diff options
Diffstat (limited to 'src/dsp/arm/common_neon.h')
-rw-r--r-- | src/dsp/arm/common_neon.h | 70 |
1 files changed, 59 insertions, 11 deletions
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h index dcb7567..05e0d05 100644 --- a/src/dsp/arm/common_neon.h +++ b/src/dsp/arm/common_neon.h @@ -28,8 +28,7 @@ #if 0 #include <cstdio> - -#include "absl/strings/str_cat.h" +#include <string> constexpr bool kEnablePrintRegs = true; @@ -86,11 +85,11 @@ inline void PrintVectQ(const DebugRegisterQ r, const char* const name, inline void PrintReg(const int32x4x2_t val, const std::string& name) { DebugRegisterQ r; - vst1q_u32(r.u32, val.val[0]); - const std::string name0 = absl::StrCat(name, ".val[0]").c_str(); + vst1q_s32(r.i32, val.val[0]); + const std::string name0 = name + std::string(".val[0]"); PrintVectQ(r, name0.c_str(), 32); - vst1q_u32(r.u32, val.val[1]); - const std::string name1 = absl::StrCat(name, ".val[1]").c_str(); + vst1q_s32(r.i32, val.val[1]); + const std::string name1 = name + std::string(".val[1]"); PrintVectQ(r, name1.c_str(), 32); } @@ -169,14 +168,14 @@ inline void PrintReg(const int8x8_t val, const char* name) { // Print an individual (non-vector) value in decimal format. inline void PrintReg(const int x, const char* name) { if (kEnablePrintRegs) { - printf("%s: %d\n", name, x); + fprintf(stderr, "%s: %d\n", name, x); } } // Print an individual (non-vector) value in hexadecimal format. inline void PrintHex(const int x, const char* name) { if (kEnablePrintRegs) { - printf("%s: %x\n", name, x); + fprintf(stderr, "%s: %x\n", name, x); } } @@ -277,22 +276,32 @@ inline void Store2(uint16_t* const buf, const uint16x4_t val) { ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane)); } +// Simplify code when caller has |buf| cast as uint8_t*. +inline void Store4(void* const buf, const uint16x4_t val) { + vst1_u16(static_cast<uint16_t*>(buf), val); +} + +// Simplify code when caller has |buf| cast as uint8_t*. +inline void Store8(void* const buf, const uint16x8_t val) { + vst1q_u16(static_cast<uint16_t*>(buf), val); +} + //------------------------------------------------------------------------------ // Bit manipulation. // vshXX_n_XX() requires an immediate. template <int shift> -inline uint8x8_t LeftShift(const uint8x8_t vector) { +inline uint8x8_t LeftShiftVector(const uint8x8_t vector) { return vreinterpret_u8_u64(vshl_n_u64(vreinterpret_u64_u8(vector), shift)); } template <int shift> -inline uint8x8_t RightShift(const uint8x8_t vector) { +inline uint8x8_t RightShiftVector(const uint8x8_t vector) { return vreinterpret_u8_u64(vshr_n_u64(vreinterpret_u64_u8(vector), shift)); } template <int shift> -inline int8x8_t RightShift(const int8x8_t vector) { +inline int8x8_t RightShiftVector(const int8x8_t vector) { return vreinterpret_s8_u64(vshr_n_u64(vreinterpret_u64_s8(vector), shift)); } @@ -387,6 +396,15 @@ inline uint16_t SumVector(const uint8x8_t a) { #endif // defined(__aarch64__) } +inline uint32_t SumVector(const uint32x2_t a) { +#if defined(__aarch64__) + return vaddv_u32(a); +#else + const uint64x1_t b = vpaddl_u32(a); + return vget_lane_u32(vreinterpret_u32_u64(b), 0); +#endif // defined(__aarch64__) +} + inline uint32_t SumVector(const uint32x4_t a) { #if defined(__aarch64__) return vaddvq_u32(a); @@ -447,6 +465,36 @@ inline uint16x8x2_t VtrnqU64(uint32x4_t a0, uint32x4_t a1) { } // Input: +// 00 01 02 03 +// 10 11 12 13 +// 20 21 22 23 +// 30 31 32 33 +inline void Transpose4x4(uint16x4_t a[4]) { + // b: + // 00 10 02 12 + // 01 11 03 13 + const uint16x4x2_t b = vtrn_u16(a[0], a[1]); + // c: + // 20 30 22 32 + // 21 31 23 33 + const uint16x4x2_t c = vtrn_u16(a[2], a[3]); + // d: + // 00 10 20 30 + // 02 12 22 32 + const uint32x2x2_t d = + vtrn_u32(vreinterpret_u32_u16(b.val[0]), vreinterpret_u32_u16(c.val[0])); + // e: + // 01 11 21 31 + // 03 13 23 33 + const uint32x2x2_t e = + vtrn_u32(vreinterpret_u32_u16(b.val[1]), vreinterpret_u32_u16(c.val[1])); + a[0] = vreinterpret_u16_u32(d.val[0]); + a[1] = vreinterpret_u16_u32(e.val[0]); + a[2] = vreinterpret_u16_u32(d.val[1]); + a[3] = vreinterpret_u16_u32(e.val[1]); +} + +// Input: // a: 00 01 02 03 10 11 12 13 // b: 20 21 22 23 30 31 32 33 // Output: |