diff options
Diffstat (limited to 'src/dsp/x86/convolve_avx2.cc')
-rw-r--r-- | src/dsp/x86/convolve_avx2.cc | 322 |
1 files changed, 130 insertions, 192 deletions
diff --git a/src/dsp/x86/convolve_avx2.cc b/src/dsp/x86/convolve_avx2.cc index 4126ca9..6e94347 100644 --- a/src/dsp/x86/convolve_avx2.cc +++ b/src/dsp/x86/convolve_avx2.cc @@ -39,17 +39,17 @@ namespace { // Multiply every entry in |src[]| by the corresponding entry in |taps[]| and // sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final // sum from outranging int16_t. -template <int filter_index> +template <int num_taps> __m256i SumOnePassTaps(const __m256i* const src, const __m256i* const taps) { __m256i sum; - if (filter_index < 2) { + if (num_taps == 6) { // 6 taps. const __m256i v_madd_21 = _mm256_maddubs_epi16(src[0], taps[0]); // k2k1 const __m256i v_madd_43 = _mm256_maddubs_epi16(src[1], taps[1]); // k4k3 const __m256i v_madd_65 = _mm256_maddubs_epi16(src[2], taps[2]); // k6k5 sum = _mm256_add_epi16(v_madd_21, v_madd_43); sum = _mm256_add_epi16(sum, v_madd_65); - } else if (filter_index == 2) { + } else if (num_taps == 8) { // 8 taps. const __m256i v_madd_10 = _mm256_maddubs_epi16(src[0], taps[0]); // k1k0 const __m256i v_madd_32 = _mm256_maddubs_epi16(src[1], taps[1]); // k3k2 @@ -58,7 +58,7 @@ __m256i SumOnePassTaps(const __m256i* const src, const __m256i* const taps) { const __m256i v_sum_3210 = _mm256_add_epi16(v_madd_10, v_madd_32); const __m256i v_sum_7654 = _mm256_add_epi16(v_madd_54, v_madd_76); sum = _mm256_add_epi16(v_sum_7654, v_sum_3210); - } else if (filter_index == 3) { + } else if (num_taps == 2) { // 2 taps. sum = _mm256_maddubs_epi16(src[0], taps[0]); // k4k3 } else { @@ -70,7 +70,7 @@ __m256i SumOnePassTaps(const __m256i* const src, const __m256i* const taps) { return sum; } -template <int filter_index> +template <int num_taps> __m256i SumHorizontalTaps(const __m256i* const src, const __m256i* const v_tap) { __m256i v_src[4]; @@ -78,32 +78,32 @@ __m256i SumHorizontalTaps(const __m256i* const src, const __m256i src_long_dup_lo = _mm256_unpacklo_epi8(src_long, src_long); const __m256i src_long_dup_hi = _mm256_unpackhi_epi8(src_long, src_long); - if (filter_index < 2) { + if (num_taps == 6) { // 6 taps. v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 3); // _21 v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7); // _43 v_src[2] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 11); // _65 - } else if (filter_index == 2) { + } else if (num_taps == 8) { // 8 taps. v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 1); // _10 v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5); // _32 v_src[2] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9); // _54 v_src[3] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 13); // _76 - } else if (filter_index == 3) { + } else if (num_taps == 2) { // 2 taps. v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7); // _43 - } else if (filter_index > 3) { + } else { // 4 taps. v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5); // _32 v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9); // _54 } - return SumOnePassTaps<filter_index>(v_src, v_tap); + return SumOnePassTaps<num_taps>(v_src, v_tap); } -template <int filter_index> +template <int num_taps> __m256i SimpleHorizontalTaps(const __m256i* const src, const __m256i* const v_tap) { - __m256i sum = SumHorizontalTaps<filter_index>(src, v_tap); + __m256i sum = SumHorizontalTaps<num_taps>(src, v_tap); // Normally the Horizontal pass does the downshift in two passes: // kInterRoundBitsHorizontal - 1 and then (kFilterBits - @@ -116,17 +116,16 @@ __m256i SimpleHorizontalTaps(const __m256i* const src, return _mm256_packus_epi16(sum, sum); } -template <int filter_index> +template <int num_taps> __m256i HorizontalTaps8To16(const __m256i* const src, const __m256i* const v_tap) { - const __m256i sum = SumHorizontalTaps<filter_index>(src, v_tap); + const __m256i sum = SumHorizontalTaps<num_taps>(src, v_tap); return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); } // Filter 2xh sizes. -template <int num_taps, int filter_index, bool is_2d = false, - bool is_compound = false> +template <int num_taps, bool is_2d = false, bool is_compound = false> void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, const ptrdiff_t src_stride, void* LIBGAV1_RESTRICT const dest, @@ -145,14 +144,14 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, do { if (is_2d) { const __m128i sum = - HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap); + HorizontalTaps8To16_2x2<num_taps>(src, src_stride, v_tap); Store4(&dest16[0], sum); dest16 += pred_stride; Store4(&dest16[0], _mm_srli_si128(sum, 8)); dest16 += pred_stride; } else { const __m128i sum = - SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + SimpleHorizontalTaps2x2<num_taps>(src, src_stride, v_tap); Store2(dest8, sum); dest8 += pred_stride; Store2(dest8, _mm_srli_si128(sum, 4)); @@ -169,7 +168,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, assert(height % 2 == 1); __m128i sum; const __m128i input = LoadLo8(&src[2]); - if (filter_index == 3) { + if (num_taps == 2) { // 03 04 04 05 05 06 06 07 .... const __m128i v_src_43 = _mm_srli_si128(_mm_unpacklo_epi8(input, input), 3); @@ -194,8 +193,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, } // Filter widths >= 4. -template <int num_taps, int filter_index, bool is_2d = false, - bool is_compound = false> +template <int num_taps, bool is_2d = false, bool is_compound = false> void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, const ptrdiff_t src_stride, void* LIBGAV1_RESTRICT const dest, @@ -214,11 +212,11 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, const __m256i src_long = SetrM128i(LoadUnaligned16(&src[x]), LoadUnaligned16(&src[x + 8])); const __m256i result = - HorizontalTaps8To16<filter_index>(&src_long, v_tap); + HorizontalTaps8To16<num_taps>(&src_long, v_tap); const __m256i src_long2 = SetrM128i(LoadUnaligned16(&src[x + 16]), LoadUnaligned16(&src[x + 24])); const __m256i result2 = - HorizontalTaps8To16<filter_index>(&src_long2, v_tap); + HorizontalTaps8To16<num_taps>(&src_long2, v_tap); if (is_2d) { StoreAligned32(&dest16[x], result); StoreAligned32(&dest16[x + 16], result2); @@ -230,11 +228,11 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, // Load src used to calculate dest8[7:0] and dest8[23:16]. const __m256i src_long = LoadUnaligned32(&src[x]); const __m256i result = - SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + SimpleHorizontalTaps<num_taps>(&src_long, v_tap); // Load src used to calculate dest8[15:8] and dest8[31:24]. const __m256i src_long2 = LoadUnaligned32(&src[x + 8]); const __m256i result2 = - SimpleHorizontalTaps<filter_index>(&src_long2, v_tap); + SimpleHorizontalTaps<num_taps>(&src_long2, v_tap); // Combine results and store. StoreUnaligned32(&dest8[x], _mm256_unpacklo_epi64(result, result2)); } @@ -252,13 +250,12 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, // Load into 2 128 bit lanes. const __m256i src_long = SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8])); - const __m256i result = - HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap); const __m256i src_long2 = SetrM128i(LoadUnaligned16(&src[src_stride]), LoadUnaligned16(&src[8 + src_stride])); const __m256i result2 = - HorizontalTaps8To16<filter_index>(&src_long2, v_tap); + HorizontalTaps8To16<num_taps>(&src_long2, v_tap); if (is_2d) { StoreAligned32(&dest16[0], result); StoreAligned32(&dest16[pred_stride], result2); @@ -270,12 +267,11 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, // Load into 2 128 bit lanes. const __m256i src_long = SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[src_stride])); - const __m256i result = - SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + const __m256i result = SimpleHorizontalTaps<num_taps>(&src_long, v_tap); const __m256i src_long2 = SetrM128i( LoadUnaligned16(&src[8]), LoadUnaligned16(&src[8 + src_stride])); const __m256i result2 = - SimpleHorizontalTaps<filter_index>(&src_long2, v_tap); + SimpleHorizontalTaps<num_taps>(&src_long2, v_tap); const __m256i packed_result = _mm256_unpacklo_epi64(result, result2); StoreUnaligned16(&dest8[0], _mm256_castsi256_si128(packed_result)); StoreUnaligned16(&dest8[pred_stride], @@ -292,8 +288,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, if (is_2d) { const __m256i src_long = SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8])); - const __m256i result = - HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap); StoreAligned32(&dest16[0], result); } @@ -306,8 +301,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, const __m128i next_row = LoadUnaligned16(&src[src_stride]); const __m256i src_long = SetrM128i(this_row, next_row); if (is_2d || is_compound) { - const __m256i result = - HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap); if (is_2d) { StoreAligned16(&dest16[0], _mm256_castsi256_si128(result)); StoreAligned16(&dest16[pred_stride], @@ -322,8 +316,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, const __m128i next_row = LoadUnaligned16(&src[src_stride]); // Load into 2 128 bit lanes. const __m256i src_long = SetrM128i(this_row, next_row); - const __m256i result = - SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + const __m256i result = SimpleHorizontalTaps<num_taps>(&src_long, v_tap); StoreLo8(&dest8[0], _mm256_castsi256_si128(result)); StoreLo8(&dest8[pred_stride], _mm256_extracti128_si256(result, 1)); } @@ -337,8 +330,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, // filter the remaining row. if (is_2d) { const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0])); - const __m256i result = - HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap); StoreAligned16(&dest16[0], _mm256_castsi256_si128(result)); } @@ -351,8 +343,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, const __m128i next_row = LoadUnaligned16(&src[src_stride]); const __m256i src_long = SetrM128i(this_row, next_row); if (is_2d || is_compound) { - const __m256i result = - HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap); StoreLo8(&dest16[0], _mm256_castsi256_si128(result)); StoreLo8(&dest16[pred_stride], _mm256_extracti128_si256(result, 1)); } else { @@ -360,8 +351,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, const __m128i next_row = LoadUnaligned16(&src[src_stride]); // Load into 2 128 bit lanes. const __m256i src_long = SetrM128i(this_row, next_row); - const __m256i result = - SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + const __m256i result = SimpleHorizontalTaps<num_taps>(&src_long, v_tap); Store4(&dest8[0], _mm256_castsi256_si128(result)); Store4(&dest8[pred_stride], _mm256_extracti128_si256(result, 1)); } @@ -375,8 +365,7 @@ void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src, // filter the remaining row. if (is_2d) { const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0])); - const __m256i result = - HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap); StoreLo8(&dest16[0], _mm256_castsi256_si128(result)); } } @@ -554,18 +543,15 @@ LIBGAV1_ALWAYS_INLINE void DoHorizontalPass2xH( const __m128i v_horizontal_filter = LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]); - if (filter_index == 4) { // 4 tap. - SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 4, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); - } else if (filter_index == 5) { // 4 tap. + if ((filter_index & 0x4) != 0) { // 4 tap. + // ((filter_index == 4) | (filter_index == 5)) SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 5, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); + FilterHorizontal<4, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else { // 2 tap. SetupTaps<2>(&v_horizontal_filter, v_tap); - FilterHorizontal<2, 3, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); + FilterHorizontal<2, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } } @@ -582,28 +568,25 @@ LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( if (filter_index == 2) { // 8 tap. SetupTaps<8>(&v_horizontal_filter, v_tap); - FilterHorizontal<8, 2, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); + FilterHorizontal<8, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 1) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 1, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); + FilterHorizontal<6, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 0) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 0, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); - } else if (filter_index == 4) { // 4 tap. - SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 4, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); - } else if (filter_index == 5) { // 4 tap. + FilterHorizontal<6, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); + } else if ((filter_index & 0x4) != 0) { // 4 tap. + // ((filter_index == 4) | (filter_index == 5)) SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 5, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); + FilterHorizontal<4, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else { // 2 tap. SetupTaps<2>(&v_horizontal_filter, v_tap); - FilterHorizontal<2, 3, is_2d, is_compound>(src, src_stride, dst, dst_stride, - width, height, v_tap); + FilterHorizontal<2, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } } @@ -617,7 +600,8 @@ void Convolve2D_AVX2(const void* LIBGAV1_RESTRICT const reference, const ptrdiff_t pred_stride) { const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); - const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + const int vertical_taps = + GetNumTapsInFilter(vert_filter_index, vertical_filter_id); // The output of the horizontal filter is guaranteed to fit in 16 bits. alignas(32) uint16_t @@ -730,61 +714,60 @@ __m256i Compound1DShift(const __m256i sum) { return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); } -template <int filter_index, bool unpack_high = false> +template <int num_taps, bool unpack_high = false> __m256i SumVerticalTaps(const __m256i* const srcs, const __m256i* const v_tap) { __m256i v_src[4]; if (!unpack_high) { - if (filter_index < 2) { + if (num_taps == 6) { // 6 taps. v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]); - } else if (filter_index == 2) { + } else if (num_taps == 8) { // 8 taps. v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]); v_src[3] = _mm256_unpacklo_epi8(srcs[6], srcs[7]); - } else if (filter_index == 3) { + } else if (num_taps == 2) { // 2 taps. v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); - } else if (filter_index > 3) { + } else { // 4 taps. v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); } } else { - if (filter_index < 2) { + if (num_taps == 6) { // 6 taps. v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]); - } else if (filter_index == 2) { + } else if (num_taps == 8) { // 8 taps. v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]); v_src[3] = _mm256_unpackhi_epi8(srcs[6], srcs[7]); - } else if (filter_index == 3) { + } else if (num_taps == 2) { // 2 taps. v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); - } else if (filter_index > 3) { + } else { // 4 taps. v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); } } - return SumOnePassTaps<filter_index>(v_src, v_tap); + return SumOnePassTaps<num_taps>(v_src, v_tap); } -template <int filter_index, bool is_compound = false> +template <int num_taps, bool is_compound = false> void FilterVertical32xH(const uint8_t* LIBGAV1_RESTRICT src, const ptrdiff_t src_stride, void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride, const int width, const int height, const __m256i* const v_tap) { - const int num_taps = GetNumTapsInFilter(filter_index); const int next_row = num_taps - 1; auto* dst8 = static_cast<uint8_t*>(dst); auto* dst16 = static_cast<uint16_t*>(dst); @@ -821,9 +804,9 @@ void FilterVertical32xH(const uint8_t* LIBGAV1_RESTRICT src, srcs[next_row] = LoadUnaligned32(src_x); src_x += src_stride; - const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m256i sums = SumVerticalTaps<num_taps>(srcs, v_tap); const __m256i sums_hi = - SumVerticalTaps<filter_index, /*unpack_high=*/true>(srcs, v_tap); + SumVerticalTaps<num_taps, /*unpack_high=*/true>(srcs, v_tap); if (is_compound) { const __m256i results = Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20)); @@ -861,13 +844,12 @@ void FilterVertical32xH(const uint8_t* LIBGAV1_RESTRICT src, } while (x < width); } -template <int filter_index, bool is_compound = false> +template <int num_taps, bool is_compound = false> void FilterVertical16xH(const uint8_t* LIBGAV1_RESTRICT src, const ptrdiff_t src_stride, void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride, const int /*width*/, const int height, const __m256i* const v_tap) { - const int num_taps = GetNumTapsInFilter(filter_index); const int next_row = num_taps; auto* dst8 = static_cast<uint8_t*>(dst); auto* dst16 = static_cast<uint16_t*>(dst); @@ -922,9 +904,9 @@ void FilterVertical16xH(const uint8_t* LIBGAV1_RESTRICT src, srcs[next_row - 1] = _mm256_inserti128_si256( srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1); - const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m256i sums = SumVerticalTaps<num_taps>(srcs, v_tap); const __m256i sums_hi = - SumVerticalTaps<filter_index, /*unpack_high=*/true>(srcs, v_tap); + SumVerticalTaps<num_taps, /*unpack_high=*/true>(srcs, v_tap); if (is_compound) { const __m256i results = Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20)); @@ -964,13 +946,12 @@ void FilterVertical16xH(const uint8_t* LIBGAV1_RESTRICT src, } while (y != 0); } -template <int filter_index, bool is_compound = false> +template <int num_taps, bool is_compound = false> void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src, const ptrdiff_t src_stride, void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride, const int /*width*/, const int height, const __m256i* const v_tap) { - const int num_taps = GetNumTapsInFilter(filter_index); const int next_row = num_taps; auto* dst8 = static_cast<uint8_t*>(dst); auto* dst16 = static_cast<uint16_t*>(dst); @@ -1025,7 +1006,7 @@ void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src, srcs[next_row - 1] = _mm256_inserti128_si256( srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1); - const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m256i sums = SumVerticalTaps<num_taps>(srcs, v_tap); if (is_compound) { const __m256i results = Compound1DShift(sums); const __m128i this_dst = _mm256_castsi256_si128(results); @@ -1062,13 +1043,12 @@ void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src, } while (y != 0); } -template <int filter_index, bool is_compound = false> +template <int num_taps, bool is_compound = false> void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src, const ptrdiff_t src_stride, void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride, const int /*width*/, const int height, const __m128i* const v_tap) { - const int num_taps = GetNumTapsInFilter(filter_index); const int next_row = num_taps - 1; auto* dst8 = static_cast<uint8_t*>(dst); auto* dst16 = static_cast<uint16_t*>(dst); @@ -1101,7 +1081,7 @@ void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src, srcs[next_row] = LoadLo8(src_x); src_x += src_stride; - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i sums = SumVerticalTaps<num_taps>(srcs, v_tap); if (is_compound) { const __m128i results = Compound1DShift(sums); StoreUnaligned16(dst16, results); @@ -1137,7 +1117,8 @@ void ConvolveVertical_AVX2(const void* LIBGAV1_RESTRICT const reference, const int height, void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) { const int filter_index = GetFilterIndex(vertical_filter_index, height); - const int vertical_taps = GetNumTapsInFilter(filter_index); + const int vertical_taps = + GetNumTapsInFilter(filter_index, vertical_filter_id); const ptrdiff_t src_stride = reference_stride; const auto* src = static_cast<const uint8_t*>(reference) - (vertical_taps / 2 - 1) * src_stride; @@ -1151,43 +1132,43 @@ void ConvolveVertical_AVX2(const void* LIBGAV1_RESTRICT const reference, // Use 256 bits for width > 4. if (width > 4) { __m256i taps_256[4]; - if (filter_index < 2) { // 6 tap. + if (vertical_taps == 6) { // 6 tap. SetupTaps<6>(&v_filter, taps_256); if (width == 8) { - FilterVertical8xH<0>(src, src_stride, dest, dest_stride, width, height, + FilterVertical8xH<6>(src, src_stride, dest, dest_stride, width, height, taps_256); } else if (width == 16) { - FilterVertical16xH<0>(src, src_stride, dest, dest_stride, width, height, + FilterVertical16xH<6>(src, src_stride, dest, dest_stride, width, height, taps_256); } else { - FilterVertical32xH<0>(src, src_stride, dest, dest_stride, width, height, + FilterVertical32xH<6>(src, src_stride, dest, dest_stride, width, height, taps_256); } - } else if (filter_index == 2) { // 8 tap. + } else if (vertical_taps == 8) { // 8 tap. SetupTaps<8>(&v_filter, taps_256); if (width == 8) { - FilterVertical8xH<2>(src, src_stride, dest, dest_stride, width, height, + FilterVertical8xH<8>(src, src_stride, dest, dest_stride, width, height, taps_256); } else if (width == 16) { - FilterVertical16xH<2>(src, src_stride, dest, dest_stride, width, height, + FilterVertical16xH<8>(src, src_stride, dest, dest_stride, width, height, taps_256); } else { - FilterVertical32xH<2>(src, src_stride, dest, dest_stride, width, height, + FilterVertical32xH<8>(src, src_stride, dest, dest_stride, width, height, taps_256); } - } else if (filter_index == 3) { // 2 tap. + } else if (vertical_taps == 2) { // 2 tap. SetupTaps<2>(&v_filter, taps_256); if (width == 8) { - FilterVertical8xH<3>(src, src_stride, dest, dest_stride, width, height, + FilterVertical8xH<2>(src, src_stride, dest, dest_stride, width, height, taps_256); } else if (width == 16) { - FilterVertical16xH<3>(src, src_stride, dest, dest_stride, width, height, + FilterVertical16xH<2>(src, src_stride, dest, dest_stride, width, height, taps_256); } else { - FilterVertical32xH<3>(src, src_stride, dest, dest_stride, width, height, + FilterVertical32xH<2>(src, src_stride, dest, dest_stride, width, height, taps_256); } - } else if (filter_index == 4) { // 4 tap. + } else { // 4 tap. SetupTaps<4>(&v_filter, taps_256); if (width == 8) { FilterVertical8xH<4>(src, src_stride, dest, dest_stride, width, height, @@ -1199,67 +1180,38 @@ void ConvolveVertical_AVX2(const void* LIBGAV1_RESTRICT const reference, FilterVertical32xH<4>(src, src_stride, dest, dest_stride, width, height, taps_256); } - } else { - SetupTaps<4>(&v_filter, taps_256); - if (width == 8) { - FilterVertical8xH<5>(src, src_stride, dest, dest_stride, width, height, - taps_256); - } else if (width == 16) { - FilterVertical16xH<5>(src, src_stride, dest, dest_stride, width, height, - taps_256); - } else { - FilterVertical32xH<5>(src, src_stride, dest, dest_stride, width, height, - taps_256); - } } } else { // width <= 8 // Use 128 bit code. __m128i taps[4]; - if (filter_index < 2) { // 6 tap. + if (vertical_taps == 6) { // 6 tap. SetupTaps<6>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<6, 0>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical2xH<6>(src, src_stride, dest, dest_stride, height, taps); } else { - FilterVertical4xH<6, 0>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical4xH<6>(src, src_stride, dest, dest_stride, height, taps); } - } else if (filter_index == 2) { // 8 tap. + } else if (vertical_taps == 8) { // 8 tap. SetupTaps<8>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<8, 2>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical2xH<8>(src, src_stride, dest, dest_stride, height, taps); } else { - FilterVertical4xH<8, 2>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical4xH<8>(src, src_stride, dest, dest_stride, height, taps); } - } else if (filter_index == 3) { // 2 tap. + } else if (vertical_taps == 2) { // 2 tap. SetupTaps<2>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<2, 3>(src, src_stride, dest, dest_stride, height, - taps); - } else { - FilterVertical4xH<2, 3>(src, src_stride, dest, dest_stride, height, - taps); - } - } else if (filter_index == 4) { // 4 tap. - SetupTaps<4>(&v_filter, taps); - if (width == 2) { - FilterVertical2xH<4, 4>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps); } else { - FilterVertical4xH<4, 4>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps); } - } else { + } else { // 4 tap. SetupTaps<4>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<4, 5>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, taps); } else { - FilterVertical4xH<4, 5>(src, src_stride, dest, dest_stride, height, - taps); + FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, taps); } } } @@ -1272,7 +1224,8 @@ void ConvolveCompoundVertical_AVX2( const int vertical_filter_id, const int width, const int height, void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) { const int filter_index = GetFilterIndex(vertical_filter_index, height); - const int vertical_taps = GetNumTapsInFilter(filter_index); + const int vertical_taps = + GetNumTapsInFilter(filter_index, vertical_filter_id); const ptrdiff_t src_stride = reference_stride; const auto* src = static_cast<const uint8_t*>(reference) - (vertical_taps / 2 - 1) * src_stride; @@ -1286,43 +1239,43 @@ void ConvolveCompoundVertical_AVX2( // Use 256 bits for width > 4. if (width > 4) { __m256i taps_256[4]; - if (filter_index < 2) { // 6 tap. + if (vertical_taps == 6) { // 6 tap. SetupTaps<6>(&v_filter, taps_256); if (width == 8) { - FilterVertical8xH<0, /*is_compound=*/true>( + FilterVertical8xH<6, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } else if (width == 16) { - FilterVertical16xH<0, /*is_compound=*/true>( + FilterVertical16xH<6, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } else { - FilterVertical32xH<0, /*is_compound=*/true>( + FilterVertical32xH<6, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } - } else if (filter_index == 2) { // 8 tap. + } else if (vertical_taps == 8) { // 8 tap. SetupTaps<8>(&v_filter, taps_256); if (width == 8) { - FilterVertical8xH<2, /*is_compound=*/true>( + FilterVertical8xH<8, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } else if (width == 16) { - FilterVertical16xH<2, /*is_compound=*/true>( + FilterVertical16xH<8, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } else { - FilterVertical32xH<2, /*is_compound=*/true>( + FilterVertical32xH<8, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } - } else if (filter_index == 3) { // 2 tap. + } else if (vertical_taps == 2) { // 2 tap. SetupTaps<2>(&v_filter, taps_256); if (width == 8) { - FilterVertical8xH<3, /*is_compound=*/true>( + FilterVertical8xH<2, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } else if (width == 16) { - FilterVertical16xH<3, /*is_compound=*/true>( + FilterVertical16xH<2, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } else { - FilterVertical32xH<3, /*is_compound=*/true>( + FilterVertical32xH<2, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } - } else if (filter_index == 4) { // 4 tap. + } else { // 4 tap. SetupTaps<4>(&v_filter, taps_256); if (width == 8) { FilterVertical8xH<4, /*is_compound=*/true>( @@ -1334,43 +1287,27 @@ void ConvolveCompoundVertical_AVX2( FilterVertical32xH<4, /*is_compound=*/true>( src, src_stride, dest, dest_stride, width, height, taps_256); } - } else { - SetupTaps<4>(&v_filter, taps_256); - if (width == 8) { - FilterVertical8xH<5, /*is_compound=*/true>( - src, src_stride, dest, dest_stride, width, height, taps_256); - } else if (width == 16) { - FilterVertical16xH<5, /*is_compound=*/true>( - src, src_stride, dest, dest_stride, width, height, taps_256); - } else { - FilterVertical32xH<5, /*is_compound=*/true>( - src, src_stride, dest, dest_stride, width, height, taps_256); - } } } else { // width <= 4 // Use 128 bit code. __m128i taps[4]; - if (filter_index < 2) { // 6 tap. + if (vertical_taps == 6) { // 6 tap. SetupTaps<6>(&v_filter, taps); - FilterVertical4xH<6, 0, /*is_compound=*/true>(src, src_stride, dest, - dest_stride, height, taps); - } else if (filter_index == 2) { // 8 tap. + FilterVertical4xH<6, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (vertical_taps == 8) { // 8 tap. SetupTaps<8>(&v_filter, taps); - FilterVertical4xH<8, 2, /*is_compound=*/true>(src, src_stride, dest, - dest_stride, height, taps); - } else if (filter_index == 3) { // 2 tap. + FilterVertical4xH<8, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (vertical_taps == 2) { // 2 tap. SetupTaps<2>(&v_filter, taps); - FilterVertical4xH<2, 3, /*is_compound=*/true>(src, src_stride, dest, - dest_stride, height, taps); - } else if (filter_index == 4) { // 4 tap. - SetupTaps<4>(&v_filter, taps); - FilterVertical4xH<4, 4, /*is_compound=*/true>(src, src_stride, dest, - dest_stride, height, taps); - } else { + FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else { // 4 tap. SetupTaps<4>(&v_filter, taps); - FilterVertical4xH<4, 5, /*is_compound=*/true>(src, src_stride, dest, - dest_stride, height, taps); + FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); } } } @@ -1430,7 +1367,8 @@ void ConvolveCompound2D_AVX2( void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) { const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); - const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + const int vertical_taps = + GetNumTapsInFilter(vert_filter_index, vertical_filter_id); // The output of the horizontal filter is guaranteed to fit in 16 bits. alignas(32) uint16_t |