diff options
Diffstat (limited to 'src/dsp/x86/convolve_avx2.cc')
-rw-r--r-- | src/dsp/x86/convolve_avx2.cc | 1286 |
1 files changed, 1148 insertions, 138 deletions
diff --git a/src/dsp/x86/convolve_avx2.cc b/src/dsp/x86/convolve_avx2.cc index 3df2120..2ecb77c 100644 --- a/src/dsp/x86/convolve_avx2.cc +++ b/src/dsp/x86/convolve_avx2.cc @@ -26,7 +26,6 @@ #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_avx2.h" -#include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" #include "src/utils/constants.h" @@ -35,7 +34,7 @@ namespace dsp { namespace low_bitdepth { namespace { -constexpr int kHorizontalOffset = 3; +#include "src/dsp/x86/convolve_sse4.inc" // Multiply every entry in |src[]| by the corresponding entry in |taps[]| and // sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final @@ -118,58 +117,15 @@ __m256i SimpleHorizontalTaps(const __m256i* const src, } template <int filter_index> -__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - // 00 01 02 03 04 05 06 07 10 11 12 13 14 15 16 17 - const __m128i v_src = LoadHi8(LoadLo8(&src[0]), &src[src_stride]); - - if (filter_index == 3) { - // 03 04 04 05 05 06 06 07 13 14 14 15 15 16 16 17 - const __m128i v_src_43 = _mm_shuffle_epi8( - v_src, _mm_set_epi32(0x0f0e0e0d, 0x0d0c0c0b, 0x07060605, 0x05040403)); - const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 - return v_sum_43; - } - - // 02 03 03 04 04 05 05 06 12 13 13 14 14 15 15 16 - const __m128i v_src_32 = _mm_shuffle_epi8( - v_src, _mm_set_epi32(0x0e0d0d0c, 0x0c0b0b0a, 0x06050504, 0x04030302)); - // 04 05 05 06 06 07 07 xx 14 15 15 16 16 17 17 xx - const __m128i v_src_54 = _mm_shuffle_epi8( - v_src, _mm_set_epi32(0x800f0f0e, 0x0e0d0d0c, 0x80070706, 0x06050504)); - const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 - const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 - const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32); - return v_sum_5432; -} - -template <int filter_index> -__m128i SimpleHorizontalTaps2x2(const uint8_t* src, 
const ptrdiff_t src_stride, - const __m128i* const v_tap) { - __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); - - // Normally the Horizontal pass does the downshift in two passes: - // kInterRoundBitsHorizontal - 1 and then (kFilterBits - - // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them - // requires adding the rounding offset from the skipped shift. - constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); - - sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit)); - sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); - return _mm_packus_epi16(sum, sum); -} - -template <int filter_index> -__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - const __m128i sum = - SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); +__m256i HorizontalTaps8To16(const __m256i* const src, + const __m256i* const v_tap) { + const __m256i sum = SumHorizontalTaps<filter_index>(src, v_tap); return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); } // Filter 2xh sizes. -template <int num_taps, int step, int filter_index, bool is_2d = false, +template <int num_taps, int filter_index, bool is_2d = false, bool is_compound = false> void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, void* const dest, const ptrdiff_t pred_stride, @@ -183,7 +139,8 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, assert(num_taps <= 4); if (num_taps <= 4) { if (!is_compound) { - int y = 0; + int y = height; + if (is_2d) y -= 1; do { if (is_2d) { const __m128i sum = @@ -202,8 +159,8 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } src += src_stride << 1; - y += 2; - } while (y < height - 1); + y -= 2; + } while (y != 0); // The 2d filters have an odd |height| because the horizontal pass // generates context for the vertical pass. 
@@ -236,7 +193,7 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } // Filter widths >= 4. -template <int num_taps, int step, int filter_index, bool is_2d = false, +template <int num_taps, int filter_index, bool is_2d = false, bool is_compound = false> void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, void* const dest, const ptrdiff_t pred_stride, @@ -251,7 +208,22 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, int x = 0; do { if (is_2d || is_compound) { - // placeholder + // Load into 2 128 bit lanes. + const __m256i src_long = + SetrM128i(LoadUnaligned16(&src[x]), LoadUnaligned16(&src[x + 8])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i src_long2 = SetrM128i(LoadUnaligned16(&src[x + 16]), + LoadUnaligned16(&src[x + 24])); + const __m256i result2 = + HorizontalTaps8To16<filter_index>(&src_long2, v_tap); + if (is_2d) { + StoreAligned32(&dest16[x], result); + StoreAligned32(&dest16[x + 16], result2); + } else { + StoreUnaligned32(&dest16[x], result); + StoreUnaligned32(&dest16[x + 16], result2); + } } else { // Load src used to calculate dest8[7:0] and dest8[23:16]. const __m256i src_long = LoadUnaligned32(&src[x]); @@ -264,7 +236,7 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, // Combine results and store. StoreUnaligned32(&dest8[x], _mm256_unpacklo_epi64(result, result2)); } - x += step * 4; + x += 32; } while (x < width); src += src_stride; dest8 += pred_stride; @@ -272,9 +244,26 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } while (--y != 0); } else if (width == 16) { int y = height; + if (is_2d) y -= 1; do { if (is_2d || is_compound) { - // placeholder + // Load into 2 128 bit lanes. 
+ const __m256i src_long = + SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i src_long2 = + SetrM128i(LoadUnaligned16(&src[src_stride]), + LoadUnaligned16(&src[8 + src_stride])); + const __m256i result2 = + HorizontalTaps8To16<filter_index>(&src_long2, v_tap); + if (is_2d) { + StoreAligned32(&dest16[0], result); + StoreAligned32(&dest16[pred_stride], result2); + } else { + StoreUnaligned32(&dest16[0], result); + StoreUnaligned32(&dest16[pred_stride], result2); + } } else { // Load into 2 128 bit lanes. const __m256i src_long = SetrM128i(LoadUnaligned16(&src[0]), @@ -295,11 +284,37 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, dest16 += pred_stride * 2; y -= 2; } while (y != 0); + + // The 2d filters have an odd |height| during the horizontal pass, so + // filter the remaining row. + if (is_2d) { + const __m256i src_long = + SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreAligned32(&dest16[0], result); + } + } else if (width == 8) { int y = height; + if (is_2d) y -= 1; do { + // Load into 2 128 bit lanes. 
+ const __m128i this_row = LoadUnaligned16(&src[0]); + const __m128i next_row = LoadUnaligned16(&src[src_stride]); + const __m256i src_long = SetrM128i(this_row, next_row); if (is_2d || is_compound) { - // placeholder + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + if (is_2d) { + StoreAligned16(&dest16[0], _mm256_castsi256_si128(result)); + StoreAligned16(&dest16[pred_stride], + _mm256_extracti128_si256(result, 1)); + } else { + StoreUnaligned16(&dest16[0], _mm256_castsi256_si128(result)); + StoreUnaligned16(&dest16[pred_stride], + _mm256_extracti128_si256(result, 1)); + } } else { const __m128i this_row = LoadUnaligned16(&src[0]); const __m128i next_row = LoadUnaligned16(&src[src_stride]); @@ -315,11 +330,29 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, dest16 += pred_stride * 2; y -= 2; } while (y != 0); + + // The 2d filters have an odd |height| during the horizontal pass, so + // filter the remaining row. + if (is_2d) { + const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreAligned16(&dest16[0], _mm256_castsi256_si128(result)); + } + } else { // width == 4 int y = height; + if (is_2d) y -= 1; do { + // Load into 2 128 bit lanes. 
+ const __m128i this_row = LoadUnaligned16(&src[0]); + const __m128i next_row = LoadUnaligned16(&src[src_stride]); + const __m256i src_long = SetrM128i(this_row, next_row); if (is_2d || is_compound) { - // placeholder + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreLo8(&dest16[0], _mm256_castsi256_si128(result)); + StoreLo8(&dest16[pred_stride], _mm256_extracti128_si256(result, 1)); } else { const __m128i this_row = LoadUnaligned16(&src[0]); const __m128i next_row = LoadUnaligned16(&src[src_stride]); @@ -335,93 +368,176 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, dest16 += pred_stride * 2; y -= 2; } while (y != 0); + + // The 2d filters have an odd |height| during the horizontal pass, so + // filter the remaining row. + if (is_2d) { + const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreLo8(&dest16[0], _mm256_castsi256_si128(result)); + } } } template <int num_taps, bool is_2d_vertical = false> LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, - __m128i* v_tap) { + __m256i* v_tap) { if (num_taps == 8) { - v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0); // k1k0 - v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 - v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 - v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff); // k7k6 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); - v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); - v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]); + v_tap[0] = _mm256_broadcastd_epi32(*filter); // k1k0 + v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 4)); // k3k2 + v_tap[2] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 8)); // k5k4 + v_tap[3] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 12)); // k7k6 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = 
_mm_unpacklo_epi64(v_tap[1], v_tap[1]); - v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); - v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]); + v_tap[0] = _mm256_broadcastw_epi16(*filter); // k1k0 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 + v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 + v_tap[3] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 6)); // k7k6 } } else if (num_taps == 6) { - const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); - v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0); // k2k1 - v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 - v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa); // k6k5 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); - v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 2)); // k2k1 + v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 6)); // k4k3 + v_tap[2] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 10)); // k6k5 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); - v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 1)); // k2k1 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 + v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 5)); // k6k5 } } else if (num_taps == 4) { - v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 - v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 4)); // k3k2 + v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 8)); // k5k4 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[0] = 
_mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 } } else { // num_taps == 2 - const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); - v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 6)); // k4k3 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 } } } -template <int num_taps, bool is_2d_vertical = false> -LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, - __m256i* v_tap) { - if (num_taps == 8) { - v_tap[0] = _mm256_broadcastw_epi16(*filter); // k1k0 - v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 - v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 - v_tap[3] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 6)); // k7k6 - if (is_2d_vertical) { - // placeholder - } - } else if (num_taps == 6) { - v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 1)); // k2k1 - v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 - v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 5)); // k6k5 - if (is_2d_vertical) { - // placeholder - } - } else if (num_taps == 4) { - v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 - v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 - if (is_2d_vertical) { - // placeholder - } - } else { // num_taps == 2 - v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 - if (is_2d_vertical) { - // placeholder +template <int num_taps, bool is_compound> +__m256i SimpleSum2DVerticalTaps(const __m256i* const src, + const __m256i* const taps) { + __m256i sum_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[0], src[1]), taps[0]); + __m256i sum_hi = + 
_mm256_madd_epi16(_mm256_unpackhi_epi16(src[0], src[1]), taps[0]); + if (num_taps >= 4) { + __m256i madd_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[2], src[3]), taps[1]); + __m256i madd_hi = + _mm256_madd_epi16(_mm256_unpackhi_epi16(src[2], src[3]), taps[1]); + sum_lo = _mm256_add_epi32(sum_lo, madd_lo); + sum_hi = _mm256_add_epi32(sum_hi, madd_hi); + if (num_taps >= 6) { + madd_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[4], src[5]), taps[2]); + madd_hi = + _mm256_madd_epi16(_mm256_unpackhi_epi16(src[4], src[5]), taps[2]); + sum_lo = _mm256_add_epi32(sum_lo, madd_lo); + sum_hi = _mm256_add_epi32(sum_hi, madd_hi); + if (num_taps == 8) { + madd_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[6], src[7]), taps[3]); + madd_hi = + _mm256_madd_epi16(_mm256_unpackhi_epi16(src[6], src[7]), taps[3]); + sum_lo = _mm256_add_epi32(sum_lo, madd_lo); + sum_hi = _mm256_add_epi32(sum_hi, madd_hi); + } } } + + if (is_compound) { + return _mm256_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1), + RightShiftWithRounding_S32(sum_hi, + kInterRoundBitsCompoundVertical - 1)); + } + + return _mm256_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1), + RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1)); +} + +template <int num_taps, bool is_compound = false> +void Filter2DVertical16xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int width, + const int height, const __m256i* const taps) { + assert(width >= 8); + constexpr int next_row = num_taps - 1; + // The Horizontal pass uses |width| as |stride| for the intermediate buffer. 
+ const ptrdiff_t src_stride = width; + + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int x = 0; + do { + __m256i srcs[8]; + const uint16_t* src_x = src + x; + srcs[0] = LoadAligned32(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadAligned32(src_x); + src_x += src_stride; + srcs[2] = LoadAligned32(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadAligned32(src_x); + src_x += src_stride; + srcs[4] = LoadAligned32(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadAligned32(src_x); + src_x += src_stride; + srcs[6] = LoadAligned32(src_x); + src_x += src_stride; + } + } + } + + auto* dst8_x = dst8 + x; + auto* dst16_x = dst16 + x; + int y = height; + do { + srcs[next_row] = LoadAligned32(src_x); + src_x += src_stride; + + const __m256i sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + StoreUnaligned32(dst16_x, sum); + dst16_x += dst_stride; + } else { + const __m128i packed_sum = _mm_packus_epi16( + _mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); + StoreUnaligned16(dst8_x, packed_sum); + dst8_x += dst_stride; + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (--y != 0); + x += 16; + } while (x < width); } template <bool is_2d = false, bool is_compound = false> @@ -436,16 +552,16 @@ LIBGAV1_ALWAYS_INLINE void DoHorizontalPass2xH( if (filter_index == 4) { // 4 tap. SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 4, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 4, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 5) { // 4 tap. 
SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 5, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 5, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else { // 2 tap. SetupTaps<2>(&v_horizontal_filter, v_tap); - FilterHorizontal<2, 8, 3, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<2, 3, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } } @@ -461,28 +577,792 @@ LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( if (filter_index == 2) { // 8 tap. SetupTaps<8>(&v_horizontal_filter, v_tap); - FilterHorizontal<8, 8, 2, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<8, 2, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 1) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 8, 1, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<6, 1, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 0) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 8, 0, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<6, 0, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 4) { // 4 tap. SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 4, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 4, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 5) { // 4 tap. 
SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 5, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 5, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else { // 2 tap. SetupTaps<2>(&v_horizontal_filter, v_tap); - FilterHorizontal<2, 8, 3, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<2, 3, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); + } +} + +void Convolve2D_AVX2(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + + // The output of the horizontal filter is guaranteed to fit in 16 bits. + alignas(32) uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_height = height + vertical_taps - 1; + + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset; + if (width > 2) { + DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, + width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + } else { + // Use non avx2 version for smaller widths. + DoHorizontalPass2xH</*is_2d=*/true>( + src, src_stride, intermediate_result, width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + } + + // Vertical filter. 
+ auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]); + + // Use 256 bits for width > 8. + if (width > 8) { + __m256i taps_256[4]; + const __m128i v_filter_ext = _mm_cvtepi8_epi16(v_filter); + + if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<8>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<6>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<4>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<2>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } + } else { // width <= 8 + __m128i taps[4]; + // Use 128 bit code. 
+ if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } + } +} + +// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D +// Vertical calculations. 
+__m256i Compound1DShift(const __m256i sum) { + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +template <int filter_index, bool unpack_high = false> +__m256i SumVerticalTaps(const __m256i* const srcs, const __m256i* const v_tap) { + __m256i v_src[4]; + + if (!unpack_high) { + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]); + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]); + v_src[3] = _mm256_unpacklo_epi8(srcs[6], srcs[7]); + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + } else if (filter_index > 3) { + // 4 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); + } + } else { + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]); + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]); + v_src[3] = _mm256_unpackhi_epi8(srcs[6], srcs[7]); + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + } else if (filter_index > 3) { + // 4 taps. 
+ v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); + } + } + return SumOnePassTaps<filter_index>(v_src, v_tap); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical32xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int width, const int height, + const __m256i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps - 1; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + assert(width >= 32); + int x = 0; + do { + const uint8_t* src_x = src + x; + __m256i srcs[8]; + srcs[0] = LoadUnaligned32(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadUnaligned32(src_x); + src_x += src_stride; + srcs[2] = LoadUnaligned32(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadUnaligned32(src_x); + src_x += src_stride; + srcs[4] = LoadUnaligned32(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadUnaligned32(src_x); + src_x += src_stride; + srcs[6] = LoadUnaligned32(src_x); + src_x += src_stride; + } + } + } + + auto* dst8_x = dst8 + x; + auto* dst16_x = dst16 + x; + int y = height; + do { + srcs[next_row] = LoadUnaligned32(src_x); + src_x += src_stride; + + const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m256i sums_hi = + SumVerticalTaps<filter_index, /*unpack_high=*/true>(srcs, v_tap); + if (is_compound) { + const __m256i results = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20)); + const __m256i results_hi = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x31)); + StoreUnaligned32(dst16_x, results); + StoreUnaligned32(dst16_x + 16, results_hi); + dst16_x += dst_stride; + } else { + const __m256i results = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m256i results_hi = + RightShiftWithRounding_S16(sums_hi, kFilterBits - 1); + 
const __m256i packed_results = _mm256_packus_epi16(results, results_hi); + + StoreUnaligned32(dst8_x, packed_results); + dst8_x += dst_stride; + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (--y != 0); + x += 32; + } while (x < width); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical16xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int /*width*/, const int height, + const __m256i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + const uint8_t* src_x = src; + __m256i srcs[8 + 1]; + // The upper 128 bits hold the filter data for the next row. + srcs[0] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[0] = + _mm256_inserti128_si256(srcs[0], _mm256_castsi256_si128(srcs[1]), 1); + srcs[2] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[1] = + _mm256_inserti128_si256(srcs[1], _mm256_castsi256_si128(srcs[2]), 1); + if (num_taps >= 6) { + srcs[3] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[2] = + _mm256_inserti128_si256(srcs[2], _mm256_castsi256_si128(srcs[3]), 1); + srcs[4] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[3] = + _mm256_inserti128_si256(srcs[3], _mm256_castsi256_si128(srcs[4]), 1); + if (num_taps == 8) { + srcs[5] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[4] = _mm256_inserti128_si256(srcs[4], + _mm256_castsi256_si128(srcs[5]), 1); + srcs[6] = 
_mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[5] = _mm256_inserti128_si256(srcs[5], + _mm256_castsi256_si128(srcs[6]), 1); + } + } + } + + int y = height; + do { + srcs[next_row - 1] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + + srcs[next_row - 2] = _mm256_inserti128_si256( + srcs[next_row - 2], _mm256_castsi256_si128(srcs[next_row - 1]), 1); + + srcs[next_row] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + + srcs[next_row - 1] = _mm256_inserti128_si256( + srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1); + + const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m256i sums_hi = + SumVerticalTaps<filter_index, /*unpack_high=*/true>(srcs, v_tap); + if (is_compound) { + const __m256i results = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20)); + const __m256i results_hi = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x31)); + + StoreUnaligned32(dst16, results); + StoreUnaligned32(dst16 + dst_stride, results_hi); + dst16 += dst_stride << 1; + } else { + const __m256i results = RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m256i results_hi = + RightShiftWithRounding_S16(sums_hi, kFilterBits - 1); + const __m256i packed_results = _mm256_packus_epi16(results, results_hi); + const __m128i this_dst = _mm256_castsi256_si128(packed_results); + const auto next_dst = _mm256_extracti128_si256(packed_results, 1); + + StoreUnaligned16(dst8, this_dst); + StoreUnaligned16(dst8 + dst_stride, next_dst); + dst8 += dst_stride << 1; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y -= 2; + } while (y != 0); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical8xH(const uint8_t* src, const ptrdiff_t 
src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int /*width*/, const int height, + const __m256i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + const uint8_t* src_x = src; + __m256i srcs[8 + 1]; + // The upper 128 bits hold the filter data for the next row. + srcs[0] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[0] = + _mm256_inserti128_si256(srcs[0], _mm256_castsi256_si128(srcs[1]), 1); + srcs[2] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[1] = + _mm256_inserti128_si256(srcs[1], _mm256_castsi256_si128(srcs[2]), 1); + if (num_taps >= 6) { + srcs[3] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[2] = + _mm256_inserti128_si256(srcs[2], _mm256_castsi256_si128(srcs[3]), 1); + srcs[4] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[3] = + _mm256_inserti128_si256(srcs[3], _mm256_castsi256_si128(srcs[4]), 1); + if (num_taps == 8) { + srcs[5] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[4] = _mm256_inserti128_si256(srcs[4], + _mm256_castsi256_si128(srcs[5]), 1); + srcs[6] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[5] = _mm256_inserti128_si256(srcs[5], + _mm256_castsi256_si128(srcs[6]), 1); + } + } + } + + int y = height; + do { + srcs[next_row - 1] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + + srcs[next_row - 2] = _mm256_inserti128_si256( + srcs[next_row - 2], _mm256_castsi256_si128(srcs[next_row - 1]), 1); + + srcs[next_row] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + + srcs[next_row - 1] = _mm256_inserti128_si256( + srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1); + + const __m256i sums = 
SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m256i results = Compound1DShift(sums); + const __m128i this_dst = _mm256_castsi256_si128(results); + const auto next_dst = _mm256_extracti128_si256(results, 1); + + StoreUnaligned16(dst16, this_dst); + StoreUnaligned16(dst16 + dst_stride, next_dst); + dst16 += dst_stride << 1; + } else { + const __m256i results = RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m256i packed_results = _mm256_packus_epi16(results, results); + const __m128i this_dst = _mm256_castsi256_si128(packed_results); + const auto next_dst = _mm256_extracti128_si256(packed_results, 1); + + StoreLo8(dst8, this_dst); + StoreLo8(dst8 + dst_stride, next_dst); + dst8 += dst_stride << 1; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y -= 2; + } while (y != 0); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical8xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int /*width*/, const int height, + const __m128i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps - 1; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + const uint8_t* src_x = src; + __m128i srcs[8]; + srcs[0] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadLo8(src_x); + src_x += src_stride; + srcs[2] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadLo8(src_x); + src_x += src_stride; + srcs[4] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadLo8(src_x); + src_x += src_stride; + srcs[6] = LoadLo8(src_x); + src_x += src_stride; + } + } + } + + int y = height; + do { + srcs[next_row] = LoadLo8(src_x); + 
src_x += src_stride; + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += dst_stride; + } else { + const __m128i results = RightShiftWithRounding_S16(sums, kFilterBits - 1); + StoreLo8(dst8, _mm_packus_epi16(results, results)); + dst8 += dst_stride; + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (--y != 0); +} + +void ConvolveVertical_AVX2(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int vertical_filter_index, + const int /*horizontal_filter_id*/, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]); + + // Use 256 bits for width > 4. + if (width > 4) { + __m256i taps_256[4]; + if (filter_index < 2) { // 6 tap. 
+ SetupTaps<6>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<0>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<0>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<0>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<2>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<2>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<2>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<3>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<3>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<3>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else if (filter_index == 4) { // 4 tap. 
+ SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<4>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<4>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<4>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else { + SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<5>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<5>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<5>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } + } else { // width <= 4 + // Use 128 bit code. + __m128i taps[4]; + + if (filter_index < 2) { // 6 tap. + SetupTaps<6>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<6, 0>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<6, 0>(src, src_stride, dest, dest_stride, height, + taps); + } + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<8, 2>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<8, 2>(src, src_stride, dest, dest_stride, height, + taps); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<2, 3>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<2, 3>(src, src_stride, dest, dest_stride, height, + taps); + } + } else if (filter_index == 4) { // 4 tap. 
+ SetupTaps<4>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<4, 4>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<4, 4>(src, src_stride, dest, dest_stride, height, + taps); + } + } else { + SetupTaps<4>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<4, 5>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<4, 5>(src, src_stride, dest, dest_stride, height, + taps); + } + } + } +} + +void ConvolveCompoundVertical_AVX2( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int vertical_filter_index, + const int /*horizontal_filter_id*/, const int vertical_filter_id, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = width; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]); + + // Use 256 bits for width > 4. + if (width > 4) { + __m256i taps_256[4]; + if (filter_index < 2) { // 6 tap. + SetupTaps<6>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<0, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<0, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<0, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else if (filter_index == 2) { // 8 tap. 
+ SetupTaps<8>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<2, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<2, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<2, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<3, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<3, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<3, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<4, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<4, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<4, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else { + SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<5, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<5, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<5, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } + } else { // width <= 4 + // Use 128 bit code. + __m128i taps[4]; + + if (filter_index < 2) { // 6 tap. 
+ SetupTaps<6>(&v_filter, taps); + FilterVertical4xH<6, 0, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps); + FilterVertical4xH<8, 2, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps); + FilterVertical4xH<2, 3, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_filter, taps); + FilterVertical4xH<4, 4, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else { + SetupTaps<4>(&v_filter, taps); + FilterVertical4xH<4, 5, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } } } @@ -509,10 +1389,140 @@ void ConvolveHorizontal_AVX2(const void* const reference, } } +void ConvolveCompoundHorizontal_AVX2( + const void* const reference, const ptrdiff_t reference_stride, + const int horizontal_filter_index, const int /*vertical_filter_index*/, + const int horizontal_filter_id, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + // Set |src| to the outermost tap. + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint8_t*>(prediction); + // All compound functions output to the predictor buffer with |pred_stride| + // equal to |width|. + assert(pred_stride == width); + // Compound functions start at 4x4. + assert(width >= 4 && height >= 4); + +#ifdef NDEBUG + // Quiet compiler error. 
+ (void)pred_stride; +#endif + + DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>( + src, reference_stride, dest, width, width, height, horizontal_filter_id, + filter_index); +} + +void ConvolveCompound2D_AVX2(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + + // The output of the horizontal filter is guaranteed to fit in 16 bits. + alignas(32) uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_height = height + vertical_taps - 1; + + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset; + DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>( + src, src_stride, intermediate_result, width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]); + + // Use 256 bits for width > 8. 
+ if (width > 8) { + __m256i taps_256[4]; + const __m128i v_filter_ext = _mm_cvtepi8_epi16(v_filter); + + if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<8, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<6, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<4, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<2, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } + } else { // width <= 8 + __m128i taps[4]; + // Use 128 bit code. 
+ if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<8, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<6, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<4, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<2, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } + } +} + void Init8bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); assert(dsp != nullptr); dsp->convolve[0][0][0][1] = ConvolveHorizontal_AVX2; + dsp->convolve[0][0][1][0] = ConvolveVertical_AVX2; + dsp->convolve[0][0][1][1] = Convolve2D_AVX2; + + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_AVX2; + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_AVX2; + dsp->convolve[0][1][1][1] = ConvolveCompound2D_AVX2; } } // namespace @@ -523,7 +1533,7 @@ void ConvolveInit_AVX2() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_AVX2 +#else // 
!LIBGAV1_TARGETING_AVX2 namespace libgav1 { namespace dsp { |