aboutsummaryrefslogtreecommitdiff
path: root/src/dsp/arm/common_neon.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/dsp/arm/common_neon.h')
-rw-r--r--src/dsp/arm/common_neon.h70
1 files changed, 59 insertions, 11 deletions
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h
index dcb7567..05e0d05 100644
--- a/src/dsp/arm/common_neon.h
+++ b/src/dsp/arm/common_neon.h
@@ -28,8 +28,7 @@
#if 0
#include <cstdio>
-
-#include "absl/strings/str_cat.h"
+#include <string>
constexpr bool kEnablePrintRegs = true;  // Gates the scalar PrintReg()/PrintHex() output.
@@ -86,11 +85,11 @@ inline void PrintVectQ(const DebugRegisterQ r, const char* const name,
// Stores each int32x4_t half of |val| into a DebugRegisterQ and hands it to
// PrintVectQ, labelled "<name>.val[0]" and "<name>.val[1]" respectively.
inline void PrintReg(const int32x4x2_t val, const std::string& name) {
DebugRegisterQ r;
- vst1q_u32(r.u32, val.val[0]);
- const std::string name0 = absl::StrCat(name, ".val[0]").c_str();
+ vst1q_s32(r.i32, val.val[0]);  // Signed store matches the int32 lanes of |val|.
+ const std::string name0 = name + std::string(".val[0]");
PrintVectQ(r, name0.c_str(), 32);  // 32 is presumably the lane width in bits; confirm against PrintVectQ.
- vst1q_u32(r.u32, val.val[1]);
- const std::string name1 = absl::StrCat(name, ".val[1]").c_str();
+ vst1q_s32(r.i32, val.val[1]);  // |r| is reused for the second half.
+ const std::string name1 = name + std::string(".val[1]");
PrintVectQ(r, name1.c_str(), 32);
}
@@ -169,14 +168,14 @@ inline void PrintReg(const int8x8_t val, const char* name) {
// Print an individual (non-vector) value in decimal format, as "<name>: <x>".
// Does nothing unless kEnablePrintRegs is true.
inline void PrintReg(const int x, const char* name) {
if (kEnablePrintRegs) {
- printf("%s: %d\n", name, x);
+ fprintf(stderr, "%s: %d\n", name, x);
}
}
// Print an individual (non-vector) value in hexadecimal format, as
// "<name>: <x>" (lowercase digits, no "0x" prefix). Does nothing unless
// kEnablePrintRegs is true.
inline void PrintHex(const int x, const char* name) {
if (kEnablePrintRegs) {
- printf("%s: %x\n", name, x);
+ fprintf(stderr, "%s: %x\n", name, x);
}
}
@@ -277,22 +276,32 @@ inline void Store2(uint16_t* const buf, const uint16x4_t val) {
ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane));
}
+// Stores the four uint16_t lanes of |val| to |buf|; void* spares callers holding |buf| as uint8_t* a cast.
+inline void Store4(void* const buf, const uint16x4_t val) {
+ vst1_u16(static_cast<uint16_t*>(buf), val);
+}
+
+// Stores the eight uint16_t lanes of |val| to |buf|; void* spares callers holding |buf| as uint8_t* a cast.
+inline void Store8(void* const buf, const uint16x8_t val) {
+ vst1q_u16(static_cast<uint16_t*>(buf), val);
+}
+
//------------------------------------------------------------------------------
// Bit manipulation.
// vshXX_n_XX() requires an immediate.
// Shifts |vector| left by |shift| bits as a single 64-bit value, so bits move
// across the byte lanes rather than being shifted within each lane.
template <int shift>
-inline uint8x8_t LeftShift(const uint8x8_t vector) {
+inline uint8x8_t LeftShiftVector(const uint8x8_t vector) {
return vreinterpret_u8_u64(vshl_n_u64(vreinterpret_u64_u8(vector), shift));
}
// Shifts |vector| right by |shift| bits as a single 64-bit value, so bits move
// across the byte lanes rather than being shifted within each lane.
template <int shift>
-inline uint8x8_t RightShift(const uint8x8_t vector) {
+inline uint8x8_t RightShiftVector(const uint8x8_t vector) {
return vreinterpret_u8_u64(vshr_n_u64(vreinterpret_u64_u8(vector), shift));
}
// Signed-lane overload of the whole-vector right shift. The shift itself is
// performed on the unsigned 64-bit view (vshr_n_u64), i.e. it is a logical
// shift: vacated high bits are zero, not sign-extended.
template <int shift>
-inline int8x8_t RightShift(const int8x8_t vector) {
+inline int8x8_t RightShiftVector(const int8x8_t vector) {
return vreinterpret_s8_u64(vshr_n_u64(vreinterpret_u64_s8(vector), shift));
}
@@ -387,6 +396,15 @@ inline uint16_t SumVector(const uint8x8_t a) {
#endif // defined(__aarch64__)
}
+inline uint32_t SumVector(const uint32x2_t a) {  // Sum of both lanes of |a|, as a 32-bit value.
+#if defined(__aarch64__)
+ return vaddv_u32(a);
+#else
+ const uint64x1_t b = vpaddl_u32(a);  // Pairwise add of the two lanes, widened to 64 bits.
+ return vget_lane_u32(vreinterpret_u32_u64(b), 0);  // Lane 0 of the 32-bit view; presumably the low half of the sum (little-endian lanes), confirm.
+#endif // defined(__aarch64__)
+}
+
inline uint32_t SumVector(const uint32x4_t a) {
#if defined(__aarch64__)
return vaddvq_u32(a);
@@ -447,6 +465,36 @@ inline uint16x8x2_t VtrnqU64(uint32x4_t a0, uint32x4_t a1) {
}
// Input:
+// 00 01 02 03
+// 10 11 12 13
+// 20 21 22 23
+// 30 31 32 33
+inline void Transpose4x4(uint16x4_t a[4]) {  // Transposes the 4x4 block in place; |a| is overwritten.
+ // b:
+ // 00 10 02 12
+ // 01 11 03 13
+ const uint16x4x2_t b = vtrn_u16(a[0], a[1]);
+ // c:
+ // 20 30 22 32
+ // 21 31 23 33
+ const uint16x4x2_t c = vtrn_u16(a[2], a[3]);
+ // d:
+ // 00 10 20 30
+ // 02 12 22 32
+ const uint32x2x2_t d =
+ vtrn_u32(vreinterpret_u32_u16(b.val[0]), vreinterpret_u32_u16(c.val[0]));
+ // e:
+ // 01 11 21 31
+ // 03 13 23 33
+ const uint32x2x2_t e =
+ vtrn_u32(vreinterpret_u32_u16(b.val[1]), vreinterpret_u32_u16(c.val[1]));
+ a[0] = vreinterpret_u16_u32(d.val[0]);  // 00 10 20 30
+ a[1] = vreinterpret_u16_u32(e.val[0]);  // 01 11 21 31
+ a[2] = vreinterpret_u16_u32(d.val[1]);  // 02 12 22 32
+ a[3] = vreinterpret_u16_u32(e.val[1]);  // 03 13 23 33
+}
+
+// Input:
// a: 00 01 02 03 10 11 12 13
// b: 20 21 22 23 30 31 32 33
// Output: