Diffstat (limited to 'src/dsp/arm/common_neon.h')
-rw-r--r--  src/dsp/arm/common_neon.h | 385
1 file changed, 384 insertions, 1 deletion
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h
index 05e0d05..9c46525 100644
--- a/src/dsp/arm/common_neon.h
+++ b/src/dsp/arm/common_neon.h
@@ -23,9 +23,13 @@
 #include <arm_neon.h>

+#include <algorithm>
+#include <cstddef>
 #include <cstdint>
 #include <cstring>

+#include "src/utils/compiler_attributes.h"
+
 #if 0
 #include <cstdio>
 #include <string>
@@ -183,6 +187,20 @@ inline void PrintHex(const int x, const char* name) {
 #define PD(x) PrintReg(x, #x)
 #define PX(x) PrintHex(x, #x)

+#if LIBGAV1_MSAN
+#include <sanitizer/msan_interface.h>
+
+inline void PrintShadow(const void* r, const char* const name,
+                        const size_t size) {
+  if (kEnablePrintRegs) {
+    fprintf(stderr, "Shadow for %s:\n", name);
+    __msan_print_shadow(r, size);
+  }
+}
+#define PS(var, N) PrintShadow(var, #var, N)
+
+#endif  // LIBGAV1_MSAN
+
 #endif  // 0

 namespace libgav1 {
@@ -210,6 +228,14 @@ inline uint8x8_t Load2(const void* const buf, uint8x8_t val) {
       vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
 }

+template <int lane>
+inline uint16x4_t Load2(const void* const buf, uint16x4_t val) {
+  uint32_t temp;
+  memcpy(&temp, buf, 4);
+  return vreinterpret_u16_u32(
+      vld1_lane_u32(&temp, vreinterpret_u32_u16(val), lane));
+}
+
 // Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
 // register before loading the values. Use caution when using this in loops
 // because it will re-zero the register before loading on every iteration.
@@ -229,6 +255,96 @@ inline uint8x8_t Load4(const void* const buf, uint8x8_t val) {
       vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane));
 }

+// Convenience functions for 16-bit loads from a uint8_t* source.
+inline uint16x4_t Load4U16(const void* const buf) {
+  return vld1_u16(static_cast<const uint16_t*>(buf));
+}
+
+inline uint16x8_t Load8U16(const void* const buf) {
+  return vld1q_u16(static_cast<const uint16_t*>(buf));
+}
+
+//------------------------------------------------------------------------------
+// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning.
+
+inline uint8x8_t MaskOverreads(const uint8x8_t source,
+                               const ptrdiff_t over_read_in_bytes) {
+  uint8x8_t dst = source;
+#if LIBGAV1_MSAN
+  if (over_read_in_bytes > 0) {
+    uint8x8_t mask = vdup_n_u8(0);
+    uint8x8_t valid_element_mask = vdup_n_u8(-1);
+    const int valid_bytes =
+        std::min(8, 8 - static_cast<int>(over_read_in_bytes));
+    for (int i = 0; i < valid_bytes; ++i) {
+      // Feed ff bytes into |mask| one at a time.
+      mask = vext_u8(valid_element_mask, mask, 7);
+    }
+    dst = vand_u8(dst, mask);
+  }
+#else
+  static_cast<void>(over_read_in_bytes);
+#endif
+  return dst;
+}
+
+inline uint8x16_t MaskOverreadsQ(const uint8x16_t source,
+                                 const ptrdiff_t over_read_in_bytes) {
+  uint8x16_t dst = source;
+#if LIBGAV1_MSAN
+  if (over_read_in_bytes > 0) {
+    uint8x16_t mask = vdupq_n_u8(0);
+    uint8x16_t valid_element_mask = vdupq_n_u8(-1);
+    const int valid_bytes =
+        std::min(16, 16 - static_cast<int>(over_read_in_bytes));
+    for (int i = 0; i < valid_bytes; ++i) {
+      // Feed ff bytes into |mask| one at a time.
+      mask = vextq_u8(valid_element_mask, mask, 15);
+    }
+    dst = vandq_u8(dst, mask);
+  }
+#else
+  static_cast<void>(over_read_in_bytes);
+#endif
+  return dst;
+}
+
+inline uint8x8_t Load1MsanU8(const uint8_t* const source,
+                             const ptrdiff_t over_read_in_bytes) {
+  return MaskOverreads(vld1_u8(source), over_read_in_bytes);
+}
+
+inline uint8x16_t Load1QMsanU8(const uint8_t* const source,
+                               const ptrdiff_t over_read_in_bytes) {
+  return MaskOverreadsQ(vld1q_u8(source), over_read_in_bytes);
+}
+
+inline uint16x8_t Load1QMsanU16(const uint16_t* const source,
+                                const ptrdiff_t over_read_in_bytes) {
+  return vreinterpretq_u16_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u16(vld1q_u16(source)), over_read_in_bytes));
+}
+
+inline uint16x8x2_t Load2QMsanU16(const uint16_t* const source,
+                                  const ptrdiff_t over_read_in_bytes) {
+  // Relative source index of elements (2 bytes each):
+  // dst.val[0]: 00 02 04 06 08 10 12 14
+  // dst.val[1]: 01 03 05 07 09 11 13 15
+  uint16x8x2_t dst = vld2q_u16(source);
+  dst.val[0] = vreinterpretq_u16_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u16(dst.val[0]), over_read_in_bytes >> 1));
+  dst.val[1] = vreinterpretq_u16_u8(
+      MaskOverreadsQ(vreinterpretq_u8_u16(dst.val[1]),
+                     (over_read_in_bytes >> 1) + (over_read_in_bytes % 4)));
+  return dst;
+}
+
+inline uint32x4_t Load1QMsanU32(const uint32_t* const source,
+                                const ptrdiff_t over_read_in_bytes) {
+  return vreinterpretq_u32_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u32(vld1q_u32(source)), over_read_in_bytes));
+}
+
 //------------------------------------------------------------------------------
 // Store functions.

@@ -272,7 +388,7 @@ inline void Store2(void* const buf, const uint16x8_t val) {
 // Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t
 // register.
 template <int lane>
-inline void Store2(uint16_t* const buf, const uint16x4_t val) {
+inline void Store2(void* const buf, const uint16x4_t val) {
   ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane));
 }

@@ -287,6 +403,104 @@ inline void Store8(void* const buf, const uint16x8_t val) {
 }

 //------------------------------------------------------------------------------
+// Pointer helpers.
+
+// This function adds |stride|, given as a number of bytes, to a pointer to a
+// larger type, using native pointer arithmetic.
+template <typename T>
+inline T* AddByteStride(T* ptr, const ptrdiff_t stride) {
+  return reinterpret_cast<T*>(
+      const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(ptr) + stride));
+}
+
+//------------------------------------------------------------------------------
+// Multiply.
+
+// Shim vmull_high_u16 for armv7.
+inline uint32x4_t VMullHighU16(const uint16x8_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+  return vmull_high_u16(a, b);
+#else
+  return vmull_u16(vget_high_u16(a), vget_high_u16(b));
+#endif
+}
+
+// Shim vmull_high_s16 for armv7.
+inline int32x4_t VMullHighS16(const int16x8_t a, const int16x8_t b) {
+#if defined(__aarch64__)
+  return vmull_high_s16(a, b);
+#else
+  return vmull_s16(vget_high_s16(a), vget_high_s16(b));
+#endif
+}
+
+// Shim vmlal_high_u16 for armv7.
+inline uint32x4_t VMlalHighU16(const uint32x4_t a, const uint16x8_t b,
+                               const uint16x8_t c) {
+#if defined(__aarch64__)
+  return vmlal_high_u16(a, b, c);
+#else
+  return vmlal_u16(a, vget_high_u16(b), vget_high_u16(c));
+#endif
+}
+
+// Shim vmlal_high_s16 for armv7.
+inline int32x4_t VMlalHighS16(const int32x4_t a, const int16x8_t b,
+                              const int16x8_t c) {
+#if defined(__aarch64__)
+  return vmlal_high_s16(a, b, c);
+#else
+  return vmlal_s16(a, vget_high_s16(b), vget_high_s16(c));
+#endif
+}
+
+// Shim vmul_laneq_u16 for armv7.
+template <int lane>
+inline uint16x4_t VMulLaneQU16(const uint16x4_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+  return vmul_laneq_u16(a, b, lane);
+#else
+  if (lane < 4) return vmul_lane_u16(a, vget_low_u16(b), lane & 0x3);
+  return vmul_lane_u16(a, vget_high_u16(b), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmulq_laneq_u16 for armv7.
+template <int lane>
+inline uint16x8_t VMulQLaneQU16(const uint16x8_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+  return vmulq_laneq_u16(a, b, lane);
+#else
+  if (lane < 4) return vmulq_lane_u16(a, vget_low_u16(b), lane & 0x3);
+  return vmulq_lane_u16(a, vget_high_u16(b), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmla_laneq_u16 for armv7.
+template <int lane>
+inline uint16x4_t VMlaLaneQU16(const uint16x4_t a, const uint16x4_t b,
+                               const uint16x8_t c) {
+#if defined(__aarch64__)
+  return vmla_laneq_u16(a, b, c, lane);
+#else
+  if (lane < 4) return vmla_lane_u16(a, b, vget_low_u16(c), lane & 0x3);
+  return vmla_lane_u16(a, b, vget_high_u16(c), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmlaq_laneq_u16 for armv7.
+template <int lane>
+inline uint16x8_t VMlaQLaneQU16(const uint16x8_t a, const uint16x8_t b,
+                                const uint16x8_t c) {
+#if defined(__aarch64__)
+  return vmlaq_laneq_u16(a, b, c, lane);
+#else
+  if (lane < 4) return vmlaq_lane_u16(a, b, vget_low_u16(c), lane & 0x3);
+  return vmlaq_lane_u16(a, b, vget_high_u16(c), (lane - 4) & 0x3);
+#endif
+}
+
+//------------------------------------------------------------------------------
 // Bit manipulation.

 // vshXX_n_XX() requires an immediate.
@@ -315,6 +529,51 @@ inline uint8x8_t VQTbl1U8(const uint8x16_t a, const uint8x8_t index) {
 #endif
 }

+// Shim vqtbl2_u8 for armv7.
+inline uint8x8_t VQTbl2U8(const uint8x16x2_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+  return vqtbl2_u8(a, index);
+#else
+  const uint8x8x4_t b = {vget_low_u8(a.val[0]), vget_high_u8(a.val[0]),
+                         vget_low_u8(a.val[1]), vget_high_u8(a.val[1])};
+  return vtbl4_u8(b, index);
+#endif
+}
+
+// Shim vqtbl2q_u8 for armv7.
+inline uint8x16_t VQTbl2QU8(const uint8x16x2_t a, const uint8x16_t index) {
+#if defined(__aarch64__)
+  return vqtbl2q_u8(a, index);
+#else
+  return vcombine_u8(VQTbl2U8(a, vget_low_u8(index)),
+                     VQTbl2U8(a, vget_high_u8(index)));
+#endif
+}
+
+// Shim vqtbl3_u8 for armv7.
+inline uint8x8_t VQTbl3U8(const uint8x16x3_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+  return vqtbl3_u8(a, index);
+#else
+  const uint8x8x4_t b = {vget_low_u8(a.val[0]), vget_high_u8(a.val[0]),
+                         vget_low_u8(a.val[1]), vget_high_u8(a.val[1])};
+  const uint8x8x2_t c = {vget_low_u8(a.val[2]), vget_high_u8(a.val[2])};
+  const uint8x8_t index_ext = vsub_u8(index, vdup_n_u8(32));
+  const uint8x8_t partial_lookup = vtbl4_u8(b, index);
+  return vtbx2_u8(partial_lookup, c, index_ext);
+#endif
+}
+
+// Shim vqtbl3q_u8 for armv7.
+inline uint8x16_t VQTbl3QU8(const uint8x16x3_t a, const uint8x16_t index) {
+#if defined(__aarch64__)
+  return vqtbl3q_u8(a, index);
+#else
+  return vcombine_u8(VQTbl3U8(a, vget_low_u8(index)),
+                     VQTbl3U8(a, vget_high_u8(index)));
+#endif
+}
+
 // Shim vqtbl1_s8 for armv7.
 inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
 #if defined(__aarch64__)
@@ -326,6 +585,25 @@ inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
 }

 //------------------------------------------------------------------------------
+// Saturation helpers.
+
+inline int16x4_t Clip3S16(int16x4_t val, int16x4_t low, int16x4_t high) {
+  return vmin_s16(vmax_s16(val, low), high);
+}
+
+inline int16x8_t Clip3S16(const int16x8_t val, const int16x8_t low,
+                          const int16x8_t high) {
+  return vminq_s16(vmaxq_s16(val, low), high);
+}
+
+inline uint16x8_t ConvertToUnsignedPixelU16(int16x8_t val, int bitdepth) {
+  const int16x8_t low = vdupq_n_s16(0);
+  const uint16x8_t high = vdupq_n_u16((1 << bitdepth) - 1);
+
+  return vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(val, low)), high);
+}
+
+//------------------------------------------------------------------------------
 // Interleave.

 // vzipN is exclusive to A64.
@@ -439,6 +717,9 @@ inline uint8x8_t Transpose32(const uint8x8_t a) {
   return vreinterpret_u8_u32(b);
 }

+// Swap high and low halves.
+inline uint16x8_t Transpose64(const uint16x8_t a) { return vextq_u16(a, a, 4); }
+
 // Implement vtrnq_s64().
 // Input:
 // a0: 00 01 02 03 04 05 06 07
@@ -512,6 +793,108 @@ inline void Transpose4x4(uint8x8_t* a, uint8x8_t* b) {
   *b = e.val[1];
 }

+// 4x8 Input:
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// 8x4 Output:
+// a[0]: 00 10 20 30 04 14 24 34
+// a[1]: 01 11 21 31 05 15 25 35
+// a[2]: 02 12 22 32 06 16 26 36
+// a[3]: 03 13 23 33 07 17 27 37
+inline void Transpose4x8(uint16x8_t a[4]) {
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // b0.val[1]: 01 11 03 13 05 15 07 17
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // b1.val[1]: 21 31 23 33 25 35 27 37
+  const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+  const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+
+  // c0.val[0]: 00 10 20 30 04 14 24 34
+  // c0.val[1]: 02 12 22 32 06 16 26 36
+  // c1.val[0]: 01 11 21 31 05 15 25 35
+  // c1.val[1]: 03 13 23 33 07 17 27 37
+  const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]),
+                                    vreinterpretq_u32_u16(b1.val[0]));
+  const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]),
+                                    vreinterpretq_u32_u16(b1.val[1]));
+
+  a[0] = vreinterpretq_u16_u32(c0.val[0]);
+  a[1] = vreinterpretq_u16_u32(c1.val[0]);
+  a[2] = vreinterpretq_u16_u32(c0.val[1]);
+  a[3] = vreinterpretq_u16_u32(c1.val[1]);
+}
+
+// Special transpose for loop filter.
+// 4x8 Input:
+// p_q:  p3 p2 p1 p0 q0 q1 q2 q3
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// 8x4 Output:
+// a[0]: 03 13 23 33 04 14 24 34  p0q0
+// a[1]: 02 12 22 32 05 15 25 35  p1q1
+// a[2]: 01 11 21 31 06 16 26 36  p2q2
+// a[3]: 00 10 20 30 07 17 27 37  p3q3
+// Direct reapplication of the function will reset the high halves, but
+// reverse the low halves:
+// p_q:  p0 p1 p2 p3 q0 q1 q2 q3
+// a[0]: 33 32 31 30 04 05 06 07
+// a[1]: 23 22 21 20 14 15 16 17
+// a[2]: 13 12 11 10 24 25 26 27
+// a[3]: 03 02 01 00 34 35 36 37
+// Simply reordering the inputs (3, 2, 1, 0) will reset the low halves, but
+// reverse the high halves.
+// The standard Transpose4x8 will produce the same reversals, but with the
+// order of the low halves also restored relative to the high halves. This is
+// preferable because it puts all values from the same source row back together,
+// but some post-processing is inevitable.
+inline void LoopFilterTranspose4x8(uint16x8_t a[4]) {
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // b0.val[1]: 01 11 03 13 05 15 07 17
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // b1.val[1]: 21 31 23 33 25 35 27 37
+  const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+  const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+
+  // Reverse odd vectors to bring the appropriate items to the front of zips.
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // r0       : 03 13 01 11 07 17 05 15
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // r1       : 23 33 21 31 27 37 25 35
+  const uint32x4_t r0 = vrev64q_u32(vreinterpretq_u32_u16(b0.val[1]));
+  const uint32x4_t r1 = vrev64q_u32(vreinterpretq_u32_u16(b1.val[1]));
+
+  // Zip to complete the halves.
+  // c0.val[0]: 00 10 20 30 02 12 22 32  p3p1
+  // c0.val[1]: 04 14 24 34 06 16 26 36  q0q2
+  // c1.val[0]: 03 13 23 33 01 11 21 31  p0p2
+  // c1.val[1]: 07 17 27 37 05 15 25 35  q3q1
+  const uint32x4x2_t c0 = vzipq_u32(vreinterpretq_u32_u16(b0.val[0]),
+                                    vreinterpretq_u32_u16(b1.val[0]));
+  const uint32x4x2_t c1 = vzipq_u32(r0, r1);
+
+  // d0.val[0]: 00 10 20 30 07 17 27 37  p3q3
+  // d0.val[1]: 02 12 22 32 05 15 25 35  p1q1
+  // d1.val[0]: 03 13 23 33 04 14 24 34  p0q0
+  // d1.val[1]: 01 11 21 31 06 16 26 36  p2q2
+  const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c1.val[1]);
+  // The third row of c comes first here to swap p2 with q0.
+  const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c0.val[1]);
+
+  // 8x4 Output:
+  // a[0]: 03 13 23 33 04 14 24 34  p0q0
+  // a[1]: 02 12 22 32 05 15 25 35  p1q1
+  // a[2]: 01 11 21 31 06 16 26 36  p2q2
+  // a[3]: 00 10 20 30 07 17 27 37  p3q3
+  a[0] = d1.val[0];  // p0q0
+  a[1] = d0.val[1];  // p1q1
+  a[2] = d1.val[1];  // p2q2
+  a[3] = d0.val[0];  // p3q3
+}
+
 // Reversible if the x4 values are packed next to each other.
 // x4 input / x8 output:
 // a0: 00 01 02 03 40 41 42 43 44
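
The MaskOverreads()/Load*Msan*() additions exist so SIMD loops can keep issuing full-width vector loads past the end of a row without tripping MemorySanitizer: the tail bytes that were read but never valid are zeroed under MSAN and left untouched otherwise. Below is a minimal usage sketch, not taken from libgav1 itself; the loop, the CopyRowMsan name, and the way over_read_in_bytes is computed are illustrative, and it assumes a NEON-enabled build where the helpers are visible through the libgav1::dsp namespace used elsewhere in the header.

#include <arm_neon.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "src/dsp/arm/common_neon.h"

// Hypothetical example: copy |width| pixels of one row with 16-byte loads.
// Both rows are assumed to be padded to a multiple of 16 bytes, so reading
// and writing past |width| is safe at runtime; the over-read count only
// tells MSan which tail bytes are intentionally unused.
inline void CopyRowMsan(const uint8_t* src, uint8_t* dst, const int width) {
  int x = 0;
  do {
    // Bytes this 16-byte load reaches beyond the last valid pixel.
    const ptrdiff_t over_read_in_bytes = std::max(0, x + 16 - width);
    const uint8x16_t v =
        libgav1::dsp::Load1QMsanU8(src + x, over_read_in_bytes);
    vst1q_u8(dst + x, v);
    x += 16;
  } while (x < width);
}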
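AddByteStride() addresses a recurring high-bitdepth pitfall: strides in the codebase are byte counts, while 10/12-bit pixel pointers are uint16_t*, so plain dst += stride would advance twice as far as intended. A short illustrative sketch (the ZeroFirstColumn name and parameters are not part of the header):

#include <cstddef>
#include <cstdint>

#include "src/dsp/arm/common_neon.h"

// Hypothetical example: walk down a column of 16-bit pixels one row at a
// time, where |stride| is expressed in bytes.
inline void ZeroFirstColumn(uint16_t* dst, const ptrdiff_t stride,
                            const int height) {
  for (int y = 0; y < height; ++y) {
    dst[0] = 0;
    dst = libgav1::dsp::AddByteStride(dst, stride);  // Not dst += stride.
  }
}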
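The *laneq* shims (VMulLaneQU16 and friends) emulate the aarch64 Q-register lane intrinsics on armv7 by splitting the 128-bit operand into halves and remapping the lane index. A quick, hypothetical sanity check of that mapping, compilable on either target (the CheckVMulLaneQU16 name and test values are illustrative):

#include <arm_neon.h>

#include <cassert>
#include <cstdint>

#include "src/dsp/arm/common_neon.h"

// Multiplying by lane 5 of |b| must match scalar multiplication by b[5],
// whether vmul_laneq_u16 is available or the low/high split is used.
inline void CheckVMulLaneQU16() {
  const uint16_t a_src[4] = {1, 2, 3, 4};
  const uint16_t b_src[8] = {10, 20, 30, 40, 50, 60, 70, 80};
  const uint16x4_t a = vld1_u16(a_src);
  const uint16x8_t b = vld1q_u16(b_src);
  const uint16x4_t c = libgav1::dsp::VMulLaneQU16<5>(a, b);  // lane 5 == 60
  uint16_t out[4];
  vst1_u16(out, c);
  for (int i = 0; i < 4; ++i) assert(out[i] == a_src[i] * b_src[5]);
}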
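Clip3S16() and ConvertToUnsignedPixelU16() cover the two clamps that recur in high-bitdepth filtering: keeping a signed intermediate inside an arbitrary range, and pinning a final sum to [0, (1 << bitdepth) - 1] while converting to unsigned. A small sketch of both on a 10-bit path; the wrapper names and the threshold parameter are illustrative only:

#include <arm_neon.h>

#include <cstdint>

#include "src/dsp/arm/common_neon.h"

// Hypothetical post-filter step: clamp a signed 16-bit sum to the 10-bit
// pixel range before storing it as unsigned.
inline uint16x8_t ClampTo10Bit(const int16x8_t sum) {
  return libgav1::dsp::ConvertToUnsignedPixelU16(sum, /*bitdepth=*/10);
}

// Hypothetical in-filter step: restrict a correction to [-thresh, thresh].
inline int16x8_t ClampCorrection(const int16x8_t value, const int16_t thresh) {
  return libgav1::dsp::Clip3S16(value, vdupq_n_s16(-thresh),
                                vdupq_n_s16(thresh));
}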
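Transpose4x8() is the 16-bit building block for turning four rows into column-major data, e.g. when a vertical pass wants to reuse a horizontal kernel. The sketch below (buffer layout and the TransposeBlock4x8 name are assumptions, not from the header) loads a 4x8 block row by row, transposes it in registers, and stores the result; per the output diagram in the diff, each output register then holds source columns i and i + 4.

#include <arm_neon.h>

#include <cstdint>

#include "src/dsp/arm/common_neon.h"

// Hypothetical example: transpose a 4x8 block of 16-bit values stored row by
// row in |src|. After the call, rows[i] holds columns i and i + 4 of the
// source block.
inline void TransposeBlock4x8(const uint16_t src[32], uint16_t dst[32]) {
  uint16x8_t rows[4];
  for (int i = 0; i < 4; ++i) rows[i] = vld1q_u16(src + i * 8);
  libgav1::dsp::Transpose4x8(rows);
  for (int i = 0; i < 4; ++i) vst1q_u16(dst + i * 8, rows[i]);
}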