Diffstat (limited to 'src/dsp/arm/common_neon.h')
 src/dsp/arm/common_neon.h | 385 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 384 insertions(+), 1 deletion(-)
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h
index 05e0d05..9c46525 100644
--- a/src/dsp/arm/common_neon.h
+++ b/src/dsp/arm/common_neon.h
@@ -23,9 +23,13 @@
#include <arm_neon.h>
+#include <algorithm>
+#include <cstddef>
#include <cstdint>
#include <cstring>
+#include "src/utils/compiler_attributes.h"
+
#if 0
#include <cstdio>
#include <string>
@@ -183,6 +187,20 @@ inline void PrintHex(const int x, const char* name) {
#define PD(x) PrintReg(x, #x)
#define PX(x) PrintHex(x, #x)
+#if LIBGAV1_MSAN
+#include <sanitizer/msan_interface.h>
+
+inline void PrintShadow(const void* r, const char* const name,
+ const size_t size) {
+ if (kEnablePrintRegs) {
+ fprintf(stderr, "Shadow for %s:\n", name);
+ __msan_print_shadow(r, size);
+ }
+}
+#define PS(var, N) PrintShadow(var, #var, N)
+
+#endif // LIBGAV1_MSAN
+
#endif // 0
namespace libgav1 {
@@ -210,6 +228,14 @@ inline uint8x8_t Load2(const void* const buf, uint8x8_t val) {
vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
}
+template <int lane>
+inline uint16x4_t Load2(const void* const buf, uint16x4_t val) {
+ uint32_t temp;
+ memcpy(&temp, buf, 4);
+ return vreinterpret_u16_u32(
+ vld1_lane_u32(&temp, vreinterpret_u32_u16(val), lane));
+}
+
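// Usage sketch for the overload above; the function name and the |row0| and
// |row1| pointers (each with at least two valid uint16_t values) are
// hypothetical.
inline uint16x4_t Gather2x2Sketch(const uint16_t* row0, const uint16_t* row1) {
  uint16x4_t v = vdup_n_u16(0);
  v = Load2<0>(row0, v);  // Lanes 0-1 <- row0[0], row0[1].
  v = Load2<1>(row1, v);  // Lanes 2-3 <- row1[0], row1[1].
  return v;
}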
// Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
// register before loading the values. Use caution when using this in loops
// because it will re-zero the register before loading on every iteration.
@@ -229,6 +255,96 @@ inline uint8x8_t Load4(const void* const buf, uint8x8_t val) {
vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane));
}
+// Convenience functions for 16-bit loads from a uint8_t* source.
+inline uint16x4_t Load4U16(const void* const buf) {
+ return vld1_u16(static_cast<const uint16_t*>(buf));
+}
+
+inline uint16x8_t Load8U16(const void* const buf) {
+ return vld1q_u16(static_cast<const uint16_t*>(buf));
+}
+
+//------------------------------------------------------------------------------
+// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning.
+
+inline uint8x8_t MaskOverreads(const uint8x8_t source,
+ const ptrdiff_t over_read_in_bytes) {
+ uint8x8_t dst = source;
+#if LIBGAV1_MSAN
+ if (over_read_in_bytes > 0) {
+ uint8x8_t mask = vdup_n_u8(0);
+ uint8x8_t valid_element_mask = vdup_n_u8(-1);
+ const int valid_bytes =
+ std::min(8, 8 - static_cast<int>(over_read_in_bytes));
+ for (int i = 0; i < valid_bytes; ++i) {
+ // Feed ff bytes into |mask| one at a time.
+ mask = vext_u8(valid_element_mask, mask, 7);
+ }
+ dst = vand_u8(dst, mask);
+ }
+#else
+ static_cast<void>(over_read_in_bytes);
+#endif
+ return dst;
+}
+
+inline uint8x16_t MaskOverreadsQ(const uint8x16_t source,
+ const ptrdiff_t over_read_in_bytes) {
+ uint8x16_t dst = source;
+#if LIBGAV1_MSAN
+ if (over_read_in_bytes > 0) {
+ uint8x16_t mask = vdupq_n_u8(0);
+ uint8x16_t valid_element_mask = vdupq_n_u8(-1);
+ const int valid_bytes =
+ std::min(16, 16 - static_cast<int>(over_read_in_bytes));
+ for (int i = 0; i < valid_bytes; ++i) {
+ // Feed ff bytes into |mask| one at a time.
+ mask = vextq_u8(valid_element_mask, mask, 15);
+ }
+ dst = vandq_u8(dst, mask);
+ }
+#else
+ static_cast<void>(over_read_in_bytes);
+#endif
+ return dst;
+}
+
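// Behavior sketch: with |over_read_in_bytes| == 3, the loop above builds the
// mask ff ff ff ff ff 00 00 00, so the three overread lanes are forced to
// zero. A hypothetical caller loading 8 bytes from a 5-byte-wide row:
//   const uint8x8_t raw = vld1_u8(row);  // bytes 5..7 are uninitialized
//   const uint8x8_t safe = MaskOverreads(raw, /*over_read_in_bytes=*/3);
//   // |safe| keeps row[0..4] in lanes 0..4 and holds zero in lanes 5..7.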
+inline uint8x8_t Load1MsanU8(const uint8_t* const source,
+ const ptrdiff_t over_read_in_bytes) {
+ return MaskOverreads(vld1_u8(source), over_read_in_bytes);
+}
+
+inline uint8x16_t Load1QMsanU8(const uint8_t* const source,
+ const ptrdiff_t over_read_in_bytes) {
+ return MaskOverreadsQ(vld1q_u8(source), over_read_in_bytes);
+}
+
+inline uint16x8_t Load1QMsanU16(const uint16_t* const source,
+ const ptrdiff_t over_read_in_bytes) {
+ return vreinterpretq_u16_u8(MaskOverreadsQ(
+ vreinterpretq_u8_u16(vld1q_u16(source)), over_read_in_bytes));
+}
+
+inline uint16x8x2_t Load2QMsanU16(const uint16_t* const source,
+ const ptrdiff_t over_read_in_bytes) {
+ // Relative source index of elements (2 bytes each):
+ // dst.val[0]: 00 02 04 06 08 10 12 14
+ // dst.val[1]: 01 03 05 07 09 11 13 15
+ uint16x8x2_t dst = vld2q_u16(source);
+ dst.val[0] = vreinterpretq_u16_u8(MaskOverreadsQ(
+ vreinterpretq_u8_u16(dst.val[0]), over_read_in_bytes >> 1));
+ dst.val[1] = vreinterpretq_u16_u8(
+ MaskOverreadsQ(vreinterpretq_u8_u16(dst.val[1]),
+ (over_read_in_bytes >> 1) + (over_read_in_bytes % 4)));
+ return dst;
+}
+
+inline uint32x4_t Load1QMsanU32(const uint32_t* const source,
+ const ptrdiff_t over_read_in_bytes) {
+ return vreinterpretq_u32_u8(MaskOverreadsQ(
+ vreinterpretq_u8_u32(vld1q_u32(source)), over_read_in_bytes));
+}
+
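// Call-site sketch for the wrappers above, assuming |src| and |dst| rows are
// padded so 8-lane loads and stores past |width| stay within the allocation;
// the function and parameter names are hypothetical.
inline void CopyRowMsanSketch(const uint16_t* src, uint16_t* dst,
                              const int width) {
  int x = 0;
  do {
    // Bytes past the last valid pixel are read but masked to zero.
    const ptrdiff_t over_read_in_bytes =
        sizeof(uint16_t) * std::max(0, x + 8 - width);
    vst1q_u16(dst + x, Load1QMsanU16(src + x, over_read_in_bytes));
    x += 8;
  } while (x < width);
}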
//------------------------------------------------------------------------------
// Store functions.
@@ -272,7 +388,7 @@ inline void Store2(void* const buf, const uint16x8_t val) {
// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t
// register.
template <int lane>
-inline void Store2(uint16_t* const buf, const uint16x4_t val) {
+inline void Store2(void* const buf, const uint16x4_t val) {
ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane));
}
@@ -287,6 +403,104 @@ inline void Store8(void* const buf, const uint16x8_t val) {
}
//------------------------------------------------------------------------------
+// Pointer helpers.
+
+// This function adds |stride|, given as a number of bytes, to a pointer to a
+// larger type, using native pointer arithmetic.
+template <typename T>
+inline T* AddByteStride(T* ptr, const ptrdiff_t stride) {
+ return reinterpret_cast<T*>(
+ const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(ptr) + stride));
+}
+
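// Usage sketch: high-bitdepth code often walks uint16_t rows with a stride
// given in bytes. |dst|, |stride| and |height| are hypothetical.
//   uint16_t* row = dst;
//   for (int y = 0; y < height; ++y) {
//     vst1q_u16(row, vdupq_n_u16(0));
//     row = AddByteStride(row, stride);  // no casting at the call site
//   }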
+//------------------------------------------------------------------------------
+// Multiply.
+
+// Shim vmull_high_u16 for armv7.
+inline uint32x4_t VMullHighU16(const uint16x8_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+ return vmull_high_u16(a, b);
+#else
+ return vmull_u16(vget_high_u16(a), vget_high_u16(b));
+#endif
+}
+
+// Shim vmull_high_s16 for armv7.
+inline int32x4_t VMullHighS16(const int16x8_t a, const int16x8_t b) {
+#if defined(__aarch64__)
+ return vmull_high_s16(a, b);
+#else
+ return vmull_s16(vget_high_s16(a), vget_high_s16(b));
+#endif
+}
+
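// Usage sketch: paired with vmull_u16, the shim widens all 8 lanes of a
// multiply; |a| and |b| are hypothetical uint16x8_t values.
//   const uint32x4_t lo = vmull_u16(vget_low_u16(a), vget_low_u16(b));
//   const uint32x4_t hi = VMullHighU16(a, b);  // one instruction on A64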
+// Shim vmlal_high_u16 for armv7.
+inline uint32x4_t VMlalHighU16(const uint32x4_t a, const uint16x8_t b,
+ const uint16x8_t c) {
+#if defined(__aarch64__)
+ return vmlal_high_u16(a, b, c);
+#else
+ return vmlal_u16(a, vget_high_u16(b), vget_high_u16(c));
+#endif
+}
+
+// Shim vmlal_high_s16 for armv7.
+inline int32x4_t VMlalHighS16(const int32x4_t a, const int16x8_t b,
+ const int16x8_t c) {
+#if defined(__aarch64__)
+ return vmlal_high_s16(a, b, c);
+#else
+ return vmlal_s16(a, vget_high_s16(b), vget_high_s16(c));
+#endif
+}
+
+// Shim vmul_laneq_u16 for armv7.
+template <int lane>
+inline uint16x4_t VMulLaneQU16(const uint16x4_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+ return vmul_laneq_u16(a, b, lane);
+#else
+ if (lane < 4) return vmul_lane_u16(a, vget_low_u16(b), lane & 0x3);
+ return vmul_lane_u16(a, vget_high_u16(b), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmulq_laneq_u16 for armv7.
+template <int lane>
+inline uint16x8_t VMulQLaneQU16(const uint16x8_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+ return vmulq_laneq_u16(a, b, lane);
+#else
+ if (lane < 4) return vmulq_lane_u16(a, vget_low_u16(b), lane & 0x3);
+ return vmulq_lane_u16(a, vget_high_u16(b), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmla_laneq_u16 for armv7.
+template <int lane>
+inline uint16x4_t VMlaLaneQU16(const uint16x4_t a, const uint16x4_t b,
+ const uint16x8_t c) {
+#if defined(__aarch64__)
+ return vmla_laneq_u16(a, b, c, lane);
+#else
+ if (lane < 4) return vmla_lane_u16(a, b, vget_low_u16(c), lane & 0x3);
+ return vmla_lane_u16(a, b, vget_high_u16(c), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmlaq_laneq_u16 for armv7.
+template <int lane>
+inline uint16x8_t VMlaQLaneQU16(const uint16x8_t a, const uint16x8_t b,
+ const uint16x8_t c) {
+#if defined(__aarch64__)
+ return vmlaq_laneq_u16(a, b, c, lane);
+#else
+ if (lane < 4) return vmlaq_lane_u16(a, b, vget_low_u16(c), lane & 0x3);
+ return vmlaq_lane_u16(a, b, vget_high_u16(c), (lane - 4) & 0x3);
+#endif
+}
+
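// Usage sketch: |lane| must be a template parameter because the underlying
// instructions encode the lane as an immediate. A hypothetical two-tap
// filter with both coefficients held in one register:
inline uint16x8_t ApplyTwoTapsSketch(const uint16x8_t src0,
                                     const uint16x8_t src1,
                                     const uint16x8_t taps) {
  const uint16x8_t sum = VMulQLaneQU16<0>(src0, taps);
  return VMlaQLaneQU16<1>(sum, src1, taps);
}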
+//------------------------------------------------------------------------------
// Bit manipulation.
// vshXX_n_XX() requires an immediate.
@@ -315,6 +529,51 @@ inline uint8x8_t VQTbl1U8(const uint8x16_t a, const uint8x8_t index) {
#endif
}
+// Shim vqtbl2_u8 for armv7.
+inline uint8x8_t VQTbl2U8(const uint8x16x2_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+ return vqtbl2_u8(a, index);
+#else
+ const uint8x8x4_t b = {vget_low_u8(a.val[0]), vget_high_u8(a.val[0]),
+ vget_low_u8(a.val[1]), vget_high_u8(a.val[1])};
+ return vtbl4_u8(b, index);
+#endif
+}
+
+// Shim vqtbl2q_u8 for armv7.
+inline uint8x16_t VQTbl2QU8(const uint8x16x2_t a, const uint8x16_t index) {
+#if defined(__aarch64__)
+ return vqtbl2q_u8(a, index);
+#else
+ return vcombine_u8(VQTbl2U8(a, vget_low_u8(index)),
+ VQTbl2U8(a, vget_high_u8(index)));
+#endif
+}
+
+// Shim vqtbl3_u8 for armv7.
+inline uint8x8_t VQTbl3U8(const uint8x16x3_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+ return vqtbl3_u8(a, index);
+#else
+ const uint8x8x4_t b = {vget_low_u8(a.val[0]), vget_high_u8(a.val[0]),
+ vget_low_u8(a.val[1]), vget_high_u8(a.val[1])};
+ const uint8x8x2_t c = {vget_low_u8(a.val[2]), vget_high_u8(a.val[2])};
+ const uint8x8_t index_ext = vsub_u8(index, vdup_n_u8(32));
+ const uint8x8_t partial_lookup = vtbl4_u8(b, index);
+ return vtbx2_u8(partial_lookup, c, index_ext);
+#endif
+}
+
+// Shim vqtbl3q_u8 for armv7.
+inline uint8x16_t VQTbl3QU8(const uint8x16x3_t a, const uint8x16_t index) {
+#if defined(__aarch64__)
+ return vqtbl3q_u8(a, index);
+#else
+ return vcombine_u8(VQTbl3U8(a, vget_low_u8(index)),
+ VQTbl3U8(a, vget_high_u8(index)));
+#endif
+}
+
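// Usage sketch: the shims extend single-register TBL semantics to 32- and
// 48-byte tables, with out-of-range indices yielding zero. |lut0| and |lut1|
// are hypothetical uint8x16_t table halves.
//   const uint8x16x2_t table = {lut0, lut1};
//   const uint8x8_t idx = vcreate_u8(0xff20100c08040200ULL);
//   // Lanes 0-5 read entries 0, 2, 4, 8, 12 and 16; the last two indices
//   // (0x20 and 0xff) are out of range, so those lanes are zero.
//   const uint8x8_t out = VQTbl2U8(table, idx);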
// Shim vqtbl1_s8 for armv7.
inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
#if defined(__aarch64__)
@@ -326,6 +585,25 @@ inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
}
//------------------------------------------------------------------------------
+// Saturation helpers.
+
+inline int16x4_t Clip3S16(const int16x4_t val, const int16x4_t low,
+                          const int16x4_t high) {
+ return vmin_s16(vmax_s16(val, low), high);
+}
+
+inline int16x8_t Clip3S16(const int16x8_t val, const int16x8_t low,
+ const int16x8_t high) {
+ return vminq_s16(vmaxq_s16(val, low), high);
+}
+
+inline uint16x8_t ConvertToUnsignedPixelU16(int16x8_t val, int bitdepth) {
+ const int16x8_t low = vdupq_n_s16(0);
+ const uint16x8_t high = vdupq_n_u16((1 << bitdepth) - 1);
+
+ return vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(val, low)), high);
+}
+
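// Usage sketch: the typical last step of a high-bitdepth filter clamps a
// widened sum back to the pixel range. |sum| and |dst| are hypothetical.
//   const uint16x8_t pixel = ConvertToUnsignedPixelU16(sum, /*bitdepth=*/10);
//   vst1q_u16(dst, pixel);  // pixels are now in [0, 1023]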
+//------------------------------------------------------------------------------
// Interleave.
// vzipN is exclusive to A64.
@@ -439,6 +717,9 @@ inline uint8x8_t Transpose32(const uint8x8_t a) {
return vreinterpret_u8_u32(b);
}
+// Swap high and low halves.
+inline uint16x8_t Transpose64(const uint16x8_t a) { return vextq_u16(a, a, 4); }
+
// Implement vtrnq_s64().
// Input:
// a0: 00 01 02 03 04 05 06 07
@@ -512,6 +793,108 @@ inline void Transpose4x4(uint8x8_t* a, uint8x8_t* b) {
*b = e.val[1];
}
+// 4x8 Input:
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// 8x4 Output:
+// a[0]: 00 10 20 30 04 14 24 34
+// a[1]: 01 11 21 31 05 15 25 35
+// a[2]: 02 12 22 32 06 16 26 36
+// a[3]: 03 13 23 33 07 17 27 37
+inline void Transpose4x8(uint16x8_t a[4]) {
+ // b0.val[0]: 00 10 02 12 04 14 06 16
+ // b0.val[1]: 01 11 03 13 05 15 07 17
+ // b1.val[0]: 20 30 22 32 24 34 26 36
+ // b1.val[1]: 21 31 23 33 25 35 27 37
+ const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+ const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+
+ // c0.val[0]: 00 10 20 30 04 14 24 34
+ // c0.val[1]: 02 12 22 32 06 16 26 36
+ // c1.val[0]: 01 11 21 31 05 15 25 35
+ // c1.val[1]: 03 13 23 33 07 17 27 37
+ const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]),
+ vreinterpretq_u32_u16(b1.val[0]));
+ const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]),
+ vreinterpretq_u32_u16(b1.val[1]));
+
+ a[0] = vreinterpretq_u16_u32(c0.val[0]);
+ a[1] = vreinterpretq_u16_u32(c1.val[0]);
+ a[2] = vreinterpretq_u16_u32(c0.val[1]);
+ a[3] = vreinterpretq_u16_u32(c1.val[1]);
+}
+
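// Usage sketch: a column-oriented kernel loads four rows and transposes so
// each vector holds column values. |src| and |stride| (in uint16_t units)
// are hypothetical.
//   uint16x8_t rows[4];
//   for (int i = 0; i < 4; ++i) rows[i] = vld1q_u16(src + i * stride);
//   Transpose4x8(rows);
//   // rows[c] now holds columns c and c + 4, as documented above.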
+// Special transpose for loop filter.
+// 4x8 Input:
+// p_q: p3 p2 p1 p0 q0 q1 q2 q3
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// 8x4 Output:
+// a[0]: 03 13 23 33 04 14 24 34 p0q0
+// a[1]: 02 12 22 32 05 15 25 35 p1q1
+// a[2]: 01 11 21 31 06 16 26 36 p2q2
+// a[3]: 00 10 20 30 07 17 27 37 p3q3
+// Direct reapplication of the function will reset the high halves, but
+// reverse the low halves:
+// p_q: p0 p1 p2 p3 q0 q1 q2 q3
+// a[0]: 33 32 31 30 04 05 06 07
+// a[1]: 23 22 21 20 14 15 16 17
+// a[2]: 13 12 11 10 24 25 26 27
+// a[3]: 03 02 01 00 34 35 36 37
+// Simply reordering the inputs (3, 2, 1, 0) will reset the low halves, but
+// reverse the high halves.
+// The standard Transpose4x8 will produce the same reversals, but with the
+// order of the low halves also restored relative to the high halves. This is
+// preferable because it puts all values from the same source row back together,
+// but some post-processing is inevitable.
+inline void LoopFilterTranspose4x8(uint16x8_t a[4]) {
+ // b0.val[0]: 00 10 02 12 04 14 06 16
+ // b0.val[1]: 01 11 03 13 05 15 07 17
+ // b1.val[0]: 20 30 22 32 24 34 26 36
+ // b1.val[1]: 21 31 23 33 25 35 27 37
+ const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+ const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+
+ // Reverse odd vectors to bring the appropriate items to the front of zips.
+ // b0.val[0]: 00 10 02 12 04 14 06 16
+ // r0 : 03 13 01 11 07 17 05 15
+ // b1.val[0]: 20 30 22 32 24 34 26 36
+ // r1 : 23 33 21 31 27 37 25 35
+ const uint32x4_t r0 = vrev64q_u32(vreinterpretq_u32_u16(b0.val[1]));
+ const uint32x4_t r1 = vrev64q_u32(vreinterpretq_u32_u16(b1.val[1]));
+
+ // Zip to complete the halves.
+ // c0.val[0]: 00 10 20 30 02 12 22 32 p3p1
+ // c0.val[1]: 04 14 24 34 06 16 26 36 q0q2
+ // c1.val[0]: 03 13 23 33 01 11 21 31 p0p2
+ // c1.val[1]: 07 17 27 37 05 15 25 35 q3q1
+ const uint32x4x2_t c0 = vzipq_u32(vreinterpretq_u32_u16(b0.val[0]),
+ vreinterpretq_u32_u16(b1.val[0]));
+ const uint32x4x2_t c1 = vzipq_u32(r0, r1);
+
+ // d0.val[0]: 00 10 20 30 07 17 27 37 p3q3
+ // d0.val[1]: 02 12 22 32 05 15 25 35 p1q1
+ // d1.val[0]: 03 13 23 33 04 14 24 34 p0q0
+ // d1.val[1]: 01 11 21 31 06 16 26 36 p2q2
+ const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c1.val[1]);
+ // The third row of c comes first here to swap p2 with q0.
+ const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c0.val[1]);
+
+ // 8x4 Output:
+ // a[0]: 03 13 23 33 04 14 24 34 p0q0
+ // a[1]: 02 12 22 32 05 15 25 35 p1q1
+ // a[2]: 01 11 21 31 06 16 26 36 p2q2
+ // a[3]: 00 10 20 30 07 17 27 37 p3q3
+ a[0] = d1.val[0]; // p0q0
+ a[1] = d0.val[1]; // p1q1
+ a[2] = d1.val[1]; // p2q2
+ a[3] = d0.val[0]; // p3q3
+}
+
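// Usage sketch: for a vertical edge, four 8-pixel rows spanning
// p3 p2 p1 p0 q0 q1 q2 q3 are loaded and transposed so that p_q[0] holds
// p0q0, p_q[1] holds p1q1, and so on. |dst| (pointing at q0) and |stride|
// (in uint16_t units) are hypothetical.
//   uint16x8_t p_q[4];
//   for (int i = 0; i < 4; ++i) p_q[i] = vld1q_u16(dst + i * stride - 4);
//   LoopFilterTranspose4x8(p_q);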
// Reversible if the x4 values are packed next to each other.
// x4 input / x8 output:
a0: 00 01 02 03 40 41 42 43