diff options
Diffstat (limited to 'src/dsp/arm/loop_restoration_neon.cc')
-rw-r--r-- | src/dsp/arm/loop_restoration_neon.cc | 1470 |
1 files changed, 1011 insertions, 459 deletions
diff --git a/src/dsp/arm/loop_restoration_neon.cc b/src/dsp/arm/loop_restoration_neon.cc index 337c9b4..e6ceb66 100644 --- a/src/dsp/arm/loop_restoration_neon.cc +++ b/src/dsp/arm/loop_restoration_neon.cc @@ -41,10 +41,25 @@ inline uint8x8_t VshrU128(const uint8x8x2_t src) { } template <int bytes> +inline uint8x8_t VshrU128(const uint8x8_t src[2]) { + return vext_u8(src[0], src[1], bytes); +} + +template <int bytes> +inline uint8x16_t VshrU128(const uint8x16_t src[2]) { + return vextq_u8(src[0], src[1], bytes); +} + +template <int bytes> inline uint16x8_t VshrU128(const uint16x8x2_t src) { return vextq_u16(src.val[0], src.val[1], bytes / 2); } +template <int bytes> +inline uint16x8_t VshrU128(const uint16x8_t src[2]) { + return vextq_u16(src[0], src[1], bytes / 2); +} + // Wiener // Must make a local copy of coefficients to help compiler know that they have @@ -177,18 +192,17 @@ inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride, int16_t** const wiener_buffer) { for (int y = height; y != 0; --y) { const uint8_t* src_ptr = src; - uint8x16_t s[4]; - s[0] = vld1q_u8(src_ptr); + uint8x16_t s[3]; ptrdiff_t x = width; do { - src_ptr += 16; - s[3] = vld1q_u8(src_ptr); - s[1] = vextq_u8(s[0], s[3], 1); - s[2] = vextq_u8(s[0], s[3], 2); + // Slightly faster than using vextq_u8(). + s[0] = vld1q_u8(src_ptr); + s[1] = vld1q_u8(src_ptr + 1); + s[2] = vld1q_u8(src_ptr + 2); int16x8x2_t sum; sum.val[0] = sum.val[1] = vdupq_n_s16(0); WienerHorizontalSum(s, filter, sum, *wiener_buffer); - s[0] = s[3]; + src_ptr += 16; *wiener_buffer += 16; x -= 16; } while (x != 0); @@ -476,12 +490,12 @@ inline void WienerVerticalTap1(const int16_t* wiener_buffer, // For width 16 and up, store the horizontal results, and then do the vertical // filter row by row. This is faster than doing it column by column when // considering cache issues. -void WienerFilter_NEON(const RestorationUnitInfo& restoration_info, - const void* const source, const void* const top_border, - const void* const bottom_border, const ptrdiff_t stride, - const int width, const int height, - RestorationBuffer* const restoration_buffer, - void* const dest) { +void WienerFilter_NEON( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { const int16_t* const number_leading_zero_coefficients = restoration_info.wiener_info.number_leading_zero_coefficients; const int number_rows_to_skip = std::max( @@ -509,39 +523,42 @@ void WienerFilter_NEON(const RestorationUnitInfo& restoration_info, const auto* const top = static_cast<const uint8_t*>(top_border); const auto* const bottom = static_cast<const uint8_t*>(bottom_border); if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { - WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, - wiener_stride, height_extra, filter_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3, + top_border_stride, wiener_stride, height_extra, filter_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, filter_horizontal, &wiener_buffer_horizontal); - } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { - WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, - wiener_stride, height_extra, filter_horizontal, + WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, + height_extra, filter_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2, + top_border_stride, wiener_stride, height_extra, filter_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, + height_extra, filter_horizontal, + &wiener_buffer_horizontal); } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { // The maximum over-reads happen here. - WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, - wiener_stride, height_extra, filter_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1, + top_border_stride, wiener_stride, height_extra, filter_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride, + height_extra, filter_horizontal, + &wiener_buffer_horizontal); } else { assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); - WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, - wiener_stride, height_extra, + WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride, + top_border_stride, wiener_stride, height_extra, &wiener_buffer_horizontal); WienerHorizontalTap1(src, stride, wiener_stride, height, &wiener_buffer_horizontal); - WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, - &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride, + height_extra, &wiener_buffer_horizontal); } // vertical filtering. @@ -574,13 +591,20 @@ void WienerFilter_NEON(const RestorationUnitInfo& restoration_info, //------------------------------------------------------------------------------ // SGR -inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) { +inline void Prepare3_8(const uint8x8_t src[2], uint8x8_t dst[3]) { dst[0] = VshrU128<0>(src); dst[1] = VshrU128<1>(src); dst[2] = VshrU128<2>(src); } -inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3], +template <int offset> +inline void Prepare3_8(const uint8x16_t src[2], uint8x16_t dst[3]) { + dst[0] = VshrU128<offset + 0>(src); + dst[1] = VshrU128<offset + 1>(src); + dst[2] = VshrU128<offset + 2>(src); +} + +inline void Prepare3_16(const uint16x8_t src[2], uint16x4_t low[3], uint16x4_t high[3]) { uint16x8_t s[3]; s[0] = VshrU128<0>(src); @@ -594,7 +618,7 @@ inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3], high[2] = vget_high_u16(s[2]); } -inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) { +inline void Prepare5_8(const uint8x8_t src[2], uint8x8_t dst[5]) { dst[0] = VshrU128<0>(src); dst[1] = VshrU128<1>(src); dst[2] = VshrU128<2>(src); @@ -602,7 +626,16 @@ inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) { dst[4] = VshrU128<4>(src); } -inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5], +template <int offset> +inline void Prepare5_8(const uint8x16_t src[2], uint8x16_t dst[5]) { + dst[0] = VshrU128<offset + 0>(src); + dst[1] = VshrU128<offset + 1>(src); + dst[2] = VshrU128<offset + 2>(src); + dst[3] = VshrU128<offset + 3>(src); + dst[4] = VshrU128<offset + 4>(src); +} + +inline void Prepare5_16(const uint16x8_t src[2], uint16x4_t low[5], uint16x4_t high[5]) { Prepare3_16(src, low, high); const uint16x8_t s3 = VshrU128<6>(src); @@ -641,6 +674,30 @@ inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) { return vaddw_u8(sum, src[2]); } +inline uint16x8_t Sum3WLo16(const uint8x16_t src[3]) { + const uint16x8_t sum = vaddl_u8(vget_low_u8(src[0]), vget_low_u8(src[1])); + return vaddw_u8(sum, vget_low_u8(src[2])); +} + +inline uint16x8_t Sum3WHi16(const uint8x16_t src[3]) { + const uint16x8_t sum = vaddl_u8(vget_high_u8(src[0]), vget_high_u8(src[1])); + return vaddw_u8(sum, vget_high_u8(src[2])); +} + +inline uint16x8_t Sum5WLo16(const uint8x16_t src[5]) { + const uint16x8_t sum01 = vaddl_u8(vget_low_u8(src[0]), vget_low_u8(src[1])); + const uint16x8_t sum23 = vaddl_u8(vget_low_u8(src[2]), vget_low_u8(src[3])); + const uint16x8_t sum = vaddq_u16(sum01, sum23); + return vaddw_u8(sum, vget_low_u8(src[4])); +} + +inline uint16x8_t Sum5WHi16(const uint8x16_t src[5]) { + const uint16x8_t sum01 = vaddl_u8(vget_high_u8(src[0]), vget_high_u8(src[1])); + const uint16x8_t sum23 = vaddl_u8(vget_high_u8(src[2]), vget_high_u8(src[3])); + const uint16x8_t sum = vaddq_u16(sum01, sum23); + return vaddw_u8(sum, vget_high_u8(src[4])); +} + inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) { const uint32x4_t sum = vaddl_u16(src[0], src[1]); return vaddw_u16(sum, src[2]); @@ -678,13 +735,28 @@ inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) { return vaddw_u16(sum0123, src[4]); } -inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) { +inline uint16x8_t Sum3Horizontal(const uint8x8_t src[2]) { uint8x8_t s[3]; Prepare3_8(src, s); return Sum3W_16(s); } -inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) { +inline uint16x8_t Sum3Horizontal(const uint8x16_t src) { + uint8x8_t s[2]; + s[0] = vget_low_u8(src); + s[1] = vget_high_u8(src); + return Sum3Horizontal(s); +} + +template <int offset> +inline void Sum3Horizontal(const uint8x16_t src[2], uint16x8_t dst[2]) { + uint8x16_t s[3]; + Prepare3_8<offset>(src, s); + dst[0] = Sum3WLo16(s); + dst[1] = Sum3WHi16(s); +} + +inline uint32x4x2_t Sum3WHorizontal(const uint16x8_t src[2]) { uint16x4_t low[3], high[3]; uint32x4x2_t sum; Prepare3_16(src, low, high); @@ -693,7 +765,7 @@ inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) { return sum; } -inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) { +inline uint16x8_t Sum5Horizontal(const uint8x8_t src[2]) { uint8x8_t s[5]; Prepare5_8(src, s); const uint16x8_t sum01 = vaddl_u8(s[0], s[1]); @@ -702,7 +774,23 @@ inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) { return vaddw_u8(sum0123, s[4]); } -inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) { +inline uint16x8_t Sum5Horizontal(const uint8x16_t src) { + uint8x8_t s[2]; + s[0] = vget_low_u8(src); + s[1] = vget_high_u8(src); + return Sum5Horizontal(s); +} + +template <int offset> +inline void Sum5Horizontal(const uint8x16_t src[2], uint16x8_t* const dst0, + uint16x8_t* const dst1) { + uint8x16_t s[5]; + Prepare5_8<offset>(src, s); + *dst0 = Sum5WLo16(s); + *dst1 = Sum5WHi16(s); +} + +inline uint32x4x2_t Sum5WHorizontal(const uint16x8_t src[2]) { uint16x4_t low[5], high[5]; Prepare5_16(src, low, high); uint32x4x2_t sum; @@ -711,35 +799,68 @@ inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) { return sum; } -void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3, - uint32x4_t* const row_sq5) { - const uint32x4_t sum04 = vaddl_u16(src[0], src[4]); - const uint32x4_t sum12 = vaddl_u16(src[1], src[2]); - *row_sq3 = vaddw_u16(sum12, src[3]); - *row_sq5 = vaddq_u32(sum04, *row_sq3); +template <int offset> +void SumHorizontal(const uint8x16_t src[2], uint16x8_t* const row3_0, + uint16x8_t* const row3_1, uint16x8_t* const row5_0, + uint16x8_t* const row5_1) { + uint8x16_t s[5]; + Prepare5_8<offset>(src, s); + const uint16x8_t sum04_lo = vaddl_u8(vget_low_u8(s[0]), vget_low_u8(s[4])); + const uint16x8_t sum04_hi = vaddl_u8(vget_high_u8(s[0]), vget_high_u8(s[4])); + *row3_0 = Sum3WLo16(s + 1); + *row3_1 = Sum3WHi16(s + 1); + *row5_0 = vaddq_u16(sum04_lo, *row3_0); + *row5_1 = vaddq_u16(sum04_hi, *row3_1); } -void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq, - uint16x8_t* const row3, uint16x8_t* const row5, - uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) { +void SumHorizontal(const uint8x8_t src[2], uint16x8_t* const row3, + uint16x8_t* const row5) { uint8x8_t s[5]; Prepare5_8(src, s); const uint16x8_t sum04 = vaddl_u8(s[0], s[4]); const uint16x8_t sum12 = vaddl_u8(s[1], s[2]); *row3 = vaddw_u8(sum12, s[3]); *row5 = vaddq_u16(sum04, *row3); +} + +void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3, + uint32x4_t* const row_sq5) { + const uint32x4_t sum04 = vaddl_u16(src[0], src[4]); + const uint32x4_t sum12 = vaddl_u16(src[1], src[2]); + *row_sq3 = vaddw_u16(sum12, src[3]); + *row_sq5 = vaddq_u32(sum04, *row_sq3); +} + +void SumHorizontal(const uint16x8_t sq[2], uint32x4x2_t* const row_sq3, + uint32x4x2_t* const row_sq5) { uint16x4_t low[5], high[5]; Prepare5_16(sq, low, high); SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]); SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]); } -inline uint16x8_t Sum343(const uint8x8x2_t src) { - uint8x8_t s[3]; - Prepare3_8(src, s); - const uint16x8_t sum = Sum3W_16(s); +void SumHorizontal(const uint8x8_t src[2], const uint16x8_t sq[2], + uint16x8_t* const row3, uint16x8_t* const row5, + uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) { + SumHorizontal(src, row3, row5); + SumHorizontal(sq, row_sq3, row_sq5); +} + +void SumHorizontal(const uint8x16_t src, const uint16x8_t sq[2], + uint16x8_t* const row3, uint16x8_t* const row5, + uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) { + uint8x8_t s[2]; + s[0] = vget_low_u8(src); + s[1] = vget_high_u8(src); + return SumHorizontal(s, sq, row3, row5, row_sq3, row_sq5); +} + +template <int offset> +inline uint16x8_t Sum343(const uint8x16_t ma3[2]) { + const uint16x8_t sum = (offset == 0) ? Sum3WLo16(ma3) : Sum3WHi16(ma3); const uint16x8_t sum3 = Sum3_16(sum, sum, sum); - return vaddw_u8(sum3, s[1]); + return vaddw_u8(sum3, + (offset == 0) ? vget_low_u8(ma3[1]) : vget_high_u8(ma3[1])); } inline uint32x4_t Sum343W(const uint16x4_t src[3]) { @@ -748,7 +869,7 @@ inline uint32x4_t Sum343W(const uint16x4_t src[3]) { return vaddw_u16(sum3, src[1]); } -inline uint32x4x2_t Sum343W(const uint16x8x2_t src) { +inline uint32x4x2_t Sum343W(const uint16x8_t src[2]) { uint16x4_t low[3], high[3]; uint32x4x2_t d; Prepare3_16(src, low, high); @@ -757,13 +878,13 @@ inline uint32x4x2_t Sum343W(const uint16x8x2_t src) { return d; } -inline uint16x8_t Sum565(const uint8x8x2_t src) { - uint8x8_t s[3]; - Prepare3_8(src, s); - const uint16x8_t sum = Sum3W_16(s); +template <int offset> +inline uint16x8_t Sum565(const uint8x16_t ma5[2]) { + const uint16x8_t sum = (offset == 0) ? Sum3WLo16(ma5) : Sum3WHi16(ma5); const uint16x8_t sum4 = vshlq_n_u16(sum, 2); const uint16x8_t sum5 = vaddq_u16(sum4, sum); - return vaddw_u8(sum5, s[1]); + return vaddw_u8(sum5, + (offset == 0) ? vget_low_u8(ma5[1]) : vget_high_u8(ma5[1])); } inline uint32x4_t Sum565W(const uint16x4_t src[3]) { @@ -773,7 +894,7 @@ inline uint32x4_t Sum565W(const uint16x4_t src[3]) { return vaddw_u16(sum5, src[1]); } -inline uint32x4x2_t Sum565W(const uint16x8x2_t src) { +inline uint32x4x2_t Sum565W(const uint16x8_t src[2]) { uint16x4_t low[3], high[3]; uint32x4x2_t d; Prepare3_16(src, low, high); @@ -783,21 +904,21 @@ inline uint32x4x2_t Sum565W(const uint16x8x2_t src) { } inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, - const int height, const ptrdiff_t sum_stride, uint16_t* sum3, - uint16_t* sum5, uint32_t* square_sum3, - uint32_t* square_sum5) { - int y = height; + const ptrdiff_t sum_stride, uint16_t* sum3, uint16_t* sum5, + uint32_t* square_sum3, uint32_t* square_sum5) { + int y = 2; + // Don't change loop width to 16, which is even slower. do { - uint8x8x2_t s; - uint16x8x2_t sq; - s.val[0] = vld1_u8(src); - sq.val[0] = vmull_u8(s.val[0], s.val[0]); + uint8x8_t s[2]; + uint16x8_t sq[2]; + s[0] = vld1_u8(src); + sq[0] = vmull_u8(s[0], s[0]); ptrdiff_t x = 0; do { uint16x8_t row3, row5; uint32x4x2_t row_sq3, row_sq5; - s.val[1] = vld1_u8(src + x + 8); - sq.val[1] = vmull_u8(s.val[1], s.val[1]); + s[1] = vld1_u8(src + x + 8); + sq[1] = vmull_u8(s[1], s[1]); SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5); vst1q_u16(sum3, row3); vst1q_u16(sum5, row5); @@ -805,8 +926,8 @@ inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, vst1q_u32(square_sum3 + 4, row_sq3.val[1]); vst1q_u32(square_sum5 + 0, row_sq5.val[0]); vst1q_u32(square_sum5 + 4, row_sq5.val[1]); - s.val[0] = s.val[1]; - sq.val[0] = sq.val[1]; + s[0] = s[1]; + sq[0] = sq[1]; sum3 += 8; sum5 += 8; square_sum3 += 8; @@ -819,21 +940,22 @@ inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, template <int size> inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, - const int height, const ptrdiff_t sum_stride, uint16_t* sums, + const ptrdiff_t sum_stride, uint16_t* sums, uint32_t* square_sums) { static_assert(size == 3 || size == 5, ""); - int y = height; + int y = 2; + // Don't change loop width to 16, which is even slower. do { - uint8x8x2_t s; - uint16x8x2_t sq; - s.val[0] = vld1_u8(src); - sq.val[0] = vmull_u8(s.val[0], s.val[0]); + uint8x8_t s[2]; + uint16x8_t sq[2]; + s[0] = vld1_u8(src); + sq[0] = vmull_u8(s[0], s[0]); ptrdiff_t x = 0; do { uint16x8_t row; uint32x4x2_t row_sq; - s.val[1] = vld1_u8(src + x + 8); - sq.val[1] = vmull_u8(s.val[1], s.val[1]); + s[1] = vld1_u8(src + x + 8); + sq[1] = vmull_u8(s[1], s[1]); if (size == 3) { row = Sum3Horizontal(s); row_sq = Sum3WHorizontal(sq); @@ -844,8 +966,8 @@ inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, vst1q_u16(sums, row); vst1q_u32(square_sums + 0, row_sq.val[0]); vst1q_u32(square_sums + 4, row_sq.val[1]); - s.val[0] = s.val[1]; - sq.val[0] = sq.val[1]; + s[0] = s[1]; + sq[0] = sq[1]; sums += 8; square_sums += 8; x += 8; @@ -871,10 +993,18 @@ inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq, return vmovn_u32(shifted); } -template <int n> +inline uint8x8_t AdjustValue(const uint8x8_t value, const uint8x8_t index, + const int threshold) { + const uint8x8_t thresholds = vdup_n_u8(threshold); + const uint8x8_t offset = vcgt_u8(index, thresholds); + // Adding 255 is equivalent to subtracting 1 for 8-bit data. + return vadd_u8(value, offset); +} + +template <int n, int offset> inline void CalculateIntermediate(const uint16x8_t sum, const uint32x4x2_t sum_sq, - const uint32_t scale, uint8x8_t* const ma, + const uint32_t scale, uint8x16_t* const ma, uint16x8_t* const b) { constexpr uint32_t one_over_n = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; @@ -882,19 +1012,39 @@ inline void CalculateIntermediate(const uint16x8_t sum, const uint16x4_t z1 = CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale); const uint16x8_t z01 = vcombine_u16(z0, z1); - // Using vqmovn_u16() needs an extra sign extension instruction. - const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255)); - // Using vgetq_lane_s16() can save the sign extension instruction. - const uint8_t lookup[8] = { - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)], - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)], - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)], - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)], - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)], - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)], - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)], - kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]}; - *ma = vld1_u8(lookup); + const uint8x8_t idx = vqmovn_u16(z01); + // Use table lookup to read elements whose indices are less than 48. + // Using one uint8x8x4_t vector and one uint8x8x2_t vector is faster than + // using two uint8x8x3_t vectors. + uint8x8x4_t table0; + uint8x8x2_t table1; + table0.val[0] = vld1_u8(kSgrMaLookup + 0 * 8); + table0.val[1] = vld1_u8(kSgrMaLookup + 1 * 8); + table0.val[2] = vld1_u8(kSgrMaLookup + 2 * 8); + table0.val[3] = vld1_u8(kSgrMaLookup + 3 * 8); + table1.val[0] = vld1_u8(kSgrMaLookup + 4 * 8); + table1.val[1] = vld1_u8(kSgrMaLookup + 5 * 8); + // All elements whose indices are out of range [0, 47] are set to 0. + uint8x8_t val = vtbl4_u8(table0, idx); // Range [0, 31]. + // Subtract 8 to shuffle the next index range. + const uint8x8_t index = vsub_u8(idx, vdup_n_u8(32)); + const uint8x8_t res = vtbl2_u8(table1, index); // Range [32, 47]. + // Use OR instruction to combine shuffle results together. + val = vorr_u8(val, res); + + // For elements whose indices are larger than 47, since they seldom change + // values with the increase of the index, we use comparison and arithmetic + // operations to calculate their values. + // Elements whose indices are larger than 47 (with value 0) are set to 5. + val = vmax_u8(val, vdup_n_u8(5)); + val = AdjustValue(val, idx, 55); // 55 is the last index which value is 5. + val = AdjustValue(val, idx, 72); // 72 is the last index which value is 4. + val = AdjustValue(val, idx, 101); // 101 is the last index which value is 3. + val = AdjustValue(val, idx, 169); // 169 is the last index which value is 2. + val = AdjustValue(val, idx, 254); // 254 is the last index which value is 1. + *ma = (offset == 0) ? vcombine_u8(val, vget_high_u8(*ma)) + : vcombine_u8(vget_low_u8(*ma), val); + // b = ma * b * one_over_n // |ma| = [0, 255] // |sum| is a box sum with radius 1 or 2. @@ -906,7 +1056,8 @@ inline void CalculateIntermediate(const uint16x8_t sum, // |kSgrProjReciprocalBits| is 12. // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). - const uint16x8_t maq = vmovl_u8(*ma); + const uint16x8_t maq = + vmovl_u8((offset == 0) ? vget_low_u8(*ma) : vget_high_u8(*ma)); const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum)); const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum)); const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n); @@ -916,37 +1067,39 @@ inline void CalculateIntermediate(const uint16x8_t sum, *b = vcombine_u16(b_lo, b_hi); } +template <int offset> inline void CalculateIntermediate5(const uint16x8_t s5[5], const uint32x4x2_t sq5[5], - const uint32_t scale, uint8x8_t* const ma, + const uint32_t scale, uint8x16_t* const ma, uint16x8_t* const b) { const uint16x8_t sum = Sum5_16(s5); const uint32x4x2_t sum_sq = Sum5_32(sq5); - CalculateIntermediate<25>(sum, sum_sq, scale, ma, b); + CalculateIntermediate<25, offset>(sum, sum_sq, scale, ma, b); } +template <int offset> inline void CalculateIntermediate3(const uint16x8_t s3[3], const uint32x4x2_t sq3[3], - const uint32_t scale, uint8x8_t* const ma, + const uint32_t scale, uint8x16_t* const ma, uint16x8_t* const b) { const uint16x8_t sum = Sum3_16(s3); const uint32x4x2_t sum_sq = Sum3_32(sq3); - CalculateIntermediate<9>(sum, sum_sq, scale, ma, b); + CalculateIntermediate<9, offset>(sum, sum_sq, scale, ma, b); } -inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, +template <int offset> +inline void Store343_444(const uint8x16_t ma3[3], const uint16x8_t b3[2], const ptrdiff_t x, uint16x8_t* const sum_ma343, uint16x8_t* const sum_ma444, uint32x4x2_t* const sum_b343, uint32x4x2_t* const sum_b444, uint16_t* const ma343, uint16_t* const ma444, uint32_t* const b343, uint32_t* const b444) { - uint8x8_t s[3]; - Prepare3_8(ma3, s); - const uint16x8_t sum_ma111 = Sum3W_16(s); + const uint16x8_t sum_ma111 = (offset == 0) ? Sum3WLo16(ma3) : Sum3WHi16(ma3); *sum_ma444 = vshlq_n_u16(sum_ma111, 2); const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111); - *sum_ma343 = vaddw_u8(sum333, s[1]); + *sum_ma343 = vaddw_u8( + sum333, (offset == 0) ? vget_low_u8(ma3[1]) : vget_high_u8(ma3[1])); uint16x4_t low[3], high[3]; uint32x4x2_t sum_b111; Prepare3_16(b3, low, high); @@ -966,93 +1119,211 @@ inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, vst1q_u32(b444 + x + 4, sum_b444->val[1]); } -inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, +template <int offset> +inline void Store343_444(const uint8x16_t ma3[3], const uint16x8_t b3[2], const ptrdiff_t x, uint16x8_t* const sum_ma343, uint32x4x2_t* const sum_b343, uint16_t* const ma343, uint16_t* const ma444, uint32_t* const b343, uint32_t* const b444) { uint16x8_t sum_ma444; uint32x4x2_t sum_b444; - Store343_444(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, ma343, - ma444, b343, b444); + Store343_444<offset>(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, + ma343, ma444, b343, b444); } -inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, +template <int offset> +inline void Store343_444(const uint8x16_t ma3[3], const uint16x8_t b3[2], const ptrdiff_t x, uint16_t* const ma343, uint16_t* const ma444, uint32_t* const b343, uint32_t* const b444) { uint16x8_t sum_ma343; uint32x4x2_t sum_b343; - Store343_444(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, b444); + Store343_444<offset>(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, + b444); } -LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( - const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x, - const uint32_t scale, uint16_t* const sum5[5], - uint32_t* const square_sum5[5], uint8x8x2_t s[2], uint16x8x2_t sq[2], - uint8x8_t* const ma, uint16x8_t* const b) { +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo( + const uint8_t* const src0, const uint8_t* const src1, const uint32_t scale, + uint8x16_t s[2][2], uint16_t* const sum5[5], uint32_t* const square_sum5[5], + uint16x8_t sq[2][4], uint8x16_t* const ma, uint16x8_t* const b) { uint16x8_t s5[5]; uint32x4x2_t sq5[5]; - s[0].val[1] = vld1_u8(src0 + x + 8); - s[1].val[1] = vld1_u8(src1 + x + 8); - sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]); - sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]); - s5[3] = Sum5Horizontal(s[0]); - s5[4] = Sum5Horizontal(s[1]); + s[0][0] = vld1q_u8(src0); + s[1][0] = vld1q_u8(src1); + sq[0][0] = vmull_u8(vget_low_u8(s[0][0]), vget_low_u8(s[0][0])); + sq[1][0] = vmull_u8(vget_low_u8(s[1][0]), vget_low_u8(s[1][0])); + sq[0][1] = vmull_u8(vget_high_u8(s[0][0]), vget_high_u8(s[0][0])); + sq[1][1] = vmull_u8(vget_high_u8(s[1][0]), vget_high_u8(s[1][0])); + s5[3] = Sum5Horizontal(s[0][0]); + s5[4] = Sum5Horizontal(s[1][0]); sq5[3] = Sum5WHorizontal(sq[0]); sq5[4] = Sum5WHorizontal(sq[1]); - vst1q_u16(sum5[3] + x, s5[3]); - vst1q_u16(sum5[4] + x, s5[4]); + vst1q_u16(sum5[3], s5[3]); + vst1q_u16(sum5[4], s5[4]); + vst1q_u32(square_sum5[3] + 0, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + 4, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + 0, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + 4, sq5[4].val[1]); + s5[0] = vld1q_u16(sum5[0]); + s5[1] = vld1q_u16(sum5[1]); + s5[2] = vld1q_u16(sum5[2]); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4); + CalculateIntermediate5<0>(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x, + const uint32_t scale, uint8x16_t s[2][2], uint16_t* const sum5[5], + uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t ma[2], + uint16x8_t b[2]) { + uint16x8_t s5[2][5]; + uint32x4x2_t sq5[5]; + s[0][1] = vld1q_u8(src0 + x + 8); + s[1][1] = vld1q_u8(src1 + x + 8); + sq[0][2] = vmull_u8(vget_low_u8(s[0][1]), vget_low_u8(s[0][1])); + sq[1][2] = vmull_u8(vget_low_u8(s[1][1]), vget_low_u8(s[1][1])); + Sum5Horizontal<8>(s[0], &s5[0][3], &s5[1][3]); + Sum5Horizontal<8>(s[1], &s5[0][4], &s5[1][4]); + sq5[3] = Sum5WHorizontal(sq[0] + 1); + sq5[4] = Sum5WHorizontal(sq[1] + 1); + vst1q_u16(sum5[3] + x, s5[0][3]); + vst1q_u16(sum5[4] + x, s5[0][4]); vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]); vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]); vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]); vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]); - s5[0] = vld1q_u16(sum5[0] + x); - s5[1] = vld1q_u16(sum5[1] + x); - s5[2] = vld1q_u16(sum5[2] + x); + s5[0][0] = vld1q_u16(sum5[0] + x); + s5[0][1] = vld1q_u16(sum5[1] + x); + s5[0][2] = vld1q_u16(sum5[2] + x); sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); - CalculateIntermediate5(s5, sq5, scale, ma, b); + CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[0]); + + sq[0][3] = vmull_u8(vget_high_u8(s[0][1]), vget_high_u8(s[0][1])); + sq[1][3] = vmull_u8(vget_high_u8(s[1][1]), vget_high_u8(s[1][1])); + sq5[3] = Sum5WHorizontal(sq[0] + 2); + sq5[4] = Sum5WHorizontal(sq[1] + 2); + vst1q_u16(sum5[3] + x + 8, s5[1][3]); + vst1q_u16(sum5[4] + x + 8, s5[1][4]); + vst1q_u32(square_sum5[3] + x + 8, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + x + 12, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + x + 8, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + x + 12, sq5[4].val[1]); + s5[1][0] = vld1q_u16(sum5[0] + x + 8); + s5[1][1] = vld1q_u16(sum5[1] + x + 8); + s5[1][2] = vld1q_u16(sum5[2] + x + 8); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12); + CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[1]); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo( + const uint8_t* const src, const uint32_t scale, uint8x16_t* const s, + const uint16_t* const sum5[5], const uint32_t* const square_sum5[5], + uint16x8_t sq[2], uint8x16_t* const ma, uint16x8_t* const b) { + uint16x8_t s5[5]; + uint32x4x2_t sq5[5]; + *s = vld1q_u8(src); + sq[0] = vmull_u8(vget_low_u8(*s), vget_low_u8(*s)); + sq[1] = vmull_u8(vget_high_u8(*s), vget_high_u8(*s)); + s5[3] = s5[4] = Sum5Horizontal(*s); + sq5[3] = sq5[4] = Sum5WHorizontal(sq); + s5[0] = vld1q_u16(sum5[0]); + s5[1] = vld1q_u16(sum5[1]); + s5[2] = vld1q_u16(sum5[2]); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4); + CalculateIntermediate5<0>(s5, sq5, scale, ma, b); } LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow( const uint8_t* const src, const ptrdiff_t x, const uint32_t scale, - const uint16_t* const sum5[5], const uint32_t* const square_sum5[5], - uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma, - uint16x8_t* const b) { - uint16x8_t s5[5]; + uint8x16_t s[2], const uint16_t* const sum5[5], + const uint32_t* const square_sum5[5], uint16x8_t sq[3], uint8x16_t ma[2], + uint16x8_t b[2]) { + uint16x8_t s5[2][5]; uint32x4x2_t sq5[5]; - s->val[1] = vld1_u8(src + x + 8); - sq->val[1] = vmull_u8(s->val[1], s->val[1]); - s5[3] = s5[4] = Sum5Horizontal(*s); - sq5[3] = sq5[4] = Sum5WHorizontal(*sq); - s5[0] = vld1q_u16(sum5[0] + x); - s5[1] = vld1q_u16(sum5[1] + x); - s5[2] = vld1q_u16(sum5[2] + x); + s[1] = vld1q_u8(src + x + 8); + sq[1] = vmull_u8(vget_low_u8(s[1]), vget_low_u8(s[1])); + Sum5Horizontal<8>(s, &s5[0][3], &s5[1][3]); + sq5[3] = sq5[4] = Sum5WHorizontal(sq); + s5[0][0] = vld1q_u16(sum5[0] + x); + s5[0][1] = vld1q_u16(sum5[1] + x); + s5[0][2] = vld1q_u16(sum5[2] + x); + s5[0][4] = s5[0][3]; sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); - CalculateIntermediate5(s5, sq5, scale, ma, b); + CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[0]); + + sq[2] = vmull_u8(vget_high_u8(s[1]), vget_high_u8(s[1])); + sq5[3] = sq5[4] = Sum5WHorizontal(sq + 1); + s5[1][0] = vld1q_u16(sum5[0] + x + 8); + s5[1][1] = vld1q_u16(sum5[1] + x + 8); + s5[1][2] = vld1q_u16(sum5[2] + x + 8); + s5[1][4] = s5[1][3]; + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12); + CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[1]); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo( + const uint8_t* const src, const uint32_t scale, uint8x16_t* const s, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16x8_t sq[2], + uint8x16_t* const ma, uint16x8_t* const b) { + uint16x8_t s3[3]; + uint32x4x2_t sq3[3]; + *s = vld1q_u8(src); + sq[0] = vmull_u8(vget_low_u8(*s), vget_low_u8(*s)); + sq[1] = vmull_u8(vget_high_u8(*s), vget_high_u8(*s)); + s3[2] = Sum3Horizontal(*s); + sq3[2] = Sum3WHorizontal(sq); + vst1q_u16(sum3[2], s3[2]); + vst1q_u32(square_sum3[2] + 0, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + 4, sq3[2].val[1]); + s3[0] = vld1q_u16(sum3[0]); + s3[1] = vld1q_u16(sum3[1]); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + 4); + CalculateIntermediate3<0>(s3, sq3, scale, ma, b); } LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( const uint8_t* const src, const ptrdiff_t x, const uint32_t scale, - uint16_t* const sum3[3], uint32_t* const square_sum3[3], - uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma, - uint16x8_t* const b) { - uint16x8_t s3[3]; + uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint8x16_t s[2], + uint16x8_t sq[3], uint8x16_t ma[2], uint16x8_t b[2]) { + uint16x8_t s3[4]; uint32x4x2_t sq3[3]; - s->val[1] = vld1_u8(src + x + 8); - sq->val[1] = vmull_u8(s->val[1], s->val[1]); - s3[2] = Sum3Horizontal(*s); - sq3[2] = Sum3WHorizontal(*sq); + s[1] = vld1q_u8(src + x + 8); + sq[1] = vmull_u8(vget_low_u8(s[1]), vget_low_u8(s[1])); + Sum3Horizontal<8>(s, s3 + 2); + sq3[2] = Sum3WHorizontal(sq); vst1q_u16(sum3[2] + x, s3[2]); vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]); vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]); @@ -1062,71 +1333,204 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); - CalculateIntermediate3(s3, sq3, scale, ma, b); + CalculateIntermediate3<8>(s3, sq3, scale, &ma[0], &b[0]); + + sq[2] = vmull_u8(vget_high_u8(s[1]), vget_high_u8(s[1])); + sq3[2] = Sum3WHorizontal(sq + 1); + vst1q_u16(sum3[2] + x + 8, s3[3]); + vst1q_u32(square_sum3[2] + x + 8, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + x + 12, sq3[2].val[1]); + s3[1] = vld1q_u16(sum3[0] + x + 8); + s3[2] = vld1q_u16(sum3[1] + x + 8); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 8); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 12); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 8); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 12); + CalculateIntermediate3<0>(s3 + 1, sq3, scale, &ma[1], &b[1]); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo( + const uint8_t* const src0, const uint8_t* const src1, + const uint16_t scales[2], uint8x16_t s[2][2], uint16_t* const sum3[4], + uint16_t* const sum5[5], uint32_t* const square_sum3[4], + uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t ma3[2][2], + uint16x8_t b3[2][3], uint8x16_t* const ma5, uint16x8_t* const b5) { + uint16x8_t s3[4], s5[5]; + uint32x4x2_t sq3[4], sq5[5]; + s[0][0] = vld1q_u8(src0); + s[1][0] = vld1q_u8(src1); + sq[0][0] = vmull_u8(vget_low_u8(s[0][0]), vget_low_u8(s[0][0])); + sq[1][0] = vmull_u8(vget_low_u8(s[1][0]), vget_low_u8(s[1][0])); + sq[0][1] = vmull_u8(vget_high_u8(s[0][0]), vget_high_u8(s[0][0])); + sq[1][1] = vmull_u8(vget_high_u8(s[1][0]), vget_high_u8(s[1][0])); + SumHorizontal(s[0][0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]); + SumHorizontal(s[1][0], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]); + vst1q_u16(sum3[2], s3[2]); + vst1q_u16(sum3[3], s3[3]); + vst1q_u32(square_sum3[2] + 0, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + 4, sq3[2].val[1]); + vst1q_u32(square_sum3[3] + 0, sq3[3].val[0]); + vst1q_u32(square_sum3[3] + 4, sq3[3].val[1]); + vst1q_u16(sum5[3], s5[3]); + vst1q_u16(sum5[4], s5[4]); + vst1q_u32(square_sum5[3] + 0, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + 4, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + 0, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + 4, sq5[4].val[1]); + s3[0] = vld1q_u16(sum3[0]); + s3[1] = vld1q_u16(sum3[1]); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + 4); + s5[0] = vld1q_u16(sum5[0]); + s5[1] = vld1q_u16(sum5[1]); + s5[2] = vld1q_u16(sum5[2]); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4); + CalculateIntermediate3<0>(s3, sq3, scales[1], ma3[0], b3[0]); + CalculateIntermediate3<0>(s3 + 1, sq3 + 1, scales[1], ma3[1], b3[1]); + CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5); } LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x, - const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], - uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - uint8x8x2_t s[2], uint16x8x2_t sq[2], uint8x8_t* const ma3_0, - uint8x8_t* const ma3_1, uint16x8_t* const b3_0, uint16x8_t* const b3_1, - uint8x8_t* const ma5, uint16x8_t* const b5) { - uint16x8_t s3[4], s5[5]; + const uint16_t scales[2], uint8x16_t s[2][2], uint16_t* const sum3[4], + uint16_t* const sum5[5], uint32_t* const square_sum3[4], + uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t ma3[2][2], + uint16x8_t b3[2][3], uint8x16_t ma5[2], uint16x8_t b5[2]) { + uint16x8_t s3[2][4], s5[2][5]; uint32x4x2_t sq3[4], sq5[5]; - s[0].val[1] = vld1_u8(src0 + x + 8); - s[1].val[1] = vld1_u8(src1 + x + 8); - sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]); - sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]); - SumHorizontal(s[0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]); - SumHorizontal(s[1], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]); - vst1q_u16(sum3[2] + x, s3[2]); - vst1q_u16(sum3[3] + x, s3[3]); + s[0][1] = vld1q_u8(src0 + x + 8); + s[1][1] = vld1q_u8(src1 + x + 8); + sq[0][2] = vmull_u8(vget_low_u8(s[0][1]), vget_low_u8(s[0][1])); + sq[1][2] = vmull_u8(vget_low_u8(s[1][1]), vget_low_u8(s[1][1])); + SumHorizontal<8>(s[0], &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]); + SumHorizontal<8>(s[1], &s3[0][3], &s3[1][3], &s5[0][4], &s5[1][4]); + SumHorizontal(sq[0] + 1, &sq3[2], &sq5[3]); + SumHorizontal(sq[1] + 1, &sq3[3], &sq5[4]); + vst1q_u16(sum3[2] + x, s3[0][2]); + vst1q_u16(sum3[3] + x, s3[0][3]); vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]); vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]); vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]); vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]); - vst1q_u16(sum5[3] + x, s5[3]); - vst1q_u16(sum5[4] + x, s5[4]); + vst1q_u16(sum5[3] + x, s5[0][3]); + vst1q_u16(sum5[4] + x, s5[0][4]); vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]); vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]); vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]); vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]); - s3[0] = vld1q_u16(sum3[0] + x); - s3[1] = vld1q_u16(sum3[1] + x); + s3[0][0] = vld1q_u16(sum3[0] + x); + s3[0][1] = vld1q_u16(sum3[1] + x); sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); - s5[0] = vld1q_u16(sum5[0] + x); - s5[1] = vld1q_u16(sum5[1] + x); - s5[2] = vld1q_u16(sum5[2] + x); + s5[0][0] = vld1q_u16(sum5[0] + x); + s5[0][1] = vld1q_u16(sum5[1] + x); + s5[0][2] = vld1q_u16(sum5[2] + x); sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); - CalculateIntermediate3(s3, sq3, scales[1], ma3_0, b3_0); - CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], ma3_1, b3_1); - CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); + CalculateIntermediate3<8>(s3[0], sq3, scales[1], &ma3[0][0], &b3[0][1]); + CalculateIntermediate3<8>(s3[0] + 1, sq3 + 1, scales[1], &ma3[1][0], + &b3[1][1]); + CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], &b5[0]); + + sq[0][3] = vmull_u8(vget_high_u8(s[0][1]), vget_high_u8(s[0][1])); + sq[1][3] = vmull_u8(vget_high_u8(s[1][1]), vget_high_u8(s[1][1])); + SumHorizontal(sq[0] + 2, &sq3[2], &sq5[3]); + SumHorizontal(sq[1] + 2, &sq3[3], &sq5[4]); + vst1q_u16(sum3[2] + x + 8, s3[1][2]); + vst1q_u16(sum3[3] + x + 8, s3[1][3]); + vst1q_u32(square_sum3[2] + x + 8, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + x + 12, sq3[2].val[1]); + vst1q_u32(square_sum3[3] + x + 8, sq3[3].val[0]); + vst1q_u32(square_sum3[3] + x + 12, sq3[3].val[1]); + vst1q_u16(sum5[3] + x + 8, s5[1][3]); + vst1q_u16(sum5[4] + x + 8, s5[1][4]); + vst1q_u32(square_sum5[3] + x + 8, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + x + 12, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + x + 8, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + x + 12, sq5[4].val[1]); + s3[1][0] = vld1q_u16(sum3[0] + x + 8); + s3[1][1] = vld1q_u16(sum3[1] + x + 8); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 8); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 12); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 8); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 12); + s5[1][0] = vld1q_u16(sum5[0] + x + 8); + s5[1][1] = vld1q_u16(sum5[1] + x + 8); + s5[1][2] = vld1q_u16(sum5[2] + x + 8); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12); + CalculateIntermediate3<0>(s3[1], sq3, scales[1], &ma3[0][1], &b3[0][2]); + CalculateIntermediate3<0>(s3[1] + 1, sq3 + 1, scales[1], &ma3[1][1], + &b3[1][2]); + CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], &b5[1]); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo( + const uint8_t* const src, const uint16_t scales[2], + const uint16_t* const sum3[4], const uint16_t* const sum5[5], + const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5], + uint8x16_t* const s, uint16x8_t sq[2], uint8x16_t* const ma3, + uint8x16_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) { + uint16x8_t s3[3], s5[5]; + uint32x4x2_t sq3[3], sq5[5]; + *s = vld1q_u8(src); + sq[0] = vmull_u8(vget_low_u8(*s), vget_low_u8(*s)); + sq[1] = vmull_u8(vget_high_u8(*s), vget_high_u8(*s)); + SumHorizontal(*s, sq, &s3[2], &s5[3], &sq3[2], &sq5[3]); + s5[0] = vld1q_u16(sum5[0]); + s5[1] = vld1q_u16(sum5[1]); + s5[2] = vld1q_u16(sum5[2]); + s5[4] = s5[3]; + sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4); + sq5[4] = sq5[3]; + CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5); + s3[0] = vld1q_u16(sum3[0]); + s3[1] = vld1q_u16(sum3[1]); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + 4); + CalculateIntermediate3<0>(s3, sq3, scales[1], ma3, b3); } LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2], const uint16_t* const sum3[4], const uint16_t* const sum5[5], const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5], - uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3, - uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) { - uint16x8_t s3[3], s5[5]; + uint8x16_t s[2], uint16x8_t sq[3], uint8x16_t ma3[2], uint8x16_t ma5[2], + uint16x8_t b3[2], uint16x8_t b5[2]) { + uint16x8_t s3[2][3], s5[2][5]; uint32x4x2_t sq3[3], sq5[5]; - s->val[1] = vld1_u8(src + x + 8); - sq->val[1] = vmull_u8(s->val[1], s->val[1]); - SumHorizontal(*s, *sq, &s3[2], &s5[3], &sq3[2], &sq5[3]); - s5[0] = vld1q_u16(sum5[0] + x); - s5[1] = vld1q_u16(sum5[1] + x); - s5[2] = vld1q_u16(sum5[2] + x); - s5[4] = s5[3]; + s[1] = vld1q_u8(src + x + 8); + sq[1] = vmull_u8(vget_low_u8(s[1]), vget_low_u8(s[1])); + SumHorizontal<8>(s, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]); + SumHorizontal(sq, &sq3[2], &sq5[3]); + s5[0][0] = vld1q_u16(sum5[0] + x); + s5[0][1] = vld1q_u16(sum5[1] + x); + s5[0][2] = vld1q_u16(sum5[2] + x); + s5[0][4] = s5[0][3]; sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); @@ -1134,14 +1538,36 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); sq5[4] = sq5[3]; - CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); - s3[0] = vld1q_u16(sum3[0] + x); - s3[1] = vld1q_u16(sum3[1] + x); + CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], &b5[0]); + s3[0][0] = vld1q_u16(sum3[0] + x); + s3[0][1] = vld1q_u16(sum3[1] + x); sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); - CalculateIntermediate3(s3, sq3, scales[1], ma3, b3); + CalculateIntermediate3<8>(s3[0], sq3, scales[1], &ma3[0], &b3[0]); + + sq[2] = vmull_u8(vget_high_u8(s[1]), vget_high_u8(s[1])); + SumHorizontal(sq + 1, &sq3[2], &sq5[3]); + s5[1][0] = vld1q_u16(sum5[0] + x + 8); + s5[1][1] = vld1q_u16(sum5[1] + x + 8); + s5[1][2] = vld1q_u16(sum5[2] + x + 8); + s5[1][4] = s5[1][3]; + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12); + sq5[4] = sq5[3]; + CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], &b5[1]); + s3[1][0] = vld1q_u16(sum3[0] + x + 8); + s3[1][1] = vld1q_u16(sum3[1] + x + 8); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 8); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 12); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 8); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 12); + CalculateIntermediate3<0>(s3[1], sq3, scales[1], &ma3[1], &b3[1]); } inline void BoxSumFilterPreProcess5(const uint8_t* const src0, @@ -1150,33 +1576,39 @@ inline void BoxSumFilterPreProcess5(const uint8_t* const src0, uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565, uint32_t* b565) { - uint8x8x2_t s[2], mas; - uint16x8x2_t sq[2], bs; - s[0].val[0] = vld1_u8(src0); - s[1].val[0] = vld1_u8(src1); - sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); - sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); - BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq, - &mas.val[0], &bs.val[0]); + uint8x16_t s[2][2], mas[2]; + uint16x8_t sq[2][4], bs[3]; + BoxFilterPreProcess5Lo(src0, src1, scale, s, sum5, square_sum5, sq, &mas[0], + &bs[0]); int x = 0; do { - s[0].val[0] = s[0].val[1]; - s[1].val[0] = s[1].val[1]; - sq[0].val[0] = sq[0].val[1]; - sq[1].val[0] = sq[1].val[1]; - BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq, - &mas.val[1], &bs.val[1]); - const uint16x8_t ma = Sum565(mas); - const uint32x4x2_t b = Sum565W(bs); - vst1q_u16(ma565, ma); - vst1q_u32(b565 + 0, b.val[0]); - vst1q_u32(b565 + 4, b.val[1]); - mas.val[0] = mas.val[1]; - bs.val[0] = bs.val[1]; - ma565 += 8; - b565 += 8; - x += 8; + uint16x8_t ma[2]; + uint8x16_t masx[3]; + uint32x4x2_t b[2]; + BoxFilterPreProcess5(src0, src1, x + 8, scale, s, sum5, square_sum5, sq, + mas, bs + 1); + Prepare3_8<0>(mas, masx); + ma[0] = Sum565<0>(masx); + b[0] = Sum565W(bs); + vst1q_u16(ma565, ma[0]); + vst1q_u32(b565 + 0, b[0].val[0]); + vst1q_u32(b565 + 4, b[0].val[1]); + + ma[1] = Sum565<8>(masx); + b[1] = Sum565W(bs + 1); + vst1q_u16(ma565 + 8, ma[1]); + vst1q_u32(b565 + 8, b[1].val[0]); + vst1q_u32(b565 + 12, b[1].val[1]); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + ma565 += 16; + b565 += 16; + x += 16; } while (x < width); } @@ -1185,35 +1617,44 @@ LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( const uint8_t* const src, const int width, const uint32_t scale, uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343, uint16_t* ma444, uint32_t* b343, uint32_t* b444) { - uint8x8x2_t s, mas; - uint16x8x2_t sq, bs; - s.val[0] = vld1_u8(src); - sq.val[0] = vmull_u8(s.val[0], s.val[0]); - BoxFilterPreProcess3(src, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0], - &bs.val[0]); + uint8x16_t s[2], mas[2]; + uint16x8_t sq[4], bs[3]; + BoxFilterPreProcess3Lo(src, scale, &s[0], sum3, square_sum3, sq, &mas[0], + &bs[0]); int x = 0; do { - s.val[0] = s.val[1]; - sq.val[0] = sq.val[1]; - BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, &s, &sq, - &mas.val[1], &bs.val[1]); + uint8x16_t ma3x[3]; + BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, s, sq + 1, mas, + bs + 1); + Prepare3_8<0>(mas, ma3x); if (calculate444) { - Store343_444(mas, bs, 0, ma343, ma444, b343, b444); - ma444 += 8; - b444 += 8; + Store343_444<0>(ma3x, bs + 0, 0, ma343, ma444, b343, b444); + Store343_444<8>(ma3x, bs + 1, 0, ma343 + 8, ma444 + 8, b343 + 8, + b444 + 8); + ma444 += 16; + b444 += 16; } else { - const uint16x8_t ma = Sum343(mas); - const uint32x4x2_t b = Sum343W(bs); - vst1q_u16(ma343, ma); - vst1q_u32(b343 + 0, b.val[0]); - vst1q_u32(b343 + 4, b.val[1]); + uint16x8_t ma[2]; + uint32x4x2_t b[2]; + ma[0] = Sum343<0>(ma3x); + b[0] = Sum343W(bs); + vst1q_u16(ma343, ma[0]); + vst1q_u32(b343 + 0, b[0].val[0]); + vst1q_u32(b343 + 4, b[0].val[1]); + ma[1] = Sum343<8>(ma3x); + b[1] = Sum343W(bs + 1); + vst1q_u16(ma343 + 8, ma[1]); + vst1q_u32(b343 + 8, b[1].val[0]); + vst1q_u32(b343 + 12, b[1].val[1]); } - mas.val[0] = mas.val[1]; - bs.val[0] = bs.val[1]; - ma343 += 8; - b343 += 8; - x += 8; + s[0] = s[1]; + sq[1] = sq[3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + ma343 += 16; + b343 += 16; + x += 16; } while (x < width); } @@ -1221,48 +1662,58 @@ inline void BoxSumFilterPreProcess( const uint8_t* const src0, const uint8_t* const src1, const int width, const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565, - uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) { - uint8x8x2_t s[2]; - uint8x8x2_t ma3[2], ma5; - uint16x8x2_t sq[2], b3[2], b5; - s[0].val[0] = vld1_u8(src0); - s[1].val[0] = vld1_u8(src1); - sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); - sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); - BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3, - square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0], - &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]); + uint16_t* const ma343[4], uint16_t* const ma444, uint16_t* ma565, + uint32_t* const b343[4], uint32_t* const b444, uint32_t* b565) { + uint8x16_t s[2][2], ma3[2][2], ma5[2]; + uint16x8_t sq[2][4], b3[2][3], b5[3]; + BoxFilterPreProcessLo(src0, src1, scales, s, sum3, sum5, square_sum3, + square_sum5, sq, ma3, b3, &ma5[0], &b5[0]); int x = 0; do { - s[0].val[0] = s[0].val[1]; - s[1].val[0] = s[1].val[1]; - sq[0].val[0] = sq[0].val[1]; - sq[1].val[0] = sq[1].val[1]; - BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3, - square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1], - &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]); - uint16x8_t ma = Sum343(ma3[0]); - uint32x4x2_t b = Sum343W(b3[0]); - vst1q_u16(ma343[0] + x, ma); - vst1q_u32(b343[0] + x, b.val[0]); - vst1q_u32(b343[0] + x + 4, b.val[1]); - Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); - ma = Sum565(ma5); - b = Sum565W(b5); - vst1q_u16(ma565, ma); - vst1q_u32(b565 + 0, b.val[0]); - vst1q_u32(b565 + 4, b.val[1]); - ma3[0].val[0] = ma3[0].val[1]; - ma3[1].val[0] = ma3[1].val[1]; - b3[0].val[0] = b3[0].val[1]; - b3[1].val[0] = b3[1].val[1]; - ma5.val[0] = ma5.val[1]; - b5.val[0] = b5.val[1]; - ma565 += 8; - b565 += 8; - x += 8; + uint16x8_t ma[2]; + uint8x16_t ma3x[3], ma5x[3]; + uint32x4x2_t b[2]; + BoxFilterPreProcess(src0, src1, x + 8, scales, s, sum3, sum5, square_sum3, + square_sum5, sq, ma3, b3, ma5, b5 + 1); + Prepare3_8<0>(ma3[0], ma3x); + ma[0] = Sum343<0>(ma3x); + ma[1] = Sum343<8>(ma3x); + b[0] = Sum343W(b3[0] + 0); + b[1] = Sum343W(b3[0] + 1); + vst1q_u16(ma343[0] + x, ma[0]); + vst1q_u16(ma343[0] + x + 8, ma[1]); + vst1q_u32(b343[0] + x, b[0].val[0]); + vst1q_u32(b343[0] + x + 4, b[0].val[1]); + vst1q_u32(b343[0] + x + 8, b[1].val[0]); + vst1q_u32(b343[0] + x + 12, b[1].val[1]); + Prepare3_8<0>(ma3[1], ma3x); + Store343_444<0>(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444); + Store343_444<8>(ma3x, b3[1] + 1, x + 8, ma343[1], ma444, b343[1], b444); + Prepare3_8<0>(ma5, ma5x); + ma[0] = Sum565<0>(ma5x); + ma[1] = Sum565<8>(ma5x); + b[0] = Sum565W(b5); + b[1] = Sum565W(b5 + 1); + vst1q_u16(ma565, ma[0]); + vst1q_u16(ma565 + 8, ma[1]); + vst1q_u32(b565 + 0, b[0].val[0]); + vst1q_u32(b565 + 4, b[0].val[1]); + vst1q_u32(b565 + 8, b[1].val[0]); + vst1q_u32(b565 + 12, b[1].val[1]); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + ma3[0][0] = ma3[0][1]; + ma3[1][0] = ma3[1][1]; + b3[0][0] = b3[0][2]; + b3[1][0] = b3[1][2]; + ma5[0] = ma5[1]; + b5[0] = b5[2]; + ma565 += 16; + b565 += 16; + x += 16; } while (x < width); } @@ -1310,37 +1761,36 @@ inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s, return CalculateFilteredOutput<5>(s, ma_sum, b_sum); } -inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2], - uint8_t* const dst) { +inline uint8x8_t SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2]) { const int16x4_t v_lo = vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits); const int16x4_t v_hi = vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits); const int16x8_t vv = vcombine_s16(v_lo, v_hi); - const int16x8_t s = ZeroExtend(src); - const int16x8_t d = vaddq_s16(s, vv); - vst1_u8(dst, vqmovun_s16(d)); + const int16x8_t d = + vreinterpretq_s16_u16(vaddw_u8(vreinterpretq_u16_s16(vv), src)); + return vqmovun_s16(d); } -inline void SelfGuidedDoubleMultiplier(const uint8x8_t src, - const int16x8_t filter[2], const int w0, - const int w2, uint8_t* const dst) { +inline uint8x8_t SelfGuidedDoubleMultiplier(const uint8x8_t src, + const int16x8_t filter[2], + const int w0, const int w2) { int32x4_t v[2]; v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0); v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0); v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2); v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2); - SelfGuidedFinal(src, v, dst); + return SelfGuidedFinal(src, v); } -inline void SelfGuidedSingleMultiplier(const uint8x8_t src, - const int16x8_t filter, const int w0, - uint8_t* const dst) { +inline uint8x8_t SelfGuidedSingleMultiplier(const uint8x8_t src, + const int16x8_t filter, + const int w0) { // weight: -96 to 96 (Sgrproj_Xqd_Min/Max) int32x4_t v[2]; v[0] = vmull_n_s16(vget_low_s16(filter), w0); v[1] = vmull_n_s16(vget_high_s16(filter), w0); - SelfGuidedFinal(src, v, dst); + return SelfGuidedFinal(src, v); } LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( @@ -1349,43 +1799,60 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( uint32_t* const square_sum5[5], const int width, const uint32_t scale, const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2], uint8_t* const dst) { - uint8x8x2_t s[2], mas; - uint16x8x2_t sq[2], bs; - s[0].val[0] = vld1_u8(src0); - s[1].val[0] = vld1_u8(src1); - sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); - sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); - BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq, - &mas.val[0], &bs.val[0]); + uint8x16_t s[2][2], mas[2]; + uint16x8_t sq[2][4], bs[3]; + BoxFilterPreProcess5Lo(src0, src1, scale, s, sum5, square_sum5, sq, &mas[0], + &bs[0]); int x = 0; do { - s[0].val[0] = s[0].val[1]; - s[1].val[0] = s[1].val[1]; - sq[0].val[0] = sq[0].val[1]; - sq[1].val[0] = sq[1].val[1]; - BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq, - &mas.val[1], &bs.val[1]); uint16x8_t ma[2]; + uint8x16_t masx[3]; uint32x4x2_t b[2]; - ma[1] = Sum565(mas); + int16x8_t p0, p1; + BoxFilterPreProcess5(src0, src1, x + 8, scale, s, sum5, square_sum5, sq, + mas, bs + 1); + Prepare3_8<0>(mas, masx); + ma[1] = Sum565<0>(masx); b[1] = Sum565W(bs); vst1q_u16(ma565[1] + x, ma[1]); vst1q_u32(b565[1] + x + 0, b[1].val[0]); vst1q_u32(b565[1] + x + 4, b[1].val[1]); - const uint8x8_t sr0 = vld1_u8(src + x); - const uint8x8_t sr1 = vld1_u8(src + stride + x); - int16x8_t p0, p1; + const uint8x16_t sr0 = vld1q_u8(src + x); + const uint8x16_t sr1 = vld1q_u8(src + stride + x); + const uint8x8_t sr00 = vget_low_u8(sr0); + const uint8x8_t sr10 = vget_low_u8(sr1); ma[0] = vld1q_u16(ma565[0] + x); b[0].val[0] = vld1q_u32(b565[0] + x + 0); b[0].val[1] = vld1q_u32(b565[0] + x + 4); - p0 = CalculateFilteredOutputPass1(sr0, ma, b); - p1 = CalculateFilteredOutput<4>(sr1, ma[1], b[1]); - SelfGuidedSingleMultiplier(sr0, p0, w0, dst + x); - SelfGuidedSingleMultiplier(sr1, p1, w0, dst + stride + x); - mas.val[0] = mas.val[1]; - bs.val[0] = bs.val[1]; - x += 8; + p0 = CalculateFilteredOutputPass1(sr00, ma, b); + p1 = CalculateFilteredOutput<4>(sr10, ma[1], b[1]); + const uint8x8_t d00 = SelfGuidedSingleMultiplier(sr00, p0, w0); + const uint8x8_t d10 = SelfGuidedSingleMultiplier(sr10, p1, w0); + + ma[1] = Sum565<8>(masx); + b[1] = Sum565W(bs + 1); + vst1q_u16(ma565[1] + x + 8, ma[1]); + vst1q_u32(b565[1] + x + 8, b[1].val[0]); + vst1q_u32(b565[1] + x + 12, b[1].val[1]); + const uint8x8_t sr01 = vget_high_u8(sr0); + const uint8x8_t sr11 = vget_high_u8(sr1); + ma[0] = vld1q_u16(ma565[0] + x + 8); + b[0].val[0] = vld1q_u32(b565[0] + x + 8); + b[0].val[1] = vld1q_u32(b565[0] + x + 12); + p0 = CalculateFilteredOutputPass1(sr01, ma, b); + p1 = CalculateFilteredOutput<4>(sr11, ma[1], b[1]); + const uint8x8_t d01 = SelfGuidedSingleMultiplier(sr01, p0, w0); + const uint8x8_t d11 = SelfGuidedSingleMultiplier(sr11, p1, w0); + vst1q_u8(dst + x, vcombine_u8(d00, d01)); + vst1q_u8(dst + stride + x, vcombine_u8(d10, d11)); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + x += 16; } while (x < width); } @@ -1396,34 +1863,45 @@ inline void BoxFilterPass1LastRow(const uint8_t* const src, uint32_t* const square_sum5[5], uint16_t* ma565, uint32_t* b565, uint8_t* const dst) { - uint8x8x2_t s, mas; - uint16x8x2_t sq, bs; - s.val[0] = vld1_u8(src0); - sq.val[0] = vmull_u8(s.val[0], s.val[0]); - BoxFilterPreProcess5LastRow(src0, 0, scale, sum5, square_sum5, &s, &sq, - &mas.val[0], &bs.val[0]); + uint8x16_t s[2], mas[2]; + uint16x8_t sq[4], bs[4]; + BoxFilterPreProcess5LastRowLo(src0, scale, s, sum5, square_sum5, sq, &mas[0], + &bs[0]); int x = 0; do { - s.val[0] = s.val[1]; - sq.val[0] = sq.val[1]; - BoxFilterPreProcess5LastRow(src0, x + 8, scale, sum5, square_sum5, &s, &sq, - &mas.val[1], &bs.val[1]); uint16x8_t ma[2]; + uint8x16_t masx[3]; uint32x4x2_t b[2]; - ma[1] = Sum565(mas); + BoxFilterPreProcess5LastRow(src0, x + 8, scale, s, sum5, square_sum5, + sq + 1, mas, bs + 1); + Prepare3_8<0>(mas, masx); + ma[1] = Sum565<0>(masx); b[1] = Sum565W(bs); - mas.val[0] = mas.val[1]; - bs.val[0] = bs.val[1]; ma[0] = vld1q_u16(ma565); b[0].val[0] = vld1q_u32(b565 + 0); b[0].val[1] = vld1q_u32(b565 + 4); - const uint8x8_t sr = vld1_u8(src + x); - const int16x8_t p = CalculateFilteredOutputPass1(sr, ma, b); - SelfGuidedSingleMultiplier(sr, p, w0, dst + x); - ma565 += 8; - b565 += 8; - x += 8; + const uint8x16_t sr = vld1q_u8(src + x); + const uint8x8_t sr0 = vget_low_u8(sr); + const int16x8_t p0 = CalculateFilteredOutputPass1(sr0, ma, b); + const uint8x8_t d0 = SelfGuidedSingleMultiplier(sr0, p0, w0); + + ma[1] = Sum565<8>(masx); + b[1] = Sum565W(bs + 1); + bs[0] = bs[2]; + const uint8x8_t sr1 = vget_high_u8(sr); + ma[0] = vld1q_u16(ma565 + 8); + b[0].val[0] = vld1q_u32(b565 + 8); + b[0].val[1] = vld1q_u32(b565 + 12); + const int16x8_t p1 = CalculateFilteredOutputPass1(sr1, ma, b); + const uint8x8_t d1 = SelfGuidedSingleMultiplier(sr1, p1, w0); + vst1q_u8(dst + x, vcombine_u8(d0, d1)); + s[0] = s[1]; + sq[1] = sq[3]; + mas[0] = mas[1]; + ma565 += 16; + b565 += 16; + x += 16; } while (x < width); } @@ -1433,35 +1911,49 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPass2( uint32_t* const square_sum3[3], uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2], uint8_t* const dst) { - uint8x8x2_t s, mas; - uint16x8x2_t sq, bs; - s.val[0] = vld1_u8(src0); - sq.val[0] = vmull_u8(s.val[0], s.val[0]); - BoxFilterPreProcess3(src0, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0], - &bs.val[0]); + uint8x16_t s[2], mas[2]; + uint16x8_t sq[4], bs[3]; + BoxFilterPreProcess3Lo(src0, scale, &s[0], sum3, square_sum3, sq, &mas[0], + &bs[0]); int x = 0; do { - s.val[0] = s.val[1]; - sq.val[0] = sq.val[1]; - BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, &s, &sq, - &mas.val[1], &bs.val[1]); uint16x8_t ma[3]; + uint8x16_t ma3x[3]; uint32x4x2_t b[3]; - Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2], - b444[1]); - const uint8x8_t sr = vld1_u8(src + x); + BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, s, sq + 1, mas, + bs + 1); + Prepare3_8<0>(mas, ma3x); + Store343_444<0>(ma3x, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2], + b444[1]); + const uint8x16_t sr = vld1q_u8(src + x); + const uint8x8_t sr0 = vget_low_u8(sr); ma[0] = vld1q_u16(ma343[0] + x); ma[1] = vld1q_u16(ma444[0] + x); b[0].val[0] = vld1q_u32(b343[0] + x + 0); b[0].val[1] = vld1q_u32(b343[0] + x + 4); b[1].val[0] = vld1q_u32(b444[0] + x + 0); b[1].val[1] = vld1q_u32(b444[0] + x + 4); - const int16x8_t p = CalculateFilteredOutputPass2(sr, ma, b); - SelfGuidedSingleMultiplier(sr, p, w0, dst + x); - mas.val[0] = mas.val[1]; - bs.val[0] = bs.val[1]; - x += 8; + const int16x8_t p0 = CalculateFilteredOutputPass2(sr0, ma, b); + const uint8x8_t d0 = SelfGuidedSingleMultiplier(sr0, p0, w0); + + Store343_444<8>(ma3x, bs + 1, x + 8, &ma[2], &b[2], ma343[2], ma444[1], + b343[2], b444[1]); + const uint8x8_t sr1 = vget_high_u8(sr); + ma[0] = vld1q_u16(ma343[0] + x + 8); + ma[1] = vld1q_u16(ma444[0] + x + 8); + b[0].val[0] = vld1q_u32(b343[0] + x + 8); + b[0].val[1] = vld1q_u32(b343[0] + x + 12); + b[1].val[0] = vld1q_u32(b444[0] + x + 8); + b[1].val[1] = vld1q_u32(b444[0] + x + 12); + const int16x8_t p1 = CalculateFilteredOutputPass2(sr1, ma, b); + const uint8x8_t d1 = SelfGuidedSingleMultiplier(sr1, p1, w0); + vst1q_u8(dst + x, vcombine_u8(d0, d1)); + s[0] = s[1]; + sq[1] = sq[3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + x += 16; } while (x < width); } @@ -1474,64 +1966,96 @@ LIBGAV1_ALWAYS_INLINE void BoxFilter( uint16_t* const ma343[4], uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], uint32_t* const b565[2], uint8_t* const dst) { - uint8x8x2_t s[2], ma3[2], ma5; - uint16x8x2_t sq[2], b3[2], b5; - s[0].val[0] = vld1_u8(src0); - s[1].val[0] = vld1_u8(src1); - sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); - sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); - BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3, - square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0], - &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]); + uint8x16_t s[2][2], ma3[2][2], ma5[2]; + uint16x8_t sq[2][4], b3[2][3], b5[3]; + BoxFilterPreProcessLo(src0, src1, scales, s, sum3, sum5, square_sum3, + square_sum5, sq, ma3, b3, &ma5[0], &b5[0]); int x = 0; do { - s[0].val[0] = s[0].val[1]; - s[1].val[0] = s[1].val[1]; - sq[0].val[0] = sq[0].val[1]; - sq[1].val[0] = sq[1].val[1]; - BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3, - square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1], - &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]); uint16x8_t ma[3][3]; + uint8x16_t ma3x[2][3], ma5x[3]; uint32x4x2_t b[3][3]; - Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1], - ma343[2], ma444[1], b343[2], b444[1]); - Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2], - b343[3], b444[2]); - ma[0][1] = Sum565(ma5); + int16x8_t p[2][2]; + BoxFilterPreProcess(src0, src1, x + 8, scales, s, sum3, sum5, square_sum3, + square_sum5, sq, ma3, b3, ma5, b5 + 1); + Prepare3_8<0>(ma3[0], ma3x[0]); + Prepare3_8<0>(ma3[1], ma3x[1]); + Store343_444<0>(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1], + ma343[2], ma444[1], b343[2], b444[1]); + Store343_444<0>(ma3x[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2], + b343[3], b444[2]); + Prepare3_8<0>(ma5, ma5x); + ma[0][1] = Sum565<0>(ma5x); b[0][1] = Sum565W(b5); vst1q_u16(ma565[1] + x, ma[0][1]); vst1q_u32(b565[1] + x, b[0][1].val[0]); vst1q_u32(b565[1] + x + 4, b[0][1].val[1]); - ma3[0].val[0] = ma3[0].val[1]; - ma3[1].val[0] = ma3[1].val[1]; - b3[0].val[0] = b3[0].val[1]; - b3[1].val[0] = b3[1].val[1]; - ma5.val[0] = ma5.val[1]; - b5.val[0] = b5.val[1]; - int16x8_t p[2][2]; - const uint8x8_t sr0 = vld1_u8(src + x); - const uint8x8_t sr1 = vld1_u8(src + stride + x); + const uint8x16_t sr0 = vld1q_u8(src + x); + const uint8x16_t sr1 = vld1q_u8(src + stride + x); + const uint8x8_t sr00 = vget_low_u8(sr0); + const uint8x8_t sr10 = vget_low_u8(sr1); ma[0][0] = vld1q_u16(ma565[0] + x); b[0][0].val[0] = vld1q_u32(b565[0] + x); b[0][0].val[1] = vld1q_u32(b565[0] + x + 4); - p[0][0] = CalculateFilteredOutputPass1(sr0, ma[0], b[0]); - p[1][0] = CalculateFilteredOutput<4>(sr1, ma[0][1], b[0][1]); + p[0][0] = CalculateFilteredOutputPass1(sr00, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr10, ma[0][1], b[0][1]); ma[1][0] = vld1q_u16(ma343[0] + x); ma[1][1] = vld1q_u16(ma444[0] + x); b[1][0].val[0] = vld1q_u32(b343[0] + x); b[1][0].val[1] = vld1q_u32(b343[0] + x + 4); b[1][1].val[0] = vld1q_u32(b444[0] + x); b[1][1].val[1] = vld1q_u32(b444[0] + x + 4); - p[0][1] = CalculateFilteredOutputPass2(sr0, ma[1], b[1]); + p[0][1] = CalculateFilteredOutputPass2(sr00, ma[1], b[1]); ma[2][0] = vld1q_u16(ma343[1] + x); b[2][0].val[0] = vld1q_u32(b343[1] + x); b[2][0].val[1] = vld1q_u32(b343[1] + x + 4); - p[1][1] = CalculateFilteredOutputPass2(sr1, ma[2], b[2]); - SelfGuidedDoubleMultiplier(sr0, p[0], w0, w2, dst + x); - SelfGuidedDoubleMultiplier(sr1, p[1], w0, w2, dst + stride + x); - x += 8; + p[1][1] = CalculateFilteredOutputPass2(sr10, ma[2], b[2]); + const uint8x8_t d00 = SelfGuidedDoubleMultiplier(sr00, p[0], w0, w2); + const uint8x8_t d10 = SelfGuidedDoubleMultiplier(sr10, p[1], w0, w2); + + Store343_444<8>(ma3x[0], b3[0] + 1, x + 8, &ma[1][2], &ma[2][1], &b[1][2], + &b[2][1], ma343[2], ma444[1], b343[2], b444[1]); + Store343_444<8>(ma3x[1], b3[1] + 1, x + 8, &ma[2][2], &b[2][2], ma343[3], + ma444[2], b343[3], b444[2]); + ma[0][1] = Sum565<8>(ma5x); + b[0][1] = Sum565W(b5 + 1); + vst1q_u16(ma565[1] + x + 8, ma[0][1]); + vst1q_u32(b565[1] + x + 8, b[0][1].val[0]); + vst1q_u32(b565[1] + x + 12, b[0][1].val[1]); + b3[0][0] = b3[0][2]; + b3[1][0] = b3[1][2]; + b5[0] = b5[2]; + const uint8x8_t sr01 = vget_high_u8(sr0); + const uint8x8_t sr11 = vget_high_u8(sr1); + ma[0][0] = vld1q_u16(ma565[0] + x + 8); + b[0][0].val[0] = vld1q_u32(b565[0] + x + 8); + b[0][0].val[1] = vld1q_u32(b565[0] + x + 12); + p[0][0] = CalculateFilteredOutputPass1(sr01, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr11, ma[0][1], b[0][1]); + ma[1][0] = vld1q_u16(ma343[0] + x + 8); + ma[1][1] = vld1q_u16(ma444[0] + x + 8); + b[1][0].val[0] = vld1q_u32(b343[0] + x + 8); + b[1][0].val[1] = vld1q_u32(b343[0] + x + 12); + b[1][1].val[0] = vld1q_u32(b444[0] + x + 8); + b[1][1].val[1] = vld1q_u32(b444[0] + x + 12); + p[0][1] = CalculateFilteredOutputPass2(sr01, ma[1], b[1]); + ma[2][0] = vld1q_u16(ma343[1] + x + 8); + b[2][0].val[0] = vld1q_u32(b343[1] + x + 8); + b[2][0].val[1] = vld1q_u32(b343[1] + x + 12); + p[1][1] = CalculateFilteredOutputPass2(sr11, ma[2], b[2]); + const uint8x8_t d01 = SelfGuidedDoubleMultiplier(sr01, p[0], w0, w2); + const uint8x8_t d11 = SelfGuidedDoubleMultiplier(sr11, p[1], w0, w2); + vst1q_u8(dst + x, vcombine_u8(d00, d01)); + vst1q_u8(dst + stride + x, vcombine_u8(d10, d11)); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + ma3[0][0] = ma3[0][1]; + ma3[1][0] = ma3[1][1]; + ma5[0] = ma5[1]; + x += 16; } while (x < width); } @@ -1540,58 +2064,79 @@ inline void BoxFilterLastRow( const uint16_t scales[2], const int16_t w0, const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - uint16_t* const ma343[4], uint16_t* const ma444[3], - uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], - uint32_t* const b565[2], uint8_t* const dst) { - uint8x8x2_t s, ma3, ma5; - uint16x8x2_t sq, b3, b5; - uint16x8_t ma[3]; + uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565, + uint32_t* const b343, uint32_t* const b444, uint32_t* const b565, + uint8_t* const dst) { + uint8x16_t s[2], ma3[2], ma5[2]; + uint16x8_t sq[4], ma[3], b3[3], b5[3]; uint32x4x2_t b[3]; - s.val[0] = vld1_u8(src0); - sq.val[0] = vmull_u8(s.val[0], s.val[0]); - BoxFilterPreProcessLastRow(src0, 0, scales, sum3, sum5, square_sum3, - square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0], - &b3.val[0], &b5.val[0]); + BoxFilterPreProcessLastRowLo(src0, scales, sum3, sum5, square_sum3, + square_sum5, &s[0], sq, &ma3[0], &ma5[0], &b3[0], + &b5[0]); int x = 0; do { - s.val[0] = s.val[1]; - sq.val[0] = sq.val[1]; + uint8x16_t ma3x[3], ma5x[3]; + int16x8_t p[2]; BoxFilterPreProcessLastRow(src0, x + 8, scales, sum3, sum5, square_sum3, - square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1], - &b3.val[1], &b5.val[1]); - ma[1] = Sum565(ma5); + square_sum5, s, sq + 1, ma3, ma5, &b3[1], + &b5[1]); + Prepare3_8<0>(ma5, ma5x); + ma[1] = Sum565<0>(ma5x); b[1] = Sum565W(b5); - ma5.val[0] = ma5.val[1]; - b5.val[0] = b5.val[1]; - ma[2] = Sum343(ma3); + Prepare3_8<0>(ma3, ma3x); + ma[2] = Sum343<0>(ma3x); b[2] = Sum343W(b3); - ma3.val[0] = ma3.val[1]; - b3.val[0] = b3.val[1]; - const uint8x8_t sr = vld1_u8(src + x); - int16x8_t p[2]; - ma[0] = vld1q_u16(ma565[0] + x); - b[0].val[0] = vld1q_u32(b565[0] + x + 0); - b[0].val[1] = vld1q_u32(b565[0] + x + 4); - p[0] = CalculateFilteredOutputPass1(sr, ma, b); - ma[0] = vld1q_u16(ma343[0] + x); - ma[1] = vld1q_u16(ma444[0] + x); - b[0].val[0] = vld1q_u32(b343[0] + x + 0); - b[0].val[1] = vld1q_u32(b343[0] + x + 4); - b[1].val[0] = vld1q_u32(b444[0] + x + 0); - b[1].val[1] = vld1q_u32(b444[0] + x + 4); - p[1] = CalculateFilteredOutputPass2(sr, ma, b); - SelfGuidedDoubleMultiplier(sr, p, w0, w2, dst + x); - x += 8; + const uint8x16_t sr = vld1q_u8(src + x); + const uint8x8_t sr0 = vget_low_u8(sr); + ma[0] = vld1q_u16(ma565 + x); + b[0].val[0] = vld1q_u32(b565 + x + 0); + b[0].val[1] = vld1q_u32(b565 + x + 4); + p[0] = CalculateFilteredOutputPass1(sr0, ma, b); + ma[0] = vld1q_u16(ma343 + x); + ma[1] = vld1q_u16(ma444 + x); + b[0].val[0] = vld1q_u32(b343 + x + 0); + b[0].val[1] = vld1q_u32(b343 + x + 4); + b[1].val[0] = vld1q_u32(b444 + x + 0); + b[1].val[1] = vld1q_u32(b444 + x + 4); + p[1] = CalculateFilteredOutputPass2(sr0, ma, b); + const uint8x8_t d0 = SelfGuidedDoubleMultiplier(sr0, p, w0, w2); + + ma[1] = Sum565<8>(ma5x); + b[1] = Sum565W(b5 + 1); + b5[0] = b5[2]; + ma[2] = Sum343<8>(ma3x); + b[2] = Sum343W(b3 + 1); + b3[0] = b3[2]; + const uint8x8_t sr1 = vget_high_u8(sr); + ma[0] = vld1q_u16(ma565 + x + 8); + b[0].val[0] = vld1q_u32(b565 + x + 8); + b[0].val[1] = vld1q_u32(b565 + x + 12); + p[0] = CalculateFilteredOutputPass1(sr1, ma, b); + ma[0] = vld1q_u16(ma343 + x + 8); + ma[1] = vld1q_u16(ma444 + x + 8); + b[0].val[0] = vld1q_u32(b343 + x + 8); + b[0].val[1] = vld1q_u32(b343 + x + 12); + b[1].val[0] = vld1q_u32(b444 + x + 8); + b[1].val[1] = vld1q_u32(b444 + x + 12); + p[1] = CalculateFilteredOutputPass2(sr1, ma, b); + const uint8x8_t d1 = SelfGuidedDoubleMultiplier(sr1, p, w0, w2); + vst1q_u8(dst + x, vcombine_u8(d0, d1)); + s[0] = s[1]; + sq[1] = sq[3]; + ma3[0] = ma3[1]; + ma5[0] = ma5[1]; + x += 16; } while (x < width); } LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const RestorationUnitInfo& restoration_info, const uint8_t* src, - const uint8_t* const top_border, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, SgrBuffer* const sgr_buffer, uint8_t* dst) { - const auto temp_stride = Align<ptrdiff_t>(width, 8); + const auto temp_stride = Align<ptrdiff_t>(width, 16); const ptrdiff_t sum_stride = temp_stride + 8; const int sgr_proj_index = restoration_info.sgr_proj_info.index; const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. @@ -1628,13 +2173,13 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( b565[1] = b565[0] + temp_stride; assert(scales[0] != 0); assert(scales[1] != 0); - BoxSum(top_border, stride, 2, sum_stride, sum3[0], sum5[1], square_sum3[0], - square_sum5[1]); + BoxSum(top_border, top_border_stride, sum_stride, sum3[0], sum5[1], + square_sum3[0], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? src + stride : bottom_border; BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, - square_sum5, ma343, ma444, ma565[0], b343, b444, + square_sum5, ma343, ma444[0], ma565[0], b343, b444[0], b565[0]); sum5[0] = sgr_buffer->sum5; square_sum5[0] = sgr_buffer->square_sum5; @@ -1665,7 +2210,7 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -1689,20 +2234,22 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( std::swap(ma565[0], ma565[1]); std::swap(b565[0], b565[1]); } - BoxFilterLastRow(src + 3, bottom_border + stride, width, scales, w0, w2, - sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565, - b343, b444, b565, dst); + BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width, + scales, w0, w2, sum3, sum5, square_sum3, square_sum5, + ma343[0], ma444[0], ma565[0], b343[0], b444[0], b565[0], + dst); } } inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { - const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 16); const ptrdiff_t sum_stride = temp_stride + 8; const int sgr_proj_index = restoration_info.sgr_proj_info.index; const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. @@ -1720,7 +2267,7 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, b565[0] = sgr_buffer->b565; b565[1] = b565[0] + temp_stride; assert(scale != 0); - BoxSum<5>(top_border, stride, 2, sum_stride, sum5[1], square_sum5[1]); + BoxSum<5>(top_border, top_border_stride, sum_stride, sum5[1], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? src + stride : bottom_border; @@ -1746,7 +2293,7 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -1763,20 +2310,21 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, Circulate5PointersBy2<uint16_t>(sum5); Circulate5PointersBy2<uint32_t>(square_sum5); } - BoxFilterPass1LastRow(src + 3, bottom_border + stride, width, scale, w0, - sum5, square_sum5, ma565[0], b565[0], dst); + BoxFilterPass1LastRow(src + 3, bottom_border + bottom_border_stride, width, + scale, w0, sum5, square_sum5, ma565[0], b565[0], dst); } } inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { assert(restoration_info.sgr_proj_info.multiplier[0] == 0); - const auto temp_stride = Align<ptrdiff_t>(width, 8); + const auto temp_stride = Align<ptrdiff_t>(width, 16); const ptrdiff_t sum_stride = temp_stride + 8; const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; @@ -1799,7 +2347,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, b444[0] = sgr_buffer->b444; b444[1] = b444[0] + temp_stride; assert(scale != 0); - BoxSum<3>(top_border, stride, 2, sum_stride, sum3[0], square_sum3[0]); + BoxSum<3>(top_border, top_border_stride, sum_stride, sum3[0], square_sum3[0]); BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0], nullptr, b343[0], nullptr); Circulate3PointersBy1<uint16_t>(sum3); @@ -1809,7 +2357,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, s = src + stride; } else { s = bottom_border; - bottom_border += stride; + bottom_border += bottom_border_stride; } BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, ma343[1], ma444[0], b343[1], b444[0]); @@ -1836,7 +2384,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, ma343, ma444, b343, b444, dst); src += stride; dst += stride; - bottom_border += stride; + bottom_border += bottom_border_stride; Circulate3PointersBy1<uint16_t>(ma343); Circulate3PointersBy1<uint32_t>(b343); std::swap(ma444[0], ma444[1]); @@ -1849,8 +2397,9 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, // part of the visible frame. void SelfGuidedFilter_NEON( const RestorationUnitInfo& restoration_info, const void* const source, - const void* const top_border, const void* const bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, RestorationBuffer* const restoration_buffer, void* const dest) { const int index = restoration_info.sgr_proj_info.index; const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 @@ -1864,14 +2413,17 @@ void SelfGuidedFilter_NEON( // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the // following assertion. assert(radius_pass_0 != 0); - BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, + width, height, sgr_buffer, dst); } else if (radius_pass_0 == 0) { - BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2, + top_border_stride, bottom - 2, bottom_border_stride, + width, height, sgr_buffer, dst); } else { - BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, - width, height, sgr_buffer, dst); + BoxFilterProcess(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, width, + height, sgr_buffer, dst); } } @@ -1890,7 +2442,7 @@ void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_ENABLE_NEON +#else // !LIBGAV1_ENABLE_NEON namespace libgav1 { namespace dsp { |