diff options
Diffstat (limited to 'src/dsp/x86/loop_restoration_avx2.cc')
-rw-r--r-- | src/dsp/x86/loop_restoration_avx2.cc | 339 |
1 files changed, 189 insertions, 150 deletions
diff --git a/src/dsp/x86/loop_restoration_avx2.cc b/src/dsp/x86/loop_restoration_avx2.cc index 7ae7c90..351a324 100644 --- a/src/dsp/x86/loop_restoration_avx2.cc +++ b/src/dsp/x86/loop_restoration_avx2.cc @@ -28,7 +28,6 @@ #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_avx2.h" -#include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" #include "src/utils/constants.h" @@ -116,7 +115,8 @@ inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride, filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0100)); filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302)); filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0102)); - filter[3] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8000)); + filter[3] = _mm256_shuffle_epi8( + coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8000))); for (int y = height; y != 0; --y) { __m256i s = LoadUnaligned32(src); __m256i ss[4]; @@ -144,7 +144,8 @@ inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride, __m256i filter[3]; filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0201)); filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0203)); - filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8001)); + filter[2] = _mm256_shuffle_epi8( + coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8001))); for (int y = height; y != 0; --y) { __m256i s = LoadUnaligned32(src); __m256i ss[4]; @@ -171,7 +172,8 @@ inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride, int16_t** const wiener_buffer) { __m256i filter[2]; filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302)); - filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8002)); + filter[1] = _mm256_shuffle_epi8( + coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8002))); for (int y = height; y != 0; --y) { __m256i s = LoadUnaligned32(src); __m256i ss[4]; @@ -480,12 +482,12 @@ inline void WienerVerticalTap1(const int16_t* wiener_buffer, } } -void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, - const void* const source, const void* const top_border, - const void* const bottom_border, const ptrdiff_t stride, - const int width, const int height, - RestorationBuffer* const restoration_buffer, - void* const dest) { +void WienerFilter_AVX2( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { const int16_t* const number_leading_zero_coefficients = restoration_info.wiener_info.number_leading_zero_coefficients; const int number_rows_to_skip = std::max( @@ -515,39 +517,42 @@ void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, c_horizontal = _mm_packs_epi16(c_horizontal, c_horizontal); const __m256i coefficients_horizontal = _mm256_broadcastd_epi32(c_horizontal); if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { - WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, - wiener_stride, height_extra, coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); - } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { - WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, - wiener_stride, height_extra, coefficients_horizontal, + WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { // The maximum over-reads happen here. - WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, - wiener_stride, height_extra, coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); } else { assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); - WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, - wiener_stride, height_extra, + WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride, + top_border_stride, wiener_stride, height_extra, &wiener_buffer_horizontal); WienerHorizontalTap1(src, stride, wiener_stride, height, &wiener_buffer_horizontal); - WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, - &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride, + height_extra, &wiener_buffer_horizontal); } // vertical filtering. @@ -765,17 +770,6 @@ inline __m256i VaddwHi16(const __m256i src0, const __m256i src1) { return _mm256_add_epi32(src0, s1); } -// Using VgetLane16() can save a sign extension instruction. -template <int n> -inline int VgetLane16(__m256i src) { - return _mm256_extract_epi16(src, n); -} - -template <int n> -inline int VgetLane8(__m256i src) { - return _mm256_extract_epi8(src, n); -} - inline __m256i VmullNLo8(const __m256i src0, const int src1) { const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256()); return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1)); @@ -1253,9 +1247,8 @@ inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, do { const __m128i s0 = LoadUnaligned16Msan(src, kOverreadInBytesPass1_128 - width); - __m128i sq_128[2]; + __m128i sq_128[2], s3, s5, sq3[2], sq5[2]; __m256i sq[3]; - __m128i s3, s5, sq3[2], sq5[2]; sq_128[0] = SquareLo8(s0); sq_128[1] = SquareHi8(s0); SumHorizontalLo(s0, &s3, &s5); @@ -1432,11 +1425,43 @@ inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq[2], return _mm256_packus_epi32(z0, z1); } -template <int n> -inline __m128i CalculateB(const __m128i sum, const __m128i ma) { - static_assert(n == 9 || n == 25, ""); +inline __m128i CalculateB5(const __m128i sum, const __m128i ma) { + // one_over_n == 164. constexpr uint32_t one_over_n = - ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. + constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. + const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter)); + const __m128i m0 = VmullLo16(m, sum); + const __m128i m1 = VmullHi16(m, sum); + const __m128i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2); + const __m128i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2); + return _mm_packus_epi32(b_lo, b_hi); +} + +inline __m256i CalculateB5(const __m256i sum, const __m256i ma) { + // one_over_n == 164. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. + constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. + const __m256i m = + _mm256_maddubs_epi16(ma, _mm256_set1_epi16(one_over_n_quarter)); + const __m256i m0 = VmullLo16(m, sum); + const __m256i m1 = VmullHi16(m, sum); + const __m256i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2); + const __m256i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2); + return _mm256_packus_epi32(b_lo, b_hi); +} + +inline __m128i CalculateB3(const __m128i sum, const __m128i ma) { + // one_over_n == 455. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9; const __m128i m0 = VmullLo16(ma, sum); const __m128i m1 = VmullHi16(ma, sum); const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n)); @@ -1446,11 +1471,10 @@ inline __m128i CalculateB(const __m128i sum, const __m128i ma) { return _mm_packus_epi32(b_lo, b_hi); } -template <int n> -inline __m256i CalculateB(const __m256i sum, const __m256i ma) { - static_assert(n == 9 || n == 25, ""); +inline __m256i CalculateB3(const __m256i sum, const __m256i ma) { + // one_over_n == 455. constexpr uint32_t one_over_n = - ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9; const __m256i m0 = VmullLo16(ma, sum); const __m256i m1 = VmullHi16(ma, sum); const __m256i m2 = _mm256_mullo_epi32(m0, _mm256_set1_epi32(one_over_n)); @@ -1525,7 +1549,7 @@ inline void LookupIntermediate(const __m128i sum, const __m128i index, // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). const __m128i maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); - *b = CalculateB<n>(sum, maq); + *b = (n == 9) ? CalculateB3(sum, maq) : CalculateB5(sum, maq); } // Repeat the first 48 elements in kSgrMaLookup with a period of 16. @@ -1539,7 +1563,7 @@ alignas(32) constexpr uint8_t kSgrMaLookupAvx2[96] = { // Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b // to get value 0 as the shuffle result. The most significiant bit 1 comes -// either from the comparision instruction, or from the sign bit of the index. +// either from the comparison instruction, or from the sign bit of the index. inline __m256i ShuffleIndex(const __m256i table, const __m256i index) { __m256i mask; mask = _mm256_cmpgt_epi8(index, _mm256_set1_epi8(15)); @@ -1558,15 +1582,15 @@ template <int n> inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], __m256i ma[3], __m256i b[2]) { static_assert(n == 9 || n == 25, ""); - // Use table lookup to read elements which indices are less than 48. + // Use table lookup to read elements whose indices are less than 48. const __m256i c0 = LoadAligned32(kSgrMaLookupAvx2 + 0 * 32); const __m256i c1 = LoadAligned32(kSgrMaLookupAvx2 + 1 * 32); const __m256i c2 = LoadAligned32(kSgrMaLookupAvx2 + 2 * 32); const __m256i indices = _mm256_packus_epi16(index[0], index[1]); __m256i idx, mas; - // Clip idx to 127 to apply signed comparision instructions. + // Clip idx to 127 to apply signed comparison instructions. idx = _mm256_min_epu8(indices, _mm256_set1_epi8(127)); - // All elements which indices are less than 48 are set to 0. + // All elements whose indices are less than 48 are set to 0. // Get shuffle results for indices in range [0, 15]. mas = ShuffleIndex(c0, idx); // Get shuffle results for indices in range [16, 31]. @@ -1581,12 +1605,12 @@ inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], const __m256i res2 = ShuffleIndex(c2, idx); mas = _mm256_or_si256(mas, res2); - // For elements which indices are larger than 47, since they seldom change + // For elements whose indices are larger than 47, since they seldom change // values with the increase of the index, we use comparison and arithmetic // operations to calculate their values. - // Add -128 to apply signed comparision instructions. + // Add -128 to apply signed comparison instructions. idx = _mm256_add_epi8(indices, _mm256_set1_epi8(-128)); - // Elements which indices are larger than 47 (with value 0) are set to 5. + // Elements whose indices are larger than 47 (with value 0) are set to 5. mas = _mm256_max_epu8(mas, _mm256_set1_epi8(5)); mas = AdjustValue(mas, idx, 55); // 55 is the last index which value is 5. mas = AdjustValue(mas, idx, 72); // 72 is the last index which value is 4. @@ -1611,8 +1635,13 @@ inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). const __m256i maq0 = _mm256_unpackhi_epi8(ma[0], _mm256_setzero_si256()); const __m256i maq1 = _mm256_unpacklo_epi8(ma[1], _mm256_setzero_si256()); - b[0] = CalculateB<n>(sum[0], maq0); - b[1] = CalculateB<n>(sum[1], maq1); + if (n == 9) { + b[0] = CalculateB3(sum[0], maq0); + b[1] = CalculateB3(sum[1], maq1); + } else { + b[0] = CalculateB5(sum[0], maq0); + b[1] = CalculateB5(sum[1], maq1); + } } inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2], @@ -1903,8 +1932,8 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( __m256i b3[2][5], __m256i ma5[3], __m256i b5[5]) { const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8); const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8); - __m256i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sq3t[4][2], sq5t[5][2], - sum_3[2][2], index_3[2][2], sum_5[2], index_5[2]; + __m256i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2][2], index_3[2][2], + sum_5[2], index_5[2]; sq[0][1] = SquareLo8(s0); sq[0][2] = SquareHi8(s0); sq[1][1] = SquareLo8(s1); @@ -1938,22 +1967,22 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( LoadAligned64x3U32(square_sum5, x, sq5); CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); - SumHorizontal(sq[0] + 1, &sq3t[2][0], &sq3t[2][1], &sq5t[3][0], &sq5t[3][1]); - SumHorizontal(sq[1] + 1, &sq3t[3][0], &sq3t[3][1], &sq5t[4][0], &sq5t[4][1]); - StoreAligned64(square_sum3[2] + x + 16, sq3t[2]); - StoreAligned64(square_sum5[3] + x + 16, sq5t[3]); - StoreAligned64(square_sum3[3] + x + 16, sq3t[3]); - StoreAligned64(square_sum5[4] + x + 16, sq5t[4]); + SumHorizontal(sq[0] + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + SumHorizontal(sq[1] + 1, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned64(square_sum3[2] + x + 16, sq3[2]); + StoreAligned64(square_sum5[3] + x + 16, sq5[3]); + StoreAligned64(square_sum3[3] + x + 16, sq3[3]); + StoreAligned64(square_sum5[4] + x + 16, sq5[4]); LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); - LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3t); - CalculateSumAndIndex3(s3[1], sq3t, scales[1], &sum_3[0][1], &index_3[0][1]); - CalculateSumAndIndex3(s3[1] + 1, sq3t + 1, scales[1], &sum_3[1][1], + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[0][1], &index_3[0][1]); + CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum_3[1][1], &index_3[1][1]); CalculateIntermediate<9>(sum_3[0], index_3[0], ma3[0], b3[0] + 1); CalculateIntermediate<9>(sum_3[1], index_3[1], ma3[1], b3[1] + 1); LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); - LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5t); - CalculateSumAndIndex5(s5[1], sq5t, scales[0], &sum_5[1], &index_5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]); CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1); b3[0][0] = _mm256_permute2x128_si256(b3[0][0], b3[0][2], 0x21); b3[1][0] = _mm256_permute2x128_si256(b3[1][0], b3[1][2], 0x21); @@ -1988,8 +2017,8 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( __m256i sq[6], __m256i ma3[2], __m256i ma5[2], __m256i b3[5], __m256i b5[5]) { const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8); - __m256i s3[2][3], s5[2][5], sq3[4][2], sq3t[4][2], sq5[5][2], sq5t[5][2], - sum_3[2], index_3[2], sum_5[2], index_5[2]; + __m256i s3[2][3], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2], index_3[2], + sum_5[2], index_5[2]; sq[1] = SquareLo8(s0); sq[2] = SquareHi8(s0); sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); @@ -2006,17 +2035,17 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( sq5[4][1] = sq5[3][1]; CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); - SumHorizontal(sq + 1, &sq3t[2][0], &sq3t[2][1], &sq5t[3][0], &sq5t[3][1]); + SumHorizontal(sq + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); - LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3t); - CalculateSumAndIndex3(s3[1], sq3t, scales[1], &sum_3[1], &index_3[1]); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[1], &index_3[1]); CalculateIntermediate<9>(sum_3, index_3, ma3, b3 + 1); LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); s5[1][4] = s5[1][3]; - LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5t); - sq5t[4][0] = sq5t[3][0]; - sq5t[4][1] = sq5t[3][1]; - CalculateSumAndIndex5(s5[1], sq5t, scales[0], &sum_5[1], &index_5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]); CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1); b3[0] = _mm256_permute2x128_si256(b3[0], b3[2], 0x21); b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21); @@ -2071,9 +2100,9 @@ LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( uint16_t* const sum3[3], uint32_t* const square_sum3[3], const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343, uint32_t* b444) { + const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytesPass2_128 - width); __m128i ma0, sq_128[2], b0; __m256i mas[3], sq[3], bs[3]; - const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytesPass2_128 - width); sq_128[0] = SquareLo8(s); BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq_128, &ma0, &b0); sq[0] = SetrM128i(sq_128[0], sq_128[1]); @@ -2115,9 +2144,9 @@ inline void BoxSumFilterPreProcess( const uint8_t* const src0, const uint8_t* const src1, const int width, const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - const ptrdiff_t sum_width, uint16_t* const ma343[4], - uint16_t* const ma444[2], uint16_t* ma565, uint32_t* const b343[4], - uint32_t* const b444[2], uint32_t* b565) { + const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444, + uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444, + uint32_t* b565) { __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0; __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5]; s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); @@ -2151,9 +2180,8 @@ inline void BoxSumFilterPreProcess( Sum565W(b5, b); StoreAligned64(b565, b); Prepare3_8(ma3[1], ma3x); - Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); - Store343_444Hi(ma3x, b3[1] + 1, x + 16, ma343[1], ma444[0], b343[1], - b444[0]); + Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444); + Store343_444Hi(ma3x, b3[1] + 1, x + 16, ma343[1], ma444, b343[1], b444); Prepare3_8(ma5, ma5x); ma[0] = Sum565Lo(ma5x); ma[1] = Sum565Hi(ma5x); @@ -2199,8 +2227,9 @@ inline __m256i CalculateFilteredOutput(const __m256i src, const __m256i ma, return _mm256_packs_epi32(dst_lo, dst_hi); // 13 bits } -inline __m256i CalculateFilteredOutputPass1(const __m256i src, __m256i ma[2], - __m256i b[2][2]) { +inline __m256i CalculateFilteredOutputPass1(const __m256i src, + const __m256i ma[2], + const __m256i b[2][2]) { const __m256i ma_sum = _mm256_add_epi16(ma[0], ma[1]); __m256i b_sum[2]; b_sum[0] = _mm256_add_epi32(b[0][0], b[1][0]); @@ -2208,8 +2237,9 @@ inline __m256i CalculateFilteredOutputPass1(const __m256i src, __m256i ma[2], return CalculateFilteredOutput<5>(src, ma_sum, b_sum); } -inline __m256i CalculateFilteredOutputPass2(const __m256i src, __m256i ma[3], - __m256i b[3][2]) { +inline __m256i CalculateFilteredOutputPass2(const __m256i src, + const __m256i ma[3], + const __m256i b[3][2]) { const __m256i ma_sum = Sum3_16(ma); __m256i b_sum[2]; Sum3_32(b, b_sum); @@ -2267,13 +2297,13 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( int x = 0; do { - __m256i ma[3], ma3[3], b[2][2][2]; + __m256i ma[3], ma5[3], b[2][2][2]; BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8, x + 8 + kOverreadInBytesPass1_256 - width, sum_width, x + 8, scale, sum5, square_sum5, sq, mas, bs); - Prepare3_8(mas, ma3); - ma[1] = Sum565Lo(ma3); - ma[2] = Sum565Hi(ma3); + Prepare3_8(mas, ma5); + ma[1] = Sum565Lo(ma5); + ma[2] = Sum565Hi(ma5); StoreAligned64(ma565[1] + x, ma + 1); Sum565W(bs + 0, b[0][1]); Sum565W(bs + 1, b[1][1]); @@ -2511,9 +2541,9 @@ inline void BoxFilterLastRow( const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0, const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - uint16_t* const ma343[4], uint16_t* const ma444[3], - uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], - uint32_t* const b565[2], uint8_t* const dst) { + uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565, + uint32_t* const b343, uint32_t* const b444, uint32_t* const b565, + uint8_t* const dst) { const __m128i s0 = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); __m128i ma3_0, ma5_0, b3_0, b5_0, sq_128[2]; @@ -2542,13 +2572,13 @@ inline void BoxFilterLastRow( Sum343W(b3, b[2]); const __m256i sr = LoadUnaligned32(src + x); const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256()); - ma[0] = LoadAligned32(ma565[0] + x); - LoadAligned64(b565[0] + x, b[0]); + ma[0] = LoadAligned32(ma565 + x); + LoadAligned64(b565 + x, b[0]); p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b); - ma[0] = LoadAligned32(ma343[0] + x); - ma[1] = LoadAligned32(ma444[0] + x); - LoadAligned64(b343[0] + x, b[0]); - LoadAligned64(b444[0] + x, b[1]); + ma[0] = LoadAligned32(ma343 + x); + ma[1] = LoadAligned32(ma444 + x); + LoadAligned64(b343 + x, b[0]); + LoadAligned64(b444 + x, b[1]); p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b); const __m256i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2); @@ -2557,13 +2587,13 @@ inline void BoxFilterLastRow( mat[2] = Sum343Hi(ma3x); Sum343W(b3 + 1, b[2]); const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256()); - mat[0] = LoadAligned32(ma565[0] + x + 16); - LoadAligned64(b565[0] + x + 16, b[0]); + mat[0] = LoadAligned32(ma565 + x + 16); + LoadAligned64(b565 + x + 16, b[0]); p[0] = CalculateFilteredOutputPass1(sr_hi, mat, b); - mat[0] = LoadAligned32(ma343[0] + x + 16); - mat[1] = LoadAligned32(ma444[0] + x + 16); - LoadAligned64(b343[0] + x + 16, b[0]); - LoadAligned64(b444[0] + x + 16, b[1]); + mat[0] = LoadAligned32(ma343 + x + 16); + mat[1] = LoadAligned32(ma444 + x + 16); + LoadAligned64(b343 + x + 16, b[0]); + LoadAligned64(b444 + x + 16, b[1]); p[1] = CalculateFilteredOutputPass2(sr_hi, mat, b); const __m256i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2); StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); @@ -2578,8 +2608,9 @@ inline void BoxFilterLastRow( LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const RestorationUnitInfo& restoration_info, const uint8_t* src, - const uint8_t* const top_border, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, SgrBuffer* const sgr_buffer, uint8_t* dst) { const auto temp_stride = Align<ptrdiff_t>(width, 32); const auto sum_width = temp_stride + 8; @@ -2619,14 +2650,14 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( b565[1] = b565[0] + temp_stride; assert(scales[0] != 0); assert(scales[1] != 0); - BoxSum(top_border, stride, width, sum_stride, temp_stride, sum3[0], sum5[1], - square_sum3[0], square_sum5[1]); + BoxSum(top_border, top_border_stride, width, sum_stride, temp_stride, sum3[0], + sum5[1], square_sum3[0], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? src + stride : bottom_border; BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, - square_sum5, sum_width, ma343, ma444, ma565[0], b343, - b444, b565[0]); + square_sum5, sum_width, ma343, ma444[0], ma565[0], + b343, b444[0], b565[0]); sum5[0] = sgr_buffer->sum5 + kSumOffset; square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; @@ -2656,7 +2687,7 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -2680,19 +2711,21 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( std::swap(ma565[0], ma565[1]); std::swap(b565[0], b565[1]); } - BoxFilterLastRow(src + 3, bottom_border + stride, width, sum_width, scales, - w0, w2, sum3, sum5, square_sum3, square_sum5, ma343, ma444, - ma565, b343, b444, b565, dst); + BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width, + sum_width, scales, w0, w2, sum3, sum5, square_sum3, + square_sum5, ma343[0], ma444[0], ma565[0], b343[0], + b444[0], b565[0], dst); } } inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { const auto temp_stride = Align<ptrdiff_t>(width, 32); const auto sum_width = temp_stride + 8; const auto sum_stride = temp_stride + 32; @@ -2712,8 +2745,8 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, b565[0] = sgr_buffer->b565; b565[1] = b565[0] + temp_stride; assert(scale != 0); - BoxSum<5>(top_border, stride, width, sum_stride, temp_stride, sum5[1], - square_sum5[1]); + BoxSum<5>(top_border, top_border_stride, width, sum_stride, temp_stride, + sum5[1], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? src + stride : bottom_border; @@ -2739,7 +2772,7 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -2757,18 +2790,20 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, Circulate5PointersBy2<uint16_t>(sum5); Circulate5PointersBy2<uint32_t>(square_sum5); } - BoxFilterPass1LastRow(src, bottom_border + stride, width, sum_width, scale, - w0, sum5, square_sum5, ma565[0], b565[0], dst); + BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width, + sum_width, scale, w0, sum5, square_sum5, ma565[0], + b565[0], dst); } } inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { assert(restoration_info.sgr_proj_info.multiplier[0] == 0); const auto temp_stride = Align<ptrdiff_t>(width, 32); const auto sum_width = temp_stride + 8; @@ -2794,8 +2829,8 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, b444[0] = sgr_buffer->b444; b444[1] = b444[0] + temp_stride; assert(scale != 0); - BoxSum<3>(top_border, stride, width, sum_stride, temp_stride, sum3[0], - square_sum3[0]); + BoxSum<3>(top_border, top_border_stride, width, sum_stride, temp_stride, + sum3[0], square_sum3[0]); BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, sum_width, ma343[0], nullptr, b343[0], nullptr); @@ -2806,7 +2841,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, s = src + stride; } else { s = bottom_border; - bottom_border += stride; + bottom_border += bottom_border_stride; } BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width, ma343[1], ma444[0], b343[1], b444[0]); @@ -2833,7 +2868,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, square_sum3, ma343, ma444, b343, b444, dst); src += stride; dst += stride; - bottom_border += stride; + bottom_border += bottom_border_stride; Circulate3PointersBy1<uint16_t>(ma343); Circulate3PointersBy1<uint32_t>(b343); std::swap(ma444[0], ma444[1]); @@ -2841,13 +2876,14 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, } while (--y != 0); } -// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in -// the end of each row. It is safe to overwrite the output as it will not be +// If |width| is non-multiple of 32, up to 31 more pixels are written to |dest| +// in the end of each row. It is safe to overwrite the output as it will not be // part of the visible frame. void SelfGuidedFilter_AVX2( const RestorationUnitInfo& restoration_info, const void* const source, - const void* const top_border, const void* const bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, RestorationBuffer* const restoration_buffer, void* const dest) { const int index = restoration_info.sgr_proj_info.index; const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 @@ -2861,14 +2897,17 @@ void SelfGuidedFilter_AVX2( // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the // following assertion. assert(radius_pass_0 != 0); - BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, + width, height, sgr_buffer, dst); } else if (radius_pass_0 == 0) { - BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2, + top_border_stride, bottom - 2, bottom_border_stride, + width, height, sgr_buffer, dst); } else { - BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, - width, height, sgr_buffer, dst); + BoxFilterProcess(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, width, + height, sgr_buffer, dst); } } @@ -2891,7 +2930,7 @@ void LoopRestorationInit_AVX2() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_AVX2 +#else // !LIBGAV1_TARGETING_AVX2 namespace libgav1 { namespace dsp { |