diff options
Diffstat (limited to 'src/dsp/arm/cdef_neon.cc')
-rw-r--r-- | src/dsp/arm/cdef_neon.cc | 268 |
1 files changed, 188 insertions, 80 deletions
diff --git a/src/dsp/arm/cdef_neon.cc b/src/dsp/arm/cdef_neon.cc index 60c72d6..da271f2 100644 --- a/src/dsp/arm/cdef_neon.cc +++ b/src/dsp/arm/cdef_neon.cc @@ -33,7 +33,6 @@ namespace libgav1 { namespace dsp { -namespace low_bitdepth { namespace { #include "src/dsp/cdef.inc" @@ -234,7 +233,8 @@ LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(uint8x8_t* v_src, *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[3], v_zero, 5)); } -LIBGAV1_ALWAYS_INLINE void AddPartial(const void* const source, +template <int bitdepth> +LIBGAV1_ALWAYS_INLINE void AddPartial(const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride, uint16x8_t* partial_lo, uint16x8_t* partial_hi) { const auto* src = static_cast<const uint8_t*>(source); @@ -249,11 +249,20 @@ LIBGAV1_ALWAYS_INLINE void AddPartial(const void* const source, // 60 61 62 63 64 65 66 67 // 70 71 72 73 74 75 76 77 uint8x8_t v_src[8]; - for (int i = 0; i < 8; ++i) { - v_src[i] = vld1_u8(src); - src += stride; + if (bitdepth == kBitdepth8) { + for (auto& v : v_src) { + v = vld1_u8(src); + src += stride; + } + } else { + // bitdepth - 8 + constexpr int src_shift = (bitdepth == kBitdepth10) ? 2 : 4; + for (auto& v : v_src) { + v = vshrn_n_u16(vld1q_u16(reinterpret_cast<const uint16_t*>(src)), + src_shift); + src += stride; + } } - // partial for direction 2 // -------------------------------------------------------------------------- // partial[2][i] += x; @@ -358,15 +367,19 @@ uint32_t CostOdd(const uint16x8_t a, const uint16x8_t b, const uint32x4_t mask, return SumVector(c); } -void CdefDirection_NEON(const void* const source, ptrdiff_t stride, - uint8_t* const direction, int* const variance) { +template <int bitdepth> +void CdefDirection_NEON(const void* LIBGAV1_RESTRICT const source, + ptrdiff_t stride, + uint8_t* LIBGAV1_RESTRICT const direction, + int* LIBGAV1_RESTRICT const variance) { assert(direction != nullptr); assert(variance != nullptr); const auto* src = static_cast<const uint8_t*>(source); + uint32_t cost[8]; uint16x8_t partial_lo[8], partial_hi[8]; - AddPartial(src, stride, partial_lo, partial_hi); + AddPartial<bitdepth>(src, stride, partial_lo, partial_hi); cost[2] = SquareAccumulate(partial_lo[2]); cost[6] = SquareAccumulate(partial_lo[6]); @@ -407,8 +420,9 @@ void CdefDirection_NEON(const void* const source, ptrdiff_t stride, // CdefFilter // Load 4 vectors based on the given |direction|. -void LoadDirection(const uint16_t* const src, const ptrdiff_t stride, - uint16x8_t* output, const int direction) { +void LoadDirection(const uint16_t* LIBGAV1_RESTRICT const src, + const ptrdiff_t stride, uint16x8_t* output, + const int direction) { // Each |direction| describes a different set of source values. Expand this // set by negating each set. For |direction| == 0 this gives a diagonal line // from top right to bottom left. The first value is y, the second x. Negative @@ -432,8 +446,9 @@ void LoadDirection(const uint16_t* const src, const ptrdiff_t stride, // Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to // do 2 rows at a time. -void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride, - uint16x8_t* output, const int direction) { +void LoadDirection4(const uint16_t* LIBGAV1_RESTRICT const src, + const ptrdiff_t stride, uint16x8_t* output, + const int direction) { const int y_0 = kCdefDirections[direction][0][0]; const int x_0 = kCdefDirections[direction][0][1]; const int y_1 = kCdefDirections[direction][1][0]; @@ -469,12 +484,90 @@ int16x8_t Constrain(const uint16x8_t pixel, const uint16x8_t reference, vsubq_u16(veorq_u16(clamp_abs_diff, sign), sign)); } -template <int width, bool enable_primary = true, bool enable_secondary = true> -void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride, - const int height, const int primary_strength, - const int secondary_strength, const int damping, - const int direction, void* dest, - const ptrdiff_t dst_stride) { +template <typename Pixel> +uint16x8_t GetMaxPrimary(uint16x8_t* primary_val, uint16x8_t max, + uint16x8_t cdef_large_value_mask) { + if (sizeof(Pixel) == 1) { + // The source is 16 bits, however, we only really care about the lower + // 8 bits. The upper 8 bits contain the "large" flag. After the final + // primary max has been calculated, zero out the upper 8 bits. Use this + // to find the "16 bit" max. + const uint8x16_t max_p01 = vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]), + vreinterpretq_u8_u16(primary_val[1])); + const uint8x16_t max_p23 = vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]), + vreinterpretq_u8_u16(primary_val[3])); + const uint16x8_t max_p = vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23)); + max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask)); + } else { + // Convert kCdefLargeValue to 0 before calculating max. + max = vmaxq_u16(max, vandq_u16(primary_val[0], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(primary_val[1], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(primary_val[2], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(primary_val[3], cdef_large_value_mask)); + } + return max; +} + +template <typename Pixel> +uint16x8_t GetMaxSecondary(uint16x8_t* secondary_val, uint16x8_t max, + uint16x8_t cdef_large_value_mask) { + if (sizeof(Pixel) == 1) { + const uint8x16_t max_s01 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]), + vreinterpretq_u8_u16(secondary_val[1])); + const uint8x16_t max_s23 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]), + vreinterpretq_u8_u16(secondary_val[3])); + const uint8x16_t max_s45 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]), + vreinterpretq_u8_u16(secondary_val[5])); + const uint8x16_t max_s67 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]), + vreinterpretq_u8_u16(secondary_val[7])); + const uint16x8_t max_s = vreinterpretq_u16_u8( + vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67))); + max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask)); + } else { + max = vmaxq_u16(max, vandq_u16(secondary_val[0], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(secondary_val[1], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(secondary_val[2], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(secondary_val[3], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(secondary_val[4], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(secondary_val[5], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(secondary_val[6], cdef_large_value_mask)); + max = vmaxq_u16(max, vandq_u16(secondary_val[7], cdef_large_value_mask)); + } + return max; +} + +template <typename Pixel, int width> +void StorePixels(void* dest, ptrdiff_t dst_stride, int16x8_t result) { + auto* const dst8 = static_cast<uint8_t*>(dest); + if (sizeof(Pixel) == 1) { + const uint8x8_t dst_pixel = vqmovun_s16(result); + if (width == 8) { + vst1_u8(dst8, dst_pixel); + } else { + StoreLo4(dst8, dst_pixel); + StoreHi4(dst8 + dst_stride, dst_pixel); + } + } else { + const uint16x8_t dst_pixel = vreinterpretq_u16_s16(result); + auto* const dst16 = reinterpret_cast<uint16_t*>(dst8); + if (width == 8) { + vst1q_u16(dst16, dst_pixel); + } else { + auto* const dst16_next_row = + reinterpret_cast<uint16_t*>(dst8 + dst_stride); + vst1_u16(dst16, vget_low_u16(dst_pixel)); + vst1_u16(dst16_next_row, vget_high_u16(dst_pixel)); + } + } +} + +template <int width, typename Pixel, bool enable_primary = true, + bool enable_secondary = true> +void CdefFilter_NEON(const uint16_t* LIBGAV1_RESTRICT src, + const ptrdiff_t src_stride, const int height, + const int primary_strength, const int secondary_strength, + const int damping, const int direction, + void* LIBGAV1_RESTRICT dest, const ptrdiff_t dst_stride) { static_assert(width == 8 || width == 4, ""); static_assert(enable_primary || enable_secondary, ""); constexpr bool clipping_required = enable_primary && enable_secondary; @@ -488,22 +581,34 @@ void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride, // FloorLog2() requires input to be > 0. // 8-bit damping range: Y: [3, 6], UV: [2, 5]. + // 10-bit damping range: Y: [3, 6 + 2], UV: [2, 5 + 2]. if (enable_primary) { - // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary - // for UV filtering. + // 8-bit primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is + // necessary for UV filtering. + // 10-bit primary_strength: [0, 15 << 2]. primary_damping_shift = vdupq_n_s16(-std::max(0, damping - FloorLog2(primary_strength))); } + if (enable_secondary) { - // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is - // necessary. - assert(damping - FloorLog2(secondary_strength) >= 0); - secondary_damping_shift = - vdupq_n_s16(-(damping - FloorLog2(secondary_strength))); + if (sizeof(Pixel) == 1) { + // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is + // necessary. + assert(damping - FloorLog2(secondary_strength) >= 0); + secondary_damping_shift = + vdupq_n_s16(-(damping - FloorLog2(secondary_strength))); + } else { + // secondary_strength: [0, 4 << 2] + secondary_damping_shift = + vdupq_n_s16(-std::max(0, damping - FloorLog2(secondary_strength))); + } } - const int primary_tap_0 = kCdefPrimaryTaps[primary_strength & 1][0]; - const int primary_tap_1 = kCdefPrimaryTaps[primary_strength & 1][1]; + constexpr int coeff_shift = (sizeof(Pixel) == 1) ? 0 : kBitdepth10 - 8; + const int primary_tap_0 = + kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][0]; + const int primary_tap_1 = + kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][1]; int y = height; do { @@ -533,19 +638,7 @@ void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride, min = vminq_u16(min, primary_val[2]); min = vminq_u16(min, primary_val[3]); - // The source is 16 bits, however, we only really care about the lower - // 8 bits. The upper 8 bits contain the "large" flag. After the final - // primary max has been calculated, zero out the upper 8 bits. Use this - // to find the "16 bit" max. - const uint8x16_t max_p01 = - vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]), - vreinterpretq_u8_u16(primary_val[1])); - const uint8x16_t max_p23 = - vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]), - vreinterpretq_u8_u16(primary_val[3])); - const uint16x8_t max_p = - vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23)); - max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask)); + max = GetMaxPrimary<Pixel>(primary_val, max, cdef_large_value_mask); } sum = Constrain(primary_val[0], pixel, primary_threshold, @@ -588,21 +681,7 @@ void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride, min = vminq_u16(min, secondary_val[6]); min = vminq_u16(min, secondary_val[7]); - const uint8x16_t max_s01 = - vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]), - vreinterpretq_u8_u16(secondary_val[1])); - const uint8x16_t max_s23 = - vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]), - vreinterpretq_u8_u16(secondary_val[3])); - const uint8x16_t max_s45 = - vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]), - vreinterpretq_u8_u16(secondary_val[5])); - const uint8x16_t max_s67 = - vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]), - vreinterpretq_u8_u16(secondary_val[7])); - const uint16x8_t max_s = vreinterpretq_u16_u8( - vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67))); - max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask)); + max = GetMaxSecondary<Pixel>(secondary_val, max, cdef_large_value_mask); } sum = vmlaq_n_s16(sum, @@ -647,41 +726,70 @@ void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride, result = vmaxq_s16(result, vreinterpretq_s16_u16(min)); } - const uint8x8_t dst_pixel = vqmovun_s16(result); - if (width == 8) { - src += src_stride; - vst1_u8(dst, dst_pixel); - dst += dst_stride; - --y; - } else { - src += src_stride << 1; - StoreLo4(dst, dst_pixel); - dst += dst_stride; - StoreHi4(dst, dst_pixel); - dst += dst_stride; - y -= 2; - } + StorePixels<Pixel, width>(dst, dst_stride, result); + + src += (width == 8) ? src_stride : src_stride << 1; + dst += (width == 8) ? dst_stride : dst_stride << 1; + y -= (width == 8) ? 1 : 2; } while (y != 0); } +} // namespace + +namespace low_bitdepth { +namespace { + void Init8bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); assert(dsp != nullptr); - dsp->cdef_direction = CdefDirection_NEON; - dsp->cdef_filters[0][0] = CdefFilter_NEON<4>; - dsp->cdef_filters[0][1] = - CdefFilter_NEON<4, /*enable_primary=*/true, /*enable_secondary=*/false>; - dsp->cdef_filters[0][2] = CdefFilter_NEON<4, /*enable_primary=*/false>; - dsp->cdef_filters[1][0] = CdefFilter_NEON<8>; - dsp->cdef_filters[1][1] = - CdefFilter_NEON<8, /*enable_primary=*/true, /*enable_secondary=*/false>; - dsp->cdef_filters[1][2] = CdefFilter_NEON<8, /*enable_primary=*/false>; + dsp->cdef_direction = CdefDirection_NEON<kBitdepth8>; + dsp->cdef_filters[0][0] = CdefFilter_NEON<4, uint8_t>; + dsp->cdef_filters[0][1] = CdefFilter_NEON<4, uint8_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = + CdefFilter_NEON<4, uint8_t, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_NEON<8, uint8_t>; + dsp->cdef_filters[1][1] = CdefFilter_NEON<8, uint8_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = + CdefFilter_NEON<8, uint8_t, /*enable_primary=*/false>; } } // namespace } // namespace low_bitdepth -void CdefInit_NEON() { low_bitdepth::Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->cdef_direction = CdefDirection_NEON<kBitdepth10>; + dsp->cdef_filters[0][0] = CdefFilter_NEON<4, uint16_t>; + dsp->cdef_filters[0][1] = + CdefFilter_NEON<4, uint16_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = + CdefFilter_NEON<4, uint16_t, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_NEON<8, uint16_t>; + dsp->cdef_filters[1][1] = + CdefFilter_NEON<8, uint16_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = + CdefFilter_NEON<8, uint16_t, /*enable_primary=*/false>; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void CdefInit_NEON() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} } // namespace dsp } // namespace libgav1 |