Diffstat (limited to 'src/dsp/x86')
48 files changed, 14694 insertions, 3855 deletions
diff --git a/src/dsp/x86/average_blend_sse4.cc b/src/dsp/x86/average_blend_sse4.cc index 8e008d1..ec9f589 100644 --- a/src/dsp/x86/average_blend_sse4.cc +++ b/src/dsp/x86/average_blend_sse4.cc @@ -30,6 +30,7 @@ namespace libgav1 { namespace dsp { +namespace low_bitdepth { namespace { constexpr int kInterPostRoundBit = 4; @@ -138,13 +139,232 @@ void Init8bpp() { } } // namespace +} // namespace low_bitdepth -void AverageBlendInit_SSE4_1() { Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +constexpr int kInterPostRoundBitPlusOne = 5; + +template <const int width, const int offset> +inline void AverageBlendRow(const uint16_t* prediction_0, + const uint16_t* prediction_1, + const __m128i& compound_offset, + const __m128i& round_offset, const __m128i& max, + const __m128i& zero, uint16_t* dst, + const ptrdiff_t dest_stride) { + // pred_0/1 max range is 16b. + const __m128i pred_0 = LoadUnaligned16(prediction_0 + offset); + const __m128i pred_1 = LoadUnaligned16(prediction_1 + offset); + const __m128i pred_00 = _mm_cvtepu16_epi32(pred_0); + const __m128i pred_01 = _mm_unpackhi_epi16(pred_0, zero); + const __m128i pred_10 = _mm_cvtepu16_epi32(pred_1); + const __m128i pred_11 = _mm_unpackhi_epi16(pred_1, zero); + + const __m128i pred_add_0 = _mm_add_epi32(pred_00, pred_10); + const __m128i pred_add_1 = _mm_add_epi32(pred_01, pred_11); + const __m128i compound_offset_0 = _mm_sub_epi32(pred_add_0, compound_offset); + const __m128i compound_offset_1 = _mm_sub_epi32(pred_add_1, compound_offset); + // RightShiftWithRounding and Clip3. + const __m128i round_0 = _mm_add_epi32(compound_offset_0, round_offset); + const __m128i round_1 = _mm_add_epi32(compound_offset_1, round_offset); + const __m128i res_0 = _mm_srai_epi32(round_0, kInterPostRoundBitPlusOne); + const __m128i res_1 = _mm_srai_epi32(round_1, kInterPostRoundBitPlusOne); + const __m128i result = _mm_min_epi16(_mm_packus_epi32(res_0, res_1), max); + if (width != 4) { + // Store width=8/16/32/64/128. + StoreUnaligned16(dst + offset, result); + return; + } + assert(width == 4); + StoreLo8(dst, result); + StoreHi8(dst + dest_stride, result); +} + +void AverageBlend10bpp_SSE4_1(const void* prediction_0, + const void* prediction_1, const int width, + const int height, void* const dest, + const ptrdiff_t dst_stride) { + auto* dst = static_cast<uint16_t*>(dest); + const ptrdiff_t dest_stride = dst_stride / sizeof(dst[0]); + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + const __m128i compound_offset = + _mm_set1_epi32(kCompoundOffset + kCompoundOffset); + const __m128i round_offset = + _mm_set1_epi32((1 << kInterPostRoundBitPlusOne) >> 1); + const __m128i max = _mm_set1_epi16((1 << kBitdepth10) - 1); + const __m128i zero = _mm_setzero_si128(); + int y = height; + + if (width == 4) { + const ptrdiff_t dest_stride2 = dest_stride << 1; + const ptrdiff_t width2 = width << 1; + do { + // row0,1 + AverageBlendRow<4, 0>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + dst += dest_stride2; + pred_0 += width2; + pred_1 += width2; + y -= 2; + } while (y != 0); + return; + } + if (width == 8) { + const ptrdiff_t dest_stride2 = dest_stride << 1; + const ptrdiff_t width2 = width << 1; + do { + // row0. + AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + // row1. 
+ AverageBlendRow<8, 0>(pred_0 + width, pred_1 + width, compound_offset, + round_offset, max, zero, dst + dest_stride, + dest_stride); + dst += dest_stride2; + pred_0 += width2; + pred_1 += width2; + y -= 2; + } while (y != 0); + return; + } + if (width == 16) { + const ptrdiff_t dest_stride2 = dest_stride << 1; + const ptrdiff_t width2 = width << 1; + do { + // row0. + AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + // row1. + AverageBlendRow<8, 0>(pred_0 + width, pred_1 + width, compound_offset, + round_offset, max, zero, dst + dest_stride, + dest_stride); + AverageBlendRow<8, 8>(pred_0 + width, pred_1 + width, compound_offset, + round_offset, max, zero, dst + dest_stride, + dest_stride); + dst += dest_stride2; + pred_0 += width2; + pred_1 += width2; + y -= 2; + } while (y != 0); + return; + } + if (width == 32) { + do { + // pred [0 - 15]. + AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + // pred [16 - 31]. + AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + } while (--y != 0); + return; + } + if (width == 64) { + do { + // pred [0 - 31]. + AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + // pred [31 - 63]. + AverageBlendRow<8, 32>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 40>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 48>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 56>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + } while (--y != 0); + return; + } + assert(width == 128); + do { + // pred [0 - 31]. + AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + // pred [31 - 63]. + AverageBlendRow<8, 32>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 40>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 48>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 56>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + + // pred [64 - 95]. 
+ AverageBlendRow<8, 64>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 72>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 80>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 88>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + // pred [96 - 127]. + AverageBlendRow<8, 96>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 104>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 112>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + AverageBlendRow<8, 120>(pred_0, pred_1, compound_offset, round_offset, max, + zero, dst, dest_stride); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + } while (--y != 0); +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); +#if DSP_ENABLED_10BPP_SSE4_1(AverageBlend) + dsp->average_blend = AverageBlend10bpp_SSE4_1; +#endif +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void AverageBlendInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/average_blend_sse4.h b/src/dsp/x86/average_blend_sse4.h index 937e8e2..cd07112 100644 --- a/src/dsp/x86/average_blend_sse4.h +++ b/src/dsp/x86/average_blend_sse4.h @@ -32,9 +32,13 @@ void AverageBlendInit_SSE4_1(); // If sse4 is enabled and the baseline isn't set due to a higher level of // optimization being enabled, signal the sse4 implementation should be used. #if LIBGAV1_TARGETING_SSE4_1 + #ifndef LIBGAV1_Dsp8bpp_AverageBlend #define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_CPU_SSE4_1 #endif +#ifndef LIBGAV1_Dsp10bpp_AverageBlend +#define LIBGAV1_Dsp10bpp_AverageBlend LIBGAV1_CPU_SSE4_1 +#endif #endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/cdef_avx2.cc b/src/dsp/x86/cdef_avx2.cc new file mode 100644 index 0000000..d41dc38 --- /dev/null +++ b/src/dsp/x86/cdef_avx2.cc @@ -0,0 +1,784 @@ +// Copyright 2021 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/cdef.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_AVX2 +#include <immintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_avx2.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +#include "src/dsp/cdef.inc" + +// Used when calculating odd |cost[x]| values. 
+// Holds elements 1 3 5 7 7 7 7 7 +alignas(32) constexpr uint32_t kCdefDivisionTableOddPairsPadded[] = { + 420, 210, 140, 105, 420, 210, 140, 105, + 105, 105, 105, 105, 105, 105, 105, 105}; + +// ---------------------------------------------------------------------------- +// Refer to CdefDirection_C(). +// +// int32_t partial[8][15] = {}; +// for (int i = 0; i < 8; ++i) { +// for (int j = 0; j < 8; ++j) { +// const int x = 1; +// partial[0][i + j] += x; +// partial[1][i + j / 2] += x; +// partial[2][i] += x; +// partial[3][3 + i - j / 2] += x; +// partial[4][7 + i - j] += x; +// partial[5][3 - i / 2 + j] += x; +// partial[6][j] += x; +// partial[7][i / 2 + j] += x; +// } +// } +// +// Using the code above, generate the position count for partial[8][15]. +// +// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// +// The SIMD code shifts the input horizontally, then adds vertically to get the +// correct partial value for the given position. +// ---------------------------------------------------------------------------- + +// ---------------------------------------------------------------------------- +// partial[0][i + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 00 10 11 12 13 14 15 16 17 00 00 00 00 00 00 +// 00 00 20 21 22 23 24 25 26 27 00 00 00 00 00 +// 00 00 00 30 31 32 33 34 35 36 37 00 00 00 00 +// 00 00 00 00 40 41 42 43 44 45 46 47 00 00 00 +// 00 00 00 00 00 50 51 52 53 54 55 56 57 00 00 +// 00 00 00 00 00 00 60 61 62 63 64 65 66 67 00 +// 00 00 00 00 00 00 00 70 71 72 73 74 75 76 77 +// +// partial[4] is the same except the source is reversed. 
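Annotation, not part of the patch: a scalar model of what the shifted adds in AddPartial_D0_D4() below compute, assuming |src| is the 8x8 block already zero-extended to 16 bits (the helper name is illustrative only). partial[4] follows the same pattern with each source row reversed.

inline void AddPartialD0Scalar(const uint16_t src[8][8], uint16_t partial0[15]) {
  // |partial0| must be zero-initialized by the caller.
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      // Shifting row i left by one 16-bit lane per row places src[i][j] in
      // lane i + j; lanes 0..7 accumulate into |partial_lo|, lanes 8..14
      // into |partial_hi|.
      partial0[i + j] += src[i][j];
    }
  }
}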
+LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(__m256i* v_src_16, + __m256i* partial_lo, + __m256i* partial_hi) { + // 00 01 02 03 04 05 06 07 + *partial_lo = v_src_16[0]; + // 00 00 00 00 00 00 00 00 + *partial_hi = _mm256_setzero_si256(); + + // 00 10 11 12 13 14 15 16 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[1], 2)); + // 17 00 00 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[1], 14)); + + // 00 00 20 21 22 23 24 25 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[2], 4)); + // 26 27 00 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[2], 12)); + + // 00 00 00 30 31 32 33 34 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[3], 6)); + // 35 36 37 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[3], 10)); + + // 00 00 00 00 40 41 42 43 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[4], 8)); + // 44 45 46 47 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[4], 8)); + + // 00 00 00 00 00 50 51 52 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[5], 10)); + // 53 54 55 56 57 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[5], 6)); + + // 00 00 00 00 00 00 60 61 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[6], 12)); + // 62 63 64 65 66 67 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[6], 4)); + + // 00 00 00 00 00 00 00 70 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[7], 14)); + // 71 72 73 74 75 76 77 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[7], 2)); +} + +// ---------------------------------------------------------------------------- +// partial[1][i + j / 2] += x; +// +// A0 = src[0] + src[1], A1 = src[2] + src[3], ... +// +// A0 A1 A2 A3 00 00 00 00 00 00 00 00 00 00 00 +// 00 B0 B1 B2 B3 00 00 00 00 00 00 00 00 00 00 +// 00 00 C0 C1 C2 C3 00 00 00 00 00 00 00 00 00 +// 00 00 00 D0 D1 D2 D3 00 00 00 00 00 00 00 00 +// 00 00 00 00 E0 E1 E2 E3 00 00 00 00 00 00 00 +// 00 00 00 00 00 F0 F1 F2 F3 00 00 00 00 00 00 +// 00 00 00 00 00 00 G0 G1 G2 G3 00 00 00 00 00 +// 00 00 00 00 00 00 00 H0 H1 H2 H3 00 00 00 00 +// +// partial[3] is the same except the source is reversed. 
+LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(__m256i* v_src_16, + __m256i* partial_lo, + __m256i* partial_hi) { + __m256i v_d1_temp[8]; + const __m256i v_zero = _mm256_setzero_si256(); + + for (int i = 0; i < 8; ++i) { + v_d1_temp[i] = _mm256_hadd_epi16(v_src_16[i], v_zero); + } + + *partial_lo = *partial_hi = v_zero; + // A0 A1 A2 A3 00 00 00 00 + *partial_lo = _mm256_add_epi16(*partial_lo, v_d1_temp[0]); + + // 00 B0 B1 B2 B3 00 00 00 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[1], 2)); + + // 00 00 C0 C1 C2 C3 00 00 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[2], 4)); + // 00 00 00 D0 D1 D2 D3 00 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[3], 6)); + // 00 00 00 00 E0 E1 E2 E3 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[4], 8)); + + // 00 00 00 00 00 F0 F1 F2 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[5], 10)); + // F3 00 00 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_d1_temp[5], 6)); + + // 00 00 00 00 00 00 G0 G1 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[6], 12)); + // G2 G3 00 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_d1_temp[6], 4)); + + // 00 00 00 00 00 00 00 H0 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[7], 14)); + // H1 H2 H3 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_d1_temp[7], 2)); +} + +// ---------------------------------------------------------------------------- +// partial[7][i / 2 + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 +// 00 20 21 22 23 24 25 26 27 00 00 00 00 00 00 +// 00 30 31 32 33 34 35 36 37 00 00 00 00 00 00 +// 00 00 40 41 42 43 44 45 46 47 00 00 00 00 00 +// 00 00 50 51 52 53 54 55 56 57 00 00 00 00 00 +// 00 00 00 60 61 62 63 64 65 66 67 00 00 00 00 +// 00 00 00 70 71 72 73 74 75 76 77 00 00 00 00 +// +// partial[5] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D7_D5(__m256i* v_src, __m256i* partial_lo, + __m256i* partial_hi) { + __m256i v_pair_add[4]; + // Add vertical source pairs. 
+ v_pair_add[0] = _mm256_add_epi16(v_src[0], v_src[1]); + v_pair_add[1] = _mm256_add_epi16(v_src[2], v_src[3]); + v_pair_add[2] = _mm256_add_epi16(v_src[4], v_src[5]); + v_pair_add[3] = _mm256_add_epi16(v_src[6], v_src[7]); + + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + *partial_lo = v_pair_add[0]; + // 00 00 00 00 00 00 00 00 + // 00 00 00 00 00 00 00 00 + *partial_hi = _mm256_setzero_si256(); + + // 00 20 21 22 23 24 25 26 + // 00 30 31 32 33 34 35 36 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_pair_add[1], 2)); + // 27 00 00 00 00 00 00 00 + // 37 00 00 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_pair_add[1], 14)); + + // 00 00 40 41 42 43 44 45 + // 00 00 50 51 52 53 54 55 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_pair_add[2], 4)); + // 46 47 00 00 00 00 00 00 + // 56 57 00 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_pair_add[2], 12)); + + // 00 00 00 60 61 62 63 64 + // 00 00 00 70 71 72 73 74 + *partial_lo = + _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_pair_add[3], 6)); + // 65 66 67 00 00 00 00 00 + // 75 76 77 00 00 00 00 00 + *partial_hi = + _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_pair_add[3], 10)); +} + +LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* src, ptrdiff_t stride, + __m256i* partial) { + // 8x8 input + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + // 20 21 22 23 24 25 26 27 + // 30 31 32 33 34 35 36 37 + // 40 41 42 43 44 45 46 47 + // 50 51 52 53 54 55 56 57 + // 60 61 62 63 64 65 66 67 + // 70 71 72 73 74 75 76 77 + __m256i v_src[8]; + for (auto& i : v_src) { + i = _mm256_castsi128_si256(LoadLo8(src)); + // Dup lower lane. + i = _mm256_permute2x128_si256(i, i, 0x0); + src += stride; + } + + const __m256i v_zero = _mm256_setzero_si256(); + // partial for direction 2 + // -------------------------------------------------------------------------- + // partial[2][i] += x; + // 00 10 20 30 40 50 60 70 xx xx xx xx xx xx xx xx + // 01 11 21 33 41 51 61 71 xx xx xx xx xx xx xx xx + // 02 12 22 33 42 52 62 72 xx xx xx xx xx xx xx xx + // 03 13 23 33 43 53 63 73 xx xx xx xx xx xx xx xx + // 04 14 24 34 44 54 64 74 xx xx xx xx xx xx xx xx + // 05 15 25 35 45 55 65 75 xx xx xx xx xx xx xx xx + // 06 16 26 36 46 56 66 76 xx xx xx xx xx xx xx xx + // 07 17 27 37 47 57 67 77 xx xx xx xx xx xx xx xx + const __m256i v_src_4_0 = _mm256_unpacklo_epi64(v_src[0], v_src[4]); + const __m256i v_src_5_1 = _mm256_unpacklo_epi64(v_src[1], v_src[5]); + const __m256i v_src_6_2 = _mm256_unpacklo_epi64(v_src[2], v_src[6]); + const __m256i v_src_7_3 = _mm256_unpacklo_epi64(v_src[3], v_src[7]); + const __m256i v_hsum_4_0 = _mm256_sad_epu8(v_src_4_0, v_zero); + const __m256i v_hsum_5_1 = _mm256_sad_epu8(v_src_5_1, v_zero); + const __m256i v_hsum_6_2 = _mm256_sad_epu8(v_src_6_2, v_zero); + const __m256i v_hsum_7_3 = _mm256_sad_epu8(v_src_7_3, v_zero); + const __m256i v_hsum_1_0 = _mm256_unpacklo_epi16(v_hsum_4_0, v_hsum_5_1); + const __m256i v_hsum_3_2 = _mm256_unpacklo_epi16(v_hsum_6_2, v_hsum_7_3); + const __m256i v_hsum_5_4 = _mm256_unpackhi_epi16(v_hsum_4_0, v_hsum_5_1); + const __m256i v_hsum_7_6 = _mm256_unpackhi_epi16(v_hsum_6_2, v_hsum_7_3); + partial[2] = + _mm256_unpacklo_epi64(_mm256_unpacklo_epi32(v_hsum_1_0, v_hsum_3_2), + _mm256_unpacklo_epi32(v_hsum_5_4, v_hsum_7_6)); + + const __m256i extend_reverse = SetrM128i( + _mm_set_epi32(static_cast<int>(0x80078006), static_cast<int>(0x80058004), + 
static_cast<int>(0x80038002), static_cast<int>(0x80018000)), + _mm_set_epi32(static_cast<int>(0x80008001), static_cast<int>(0x80028003), + static_cast<int>(0x80048005), + static_cast<int>(0x80068007))); + + for (auto& i : v_src) { + // Zero extend unsigned 8 to 16. The upper lane is reversed. + i = _mm256_shuffle_epi8(i, extend_reverse); + } + + // partial for direction 6 + // -------------------------------------------------------------------------- + // partial[6][j] += x; + // 00 01 02 03 04 05 06 07 xx xx xx xx xx xx xx xx + // 10 11 12 13 14 15 16 17 xx xx xx xx xx xx xx xx + // 20 21 22 23 24 25 26 27 xx xx xx xx xx xx xx xx + // 30 31 32 33 34 35 36 37 xx xx xx xx xx xx xx xx + // 40 41 42 43 44 45 46 47 xx xx xx xx xx xx xx xx + // 50 51 52 53 54 55 56 57 xx xx xx xx xx xx xx xx + // 60 61 62 63 64 65 66 67 xx xx xx xx xx xx xx xx + // 70 71 72 73 74 75 76 77 xx xx xx xx xx xx xx xx + partial[6] = v_src[0]; + for (int i = 1; i < 8; ++i) { + partial[6] = _mm256_add_epi16(partial[6], v_src[i]); + } + + AddPartial_D0_D4(v_src, &partial[0], &partial[4]); + AddPartial_D1_D3(v_src, &partial[1], &partial[3]); + AddPartial_D7_D5(v_src, &partial[7], &partial[5]); +} + +inline __m256i SumVectorPair_S32(__m256i a) { + a = _mm256_hadd_epi32(a, a); + a = _mm256_add_epi32(a, _mm256_srli_si256(a, 4)); + return a; +} + +// |cost[0]| and |cost[4]| square the input and sum with the corresponding +// element from the other end of the vector: +// |kCdefDivisionTable[]| element: +// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) * +// kCdefDivisionTable[i + 1]; +// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8]; +inline void Cost0Or4_Pair(uint32_t* cost, const __m256i partial_0, + const __m256i partial_4, + const __m256i division_table) { + const __m256i division_table_0 = + _mm256_permute2x128_si256(division_table, division_table, 0x0); + const __m256i division_table_1 = + _mm256_permute2x128_si256(division_table, division_table, 0x11); + + // partial_lo + const __m256i a = partial_0; + // partial_hi + const __m256i b = partial_4; + + // Reverse and clear upper 2 bytes. + const __m256i reverser = _mm256_broadcastsi128_si256(_mm_set_epi32( + static_cast<int>(0x80800100), 0x03020504, 0x07060908, 0x0b0a0d0c)); + + // 14 13 12 11 10 09 08 ZZ + const __m256i b_reversed = _mm256_shuffle_epi8(b, reverser); + // 00 14 01 13 02 12 03 11 + const __m256i ab_lo = _mm256_unpacklo_epi16(a, b_reversed); + // 04 10 05 09 06 08 07 ZZ + const __m256i ab_hi = _mm256_unpackhi_epi16(a, b_reversed); + + // Square(partial[0][i]) + Square(partial[0][14 - i]) + const __m256i square_lo = _mm256_madd_epi16(ab_lo, ab_lo); + const __m256i square_hi = _mm256_madd_epi16(ab_hi, ab_hi); + + const __m256i c = _mm256_mullo_epi32(square_lo, division_table_0); + const __m256i d = _mm256_mullo_epi32(square_hi, division_table_1); + const __m256i e = SumVectorPair_S32(_mm256_add_epi32(c, d)); + // Copy upper 32bit sum to lower lane. + const __m128i sums = + _mm256_castsi256_si128(_mm256_permute4x64_epi64(e, 0x08)); + cost[0] = _mm_cvtsi128_si32(sums); + cost[4] = _mm_cvtsi128_si32(_mm_srli_si128(sums, 8)); +} + +template <int index_a, int index_b> +inline void CostOdd_Pair(uint32_t* cost, const __m256i partial_a, + const __m256i partial_b, + const __m256i division_table[2]) { + // partial_lo + const __m256i a = partial_a; + // partial_hi + const __m256i b = partial_b; + + // Reverse and clear upper 10 bytes. 
+ const __m256i reverser = _mm256_broadcastsi128_si256( + _mm_set_epi32(static_cast<int>(0x80808080), static_cast<int>(0x80808080), + static_cast<int>(0x80800100), 0x03020504)); + + // 10 09 08 ZZ ZZ ZZ ZZ ZZ + const __m256i b_reversed = _mm256_shuffle_epi8(b, reverser); + // 00 10 01 09 02 08 03 ZZ + const __m256i ab_lo = _mm256_unpacklo_epi16(a, b_reversed); + // 04 ZZ 05 ZZ 06 ZZ 07 ZZ + const __m256i ab_hi = _mm256_unpackhi_epi16(a, b_reversed); + + // Square(partial[0][i]) + Square(partial[0][14 - i]) + const __m256i square_lo = _mm256_madd_epi16(ab_lo, ab_lo); + const __m256i square_hi = _mm256_madd_epi16(ab_hi, ab_hi); + + const __m256i c = _mm256_mullo_epi32(square_lo, division_table[0]); + const __m256i d = _mm256_mullo_epi32(square_hi, division_table[1]); + const __m256i e = SumVectorPair_S32(_mm256_add_epi32(c, d)); + // Copy upper 32bit sum to lower lane. + const __m128i sums = + _mm256_castsi256_si128(_mm256_permute4x64_epi64(e, 0x08)); + cost[index_a] = _mm_cvtsi128_si32(sums); + cost[index_b] = _mm_cvtsi128_si32(_mm_srli_si128(sums, 8)); +} + +inline void Cost2And6_Pair(uint32_t* cost, const __m256i partial_a, + const __m256i partial_b, + const __m256i division_table) { + // The upper lane is a "don't care", so only use the lower lane for + // calculating cost. + const __m256i a = _mm256_permute2x128_si256(partial_a, partial_b, 0x20); + + const __m256i square_a = _mm256_madd_epi16(a, a); + const __m256i b = _mm256_mullo_epi32(square_a, division_table); + const __m256i c = SumVectorPair_S32(b); + // Copy upper 32bit sum to lower lane. + const __m128i sums = + _mm256_castsi256_si128(_mm256_permute4x64_epi64(c, 0x08)); + cost[2] = _mm_cvtsi128_si32(sums); + cost[6] = _mm_cvtsi128_si32(_mm_srli_si128(sums, 8)); +} + +void CdefDirection_AVX2(const void* const source, ptrdiff_t stride, + uint8_t* const direction, int* const variance) { + assert(direction != nullptr); + assert(variance != nullptr); + const auto* src = static_cast<const uint8_t*>(source); + uint32_t cost[8]; + + // partial[0] = add partial 0,4 low + // partial[1] = add partial 1,3 low + // partial[2] = add partial 2 low + // partial[3] = add partial 1,3 high + // partial[4] = add partial 0,4 high + // partial[5] = add partial 7,5 high + // partial[6] = add partial 6 low + // partial[7] = add partial 7,5 low + __m256i partial[8]; + + AddPartial(src, stride, partial); + + const __m256i division_table = LoadUnaligned32(kCdefDivisionTable); + const __m256i division_table_7 = + _mm256_broadcastd_epi32(_mm_cvtsi32_si128(kCdefDivisionTable[7])); + + Cost2And6_Pair(cost, partial[2], partial[6], division_table_7); + + Cost0Or4_Pair(cost, partial[0], partial[4], division_table); + + const __m256i division_table_odd[2] = { + LoadUnaligned32(kCdefDivisionTableOddPairsPadded), + LoadUnaligned32(kCdefDivisionTableOddPairsPadded + 8)}; + + CostOdd_Pair<1, 3>(cost, partial[1], partial[3], division_table_odd); + CostOdd_Pair<7, 5>(cost, partial[7], partial[5], division_table_odd); + + uint32_t best_cost = 0; + *direction = 0; + for (int i = 0; i < 8; ++i) { + if (cost[i] > best_cost) { + best_cost = cost[i]; + *direction = i; + } + } + *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10; +} + +// ------------------------------------------------------------------------- +// CdefFilter + +// Load 4 vectors based on the given |direction|. +inline void LoadDirection(const uint16_t* const src, const ptrdiff_t stride, + __m128i* output, const int direction) { + // Each |direction| describes a different set of source values. 
Expand this + // set by negating each set. For |direction| == 0 this gives a diagonal line + // from top right to bottom left. The first value is y, the second x. Negative + // y values move up. + // a b c d + // {-1, 1}, {1, -1}, {-2, 2}, {2, -2} + // c + // a + // 0 + // b + // d + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = LoadUnaligned16(src - y_0 * stride - x_0); + output[1] = LoadUnaligned16(src + y_0 * stride + x_0); + output[2] = LoadUnaligned16(src - y_1 * stride - x_1); + output[3] = LoadUnaligned16(src + y_1 * stride + x_1); +} + +// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to +// do 2 rows at a time. +void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride, + __m128i* output, const int direction) { + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = LoadHi8(LoadLo8(src - y_0 * stride - x_0), + src - y_0 * stride + stride - x_0); + output[1] = LoadHi8(LoadLo8(src + y_0 * stride + x_0), + src + y_0 * stride + stride + x_0); + output[2] = LoadHi8(LoadLo8(src - y_1 * stride - x_1), + src - y_1 * stride + stride - x_1); + output[3] = LoadHi8(LoadLo8(src + y_1 * stride + x_1), + src + y_1 * stride + stride + x_1); +} + +inline __m256i Constrain(const __m256i& pixel, const __m256i& reference, + const __m128i& damping, const __m256i& threshold) { + const __m256i diff = _mm256_sub_epi16(pixel, reference); + const __m256i abs_diff = _mm256_abs_epi16(diff); + // sign(diff) * Clip3(threshold - (std::abs(diff) >> damping), + // 0, std::abs(diff)) + const __m256i shifted_diff = _mm256_srl_epi16(abs_diff, damping); + // For bitdepth == 8, the threshold range is [0, 15] and the damping range is + // [3, 6]. If pixel == kCdefLargeValue(0x4000), shifted_diff will always be + // larger than threshold. Subtract using saturation will return 0 when pixel + // == kCdefLargeValue. + static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue"); + const __m256i thresh_minus_shifted_diff = + _mm256_subs_epu16(threshold, shifted_diff); + const __m256i clamp_abs_diff = + _mm256_min_epi16(thresh_minus_shifted_diff, abs_diff); + // Restore the sign. + return _mm256_sign_epi16(clamp_abs_diff, diff); +} + +inline __m256i ApplyConstrainAndTap(const __m256i& pixel, const __m256i& val, + const __m256i& tap, const __m128i& damping, + const __m256i& threshold) { + const __m256i constrained = Constrain(val, pixel, damping, threshold); + return _mm256_mullo_epi16(constrained, tap); +} + +template <int width, bool enable_primary = true, bool enable_secondary = true> +void CdefFilter_AVX2(const uint16_t* src, const ptrdiff_t src_stride, + const int height, const int primary_strength, + const int secondary_strength, const int damping, + const int direction, void* dest, + const ptrdiff_t dst_stride) { + static_assert(width == 8 || width == 4, "Invalid CDEF width."); + static_assert(enable_primary || enable_secondary, ""); + constexpr bool clipping_required = enable_primary && enable_secondary; + auto* dst = static_cast<uint8_t*>(dest); + __m128i primary_damping_shift, secondary_damping_shift; + + // FloorLog2() requires input to be > 0. + // 8-bit damping range: Y: [3, 6], UV: [2, 5]. 
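// Annotation (not part of the patch): with UV damping = 2 and
// primary_strength up to 15, damping - FloorLog2(primary_strength) can reach
// 2 - 3 = -1, hence the std::max(0, ...) clamp on the primary shift below;
// secondary_strength <= 4 keeps FloorLog2() <= 2, so the secondary shift
// stays non-negative.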
+ if (enable_primary) { + // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary + // for UV filtering. + primary_damping_shift = + _mm_cvtsi32_si128(std::max(0, damping - FloorLog2(primary_strength))); + } + if (enable_secondary) { + // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is + // necessary. + assert(damping - FloorLog2(secondary_strength) >= 0); + secondary_damping_shift = + _mm_cvtsi32_si128(damping - FloorLog2(secondary_strength)); + } + const __m256i primary_tap_0 = _mm256_broadcastw_epi16( + _mm_cvtsi32_si128(kCdefPrimaryTaps[primary_strength & 1][0])); + const __m256i primary_tap_1 = _mm256_broadcastw_epi16( + _mm_cvtsi32_si128(kCdefPrimaryTaps[primary_strength & 1][1])); + const __m256i secondary_tap_0 = + _mm256_broadcastw_epi16(_mm_cvtsi32_si128(kCdefSecondaryTap0)); + const __m256i secondary_tap_1 = + _mm256_broadcastw_epi16(_mm_cvtsi32_si128(kCdefSecondaryTap1)); + const __m256i cdef_large_value_mask = _mm256_broadcastw_epi16( + _mm_cvtsi32_si128(static_cast<int16_t>(~kCdefLargeValue))); + const __m256i primary_threshold = + _mm256_broadcastw_epi16(_mm_cvtsi32_si128(primary_strength)); + const __m256i secondary_threshold = + _mm256_broadcastw_epi16(_mm_cvtsi32_si128(secondary_strength)); + + int y = height; + do { + __m128i pixel_128; + if (width == 8) { + pixel_128 = LoadUnaligned16(src); + } else { + pixel_128 = LoadHi8(LoadLo8(src), src + src_stride); + } + + __m256i pixel = SetrM128i(pixel_128, pixel_128); + + __m256i min = pixel; + __m256i max = pixel; + __m256i sum_pair; + + if (enable_primary) { + // Primary |direction|. + __m128i primary_val_128[4]; + if (width == 8) { + LoadDirection(src, src_stride, primary_val_128, direction); + } else { + LoadDirection4(src, src_stride, primary_val_128, direction); + } + + __m256i primary_val[2]; + primary_val[0] = SetrM128i(primary_val_128[0], primary_val_128[1]); + primary_val[1] = SetrM128i(primary_val_128[2], primary_val_128[3]); + + if (clipping_required) { + min = _mm256_min_epu16(min, primary_val[0]); + min = _mm256_min_epu16(min, primary_val[1]); + + // The source is 16 bits, however, we only really care about the lower + // 8 bits. The upper 8 bits contain the "large" flag. After the final + // primary max has been calculated, zero out the upper 8 bits. Use this + // to find the "16 bit" max. + const __m256i max_p01 = _mm256_max_epu8(primary_val[0], primary_val[1]); + max = _mm256_max_epu16( + max, _mm256_and_si256(max_p01, cdef_large_value_mask)); + } + + sum_pair = ApplyConstrainAndTap(pixel, primary_val[0], primary_tap_0, + primary_damping_shift, primary_threshold); + sum_pair = _mm256_add_epi16( + sum_pair, + ApplyConstrainAndTap(pixel, primary_val[1], primary_tap_1, + primary_damping_shift, primary_threshold)); + } else { + sum_pair = _mm256_setzero_si256(); + } + + if (enable_secondary) { + // Secondary |direction| values (+/- 2). Clamp |direction|. 
+ __m128i secondary_val_128[8]; + if (width == 8) { + LoadDirection(src, src_stride, secondary_val_128, direction + 2); + LoadDirection(src, src_stride, secondary_val_128 + 4, direction - 2); + } else { + LoadDirection4(src, src_stride, secondary_val_128, direction + 2); + LoadDirection4(src, src_stride, secondary_val_128 + 4, direction - 2); + } + + __m256i secondary_val[4]; + secondary_val[0] = SetrM128i(secondary_val_128[0], secondary_val_128[1]); + secondary_val[1] = SetrM128i(secondary_val_128[2], secondary_val_128[3]); + secondary_val[2] = SetrM128i(secondary_val_128[4], secondary_val_128[5]); + secondary_val[3] = SetrM128i(secondary_val_128[6], secondary_val_128[7]); + + if (clipping_required) { + min = _mm256_min_epu16(min, secondary_val[0]); + min = _mm256_min_epu16(min, secondary_val[1]); + min = _mm256_min_epu16(min, secondary_val[2]); + min = _mm256_min_epu16(min, secondary_val[3]); + + const __m256i max_s01 = + _mm256_max_epu8(secondary_val[0], secondary_val[1]); + const __m256i max_s23 = + _mm256_max_epu8(secondary_val[2], secondary_val[3]); + const __m256i max_s = _mm256_max_epu8(max_s01, max_s23); + max = _mm256_max_epu8(max, + _mm256_and_si256(max_s, cdef_large_value_mask)); + } + + sum_pair = _mm256_add_epi16( + sum_pair, + ApplyConstrainAndTap(pixel, secondary_val[0], secondary_tap_0, + secondary_damping_shift, secondary_threshold)); + sum_pair = _mm256_add_epi16( + sum_pair, + ApplyConstrainAndTap(pixel, secondary_val[1], secondary_tap_1, + secondary_damping_shift, secondary_threshold)); + sum_pair = _mm256_add_epi16( + sum_pair, + ApplyConstrainAndTap(pixel, secondary_val[2], secondary_tap_0, + secondary_damping_shift, secondary_threshold)); + sum_pair = _mm256_add_epi16( + sum_pair, + ApplyConstrainAndTap(pixel, secondary_val[3], secondary_tap_1, + secondary_damping_shift, secondary_threshold)); + } + + __m128i sum = _mm_add_epi16(_mm256_castsi256_si128(sum_pair), + _mm256_extracti128_si256(sum_pair, 1)); + + // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max)) + const __m128i sum_lt_0 = _mm_srai_epi16(sum, 15); + // 8 + sum + sum = _mm_add_epi16(sum, _mm_set1_epi16(8)); + // (... - (sum < 0)) >> 4 + sum = _mm_add_epi16(sum, sum_lt_0); + sum = _mm_srai_epi16(sum, 4); + // pixel + ... 
+ sum = _mm_add_epi16(sum, _mm256_castsi256_si128(pixel)); + if (clipping_required) { + const __m128i min_128 = _mm_min_epu16(_mm256_castsi256_si128(min), + _mm256_extracti128_si256(min, 1)); + + const __m128i max_128 = _mm_max_epu16(_mm256_castsi256_si128(max), + _mm256_extracti128_si256(max, 1)); + // Clip3 + sum = _mm_min_epi16(sum, max_128); + sum = _mm_max_epi16(sum, min_128); + } + + const __m128i result = _mm_packus_epi16(sum, sum); + if (width == 8) { + src += src_stride; + StoreLo8(dst, result); + dst += dst_stride; + --y; + } else { + src += src_stride << 1; + Store4(dst, result); + dst += dst_stride; + Store4(dst, _mm_srli_si128(result, 4)); + dst += dst_stride; + y -= 2; + } + } while (y != 0); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); + dsp->cdef_direction = CdefDirection_AVX2; + + dsp->cdef_filters[0][0] = CdefFilter_AVX2<4>; + dsp->cdef_filters[0][1] = + CdefFilter_AVX2<4, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = CdefFilter_AVX2<4, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_AVX2<8>; + dsp->cdef_filters[1][1] = + CdefFilter_AVX2<8, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = CdefFilter_AVX2<8, /*enable_primary=*/false>; +} + +} // namespace +} // namespace low_bitdepth + +void CdefInit_AVX2() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_TARGETING_AVX2 +namespace libgav1 { +namespace dsp { + +void CdefInit_AVX2() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_AVX2 diff --git a/src/dsp/x86/cdef_avx2.h b/src/dsp/x86/cdef_avx2.h new file mode 100644 index 0000000..41f2d3f --- /dev/null +++ b/src/dsp/x86/cdef_avx2.h @@ -0,0 +1,45 @@ +/* + * Copyright 2021 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_CDEF_AVX2_H_ +#define LIBGAV1_SRC_DSP_X86_CDEF_AVX2_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not +// thread-safe. +void CdefInit_AVX2(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_AVX2 + +#ifndef LIBGAV1_Dsp8bpp_CdefDirection +#define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_AVX2 +#endif + +#ifndef LIBGAV1_Dsp8bpp_CdefFilters +#define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_AVX2 +#endif + +#endif // LIBGAV1_TARGETING_AVX2 + +#endif // LIBGAV1_SRC_DSP_X86_CDEF_AVX2_H_ diff --git a/src/dsp/x86/cdef_sse4.cc b/src/dsp/x86/cdef_sse4.cc index 3211a2d..6ede778 100644 --- a/src/dsp/x86/cdef_sse4.cc +++ b/src/dsp/x86/cdef_sse4.cc @@ -349,8 +349,8 @@ inline uint32_t SumVector_S32(__m128i a) { inline uint32_t Cost0Or4(const __m128i a, const __m128i b, const __m128i division_table[2]) { // Reverse and clear upper 2 bytes. 
- const __m128i reverser = - _mm_set_epi32(0x80800100, 0x03020504, 0x07060908, 0x0b0a0d0c); + const __m128i reverser = _mm_set_epi32(static_cast<int>(0x80800100), + 0x03020504, 0x07060908, 0x0b0a0d0c); // 14 13 12 11 10 09 08 ZZ const __m128i b_reversed = _mm_shuffle_epi8(b, reverser); // 00 14 01 13 02 12 03 11 @@ -371,7 +371,8 @@ inline uint32_t CostOdd(const __m128i a, const __m128i b, const __m128i division_table[2]) { // Reverse and clear upper 10 bytes. const __m128i reverser = - _mm_set_epi32(0x80808080, 0x80808080, 0x80800100, 0x03020504); + _mm_set_epi32(static_cast<int>(0x80808080), static_cast<int>(0x80808080), + static_cast<int>(0x80800100), 0x03020504); // 10 09 08 ZZ ZZ ZZ ZZ ZZ const __m128i b_reversed = _mm_shuffle_epi8(b, reverser); // 00 10 01 09 02 08 03 ZZ @@ -717,7 +718,7 @@ void CdefInit_SSE4_1() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/common_avx2.h b/src/dsp/x86/common_avx2.h index 4ce7de2..373116a 100644 --- a/src/dsp/x86/common_avx2.h +++ b/src/dsp/x86/common_avx2.h @@ -27,109 +27,60 @@ #include <cassert> #include <cstddef> #include <cstdint> +#include <cstring> namespace libgav1 { namespace dsp { - -//------------------------------------------------------------------------------ -// Compatibility functions. - -inline __m256i SetrM128i(const __m128i lo, const __m128i hi) { - // For compatibility with older gcc toolchains (< 8) use - // _mm256_inserti128_si256 over _mm256_setr_m128i. Newer gcc implementations - // are implemented similarly to the following, clang uses a different method - // but no differences in assembly have been observed. - return _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1); -} - -//------------------------------------------------------------------------------ -// Load functions. - -inline __m256i LoadAligned32(const void* a) { - assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); - return _mm256_load_si256(static_cast<const __m256i*>(a)); -} - -inline void LoadAligned64(const void* a, __m256i dst[2]) { - assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); - dst[0] = _mm256_load_si256(static_cast<const __m256i*>(a) + 0); - dst[1] = _mm256_load_si256(static_cast<const __m256i*>(a) + 1); -} - -inline __m256i LoadUnaligned32(const void* a) { - return _mm256_loadu_si256(static_cast<const __m256i*>(a)); -} - -//------------------------------------------------------------------------------ -// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning. - -inline __m256i MaskOverreads(const __m256i source, - const ptrdiff_t over_read_in_bytes) { - __m256i dst = source; -#if LIBGAV1_MSAN - if (over_read_in_bytes >= 32) return _mm256_setzero_si256(); - if (over_read_in_bytes > 0) { - __m128i m = _mm_set1_epi8(-1); - for (ptrdiff_t i = 0; i < over_read_in_bytes % 16; ++i) { - m = _mm_srli_si128(m, 1); - } - const __m256i mask = (over_read_in_bytes < 16) - ? 
SetrM128i(_mm_set1_epi8(-1), m) - : SetrM128i(m, _mm_setzero_si128()); - dst = _mm256_and_si256(dst, mask); - } -#else - static_cast<void>(over_read_in_bytes); -#endif - return dst; -} - -inline __m256i LoadAligned32Msan(const void* const source, - const ptrdiff_t over_read_in_bytes) { - return MaskOverreads(LoadAligned32(source), over_read_in_bytes); -} - -inline void LoadAligned64Msan(const void* const source, - const ptrdiff_t over_read_in_bytes, - __m256i dst[2]) { - dst[0] = MaskOverreads(LoadAligned32(source), over_read_in_bytes); - dst[1] = MaskOverreads(LoadAligned32(static_cast<const __m256i*>(source) + 1), - over_read_in_bytes); -} - -inline __m256i LoadUnaligned32Msan(const void* const source, - const ptrdiff_t over_read_in_bytes) { - return MaskOverreads(LoadUnaligned32(source), over_read_in_bytes); -} - -//------------------------------------------------------------------------------ -// Store functions. - -inline void StoreAligned32(void* a, const __m256i v) { - assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); - _mm256_store_si256(static_cast<__m256i*>(a), v); -} - -inline void StoreAligned64(void* a, const __m256i v[2]) { - assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); - _mm256_store_si256(static_cast<__m256i*>(a) + 0, v[0]); - _mm256_store_si256(static_cast<__m256i*>(a) + 1, v[1]); -} - -inline void StoreUnaligned32(void* a, const __m256i v) { - _mm256_storeu_si256(static_cast<__m256i*>(a), v); -} - -//------------------------------------------------------------------------------ -// Arithmetic utilities. - -inline __m256i RightShiftWithRounding_S16(const __m256i v_val_d, int bits) { - assert(bits <= 16); - const __m256i v_bias_d = - _mm256_set1_epi16(static_cast<int16_t>((1 << bits) >> 1)); - const __m256i v_tmp_d = _mm256_add_epi16(v_val_d, v_bias_d); - return _mm256_srai_epi16(v_tmp_d, bits); -} +namespace avx2 { + +#include "src/dsp/x86/common_avx2.inc" +#include "src/dsp/x86/common_sse4.inc" + +} // namespace avx2 + +// NOLINTBEGIN(misc-unused-using-decls) +// These function aliases shall not be visible to external code. They are +// restricted to x86/*_avx2.cc files only. This scheme exists to distinguish two +// possible implementations of common functions, which may differ based on +// whether the compiler is permitted to use avx2 instructions. 
+ +// common_sse4.inc +using avx2::Load2; +using avx2::Load2x2; +using avx2::Load4; +using avx2::Load4x2; +using avx2::LoadAligned16; +using avx2::LoadAligned16Msan; +using avx2::LoadHi8; +using avx2::LoadHi8Msan; +using avx2::LoadLo8; +using avx2::LoadLo8Msan; +using avx2::LoadUnaligned16; +using avx2::LoadUnaligned16Msan; +using avx2::MaskHighNBytes; +using avx2::RightShiftWithRounding_S16; +using avx2::RightShiftWithRounding_S32; +using avx2::RightShiftWithRounding_U16; +using avx2::RightShiftWithRounding_U32; +using avx2::Store2; +using avx2::Store4; +using avx2::StoreAligned16; +using avx2::StoreHi8; +using avx2::StoreLo8; +using avx2::StoreUnaligned16; + +// common_avx2.inc +using avx2::LoadAligned32; +using avx2::LoadAligned32Msan; +using avx2::LoadAligned64; +using avx2::LoadAligned64Msan; +using avx2::LoadUnaligned32; +using avx2::LoadUnaligned32Msan; +using avx2::SetrM128i; +using avx2::StoreAligned32; +using avx2::StoreAligned64; +using avx2::StoreUnaligned32; +// NOLINTEND } // namespace dsp } // namespace libgav1 diff --git a/src/dsp/x86/common_avx2.inc b/src/dsp/x86/common_avx2.inc new file mode 100644 index 0000000..53b4e2e --- /dev/null +++ b/src/dsp/x86/common_avx2.inc @@ -0,0 +1,121 @@ +/* + * Copyright 2021 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//------------------------------------------------------------------------------ +// Compatibility functions. + +inline __m256i SetrM128i(const __m128i lo, const __m128i hi) { + // For compatibility with older gcc toolchains (< 8) use + // _mm256_inserti128_si256 over _mm256_setr_m128i. Newer gcc implementations + // are implemented similarly to the following, clang uses a different method + // but no differences in assembly have been observed. + return _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1); +} + +//------------------------------------------------------------------------------ +// Load functions. + +inline __m256i LoadAligned32(const void* a) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + return _mm256_load_si256(static_cast<const __m256i*>(a)); +} + +inline void LoadAligned64(const void* a, __m256i dst[2]) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + dst[0] = _mm256_load_si256(static_cast<const __m256i*>(a) + 0); + dst[1] = _mm256_load_si256(static_cast<const __m256i*>(a) + 1); +} + +inline __m256i LoadUnaligned32(const void* a) { + return _mm256_loadu_si256(static_cast<const __m256i*>(a)); +} + +//------------------------------------------------------------------------------ +// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning. 
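// Annotation (not part of the patch): under MSAN, MaskOverreads() below
// clears the top |over_read_in_bytes| bytes of the 32-byte vector. For
// example, over_read_in_bytes = 20 builds m = all-ones >> 4 bytes and selects
// SetrM128i(m, zero), so bytes 12..31 are zeroed and only the 12 validly
// read bytes survive.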
+ +inline __m256i MaskOverreads(const __m256i source, + const ptrdiff_t over_read_in_bytes) { + __m256i dst = source; +#if LIBGAV1_MSAN + if (over_read_in_bytes >= 32) return _mm256_setzero_si256(); + if (over_read_in_bytes > 0) { + __m128i m = _mm_set1_epi8(-1); + for (ptrdiff_t i = 0; i < over_read_in_bytes % 16; ++i) { + m = _mm_srli_si128(m, 1); + } + const __m256i mask = (over_read_in_bytes < 16) + ? SetrM128i(_mm_set1_epi8(-1), m) + : SetrM128i(m, _mm_setzero_si128()); + dst = _mm256_and_si256(dst, mask); + } +#else + static_cast<void>(over_read_in_bytes); +#endif + return dst; +} + +inline __m256i LoadAligned32Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadAligned32(source), over_read_in_bytes); +} + +inline void LoadAligned64Msan(const void* const source, + const ptrdiff_t over_read_in_bytes, + __m256i dst[2]) { + dst[0] = MaskOverreads(LoadAligned32(source), over_read_in_bytes); + dst[1] = MaskOverreads(LoadAligned32(static_cast<const __m256i*>(source) + 1), + over_read_in_bytes); +} + +inline __m256i LoadUnaligned32Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadUnaligned32(source), over_read_in_bytes); +} + +//------------------------------------------------------------------------------ +// Store functions. + +inline void StoreAligned32(void* a, const __m256i v) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + _mm256_store_si256(static_cast<__m256i*>(a), v); +} + +inline void StoreAligned64(void* a, const __m256i v[2]) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + _mm256_store_si256(static_cast<__m256i*>(a) + 0, v[0]); + _mm256_store_si256(static_cast<__m256i*>(a) + 1, v[1]); +} + +inline void StoreUnaligned32(void* a, const __m256i v) { + _mm256_storeu_si256(static_cast<__m256i*>(a), v); +} + +//------------------------------------------------------------------------------ +// Arithmetic utilities. + +inline __m256i RightShiftWithRounding_S16(const __m256i v_val_d, int bits) { + assert(bits <= 16); + const __m256i v_bias_d = + _mm256_set1_epi16(static_cast<int16_t>((1 << bits) >> 1)); + const __m256i v_tmp_d = _mm256_add_epi16(v_val_d, v_bias_d); + return _mm256_srai_epi16(v_tmp_d, bits); +} + +inline __m256i RightShiftWithRounding_S32(const __m256i v_val_d, int bits) { + const __m256i v_bias_d = _mm256_set1_epi32((1 << bits) >> 1); + const __m256i v_tmp_d = _mm256_add_epi32(v_val_d, v_bias_d); + return _mm256_srai_epi32(v_tmp_d, bits); +} diff --git a/src/dsp/x86/common_sse4.h b/src/dsp/x86/common_sse4.h index c510f8c..41a3a68 100644 --- a/src/dsp/x86/common_sse4.h +++ b/src/dsp/x86/common_sse4.h @@ -28,7 +28,6 @@ #include <cassert> #include <cstddef> #include <cstdint> -#include <cstdlib> #include <cstring> #if 0 @@ -71,192 +70,58 @@ inline void PrintRegX(const int r, const char* const name) { #define PR(var, N) PrintReg(var, #var, N) #define PD(var) PrintReg(var, #var); #define PX(var) PrintRegX(var, #var); -#endif // 0 - -namespace libgav1 { -namespace dsp { - -//------------------------------------------------------------------------------ -// Load functions. 
- -inline __m128i Load2(const void* src) { - int16_t val; - memcpy(&val, src, sizeof(val)); - return _mm_cvtsi32_si128(val); -} - -inline __m128i Load2x2(const void* src1, const void* src2) { - uint16_t val1; - uint16_t val2; - memcpy(&val1, src1, sizeof(val1)); - memcpy(&val2, src2, sizeof(val2)); - return _mm_cvtsi32_si128(val1 | (val2 << 16)); -} - -// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1. -template <int lane> -inline __m128i Load2(const void* const buf, __m128i val) { - uint16_t temp; - memcpy(&temp, buf, 2); - return _mm_insert_epi16(val, temp, lane); -} - -inline __m128i Load4(const void* src) { - // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32 - // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a - // movss instruction. - // - // Until compiler support of _mm_loadu_si32 is widespread, use of - // _mm_loadu_si32 is banned. - int val; - memcpy(&val, src, sizeof(val)); - return _mm_cvtsi32_si128(val); -} - -inline __m128i Load4x2(const void* src1, const void* src2) { - // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32 - // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a - // movss instruction. - // - // Until compiler support of _mm_loadu_si32 is widespread, use of - // _mm_loadu_si32 is banned. - int val1, val2; - memcpy(&val1, src1, sizeof(val1)); - memcpy(&val2, src2, sizeof(val2)); - return _mm_insert_epi32(_mm_cvtsi32_si128(val1), val2, 1); -} -inline __m128i LoadLo8(const void* a) { - return _mm_loadl_epi64(static_cast<const __m128i*>(a)); -} - -inline __m128i LoadHi8(const __m128i v, const void* a) { - const __m128 x = - _mm_loadh_pi(_mm_castsi128_ps(v), static_cast<const __m64*>(a)); - return _mm_castps_si128(x); -} - -inline __m128i LoadUnaligned16(const void* a) { - return _mm_loadu_si128(static_cast<const __m128i*>(a)); -} - -inline __m128i LoadAligned16(const void* a) { - assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0); - return _mm_load_si128(static_cast<const __m128i*>(a)); -} - -//------------------------------------------------------------------------------ -// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning. 
- -inline __m128i MaskOverreads(const __m128i source, - const ptrdiff_t over_read_in_bytes) { - __m128i dst = source; #if LIBGAV1_MSAN - if (over_read_in_bytes > 0) { - __m128i mask = _mm_set1_epi8(-1); - for (ptrdiff_t i = 0; i < over_read_in_bytes; ++i) { - mask = _mm_srli_si128(mask, 1); - } - dst = _mm_and_si128(dst, mask); - } -#else - static_cast<void>(over_read_in_bytes); -#endif - return dst; -} +#include <sanitizer/msan_interface.h> -inline __m128i LoadLo8Msan(const void* const source, - const ptrdiff_t over_read_in_bytes) { - return MaskOverreads(LoadLo8(source), over_read_in_bytes + 8); +inline void PrintShadow(const void* r, const char* const name, + const size_t size) { + fprintf(stderr, "Shadow for %s:\n", name); + __msan_print_shadow(r, size); } +#define PS(var, N) PrintShadow(var, #var, N) -inline __m128i LoadHi8Msan(const __m128i v, const void* source, - const ptrdiff_t over_read_in_bytes) { - return MaskOverreads(LoadHi8(v, source), over_read_in_bytes); -} - -inline __m128i LoadAligned16Msan(const void* const source, - const ptrdiff_t over_read_in_bytes) { - return MaskOverreads(LoadAligned16(source), over_read_in_bytes); -} +#endif // LIBGAV1_MSAN -inline __m128i LoadUnaligned16Msan(const void* const source, - const ptrdiff_t over_read_in_bytes) { - return MaskOverreads(LoadUnaligned16(source), over_read_in_bytes); -} - -//------------------------------------------------------------------------------ -// Store functions. - -inline void Store2(void* dst, const __m128i x) { - const int val = _mm_cvtsi128_si32(x); - memcpy(dst, &val, 2); -} - -inline void Store4(void* dst, const __m128i x) { - const int val = _mm_cvtsi128_si32(x); - memcpy(dst, &val, sizeof(val)); -} - -inline void StoreLo8(void* a, const __m128i v) { - _mm_storel_epi64(static_cast<__m128i*>(a), v); -} - -inline void StoreHi8(void* a, const __m128i v) { - _mm_storeh_pi(static_cast<__m64*>(a), _mm_castsi128_ps(v)); -} - -inline void StoreAligned16(void* a, const __m128i v) { - assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0); - _mm_store_si128(static_cast<__m128i*>(a), v); -} - -inline void StoreUnaligned16(void* a, const __m128i v) { - _mm_storeu_si128(static_cast<__m128i*>(a), v); -} - -//------------------------------------------------------------------------------ -// Arithmetic utilities. - -inline __m128i RightShiftWithRounding_U16(const __m128i v_val_d, int bits) { - assert(bits <= 16); - // Shift out all but the last bit. - const __m128i v_tmp_d = _mm_srli_epi16(v_val_d, bits - 1); - // Avg with zero will shift by 1 and round. 
- return _mm_avg_epu16(v_tmp_d, _mm_setzero_si128()); -} - -inline __m128i RightShiftWithRounding_S16(const __m128i v_val_d, int bits) { - assert(bits <= 16); - const __m128i v_bias_d = - _mm_set1_epi16(static_cast<int16_t>((1 << bits) >> 1)); - const __m128i v_tmp_d = _mm_add_epi16(v_val_d, v_bias_d); - return _mm_srai_epi16(v_tmp_d, bits); -} - -inline __m128i RightShiftWithRounding_U32(const __m128i v_val_d, int bits) { - const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1); - const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d); - return _mm_srli_epi32(v_tmp_d, bits); -} - -inline __m128i RightShiftWithRounding_S32(const __m128i v_val_d, int bits) { - const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1); - const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d); - return _mm_srai_epi32(v_tmp_d, bits); -} - -//------------------------------------------------------------------------------ -// Masking utilities -inline __m128i MaskHighNBytes(int n) { - static constexpr uint8_t kMask[32] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - }; +#endif // 0 - return LoadUnaligned16(kMask + n); -} +namespace libgav1 { +namespace dsp { +namespace sse4 { + +#include "src/dsp/x86/common_sse4.inc" + +} // namespace sse4 + +// NOLINTBEGIN(misc-unused-using-decls) +// These function aliases shall not be visible to external code. They are +// restricted to x86/*_sse4.cc files only. This scheme exists to distinguish two +// possible implementations of common functions, which may differ based on +// whether the compiler is permitted to use avx2 instructions. +using sse4::Load2; +using sse4::Load2x2; +using sse4::Load4; +using sse4::Load4x2; +using sse4::LoadAligned16; +using sse4::LoadAligned16Msan; +using sse4::LoadHi8; +using sse4::LoadHi8Msan; +using sse4::LoadLo8; +using sse4::LoadLo8Msan; +using sse4::LoadUnaligned16; +using sse4::LoadUnaligned16Msan; +using sse4::MaskHighNBytes; +using sse4::RightShiftWithRounding_S16; +using sse4::RightShiftWithRounding_S32; +using sse4::RightShiftWithRounding_U16; +using sse4::RightShiftWithRounding_U32; +using sse4::Store2; +using sse4::Store4; +using sse4::StoreAligned16; +using sse4::StoreHi8; +using sse4::StoreLo8; +using sse4::StoreUnaligned16; +// NOLINTEND } // namespace dsp } // namespace libgav1 diff --git a/src/dsp/x86/common_sse4.inc b/src/dsp/x86/common_sse4.inc new file mode 100644 index 0000000..35c56b8 --- /dev/null +++ b/src/dsp/x86/common_sse4.inc @@ -0,0 +1,206 @@ +/* + * Copyright 2021 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//------------------------------------------------------------------------------ +// Load functions. 
+ +inline __m128i Load2(const void* src) { + int16_t val; + memcpy(&val, src, sizeof(val)); + return _mm_cvtsi32_si128(val); +} + +inline __m128i Load2x2(const void* src1, const void* src2) { + uint16_t val1; + uint16_t val2; + memcpy(&val1, src1, sizeof(val1)); + memcpy(&val2, src2, sizeof(val2)); + return _mm_cvtsi32_si128(val1 | (val2 << 16)); +} + +// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1. +template <int lane> +inline __m128i Load2(const void* const buf, __m128i val) { + int16_t temp; + memcpy(&temp, buf, 2); + return _mm_insert_epi16(val, temp, lane); +} + +inline __m128i Load4(const void* src) { + // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32 + // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a + // movss instruction. + // + // Until compiler support of _mm_loadu_si32 is widespread, use of + // _mm_loadu_si32 is banned. + int val; + memcpy(&val, src, sizeof(val)); + return _mm_cvtsi32_si128(val); +} + +inline __m128i Load4x2(const void* src1, const void* src2) { + // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32 + // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a + // movss instruction. + // + // Until compiler support of _mm_loadu_si32 is widespread, use of + // _mm_loadu_si32 is banned. + int val1, val2; + memcpy(&val1, src1, sizeof(val1)); + memcpy(&val2, src2, sizeof(val2)); + return _mm_insert_epi32(_mm_cvtsi32_si128(val1), val2, 1); +} + +inline __m128i LoadLo8(const void* a) { + return _mm_loadl_epi64(static_cast<const __m128i*>(a)); +} + +inline __m128i LoadHi8(const __m128i v, const void* a) { + const __m128 x = + _mm_loadh_pi(_mm_castsi128_ps(v), static_cast<const __m64*>(a)); + return _mm_castps_si128(x); +} + +inline __m128i LoadUnaligned16(const void* a) { + return _mm_loadu_si128(static_cast<const __m128i*>(a)); +} + +inline __m128i LoadAligned16(const void* a) { + assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0); + return _mm_load_si128(static_cast<const __m128i*>(a)); +} + +//------------------------------------------------------------------------------ +// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning. + +inline __m128i MaskOverreads(const __m128i source, + const ptrdiff_t over_read_in_bytes) { + __m128i dst = source; +#if LIBGAV1_MSAN + if (over_read_in_bytes > 0) { + __m128i mask = _mm_set1_epi8(-1); + for (ptrdiff_t i = 0; i < over_read_in_bytes; ++i) { + mask = _mm_srli_si128(mask, 1); + } + dst = _mm_and_si128(dst, mask); + } +#else + static_cast<void>(over_read_in_bytes); +#endif + return dst; +} + +inline __m128i LoadLo8Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadLo8(source), over_read_in_bytes + 8); +} + +inline __m128i LoadHi8Msan(const __m128i v, const void* source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadHi8(v, source), over_read_in_bytes); +} + +inline __m128i LoadAligned16Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadAligned16(source), over_read_in_bytes); +} + +inline __m128i LoadUnaligned16Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadUnaligned16(source), over_read_in_bytes); +} + +//------------------------------------------------------------------------------ +// Store functions. 
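Before the store helpers, a scalar model of what the MaskOverreads()-based loads above do under MSAN may be useful: when a 16-byte load deliberately reads past the end of the valid data, the trailing over-read bytes are forced to zero so MemorySanitizer does not report later arithmetic on them as a use of uninitialized memory. This sketch is illustrative only; MaskOverreadsScalar is a hypothetical name.

#include <cassert>
#include <cstdint>

// Zero the trailing |over_read_in_bytes| bytes of a 16-byte load; those are
// the bytes that came from past the end of the valid buffer.
void MaskOverreadsScalar(uint8_t bytes[16], int over_read_in_bytes) {
  for (int i = 16 - over_read_in_bytes; i < 16; ++i) bytes[i] = 0;
}

int main() {
  uint8_t v[16];
  for (int i = 0; i < 16; ++i) v[i] = 0xff;
  MaskOverreadsScalar(v, 4);
  for (int i = 0; i < 12; ++i) assert(v[i] == 0xff);  // valid bytes kept
  for (int i = 12; i < 16; ++i) assert(v[i] == 0);    // over-read bytes cleared
  return 0;
}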
+ +inline void Store2(void* dst, const __m128i x) { + const int val = _mm_cvtsi128_si32(x); + memcpy(dst, &val, 2); +} + +inline void Store4(void* dst, const __m128i x) { + const int val = _mm_cvtsi128_si32(x); + memcpy(dst, &val, sizeof(val)); +} + +inline void StoreLo8(void* a, const __m128i v) { + _mm_storel_epi64(static_cast<__m128i*>(a), v); +} + +inline void StoreHi8(void* a, const __m128i v) { + _mm_storeh_pi(static_cast<__m64*>(a), _mm_castsi128_ps(v)); +} + +inline void StoreAligned16(void* a, const __m128i v) { + assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0); + _mm_store_si128(static_cast<__m128i*>(a), v); +} + +inline void StoreUnaligned16(void* a, const __m128i v) { + _mm_storeu_si128(static_cast<__m128i*>(a), v); +} + +//------------------------------------------------------------------------------ +// Arithmetic utilities. + +inline __m128i RightShiftWithRounding_U16(const __m128i v_val_d, int bits) { + assert(bits <= 16); + // Shift out all but the last bit. + const __m128i v_tmp_d = _mm_srli_epi16(v_val_d, bits - 1); + // Avg with zero will shift by 1 and round. + return _mm_avg_epu16(v_tmp_d, _mm_setzero_si128()); +} + +inline __m128i RightShiftWithRounding_S16(const __m128i v_val_d, int bits) { + assert(bits < 16); + const __m128i v_bias_d = + _mm_set1_epi16(static_cast<int16_t>((1 << bits) >> 1)); + const __m128i v_tmp_d = _mm_add_epi16(v_val_d, v_bias_d); + return _mm_srai_epi16(v_tmp_d, bits); +} + +inline __m128i RightShiftWithRounding_U32(const __m128i v_val_d, int bits) { + const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1); + const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d); + return _mm_srli_epi32(v_tmp_d, bits); +} + +inline __m128i RightShiftWithRounding_S32(const __m128i v_val_d, int bits) { + const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1); + const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d); + return _mm_srai_epi32(v_tmp_d, bits); +} + +// Use this when |bits| is not an immediate value. +inline __m128i VariableRightShiftWithRounding_S32(const __m128i v_val_d, + int bits) { + const __m128i v_bias_d = + _mm_set1_epi32(static_cast<int32_t>((1 << bits) >> 1)); + const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d); + return _mm_sra_epi32(v_tmp_d, _mm_cvtsi32_si128(bits)); +} + +//------------------------------------------------------------------------------ +// Masking utilities +inline __m128i MaskHighNBytes(int n) { + static constexpr uint8_t kMask[32] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + }; + + return LoadUnaligned16(kMask + n); +} diff --git a/src/dsp/x86/convolve_avx2.cc b/src/dsp/x86/convolve_avx2.cc index 3df2120..2ecb77c 100644 --- a/src/dsp/x86/convolve_avx2.cc +++ b/src/dsp/x86/convolve_avx2.cc @@ -26,7 +26,6 @@ #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_avx2.h" -#include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" #include "src/utils/constants.h" @@ -35,7 +34,7 @@ namespace dsp { namespace low_bitdepth { namespace { -constexpr int kHorizontalOffset = 3; +#include "src/dsp/x86/convolve_sse4.inc" // Multiply every entry in |src[]| by the corresponding entry in |taps[]| and // sum. The filters in |taps[]| are pre-shifted by 1. 
This prevents the final @@ -118,58 +117,15 @@ __m256i SimpleHorizontalTaps(const __m256i* const src, } template <int filter_index> -__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - // 00 01 02 03 04 05 06 07 10 11 12 13 14 15 16 17 - const __m128i v_src = LoadHi8(LoadLo8(&src[0]), &src[src_stride]); - - if (filter_index == 3) { - // 03 04 04 05 05 06 06 07 13 14 14 15 15 16 16 17 - const __m128i v_src_43 = _mm_shuffle_epi8( - v_src, _mm_set_epi32(0x0f0e0e0d, 0x0d0c0c0b, 0x07060605, 0x05040403)); - const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 - return v_sum_43; - } - - // 02 03 03 04 04 05 05 06 12 13 13 14 14 15 15 16 - const __m128i v_src_32 = _mm_shuffle_epi8( - v_src, _mm_set_epi32(0x0e0d0d0c, 0x0c0b0b0a, 0x06050504, 0x04030302)); - // 04 05 05 06 06 07 07 xx 14 15 15 16 16 17 17 xx - const __m128i v_src_54 = _mm_shuffle_epi8( - v_src, _mm_set_epi32(0x800f0f0e, 0x0e0d0d0c, 0x80070706, 0x06050504)); - const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 - const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 - const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32); - return v_sum_5432; -} - -template <int filter_index> -__m128i SimpleHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); - - // Normally the Horizontal pass does the downshift in two passes: - // kInterRoundBitsHorizontal - 1 and then (kFilterBits - - // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them - // requires adding the rounding offset from the skipped shift. - constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); - - sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit)); - sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); - return _mm_packus_epi16(sum, sum); -} - -template <int filter_index> -__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - const __m128i sum = - SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); +__m256i HorizontalTaps8To16(const __m256i* const src, + const __m256i* const v_tap) { + const __m256i sum = SumHorizontalTaps<filter_index>(src, v_tap); return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); } // Filter 2xh sizes. -template <int num_taps, int step, int filter_index, bool is_2d = false, +template <int num_taps, int filter_index, bool is_2d = false, bool is_compound = false> void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, void* const dest, const ptrdiff_t pred_stride, @@ -183,7 +139,8 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, assert(num_taps <= 4); if (num_taps <= 4) { if (!is_compound) { - int y = 0; + int y = height; + if (is_2d) y -= 1; do { if (is_2d) { const __m128i sum = @@ -202,8 +159,8 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } src += src_stride << 1; - y += 2; - } while (y < height - 1); + y -= 2; + } while (y != 0); // The 2d filters have an odd |height| because the horizontal pass // generates context for the vertical pass. @@ -236,7 +193,7 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } // Filter widths >= 4. 
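The removed SimpleHorizontalTaps2x2 above (like its replacements elsewhere in this diff) folds the two horizontal rounding shifts into a single shift by adding back the rounding bias of the skipped first shift. A brute-force check of that identity, using the 8bpp constants a = kInterRoundBitsHorizontal - 1 = 2 and b = kFilterBits - kInterRoundBitsHorizontal = 4 (the -1 comes from the pre-shifted half filters); this is an illustrative sketch, not library code.

#include <cassert>

int main() {
  constexpr int a = 2;  // first rounding shift
  constexpr int b = 4;  // second rounding shift
  for (int x = 0; x < (1 << 15); ++x) {
    // Two rounding shifts applied in sequence.
    const int two_pass = (((x + (1 << (a - 1))) >> a) + (1 << (b - 1))) >> b;
    // One combined shift, with the first shift's rounding bias added back in.
    const int combined = (x + (1 << (a - 1)) + (1 << (a + b - 1))) >> (a + b);
    assert(two_pass == combined);
  }
  // The same identity holds for negative sums under arithmetic shifting.
  return 0;
}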
-template <int num_taps, int step, int filter_index, bool is_2d = false, +template <int num_taps, int filter_index, bool is_2d = false, bool is_compound = false> void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, void* const dest, const ptrdiff_t pred_stride, @@ -251,7 +208,22 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, int x = 0; do { if (is_2d || is_compound) { - // placeholder + // Load into 2 128 bit lanes. + const __m256i src_long = + SetrM128i(LoadUnaligned16(&src[x]), LoadUnaligned16(&src[x + 8])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i src_long2 = SetrM128i(LoadUnaligned16(&src[x + 16]), + LoadUnaligned16(&src[x + 24])); + const __m256i result2 = + HorizontalTaps8To16<filter_index>(&src_long2, v_tap); + if (is_2d) { + StoreAligned32(&dest16[x], result); + StoreAligned32(&dest16[x + 16], result2); + } else { + StoreUnaligned32(&dest16[x], result); + StoreUnaligned32(&dest16[x + 16], result2); + } } else { // Load src used to calculate dest8[7:0] and dest8[23:16]. const __m256i src_long = LoadUnaligned32(&src[x]); @@ -264,7 +236,7 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, // Combine results and store. StoreUnaligned32(&dest8[x], _mm256_unpacklo_epi64(result, result2)); } - x += step * 4; + x += 32; } while (x < width); src += src_stride; dest8 += pred_stride; @@ -272,9 +244,26 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } while (--y != 0); } else if (width == 16) { int y = height; + if (is_2d) y -= 1; do { if (is_2d || is_compound) { - // placeholder + // Load into 2 128 bit lanes. + const __m256i src_long = + SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + const __m256i src_long2 = + SetrM128i(LoadUnaligned16(&src[src_stride]), + LoadUnaligned16(&src[8 + src_stride])); + const __m256i result2 = + HorizontalTaps8To16<filter_index>(&src_long2, v_tap); + if (is_2d) { + StoreAligned32(&dest16[0], result); + StoreAligned32(&dest16[pred_stride], result2); + } else { + StoreUnaligned32(&dest16[0], result); + StoreUnaligned32(&dest16[pred_stride], result2); + } } else { // Load into 2 128 bit lanes. const __m256i src_long = SetrM128i(LoadUnaligned16(&src[0]), @@ -295,11 +284,37 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, dest16 += pred_stride * 2; y -= 2; } while (y != 0); + + // The 2d filters have an odd |height| during the horizontal pass, so + // filter the remaining row. + if (is_2d) { + const __m256i src_long = + SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreAligned32(&dest16[0], result); + } + } else if (width == 8) { int y = height; + if (is_2d) y -= 1; do { + // Load into 2 128 bit lanes. 
+ const __m128i this_row = LoadUnaligned16(&src[0]); + const __m128i next_row = LoadUnaligned16(&src[src_stride]); + const __m256i src_long = SetrM128i(this_row, next_row); if (is_2d || is_compound) { - // placeholder + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + if (is_2d) { + StoreAligned16(&dest16[0], _mm256_castsi256_si128(result)); + StoreAligned16(&dest16[pred_stride], + _mm256_extracti128_si256(result, 1)); + } else { + StoreUnaligned16(&dest16[0], _mm256_castsi256_si128(result)); + StoreUnaligned16(&dest16[pred_stride], + _mm256_extracti128_si256(result, 1)); + } } else { const __m128i this_row = LoadUnaligned16(&src[0]); const __m128i next_row = LoadUnaligned16(&src[src_stride]); @@ -315,11 +330,29 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, dest16 += pred_stride * 2; y -= 2; } while (y != 0); + + // The 2d filters have an odd |height| during the horizontal pass, so + // filter the remaining row. + if (is_2d) { + const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreAligned16(&dest16[0], _mm256_castsi256_si128(result)); + } + } else { // width == 4 int y = height; + if (is_2d) y -= 1; do { + // Load into 2 128 bit lanes. + const __m128i this_row = LoadUnaligned16(&src[0]); + const __m128i next_row = LoadUnaligned16(&src[src_stride]); + const __m256i src_long = SetrM128i(this_row, next_row); if (is_2d || is_compound) { - // placeholder + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreLo8(&dest16[0], _mm256_castsi256_si128(result)); + StoreLo8(&dest16[pred_stride], _mm256_extracti128_si256(result, 1)); } else { const __m128i this_row = LoadUnaligned16(&src[0]); const __m128i next_row = LoadUnaligned16(&src[src_stride]); @@ -335,93 +368,176 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, dest16 += pred_stride * 2; y -= 2; } while (y != 0); + + // The 2d filters have an odd |height| during the horizontal pass, so + // filter the remaining row. 
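A quick illustration of why that leftover row always exists: the vertical pass needs vertical_taps - 1 extra rows of context, and both block heights and tap counts are even, so the intermediate row count is always odd and the two-rows-per-iteration loop leaves exactly one row behind. Sketch only, not library code.

#include <cassert>

int main() {
  const int kTaps[] = {2, 4, 6, 8};
  for (int height = 2; height <= 128; height += 2) {
    for (const int taps : kTaps) {
      const int intermediate_height = height + taps - 1;
      assert(intermediate_height % 2 == 1);  // always odd -> one leftover row
    }
  }
  return 0;
}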
+ if (is_2d) { + const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0])); + const __m256i result = + HorizontalTaps8To16<filter_index>(&src_long, v_tap); + StoreLo8(&dest16[0], _mm256_castsi256_si128(result)); + } } } template <int num_taps, bool is_2d_vertical = false> LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, - __m128i* v_tap) { + __m256i* v_tap) { if (num_taps == 8) { - v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0); // k1k0 - v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 - v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 - v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff); // k7k6 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); - v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); - v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]); + v_tap[0] = _mm256_broadcastd_epi32(*filter); // k1k0 + v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 4)); // k3k2 + v_tap[2] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 8)); // k5k4 + v_tap[3] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 12)); // k7k6 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); - v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); - v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]); + v_tap[0] = _mm256_broadcastw_epi16(*filter); // k1k0 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 + v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 + v_tap[3] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 6)); // k7k6 } } else if (num_taps == 6) { - const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); - v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0); // k2k1 - v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 - v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa); // k6k5 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); - v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 2)); // k2k1 + v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 6)); // k4k3 + v_tap[2] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 10)); // k6k5 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); - v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 1)); // k2k1 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 + v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 5)); // k6k5 } } else if (num_taps == 4) { - v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 - v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 4)); // k3k2 + v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 8)); // k5k4 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 } } else { // num_taps == 2 - const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); - v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[0] = 
_mm256_broadcastd_epi32(_mm_srli_si128(*filter, 6)); // k4k3 } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 } } } -template <int num_taps, bool is_2d_vertical = false> -LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, - __m256i* v_tap) { - if (num_taps == 8) { - v_tap[0] = _mm256_broadcastw_epi16(*filter); // k1k0 - v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 - v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 - v_tap[3] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 6)); // k7k6 - if (is_2d_vertical) { - // placeholder - } - } else if (num_taps == 6) { - v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 1)); // k2k1 - v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 - v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 5)); // k6k5 - if (is_2d_vertical) { - // placeholder - } - } else if (num_taps == 4) { - v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 - v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 - if (is_2d_vertical) { - // placeholder - } - } else { // num_taps == 2 - v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 - if (is_2d_vertical) { - // placeholder +template <int num_taps, bool is_compound> +__m256i SimpleSum2DVerticalTaps(const __m256i* const src, + const __m256i* const taps) { + __m256i sum_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[0], src[1]), taps[0]); + __m256i sum_hi = + _mm256_madd_epi16(_mm256_unpackhi_epi16(src[0], src[1]), taps[0]); + if (num_taps >= 4) { + __m256i madd_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[2], src[3]), taps[1]); + __m256i madd_hi = + _mm256_madd_epi16(_mm256_unpackhi_epi16(src[2], src[3]), taps[1]); + sum_lo = _mm256_add_epi32(sum_lo, madd_lo); + sum_hi = _mm256_add_epi32(sum_hi, madd_hi); + if (num_taps >= 6) { + madd_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[4], src[5]), taps[2]); + madd_hi = + _mm256_madd_epi16(_mm256_unpackhi_epi16(src[4], src[5]), taps[2]); + sum_lo = _mm256_add_epi32(sum_lo, madd_lo); + sum_hi = _mm256_add_epi32(sum_hi, madd_hi); + if (num_taps == 8) { + madd_lo = + _mm256_madd_epi16(_mm256_unpacklo_epi16(src[6], src[7]), taps[3]); + madd_hi = + _mm256_madd_epi16(_mm256_unpackhi_epi16(src[6], src[7]), taps[3]); + sum_lo = _mm256_add_epi32(sum_lo, madd_lo); + sum_hi = _mm256_add_epi32(sum_hi, madd_hi); + } } } + + if (is_compound) { + return _mm256_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1), + RightShiftWithRounding_S32(sum_hi, + kInterRoundBitsCompoundVertical - 1)); + } + + return _mm256_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1), + RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1)); +} + +template <int num_taps, bool is_compound = false> +void Filter2DVertical16xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int width, + const int height, const __m256i* const taps) { + assert(width >= 8); + constexpr int next_row = num_taps - 1; + // The Horizontal pass uses |width| as |stride| for the intermediate buffer. 
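Filter2DVertical16xH below follows the usual sliding-window shape: preload num_taps - 1 rows, then for each output row load one new row, multiply-accumulate against the taps, and rotate the window down by one. A scalar, single-column sketch of the same flow; the function name, tap values, and the rounding constant here are illustrative, not the library's definitions.

#include <cassert>
#include <cstdint>

// Filter one column with a 4-tap vertical filter. |src| must provide
// height + 3 rows, mirroring the intermediate_height idea of the 2-D path.
void VerticalFilterColumn(const int16_t* src, int src_stride, int height,
                          const int16_t taps[4], int16_t* dst) {
  int32_t rows[4];
  rows[0] = src[0 * src_stride];
  rows[1] = src[1 * src_stride];
  rows[2] = src[2 * src_stride];
  for (int y = 0; y < height; ++y) {
    rows[3] = src[(y + 3) * src_stride];  // newest row, like srcs[next_row]
    int32_t sum = 0;
    for (int k = 0; k < 4; ++k) sum += rows[k] * taps[k];
    dst[y] = static_cast<int16_t>((sum + 8) >> 4);  // illustrative rounding
    rows[0] = rows[1];  // slide the window down one row
    rows[1] = rows[2];
    rows[2] = rows[3];
  }
}

int main() {
  const int16_t taps[4] = {-2, 18, 18, -2};  // example taps only
  int16_t src[8 + 3];
  for (int i = 0; i < 8 + 3; ++i) src[i] = static_cast<int16_t>(i * 4);
  int16_t dst[8];
  VerticalFilterColumn(src, /*src_stride=*/1, /*height=*/8, taps, dst);
  // Row 0: (-2*0 + 18*4 + 18*8 - 2*12 + 8) >> 4 == (2880... ) no: (192+8)>>4.
  assert(dst[0] == 12);
  return 0;
}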
+ const ptrdiff_t src_stride = width; + + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int x = 0; + do { + __m256i srcs[8]; + const uint16_t* src_x = src + x; + srcs[0] = LoadAligned32(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadAligned32(src_x); + src_x += src_stride; + srcs[2] = LoadAligned32(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadAligned32(src_x); + src_x += src_stride; + srcs[4] = LoadAligned32(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadAligned32(src_x); + src_x += src_stride; + srcs[6] = LoadAligned32(src_x); + src_x += src_stride; + } + } + } + + auto* dst8_x = dst8 + x; + auto* dst16_x = dst16 + x; + int y = height; + do { + srcs[next_row] = LoadAligned32(src_x); + src_x += src_stride; + + const __m256i sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + StoreUnaligned32(dst16_x, sum); + dst16_x += dst_stride; + } else { + const __m128i packed_sum = _mm_packus_epi16( + _mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); + StoreUnaligned16(dst8_x, packed_sum); + dst8_x += dst_stride; + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (--y != 0); + x += 16; + } while (x < width); } template <bool is_2d = false, bool is_compound = false> @@ -436,16 +552,16 @@ LIBGAV1_ALWAYS_INLINE void DoHorizontalPass2xH( if (filter_index == 4) { // 4 tap. SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 4, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 4, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 5) { // 4 tap. SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 5, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 5, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else { // 2 tap. SetupTaps<2>(&v_horizontal_filter, v_tap); - FilterHorizontal<2, 8, 3, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<2, 3, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } } @@ -461,28 +577,792 @@ LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( if (filter_index == 2) { // 8 tap. SetupTaps<8>(&v_horizontal_filter, v_tap); - FilterHorizontal<8, 8, 2, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<8, 2, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 1) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 8, 1, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<6, 1, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 0) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 8, 0, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<6, 0, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 4) { // 4 tap. 
SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 4, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 4, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 5) { // 4 tap. SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 5, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 5, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else { // 2 tap. SetupTaps<2>(&v_horizontal_filter, v_tap); - FilterHorizontal<2, 8, 3, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<2, 3, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); + } +} + +void Convolve2D_AVX2(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + + // The output of the horizontal filter is guaranteed to fit in 16 bits. + alignas(32) uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_height = height + vertical_taps - 1; + + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset; + if (width > 2) { + DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, + width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + } else { + // Use non avx2 version for smaller widths. + DoHorizontalPass2xH</*is_2d=*/true>( + src, src_stride, intermediate_result, width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + } + + // Vertical filter. + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]); + + // Use 256 bits for width > 8. + if (width > 8) { + __m256i taps_256[4]; + const __m128i v_filter_ext = _mm_cvtepi8_epi16(v_filter); + + if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<8>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<6>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<4>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<2>(intermediate_result, dest, dest_stride, width, + height, taps_256); + } + } else { // width <= 8 + __m128i taps[4]; + // Use 128 bit code. 
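For the pointer arithmetic at the top of Convolve2D_AVX2, a worked example may help: an 8-tap filter needs 3 rows of context above the block and 4 below it, so |src| is backed up by vertical_taps / 2 - 1 = 3 rows (and by kHorizontalOffset = 3 columns for the horizontal taps) before the two passes run. The check below is an illustrative sketch, not library code.

#include <cassert>

int main() {
  const int vertical_taps = 8;
  const int height = 16;
  const int rows_above = vertical_taps / 2 - 1;  // context rows above: 3
  const int rows_below = vertical_taps / 2;      // context rows below: 4
  // Rows touched by the vertical pass match the intermediate buffer height.
  assert(rows_above + height + rows_below == height + vertical_taps - 1);
  return 0;
}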
+ if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, + height, taps); + } + } + } +} + +// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D +// Vertical calculations. +__m256i Compound1DShift(const __m256i sum) { + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +template <int filter_index, bool unpack_high = false> +__m256i SumVerticalTaps(const __m256i* const srcs, const __m256i* const v_tap) { + __m256i v_src[4]; + + if (!unpack_high) { + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]); + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]); + v_src[3] = _mm256_unpacklo_epi8(srcs[6], srcs[7]); + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + } else if (filter_index > 3) { + // 4 taps. + v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]); + } + } else { + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]); + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); + v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]); + v_src[3] = _mm256_unpackhi_epi8(srcs[6], srcs[7]); + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + } else if (filter_index > 3) { + // 4 taps. 
+ v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]); + v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]); + } + } + return SumOnePassTaps<filter_index>(v_src, v_tap); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical32xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int width, const int height, + const __m256i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps - 1; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + assert(width >= 32); + int x = 0; + do { + const uint8_t* src_x = src + x; + __m256i srcs[8]; + srcs[0] = LoadUnaligned32(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadUnaligned32(src_x); + src_x += src_stride; + srcs[2] = LoadUnaligned32(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadUnaligned32(src_x); + src_x += src_stride; + srcs[4] = LoadUnaligned32(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadUnaligned32(src_x); + src_x += src_stride; + srcs[6] = LoadUnaligned32(src_x); + src_x += src_stride; + } + } + } + + auto* dst8_x = dst8 + x; + auto* dst16_x = dst16 + x; + int y = height; + do { + srcs[next_row] = LoadUnaligned32(src_x); + src_x += src_stride; + + const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m256i sums_hi = + SumVerticalTaps<filter_index, /*unpack_high=*/true>(srcs, v_tap); + if (is_compound) { + const __m256i results = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20)); + const __m256i results_hi = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x31)); + StoreUnaligned32(dst16_x, results); + StoreUnaligned32(dst16_x + 16, results_hi); + dst16_x += dst_stride; + } else { + const __m256i results = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m256i results_hi = + RightShiftWithRounding_S16(sums_hi, kFilterBits - 1); + const __m256i packed_results = _mm256_packus_epi16(results, results_hi); + + StoreUnaligned32(dst8_x, packed_results); + dst8_x += dst_stride; + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (--y != 0); + x += 32; + } while (x < width); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical16xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int /*width*/, const int height, + const __m256i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + const uint8_t* src_x = src; + __m256i srcs[8 + 1]; + // The upper 128 bits hold the filter data for the next row. 
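FilterVertical16xH below (and the 8xH variant after it) keeps row N in the low 128-bit lane of each __m256i and row N + 1 in the high lane, so every AVX2 multiply-add filters two output rows at once. A minimal demonstration of that lane layout; compile with -mavx2, and note the data values are arbitrary.

#include <immintrin.h>
#include <cassert>
#include <cstdint>

int main() {
  alignas(16) uint8_t row0[16], row1[16];
  for (int i = 0; i < 16; ++i) {
    row0[i] = static_cast<uint8_t>(i);
    row1[i] = static_cast<uint8_t>(100 + i);
  }
  // Pack row0 into lane 0 and row1 into lane 1 (what SetrM128i / inserti128
  // do in the code below).
  __m256i pair = _mm256_castsi128_si256(
      _mm_loadu_si128(reinterpret_cast<const __m128i*>(row0)));
  pair = _mm256_inserti128_si256(
      pair, _mm_loadu_si128(reinterpret_cast<const __m128i*>(row1)), 1);
  // Splitting the two lanes again recovers the original rows, which is how
  // the packed results are written back to two destination rows.
  alignas(16) uint8_t out0[16], out1[16];
  _mm_storeu_si128(reinterpret_cast<__m128i*>(out0),
                   _mm256_castsi256_si128(pair));
  _mm_storeu_si128(reinterpret_cast<__m128i*>(out1),
                   _mm256_extracti128_si256(pair, 1));
  for (int i = 0; i < 16; ++i) {
    assert(out0[i] == row0[i]);
    assert(out1[i] == row1[i]);
  }
  return 0;
}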
+ srcs[0] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[0] = + _mm256_inserti128_si256(srcs[0], _mm256_castsi256_si128(srcs[1]), 1); + srcs[2] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[1] = + _mm256_inserti128_si256(srcs[1], _mm256_castsi256_si128(srcs[2]), 1); + if (num_taps >= 6) { + srcs[3] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[2] = + _mm256_inserti128_si256(srcs[2], _mm256_castsi256_si128(srcs[3]), 1); + srcs[4] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[3] = + _mm256_inserti128_si256(srcs[3], _mm256_castsi256_si128(srcs[4]), 1); + if (num_taps == 8) { + srcs[5] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[4] = _mm256_inserti128_si256(srcs[4], + _mm256_castsi256_si128(srcs[5]), 1); + srcs[6] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + srcs[5] = _mm256_inserti128_si256(srcs[5], + _mm256_castsi256_si128(srcs[6]), 1); + } + } + } + + int y = height; + do { + srcs[next_row - 1] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + + srcs[next_row - 2] = _mm256_inserti128_si256( + srcs[next_row - 2], _mm256_castsi256_si128(srcs[next_row - 1]), 1); + + srcs[next_row] = _mm256_castsi128_si256(LoadUnaligned16(src_x)); + src_x += src_stride; + + srcs[next_row - 1] = _mm256_inserti128_si256( + srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1); + + const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m256i sums_hi = + SumVerticalTaps<filter_index, /*unpack_high=*/true>(srcs, v_tap); + if (is_compound) { + const __m256i results = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20)); + const __m256i results_hi = + Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x31)); + + StoreUnaligned32(dst16, results); + StoreUnaligned32(dst16 + dst_stride, results_hi); + dst16 += dst_stride << 1; + } else { + const __m256i results = RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m256i results_hi = + RightShiftWithRounding_S16(sums_hi, kFilterBits - 1); + const __m256i packed_results = _mm256_packus_epi16(results, results_hi); + const __m128i this_dst = _mm256_castsi256_si128(packed_results); + const auto next_dst = _mm256_extracti128_si256(packed_results, 1); + + StoreUnaligned16(dst8, this_dst); + StoreUnaligned16(dst8 + dst_stride, next_dst); + dst8 += dst_stride << 1; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y -= 2; + } while (y != 0); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical8xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int /*width*/, const int height, + const __m256i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + const uint8_t* src_x = src; + __m256i srcs[8 + 1]; + // The upper 128 bits hold the filter data for the next row. 
+ srcs[0] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[0] = + _mm256_inserti128_si256(srcs[0], _mm256_castsi256_si128(srcs[1]), 1); + srcs[2] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[1] = + _mm256_inserti128_si256(srcs[1], _mm256_castsi256_si128(srcs[2]), 1); + if (num_taps >= 6) { + srcs[3] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[2] = + _mm256_inserti128_si256(srcs[2], _mm256_castsi256_si128(srcs[3]), 1); + srcs[4] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[3] = + _mm256_inserti128_si256(srcs[3], _mm256_castsi256_si128(srcs[4]), 1); + if (num_taps == 8) { + srcs[5] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[4] = _mm256_inserti128_si256(srcs[4], + _mm256_castsi256_si128(srcs[5]), 1); + srcs[6] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + srcs[5] = _mm256_inserti128_si256(srcs[5], + _mm256_castsi256_si128(srcs[6]), 1); + } + } + } + + int y = height; + do { + srcs[next_row - 1] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + + srcs[next_row - 2] = _mm256_inserti128_si256( + srcs[next_row - 2], _mm256_castsi256_si128(srcs[next_row - 1]), 1); + + srcs[next_row] = _mm256_castsi128_si256(LoadLo8(src_x)); + src_x += src_stride; + + srcs[next_row - 1] = _mm256_inserti128_si256( + srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1); + + const __m256i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m256i results = Compound1DShift(sums); + const __m128i this_dst = _mm256_castsi256_si128(results); + const auto next_dst = _mm256_extracti128_si256(results, 1); + + StoreUnaligned16(dst16, this_dst); + StoreUnaligned16(dst16 + dst_stride, next_dst); + dst16 += dst_stride << 1; + } else { + const __m256i results = RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m256i packed_results = _mm256_packus_epi16(results, results); + const __m128i this_dst = _mm256_castsi256_si128(packed_results); + const auto next_dst = _mm256_extracti128_si256(packed_results, 1); + + StoreLo8(dst8, this_dst); + StoreLo8(dst8 + dst_stride, next_dst); + dst8 += dst_stride << 1; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y -= 2; + } while (y != 0); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical8xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int /*width*/, const int height, + const __m128i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps - 1; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + const uint8_t* src_x = src; + __m128i srcs[8]; + srcs[0] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadLo8(src_x); + src_x += src_stride; + srcs[2] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadLo8(src_x); + src_x += src_stride; + srcs[4] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadLo8(src_x); + src_x += src_stride; + srcs[6] = LoadLo8(src_x); + src_x += src_stride; + } + } + } + + int y = height; + do { + srcs[next_row] = 
LoadLo8(src_x); + src_x += src_stride; + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += dst_stride; + } else { + const __m128i results = RightShiftWithRounding_S16(sums, kFilterBits - 1); + StoreLo8(dst8, _mm_packus_epi16(results, results)); + dst8 += dst_stride; + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (--y != 0); +} + +void ConvolveVertical_AVX2(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int vertical_filter_index, + const int /*horizontal_filter_id*/, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]); + + // Use 256 bits for width > 4. + if (width > 4) { + __m256i taps_256[4]; + if (filter_index < 2) { // 6 tap. + SetupTaps<6>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<0>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<0>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<0>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<2>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<2>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<2>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<3>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<3>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<3>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else if (filter_index == 4) { // 4 tap. 
+ SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<4>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<4>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<4>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } else { + SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<5>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else if (width == 16) { + FilterVertical16xH<5>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } else { + FilterVertical32xH<5>(src, src_stride, dest, dest_stride, width, height, + taps_256); + } + } + } else { // width <= 8 + // Use 128 bit code. + __m128i taps[4]; + + if (filter_index < 2) { // 6 tap. + SetupTaps<6>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<6, 0>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<6, 0>(src, src_stride, dest, dest_stride, height, + taps); + } + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<8, 2>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<8, 2>(src, src_stride, dest, dest_stride, height, + taps); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<2, 3>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<2, 3>(src, src_stride, dest, dest_stride, height, + taps); + } + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<4, 4>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<4, 4>(src, src_stride, dest, dest_stride, height, + taps); + } + } else { + SetupTaps<4>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<4, 5>(src, src_stride, dest, dest_stride, height, + taps); + } else { + FilterVertical4xH<4, 5>(src, src_stride, dest, dest_stride, height, + taps); + } + } + } +} + +void ConvolveCompoundVertical_AVX2( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int vertical_filter_index, + const int /*horizontal_filter_id*/, const int vertical_filter_id, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = width; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]); + + // Use 256 bits for width > 4. + if (width > 4) { + __m256i taps_256[4]; + if (filter_index < 2) { // 6 tap. 
+ SetupTaps<6>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<0, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<0, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<0, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<2, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<2, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<2, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<3, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<3, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<3, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<4, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<4, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<4, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } else { + SetupTaps<4>(&v_filter, taps_256); + if (width == 8) { + FilterVertical8xH<5, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else if (width == 16) { + FilterVertical16xH<5, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } else { + FilterVertical32xH<5, /*is_compound=*/true>( + src, src_stride, dest, dest_stride, width, height, taps_256); + } + } + } else { // width <= 4 + // Use 128 bit code. + __m128i taps[4]; + + if (filter_index < 2) { // 6 tap. + SetupTaps<6>(&v_filter, taps); + FilterVertical4xH<6, 0, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps); + FilterVertical4xH<8, 2, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps); + FilterVertical4xH<2, 3, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else if (filter_index == 4) { // 4 tap. 
+ SetupTaps<4>(&v_filter, taps); + FilterVertical4xH<4, 4, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } else { + SetupTaps<4>(&v_filter, taps); + FilterVertical4xH<4, 5, /*is_compound=*/true>(src, src_stride, dest, + dest_stride, height, taps); + } } } @@ -509,10 +1389,140 @@ void ConvolveHorizontal_AVX2(const void* const reference, } } +void ConvolveCompoundHorizontal_AVX2( + const void* const reference, const ptrdiff_t reference_stride, + const int horizontal_filter_index, const int /*vertical_filter_index*/, + const int horizontal_filter_id, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + // Set |src| to the outermost tap. + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint8_t*>(prediction); + // All compound functions output to the predictor buffer with |pred_stride| + // equal to |width|. + assert(pred_stride == width); + // Compound functions start at 4x4. + assert(width >= 4 && height >= 4); + +#ifdef NDEBUG + // Quiet compiler error. + (void)pred_stride; +#endif + + DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>( + src, reference_stride, dest, width, width, height, horizontal_filter_id, + filter_index); +} + +void ConvolveCompound2D_AVX2(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + + // The output of the horizontal filter is guaranteed to fit in 16 bits. + alignas(32) uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_height = height + vertical_taps - 1; + + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset; + DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>( + src, src_stride, intermediate_result, width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]); + + // Use 256 bits for width > 8. 
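For context on the compound paths above: the 1-D compound filters do not round all the way down to pixel values; they keep a 16-bit intermediate, shifted only by kInterRoundBitsHorizontal - 1 (the Compound1DShift earlier in this file, the -1 again reflecting the pre-shifted half filters), and write it to a buffer whose stride equals the block width for a later blend stage. The scalar model below is a rough, illustrative sketch of one such output value, not the library's definition, and the tap values are examples chosen only to sum to 64.

#include <cassert>
#include <cstdint>

int16_t CompoundIntermediate(const uint8_t src[8], const int8_t taps[8]) {
  int32_t sum = 0;
  for (int k = 0; k < 8; ++k) sum += src[k] * taps[k];
  const int bits = 3 - 1;  // kInterRoundBitsHorizontal - 1 for 8bpp
  return static_cast<int16_t>((sum + ((1 << bits) >> 1)) >> bits);
}

int main() {
  const uint8_t src[8] = {10, 20, 30, 40, 50, 60, 70, 80};
  const int8_t taps[8] = {-1, 2, -5, 36, 36, -5, 2, -1};  // sums to 64
  // Weighted sum is 2880; (2880 + 2) >> 2 == 720.
  assert(CompoundIntermediate(src, taps) == 720);
  return 0;
}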
+ if (width > 8) { + __m256i taps_256[4]; + const __m128i v_filter_ext = _mm_cvtepi8_epi16(v_filter); + + if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<8, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<6, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<4, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256); + Filter2DVertical16xH<2, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps_256); + } + } else { // width <= 8 + __m128i taps[4]; + // Use 128 bit code. + if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<8, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<6, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<4, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<2, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } + } +} + void Init8bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); assert(dsp != nullptr); dsp->convolve[0][0][0][1] = ConvolveHorizontal_AVX2; + dsp->convolve[0][0][1][0] = ConvolveVertical_AVX2; + dsp->convolve[0][0][1][1] = Convolve2D_AVX2; + + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_AVX2; + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_AVX2; + dsp->convolve[0][1][1][1] = ConvolveCompound2D_AVX2; } } // namespace @@ -523,7 +1533,7 @@ void ConvolveInit_AVX2() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_AVX2 +#else // !LIBGAV1_TARGETING_AVX2 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/convolve_avx2.h b/src/dsp/x86/convolve_avx2.h index 6179d98..e509bc9 100644 --- a/src/dsp/x86/convolve_avx2.h +++ b/src/dsp/x86/convolve_avx2.h @@ -38,6 +38,22 @@ void ConvolveInit_AVX2(); #define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_AVX2 #endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal +#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_AVX2 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveVertical +#define 
LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_CPU_AVX2 +#endif + +#ifndef LIBGAV1_Dsp8bpp_Convolve2D +#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_CPU_AVX2 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundVertical +#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_CPU_AVX2 +#endif + #endif // LIBGAV1_TARGETING_AVX2 #endif // LIBGAV1_SRC_DSP_X86_CONVOLVE_AVX2_H_ diff --git a/src/dsp/x86/convolve_sse4.cc b/src/dsp/x86/convolve_sse4.cc index 3a0fff5..9b72fe4 100644 --- a/src/dsp/x86/convolve_sse4.cc +++ b/src/dsp/x86/convolve_sse4.cc @@ -34,41 +34,7 @@ namespace dsp { namespace low_bitdepth { namespace { -#include "src/dsp/convolve.inc" - -// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and -// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final -// sum from outranging int16_t. -template <int filter_index> -__m128i SumOnePassTaps(const __m128i* const src, const __m128i* const taps) { - __m128i sum; - if (filter_index < 2) { - // 6 taps. - const __m128i v_madd_21 = _mm_maddubs_epi16(src[0], taps[0]); // k2k1 - const __m128i v_madd_43 = _mm_maddubs_epi16(src[1], taps[1]); // k4k3 - const __m128i v_madd_65 = _mm_maddubs_epi16(src[2], taps[2]); // k6k5 - sum = _mm_add_epi16(v_madd_21, v_madd_43); - sum = _mm_add_epi16(sum, v_madd_65); - } else if (filter_index == 2) { - // 8 taps. - const __m128i v_madd_10 = _mm_maddubs_epi16(src[0], taps[0]); // k1k0 - const __m128i v_madd_32 = _mm_maddubs_epi16(src[1], taps[1]); // k3k2 - const __m128i v_madd_54 = _mm_maddubs_epi16(src[2], taps[2]); // k5k4 - const __m128i v_madd_76 = _mm_maddubs_epi16(src[3], taps[3]); // k7k6 - const __m128i v_sum_3210 = _mm_add_epi16(v_madd_10, v_madd_32); - const __m128i v_sum_7654 = _mm_add_epi16(v_madd_54, v_madd_76); - sum = _mm_add_epi16(v_sum_7654, v_sum_3210); - } else if (filter_index == 3) { - // 2 taps. - sum = _mm_maddubs_epi16(src[0], taps[0]); // k4k3 - } else { - // 4 taps. - const __m128i v_madd_32 = _mm_maddubs_epi16(src[0], taps[0]); // k3k2 - const __m128i v_madd_54 = _mm_maddubs_epi16(src[1], taps[1]); // k5k4 - sum = _mm_add_epi16(v_madd_32, v_madd_54); - } - return sum; -} +#include "src/dsp/x86/convolve_sse4.inc" template <int filter_index> __m128i SumHorizontalTaps(const uint8_t* const src, @@ -125,68 +91,7 @@ __m128i HorizontalTaps8To16(const uint8_t* const src, return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); } -template <int filter_index> -__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - const __m128i input0 = LoadLo8(&src[2]); - const __m128i input1 = LoadLo8(&src[2 + src_stride]); - - if (filter_index == 3) { - // 03 04 04 05 05 06 06 07 .... - const __m128i input0_dup = - _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 3); - // 13 14 14 15 15 16 16 17 .... - const __m128i input1_dup = - _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 3); - const __m128i v_src_43 = _mm_unpacklo_epi64(input0_dup, input1_dup); - const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 - return v_sum_43; - } - - // 02 03 03 04 04 05 05 06 06 07 .... - const __m128i input0_dup = - _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 1); - // 12 13 13 14 14 15 15 16 16 17 .... - const __m128i input1_dup = - _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 1); - // 04 05 05 06 06 07 07 08 ... - const __m128i input0_dup_54 = _mm_srli_si128(input0_dup, 4); - // 14 15 15 16 16 17 17 18 ... 
- const __m128i input1_dup_54 = _mm_srli_si128(input1_dup, 4); - const __m128i v_src_32 = _mm_unpacklo_epi64(input0_dup, input1_dup); - const __m128i v_src_54 = _mm_unpacklo_epi64(input0_dup_54, input1_dup_54); - const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 - const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 - const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32); - return v_sum_5432; -} - -template <int filter_index> -__m128i SimpleHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); - - // Normally the Horizontal pass does the downshift in two passes: - // kInterRoundBitsHorizontal - 1 and then (kFilterBits - - // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them - // requires adding the rounding offset from the skipped shift. - constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); - - sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit)); - sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); - return _mm_packus_epi16(sum, sum); -} - -template <int filter_index> -__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride, - const __m128i* const v_tap) { - const __m128i sum = - SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); - - return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); -} - -template <int num_taps, int step, int filter_index, bool is_2d = false, +template <int num_taps, int filter_index, bool is_2d = false, bool is_compound = false> void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, void* const dest, const ptrdiff_t pred_stride, @@ -197,7 +102,7 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, // 4 tap filters are never used when width > 4. if (num_taps != 4 && width > 4) { - int y = 0; + int y = height; do { int x = 0; do { @@ -214,12 +119,12 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, SimpleHorizontalTaps<filter_index>(&src[x], v_tap); StoreLo8(&dest8[x], result); } - x += step; + x += 8; } while (x < width); src += src_stride; dest8 += pred_stride; dest16 += pred_stride; - } while (++y < height); + } while (--y != 0); return; } @@ -229,7 +134,7 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, assert(num_taps <= 4); if (num_taps <= 4) { if (width == 4) { - int y = 0; + int y = height; do { if (is_2d || is_compound) { const __m128i v_sum = HorizontalTaps8To16<filter_index>(src, v_tap); @@ -241,12 +146,13 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, src += src_stride; dest8 += pred_stride; dest16 += pred_stride; - } while (++y < height); + } while (--y != 0); return; } if (!is_compound) { - int y = 0; + int y = height; + if (is_2d) y -= 1; do { if (is_2d) { const __m128i sum = @@ -265,8 +171,8 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } src += src_stride << 1; - y += 2; - } while (y < height - 1); + y -= 2; + } while (y != 0); // The 2d filters have an odd |height| because the horizontal pass // generates context for the vertical pass. 
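The helpers removed above (SumHorizontalTaps2x2(), SimpleHorizontalTaps2x2(), HorizontalTaps8To16_2x2()) come back, in reworked form, in the new convolve_sse4.inc added later in this diff. The subtle part is the comment about folding the two downshifts into one: adding the rounding offset of the skipped first shift and then shifting once by kFilterBits - 1 reproduces the two separate rounding shifts. A minimal scalar sketch of that identity, not part of the patch; the helper names are local stand-ins, and the constants assume libgav1's 8bpp values kFilterBits == 7 and kInterRoundBitsHorizontal == 3:

#include <cassert>

namespace {

constexpr int kFilterBits = 7;
constexpr int kInterRoundBitsHorizontal = 3;

// Rounding right shift, the same formula the SIMD code applies per lane.
int RightShiftWithRounding(int value, int bits) {
  return (value + (1 << (bits - 1))) >> bits;
}

// Two separate rounding shifts: the horizontal-pass shift followed by the
// final shift. The filters are pre-shifted by 1, hence the "- 1".
int TwoStageShift(int sum) {
  const int horizontal =
      RightShiftWithRounding(sum, kInterRoundBitsHorizontal - 1);
  return RightShiftWithRounding(horizontal,
                                kFilterBits - kInterRoundBitsHorizontal);
}

// Combined shift as in SimpleHorizontalTaps*: add the rounding offset of the
// skipped first shift, then shift by the combined amount.
int CombinedShift(int sum) {
  constexpr int first_shift_rounding_bit =
      1 << (kInterRoundBitsHorizontal - 2);
  return RightShiftWithRounding(sum + first_shift_rounding_bit,
                                kFilterBits - 1);
}

}  // namespace

int main() {
  // Check the non-negative range an 8-bit row times the pre-shifted taps can
  // produce; the same identity holds for negative sums with arithmetic shifts.
  for (int sum = 0; sum < (1 << 15); ++sum) {
    assert(TwoStageShift(sum) == CombinedShift(sum));
  }
  return 0;
}

The copies of these helpers added to convolve_sse4.inc below rely on the same arithmetic; only their location changes.
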
@@ -298,303 +204,6 @@ void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, } } -template <int num_taps, bool is_2d_vertical = false> -LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, - __m128i* v_tap) { - if (num_taps == 8) { - v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0); // k1k0 - v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 - v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 - v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff); // k7k6 - if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); - v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); - v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]); - } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); - v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); - v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]); - } - } else if (num_taps == 6) { - const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); - v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0); // k2k1 - v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 - v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa); // k6k5 - if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); - v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); - } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); - v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); - } - } else if (num_taps == 4) { - v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 - v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 - if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); - } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); - } - } else { // num_taps == 2 - const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); - v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 - if (is_2d_vertical) { - v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); - } else { - v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); - } - } -} - -template <int num_taps, bool is_compound> -__m128i SimpleSum2DVerticalTaps(const __m128i* const src, - const __m128i* const taps) { - __m128i sum_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[0], src[1]), taps[0]); - __m128i sum_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[0], src[1]), taps[0]); - if (num_taps >= 4) { - __m128i madd_lo = - _mm_madd_epi16(_mm_unpacklo_epi16(src[2], src[3]), taps[1]); - __m128i madd_hi = - _mm_madd_epi16(_mm_unpackhi_epi16(src[2], src[3]), taps[1]); - sum_lo = _mm_add_epi32(sum_lo, madd_lo); - sum_hi = _mm_add_epi32(sum_hi, madd_hi); - if (num_taps >= 6) { - madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[4], src[5]), taps[2]); - madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[4], src[5]), taps[2]); - sum_lo = _mm_add_epi32(sum_lo, madd_lo); - sum_hi = _mm_add_epi32(sum_hi, madd_hi); - if (num_taps == 8) { - madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[6], src[7]), taps[3]); - madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[6], src[7]), taps[3]); - sum_lo = _mm_add_epi32(sum_lo, madd_lo); - sum_hi = _mm_add_epi32(sum_hi, madd_hi); - } - } - } - - if (is_compound) { - return _mm_packs_epi32( - RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1), - RightShiftWithRounding_S32(sum_hi, - kInterRoundBitsCompoundVertical - 1)); - } - - return _mm_packs_epi32( - RightShiftWithRounding_S32(sum_lo, 
kInterRoundBitsVertical - 1), - RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1)); -} - -template <int num_taps, bool is_compound = false> -void Filter2DVertical(const uint16_t* src, void* const dst, - const ptrdiff_t dst_stride, const int width, - const int height, const __m128i* const taps) { - assert(width >= 8); - constexpr int next_row = num_taps - 1; - // The Horizontal pass uses |width| as |stride| for the intermediate buffer. - const ptrdiff_t src_stride = width; - - auto* dst8 = static_cast<uint8_t*>(dst); - auto* dst16 = static_cast<uint16_t*>(dst); - - int x = 0; - do { - __m128i srcs[8]; - const uint16_t* src_x = src + x; - srcs[0] = LoadAligned16(src_x); - src_x += src_stride; - if (num_taps >= 4) { - srcs[1] = LoadAligned16(src_x); - src_x += src_stride; - srcs[2] = LoadAligned16(src_x); - src_x += src_stride; - if (num_taps >= 6) { - srcs[3] = LoadAligned16(src_x); - src_x += src_stride; - srcs[4] = LoadAligned16(src_x); - src_x += src_stride; - if (num_taps == 8) { - srcs[5] = LoadAligned16(src_x); - src_x += src_stride; - srcs[6] = LoadAligned16(src_x); - src_x += src_stride; - } - } - } - - int y = 0; - do { - srcs[next_row] = LoadAligned16(src_x); - src_x += src_stride; - - const __m128i sum = - SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); - if (is_compound) { - StoreUnaligned16(dst16 + x + y * dst_stride, sum); - } else { - StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(sum, sum)); - } - - srcs[0] = srcs[1]; - if (num_taps >= 4) { - srcs[1] = srcs[2]; - srcs[2] = srcs[3]; - if (num_taps >= 6) { - srcs[3] = srcs[4]; - srcs[4] = srcs[5]; - if (num_taps == 8) { - srcs[5] = srcs[6]; - srcs[6] = srcs[7]; - } - } - } - } while (++y < height); - x += 8; - } while (x < width); -} - -// Take advantage of |src_stride| == |width| to process two rows at a time. -template <int num_taps, bool is_compound = false> -void Filter2DVertical4xH(const uint16_t* src, void* const dst, - const ptrdiff_t dst_stride, const int height, - const __m128i* const taps) { - auto* dst8 = static_cast<uint8_t*>(dst); - auto* dst16 = static_cast<uint16_t*>(dst); - - __m128i srcs[9]; - srcs[0] = LoadAligned16(src); - src += 8; - if (num_taps >= 4) { - srcs[2] = LoadAligned16(src); - src += 8; - srcs[1] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[0], 8), srcs[2]); - if (num_taps >= 6) { - srcs[4] = LoadAligned16(src); - src += 8; - srcs[3] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[2], 8), srcs[4]); - if (num_taps == 8) { - srcs[6] = LoadAligned16(src); - src += 8; - srcs[5] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[4], 8), srcs[6]); - } - } - } - - int y = 0; - do { - srcs[num_taps] = LoadAligned16(src); - src += 8; - srcs[num_taps - 1] = _mm_unpacklo_epi64( - _mm_srli_si128(srcs[num_taps - 2], 8), srcs[num_taps]); - - const __m128i sum = - SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); - if (is_compound) { - StoreUnaligned16(dst16, sum); - dst16 += 4 << 1; - } else { - const __m128i results = _mm_packus_epi16(sum, sum); - Store4(dst8, results); - dst8 += dst_stride; - Store4(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - } - - srcs[0] = srcs[2]; - if (num_taps >= 4) { - srcs[1] = srcs[3]; - srcs[2] = srcs[4]; - if (num_taps >= 6) { - srcs[3] = srcs[5]; - srcs[4] = srcs[6]; - if (num_taps == 8) { - srcs[5] = srcs[7]; - srcs[6] = srcs[8]; - } - } - } - y += 2; - } while (y < height); -} - -// Take advantage of |src_stride| == |width| to process four rows at a time. 
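The 4xH and 2xH variants deleted here (and re-added in convolve_sse4.inc further down in this diff) depend on the 2D intermediate buffer being packed with |src_stride| == |width|: at width 4 a single aligned 16-byte load covers two rows of uint16_t, at width 2 it covers four, so each loop iteration can emit two or four output rows. A small standalone illustration of the row-splicing step, with made-up buffer contents and the same shuffles the helpers use (not part of the patch):

#include <emmintrin.h>

#include <cstdint>
#include <cstdio>

int main() {
  // Four rows of a width-4 intermediate buffer, packed with stride == width.
  alignas(16) const uint16_t buf[16] = {0,  1,  2,  3,  10, 11, 12, 13,
                                        20, 21, 22, 23, 30, 31, 32, 33};
  // One load picks up two rows at a time.
  const __m128i rows01 =
      _mm_load_si128(reinterpret_cast<const __m128i*>(buf));
  const __m128i rows23 =
      _mm_load_si128(reinterpret_cast<const __m128i*>(buf + 8));
  // Splice row1 and row2 together, mirroring the
  // srcs[1] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[0], 8), srcs[2]) step.
  const __m128i rows12 =
      _mm_unpacklo_epi64(_mm_srli_si128(rows01, 8), rows23);
  alignas(16) uint16_t out[8];
  _mm_store_si128(reinterpret_cast<__m128i*>(out), rows12);
  for (int i = 0; i < 8; ++i) printf("%d ", out[i]);
  printf("\n");  // 10 11 12 13 20 21 22 23
  return 0;
}

Each srcs[i] then holds rows i and i+1, so one SimpleSum2DVerticalTaps call produces two output rows per iteration; the 2xH path plays the same game with _mm_alignr_epi8 at 4-byte granularity.
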
-template <int num_taps> -void Filter2DVertical2xH(const uint16_t* src, void* const dst, - const ptrdiff_t dst_stride, const int height, - const __m128i* const taps) { - constexpr int next_row = (num_taps < 6) ? 4 : 8; - - auto* dst8 = static_cast<uint8_t*>(dst); - - __m128i srcs[9]; - srcs[0] = LoadAligned16(src); - src += 8; - if (num_taps >= 6) { - srcs[4] = LoadAligned16(src); - src += 8; - srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); - if (num_taps == 8) { - srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); - srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); - } - } - - int y = 0; - do { - srcs[next_row] = LoadAligned16(src); - src += 8; - if (num_taps == 2) { - srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); - } else if (num_taps == 4) { - srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); - srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); - srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); - } else if (num_taps == 6) { - srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); - srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); - srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4); - } else if (num_taps == 8) { - srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4); - srcs[6] = _mm_alignr_epi8(srcs[8], srcs[4], 8); - srcs[7] = _mm_alignr_epi8(srcs[8], srcs[4], 12); - } - - const __m128i sum = - SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps); - const __m128i results = _mm_packus_epi16(sum, sum); - - Store2(dst8, results); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 2)); - // When |height| <= 4 the taps are restricted to 2 and 4 tap variants. - // Therefore we don't need to check this condition when |height| > 4. - if (num_taps <= 4 && height == 2) return; - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 6)); - dst8 += dst_stride; - - srcs[0] = srcs[4]; - if (num_taps == 6) { - srcs[1] = srcs[5]; - srcs[4] = srcs[8]; - } else if (num_taps == 8) { - srcs[1] = srcs[5]; - srcs[2] = srcs[6]; - srcs[3] = srcs[7]; - srcs[4] = srcs[8]; - } - - y += 4; - } while (y < height); -} - template <bool is_2d = false, bool is_compound = false> LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( const uint8_t* const src, const ptrdiff_t src_stride, void* const dst, @@ -607,28 +216,28 @@ LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( if (filter_index == 2) { // 8 tap. SetupTaps<8>(&v_horizontal_filter, v_tap); - FilterHorizontal<8, 8, 2, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<8, 2, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 1) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 8, 1, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<6, 1, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 0) { // 6 tap. SetupTaps<6>(&v_horizontal_filter, v_tap); - FilterHorizontal<6, 8, 0, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<6, 0, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 4) { // 4 tap. 
SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 4, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 4, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else if (filter_index == 5) { // 4 tap. SetupTaps<4>(&v_horizontal_filter, v_tap); - FilterHorizontal<4, 8, 5, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<4, 5, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } else { // 2 tap. SetupTaps<2>(&v_horizontal_filter, v_tap); - FilterHorizontal<2, 8, 3, is_2d, is_compound>( - src, src_stride, dst, dst_stride, width, height, v_tap); + FilterHorizontal<2, 3, is_2d, is_compound>(src, src_stride, dst, dst_stride, + width, height, v_tap); } } @@ -718,39 +327,6 @@ void Convolve2D_SSE4_1(const void* const reference, } } -// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D -// Vertical calculations. -__m128i Compound1DShift(const __m128i sum) { - return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); -} - -template <int filter_index> -__m128i SumVerticalTaps(const __m128i* const srcs, const __m128i* const v_tap) { - __m128i v_src[4]; - - if (filter_index < 2) { - // 6 taps. - v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); - v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); - v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]); - } else if (filter_index == 2) { - // 8 taps. - v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); - v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); - v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]); - v_src[3] = _mm_unpacklo_epi8(srcs[6], srcs[7]); - } else if (filter_index == 3) { - // 2 taps. - v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); - } else if (filter_index > 3) { - // 4 taps. 
- v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); - v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); - } - const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap); - return sum; -} - template <int filter_index, bool is_compound = false> void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride, void* const dst, const ptrdiff_t dst_stride, @@ -787,7 +363,9 @@ void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride, } } - int y = 0; + auto* dst8_x = dst8 + x; + auto* dst16_x = dst16 + x; + int y = height; do { srcs[next_row] = LoadLo8(src_x); src_x += src_stride; @@ -795,11 +373,13 @@ void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride, const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); if (is_compound) { const __m128i results = Compound1DShift(sums); - StoreUnaligned16(dst16 + x + y * dst_stride, results); + StoreUnaligned16(dst16_x, results); + dst16_x += dst_stride; } else { const __m128i results = RightShiftWithRounding_S16(sums, kFilterBits - 1); - StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(results, results)); + StoreLo8(dst8_x, _mm_packus_epi16(results, results)); + dst8_x += dst_stride; } srcs[0] = srcs[1]; @@ -815,506 +395,11 @@ void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride, } } } - } while (++y < height); + } while (--y != 0); x += 8; } while (x < width); } -template <int filter_index, bool is_compound = false> -void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride, - void* const dst, const ptrdiff_t dst_stride, - const int height, const __m128i* const v_tap) { - const int num_taps = GetNumTapsInFilter(filter_index); - auto* dst8 = static_cast<uint8_t*>(dst); - auto* dst16 = static_cast<uint16_t*>(dst); - - __m128i srcs[9]; - - if (num_taps == 2) { - srcs[2] = _mm_setzero_si128(); - // 00 01 02 03 - srcs[0] = Load4(src); - src += src_stride; - - int y = 0; - do { - // 10 11 12 13 - const __m128i a = Load4(src); - // 00 01 02 03 10 11 12 13 - srcs[0] = _mm_unpacklo_epi32(srcs[0], a); - src += src_stride; - // 20 21 22 23 - srcs[2] = Load4(src); - src += src_stride; - // 10 11 12 13 20 21 22 23 - srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); - - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - if (is_compound) { - const __m128i results = Compound1DShift(sums); - StoreUnaligned16(dst16, results); - dst16 += 4 << 1; - } else { - const __m128i results_16 = - RightShiftWithRounding_S16(sums, kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - Store4(dst8, results); - dst8 += dst_stride; - Store4(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - } - - srcs[0] = srcs[2]; - y += 2; - } while (y < height); - } else if (num_taps == 4) { - srcs[4] = _mm_setzero_si128(); - // 00 01 02 03 - srcs[0] = Load4(src); - src += src_stride; - // 10 11 12 13 - const __m128i a = Load4(src); - // 00 01 02 03 10 11 12 13 - srcs[0] = _mm_unpacklo_epi32(srcs[0], a); - src += src_stride; - // 20 21 22 23 - srcs[2] = Load4(src); - src += src_stride; - // 10 11 12 13 20 21 22 23 - srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); - - int y = 0; - do { - // 30 31 32 33 - const __m128i b = Load4(src); - // 20 21 22 23 30 31 32 33 - srcs[2] = _mm_unpacklo_epi32(srcs[2], b); - src += src_stride; - // 40 41 42 43 - srcs[4] = Load4(src); - src += src_stride; - // 30 31 32 33 40 41 42 43 - srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); - - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - if (is_compound) { - const __m128i results = Compound1DShift(sums); - 
StoreUnaligned16(dst16, results); - dst16 += 4 << 1; - } else { - const __m128i results_16 = - RightShiftWithRounding_S16(sums, kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - Store4(dst8, results); - dst8 += dst_stride; - Store4(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - } - - srcs[0] = srcs[2]; - srcs[1] = srcs[3]; - srcs[2] = srcs[4]; - y += 2; - } while (y < height); - } else if (num_taps == 6) { - srcs[6] = _mm_setzero_si128(); - // 00 01 02 03 - srcs[0] = Load4(src); - src += src_stride; - // 10 11 12 13 - const __m128i a = Load4(src); - // 00 01 02 03 10 11 12 13 - srcs[0] = _mm_unpacklo_epi32(srcs[0], a); - src += src_stride; - // 20 21 22 23 - srcs[2] = Load4(src); - src += src_stride; - // 10 11 12 13 20 21 22 23 - srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); - // 30 31 32 33 - const __m128i b = Load4(src); - // 20 21 22 23 30 31 32 33 - srcs[2] = _mm_unpacklo_epi32(srcs[2], b); - src += src_stride; - // 40 41 42 43 - srcs[4] = Load4(src); - src += src_stride; - // 30 31 32 33 40 41 42 43 - srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); - - int y = 0; - do { - // 50 51 52 53 - const __m128i c = Load4(src); - // 40 41 42 43 50 51 52 53 - srcs[4] = _mm_unpacklo_epi32(srcs[4], c); - src += src_stride; - // 60 61 62 63 - srcs[6] = Load4(src); - src += src_stride; - // 50 51 52 53 60 61 62 63 - srcs[5] = _mm_unpacklo_epi32(c, srcs[6]); - - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - if (is_compound) { - const __m128i results = Compound1DShift(sums); - StoreUnaligned16(dst16, results); - dst16 += 4 << 1; - } else { - const __m128i results_16 = - RightShiftWithRounding_S16(sums, kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - Store4(dst8, results); - dst8 += dst_stride; - Store4(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - } - - srcs[0] = srcs[2]; - srcs[1] = srcs[3]; - srcs[2] = srcs[4]; - srcs[3] = srcs[5]; - srcs[4] = srcs[6]; - y += 2; - } while (y < height); - } else if (num_taps == 8) { - srcs[8] = _mm_setzero_si128(); - // 00 01 02 03 - srcs[0] = Load4(src); - src += src_stride; - // 10 11 12 13 - const __m128i a = Load4(src); - // 00 01 02 03 10 11 12 13 - srcs[0] = _mm_unpacklo_epi32(srcs[0], a); - src += src_stride; - // 20 21 22 23 - srcs[2] = Load4(src); - src += src_stride; - // 10 11 12 13 20 21 22 23 - srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); - // 30 31 32 33 - const __m128i b = Load4(src); - // 20 21 22 23 30 31 32 33 - srcs[2] = _mm_unpacklo_epi32(srcs[2], b); - src += src_stride; - // 40 41 42 43 - srcs[4] = Load4(src); - src += src_stride; - // 30 31 32 33 40 41 42 43 - srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); - // 50 51 52 53 - const __m128i c = Load4(src); - // 40 41 42 43 50 51 52 53 - srcs[4] = _mm_unpacklo_epi32(srcs[4], c); - src += src_stride; - // 60 61 62 63 - srcs[6] = Load4(src); - src += src_stride; - // 50 51 52 53 60 61 62 63 - srcs[5] = _mm_unpacklo_epi32(c, srcs[6]); - - int y = 0; - do { - // 70 71 72 73 - const __m128i d = Load4(src); - // 60 61 62 63 70 71 72 73 - srcs[6] = _mm_unpacklo_epi32(srcs[6], d); - src += src_stride; - // 80 81 82 83 - srcs[8] = Load4(src); - src += src_stride; - // 70 71 72 73 80 81 82 83 - srcs[7] = _mm_unpacklo_epi32(d, srcs[8]); - - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - if (is_compound) { - const __m128i results = Compound1DShift(sums); - StoreUnaligned16(dst16, results); - dst16 += 4 << 1; - } else { - const __m128i results_16 = - RightShiftWithRounding_S16(sums, 
kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - Store4(dst8, results); - dst8 += dst_stride; - Store4(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - } - - srcs[0] = srcs[2]; - srcs[1] = srcs[3]; - srcs[2] = srcs[4]; - srcs[3] = srcs[5]; - srcs[4] = srcs[6]; - srcs[5] = srcs[7]; - srcs[6] = srcs[8]; - y += 2; - } while (y < height); - } -} - -template <int filter_index, bool negative_outside_taps = false> -void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride, - void* const dst, const ptrdiff_t dst_stride, - const int height, const __m128i* const v_tap) { - const int num_taps = GetNumTapsInFilter(filter_index); - auto* dst8 = static_cast<uint8_t*>(dst); - - __m128i srcs[9]; - - if (num_taps == 2) { - srcs[2] = _mm_setzero_si128(); - // 00 01 - srcs[0] = Load2(src); - src += src_stride; - - int y = 0; - do { - // 00 01 10 11 - srcs[0] = Load2<1>(src, srcs[0]); - src += src_stride; - // 00 01 10 11 20 21 - srcs[0] = Load2<2>(src, srcs[0]); - src += src_stride; - // 00 01 10 11 20 21 30 31 - srcs[0] = Load2<3>(src, srcs[0]); - src += src_stride; - // 40 41 - srcs[2] = Load2<0>(src, srcs[2]); - src += src_stride; - // 00 01 10 11 20 21 30 31 40 41 - const __m128i srcs_0_2 = _mm_unpacklo_epi64(srcs[0], srcs[2]); - // 10 11 20 21 30 31 40 41 - srcs[1] = _mm_srli_si128(srcs_0_2, 2); - // This uses srcs[0]..srcs[1]. - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - const __m128i results_16 = - RightShiftWithRounding_S16(sums, kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - - Store2(dst8, results); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 2)); - if (height == 2) return; - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 6)); - dst8 += dst_stride; - - srcs[0] = srcs[2]; - y += 4; - } while (y < height); - } else if (num_taps == 4) { - srcs[4] = _mm_setzero_si128(); - - // 00 01 - srcs[0] = Load2(src); - src += src_stride; - // 00 01 10 11 - srcs[0] = Load2<1>(src, srcs[0]); - src += src_stride; - // 00 01 10 11 20 21 - srcs[0] = Load2<2>(src, srcs[0]); - src += src_stride; - - int y = 0; - do { - // 00 01 10 11 20 21 30 31 - srcs[0] = Load2<3>(src, srcs[0]); - src += src_stride; - // 40 41 - srcs[4] = Load2<0>(src, srcs[4]); - src += src_stride; - // 40 41 50 51 - srcs[4] = Load2<1>(src, srcs[4]); - src += src_stride; - // 40 41 50 51 60 61 - srcs[4] = Load2<2>(src, srcs[4]); - src += src_stride; - // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 - const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); - // 10 11 20 21 30 31 40 41 - srcs[1] = _mm_srli_si128(srcs_0_4, 2); - // 20 21 30 31 40 41 50 51 - srcs[2] = _mm_srli_si128(srcs_0_4, 4); - // 30 31 40 41 50 51 60 61 - srcs[3] = _mm_srli_si128(srcs_0_4, 6); - - // This uses srcs[0]..srcs[3]. 
- const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - const __m128i results_16 = - RightShiftWithRounding_S16(sums, kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - - Store2(dst8, results); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 2)); - if (height == 2) return; - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 6)); - dst8 += dst_stride; - - srcs[0] = srcs[4]; - y += 4; - } while (y < height); - } else if (num_taps == 6) { - // During the vertical pass the number of taps is restricted when - // |height| <= 4. - assert(height > 4); - srcs[8] = _mm_setzero_si128(); - - // 00 01 - srcs[0] = Load2(src); - src += src_stride; - // 00 01 10 11 - srcs[0] = Load2<1>(src, srcs[0]); - src += src_stride; - // 00 01 10 11 20 21 - srcs[0] = Load2<2>(src, srcs[0]); - src += src_stride; - // 00 01 10 11 20 21 30 31 - srcs[0] = Load2<3>(src, srcs[0]); - src += src_stride; - // 40 41 - srcs[4] = Load2(src); - src += src_stride; - // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 - const __m128i srcs_0_4x = _mm_unpacklo_epi64(srcs[0], srcs[4]); - // 10 11 20 21 30 31 40 41 - srcs[1] = _mm_srli_si128(srcs_0_4x, 2); - - int y = 0; - do { - // 40 41 50 51 - srcs[4] = Load2<1>(src, srcs[4]); - src += src_stride; - // 40 41 50 51 60 61 - srcs[4] = Load2<2>(src, srcs[4]); - src += src_stride; - // 40 41 50 51 60 61 70 71 - srcs[4] = Load2<3>(src, srcs[4]); - src += src_stride; - // 80 81 - srcs[8] = Load2<0>(src, srcs[8]); - src += src_stride; - // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 - const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); - // 20 21 30 31 40 41 50 51 - srcs[2] = _mm_srli_si128(srcs_0_4, 4); - // 30 31 40 41 50 51 60 61 - srcs[3] = _mm_srli_si128(srcs_0_4, 6); - const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]); - // 50 51 60 61 70 71 80 81 - srcs[5] = _mm_srli_si128(srcs_4_8, 2); - - // This uses srcs[0]..srcs[5]. - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - const __m128i results_16 = - RightShiftWithRounding_S16(sums, kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - - Store2(dst8, results); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 2)); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 6)); - dst8 += dst_stride; - - srcs[0] = srcs[4]; - srcs[1] = srcs[5]; - srcs[4] = srcs[8]; - y += 4; - } while (y < height); - } else if (num_taps == 8) { - // During the vertical pass the number of taps is restricted when - // |height| <= 4. 
- assert(height > 4); - srcs[8] = _mm_setzero_si128(); - // 00 01 - srcs[0] = Load2(src); - src += src_stride; - // 00 01 10 11 - srcs[0] = Load2<1>(src, srcs[0]); - src += src_stride; - // 00 01 10 11 20 21 - srcs[0] = Load2<2>(src, srcs[0]); - src += src_stride; - // 00 01 10 11 20 21 30 31 - srcs[0] = Load2<3>(src, srcs[0]); - src += src_stride; - // 40 41 - srcs[4] = Load2(src); - src += src_stride; - // 40 41 50 51 - srcs[4] = Load2<1>(src, srcs[4]); - src += src_stride; - // 40 41 50 51 60 61 - srcs[4] = Load2<2>(src, srcs[4]); - src += src_stride; - - // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 - const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); - // 10 11 20 21 30 31 40 41 - srcs[1] = _mm_srli_si128(srcs_0_4, 2); - // 20 21 30 31 40 41 50 51 - srcs[2] = _mm_srli_si128(srcs_0_4, 4); - // 30 31 40 41 50 51 60 61 - srcs[3] = _mm_srli_si128(srcs_0_4, 6); - - int y = 0; - do { - // 40 41 50 51 60 61 70 71 - srcs[4] = Load2<3>(src, srcs[4]); - src += src_stride; - // 80 81 - srcs[8] = Load2<0>(src, srcs[8]); - src += src_stride; - // 80 81 90 91 - srcs[8] = Load2<1>(src, srcs[8]); - src += src_stride; - // 80 81 90 91 a0 a1 - srcs[8] = Load2<2>(src, srcs[8]); - src += src_stride; - - // 40 41 50 51 60 61 70 71 80 81 90 91 a0 a1 - const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]); - // 50 51 60 61 70 71 80 81 - srcs[5] = _mm_srli_si128(srcs_4_8, 2); - // 60 61 70 71 80 81 90 91 - srcs[6] = _mm_srli_si128(srcs_4_8, 4); - // 70 71 80 81 90 91 a0 a1 - srcs[7] = _mm_srli_si128(srcs_4_8, 6); - - // This uses srcs[0]..srcs[7]. - const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); - const __m128i results_16 = - RightShiftWithRounding_S16(sums, kFilterBits - 1); - const __m128i results = _mm_packus_epi16(results_16, results_16); - - Store2(dst8, results); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 2)); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 4)); - dst8 += dst_stride; - Store2(dst8, _mm_srli_si128(results, 6)); - dst8 += dst_stride; - - srcs[0] = srcs[4]; - srcs[1] = srcs[5]; - srcs[2] = srcs[6]; - srcs[3] = srcs[7]; - srcs[4] = srcs[8]; - y += 4; - } while (y < height); - } -} - void ConvolveVertical_SSE4_1(const void* const reference, const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/, @@ -1339,9 +424,9 @@ void ConvolveVertical_SSE4_1(const void* const reference, if (filter_index < 2) { // 6 tap. SetupTaps<6>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical2xH<6, 0>(src, src_stride, dest, dest_stride, height, taps); } else if (width == 4) { - FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical4xH<6, 0>(src, src_stride, dest, dest_stride, height, taps); } else { FilterVertical<0>(src, src_stride, dest, dest_stride, width, height, taps); @@ -1349,9 +434,9 @@ void ConvolveVertical_SSE4_1(const void* const reference, } else if (filter_index == 2) { // 8 tap. 
SetupTaps<8>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical2xH<8, 2>(src, src_stride, dest, dest_stride, height, taps); } else if (width == 4) { - FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical4xH<8, 2>(src, src_stride, dest, dest_stride, height, taps); } else { FilterVertical<2>(src, src_stride, dest, dest_stride, width, height, taps); @@ -1359,9 +444,9 @@ void ConvolveVertical_SSE4_1(const void* const reference, } else if (filter_index == 3) { // 2 tap. SetupTaps<2>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical2xH<2, 3>(src, src_stride, dest, dest_stride, height, taps); } else if (width == 4) { - FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical4xH<2, 3>(src, src_stride, dest, dest_stride, height, taps); } else { FilterVertical<3>(src, src_stride, dest, dest_stride, width, height, taps); @@ -1369,9 +454,9 @@ void ConvolveVertical_SSE4_1(const void* const reference, } else if (filter_index == 4) { // 4 tap. SetupTaps<4>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical2xH<4, 4>(src, src_stride, dest, dest_stride, height, taps); } else if (width == 4) { - FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical4xH<4, 4>(src, src_stride, dest, dest_stride, height, taps); } else { FilterVertical<4>(src, src_stride, dest, dest_stride, width, height, taps); @@ -1382,9 +467,9 @@ void ConvolveVertical_SSE4_1(const void* const reference, SetupTaps<4>(&v_filter, taps); if (width == 2) { - FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical2xH<4, 5>(src, src_stride, dest, dest_stride, height, taps); } else if (width == 4) { - FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height, taps); + FilterVertical4xH<4, 5>(src, src_stride, dest, dest_stride, height, taps); } else { FilterVertical<5>(src, src_stride, dest, dest_stride, width, height, taps); @@ -1474,8 +559,8 @@ void ConvolveCompoundVertical_SSE4_1( if (filter_index < 2) { // 6 tap. 
SetupTaps<6>(&v_filter, taps); if (width == 4) { - FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4, - height, taps); + FilterVertical4xH<6, 0, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); } else { FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width, width, height, taps); @@ -1484,8 +569,8 @@ void ConvolveCompoundVertical_SSE4_1( SetupTaps<8>(&v_filter, taps); if (width == 4) { - FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4, - height, taps); + FilterVertical4xH<8, 2, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); } else { FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width, width, height, taps); @@ -1494,8 +579,8 @@ void ConvolveCompoundVertical_SSE4_1( SetupTaps<2>(&v_filter, taps); if (width == 4) { - FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4, - height, taps); + FilterVertical4xH<2, 3, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); } else { FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width, width, height, taps); @@ -1504,8 +589,8 @@ void ConvolveCompoundVertical_SSE4_1( SetupTaps<4>(&v_filter, taps); if (width == 4) { - FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4, - height, taps); + FilterVertical4xH<4, 4, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); } else { FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width, width, height, taps); @@ -1514,8 +599,8 @@ void ConvolveCompoundVertical_SSE4_1( SetupTaps<4>(&v_filter, taps); if (width == 4) { - FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4, - height, taps); + FilterVertical4xH<4, 5, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); } else { FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width, width, height, taps); @@ -1752,7 +837,11 @@ inline void GetHalfSubPixelFilter(__m128i* output) { template <int num_taps, int grade_x> inline void PrepareSourceVectors(const uint8_t* src, const __m128i src_indices, __m128i* const source /*[num_taps >> 1]*/) { - const __m128i src_vals = LoadUnaligned16(src); + // |used_bytes| is only computed in msan builds. Mask away unused bytes for + // msan because it incorrectly models the outcome of the shuffles in some + // cases. This has not been reproduced out of context. + const int used_bytes = _mm_extract_epi8(src_indices, 15) + 1 + num_taps - 2; + const __m128i src_vals = LoadUnaligned16Msan(src, 16 - used_bytes); source[0] = _mm_shuffle_epi8(src_vals, src_indices); if (grade_x == 1) { if (num_taps > 2) { @@ -1768,7 +857,7 @@ inline void PrepareSourceVectors(const uint8_t* src, const __m128i src_indices, assert(grade_x > 1); assert(num_taps != 4); // grade_x > 1 also means width >= 8 && num_taps != 4 - const __m128i src_vals_ext = LoadLo8(src + 16); + const __m128i src_vals_ext = LoadLo8Msan(src + 16, 24 - used_bytes); if (num_taps > 2) { source[1] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 2), src_indices); @@ -1983,14 +1072,10 @@ __m128i Sum2DVerticalTaps4x2(const __m128i* const src, const __m128i* taps_lo, // |width_class| is 2, 4, or 8, according to the Store function that should be // used. 
template <int num_taps, int width_class, bool is_compound> -#if LIBGAV1_MSAN -__attribute__((no_sanitize_memory)) void ConvolveVerticalScale( -#else -inline void ConvolveVerticalScale( -#endif - const int16_t* src, const int width, const int subpixel_y, - const int filter_index, const int step_y, const int height, void* dest, - const ptrdiff_t dest_stride) { +inline void ConvolveVerticalScale(const int16_t* src, const int width, + const int subpixel_y, const int filter_index, + const int step_y, const int height, + void* dest, const ptrdiff_t dest_stride) { constexpr ptrdiff_t src_stride = kIntermediateStride; constexpr int kernel_offset = (8 - num_taps) / 2; const int16_t* src_y = src; @@ -2819,7 +1904,7 @@ void ConvolveInit_SSE4_1() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/convolve_sse4.inc b/src/dsp/x86/convolve_sse4.inc new file mode 100644 index 0000000..550d6a4 --- /dev/null +++ b/src/dsp/x86/convolve_sse4.inc @@ -0,0 +1,934 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Common 128 bit functions used for sse4/avx2 convolve implementations. +// This will be included inside an anonymous namespace on files where these are +// necessary. + +#include "src/dsp/convolve.inc" + +// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and +// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final +// sum from outranging int16_t. +template <int filter_index> +__m128i SumOnePassTaps(const __m128i* const src, const __m128i* const taps) { + __m128i sum; + if (filter_index < 2) { + // 6 taps. + const __m128i v_madd_21 = _mm_maddubs_epi16(src[0], taps[0]); // k2k1 + const __m128i v_madd_43 = _mm_maddubs_epi16(src[1], taps[1]); // k4k3 + const __m128i v_madd_65 = _mm_maddubs_epi16(src[2], taps[2]); // k6k5 + sum = _mm_add_epi16(v_madd_21, v_madd_43); + sum = _mm_add_epi16(sum, v_madd_65); + } else if (filter_index == 2) { + // 8 taps. + const __m128i v_madd_10 = _mm_maddubs_epi16(src[0], taps[0]); // k1k0 + const __m128i v_madd_32 = _mm_maddubs_epi16(src[1], taps[1]); // k3k2 + const __m128i v_madd_54 = _mm_maddubs_epi16(src[2], taps[2]); // k5k4 + const __m128i v_madd_76 = _mm_maddubs_epi16(src[3], taps[3]); // k7k6 + const __m128i v_sum_3210 = _mm_add_epi16(v_madd_10, v_madd_32); + const __m128i v_sum_7654 = _mm_add_epi16(v_madd_54, v_madd_76); + sum = _mm_add_epi16(v_sum_7654, v_sum_3210); + } else if (filter_index == 3) { + // 2 taps. + sum = _mm_maddubs_epi16(src[0], taps[0]); // k4k3 + } else { + // 4 taps. 
+ const __m128i v_madd_32 = _mm_maddubs_epi16(src[0], taps[0]); // k3k2 + const __m128i v_madd_54 = _mm_maddubs_epi16(src[1], taps[1]); // k5k4 + sum = _mm_add_epi16(v_madd_32, v_madd_54); + } + return sum; +} + +template <int filter_index> +__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + // 00 01 02 03 04 05 06 07 10 11 12 13 14 15 16 17 + const __m128i v_src = LoadHi8(LoadLo8(&src[0]), &src[src_stride]); + + if (filter_index == 3) { + // 03 04 04 05 05 06 06 07 13 14 14 15 15 16 16 17 + const __m128i v_src_43 = _mm_shuffle_epi8( + v_src, _mm_set_epi32(0x0f0e0e0d, 0x0d0c0c0b, 0x07060605, 0x05040403)); + const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 + return v_sum_43; + } + + // 02 03 03 04 04 05 05 06 12 13 13 14 14 15 15 16 + const __m128i v_src_32 = _mm_shuffle_epi8( + v_src, _mm_set_epi32(0x0e0d0d0c, 0x0c0b0b0a, 0x06050504, 0x04030302)); + // 04 05 05 06 06 07 07 xx 14 15 15 16 16 17 17 xx + const __m128i v_src_54 = _mm_shuffle_epi8( + v_src, _mm_set_epi32(static_cast<int>(0x800f0f0e), 0x0e0d0d0c, + static_cast<int>(0x80070706), 0x06050504)); + const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 + const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 + const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32); + return v_sum_5432; +} + +template <int filter_index> +__m128i SimpleHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. 
+ constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit)); + sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); + return _mm_packus_epi16(sum, sum); +} + +template <int filter_index> +__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + const __m128i sum = + SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +template <int num_taps, bool is_2d_vertical = false> +LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, + __m128i* v_tap) { + if (num_taps == 8) { + v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0); // k1k0 + v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 + v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 + v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff); // k7k6 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]); + } + } else if (num_taps == 6) { + const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); + v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0); // k2k1 + v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 + v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa); // k6k5 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + } + } else if (num_taps == 4) { + v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 + v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + } + } else { // num_taps == 2 + const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); + v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + } + } +} + +template <int num_taps, bool is_compound> +__m128i SimpleSum2DVerticalTaps(const __m128i* const src, + const __m128i* const taps) { + __m128i sum_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[0], src[1]), taps[0]); + __m128i sum_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[0], src[1]), taps[0]); + if (num_taps >= 4) { + __m128i madd_lo = + _mm_madd_epi16(_mm_unpacklo_epi16(src[2], src[3]), taps[1]); + __m128i madd_hi = + _mm_madd_epi16(_mm_unpackhi_epi16(src[2], src[3]), taps[1]); + sum_lo = _mm_add_epi32(sum_lo, madd_lo); + sum_hi = _mm_add_epi32(sum_hi, madd_hi); + if (num_taps >= 6) { + madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[4], src[5]), taps[2]); + madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[4], src[5]), taps[2]); + sum_lo = _mm_add_epi32(sum_lo, madd_lo); + sum_hi = _mm_add_epi32(sum_hi, madd_hi); + if (num_taps == 8) { + madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[6], 
src[7]), taps[3]); + madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[6], src[7]), taps[3]); + sum_lo = _mm_add_epi32(sum_lo, madd_lo); + sum_hi = _mm_add_epi32(sum_hi, madd_hi); + } + } + } + + if (is_compound) { + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1), + RightShiftWithRounding_S32(sum_hi, + kInterRoundBitsCompoundVertical - 1)); + } + + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1), + RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1)); +} + +template <int num_taps, bool is_compound = false> +void Filter2DVertical(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int width, + const int height, const __m128i* const taps) { + assert(width >= 8); + constexpr int next_row = num_taps - 1; + // The Horizontal pass uses |width| as |stride| for the intermediate buffer. + const ptrdiff_t src_stride = width; + + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int x = 0; + do { + __m128i srcs[8]; + const uint16_t* src_x = src + x; + srcs[0] = LoadAligned16(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadAligned16(src_x); + src_x += src_stride; + srcs[2] = LoadAligned16(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadAligned16(src_x); + src_x += src_stride; + srcs[4] = LoadAligned16(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadAligned16(src_x); + src_x += src_stride; + srcs[6] = LoadAligned16(src_x); + src_x += src_stride; + } + } + } + + auto* dst8_x = dst8 + x; + auto* dst16_x = dst16 + x; + int y = height; + do { + srcs[next_row] = LoadAligned16(src_x); + src_x += src_stride; + + const __m128i sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + StoreUnaligned16(dst16_x, sum); + dst16_x += dst_stride; + } else { + StoreLo8(dst8_x, _mm_packus_epi16(sum, sum)); + dst8_x += dst_stride; + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (--y != 0); + x += 8; + } while (x < width); +} + +// Take advantage of |src_stride| == |width| to process two rows at a time. 
+template <int num_taps, bool is_compound = false> +void Filter2DVertical4xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const __m128i* const taps) { + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + __m128i srcs[9]; + srcs[0] = LoadAligned16(src); + src += 8; + if (num_taps >= 4) { + srcs[2] = LoadAligned16(src); + src += 8; + srcs[1] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[0], 8), srcs[2]); + if (num_taps >= 6) { + srcs[4] = LoadAligned16(src); + src += 8; + srcs[3] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[2], 8), srcs[4]); + if (num_taps == 8) { + srcs[6] = LoadAligned16(src); + src += 8; + srcs[5] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[4], 8), srcs[6]); + } + } + } + + int y = height; + do { + srcs[num_taps] = LoadAligned16(src); + src += 8; + srcs[num_taps - 1] = _mm_unpacklo_epi64( + _mm_srli_si128(srcs[num_taps - 2], 8), srcs[num_taps]); + + const __m128i sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + StoreUnaligned16(dst16, sum); + dst16 += 4 << 1; + } else { + const __m128i results = _mm_packus_epi16(sum, sum); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y -= 2; + } while (y != 0); +} + +// Take advantage of |src_stride| == |width| to process four rows at a time. +template <int num_taps> +void Filter2DVertical2xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const __m128i* const taps) { + constexpr int next_row = (num_taps < 6) ? 4 : 8; + + auto* dst8 = static_cast<uint8_t*>(dst); + + __m128i srcs[9]; + srcs[0] = LoadAligned16(src); + src += 8; + if (num_taps >= 6) { + srcs[4] = LoadAligned16(src); + src += 8; + srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); + if (num_taps == 8) { + srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); + srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); + } + } + + int y = height; + do { + srcs[next_row] = LoadAligned16(src); + src += 8; + if (num_taps == 2) { + srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); + } else if (num_taps == 4) { + srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); + srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); + srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); + } else if (num_taps == 6) { + srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); + srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); + srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4); + } else if (num_taps == 8) { + srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4); + srcs[6] = _mm_alignr_epi8(srcs[8], srcs[4], 8); + srcs[7] = _mm_alignr_epi8(srcs[8], srcs[4], 12); + } + + const __m128i sum = + SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps); + const __m128i results = _mm_packus_epi16(sum, sum); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + // When |height| <= 4 the taps are restricted to 2 and 4 tap variants. + // Therefore we don't need to check this condition when |height| > 4. 
+ if (num_taps <= 4 && height == 2) return; + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + if (num_taps == 6) { + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + } else if (num_taps == 8) { + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + } + + y -= 4; + } while (y != 0); +} + +// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D +// Vertical calculations. +__m128i Compound1DShift(const __m128i sum) { + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +template <int filter_index> +__m128i SumVerticalTaps(const __m128i* const srcs, const __m128i* const v_tap) { + __m128i v_src[4]; + + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]); + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]); + v_src[3] = _mm_unpacklo_epi8(srcs[6], srcs[7]); + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + } else if (filter_index > 3) { + // 4 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); + } + const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap); + return sum; +} + +// TODO(slavarnway): Use num_taps instead of filter_index for templates. See the +// 2D version. +template <int num_taps, int filter_index, bool is_compound = false> +void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const __m128i* const v_tap) { + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + __m128i srcs[9]; + + if (num_taps == 2) { + srcs[2] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + + int y = height; + do { + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + y -= 2; + } while (y != 0); + } else if (num_taps == 4) { + srcs[4] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + + int y = height; + do { + // 30 31 32 33 + const __m128i b = Load4(src); + // 20 21 22 23 30 31 32 33 + srcs[2] = _mm_unpacklo_epi32(srcs[2], b); + src += src_stride; + // 40 41 42 43 
+ srcs[4] = Load4(src); + src += src_stride; + // 30 31 32 33 40 41 42 43 + srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + y -= 2; + } while (y != 0); + } else if (num_taps == 6) { + srcs[6] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + // 30 31 32 33 + const __m128i b = Load4(src); + // 20 21 22 23 30 31 32 33 + srcs[2] = _mm_unpacklo_epi32(srcs[2], b); + src += src_stride; + // 40 41 42 43 + srcs[4] = Load4(src); + src += src_stride; + // 30 31 32 33 40 41 42 43 + srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); + + int y = height; + do { + // 50 51 52 53 + const __m128i c = Load4(src); + // 40 41 42 43 50 51 52 53 + srcs[4] = _mm_unpacklo_epi32(srcs[4], c); + src += src_stride; + // 60 61 62 63 + srcs[6] = Load4(src); + src += src_stride; + // 50 51 52 53 60 61 62 63 + srcs[5] = _mm_unpacklo_epi32(c, srcs[6]); + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + y -= 2; + } while (y != 0); + } else if (num_taps == 8) { + srcs[8] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + // 30 31 32 33 + const __m128i b = Load4(src); + // 20 21 22 23 30 31 32 33 + srcs[2] = _mm_unpacklo_epi32(srcs[2], b); + src += src_stride; + // 40 41 42 43 + srcs[4] = Load4(src); + src += src_stride; + // 30 31 32 33 40 41 42 43 + srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); + // 50 51 52 53 + const __m128i c = Load4(src); + // 40 41 42 43 50 51 52 53 + srcs[4] = _mm_unpacklo_epi32(srcs[4], c); + src += src_stride; + // 60 61 62 63 + srcs[6] = Load4(src); + src += src_stride; + // 50 51 52 53 60 61 62 63 + srcs[5] = _mm_unpacklo_epi32(c, srcs[6]); + + int y = height; + do { + // 70 71 72 73 + const __m128i d = Load4(src); + // 60 61 62 63 70 71 72 73 + srcs[6] = _mm_unpacklo_epi32(srcs[6], d); + src += src_stride; + // 80 81 82 83 + srcs[8] = Load4(src); + src += src_stride; + // 70 71 72 73 80 81 82 83 + srcs[7] = _mm_unpacklo_epi32(d, srcs[8]); + + const 
__m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + y -= 2; + } while (y != 0); + } +} + +template <int num_taps, int filter_index, bool negative_outside_taps = false> +void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const __m128i* const v_tap) { + auto* dst8 = static_cast<uint8_t*>(dst); + + __m128i srcs[9]; + + if (num_taps == 2) { + srcs[2] = _mm_setzero_si128(); + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + + int y = height; + do { + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[2] = Load2<0>(src, srcs[2]); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 + const __m128i srcs_0_2 = _mm_unpacklo_epi64(srcs[0], srcs[2]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_2, 2); + // This uses srcs[0]..srcs[1]. + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + if (height == 2) return; + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[2]; + y -= 4; + } while (y != 0); + } else if (num_taps == 4) { + srcs[4] = _mm_setzero_si128(); + + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + + int y = height; + do { + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[4] = Load2<0>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_4, 2); + // 20 21 30 31 40 41 50 51 + srcs[2] = _mm_srli_si128(srcs_0_4, 4); + // 30 31 40 41 50 51 60 61 + srcs[3] = _mm_srli_si128(srcs_0_4, 6); + + // This uses srcs[0]..srcs[3]. 
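+      // srcs[1]..srcs[3] are byte-shifted views of the packed registers, so
+      // only the newly arriving 2-pixel rows are loaded each iteration.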
+ const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + if (height == 2) return; + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + y -= 4; + } while (y != 0); + } else if (num_taps == 6) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. + assert(height > 4); + srcs[8] = _mm_setzero_si128(); + + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[4] = Load2(src); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4x = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_4x, 2); + + int y = height; + do { + // 40 41 50 51 + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 70 71 + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + // 80 81 + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 20 21 30 31 40 41 50 51 + srcs[2] = _mm_srli_si128(srcs_0_4, 4); + // 30 31 40 41 50 51 60 61 + srcs[3] = _mm_srli_si128(srcs_0_4, 6); + const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]); + // 50 51 60 61 70 71 80 81 + srcs[5] = _mm_srli_si128(srcs_4_8, 2); + + // This uses srcs[0]..srcs[5]. + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + y -= 4; + } while (y != 0); + } else if (num_taps == 8) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. 
+ assert(height > 4); + srcs[8] = _mm_setzero_si128(); + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[4] = Load2(src); + src += src_stride; + // 40 41 50 51 + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_4, 2); + // 20 21 30 31 40 41 50 51 + srcs[2] = _mm_srli_si128(srcs_0_4, 4); + // 30 31 40 41 50 51 60 61 + srcs[3] = _mm_srli_si128(srcs_0_4, 6); + + int y = height; + do { + // 40 41 50 51 60 61 70 71 + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + // 80 81 + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + // 80 81 90 91 + srcs[8] = Load2<1>(src, srcs[8]); + src += src_stride; + // 80 81 90 91 a0 a1 + srcs[8] = Load2<2>(src, srcs[8]); + src += src_stride; + + // 40 41 50 51 60 61 70 71 80 81 90 91 a0 a1 + const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]); + // 50 51 60 61 70 71 80 81 + srcs[5] = _mm_srli_si128(srcs_4_8, 2); + // 60 61 70 71 80 81 90 91 + srcs[6] = _mm_srli_si128(srcs_4_8, 4); + // 70 71 80 81 90 91 a0 a1 + srcs[7] = _mm_srli_si128(srcs_4_8, 6); + + // This uses srcs[0]..srcs[7]. + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + y -= 4; + } while (y != 0); + } +} diff --git a/src/dsp/x86/distance_weighted_blend_sse4.cc b/src/dsp/x86/distance_weighted_blend_sse4.cc index deb57ef..3c29b19 100644 --- a/src/dsp/x86/distance_weighted_blend_sse4.cc +++ b/src/dsp/x86/distance_weighted_blend_sse4.cc @@ -30,6 +30,7 @@ namespace libgav1 { namespace dsp { +namespace low_bitdepth { namespace { constexpr int kInterPostRoundBit = 4; @@ -212,13 +213,231 @@ void Init8bpp() { } } // namespace +} // namespace low_bitdepth -void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +constexpr int kMax10bppSample = (1 << 10) - 1; +constexpr int kInterPostRoundBit = 4; + +inline __m128i ComputeWeightedAverage8(const __m128i& pred0, + const __m128i& pred1, + const __m128i& weight0, + const __m128i& weight1) { + // This offset is a combination of round_factor and round_offset + // which are to be added and subtracted respectively. + // Here kInterPostRoundBit + 4 is considering bitdepth=10. 
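+  // The two blend weights sum to 16, so the kCompoundOffset embedded in each
+  // prediction combines to kCompoundOffset << 4, and the rounding term for
+  // the final arithmetic shift is 1 << (kInterPostRoundBit + 4 - 1).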
+ constexpr int offset = + (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4); + const __m128i zero = _mm_setzero_si128(); + const __m128i bias = _mm_set1_epi32(offset); + const __m128i clip_high = _mm_set1_epi16(kMax10bppSample); + + __m128i prediction0 = _mm_cvtepu16_epi32(pred0); + __m128i mult0 = _mm_mullo_epi32(prediction0, weight0); + __m128i prediction1 = _mm_cvtepu16_epi32(pred1); + __m128i mult1 = _mm_mullo_epi32(prediction1, weight1); + __m128i sum = _mm_add_epi32(mult0, mult1); + sum = _mm_add_epi32(sum, bias); + const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4); + + prediction0 = _mm_unpackhi_epi16(pred0, zero); + mult0 = _mm_mullo_epi32(prediction0, weight0); + prediction1 = _mm_unpackhi_epi16(pred1, zero); + mult1 = _mm_mullo_epi32(prediction1, weight1); + sum = _mm_add_epi32(mult0, mult1); + sum = _mm_add_epi32(sum, bias); + const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4); + const __m128i pack = _mm_packus_epi32(result0, result1); + + return _mm_min_epi16(pack, clip_high); +} + +template <int height> +inline void DistanceWeightedBlend4xH_SSE4_1( + const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0, + const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint16_t*>(dest); + const __m128i weight0 = _mm_set1_epi32(weight_0); + const __m128i weight1 = _mm_set1_epi32(weight_1); + + int y = height; + do { + const __m128i src_00 = LoadLo8(pred_0); + const __m128i src_10 = LoadLo8(pred_1); + pred_0 += 4; + pred_1 += 4; + __m128i src_0 = LoadHi8(src_00, pred_0); + __m128i src_1 = LoadHi8(src_10, pred_1); + pred_0 += 4; + pred_1 += 4; + const __m128i res0 = + ComputeWeightedAverage8(src_0, src_1, weight0, weight1); + + const __m128i src_01 = LoadLo8(pred_0); + const __m128i src_11 = LoadLo8(pred_1); + pred_0 += 4; + pred_1 += 4; + src_0 = LoadHi8(src_01, pred_0); + src_1 = LoadHi8(src_11, pred_1); + pred_0 += 4; + pred_1 += 4; + const __m128i res1 = + ComputeWeightedAverage8(src_0, src_1, weight0, weight1); + + StoreLo8(dst, res0); + dst += dest_stride; + StoreHi8(dst, res0); + dst += dest_stride; + StoreLo8(dst, res1); + dst += dest_stride; + StoreHi8(dst, res1); + dst += dest_stride; + y -= 4; + } while (y != 0); +} + +template <int height> +inline void DistanceWeightedBlend8xH_SSE4_1( + const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0, + const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint16_t*>(dest); + const __m128i weight0 = _mm_set1_epi32(weight_0); + const __m128i weight1 = _mm_set1_epi32(weight_1); + + int y = height; + do { + const __m128i src_00 = LoadAligned16(pred_0); + const __m128i src_10 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; + const __m128i res0 = + ComputeWeightedAverage8(src_00, src_10, weight0, weight1); + + const __m128i src_01 = LoadAligned16(pred_0); + const __m128i src_11 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; + const __m128i res1 = + ComputeWeightedAverage8(src_01, src_11, weight0, weight1); + + StoreUnaligned16(dst, res0); + dst += dest_stride; + StoreUnaligned16(dst, res1); + dst += dest_stride; + y -= 2; + } while (y != 0); +} + +inline void DistanceWeightedBlendLarge_SSE4_1( + const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0, + const uint8_t weight_1, const int width, const int height, void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint16_t*>(dest); + const __m128i weight0 = 
_mm_set1_epi32(weight_0); + const __m128i weight1 = _mm_set1_epi32(weight_1); + + int y = height; + do { + int x = 0; + do { + const __m128i src_0_lo = LoadAligned16(pred_0 + x); + const __m128i src_1_lo = LoadAligned16(pred_1 + x); + const __m128i res_lo = + ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1); + + const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8); + const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8); + const __m128i res_hi = + ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1); + + StoreUnaligned16(dst + x, res_lo); + x += 8; + StoreUnaligned16(dst + x, res_hi); + x += 8; + } while (x < width); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + } while (--y != 0); +} + +void DistanceWeightedBlend_SSE4_1(const void* prediction_0, + const void* prediction_1, + const uint8_t weight_0, + const uint8_t weight_1, const int width, + const int height, void* const dest, + const ptrdiff_t dest_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0); + if (width == 4) { + if (height == 4) { + DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, + dest, dst_stride); + } else if (height == 8) { + DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, + dest, dst_stride); + } else { + assert(height == 16); + DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, + dest, dst_stride); + } + return; + } + + if (width == 8) { + switch (height) { + case 4: + DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, + dest, dst_stride); + return; + case 8: + DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, + dest, dst_stride); + return; + case 16: + DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, + dest, dst_stride); + return; + default: + assert(height == 32); + DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1, + dest, dst_stride); + + return; + } + } + + DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width, + height, dest, dst_stride); +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); +#if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend) + dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1; +#endif +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void DistanceWeightedBlendInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/distance_weighted_blend_sse4.h b/src/dsp/x86/distance_weighted_blend_sse4.h index 8646eca..dbb9f88 100644 --- a/src/dsp/x86/distance_weighted_blend_sse4.h +++ b/src/dsp/x86/distance_weighted_blend_sse4.h @@ -36,6 +36,10 @@ void DistanceWeightedBlendInit_SSE4_1(); #define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_CPU_SSE4_1 #endif +#ifndef LIBGAV1_Dsp10bpp_DistanceWeightedBlend +#define LIBGAV1_Dsp10bpp_DistanceWeightedBlend LIBGAV1_CPU_SSE4_1 +#endif + #endif // LIBGAV1_TARGETING_SSE4_1 #endif // LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_ diff --git a/src/dsp/x86/film_grain_sse4.cc b/src/dsp/x86/film_grain_sse4.cc new file mode 100644 index 0000000..745c1ca --- 
/dev/null +++ b/src/dsp/x86/film_grain_sse4.cc @@ -0,0 +1,514 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/film_grain.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 +#include <smmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/film_grain_common.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" +#include "src/utils/logging.h" + +namespace libgav1 { +namespace dsp { +namespace film_grain { +namespace { + +// Load 8 values from source, widening to int16_t intermediate value size. +// The function is overloaded for each type and bitdepth for simplicity. +inline __m128i LoadSource(const int8_t* src) { + return _mm_cvtepi8_epi16(LoadLo8(src)); +} + +// Load 8 values from source, widening to int16_t intermediate value size. +inline __m128i LoadSource(const uint8_t* src) { + return _mm_cvtepu8_epi16(LoadLo8(src)); +} + +inline __m128i LoadSourceMsan(const uint8_t* src, const int valid_range) { + return _mm_cvtepu8_epi16(LoadLo8Msan(src, 8 - valid_range)); +} + +// Store 8 values to dest, narrowing to uint8_t from int16_t intermediate value. +inline void StoreUnsigned(uint8_t* dest, const __m128i data) { + StoreLo8(dest, _mm_packus_epi16(data, data)); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +// Load 8 values from source. +inline __m128i LoadSource(const int16_t* src) { return LoadUnaligned16(src); } + +// Load 8 values from source. +inline __m128i LoadSource(const uint16_t* src) { return LoadUnaligned16(src); } + +// Store 8 values to dest. +inline void StoreUnsigned(uint16_t* dest, const __m128i data) { + StoreUnaligned16(dest, data); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed. +inline __m128i GetAverageLuma(const uint8_t* const luma, int subsampling_x) { + if (subsampling_x != 0) { + const __m128i src = LoadUnaligned16(luma); + + return RightShiftWithRounding_U16( + _mm_hadd_epi16(_mm_cvtepu8_epi16(src), + _mm_unpackhi_epi8(src, _mm_setzero_si128())), + 1); + } + return _mm_cvtepu8_epi16(LoadLo8(luma)); +} + +inline __m128i GetAverageLumaMsan(const uint8_t* const luma, int subsampling_x, + int valid_range) { + if (subsampling_x != 0) { + const __m128i src = LoadUnaligned16Msan(luma, 16 - valid_range); + + return RightShiftWithRounding_U16( + _mm_hadd_epi16(_mm_cvtepu8_epi16(src), + _mm_unpackhi_epi8(src, _mm_setzero_si128())), + 1); + } + return _mm_cvtepu8_epi16(LoadLo8Msan(luma, 8 - valid_range)); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed. 
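+// With |subsampling_x| set, adjacent luma pairs are averaged with rounding,
+// i.e. (a + b + 1) >> 1, via a horizontal add followed by a rounding shift.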
+inline __m128i GetAverageLuma(const uint16_t* const luma, int subsampling_x) { + if (subsampling_x != 0) { + return RightShiftWithRounding_U16( + _mm_hadd_epi16(LoadUnaligned16(luma), LoadUnaligned16(luma + 8)), 1); + } + return LoadUnaligned16(luma); +} + +inline __m128i GetAverageLumaMsan(const uint16_t* const luma, int subsampling_x, + int valid_range) { + if (subsampling_x != 0) { + return RightShiftWithRounding_U16( + _mm_hadd_epi16( + LoadUnaligned16Msan(luma, 16 - valid_range * sizeof(*luma)), + LoadUnaligned16Msan(luma + 8, 32 - valid_range * sizeof(*luma))), + 1); + } + return LoadUnaligned16Msan(luma, 16 - valid_range * sizeof(*luma)); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +inline __m128i Clip3(const __m128i value, const __m128i low, + const __m128i high) { + const __m128i clipped_to_ceiling = _mm_min_epi16(high, value); + return _mm_max_epi16(low, clipped_to_ceiling); +} + +template <int bitdepth, typename Pixel> +inline __m128i GetScalingFactors( + const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* source) { + alignas(16) int16_t start_vals[8]; + if (bitdepth == 8) { + // TODO(petersonab): Speed this up by creating a uint16_t scaling_lut. + // Currently this code results in a series of movzbl. + for (int i = 0; i < 8; ++i) { + start_vals[i] = scaling_lut[source[i]]; + } + return LoadAligned16(start_vals); + } + alignas(16) int16_t end_vals[8]; + // TODO(petersonab): Precompute this into a larger table for direct lookups. + for (int i = 0; i < 8; ++i) { + const int index = source[i] >> 2; + start_vals[i] = scaling_lut[index]; + end_vals[i] = scaling_lut[index + 1]; + } + const __m128i start = LoadAligned16(start_vals); + const __m128i end = LoadAligned16(end_vals); + __m128i remainder = LoadSource(source); + remainder = _mm_srli_epi16(_mm_slli_epi16(remainder, 14), 1); + const __m128i delta = _mm_mulhrs_epi16(_mm_sub_epi16(end, start), remainder); + return _mm_add_epi16(start, delta); +} + +// |scaling_shift| is in range [8,11]. +template <int bitdepth> +inline __m128i ScaleNoise(const __m128i noise, const __m128i scaling, + const __m128i scaling_shift) { + const __m128i shifted_scale_factors = _mm_sll_epi16(scaling, scaling_shift); + return _mm_mulhrs_epi16(noise, shifted_scale_factors); +} + +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageLuma_SSE4_1( + const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift, + int width, int height, int start_height, + const uint8_t scaling_lut_y[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y, + ptrdiff_t dest_stride_y) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y_row = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + auto* out_y_row = static_cast<Pixel*>(dest_plane_y); + dest_stride_y /= sizeof(Pixel); + const __m128i floor = _mm_set1_epi16(min_value); + const __m128i ceiling = _mm_set1_epi16(max_luma); + const int safe_width = width & ~7; + const __m128i derived_scaling_shift = _mm_cvtsi32_si128(15 - scaling_shift); + int y = 0; + do { + int x = 0; + for (; x < safe_width; x += 8) { + // TODO(b/133525232): Make 16-pixel version of loop body. 
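+      // Each iteration handles 8 pixels: the noise is scaled by the
+      // piecewise-linear scaling function evaluated at the source luma, added
+      // to the source, and clipped to [min_value, max_luma].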
+ const __m128i orig = LoadSource(&in_y_row[x]); + const __m128i scaling = + GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, &in_y_row[x]); + __m128i noise = LoadSource(&(noise_image[kPlaneY][y + start_height][x])); + + noise = ScaleNoise<bitdepth>(noise, scaling, derived_scaling_shift); + const __m128i combined = _mm_add_epi16(orig, noise); + StoreUnsigned(&out_y_row[x], Clip3(combined, floor, ceiling)); + } + + if (x < width) { + Pixel luma_buffer[8]; + // Prevent arbitrary indices from entering GetScalingFactors. + memset(luma_buffer, 0, sizeof(luma_buffer)); + const int valid_range = width - x; + memcpy(luma_buffer, &in_y_row[x], valid_range * sizeof(in_y_row[0])); + luma_buffer[valid_range] = in_y_row[width - 1]; + const __m128i orig = LoadSource(&in_y_row[x]); + const __m128i scaling = + GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, luma_buffer); + __m128i noise = LoadSource(&(noise_image[kPlaneY][y + start_height][x])); + + noise = ScaleNoise<bitdepth>(noise, scaling, derived_scaling_shift); + const __m128i combined = _mm_add_epi16(orig, noise); + StoreUnsigned(&out_y_row[x], Clip3(combined, floor, ceiling)); + } + in_y_row += source_stride_y; + out_y_row += dest_stride_y; + } while (++y < height); + out_y_row = static_cast<Pixel*>(dest_plane_y); +} + +template <int bitdepth, typename GrainType, typename Pixel> +inline __m128i BlendChromaValsWithCfl( + const Pixel* average_luma_buffer, + const uint8_t scaling_lut[kScalingLookupTableSize], + const Pixel* chroma_cursor, const GrainType* noise_image_cursor, + const __m128i scaling_shift) { + const __m128i scaling = + GetScalingFactors<bitdepth, Pixel>(scaling_lut, average_luma_buffer); + const __m128i orig = LoadSource(chroma_cursor); + __m128i noise = LoadSource(noise_image_cursor); + noise = ScaleNoise<bitdepth>(noise, scaling, scaling_shift); + return _mm_add_epi16(orig, noise); +} + +template <int bitdepth, typename GrainType, typename Pixel> +LIBGAV1_ALWAYS_INLINE void BlendChromaPlaneWithCfl_SSE4_1( + const Array2D<GrainType>& noise_image, int min_value, int max_chroma, + int width, int height, int start_height, int subsampling_x, + int subsampling_y, int scaling_shift, + const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* in_y_row, + ptrdiff_t source_stride_y, const Pixel* in_chroma_row, + ptrdiff_t source_stride_chroma, Pixel* out_chroma_row, + ptrdiff_t dest_stride) { + const __m128i floor = _mm_set1_epi16(min_value); + const __m128i ceiling = _mm_set1_epi16(max_chroma); + alignas(16) Pixel luma_buffer[16]; + + const int chroma_height = (height + subsampling_y) >> subsampling_y; + const int chroma_width = (width + subsampling_x) >> subsampling_x; + // |chroma_width| is rounded up. If |width| is odd, then the final pixel will + // need to be guarded from overread, even if |chroma_width| is divisible by 8. + const int safe_chroma_width = (chroma_width - (width & 1)) & ~7; + + // Writing to this buffer avoids the cost of doing 8 lane lookups in a row + // in GetScalingFactors. + Pixel average_luma_buffer[8]; + assert(start_height % 2 == 0); + start_height >>= subsampling_y; + const __m128i derived_scaling_shift = _mm_cvtsi32_si128(15 - scaling_shift); + int y = 0; + do { + int x = 0; + for (; x < safe_chroma_width; x += 8) { + const int luma_x = x << subsampling_x; + // TODO(petersonab): Consider specializing by subsampling_x. In the 444 + // case &in_y_row[x] can be passed to GetScalingFactors directly. 
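+      // The averaged luma is spilled to |average_luma_buffer| so the scalar
+      // scaling-table lookups in GetScalingFactors can index it as plain
+      // Pixel values.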
+ const __m128i average_luma = + GetAverageLuma(&in_y_row[luma_x], subsampling_x); + StoreUnsigned(average_luma_buffer, average_luma); + + const __m128i blended = + BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>( + average_luma_buffer, scaling_lut, &in_chroma_row[x], + &(noise_image[y + start_height][x]), derived_scaling_shift); + StoreUnsigned(&out_chroma_row[x], Clip3(blended, floor, ceiling)); + } + + // This section only runs if width % (8 << sub_x) != 0. It should never run + // on 720p and above. + if (x < chroma_width) { + // Prevent huge indices from entering GetScalingFactors due to + // uninitialized values. This is not a problem in 8bpp because the table + // is made larger than 255 values. + if (bitdepth > 8) { + memset(luma_buffer, 0, sizeof(luma_buffer)); + } + const int luma_x = x << subsampling_x; + const int valid_range = width - luma_x; + assert(valid_range < 16); + memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0])); + luma_buffer[valid_range] = in_y_row[width - 1]; + const __m128i average_luma = + GetAverageLumaMsan(luma_buffer, subsampling_x, valid_range + 1); + StoreUnsigned(average_luma_buffer, average_luma); + + const __m128i blended = + BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>( + average_luma_buffer, scaling_lut, &in_chroma_row[x], + &(noise_image[y + start_height][x]), derived_scaling_shift); + StoreUnsigned(&out_chroma_row[x], Clip3(blended, floor, ceiling)); + } + + in_y_row += source_stride_y << subsampling_y; + in_chroma_row += source_stride_chroma; + out_chroma_row += dest_stride; + } while (++y < chroma_height); +} + +// This function is for the case params_.chroma_scaling_from_luma == true. +// This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y. +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageChromaWithCfl_SSE4_1( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + + const auto* in_uv = static_cast<const Pixel*>(source_plane_uv); + source_stride_uv /= sizeof(Pixel); + auto* out_uv = static_cast<Pixel*>(dest_plane_uv); + dest_stride_uv /= sizeof(Pixel); + BlendChromaPlaneWithCfl_SSE4_1<bitdepth, GrainType, Pixel>( + noise_image[plane], min_value, max_chroma, width, height, start_height, + subsampling_x, subsampling_y, params.chroma_scaling, scaling_lut, in_y, + source_stride_y, in_uv, source_stride_uv, out_uv, dest_stride_uv); +} + +} // namespace + +namespace low_bitdepth { +namespace { + +// |offset| is 32x4 packed to add with the result of _mm_madd_epi16. 
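+// The multipliers are packed as (chroma_multiplier << 16) | luma_multiplier
+// per 32-bit lane, so a single _mm_madd_epi16 on interleaved (luma, chroma)
+// pairs computes luma * luma_multiplier + chroma * chroma_multiplier.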
+inline __m128i BlendChromaValsNoCfl8bpp( + const uint8_t scaling_lut[kScalingLookupTableSize], const __m128i& orig, + const int8_t* noise_image_cursor, const __m128i& average_luma, + const __m128i& scaling_shift, const __m128i& offset, + const __m128i& weights) { + uint8_t merged_buffer[8]; + const __m128i combined_lo = + _mm_madd_epi16(_mm_unpacklo_epi16(average_luma, orig), weights); + const __m128i combined_hi = + _mm_madd_epi16(_mm_unpackhi_epi16(average_luma, orig), weights); + const __m128i merged_base = _mm_packs_epi32(_mm_srai_epi32((combined_lo), 6), + _mm_srai_epi32((combined_hi), 6)); + + const __m128i merged = _mm_add_epi16(merged_base, offset); + + StoreLo8(merged_buffer, _mm_packus_epi16(merged, merged)); + const __m128i scaling = + GetScalingFactors<8, uint8_t>(scaling_lut, merged_buffer); + __m128i noise = LoadSource(noise_image_cursor); + noise = ScaleNoise<8>(noise, scaling, scaling_shift); + return _mm_add_epi16(orig, noise); +} + +LIBGAV1_ALWAYS_INLINE void BlendChromaPlane8bpp_SSE4_1( + const Array2D<int8_t>& noise_image, int min_value, int max_chroma, + int width, int height, int start_height, int subsampling_x, + int subsampling_y, int scaling_shift, int chroma_offset, + int chroma_multiplier, int luma_multiplier, + const uint8_t scaling_lut[kScalingLookupTableSize], const uint8_t* in_y_row, + ptrdiff_t source_stride_y, const uint8_t* in_chroma_row, + ptrdiff_t source_stride_chroma, uint8_t* out_chroma_row, + ptrdiff_t dest_stride) { + const __m128i floor = _mm_set1_epi16(min_value); + const __m128i ceiling = _mm_set1_epi16(max_chroma); + + const int chroma_height = (height + subsampling_y) >> subsampling_y; + const int chroma_width = (width + subsampling_x) >> subsampling_x; + // |chroma_width| is rounded up. If |width| is odd, then the final luma pixel + // will need to be guarded from overread, even if |chroma_width| is a + // multiple of 8. + const int safe_chroma_width = (chroma_width - (width & 1)) & ~7; + alignas(16) uint8_t luma_buffer[16]; + const __m128i offset = _mm_set1_epi16(chroma_offset); + const __m128i multipliers = _mm_set1_epi32(LeftShift(chroma_multiplier, 16) | + (luma_multiplier & 0xFFFF)); + const __m128i derived_scaling_shift = _mm_cvtsi32_si128(15 - scaling_shift); + + start_height >>= subsampling_y; + int y = 0; + do { + int x = 0; + for (; x < safe_chroma_width; x += 8) { + const int luma_x = x << subsampling_x; + const __m128i average_luma = + GetAverageLuma(&in_y_row[luma_x], subsampling_x); + const __m128i orig_chroma = LoadSource(&in_chroma_row[x]); + const __m128i blended = BlendChromaValsNoCfl8bpp( + scaling_lut, orig_chroma, &(noise_image[y + start_height][x]), + average_luma, derived_scaling_shift, offset, multipliers); + StoreUnsigned(&out_chroma_row[x], Clip3(blended, floor, ceiling)); + } + + if (x < chroma_width) { + // Begin right edge iteration. Same as the normal iterations, but the + // |average_luma| computation requires a duplicated luma value at the + // end. + const int luma_x = x << subsampling_x; + const int valid_range = width - luma_x; + assert(valid_range < 16); + // There is no need to pre-initialize this buffer, because merged values + // used as indices are saturated in the 8bpp case. Uninitialized values + // are written outside the frame. 
+ memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0])); + luma_buffer[valid_range] = in_y_row[width - 1]; + const int valid_range_chroma = chroma_width - x; + uint8_t chroma_buffer[8]; + memcpy(chroma_buffer, &in_chroma_row[x], + valid_range_chroma * sizeof(in_chroma_row[0])); + + const __m128i average_luma = + GetAverageLumaMsan(luma_buffer, subsampling_x, valid_range + 1); + const __m128i orig_chroma = + LoadSourceMsan(chroma_buffer, valid_range_chroma); + const __m128i blended = BlendChromaValsNoCfl8bpp( + scaling_lut, orig_chroma, &(noise_image[y + start_height][x]), + average_luma, derived_scaling_shift, offset, multipliers); + StoreUnsigned(&out_chroma_row[x], Clip3(blended, floor, ceiling)); + // End of right edge iteration. + } + + in_y_row += source_stride_y << subsampling_y; + in_chroma_row += source_stride_chroma; + out_chroma_row += dest_stride; + } while (++y < chroma_height); +} + +// This function is for the case params_.chroma_scaling_from_luma == false. +void BlendNoiseWithImageChroma8bpp_SSE4_1( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + assert(plane == kPlaneU || plane == kPlaneV); + const auto* noise_image = + static_cast<const Array2D<int8_t>*>(noise_image_ptr); + const auto* in_y = static_cast<const uint8_t*>(source_plane_y); + const auto* in_uv = static_cast<const uint8_t*>(source_plane_uv); + auto* out_uv = static_cast<uint8_t*>(dest_plane_uv); + + const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset; + const int luma_multiplier = + (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier; + const int multiplier = + (plane == kPlaneU) ? 
params.u_multiplier : params.v_multiplier; + BlendChromaPlane8bpp_SSE4_1( + noise_image[plane], min_value, max_chroma, width, height, start_height, + subsampling_x, subsampling_y, params.chroma_scaling, offset, multiplier, + luma_multiplier, scaling_lut, in_y, source_stride_y, in_uv, + source_stride_uv, out_uv, dest_stride_uv); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_SSE4_1<8, int8_t, uint8_t>; + dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma8bpp_SSE4_1; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_SSE4_1<8, int8_t, uint8_t>; +} + +} // namespace +} // namespace low_bitdepth + +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_SSE4_1<10, int16_t, uint16_t>; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_SSE4_1<10, int16_t, uint16_t>; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +} // namespace film_grain + +void FilmGrainInit_SSE4_1() { + film_grain::low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + film_grain::high_bitdepth::Init10bpp(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void FilmGrainInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/film_grain_sse4.h b/src/dsp/x86/film_grain_sse4.h new file mode 100644 index 0000000..1cacbac --- /dev/null +++ b/src/dsp/x86/film_grain_sse4.h @@ -0,0 +1,40 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_FILM_GRAIN_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_FILM_GRAIN_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initialize members of Dsp::film_grain. This function is not thread-safe. 
+void FilmGrainInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_SSE4_1 +#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_SSE4_1 +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma LIBGAV1_DSP_SSE4_1 +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_SSE4_1 +#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_SSE4_1 +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_FILM_GRAIN_SSE4_H_ diff --git a/src/dsp/x86/intra_edge_sse4.cc b/src/dsp/x86/intra_edge_sse4.cc index 4a8658d..d6af907 100644 --- a/src/dsp/x86/intra_edge_sse4.cc +++ b/src/dsp/x86/intra_edge_sse4.cc @@ -22,7 +22,7 @@ #include <cassert> #include <cstddef> #include <cstdint> -#include <cstring> // memcpy +#include <cstring> #include "src/dsp/constants.h" #include "src/dsp/dsp.h" @@ -259,7 +259,7 @@ void IntraEdgeInit_SSE4_1() { Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/intrapred_cfl_sse4.cc b/src/dsp/x86/intrapred_cfl_sse4.cc index fac1556..f2dcfdb 100644 --- a/src/dsp/x86/intrapred_cfl_sse4.cc +++ b/src/dsp/x86/intrapred_cfl_sse4.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "src/dsp/intrapred.h" +#include "src/dsp/intrapred_cfl.h" #include "src/utils/cpu.h" #if LIBGAV1_TARGETING_SSE4_1 @@ -29,9 +29,48 @@ #include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" #include "src/utils/compiler_attributes.h" +#include "src/utils/constants.h" namespace libgav1 { namespace dsp { +namespace { + +// This duplicates the last two 16-bit values in |row|. +inline __m128i LastRowSamples(const __m128i row) { + return _mm_shuffle_epi32(row, 0xFF); +} + +// This duplicates the last 16-bit value in |row|. +inline __m128i LastRowResult(const __m128i row) { + const __m128i dup_row = _mm_shufflehi_epi16(row, 0xFF); + return _mm_shuffle_epi32(dup_row, 0xFF); +} + +// Takes in two sums of input row pairs, and completes the computation for two +// output rows. +inline __m128i StoreLumaResults4_420(const __m128i vertical_sum0, + const __m128i vertical_sum1, + int16_t* luma_ptr) { + __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1); + result = _mm_slli_epi16(result, 1); + StoreLo8(luma_ptr, result); + StoreHi8(luma_ptr + kCflLumaBufferStride, result); + return result; +} + +// Takes two halves of a vertically added pair of rows and completes the +// computation for one output row. 
+inline __m128i StoreLumaResults8_420(const __m128i vertical_sum0, + const __m128i vertical_sum1, + int16_t* luma_ptr) { + __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1); + result = _mm_slli_epi16(result, 1); + StoreUnaligned16(luma_ptr, result); + return result; +} + +} // namespace + namespace low_bitdepth { namespace { @@ -40,8 +79,8 @@ namespace { inline __m128i CflPredictUnclipped(const __m128i* input, __m128i alpha_q12, __m128i alpha_sign, __m128i dc_q0) { - __m128i ac_q3 = LoadUnaligned16(input); - __m128i ac_sign = _mm_sign_epi16(alpha_sign, ac_q3); + const __m128i ac_q3 = LoadUnaligned16(input); + const __m128i ac_sign = _mm_sign_epi16(alpha_sign, ac_q3); __m128i scaled_luma_q0 = _mm_mulhrs_epi16(_mm_abs_epi16(ac_q3), alpha_q12); scaled_luma_q0 = _mm_sign_epi16(scaled_luma_q0, ac_sign); return _mm_add_epi16(scaled_luma_q0, dc_q0); @@ -88,8 +127,7 @@ void CflIntraPredictor_SSE4_1( template <int block_height_log2, bool is_inside> void CflSubsampler444_4xH_SSE4_1( int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], - const int /*max_luma_width*/, const int max_luma_height, - const void* const source, ptrdiff_t stride) { + const int max_luma_height, const void* const source, ptrdiff_t stride) { static_assert(block_height_log2 <= 4, ""); const int block_height = 1 << block_height_log2; const int visible_height = max_luma_height; @@ -119,12 +157,15 @@ void CflSubsampler444_4xH_SSE4_1( } while (y < visible_height); if (!is_inside) { - int y = visible_height; + // Replicate the 2 high lanes. + samples = _mm_shuffle_epi32(samples, 0xee); do { + StoreLo8(luma_ptr, samples); + luma_ptr += kCflLumaBufferStride; StoreHi8(luma_ptr, samples); luma_ptr += kCflLumaBufferStride; sum = _mm_add_epi16(sum, samples); - ++y; + y += 2; } while (y < block_height); } @@ -152,15 +193,15 @@ void CflSubsampler444_4xH_SSE4_1( static_assert(block_height_log2 <= 4, ""); assert(max_luma_width >= 4); assert(max_luma_height >= 4); - const int block_height = 1 << block_height_log2; - const int block_width = 4; + static_cast<void>(max_luma_width); + constexpr int block_height = 1 << block_height_log2; - if (block_height <= max_luma_height && block_width <= max_luma_width) { - CflSubsampler444_4xH_SSE4_1<block_height_log2, true>( - luma, max_luma_width, max_luma_height, source, stride); + if (block_height <= max_luma_height) { + CflSubsampler444_4xH_SSE4_1<block_height_log2, true>(luma, max_luma_height, + source, stride); } else { - CflSubsampler444_4xH_SSE4_1<block_height_log2, false>( - luma, max_luma_width, max_luma_height, source, stride); + CflSubsampler444_4xH_SSE4_1<block_height_log2, false>(luma, max_luma_height, + source, stride); } } @@ -302,19 +343,9 @@ void CflSubsampler444_SSE4_1( __m128i inner_sum_lo, inner_sum_hi; int y = 0; do { -#if LIBGAV1_MSAN // We can load uninitialized values here. Even though they are - // then masked off by blendv, MSAN isn't smart enough to - // understand that. So we switch to a C implementation here. - uint16_t c_arr[16]; - for (int x = 0; x < 16; x++) { - const int x_index = std::min(x, visible_width_16 - 1); - c_arr[x] = src[x_index] << 3; - } - samples0 = LoadUnaligned16(c_arr); - samples1 = LoadUnaligned16(c_arr + 8); - static_cast<void>(blend_mask_16); -#else - __m128i samples01 = LoadUnaligned16(src); + // We can load uninitialized values here. Even though they are then masked + // off by blendv, MSAN doesn't model that behavior. 
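+    // The second argument to the Msan load helpers is the number of trailing
+    // bytes that may be uninitialized, so MSAN builds can account for the
+    // deliberate over-read.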
+ __m128i samples01 = LoadUnaligned16Msan(src, invisible_width_16); if (!inside) { const __m128i border16 = @@ -323,26 +354,15 @@ void CflSubsampler444_SSE4_1( } samples0 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples01), 3); samples1 = _mm_slli_epi16(_mm_unpackhi_epi8(samples01, zero), 3); -#endif // LIBGAV1_MSAN StoreUnaligned16(luma_ptr, samples0); StoreUnaligned16(luma_ptr + 8, samples1); __m128i inner_sum = _mm_add_epi16(samples0, samples1); if (block_width == 32) { -#if LIBGAV1_MSAN // We can load uninitialized values here. Even though they are - // then masked off by blendv, MSAN isn't smart enough to - // understand that. So we switch to a C implementation here. - uint16_t c_arr[16]; - for (int x = 16; x < 32; x++) { - const int x_index = std::min(x, visible_width_32 - 1); - c_arr[x - 16] = src[x_index] << 3; - } - samples2 = LoadUnaligned16(c_arr); - samples3 = LoadUnaligned16(c_arr + 8); - static_cast<void>(blend_mask_32); -#else - __m128i samples23 = LoadUnaligned16(src + 16); + // We can load uninitialized values here. Even though they are then masked + // off by blendv, MSAN doesn't model that behavior. + __m128i samples23 = LoadUnaligned16Msan(src + 16, invisible_width_32); if (!inside) { const __m128i border32 = _mm_set1_epi8(static_cast<int8_t>(src[visible_width_32 - 1])); @@ -350,7 +370,6 @@ void CflSubsampler444_SSE4_1( } samples2 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples23), 3); samples3 = _mm_slli_epi16(_mm_unpackhi_epi8(samples23, zero), 3); -#endif // LIBGAV1_MSAN StoreUnaligned16(luma_ptr + 16, samples2); StoreUnaligned16(luma_ptr + 24, samples3); @@ -418,29 +437,6 @@ void CflSubsampler444_SSE4_1( } } -// Takes in two sums of input row pairs, and completes the computation for two -// output rows. -inline __m128i StoreLumaResults4_420(const __m128i vertical_sum0, - const __m128i vertical_sum1, - int16_t* luma_ptr) { - __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1); - result = _mm_slli_epi16(result, 1); - StoreLo8(luma_ptr, result); - StoreHi8(luma_ptr + kCflLumaBufferStride, result); - return result; -} - -// Takes two halves of a vertically added pair of rows and completes the -// computation for one output row. -inline __m128i StoreLumaResults8_420(const __m128i vertical_sum0, - const __m128i vertical_sum1, - int16_t* luma_ptr) { - __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1); - result = _mm_slli_epi16(result, 1); - StoreUnaligned16(luma_ptr, result); - return result; -} - template <int block_height_log2> void CflSubsampler420_4xH_SSE4_1( int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], @@ -511,17 +507,6 @@ void CflSubsampler420_4xH_SSE4_1( } } -// This duplicates the last two 16-bit values in |row|. -inline __m128i LastRowSamples(const __m128i row) { - return _mm_shuffle_epi32(row, 0xFF); -} - -// This duplicates the last 16-bit value in |row|. -inline __m128i LastRowResult(const __m128i row) { - const __m128i dup_row = _mm_shufflehi_epi16(row, 0xFF); - return _mm_shuffle_epi32(dup_row, 0xFF); -} - template <int block_height_log2, int max_luma_width> inline void CflSubsampler420Impl_8xH_SSE4_1( int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], @@ -655,10 +640,11 @@ inline void CflSubsampler420Impl_WxH_SSE4_1( __m128i final_sum = zero; const int block_height = 1 << block_height_log2; const int luma_height = std::min(block_height, max_luma_height >> 1); + static_assert(max_luma_width <= 32, ""); int16_t* luma_ptr = luma[0]; __m128i final_row_result; - // Begin first y section, covering width up to 16. 
+ // Begin first y section, covering width up to 32. int y = 0; do { const uint8_t* src_next = src + stride; @@ -694,29 +680,32 @@ inline void CflSubsampler420Impl_WxH_SSE4_1( final_row_result = StoreLumaResults8_420(luma_sum2, luma_sum3, luma_ptr + 8); sum = _mm_add_epi16(sum, final_row_result); + if (block_width_log2 == 5) { + const __m128i wide_fill = LastRowResult(final_row_result); + sum = _mm_add_epi16(sum, wide_fill); + sum = _mm_add_epi16(sum, wide_fill); + } final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum)); final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero)); src += stride << 1; luma_ptr += kCflLumaBufferStride; } while (++y < luma_height); - // Because max_luma_width is at most 32, any values beyond x=16 will - // necessarily be duplicated. - if (block_width_log2 == 5) { - const __m128i wide_fill = LastRowResult(final_row_result); - // Multiply duplicated value by number of occurrences, height * 4, since - // there are 16 in each row and the value appears in the vector 4 times. - final_sum = _mm_add_epi32( - final_sum, - _mm_slli_epi32(_mm_cvtepi16_epi32(wide_fill), block_height_log2 + 2)); - } - // Begin second y section. if (y < block_height) { const __m128i final_fill0 = LoadUnaligned16(luma_ptr - kCflLumaBufferStride); const __m128i final_fill1 = LoadUnaligned16(luma_ptr - kCflLumaBufferStride + 8); + __m128i wide_fill; + + if (block_width_log2 == 5) { + // There are 16 16-bit fill values per row, shifting by 2 accounts for + // the widening to 32-bit. + wide_fill = + _mm_slli_epi32(_mm_cvtepi16_epi32(LastRowResult(final_fill1)), 2); + } + const __m128i final_inner_sum = _mm_add_epi16(final_fill0, final_fill1); const __m128i final_inner_sum0 = _mm_cvtepu16_epi32(final_inner_sum); const __m128i final_inner_sum1 = _mm_unpackhi_epi16(final_inner_sum, zero); @@ -726,6 +715,9 @@ inline void CflSubsampler420Impl_WxH_SSE4_1( do { StoreUnaligned16(luma_ptr, final_fill0); StoreUnaligned16(luma_ptr + 8, final_fill1); + if (block_width_log2 == 5) { + final_sum = _mm_add_epi32(final_sum, wide_fill); + } luma_ptr += kCflLumaBufferStride; final_sum = _mm_add_epi32(final_sum, final_fill_to_sum); @@ -747,14 +739,10 @@ inline void CflSubsampler420Impl_WxH_SSE4_1( const __m128i samples1 = LoadUnaligned16(luma_ptr + 8); final_row_result = _mm_sub_epi16(samples1, averages); StoreUnaligned16(luma_ptr + 8, final_row_result); - } - if (block_width_log2 == 5) { - int16_t* wide_luma_ptr = luma[0] + 16; - const __m128i wide_fill = LastRowResult(final_row_result); - for (int i = 0; i < block_height; - ++i, wide_luma_ptr += kCflLumaBufferStride) { - StoreUnaligned16(wide_luma_ptr, wide_fill); - StoreUnaligned16(wide_luma_ptr + 8, wide_fill); + if (block_width_log2 == 5) { + const __m128i wide_fill = LastRowResult(final_row_result); + StoreUnaligned16(luma_ptr + 16, wide_fill); + StoreUnaligned16(luma_ptr + 24, wide_fill); } } } @@ -958,7 +946,882 @@ void Init8bpp() { } // namespace } // namespace low_bitdepth -void IntraPredCflInit_SSE4_1() { low_bitdepth::Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +//------------------------------------------------------------------------------ +// CflIntraPredictor_10bpp_SSE4_1 + +inline __m128i CflPredictUnclipped(const __m128i* input, __m128i alpha_q12, + __m128i alpha_sign, __m128i dc_q0) { + const __m128i ac_q3 = LoadUnaligned16(input); + const __m128i ac_sign = _mm_sign_epi16(alpha_sign, ac_q3); + __m128i scaled_luma_q0 = _mm_mulhrs_epi16(_mm_abs_epi16(ac_q3), alpha_q12); + scaled_luma_q0 = 
_mm_sign_epi16(scaled_luma_q0, ac_sign); + return _mm_add_epi16(scaled_luma_q0, dc_q0); +} + +inline __m128i ClipEpi16(__m128i x, __m128i min, __m128i max) { + return _mm_max_epi16(_mm_min_epi16(x, max), min); +} + +template <int width, int height> +void CflIntraPredictor_10bpp_SSE4_1( + void* const dest, ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + constexpr int kCflLumaBufferStrideLog2_16i = 5; + constexpr int kCflLumaBufferStrideLog2_128i = + kCflLumaBufferStrideLog2_16i - 3; + constexpr int kRowIncr = 1 << kCflLumaBufferStrideLog2_128i; + auto* dst = static_cast<uint16_t*>(dest); + const __m128i alpha_sign = _mm_set1_epi16(alpha); + const __m128i alpha_q12 = _mm_slli_epi16(_mm_abs_epi16(alpha_sign), 9); + auto* row = reinterpret_cast<const __m128i*>(luma); + const __m128i* row_end = row + (height << kCflLumaBufferStrideLog2_128i); + const __m128i dc_val = _mm_set1_epi16(dst[0]); + const __m128i min = _mm_setzero_si128(); + const __m128i max = _mm_set1_epi16((1 << kBitdepth10) - 1); + + stride >>= 1; + + do { + __m128i res = CflPredictUnclipped(row, alpha_q12, alpha_sign, dc_val); + res = ClipEpi16(res, min, max); + if (width == 4) { + StoreLo8(dst, res); + } else if (width == 8) { + StoreUnaligned16(dst, res); + } else if (width == 16) { + StoreUnaligned16(dst, res); + const __m128i res_1 = + CflPredictUnclipped(row + 1, alpha_q12, alpha_sign, dc_val); + StoreUnaligned16(dst + 8, ClipEpi16(res_1, min, max)); + } else { + StoreUnaligned16(dst, res); + const __m128i res_1 = + CflPredictUnclipped(row + 1, alpha_q12, alpha_sign, dc_val); + StoreUnaligned16(dst + 8, ClipEpi16(res_1, min, max)); + const __m128i res_2 = + CflPredictUnclipped(row + 2, alpha_q12, alpha_sign, dc_val); + StoreUnaligned16(dst + 16, ClipEpi16(res_2, min, max)); + const __m128i res_3 = + CflPredictUnclipped(row + 3, alpha_q12, alpha_sign, dc_val); + StoreUnaligned16(dst + 24, ClipEpi16(res_3, min, max)); + } + + dst += stride; + } while ((row += kRowIncr) < row_end); +} + +template <int block_height_log2, bool is_inside> +void CflSubsampler444_4xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_height, const void* const source, ptrdiff_t stride) { + static_assert(block_height_log2 <= 4, ""); + const int block_height = 1 << block_height_log2; + const int visible_height = max_luma_height; + const auto* src = static_cast<const uint16_t*>(source); + const ptrdiff_t src_stride = stride / sizeof(src[0]); + int16_t* luma_ptr = luma[0]; + __m128i zero = _mm_setzero_si128(); + __m128i sum = zero; + __m128i samples; + int y = visible_height; + + do { + samples = LoadHi8(LoadLo8(src), src + src_stride); + src += src_stride << 1; + sum = _mm_add_epi16(sum, samples); + y -= 2; + } while (y != 0); + + if (!is_inside) { + y = visible_height; + samples = _mm_unpackhi_epi64(samples, samples); + do { + sum = _mm_add_epi16(sum, samples); + y += 2; + } while (y < block_height); + } + + sum = _mm_add_epi32(_mm_unpackhi_epi16(sum, zero), _mm_cvtepu16_epi32(sum)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4)); + + // Here the left shift by 3 (to increase precision) is nullified in right + // shift ((log2 of width 4) + 1). 
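+  // For example, a 4x4 block sums 16 samples: the mean would need a shift by
+  // 4, and the samples carry 8x (1 << 3) precision, so the net shift is
+  // 4 - 3 = block_height_log2 - 1.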
+ __m128i averages = RightShiftWithRounding_U32(sum, block_height_log2 - 1); + averages = _mm_shufflelo_epi16(averages, 0); + src = static_cast<const uint16_t*>(source); + luma_ptr = luma[0]; + y = visible_height; + do { + samples = LoadLo8(src); + samples = _mm_slli_epi16(samples, 3); + StoreLo8(luma_ptr, _mm_sub_epi16(samples, averages)); + src += src_stride; + luma_ptr += kCflLumaBufferStride; + } while (--y != 0); + + if (!is_inside) { + y = visible_height; + // Replicate last line + do { + StoreLo8(luma_ptr, _mm_sub_epi16(samples, averages)); + luma_ptr += kCflLumaBufferStride; + } while (++y < block_height); + } +} + +template <int block_height_log2> +void CflSubsampler444_4xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_cast<void>(max_luma_width); + static_cast<void>(max_luma_height); + static_assert(block_height_log2 <= 4, ""); + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + const int block_height = 1 << block_height_log2; + + if (block_height <= max_luma_height) { + CflSubsampler444_4xH_SSE4_1<block_height_log2, true>(luma, max_luma_height, + source, stride); + } else { + CflSubsampler444_4xH_SSE4_1<block_height_log2, false>(luma, max_luma_height, + source, stride); + } +} + +template <int block_height_log2, bool is_inside> +void CflSubsampler444_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_height, const void* const source, ptrdiff_t stride) { + const int block_height = 1 << block_height_log2; + const int visible_height = max_luma_height; + const __m128i dup16 = _mm_set1_epi32(0x01000100); + const auto* src = static_cast<const uint16_t*>(source); + const ptrdiff_t src_stride = stride / sizeof(src[0]); + int16_t* luma_ptr = luma[0]; + const __m128i zero = _mm_setzero_si128(); + __m128i sum = zero; + __m128i samples; + int y = visible_height; + + do { + samples = LoadUnaligned16(src); + src += src_stride; + sum = _mm_add_epi16(sum, samples); + } while (--y != 0); + + if (!is_inside) { + y = visible_height; + do { + sum = _mm_add_epi16(sum, samples); + } while (++y < block_height); + } + + sum = _mm_add_epi32(_mm_unpackhi_epi16(sum, zero), _mm_cvtepu16_epi32(sum)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4)); + + // Here the left shift by 3 (to increase precision) is nullified in right + // shift (log2 of width 8). 
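  // Concretely: |sum| covers 8 * block_height = 1 << (block_height_log2 + 3)
  // raw samples, and the stored values carry an extra << 3, so the two shifts
  // cancel and the rounded divisor is simply 1 << block_height_log2.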
+ __m128i averages = RightShiftWithRounding_U32(sum, block_height_log2); + averages = _mm_shuffle_epi8(averages, dup16); + + src = static_cast<const uint16_t*>(source); + luma_ptr = luma[0]; + y = visible_height; + do { + samples = LoadUnaligned16(src); + samples = _mm_slli_epi16(samples, 3); + StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples, averages)); + src += src_stride; + luma_ptr += kCflLumaBufferStride; + } while (--y != 0); + + if (!is_inside) { + y = visible_height; + // Replicate last line + do { + StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples, averages)); + luma_ptr += kCflLumaBufferStride; + } while (++y < block_height); + } +} + +template <int block_height_log2> +void CflSubsampler444_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_cast<void>(max_luma_width); + static_cast<void>(max_luma_height); + static_assert(block_height_log2 <= 5, ""); + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + const int block_height = 1 << block_height_log2; + const int block_width = 8; + + const int horz_inside = block_width <= max_luma_width; + const int vert_inside = block_height <= max_luma_height; + if (horz_inside && vert_inside) { + CflSubsampler444_8xH_SSE4_1<block_height_log2, true>(luma, max_luma_height, + source, stride); + } else { + CflSubsampler444_8xH_SSE4_1<block_height_log2, false>(luma, max_luma_height, + source, stride); + } +} + +template <int block_width_log2, int block_height_log2, bool is_inside> +void CflSubsampler444_WxH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + const int block_height = 1 << block_height_log2; + const int visible_height = max_luma_height; + const int block_width = 1 << block_width_log2; + const __m128i dup16 = _mm_set1_epi32(0x01000100); + const __m128i zero = _mm_setzero_si128(); + const auto* src = static_cast<const uint16_t*>(source); + const ptrdiff_t src_stride = stride / sizeof(src[0]); + int16_t* luma_ptr = luma[0]; + __m128i sum = zero; + __m128i inner_sum_lo, inner_sum_hi; + __m128i samples[4]; + int y = visible_height; + + do { + samples[0] = LoadUnaligned16(src); + samples[1] = (max_luma_width >= 16) ? LoadUnaligned16(src + 8) + : LastRowResult(samples[0]); + __m128i inner_sum = _mm_add_epi16(samples[0], samples[1]); + if (block_width == 32) { + samples[2] = (max_luma_width >= 24) ? LoadUnaligned16(src + 16) + : LastRowResult(samples[1]); + samples[3] = (max_luma_width == 32) ? 
LoadUnaligned16(src + 24) + : LastRowResult(samples[2]); + + inner_sum = _mm_add_epi16(samples[2], inner_sum); + inner_sum = _mm_add_epi16(samples[3], inner_sum); + } + inner_sum_lo = _mm_cvtepu16_epi32(inner_sum); + inner_sum_hi = _mm_unpackhi_epi16(inner_sum, zero); + sum = _mm_add_epi32(sum, inner_sum_lo); + sum = _mm_add_epi32(sum, inner_sum_hi); + src += src_stride; + } while (--y != 0); + + if (!is_inside) { + y = visible_height; + __m128i inner_sum = _mm_add_epi16(samples[0], samples[1]); + if (block_width == 32) { + inner_sum = _mm_add_epi16(samples[2], inner_sum); + inner_sum = _mm_add_epi16(samples[3], inner_sum); + } + inner_sum_lo = _mm_cvtepu16_epi32(inner_sum); + inner_sum_hi = _mm_unpackhi_epi16(inner_sum, zero); + do { + sum = _mm_add_epi32(sum, inner_sum_lo); + sum = _mm_add_epi32(sum, inner_sum_hi); + } while (++y < block_height); + } + + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4)); + + // Here the left shift by 3 (to increase precision) is subtracted in right + // shift factor (block_width_log2 + block_height_log2 - 3). + __m128i averages = + RightShiftWithRounding_U32(sum, block_width_log2 + block_height_log2 - 3); + averages = _mm_shuffle_epi8(averages, dup16); + + src = static_cast<const uint16_t*>(source); + __m128i samples_ext = zero; + luma_ptr = luma[0]; + y = visible_height; + do { + int idx = 0; + for (int x = 0; x < block_width; x += 8) { + if (max_luma_width > x) { + samples[idx] = LoadUnaligned16(&src[x]); + samples[idx] = _mm_slli_epi16(samples[idx], 3); + samples_ext = samples[idx]; + } else { + samples[idx] = LastRowResult(samples_ext); + } + StoreUnaligned16(&luma_ptr[x], _mm_sub_epi16(samples[idx++], averages)); + } + src += src_stride; + luma_ptr += kCflLumaBufferStride; + } while (--y != 0); + + if (!is_inside) { + y = visible_height; + // Replicate last line + do { + int idx = 0; + for (int x = 0; x < block_width; x += 8) { + StoreUnaligned16(&luma_ptr[x], _mm_sub_epi16(samples[idx++], averages)); + } + luma_ptr += kCflLumaBufferStride; + } while (++y < block_height); + } +} + +template <int block_width_log2, int block_height_log2> +void CflSubsampler444_WxH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_assert(block_width_log2 == 4 || block_width_log2 == 5, + "This function will only work for block_width 16 and 32."); + static_assert(block_height_log2 <= 5, ""); + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + + const int block_height = 1 << block_height_log2; + const int vert_inside = block_height <= max_luma_height; + if (vert_inside) { + CflSubsampler444_WxH_SSE4_1<block_width_log2, block_height_log2, true>( + luma, max_luma_width, max_luma_height, source, stride); + } else { + CflSubsampler444_WxH_SSE4_1<block_width_log2, block_height_log2, false>( + luma, max_luma_width, max_luma_height, source, stride); + } +} + +template <int block_height_log2> +void CflSubsampler420_4xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int /*max_luma_width*/, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + const int block_height = 1 << block_height_log2; + const auto* src = static_cast<const uint16_t*>(source); + const ptrdiff_t src_stride = stride / sizeof(src[0]); + int16_t* luma_ptr = luma[0]; + const __m128i zero = _mm_setzero_si128(); + __m128i final_sum = zero; + const int luma_height = 
std::min(block_height, max_luma_height >> 1); + int y = luma_height; + + do { + const __m128i samples_row0 = LoadUnaligned16(src); + src += src_stride; + const __m128i samples_row1 = LoadUnaligned16(src); + src += src_stride; + const __m128i luma_sum01 = _mm_add_epi16(samples_row0, samples_row1); + + const __m128i samples_row2 = LoadUnaligned16(src); + src += src_stride; + const __m128i samples_row3 = LoadUnaligned16(src); + src += src_stride; + const __m128i luma_sum23 = _mm_add_epi16(samples_row2, samples_row3); + __m128i sum = StoreLumaResults4_420(luma_sum01, luma_sum23, luma_ptr); + luma_ptr += kCflLumaBufferStride << 1; + + const __m128i samples_row4 = LoadUnaligned16(src); + src += src_stride; + const __m128i samples_row5 = LoadUnaligned16(src); + src += src_stride; + const __m128i luma_sum45 = _mm_add_epi16(samples_row4, samples_row5); + + const __m128i samples_row6 = LoadUnaligned16(src); + src += src_stride; + const __m128i samples_row7 = LoadUnaligned16(src); + src += src_stride; + const __m128i luma_sum67 = _mm_add_epi16(samples_row6, samples_row7); + sum = _mm_add_epi16( + sum, StoreLumaResults4_420(luma_sum45, luma_sum67, luma_ptr)); + luma_ptr += kCflLumaBufferStride << 1; + + final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum)); + final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero)); + y -= 4; + } while (y != 0); + + const __m128i final_fill = LoadLo8(luma_ptr - kCflLumaBufferStride); + const __m128i final_fill_to_sum = _mm_cvtepu16_epi32(final_fill); + for (y = luma_height; y < block_height; ++y) { + StoreLo8(luma_ptr, final_fill); + luma_ptr += kCflLumaBufferStride; + final_sum = _mm_add_epi32(final_sum, final_fill_to_sum); + } + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8)); + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4)); + + __m128i averages = RightShiftWithRounding_U32( + final_sum, block_height_log2 + 2 /*log2 of width 4*/); + + averages = _mm_shufflelo_epi16(averages, 0); + luma_ptr = luma[0]; + y = block_height; + do { + const __m128i samples = LoadLo8(luma_ptr); + StoreLo8(luma_ptr, _mm_sub_epi16(samples, averages)); + luma_ptr += kCflLumaBufferStride; + } while (--y != 0); +} + +template <int block_height_log2, int max_luma_width> +inline void CflSubsampler420Impl_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_height, const void* const source, ptrdiff_t stride) { + const int block_height = 1 << block_height_log2; + const auto* src = static_cast<const uint16_t*>(source); + const ptrdiff_t src_stride = stride / sizeof(src[0]); + const __m128i zero = _mm_setzero_si128(); + __m128i final_sum = zero; + int16_t* luma_ptr = luma[0]; + const int luma_height = std::min(block_height, max_luma_height >> 1); + int y = luma_height; + + do { + const __m128i samples_row00 = LoadUnaligned16(src); + const __m128i samples_row01 = (max_luma_width == 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row00); + src += src_stride; + const __m128i samples_row10 = LoadUnaligned16(src); + const __m128i samples_row11 = (max_luma_width == 16) + ? 
LoadUnaligned16(src + 8) + : LastRowSamples(samples_row10); + src += src_stride; + const __m128i luma_sum00 = _mm_add_epi16(samples_row00, samples_row10); + const __m128i luma_sum01 = _mm_add_epi16(samples_row01, samples_row11); + __m128i sum = StoreLumaResults8_420(luma_sum00, luma_sum01, luma_ptr); + luma_ptr += kCflLumaBufferStride; + + const __m128i samples_row20 = LoadUnaligned16(src); + const __m128i samples_row21 = (max_luma_width == 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row20); + src += src_stride; + const __m128i samples_row30 = LoadUnaligned16(src); + const __m128i samples_row31 = (max_luma_width == 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row30); + src += src_stride; + const __m128i luma_sum10 = _mm_add_epi16(samples_row20, samples_row30); + const __m128i luma_sum11 = _mm_add_epi16(samples_row21, samples_row31); + sum = _mm_add_epi16( + sum, StoreLumaResults8_420(luma_sum10, luma_sum11, luma_ptr)); + luma_ptr += kCflLumaBufferStride; + + const __m128i samples_row40 = LoadUnaligned16(src); + const __m128i samples_row41 = (max_luma_width == 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row40); + src += src_stride; + const __m128i samples_row50 = LoadUnaligned16(src); + const __m128i samples_row51 = (max_luma_width == 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row50); + src += src_stride; + const __m128i luma_sum20 = _mm_add_epi16(samples_row40, samples_row50); + const __m128i luma_sum21 = _mm_add_epi16(samples_row41, samples_row51); + sum = _mm_add_epi16( + sum, StoreLumaResults8_420(luma_sum20, luma_sum21, luma_ptr)); + luma_ptr += kCflLumaBufferStride; + + const __m128i samples_row60 = LoadUnaligned16(src); + const __m128i samples_row61 = (max_luma_width == 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row60); + src += src_stride; + const __m128i samples_row70 = LoadUnaligned16(src); + const __m128i samples_row71 = (max_luma_width == 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row70); + src += src_stride; + const __m128i luma_sum30 = _mm_add_epi16(samples_row60, samples_row70); + const __m128i luma_sum31 = _mm_add_epi16(samples_row61, samples_row71); + sum = _mm_add_epi16( + sum, StoreLumaResults8_420(luma_sum30, luma_sum31, luma_ptr)); + luma_ptr += kCflLumaBufferStride; + + final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum)); + final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero)); + y -= 4; + } while (y != 0); + + // Duplicate the final row downward to the end after max_luma_height. 
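  // When the visible luma region is shorter than the block
  // (luma_height < block_height), each remaining row is filled with a copy of
  // the last computed row of subsampled values, and that row's sum keeps being
  // added to |final_sum| so the average below covers all block_height rows.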
+ const __m128i final_fill = LoadUnaligned16(luma_ptr - kCflLumaBufferStride); + const __m128i final_fill_to_sum0 = _mm_cvtepi16_epi32(final_fill); + const __m128i final_fill_to_sum1 = + _mm_cvtepi16_epi32(_mm_srli_si128(final_fill, 8)); + const __m128i final_fill_to_sum = + _mm_add_epi32(final_fill_to_sum0, final_fill_to_sum1); + for (y = luma_height; y < block_height; ++y) { + StoreUnaligned16(luma_ptr, final_fill); + luma_ptr += kCflLumaBufferStride; + final_sum = _mm_add_epi32(final_sum, final_fill_to_sum); + } + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8)); + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4)); + + __m128i averages = RightShiftWithRounding_S32( + final_sum, block_height_log2 + 3 /*log2 of width 8*/); + + averages = _mm_shufflelo_epi16(averages, 0); + averages = _mm_shuffle_epi32(averages, 0); + luma_ptr = luma[0]; + y = block_height; + do { + const __m128i samples = LoadUnaligned16(luma_ptr); + StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples, averages)); + luma_ptr += kCflLumaBufferStride; + } while (--y != 0); +} + +template <int block_height_log2> +void CflSubsampler420_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + if (max_luma_width == 8) { + CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 8>(luma, max_luma_height, + source, stride); + } else { + CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 16>( + luma, max_luma_height, source, stride); + } +} + +template <int block_width_log2, int block_height_log2, int max_luma_width> +inline void CflSubsampler420Impl_WxH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_height, const void* const source, ptrdiff_t stride) { + const auto* src = static_cast<const uint16_t*>(source); + const ptrdiff_t src_stride = stride / sizeof(src[0]); + const __m128i zero = _mm_setzero_si128(); + __m128i final_sum = zero; + const int block_height = 1 << block_height_log2; + const int luma_height = std::min(block_height, max_luma_height >> 1); + int16_t* luma_ptr = luma[0]; + __m128i final_row_result; + // Begin first y section, covering width up to 32. + int y = luma_height; + + do { + const uint16_t* src_next = src + src_stride; + const __m128i samples_row00 = LoadUnaligned16(src); + const __m128i samples_row01 = (max_luma_width >= 16) + ? LoadUnaligned16(src + 8) + : LastRowSamples(samples_row00); + const __m128i samples_row02 = (max_luma_width >= 24) + ? LoadUnaligned16(src + 16) + : LastRowSamples(samples_row01); + const __m128i samples_row03 = (max_luma_width == 32) + ? LoadUnaligned16(src + 24) + : LastRowSamples(samples_row02); + const __m128i samples_row10 = LoadUnaligned16(src_next); + const __m128i samples_row11 = (max_luma_width >= 16) + ? LoadUnaligned16(src_next + 8) + : LastRowSamples(samples_row10); + const __m128i samples_row12 = (max_luma_width >= 24) + ? LoadUnaligned16(src_next + 16) + : LastRowSamples(samples_row11); + const __m128i samples_row13 = (max_luma_width == 32) + ? 
LoadUnaligned16(src_next + 24) + : LastRowSamples(samples_row12); + const __m128i luma_sum0 = _mm_add_epi16(samples_row00, samples_row10); + const __m128i luma_sum1 = _mm_add_epi16(samples_row01, samples_row11); + const __m128i luma_sum2 = _mm_add_epi16(samples_row02, samples_row12); + const __m128i luma_sum3 = _mm_add_epi16(samples_row03, samples_row13); + __m128i sum = StoreLumaResults8_420(luma_sum0, luma_sum1, luma_ptr); + final_row_result = + StoreLumaResults8_420(luma_sum2, luma_sum3, luma_ptr + 8); + sum = _mm_add_epi16(sum, final_row_result); + final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum)); + final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero)); + + // Because max_luma_width is at most 32, any values beyond x=16 will + // necessarily be duplicated. + if (block_width_log2 == 5) { + const __m128i wide_fill = LastRowResult(final_row_result); + // There are 16 16-bit fill values per row, shifting by 2 accounts for + // the widening to 32-bit. + final_sum = _mm_add_epi32( + final_sum, _mm_slli_epi32(_mm_cvtepi16_epi32(wide_fill), 2)); + } + src += src_stride << 1; + luma_ptr += kCflLumaBufferStride; + } while (--y != 0); + + // Begin second y section. + y = luma_height; + if (y < block_height) { + const __m128i final_fill0 = + LoadUnaligned16(luma_ptr - kCflLumaBufferStride); + const __m128i final_fill1 = + LoadUnaligned16(luma_ptr - kCflLumaBufferStride + 8); + __m128i wide_fill; + if (block_width_log2 == 5) { + // There are 16 16-bit fill values per row, shifting by 2 accounts for + // the widening to 32-bit. + wide_fill = + _mm_slli_epi32(_mm_cvtepi16_epi32(LastRowResult(final_fill1)), 2); + } + const __m128i final_inner_sum = _mm_add_epi16(final_fill0, final_fill1); + const __m128i final_inner_sum0 = _mm_cvtepu16_epi32(final_inner_sum); + const __m128i final_inner_sum1 = _mm_unpackhi_epi16(final_inner_sum, zero); + const __m128i final_fill_to_sum = + _mm_add_epi32(final_inner_sum0, final_inner_sum1); + + do { + StoreUnaligned16(luma_ptr, final_fill0); + StoreUnaligned16(luma_ptr + 8, final_fill1); + if (block_width_log2 == 5) { + final_sum = _mm_add_epi32(final_sum, wide_fill); + } + luma_ptr += kCflLumaBufferStride; + final_sum = _mm_add_epi32(final_sum, final_fill_to_sum); + } while (++y < block_height); + } // End second y section. 
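  // At this point |final_sum| holds the sum of all
  // (1 << block_width_log2) * (1 << block_height_log2) stored values
  // (including the replicated fill rows and, for width 32, the duplicated
  // right-half columns), so the rounded average below shifts by
  // block_width_log2 + block_height_log2.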
+ + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8)); + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4)); + + __m128i averages = RightShiftWithRounding_S32( + final_sum, block_width_log2 + block_height_log2); + averages = _mm_shufflelo_epi16(averages, 0); + averages = _mm_shuffle_epi32(averages, 0); + + luma_ptr = luma[0]; + y = block_height; + do { + const __m128i samples0 = LoadUnaligned16(luma_ptr); + StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples0, averages)); + const __m128i samples1 = LoadUnaligned16(luma_ptr + 8); + final_row_result = _mm_sub_epi16(samples1, averages); + StoreUnaligned16(luma_ptr + 8, final_row_result); + + if (block_width_log2 == 5) { + const __m128i wide_fill = LastRowResult(final_row_result); + StoreUnaligned16(luma_ptr + 16, wide_fill); + StoreUnaligned16(luma_ptr + 24, wide_fill); + } + luma_ptr += kCflLumaBufferStride; + } while (--y != 0); +} + +template <int block_width_log2, int block_height_log2> +void CflSubsampler420_WxH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + switch (max_luma_width) { + case 8: + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 8>( + luma, max_luma_height, source, stride); + return; + case 16: + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 16>( + luma, max_luma_height, source, stride); + return; + case 24: + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 24>( + luma, max_luma_height, source, stride); + return; + default: + assert(max_luma_width == 32); + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 32>( + luma, max_luma_height, source, stride); + return; + } +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x4_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize4x4] = + CflIntraPredictor_10bpp_SSE4_1<4, 4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize4x8] = + CflIntraPredictor_10bpp_SSE4_1<4, 8>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize4x16] = + CflIntraPredictor_10bpp_SSE4_1<4, 16>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x4_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x4] = + CflIntraPredictor_10bpp_SSE4_1<8, 4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x8] = + CflIntraPredictor_10bpp_SSE4_1<8, 8>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x16] = + CflIntraPredictor_10bpp_SSE4_1<8, 16>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x32_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x32] = + CflIntraPredictor_10bpp_SSE4_1<8, 32>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x4_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x4] = + CflIntraPredictor_10bpp_SSE4_1<16, 4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x8] = + CflIntraPredictor_10bpp_SSE4_1<16, 8>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x16] = + 
CflIntraPredictor_10bpp_SSE4_1<16, 16>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x32_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x32] = + CflIntraPredictor_10bpp_SSE4_1<16, 32>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize32x8] = + CflIntraPredictor_10bpp_SSE4_1<32, 8>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize32x16] = + CflIntraPredictor_10bpp_SSE4_1<32, 16>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x32_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize32x32] = + CflIntraPredictor_10bpp_SSE4_1<32, 32>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x4_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] = + CflSubsampler420_4xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] = + CflSubsampler420_4xH_SSE4_1<3>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] = + CflSubsampler420_4xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x4_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<3>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x32_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<5>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x4_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 2>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 3>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x32_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 5>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<5, 3>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<5, 4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x32_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<5, 5>; +#endif + +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x4_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] = + CflSubsampler444_4xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] = + CflSubsampler444_4xH_SSE4_1<3>; +#endif +#if 
DSP_ENABLED_10BPP_SSE4_1(TransformSize4x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] = + CflSubsampler444_4xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x4_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<3>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x32_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<5>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x4_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] = + CflSubsampler444_WxH_SSE4_1<4, 2>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] = + CflSubsampler444_WxH_SSE4_1<4, 3>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] = + CflSubsampler444_WxH_SSE4_1<4, 4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x32_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] = + CflSubsampler444_WxH_SSE4_1<4, 5>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] = + CflSubsampler444_WxH_SSE4_1<5, 3>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] = + CflSubsampler444_WxH_SSE4_1<5, 4>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x32_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] = + CflSubsampler444_WxH_SSE4_1<5, 5>; +#endif +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void IntraPredCflInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} } // namespace dsp } // namespace libgav1 diff --git a/src/dsp/x86/intrapred_cfl_sse4.h b/src/dsp/x86/intrapred_cfl_sse4.h new file mode 100644 index 0000000..5d1a425 --- /dev/null +++ b/src/dsp/x86/intrapred_cfl_sse4.h @@ -0,0 +1,376 @@ +/* + * Copyright 2021 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_INTRAPRED_CFL_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_INTRAPRED_CFL_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::cfl_intra_predictors and Dsp::cfl_subsamplers, see the +// defines below for specifics. 
These functions are not thread-safe. +void IntraPredCflInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 
LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +//------------------------------------------------------------------------------ +// 10bpp + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_CflSubsampler420 +#define 
LIBGAV1_Dsp10bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_CflSubsampler420 +#define LIBGAV1_Dsp10bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + 
+#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_CflSubsampler444 +#define LIBGAV1_Dsp10bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_CflIntraPredictor +#define LIBGAV1_Dsp10bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_INTRAPRED_CFL_SSE4_H_ diff --git a/src/dsp/x86/intrapred_directional_sse4.cc b/src/dsp/x86/intrapred_directional_sse4.cc new file mode 100644 index 0000000..e642aee --- /dev/null +++ b/src/dsp/x86/intrapred_directional_sse4.cc @@ -0,0 +1,1478 @@ +// Copyright 2021 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
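For reference, the core operation vectorized throughout this file is a two-tap weighted average of adjacent top-row (or left-column) samples with weights that sum to 32. Below is a minimal scalar sketch of one zone 1 output row for the non-upsampled case; the function and parameter names are illustrative only, not taken from the library.

#include <cstdint>

// Scalar model of one zone 1 (angle < 90) output row, no upsampling.
void Zone1RowScalar(uint8_t* dst, const uint8_t* top, int width, int y,
                    int xstep, int max_base_x) {
  const int top_x = (y + 1) * xstep;      // |top_x| advances by xstep per row.
  const int base = top_x >> 6;            // scale_bits = 6 without upsampling.
  const int shift = (top_x & 0x3F) >> 1;  // 0..31; weights are (32 - shift, shift).
  for (int x = 0; x < width; ++x) {
    const int base_x = base + x;
    if (base_x >= max_base_x) {
      dst[x] = top[max_base_x];  // Past the end of |top|: use the corner sample.
    } else {
      dst[x] = static_cast<uint8_t>(
          (top[base_x] * (32 - shift) + top[base_x + 1] * shift + 16) >> 5);
    }
  }
}

The SSE4 paths below compute |shift| once per row, form the (32 - shift, shift) byte pairs for _mm_maddubs_epi16, and handle the out-of-range tail with _mm_blendv_epi8 against the corner sample instead of a per-pixel branch.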
+ +#include "src/dsp/intrapred_directional.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/dsp/x86/transpose_sse4.h" +#include "src/utils/common.h" +#include "src/utils/memory.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +//------------------------------------------------------------------------------ +// 7.11.2.4. Directional intra prediction process + +// Special case: An |xstep| of 64 corresponds to an angle delta of 45, meaning +// upsampling is ruled out. In addition, the bits masked by 0x3F for +// |shift_val| are 0 for all multiples of 64, so the formula +// val = top[top_base_x]*shift + top[top_base_x+1]*(32-shift), reduces to +// val = top[top_base_x+1] << 5, meaning only the second set of pixels is +// involved in the output. Hence |top| is offset by 1. +inline void DirectionalZone1_Step64(uint8_t* dst, ptrdiff_t stride, + const uint8_t* const top, const int width, + const int height) { + ptrdiff_t offset = 1; + if (height == 4) { + memcpy(dst, top + offset, width); + dst += stride; + memcpy(dst, top + offset + 1, width); + dst += stride; + memcpy(dst, top + offset + 2, width); + dst += stride; + memcpy(dst, top + offset + 3, width); + return; + } + int y = 0; + do { + memcpy(dst, top + offset, width); + dst += stride; + memcpy(dst, top + offset + 1, width); + dst += stride; + memcpy(dst, top + offset + 2, width); + dst += stride; + memcpy(dst, top + offset + 3, width); + dst += stride; + memcpy(dst, top + offset + 4, width); + dst += stride; + memcpy(dst, top + offset + 5, width); + dst += stride; + memcpy(dst, top + offset + 6, width); + dst += stride; + memcpy(dst, top + offset + 7, width); + dst += stride; + + offset += 8; + y += 8; + } while (y < height); +} + +inline void DirectionalZone1_4xH(uint8_t* dst, ptrdiff_t stride, + const uint8_t* const top, const int height, + const int xstep, const bool upsampled) { + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shift = _mm_set1_epi8(32); + // Downscaling for a weighted average whose weights sum to 32 (max_shift). + const int rounding_bits = 5; + const int max_base_x = (height + 3 /* width - 1 */) << upsample_shift; + const __m128i final_top_val = _mm_set1_epi16(top[max_base_x]); + const __m128i sampler = upsampled ? _mm_set_epi64x(0, 0x0706050403020100) + : _mm_set_epi64x(0, 0x0403030202010100); + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is + // not supported for packed integers. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + // All rows from |min_corner_only_y| down will simply use memcpy. |max_base_x| + // is always greater than |height|, so clipping to 1 is enough to make the + // logic work. + const int xstep_units = std::max(xstep >> scale_bits, 1); + const int min_corner_only_y = std::min(max_base_x / xstep_units, height); + + // Rows up to this y-value can be computed without checking for bounds. 
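  // In the loop below, |shifts| interleaves (32 - shift, shift) byte pairs and
  // |sampler| gathers adjacent top samples into (top[base_x], top[base_x + 1])
  // pairs, so _mm_maddubs_epi16 produces
  // top[base_x] * (32 - shift) + top[base_x + 1] * shift in each 16-bit lane
  // before the rounded shift by |rounding_bits|.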
+ int y = 0; + int top_x = xstep; + + for (; y < min_corner_only_y; ++y, dst += stride, top_x += xstep) { + const int top_base_x = top_x >> scale_bits; + + // Permit negative values of |top_x|. + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + + // Load 8 values because we will select the sampled values based on + // |upsampled|. + const __m128i values = LoadLo8(top + top_base_x); + const __m128i sampled_values = _mm_shuffle_epi8(values, sampler); + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + __m128i prod = _mm_maddubs_epi16(sampled_values, shifts); + prod = RightShiftWithRounding_U16(prod, rounding_bits); + // Replace pixels from invalid range with top-right corner. + prod = _mm_blendv_epi8(prod, final_top_val, past_max); + Store4(dst, _mm_packus_epi16(prod, prod)); + } + + // Fill in corner-only rows. + for (; y < height; ++y) { + memset(dst, top[max_base_x], /* width */ 4); + dst += stride; + } +} + +// 7.11.2.4 (7) angle < 90 +inline void DirectionalZone1_Large(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const int width, const int height, + const int xstep, const bool upsampled) { + const int upsample_shift = static_cast<int>(upsampled); + const __m128i sampler = + upsampled ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + const int scale_bits = 6 - upsample_shift; + const int max_base_x = ((width + height) - 1) << upsample_shift; + + const __m128i max_shift = _mm_set1_epi8(32); + // Downscaling for a weighted average whose weights sum to 32 (max_shift). + const int rounding_bits = 5; + const int base_step = 1 << upsample_shift; + const int base_step8 = base_step << 3; + + // All rows from |min_corner_only_y| down will simply use memcpy. |max_base_x| + // is always greater than |height|, so clipping to 1 is enough to make the + // logic work. + const int xstep_units = std::max(xstep >> scale_bits, 1); + const int min_corner_only_y = std::min(max_base_x / xstep_units, height); + + // Rows up to this y-value can be computed without checking for bounds. + const int max_no_corner_y = std::min( + LeftShift((max_base_x - (base_step * width)), scale_bits) / xstep, + height); + // No need to check for exceeding |max_base_x| in the first loop. + int y = 0; + int top_x = xstep; + for (; y < max_no_corner_y; ++y, dest += stride, top_x += xstep) { + int top_base_x = top_x >> scale_bits; + // Permit negative values of |top_x|. 
+ const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + int x = 0; + do { + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + top_base_x += base_step8; + x += 8; + } while (x < width); + } + + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is + // not supported for packed integers. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); + const __m128i base_step8_vect = _mm_set1_epi16(base_step8); + for (; y < min_corner_only_y; ++y, dest += stride, top_x += xstep) { + int top_base_x = top_x >> scale_bits; + + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + + int x = 0; + const int min_corner_only_x = + std::min(width, ((max_base_x - top_base_x) >> upsample_shift) + 7) & ~7; + for (; x < min_corner_only_x; + x += 8, top_base_x += base_step8, + top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + // Assuming a buffer zone of 8 bytes at the end of top_row, this prevents + // reading out of bounds. If all indices are past max and we don't need to + // use the loaded bytes at all, |top_base_x| becomes 0. |top_base_x| will + // reset for the next |y|. + top_base_x &= ~_mm_cvtsi128_si32(past_max); + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + vals = _mm_blendv_epi8(vals, final_top_val, past_max); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + } + // Corner-only section of the row. + memset(dest + x, top_row[max_base_x], width - x); + } + // Fill in corner-only rows. + for (; y < height; ++y) { + memset(dest, top_row[max_base_x], width); + dest += stride; + } +} + +// 7.11.2.4 (7) angle < 90 +inline void DirectionalZone1_SSE4_1(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const int width, const int height, + const int xstep, const bool upsampled) { + const int upsample_shift = static_cast<int>(upsampled); + if (xstep == 64) { + DirectionalZone1_Step64(dest, stride, top_row, width, height); + return; + } + if (width == 4) { + DirectionalZone1_4xH(dest, stride, top_row, height, xstep, upsampled); + return; + } + if (width >= 32) { + DirectionalZone1_Large(dest, stride, top_row, width, height, xstep, + upsampled); + return; + } + const __m128i sampler = + upsampled ? 
_mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + const int scale_bits = 6 - upsample_shift; + const int max_base_x = ((width + height) - 1) << upsample_shift; + + const __m128i max_shift = _mm_set1_epi8(32); + // Downscaling for a weighted average whose weights sum to 32 (max_shift). + const int rounding_bits = 5; + const int base_step = 1 << upsample_shift; + const int base_step8 = base_step << 3; + + // No need to check for exceeding |max_base_x| in the loops. + if (((xstep * height) >> scale_bits) + base_step * width < max_base_x) { + int top_x = xstep; + int y = 0; + do { + int top_base_x = top_x >> scale_bits; + // Permit negative values of |top_x|. + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + int x = 0; + do { + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + top_base_x += base_step8; + x += 8; + } while (x < width); + dest += stride; + top_x += xstep; + } while (++y < height); + return; + } + + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is + // not supported for packed integers. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); + const __m128i base_step8_vect = _mm_set1_epi16(base_step8); + int top_x = xstep; + int y = 0; + do { + int top_base_x = top_x >> scale_bits; + + if (top_base_x >= max_base_x) { + for (int i = y; i < height; ++i) { + memset(dest, top_row[max_base_x], width); + dest += stride; + } + return; + } + + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + + int x = 0; + for (; x < width - 8; + x += 8, top_base_x += base_step8, + top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + // Assuming a buffer zone of 8 bytes at the end of top_row, this prevents + // reading out of bounds. If all indices are past max and we don't need to + // use the loaded bytes at all, |top_base_x| becomes 0. |top_base_x| will + // reset for the next |y|. 
+ top_base_x &= ~_mm_cvtsi128_si32(past_max); + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + vals = _mm_blendv_epi8(vals, final_top_val, past_max); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + } + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + __m128i vals; + if (upsampled) { + vals = LoadUnaligned16(top_row + top_base_x); + } else { + const __m128i top_vals = LoadLo8(top_row + top_base_x); + vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_insert_epi8(vals, top_row[top_base_x + 8], 15); + } + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + vals = _mm_blendv_epi8(vals, final_top_val, past_max); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + dest += stride; + top_x += xstep; + } while (++y < height); +} + +void DirectionalIntraPredictorZone1_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const int width, const int height, + const int xstep, + const bool upsampled_top) { + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + auto* dst = static_cast<uint8_t*>(dest); + DirectionalZone1_SSE4_1(dst, stride, top_ptr, width, height, xstep, + upsampled_top); +} + +template <bool upsampled> +inline void DirectionalZone3_4x4(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const left_column, + const int base_left_y, const int ystep) { + // For use in the non-upsampled case. + const __m128i sampler = _mm_set_epi64x(0, 0x0403030202010100); + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shift = _mm_set1_epi8(32); + // Downscaling for a weighted average whose weights sum to 32 (max_shift). + const int rounding_bits = 5; + + __m128i result_block[4]; + for (int x = 0, left_y = base_left_y; x < 4; x++, left_y += ystep) { + const int left_base_y = left_y >> scale_bits; + const int shift_val = ((left_y << upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i vals; + if (upsampled) { + vals = LoadLo8(left_column + left_base_y); + } else { + const __m128i top_vals = LoadLo8(left_column + left_base_y); + vals = _mm_shuffle_epi8(top_vals, sampler); + } + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + result_block[x] = _mm_packus_epi16(vals, vals); + } + const __m128i result = Transpose4x4_U8(result_block); + // This is result_row0. + Store4(dest, result); + dest += stride; + const int result_row1 = _mm_extract_epi32(result, 1); + memcpy(dest, &result_row1, sizeof(result_row1)); + dest += stride; + const int result_row2 = _mm_extract_epi32(result, 2); + memcpy(dest, &result_row2, sizeof(result_row2)); + dest += stride; + const int result_row3 = _mm_extract_epi32(result, 3); + memcpy(dest, &result_row3, sizeof(result_row3)); +} + +template <bool upsampled, int height> +inline void DirectionalZone3_8xH(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const left_column, + const int base_left_y, const int ystep) { + // For use in the non-upsampled case. 
+ const __m128i sampler = + _mm_set_epi64x(0x0807070606050504, 0x0403030202010100); + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shift = _mm_set1_epi8(32); + // Downscaling for a weighted average whose weights sum to 32 (max_shift). + const int rounding_bits = 5; + + __m128i result_block[8]; + for (int x = 0, left_y = base_left_y; x < 8; x++, left_y += ystep) { + const int left_base_y = left_y >> scale_bits; + const int shift_val = (LeftShift(left_y, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i vals; + if (upsampled) { + vals = LoadUnaligned16(left_column + left_base_y); + } else { + const __m128i top_vals = LoadUnaligned16(left_column + left_base_y); + vals = _mm_shuffle_epi8(top_vals, sampler); + } + vals = _mm_maddubs_epi16(vals, shifts); + result_block[x] = RightShiftWithRounding_U16(vals, rounding_bits); + } + Transpose8x8_U16(result_block, result_block); + for (int y = 0; y < height; ++y) { + StoreLo8(dest, _mm_packus_epi16(result_block[y], result_block[y])); + dest += stride; + } +} + +// 7.11.2.4 (9) angle > 180 +void DirectionalIntraPredictorZone3_SSE4_1(void* dest, ptrdiff_t stride, + const void* const left_column, + const int width, const int height, + const int ystep, + const bool upsampled) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const int upsample_shift = static_cast<int>(upsampled); + if (width == 4 || height == 4) { + const ptrdiff_t stride4 = stride << 2; + if (upsampled) { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_4x4<true>( + dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); + dst_x += stride4; + y += 4; + } while (y < height); + left_y += ystep << 2; + x += 4; + } while (x < width); + } else { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_4x4<false>(dst_x, stride, left_ptr + y, left_y, + ystep); + dst_x += stride4; + y += 4; + } while (y < height); + left_y += ystep << 2; + x += 4; + } while (x < width); + } + return; + } + + const ptrdiff_t stride8 = stride << 3; + if (upsampled) { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_8xH<true, 8>( + dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); + dst_x += stride8; + y += 8; + } while (y < height); + left_y += ystep << 3; + x += 8; + } while (x < width); + } else { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_8xH<false, 8>( + dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); + dst_x += stride8; + y += 8; + } while (y < height); + left_y += ystep << 3; + x += 8; + } while (x < width); + } +} + +//------------------------------------------------------------------------------ +// Directional Zone 2 Functions +// 7.11.2.4 (8) + +// DirectionalBlend* selectively overwrites the values written by +// DirectionalZone2FromLeftCol*. |zone_bounds| has one 16-bit index for each +// row. 
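Read as scalar code, the rule that DirectionalBlend4/8 apply per row is simple: keep the left-column prediction for x positions below the row's zone bound, and take the top-row prediction from there on. A minimal sketch of that selection, with illustrative names that are not part of the patch:

#include <cstdint>

// Scalar model of one blended row in zone 2. |from_left| holds the values the
// left-column pass has already written, |from_top| the candidate top-row
// values, and |zone_bound| the first x that switches to the top-row path.
inline void BlendZone2RowSketch(uint8_t* dst, const uint8_t* from_left,
                                const uint8_t* from_top, const int width,
                                const int zone_bound) {
  for (int x = 0; x < width; ++x) {
    dst[x] = (x < zone_bound) ? from_left[x] : from_top[x];
  }
}

In the vector code below, |from_left| is what already sits in |dest| and the selection mask comes from comparing |dest_index_vect| against the row's entry in |zone_bounds|.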
+template <int y_selector> +inline void DirectionalBlend4_SSE4_1(uint8_t* dest, + const __m128i& dest_index_vect, + const __m128i& vals, + const __m128i& zone_bounds) { + const __m128i max_dest_x_vect = _mm_shufflelo_epi16(zone_bounds, y_selector); + const __m128i use_left = _mm_cmplt_epi16(dest_index_vect, max_dest_x_vect); + const __m128i original_vals = _mm_cvtepu8_epi16(Load4(dest)); + const __m128i blended_vals = _mm_blendv_epi8(vals, original_vals, use_left); + Store4(dest, _mm_packus_epi16(blended_vals, blended_vals)); +} + +inline void DirectionalBlend8_SSE4_1(uint8_t* dest, + const __m128i& dest_index_vect, + const __m128i& vals, + const __m128i& zone_bounds, + const __m128i& bounds_selector) { + const __m128i max_dest_x_vect = + _mm_shuffle_epi8(zone_bounds, bounds_selector); + const __m128i use_left = _mm_cmplt_epi16(dest_index_vect, max_dest_x_vect); + const __m128i original_vals = _mm_cvtepu8_epi16(LoadLo8(dest)); + const __m128i blended_vals = _mm_blendv_epi8(vals, original_vals, use_left); + StoreLo8(dest, _mm_packus_epi16(blended_vals, blended_vals)); +} + +constexpr int kDirectionalWeightBits = 5; +// |source| is packed with 4 or 8 pairs of 8-bit values from left or top. +// |shifts| is named to match the specification, with 4 or 8 pairs of (32 - +// shift) and shift. Shift is guaranteed to be between 0 and 32. +inline __m128i DirectionalZone2FromSource_SSE4_1(const uint8_t* const source, + const __m128i& shifts, + const __m128i& sampler) { + const __m128i src_vals = LoadUnaligned16(source); + __m128i vals = _mm_shuffle_epi8(src_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + return RightShiftWithRounding_U16(vals, kDirectionalWeightBits); +} + +// Because the source values "move backwards" as the row index increases, the +// indices derived from ystep are generally negative. This is accommodated by +// making sure the relative indices are within [-15, 0] when the function is +// called, and sliding them into the inclusive range [0, 15], relative to a +// lower base address. +constexpr int kPositiveIndexOffset = 15; + +template <bool upsampled> +inline void DirectionalZone2FromLeftCol_4x4_SSE4_1( + uint8_t* dst, ptrdiff_t stride, const uint8_t* const left_column_base, + __m128i left_y) { + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shifts = _mm_set1_epi8(32); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + const __m128i index_increment = _mm_cvtsi32_si128(0x01010101); + const __m128i positive_offset = _mm_set1_epi8(kPositiveIndexOffset); + // Left_column and sampler are both offset by 15 so the indices are always + // positive. + const uint8_t* left_column = left_column_base - kPositiveIndexOffset; + for (int y = 0; y < 4; dst += stride, ++y) { + __m128i offset_y = _mm_srai_epi16(left_y, scale_bits); + offset_y = _mm_packs_epi16(offset_y, offset_y); + + const __m128i adjacent = _mm_add_epi8(offset_y, index_increment); + __m128i sampler = _mm_unpacklo_epi8(offset_y, adjacent); + // Slide valid |offset_y| indices from range [-15, 0] to [0, 15] so they + // can work as shuffle indices. Some values may be out of bounds, but their + // pred results will be masked over by top prediction. 
+ sampler = _mm_add_epi8(sampler, positive_offset); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(left_y, upsample_shift), shift_mask), 1); + shifts = _mm_packus_epi16(shifts, shifts); + const __m128i opposite_shifts = _mm_sub_epi8(max_shifts, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + const __m128i vals = DirectionalZone2FromSource_SSE4_1( + left_column + (y << upsample_shift), shifts, sampler); + Store4(dst, _mm_packus_epi16(vals, vals)); + } +} + +// The height at which a load of 16 bytes will not contain enough source pixels +// from |left_column| to supply an accurate row when computing 8 pixels at a +// time. The values are found by inspection. By coincidence, all angles that +// satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up +// by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15. +constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = { + 1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40}; + +template <bool upsampled> +inline void DirectionalZone2FromLeftCol_8x8_SSE4_1( + uint8_t* dst, ptrdiff_t stride, const uint8_t* const left_column, + __m128i left_y) { + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shifts = _mm_set1_epi8(32); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + const __m128i index_increment = _mm_set1_epi8(1); + const __m128i denegation = _mm_set1_epi8(kPositiveIndexOffset); + for (int y = 0; y < 8; dst += stride, ++y) { + __m128i offset_y = _mm_srai_epi16(left_y, scale_bits); + offset_y = _mm_packs_epi16(offset_y, offset_y); + const __m128i adjacent = _mm_add_epi8(offset_y, index_increment); + + // Offset the relative index because ystep is negative in Zone 2 and shuffle + // indices must be nonnegative. + __m128i sampler = _mm_unpacklo_epi8(offset_y, adjacent); + sampler = _mm_add_epi8(sampler, denegation); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(left_y, upsample_shift), shift_mask), 1); + shifts = _mm_packus_epi16(shifts, shifts); + const __m128i opposite_shifts = _mm_sub_epi8(max_shifts, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + + // The specification adds (y << 6) to left_y, which is subject to + // upsampling, but this puts sampler indices out of the 0-15 range. It is + // equivalent to offset the source address by (y << upsample_shift) instead. + const __m128i vals = DirectionalZone2FromSource_SSE4_1( + left_column - kPositiveIndexOffset + (y << upsample_shift), shifts, + sampler); + StoreLo8(dst, _mm_packus_epi16(vals, vals)); + } +} + +// |zone_bounds| is an epi16 of the relative x index at which base >= -(1 << +// upsampled_top), for each row. When there are 4 values, they can be duplicated +// with a non-register shuffle mask. +// |shifts| is one pair of weights that applies throughout a given row. 
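The |shifts| argument uses the same two-tap packing seen throughout this file: (32 - shift) and shift are interleaved as bytes so that a single _mm_maddubs_epi16 yields pixel * (32 - shift) + next_pixel * shift in each 16-bit lane. A minimal sketch of that step, assuming the pixel pairs have already been gathered by the sampler shuffle (the helper name is illustrative):

#include <smmintrin.h>
#include <cstdint>

// |pixel_pairs| holds eight (p[base], p[base + 1]) byte pairs; |shift| is the
// per-row weight in [0, 31]. Each output lane is the unrounded sum
// p[base] * (32 - shift) + p[base + 1] * shift, which callers then round by
// kDirectionalWeightBits.
inline __m128i TwoTapWeightedSumSketch(const __m128i pixel_pairs,
                                       const int shift) {
  const __m128i weights =
      _mm_set1_epi16(static_cast<int16_t>((shift << 8) | (32 - shift)));
  return _mm_maddubs_epi16(pixel_pairs, weights);
}

The weights constant here matches the byte pattern produced by the _mm_unpacklo_epi8(opposite_shifts, shifts) idiom used above.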
+template <bool upsampled_top> +inline void DirectionalZone1Blend_4x4( + uint8_t* dest, const uint8_t* const top_row, ptrdiff_t stride, + __m128i sampler, const __m128i& zone_bounds, const __m128i& shifts, + const __m128i& dest_index_x, int top_x, const int xstep) { + const int upsample_shift = static_cast<int>(upsampled_top); + const int scale_bits_x = 6 - upsample_shift; + top_x -= xstep; + + int top_base_x = (top_x >> scale_bits_x); + const __m128i vals0 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0x00), sampler); + DirectionalBlend4_SSE4_1<0x00>(dest, dest_index_x, vals0, zone_bounds); + top_x -= xstep; + dest += stride; + + top_base_x = (top_x >> scale_bits_x); + const __m128i vals1 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0x55), sampler); + DirectionalBlend4_SSE4_1<0x55>(dest, dest_index_x, vals1, zone_bounds); + top_x -= xstep; + dest += stride; + + top_base_x = (top_x >> scale_bits_x); + const __m128i vals2 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0xAA), sampler); + DirectionalBlend4_SSE4_1<0xAA>(dest, dest_index_x, vals2, zone_bounds); + top_x -= xstep; + dest += stride; + + top_base_x = (top_x >> scale_bits_x); + const __m128i vals3 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0xFF), sampler); + DirectionalBlend4_SSE4_1<0xFF>(dest, dest_index_x, vals3, zone_bounds); +} + +template <bool upsampled_top, int height> +inline void DirectionalZone1Blend_8xH( + uint8_t* dest, const uint8_t* const top_row, ptrdiff_t stride, + __m128i sampler, const __m128i& zone_bounds, const __m128i& shifts, + const __m128i& dest_index_x, int top_x, const int xstep) { + const int upsample_shift = static_cast<int>(upsampled_top); + const int scale_bits_x = 6 - upsample_shift; + + __m128i y_selector = _mm_set1_epi32(0x01000100); + const __m128i index_increment = _mm_set1_epi32(0x02020202); + for (int y = 0; y < height; ++y, + y_selector = _mm_add_epi8(y_selector, index_increment), + dest += stride) { + top_x -= xstep; + const int top_base_x = top_x >> scale_bits_x; + const __m128i vals = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shuffle_epi8(shifts, y_selector), sampler); + DirectionalBlend8_SSE4_1(dest, dest_index_x, vals, zone_bounds, y_selector); + } +} + +// 7.11.2.4 (8) 90 < angle > 180 +// The strategy for this function is to know how many blocks can be processed +// with just pixels from |top_ptr|, then handle mixed blocks, then handle only +// blocks that take from |left_ptr|. Additionally, a fast index-shuffle +// approach is used for pred values from |left_column| in sections that permit +// it. +template <bool upsampled_left, bool upsampled_top> +inline void DirectionalZone2_SSE4_1(void* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int width, const int height, + const int xstep, const int ystep) { + auto* dst = static_cast<uint8_t*>(dest); + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + const __m128i max_shift = _mm_set1_epi8(32); + const ptrdiff_t stride8 = stride << 3; + const __m128i dest_index_x = + _mm_set_epi32(0x00070006, 0x00050004, 0x00030002, 0x00010000); + const __m128i sampler_top = + upsampled_top + ? 
_mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute. This assumes minimum |xstep| is 3. + const int min_top_only_x = std::min((height * xstep) >> 6, width); + + // For steep angles, the source pixels from left_column may not fit in a + // 16-byte load for shuffling. + // TODO(petersonab): Find a more precise formula for this subject to x. + const int max_shuffle_height = + std::min(height, kDirectionalZone2ShuffleInvalidHeight[ystep >> 6]); + + const int xstep8 = xstep << 3; + const __m128i xstep8_vect = _mm_set1_epi16(xstep8); + // Accumulate xstep across 8 rows. + const __m128i xstep_dup = _mm_set1_epi16(-xstep); + const __m128i increments = _mm_set_epi16(8, 7, 6, 5, 4, 3, 2, 1); + const __m128i xstep_for_shift = _mm_mullo_epi16(xstep_dup, increments); + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + const __m128i scaled_one = _mm_set1_epi16(-64); + __m128i xstep_bounds_base = + (xstep == 64) ? _mm_sub_epi16(scaled_one, xstep_for_shift) + : _mm_sub_epi16(_mm_set1_epi16(-1), xstep_for_shift); + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + const int ystep8 = ystep << 3; + const int left_base_increment8 = ystep8 >> 6; + const int ystep_remainder8 = ystep8 & 0x3F; + const __m128i increment_left8 = _mm_set1_epi16(-ystep_remainder8); + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which is covered under the left_column + // offset. Following values need the full ystep as a relative offset. + const __m128i ystep_init = _mm_set1_epi16(-ystep_remainder); + const __m128i ystep_dup = _mm_set1_epi16(-ystep); + __m128i left_y = _mm_mullo_epi16(ystep_dup, dest_index_x); + left_y = _mm_add_epi16(ystep_init, left_y); + + const __m128i increment_top8 = _mm_set1_epi16(8 << 6); + int x = 0; + + // This loop treats each set of 4 columns in 3 stages with y-value boundaries. + // The first stage, before the first y-loop, covers blocks that are only + // computed from the top row. The second stage, comprising two y-loops, covers + // blocks that have a mixture of values computed from top or left. The final + // stage covers blocks that are only computed from the left. + for (int left_offset = -left_base_increment; x < min_top_only_x; + x += 8, + xstep_bounds_base = _mm_sub_epi16(xstep_bounds_base, increment_top8), + // Watch left_y because it can still get big. + left_y = _mm_add_epi16(left_y, increment_left8), + left_offset -= left_base_increment8) { + uint8_t* dst_x = dst + x; + + // Round down to the nearest multiple of 8. + const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7; + DirectionalZone1_4xH(dst_x, stride, top_row + (x << upsample_top_shift), + max_top_only_y, -xstep, upsampled_top); + DirectionalZone1_4xH(dst_x + 4, stride, + top_row + ((x + 4) << upsample_top_shift), + max_top_only_y, -xstep, upsampled_top); + + int y = max_top_only_y; + dst_x += stride * y; + const int xstep_y = xstep * y; + const __m128i xstep_y_vect = _mm_set1_epi16(xstep_y); + // All rows from |min_left_only_y| down for this set of columns, only need + // |left_column| to compute. + const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height); + // At high angles such that min_left_only_y < 8, ystep is low and xstep is + // high. 
This means that max_shuffle_height is unbounded and xstep_bounds + // will overflow in 16 bits. This is prevented by stopping the first + // blending loop at min_left_only_y for such cases, which means we skip over + // the second blending loop as well. + const int left_shuffle_stop_y = + std::min(max_shuffle_height, min_left_only_y); + __m128i xstep_bounds = _mm_add_epi16(xstep_bounds_base, xstep_y_vect); + __m128i xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift, xstep_y_vect); + int top_x = -xstep_y; + + for (; y < left_shuffle_stop_y; + y += 8, dst_x += stride8, + xstep_bounds = _mm_add_epi16(xstep_bounds, xstep8_vect), + xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep8_vect), + top_x -= xstep8) { + DirectionalZone2FromLeftCol_8x8_SSE4_1<upsampled_left>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), left_y); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), + shift_mask), + 1); + shifts = _mm_packus_epi16(shifts, shifts); + __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); + DirectionalZone1Blend_8xH<upsampled_top, 8>( + dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, + xstep_bounds_off, shifts, dest_index_x, top_x, xstep); + } + // Pick up from the last y-value, using the 10% slower but secure method for + // left prediction. + const auto base_left_y = static_cast<int16_t>(_mm_extract_epi16(left_y, 0)); + for (; y < min_left_only_y; + y += 8, dst_x += stride8, + xstep_bounds = _mm_add_epi16(xstep_bounds, xstep8_vect), + xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep8_vect), + top_x -= xstep8) { + const __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); + + DirectionalZone3_8xH<upsampled_left, 8>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), + shift_mask), + 1); + shifts = _mm_packus_epi16(shifts, shifts); + __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + DirectionalZone1Blend_8xH<upsampled_top, 8>( + dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, + xstep_bounds_off, shifts, dest_index_x, top_x, xstep); + } + // Loop over y for left_only rows. + for (; y < height; y += 8, dst_x += stride8) { + DirectionalZone3_8xH<upsampled_left, 8>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep); + } + } + for (; x < width; x += 4) { + DirectionalZone1_4xH(dst + x, stride, top_row + (x << upsample_top_shift), + height, -xstep, upsampled_top); + } +} + +template <bool upsampled_left, bool upsampled_top> +inline void DirectionalZone2_4_SSE4_1(void* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int width, const int height, + const int xstep, const int ystep) { + auto* dst = static_cast<uint8_t*>(dest); + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + const __m128i max_shift = _mm_set1_epi8(32); + const ptrdiff_t stride4 = stride << 2; + const __m128i dest_index_x = _mm_set_epi32(0, 0, 0x00030002, 0x00010000); + const __m128i sampler_top = + upsampled_top + ? 
_mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute. + assert(xstep >= 3); + const int min_top_only_x = std::min((height * xstep) >> 6, width); + + const int xstep4 = xstep << 2; + const __m128i xstep4_vect = _mm_set1_epi16(xstep4); + const __m128i xstep_dup = _mm_set1_epi16(-xstep); + const __m128i increments = _mm_set_epi32(0, 0, 0x00040003, 0x00020001); + __m128i xstep_for_shift = _mm_mullo_epi16(xstep_dup, increments); + const __m128i scaled_one = _mm_set1_epi16(-64); + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + __m128i xstep_bounds_base = + (xstep == 64) ? _mm_sub_epi16(scaled_one, xstep_for_shift) + : _mm_sub_epi16(_mm_set1_epi16(-1), xstep_for_shift); + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + const int ystep4 = ystep << 2; + const int left_base_increment4 = ystep4 >> 6; + // This is guaranteed to be less than 64, but accumulation may bring it past + // 64 for higher x values. + const int ystep_remainder4 = ystep4 & 0x3F; + const __m128i increment_left4 = _mm_set1_epi16(-ystep_remainder4); + const __m128i increment_top4 = _mm_set1_epi16(4 << 6); + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which will go into the left_column offset. + // Following values need the full ystep as a relative offset. + const __m128i ystep_init = _mm_set1_epi16(-ystep_remainder); + const __m128i ystep_dup = _mm_set1_epi16(-ystep); + __m128i left_y = _mm_mullo_epi16(ystep_dup, dest_index_x); + left_y = _mm_add_epi16(ystep_init, left_y); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + + int x = 0; + // Loop over x for columns with a mixture of sources. + for (int left_offset = -left_base_increment; x < min_top_only_x; x += 4, + xstep_bounds_base = _mm_sub_epi16(xstep_bounds_base, increment_top4), + left_y = _mm_add_epi16(left_y, increment_left4), + left_offset -= left_base_increment4) { + uint8_t* dst_x = dst + x; + + // Round down to the nearest multiple of 8. + const int max_top_only_y = std::min((x << 6) / xstep, height) & 0xFFFFFFF4; + DirectionalZone1_4xH(dst_x, stride, top_row + (x << upsample_top_shift), + max_top_only_y, -xstep, upsampled_top); + int y = max_top_only_y; + dst_x += stride * y; + const int xstep_y = xstep * y; + const __m128i xstep_y_vect = _mm_set1_epi16(xstep_y); + // All rows from |min_left_only_y| down for this set of columns, only need + // |left_column| to compute. Rounded up to the nearest multiple of 4. + const int min_left_only_y = std::min(((x + 4) << 6) / xstep, height); + + __m128i xstep_bounds = _mm_add_epi16(xstep_bounds_base, xstep_y_vect); + __m128i xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift, xstep_y_vect); + int top_x = -xstep_y; + + // Loop over y for mixed rows. 
+ for (; y < min_left_only_y; + y += 4, dst_x += stride4, + xstep_bounds = _mm_add_epi16(xstep_bounds, xstep4_vect), + xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep4_vect), + top_x -= xstep4) { + DirectionalZone2FromLeftCol_4x4_SSE4_1<upsampled_left>( + dst_x, stride, + left_column + ((left_offset + y) * (1 << upsample_left_shift)), + left_y); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), + shift_mask), + 1); + shifts = _mm_packus_epi16(shifts, shifts); + const __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + const __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); + DirectionalZone1Blend_4x4<upsampled_top>( + dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, + xstep_bounds_off, shifts, dest_index_x, top_x, xstep); + } + // Loop over y for left-only rows, if any. + for (; y < height; y += 4, dst_x += stride4) { + DirectionalZone2FromLeftCol_4x4_SSE4_1<upsampled_left>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), left_y); + } + } + // Loop over top-only columns, if any. + for (; x < width; x += 4) { + DirectionalZone1_4xH(dst + x, stride, top_row + (x << upsample_top_shift), + height, -xstep, upsampled_top); + } +} + +void DirectionalIntraPredictorZone2_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + const int width, const int height, + const int xstep, const int ystep, + const bool upsampled_top, + const bool upsampled_left) { + // Increasing the negative buffer for this function allows more rows to be + // processed at a time without branching in an inner loop to check the base. + uint8_t top_buffer[288]; + uint8_t left_buffer[288]; + memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160); + memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160); + const uint8_t* top_ptr = top_buffer + 144; + const uint8_t* left_ptr = left_buffer + 144; + if (width == 4 || height == 4) { + if (upsampled_left) { + if (upsampled_top) { + DirectionalZone2_4_SSE4_1<true, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_4_SSE4_1<true, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } else { + if (upsampled_top) { + DirectionalZone2_4_SSE4_1<false, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_4_SSE4_1<false, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } + return; + } + if (upsampled_left) { + if (upsampled_top) { + DirectionalZone2_SSE4_1<true, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_SSE4_1<true, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } else { + if (upsampled_top) { + DirectionalZone2_SSE4_1<false, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_SSE4_1<false, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone1) + dsp->directional_intra_predictor_zone1 = + DirectionalIntraPredictorZone1_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone2) + 
dsp->directional_intra_predictor_zone2 = + DirectionalIntraPredictorZone2_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone3) + dsp->directional_intra_predictor_zone3 = + DirectionalIntraPredictorZone3_SSE4_1; +#endif +} + +} // namespace +} // namespace low_bitdepth + +//------------------------------------------------------------------------------ +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +//------------------------------------------------------------------------------ +// 7.11.2.4. Directional intra prediction process + +// Special case: An |xstep| of 64 corresponds to an angle delta of 45, meaning +// upsampling is ruled out. In addition, the bits masked by 0x3F for +// |shift_val| are 0 for all multiples of 64, so the formula +// val = top[top_base_x]*shift + top[top_base_x+1]*(32-shift), reduces to +// val = top[top_base_x+1] << 5, meaning only the second set of pixels is +// involved in the output. Hence |top| is offset by 1. +inline void DirectionalZone1_Step64(uint16_t* dst, ptrdiff_t stride, + const uint16_t* const top, const int width, + const int height) { + ptrdiff_t offset = 1; + if (height == 4) { + memcpy(dst, top + offset, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 1, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 2, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 3, width * sizeof(dst[0])); + return; + } + int y = height; + do { + memcpy(dst, top + offset, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 1, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 2, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 3, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 4, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 5, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 6, width * sizeof(dst[0])); + dst += stride; + memcpy(dst, top + offset + 7, width * sizeof(dst[0])); + dst += stride; + + offset += 8; + y -= 8; + } while (y != 0); +} + +// Produce a weighted average whose weights sum to 32. +inline __m128i CombineTopVals4(const __m128i& top_vals, const __m128i& sampler, + const __m128i& shifts, + const __m128i& top_indices, + const __m128i& final_top_val, + const __m128i& border_index) { + const __m128i sampled_values = _mm_shuffle_epi8(top_vals, sampler); + __m128i prod = _mm_mullo_epi16(sampled_values, shifts); + prod = _mm_hadd_epi16(prod, prod); + const __m128i result = RightShiftWithRounding_U16(prod, 5 /*log2(32)*/); + + const __m128i past_max = _mm_cmpgt_epi16(top_indices, border_index); + // Replace pixels from invalid range with top-right corner. + return _mm_blendv_epi8(result, final_top_val, past_max); +} + +// When width is 4, only one load operation is needed per iteration. We also +// avoid extra loop precomputations that cause too much overhead. 
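As a plain-C reference for the helpers above and the 4xH loop below, each 10bpp zone 1 sample is a two-tap weighted average of adjacent top pixels, with the corner value substituted once the base index runs past the buffer. A scalar sketch, showing the non-upsampled boundary rule and using an illustrative function name:

#include <cstdint>

// Scalar model of one zone 1 output sample at 10 bits. |shift| is the per-row
// weight in [0, 31] and |base| the column's index into |top|. The two weights
// sum to 32 and inputs are at most 1023, so the products and their sum stay
// within the 16-bit lanes used by CombineTopVals4.
inline uint16_t Zone1Sample10bppSketch(const uint16_t* top, const int base,
                                       const int shift, const int max_base_x) {
  if (base >= max_base_x) return top[max_base_x];
  return static_cast<uint16_t>(
      (top[base] * (32 - shift) + top[base + 1] * shift + 16) >> 5);
}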
+inline void DirectionalZone1_4xH(uint16_t* dst, ptrdiff_t stride, + const uint16_t* const top, const int height, + const int xstep, const bool upsampled, + const __m128i& sampler) { + const int upsample_shift = static_cast<int>(upsampled); + const int index_scale_bits = 6 - upsample_shift; + const int max_base_x = (height + 3 /* width - 1 */) << upsample_shift; + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + const __m128i final_top_val = _mm_set1_epi16(top[max_base_x]); + + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" because + // only cmpgt is available. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + // All rows from |min_corner_only_y| down will simply use memcpy. + // |max_base_x| is always greater than |height|, so clipping the denominator + // to 1 is enough to make the logic work. + const int xstep_units = std::max(xstep >> index_scale_bits, 1); + const int min_corner_only_y = std::min(max_base_x / xstep_units, height); + + int y = 0; + int top_x = xstep; + const __m128i max_shift = _mm_set1_epi16(32); + + for (; y < min_corner_only_y; ++y, dst += stride, top_x += xstep) { + const int top_base_x = top_x >> index_scale_bits; + + // Permit negative values of |top_x|. + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi16(shift_val); + const __m128i opposite_shift = _mm_sub_epi16(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi16(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + + // Load 8 values because we will select the sampled values based on + // |upsampled|. + const __m128i values = LoadUnaligned16(top + top_base_x); + const __m128i pred = + CombineTopVals4(values, sampler, shifts, top_index_vect, final_top_val, + max_base_x_vect); + StoreLo8(dst, pred); + } + + // Fill in corner-only rows. + for (; y < height; ++y) { + Memset(dst, top[max_base_x], /* width */ 4); + dst += stride; + } +} + +// General purpose combine function. +// |check_border| means the final source value has to be duplicated into the +// result. This simplifies the loop structures that use precomputed boundaries +// to identify sections where it is safe to compute without checking for the +// right border. +template <bool check_border> +inline __m128i CombineTopVals( + const __m128i& top_vals_0, const __m128i& top_vals_1, + const __m128i& sampler, const __m128i& shifts, + const __m128i& top_indices = _mm_setzero_si128(), + const __m128i& final_top_val = _mm_setzero_si128(), + const __m128i& border_index = _mm_setzero_si128()) { + constexpr int scale_int_bits = 5; + const __m128i sampled_values_0 = _mm_shuffle_epi8(top_vals_0, sampler); + const __m128i sampled_values_1 = _mm_shuffle_epi8(top_vals_1, sampler); + const __m128i prod_0 = _mm_mullo_epi16(sampled_values_0, shifts); + const __m128i prod_1 = _mm_mullo_epi16(sampled_values_1, shifts); + const __m128i combined = _mm_hadd_epi16(prod_0, prod_1); + const __m128i result = RightShiftWithRounding_U16(combined, scale_int_bits); + if (check_border) { + const __m128i past_max = _mm_cmpgt_epi16(top_indices, border_index); + // Replace pixels from invalid range with top-right corner. 
+ return _mm_blendv_epi8(result, final_top_val, past_max); + } + return result; +} + +// 7.11.2.4 (7) angle < 90 +inline void DirectionalZone1_Large(uint16_t* dest, ptrdiff_t stride, + const uint16_t* const top_row, + const int width, const int height, + const int xstep, const bool upsampled, + const __m128i& sampler) { + const int upsample_shift = static_cast<int>(upsampled); + const int index_scale_bits = 6 - upsample_shift; + const int max_base_x = ((width + height) - 1) << upsample_shift; + + const __m128i max_shift = _mm_set1_epi16(32); + const int base_step = 1 << upsample_shift; + const int base_step8 = base_step << 3; + + // All rows from |min_corner_only_y| down will simply use memcpy. + // |max_base_x| is always greater than |height|, so clipping to 1 is enough + // to make the logic work. + const int xstep_units = std::max(xstep >> index_scale_bits, 1); + const int min_corner_only_y = std::min(max_base_x / xstep_units, height); + + // Rows up to this y-value can be computed without checking for bounds. + const int max_no_corner_y = std::min( + LeftShift((max_base_x - (base_step * width)), index_scale_bits) / xstep, + height); + // No need to check for exceeding |max_base_x| in the first loop. + int y = 0; + int top_x = xstep; + for (; y < max_no_corner_y; ++y, dest += stride, top_x += xstep) { + int top_base_x = top_x >> index_scale_bits; + // Permit negative values of |top_x|. + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi16(shift_val); + const __m128i opposite_shift = _mm_sub_epi16(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi16(opposite_shift, shift); + int x = 0; + do { + const __m128i top_vals_0 = LoadUnaligned16(top_row + top_base_x); + const __m128i top_vals_1 = + LoadUnaligned16(top_row + top_base_x + (4 << upsample_shift)); + + const __m128i pred = + CombineTopVals<false>(top_vals_0, top_vals_1, sampler, shifts); + + StoreUnaligned16(dest + x, pred); + top_base_x += base_step8; + x += 8; + } while (x < width); + } + + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to |top_base_x|, it is used to mask values + // that pass the end of the |top| buffer. Starting from 1 to simulate "cmpge" + // which is not supported for packed integers. 
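The offset-by-one indices described above are just a way to get an inclusive compare out of the signed greater-than that SSE provides. In isolation (illustrative helper name):

#include <smmintrin.h>

// Emulates a per-lane "index >= max" test. For the small positive index
// ranges used here, (index + 1) > max is equivalent to index >= max, so
// _mm_cmpgt_epi16 stands in for the missing packed "cmpge".
inline __m128i CmpgeEpi16Sketch(const __m128i index, const __m128i max) {
  return _mm_cmpgt_epi16(_mm_add_epi16(index, _mm_set1_epi16(1)), max);
}

In the loop below, the +1 is folded directly into the |offsets| constant rather than added separately.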
+ const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); + const __m128i base_step8_vect = _mm_set1_epi16(base_step8); + for (; y < min_corner_only_y; ++y, dest += stride, top_x += xstep) { + int top_base_x = top_x >> index_scale_bits; + + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi16(shift_val); + const __m128i opposite_shift = _mm_sub_epi16(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi16(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + + int x = 0; + const int min_corner_only_x = + std::min(width, ((max_base_x - top_base_x) >> upsample_shift) + 7) & ~7; + for (; x < min_corner_only_x; + x += 8, top_base_x += base_step8, + top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { + const __m128i top_vals_0 = LoadUnaligned16(top_row + top_base_x); + const __m128i top_vals_1 = + LoadUnaligned16(top_row + top_base_x + (4 << upsample_shift)); + const __m128i pred = + CombineTopVals<true>(top_vals_0, top_vals_1, sampler, shifts, + top_index_vect, final_top_val, max_base_x_vect); + StoreUnaligned16(dest + x, pred); + } + // Corner-only section of the row. + Memset(dest + x, top_row[max_base_x], width - x); + } + // Fill in corner-only rows. + for (; y < height; ++y) { + Memset(dest, top_row[max_base_x], width); + dest += stride; + } +} + +// 7.11.2.4 (7) angle < 90 +inline void DirectionalIntraPredictorZone1_SSE4_1( + void* dest_ptr, ptrdiff_t stride, const void* const top_ptr, + const int width, const int height, const int xstep, const bool upsampled) { + const auto* const top_row = static_cast<const uint16_t*>(top_ptr); + auto* dest = static_cast<uint16_t*>(dest_ptr); + stride /= sizeof(uint16_t); + const int upsample_shift = static_cast<int>(upsampled); + if (xstep == 64) { + DirectionalZone1_Step64(dest, stride, top_row, width, height); + return; + } + // Each base pixel paired with its following pixel, for hadd purposes. + const __m128i adjacency_shuffler = _mm_set_epi16( + 0x0908, 0x0706, 0x0706, 0x0504, 0x0504, 0x0302, 0x0302, 0x0100); + // This is equivalent to not shuffling at all. + const __m128i identity_shuffler = _mm_set_epi16( + 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + // This represents a trade-off between code size and speed. When upsampled + // is true, no shuffle is necessary. But to avoid in-loop branching, we + // would need 2 copies of the main function body. + const __m128i sampler = upsampled ? identity_shuffler : adjacency_shuffler; + if (width == 4) { + DirectionalZone1_4xH(dest, stride, top_row, height, xstep, upsampled, + sampler); + return; + } + if (width >= 32) { + DirectionalZone1_Large(dest, stride, top_row, width, height, xstep, + upsampled, sampler); + return; + } + const int index_scale_bits = 6 - upsample_shift; + const int max_base_x = ((width + height) - 1) << upsample_shift; + + const __m128i max_shift = _mm_set1_epi16(32); + const int base_step = 1 << upsample_shift; + const int base_step8 = base_step << 3; + + // No need to check for exceeding |max_base_x| in the loops. + if (((xstep * height) >> index_scale_bits) + base_step * width < max_base_x) { + int top_x = xstep; + int y = height; + do { + int top_base_x = top_x >> index_scale_bits; + // Permit negative values of |top_x|. 
+ const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi16(shift_val); + const __m128i opposite_shift = _mm_sub_epi16(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi16(opposite_shift, shift); + int x = 0; + do { + const __m128i top_vals_0 = LoadUnaligned16(top_row + top_base_x); + const __m128i top_vals_1 = + LoadUnaligned16(top_row + top_base_x + (4 << upsample_shift)); + const __m128i pred = + CombineTopVals<false>(top_vals_0, top_vals_1, sampler, shifts); + StoreUnaligned16(dest + x, pred); + top_base_x += base_step8; + x += 8; + } while (x < width); + dest += stride; + top_x += xstep; + } while (--y != 0); + return; + } + + // General case. Blocks with width less than 32 do not benefit from x-wise + // loop splitting, but do benefit from using memset on appropriate rows. + + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is + // not supported for packed integers. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); + const __m128i base_step8_vect = _mm_set1_epi16(base_step8); + + // All rows from |min_corner_only_y| down will simply use memcpy. + // |max_base_x| is always greater than |height|, so clipping the denominator + // to 1 is enough to make the logic work. + const int xstep_units = std::max(xstep >> index_scale_bits, 1); + const int min_corner_only_y = std::min(max_base_x / xstep_units, height); + + int top_x = xstep; + int y = 0; + for (; y < min_corner_only_y; ++y, dest += stride, top_x += xstep) { + int top_base_x = top_x >> index_scale_bits; + + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi16(shift_val); + const __m128i opposite_shift = _mm_sub_epi16(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi16(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + + for (int x = 0; x < width; x += 8, top_base_x += base_step8, + top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { + const __m128i top_vals_0 = LoadUnaligned16(top_row + top_base_x); + const __m128i top_vals_1 = + LoadUnaligned16(top_row + top_base_x + (4 << upsample_shift)); + const __m128i pred = + CombineTopVals<true>(top_vals_0, top_vals_1, sampler, shifts, + top_index_vect, final_top_val, max_base_x_vect); + StoreUnaligned16(dest + x, pred); + } + } + + // Fill in corner-only rows. 
+ for (; y < height; ++y) { + Memset(dest, top_row[max_base_x], width); + dest += stride; + } +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_10BPP_SSE4_1(DirectionalIntraPredictorZone1) + dsp->directional_intra_predictor_zone1 = + DirectionalIntraPredictorZone1_SSE4_1; +#endif +} + +} // namespace +} // namespace high_bitdepth + +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void IntraPredDirectionalInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void IntraPredDirectionalInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/intrapred_directional_sse4.h b/src/dsp/x86/intrapred_directional_sse4.h new file mode 100644 index 0000000..b352450 --- /dev/null +++ b/src/dsp/x86/intrapred_directional_sse4.h @@ -0,0 +1,54 @@ +/* + * Copyright 2021 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_INTRAPRED_DIRECTIONAL_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_INTRAPRED_DIRECTIONAL_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::directional_intra_predictor_zone*, see the defines below for +// specifics. These functions are not thread-safe. +void IntraPredDirectionalInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone1 +#define LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_INTRAPRED_DIRECTIONAL_SSE4_H_ diff --git a/src/dsp/x86/intrapred_filter_sse4.cc b/src/dsp/x86/intrapred_filter_sse4.cc new file mode 100644 index 0000000..022af8d --- /dev/null +++ b/src/dsp/x86/intrapred_filter_sse4.cc @@ -0,0 +1,432 @@ +// Copyright 2021 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred_filter.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <xmmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/dsp/x86/transpose_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +//------------------------------------------------------------------------------ +// FilterIntraPredictor_SSE4_1 +// Section 7.11.2.3. Recursive intra prediction process +// This filter applies recursively to 4x2 sub-blocks within the transform block, +// meaning that the predicted pixels in each sub-block are used as inputs to +// sub-blocks below and to the right, if present. +// +// Each output value in the sub-block is predicted by a different filter applied +// to the same array of top-left, top, and left values. If fn refers to the +// output of the nth filter, given this block: +// TL T0 T1 T2 T3 +// L0 f0 f1 f2 f3 +// L1 f4 f5 f6 f7 +// The filter input order is p0, p1, p2, p3, p4, p5, p6: +// p0 p1 p2 p3 p4 +// p5 f0 f1 f2 f3 +// p6 f4 f5 f6 f7 +// Filters usually apply to 8 values for convenience, so in this case we fix +// the 8th filter tap to 0 and disregard the value of the 8th input. + +// This shuffle mask selects 32-bit blocks in the order 0, 1, 0, 1, which +// duplicates the first 8 bytes of a 128-bit vector into the second 8 bytes. +constexpr int kDuplicateFirstHalf = 0x44; + +// Apply all filter taps to the given 7 packed 16-bit values, keeping the 8th +// at zero to preserve the sum. +// |pixels| contains p0-p7 in order as shown above. +// |taps_0_1| contains the filter kernels used to predict f0 and f1, and so on. +inline void Filter4x2_SSE4_1(uint8_t* dst, const ptrdiff_t stride, + const __m128i& pixels, const __m128i& taps_0_1, + const __m128i& taps_2_3, const __m128i& taps_4_5, + const __m128i& taps_6_7) { + const __m128i mul_0_01 = _mm_maddubs_epi16(pixels, taps_0_1); + const __m128i mul_0_23 = _mm_maddubs_epi16(pixels, taps_2_3); + // |output_half| contains 8 partial sums for f0-f7. + __m128i output_half = _mm_hadd_epi16(mul_0_01, mul_0_23); + __m128i output = _mm_hadd_epi16(output_half, output_half); + const __m128i output_row0 = + _mm_packus_epi16(RightShiftWithRounding_S16(output, 4), + /* unused half */ output); + Store4(dst, output_row0); + const __m128i mul_1_01 = _mm_maddubs_epi16(pixels, taps_4_5); + const __m128i mul_1_23 = _mm_maddubs_epi16(pixels, taps_6_7); + output_half = _mm_hadd_epi16(mul_1_01, mul_1_23); + output = _mm_hadd_epi16(output_half, output_half); + const __m128i output_row1 = + _mm_packus_epi16(RightShiftWithRounding_S16(output, 4), + /* arbitrary pack arg */ output); + Store4(dst + stride, output_row1); +} + +// 4xH transform sizes are given special treatment because LoadLo8 goes out +// of bounds and every block involves the left column. 
The top-left pixel, p0, +// is stored in the top buffer for the first 4x2, but comes from the left buffer +// for successive blocks. This implementation takes advantage of the fact +// that the p5 and p6 for each sub-block come solely from the |left_ptr| buffer, +// using shifts to arrange things to fit reusable shuffle vectors. +inline void Filter4xH(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const top_ptr, + const uint8_t* const left_ptr, FilterIntraPredictor pred, + const int height) { + // Two filter kernels per vector. + const __m128i taps_0_1 = LoadAligned16(kFilterIntraTaps[pred][0]); + const __m128i taps_2_3 = LoadAligned16(kFilterIntraTaps[pred][2]); + const __m128i taps_4_5 = LoadAligned16(kFilterIntraTaps[pred][4]); + const __m128i taps_6_7 = LoadAligned16(kFilterIntraTaps[pred][6]); + __m128i top = Load4(top_ptr - 1); + __m128i pixels = _mm_insert_epi8(top, top_ptr[3], 4); + __m128i left = (height == 4 ? Load4(left_ptr) : LoadLo8(left_ptr)); + left = _mm_slli_si128(left, 5); + + // Relative pixels: top[-1], top[0], top[1], top[2], top[3], left[0], left[1], + // left[2], left[3], left[4], left[5], left[6], left[7] + // Let rn represent a pixel usable as pn for the 4x2 after this one. We get: + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p0 p1 p2 p3 p4 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + + // Two sets of the same input pixels to apply two filters at once. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 1. + pixels = Load4(dest); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, left[-2], left[-1], + // left[0], left[1], ... + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p1 p2 p3 p4 xx xx p0 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + + // This mask rearranges bytes in the order: 6, 0, 1, 2, 3, 7, 8, 15. The last + // byte is an unused value, which shall be multiplied by 0 when we apply the + // filter. + constexpr int64_t kInsertTopLeftFirstMask = 0x0F08070302010006; + + // Insert left[-1] in front as TL and put left[0] and left[1] at the end. + const __m128i pixel_order1 = _mm_set1_epi64x(kInsertTopLeftFirstMask); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 2. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 3. + + // Compute the middle 8 rows before using common code for the final 4 rows, in + // order to fit the assumption that |left| has the next TL at position 8. + if (height == 16) { + // This shift allows us to use pixel_order2 twice after shifting by 2 later. + left = _mm_slli_si128(left, 1); + pixels = Load4(dest); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, left[-4], + // left[-3], left[-2], left[-1], left[0], left[1], left[2], left[3] + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p1 p2 p3 p4 xx xx xx xx xx p0 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + + // This mask rearranges bytes in the order: 9, 0, 1, 2, 3, 7, 8, 15. The + // last byte is an unused value, as above. The top-left was shifted to + // position nine to keep two empty spaces after the top pixels. + constexpr int64_t kInsertTopLeftSecondMask = 0x0F0B0A0302010009; + + // Insert (relative) left[-1] in front as TL and put left[0] and left[1] at + // the end. 
+ const __m128i pixel_order2 = _mm_set1_epi64x(kInsertTopLeftSecondMask); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + dest += stride; // Move to y = 4. + + // First 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + + // Clear all but final pixel in the first 8 of left column. + __m128i keep_top_left = _mm_srli_si128(left, 13); + dest += stride; // Move to y = 5. + pixels = Load4(dest); + left = _mm_srli_si128(left, 2); + + // Relative pixels: top[0], top[1], top[2], top[3], left[-6], + // left[-5], left[-4], left[-3], left[-2], left[-1], left[0], left[1] + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p1 p2 p3 p4 xx xx xx xx xx p0 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + left = LoadLo8(left_ptr + 8); + + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + dest += stride; // Move to y = 6. + + // Second 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + + // Position TL value so we can use pixel_order1. + keep_top_left = _mm_slli_si128(keep_top_left, 6); + dest += stride; // Move to y = 7. + pixels = Load4(dest); + left = _mm_slli_si128(left, 7); + left = _mm_or_si128(left, keep_top_left); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, + // left[-1], left[0], left[1], left[2], left[3], ... + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p1 p2 p3 p4 xx xx p0 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 8. + + // Third 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 9. + + // Prepare final inputs. + pixels = Load4(dest); + left = _mm_srli_si128(left, 2); + + // Relative pixels: top[0], top[1], top[2], top[3], left[-3], left[-2] + // left[-1], left[0], left[1], left[2], left[3], ... + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p1 p2 p3 p4 xx xx p0 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 10. + + // Fourth 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 11. + } + + // In both the 8 and 16 case at this point, we can assume that |left| has the + // next TL at position 8. + if (height > 4) { + // Erase prior left pixels by shifting TL to position 0. + left = _mm_srli_si128(left, 8); + left = _mm_slli_si128(left, 6); + pixels = Load4(dest); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, + // left[-1], left[0], left[1], left[2], left[3], ... + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p1 p2 p3 p4 xx xx p0 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 12 or 4. + + // First of final two 4x2 blocks. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 13 or 5. + pixels = Load4(dest); + left = _mm_srli_si128(left, 2); + + // Relative pixels: top[0], top[1], top[2], top[3], left[-3], left[-2] + // left[-1], left[0], left[1], left[2], left[3], ... + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // p1 p2 p3 p4 xx xx p0 p5 p6 r5 r6 ... + // r0 + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 14 or 6. 
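Stripped of the shuffling, every Filter4x2_SSE4_1 call in this function computes the same thing: eight 7-tap weighted sums over the p0..p6 neighbourhood, rounded by 4 bits and clipped to 8 bits. A scalar sketch with an illustrative flat taps layout (the intrinsics load the packed kFilterIntraTaps pairs instead):

#include <algorithm>
#include <cstdint>

// Scalar model of one 4x2 filter-intra block. |p| holds p0..p6 (top-left,
// four top pixels, two left pixels); |taps| holds the eight signed 7-tap
// kernels for the chosen mode, one per output pixel, in raster order. The
// vector code pads each kernel with a zero 8th tap.
inline void FilterIntra4x2Sketch(uint8_t out[8], const uint8_t p[7],
                                 const int8_t taps[8][7]) {
  for (int i = 0; i < 8; ++i) {
    int sum = 0;
    for (int j = 0; j < 7; ++j) sum += taps[i][j] * p[j];
    sum = (sum + 8) >> 4;  // RightShiftWithRounding(sum, 4).
    out[i] = static_cast<uint8_t>(std::min(255, std::max(0, sum)));
  }
}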
+ + // Last of final two 4x2 blocks. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + } +} + +void FilterIntraPredictor_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + FilterIntraPredictor pred, const int width, + const int height) { + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + if (width == 4) { + Filter4xH(dst, stride, top_ptr, left_ptr, pred, height); + return; + } + + // There is one set of 7 taps for each of the 4x2 output pixels. + const __m128i taps_0_1 = LoadAligned16(kFilterIntraTaps[pred][0]); + const __m128i taps_2_3 = LoadAligned16(kFilterIntraTaps[pred][2]); + const __m128i taps_4_5 = LoadAligned16(kFilterIntraTaps[pred][4]); + const __m128i taps_6_7 = LoadAligned16(kFilterIntraTaps[pred][6]); + + // This mask rearranges bytes in the order: 0, 1, 2, 3, 4, 8, 9, 15. The 15 at + // the end is an unused value, which shall be multiplied by 0 when we apply + // the filter. + constexpr int64_t kCondenseLeftMask = 0x0F09080403020100; + + // Takes the "left section" and puts it right after p0-p4. + const __m128i pixel_order1 = _mm_set1_epi64x(kCondenseLeftMask); + + // This mask rearranges bytes in the order: 8, 0, 1, 2, 3, 9, 10, 15. The last + // byte is unused as above. + constexpr int64_t kInsertTopLeftMask = 0x0F0A090302010008; + + // Shuffles the "top left" from the left section, to the front. Used when + // grabbing data from left_column and not top_row. + const __m128i pixel_order2 = _mm_set1_epi64x(kInsertTopLeftMask); + + // This first pass takes care of the cases where the top left pixel comes from + // top_row. + __m128i pixels = LoadLo8(top_ptr - 1); + __m128i left = _mm_slli_si128(Load4(left_column), 8); + pixels = _mm_or_si128(pixels, left); + + // Two sets of the same pixels to multiply with two sets of taps. + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, taps_6_7); + left = _mm_srli_si128(left, 1); + + // Load + pixels = Load4(dst + stride); + + // Because of the above shift, this OR 'invades' the final of the first 8 + // bytes of |pixels|. This is acceptable because the 8th filter tap is always + // a padded 0. + pixels = _mm_or_si128(pixels, left); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + const ptrdiff_t stride2 = stride << 1; + const ptrdiff_t stride4 = stride << 2; + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dst += 4; + for (int x = 3; x < width - 4; x += 4) { + pixels = Load4(top_ptr + x); + pixels = _mm_insert_epi8(pixels, top_ptr[x + 4], 4); + pixels = _mm_insert_epi8(pixels, dst[-1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride - 1], 6); + + // Duplicate bottom half into upper half. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + pixels = Load4(dst + stride - 1); + pixels = _mm_insert_epi8(pixels, dst[stride + 3], 4); + pixels = _mm_insert_epi8(pixels, dst[stride2 - 1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride + stride2 - 1], 6); + + // Duplicate bottom half into upper half. 
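Editorial note: the kDuplicateFirstHalf constant used in this function is the _mm_shuffle_epi32 immediate 0x44, which selects 32-bit lanes 0, 1, 0, 1 and therefore copies the low 8 bytes of the register into the high 8 bytes, so both halves of the tap vectors see the same input pixels:

#include <smmintrin.h>

// 0x44 = 0b01'00'01'00: output lanes are source lanes 0, 1, 0, 1.
inline __m128i DuplicateLow8Bytes(__m128i v) {
  return _mm_shuffle_epi32(v, 0x44);
}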
+ pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, + taps_4_5, taps_6_7); + dst += 4; + } + + // Now we handle heights that reference previous blocks rather than top_row. + for (int y = 4; y < height; y += 4) { + // Leftmost 4x4 block for this height. + dst -= width; + dst += stride4; + + // Top Left is not available by offset in these leftmost blocks. + pixels = Load4(dst - stride); + left = _mm_slli_si128(Load4(left_ptr + y - 1), 8); + left = _mm_insert_epi8(left, left_ptr[y + 3], 12); + pixels = _mm_or_si128(pixels, left); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + + // The bytes shifted into positions 6 and 7 will be ignored by the shuffle. + left = _mm_srli_si128(left, 2); + pixels = Load4(dst + stride); + pixels = _mm_or_si128(pixels, left); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, + taps_4_5, taps_6_7); + + dst += 4; + + // Remaining 4x4 blocks for this height. + for (int x = 4; x < width; x += 4) { + pixels = Load4(dst - stride - 1); + pixels = _mm_insert_epi8(pixels, dst[-stride + 3], 4); + pixels = _mm_insert_epi8(pixels, dst[-1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride - 1], 6); + + // Duplicate bottom half into upper half. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + pixels = Load4(dst + stride - 1); + pixels = _mm_insert_epi8(pixels, dst[stride + 3], 4); + pixels = _mm_insert_epi8(pixels, dst[stride2 - 1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride2 + stride - 1], 6); + + // Duplicate bottom half into upper half. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, + taps_4_5, taps_6_7); + dst += 4; + } + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + static_cast<void>(dsp); +// These guards check if this version of the function was not superseded by +// a higher optimization level, such as AVX. The corresponding #define also +// prevents the C version from being added to the table. +#if DSP_ENABLED_8BPP_SSE4_1(FilterIntraPredictor) + dsp->filter_intra_predictor = FilterIntraPredictor_SSE4_1; +#endif +} + +} // namespace + +void IntraPredFilterInit_SSE4_1() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void IntraPredFilterInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/intrapred_filter_sse4.h b/src/dsp/x86/intrapred_filter_sse4.h new file mode 100644 index 0000000..ce28f93 --- /dev/null +++ b/src/dsp/x86/intrapred_filter_sse4.h @@ -0,0 +1,41 @@ +/* + * Copyright 2021 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_INTRAPRED_FILTER_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_INTRAPRED_FILTER_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::filter_intra_predictor, see the defines below for specifics. +// These functions are not thread-safe. +void IntraPredFilterInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_FilterIntraPredictor +#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_INTRAPRED_FILTER_SSE4_H_ diff --git a/src/dsp/x86/intrapred_smooth_sse4.cc b/src/dsp/x86/intrapred_smooth_sse4.cc index e944ea3..de9f551 100644 --- a/src/dsp/x86/intrapred_smooth_sse4.cc +++ b/src/dsp/x86/intrapred_smooth_sse4.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "src/dsp/intrapred.h" +#include "src/dsp/intrapred_smooth.h" #include "src/utils/cpu.h" #if LIBGAV1_TARGETING_SSE4_1 @@ -22,12 +22,12 @@ #include <cassert> #include <cstddef> #include <cstdint> -#include <cstring> // memcpy #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" +#include "src/utils/constants.h" namespace libgav1 { namespace dsp { @@ -67,29 +67,6 @@ inline void WriteSmoothHorizontalSum4(void* const dest, const __m128i& left, Store4(dest, _mm_shuffle_epi8(pred, cvtepi32_epi8)); } -template <int y_mask> -inline __m128i SmoothVerticalSum4(const __m128i& top, const __m128i& weights, - const __m128i& scaled_bottom_left) { - const __m128i weights_y = _mm_shuffle_epi32(weights, y_mask); - const __m128i weighted_top_y = _mm_mullo_epi16(top, weights_y); - const __m128i scaled_bottom_left_y = - _mm_shuffle_epi32(scaled_bottom_left, y_mask); - return _mm_add_epi32(scaled_bottom_left_y, weighted_top_y); -} - -template <int y_mask> -inline void WriteSmoothVerticalSum4(uint8_t* dest, const __m128i& top, - const __m128i& weights, - const __m128i& scaled_bottom_left, - const __m128i& round) { - __m128i pred_sum = - SmoothVerticalSum4<y_mask>(top, weights, scaled_bottom_left); - // Equivalent to RightShiftWithRounding(pred[x][y], 8). - pred_sum = _mm_srli_epi32(_mm_add_epi32(pred_sum, round), 8); - const __m128i cvtepi32_epi8 = _mm_set1_epi32(0x0C080400); - Store4(dest, _mm_shuffle_epi8(pred_sum, cvtepi32_epi8)); -} - // For SMOOTH_H, |pixels| is the repeated left value for the row. For SMOOTH_V, // |pixels| is a segment of the top row or the whole top row, and |weights| is // repeated. diff --git a/src/dsp/x86/intrapred_smooth_sse4.h b/src/dsp/x86/intrapred_smooth_sse4.h new file mode 100644 index 0000000..9353371 --- /dev/null +++ b/src/dsp/x86/intrapred_smooth_sse4.h @@ -0,0 +1,318 @@ +/* + * Copyright 2021 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_INTRAPRED_SMOOTH_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_INTRAPRED_SMOOTH_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_predictors[][kIntraPredictorSmooth.*]. +// This function is not thread-safe. +void IntraPredSmoothInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth \ + 
LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef 
LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal +#define 
LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_INTRAPRED_SMOOTH_SSE4_H_ diff --git a/src/dsp/x86/intrapred_sse4.cc b/src/dsp/x86/intrapred_sse4.cc index 9938dfe..063929d 100644 --- a/src/dsp/x86/intrapred_sse4.cc +++ b/src/dsp/x86/intrapred_sse4.cc @@ -23,13 +23,14 @@ #include <cassert> #include <cstddef> #include <cstdint> -#include <cstring> // memcpy +#include <cstring> #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_sse4.h" #include "src/dsp/x86/transpose_sse4.h" #include "src/utils/common.h" +#include "src/utils/constants.h" namespace libgav1 { namespace dsp { @@ -51,10 +52,6 @@ inline __m128i DivideByMultiplyShift_U32(const __m128i dividend) { return _mm_mulhi_epi16(interm, _mm_cvtsi32_si128(multiplier)); } -// This shuffle mask selects 32-bit blocks in the order 0, 1, 0, 1, which -// duplicates the first 8 bytes of a 128-bit vector into the second 8 bytes. -constexpr int kDuplicateFirstHalf = 0x44; - //------------------------------------------------------------------------------ // DcPredFuncs_SSE4_1 @@ -1408,1337 +1405,6 @@ void Paeth64x64_SSE4_1(void* const dest, ptrdiff_t stride, WritePaeth16x16(dst + 48, stride, top_left, top_3, left_3); } -//------------------------------------------------------------------------------ -// 7.11.2.4. Directional intra prediction process - -// Special case: An |xstep| of 64 corresponds to an angle delta of 45, meaning -// upsampling is ruled out. In addition, the bits masked by 0x3F for -// |shift_val| are 0 for all multiples of 64, so the formula -// val = top[top_base_x]*shift + top[top_base_x+1]*(32-shift), reduces to -// val = top[top_base_x+1] << 5, meaning only the second set of pixels is -// involved in the output. Hence |top| is offset by 1. -inline void DirectionalZone1_Step64(uint8_t* dst, ptrdiff_t stride, - const uint8_t* const top, const int width, - const int height) { - ptrdiff_t offset = 1; - if (height == 4) { - memcpy(dst, top + offset, width); - dst += stride; - memcpy(dst, top + offset + 1, width); - dst += stride; - memcpy(dst, top + offset + 2, width); - dst += stride; - memcpy(dst, top + offset + 3, width); - return; - } - int y = 0; - do { - memcpy(dst, top + offset, width); - dst += stride; - memcpy(dst, top + offset + 1, width); - dst += stride; - memcpy(dst, top + offset + 2, width); - dst += stride; - memcpy(dst, top + offset + 3, width); - dst += stride; - memcpy(dst, top + offset + 4, width); - dst += stride; - memcpy(dst, top + offset + 5, width); - dst += stride; - memcpy(dst, top + offset + 6, width); - dst += stride; - memcpy(dst, top + offset + 7, width); - dst += stride; - - offset += 8; - y += 8; - } while (y < height); -} - -inline void DirectionalZone1_4xH(uint8_t* dst, ptrdiff_t stride, - const uint8_t* const top, const int height, - const int xstep, const bool upsampled) { - const int upsample_shift = static_cast<int>(upsampled); - const int scale_bits = 6 - upsample_shift; - const int rounding_bits = 5; - const int max_base_x = (height + 3 /* width - 1 */) << upsample_shift; - const __m128i final_top_val = _mm_set1_epi16(top[max_base_x]); - const __m128i sampler = upsampled ? 
_mm_set_epi64x(0, 0x0706050403020100) - : _mm_set_epi64x(0, 0x0403030202010100); - // Each 16-bit value here corresponds to a position that may exceed - // |max_base_x|. When added to the top_base_x, it is used to mask values - // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is - // not supported for packed integers. - const __m128i offsets = - _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); - - // All rows from |min_corner_only_y| down will simply use memcpy. |max_base_x| - // is always greater than |height|, so clipping to 1 is enough to make the - // logic work. - const int xstep_units = std::max(xstep >> scale_bits, 1); - const int min_corner_only_y = std::min(max_base_x / xstep_units, height); - - // Rows up to this y-value can be computed without checking for bounds. - int y = 0; - int top_x = xstep; - - for (; y < min_corner_only_y; ++y, dst += stride, top_x += xstep) { - const int top_base_x = top_x >> scale_bits; - - // Permit negative values of |top_x|. - const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; - const __m128i shift = _mm_set1_epi8(shift_val); - const __m128i max_shift = _mm_set1_epi8(32); - const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); - const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); - __m128i top_index_vect = _mm_set1_epi16(top_base_x); - top_index_vect = _mm_add_epi16(top_index_vect, offsets); - const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); - - // Load 8 values because we will select the sampled values based on - // |upsampled|. - const __m128i values = LoadLo8(top + top_base_x); - const __m128i sampled_values = _mm_shuffle_epi8(values, sampler); - const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); - __m128i prod = _mm_maddubs_epi16(sampled_values, shifts); - prod = RightShiftWithRounding_U16(prod, rounding_bits); - // Replace pixels from invalid range with top-right corner. - prod = _mm_blendv_epi8(prod, final_top_val, past_max); - Store4(dst, _mm_packus_epi16(prod, prod)); - } - - // Fill in corner-only rows. - for (; y < height; ++y) { - memset(dst, top[max_base_x], /* width */ 4); - dst += stride; - } -} - -// 7.11.2.4 (7) angle < 90 -inline void DirectionalZone1_Large(uint8_t* dest, ptrdiff_t stride, - const uint8_t* const top_row, - const int width, const int height, - const int xstep, const bool upsampled) { - const int upsample_shift = static_cast<int>(upsampled); - const __m128i sampler = - upsampled ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) - : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); - const int scale_bits = 6 - upsample_shift; - const int max_base_x = ((width + height) - 1) << upsample_shift; - - const __m128i max_shift = _mm_set1_epi8(32); - const int rounding_bits = 5; - const int base_step = 1 << upsample_shift; - const int base_step8 = base_step << 3; - - // All rows from |min_corner_only_y| down will simply use memcpy. |max_base_x| - // is always greater than |height|, so clipping to 1 is enough to make the - // logic work. - const int xstep_units = std::max(xstep >> scale_bits, 1); - const int min_corner_only_y = std::min(max_base_x / xstep_units, height); - - // Rows up to this y-value can be computed without checking for bounds. - const int max_no_corner_y = std::min( - LeftShift((max_base_x - (base_step * width)), scale_bits) / xstep, - height); - // No need to check for exceeding |max_base_x| in the first loop. 
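Editorial note: in these zone-1 helpers the per-pixel arithmetic carried by the shuffle / _mm_maddubs_epi16 / shift sequence is a two-tap interpolation between adjacent top pixels. A scalar sketch under the same conventions as the code (scale_bits = 6 - upsample_shift, rounding by 5 bits); it shows the weighting only and does not reproduce every clamping detail of the upsampled path:

#include <cstdint>

inline uint8_t Zone1Pixel_Scalar(const uint8_t* top, int x, int y, int xstep,
                                 int upsample_shift, int max_base_x) {
  const int scale_bits = 6 - upsample_shift;
  const int top_x = (y + 1) * xstep;
  const int shift = ((top_x << upsample_shift) & 0x3F) >> 1;
  const int base = (top_x >> scale_bits) + (x << upsample_shift);
  if (base >= max_base_x) return top[max_base_x];  // Top-right corner region.
  const int val = (32 - shift) * top[base] + shift * top[base + 1];
  return static_cast<uint8_t>((val + 16) >> 5);    // RightShiftWithRounding.
}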
- int y = 0; - int top_x = xstep; - for (; y < max_no_corner_y; ++y, dest += stride, top_x += xstep) { - int top_base_x = top_x >> scale_bits; - // Permit negative values of |top_x|. - const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; - const __m128i shift = _mm_set1_epi8(shift_val); - const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); - const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); - int x = 0; - do { - const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); - __m128i vals = _mm_shuffle_epi8(top_vals, sampler); - vals = _mm_maddubs_epi16(vals, shifts); - vals = RightShiftWithRounding_U16(vals, rounding_bits); - StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); - top_base_x += base_step8; - x += 8; - } while (x < width); - } - - // Each 16-bit value here corresponds to a position that may exceed - // |max_base_x|. When added to the top_base_x, it is used to mask values - // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is - // not supported for packed integers. - const __m128i offsets = - _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); - - const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); - const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); - const __m128i base_step8_vect = _mm_set1_epi16(base_step8); - for (; y < min_corner_only_y; ++y, dest += stride, top_x += xstep) { - int top_base_x = top_x >> scale_bits; - - const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; - const __m128i shift = _mm_set1_epi8(shift_val); - const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); - const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); - __m128i top_index_vect = _mm_set1_epi16(top_base_x); - top_index_vect = _mm_add_epi16(top_index_vect, offsets); - - int x = 0; - const int min_corner_only_x = - std::min(width, ((max_base_x - top_base_x) >> upsample_shift) + 7) & ~7; - for (; x < min_corner_only_x; - x += 8, top_base_x += base_step8, - top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { - const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); - // Assuming a buffer zone of 8 bytes at the end of top_row, this prevents - // reading out of bounds. If all indices are past max and we don't need to - // use the loaded bytes at all, |top_base_x| becomes 0. |top_base_x| will - // reset for the next |y|. - top_base_x &= ~_mm_cvtsi128_si32(past_max); - const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); - __m128i vals = _mm_shuffle_epi8(top_vals, sampler); - vals = _mm_maddubs_epi16(vals, shifts); - vals = RightShiftWithRounding_U16(vals, rounding_bits); - vals = _mm_blendv_epi8(vals, final_top_val, past_max); - StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); - } - // Corner-only section of the row. - memset(dest + x, top_row[max_base_x], width - x); - } - // Fill in corner-only rows. 
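Editorial note on the top_base_x &= ~_mm_cvtsi128_si32(past_max) line in the masked loop above: the low lanes of the comparison are all ones exactly when even the smallest index is past |max_base_x| (and therefore all of them are), so the AND collapses the load index to zero and the unaligned 16-byte load stays inside the buffer. In isolation:

#include <smmintrin.h>

// |past_max| is a per-lane 16-bit comparison result. If its low lanes are set,
// every lane is past the limit and the returned index becomes 0; otherwise the
// index is effectively unchanged.
inline int MaskIndexIfPastMax(int index, __m128i past_max) {
  return index & ~_mm_cvtsi128_si32(past_max);
}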
- for (; y < height; ++y) { - memset(dest, top_row[max_base_x], width); - dest += stride; - } -} - -// 7.11.2.4 (7) angle < 90 -inline void DirectionalZone1_SSE4_1(uint8_t* dest, ptrdiff_t stride, - const uint8_t* const top_row, - const int width, const int height, - const int xstep, const bool upsampled) { - const int upsample_shift = static_cast<int>(upsampled); - if (xstep == 64) { - DirectionalZone1_Step64(dest, stride, top_row, width, height); - return; - } - if (width == 4) { - DirectionalZone1_4xH(dest, stride, top_row, height, xstep, upsampled); - return; - } - if (width >= 32) { - DirectionalZone1_Large(dest, stride, top_row, width, height, xstep, - upsampled); - return; - } - const __m128i sampler = - upsampled ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) - : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); - const int scale_bits = 6 - upsample_shift; - const int max_base_x = ((width + height) - 1) << upsample_shift; - - const __m128i max_shift = _mm_set1_epi8(32); - const int rounding_bits = 5; - const int base_step = 1 << upsample_shift; - const int base_step8 = base_step << 3; - - // No need to check for exceeding |max_base_x| in the loops. - if (((xstep * height) >> scale_bits) + base_step * width < max_base_x) { - int top_x = xstep; - int y = 0; - do { - int top_base_x = top_x >> scale_bits; - // Permit negative values of |top_x|. - const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; - const __m128i shift = _mm_set1_epi8(shift_val); - const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); - const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); - int x = 0; - do { - const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); - __m128i vals = _mm_shuffle_epi8(top_vals, sampler); - vals = _mm_maddubs_epi16(vals, shifts); - vals = RightShiftWithRounding_U16(vals, rounding_bits); - StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); - top_base_x += base_step8; - x += 8; - } while (x < width); - dest += stride; - top_x += xstep; - } while (++y < height); - return; - } - - // Each 16-bit value here corresponds to a position that may exceed - // |max_base_x|. When added to the top_base_x, it is used to mask values - // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is - // not supported for packed integers. 
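Editorial note: because SSE4.1 has no packed 16-bit "greater or equal" comparison, the offset vector described above is 1-based so that _mm_cmpgt_epi16 on (index + 1) answers the index >= max question; the resulting mask then drives _mm_blendv_epi8 to substitute the top-right corner pixel. A self-contained sketch of that clamping idiom (names here are illustrative, not from the library):

#include <smmintrin.h>
#include <cstdint>

// Replaces every 16-bit lane whose absolute index (base_index + lane) is at
// least |max_index| with |fill|.
inline __m128i ClampLanesPastMax(__m128i values, int base_index, int max_index,
                                 uint16_t fill) {
  const __m128i one_based =
      _mm_add_epi16(_mm_set1_epi16(static_cast<int16_t>(base_index)),
                    _mm_set_epi16(8, 7, 6, 5, 4, 3, 2, 1));
  const __m128i past_max = _mm_cmpgt_epi16(
      one_based, _mm_set1_epi16(static_cast<int16_t>(max_index)));
  return _mm_blendv_epi8(values, _mm_set1_epi16(static_cast<int16_t>(fill)),
                         past_max);
}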
- const __m128i offsets = - _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); - - const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); - const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); - const __m128i base_step8_vect = _mm_set1_epi16(base_step8); - int top_x = xstep; - int y = 0; - do { - int top_base_x = top_x >> scale_bits; - - if (top_base_x >= max_base_x) { - for (int i = y; i < height; ++i) { - memset(dest, top_row[max_base_x], width); - dest += stride; - } - return; - } - - const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; - const __m128i shift = _mm_set1_epi8(shift_val); - const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); - const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); - __m128i top_index_vect = _mm_set1_epi16(top_base_x); - top_index_vect = _mm_add_epi16(top_index_vect, offsets); - - int x = 0; - for (; x < width - 8; - x += 8, top_base_x += base_step8, - top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { - const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); - // Assuming a buffer zone of 8 bytes at the end of top_row, this prevents - // reading out of bounds. If all indices are past max and we don't need to - // use the loaded bytes at all, |top_base_x| becomes 0. |top_base_x| will - // reset for the next |y|. - top_base_x &= ~_mm_cvtsi128_si32(past_max); - const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); - __m128i vals = _mm_shuffle_epi8(top_vals, sampler); - vals = _mm_maddubs_epi16(vals, shifts); - vals = RightShiftWithRounding_U16(vals, rounding_bits); - vals = _mm_blendv_epi8(vals, final_top_val, past_max); - StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); - } - const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); - __m128i vals; - if (upsampled) { - vals = LoadUnaligned16(top_row + top_base_x); - } else { - const __m128i top_vals = LoadLo8(top_row + top_base_x); - vals = _mm_shuffle_epi8(top_vals, sampler); - vals = _mm_insert_epi8(vals, top_row[top_base_x + 8], 15); - } - vals = _mm_maddubs_epi16(vals, shifts); - vals = RightShiftWithRounding_U16(vals, rounding_bits); - vals = _mm_blendv_epi8(vals, final_top_val, past_max); - StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); - dest += stride; - top_x += xstep; - } while (++y < height); -} - -void DirectionalIntraPredictorZone1_SSE4_1(void* const dest, ptrdiff_t stride, - const void* const top_row, - const int width, const int height, - const int xstep, - const bool upsampled_top) { - const auto* const top_ptr = static_cast<const uint8_t*>(top_row); - auto* dst = static_cast<uint8_t*>(dest); - DirectionalZone1_SSE4_1(dst, stride, top_ptr, width, height, xstep, - upsampled_top); -} - -template <bool upsampled> -inline void DirectionalZone3_4x4(uint8_t* dest, ptrdiff_t stride, - const uint8_t* const left_column, - const int base_left_y, const int ystep) { - // For use in the non-upsampled case. 
- const __m128i sampler = _mm_set_epi64x(0, 0x0403030202010100); - const int upsample_shift = static_cast<int>(upsampled); - const int scale_bits = 6 - upsample_shift; - const __m128i max_shift = _mm_set1_epi8(32); - const int rounding_bits = 5; - - __m128i result_block[4]; - for (int x = 0, left_y = base_left_y; x < 4; x++, left_y += ystep) { - const int left_base_y = left_y >> scale_bits; - const int shift_val = ((left_y << upsample_shift) & 0x3F) >> 1; - const __m128i shift = _mm_set1_epi8(shift_val); - const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); - const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); - __m128i vals; - if (upsampled) { - vals = LoadLo8(left_column + left_base_y); - } else { - const __m128i top_vals = LoadLo8(left_column + left_base_y); - vals = _mm_shuffle_epi8(top_vals, sampler); - } - vals = _mm_maddubs_epi16(vals, shifts); - vals = RightShiftWithRounding_U16(vals, rounding_bits); - result_block[x] = _mm_packus_epi16(vals, vals); - } - const __m128i result = Transpose4x4_U8(result_block); - // This is result_row0. - Store4(dest, result); - dest += stride; - const int result_row1 = _mm_extract_epi32(result, 1); - memcpy(dest, &result_row1, sizeof(result_row1)); - dest += stride; - const int result_row2 = _mm_extract_epi32(result, 2); - memcpy(dest, &result_row2, sizeof(result_row2)); - dest += stride; - const int result_row3 = _mm_extract_epi32(result, 3); - memcpy(dest, &result_row3, sizeof(result_row3)); -} - -template <bool upsampled, int height> -inline void DirectionalZone3_8xH(uint8_t* dest, ptrdiff_t stride, - const uint8_t* const left_column, - const int base_left_y, const int ystep) { - // For use in the non-upsampled case. - const __m128i sampler = - _mm_set_epi64x(0x0807070606050504, 0x0403030202010100); - const int upsample_shift = static_cast<int>(upsampled); - const int scale_bits = 6 - upsample_shift; - const __m128i max_shift = _mm_set1_epi8(32); - const int rounding_bits = 5; - - __m128i result_block[8]; - for (int x = 0, left_y = base_left_y; x < 8; x++, left_y += ystep) { - const int left_base_y = left_y >> scale_bits; - const int shift_val = (LeftShift(left_y, upsample_shift) & 0x3F) >> 1; - const __m128i shift = _mm_set1_epi8(shift_val); - const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); - const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); - __m128i vals; - if (upsampled) { - vals = LoadUnaligned16(left_column + left_base_y); - } else { - const __m128i top_vals = LoadUnaligned16(left_column + left_base_y); - vals = _mm_shuffle_epi8(top_vals, sampler); - } - vals = _mm_maddubs_epi16(vals, shifts); - result_block[x] = RightShiftWithRounding_U16(vals, rounding_bits); - } - Transpose8x8_U16(result_block, result_block); - for (int y = 0; y < height; ++y) { - StoreLo8(dest, _mm_packus_epi16(result_block[y], result_block[y])); - dest += stride; - } -} - -// 7.11.2.4 (9) angle > 180 -void DirectionalIntraPredictorZone3_SSE4_1(void* dest, ptrdiff_t stride, - const void* const left_column, - const int width, const int height, - const int ystep, - const bool upsampled) { - const auto* const left_ptr = static_cast<const uint8_t*>(left_column); - auto* dst = static_cast<uint8_t*>(dest); - const int upsample_shift = static_cast<int>(upsampled); - if (width == 4 || height == 4) { - const ptrdiff_t stride4 = stride << 2; - if (upsampled) { - int left_y = ystep; - int x = 0; - do { - uint8_t* dst_x = dst + x; - int y = 0; - do { - DirectionalZone3_4x4<true>( - dst_x, stride, left_ptr + (y << upsample_shift), 
left_y, ystep); - dst_x += stride4; - y += 4; - } while (y < height); - left_y += ystep << 2; - x += 4; - } while (x < width); - } else { - int left_y = ystep; - int x = 0; - do { - uint8_t* dst_x = dst + x; - int y = 0; - do { - DirectionalZone3_4x4<false>(dst_x, stride, left_ptr + y, left_y, - ystep); - dst_x += stride4; - y += 4; - } while (y < height); - left_y += ystep << 2; - x += 4; - } while (x < width); - } - return; - } - - const ptrdiff_t stride8 = stride << 3; - if (upsampled) { - int left_y = ystep; - int x = 0; - do { - uint8_t* dst_x = dst + x; - int y = 0; - do { - DirectionalZone3_8xH<true, 8>( - dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); - dst_x += stride8; - y += 8; - } while (y < height); - left_y += ystep << 3; - x += 8; - } while (x < width); - } else { - int left_y = ystep; - int x = 0; - do { - uint8_t* dst_x = dst + x; - int y = 0; - do { - DirectionalZone3_8xH<false, 8>( - dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); - dst_x += stride8; - y += 8; - } while (y < height); - left_y += ystep << 3; - x += 8; - } while (x < width); - } -} - -//------------------------------------------------------------------------------ -// Directional Zone 2 Functions -// 7.11.2.4 (8) - -// DirectionalBlend* selectively overwrites the values written by -// DirectionalZone2FromLeftCol*. |zone_bounds| has one 16-bit index for each -// row. -template <int y_selector> -inline void DirectionalBlend4_SSE4_1(uint8_t* dest, - const __m128i& dest_index_vect, - const __m128i& vals, - const __m128i& zone_bounds) { - const __m128i max_dest_x_vect = _mm_shufflelo_epi16(zone_bounds, y_selector); - const __m128i use_left = _mm_cmplt_epi16(dest_index_vect, max_dest_x_vect); - const __m128i original_vals = _mm_cvtepu8_epi16(Load4(dest)); - const __m128i blended_vals = _mm_blendv_epi8(vals, original_vals, use_left); - Store4(dest, _mm_packus_epi16(blended_vals, blended_vals)); -} - -inline void DirectionalBlend8_SSE4_1(uint8_t* dest, - const __m128i& dest_index_vect, - const __m128i& vals, - const __m128i& zone_bounds, - const __m128i& bounds_selector) { - const __m128i max_dest_x_vect = - _mm_shuffle_epi8(zone_bounds, bounds_selector); - const __m128i use_left = _mm_cmplt_epi16(dest_index_vect, max_dest_x_vect); - const __m128i original_vals = _mm_cvtepu8_epi16(LoadLo8(dest)); - const __m128i blended_vals = _mm_blendv_epi8(vals, original_vals, use_left); - StoreLo8(dest, _mm_packus_epi16(blended_vals, blended_vals)); -} - -constexpr int kDirectionalWeightBits = 5; -// |source| is packed with 4 or 8 pairs of 8-bit values from left or top. -// |shifts| is named to match the specification, with 4 or 8 pairs of (32 - -// shift) and shift. Shift is guaranteed to be between 0 and 32. -inline __m128i DirectionalZone2FromSource_SSE4_1(const uint8_t* const source, - const __m128i& shifts, - const __m128i& sampler) { - const __m128i src_vals = LoadUnaligned16(source); - __m128i vals = _mm_shuffle_epi8(src_vals, sampler); - vals = _mm_maddubs_epi16(vals, shifts); - return RightShiftWithRounding_U16(vals, kDirectionalWeightBits); -} - -// Because the source values "move backwards" as the row index increases, the -// indices derived from ystep are generally negative. This is accommodated by -// making sure the relative indices are within [-15, 0] when the function is -// called, and sliding them into the inclusive range [0, 15], relative to a -// lower base address. 
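Editorial note: put differently, a shuffle control byte must land in [0, 15], so a relative index i in [-15, 0] is stored as i + 15 while the load address is moved back by the same 15 bytes; the two adjustments cancel and the same element is read. A trivial scalar illustration (the constant mirrors kPositiveIndexOffset below):

#include <cassert>
#include <cstdint>

inline uint8_t LoadWithNegativeIndex(const uint8_t* source, int base, int i) {
  assert(i >= -15 && i <= 0);
  constexpr int kOffset = 15;
  const uint8_t* rebased = source + base - kOffset;  // Lower base address.
  return rebased[i + kOffset];  // Equals source[base + i], index now >= 0.
}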
-constexpr int kPositiveIndexOffset = 15; - -template <bool upsampled> -inline void DirectionalZone2FromLeftCol_4x4_SSE4_1( - uint8_t* dst, ptrdiff_t stride, const uint8_t* const left_column_base, - __m128i left_y) { - const int upsample_shift = static_cast<int>(upsampled); - const int scale_bits = 6 - upsample_shift; - const __m128i max_shifts = _mm_set1_epi8(32); - const __m128i shift_mask = _mm_set1_epi32(0x003F003F); - const __m128i index_increment = _mm_cvtsi32_si128(0x01010101); - const __m128i positive_offset = _mm_set1_epi8(kPositiveIndexOffset); - // Left_column and sampler are both offset by 15 so the indices are always - // positive. - const uint8_t* left_column = left_column_base - kPositiveIndexOffset; - for (int y = 0; y < 4; dst += stride, ++y) { - __m128i offset_y = _mm_srai_epi16(left_y, scale_bits); - offset_y = _mm_packs_epi16(offset_y, offset_y); - - const __m128i adjacent = _mm_add_epi8(offset_y, index_increment); - __m128i sampler = _mm_unpacklo_epi8(offset_y, adjacent); - // Slide valid |offset_y| indices from range [-15, 0] to [0, 15] so they - // can work as shuffle indices. Some values may be out of bounds, but their - // pred results will be masked over by top prediction. - sampler = _mm_add_epi8(sampler, positive_offset); - - __m128i shifts = _mm_srli_epi16( - _mm_and_si128(_mm_slli_epi16(left_y, upsample_shift), shift_mask), 1); - shifts = _mm_packus_epi16(shifts, shifts); - const __m128i opposite_shifts = _mm_sub_epi8(max_shifts, shifts); - shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); - const __m128i vals = DirectionalZone2FromSource_SSE4_1( - left_column + (y << upsample_shift), shifts, sampler); - Store4(dst, _mm_packus_epi16(vals, vals)); - } -} - -// The height at which a load of 16 bytes will not contain enough source pixels -// from |left_column| to supply an accurate row when computing 8 pixels at a -// time. The values are found by inspection. By coincidence, all angles that -// satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up -// by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15. -constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = { - 1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40}; - -template <bool upsampled> -inline void DirectionalZone2FromLeftCol_8x8_SSE4_1( - uint8_t* dst, ptrdiff_t stride, const uint8_t* const left_column, - __m128i left_y) { - const int upsample_shift = static_cast<int>(upsampled); - const int scale_bits = 6 - upsample_shift; - const __m128i max_shifts = _mm_set1_epi8(32); - const __m128i shift_mask = _mm_set1_epi32(0x003F003F); - const __m128i index_increment = _mm_set1_epi8(1); - const __m128i denegation = _mm_set1_epi8(kPositiveIndexOffset); - for (int y = 0; y < 8; dst += stride, ++y) { - __m128i offset_y = _mm_srai_epi16(left_y, scale_bits); - offset_y = _mm_packs_epi16(offset_y, offset_y); - const __m128i adjacent = _mm_add_epi8(offset_y, index_increment); - - // Offset the relative index because ystep is negative in Zone 2 and shuffle - // indices must be nonnegative. 
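Editorial note: the next few statements build the pshufb control from the per-column base indices by packing the 16-bit indices to bytes, adding one to get the adjacent neighbour, and interleaving so each output word selects the (left[i], left[i + 1]) pair that _mm_maddubs_epi16 will weight. The construction on its own, with an illustrative name:

#include <smmintrin.h>

// Returns a shuffle control holding the byte pairs (i, i+1) for the eight
// 16-bit indices held in |indices|.
inline __m128i MakeAdjacentPairSampler(__m128i indices) {
  const __m128i idx = _mm_packs_epi16(indices, indices);
  const __m128i idx_plus_1 = _mm_add_epi8(idx, _mm_set1_epi8(1));
  return _mm_unpacklo_epi8(idx, idx_plus_1);
}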
- __m128i sampler = _mm_unpacklo_epi8(offset_y, adjacent); - sampler = _mm_add_epi8(sampler, denegation); - - __m128i shifts = _mm_srli_epi16( - _mm_and_si128(_mm_slli_epi16(left_y, upsample_shift), shift_mask), 1); - shifts = _mm_packus_epi16(shifts, shifts); - const __m128i opposite_shifts = _mm_sub_epi8(max_shifts, shifts); - shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); - - // The specification adds (y << 6) to left_y, which is subject to - // upsampling, but this puts sampler indices out of the 0-15 range. It is - // equivalent to offset the source address by (y << upsample_shift) instead. - const __m128i vals = DirectionalZone2FromSource_SSE4_1( - left_column - kPositiveIndexOffset + (y << upsample_shift), shifts, - sampler); - StoreLo8(dst, _mm_packus_epi16(vals, vals)); - } -} - -// |zone_bounds| is an epi16 of the relative x index at which base >= -(1 << -// upsampled_top), for each row. When there are 4 values, they can be duplicated -// with a non-register shuffle mask. -// |shifts| is one pair of weights that applies throughout a given row. -template <bool upsampled_top> -inline void DirectionalZone1Blend_4x4( - uint8_t* dest, const uint8_t* const top_row, ptrdiff_t stride, - __m128i sampler, const __m128i& zone_bounds, const __m128i& shifts, - const __m128i& dest_index_x, int top_x, const int xstep) { - const int upsample_shift = static_cast<int>(upsampled_top); - const int scale_bits_x = 6 - upsample_shift; - top_x -= xstep; - - int top_base_x = (top_x >> scale_bits_x); - const __m128i vals0 = DirectionalZone2FromSource_SSE4_1( - top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0x00), sampler); - DirectionalBlend4_SSE4_1<0x00>(dest, dest_index_x, vals0, zone_bounds); - top_x -= xstep; - dest += stride; - - top_base_x = (top_x >> scale_bits_x); - const __m128i vals1 = DirectionalZone2FromSource_SSE4_1( - top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0x55), sampler); - DirectionalBlend4_SSE4_1<0x55>(dest, dest_index_x, vals1, zone_bounds); - top_x -= xstep; - dest += stride; - - top_base_x = (top_x >> scale_bits_x); - const __m128i vals2 = DirectionalZone2FromSource_SSE4_1( - top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0xAA), sampler); - DirectionalBlend4_SSE4_1<0xAA>(dest, dest_index_x, vals2, zone_bounds); - top_x -= xstep; - dest += stride; - - top_base_x = (top_x >> scale_bits_x); - const __m128i vals3 = DirectionalZone2FromSource_SSE4_1( - top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0xFF), sampler); - DirectionalBlend4_SSE4_1<0xFF>(dest, dest_index_x, vals3, zone_bounds); -} - -template <bool upsampled_top, int height> -inline void DirectionalZone1Blend_8xH( - uint8_t* dest, const uint8_t* const top_row, ptrdiff_t stride, - __m128i sampler, const __m128i& zone_bounds, const __m128i& shifts, - const __m128i& dest_index_x, int top_x, const int xstep) { - const int upsample_shift = static_cast<int>(upsampled_top); - const int scale_bits_x = 6 - upsample_shift; - - __m128i y_selector = _mm_set1_epi32(0x01000100); - const __m128i index_increment = _mm_set1_epi32(0x02020202); - for (int y = 0; y < height; ++y, - y_selector = _mm_add_epi8(y_selector, index_increment), - dest += stride) { - top_x -= xstep; - const int top_base_x = top_x >> scale_bits_x; - const __m128i vals = DirectionalZone2FromSource_SSE4_1( - top_row + top_base_x, _mm_shuffle_epi8(shifts, y_selector), sampler); - DirectionalBlend8_SSE4_1(dest, dest_index_x, vals, zone_bounds, y_selector); - } -} - -// 7.11.2.4 (8) 90 < angle > 180 -// The strategy for this function is to know how many 
blocks can be processed -// with just pixels from |top_ptr|, then handle mixed blocks, then handle only -// blocks that take from |left_ptr|. Additionally, a fast index-shuffle -// approach is used for pred values from |left_column| in sections that permit -// it. -template <bool upsampled_left, bool upsampled_top> -inline void DirectionalZone2_SSE4_1(void* dest, ptrdiff_t stride, - const uint8_t* const top_row, - const uint8_t* const left_column, - const int width, const int height, - const int xstep, const int ystep) { - auto* dst = static_cast<uint8_t*>(dest); - const int upsample_left_shift = static_cast<int>(upsampled_left); - const int upsample_top_shift = static_cast<int>(upsampled_top); - const __m128i max_shift = _mm_set1_epi8(32); - const ptrdiff_t stride8 = stride << 3; - const __m128i dest_index_x = - _mm_set_epi32(0x00070006, 0x00050004, 0x00030002, 0x00010000); - const __m128i sampler_top = - upsampled_top - ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) - : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); - const __m128i shift_mask = _mm_set1_epi32(0x003F003F); - // All columns from |min_top_only_x| to the right will only need |top_row| to - // compute. This assumes minimum |xstep| is 3. - const int min_top_only_x = std::min((height * xstep) >> 6, width); - - // For steep angles, the source pixels from left_column may not fit in a - // 16-byte load for shuffling. - // TODO(petersonab): Find a more precise formula for this subject to x. - const int max_shuffle_height = - std::min(height, kDirectionalZone2ShuffleInvalidHeight[ystep >> 6]); - - const int xstep8 = xstep << 3; - const __m128i xstep8_vect = _mm_set1_epi16(xstep8); - // Accumulate xstep across 8 rows. - const __m128i xstep_dup = _mm_set1_epi16(-xstep); - const __m128i increments = _mm_set_epi16(8, 7, 6, 5, 4, 3, 2, 1); - const __m128i xstep_for_shift = _mm_mullo_epi16(xstep_dup, increments); - // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 - const __m128i scaled_one = _mm_set1_epi16(-64); - __m128i xstep_bounds_base = - (xstep == 64) ? _mm_sub_epi16(scaled_one, xstep_for_shift) - : _mm_sub_epi16(_mm_set1_epi16(-1), xstep_for_shift); - - const int left_base_increment = ystep >> 6; - const int ystep_remainder = ystep & 0x3F; - const int ystep8 = ystep << 3; - const int left_base_increment8 = ystep8 >> 6; - const int ystep_remainder8 = ystep8 & 0x3F; - const __m128i increment_left8 = _mm_set1_epi16(-ystep_remainder8); - - // If the 64 scaling is regarded as a decimal point, the first value of the - // left_y vector omits the portion which is covered under the left_column - // offset. Following values need the full ystep as a relative offset. - const __m128i ystep_init = _mm_set1_epi16(-ystep_remainder); - const __m128i ystep_dup = _mm_set1_epi16(-ystep); - __m128i left_y = _mm_mullo_epi16(ystep_dup, dest_index_x); - left_y = _mm_add_epi16(ystep_init, left_y); - - const __m128i increment_top8 = _mm_set1_epi16(8 << 6); - int x = 0; - - // This loop treats each set of 4 columns in 3 stages with y-value boundaries. - // The first stage, before the first y-loop, covers blocks that are only - // computed from the top row. The second stage, comprising two y-loops, covers - // blocks that have a mixture of values computed from top or left. The final - // stage covers blocks that are only computed from the left. 
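Editorial note: the y boundaries that separate those three stages come from comparing the column position against the accumulated x advance, using the same expressions that appear in the loop below. A simplified scalar classifier (illustrative only; the real code also rounds the top-only boundary down to a multiple of 8):

#include <algorithm>

enum class Zone2Source { kTopOnly, kMixed, kLeftOnly };

// For the 8 columns starting at |x|: rows below |max_top_only_y| use only the
// top row, rows at or past |min_left_only_y| use only the left column, and the
// rows in between blend both sources. |xstep| is in 1/64-pixel units.
inline Zone2Source ClassifyZone2Row(int x, int y, int xstep, int height) {
  const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height);
  const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height);
  if (y < max_top_only_y) return Zone2Source::kTopOnly;
  if (y >= min_left_only_y) return Zone2Source::kLeftOnly;
  return Zone2Source::kMixed;
}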
- for (int left_offset = -left_base_increment; x < min_top_only_x; - x += 8, - xstep_bounds_base = _mm_sub_epi16(xstep_bounds_base, increment_top8), - // Watch left_y because it can still get big. - left_y = _mm_add_epi16(left_y, increment_left8), - left_offset -= left_base_increment8) { - uint8_t* dst_x = dst + x; - - // Round down to the nearest multiple of 8. - const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7; - DirectionalZone1_4xH(dst_x, stride, top_row + (x << upsample_top_shift), - max_top_only_y, -xstep, upsampled_top); - DirectionalZone1_4xH(dst_x + 4, stride, - top_row + ((x + 4) << upsample_top_shift), - max_top_only_y, -xstep, upsampled_top); - - int y = max_top_only_y; - dst_x += stride * y; - const int xstep_y = xstep * y; - const __m128i xstep_y_vect = _mm_set1_epi16(xstep_y); - // All rows from |min_left_only_y| down for this set of columns, only need - // |left_column| to compute. - const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height); - // At high angles such that min_left_only_y < 8, ystep is low and xstep is - // high. This means that max_shuffle_height is unbounded and xstep_bounds - // will overflow in 16 bits. This is prevented by stopping the first - // blending loop at min_left_only_y for such cases, which means we skip over - // the second blending loop as well. - const int left_shuffle_stop_y = - std::min(max_shuffle_height, min_left_only_y); - __m128i xstep_bounds = _mm_add_epi16(xstep_bounds_base, xstep_y_vect); - __m128i xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift, xstep_y_vect); - int top_x = -xstep_y; - - for (; y < left_shuffle_stop_y; - y += 8, dst_x += stride8, - xstep_bounds = _mm_add_epi16(xstep_bounds, xstep8_vect), - xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep8_vect), - top_x -= xstep8) { - DirectionalZone2FromLeftCol_8x8_SSE4_1<upsampled_left>( - dst_x, stride, - left_column + ((left_offset + y) << upsample_left_shift), left_y); - - __m128i shifts = _mm_srli_epi16( - _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), - shift_mask), - 1); - shifts = _mm_packus_epi16(shifts, shifts); - __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); - shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); - __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); - DirectionalZone1Blend_8xH<upsampled_top, 8>( - dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, - xstep_bounds_off, shifts, dest_index_x, top_x, xstep); - } - // Pick up from the last y-value, using the 10% slower but secure method for - // left prediction. 
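Editorial note: in both mixed-row loops the freshly computed top-row values are merged over the left-column values already written to |dst_x| by DirectionalBlend8_SSE4_1: lanes left of the row's zone boundary keep the left prediction, lanes at or past it take the top prediction. Reduced to the selection step only (names are illustrative):

#include <smmintrin.h>

// |top_vals| and |left_vals| hold eight 16-bit pixels of one row; |zone_bound|
// is the first x at which the top row becomes usable for that row.
inline __m128i SelectTopOrLeft(__m128i top_vals, __m128i left_vals,
                               int zone_bound) {
  const __m128i lane_index = _mm_set_epi16(7, 6, 5, 4, 3, 2, 1, 0);
  const __m128i use_left = _mm_cmplt_epi16(
      lane_index, _mm_set1_epi16(static_cast<int16_t>(zone_bound)));
  return _mm_blendv_epi8(top_vals, left_vals, use_left);
}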
- const auto base_left_y = static_cast<int16_t>(_mm_extract_epi16(left_y, 0)); - for (; y < min_left_only_y; - y += 8, dst_x += stride8, - xstep_bounds = _mm_add_epi16(xstep_bounds, xstep8_vect), - xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep8_vect), - top_x -= xstep8) { - const __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); - - DirectionalZone3_8xH<upsampled_left, 8>( - dst_x, stride, - left_column + ((left_offset + y) << upsample_left_shift), base_left_y, - -ystep); - - __m128i shifts = _mm_srli_epi16( - _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), - shift_mask), - 1); - shifts = _mm_packus_epi16(shifts, shifts); - __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); - shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); - DirectionalZone1Blend_8xH<upsampled_top, 8>( - dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, - xstep_bounds_off, shifts, dest_index_x, top_x, xstep); - } - // Loop over y for left_only rows. - for (; y < height; y += 8, dst_x += stride8) { - DirectionalZone3_8xH<upsampled_left, 8>( - dst_x, stride, - left_column + ((left_offset + y) << upsample_left_shift), base_left_y, - -ystep); - } - } - for (; x < width; x += 4) { - DirectionalZone1_4xH(dst + x, stride, top_row + (x << upsample_top_shift), - height, -xstep, upsampled_top); - } -} - -template <bool upsampled_left, bool upsampled_top> -inline void DirectionalZone2_4_SSE4_1(void* dest, ptrdiff_t stride, - const uint8_t* const top_row, - const uint8_t* const left_column, - const int width, const int height, - const int xstep, const int ystep) { - auto* dst = static_cast<uint8_t*>(dest); - const int upsample_left_shift = static_cast<int>(upsampled_left); - const int upsample_top_shift = static_cast<int>(upsampled_top); - const __m128i max_shift = _mm_set1_epi8(32); - const ptrdiff_t stride4 = stride << 2; - const __m128i dest_index_x = _mm_set_epi32(0, 0, 0x00030002, 0x00010000); - const __m128i sampler_top = - upsampled_top - ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) - : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); - // All columns from |min_top_only_x| to the right will only need |top_row| to - // compute. - assert(xstep >= 3); - const int min_top_only_x = std::min((height * xstep) >> 6, width); - - const int xstep4 = xstep << 2; - const __m128i xstep4_vect = _mm_set1_epi16(xstep4); - const __m128i xstep_dup = _mm_set1_epi16(-xstep); - const __m128i increments = _mm_set_epi32(0, 0, 0x00040003, 0x00020001); - __m128i xstep_for_shift = _mm_mullo_epi16(xstep_dup, increments); - const __m128i scaled_one = _mm_set1_epi16(-64); - // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 - __m128i xstep_bounds_base = - (xstep == 64) ? _mm_sub_epi16(scaled_one, xstep_for_shift) - : _mm_sub_epi16(_mm_set1_epi16(-1), xstep_for_shift); - - const int left_base_increment = ystep >> 6; - const int ystep_remainder = ystep & 0x3F; - const int ystep4 = ystep << 2; - const int left_base_increment4 = ystep4 >> 6; - // This is guaranteed to be less than 64, but accumulation may bring it past - // 64 for higher x values. - const int ystep_remainder4 = ystep4 & 0x3F; - const __m128i increment_left4 = _mm_set1_epi16(-ystep_remainder4); - const __m128i increment_top4 = _mm_set1_epi16(4 << 6); - - // If the 64 scaling is regarded as a decimal point, the first value of the - // left_y vector omits the portion which will go into the left_column offset. 
- // Following values need the full ystep as a relative offset. - const __m128i ystep_init = _mm_set1_epi16(-ystep_remainder); - const __m128i ystep_dup = _mm_set1_epi16(-ystep); - __m128i left_y = _mm_mullo_epi16(ystep_dup, dest_index_x); - left_y = _mm_add_epi16(ystep_init, left_y); - const __m128i shift_mask = _mm_set1_epi32(0x003F003F); - - int x = 0; - // Loop over x for columns with a mixture of sources. - for (int left_offset = -left_base_increment; x < min_top_only_x; x += 4, - xstep_bounds_base = _mm_sub_epi16(xstep_bounds_base, increment_top4), - left_y = _mm_add_epi16(left_y, increment_left4), - left_offset -= left_base_increment4) { - uint8_t* dst_x = dst + x; - - // Round down to the nearest multiple of 8. - const int max_top_only_y = std::min((x << 6) / xstep, height) & 0xFFFFFFF4; - DirectionalZone1_4xH(dst_x, stride, top_row + (x << upsample_top_shift), - max_top_only_y, -xstep, upsampled_top); - int y = max_top_only_y; - dst_x += stride * y; - const int xstep_y = xstep * y; - const __m128i xstep_y_vect = _mm_set1_epi16(xstep_y); - // All rows from |min_left_only_y| down for this set of columns, only need - // |left_column| to compute. Rounded up to the nearest multiple of 4. - const int min_left_only_y = std::min(((x + 4) << 6) / xstep, height); - - __m128i xstep_bounds = _mm_add_epi16(xstep_bounds_base, xstep_y_vect); - __m128i xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift, xstep_y_vect); - int top_x = -xstep_y; - - // Loop over y for mixed rows. - for (; y < min_left_only_y; - y += 4, dst_x += stride4, - xstep_bounds = _mm_add_epi16(xstep_bounds, xstep4_vect), - xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep4_vect), - top_x -= xstep4) { - DirectionalZone2FromLeftCol_4x4_SSE4_1<upsampled_left>( - dst_x, stride, - left_column + ((left_offset + y) * (1 << upsample_left_shift)), - left_y); - - __m128i shifts = _mm_srli_epi16( - _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), - shift_mask), - 1); - shifts = _mm_packus_epi16(shifts, shifts); - const __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); - shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); - const __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); - DirectionalZone1Blend_4x4<upsampled_top>( - dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, - xstep_bounds_off, shifts, dest_index_x, top_x, xstep); - } - // Loop over y for left-only rows, if any. - for (; y < height; y += 4, dst_x += stride4) { - DirectionalZone2FromLeftCol_4x4_SSE4_1<upsampled_left>( - dst_x, stride, - left_column + ((left_offset + y) << upsample_left_shift), left_y); - } - } - // Loop over top-only columns, if any. - for (; x < width; x += 4) { - DirectionalZone1_4xH(dst + x, stride, top_row + (x << upsample_top_shift), - height, -xstep, upsampled_top); - } -} - -void DirectionalIntraPredictorZone2_SSE4_1(void* const dest, ptrdiff_t stride, - const void* const top_row, - const void* const left_column, - const int width, const int height, - const int xstep, const int ystep, - const bool upsampled_top, - const bool upsampled_left) { - // Increasing the negative buffer for this function allows more rows to be - // processed at a time without branching in an inner loop to check the base. 
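// The 288-byte negative buffers declared just below are the concrete form of
// that idea. A minimal standalone sketch of the index arithmetic (names are
// hypothetical; the constants mirror the code that follows, assuming 8-bit
// pixels):
constexpr int kBufferSize = 288;  // uint8_t top_buffer[288]
constexpr int kCopyOffset = 128;  // memcpy destination: top_buffer + 128
constexpr int kSourceBack = 16;   // memcpy source: top_row - 16
constexpr int kCopyBytes = 160;   // covers top_row[-16 .. 143]
constexpr int kPtrOffset = kCopyOffset + kSourceBack;  // top_ptr = top_buffer + 144
// top_ptr[i] aliases top_row[i] for i in [-16, 143]; anything down to
// top_ptr[-144] still lands inside the local buffer, so the inner loops never
// need to branch on the base pointer.
static_assert(kPtrOffset == 144, "top_ptr[0] aliases top_row[0]");
static_assert(kPtrOffset - kBufferSize == -144,
              "negative reads are bounded at top_ptr[-144]");
static_assert(kCopyOffset + kCopyBytes <= kBufferSize, "the copy stays in bounds");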
- uint8_t top_buffer[288]; - uint8_t left_buffer[288]; - memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160); - memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160); - const uint8_t* top_ptr = top_buffer + 144; - const uint8_t* left_ptr = left_buffer + 144; - if (width == 4 || height == 4) { - if (upsampled_left) { - if (upsampled_top) { - DirectionalZone2_4_SSE4_1<true, true>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } else { - DirectionalZone2_4_SSE4_1<true, false>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } - } else { - if (upsampled_top) { - DirectionalZone2_4_SSE4_1<false, true>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } else { - DirectionalZone2_4_SSE4_1<false, false>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } - } - return; - } - if (upsampled_left) { - if (upsampled_top) { - DirectionalZone2_SSE4_1<true, true>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } else { - DirectionalZone2_SSE4_1<true, false>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } - } else { - if (upsampled_top) { - DirectionalZone2_SSE4_1<false, true>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } else { - DirectionalZone2_SSE4_1<false, false>(dest, stride, top_ptr, left_ptr, - width, height, xstep, ystep); - } - } -} - -//------------------------------------------------------------------------------ -// FilterIntraPredictor_SSE4_1 - -// Apply all filter taps to the given 7 packed 16-bit values, keeping the 8th -// at zero to preserve the sum. -inline void Filter4x2_SSE4_1(uint8_t* dst, const ptrdiff_t stride, - const __m128i& pixels, const __m128i& taps_0_1, - const __m128i& taps_2_3, const __m128i& taps_4_5, - const __m128i& taps_6_7) { - const __m128i mul_0_01 = _mm_maddubs_epi16(pixels, taps_0_1); - const __m128i mul_0_23 = _mm_maddubs_epi16(pixels, taps_2_3); - // |output_half| contains 8 partial sums. - __m128i output_half = _mm_hadd_epi16(mul_0_01, mul_0_23); - __m128i output = _mm_hadd_epi16(output_half, output_half); - const __m128i output_row0 = - _mm_packus_epi16(RightShiftWithRounding_S16(output, 4), - /* arbitrary pack arg */ output); - Store4(dst, output_row0); - const __m128i mul_1_01 = _mm_maddubs_epi16(pixels, taps_4_5); - const __m128i mul_1_23 = _mm_maddubs_epi16(pixels, taps_6_7); - output_half = _mm_hadd_epi16(mul_1_01, mul_1_23); - output = _mm_hadd_epi16(output_half, output_half); - const __m128i output_row1 = - _mm_packus_epi16(RightShiftWithRounding_S16(output, 4), - /* arbitrary pack arg */ output); - Store4(dst + stride, output_row1); -} - -// 4xH transform sizes are given special treatment because LoadLo8 goes out -// of bounds and every block involves the left column. This implementation -// loads TL from the top row for the first block, so it is not -inline void Filter4xH(uint8_t* dest, ptrdiff_t stride, - const uint8_t* const top_ptr, - const uint8_t* const left_ptr, FilterIntraPredictor pred, - const int height) { - const __m128i taps_0_1 = LoadUnaligned16(kFilterIntraTaps[pred][0]); - const __m128i taps_2_3 = LoadUnaligned16(kFilterIntraTaps[pred][2]); - const __m128i taps_4_5 = LoadUnaligned16(kFilterIntraTaps[pred][4]); - const __m128i taps_6_7 = LoadUnaligned16(kFilterIntraTaps[pred][6]); - __m128i top = Load4(top_ptr - 1); - __m128i pixels = _mm_insert_epi8(top, top_ptr[3], 4); - __m128i left = (height == 4 ? 
Load4(left_ptr) : LoadLo8(left_ptr)); - left = _mm_slli_si128(left, 5); - - // Relative pixels: top[-1], top[0], top[1], top[2], top[3], left[0], left[1], - // left[2], left[3], left[4], left[5], left[6], left[7] - pixels = _mm_or_si128(left, pixels); - - // Duplicate first 8 bytes. - pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); - Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - dest += stride; // Move to y = 1. - pixels = Load4(dest); - - // Relative pixels: top[0], top[1], top[2], top[3], empty, left[-2], left[-1], - // left[0], left[1], ... - pixels = _mm_or_si128(left, pixels); - - // This mask rearranges bytes in the order: 6, 0, 1, 2, 3, 7, 8, 15. The last - // byte is an unused value, which shall be multiplied by 0 when we apply the - // filter. - constexpr int64_t kInsertTopLeftFirstMask = 0x0F08070302010006; - - // Insert left[-1] in front as TL and put left[0] and left[1] at the end. - const __m128i pixel_order1 = _mm_set1_epi64x(kInsertTopLeftFirstMask); - pixels = _mm_shuffle_epi8(pixels, pixel_order1); - dest += stride; // Move to y = 2. - Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - dest += stride; // Move to y = 3. - - // Compute the middle 8 rows before using common code for the final 4 rows. - // Because the common code below this block assumes that - if (height == 16) { - // This shift allows us to use pixel_order2 twice after shifting by 2 later. - left = _mm_slli_si128(left, 1); - pixels = Load4(dest); - - // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, left[-4], - // left[-3], left[-2], left[-1], left[0], left[1], left[2], left[3] - pixels = _mm_or_si128(left, pixels); - - // This mask rearranges bytes in the order: 9, 0, 1, 2, 3, 7, 8, 15. The - // last byte is an unused value, as above. The top-left was shifted to - // position nine to keep two empty spaces after the top pixels. - constexpr int64_t kInsertTopLeftSecondMask = 0x0F0B0A0302010009; - - // Insert (relative) left[-1] in front as TL and put left[0] and left[1] at - // the end. - const __m128i pixel_order2 = _mm_set1_epi64x(kInsertTopLeftSecondMask); - pixels = _mm_shuffle_epi8(pixels, pixel_order2); - dest += stride; // Move to y = 4. - - // First 4x2 in the if body. - Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - - // Clear all but final pixel in the first 8 of left column. - __m128i keep_top_left = _mm_srli_si128(left, 13); - dest += stride; // Move to y = 5. - pixels = Load4(dest); - left = _mm_srli_si128(left, 2); - - // Relative pixels: top[0], top[1], top[2], top[3], left[-6], - // left[-5], left[-4], left[-3], left[-2], left[-1], left[0], left[1] - pixels = _mm_or_si128(left, pixels); - left = LoadLo8(left_ptr + 8); - - pixels = _mm_shuffle_epi8(pixels, pixel_order2); - dest += stride; // Move to y = 6. - - // Second 4x2 in the if body. - Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - - // Position TL value so we can use pixel_order1. - keep_top_left = _mm_slli_si128(keep_top_left, 6); - dest += stride; // Move to y = 7. - pixels = Load4(dest); - left = _mm_slli_si128(left, 7); - left = _mm_or_si128(left, keep_top_left); - - // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, - // left[-1], left[0], left[1], left[2], left[3], ... - pixels = _mm_or_si128(left, pixels); - pixels = _mm_shuffle_epi8(pixels, pixel_order1); - dest += stride; // Move to y = 8. - - // Third 4x2 in the if body. 
- Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - dest += stride; // Move to y = 9. - - // Prepare final inputs. - pixels = Load4(dest); - left = _mm_srli_si128(left, 2); - - // Relative pixels: top[0], top[1], top[2], top[3], left[-3], left[-2] - // left[-1], left[0], left[1], left[2], left[3], ... - pixels = _mm_or_si128(left, pixels); - pixels = _mm_shuffle_epi8(pixels, pixel_order1); - dest += stride; // Move to y = 10. - - // Fourth 4x2 in the if body. - Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - dest += stride; // Move to y = 11. - } - - // In both the 8 and 16 case, we assume that the left vector has the next TL - // at position 8. - if (height > 4) { - // Erase prior left pixels by shifting TL to position 0. - left = _mm_srli_si128(left, 8); - left = _mm_slli_si128(left, 6); - pixels = Load4(dest); - - // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, - // left[-1], left[0], left[1], left[2], left[3], ... - pixels = _mm_or_si128(left, pixels); - pixels = _mm_shuffle_epi8(pixels, pixel_order1); - dest += stride; // Move to y = 12 or 4. - - // First of final two 4x2 blocks. - Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - dest += stride; // Move to y = 13 or 5. - pixels = Load4(dest); - left = _mm_srli_si128(left, 2); - - // Relative pixels: top[0], top[1], top[2], top[3], left[-3], left[-2] - // left[-1], left[0], left[1], left[2], left[3], ... - pixels = _mm_or_si128(left, pixels); - pixels = _mm_shuffle_epi8(pixels, pixel_order1); - dest += stride; // Move to y = 14 or 6. - - // Last of final two 4x2 blocks. - Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - } -} - -void FilterIntraPredictor_SSE4_1(void* const dest, ptrdiff_t stride, - const void* const top_row, - const void* const left_column, - FilterIntraPredictor pred, const int width, - const int height) { - const auto* const top_ptr = static_cast<const uint8_t*>(top_row); - const auto* const left_ptr = static_cast<const uint8_t*>(left_column); - auto* dst = static_cast<uint8_t*>(dest); - if (width == 4) { - Filter4xH(dst, stride, top_ptr, left_ptr, pred, height); - return; - } - - // There is one set of 7 taps for each of the 4x2 output pixels. - const __m128i taps_0_1 = LoadUnaligned16(kFilterIntraTaps[pred][0]); - const __m128i taps_2_3 = LoadUnaligned16(kFilterIntraTaps[pred][2]); - const __m128i taps_4_5 = LoadUnaligned16(kFilterIntraTaps[pred][4]); - const __m128i taps_6_7 = LoadUnaligned16(kFilterIntraTaps[pred][6]); - - // This mask rearranges bytes in the order: 0, 1, 2, 3, 4, 8, 9, 15. The 15 at - // the end is an unused value, which shall be multiplied by 0 when we apply - // the filter. - constexpr int64_t kCondenseLeftMask = 0x0F09080403020100; - - // Takes the "left section" and puts it right after p0-p4. - const __m128i pixel_order1 = _mm_set1_epi64x(kCondenseLeftMask); - - // This mask rearranges bytes in the order: 8, 0, 1, 2, 3, 9, 10, 15. The last - // byte is unused as above. - constexpr int64_t kInsertTopLeftMask = 0x0F0A090302010008; - - // Shuffles the "top left" from the left section, to the front. Used when - // grabbing data from left_column and not top_row. - const __m128i pixel_order2 = _mm_set1_epi64x(kInsertTopLeftMask); - - // This first pass takes care of the cases where the top left pixel comes from - // top_row. 
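// A scalar sketch of the per-pixel computation each Filter4x2_SSE4_1 call in
// this function performs. |taps| here is a hypothetical per-output column of
// seven coefficients; the real kFilterIntraTaps rows are interleaved for
// _mm_maddubs_epi16, so this is a reference model, not the table layout.
inline unsigned char FilterIntraPixelRef(const unsigned char neighbors[7],
                                         const signed char taps[7]) {
  // neighbors = {top-left, top[0..3], left[0], left[1]} for the 4x2 block.
  int sum = 0;
  for (int k = 0; k < 7; ++k) sum += taps[k] * neighbors[k];  // 8th tap is 0.
  // RightShiftWithRounding(sum, 4); assumes arithmetic >> on negative sums,
  // matching the srai/packus sequence in Filter4x2_SSE4_1.
  int value = (sum + 8) >> 4;
  if (value < 0) value = 0;
  if (value > 255) value = 255;  // Same clamp _mm_packus_epi16 applies.
  return static_cast<unsigned char>(value);
}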
- __m128i pixels = LoadLo8(top_ptr - 1); - __m128i left = _mm_slli_si128(Load4(left_column), 8); - pixels = _mm_or_si128(pixels, left); - - // Two sets of the same pixels to multiply with two sets of taps. - pixels = _mm_shuffle_epi8(pixels, pixel_order1); - Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, taps_6_7); - left = _mm_srli_si128(left, 1); - - // Load - pixels = Load4(dst + stride); - - // Because of the above shift, this OR 'invades' the final of the first 8 - // bytes of |pixels|. This is acceptable because the 8th filter tap is always - // a padded 0. - pixels = _mm_or_si128(pixels, left); - pixels = _mm_shuffle_epi8(pixels, pixel_order2); - const ptrdiff_t stride2 = stride << 1; - const ptrdiff_t stride4 = stride << 2; - Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - dst += 4; - for (int x = 3; x < width - 4; x += 4) { - pixels = Load4(top_ptr + x); - pixels = _mm_insert_epi8(pixels, top_ptr[x + 4], 4); - pixels = _mm_insert_epi8(pixels, dst[-1], 5); - pixels = _mm_insert_epi8(pixels, dst[stride - 1], 6); - - // Duplicate bottom half into upper half. - pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); - Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - pixels = Load4(dst + stride - 1); - pixels = _mm_insert_epi8(pixels, dst[stride + 3], 4); - pixels = _mm_insert_epi8(pixels, dst[stride2 - 1], 5); - pixels = _mm_insert_epi8(pixels, dst[stride + stride2 - 1], 6); - - // Duplicate bottom half into upper half. - pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); - Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, - taps_4_5, taps_6_7); - dst += 4; - } - - // Now we handle heights that reference previous blocks rather than top_row. - for (int y = 4; y < height; y += 4) { - // Leftmost 4x4 block for this height. - dst -= width; - dst += stride4; - - // Top Left is not available by offset in these leftmost blocks. - pixels = Load4(dst - stride); - left = _mm_slli_si128(Load4(left_ptr + y - 1), 8); - left = _mm_insert_epi8(left, left_ptr[y + 3], 12); - pixels = _mm_or_si128(pixels, left); - pixels = _mm_shuffle_epi8(pixels, pixel_order2); - Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - - // The bytes shifted into positions 6 and 7 will be ignored by the shuffle. - left = _mm_srli_si128(left, 2); - pixels = Load4(dst + stride); - pixels = _mm_or_si128(pixels, left); - pixels = _mm_shuffle_epi8(pixels, pixel_order2); - Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, - taps_4_5, taps_6_7); - - dst += 4; - - // Remaining 4x4 blocks for this height. - for (int x = 4; x < width; x += 4) { - pixels = Load4(dst - stride - 1); - pixels = _mm_insert_epi8(pixels, dst[-stride + 3], 4); - pixels = _mm_insert_epi8(pixels, dst[-1], 5); - pixels = _mm_insert_epi8(pixels, dst[stride - 1], 6); - - // Duplicate bottom half into upper half. - pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); - Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, - taps_6_7); - pixels = Load4(dst + stride - 1); - pixels = _mm_insert_epi8(pixels, dst[stride + 3], 4); - pixels = _mm_insert_epi8(pixels, dst[stride2 - 1], 5); - pixels = _mm_insert_epi8(pixels, dst[stride2 + stride - 1], 6); - - // Duplicate bottom half into upper half. 
- pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); - Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, - taps_4_5, taps_6_7); - dst += 4; - } - } -} - void Init8bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); assert(dsp != nullptr); @@ -2746,21 +1412,6 @@ void Init8bpp() { // These guards check if this version of the function was not superseded by // a higher optimization level, such as AVX. The corresponding #define also // prevents the C version from being added to the table. -#if DSP_ENABLED_8BPP_SSE4_1(FilterIntraPredictor) - dsp->filter_intra_predictor = FilterIntraPredictor_SSE4_1; -#endif -#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone1) - dsp->directional_intra_predictor_zone1 = - DirectionalIntraPredictorZone1_SSE4_1; -#endif -#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone2) - dsp->directional_intra_predictor_zone2 = - DirectionalIntraPredictorZone2_SSE4_1; -#endif -#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone3) - dsp->directional_intra_predictor_zone3 = - DirectionalIntraPredictorZone3_SSE4_1; -#endif #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorDcTop) dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = DcDefs::_4x4::DcTop; @@ -3524,7 +2175,7 @@ void IntraPredInit_SSE4_1() { } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/intrapred_sse4.h b/src/dsp/x86/intrapred_sse4.h index 7f4fcd7..1f6f30a 100644 --- a/src/dsp/x86/intrapred_sse4.h +++ b/src/dsp/x86/intrapred_sse4.h @@ -23,13 +23,9 @@ namespace libgav1 { namespace dsp { -// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*, -// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and -// Dsp::filter_intra_predictor, see the defines below for specifics. These -// functions are not thread-safe. +// Initializes Dsp::intra_predictors. See the defines below for specifics. +// These functions are not thread-safe. void IntraPredInit_SSE4_1(); -void IntraPredCflInit_SSE4_1(); -void IntraPredSmoothInit_SSE4_1(); } // namespace dsp } // namespace libgav1 @@ -37,22 +33,6 @@ void IntraPredSmoothInit_SSE4_1(); // If sse4 is enabled and the baseline isn't set due to a higher level of // optimization being enabled, signal the sse4 implementation should be used. 
#if LIBGAV1_TARGETING_SSE4_1 -#ifndef LIBGAV1_Dsp8bpp_FilterIntraPredictor -#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 -#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 -#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 -#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_SSE4_1 -#endif - #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 #endif @@ -138,174 +118,6 @@ void IntraPredSmoothInit_SSE4_1(); LIBGAV1_CPU_SSE4_1 #endif -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 -#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 -#define 
LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 -#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor -#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 -#endif - #ifndef 
LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1 #endif @@ -658,287 +470,6 @@ void IntraPredSmoothInit_SSE4_1(); LIBGAV1_CPU_SSE4_1 #endif -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth -#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical -#define 
LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical -#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal 
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - -#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal -#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \ - LIBGAV1_CPU_SSE4_1 -#endif - //------------------------------------------------------------------------------ // 10bpp diff --git a/src/dsp/x86/inverse_transform_sse4.cc b/src/dsp/x86/inverse_transform_sse4.cc index 787d706..12c008f 100644 --- a/src/dsp/x86/inverse_transform_sse4.cc +++ b/src/dsp/x86/inverse_transform_sse4.cc @@ -94,8 +94,7 @@ LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(__m128i* a, __m128i* b, static_cast<uint16_t>(cos128) | (static_cast<uint32_t>(sin128) << 16)); const __m128i ba = _mm_unpacklo_epi16(*a, *b); const __m128i ab = 
_mm_unpacklo_epi16(*b, *a); - const __m128i sign = - _mm_set_epi32(0x80000001, 0x80000001, 0x80000001, 0x80000001); + const __m128i sign = _mm_set1_epi32(static_cast<int>(0x80000001)); // -sin cos, -sin cos, -sin cos, -sin cos const __m128i msin_pcos = _mm_sign_epi16(psin_pcos, sign); const __m128i x0 = _mm_madd_epi16(ba, msin_pcos); @@ -121,8 +120,7 @@ LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(__m128i* a, __m128i* b, const int16_t sin128 = Sin128(angle); const __m128i psin_pcos = _mm_set1_epi32( static_cast<uint16_t>(cos128) | (static_cast<uint32_t>(sin128) << 16)); - const __m128i sign = - _mm_set_epi32(0x80000001, 0x80000001, 0x80000001, 0x80000001); + const __m128i sign = _mm_set1_epi32(static_cast<int>(0x80000001)); // -sin cos, -sin cos, -sin cos, -sin cos const __m128i msin_pcos = _mm_sign_epi16(psin_pcos, sign); const __m128i ba = _mm_unpacklo_epi16(*a, *b); @@ -229,7 +227,8 @@ LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height, const __m128i v_src_lo = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0); const __m128i v_src = (width == 4) ? v_src_lo : _mm_shuffle_epi32(v_src_lo, 0); - const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_mask = + _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0); const __m128i v_kTransformRowMultiplier = _mm_set1_epi16(kTransformRowMultiplier << 3); const __m128i v_src_round = @@ -1039,7 +1038,8 @@ LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height, auto* dst = static_cast<int16_t*>(dest); const __m128i v_src = _mm_shuffle_epi32(_mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0), 0); - const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_mask = + _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0); const __m128i v_kTransformRowMultiplier = _mm_set1_epi16(kTransformRowMultiplier << 3); const __m128i v_src_round = @@ -1194,7 +1194,8 @@ LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height, __m128i s[8]; const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0); - const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_mask = + _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0); const __m128i v_kTransformRowMultiplier = _mm_set1_epi16(kTransformRowMultiplier << 3); const __m128i v_src_round = @@ -1519,7 +1520,8 @@ LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height, __m128i x[16]; const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0); - const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_mask = + _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0); const __m128i v_kTransformRowMultiplier = _mm_set1_epi16(kTransformRowMultiplier << 3); const __m128i v_src_round = @@ -1615,7 +1617,8 @@ LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height, auto* dst = static_cast<int16_t*>(dest); const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]); - const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_mask = + _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0); const __m128i v_kTransformRowMultiplier = _mm_set1_epi16(kTransformRowMultiplier << 3); const __m128i v_src_round = @@ -1767,7 +1770,8 @@ LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height, auto* dst = static_cast<int16_t*>(dest); const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]); - const __m128i v_mask = _mm_set1_epi16(should_round ? 
0xffff : 0); + const __m128i v_mask = + _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0); const __m128i v_kTransformRowMultiplier = _mm_set1_epi16(kTransformRowMultiplier << 3); const __m128i v_src_round = @@ -1859,7 +1863,8 @@ LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height, auto* dst = static_cast<int16_t*>(dest); const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]); - const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_mask = + _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0); const __m128i v_kTransformRowMultiplier = _mm_set1_epi16(kTransformRowMultiplier << 3); const __m128i v_src_round0 = @@ -2918,75 +2923,11 @@ void Wht4TransformLoopColumn_SSE4_1(TransformType tx_type, //------------------------------------------------------------------------------ -template <typename Residual, typename Pixel> -void InitAll(Dsp* const dsp) { - // Maximum transform size for Dct is 64. - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = - Dct4TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = - Dct4TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = - Dct8TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = - Dct8TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = - Dct16TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = - Dct16TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = - Dct32TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = - Dct32TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = - Dct64TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = - Dct64TransformLoopColumn_SSE4_1; - - // Maximum transform size for Adst is 16. - dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = - Adst4TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = - Adst4TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = - Adst8TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = - Adst8TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = - Adst16TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = - Adst16TransformLoopColumn_SSE4_1; - - // Maximum transform size for Identity transform is 32. 
- dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = - Identity4TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = - Identity4TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = - Identity8TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = - Identity8TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = - Identity16TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = - Identity16TransformLoopColumn_SSE4_1; - dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = - Identity32TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = - Identity32TransformLoopColumn_SSE4_1; - - // Maximum transform size for Wht is 4. - dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = - Wht4TransformLoopRow_SSE4_1; - dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = - Wht4TransformLoopColumn_SSE4_1; -} - void Init8bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); assert(dsp != nullptr); -#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS - InitAll<int16_t, uint8_t>(dsp); -#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + + // Maximum transform size for Dct is 64. #if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformDct) dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = Dct4TransformLoopRow_SSE4_1; @@ -3017,6 +2958,8 @@ void Init8bpp() { dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = Dct64TransformLoopColumn_SSE4_1; #endif + + // Maximum transform size for Adst is 16. #if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformAdst) dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = Adst4TransformLoopRow_SSE4_1; @@ -3035,6 +2978,8 @@ void Init8bpp() { dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = Adst16TransformLoopColumn_SSE4_1; #endif + + // Maximum transform size for Identity transform is 32. #if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformIdentity) dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = Identity4TransformLoopRow_SSE4_1; @@ -3059,13 +3004,14 @@ void Init8bpp() { dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = Identity32TransformLoopColumn_SSE4_1; #endif + + // Maximum transform size for Wht is 4. 
#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformWht) dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = Wht4TransformLoopRow_SSE4_1; dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = Wht4TransformLoopColumn_SSE4_1; #endif -#endif } } // namespace @@ -3075,7 +3021,7 @@ void InverseTransformInit_SSE4_1() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/loop_filter_sse4.cc b/src/dsp/x86/loop_filter_sse4.cc index d67b450..b9da2d5 100644 --- a/src/dsp/x86/loop_filter_sse4.cc +++ b/src/dsp/x86/loop_filter_sse4.cc @@ -350,7 +350,7 @@ void Horizontal6(void* dest, ptrdiff_t stride, int outer_thresh, const __m128i v_mask = _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat3_mask), 0); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp1_f6; __m128i oqp0_f6; @@ -454,7 +454,7 @@ void Vertical6(void* dest, ptrdiff_t stride, int outer_thresh, int inner_thresh, const __m128i v_mask = _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat3_mask), 0); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp1_f6; __m128i oqp0_f6; @@ -595,7 +595,7 @@ void Horizontal8(void* dest, ptrdiff_t stride, int outer_thresh, const __m128i v_mask = _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp2_f8; __m128i oqp1_f8; __m128i oqp0_f8; @@ -697,7 +697,7 @@ void Vertical8(void* dest, ptrdiff_t stride, int outer_thresh, int inner_thresh, const __m128i v_mask = _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp2_f8; __m128i oqp1_f8; __m128i oqp0_f8; @@ -838,7 +838,7 @@ void Horizontal14(void* dest, ptrdiff_t stride, int outer_thresh, const __m128i v_mask = _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { const __m128i p6 = Load4(dst - 7 * stride); const __m128i p5 = Load4(dst - 6 * stride); const __m128i p4 = Load4(dst - 5 * stride); @@ -864,8 +864,7 @@ void Horizontal14(void* dest, ptrdiff_t stride, int outer_thresh, oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); - if (_mm_test_all_zeros(v_flat4_mask, - _mm_cmpeq_epi8(v_flat4_mask, v_flat4_mask)) == 0) { + if (_mm_test_all_zeros(v_flat4_mask, v_flat4_mask) == 0) { __m128i oqp5_f14; __m128i oqp4_f14; __m128i oqp3_f14; @@ -1050,7 +1049,7 @@ void Vertical14(void* dest, ptrdiff_t stride, int outer_thresh, const __m128i v_mask = _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { const __m128i v_isflatouter4_mask = IsFlat4(qp6, qp5, qp4, qp0, v_flat_thresh); const __m128i v_flat4_mask = @@ -1066,8 +1065,7 @@ void Vertical14(void* dest, ptrdiff_t stride, int outer_thresh, oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); - if (_mm_test_all_zeros(v_flat4_mask, - 
_mm_cmpeq_epi8(v_flat4_mask, v_flat4_mask)) == 0) { + if (_mm_test_all_zeros(v_flat4_mask, v_flat4_mask) == 0) { __m128i oqp5_f14; __m128i oqp4_f14; __m128i oqp3_f14; @@ -1458,7 +1456,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal6(void* dest, const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat3_mask); const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp1_f6; __m128i oqp0_f6; @@ -1572,7 +1570,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical6(void* dest, ptrdiff_t stride8, const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat3_mask); const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp1_f6; __m128i oqp0_f6; @@ -1711,7 +1709,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal8(void* dest, const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp2_f8; __m128i oqp1_f8; __m128i oqp0_f8; @@ -1821,7 +1819,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical8(void* dest, ptrdiff_t stride8, const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { __m128i oqp2_f8; __m128i oqp1_f8; __m128i oqp0_f8; @@ -1957,7 +1955,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal14(void* dest, const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { const __m128i p6 = LoadLo8(dst - 7 * stride); const __m128i p5 = LoadLo8(dst - 6 * stride); const __m128i p4 = LoadLo8(dst - 5 * stride); @@ -1984,8 +1982,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal14(void* dest, oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); - if (_mm_test_all_zeros(v_flat4_mask, - _mm_cmpeq_epi16(v_flat4_mask, v_flat4_mask)) == 0) { + if (_mm_test_all_zeros(v_flat4_mask, v_flat4_mask) == 0) { __m128i oqp5_f14; __m128i oqp4_f14; __m128i oqp3_f14; @@ -2133,7 +2130,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical14(void* dest, ptrdiff_t stride8, const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); - if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + if (_mm_test_all_zeros(v_mask, v_mask) == 0) { const __m128i v_isflatouter4_mask = IsFlat4(qp6, qp5, qp4, qp0, v_flat_thresh); const __m128i v_flat4_mask_lo = _mm_and_si128(v_mask, v_isflatouter4_mask); @@ -2150,8 +2147,7 @@ void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical14(void* dest, ptrdiff_t stride8, oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); - if (_mm_test_all_zeros(v_flat4_mask, - _mm_cmpeq_epi16(v_flat4_mask, v_flat4_mask)) == 0) { + if (_mm_test_all_zeros(v_flat4_mask, v_flat4_mask) == 0) { __m128i oqp5_f14; __m128i oqp4_f14; __m128i oqp3_f14; @@ -2245,7 
+2241,7 @@ void LoopFilterInit_SSE4_1() { } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/loop_restoration_10bit_avx2.cc b/src/dsp/x86/loop_restoration_10bit_avx2.cc index 702bdea..b38f322 100644 --- a/src/dsp/x86/loop_restoration_10bit_avx2.cc +++ b/src/dsp/x86/loop_restoration_10bit_avx2.cc @@ -28,7 +28,6 @@ #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_avx2.h" -#include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" #include "src/utils/constants.h" @@ -472,12 +471,12 @@ inline void WienerVerticalTap1(const int16_t* wiener_buffer, } } -void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, - const void* const source, const void* const top_border, - const void* const bottom_border, const ptrdiff_t stride, - const int width, const int height, - RestorationBuffer* const restoration_buffer, - void* const dest) { +void WienerFilter_AVX2( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { const int16_t* const number_leading_zero_coefficients = restoration_info.wiener_info.number_leading_zero_coefficients; const int number_rows_to_skip = std::max( @@ -502,39 +501,42 @@ void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]); const __m256i coefficients_horizontal = _mm256_broadcastq_epi64(c); if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { - WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, - wiener_stride, height_extra, &coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3, + top_border_stride, wiener_stride, height_extra, &coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, &coefficients_horizontal, &wiener_buffer_horizontal); - } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { - WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, - wiener_stride, height_extra, &coefficients_horizontal, + WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, + height_extra, &coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2, + top_border_stride, wiener_stride, height_extra, &coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, &coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, + height_extra, &coefficients_horizontal, + &wiener_buffer_horizontal); } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { // The maximum over-reads happen here. 
- WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, - wiener_stride, height_extra, &coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1, + top_border_stride, wiener_stride, height_extra, &coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, &coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride, + height_extra, &coefficients_horizontal, + &wiener_buffer_horizontal); } else { assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); - WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, - wiener_stride, height_extra, + WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride, + top_border_stride, wiener_stride, height_extra, &wiener_buffer_horizontal); WienerHorizontalTap1(src, stride, wiener_stride, height, &wiener_buffer_horizontal); - WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, - &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride, + height_extra, &wiener_buffer_horizontal); } // vertical filtering. @@ -566,12 +568,2575 @@ void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, } } +//------------------------------------------------------------------------------ +// SGR + +constexpr int kSumOffset = 24; + +// SIMD overreads the number of pixels in SIMD registers - (width % 8) - 2 * +// padding pixels, where padding is 3 for Pass 1 and 2 for Pass 2. The number of +// bytes in SIMD registers is 16 for SSE4.1 and 32 for AVX2. 
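// A worked check (illustrative only) of how the overread constants defined
// next follow from the comment above, assuming 2-byte (10bpp) pixels and the
// worst case width % 8 == 0; OverreadInBytes is a hypothetical helper.
constexpr int OverreadInBytes(int register_bytes, int padding) {
  // Pixels over-read = pixels per register - (width % 8) - 2 * padding.
  return (register_bytes / 2 - /*width % 8=*/0 - 2 * padding) * 2;
}
static_assert(OverreadInBytes(16, 3) == 4, "kOverreadInBytesPass1_128");
static_assert(OverreadInBytes(16, 2) == 8, "kOverreadInBytesPass2_128");
static_assert(OverreadInBytes(32, 3) == 4 + 16, "kOverreadInBytesPass1_256");
static_assert(OverreadInBytes(32, 2) == 8 + 16, "kOverreadInBytesPass2_256");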
+constexpr int kOverreadInBytesPass1_128 = 4; +constexpr int kOverreadInBytesPass2_128 = 8; +constexpr int kOverreadInBytesPass1_256 = kOverreadInBytesPass1_128 + 16; +constexpr int kOverreadInBytesPass2_256 = kOverreadInBytesPass2_128 + 16; + +inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x, + __m128i dst[2]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); +} + +inline void LoadAligned32x2U16(const uint16_t* const src[2], const ptrdiff_t x, + __m256i dst[2]) { + dst[0] = LoadAligned32(src[0] + x); + dst[1] = LoadAligned32(src[1] + x); +} + +inline void LoadAligned32x2U16Msan(const uint16_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[2]) { + dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border)); + dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border)); +} + +inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x, + __m128i dst[3]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); + dst[2] = LoadAligned16(src[2] + x); +} + +inline void LoadAligned32x3U16(const uint16_t* const src[3], const ptrdiff_t x, + __m256i dst[3]) { + dst[0] = LoadAligned32(src[0] + x); + dst[1] = LoadAligned32(src[1] + x); + dst[2] = LoadAligned32(src[2] + x); +} + +inline void LoadAligned32x3U16Msan(const uint16_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[3]) { + dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border)); + dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border)); + dst[2] = LoadAligned32Msan(src[2] + x, sizeof(**src) * (x + 16 - border)); +} + +inline void LoadAligned32U32(const uint32_t* const src, __m128i dst[2]) { + dst[0] = LoadAligned16(src + 0); + dst[1] = LoadAligned16(src + 4); +} + +inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x, + __m128i dst[2][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); +} + +inline void LoadAligned64x2U32(const uint32_t* const src[2], const ptrdiff_t x, + __m256i dst[2][2]) { + LoadAligned64(src[0] + x, dst[0]); + LoadAligned64(src[1] + x, dst[1]); +} + +inline void LoadAligned64x2U32Msan(const uint32_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[2][2]) { + LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]); + LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]); +} + +inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x, + __m128i dst[3][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); + LoadAligned32U32(src[2] + x, dst[2]); +} + +inline void LoadAligned64x3U32(const uint32_t* const src[3], const ptrdiff_t x, + __m256i dst[3][2]) { + LoadAligned64(src[0] + x, dst[0]); + LoadAligned64(src[1] + x, dst[1]); + LoadAligned64(src[2] + x, dst[2]); +} + +inline void LoadAligned64x3U32Msan(const uint32_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[3][2]) { + LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]); + LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]); + LoadAligned64Msan(src[2] + x, sizeof(**src) * (x + 16 - border), dst[2]); +} + +inline void StoreAligned32U32(uint32_t* const dst, const __m128i src[2]) { + StoreAligned16(dst + 0, src[0]); + StoreAligned16(dst + 4, src[1]); +} + +// The AVX2 ymm register holds ma[0], ma[1], ..., ma[7], and ma[16], 
ma[17], +// ..., ma[23]. +// There is an 8 pixel gap between the first half and the second half. +constexpr int kMaStoreOffset = 8; + +inline void StoreAligned32_ma(uint16_t* src, const __m256i v) { + StoreAligned16(src + 0 * 8, _mm256_extracti128_si256(v, 0)); + StoreAligned16(src + 2 * 8, _mm256_extracti128_si256(v, 1)); +} + +inline void StoreAligned64_ma(uint16_t* src, const __m256i v[2]) { + // The next 4 lines are much faster than: + // StoreAligned32(src + 0, _mm256_permute2x128_si256(v[0], v[1], 0x20)); + // StoreAligned32(src + 16, _mm256_permute2x128_si256(v[0], v[1], 0x31)); + StoreAligned16(src + 0 * 8, _mm256_extracti128_si256(v[0], 0)); + StoreAligned16(src + 1 * 8, _mm256_extracti128_si256(v[1], 0)); + StoreAligned16(src + 2 * 8, _mm256_extracti128_si256(v[0], 1)); + StoreAligned16(src + 3 * 8, _mm256_extracti128_si256(v[1], 1)); +} + +// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following +// functions. Some compilers may generate super inefficient code and the whole +// decoder could be 15% slower. + +inline __m256i VaddlLo8(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpacklo_epi8(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(s0, s1); +} + +inline __m256i VaddlHi8(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpackhi_epi8(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(s0, s1); +} + +inline __m256i VaddwLo8(const __m256i src0, const __m256i src1) { + const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(src0, s1); +} + +inline __m256i VaddwHi8(const __m256i src0, const __m256i src1) { + const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(src0, s1); +} + +inline __m256i VmullNLo8(const __m256i src0, const int src1) { + const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1)); +} + +inline __m256i VmullNHi8(const __m256i src0, const int src1) { + const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1)); +} + +inline __m128i VmullLo16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m256i VmullLo16(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, s1); +} + +inline __m128i VmullHi16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m256i VmullHi16(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, s1); +} + +inline __m128i VrshrU16(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi16(src0, _mm_set1_epi16(1 << (src1 - 1))); + return _mm_srli_epi16(sum, src1); +} + +inline __m256i VrshrU16(const __m256i src0, const int 
src1) { + const __m256i sum = + _mm256_add_epi16(src0, _mm256_set1_epi16(1 << (src1 - 1))); + return _mm256_srli_epi16(sum, src1); +} + +inline __m256i VrshrS32(const __m256i src0, const int src1) { + const __m256i sum = + _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1))); + return _mm256_srai_epi32(sum, src1); +} + +inline __m128i VrshrU32(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1))); + return _mm_srli_epi32(sum, src1); +} + +inline __m256i VrshrU32(const __m256i src0, const int src1) { + const __m256i sum = + _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1))); + return _mm256_srli_epi32(sum, src1); +} + +inline void Square(const __m128i src, __m128i dst[2]) { + const __m128i s0 = _mm_unpacklo_epi16(src, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src, _mm_setzero_si128()); + dst[0] = _mm_madd_epi16(s0, s0); + dst[1] = _mm_madd_epi16(s1, s1); +} + +inline void Square(const __m256i src, __m256i dst[2]) { + const __m256i s0 = _mm256_unpacklo_epi16(src, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpackhi_epi16(src, _mm256_setzero_si256()); + dst[0] = _mm256_madd_epi16(s0, s0); + dst[1] = _mm256_madd_epi16(s1, s1); +} + +inline void Prepare3_8(const __m256i src[2], __m256i dst[3]) { + dst[0] = _mm256_alignr_epi8(src[1], src[0], 0); + dst[1] = _mm256_alignr_epi8(src[1], src[0], 1); + dst[2] = _mm256_alignr_epi8(src[1], src[0], 2); +} + +inline void Prepare3_16(const __m128i src[2], __m128i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm_alignr_epi8(src[1], src[0], 2); + dst[2] = _mm_alignr_epi8(src[1], src[0], 4); +} + +inline void Prepare3_32(const __m128i src[2], __m128i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm_alignr_epi8(src[1], src[0], 4); + dst[2] = _mm_alignr_epi8(src[1], src[0], 8); +} + +inline void Prepare3_32(const __m256i src[2], __m256i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm256_alignr_epi8(src[1], src[0], 4); + dst[2] = _mm256_alignr_epi8(src[1], src[0], 8); +} + +inline void Prepare5_16(const __m128i src[2], __m128i dst[5]) { + Prepare3_16(src, dst); + dst[3] = _mm_alignr_epi8(src[1], src[0], 6); + dst[4] = _mm_alignr_epi8(src[1], src[0], 8); +} + +inline void Prepare5_32(const __m128i src[2], __m128i dst[5]) { + Prepare3_32(src, dst); + dst[3] = _mm_alignr_epi8(src[1], src[0], 12); + dst[4] = src[1]; +} + +inline void Prepare5_32(const __m256i src[2], __m256i dst[5]) { + Prepare3_32(src, dst); + dst[3] = _mm256_alignr_epi8(src[1], src[0], 12); + dst[4] = src[1]; +} + +inline __m128i Sum3_16(const __m128i src0, const __m128i src1, + const __m128i src2) { + const __m128i sum = _mm_add_epi16(src0, src1); + return _mm_add_epi16(sum, src2); +} + +inline __m256i Sum3_16(const __m256i src0, const __m256i src1, + const __m256i src2) { + const __m256i sum = _mm256_add_epi16(src0, src1); + return _mm256_add_epi16(sum, src2); +} + +inline __m128i Sum3_16(const __m128i src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline __m256i Sum3_16(const __m256i src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline __m128i Sum3_32(const __m128i src0, const __m128i src1, + const __m128i src2) { + const __m128i sum = _mm_add_epi32(src0, src1); + return _mm_add_epi32(sum, src2); +} + +inline __m256i Sum3_32(const __m256i src0, const __m256i src1, + const __m256i src2) { + const __m256i sum = _mm256_add_epi32(src0, src1); + return _mm256_add_epi32(sum, src2); +} + +inline __m128i Sum3_32(const __m128i src[3]) { + return Sum3_32(src[0], src[1], src[2]); +} + +inline 
__m256i Sum3_32(const __m256i src[3]) { + return Sum3_32(src[0], src[1], src[2]); +} + +inline void Sum3_32(const __m128i src[3][2], __m128i dst[2]) { + dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]); + dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]); +} + +inline void Sum3_32(const __m256i src[3][2], __m256i dst[2]) { + dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]); + dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]); +} + +inline __m256i Sum3WLo16(const __m256i src[3]) { + const __m256i sum = VaddlLo8(src[0], src[1]); + return VaddwLo8(sum, src[2]); +} + +inline __m256i Sum3WHi16(const __m256i src[3]) { + const __m256i sum = VaddlHi8(src[0], src[1]); + return VaddwHi8(sum, src[2]); +} + +inline __m128i Sum5_16(const __m128i src[5]) { + const __m128i sum01 = _mm_add_epi16(src[0], src[1]); + const __m128i sum23 = _mm_add_epi16(src[2], src[3]); + const __m128i sum = _mm_add_epi16(sum01, sum23); + return _mm_add_epi16(sum, src[4]); +} + +inline __m256i Sum5_16(const __m256i src[5]) { + const __m256i sum01 = _mm256_add_epi16(src[0], src[1]); + const __m256i sum23 = _mm256_add_epi16(src[2], src[3]); + const __m256i sum = _mm256_add_epi16(sum01, sum23); + return _mm256_add_epi16(sum, src[4]); +} + +inline __m128i Sum5_32(const __m128i* const src0, const __m128i* const src1, + const __m128i* const src2, const __m128i* const src3, + const __m128i* const src4) { + const __m128i sum01 = _mm_add_epi32(*src0, *src1); + const __m128i sum23 = _mm_add_epi32(*src2, *src3); + const __m128i sum = _mm_add_epi32(sum01, sum23); + return _mm_add_epi32(sum, *src4); +} + +inline __m256i Sum5_32(const __m256i* const src0, const __m256i* const src1, + const __m256i* const src2, const __m256i* const src3, + const __m256i* const src4) { + const __m256i sum01 = _mm256_add_epi32(*src0, *src1); + const __m256i sum23 = _mm256_add_epi32(*src2, *src3); + const __m256i sum = _mm256_add_epi32(sum01, sum23); + return _mm256_add_epi32(sum, *src4); +} + +inline __m128i Sum5_32(const __m128i src[5]) { + return Sum5_32(&src[0], &src[1], &src[2], &src[3], &src[4]); +} + +inline __m256i Sum5_32(const __m256i src[5]) { + return Sum5_32(&src[0], &src[1], &src[2], &src[3], &src[4]); +} + +inline void Sum5_32(const __m128i src[5][2], __m128i dst[2]) { + dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]); + dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]); +} + +inline void Sum5_32(const __m256i src[5][2], __m256i dst[2]) { + dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]); + dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]); +} + +inline __m128i Sum3Horizontal16(const __m128i src[2]) { + __m128i s[3]; + Prepare3_16(src, s); + return Sum3_16(s); +} + +inline __m256i Sum3Horizontal16(const uint16_t* const src, + const ptrdiff_t over_read_in_bytes) { + __m256i s[3]; + s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0); + s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 2); + s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 4); + return Sum3_16(s); +} + +inline __m128i Sum5Horizontal16(const __m128i src[2]) { + __m128i s[5]; + Prepare5_16(src, s); + return Sum5_16(s); +} + +inline __m256i Sum5Horizontal16(const uint16_t* const src, + const ptrdiff_t over_read_in_bytes) { + __m256i s[5]; + s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0); + s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 2); + s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 4); + s[3] = 
LoadUnaligned32Msan(src + 3, over_read_in_bytes + 6); + s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 8); + return Sum5_16(s); +} + +inline void SumHorizontal16(const uint16_t* const src, + const ptrdiff_t over_read_in_bytes, + __m256i* const row3, __m256i* const row5) { + __m256i s[5]; + s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0); + s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 2); + s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 4); + s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 6); + s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 8); + const __m256i sum04 = _mm256_add_epi16(s[0], s[4]); + *row3 = Sum3_16(s + 1); + *row5 = _mm256_add_epi16(sum04, *row3); +} + +inline void SumHorizontal16(const uint16_t* const src, + const ptrdiff_t over_read_in_bytes, + __m256i* const row3_0, __m256i* const row3_1, + __m256i* const row5_0, __m256i* const row5_1) { + SumHorizontal16(src + 0, over_read_in_bytes + 0, row3_0, row5_0); + SumHorizontal16(src + 16, over_read_in_bytes + 32, row3_1, row5_1); +} + +inline void SumHorizontal32(const __m128i src[5], __m128i* const row_sq3, + __m128i* const row_sq5) { + const __m128i sum04 = _mm_add_epi32(src[0], src[4]); + *row_sq3 = Sum3_32(src + 1); + *row_sq5 = _mm_add_epi32(sum04, *row_sq3); +} + +inline void SumHorizontal32(const __m256i src[5], __m256i* const row_sq3, + __m256i* const row_sq5) { + const __m256i sum04 = _mm256_add_epi32(src[0], src[4]); + *row_sq3 = Sum3_32(src + 1); + *row_sq5 = _mm256_add_epi32(sum04, *row_sq3); +} + +inline void SumHorizontal32(const __m128i src[3], __m128i* const row_sq3_0, + __m128i* const row_sq3_1, __m128i* const row_sq5_0, + __m128i* const row_sq5_1) { + __m128i s[5]; + Prepare5_32(src + 0, s); + SumHorizontal32(s, row_sq3_0, row_sq5_0); + Prepare5_32(src + 1, s); + SumHorizontal32(s, row_sq3_1, row_sq5_1); +} + +inline void SumHorizontal32(const __m256i src[3], __m256i* const row_sq3_0, + __m256i* const row_sq3_1, __m256i* const row_sq5_0, + __m256i* const row_sq5_1) { + __m256i s[5]; + Prepare5_32(src + 0, s); + SumHorizontal32(s, row_sq3_0, row_sq5_0); + Prepare5_32(src + 1, s); + SumHorizontal32(s, row_sq3_1, row_sq5_1); +} + +inline void Sum3Horizontal32(const __m128i src[3], __m128i dst[2]) { + __m128i s[3]; + Prepare3_32(src + 0, s); + dst[0] = Sum3_32(s); + Prepare3_32(src + 1, s); + dst[1] = Sum3_32(s); +} + +inline void Sum3Horizontal32(const __m256i src[3], __m256i dst[2]) { + __m256i s[3]; + Prepare3_32(src + 0, s); + dst[0] = Sum3_32(s); + Prepare3_32(src + 1, s); + dst[1] = Sum3_32(s); +} + +inline void Sum5Horizontal32(const __m128i src[3], __m128i dst[2]) { + __m128i s[5]; + Prepare5_32(src + 0, s); + dst[0] = Sum5_32(s); + Prepare5_32(src + 1, s); + dst[1] = Sum5_32(s); +} + +inline void Sum5Horizontal32(const __m256i src[3], __m256i dst[2]) { + __m256i s[5]; + Prepare5_32(src + 0, s); + dst[0] = Sum5_32(s); + Prepare5_32(src + 1, s); + dst[1] = Sum5_32(s); +} + +void SumHorizontal16(const __m128i src[2], __m128i* const row3, + __m128i* const row5) { + __m128i s[5]; + Prepare5_16(src, s); + const __m128i sum04 = _mm_add_epi16(s[0], s[4]); + *row3 = Sum3_16(s + 1); + *row5 = _mm_add_epi16(sum04, *row3); +} + +inline __m256i Sum343Lo(const __m256i ma3[3]) { + const __m256i sum = Sum3WLo16(ma3); + const __m256i sum3 = Sum3_16(sum, sum, sum); + return VaddwLo8(sum3, ma3[1]); +} + +inline __m256i Sum343Hi(const __m256i ma3[3]) { + const __m256i sum = Sum3WHi16(ma3); + const __m256i sum3 = Sum3_16(sum, sum, sum); + return VaddwHi8(sum3, 
ma3[1]); +} + +inline __m256i Sum343(const __m256i src[3]) { + const __m256i sum = Sum3_32(src); + const __m256i sum3 = Sum3_32(sum, sum, sum); + return _mm256_add_epi32(sum3, src[1]); +} + +inline void Sum343(const __m256i src[3], __m256i dst[2]) { + __m256i s[3]; + Prepare3_32(src + 0, s); + dst[0] = Sum343(s); + Prepare3_32(src + 1, s); + dst[1] = Sum343(s); +} + +inline __m256i Sum565Lo(const __m256i src[3]) { + const __m256i sum = Sum3WLo16(src); + const __m256i sum4 = _mm256_slli_epi16(sum, 2); + const __m256i sum5 = _mm256_add_epi16(sum4, sum); + return VaddwLo8(sum5, src[1]); +} + +inline __m256i Sum565Hi(const __m256i src[3]) { + const __m256i sum = Sum3WHi16(src); + const __m256i sum4 = _mm256_slli_epi16(sum, 2); + const __m256i sum5 = _mm256_add_epi16(sum4, sum); + return VaddwHi8(sum5, src[1]); +} + +inline __m256i Sum565(const __m256i src[3]) { + const __m256i sum = Sum3_32(src); + const __m256i sum4 = _mm256_slli_epi32(sum, 2); + const __m256i sum5 = _mm256_add_epi32(sum4, sum); + return _mm256_add_epi32(sum5, src[1]); +} + +inline void Sum565(const __m256i src[3], __m256i dst[2]) { + __m256i s[3]; + Prepare3_32(src + 0, s); + dst[0] = Sum565(s); + Prepare3_32(src + 1, s); + dst[1] = Sum565(s); +} + +inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5, + uint32_t* square_sum3, uint32_t* square_sum5) { + const ptrdiff_t overread_in_bytes_128 = + kOverreadInBytesPass1_128 - sizeof(*src) * width; + const ptrdiff_t overread_in_bytes_256 = + kOverreadInBytesPass1_256 - sizeof(*src) * width; + int y = 2; + do { + __m128i s0[2], sq_128[4], s3, s5, sq3[2], sq5[2]; + __m256i sq[8]; + s0[0] = LoadUnaligned16Msan(src + 0, overread_in_bytes_128 + 0); + s0[1] = LoadUnaligned16Msan(src + 8, overread_in_bytes_128 + 16); + Square(s0[0], sq_128 + 0); + Square(s0[1], sq_128 + 2); + SumHorizontal16(s0, &s3, &s5); + StoreAligned16(sum3, s3); + StoreAligned16(sum5, s5); + SumHorizontal32(sq_128, &sq3[0], &sq3[1], &sq5[0], &sq5[1]); + StoreAligned32U32(square_sum3, sq3); + StoreAligned32U32(square_sum5, sq5); + src += 8; + sum3 += 8; + sum5 += 8; + square_sum3 += 8; + square_sum5 += 8; + sq[0] = SetrM128i(sq_128[2], sq_128[2]); + sq[1] = SetrM128i(sq_128[3], sq_128[3]); + ptrdiff_t x = sum_width; + do { + __m256i s[2], row3[2], row5[2], row_sq3[2], row_sq5[2]; + s[0] = LoadUnaligned32Msan( + src + 8, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8)); + s[1] = LoadUnaligned32Msan( + src + 24, + overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 24)); + Square(s[0], sq + 2); + Square(s[1], sq + 6); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21); + sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21); + sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21); + SumHorizontal16( + src, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8), + &row3[0], &row3[1], &row5[0], &row5[1]); + StoreAligned64(sum3, row3); + StoreAligned64(sum5, row5); + SumHorizontal32(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0], + &row_sq5[1]); + StoreAligned64(square_sum3 + 0, row_sq3); + StoreAligned64(square_sum5 + 0, row_sq5); + SumHorizontal32(sq + 4, &row_sq3[0], &row_sq3[1], &row_sq5[0], + &row_sq5[1]); + StoreAligned64(square_sum3 + 16, row_sq3); + StoreAligned64(square_sum5 + 16, row_sq5); + sq[0] = sq[6]; + sq[1] = sq[7]; + src += 32; + sum3 += 32; + sum5 += 32; + square_sum3 += 32; + square_sum5 += 32; + 
x -= 32; + } while (x != 0); + src += src_stride - sum_width - 8; + sum3 += sum_stride - sum_width - 8; + sum5 += sum_stride - sum_width - 8; + square_sum3 += sum_stride - sum_width - 8; + square_sum5 += sum_stride - sum_width - 8; + } while (--y != 0); +} + +template <int size> +inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sums, + uint32_t* square_sums) { + static_assert(size == 3 || size == 5, ""); + int overread_in_bytes_128, overread_in_bytes_256; + if (size == 3) { + overread_in_bytes_128 = kOverreadInBytesPass2_128; + overread_in_bytes_256 = kOverreadInBytesPass2_256; + } else { + overread_in_bytes_128 = kOverreadInBytesPass1_128; + overread_in_bytes_256 = kOverreadInBytesPass1_256; + } + overread_in_bytes_128 -= sizeof(*src) * width; + overread_in_bytes_256 -= sizeof(*src) * width; + int y = 2; + do { + __m128i s_128[2], ss, sq_128[4], sqs[2]; + __m256i sq[8]; + s_128[0] = LoadUnaligned16Msan(src + 0, overread_in_bytes_128); + s_128[1] = LoadUnaligned16Msan(src + 8, overread_in_bytes_128 + 16); + Square(s_128[0], sq_128 + 0); + Square(s_128[1], sq_128 + 2); + if (size == 3) { + ss = Sum3Horizontal16(s_128); + Sum3Horizontal32(sq_128, sqs); + } else { + ss = Sum5Horizontal16(s_128); + Sum5Horizontal32(sq_128, sqs); + } + StoreAligned16(sums, ss); + StoreAligned32U32(square_sums, sqs); + src += 8; + sums += 8; + square_sums += 8; + sq[0] = SetrM128i(sq_128[2], sq_128[2]); + sq[1] = SetrM128i(sq_128[3], sq_128[3]); + ptrdiff_t x = sum_width; + do { + __m256i s[2], row[2], row_sq[4]; + s[0] = LoadUnaligned32Msan( + src + 8, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8)); + s[1] = LoadUnaligned32Msan( + src + 24, + overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 24)); + Square(s[0], sq + 2); + Square(s[1], sq + 6); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21); + sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21); + sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21); + if (size == 3) { + row[0] = Sum3Horizontal16( + src, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8)); + row[1] = + Sum3Horizontal16(src + 16, overread_in_bytes_256 + + sizeof(*src) * (sum_width - x + 24)); + Sum3Horizontal32(sq + 0, row_sq + 0); + Sum3Horizontal32(sq + 4, row_sq + 2); + } else { + row[0] = Sum5Horizontal16( + src, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8)); + row[1] = + Sum5Horizontal16(src + 16, overread_in_bytes_256 + + sizeof(*src) * (sum_width - x + 24)); + Sum5Horizontal32(sq + 0, row_sq + 0); + Sum5Horizontal32(sq + 4, row_sq + 2); + } + StoreAligned64(sums, row); + StoreAligned64(square_sums + 0, row_sq + 0); + StoreAligned64(square_sums + 16, row_sq + 2); + sq[0] = sq[6]; + sq[1] = sq[7]; + src += 32; + sums += 32; + square_sums += 32; + x -= 32; + } while (x != 0); + src += src_stride - sum_width - 8; + sums += sum_stride - sum_width - 8; + square_sums += sum_stride - sum_width - 8; + } while (--y != 0); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq, + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const __m128i dxd = _mm_madd_epi16(sum, sum); + // _mm_mullo_epi32() has high latency. Using shifts and additions instead. + // Some compilers could do this for us but we make this explicit. 
+ // return _mm_mullo_epi32(sum_sq, _mm_set1_epi32(n)); + __m128i axn = _mm_add_epi32(sum_sq, _mm_slli_epi32(sum_sq, 3)); + if (n == 25) axn = _mm_add_epi32(axn, _mm_slli_epi32(sum_sq, 4)); + const __m128i sub = _mm_sub_epi32(axn, dxd); + const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128()); + const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(scale)); + return VrshrU32(pxs, kSgrProjScaleBits); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2], + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + const __m128i b = VrshrU16(sum, 2); + const __m128i sum_lo = _mm_unpacklo_epi16(b, _mm_setzero_si128()); + const __m128i sum_hi = _mm_unpackhi_epi16(b, _mm_setzero_si128()); + const __m128i z0 = CalculateMa<n>(sum_lo, VrshrU32(sum_sq[0], 4), scale); + const __m128i z1 = CalculateMa<n>(sum_hi, VrshrU32(sum_sq[1], 4), scale); + return _mm_packus_epi32(z0, z1); +} + +template <int n> +inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq, + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const __m256i dxd = _mm256_madd_epi16(sum, sum); + // _mm256_mullo_epi32() has high latency. Using shifts and additions instead. + // Some compilers could do this for us but we make this explicit. + // return _mm256_mullo_epi32(sum_sq, _mm256_set1_epi32(n)); + __m256i axn = _mm256_add_epi32(sum_sq, _mm256_slli_epi32(sum_sq, 3)); + if (n == 25) axn = _mm256_add_epi32(axn, _mm256_slli_epi32(sum_sq, 4)); + const __m256i sub = _mm256_sub_epi32(axn, dxd); + const __m256i p = _mm256_max_epi32(sub, _mm256_setzero_si256()); + const __m256i pxs = _mm256_mullo_epi32(p, _mm256_set1_epi32(scale)); + return VrshrU32(pxs, kSgrProjScaleBits); +} + +template <int n> +inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq[2], + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + const __m256i b = VrshrU16(sum, 2); + const __m256i sum_lo = _mm256_unpacklo_epi16(b, _mm256_setzero_si256()); + const __m256i sum_hi = _mm256_unpackhi_epi16(b, _mm256_setzero_si256()); + const __m256i z0 = CalculateMa<n>(sum_lo, VrshrU32(sum_sq[0], 4), scale); + const __m256i z1 = CalculateMa<n>(sum_hi, VrshrU32(sum_sq[1], 4), scale); + return _mm256_packus_epi32(z0, z1); +} + +inline void CalculateB5(const __m128i sum, const __m128i ma, __m128i b[2]) { + // one_over_n == 164. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. + constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. + const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter)); + const __m128i m0 = VmullLo16(m, sum); + const __m128i m1 = VmullHi16(m, sum); + b[0] = VrshrU32(m0, kSgrProjReciprocalBits - 2); + b[1] = VrshrU32(m1, kSgrProjReciprocalBits - 2); +} + +inline void CalculateB5(const __m256i sum, const __m256i ma, __m256i b[2]) { + // one_over_n == 164. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. + constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. 
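  // Why the multiply below uses |one_over_n_quarter| (41) instead of 164:
  // _mm256_maddubs_epi16() treats its second operand as signed 8-bit values,
  // so each byte of the multiplier must fit in [-128, 127]. 164 does not, but
  // 41 == 164 / 4 does, and the missing factor of 4 is recovered by shifting
  // right by 2 fewer bits further down. The results match exactly because
  // x * 164 == (x * 41) << 2, hence
  //   RightShiftWithRounding(x * 164, kSgrProjReciprocalBits) ==
  //   RightShiftWithRounding(x * 41, kSgrProjReciprocalBits - 2).
  static_assert(one_over_n_quarter < 128, "multiplier must fit in a signed byte");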
+ const __m256i m = + _mm256_maddubs_epi16(ma, _mm256_set1_epi16(one_over_n_quarter)); + const __m256i m0 = VmullLo16(m, sum); + const __m256i m1 = VmullHi16(m, sum); + b[0] = VrshrU32(m0, kSgrProjReciprocalBits - 2); + b[1] = VrshrU32(m1, kSgrProjReciprocalBits - 2); +} + +inline void CalculateB3(const __m128i sum, const __m128i ma, __m128i b[2]) { + // one_over_n == 455. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9; + const __m128i m0 = VmullLo16(ma, sum); + const __m128i m1 = VmullHi16(ma, sum); + const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n)); + const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n)); + b[0] = VrshrU32(m2, kSgrProjReciprocalBits); + b[1] = VrshrU32(m3, kSgrProjReciprocalBits); +} + +inline void CalculateB3(const __m256i sum, const __m256i ma, __m256i b[2]) { + // one_over_n == 455. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9; + const __m256i m0 = VmullLo16(ma, sum); + const __m256i m1 = VmullHi16(ma, sum); + const __m256i m2 = _mm256_mullo_epi32(m0, _mm256_set1_epi32(one_over_n)); + const __m256i m3 = _mm256_mullo_epi32(m1, _mm256_set1_epi32(one_over_n)); + b[0] = VrshrU32(m2, kSgrProjReciprocalBits); + b[1] = VrshrU32(m3, kSgrProjReciprocalBits); +} + +inline void CalculateSumAndIndex5(const __m128i s5[5], const __m128i sq5[5][2], + const uint32_t scale, __m128i* const sum, + __m128i* const index) { + __m128i sum_sq[2]; + *sum = Sum5_16(s5); + Sum5_32(sq5, sum_sq); + *index = CalculateMa<25>(*sum, sum_sq, scale); +} + +inline void CalculateSumAndIndex5(const __m256i s5[5], const __m256i sq5[5][2], + const uint32_t scale, __m256i* const sum, + __m256i* const index) { + __m256i sum_sq[2]; + *sum = Sum5_16(s5); + Sum5_32(sq5, sum_sq); + *index = CalculateMa<25>(*sum, sum_sq, scale); +} + +inline void CalculateSumAndIndex3(const __m128i s3[3], const __m128i sq3[3][2], + const uint32_t scale, __m128i* const sum, + __m128i* const index) { + __m128i sum_sq[2]; + *sum = Sum3_16(s3); + Sum3_32(sq3, sum_sq); + *index = CalculateMa<9>(*sum, sum_sq, scale); +} + +inline void CalculateSumAndIndex3(const __m256i s3[3], const __m256i sq3[3][2], + const uint32_t scale, __m256i* const sum, + __m256i* const index) { + __m256i sum_sq[2]; + *sum = Sum3_16(s3); + Sum3_32(sq3, sum_sq); + *index = CalculateMa<9>(*sum, sum_sq, scale); +} + +template <int n> +inline void LookupIntermediate(const __m128i sum, const __m128i index, + __m128i* const ma, __m128i b[2]) { + static_assert(n == 9 || n == 25, ""); + const __m128i idx = _mm_packus_epi16(index, index); + // Actually it's not stored and loaded. The compiler will use a 64-bit + // general-purpose register to process. Faster than using _mm_extract_epi8(). + uint8_t temp[8]; + StoreLo8(temp, idx); + *ma = _mm_cvtsi32_si128(kSgrMaLookup[temp[0]]); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[1]], 1); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[2]], 2); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[3]], 3); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[4]], 4); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[5]], 5); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[6]], 6); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[7]], 7); + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. 
+ // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const __m128i maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); + if (n == 9) { + CalculateB3(sum, maq, b); + } else { + CalculateB5(sum, maq, b); + } +} + +// Repeat the first 48 elements in kSgrMaLookup with a period of 16. +alignas(32) constexpr uint8_t kSgrMaLookupAvx2[96] = { + 255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16, + 255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16, + 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 9, 9, 8, 8, + 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 9, 9, 8, 8, + 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 5, 5, + 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 5, 5}; + +// Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b +// to get value 0 as the shuffle result. The most significiant bit 1 comes +// either from the comparison instruction, or from the sign bit of the index. +inline __m128i ShuffleIndex(const __m128i table, const __m128i index) { + __m128i mask; + mask = _mm_cmpgt_epi8(index, _mm_set1_epi8(15)); + mask = _mm_or_si128(mask, index); + return _mm_shuffle_epi8(table, mask); +} + +inline __m256i ShuffleIndex(const __m256i table, const __m256i index) { + __m256i mask; + mask = _mm256_cmpgt_epi8(index, _mm256_set1_epi8(15)); + mask = _mm256_or_si256(mask, index); + return _mm256_shuffle_epi8(table, mask); +} + +inline __m128i AdjustValue(const __m128i value, const __m128i index, + const int threshold) { + const __m128i thresholds = _mm_set1_epi8(threshold - 128); + const __m128i offset = _mm_cmpgt_epi8(index, thresholds); + return _mm_add_epi8(value, offset); +} + +inline __m256i AdjustValue(const __m256i value, const __m256i index, + const int threshold) { + const __m256i thresholds = _mm256_set1_epi8(threshold - 128); + const __m256i offset = _mm256_cmpgt_epi8(index, thresholds); + return _mm256_add_epi8(value, offset); +} + +inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], + __m128i* const ma, __m128i b0[2], + __m128i b1[2]) { + // Use table lookup to read elements whose indices are less than 48. + const __m128i c0 = LoadAligned16(kSgrMaLookup + 0 * 16); + const __m128i c1 = LoadAligned16(kSgrMaLookup + 1 * 16); + const __m128i c2 = LoadAligned16(kSgrMaLookup + 2 * 16); + const __m128i indices = _mm_packus_epi16(index[0], index[1]); + __m128i idx; + // Clip idx to 127 to apply signed comparison instructions. + idx = _mm_min_epu8(indices, _mm_set1_epi8(127)); + // All elements whose indices are less than 48 are set to 0. + // Get shuffle results for indices in range [0, 15]. + *ma = ShuffleIndex(c0, idx); + // Get shuffle results for indices in range [16, 31]. + // Subtract 16 to utilize the sign bit of the index. + idx = _mm_sub_epi8(idx, _mm_set1_epi8(16)); + const __m128i res1 = ShuffleIndex(c1, idx); + // Use OR instruction to combine shuffle results together. + *ma = _mm_or_si128(*ma, res1); + // Get shuffle results for indices in range [32, 47]. + // Subtract 16 to utilize the sign bit of the index. 
+ idx = _mm_sub_epi8(idx, _mm_set1_epi8(16)); + const __m128i res2 = ShuffleIndex(c2, idx); + *ma = _mm_or_si128(*ma, res2); + + // For elements whose indices are larger than 47, since they seldom change + // values with the increase of the index, we use comparison and arithmetic + // operations to calculate their values. + // Add -128 to apply signed comparison instructions. + idx = _mm_add_epi8(indices, _mm_set1_epi8(-128)); + // Elements whose indices are larger than 47 (with value 0) are set to 5. + *ma = _mm_max_epu8(*ma, _mm_set1_epi8(5)); + *ma = AdjustValue(*ma, idx, 55); // 55 is the last index which value is 5. + *ma = AdjustValue(*ma, idx, 72); // 72 is the last index which value is 4. + *ma = AdjustValue(*ma, idx, 101); // 101 is the last index which value is 3. + *ma = AdjustValue(*ma, idx, 169); // 169 is the last index which value is 2. + *ma = AdjustValue(*ma, idx, 254); // 254 is the last index which value is 1. + + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const __m128i maq0 = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); + CalculateB3(sum[0], maq0, b0); + const __m128i maq1 = _mm_unpackhi_epi8(*ma, _mm_setzero_si128()); + CalculateB3(sum[1], maq1, b1); +} + +template <int n> +inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], + __m256i ma[3], __m256i b0[2], __m256i b1[2]) { + static_assert(n == 9 || n == 25, ""); + // Use table lookup to read elements whose indices are less than 48. + const __m256i c0 = LoadAligned32(kSgrMaLookupAvx2 + 0 * 32); + const __m256i c1 = LoadAligned32(kSgrMaLookupAvx2 + 1 * 32); + const __m256i c2 = LoadAligned32(kSgrMaLookupAvx2 + 2 * 32); + const __m256i indices = _mm256_packus_epi16(index[0], index[1]); // 0 2 1 3 + __m256i idx, mas; + // Clip idx to 127 to apply signed comparison instructions. + idx = _mm256_min_epu8(indices, _mm256_set1_epi8(127)); + // All elements whose indices are less than 48 are set to 0. + // Get shuffle results for indices in range [0, 15]. + mas = ShuffleIndex(c0, idx); + // Get shuffle results for indices in range [16, 31]. + // Subtract 16 to utilize the sign bit of the index. + idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16)); + const __m256i res1 = ShuffleIndex(c1, idx); + // Use OR instruction to combine shuffle results together. + mas = _mm256_or_si256(mas, res1); + // Get shuffle results for indices in range [32, 47]. + // Subtract 16 to utilize the sign bit of the index. + idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16)); + const __m256i res2 = ShuffleIndex(c2, idx); + mas = _mm256_or_si256(mas, res2); + + // For elements whose indices are larger than 47, since they seldom change + // values with the increase of the index, we use comparison and arithmetic + // operations to calculate their values. + // Add -128 to apply signed comparison instructions. + idx = _mm256_add_epi8(indices, _mm256_set1_epi8(-128)); + // Elements whose indices are larger than 47 (with value 0) are set to 5. 
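  // In scalar form, the max_epu8() and the five AdjustValue() steps below
  // reproduce the tail of kSgrMaLookup for indices above 47:
  //   ma = (idx <= 55) ? 5 : (idx <= 72) ? 4 : (idx <= 101) ? 3
  //        : (idx <= 169) ? 2 : (idx <= 254) ? 1 : 0;
  // Each AdjustValue() compare yields an all-ones (-1) byte once the index
  // passes its threshold, so the value drops by one per threshold crossed,
  // while entries at or below 47 (whose table values are at least 5) are left
  // unchanged.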
+ mas = _mm256_max_epu8(mas, _mm256_set1_epi8(5)); + mas = AdjustValue(mas, idx, 55); // 55 is the last index which value is 5. + mas = AdjustValue(mas, idx, 72); // 72 is the last index which value is 4. + mas = AdjustValue(mas, idx, 101); // 101 is the last index which value is 3. + mas = AdjustValue(mas, idx, 169); // 169 is the last index which value is 2. + mas = AdjustValue(mas, idx, 254); // 254 is the last index which value is 1. + + ma[2] = _mm256_permute4x64_epi64(mas, 0x63); // 32-39 8-15 16-23 24-31 + ma[0] = _mm256_blend_epi32(ma[0], ma[2], 0xfc); // 0-7 8-15 16-23 24-31 + ma[1] = _mm256_permute2x128_si256(ma[0], ma[2], 0x21); + + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const __m256i maq0 = _mm256_unpackhi_epi8(ma[0], _mm256_setzero_si256()); + const __m256i maq1 = _mm256_unpacklo_epi8(ma[1], _mm256_setzero_si256()); + __m256i sums[2]; + sums[0] = _mm256_permute2x128_si256(sum[0], sum[1], 0x20); + sums[1] = _mm256_permute2x128_si256(sum[0], sum[1], 0x31); + if (n == 9) { + CalculateB3(sums[0], maq0, b0); + CalculateB3(sums[1], maq1, b1); + } else { + CalculateB5(sums[0], maq0, b0); + CalculateB5(sums[1], maq1, b1); + } +} + +inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2], + const uint32_t scale, __m128i* const ma, + __m128i b[2]) { + __m128i sum, index; + CalculateSumAndIndex5(s5, sq5, scale, &sum, &index); + LookupIntermediate<25>(sum, index, ma, b); +} + +inline void CalculateIntermediate3(const __m128i s3[3], const __m128i sq3[3][2], + const uint32_t scale, __m128i* const ma, + __m128i b[2]) { + __m128i sum, index; + CalculateSumAndIndex3(s3, sq3, scale, &sum, &index); + LookupIntermediate<9>(sum, index, ma, b); +} + +inline void Store343_444(const __m256i b3[3], const ptrdiff_t x, + __m256i sum_b343[2], __m256i sum_b444[2], + uint32_t* const b343, uint32_t* const b444) { + __m256i b[3], sum_b111[2]; + Prepare3_32(b3 + 0, b); + sum_b111[0] = Sum3_32(b); + sum_b444[0] = _mm256_slli_epi32(sum_b111[0], 2); + sum_b343[0] = _mm256_sub_epi32(sum_b444[0], sum_b111[0]); + sum_b343[0] = _mm256_add_epi32(sum_b343[0], b[1]); + Prepare3_32(b3 + 1, b); + sum_b111[1] = Sum3_32(b); + sum_b444[1] = _mm256_slli_epi32(sum_b111[1], 2); + sum_b343[1] = _mm256_sub_epi32(sum_b444[1], sum_b111[1]); + sum_b343[1] = _mm256_add_epi32(sum_b343[1], b[1]); + StoreAligned64(b444 + x, sum_b444); + StoreAligned64(b343 + x, sum_b343); +} + +inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i* const sum_ma444, __m256i sum_b343[2], + __m256i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m256i sum_ma111 = Sum3WLo16(ma3); + *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2); + StoreAligned32_ma(ma444 + x, *sum_ma444); + const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwLo8(sum333, ma3[1]); + StoreAligned32_ma(ma343 + x, *sum_ma343); + Store343_444(b3, x, sum_b343, sum_b444, b343, b444); +} + 
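For reference, the cross-row blends vectorized by the Sum565*(), Sum343*() and Store343_444*() helpers come down to fixed 5-6-5, 4-4-4 and 3-4-3 weightings of their three input rows. A minimal scalar sketch (function names here are illustrative only, not part of the source):

inline int Weight565(const int a[3]) {  // 5x5 (pass 1) path
  const int sum = a[0] + a[1] + a[2];
  return 4 * sum + sum + a[1];  // == 5 * a[0] + 6 * a[1] + 5 * a[2]
}
inline int Weight444(const int a[3]) {  // 3x3 (pass 2) path
  return 4 * (a[0] + a[1] + a[2]);
}
inline int Weight343(const int a[3]) {  // 3x3 (pass 2) path
  const int sum = a[0] + a[1] + a[2];
  return 4 * sum - sum + a[1];  // == 3 * a[0] + 4 * a[1] + 3 * a[2]
}

The vector code computes the same values with a shift plus an add or subtract (e.g. sum << 2 followed by adding the middle row) rather than explicit multiplies.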
+inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i* const sum_ma444, __m256i sum_b343[2], + __m256i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m256i sum_ma111 = Sum3WHi16(ma3); + *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2); + StoreAligned32_ma(ma444 + x, *sum_ma444); + const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwHi8(sum333, ma3[1]); + StoreAligned32_ma(ma343 + x, *sum_ma343); + Store343_444(b3, x + kMaStoreOffset, sum_b343, sum_b444, b343, b444); +} + +inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma444, sum_b444[2]; + Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma444, sum_b444[2]; + Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma343, sum_b343[2]; + Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma343, sum_b343[2]; + Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +// Don't combine the following 2 functions, which would be slower. 
+inline void Store343_444(const __m256i ma3[3], const __m256i b3[6], + const ptrdiff_t x, __m256i* const sum_ma343_lo, + __m256i* const sum_ma343_hi, + __m256i* const sum_ma444_lo, + __m256i* const sum_ma444_hi, __m256i sum_b343_lo[2], + __m256i sum_b343_hi[2], __m256i sum_b444_lo[2], + __m256i sum_b444_hi[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_mat343[2], sum_mat444[2]; + const __m256i sum_ma111_lo = Sum3WLo16(ma3); + sum_mat444[0] = _mm256_slli_epi16(sum_ma111_lo, 2); + const __m256i sum333_lo = _mm256_sub_epi16(sum_mat444[0], sum_ma111_lo); + sum_mat343[0] = VaddwLo8(sum333_lo, ma3[1]); + Store343_444(b3, x, sum_b343_lo, sum_b444_lo, b343, b444); + const __m256i sum_ma111_hi = Sum3WHi16(ma3); + sum_mat444[1] = _mm256_slli_epi16(sum_ma111_hi, 2); + *sum_ma444_lo = _mm256_permute2x128_si256(sum_mat444[0], sum_mat444[1], 0x20); + *sum_ma444_hi = _mm256_permute2x128_si256(sum_mat444[0], sum_mat444[1], 0x31); + StoreAligned32(ma444 + x + 0, *sum_ma444_lo); + StoreAligned32(ma444 + x + 16, *sum_ma444_hi); + const __m256i sum333_hi = _mm256_sub_epi16(sum_mat444[1], sum_ma111_hi); + sum_mat343[1] = VaddwHi8(sum333_hi, ma3[1]); + *sum_ma343_lo = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x20); + *sum_ma343_hi = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x31); + StoreAligned32(ma343 + x + 0, *sum_ma343_lo); + StoreAligned32(ma343 + x + 16, *sum_ma343_hi); + Store343_444(b3 + 3, x + 16, sum_b343_hi, sum_b444_hi, b343, b444); +} + +inline void Store343_444(const __m256i ma3[3], const __m256i b3[6], + const ptrdiff_t x, __m256i* const sum_ma343_lo, + __m256i* const sum_ma343_hi, __m256i sum_b343_lo[2], + __m256i sum_b343_hi[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma444[2], sum_b444[2], sum_mat343[2]; + const __m256i sum_ma111_lo = Sum3WLo16(ma3); + sum_ma444[0] = _mm256_slli_epi16(sum_ma111_lo, 2); + const __m256i sum333_lo = _mm256_sub_epi16(sum_ma444[0], sum_ma111_lo); + sum_mat343[0] = VaddwLo8(sum333_lo, ma3[1]); + Store343_444(b3, x, sum_b343_lo, sum_b444, b343, b444); + const __m256i sum_ma111_hi = Sum3WHi16(ma3); + sum_ma444[1] = _mm256_slli_epi16(sum_ma111_hi, 2); + StoreAligned64_ma(ma444 + x, sum_ma444); + const __m256i sum333_hi = _mm256_sub_epi16(sum_ma444[1], sum_ma111_hi); + sum_mat343[1] = VaddwHi8(sum333_hi, ma3[1]); + *sum_ma343_lo = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x20); + *sum_ma343_hi = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x31); + StoreAligned32(ma343 + x + 0, *sum_ma343_lo); + StoreAligned32(ma343 + x + 16, *sum_ma343_hi); + Store343_444(b3 + 3, x + 16, sum_b343_hi, sum_b444, b343, b444); +} + +inline void PermuteB(const __m256i t[4], __m256i b[7]) { + // Input: + // 0 1 2 3 // b[0] + // 4 5 6 7 // b[1] + // 8 9 10 11 24 25 26 27 // t[0] + // 12 13 14 15 28 29 30 31 // t[1] + // 16 17 18 19 32 33 34 35 // t[2] + // 20 21 22 23 36 37 38 39 // t[3] + + // Output: + // 0 1 2 3 8 9 10 11 // b[0] + // 4 5 6 7 12 13 14 15 // b[1] + // 8 9 10 11 16 17 18 19 // b[2] + // 16 17 18 19 24 25 26 27 // b[3] + // 20 21 22 23 28 29 30 31 // b[4] + // 24 25 26 27 32 33 34 35 // b[5] + // 20 21 22 23 36 37 38 39 // b[6] + b[0] = _mm256_permute2x128_si256(b[0], t[0], 0x21); + b[1] = _mm256_permute2x128_si256(b[1], t[1], 0x21); + b[2] = _mm256_permute2x128_si256(t[0], t[2], 0x20); + b[3] = _mm256_permute2x128_si256(t[2], t[0], 0x30); + b[4] = _mm256_permute2x128_si256(t[3], t[1], 
0x30); + b[5] = _mm256_permute2x128_si256(t[0], t[2], 0x31); + b[6] = t[3]; +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo( + const __m128i s[2][2], const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m128i sq[2][4], __m128i* const ma, + __m128i b[2]) { + __m128i s5[2][5], sq5[5][2]; + Square(s[0][1], sq[0] + 2); + Square(s[1][1], sq[1] + 2); + s5[0][3] = Sum5Horizontal16(s[0]); + StoreAligned16(sum5[3], s5[0][3]); + s5[0][4] = Sum5Horizontal16(s[1]); + StoreAligned16(sum5[4], s5[0][4]); + Sum5Horizontal32(sq[0], sq5[3]); + StoreAligned32U32(square_sum5[3], sq5[3]); + Sum5Horizontal32(sq[1], sq5[4]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x3U16(sum5, 0, s5[0]); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5(s5[0], sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const uint16_t* const src0, const uint16_t* const src1, + const ptrdiff_t over_read_in_bytes, const ptrdiff_t sum_width, + const ptrdiff_t x, const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m256i sq[2][8], __m256i ma[3], + __m256i b[3]) { + __m256i s[2], s5[2][5], sq5[5][2], sum[2], index[2], t[4]; + s[0] = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 16); + s[1] = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 16); + Square(s[0], sq[0] + 2); + Square(s[1], sq[1] + 2); + sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21); + sq[0][1] = _mm256_permute2x128_si256(sq[0][1], sq[0][3], 0x21); + sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21); + sq[1][1] = _mm256_permute2x128_si256(sq[1][1], sq[1][3], 0x21); + s5[0][3] = Sum5Horizontal16(src0 + 0, over_read_in_bytes + 0); + s5[1][3] = Sum5Horizontal16(src0 + 16, over_read_in_bytes + 32); + s5[0][4] = Sum5Horizontal16(src1 + 0, over_read_in_bytes + 0); + s5[1][4] = Sum5Horizontal16(src1 + 16, over_read_in_bytes + 32); + StoreAligned32(sum5[3] + x + 0, s5[0][3]); + StoreAligned32(sum5[3] + x + 16, s5[1][3]); + StoreAligned32(sum5[4] + x + 0, s5[0][4]); + StoreAligned32(sum5[4] + x + 16, s5[1][4]); + Sum5Horizontal32(sq[0], sq5[3]); + StoreAligned64(square_sum5[3] + x, sq5[3]); + Sum5Horizontal32(sq[1], sq5[4]); + StoreAligned64(square_sum5[4] + x, sq5[4]); + LoadAligned32x3U16(sum5, x, s5[0]); + LoadAligned64x3U32(square_sum5, x, sq5); + CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]); + + s[0] = LoadUnaligned32Msan(src0 + 24, over_read_in_bytes + 48); + s[1] = LoadUnaligned32Msan(src1 + 24, over_read_in_bytes + 48); + Square(s[0], sq[0] + 6); + Square(s[1], sq[1] + 6); + sq[0][4] = _mm256_permute2x128_si256(sq[0][2], sq[0][6], 0x21); + sq[0][5] = _mm256_permute2x128_si256(sq[0][3], sq[0][7], 0x21); + sq[1][4] = _mm256_permute2x128_si256(sq[1][2], sq[1][6], 0x21); + sq[1][5] = _mm256_permute2x128_si256(sq[1][3], sq[1][7], 0x21); + Sum5Horizontal32(sq[0] + 4, sq5[3]); + StoreAligned64(square_sum5[3] + x + 16, sq5[3]); + Sum5Horizontal32(sq[1] + 4, sq5[4]); + StoreAligned64(square_sum5[4] + x + 16, sq5[4]); + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]); + CalculateIntermediate<25>(sum, index, ma, t, t + 2); + PermuteB(t, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo( + const __m128i s[2], const uint32_t scale, const uint16_t* const sum5[5], + const uint32_t* const square_sum5[5], __m128i sq[4], __m128i* const ma, + __m128i b[2]) { + __m128i 
s5[5], sq5[5][2]; + Square(s[1], sq + 2); + s5[3] = s5[4] = Sum5Horizontal16(s); + Sum5Horizontal32(sq, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow( + const uint16_t* const src, const ptrdiff_t over_read_in_bytes, + const ptrdiff_t sum_width, const ptrdiff_t x, const uint32_t scale, + const uint16_t* const sum5[5], const uint32_t* const square_sum5[5], + __m256i sq[3], __m256i ma[3], __m256i b[3]) { + const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 16); + __m256i s5[2][5], sq5[5][2], sum[2], index[2], t[4]; + Square(s0, sq + 2); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21); + s5[0][3] = Sum5Horizontal16(src + 0, over_read_in_bytes + 0); + s5[1][3] = Sum5Horizontal16(src + 16, over_read_in_bytes + 32); + s5[0][4] = s5[0][3]; + s5[1][4] = s5[1][3]; + Sum5Horizontal32(sq, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned32x3U16(sum5, x, s5[0]); + LoadAligned64x3U32(square_sum5, x, sq5); + CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]); + + const __m256i s1 = LoadUnaligned32Msan(src + 24, over_read_in_bytes + 48); + Square(s1, sq + 6); + sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21); + sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21); + Sum5Horizontal32(sq + 4, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]); + CalculateIntermediate<25>(sum, index, ma, t, t + 2); + PermuteB(t, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo( + const __m128i s[2], const uint32_t scale, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], __m128i sq[4], __m128i* const ma, + __m128i b[2]) { + __m128i s3[3], sq3[3][2]; + Square(s[1], sq + 2); + s3[2] = Sum3Horizontal16(s); + StoreAligned16(sum3[2], s3[2]); + Sum3Horizontal32(sq, sq3[2]); + StoreAligned32U32(square_sum3[2], sq3[2]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( + const uint16_t* const src, const ptrdiff_t over_read_in_bytes, + const ptrdiff_t x, const ptrdiff_t sum_width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], __m256i sq[8], + __m256i ma[3], __m256i b[7]) { + __m256i s[2], s3[4], sq3[3][2], sum[2], index[2], t[4]; + s[0] = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 16); + s[1] = LoadUnaligned32Msan(src + 24, over_read_in_bytes + 48); + Square(s[0], sq + 2); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21); + s3[2] = Sum3Horizontal16(src, over_read_in_bytes); + s3[3] = Sum3Horizontal16(src + 16, over_read_in_bytes + 32); + StoreAligned64(sum3[2] + x, s3 + 2); + Sum3Horizontal32(sq + 0, sq3[2]); + StoreAligned64(square_sum3[2] + x, sq3[2]); + LoadAligned32x2U16(sum3, x, s3); + LoadAligned64x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]); + + Square(s[1], sq + 6); + sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21); + sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21); + Sum3Horizontal32(sq + 4, sq3[2]); + 
StoreAligned64(square_sum3[2] + x + 16, sq3[2]); + LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3 + 1); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]); + CalculateIntermediate<9>(sum, index, ma, t, t + 2); + PermuteB(t, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo( + const __m128i s[2][4], const uint16_t scales[2], uint16_t* const sum3[4], + uint16_t* const sum5[5], uint32_t* const square_sum3[4], + uint32_t* const square_sum5[5], __m128i sq[2][8], __m128i ma3[2][3], + __m128i b3[2][10], __m128i* const ma5, __m128i b5[2]) { + __m128i s3[4], s5[5], sq3[4][2], sq5[5][2], sum[2], index[2]; + Square(s[0][1], sq[0] + 2); + Square(s[1][1], sq[1] + 2); + SumHorizontal16(s[0], &s3[2], &s5[3]); + SumHorizontal16(s[1], &s3[3], &s5[4]); + StoreAligned16(sum3[2], s3[2]); + StoreAligned16(sum3[3], s3[3]); + StoreAligned16(sum5[3], s5[3]); + StoreAligned16(sum5[4], s5[4]); + SumHorizontal32(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2], sq3[2]); + StoreAligned32U32(square_sum5[3], sq5[3]); + SumHorizontal32(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3], sq3[3]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateSumAndIndex3(s3 + 0, sq3 + 0, scales[1], &sum[0], &index[0]); + CalculateSumAndIndex3(s3 + 1, sq3 + 1, scales[1], &sum[1], &index[1]); + CalculateIntermediate(sum, index, &ma3[0][0], b3[0], b3[1]); + ma3[1][0] = _mm_srli_si128(ma3[0][0], 8); + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( + const uint16_t* const src0, const uint16_t* const src1, + const ptrdiff_t over_read_in_bytes, const ptrdiff_t x, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, __m256i sq[2][8], __m256i ma3[2][3], + __m256i b3[2][7], __m256i ma5[3], __m256i b5[5]) { + __m256i s[2], s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2][2], + index_3[2][2], sum_5[2], index_5[2], t[4]; + s[0] = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 16); + s[1] = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 16); + Square(s[0], sq[0] + 2); + Square(s[1], sq[1] + 2); + sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21); + sq[0][1] = _mm256_permute2x128_si256(sq[0][1], sq[0][3], 0x21); + sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21); + sq[1][1] = _mm256_permute2x128_si256(sq[1][1], sq[1][3], 0x21); + SumHorizontal16(src0, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3], + &s5[1][3]); + SumHorizontal16(src1, over_read_in_bytes, &s3[0][3], &s3[1][3], &s5[0][4], + &s5[1][4]); + StoreAligned32(sum3[2] + x + 0, s3[0][2]); + StoreAligned32(sum3[2] + x + 16, s3[1][2]); + StoreAligned32(sum3[3] + x + 0, s3[0][3]); + StoreAligned32(sum3[3] + x + 16, s3[1][3]); + StoreAligned32(sum5[3] + x + 0, s5[0][3]); + StoreAligned32(sum5[3] + x + 16, s5[1][3]); + StoreAligned32(sum5[4] + x + 0, s5[0][4]); + StoreAligned32(sum5[4] + x + 16, s5[1][4]); + SumHorizontal32(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + SumHorizontal32(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned64(square_sum3[2] + x, sq3[2]); + StoreAligned64(square_sum5[3] + x, sq5[3]); + 
StoreAligned64(square_sum3[3] + x, sq3[3]); + StoreAligned64(square_sum5[4] + x, sq5[4]); + LoadAligned32x2U16(sum3, x, s3[0]); + LoadAligned64x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0][0], &index_3[0][0]); + CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum_3[1][0], + &index_3[1][0]); + LoadAligned32x3U16(sum5, x, s5[0]); + LoadAligned64x3U32(square_sum5, x, sq5); + CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); + + s[0] = LoadUnaligned32Msan(src0 + 24, over_read_in_bytes + 48); + s[1] = LoadUnaligned32Msan(src1 + 24, over_read_in_bytes + 48); + Square(s[0], sq[0] + 6); + Square(s[1], sq[1] + 6); + sq[0][4] = _mm256_permute2x128_si256(sq[0][2], sq[0][6], 0x21); + sq[0][5] = _mm256_permute2x128_si256(sq[0][3], sq[0][7], 0x21); + sq[1][4] = _mm256_permute2x128_si256(sq[1][2], sq[1][6], 0x21); + sq[1][5] = _mm256_permute2x128_si256(sq[1][3], sq[1][7], 0x21); + SumHorizontal32(sq[0] + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + SumHorizontal32(sq[1] + 4, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned64(square_sum3[2] + x + 16, sq3[2]); + StoreAligned64(square_sum5[3] + x + 16, sq5[3]); + StoreAligned64(square_sum3[3] + x + 16, sq3[3]); + StoreAligned64(square_sum5[4] + x + 16, sq5[4]); + LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[0][1], &index_3[0][1]); + CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum_3[1][1], + &index_3[1][1]); + CalculateIntermediate<9>(sum_3[0], index_3[0], ma3[0], t, t + 2); + PermuteB(t, b3[0]); + CalculateIntermediate<9>(sum_3[1], index_3[1], ma3[1], t, t + 2); + PermuteB(t, b3[1]); + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]); + CalculateIntermediate<25>(sum_5, index_5, ma5, t, t + 2); + PermuteB(t, b5); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo( + const __m128i s[2], const uint16_t scales[2], const uint16_t* const sum3[4], + const uint16_t* const sum5[5], const uint32_t* const square_sum3[4], + const uint32_t* const square_sum5[5], __m128i sq[4], __m128i* const ma3, + __m128i* const ma5, __m128i b3[2], __m128i b5[2]) { + __m128i s3[3], s5[5], sq3[3][2], sq5[5][2]; + Square(s[1], sq + 2); + SumHorizontal16(s, &s3[2], &s5[3]); + SumHorizontal32(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16(sum5, 0, s5); + s5[4] = s5[3]; + LoadAligned32x3U32(square_sum5, 0, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scales[1], ma3, b3); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( + const uint16_t* const src, const ptrdiff_t over_read_in_bytes, + const ptrdiff_t sum_width, const ptrdiff_t x, const uint16_t scales[2], + const uint16_t* const sum3[4], const uint16_t* const sum5[5], + const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5], + __m256i sq[6], __m256i ma3[2], __m256i ma5[2], __m256i b3[5], + __m256i b5[5]) { + const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 16); + __m256i s3[2][3], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2], index_3[2], + sum_5[2], index_5[2], t[4]; + Square(s0, sq + 2); + sq[0] = 
_mm256_permute2x128_si256(sq[0], sq[2], 0x21); + sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21); + SumHorizontal16(src, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3], + &s5[1][3]); + SumHorizontal32(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned32x2U16(sum3, x, s3[0]); + LoadAligned64x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0], &index_3[0]); + LoadAligned32x3U16(sum5, x, s5[0]); + s5[0][4] = s5[0][3]; + LoadAligned64x3U32(square_sum5, x, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); + + const __m256i s1 = LoadUnaligned32Msan(src + 24, over_read_in_bytes + 48); + Square(s1, sq + 6); + sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21); + sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21); + SumHorizontal32(sq + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[1], &index_3[1]); + CalculateIntermediate<9>(sum_3, index_3, ma3, t, t + 2); + PermuteB(t, b3); + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + s5[1][4] = s5[1][3]; + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]); + CalculateIntermediate<25>(sum_5, index_5, ma5, t, t + 2); + PermuteB(t, b5); +} + +inline void BoxSumFilterPreProcess5(const uint16_t* const src0, + const uint16_t* const src1, const int width, + const uint32_t scale, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* ma565, + uint32_t* b565) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1_128 - sizeof(*src0) * width; + __m128i s[2][2], ma0, sq_128[2][4], b0[2]; + __m256i mas[3], sq[2][8], bs[10]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq_128[0]); + Square(s[1][0], sq_128[1]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, b0); + sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]); + sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]); + sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]); + sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0[0], b0[0]); + bs[1] = SetrM128i(b0[1], b0[1]); + + int x = 0; + do { + __m256i ma5[3], ma[2], b[4]; + BoxFilterPreProcess5( + src0 + x + 8, src1 + x + 8, + kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width, + x + 8, scale, sum5, square_sum5, sq, mas, bs); + Prepare3_8(mas, ma5); + ma[0] = Sum565Lo(ma5); + ma[1] = Sum565Hi(ma5); + StoreAligned64_ma(ma565, ma); + Sum565(bs + 0, b + 0); + Sum565(bs + 3, b + 2); + StoreAligned64(b565, b + 0); + StoreAligned64(b565 + 16, b + 2); + sq[0][0] = sq[0][6]; + sq[0][1] = sq[0][7]; + sq[1][0] = sq[1][6]; + sq[1][1] = sq[1][7]; + mas[0] = mas[2]; + bs[0] = bs[5]; + bs[1] = bs[6]; + ma565 += 32; + b565 += 32; + x += 32; + } while (x < width); +} + +template <bool calculate444> +LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( + const uint16_t* const src, const int width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* 
const square_sum3[3], + const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343, + uint32_t* b444) { + const ptrdiff_t overread_in_bytes_128 = + kOverreadInBytesPass2_128 - sizeof(*src) * width; + __m128i s[2], ma0, sq_128[4], b0[2]; + __m256i mas[3], sq[8], bs[7]; + s[0] = LoadUnaligned16Msan(src + 0, overread_in_bytes_128 + 0); + s[1] = LoadUnaligned16Msan(src + 8, overread_in_bytes_128 + 16); + Square(s[0], sq_128); + BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq_128, &ma0, b0); + sq[0] = SetrM128i(sq_128[2], sq_128[2]); + sq[1] = SetrM128i(sq_128[3], sq_128[3]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0[0], b0[0]); + bs[1] = SetrM128i(b0[1], b0[1]); + + int x = 0; + do { + __m256i ma3[3]; + BoxFilterPreProcess3( + src + x + 8, kOverreadInBytesPass2_256 + sizeof(*src) * (x + 8 - width), + x + 8, sum_width, scale, sum3, square_sum3, sq, mas, bs); + Prepare3_8(mas, ma3); + if (calculate444) { // NOLINT(readability-simplify-boolean-expr) + Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444); + Store343_444Hi(ma3, bs + 3, kMaStoreOffset, ma343, ma444, b343, b444); + ma444 += 32; + b444 += 32; + } else { + __m256i ma[2], b[4]; + ma[0] = Sum343Lo(ma3); + ma[1] = Sum343Hi(ma3); + StoreAligned64_ma(ma343, ma); + Sum343(bs + 0, b + 0); + Sum343(bs + 3, b + 2); + StoreAligned64(b343 + 0, b + 0); + StoreAligned64(b343 + 16, b + 2); + } + sq[0] = sq[6]; + sq[1] = sq[7]; + mas[0] = mas[2]; + bs[0] = bs[5]; + bs[1] = bs[6]; + ma343 += 32; + b343 += 32; + x += 32; + } while (x < width); +} + +inline void BoxSumFilterPreProcess( + const uint16_t* const src0, const uint16_t* const src1, const int width, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444, + uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444, + uint32_t* b565) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1_128 - sizeof(*src0) * width; + __m128i s[2][4], ma3_128[2][3], ma5_128[3], sq_128[2][8], b3_128[2][10], + b5_128[10]; + __m256i ma3[2][3], ma5[3], sq[2][8], b3[2][7], b5[7]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq_128[0]); + Square(s[1][0], sq_128[1]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128, + ma3_128, b3_128, &ma5_128[0], b5_128); + sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]); + sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]); + sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]); + sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]); + ma3[0][0] = SetrM128i(ma3_128[0][0], ma3_128[0][0]); + ma3[1][0] = SetrM128i(ma3_128[1][0], ma3_128[1][0]); + ma5[0] = SetrM128i(ma5_128[0], ma5_128[0]); + b3[0][0] = SetrM128i(b3_128[0][0], b3_128[0][0]); + b3[0][1] = SetrM128i(b3_128[0][1], b3_128[0][1]); + b3[1][0] = SetrM128i(b3_128[1][0], b3_128[1][0]); + b3[1][1] = SetrM128i(b3_128[1][1], b3_128[1][1]); + b5[0] = SetrM128i(b5_128[0], b5_128[0]); + b5[1] = SetrM128i(b5_128[1], b5_128[1]); + + int x = 0; + do { + __m256i ma[2], b[4], ma3x[3], ma5x[3]; + BoxFilterPreProcess( + src0 + x + 8, src1 + x + 8, + kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), x + 8, + scales, sum3, sum5, square_sum3, square_sum5, 
sum_width, sq, ma3, b3, + ma5, b5); + Prepare3_8(ma3[0], ma3x); + ma[0] = Sum343Lo(ma3x); + ma[1] = Sum343Hi(ma3x); + StoreAligned64_ma(ma343[0] + x, ma); + Sum343(b3[0], b); + Sum343(b3[0] + 3, b + 2); + StoreAligned64(b343[0] + x, b); + StoreAligned64(b343[0] + x + 16, b + 2); + Prepare3_8(ma3[1], ma3x); + Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444); + Store343_444Hi(ma3x, b3[1] + 3, x + kMaStoreOffset, ma343[1], ma444, + b343[1], b444); + Prepare3_8(ma5, ma5x); + ma[0] = Sum565Lo(ma5x); + ma[1] = Sum565Hi(ma5x); + StoreAligned64_ma(ma565, ma); + Sum565(b5, b); + StoreAligned64(b565, b); + Sum565(b5 + 3, b); + StoreAligned64(b565 + 16, b); + sq[0][0] = sq[0][6]; + sq[0][1] = sq[0][7]; + sq[1][0] = sq[1][6]; + sq[1][1] = sq[1][7]; + ma3[0][0] = ma3[0][2]; + ma3[1][0] = ma3[1][2]; + ma5[0] = ma5[2]; + b3[0][0] = b3[0][5]; + b3[0][1] = b3[0][6]; + b3[1][0] = b3[1][5]; + b3[1][1] = b3[1][6]; + b5[0] = b5[5]; + b5[1] = b5[6]; + ma565 += 32; + b565 += 32; + x += 32; + } while (x < width); +} + +template <int shift> +inline __m256i FilterOutput(const __m256i ma_x_src, const __m256i b) { + // ma: 255 * 32 = 8160 (13 bits) + // b: 65088 * 32 = 2082816 (21 bits) + // v: b - ma * 255 (22 bits) + const __m256i v = _mm256_sub_epi32(b, ma_x_src); + // kSgrProjSgrBits = 8 + // kSgrProjRestoreBits = 4 + // shift = 4 or 5 + // v >> 8 or 9 (13 bits) + return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits); +} + +template <int shift> +inline __m256i CalculateFilteredOutput(const __m256i src, const __m256i ma, + const __m256i b[2]) { + const __m256i ma_x_src_lo = VmullLo16(ma, src); + const __m256i ma_x_src_hi = VmullHi16(ma, src); + const __m256i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]); + const __m256i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]); + return _mm256_packs_epi32(dst_lo, dst_hi); // 13 bits +} + +inline __m256i CalculateFilteredOutputPass1(const __m256i src, + const __m256i ma[2], + const __m256i b[2][2]) { + const __m256i ma_sum = _mm256_add_epi16(ma[0], ma[1]); + __m256i b_sum[2]; + b_sum[0] = _mm256_add_epi32(b[0][0], b[1][0]); + b_sum[1] = _mm256_add_epi32(b[0][1], b[1][1]); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m256i CalculateFilteredOutputPass2(const __m256i src, + const __m256i ma[3], + const __m256i b[3][2]) { + const __m256i ma_sum = Sum3_16(ma); + __m256i b_sum[2]; + Sum3_32(b, b_sum); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m256i SelfGuidedFinal(const __m256i src, const __m256i v[2]) { + const __m256i v_lo = + VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m256i v_hi = + VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m256i vv = _mm256_packs_epi32(v_lo, v_hi); + return _mm256_add_epi16(src, vv); +} + +inline __m256i SelfGuidedDoubleMultiplier(const __m256i src, + const __m256i filter[2], const int w0, + const int w2) { + __m256i v[2]; + const __m256i w0_w2 = + _mm256_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0)); + const __m256i f_lo = _mm256_unpacklo_epi16(filter[0], filter[1]); + const __m256i f_hi = _mm256_unpackhi_epi16(filter[0], filter[1]); + v[0] = _mm256_madd_epi16(w0_w2, f_lo); + v[1] = _mm256_madd_epi16(w0_w2, f_hi); + return SelfGuidedFinal(src, v); +} + +inline __m256i SelfGuidedSingleMultiplier(const __m256i src, + const __m256i filter, const int w0) { + // weight: -96 to 96 (Sgrproj_Xqd_Min/Max) + __m256i v[2]; + v[0] = VmullNLo8(filter, w0); + v[1] = VmullNHi8(filter, w0); + return SelfGuidedFinal(src, v); +} + 
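For readers following the bit-range comments above: per output sample, the two-pass path is "weight the two filter outputs, rounding-shift, add to the source, clip to 10 bits". A scalar sketch, with hypothetical names (RoundShift, SelfGuidedSample), assuming kSgrProjRestoreBits == 4 and kSgrProjPrecisionBits == 7 as in the rest of the loop-restoration code; filter_pass1/filter_pass2 stand for the 13-bit signed values produced by CalculateFilteredOutputPass1()/CalculateFilteredOutputPass2():

#include <algorithm>
#include <cstdint>

inline int RoundShift(int value, int bits) {
  return (value + (1 << (bits - 1))) >> bits;
}

inline uint16_t SelfGuidedSample(uint16_t src, int filter_pass1,
                                 int filter_pass2, int w0, int w2) {
  // SelfGuidedDoubleMultiplier(): w0 and w2 are applied in one
  // _mm256_madd_epi16 on the interleaved filter outputs.
  const int v = w0 * filter_pass1 + w2 * filter_pass2;
  // SelfGuidedFinal(): rounding shift by kSgrProjRestoreBits +
  // kSgrProjPrecisionBits (assumed 4 + 7), then add to the source sample.
  const int d = src + RoundShift(v, 4 + 7);
  // ClipAndStore() (defined just below): clamp to the 10-bit range.
  return static_cast<uint16_t>(std::min(std::max(d, 0), 1023));
}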
+inline void ClipAndStore(uint16_t* const dst, const __m256i val) { + const __m256i val0 = _mm256_max_epi16(val, _mm256_setzero_si256()); + const __m256i val1 = _mm256_min_epi16(val0, _mm256_set1_epi16(1023)); + StoreUnaligned32(dst, val1); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( + const uint16_t* const src, const uint16_t* const src0, + const uint16_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width, + const uint32_t scale, const int16_t w0, uint16_t* const ma565[2], + uint32_t* const b565[2], uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1_128 - sizeof(*src0) * width; + __m128i s[2][2], ma0, sq_128[2][4], b0[2]; + __m256i mas[3], sq[2][8], bs[7]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq_128[0]); + Square(s[1][0], sq_128[1]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, b0); + sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]); + sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]); + sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]); + sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0[0], b0[0]); + bs[1] = SetrM128i(b0[1], b0[1]); + + int x = 0; + do { + __m256i ma5[3], ma[4], b[4][2]; + BoxFilterPreProcess5( + src0 + x + 8, src1 + x + 8, + kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width, + x + 8, scale, sum5, square_sum5, sq, mas, bs); + Prepare3_8(mas, ma5); + ma[2] = Sum565Lo(ma5); + ma[3] = Sum565Hi(ma5); + ma[1] = _mm256_permute2x128_si256(ma[2], ma[3], 0x20); + ma[3] = _mm256_permute2x128_si256(ma[2], ma[3], 0x31); + StoreAligned32(ma565[1] + x + 0, ma[1]); + StoreAligned32(ma565[1] + x + 16, ma[3]); + Sum565(bs + 0, b[1]); + Sum565(bs + 3, b[3]); + StoreAligned64(b565[1] + x, b[1]); + StoreAligned64(b565[1] + x + 16, b[3]); + const __m256i sr0_lo = LoadUnaligned32(src + x + 0); + ma[0] = LoadAligned32(ma565[0] + x); + LoadAligned64(b565[0] + x, b[0]); + const __m256i p0 = CalculateFilteredOutputPass1(sr0_lo, ma, b); + const __m256i d0 = SelfGuidedSingleMultiplier(sr0_lo, p0, w0); + ClipAndStore(dst + x + 0, d0); + const __m256i sr0_hi = LoadUnaligned32(src + x + 16); + ma[2] = LoadAligned32(ma565[0] + x + 16); + LoadAligned64(b565[0] + x + 16, b[2]); + const __m256i p1 = CalculateFilteredOutputPass1(sr0_hi, ma + 2, b + 2); + const __m256i d1 = SelfGuidedSingleMultiplier(sr0_hi, p1, w0); + ClipAndStore(dst + x + 16, d1); + const __m256i sr1_lo = LoadUnaligned32(src + stride + x + 0); + const __m256i p10 = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[1]); + const __m256i d10 = SelfGuidedSingleMultiplier(sr1_lo, p10, w0); + ClipAndStore(dst + stride + x + 0, d10); + const __m256i sr1_hi = LoadUnaligned32(src + stride + x + 16); + const __m256i p11 = CalculateFilteredOutput<4>(sr1_hi, ma[3], b[3]); + const __m256i d11 = SelfGuidedSingleMultiplier(sr1_hi, p11, w0); + ClipAndStore(dst + stride + x + 16, d11); + sq[0][0] = sq[0][6]; + sq[0][1] = sq[0][7]; + sq[1][0] = sq[1][6]; + sq[1][1] = sq[1][7]; + mas[0] = mas[2]; + bs[0] = bs[5]; + bs[1] = bs[6]; + x += 32; + } while (x < width); +} + +inline void BoxFilterPass1LastRow( + const uint16_t* const src, const uint16_t* const src0, const int width, + const ptrdiff_t 
sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565, + uint32_t* b565, uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1_128 - sizeof(*src0) * width; + __m128i s[2], ma0[2], sq_128[8], b0[6]; + __m256i mas[3], sq[8], bs[7]; + s[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + Square(s[0], sq_128); + BoxFilterPreProcess5LastRowLo(s, scale, sum5, square_sum5, sq_128, &ma0[0], + b0); + sq[0] = SetrM128i(sq_128[2], sq_128[2]); + sq[1] = SetrM128i(sq_128[3], sq_128[3]); + mas[0] = SetrM128i(ma0[0], ma0[0]); + bs[0] = SetrM128i(b0[0], b0[0]); + bs[1] = SetrM128i(b0[1], b0[1]); + + int x = 0; + do { + __m256i ma5[3], ma[4], b[4][2]; + BoxFilterPreProcess5LastRow( + src0 + x + 8, + kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width, + x + 8, scale, sum5, square_sum5, sq, mas, bs); + Prepare3_8(mas, ma5); + ma[2] = Sum565Lo(ma5); + ma[3] = Sum565Hi(ma5); + Sum565(bs + 0, b[1]); + Sum565(bs + 3, b[3]); + const __m256i sr0_lo = LoadUnaligned32(src + x + 0); + ma[0] = LoadAligned32(ma565 + x); + ma[1] = _mm256_permute2x128_si256(ma[2], ma[3], 0x20); + LoadAligned64(b565 + x, b[0]); + const __m256i p0 = CalculateFilteredOutputPass1(sr0_lo, ma, b); + const __m256i d0 = SelfGuidedSingleMultiplier(sr0_lo, p0, w0); + ClipAndStore(dst + x + 0, d0); + const __m256i sr0_hi = LoadUnaligned32(src + x + 16); + ma[0] = LoadAligned32(ma565 + x + 16); + ma[1] = _mm256_permute2x128_si256(ma[2], ma[3], 0x31); + LoadAligned64(b565 + x + 16, b[2]); + const __m256i p1 = CalculateFilteredOutputPass1(sr0_hi, ma, b + 2); + const __m256i d1 = SelfGuidedSingleMultiplier(sr0_hi, p1, w0); + ClipAndStore(dst + x + 16, d1); + sq[0] = sq[6]; + sq[1] = sq[7]; + mas[0] = mas[2]; + bs[0] = bs[5]; + bs[1] = bs[6]; + x += 32; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass2( + const uint16_t* const src, const uint16_t* const src0, const int width, + const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3], + uint32_t* const b444[2], uint16_t* const dst) { + const ptrdiff_t overread_in_bytes_128 = + kOverreadInBytesPass2_128 - sizeof(*src0) * width; + __m128i s0[2], ma0, sq_128[4], b0[2]; + __m256i mas[3], sq[8], bs[7]; + s0[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes_128 + 0); + s0[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes_128 + 16); + Square(s0[0], sq_128); + BoxFilterPreProcess3Lo(s0, scale, sum3, square_sum3, sq_128, &ma0, b0); + sq[0] = SetrM128i(sq_128[2], sq_128[2]); + sq[1] = SetrM128i(sq_128[3], sq_128[3]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0[0], b0[0]); + bs[1] = SetrM128i(b0[1], b0[1]); + + int x = 0; + do { + __m256i ma[4], b[4][2], ma3[3]; + BoxFilterPreProcess3( + src0 + x + 8, + kOverreadInBytesPass2_256 + sizeof(*src0) * (x + 8 - width), x + 8, + sum_width, scale, sum3, square_sum3, sq, mas, bs); + Prepare3_8(mas, ma3); + Store343_444(ma3, bs, x, &ma[2], &ma[3], b[2], b[3], ma343[2], ma444[1], + b343[2], b444[1]); + const __m256i sr_lo = LoadUnaligned32(src + x + 0); + const __m256i sr_hi = LoadUnaligned32(src + x + 16); + ma[0] = LoadAligned32(ma343[0] + x); + ma[1] = LoadAligned32(ma444[0] + x); + LoadAligned64(b343[0] + x, b[0]); + LoadAligned64(b444[0] + x, b[1]); + const __m256i p0 = 
CalculateFilteredOutputPass2(sr_lo, ma, b); + ma[1] = LoadAligned32(ma343[0] + x + 16); + ma[2] = LoadAligned32(ma444[0] + x + 16); + LoadAligned64(b343[0] + x + 16, b[1]); + LoadAligned64(b444[0] + x + 16, b[2]); + const __m256i p1 = CalculateFilteredOutputPass2(sr_hi, ma + 1, b + 1); + const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0); + const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0); + ClipAndStore(dst + x + 0, d0); + ClipAndStore(dst + x + 16, d1); + sq[0] = sq[6]; + sq[1] = sq[7]; + mas[0] = mas[2]; + bs[0] = bs[5]; + bs[1] = bs[6]; + x += 32; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilter( + const uint16_t* const src, const uint16_t* const src0, + const uint16_t* const src1, const ptrdiff_t stride, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], + uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4], + uint32_t* const b444[3], uint32_t* const b565[2], uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1_128 - sizeof(*src0) * width; + __m128i s[2][4], ma3_128[2][3], ma5_0, sq_128[2][8], b3_128[2][10], b5_128[2]; + __m256i ma3[2][3], ma5[3], sq[2][8], b3[2][7], b5[7]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq_128[0]); + Square(s[1][0], sq_128[1]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128, + ma3_128, b3_128, &ma5_0, b5_128); + sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]); + sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]); + sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]); + sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]); + ma3[0][0] = SetrM128i(ma3_128[0][0], ma3_128[0][0]); + ma3[1][0] = SetrM128i(ma3_128[1][0], ma3_128[1][0]); + ma5[0] = SetrM128i(ma5_0, ma5_0); + b3[0][0] = SetrM128i(b3_128[0][0], b3_128[0][0]); + b3[0][1] = SetrM128i(b3_128[0][1], b3_128[0][1]); + b3[1][0] = SetrM128i(b3_128[1][0], b3_128[1][0]); + b3[1][1] = SetrM128i(b3_128[1][1], b3_128[1][1]); + b5[0] = SetrM128i(b5_128[0], b5_128[0]); + b5[1] = SetrM128i(b5_128[1], b5_128[1]); + + int x = 0; + do { + __m256i ma[3][4], mat[3][3], b[3][3][2], bt[3][3][2], p[2][2], ma3x[2][3], + ma5x[3]; + BoxFilterPreProcess( + src0 + x + 8, src1 + x + 8, + kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), x + 8, + scales, sum3, sum5, square_sum3, square_sum5, sum_width, sq, ma3, b3, + ma5, b5); + Prepare3_8(ma3[0], ma3x[0]); + Prepare3_8(ma3[1], ma3x[1]); + Prepare3_8(ma5, ma5x); + Store343_444(ma3x[0], b3[0], x, &ma[1][2], &mat[1][2], &ma[2][1], + &mat[2][1], b[1][2], bt[1][2], b[2][1], bt[2][1], ma343[2], + ma444[1], b343[2], b444[1]); + Store343_444(ma3x[1], b3[1], x, &ma[2][2], &mat[2][2], b[2][2], bt[2][2], + ma343[3], ma444[2], b343[3], b444[2]); + + ma[0][2] = Sum565Lo(ma5x); + ma[0][3] = Sum565Hi(ma5x); + ma[0][1] = _mm256_permute2x128_si256(ma[0][2], ma[0][3], 0x20); + ma[0][3] = _mm256_permute2x128_si256(ma[0][2], ma[0][3], 0x31); + StoreAligned32(ma565[1] + x + 0, ma[0][1]); + StoreAligned32(ma565[1] + x + 16, ma[0][3]); + Sum565(b5, b[0][1]); + StoreAligned64(b565[1] + x, b[0][1]); + const __m256i sr0_lo = 
LoadUnaligned32(src + x); + const __m256i sr1_lo = LoadUnaligned32(src + stride + x); + ma[0][0] = LoadAligned32(ma565[0] + x); + LoadAligned64(b565[0] + x, b[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]); + ma[1][0] = LoadAligned32(ma343[0] + x); + ma[1][1] = LoadAligned32(ma444[0] + x); + // Keeping the following 4 redundant lines is faster. The reason is that + // there are not enough registers available, and these values could be saved + // and loaded which is even slower. + ma[1][2] = LoadAligned32(ma343[2] + x); // Redundant line 1. + LoadAligned64(b343[0] + x, b[1][0]); + LoadAligned64(b444[0] + x, b[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]); + ma[2][0] = LoadAligned32(ma343[1] + x); + ma[2][1] = LoadAligned32(ma444[1] + x); // Redundant line 2. + LoadAligned64(b343[1] + x, b[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]); + const __m256i d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2); + ClipAndStore(dst + x, d00); + const __m256i d10x = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2); + ClipAndStore(dst + stride + x, d10x); + + Sum565(b5 + 3, bt[0][1]); + StoreAligned64(b565[1] + x + 16, bt[0][1]); + const __m256i sr0_hi = LoadUnaligned32(src + x + 16); + const __m256i sr1_hi = LoadUnaligned32(src + stride + x + 16); + ma[0][2] = LoadAligned32(ma565[0] + x + 16); + LoadAligned64(b565[0] + x + 16, bt[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_hi, ma[0] + 2, bt[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_hi, ma[0][3], bt[0][1]); + mat[1][0] = LoadAligned32(ma343[0] + x + 16); + mat[1][1] = LoadAligned32(ma444[0] + x + 16); + mat[1][2] = LoadAligned32(ma343[2] + x + 16); // Redundant line 3. + LoadAligned64(b343[0] + x + 16, bt[1][0]); + LoadAligned64(b444[0] + x + 16, bt[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_hi, mat[1], bt[1]); + mat[2][0] = LoadAligned32(ma343[1] + x + 16); + mat[2][1] = LoadAligned32(ma444[1] + x + 16); // Redundant line 4. 
+ LoadAligned64(b343[1] + x + 16, bt[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_hi, mat[2], bt[2]); + const __m256i d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2); + ClipAndStore(dst + x + 16, d01); + const __m256i d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2); + ClipAndStore(dst + stride + x + 16, d11); + + sq[0][0] = sq[0][6]; + sq[0][1] = sq[0][7]; + sq[1][0] = sq[1][6]; + sq[1][1] = sq[1][7]; + ma3[0][0] = ma3[0][2]; + ma3[1][0] = ma3[1][2]; + ma5[0] = ma5[2]; + b3[0][0] = b3[0][5]; + b3[0][1] = b3[0][6]; + b3[1][0] = b3[1][5]; + b3[1][1] = b3[1][6]; + b5[0] = b5[5]; + b5[1] = b5[6]; + x += 32; + } while (x < width); +} + +inline void BoxFilterLastRow( + const uint16_t* const src, const uint16_t* const src0, const int width, + const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0, + const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565, + uint32_t* const b343, uint32_t* const b444, uint32_t* const b565, + uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1_128 - sizeof(*src0) * width; + __m128i s[2], ma3_0, ma5_0, sq_128[4], b3_128[2], b5_128[2]; + __m256i ma3[3], ma5[3], sq[8], b3[7], b5[7]; + s[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + Square(s[0], sq_128); + BoxFilterPreProcessLastRowLo(s, scales, sum3, sum5, square_sum3, square_sum5, + sq_128, &ma3_0, &ma5_0, b3_128, b5_128); + sq[0] = SetrM128i(sq_128[2], sq_128[2]); + sq[1] = SetrM128i(sq_128[3], sq_128[3]); + ma3[0] = SetrM128i(ma3_0, ma3_0); + ma5[0] = SetrM128i(ma5_0, ma5_0); + b3[0] = SetrM128i(b3_128[0], b3_128[0]); + b3[1] = SetrM128i(b3_128[1], b3_128[1]); + b5[0] = SetrM128i(b5_128[0], b5_128[0]); + b5[1] = SetrM128i(b5_128[1], b5_128[1]); + + int x = 0; + do { + __m256i ma[4], mat[4], b[3][2], bt[3][2], ma3x[3], ma5x[3], p[2]; + BoxFilterPreProcessLastRow( + src0 + x + 8, + kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width, + x + 8, scales, sum3, sum5, square_sum3, square_sum5, sq, ma3, ma5, b3, + b5); + Prepare3_8(ma3, ma3x); + Prepare3_8(ma5, ma5x); + ma[2] = Sum565Lo(ma5x); + Sum565(b5, b[1]); + mat[1] = Sum565Hi(ma5x); + Sum565(b5 + 3, bt[1]); + ma[3] = Sum343Lo(ma3x); + Sum343(b3, b[2]); + mat[2] = Sum343Hi(ma3x); + Sum343(b3 + 3, bt[2]); + + const __m256i sr_lo = LoadUnaligned32(src + x); + ma[0] = LoadAligned32(ma565 + x); + ma[1] = _mm256_permute2x128_si256(ma[2], mat[1], 0x20); + mat[1] = _mm256_permute2x128_si256(ma[2], mat[1], 0x31); + LoadAligned64(b565 + x, b[0]); + p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b); + ma[0] = LoadAligned32(ma343 + x); + ma[1] = LoadAligned32(ma444 + x); + ma[2] = _mm256_permute2x128_si256(ma[3], mat[2], 0x20); + LoadAligned64(b343 + x, b[0]); + LoadAligned64(b444 + x, b[1]); + p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b); + const __m256i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2); + + const __m256i sr_hi = LoadUnaligned32(src + x + 16); + mat[0] = LoadAligned32(ma565 + x + 16); + LoadAligned64(b565 + x + 16, bt[0]); + p[0] = CalculateFilteredOutputPass1(sr_hi, mat, bt); + mat[0] = LoadAligned32(ma343 + x + 16); + mat[1] = LoadAligned32(ma444 + x + 16); + mat[2] = _mm256_permute2x128_si256(ma[3], mat[2], 0x31); + LoadAligned64(b343 + x + 16, bt[0]); + LoadAligned64(b444 + x + 16, bt[1]); + p[1] = CalculateFilteredOutputPass2(sr_hi, mat, 
bt); + const __m256i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2); + ClipAndStore(dst + x + 0, d0); + ClipAndStore(dst + x + 16, d1); + + sq[0] = sq[6]; + sq[1] = sq[7]; + ma3[0] = ma3[2]; + ma5[0] = ma5[2]; + b3[0] = b3[5]; + b3[1] = b3[6]; + b5[0] = b5[5]; + b5[1] = b5[6]; + x += 32; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( + const RestorationUnitInfo& restoration_info, const uint16_t* src, + const ptrdiff_t stride, const uint16_t* const top_border, + const ptrdiff_t top_border_stride, const uint16_t* bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + SgrBuffer* const sgr_buffer, uint16_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 32); + const auto sum_width = temp_stride + 8; + const auto sum_stride = temp_stride + 32; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1; + uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2]; + uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2]; + sum3[0] = sgr_buffer->sum3 + kSumOffset; + square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 3; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma444[0] = sgr_buffer->ma444; + b444[0] = sgr_buffer->b444; + for (int i = 1; i <= 2; ++i) { + ma444[i] = ma444[i - 1] + temp_stride; + b444[i] = b444[i - 1] + temp_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scales[0] != 0); + assert(scales[1] != 0); + BoxSum(top_border, top_border_stride, width, sum_stride, temp_stride, sum3[0], + sum5[1], square_sum3[0], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint16_t* const s = (height > 1) ? 
src + stride : bottom_border; + BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, + square_sum5, sum_width, ma343, ma444[0], ma565[0], + b343, b444[0], b565[0]); + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width, + scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width, + ma343, ma444, ma565, b343, b444, b565, dst); + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint16_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + bottom_border_stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5, + square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343, + b444, b565, dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width, + sum_width, scales, w0, w2, sum3, sum5, square_sum3, + square_sum5, ma343[0], ma444[0], ma565[0], b343[0], + b444[0], b565[0], dst); + } +} + +inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, + const uint16_t* src, const ptrdiff_t stride, + const uint16_t* const top_border, + const ptrdiff_t top_border_stride, + const uint16_t* bottom_border, + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint16_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 32); + const auto sum_width = temp_stride + 8; + const auto sum_stride = temp_stride + 32; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. 
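Circulate4PointersBy2()/Circulate5PointersBy2(), used by the row loop of BoxFilterProcess() above and by the Pass1/Pass2 drivers below, appear to be shared loop-restoration helpers defined outside this diff. A plausible sketch of the 5-pointer rotation, not the library's definition:

// Rotate the five row pointers left by two, so the buffers that held the two
// oldest rows are reused for the next two rows (sketch only).
template <typename T>
inline void Circulate5PointersBy2Sketch(T* p[5]) {
  T* const p0 = p[0];
  T* const p1 = p[1];
  p[0] = p[2];
  p[1] = p[3];
  p[2] = p[4];
  p[3] = p0;
  p[4] = p1;
}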
+ const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + uint16_t *sum5[5], *ma565[2]; + uint32_t *square_sum5[5], *b565[2]; + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scale != 0); + BoxSum<5>(top_border, top_border_stride, width, sum_stride, temp_stride, + sum5[1], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint16_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width, + ma565[0], b565[0]); + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5, + square_sum5, width, sum_width, scale, w0, ma565, b565, dst); + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint16_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + bottom_border_stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width, + sum_width, scale, w0, ma565, b565, dst); + } + if ((height & 1) != 0) { + src += 3; + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + } + BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width, + sum_width, scale, w0, sum5, square_sum5, ma565[0], + b565[0], dst); + } +} + +inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, + const uint16_t* src, const ptrdiff_t stride, + const uint16_t* const top_border, + const ptrdiff_t top_border_stride, + const uint16_t* bottom_border, + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint16_t* dst) { + assert(restoration_info.sgr_proj_info.multiplier[0] == 0); + const auto temp_stride = Align<ptrdiff_t>(width, 32); + const auto sum_width = temp_stride + 8; + const auto sum_stride = temp_stride + 32; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12. 
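Assuming kSgrProjPrecisionBits == 7, as elsewhere in the loop-restoration code, unity weight is 128: the dual-pass path above uses w2 = 128 - w0 - w1, and the single weight used in this pass-2-only path, w0 = 128 - w1, is exactly that w2 evaluated with multiplier[0] == 0 (for example, w1 = 47 gives an applied weight of 81). It is applied through SelfGuidedSingleMultiplier() in BoxFilterPass2().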
+ uint16_t *sum3[3], *ma343[3], *ma444[2]; + uint32_t *square_sum3[3], *b343[3], *b444[2]; + sum3[0] = sgr_buffer->sum3 + kSumOffset; + square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 2; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + ma444[0] = sgr_buffer->ma444; + ma444[1] = ma444[0] + temp_stride; + b444[0] = sgr_buffer->b444; + b444[1] = b444[0] + temp_stride; + assert(scale != 0); + BoxSum<3>(top_border, top_border_stride, width, sum_stride, temp_stride, + sum3[0], square_sum3[0]); + BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, + sum_width, ma343[0], nullptr, b343[0], + nullptr); + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + const uint16_t* s; + if (height > 1) { + s = src + stride; + } else { + s = bottom_border; + bottom_border += bottom_border_stride; + } + BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width, + ma343[1], ma444[0], b343[1], b444[0]); + + for (int y = height - 2; y > 0; --y) { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } + + int y = std::min(height, 2); + src += 2; + do { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + bottom_border += bottom_border_stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } while (--y != 0); +} + +// If |width| is non-multiple of 32, up to 31 more pixels are written to |dest| +// in the end of each row. It is safe to overwrite the output as it will not be +// part of the visible frame. +void SelfGuidedFilter_AVX2( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { + const int index = restoration_info.sgr_proj_info.index; + const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 + const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0 + const auto* const src = static_cast<const uint16_t*>(source); + const auto* const top = static_cast<const uint16_t*>(top_border); + const auto* const bottom = static_cast<const uint16_t*>(bottom_border); + auto* const dst = static_cast<uint16_t*>(dest); + SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer; + if (radius_pass_1 == 0) { + // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the + // following assertion. 
+ assert(radius_pass_0 != 0); + BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, + width, height, sgr_buffer, dst); + } else if (radius_pass_0 == 0) { + BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2, + top_border_stride, bottom - 2, bottom_border_stride, + width, height, sgr_buffer, dst); + } else { + BoxFilterProcess(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, width, + height, sgr_buffer, dst); + } +} + void Init10bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); assert(dsp != nullptr); #if DSP_ENABLED_10BPP_AVX2(WienerFilter) dsp->loop_restorations[0] = WienerFilter_AVX2; #endif +#if DSP_ENABLED_10BPP_AVX2(SelfGuidedFilter) + dsp->loop_restorations[1] = SelfGuidedFilter_AVX2; +#endif } } // namespace @@ -581,7 +3146,7 @@ void LoopRestorationInit10bpp_AVX2() { Init10bpp(); } } // namespace dsp } // namespace libgav1 -#else // !(LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10) +#else // !(LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10) namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/loop_restoration_10bit_sse4.cc b/src/dsp/x86/loop_restoration_10bit_sse4.cc index 0598435..96380e3 100644 --- a/src/dsp/x86/loop_restoration_10bit_sse4.cc +++ b/src/dsp/x86/loop_restoration_10bit_sse4.cc @@ -428,13 +428,12 @@ inline void WienerVerticalTap1(const int16_t* wiener_buffer, } } -void WienerFilter_SSE4_1(const RestorationUnitInfo& restoration_info, - const void* const source, const void* const top_border, - const void* const bottom_border, - const ptrdiff_t stride, const int width, - const int height, - RestorationBuffer* const restoration_buffer, - void* const dest) { +void WienerFilter_SSE4_1( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { const int16_t* const number_leading_zero_coefficients = restoration_info.wiener_info.number_leading_zero_coefficients; const int number_rows_to_skip = std::max( @@ -458,39 +457,42 @@ void WienerFilter_SSE4_1(const RestorationUnitInfo& restoration_info, const __m128i coefficients_horizontal = LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]); if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { - WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, - wiener_stride, height_extra, coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); - } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { - WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, - wiener_stride, height_extra, coefficients_horizontal, + WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(src - 2, stride, 
wiener_stride, height, + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { // The maximum over-reads happen here. - WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, - wiener_stride, height_extra, coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); } else { assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); - WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, - wiener_stride, height_extra, + WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride, + top_border_stride, wiener_stride, height_extra, &wiener_buffer_horizontal); WienerHorizontalTap1(src, stride, wiener_stride, height, &wiener_buffer_horizontal); - WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, - &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride, + height_extra, &wiener_buffer_horizontal); } // vertical filtering. @@ -522,6 +524,1978 @@ void WienerFilter_SSE4_1(const RestorationUnitInfo& restoration_info, } } +//------------------------------------------------------------------------------ +// SGR + +// SIMD overreads 8 - (width % 8) - 2 * padding pixels, where padding is 3 for +// Pass 1 and 2 for Pass 2. 
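Plugging the worst case (width a multiple of 8) into the formula above: Pass 1 over-reads 8 - 0 - 2 * 3 = 2 pixels and Pass 2 over-reads 8 - 0 - 2 * 2 = 4 pixels; with 2-byte uint16_t pixels that is 4 and 8 bytes respectively, which appears to be what the two constants below encode.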
+constexpr int kOverreadInBytesPass1 = 4; +constexpr int kOverreadInBytesPass2 = 8; + +inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x, + __m128i dst[2]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); +} + +inline void LoadAligned16x2U16Msan(const uint16_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[2]) { + dst[0] = LoadAligned16Msan(src[0] + x, sizeof(**src) * (x + 8 - border)); + dst[1] = LoadAligned16Msan(src[1] + x, sizeof(**src) * (x + 8 - border)); +} + +inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x, + __m128i dst[3]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); + dst[2] = LoadAligned16(src[2] + x); +} + +inline void LoadAligned16x3U16Msan(const uint16_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[3]) { + dst[0] = LoadAligned16Msan(src[0] + x, sizeof(**src) * (x + 8 - border)); + dst[1] = LoadAligned16Msan(src[1] + x, sizeof(**src) * (x + 8 - border)); + dst[2] = LoadAligned16Msan(src[2] + x, sizeof(**src) * (x + 8 - border)); +} + +inline void LoadAligned32U32(const uint32_t* const src, __m128i dst[2]) { + dst[0] = LoadAligned16(src + 0); + dst[1] = LoadAligned16(src + 4); +} + +inline void LoadAligned32U32Msan(const uint32_t* const src, const ptrdiff_t x, + const ptrdiff_t border, __m128i dst[2]) { + dst[0] = LoadAligned16Msan(src + x + 0, sizeof(*src) * (x + 4 - border)); + dst[1] = LoadAligned16Msan(src + x + 4, sizeof(*src) * (x + 8 - border)); +} + +inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x, + __m128i dst[2][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); +} + +inline void LoadAligned32x2U32Msan(const uint32_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[2][2]) { + LoadAligned32U32Msan(src[0], x, border, dst[0]); + LoadAligned32U32Msan(src[1], x, border, dst[1]); +} + +inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x, + __m128i dst[3][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); + LoadAligned32U32(src[2] + x, dst[2]); +} + +inline void LoadAligned32x3U32Msan(const uint32_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[3][2]) { + LoadAligned32U32Msan(src[0], x, border, dst[0]); + LoadAligned32U32Msan(src[1], x, border, dst[1]); + LoadAligned32U32Msan(src[2], x, border, dst[2]); +} + +inline void StoreAligned32U16(uint16_t* const dst, const __m128i src[2]) { + StoreAligned16(dst + 0, src[0]); + StoreAligned16(dst + 8, src[1]); +} + +inline void StoreAligned32U32(uint32_t* const dst, const __m128i src[2]) { + StoreAligned16(dst + 0, src[0]); + StoreAligned16(dst + 4, src[1]); +} + +inline void StoreAligned64U32(uint32_t* const dst, const __m128i src[4]) { + StoreAligned32U32(dst + 0, src + 0); + StoreAligned32U32(dst + 8, src + 2); +} + +// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following +// functions. Some compilers may generate super inefficient code and the whole +// decoder could be 15% slower. 
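As a concrete illustration of the comment above, the two forms below are equivalent for the low half of a register; the functions that follow use the unpack-with-zero form. Hypothetical helper names, SSE4.1 assumed (as for the rest of this file):

#include <smmintrin.h>  // SSE4.1

// Zero-extend the low four uint16 lanes to uint32: the idiom used in the
// helpers below...
inline __m128i ZeroExtendLo16To32_Unpack(const __m128i src) {
  return _mm_unpacklo_epi16(src, _mm_setzero_si128());
}

// ...and the single-instruction form the comment advises against here.
inline __m128i ZeroExtendLo16To32_Convert(const __m128i src) {
  return _mm_cvtepu16_epi32(src);
}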
+ +inline __m128i VaddlLo8(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi8(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(s0, s1); +} + +inline __m128i VaddlHi8(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi8(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(s0, s1); +} + +inline __m128i VaddwLo8(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(src0, s1); +} + +inline __m128i VaddwHi8(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpackhi_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(src0, s1); +} + +inline __m128i VmullNLo8(const __m128i src0, const int src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + return _mm_madd_epi16(s0, _mm_set1_epi32(src1)); +} + +inline __m128i VmullNHi8(const __m128i src0, const int src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + return _mm_madd_epi16(s0, _mm_set1_epi32(src1)); +} + +inline __m128i VmullLo16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m128i VmullHi16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m128i VrshrU16(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi16(src0, _mm_set1_epi16(1 << (src1 - 1))); + return _mm_srli_epi16(sum, src1); +} + +inline __m128i VrshrS32(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1))); + return _mm_srai_epi32(sum, src1); +} + +inline __m128i VrshrU32(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1))); + return _mm_srli_epi32(sum, src1); +} + +inline void Square(const __m128i src, __m128i dst[2]) { + const __m128i s0 = _mm_unpacklo_epi16(src, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src, _mm_setzero_si128()); + dst[0] = _mm_madd_epi16(s0, s0); + dst[1] = _mm_madd_epi16(s1, s1); +} + +template <int offset> +inline void Prepare3_8(const __m128i src[2], __m128i dst[3]) { + dst[0] = _mm_alignr_epi8(src[1], src[0], offset + 0); + dst[1] = _mm_alignr_epi8(src[1], src[0], offset + 1); + dst[2] = _mm_alignr_epi8(src[1], src[0], offset + 2); +} + +inline void Prepare3_16(const __m128i src[2], __m128i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm_alignr_epi8(src[1], src[0], 2); + dst[2] = _mm_alignr_epi8(src[1], src[0], 4); +} + +inline void Prepare3_32(const __m128i src[2], __m128i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm_alignr_epi8(src[1], src[0], 4); + dst[2] = _mm_alignr_epi8(src[1], src[0], 8); +} + +inline void Prepare5_16(const __m128i src[2], __m128i dst[5]) { + Prepare3_16(src, dst); + dst[3] = _mm_alignr_epi8(src[1], src[0], 6); + dst[4] = _mm_alignr_epi8(src[1], src[0], 8); +} + +inline void Prepare5_32(const __m128i src[2], __m128i dst[5]) { + Prepare3_32(src, dst); + dst[3] = _mm_alignr_epi8(src[1], src[0], 12); + dst[4] = src[1]; +} + +inline __m128i Sum3_16(const __m128i src0, const __m128i src1, + const __m128i src2) { 
+ const __m128i sum = _mm_add_epi16(src0, src1); + return _mm_add_epi16(sum, src2); +} + +inline __m128i Sum3_16(const __m128i src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline __m128i Sum3_32(const __m128i src0, const __m128i src1, + const __m128i src2) { + const __m128i sum = _mm_add_epi32(src0, src1); + return _mm_add_epi32(sum, src2); +} + +inline __m128i Sum3_32(const __m128i src[3]) { + return Sum3_32(src[0], src[1], src[2]); +} + +inline void Sum3_32(const __m128i src[3][2], __m128i dst[2]) { + dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]); + dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]); +} + +inline __m128i Sum3WLo16(const __m128i src[3]) { + const __m128i sum = VaddlLo8(src[0], src[1]); + return VaddwLo8(sum, src[2]); +} + +inline __m128i Sum3WHi16(const __m128i src[3]) { + const __m128i sum = VaddlHi8(src[0], src[1]); + return VaddwHi8(sum, src[2]); +} + +inline __m128i Sum5_16(const __m128i src[5]) { + const __m128i sum01 = _mm_add_epi16(src[0], src[1]); + const __m128i sum23 = _mm_add_epi16(src[2], src[3]); + const __m128i sum = _mm_add_epi16(sum01, sum23); + return _mm_add_epi16(sum, src[4]); +} + +inline __m128i Sum5_32(const __m128i* const src0, const __m128i* const src1, + const __m128i* const src2, const __m128i* const src3, + const __m128i* const src4) { + const __m128i sum01 = _mm_add_epi32(*src0, *src1); + const __m128i sum23 = _mm_add_epi32(*src2, *src3); + const __m128i sum = _mm_add_epi32(sum01, sum23); + return _mm_add_epi32(sum, *src4); +} + +inline __m128i Sum5_32(const __m128i src[5]) { + return Sum5_32(&src[0], &src[1], &src[2], &src[3], &src[4]); +} + +inline void Sum5_32(const __m128i src[5][2], __m128i dst[2]) { + dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]); + dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]); +} + +inline __m128i Sum3Horizontal16(const __m128i src[2]) { + __m128i s[3]; + Prepare3_16(src, s); + return Sum3_16(s); +} + +inline void Sum3Horizontal32(const __m128i src[3], __m128i dst[2]) { + __m128i s[3]; + Prepare3_32(src + 0, s); + dst[0] = Sum3_32(s); + Prepare3_32(src + 1, s); + dst[1] = Sum3_32(s); +} + +inline __m128i Sum5Horizontal16(const __m128i src[2]) { + __m128i s[5]; + Prepare5_16(src, s); + return Sum5_16(s); +} + +inline void Sum5Horizontal32(const __m128i src[3], __m128i dst[2]) { + __m128i s[5]; + Prepare5_32(src + 0, s); + dst[0] = Sum5_32(s); + Prepare5_32(src + 1, s); + dst[1] = Sum5_32(s); +} + +void SumHorizontal16(const __m128i src[2], __m128i* const row3, + __m128i* const row5) { + __m128i s[5]; + Prepare5_16(src, s); + const __m128i sum04 = _mm_add_epi16(s[0], s[4]); + *row3 = Sum3_16(s + 1); + *row5 = _mm_add_epi16(sum04, *row3); +} + +inline void SumHorizontal16(const __m128i src[3], __m128i* const row3_0, + __m128i* const row3_1, __m128i* const row5_0, + __m128i* const row5_1) { + SumHorizontal16(src + 0, row3_0, row5_0); + SumHorizontal16(src + 1, row3_1, row5_1); +} + +void SumHorizontal32(const __m128i src[5], __m128i* const row_sq3, + __m128i* const row_sq5) { + const __m128i sum04 = _mm_add_epi32(src[0], src[4]); + *row_sq3 = Sum3_32(src + 1); + *row_sq5 = _mm_add_epi32(sum04, *row_sq3); +} + +inline void SumHorizontal32(const __m128i src[3], __m128i* const row_sq3_0, + __m128i* const row_sq3_1, __m128i* const row_sq5_0, + __m128i* const row_sq5_1) { + __m128i s[5]; + Prepare5_32(src + 0, s); + SumHorizontal32(s, row_sq3_0, row_sq5_0); + Prepare5_32(src + 1, s); + SumHorizontal32(s, row_sq3_1, row_sq5_1); +} + +inline __m128i 
Sum343Lo(const __m128i ma3[3]) { + const __m128i sum = Sum3WLo16(ma3); + const __m128i sum3 = Sum3_16(sum, sum, sum); + return VaddwLo8(sum3, ma3[1]); +} + +inline __m128i Sum343Hi(const __m128i ma3[3]) { + const __m128i sum = Sum3WHi16(ma3); + const __m128i sum3 = Sum3_16(sum, sum, sum); + return VaddwHi8(sum3, ma3[1]); +} + +inline __m128i Sum343(const __m128i src[3]) { + const __m128i sum = Sum3_32(src); + const __m128i sum3 = Sum3_32(sum, sum, sum); + return _mm_add_epi32(sum3, src[1]); +} + +inline void Sum343(const __m128i src[3], __m128i dst[2]) { + __m128i s[3]; + Prepare3_32(src + 0, s); + dst[0] = Sum343(s); + Prepare3_32(src + 1, s); + dst[1] = Sum343(s); +} + +inline __m128i Sum565Lo(const __m128i src[3]) { + const __m128i sum = Sum3WLo16(src); + const __m128i sum4 = _mm_slli_epi16(sum, 2); + const __m128i sum5 = _mm_add_epi16(sum4, sum); + return VaddwLo8(sum5, src[1]); +} + +inline __m128i Sum565Hi(const __m128i src[3]) { + const __m128i sum = Sum3WHi16(src); + const __m128i sum4 = _mm_slli_epi16(sum, 2); + const __m128i sum5 = _mm_add_epi16(sum4, sum); + return VaddwHi8(sum5, src[1]); +} + +inline __m128i Sum565(const __m128i src[3]) { + const __m128i sum = Sum3_32(src); + const __m128i sum4 = _mm_slli_epi32(sum, 2); + const __m128i sum5 = _mm_add_epi32(sum4, sum); + return _mm_add_epi32(sum5, src[1]); +} + +inline void Sum565(const __m128i src[3], __m128i dst[2]) { + __m128i s[3]; + Prepare3_32(src + 0, s); + dst[0] = Sum565(s); + Prepare3_32(src + 1, s); + dst[1] = Sum565(s); +} + +inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5, + uint32_t* square_sum3, uint32_t* square_sum5) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1 - sizeof(*src) * width; + int y = 2; + do { + __m128i s[3], sq[6]; + s[0] = LoadUnaligned16Msan(src, overread_in_bytes); + Square(s[0], sq); + ptrdiff_t x = sum_width; + do { + __m128i row3[2], row5[2], row_sq3[2], row_sq5[2]; + s[1] = LoadUnaligned16Msan( + src + 8, overread_in_bytes + sizeof(*src) * (sum_width - x + 8)); + x -= 16; + src += 16; + s[2] = LoadUnaligned16Msan( + src, overread_in_bytes + sizeof(*src) * (sum_width - x)); + Square(s[1], sq + 2); + Square(s[2], sq + 4); + SumHorizontal16(s, &row3[0], &row3[1], &row5[0], &row5[1]); + StoreAligned32U16(sum3, row3); + StoreAligned32U16(sum5, row5); + SumHorizontal32(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0], + &row_sq5[1]); + StoreAligned32U32(square_sum3 + 0, row_sq3); + StoreAligned32U32(square_sum5 + 0, row_sq5); + SumHorizontal32(sq + 2, &row_sq3[0], &row_sq3[1], &row_sq5[0], + &row_sq5[1]); + StoreAligned32U32(square_sum3 + 8, row_sq3); + StoreAligned32U32(square_sum5 + 8, row_sq5); + s[0] = s[2]; + sq[0] = sq[4]; + sq[1] = sq[5]; + sum3 += 16; + sum5 += 16; + square_sum3 += 16; + square_sum5 += 16; + } while (x != 0); + src += src_stride - sum_width; + sum3 += sum_stride - sum_width; + sum5 += sum_stride - sum_width; + square_sum3 += sum_stride - sum_width; + square_sum5 += sum_stride - sum_width; + } while (--y != 0); +} + +template <int size> +inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sums, + uint32_t* square_sums) { + static_assert(size == 3 || size == 5, ""); + const ptrdiff_t overread_in_bytes = + ((size == 5) ? 
kOverreadInBytesPass1 : kOverreadInBytesPass2) - + sizeof(*src) * width; + int y = 2; + do { + __m128i s[3], sq[6]; + s[0] = LoadUnaligned16Msan(src, overread_in_bytes); + Square(s[0], sq); + ptrdiff_t x = sum_width; + do { + __m128i row[2], row_sq[4]; + s[1] = LoadUnaligned16Msan( + src + 8, overread_in_bytes + sizeof(*src) * (sum_width - x + 8)); + x -= 16; + src += 16; + s[2] = LoadUnaligned16Msan( + src, overread_in_bytes + sizeof(*src) * (sum_width - x)); + Square(s[1], sq + 2); + Square(s[2], sq + 4); + if (size == 3) { + row[0] = Sum3Horizontal16(s + 0); + row[1] = Sum3Horizontal16(s + 1); + Sum3Horizontal32(sq + 0, row_sq + 0); + Sum3Horizontal32(sq + 2, row_sq + 2); + } else { + row[0] = Sum5Horizontal16(s + 0); + row[1] = Sum5Horizontal16(s + 1); + Sum5Horizontal32(sq + 0, row_sq + 0); + Sum5Horizontal32(sq + 2, row_sq + 2); + } + StoreAligned32U16(sums, row); + StoreAligned64U32(square_sums, row_sq); + s[0] = s[2]; + sq[0] = sq[4]; + sq[1] = sq[5]; + sums += 16; + square_sums += 16; + } while (x != 0); + src += src_stride - sum_width; + sums += sum_stride - sum_width; + square_sums += sum_stride - sum_width; + } while (--y != 0); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq, + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const __m128i dxd = _mm_madd_epi16(sum, sum); + // _mm_mullo_epi32() has high latency. Using shifts and additions instead. + // Some compilers could do this for us but we make this explicit. + // return _mm_mullo_epi32(sum_sq, _mm_set1_epi32(n)); + __m128i axn = _mm_add_epi32(sum_sq, _mm_slli_epi32(sum_sq, 3)); + if (n == 25) axn = _mm_add_epi32(axn, _mm_slli_epi32(sum_sq, 4)); + const __m128i sub = _mm_sub_epi32(axn, dxd); + const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128()); + const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(scale)); + return VrshrU32(pxs, kSgrProjScaleBits); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2], + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + const __m128i b = VrshrU16(sum, 2); + const __m128i sum_lo = _mm_unpacklo_epi16(b, _mm_setzero_si128()); + const __m128i sum_hi = _mm_unpackhi_epi16(b, _mm_setzero_si128()); + const __m128i z0 = CalculateMa<n>(sum_lo, VrshrU32(sum_sq[0], 4), scale); + const __m128i z1 = CalculateMa<n>(sum_hi, VrshrU32(sum_sq[1], 4), scale); + return _mm_packus_epi32(z0, z1); +} + +inline void CalculateB5(const __m128i sum, const __m128i ma, __m128i b[2]) { + // one_over_n == 164. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. + constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. + const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter)); + const __m128i m0 = VmullLo16(m, sum); + const __m128i m1 = VmullHi16(m, sum); + b[0] = VrshrU32(m0, kSgrProjReciprocalBits - 2); + b[1] = VrshrU32(m1, kSgrProjReciprocalBits - 2); +} + +inline void CalculateB3(const __m128i sum, const __m128i ma, __m128i b[2]) { + // one_over_n == 455. 
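+  // That is ((1 << 12) + 4) / 9 = 455, with kSgrProjReciprocalBits == 12.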
+  constexpr uint32_t one_over_n =
+      ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
+  const __m128i m0 = VmullLo16(ma, sum);
+  const __m128i m1 = VmullHi16(ma, sum);
+  const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n));
+  const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n));
+  b[0] = VrshrU32(m2, kSgrProjReciprocalBits);
+  b[1] = VrshrU32(m3, kSgrProjReciprocalBits);
+}
+
+inline void CalculateSumAndIndex5(const __m128i s5[5], const __m128i sq5[5][2],
+                                  const uint32_t scale, __m128i* const sum,
+                                  __m128i* const index) {
+  __m128i sum_sq[2];
+  *sum = Sum5_16(s5);
+  Sum5_32(sq5, sum_sq);
+  *index = CalculateMa<25>(*sum, sum_sq, scale);
+}
+
+inline void CalculateSumAndIndex3(const __m128i s3[3], const __m128i sq3[3][2],
+                                  const uint32_t scale, __m128i* const sum,
+                                  __m128i* const index) {
+  __m128i sum_sq[2];
+  *sum = Sum3_16(s3);
+  Sum3_32(sq3, sum_sq);
+  *index = CalculateMa<9>(*sum, sum_sq, scale);
+}
+
+template <int n, int offset>
+inline void LookupIntermediate(const __m128i sum, const __m128i index,
+                               __m128i* const ma, __m128i b[2]) {
+  static_assert(n == 9 || n == 25, "");
+  static_assert(offset == 0 || offset == 8, "");
+  const __m128i idx = _mm_packus_epi16(index, index);
+  // |temp| is not actually stored and reloaded; the compiler keeps it in a
+  // 64-bit general-purpose register, which is faster than using
+  // _mm_extract_epi8().
+  uint8_t temp[8];
+  StoreLo8(temp, idx);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[0]], offset + 0);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[1]], offset + 1);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[2]], offset + 2);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[3]], offset + 3);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[4]], offset + 4);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[5]], offset + 5);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[6]], offset + 6);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[7]], offset + 7);
+  // b = ma * b * one_over_n
+  // |ma| = [0, 255]
+  // |sum| is a box sum with radius 1 or 2.
+  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
+  // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
+  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
+  // When radius is 2 |n| is 25. |one_over_n| is 164.
+  // When radius is 1 |n| is 9. |one_over_n| is 455.
+  // |kSgrProjReciprocalBits| is 12.
+  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
+  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
+  __m128i maq;
+  if (offset == 0) {
+    maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128());
+  } else {
+    maq = _mm_unpackhi_epi8(*ma, _mm_setzero_si128());
+  }
+  if (n == 9) {
+    CalculateB3(sum, maq, b);
+  } else {
+    CalculateB5(sum, maq, b);
+  }
+}
+
+// Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b
+// to get value 0 as the shuffle result. The most significant bit 1 comes
+// either from the comparison instruction, or from the sign bit of the index.
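+// For example, a control byte of 20 compares greater than 15, so the
+// comparison yields 0xff, the OR keeps the sign bit set and the shuffle
+// writes 0 for that lane. After the caller subtracts 16, the same element
+// becomes 4, its sign bit is clear, and it selects byte 4 of the next
+// 16-entry chunk of the table.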
+inline __m128i ShuffleIndex(const __m128i table, const __m128i index) {
+  __m128i mask;
+  mask = _mm_cmpgt_epi8(index, _mm_set1_epi8(15));
+  mask = _mm_or_si128(mask, index);
+  return _mm_shuffle_epi8(table, mask);
+}
+
+inline __m128i AdjustValue(const __m128i value, const __m128i index,
+                           const int threshold) {
+  const __m128i thresholds = _mm_set1_epi8(threshold - 128);
+  const __m128i offset = _mm_cmpgt_epi8(index, thresholds);
+  return _mm_add_epi8(value, offset);
+}
+
+inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2],
+                                  __m128i* const ma, __m128i b0[2],
+                                  __m128i b1[2]) {
+  // Use table lookup to read elements whose indices are less than 48.
+  const __m128i c0 = LoadAligned16(kSgrMaLookup + 0 * 16);
+  const __m128i c1 = LoadAligned16(kSgrMaLookup + 1 * 16);
+  const __m128i c2 = LoadAligned16(kSgrMaLookup + 2 * 16);
+  const __m128i indices = _mm_packus_epi16(index[0], index[1]);
+  __m128i idx;
+  // Clip idx to 127 to apply signed comparison instructions.
+  idx = _mm_min_epu8(indices, _mm_set1_epi8(127));
+  // Elements whose indices are 48 or more come out of all the shuffles as 0.
+  // Get shuffle results for indices in range [0, 15].
+  *ma = ShuffleIndex(c0, idx);
+  // Get shuffle results for indices in range [16, 31].
+  // Subtract 16 to utilize the sign bit of the index.
+  idx = _mm_sub_epi8(idx, _mm_set1_epi8(16));
+  const __m128i res1 = ShuffleIndex(c1, idx);
+  // Use OR instruction to combine shuffle results together.
+  *ma = _mm_or_si128(*ma, res1);
+  // Get shuffle results for indices in range [32, 47].
+  // Subtract 16 to utilize the sign bit of the index.
+  idx = _mm_sub_epi8(idx, _mm_set1_epi8(16));
+  const __m128i res2 = ShuffleIndex(c2, idx);
+  *ma = _mm_or_si128(*ma, res2);
+
+  // For elements whose indices are larger than 47, the table values change
+  // only rarely as the index increases, so they are calculated with
+  // comparisons and additions instead.
+  // Add -128 to apply signed comparison instructions.
+  idx = _mm_add_epi8(indices, _mm_set1_epi8(-128));
+  // Elements whose indices are larger than 47 (with value 0) are set to 5.
+  *ma = _mm_max_epu8(*ma, _mm_set1_epi8(5));
+  *ma = AdjustValue(*ma, idx, 55);   // 55 is the last index whose value is 5.
+  *ma = AdjustValue(*ma, idx, 72);   // 72 is the last index whose value is 4.
+  *ma = AdjustValue(*ma, idx, 101);  // 101 is the last index whose value is 3.
+  *ma = AdjustValue(*ma, idx, 169);  // 169 is the last index whose value is 2.
+  *ma = AdjustValue(*ma, idx, 254);  // 254 is the last index whose value is 1.
+
+  // b = ma * b * one_over_n
+  // |ma| = [0, 255]
+  // |sum| is a box sum with radius 1 or 2.
+  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
+  // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
+  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
+  // When radius is 2 |n| is 25. |one_over_n| is 164.
+  // When radius is 1 |n| is 9. |one_over_n| is 455.
+  // |kSgrProjReciprocalBits| is 12.
+  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
+  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
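+  // The unpacks below widen |ma| from 8 to 16 bits per lane so that
+  // CalculateB3() can form the 32-bit products |ma| * |sum|.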
+ const __m128i maq0 = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); + CalculateB3(sum[0], maq0, b0); + const __m128i maq1 = _mm_unpackhi_epi8(*ma, _mm_setzero_si128()); + CalculateB3(sum[1], maq1, b1); +} + +inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], + __m128i ma[2], __m128i b[4]) { + __m128i mas; + CalculateIntermediate(sum, index, &mas, b + 0, b + 2); + ma[0] = _mm_unpacklo_epi64(ma[0], mas); + ma[1] = _mm_srli_si128(mas, 8); +} + +// Note: It has been tried to call CalculateIntermediate() to replace the slow +// LookupIntermediate() when calculating 16 intermediate data points. However, +// the compiler generates even slower code. +template <int offset> +inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2], + const uint32_t scale, __m128i* const ma, + __m128i b[2]) { + static_assert(offset == 0 || offset == 8, ""); + __m128i sum, index; + CalculateSumAndIndex5(s5, sq5, scale, &sum, &index); + LookupIntermediate<25, offset>(sum, index, ma, b); +} + +inline void CalculateIntermediate3(const __m128i s3[3], const __m128i sq3[3][2], + const uint32_t scale, __m128i* const ma, + __m128i b[2]) { + __m128i sum, index; + CalculateSumAndIndex3(s3, sq3, scale, &sum, &index); + LookupIntermediate<9, 0>(sum, index, ma, b); +} + +inline void Store343_444(const __m128i b3[3], const ptrdiff_t x, + __m128i sum_b343[2], __m128i sum_b444[2], + uint32_t* const b343, uint32_t* const b444) { + __m128i b[3], sum_b111[2]; + Prepare3_32(b3 + 0, b); + sum_b111[0] = Sum3_32(b); + sum_b444[0] = _mm_slli_epi32(sum_b111[0], 2); + sum_b343[0] = _mm_sub_epi32(sum_b444[0], sum_b111[0]); + sum_b343[0] = _mm_add_epi32(sum_b343[0], b[1]); + Prepare3_32(b3 + 1, b); + sum_b111[1] = Sum3_32(b); + sum_b444[1] = _mm_slli_epi32(sum_b111[1], 2); + sum_b343[1] = _mm_sub_epi32(sum_b444[1], sum_b111[1]); + sum_b343[1] = _mm_add_epi32(sum_b343[1], b[1]); + StoreAligned32U32(b444 + x, sum_b444); + StoreAligned32U32(b343 + x, sum_b343); +} + +inline void Store343_444Lo(const __m128i ma3[3], const __m128i b3[3], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i* const sum_ma444, __m128i sum_b343[2], + __m128i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m128i sum_ma111 = Sum3WLo16(ma3); + *sum_ma444 = _mm_slli_epi16(sum_ma111, 2); + StoreAligned16(ma444 + x, *sum_ma444); + const __m128i sum333 = _mm_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwLo8(sum333, ma3[1]); + StoreAligned16(ma343 + x, *sum_ma343); + Store343_444(b3, x, sum_b343, sum_b444, b343, b444); +} + +inline void Store343_444Hi(const __m128i ma3[3], const __m128i b3[3], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i* const sum_ma444, __m128i sum_b343[2], + __m128i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m128i sum_ma111 = Sum3WHi16(ma3); + *sum_ma444 = _mm_slli_epi16(sum_ma111, 2); + StoreAligned16(ma444 + x, *sum_ma444); + const __m128i sum333 = _mm_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwHi8(sum333, ma3[1]); + StoreAligned16(ma343 + x, *sum_ma343); + Store343_444(b3, x, sum_b343, sum_b444, b343, b444); +} + +inline void Store343_444Lo(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma444, sum_b444[2]; + Store343_444Lo(ma3, b3, x, sum_ma343, 
&sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Hi(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma444, sum_b444[2]; + Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Lo(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma343, sum_b343[2]; + Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +inline void Store343_444Hi(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma343, sum_b343[2]; + Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo( + const __m128i s[2][4], const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m128i sq[2][8], __m128i* const ma, + __m128i b[2]) { + __m128i s5[2][5], sq5[5][2]; + Square(s[0][1], sq[0] + 2); + Square(s[1][1], sq[1] + 2); + s5[0][3] = Sum5Horizontal16(s[0]); + StoreAligned16(sum5[3], s5[0][3]); + s5[0][4] = Sum5Horizontal16(s[1]); + StoreAligned16(sum5[4], s5[0][4]); + Sum5Horizontal32(sq[0], sq5[3]); + StoreAligned32U32(square_sum5[3], sq5[3]); + Sum5Horizontal32(sq[1], sq5[4]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x3U16(sum5, 0, s5[0]); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5<0>(s5[0], sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const __m128i s[2][4], const ptrdiff_t sum_width, const ptrdiff_t x, + const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m128i sq[2][8], __m128i ma[2], + __m128i b[6]) { + __m128i s5[2][5], sq5[5][2]; + Square(s[0][2], sq[0] + 4); + Square(s[1][2], sq[1] + 4); + s5[0][3] = Sum5Horizontal16(s[0] + 1); + s5[1][3] = Sum5Horizontal16(s[0] + 2); + StoreAligned16(sum5[3] + x + 0, s5[0][3]); + StoreAligned16(sum5[3] + x + 8, s5[1][3]); + s5[0][4] = Sum5Horizontal16(s[1] + 1); + s5[1][4] = Sum5Horizontal16(s[1] + 2); + StoreAligned16(sum5[4] + x + 0, s5[0][4]); + StoreAligned16(sum5[4] + x + 8, s5[1][4]); + Sum5Horizontal32(sq[0] + 2, sq5[3]); + StoreAligned32U32(square_sum5[3] + x, sq5[3]); + Sum5Horizontal32(sq[1] + 2, sq5[4]); + StoreAligned32U32(square_sum5[4] + x, sq5[4]); + LoadAligned16x3U16(sum5, x, s5[0]); + LoadAligned32x3U32(square_sum5, x, sq5); + CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], b + 2); + + Square(s[0][3], sq[0] + 6); + Square(s[1][3], sq[1] + 6); + Sum5Horizontal32(sq[0] + 4, sq5[3]); + StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]); + Sum5Horizontal32(sq[1] + 4, sq5[4]); + StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]); + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], b + 4); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo( + const __m128i s[2], const uint32_t scale, const uint16_t* const sum5[5], + const uint32_t* const square_sum5[5], __m128i sq[4], __m128i* const ma, + __m128i b[2]) { + __m128i s5[5], sq5[5][2]; + Square(s[1], sq + 2); + s5[3] = s5[4] = 
Sum5Horizontal16(s); + Sum5Horizontal32(sq, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5<0>(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow( + const __m128i s[4], const ptrdiff_t sum_width, const ptrdiff_t x, + const uint32_t scale, const uint16_t* const sum5[5], + const uint32_t* const square_sum5[5], __m128i sq[8], __m128i ma[2], + __m128i b[6]) { + __m128i s5[2][5], sq5[5][2]; + Square(s[2], sq + 4); + s5[0][3] = Sum5Horizontal16(s + 1); + s5[1][3] = Sum5Horizontal16(s + 2); + s5[0][4] = s5[0][3]; + s5[1][4] = s5[1][3]; + Sum5Horizontal32(sq + 2, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16(sum5, x, s5[0]); + LoadAligned32x3U32(square_sum5, x, sq5); + CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], b + 2); + + Square(s[3], sq + 6); + Sum5Horizontal32(sq + 4, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], b + 4); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo( + const __m128i s[2], const uint32_t scale, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], __m128i sq[4], __m128i* const ma, + __m128i b[2]) { + __m128i s3[3], sq3[3][2]; + Square(s[1], sq + 2); + s3[2] = Sum3Horizontal16(s); + StoreAligned16(sum3[2], s3[2]); + Sum3Horizontal32(sq, sq3[2]); + StoreAligned32U32(square_sum3[2], sq3[2]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( + const __m128i s[4], const ptrdiff_t x, const ptrdiff_t sum_width, + const uint32_t scale, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], __m128i sq[8], __m128i ma[2], + __m128i b[6]) { + __m128i s3[4], sq3[3][2], sum[2], index[2]; + Square(s[2], sq + 4); + s3[2] = Sum3Horizontal16(s + 1); + s3[3] = Sum3Horizontal16(s + 2); + StoreAligned32U16(sum3[2] + x, s3 + 2); + Sum3Horizontal32(sq + 2, sq3[2]); + StoreAligned32U32(square_sum3[2] + x + 0, sq3[2]); + LoadAligned16x2U16(sum3, x, s3); + LoadAligned32x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]); + + Square(s[3], sq + 6); + Sum3Horizontal32(sq + 4, sq3[2]); + StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]); + LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3 + 1); + LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3); + CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]); + CalculateIntermediate(sum, index, ma, b + 2); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo( + const __m128i s[2][4], const uint16_t scales[2], uint16_t* const sum3[4], + uint16_t* const sum5[5], uint32_t* const square_sum3[4], + uint32_t* const square_sum5[5], __m128i sq[2][8], __m128i ma3[2][2], + __m128i b3[2][6], __m128i* const ma5, __m128i b5[2]) { + __m128i s3[4], s5[5], sq3[4][2], sq5[5][2], sum[2], index[2]; + Square(s[0][1], sq[0] + 2); + Square(s[1][1], sq[1] + 2); + SumHorizontal16(s[0], &s3[2], &s5[3]); + SumHorizontal16(s[1], &s3[3], &s5[4]); + StoreAligned16(sum3[2], s3[2]); + StoreAligned16(sum3[3], s3[3]); + StoreAligned16(sum5[3], s5[3]); + StoreAligned16(sum5[4], s5[4]); + SumHorizontal32(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2], sq3[2]); + 
StoreAligned32U32(square_sum5[3], sq5[3]); + SumHorizontal32(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3], sq3[3]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateSumAndIndex3(s3 + 0, sq3 + 0, scales[1], &sum[0], &index[0]); + CalculateSumAndIndex3(s3 + 1, sq3 + 1, scales[1], &sum[1], &index[1]); + CalculateIntermediate(sum, index, &ma3[0][0], b3[0], b3[1]); + ma3[1][0] = _mm_srli_si128(ma3[0][0], 8); + CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( + const __m128i s[2][4], const ptrdiff_t x, const uint16_t scales[2], + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, __m128i sq[2][8], __m128i ma3[2][2], + __m128i b3[2][6], __m128i ma5[2], __m128i b5[6]) { + __m128i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum[2][2], index[2][2]; + SumHorizontal16(s[0] + 1, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]); + StoreAligned16(sum3[2] + x + 0, s3[0][2]); + StoreAligned16(sum3[2] + x + 8, s3[1][2]); + StoreAligned16(sum5[3] + x + 0, s5[0][3]); + StoreAligned16(sum5[3] + x + 8, s5[1][3]); + SumHorizontal16(s[1] + 1, &s3[0][3], &s3[1][3], &s5[0][4], &s5[1][4]); + StoreAligned16(sum3[3] + x + 0, s3[0][3]); + StoreAligned16(sum3[3] + x + 8, s3[1][3]); + StoreAligned16(sum5[4] + x + 0, s5[0][4]); + StoreAligned16(sum5[4] + x + 8, s5[1][4]); + Square(s[0][2], sq[0] + 4); + Square(s[1][2], sq[1] + 4); + SumHorizontal32(sq[0] + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2] + x, sq3[2]); + StoreAligned32U32(square_sum5[3] + x, sq5[3]); + SumHorizontal32(sq[1] + 2, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3] + x, sq3[3]); + StoreAligned32U32(square_sum5[4] + x, sq5[4]); + LoadAligned16x2U16(sum3, x, s3[0]); + LoadAligned32x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0][0], &index[0][0]); + CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum[1][0], + &index[1][0]); + LoadAligned16x3U16(sum5, x, s5[0]); + LoadAligned32x3U32(square_sum5, x, sq5); + CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], b5 + 2); + + Square(s[0][3], sq[0] + 6); + Square(s[1][3], sq[1] + 6); + SumHorizontal32(sq[0] + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]); + StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]); + SumHorizontal32(sq[1] + 4, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3] + x + 8, sq3[3]); + StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]); + LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]); + LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[0][1], &index[0][1]); + CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum[1][1], + &index[1][1]); + CalculateIntermediate(sum[0], index[0], ma3[0], b3[0] + 2); + CalculateIntermediate(sum[1], index[1], ma3[1], b3[1] + 2); + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], b5 + 4); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo( + const __m128i s[2], const uint16_t scales[2], const 
uint16_t* const sum3[4], + const uint16_t* const sum5[5], const uint32_t* const square_sum3[4], + const uint32_t* const square_sum5[5], __m128i sq[4], __m128i* const ma3, + __m128i* const ma5, __m128i b3[2], __m128i b5[2]) { + __m128i s3[3], s5[5], sq3[3][2], sq5[5][2]; + Square(s[1], sq + 2); + SumHorizontal16(s, &s3[2], &s5[3]); + SumHorizontal32(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16(sum5, 0, s5); + s5[4] = s5[3]; + LoadAligned32x3U32(square_sum5, 0, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scales[1], ma3, b3); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( + const __m128i s[4], const ptrdiff_t sum_width, const ptrdiff_t x, + const uint16_t scales[2], const uint16_t* const sum3[4], + const uint16_t* const sum5[5], const uint32_t* const square_sum3[4], + const uint32_t* const square_sum5[5], __m128i sq[8], __m128i ma3[2], + __m128i ma5[2], __m128i b3[6], __m128i b5[6]) { + __m128i s3[2][3], s5[2][5], sq3[3][2], sq5[5][2], sum[2], index[2]; + Square(s[2], sq + 4); + SumHorizontal16(s + 1, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]); + SumHorizontal32(sq + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16(sum5, x, s5[0]); + s5[0][4] = s5[0][3]; + LoadAligned32x3U32(square_sum5, x, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5<8>(s5[0], sq5, scales[0], ma5, b5 + 2); + LoadAligned16x2U16(sum3, x, s3[0]); + LoadAligned32x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0], &index[0]); + + Square(s[3], sq + 6); + SumHorizontal32(sq + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + s5[1][4] = s5[1][3]; + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5<0>(s5[1], sq5, scales[0], ma5 + 1, b5 + 4); + LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]); + LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[1], &index[1]); + CalculateIntermediate(sum, index, ma3, b3 + 2); +} + +inline void BoxSumFilterPreProcess5(const uint16_t* const src0, + const uint16_t* const src1, const int width, + const uint32_t scale, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* ma565, + uint32_t* b565) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1 - sizeof(*src0) * width; + __m128i s[2][4], mas[2], sq[2][8], bs[6]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq[0]); + Square(s[1][0], sq[1]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], bs); + + int x = 0; + do { + __m128i ma5[3], ma[2], b[4]; + s[0][2] = LoadUnaligned16Msan(src0 + x + 16, + overread_in_bytes + sizeof(*src0) * (x + 16)); + s[0][3] = LoadUnaligned16Msan(src0 + x + 24, + overread_in_bytes + sizeof(*src0) * (x + 24)); + s[1][2] = LoadUnaligned16Msan(src1 + x + 16, + overread_in_bytes + sizeof(*src1) * (x + 16)); + s[1][3] = LoadUnaligned16Msan(src1 + x + 24, + overread_in_bytes + sizeof(*src1) * (x + 24)); + 
BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas, + bs); + Prepare3_8<0>(mas, ma5); + ma[0] = Sum565Lo(ma5); + ma[1] = Sum565Hi(ma5); + StoreAligned32U16(ma565, ma); + Sum565(bs + 0, b + 0); + Sum565(bs + 2, b + 2); + StoreAligned64U32(b565, b); + s[0][0] = s[0][2]; + s[0][1] = s[0][3]; + s[1][0] = s[1][2]; + s[1][1] = s[1][3]; + sq[0][2] = sq[0][6]; + sq[0][3] = sq[0][7]; + sq[1][2] = sq[1][6]; + sq[1][3] = sq[1][7]; + mas[0] = mas[1]; + bs[0] = bs[4]; + bs[1] = bs[5]; + ma565 += 16; + b565 += 16; + x += 16; + } while (x < width); +} + +template <bool calculate444> +LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( + const uint16_t* const src, const int width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343, + uint32_t* b444) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass2 - sizeof(*src) * width; + __m128i s[4], mas[2], sq[8], bs[6]; + s[0] = LoadUnaligned16Msan(src + 0, overread_in_bytes + 0); + s[1] = LoadUnaligned16Msan(src + 8, overread_in_bytes + 16); + Square(s[0], sq); + BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq, &mas[0], bs); + + int x = 0; + do { + s[2] = LoadUnaligned16Msan(src + x + 16, + overread_in_bytes + sizeof(*src) * (x + 16)); + s[3] = LoadUnaligned16Msan(src + x + 24, + overread_in_bytes + sizeof(*src) * (x + 24)); + BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas, + bs); + __m128i ma3[3]; + Prepare3_8<0>(mas, ma3); + if (calculate444) { // NOLINT(readability-simplify-boolean-expr) + Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444); + Store343_444Hi(ma3, bs + 2, 8, ma343, ma444, b343, b444); + ma444 += 16; + b444 += 16; + } else { + __m128i ma[2], b[4]; + ma[0] = Sum343Lo(ma3); + ma[1] = Sum343Hi(ma3); + StoreAligned32U16(ma343, ma); + Sum343(bs + 0, b + 0); + Sum343(bs + 2, b + 2); + StoreAligned64U32(b343, b); + } + s[1] = s[3]; + sq[2] = sq[6]; + sq[3] = sq[7]; + mas[0] = mas[1]; + bs[0] = bs[4]; + bs[1] = bs[5]; + ma343 += 16; + b343 += 16; + x += 16; + } while (x < width); +} + +inline void BoxSumFilterPreProcess( + const uint16_t* const src0, const uint16_t* const src1, const int width, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444, + uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444, + uint32_t* b565) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1 - sizeof(*src0) * width; + __m128i s[2][4], ma3[2][2], ma5[2], sq[2][8], b3[2][6], b5[6]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq[0]); + Square(s[1][0], sq[1]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq, + ma3, b3, &ma5[0], b5); + + int x = 0; + do { + __m128i ma[2], b[4], ma3x[3], ma5x[3]; + s[0][2] = LoadUnaligned16Msan(src0 + x + 16, + overread_in_bytes + sizeof(*src0) * (x + 16)); + s[0][3] = LoadUnaligned16Msan(src0 + x + 24, + overread_in_bytes + sizeof(*src0) * (x + 24)); + s[1][2] = LoadUnaligned16Msan(src1 + x + 16, + overread_in_bytes + sizeof(*src1) * (x + 16)); + s[1][3] = LoadUnaligned16Msan(src1 + x + 24, + overread_in_bytes + 
sizeof(*src1) * (x + 24)); + BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5, + sum_width, sq, ma3, b3, ma5, b5); + + Prepare3_8<0>(ma3[0], ma3x); + ma[0] = Sum343Lo(ma3x); + ma[1] = Sum343Hi(ma3x); + StoreAligned32U16(ma343[0] + x, ma); + Sum343(b3[0] + 0, b + 0); + Sum343(b3[0] + 2, b + 2); + StoreAligned64U32(b343[0] + x, b); + Sum565(b5 + 0, b + 0); + Sum565(b5 + 2, b + 2); + StoreAligned64U32(b565, b); + Prepare3_8<0>(ma3[1], ma3x); + Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444); + Store343_444Hi(ma3x, b3[1] + 2, x + 8, ma343[1], ma444, b343[1], b444); + Prepare3_8<0>(ma5, ma5x); + ma[0] = Sum565Lo(ma5x); + ma[1] = Sum565Hi(ma5x); + StoreAligned32U16(ma565, ma); + s[0][0] = s[0][2]; + s[0][1] = s[0][3]; + s[1][0] = s[1][2]; + s[1][1] = s[1][3]; + sq[0][2] = sq[0][6]; + sq[0][3] = sq[0][7]; + sq[1][2] = sq[1][6]; + sq[1][3] = sq[1][7]; + ma3[0][0] = ma3[0][1]; + ma3[1][0] = ma3[1][1]; + ma5[0] = ma5[1]; + b3[0][0] = b3[0][4]; + b3[0][1] = b3[0][5]; + b3[1][0] = b3[1][4]; + b3[1][1] = b3[1][5]; + b5[0] = b5[4]; + b5[1] = b5[5]; + ma565 += 16; + b565 += 16; + x += 16; + } while (x < width); +} + +template <int shift> +inline __m128i FilterOutput(const __m128i ma_x_src, const __m128i b) { + // ma: 255 * 32 = 8160 (13 bits) + // b: 65088 * 32 = 2082816 (21 bits) + // v: b - ma * 255 (22 bits) + const __m128i v = _mm_sub_epi32(b, ma_x_src); + // kSgrProjSgrBits = 8 + // kSgrProjRestoreBits = 4 + // shift = 4 or 5 + // v >> 8 or 9 (13 bits) + return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits); +} + +template <int shift> +inline __m128i CalculateFilteredOutput(const __m128i src, const __m128i ma, + const __m128i b[2]) { + const __m128i ma_x_src_lo = VmullLo16(ma, src); + const __m128i ma_x_src_hi = VmullHi16(ma, src); + const __m128i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]); + const __m128i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]); + return _mm_packs_epi32(dst_lo, dst_hi); // 13 bits +} + +inline __m128i CalculateFilteredOutputPass1(const __m128i src, + const __m128i ma[2], + const __m128i b[2][2]) { + const __m128i ma_sum = _mm_add_epi16(ma[0], ma[1]); + __m128i b_sum[2]; + b_sum[0] = _mm_add_epi32(b[0][0], b[1][0]); + b_sum[1] = _mm_add_epi32(b[0][1], b[1][1]); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m128i CalculateFilteredOutputPass2(const __m128i src, + const __m128i ma[3], + const __m128i b[3][2]) { + const __m128i ma_sum = Sum3_16(ma); + __m128i b_sum[2]; + Sum3_32(b, b_sum); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m128i SelfGuidedFinal(const __m128i src, const __m128i v[2]) { + const __m128i v_lo = + VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m128i v_hi = + VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m128i vv = _mm_packs_epi32(v_lo, v_hi); + return _mm_add_epi16(src, vv); +} + +inline __m128i SelfGuidedDoubleMultiplier(const __m128i src, + const __m128i filter[2], const int w0, + const int w2) { + __m128i v[2]; + const __m128i w0_w2 = _mm_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0)); + const __m128i f_lo = _mm_unpacklo_epi16(filter[0], filter[1]); + const __m128i f_hi = _mm_unpackhi_epi16(filter[0], filter[1]); + v[0] = _mm_madd_epi16(w0_w2, f_lo); + v[1] = _mm_madd_epi16(w0_w2, f_hi); + return SelfGuidedFinal(src, v); +} + +inline __m128i SelfGuidedSingleMultiplier(const __m128i src, + const __m128i filter, const int w0) { + // weight: -96 to 96 (Sgrproj_Xqd_Min/Max) + __m128i v[2]; + v[0] 
= VmullNLo8(filter, w0); + v[1] = VmullNHi8(filter, w0); + return SelfGuidedFinal(src, v); +} + +inline void ClipAndStore(uint16_t* const dst, const __m128i val) { + const __m128i val0 = _mm_max_epi16(val, _mm_setzero_si128()); + const __m128i val1 = _mm_min_epi16(val0, _mm_set1_epi16(1023)); + StoreAligned16(dst, val1); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( + const uint16_t* const src, const uint16_t* const src0, + const uint16_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width, + const uint32_t scale, const int16_t w0, uint16_t* const ma565[2], + uint32_t* const b565[2], uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1 - sizeof(*src0) * width; + __m128i s[2][4], mas[2], sq[2][8], bs[6]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq[0]); + Square(s[1][0], sq[1]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], bs); + + int x = 0; + do { + __m128i ma[2], ma5[3], b[2][2], p[2]; + s[0][2] = LoadUnaligned16Msan(src0 + x + 16, + overread_in_bytes + sizeof(*src0) * (x + 16)); + s[0][3] = LoadUnaligned16Msan(src0 + x + 24, + overread_in_bytes + sizeof(*src0) * (x + 24)); + s[1][2] = LoadUnaligned16Msan(src1 + x + 16, + overread_in_bytes + sizeof(*src1) * (x + 16)); + s[1][3] = LoadUnaligned16Msan(src1 + x + 24, + overread_in_bytes + sizeof(*src1) * (x + 24)); + BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas, + bs); + Prepare3_8<0>(mas, ma5); + ma[1] = Sum565Lo(ma5); + StoreAligned16(ma565[1] + x, ma[1]); + Sum565(bs, b[1]); + StoreAligned32U32(b565[1] + x, b[1]); + const __m128i sr0_lo = LoadAligned16(src + x + 0); + const __m128i sr1_lo = LoadAligned16(src + stride + x + 0); + ma[0] = LoadAligned16(ma565[0] + x); + LoadAligned32U32(b565[0] + x, b[0]); + p[0] = CalculateFilteredOutputPass1(sr0_lo, ma, b); + p[1] = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[1]); + const __m128i d00 = SelfGuidedSingleMultiplier(sr0_lo, p[0], w0); + const __m128i d10 = SelfGuidedSingleMultiplier(sr1_lo, p[1], w0); + + ma[1] = Sum565Hi(ma5); + StoreAligned16(ma565[1] + x + 8, ma[1]); + Sum565(bs + 2, b[1]); + StoreAligned32U32(b565[1] + x + 8, b[1]); + const __m128i sr0_hi = LoadAligned16(src + x + 8); + const __m128i sr1_hi = LoadAligned16(src + stride + x + 8); + ma[0] = LoadAligned16(ma565[0] + x + 8); + LoadAligned32U32(b565[0] + x + 8, b[0]); + p[0] = CalculateFilteredOutputPass1(sr0_hi, ma, b); + p[1] = CalculateFilteredOutput<4>(sr1_hi, ma[1], b[1]); + const __m128i d01 = SelfGuidedSingleMultiplier(sr0_hi, p[0], w0); + ClipAndStore(dst + x + 0, d00); + ClipAndStore(dst + x + 8, d01); + const __m128i d11 = SelfGuidedSingleMultiplier(sr1_hi, p[1], w0); + ClipAndStore(dst + stride + x + 0, d10); + ClipAndStore(dst + stride + x + 8, d11); + s[0][0] = s[0][2]; + s[0][1] = s[0][3]; + s[1][0] = s[1][2]; + s[1][1] = s[1][3]; + sq[0][2] = sq[0][6]; + sq[0][3] = sq[0][7]; + sq[1][2] = sq[1][6]; + sq[1][3] = sq[1][7]; + mas[0] = mas[1]; + bs[0] = bs[4]; + bs[1] = bs[5]; + x += 16; + } while (x < width); +} + +inline void BoxFilterPass1LastRow( + const uint16_t* const src, const uint16_t* const src0, const int width, + const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const 
sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565, + uint32_t* b565, uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1 - sizeof(*src0) * width; + __m128i s[4], mas[2], sq[8], bs[6]; + s[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + Square(s[0], sq); + BoxFilterPreProcess5LastRowLo(s, scale, sum5, square_sum5, sq, &mas[0], bs); + + int x = 0; + do { + __m128i ma[2], ma5[3], b[2][2]; + s[2] = LoadUnaligned16Msan(src0 + x + 16, + overread_in_bytes + sizeof(*src0) * (x + 16)); + s[3] = LoadUnaligned16Msan(src0 + x + 24, + overread_in_bytes + sizeof(*src0) * (x + 24)); + BoxFilterPreProcess5LastRow(s, sum_width, x + 8, scale, sum5, square_sum5, + sq, mas, bs); + Prepare3_8<0>(mas, ma5); + ma[1] = Sum565Lo(ma5); + Sum565(bs, b[1]); + ma[0] = LoadAligned16(ma565); + LoadAligned32U32(b565, b[0]); + const __m128i sr_lo = LoadAligned16(src + x + 0); + __m128i p = CalculateFilteredOutputPass1(sr_lo, ma, b); + const __m128i d0 = SelfGuidedSingleMultiplier(sr_lo, p, w0); + + ma[1] = Sum565Hi(ma5); + Sum565(bs + 2, b[1]); + ma[0] = LoadAligned16(ma565 + 8); + LoadAligned32U32(b565 + 8, b[0]); + const __m128i sr_hi = LoadAligned16(src + x + 8); + p = CalculateFilteredOutputPass1(sr_hi, ma, b); + const __m128i d1 = SelfGuidedSingleMultiplier(sr_hi, p, w0); + ClipAndStore(dst + x + 0, d0); + ClipAndStore(dst + x + 8, d1); + s[1] = s[3]; + sq[2] = sq[6]; + sq[3] = sq[7]; + mas[0] = mas[1]; + bs[0] = bs[4]; + bs[1] = bs[5]; + ma565 += 16; + b565 += 16; + x += 16; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass2( + const uint16_t* const src, const uint16_t* const src0, const int width, + const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3], + uint32_t* const b444[2], uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass2 - sizeof(*src0) * width; + __m128i s[4], mas[2], sq[8], bs[6]; + s[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + Square(s[0], sq); + BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq, &mas[0], bs); + + int x = 0; + do { + s[2] = LoadUnaligned16Msan(src0 + x + 16, + overread_in_bytes + sizeof(*src0) * (x + 16)); + s[3] = LoadUnaligned16Msan(src0 + x + 24, + overread_in_bytes + sizeof(*src0) * (x + 24)); + BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas, + bs); + __m128i ma[3], b[3][2], ma3[3]; + Prepare3_8<0>(mas, ma3); + Store343_444Lo(ma3, bs + 0, x, &ma[2], b[2], ma343[2], ma444[1], b343[2], + b444[1]); + const __m128i sr_lo = LoadAligned16(src + x + 0); + ma[0] = LoadAligned16(ma343[0] + x); + ma[1] = LoadAligned16(ma444[0] + x); + LoadAligned32U32(b343[0] + x, b[0]); + LoadAligned32U32(b444[0] + x, b[1]); + const __m128i p0 = CalculateFilteredOutputPass2(sr_lo, ma, b); + + Store343_444Hi(ma3, bs + 2, x + 8, &ma[2], b[2], ma343[2], ma444[1], + b343[2], b444[1]); + const __m128i sr_hi = LoadAligned16(src + x + 8); + ma[0] = LoadAligned16(ma343[0] + x + 8); + ma[1] = LoadAligned16(ma444[0] + x + 8); + LoadAligned32U32(b343[0] + x + 8, b[0]); + LoadAligned32U32(b444[0] + x + 8, b[1]); + const __m128i p1 = CalculateFilteredOutputPass2(sr_hi, ma, b); + const __m128i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0); + const __m128i d1 = SelfGuidedSingleMultiplier(sr_hi, 
p1, w0); + ClipAndStore(dst + x + 0, d0); + ClipAndStore(dst + x + 8, d1); + s[1] = s[3]; + sq[2] = sq[6]; + sq[3] = sq[7]; + mas[0] = mas[1]; + bs[0] = bs[4]; + bs[1] = bs[5]; + x += 16; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilter( + const uint16_t* const src, const uint16_t* const src0, + const uint16_t* const src1, const ptrdiff_t stride, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], + uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4], + uint32_t* const b444[3], uint32_t* const b565[2], uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1 - sizeof(*src0) * width; + __m128i s[2][4], ma3[2][2], ma5[2], sq[2][8], b3[2][6], b5[6]; + s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0); + s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16); + Square(s[0][0], sq[0]); + Square(s[1][0], sq[1]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq, + ma3, b3, &ma5[0], b5); + + int x = 0; + do { + __m128i ma[3][3], b[3][3][2], p[2][2], ma3x[2][3], ma5x[3]; + s[0][2] = LoadUnaligned16Msan(src0 + x + 16, + overread_in_bytes + sizeof(*src0) * (x + 16)); + s[0][3] = LoadUnaligned16Msan(src0 + x + 24, + overread_in_bytes + sizeof(*src0) * (x + 24)); + s[1][2] = LoadUnaligned16Msan(src1 + x + 16, + overread_in_bytes + sizeof(*src1) * (x + 16)); + s[1][3] = LoadUnaligned16Msan(src1 + x + 24, + overread_in_bytes + sizeof(*src1) * (x + 24)); + BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5, + sum_width, sq, ma3, b3, ma5, b5); + Prepare3_8<0>(ma3[0], ma3x[0]); + Prepare3_8<0>(ma3[1], ma3x[1]); + Prepare3_8<0>(ma5, ma5x); + Store343_444Lo(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], b[1][2], b[2][1], + ma343[2], ma444[1], b343[2], b444[1]); + Store343_444Lo(ma3x[1], b3[1], x, &ma[2][2], b[2][2], ma343[3], ma444[2], + b343[3], b444[2]); + ma[0][1] = Sum565Lo(ma5x); + StoreAligned16(ma565[1] + x, ma[0][1]); + Sum565(b5, b[0][1]); + StoreAligned32U32(b565[1] + x, b[0][1]); + const __m128i sr0_lo = LoadAligned16(src + x); + const __m128i sr1_lo = LoadAligned16(src + stride + x); + ma[0][0] = LoadAligned16(ma565[0] + x); + LoadAligned32U32(b565[0] + x, b[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]); + ma[1][0] = LoadAligned16(ma343[0] + x); + ma[1][1] = LoadAligned16(ma444[0] + x); + LoadAligned32U32(b343[0] + x, b[1][0]); + LoadAligned32U32(b444[0] + x, b[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]); + const __m128i d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2); + ma[2][0] = LoadAligned16(ma343[1] + x); + LoadAligned32U32(b343[1] + x, b[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]); + const __m128i d10 = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2); + + Store343_444Hi(ma3x[0], b3[0] + 2, x + 8, &ma[1][2], &ma[2][1], b[1][2], + b[2][1], ma343[2], ma444[1], b343[2], b444[1]); + Store343_444Hi(ma3x[1], b3[1] + 2, x + 8, &ma[2][2], b[2][2], ma343[3], + ma444[2], b343[3], b444[2]); + ma[0][1] = Sum565Hi(ma5x); + StoreAligned16(ma565[1] + x + 8, ma[0][1]); + Sum565(b5 + 2, b[0][1]); + 
StoreAligned32U32(b565[1] + x + 8, b[0][1]); + const __m128i sr0_hi = LoadAligned16(src + x + 8); + const __m128i sr1_hi = LoadAligned16(src + stride + x + 8); + ma[0][0] = LoadAligned16(ma565[0] + x + 8); + LoadAligned32U32(b565[0] + x + 8, b[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_hi, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_hi, ma[0][1], b[0][1]); + ma[1][0] = LoadAligned16(ma343[0] + x + 8); + ma[1][1] = LoadAligned16(ma444[0] + x + 8); + LoadAligned32U32(b343[0] + x + 8, b[1][0]); + LoadAligned32U32(b444[0] + x + 8, b[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_hi, ma[1], b[1]); + const __m128i d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2); + ClipAndStore(dst + x + 0, d00); + ClipAndStore(dst + x + 8, d01); + ma[2][0] = LoadAligned16(ma343[1] + x + 8); + LoadAligned32U32(b343[1] + x + 8, b[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_hi, ma[2], b[2]); + const __m128i d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2); + ClipAndStore(dst + stride + x + 0, d10); + ClipAndStore(dst + stride + x + 8, d11); + s[0][0] = s[0][2]; + s[0][1] = s[0][3]; + s[1][0] = s[1][2]; + s[1][1] = s[1][3]; + sq[0][2] = sq[0][6]; + sq[0][3] = sq[0][7]; + sq[1][2] = sq[1][6]; + sq[1][3] = sq[1][7]; + ma3[0][0] = ma3[0][1]; + ma3[1][0] = ma3[1][1]; + ma5[0] = ma5[1]; + b3[0][0] = b3[0][4]; + b3[0][1] = b3[0][5]; + b3[1][0] = b3[1][4]; + b3[1][1] = b3[1][5]; + b5[0] = b5[4]; + b5[1] = b5[5]; + x += 16; + } while (x < width); +} + +inline void BoxFilterLastRow( + const uint16_t* const src, const uint16_t* const src0, const int width, + const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0, + const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565, + uint32_t* const b343, uint32_t* const b444, uint32_t* const b565, + uint16_t* const dst) { + const ptrdiff_t overread_in_bytes = + kOverreadInBytesPass1 - sizeof(*src0) * width; + __m128i s[4], ma3[2], ma5[2], sq[8], b3[6], b5[6], ma[3], b[3][2]; + s[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0); + s[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16); + Square(s[0], sq); + BoxFilterPreProcessLastRowLo(s, scales, sum3, sum5, square_sum3, square_sum5, + sq, &ma3[0], &ma5[0], b3, b5); + + int x = 0; + do { + __m128i ma3x[3], ma5x[3], p[2]; + s[2] = LoadUnaligned16Msan(src0 + x + 16, + overread_in_bytes + sizeof(*src0) * (x + 16)); + s[3] = LoadUnaligned16Msan(src0 + x + 24, + overread_in_bytes + sizeof(*src0) * (x + 24)); + BoxFilterPreProcessLastRow(s, sum_width, x + 8, scales, sum3, sum5, + square_sum3, square_sum5, sq, ma3, ma5, b3, b5); + Prepare3_8<0>(ma3, ma3x); + Prepare3_8<0>(ma5, ma5x); + ma[1] = Sum565Lo(ma5x); + Sum565(b5, b[1]); + ma[2] = Sum343Lo(ma3x); + Sum343(b3, b[2]); + const __m128i sr_lo = LoadAligned16(src + x + 0); + ma[0] = LoadAligned16(ma565 + x); + LoadAligned32U32(b565 + x, b[0]); + p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b); + ma[0] = LoadAligned16(ma343 + x); + ma[1] = LoadAligned16(ma444 + x); + LoadAligned32U32(b343 + x, b[0]); + LoadAligned32U32(b444 + x, b[1]); + p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b); + const __m128i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2); + + ma[1] = Sum565Hi(ma5x); + Sum565(b5 + 2, b[1]); + ma[2] = Sum343Hi(ma3x); + Sum343(b3 + 2, b[2]); + const __m128i sr_hi = LoadAligned16(src + x + 8); + ma[0] = LoadAligned16(ma565 + x + 8); + 
LoadAligned32U32(b565 + x + 8, b[0]); + p[0] = CalculateFilteredOutputPass1(sr_hi, ma, b); + ma[0] = LoadAligned16(ma343 + x + 8); + ma[1] = LoadAligned16(ma444 + x + 8); + LoadAligned32U32(b343 + x + 8, b[0]); + LoadAligned32U32(b444 + x + 8, b[1]); + p[1] = CalculateFilteredOutputPass2(sr_hi, ma, b); + const __m128i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2); + ClipAndStore(dst + x + 0, d0); + ClipAndStore(dst + x + 8, d1); + s[1] = s[3]; + sq[2] = sq[6]; + sq[3] = sq[7]; + ma3[0] = ma3[1]; + ma5[0] = ma5[1]; + b3[0] = b3[4]; + b3[1] = b3[5]; + b5[0] = b5[4]; + b5[1] = b5[5]; + x += 16; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( + const RestorationUnitInfo& restoration_info, const uint16_t* src, + const ptrdiff_t stride, const uint16_t* const top_border, + const ptrdiff_t top_border_stride, const uint16_t* bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + SgrBuffer* const sgr_buffer, uint16_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 16); + const auto sum_width = Align<ptrdiff_t>(width + 8, 16); + const auto sum_stride = temp_stride + 16; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1; + uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2]; + uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 3; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma444[0] = sgr_buffer->ma444; + b444[0] = sgr_buffer->b444; + for (int i = 1; i <= 2; ++i) { + ma444[i] = ma444[i - 1] + temp_stride; + b444[i] = b444[i - 1] + temp_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scales[0] != 0); + assert(scales[1] != 0); + BoxSum(top_border, top_border_stride, width, sum_stride, sum_width, sum3[0], + sum5[1], square_sum3[0], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint16_t* const s = (height > 1) ? 
src + stride : bottom_border; + BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, + square_sum5, sum_width, ma343, ma444[0], ma565[0], + b343, b444[0], b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width, + scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width, + ma343, ma444, ma565, b343, b444, b565, dst); + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint16_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + bottom_border_stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5, + square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343, + b444, b565, dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width, + sum_width, scales, w0, w2, sum3, sum5, square_sum3, + square_sum5, ma343[0], ma444[0], ma565[0], b343[0], + b444[0], b565[0], dst); + } +} + +inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, + const uint16_t* src, const ptrdiff_t stride, + const uint16_t* const top_border, + const ptrdiff_t top_border_stride, + const uint16_t* bottom_border, + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint16_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 16); + const auto sum_width = Align<ptrdiff_t>(width + 8, 16); + const auto sum_stride = temp_stride + 16; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. 
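+  // Pass 1 applies only the radius-2 (5x5) component of the filter, so only
+  // the first projection multiplier is needed below.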
+ const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + uint16_t *sum5[5], *ma565[2]; + uint32_t *square_sum5[5], *b565[2]; + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scale != 0); + BoxSum<5>(top_border, top_border_stride, width, sum_stride, sum_width, + sum5[1], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint16_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width, + ma565[0], b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5, + square_sum5, width, sum_width, scale, w0, ma565, b565, dst); + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint16_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + bottom_border_stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width, + sum_width, scale, w0, ma565, b565, dst); + } + if ((height & 1) != 0) { + src += 3; + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + } + BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width, + sum_width, scale, w0, sum5, square_sum5, ma565[0], + b565[0], dst); + } +} + +inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, + const uint16_t* src, const ptrdiff_t stride, + const uint16_t* const top_border, + const ptrdiff_t top_border_stride, + const uint16_t* bottom_border, + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint16_t* dst) { + assert(restoration_info.sgr_proj_info.multiplier[0] == 0); + const auto temp_stride = Align<ptrdiff_t>(width, 16); + const auto sum_width = Align<ptrdiff_t>(width + 8, 16); + const auto sum_stride = temp_stride + 16; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12. 
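// Editorial note (worked example, not part of the patch): for a restoration
// unit of width 70, the buffer layout declared above works out to
//   temp_stride = Align<ptrdiff_t>(70, 16)     = 80  (ma343/ma444/b343/b444 rows)
//   sum_width   = Align<ptrdiff_t>(70 + 8, 16) = 80  (columns of box sums used)
//   sum_stride  = temp_stride + 16             = 96  (sum3/square_sum3 rows)
// so each sum row carries 16 extra columns beyond the ma/b rows, presumably
// as slack for the aligned 16-pixel loads and stores near the right edge of
// the unit.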
+ uint16_t *sum3[3], *ma343[3], *ma444[2]; + uint32_t *square_sum3[3], *b343[3], *b444[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 2; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + ma444[0] = sgr_buffer->ma444; + ma444[1] = ma444[0] + temp_stride; + b444[0] = sgr_buffer->b444; + b444[1] = b444[0] + temp_stride; + assert(scale != 0); + BoxSum<3>(top_border, top_border_stride, width, sum_stride, sum_width, + sum3[0], square_sum3[0]); + BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, + sum_width, ma343[0], nullptr, b343[0], + nullptr); + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + const uint16_t* s; + if (height > 1) { + s = src + stride; + } else { + s = bottom_border; + bottom_border += bottom_border_stride; + } + BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width, + ma343[1], ma444[0], b343[1], b444[0]); + + for (int y = height - 2; y > 0; --y) { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } + + int y = std::min(height, 2); + src += 2; + do { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + bottom_border += bottom_border_stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } while (--y != 0); +} + +// If |width| is non-multiple of 16, up to 15 more pixels are written to |dest| +// in the end of each row. It is safe to overwrite the output as it will not be +// part of the visible frame. +void SelfGuidedFilter_SSE4_1( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { + const int index = restoration_info.sgr_proj_info.index; + const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 + const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0 + const auto* const src = static_cast<const uint16_t*>(source); + const auto* const top = static_cast<const uint16_t*>(top_border); + const auto* const bottom = static_cast<const uint16_t*>(bottom_border); + auto* const dst = static_cast<uint16_t*>(dest); + SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer; + if (radius_pass_1 == 0) { + // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the + // following assertion. 
+ assert(radius_pass_0 != 0); + BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, + width, height, sgr_buffer, dst); + } else if (radius_pass_0 == 0) { + BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2, + top_border_stride, bottom - 2, bottom_border_stride, + width, height, sgr_buffer, dst); + } else { + BoxFilterProcess(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, width, + height, sgr_buffer, dst); + } +} + void Init10bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); assert(dsp != nullptr); @@ -531,6 +2505,11 @@ void Init10bpp() { #else static_cast<void>(WienerFilter_SSE4_1); #endif +#if DSP_ENABLED_10BPP_SSE4_1(SelfGuidedFilter) + dsp->loop_restorations[1] = SelfGuidedFilter_SSE4_1; +#else + static_cast<void>(SelfGuidedFilter_SSE4_1); +#endif } } // namespace @@ -540,7 +2519,7 @@ void LoopRestorationInit10bpp_SSE4_1() { Init10bpp(); } } // namespace dsp } // namespace libgav1 -#else // !(LIBGAV1_TARGETING_SSE4_1 && LIBGAV1_MAX_BITDEPTH >= 10) +#else // !(LIBGAV1_TARGETING_SSE4_1 && LIBGAV1_MAX_BITDEPTH >= 10) namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/loop_restoration_avx2.cc b/src/dsp/x86/loop_restoration_avx2.cc index 7ae7c90..351a324 100644 --- a/src/dsp/x86/loop_restoration_avx2.cc +++ b/src/dsp/x86/loop_restoration_avx2.cc @@ -28,7 +28,6 @@ #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_avx2.h" -#include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" #include "src/utils/constants.h" @@ -116,7 +115,8 @@ inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride, filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0100)); filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302)); filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0102)); - filter[3] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8000)); + filter[3] = _mm256_shuffle_epi8( + coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8000))); for (int y = height; y != 0; --y) { __m256i s = LoadUnaligned32(src); __m256i ss[4]; @@ -144,7 +144,8 @@ inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride, __m256i filter[3]; filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0201)); filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0203)); - filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8001)); + filter[2] = _mm256_shuffle_epi8( + coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8001))); for (int y = height; y != 0; --y) { __m256i s = LoadUnaligned32(src); __m256i ss[4]; @@ -171,7 +172,8 @@ inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride, int16_t** const wiener_buffer) { __m256i filter[2]; filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302)); - filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8002)); + filter[1] = _mm256_shuffle_epi8( + coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8002))); for (int y = height; y != 0; --y) { __m256i s = LoadUnaligned32(src); __m256i ss[4]; @@ -480,12 +482,12 @@ inline void WienerVerticalTap1(const int16_t* wiener_buffer, } } -void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, - const void* const source, const void* const top_border, - const void* const bottom_border, const ptrdiff_t stride, - const int width, const int 
height, - RestorationBuffer* const restoration_buffer, - void* const dest) { +void WienerFilter_AVX2( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { const int16_t* const number_leading_zero_coefficients = restoration_info.wiener_info.number_leading_zero_coefficients; const int number_rows_to_skip = std::max( @@ -515,39 +517,42 @@ void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, c_horizontal = _mm_packs_epi16(c_horizontal, c_horizontal); const __m256i coefficients_horizontal = _mm256_broadcastd_epi32(c_horizontal); if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { - WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, - wiener_stride, height_extra, coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); - } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { - WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, - wiener_stride, height_extra, coefficients_horizontal, + WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { // The maximum over-reads happen here. 
- WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, - wiener_stride, height_extra, coefficients_horizontal, - &wiener_buffer_horizontal); - WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1, + top_border_stride, wiener_stride, height_extra, coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride, + height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); } else { assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); - WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, - wiener_stride, height_extra, + WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride, + top_border_stride, wiener_stride, height_extra, &wiener_buffer_horizontal); WienerHorizontalTap1(src, stride, wiener_stride, height, &wiener_buffer_horizontal); - WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, - &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride, + height_extra, &wiener_buffer_horizontal); } // vertical filtering. @@ -765,17 +770,6 @@ inline __m256i VaddwHi16(const __m256i src0, const __m256i src1) { return _mm256_add_epi32(src0, s1); } -// Using VgetLane16() can save a sign extension instruction. -template <int n> -inline int VgetLane16(__m256i src) { - return _mm256_extract_epi16(src, n); -} - -template <int n> -inline int VgetLane8(__m256i src) { - return _mm256_extract_epi8(src, n); -} - inline __m256i VmullNLo8(const __m256i src0, const int src1) { const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256()); return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1)); @@ -1253,9 +1247,8 @@ inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, do { const __m128i s0 = LoadUnaligned16Msan(src, kOverreadInBytesPass1_128 - width); - __m128i sq_128[2]; + __m128i sq_128[2], s3, s5, sq3[2], sq5[2]; __m256i sq[3]; - __m128i s3, s5, sq3[2], sq5[2]; sq_128[0] = SquareLo8(s0); sq_128[1] = SquareHi8(s0); SumHorizontalLo(s0, &s3, &s5); @@ -1432,11 +1425,43 @@ inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq[2], return _mm256_packus_epi32(z0, z1); } -template <int n> -inline __m128i CalculateB(const __m128i sum, const __m128i ma) { - static_assert(n == 9 || n == 25, ""); +inline __m128i CalculateB5(const __m128i sum, const __m128i ma) { + // one_over_n == 164. constexpr uint32_t one_over_n = - ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. + constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. + const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter)); + const __m128i m0 = VmullLo16(m, sum); + const __m128i m1 = VmullHi16(m, sum); + const __m128i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2); + const __m128i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2); + return _mm_packus_epi32(b_lo, b_hi); +} + +inline __m256i CalculateB5(const __m256i sum, const __m256i ma) { + // one_over_n == 164. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. 
+ constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. + const __m256i m = + _mm256_maddubs_epi16(ma, _mm256_set1_epi16(one_over_n_quarter)); + const __m256i m0 = VmullLo16(m, sum); + const __m256i m1 = VmullHi16(m, sum); + const __m256i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2); + const __m256i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2); + return _mm256_packus_epi32(b_lo, b_hi); +} + +inline __m128i CalculateB3(const __m128i sum, const __m128i ma) { + // one_over_n == 455. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9; const __m128i m0 = VmullLo16(ma, sum); const __m128i m1 = VmullHi16(ma, sum); const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n)); @@ -1446,11 +1471,10 @@ inline __m128i CalculateB(const __m128i sum, const __m128i ma) { return _mm_packus_epi32(b_lo, b_hi); } -template <int n> -inline __m256i CalculateB(const __m256i sum, const __m256i ma) { - static_assert(n == 9 || n == 25, ""); +inline __m256i CalculateB3(const __m256i sum, const __m256i ma) { + // one_over_n == 455. constexpr uint32_t one_over_n = - ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9; const __m256i m0 = VmullLo16(ma, sum); const __m256i m1 = VmullHi16(ma, sum); const __m256i m2 = _mm256_mullo_epi32(m0, _mm256_set1_epi32(one_over_n)); @@ -1525,7 +1549,7 @@ inline void LookupIntermediate(const __m128i sum, const __m128i index, // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). const __m128i maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); - *b = CalculateB<n>(sum, maq); + *b = (n == 9) ? CalculateB3(sum, maq) : CalculateB5(sum, maq); } // Repeat the first 48 elements in kSgrMaLookup with a period of 16. @@ -1539,7 +1563,7 @@ alignas(32) constexpr uint8_t kSgrMaLookupAvx2[96] = { // Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b // to get value 0 as the shuffle result. The most significiant bit 1 comes -// either from the comparision instruction, or from the sign bit of the index. +// either from the comparison instruction, or from the sign bit of the index. inline __m256i ShuffleIndex(const __m256i table, const __m256i index) { __m256i mask; mask = _mm256_cmpgt_epi8(index, _mm256_set1_epi8(15)); @@ -1558,15 +1582,15 @@ template <int n> inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], __m256i ma[3], __m256i b[2]) { static_assert(n == 9 || n == 25, ""); - // Use table lookup to read elements which indices are less than 48. + // Use table lookup to read elements whose indices are less than 48. const __m256i c0 = LoadAligned32(kSgrMaLookupAvx2 + 0 * 32); const __m256i c1 = LoadAligned32(kSgrMaLookupAvx2 + 1 * 32); const __m256i c2 = LoadAligned32(kSgrMaLookupAvx2 + 2 * 32); const __m256i indices = _mm256_packus_epi16(index[0], index[1]); __m256i idx, mas; - // Clip idx to 127 to apply signed comparision instructions. + // Clip idx to 127 to apply signed comparison instructions. idx = _mm256_min_epu8(indices, _mm256_set1_epi8(127)); - // All elements which indices are less than 48 are set to 0. + // All elements whose indices are less than 48 are set to 0. // Get shuffle results for indices in range [0, 15]. mas = ShuffleIndex(c0, idx); // Get shuffle results for indices in range [16, 31]. 
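// Editorial cross-check (not part of the patch): the CalculateB5 rewrite
// above relies on 164 == 4 * 41, so multiplying |ma| by one_over_n_quarter
// (41) with a cheap 8-bit multiply-add and then rounding-shifting by
// kSgrProjReciprocalBits - 2 yields exactly the same b as the previous
// ma * sum * 164 >> 12 path. A scalar sketch of that equivalence follows;
// kSgrProjReciprocalBits and RightShiftWithRounding are redefined locally so
// the sketch stands alone.

#include <cassert>
#include <cstdint>

namespace {

constexpr int kSgrProjReciprocalBits = 12;

constexpr uint32_t RightShiftWithRounding(uint32_t x, int bits) {
  return (x + (1u << (bits - 1))) >> bits;
}

}  // namespace

int main() {
  constexpr uint32_t one_over_n =
      ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;          // 164
  constexpr uint32_t one_over_n_quarter = one_over_n >> 2;       // 41
  static_assert(one_over_n == one_over_n_quarter << 2, "");
  for (uint32_t ma = 0; ma <= 255; ++ma) {
    for (uint32_t sum = 0; sum <= 25 * 255; ++sum) {  // radius-2 box sums
      const uint32_t old_b = RightShiftWithRounding(ma * sum * one_over_n,
                                                    kSgrProjReciprocalBits);
      const uint32_t new_b = RightShiftWithRounding(
          ma * one_over_n_quarter * sum, kSgrProjReciprocalBits - 2);
      assert(old_b == new_b);
    }
  }
  return 0;
}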
@@ -1581,12 +1605,12 @@ inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], const __m256i res2 = ShuffleIndex(c2, idx); mas = _mm256_or_si256(mas, res2); - // For elements which indices are larger than 47, since they seldom change + // For elements whose indices are larger than 47, since they seldom change // values with the increase of the index, we use comparison and arithmetic // operations to calculate their values. - // Add -128 to apply signed comparision instructions. + // Add -128 to apply signed comparison instructions. idx = _mm256_add_epi8(indices, _mm256_set1_epi8(-128)); - // Elements which indices are larger than 47 (with value 0) are set to 5. + // Elements whose indices are larger than 47 (with value 0) are set to 5. mas = _mm256_max_epu8(mas, _mm256_set1_epi8(5)); mas = AdjustValue(mas, idx, 55); // 55 is the last index which value is 5. mas = AdjustValue(mas, idx, 72); // 72 is the last index which value is 4. @@ -1611,8 +1635,13 @@ inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). const __m256i maq0 = _mm256_unpackhi_epi8(ma[0], _mm256_setzero_si256()); const __m256i maq1 = _mm256_unpacklo_epi8(ma[1], _mm256_setzero_si256()); - b[0] = CalculateB<n>(sum[0], maq0); - b[1] = CalculateB<n>(sum[1], maq1); + if (n == 9) { + b[0] = CalculateB3(sum[0], maq0); + b[1] = CalculateB3(sum[1], maq1); + } else { + b[0] = CalculateB5(sum[0], maq0); + b[1] = CalculateB5(sum[1], maq1); + } } inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2], @@ -1903,8 +1932,8 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( __m256i b3[2][5], __m256i ma5[3], __m256i b5[5]) { const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8); const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8); - __m256i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sq3t[4][2], sq5t[5][2], - sum_3[2][2], index_3[2][2], sum_5[2], index_5[2]; + __m256i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2][2], index_3[2][2], + sum_5[2], index_5[2]; sq[0][1] = SquareLo8(s0); sq[0][2] = SquareHi8(s0); sq[1][1] = SquareLo8(s1); @@ -1938,22 +1967,22 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( LoadAligned64x3U32(square_sum5, x, sq5); CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); - SumHorizontal(sq[0] + 1, &sq3t[2][0], &sq3t[2][1], &sq5t[3][0], &sq5t[3][1]); - SumHorizontal(sq[1] + 1, &sq3t[3][0], &sq3t[3][1], &sq5t[4][0], &sq5t[4][1]); - StoreAligned64(square_sum3[2] + x + 16, sq3t[2]); - StoreAligned64(square_sum5[3] + x + 16, sq5t[3]); - StoreAligned64(square_sum3[3] + x + 16, sq3t[3]); - StoreAligned64(square_sum5[4] + x + 16, sq5t[4]); + SumHorizontal(sq[0] + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + SumHorizontal(sq[1] + 1, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned64(square_sum3[2] + x + 16, sq3[2]); + StoreAligned64(square_sum5[3] + x + 16, sq5[3]); + StoreAligned64(square_sum3[3] + x + 16, sq3[3]); + StoreAligned64(square_sum5[4] + x + 16, sq5[4]); LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); - LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3t); - CalculateSumAndIndex3(s3[1], sq3t, scales[1], &sum_3[0][1], &index_3[0][1]); - CalculateSumAndIndex3(s3[1] + 1, sq3t + 1, scales[1], &sum_3[1][1], + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[0][1], &index_3[0][1]); + CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, 
scales[1], &sum_3[1][1], &index_3[1][1]); CalculateIntermediate<9>(sum_3[0], index_3[0], ma3[0], b3[0] + 1); CalculateIntermediate<9>(sum_3[1], index_3[1], ma3[1], b3[1] + 1); LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); - LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5t); - CalculateSumAndIndex5(s5[1], sq5t, scales[0], &sum_5[1], &index_5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]); CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1); b3[0][0] = _mm256_permute2x128_si256(b3[0][0], b3[0][2], 0x21); b3[1][0] = _mm256_permute2x128_si256(b3[1][0], b3[1][2], 0x21); @@ -1988,8 +2017,8 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( __m256i sq[6], __m256i ma3[2], __m256i ma5[2], __m256i b3[5], __m256i b5[5]) { const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8); - __m256i s3[2][3], s5[2][5], sq3[4][2], sq3t[4][2], sq5[5][2], sq5t[5][2], - sum_3[2], index_3[2], sum_5[2], index_5[2]; + __m256i s3[2][3], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2], index_3[2], + sum_5[2], index_5[2]; sq[1] = SquareLo8(s0); sq[2] = SquareHi8(s0); sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); @@ -2006,17 +2035,17 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( sq5[4][1] = sq5[3][1]; CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); - SumHorizontal(sq + 1, &sq3t[2][0], &sq3t[2][1], &sq5t[3][0], &sq5t[3][1]); + SumHorizontal(sq + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); - LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3t); - CalculateSumAndIndex3(s3[1], sq3t, scales[1], &sum_3[1], &index_3[1]); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[1], &index_3[1]); CalculateIntermediate<9>(sum_3, index_3, ma3, b3 + 1); LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); s5[1][4] = s5[1][3]; - LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5t); - sq5t[4][0] = sq5t[3][0]; - sq5t[4][1] = sq5t[3][1]; - CalculateSumAndIndex5(s5[1], sq5t, scales[0], &sum_5[1], &index_5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]); CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1); b3[0] = _mm256_permute2x128_si256(b3[0], b3[2], 0x21); b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21); @@ -2071,9 +2100,9 @@ LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( uint16_t* const sum3[3], uint32_t* const square_sum3[3], const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343, uint32_t* b444) { + const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytesPass2_128 - width); __m128i ma0, sq_128[2], b0; __m256i mas[3], sq[3], bs[3]; - const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytesPass2_128 - width); sq_128[0] = SquareLo8(s); BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq_128, &ma0, &b0); sq[0] = SetrM128i(sq_128[0], sq_128[1]); @@ -2115,9 +2144,9 @@ inline void BoxSumFilterPreProcess( const uint8_t* const src0, const uint8_t* const src1, const int width, const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - const ptrdiff_t sum_width, uint16_t* const ma343[4], - uint16_t* const ma444[2], uint16_t* ma565, uint32_t* const b343[4], - uint32_t* 
const b444[2], uint32_t* b565) { + const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444, + uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444, + uint32_t* b565) { __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0; __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5]; s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); @@ -2151,9 +2180,8 @@ inline void BoxSumFilterPreProcess( Sum565W(b5, b); StoreAligned64(b565, b); Prepare3_8(ma3[1], ma3x); - Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); - Store343_444Hi(ma3x, b3[1] + 1, x + 16, ma343[1], ma444[0], b343[1], - b444[0]); + Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444); + Store343_444Hi(ma3x, b3[1] + 1, x + 16, ma343[1], ma444, b343[1], b444); Prepare3_8(ma5, ma5x); ma[0] = Sum565Lo(ma5x); ma[1] = Sum565Hi(ma5x); @@ -2199,8 +2227,9 @@ inline __m256i CalculateFilteredOutput(const __m256i src, const __m256i ma, return _mm256_packs_epi32(dst_lo, dst_hi); // 13 bits } -inline __m256i CalculateFilteredOutputPass1(const __m256i src, __m256i ma[2], - __m256i b[2][2]) { +inline __m256i CalculateFilteredOutputPass1(const __m256i src, + const __m256i ma[2], + const __m256i b[2][2]) { const __m256i ma_sum = _mm256_add_epi16(ma[0], ma[1]); __m256i b_sum[2]; b_sum[0] = _mm256_add_epi32(b[0][0], b[1][0]); @@ -2208,8 +2237,9 @@ inline __m256i CalculateFilteredOutputPass1(const __m256i src, __m256i ma[2], return CalculateFilteredOutput<5>(src, ma_sum, b_sum); } -inline __m256i CalculateFilteredOutputPass2(const __m256i src, __m256i ma[3], - __m256i b[3][2]) { +inline __m256i CalculateFilteredOutputPass2(const __m256i src, + const __m256i ma[3], + const __m256i b[3][2]) { const __m256i ma_sum = Sum3_16(ma); __m256i b_sum[2]; Sum3_32(b, b_sum); @@ -2267,13 +2297,13 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( int x = 0; do { - __m256i ma[3], ma3[3], b[2][2][2]; + __m256i ma[3], ma5[3], b[2][2][2]; BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8, x + 8 + kOverreadInBytesPass1_256 - width, sum_width, x + 8, scale, sum5, square_sum5, sq, mas, bs); - Prepare3_8(mas, ma3); - ma[1] = Sum565Lo(ma3); - ma[2] = Sum565Hi(ma3); + Prepare3_8(mas, ma5); + ma[1] = Sum565Lo(ma5); + ma[2] = Sum565Hi(ma5); StoreAligned64(ma565[1] + x, ma + 1); Sum565W(bs + 0, b[0][1]); Sum565W(bs + 1, b[1][1]); @@ -2511,9 +2541,9 @@ inline void BoxFilterLastRow( const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0, const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - uint16_t* const ma343[4], uint16_t* const ma444[3], - uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], - uint32_t* const b565[2], uint8_t* const dst) { + uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565, + uint32_t* const b343, uint32_t* const b444, uint32_t* const b565, + uint8_t* const dst) { const __m128i s0 = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); __m128i ma3_0, ma5_0, b3_0, b5_0, sq_128[2]; @@ -2542,13 +2572,13 @@ inline void BoxFilterLastRow( Sum343W(b3, b[2]); const __m256i sr = LoadUnaligned32(src + x); const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256()); - ma[0] = LoadAligned32(ma565[0] + x); - LoadAligned64(b565[0] + x, b[0]); + ma[0] = LoadAligned32(ma565 + x); + LoadAligned64(b565 + x, b[0]); p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b); - ma[0] = LoadAligned32(ma343[0] + x); - ma[1] = LoadAligned32(ma444[0] + x); - 
LoadAligned64(b343[0] + x, b[0]); - LoadAligned64(b444[0] + x, b[1]); + ma[0] = LoadAligned32(ma343 + x); + ma[1] = LoadAligned32(ma444 + x); + LoadAligned64(b343 + x, b[0]); + LoadAligned64(b444 + x, b[1]); p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b); const __m256i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2); @@ -2557,13 +2587,13 @@ inline void BoxFilterLastRow( mat[2] = Sum343Hi(ma3x); Sum343W(b3 + 1, b[2]); const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256()); - mat[0] = LoadAligned32(ma565[0] + x + 16); - LoadAligned64(b565[0] + x + 16, b[0]); + mat[0] = LoadAligned32(ma565 + x + 16); + LoadAligned64(b565 + x + 16, b[0]); p[0] = CalculateFilteredOutputPass1(sr_hi, mat, b); - mat[0] = LoadAligned32(ma343[0] + x + 16); - mat[1] = LoadAligned32(ma444[0] + x + 16); - LoadAligned64(b343[0] + x + 16, b[0]); - LoadAligned64(b444[0] + x + 16, b[1]); + mat[0] = LoadAligned32(ma343 + x + 16); + mat[1] = LoadAligned32(ma444 + x + 16); + LoadAligned64(b343 + x + 16, b[0]); + LoadAligned64(b444 + x + 16, b[1]); p[1] = CalculateFilteredOutputPass2(sr_hi, mat, b); const __m256i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2); StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); @@ -2578,8 +2608,9 @@ inline void BoxFilterLastRow( LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const RestorationUnitInfo& restoration_info, const uint8_t* src, - const uint8_t* const top_border, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, SgrBuffer* const sgr_buffer, uint8_t* dst) { const auto temp_stride = Align<ptrdiff_t>(width, 32); const auto sum_width = temp_stride + 8; @@ -2619,14 +2650,14 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( b565[1] = b565[0] + temp_stride; assert(scales[0] != 0); assert(scales[1] != 0); - BoxSum(top_border, stride, width, sum_stride, temp_stride, sum3[0], sum5[1], - square_sum3[0], square_sum5[1]); + BoxSum(top_border, top_border_stride, width, sum_stride, temp_stride, sum3[0], + sum5[1], square_sum3[0], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? 
src + stride : bottom_border; BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, - square_sum5, sum_width, ma343, ma444, ma565[0], b343, - b444, b565[0]); + square_sum5, sum_width, ma343, ma444[0], ma565[0], + b343, b444[0], b565[0]); sum5[0] = sgr_buffer->sum5 + kSumOffset; square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; @@ -2656,7 +2687,7 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -2680,19 +2711,21 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( std::swap(ma565[0], ma565[1]); std::swap(b565[0], b565[1]); } - BoxFilterLastRow(src + 3, bottom_border + stride, width, sum_width, scales, - w0, w2, sum3, sum5, square_sum3, square_sum5, ma343, ma444, - ma565, b343, b444, b565, dst); + BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width, + sum_width, scales, w0, w2, sum3, sum5, square_sum3, + square_sum5, ma343[0], ma444[0], ma565[0], b343[0], + b444[0], b565[0], dst); } } inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { const auto temp_stride = Align<ptrdiff_t>(width, 32); const auto sum_width = temp_stride + 8; const auto sum_stride = temp_stride + 32; @@ -2712,8 +2745,8 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, b565[0] = sgr_buffer->b565; b565[1] = b565[0] + temp_stride; assert(scale != 0); - BoxSum<5>(top_border, stride, width, sum_stride, temp_stride, sum5[1], - square_sum5[1]); + BoxSum<5>(top_border, top_border_stride, width, sum_stride, temp_stride, + sum5[1], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? 
src + stride : bottom_border; @@ -2739,7 +2772,7 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -2757,18 +2790,20 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, Circulate5PointersBy2<uint16_t>(sum5); Circulate5PointersBy2<uint32_t>(square_sum5); } - BoxFilterPass1LastRow(src, bottom_border + stride, width, sum_width, scale, - w0, sum5, square_sum5, ma565[0], b565[0], dst); + BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width, + sum_width, scale, w0, sum5, square_sum5, ma565[0], + b565[0], dst); } } inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { assert(restoration_info.sgr_proj_info.multiplier[0] == 0); const auto temp_stride = Align<ptrdiff_t>(width, 32); const auto sum_width = temp_stride + 8; @@ -2794,8 +2829,8 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, b444[0] = sgr_buffer->b444; b444[1] = b444[0] + temp_stride; assert(scale != 0); - BoxSum<3>(top_border, stride, width, sum_stride, temp_stride, sum3[0], - square_sum3[0]); + BoxSum<3>(top_border, top_border_stride, width, sum_stride, temp_stride, + sum3[0], square_sum3[0]); BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, sum_width, ma343[0], nullptr, b343[0], nullptr); @@ -2806,7 +2841,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, s = src + stride; } else { s = bottom_border; - bottom_border += stride; + bottom_border += bottom_border_stride; } BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width, ma343[1], ma444[0], b343[1], b444[0]); @@ -2833,7 +2868,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, square_sum3, ma343, ma444, b343, b444, dst); src += stride; dst += stride; - bottom_border += stride; + bottom_border += bottom_border_stride; Circulate3PointersBy1<uint16_t>(ma343); Circulate3PointersBy1<uint32_t>(b343); std::swap(ma444[0], ma444[1]); @@ -2841,13 +2876,14 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, } while (--y != 0); } -// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in -// the end of each row. It is safe to overwrite the output as it will not be +// If |width| is non-multiple of 32, up to 31 more pixels are written to |dest| +// in the end of each row. It is safe to overwrite the output as it will not be // part of the visible frame. 
void SelfGuidedFilter_AVX2( const RestorationUnitInfo& restoration_info, const void* const source, - const void* const top_border, const void* const bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, RestorationBuffer* const restoration_buffer, void* const dest) { const int index = restoration_info.sgr_proj_info.index; const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 @@ -2861,14 +2897,17 @@ void SelfGuidedFilter_AVX2( // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the // following assertion. assert(radius_pass_0 != 0); - BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, + width, height, sgr_buffer, dst); } else if (radius_pass_0 == 0) { - BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2, + top_border_stride, bottom - 2, bottom_border_stride, + width, height, sgr_buffer, dst); } else { - BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, - width, height, sgr_buffer, dst); + BoxFilterProcess(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, width, + height, sgr_buffer, dst); } } @@ -2891,7 +2930,7 @@ void LoopRestorationInit_AVX2() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_AVX2 +#else // !LIBGAV1_TARGETING_AVX2 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/loop_restoration_avx2.h b/src/dsp/x86/loop_restoration_avx2.h index d80227c..2c3534a 100644 --- a/src/dsp/x86/loop_restoration_avx2.h +++ b/src/dsp/x86/loop_restoration_avx2.h @@ -47,6 +47,10 @@ void LoopRestorationInit10bpp_AVX2(); #define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_AVX2 #endif +#ifndef LIBGAV1_Dsp10bpp_SelfGuidedFilter +#define LIBGAV1_Dsp10bpp_SelfGuidedFilter LIBGAV1_CPU_AVX2 +#endif + #endif // LIBGAV1_TARGETING_AVX2 #endif // LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_AVX2_H_ diff --git a/src/dsp/x86/loop_restoration_sse4.cc b/src/dsp/x86/loop_restoration_sse4.cc index 24f5ad2..273bcc8 100644 --- a/src/dsp/x86/loop_restoration_sse4.cc +++ b/src/dsp/x86/loop_restoration_sse4.cc @@ -481,13 +481,12 @@ inline void WienerVerticalTap1(const int16_t* wiener_buffer, } } -void WienerFilter_SSE4_1(const RestorationUnitInfo& restoration_info, - const void* const source, const void* const top_border, - const void* const bottom_border, - const ptrdiff_t stride, const int width, - const int height, - RestorationBuffer* const restoration_buffer, - void* const dest) { +void WienerFilter_SSE4_1( + const RestorationUnitInfo& restoration_info, const void* const source, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { const int16_t* const number_leading_zero_coefficients = restoration_info.wiener_info.number_leading_zero_coefficients; const int number_rows_to_skip = std::max( @@ -516,45 +515,48 @@ void 
WienerFilter_SSE4_1(const RestorationUnitInfo& restoration_info, const __m128i coefficients_horizontal = _mm_sub_epi16(c, _mm_setr_epi16(0, 0, 0, 128, 0, 0, 0, 0)); if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { - WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, - wiener_stride, height_extra, filter_horizontal[0], - coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3, + top_border_stride, wiener_stride, height_extra, filter_horizontal[0], coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, filter_horizontal[0], coefficients_horizontal, &wiener_buffer_horizontal); - } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { - WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, - wiener_stride, height_extra, filter_horizontal[1], + WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, + height_extra, filter_horizontal[0], coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2, + top_border_stride, wiener_stride, height_extra, filter_horizontal[1], coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, filter_horizontal[1], coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, + height_extra, filter_horizontal[1], + coefficients_horizontal, &wiener_buffer_horizontal); } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { // The maximum over-reads happen here. - WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, - wiener_stride, height_extra, filter_horizontal[2], - coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1, + top_border_stride, wiener_stride, height_extra, filter_horizontal[2], coefficients_horizontal, &wiener_buffer_horizontal); - WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, filter_horizontal[2], coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride, + height_extra, filter_horizontal[2], + coefficients_horizontal, &wiener_buffer_horizontal); } else { assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); - WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, - wiener_stride, height_extra, + WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride, + top_border_stride, wiener_stride, height_extra, &wiener_buffer_horizontal); WienerHorizontalTap1(src, stride, wiener_stride, height, &wiener_buffer_horizontal); - WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, - &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride, + height_extra, &wiener_buffer_horizontal); } // vertical filtering. 
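// Editorial note (sketch, not part of the patch): the parameter reordering in
// the hunks above is the thread running through this change. The Wiener and
// self-guided entry points now share the shape below, with the saved
// top/bottom border rows carrying their own strides so they no longer have to
// be laid out with the frame stride. The alias name and the forward
// declarations are illustrative only; the real function-pointer typedef
// presumably lives with the Dsp table in src/dsp/dsp.h.

#include <cstddef>

namespace libgav1 {
struct RestorationUnitInfo;  // forward declarations for the sketch only.
struct RestorationBuffer;

using LoopRestorationFuncSketch = void (*)(
    const RestorationUnitInfo& restoration_info, const void* source,
    ptrdiff_t stride, const void* top_border, ptrdiff_t top_border_stride,
    const void* bottom_border, ptrdiff_t bottom_border_stride, int width,
    int height, RestorationBuffer* restoration_buffer, void* dest);
}  // namespace libgav1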
@@ -1160,11 +1162,26 @@ inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2], return _mm_packus_epi32(z0, z1); } -template <int n> -inline __m128i CalculateB(const __m128i sum, const __m128i ma) { - static_assert(n == 9 || n == 25, ""); +inline __m128i CalculateB5(const __m128i sum, const __m128i ma) { + // one_over_n == 164. constexpr uint32_t one_over_n = - ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25; + // one_over_n_quarter == 41. + constexpr uint32_t one_over_n_quarter = one_over_n >> 2; + static_assert(one_over_n == one_over_n_quarter << 2, ""); + // |ma| is in range [0, 255]. + const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter)); + const __m128i m0 = VmullLo16(m, sum); + const __m128i m1 = VmullHi16(m, sum); + const __m128i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2); + const __m128i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2); + return _mm_packus_epi32(b_lo, b_hi); +} + +inline __m128i CalculateB3(const __m128i sum, const __m128i ma) { + // one_over_n == 455. + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9; const __m128i m0 = VmullLo16(ma, sum); const __m128i m1 = VmullHi16(ma, sum); const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n)); @@ -1227,12 +1244,12 @@ inline void LookupIntermediate(const __m128i sum, const __m128i index, } else { maq = _mm_unpackhi_epi8(*ma, _mm_setzero_si128()); } - *b = CalculateB<n>(sum, maq); + *b = (n == 9) ? CalculateB3(sum, maq) : CalculateB5(sum, maq); } // Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b // to get value 0 as the shuffle result. The most significiant bit 1 comes -// either from the comparision instruction, or from the sign bit of the index. +// either from the comparison instruction, or from the sign bit of the index. inline __m128i ShuffleIndex(const __m128i table, const __m128i index) { __m128i mask; mask = _mm_cmpgt_epi8(index, _mm_set1_epi8(15)); @@ -1250,15 +1267,15 @@ inline __m128i AdjustValue(const __m128i value, const __m128i index, inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], __m128i* const ma, __m128i* const b0, __m128i* const b1) { - // Use table lookup to read elements which indices are less than 48. + // Use table lookup to read elements whose indices are less than 48. const __m128i c0 = LoadAligned16(kSgrMaLookup + 0 * 16); const __m128i c1 = LoadAligned16(kSgrMaLookup + 1 * 16); const __m128i c2 = LoadAligned16(kSgrMaLookup + 2 * 16); const __m128i indices = _mm_packus_epi16(index[0], index[1]); __m128i idx; - // Clip idx to 127 to apply signed comparision instructions. + // Clip idx to 127 to apply signed comparison instructions. idx = _mm_min_epu8(indices, _mm_set1_epi8(127)); - // All elements which indices are less than 48 are set to 0. + // All elements whose indices are less than 48 are set to 0. // Get shuffle results for indices in range [0, 15]. *ma = ShuffleIndex(c0, idx); // Get shuffle results for indices in range [16, 31]. @@ -1273,12 +1290,12 @@ inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], const __m128i res2 = ShuffleIndex(c2, idx); *ma = _mm_or_si128(*ma, res2); - // For elements which indices are larger than 47, since they seldom change + // For elements whose indices are larger than 47, since they seldom change // values with the increase of the index, we use comparison and arithmetic // operations to calculate their values. 
- // Add -128 to apply signed comparision instructions. + // Add -128 to apply signed comparison instructions. idx = _mm_add_epi8(indices, _mm_set1_epi8(-128)); - // Elements which indices are larger than 47 (with value 0) are set to 5. + // Elements whose indices are larger than 47 (with value 0) are set to 5. *ma = _mm_max_epu8(*ma, _mm_set1_epi8(5)); *ma = AdjustValue(*ma, idx, 55); // 55 is the last index which value is 5. *ma = AdjustValue(*ma, idx, 72); // 72 is the last index which value is 4. @@ -1298,9 +1315,9 @@ inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). const __m128i maq0 = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); - *b0 = CalculateB<9>(sum[0], maq0); + *b0 = CalculateB3(sum[0], maq0); const __m128i maq1 = _mm_unpackhi_epi8(*ma, _mm_setzero_si128()); - *b1 = CalculateB<9>(sum[1], maq1); + *b1 = CalculateB3(sum[1], maq1); } inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], @@ -1776,9 +1793,9 @@ inline void BoxSumFilterPreProcess( const uint8_t* const src0, const uint8_t* const src1, const int width, const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - const ptrdiff_t sum_width, uint16_t* const ma343[4], - uint16_t* const ma444[2], uint16_t* ma565, uint32_t* const b343[4], - uint32_t* const b444[2], uint32_t* b565) { + const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444, + uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444, + uint32_t* b565) { __m128i s[2][2], ma3[2][2], ma5[2], sq[2][4], b3[2][3], b5[3]; s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1 - width); @@ -1808,9 +1825,8 @@ inline void BoxSumFilterPreProcess( Sum565W(b5 + 1, b + 2); StoreAligned64U32(b565, b); Prepare3_8<0>(ma3[1], ma3x); - Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); - Store343_444Hi(ma3x, b3[1] + 1, x + 8, ma343[1], ma444[0], b343[1], - b444[0]); + Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444); + Store343_444Hi(ma3x, b3[1] + 1, x + 8, ma343[1], ma444, b343[1], b444); Prepare3_8<0>(ma5, ma5x); ma[0] = Sum565Lo(ma5x); ma[1] = Sum565Hi(ma5x); @@ -1854,8 +1870,9 @@ inline __m128i CalculateFilteredOutput(const __m128i src, const __m128i ma, return _mm_packs_epi32(dst_lo, dst_hi); // 13 bits } -inline __m128i CalculateFilteredOutputPass1(const __m128i src, __m128i ma[2], - __m128i b[2][2]) { +inline __m128i CalculateFilteredOutputPass1(const __m128i src, + const __m128i ma[2], + const __m128i b[2][2]) { const __m128i ma_sum = _mm_add_epi16(ma[0], ma[1]); __m128i b_sum[2]; b_sum[0] = _mm_add_epi32(b[0][0], b[1][0]); @@ -1863,8 +1880,9 @@ inline __m128i CalculateFilteredOutputPass1(const __m128i src, __m128i ma[2], return CalculateFilteredOutput<5>(src, ma_sum, b_sum); } -inline __m128i CalculateFilteredOutputPass2(const __m128i src, __m128i ma[3], - __m128i b[3][2]) { +inline __m128i CalculateFilteredOutputPass2(const __m128i src, + const __m128i ma[3], + const __m128i b[3][2]) { const __m128i ma_sum = Sum3_16(ma); __m128i b_sum[2]; Sum3_32(b, b_sum); @@ -1916,15 +1934,15 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( int x = 0; do { - __m128i ma[2], ma3[3], b[2][2], sr[2], p[2]; + __m128i ma[2], ma5[3], b[2][2], sr[2], p[2]; s[0][1] = LoadUnaligned16Msan(src0 + x + 16, x + 16 + 
kOverreadInBytesPass1 - width); s[1][1] = LoadUnaligned16Msan(src1 + x + 16, x + 16 + kOverreadInBytesPass1 - width); BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas, bs); - Prepare3_8<0>(mas, ma3); - ma[1] = Sum565Lo(ma3); + Prepare3_8<0>(mas, ma5); + ma[1] = Sum565Lo(ma5); StoreAligned16(ma565[1] + x, ma[1]); Sum565W(bs, b[1]); StoreAligned32U32(b565[1] + x, b[1]); @@ -1939,7 +1957,7 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( const __m128i d00 = SelfGuidedSingleMultiplier(sr0_lo, p[0], w0); const __m128i d10 = SelfGuidedSingleMultiplier(sr1_lo, p[1], w0); - ma[1] = Sum565Hi(ma3); + ma[1] = Sum565Hi(ma5); StoreAligned16(ma565[1] + x + 8, ma[1]); Sum565W(bs + 1, b[1]); StoreAligned32U32(b565[1] + x + 8, b[1]); @@ -2158,9 +2176,9 @@ inline void BoxFilterLastRow( const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0, const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], - uint16_t* const ma343[4], uint16_t* const ma444[3], - uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], - uint32_t* const b565[2], uint8_t* const dst) { + uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565, + uint32_t* const b343, uint32_t* const b444, uint32_t* const b565, + uint8_t* const dst) { __m128i s[2], ma3[2], ma5[2], sq[4], b3[3], b5[3], ma[3], b[3][2]; s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); sq[0] = SquareLo8(s[0]); @@ -2183,13 +2201,13 @@ inline void BoxFilterLastRow( Sum343W(b3, b[2]); const __m128i sr = LoadAligned16(src + x); const __m128i sr_lo = _mm_unpacklo_epi8(sr, _mm_setzero_si128()); - ma[0] = LoadAligned16(ma565[0] + x); - LoadAligned32U32(b565[0] + x, b[0]); + ma[0] = LoadAligned16(ma565 + x); + LoadAligned32U32(b565 + x, b[0]); p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b); - ma[0] = LoadAligned16(ma343[0] + x); - ma[1] = LoadAligned16(ma444[0] + x); - LoadAligned32U32(b343[0] + x, b[0]); - LoadAligned32U32(b444[0] + x, b[1]); + ma[0] = LoadAligned16(ma343 + x); + ma[1] = LoadAligned16(ma444 + x); + LoadAligned32U32(b343 + x, b[0]); + LoadAligned32U32(b444 + x, b[1]); p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b); const __m128i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2); @@ -2198,13 +2216,13 @@ inline void BoxFilterLastRow( ma[2] = Sum343Hi(ma3x); Sum343W(b3 + 1, b[2]); const __m128i sr_hi = _mm_unpackhi_epi8(sr, _mm_setzero_si128()); - ma[0] = LoadAligned16(ma565[0] + x + 8); - LoadAligned32U32(b565[0] + x + 8, b[0]); + ma[0] = LoadAligned16(ma565 + x + 8); + LoadAligned32U32(b565 + x + 8, b[0]); p[0] = CalculateFilteredOutputPass1(sr_hi, ma, b); - ma[0] = LoadAligned16(ma343[0] + x + 8); - ma[1] = LoadAligned16(ma444[0] + x + 8); - LoadAligned32U32(b343[0] + x + 8, b[0]); - LoadAligned32U32(b444[0] + x + 8, b[1]); + ma[0] = LoadAligned16(ma343 + x + 8); + ma[1] = LoadAligned16(ma444 + x + 8); + LoadAligned32U32(b343 + x + 8, b[0]); + LoadAligned32U32(b444 + x + 8, b[1]); p[1] = CalculateFilteredOutputPass2(sr_hi, ma, b); const __m128i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2); StoreAligned16(dst + x, _mm_packus_epi16(d0, d1)); @@ -2220,8 +2238,9 @@ inline void BoxFilterLastRow( LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const RestorationUnitInfo& restoration_info, const uint8_t* src, - const uint8_t* const top_border, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t 
top_border_stride, const uint8_t* bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, SgrBuffer* const sgr_buffer, uint8_t* dst) { const auto temp_stride = Align<ptrdiff_t>(width, 16); const auto sum_width = Align<ptrdiff_t>(width + 8, 16); @@ -2261,14 +2280,14 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( b565[1] = b565[0] + temp_stride; assert(scales[0] != 0); assert(scales[1] != 0); - BoxSum(top_border, stride, width, sum_stride, sum_width, sum3[0], sum5[1], - square_sum3[0], square_sum5[1]); + BoxSum(top_border, top_border_stride, width, sum_stride, sum_width, sum3[0], + sum5[1], square_sum3[0], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? src + stride : bottom_border; BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, - square_sum5, sum_width, ma343, ma444, ma565[0], b343, - b444, b565[0]); + square_sum5, sum_width, ma343, ma444[0], ma565[0], + b343, b444[0], b565[0]); sum5[0] = sgr_buffer->sum5; square_sum5[0] = sgr_buffer->square_sum5; @@ -2298,7 +2317,7 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -2322,19 +2341,21 @@ LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( std::swap(ma565[0], ma565[1]); std::swap(b565[0], b565[1]); } - BoxFilterLastRow(src + 3, bottom_border + stride, width, sum_width, scales, - w0, w2, sum3, sum5, square_sum3, square_sum5, ma343, ma444, - ma565, b343, b444, b565, dst); + BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width, + sum_width, scales, w0, w2, sum3, sum5, square_sum3, + square_sum5, ma343[0], ma444[0], ma565[0], b343[0], + b444[0], b565[0], dst); } } inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { const auto temp_stride = Align<ptrdiff_t>(width, 16); const auto sum_width = Align<ptrdiff_t>(width + 8, 16); const auto sum_stride = temp_stride + 16; @@ -2354,8 +2375,8 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, b565[0] = sgr_buffer->b565; b565[1] = b565[0] + temp_stride; assert(scale != 0); - BoxSum<5>(top_border, stride, width, sum_stride, sum_width, sum5[1], - square_sum5[1]); + BoxSum<5>(top_border, top_border_stride, width, sum_stride, sum_width, + sum5[1], square_sum5[1]); sum5[0] = sum5[1]; square_sum5[0] = square_sum5[1]; const uint8_t* const s = (height > 1) ? 
src + stride : bottom_border; @@ -2381,7 +2402,7 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, const uint8_t* sr[2]; if ((height & 1) == 0) { sr[0] = bottom_border; - sr[1] = bottom_border + stride; + sr[1] = bottom_border + bottom_border_stride; } else { sr[0] = src + 2 * stride; sr[1] = bottom_border; @@ -2399,18 +2420,20 @@ inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, Circulate5PointersBy2<uint16_t>(sum5); Circulate5PointersBy2<uint32_t>(square_sum5); } - BoxFilterPass1LastRow(src, bottom_border + stride, width, sum_width, scale, - w0, sum5, square_sum5, ma565[0], b565[0], dst); + BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width, + sum_width, scale, w0, sum5, square_sum5, ma565[0], + b565[0], dst); } } inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, - const uint8_t* src, + const uint8_t* src, const ptrdiff_t stride, const uint8_t* const top_border, + const ptrdiff_t top_border_stride, const uint8_t* bottom_border, - const ptrdiff_t stride, const int width, - const int height, SgrBuffer* const sgr_buffer, - uint8_t* dst) { + const ptrdiff_t bottom_border_stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { assert(restoration_info.sgr_proj_info.multiplier[0] == 0); const auto temp_stride = Align<ptrdiff_t>(width, 16); const auto sum_width = Align<ptrdiff_t>(width + 8, 16); @@ -2436,8 +2459,8 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, b444[0] = sgr_buffer->b444; b444[1] = b444[0] + temp_stride; assert(scale != 0); - BoxSum<3>(top_border, stride, width, sum_stride, sum_width, sum3[0], - square_sum3[0]); + BoxSum<3>(top_border, top_border_stride, width, sum_stride, sum_width, + sum3[0], square_sum3[0]); BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, sum_width, ma343[0], nullptr, b343[0], nullptr); @@ -2448,7 +2471,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, s = src + stride; } else { s = bottom_border; - bottom_border += stride; + bottom_border += bottom_border_stride; } BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width, ma343[1], ma444[0], b343[1], b444[0]); @@ -2475,7 +2498,7 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, square_sum3, ma343, ma444, b343, b444, dst); src += stride; dst += stride; - bottom_border += stride; + bottom_border += bottom_border_stride; Circulate3PointersBy1<uint16_t>(ma343); Circulate3PointersBy1<uint32_t>(b343); std::swap(ma444[0], ma444[1]); @@ -2483,13 +2506,14 @@ inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, } while (--y != 0); } -// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in -// the end of each row. It is safe to overwrite the output as it will not be +// If |width| is non-multiple of 16, up to 15 more pixels are written to |dest| +// in the end of each row. It is safe to overwrite the output as it will not be // part of the visible frame. 
void SelfGuidedFilter_SSE4_1( const RestorationUnitInfo& restoration_info, const void* const source, - const void* const top_border, const void* const bottom_border, - const ptrdiff_t stride, const int width, const int height, + const ptrdiff_t stride, const void* const top_border, + const ptrdiff_t top_border_stride, const void* const bottom_border, + const ptrdiff_t bottom_border_stride, const int width, const int height, RestorationBuffer* const restoration_buffer, void* const dest) { const int index = restoration_info.sgr_proj_info.index; const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 @@ -2503,14 +2527,17 @@ void SelfGuidedFilter_SSE4_1( // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the // following assertion. assert(radius_pass_0 != 0); - BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, + width, height, sgr_buffer, dst); } else if (radius_pass_0 == 0) { - BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, - stride, width, height, sgr_buffer, dst); + BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2, + top_border_stride, bottom - 2, bottom_border_stride, + width, height, sgr_buffer, dst); } else { - BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, - width, height, sgr_buffer, dst); + BoxFilterProcess(restoration_info, src - 3, stride, top - 3, + top_border_stride, bottom - 3, bottom_border_stride, width, + height, sgr_buffer, dst); } } @@ -2538,7 +2565,7 @@ void LoopRestorationInit_SSE4_1() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/loop_restoration_sse4.h b/src/dsp/x86/loop_restoration_sse4.h index 65b2b11..00df3af 100644 --- a/src/dsp/x86/loop_restoration_sse4.h +++ b/src/dsp/x86/loop_restoration_sse4.h @@ -47,6 +47,10 @@ void LoopRestorationInit10bpp_SSE4_1(); #define LIBGAV1_Dsp10bpp_WienerFilter LIBGAV1_CPU_SSE4_1 #endif +#ifndef LIBGAV1_Dsp10bpp_SelfGuidedFilter +#define LIBGAV1_Dsp10bpp_SelfGuidedFilter LIBGAV1_CPU_SSE4_1 +#endif + #endif // LIBGAV1_TARGETING_SSE4_1 #endif // LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_SSE4_H_ diff --git a/src/dsp/x86/mask_blend_sse4.cc b/src/dsp/x86/mask_blend_sse4.cc index d8036be..2e836af 100644 --- a/src/dsp/x86/mask_blend_sse4.cc +++ b/src/dsp/x86/mask_blend_sse4.cc @@ -430,12 +430,515 @@ void Init8bpp() { } // namespace } // namespace low_bitdepth -void MaskBlendInit_SSE4_1() { low_bitdepth::Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +constexpr int kMax10bppSample = (1 << 10) - 1; +constexpr int kMaskInverse = 64; +constexpr int kRoundBitsMaskBlend = 4; + +inline __m128i RightShiftWithRoundingZero_U16(const __m128i v_val_d, int bits, + const __m128i zero) { + // Shift out all but the last bit. + const __m128i v_tmp_d = _mm_srli_epi16(v_val_d, bits - 1); + // Avg with zero will shift by 1 and round. 
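The two comments above lean on an identity worth spelling out: for bits >= 1, (v + (1 << (bits - 1))) >> bits equals ((v >> (bits - 1)) + 1) >> 1, and _mm_avg_epu16 computes exactly (a + b + 1) >> 1 per lane with a widened intermediate, so averaging the pre-shifted value with zero supplies the rounding add and the last shift without any 16-bit overflow. A minimal scalar sketch of that identity follows; the helper name is hypothetical and the snippet is self-contained, independent of the libgav1 helpers. The intrinsic form, averaging against a zero register, resumes just below.

#include <cassert>
#include <cstdint>

// Scalar model of the trick: narrow first (keep one rounding bit), then let
// the "average with zero" perform the +1 rounding and the final >> 1.
inline uint16_t RoundingShiftViaAvg(uint16_t v, int bits) {
  const uint16_t narrowed = v >> (bits - 1);           // keep one rounding bit.
  return static_cast<uint16_t>((narrowed + 1) >> 1);   // avg(narrowed, 0).
}

int main() {
  for (uint32_t v = 0; v <= 0xFFFF; ++v) {
    for (int bits = 1; bits <= 8; ++bits) {
      const uint16_t reference =
          static_cast<uint16_t>((v + (1u << (bits - 1))) >> bits);
      assert(RoundingShiftViaAvg(static_cast<uint16_t>(v), bits) == reference);
    }
  }
  return 0;
}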
+ return _mm_avg_epu16(v_tmp_d, zero); +} + +inline __m128i RightShiftWithRoundingConst_S32(const __m128i v_val_d, int bits, + const __m128i shift) { + const __m128i v_tmp_d = _mm_add_epi32(v_val_d, shift); + return _mm_srai_epi32(v_tmp_d, bits); +} + +template <int subsampling_x, int subsampling_y> +inline __m128i GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride, + const __m128i zero) { + if (subsampling_x == 1) { + if (subsampling_y == 0) { + const __m128i mask_val_0 = _mm_cvtepu8_epi16(LoadLo8(mask)); + const __m128i mask_val_1 = + _mm_cvtepu8_epi16(LoadLo8(mask + (mask_stride << subsampling_y))); + __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1); + return RightShiftWithRoundingZero_U16(subsampled_mask, 1, zero); + } + const __m128i one = _mm_set1_epi8(1); + const __m128i mask_val_0 = + LoadHi8(LoadLo8(mask), mask + (mask_stride << 1)); + const __m128i mask_val_1 = LoadHi8(LoadLo8(mask + mask_stride), + mask + (mask_stride << 1) + mask_stride); + const __m128i add = _mm_adds_epu8(mask_val_0, mask_val_1); + const __m128i subsampled_mask = _mm_maddubs_epi16(add, one); + return RightShiftWithRoundingZero_U16(subsampled_mask, 2, zero); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const __m128i mask_val_0 = Load4(mask); + const __m128i mask_val_1 = Load4(mask + mask_stride); + return _mm_cvtepu8_epi16( + _mm_or_si128(mask_val_0, _mm_slli_si128(mask_val_1, 4))); +} + +template <int subsampling_x, int subsampling_y> +inline __m128i GetMask8(const uint8_t* mask, const ptrdiff_t stride, + const __m128i zero) { + if (subsampling_x == 1) { + if (subsampling_y == 0) { + const __m128i row_vals = LoadUnaligned16(mask); + const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals); + const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8)); + __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1); + return RightShiftWithRoundingZero_U16(subsampled_mask, 1, zero); + } + const __m128i one = _mm_set1_epi8(1); + const __m128i mask_val_0 = LoadUnaligned16(mask); + const __m128i mask_val_1 = LoadUnaligned16(mask + stride); + const __m128i add_0 = _mm_adds_epu8(mask_val_0, mask_val_1); + const __m128i mask_0 = _mm_maddubs_epi16(add_0, one); + return RightShiftWithRoundingZero_U16(mask_0, 2, zero); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const __m128i mask_val = LoadLo8(mask); + return _mm_cvtepu8_epi16(mask_val); +} + +inline void WriteMaskBlendLine10bpp4x2_SSE4_1( + const uint16_t* pred_0, const uint16_t* pred_1, + const ptrdiff_t pred_stride_1, const __m128i& pred_mask_0, + const __m128i& pred_mask_1, const __m128i& offset, const __m128i& max, + const __m128i& shift4, uint16_t* dst, const ptrdiff_t dst_stride) { + const __m128i pred_val_0 = LoadUnaligned16(pred_0); + const __m128i pred_val_1 = LoadHi8(LoadLo8(pred_1), pred_1 + pred_stride_1); + + // int res = (mask_value * pred_0[x] + (64 - mask_value) * pred_1[x]) >> 6; + const __m128i compound_pred_lo_0 = _mm_mullo_epi16(pred_val_0, pred_mask_0); + const __m128i compound_pred_hi_0 = _mm_mulhi_epu16(pred_val_0, pred_mask_0); + const __m128i compound_pred_lo_1 = _mm_mullo_epi16(pred_val_1, pred_mask_1); + const __m128i compound_pred_hi_1 = _mm_mulhi_epu16(pred_val_1, pred_mask_1); + const __m128i pack0_lo = + _mm_unpacklo_epi16(compound_pred_lo_0, compound_pred_hi_0); + const __m128i pack0_hi = + _mm_unpackhi_epi16(compound_pred_lo_0, compound_pred_hi_0); + const __m128i pack1_lo = + _mm_unpacklo_epi16(compound_pred_lo_1, compound_pred_hi_1); + const __m128i pack1_hi = + 
_mm_unpackhi_epi16(compound_pred_lo_1, compound_pred_hi_1); + const __m128i compound_pred_lo = _mm_add_epi32(pack0_lo, pack1_lo); + const __m128i compound_pred_hi = _mm_add_epi32(pack0_hi, pack1_hi); + // res -= (bitdepth == 8) ? 0 : kCompoundOffset; + const __m128i sub_0 = + _mm_sub_epi32(_mm_srli_epi32(compound_pred_lo, 6), offset); + const __m128i sub_1 = + _mm_sub_epi32(_mm_srli_epi32(compound_pred_hi, 6), offset); + + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + const __m128i shift_0 = + RightShiftWithRoundingConst_S32(sub_0, kRoundBitsMaskBlend, shift4); + const __m128i shift_1 = + RightShiftWithRoundingConst_S32(sub_1, kRoundBitsMaskBlend, shift4); + const __m128i result = _mm_min_epi16(_mm_packus_epi32(shift_0, shift_1), max); + StoreLo8(dst, result); + StoreHi8(dst + dst_stride, result); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlend10bpp4x4_SSE4_1(const uint16_t* pred_0, + const uint16_t* pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* mask, + const ptrdiff_t mask_stride, uint16_t* dst, + const ptrdiff_t dst_stride) { + const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse); + const __m128i zero = _mm_setzero_si128(); + const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1); + const __m128i offset = _mm_set1_epi32(kCompoundOffset); + const __m128i max = _mm_set1_epi16(kMax10bppSample); + __m128i pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, pred_mask_0, + pred_mask_1, offset, max, shift4, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, pred_mask_0, + pred_mask_1, offset, max, shift4, dst, + dst_stride); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlend10bpp4xH_SSE4_1(const uint16_t* pred_0, + const uint16_t* pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, + const int height, uint16_t* dst, + const ptrdiff_t dst_stride) { + const uint8_t* mask = mask_ptr; + if (height == 4) { + MaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride); + return; + } + const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse); + const __m128i zero = _mm_setzero_si128(); + const uint8_t pred0_stride2 = 4 << 1; + const ptrdiff_t pred1_stride2 = pred_stride_1 << 1; + const ptrdiff_t mask_stride2 = mask_stride << (1 + subsampling_y); + const ptrdiff_t dst_stride2 = dst_stride << 1; + const __m128i offset = _mm_set1_epi32(kCompoundOffset); + const __m128i max = _mm_set1_epi16(kMax10bppSample); + const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1); + int y = height; + do { + __m128i pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + + WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, offset, max, + shift4, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + 
mask += mask_stride2; + dst += dst_stride2; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, offset, max, + shift4, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + mask += mask_stride2; + dst += dst_stride2; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, offset, max, + shift4, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + mask += mask_stride2; + dst += dst_stride2; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, offset, max, + shift4, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + mask += mask_stride2; + dst += dst_stride2; + y -= 8; + } while (y != 0); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlend10bpp_SSE4_1(const void* prediction_0, + const void* prediction_1, + const ptrdiff_t prediction_stride_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int width, + const int height, void* dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint16_t*>(dest); + const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]); + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + const ptrdiff_t pred_stride_0 = width; + const ptrdiff_t pred_stride_1 = prediction_stride_1; + if (width == 4) { + MaskBlend10bpp4xH_SSE4_1<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask_ptr, mask_stride, height, dst, + dst_stride); + return; + } + const uint8_t* mask = mask_ptr; + const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse); + const __m128i zero = _mm_setzero_si128(); + const ptrdiff_t mask_stride_ss = mask_stride << subsampling_y; + const __m128i offset = _mm_set1_epi32(kCompoundOffset); + const __m128i max = _mm_set1_epi16(kMax10bppSample); + const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1); + int y = height; + do { + int x = 0; + do { + const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride, zero); + const __m128i pred_val_0 = LoadUnaligned16(pred_0 + x); + const __m128i pred_val_1 = LoadUnaligned16(pred_1 + x); + // 64 - mask + const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + + const __m128i compound_pred_lo_0 = + _mm_mullo_epi16(pred_val_0, pred_mask_0); + const __m128i compound_pred_hi_0 = + _mm_mulhi_epu16(pred_val_0, pred_mask_0); + const __m128i compound_pred_lo_1 = + _mm_mullo_epi16(pred_val_1, pred_mask_1); + const __m128i compound_pred_hi_1 = + _mm_mulhi_epu16(pred_val_1, pred_mask_1); + const __m128i pack0_lo = + _mm_unpacklo_epi16(compound_pred_lo_0, compound_pred_hi_0); + const __m128i pack0_hi = + _mm_unpackhi_epi16(compound_pred_lo_0, compound_pred_hi_0); + const __m128i pack1_lo = + _mm_unpacklo_epi16(compound_pred_lo_1, compound_pred_hi_1); + const __m128i pack1_hi = + _mm_unpackhi_epi16(compound_pred_lo_1, compound_pred_hi_1); + const __m128i compound_pred_lo = 
_mm_add_epi32(pack0_lo, pack1_lo); + const __m128i compound_pred_hi = _mm_add_epi32(pack0_hi, pack1_hi); + + const __m128i sub_0 = + _mm_sub_epi32(_mm_srli_epi32(compound_pred_lo, 6), offset); + const __m128i sub_1 = + _mm_sub_epi32(_mm_srli_epi32(compound_pred_hi, 6), offset); + const __m128i shift_0 = + RightShiftWithRoundingConst_S32(sub_0, kRoundBitsMaskBlend, shift4); + const __m128i shift_1 = + RightShiftWithRoundingConst_S32(sub_1, kRoundBitsMaskBlend, shift4); + const __m128i result = + _mm_min_epi16(_mm_packus_epi32(shift_0, shift_1), max); + StoreUnaligned16(dst + x, result); + x += 8; + } while (x < width); + dst += dst_stride; + pred_0 += pred_stride_0; + pred_1 += pred_stride_1; + mask += mask_stride_ss; + } while (--y != 0); +} + +inline void InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1( + const uint16_t* prediction_0, const uint16_t* prediction_1, + const ptrdiff_t pred_stride_1, const __m128i& pred_mask_0, + const __m128i& pred_mask_1, const __m128i& shift6, uint16_t* dst, + const ptrdiff_t dst_stride) { + const __m128i pred_val_0 = LoadUnaligned16(prediction_0); + const __m128i pred_val_1 = + LoadHi8(LoadLo8(prediction_1), prediction_1 + pred_stride_1); + + const __m128i mask_0 = _mm_unpacklo_epi16(pred_mask_1, pred_mask_0); + const __m128i mask_1 = _mm_unpackhi_epi16(pred_mask_1, pred_mask_0); + const __m128i pred_0 = _mm_unpacklo_epi16(pred_val_0, pred_val_1); + const __m128i pred_1 = _mm_unpackhi_epi16(pred_val_0, pred_val_1); + + const __m128i compound_pred_0 = _mm_madd_epi16(pred_0, mask_0); + const __m128i compound_pred_1 = _mm_madd_epi16(pred_1, mask_1); + const __m128i shift_0 = + RightShiftWithRoundingConst_S32(compound_pred_0, 6, shift6); + const __m128i shift_1 = + RightShiftWithRoundingConst_S32(compound_pred_1, 6, shift6); + const __m128i res = _mm_packus_epi32(shift_0, shift_1); + StoreLo8(dst, res); + StoreHi8(dst + dst_stride, res); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlend10bpp4x4_SSE4_1( + const uint16_t* pred_0, const uint16_t* pred_1, + const ptrdiff_t pred_stride_1, const uint8_t* mask, + const ptrdiff_t mask_stride, uint16_t* dst, const ptrdiff_t dst_stride) { + const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse); + const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1); + const __m128i zero = _mm_setzero_si128(); + __m128i pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, shift6, + dst, dst_stride); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, shift6, + dst, dst_stride); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlend10bpp4xH_SSE4_1(const uint16_t* pred_0, + const uint16_t* pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, + const int height, uint16_t* dst, + const ptrdiff_t dst_stride) { + const uint8_t* mask = mask_ptr; + if (height == 4) { + InterIntraMaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride); + 
return; + } + const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse); + const __m128i zero = _mm_setzero_si128(); + const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1); + const uint8_t pred0_stride2 = 4 << 1; + const ptrdiff_t pred1_stride2 = pred_stride_1 << 1; + const ptrdiff_t mask_stride2 = mask_stride << (1 + subsampling_y); + const ptrdiff_t dst_stride2 = dst_stride << 1; + int y = height; + do { + __m128i pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, + shift6, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + mask += mask_stride2; + dst += dst_stride2; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, + shift6, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + mask += mask_stride2; + dst += dst_stride2; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, + shift6, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + mask += mask_stride2; + dst += dst_stride2; + + pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride, zero); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1, + shift6, dst, dst_stride); + pred_0 += pred0_stride2; + pred_1 += pred1_stride2; + mask += mask_stride2; + dst += dst_stride2; + y -= 8; + } while (y != 0); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlend10bpp_SSE4_1( + const void* prediction_0, const void* prediction_1, + const ptrdiff_t prediction_stride_1, const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int width, const int height, void* dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint16_t*>(dest); + const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]); + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + const ptrdiff_t pred_stride_0 = width; + const ptrdiff_t pred_stride_1 = prediction_stride_1; + if (width == 4) { + InterIntraMaskBlend10bpp4xH_SSE4_1<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask_ptr, mask_stride, height, dst, + dst_stride); + return; + } + const uint8_t* mask = mask_ptr; + const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse); + const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1); + const __m128i zero = _mm_setzero_si128(); + const ptrdiff_t mask_stride_ss = mask_stride << subsampling_y; + int y = height; + do { + int x = 0; + do { + const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride, zero); + const __m128i pred_val_0 = LoadUnaligned16(pred_0 + x); + const __m128i pred_val_1 = LoadUnaligned16(pred_1 + x); + // 64 - mask + const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + const __m128i mask_0 = _mm_unpacklo_epi16(pred_mask_1, pred_mask_0); + const __m128i mask_1 = 
_mm_unpackhi_epi16(pred_mask_1, pred_mask_0); + const __m128i pred_0 = _mm_unpacklo_epi16(pred_val_0, pred_val_1); + const __m128i pred_1 = _mm_unpackhi_epi16(pred_val_0, pred_val_1); + + const __m128i compound_pred_0 = _mm_madd_epi16(pred_0, mask_0); + const __m128i compound_pred_1 = _mm_madd_epi16(pred_1, mask_1); + const __m128i shift_0 = + RightShiftWithRoundingConst_S32(compound_pred_0, 6, shift6); + const __m128i shift_1 = + RightShiftWithRoundingConst_S32(compound_pred_1, 6, shift6); + StoreUnaligned16(dst + x, _mm_packus_epi32(shift_0, shift_1)); + x += 8; + } while (x < width); + dst += dst_stride; + pred_0 += pred_stride_0; + pred_1 += pred_stride_1; + mask += mask_stride_ss; + } while (--y != 0); +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + +#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend444) + dsp->mask_blend[0][0] = MaskBlend10bpp_SSE4_1<0, 0>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend422) + dsp->mask_blend[1][0] = MaskBlend10bpp_SSE4_1<1, 0>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend420) + dsp->mask_blend[2][0] = MaskBlend10bpp_SSE4_1<1, 1>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra444) + dsp->mask_blend[0][1] = InterIntraMaskBlend10bpp_SSE4_1<0, 0>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra422) + dsp->mask_blend[1][1] = InterIntraMaskBlend10bpp_SSE4_1<1, 0>; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra420) + dsp->mask_blend[2][1] = InterIntraMaskBlend10bpp_SSE4_1<1, 1>; +#endif +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void MaskBlendInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/mask_blend_sse4.h b/src/dsp/x86/mask_blend_sse4.h index 52b0b5c..4a95f0c 100644 --- a/src/dsp/x86/mask_blend_sse4.h +++ b/src/dsp/x86/mask_blend_sse4.h @@ -55,6 +55,30 @@ void MaskBlendInit_SSE4_1(); #define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_SSE4_1 #endif +#ifndef LIBGAV1_Dsp10bpp_MaskBlend444 +#define LIBGAV1_Dsp10bpp_MaskBlend444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_MaskBlend422 +#define LIBGAV1_Dsp10bpp_MaskBlend422 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_MaskBlend420 +#define LIBGAV1_Dsp10bpp_MaskBlend420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra444 +#define LIBGAV1_Dsp10bpp_MaskBlendInterIntra444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra422 +#define LIBGAV1_Dsp10bpp_MaskBlendInterIntra422 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra420 +#define LIBGAV1_Dsp10bpp_MaskBlendInterIntra420 LIBGAV1_CPU_SSE4_1 +#endif + #endif // LIBGAV1_TARGETING_SSE4_1 #endif // LIBGAV1_SRC_DSP_X86_MASK_BLEND_SSE4_H_ diff --git a/src/dsp/x86/motion_field_projection_sse4.cc b/src/dsp/x86/motion_field_projection_sse4.cc index c506941..e3f2cce 100644 --- a/src/dsp/x86/motion_field_projection_sse4.cc +++ b/src/dsp/x86/motion_field_projection_sse4.cc @@ -139,9 +139,9 @@ inline void Store(const __m128i position, const __m128i reference_offset, const ptrdiff_t offset = static_cast<int16_t>(_mm_extract_epi16(position, idx)); if ((idx & 3) == 0) { - dst_mv[offset].mv32 = _mm_cvtsi128_si32(mv); + dst_mv[offset].mv32 = 
static_cast<uint32_t>(_mm_cvtsi128_si32(mv)); } else { - dst_mv[offset].mv32 = _mm_extract_epi32(mv, idx & 3); + dst_mv[offset].mv32 = static_cast<uint32_t>(_mm_extract_epi32(mv, idx & 3)); } dst_reference_offset[offset] = _mm_extract_epi8(reference_offset, idx); } @@ -386,7 +386,7 @@ void MotionFieldProjectionInit_SSE4_1() { } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/motion_vector_search_sse4.cc b/src/dsp/x86/motion_vector_search_sse4.cc index e9cdd4c..7f5f035 100644 --- a/src/dsp/x86/motion_vector_search_sse4.cc +++ b/src/dsp/x86/motion_vector_search_sse4.cc @@ -251,7 +251,7 @@ void MotionVectorSearchInit_SSE4_1() { } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/obmc_sse4.cc b/src/dsp/x86/obmc_sse4.cc index 3a1d1fd..c34a7f7 100644 --- a/src/dsp/x86/obmc_sse4.cc +++ b/src/dsp/x86/obmc_sse4.cc @@ -31,6 +31,7 @@ namespace libgav1 { namespace dsp { +namespace low_bitdepth { namespace { #include "src/dsp/obmc.inc" @@ -311,13 +312,295 @@ void Init8bpp() { } } // namespace +} // namespace low_bitdepth -void ObmcInit_SSE4_1() { Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +#include "src/dsp/obmc.inc" + +constexpr int kRoundBitsObmcBlend = 6; + +inline void OverlapBlendFromLeft2xH_SSE4_1( + uint16_t* const prediction, const ptrdiff_t pred_stride, const int height, + const uint16_t* const obmc_prediction, const ptrdiff_t obmc_pred_stride) { + uint16_t* pred = prediction; + const uint16_t* obmc_pred = obmc_prediction; + const ptrdiff_t pred_stride2 = pred_stride << 1; + const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1; + const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040); + const __m128i mask_val = _mm_shufflelo_epi16(Load2(kObmcMask), 0x00); + // 64 - mask. + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = + _mm_cvtepi8_epi16(_mm_unpacklo_epi8(mask_val, obmc_mask_val)); + int y = height; + do { + const __m128i pred_val = Load4x2(pred, pred + pred_stride); + const __m128i obmc_pred_val = + Load4x2(obmc_pred, obmc_pred + obmc_pred_stride); + const __m128i terms = _mm_unpacklo_epi16(pred_val, obmc_pred_val); + const __m128i result = RightShiftWithRounding_U32( + _mm_madd_epi16(terms, masks), kRoundBitsObmcBlend); + const __m128i packed_result = _mm_packus_epi32(result, result); + Store4(pred, packed_result); + Store4(pred + pred_stride, _mm_srli_si128(packed_result, 4)); + pred += pred_stride2; + obmc_pred += obmc_pred_stride2; + y -= 2; + } while (y != 0); +} + +inline void OverlapBlendFromLeft4xH_SSE4_1( + uint16_t* const prediction, const ptrdiff_t pred_stride, const int height, + const uint16_t* const obmc_prediction, const ptrdiff_t obmc_pred_stride) { + uint16_t* pred = prediction; + const uint16_t* obmc_pred = obmc_prediction; + const ptrdiff_t pred_stride2 = pred_stride << 1; + const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1; + const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040); + const __m128i mask_val = Load4(kObmcMask + 2); + // 64 - mask. 
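The arithmetic these from-left kernels vectorize is a plain 64-weighted average of the two predictions with 6-bit rounding (kRoundBitsObmcBlend above). A scalar sketch is shown here; the function name is hypothetical and the per-column weights are passed in as a raw pointer (corresponding to kObmcMask + width - 2 in the code) to keep it self-contained. The SIMD version below interleaves m and 64 - m so a single _mm_madd_epi16 performs both multiplies and the add.

#include <cstddef>
#include <cstdint>

// Scalar model of the horizontal (from-left) OBMC blend in the 10bpp path:
// pred[x] = (pred[x] * m + obmc_pred[x] * (64 - m) + 32) >> 6, where m is the
// per-column weight.
void OverlapBlendFromLeftScalar(uint16_t* pred, ptrdiff_t pred_stride,
                                int width, int height,
                                const uint16_t* obmc_pred,
                                ptrdiff_t obmc_pred_stride,
                                const uint8_t* mask) {
  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      const int m = mask[x];      // weight for the existing prediction.
      pred[x] = static_cast<uint16_t>(
          (pred[x] * m + obmc_pred[x] * (64 - m) + 32) >> 6);
    }
    pred += pred_stride;
    obmc_pred += obmc_pred_stride;
  }
}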
+ const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = + _mm_cvtepi8_epi16(_mm_unpacklo_epi8(mask_val, obmc_mask_val)); + int y = height; + do { + const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + pred_stride); + const __m128i obmc_pred_val = + LoadHi8(LoadLo8(obmc_pred), obmc_pred + obmc_pred_stride); + const __m128i terms_lo = _mm_unpacklo_epi16(pred_val, obmc_pred_val); + const __m128i terms_hi = _mm_unpackhi_epi16(pred_val, obmc_pred_val); + const __m128i result_lo = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_lo, masks), kRoundBitsObmcBlend); + const __m128i result_hi = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_hi, masks), kRoundBitsObmcBlend); + const __m128i packed_result = _mm_packus_epi32(result_lo, result_hi); + StoreLo8(pred, packed_result); + StoreHi8(pred + pred_stride, packed_result); + pred += pred_stride2; + obmc_pred += obmc_pred_stride2; + y -= 2; + } while (y != 0); +} + +void OverlapBlendFromLeft10bpp_SSE4_1(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint16_t*>(prediction); + const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction); + const ptrdiff_t pred_stride = prediction_stride / sizeof(pred[0]); + const ptrdiff_t obmc_pred_stride = + obmc_prediction_stride / sizeof(obmc_pred[0]); + + if (width == 2) { + OverlapBlendFromLeft2xH_SSE4_1(pred, pred_stride, height, obmc_pred, + obmc_pred_stride); + return; + } + if (width == 4) { + OverlapBlendFromLeft4xH_SSE4_1(pred, pred_stride, height, obmc_pred, + obmc_pred_stride); + return; + } + const __m128i mask_inverter = _mm_set1_epi8(64); + const uint8_t* mask = kObmcMask + width - 2; + int x = 0; + do { + pred = static_cast<uint16_t*>(prediction) + x; + obmc_pred = static_cast<const uint16_t*>(obmc_prediction) + x; + const __m128i mask_val = LoadLo8(mask + x); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val); + const __m128i masks_lo = _mm_cvtepi8_epi16(masks); + const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8)); + int y = height; + do { + const __m128i pred_val = LoadUnaligned16(pred); + const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred); + const __m128i terms_lo = _mm_unpacklo_epi16(pred_val, obmc_pred_val); + const __m128i terms_hi = _mm_unpackhi_epi16(pred_val, obmc_pred_val); + const __m128i result_lo = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend); + const __m128i result_hi = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend); + StoreUnaligned16(pred, _mm_packus_epi32(result_lo, result_hi)); + + pred += pred_stride; + obmc_pred += obmc_pred_stride; + } while (--y != 0); + x += 8; + } while (x < width); +} + +inline void OverlapBlendFromTop2xH_SSE4_1(uint16_t* const prediction, + const ptrdiff_t pred_stride, + const int height, + const uint16_t* const obmc_prediction, + const ptrdiff_t obmc_pred_stride) { + uint16_t* pred = prediction; + const uint16_t* obmc_pred = obmc_prediction; + const __m128i mask_inverter = _mm_set1_epi16(64); + const __m128i mask_shuffler = _mm_set_epi32(0x01010101, 0x01010101, 0, 0); + const __m128i mask_preinverter = _mm_set1_epi16(-256 | 1); + const uint8_t* mask = kObmcMask + height - 2; + const int compute_height = + height - (height >> 2); // 
compute_height based on 8-bit opt + const ptrdiff_t pred_stride2 = pred_stride << 1; + const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1; + int y = 0; + do { + // First mask in the first half, second mask in the second half. + const __m128i mask_val = _mm_shuffle_epi8(Load4(mask + y), mask_shuffler); + const __m128i masks = + _mm_sub_epi8(mask_inverter, _mm_sign_epi8(mask_val, mask_preinverter)); + const __m128i masks_lo = _mm_cvtepi8_epi16(masks); + const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8)); + + const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + pred_stride); + const __m128i obmc_pred_val = + LoadHi8(LoadLo8(obmc_pred), obmc_pred + obmc_pred_stride); + const __m128i terms_lo = _mm_unpacklo_epi16(obmc_pred_val, pred_val); + const __m128i terms_hi = _mm_unpackhi_epi16(obmc_pred_val, pred_val); + const __m128i result_lo = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend); + const __m128i result_hi = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend); + const __m128i packed_result = _mm_packus_epi32(result_lo, result_hi); + + Store4(pred, packed_result); + Store4(pred + pred_stride, _mm_srli_si128(packed_result, 8)); + pred += pred_stride2; + obmc_pred += obmc_pred_stride2; + y += 2; + } while (y < compute_height); +} + +inline void OverlapBlendFromTop4xH_SSE4_1(uint16_t* const prediction, + const ptrdiff_t pred_stride, + const int height, + const uint16_t* const obmc_prediction, + const ptrdiff_t obmc_pred_stride) { + uint16_t* pred = prediction; + const uint16_t* obmc_pred = obmc_prediction; + const __m128i mask_inverter = _mm_set1_epi16(64); + const __m128i mask_shuffler = _mm_set_epi32(0x01010101, 0x01010101, 0, 0); + const __m128i mask_preinverter = _mm_set1_epi16(-256 | 1); + const uint8_t* mask = kObmcMask + height - 2; + const int compute_height = height - (height >> 2); + const ptrdiff_t pred_stride2 = pred_stride << 1; + const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1; + int y = 0; + do { + // First mask in the first half, second mask in the second half. 
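For reference, the vertical (from-top) kernels apply the same 6-bit weighted average, but the weight varies per row rather than per column and, per the compute_height expression above, only the top height - (height >> 2) rows are touched; the remaining rows keep the original prediction. The sketch below is a hypothetical scalar model under those assumptions, again taking the row weights (kObmcMask + height - 2 in the code) as a plain pointer. The SIMD loop continuing below packs (64 - m, m) byte pairs so _mm_madd_epi16 produces obmc * (64 - m) + pred * m directly.

#include <cstddef>
#include <cstdint>

// Scalar model of the vertical (from-top) OBMC blend in the 10bpp path.
void OverlapBlendFromTopScalar(uint16_t* pred, ptrdiff_t pred_stride,
                               int width, int height,
                               const uint16_t* obmc_pred,
                               ptrdiff_t obmc_pred_stride,
                               const uint8_t* mask) {
  const int compute_height = height - (height >> 2);
  for (int y = 0; y < compute_height; ++y) {
    const int m = mask[y];  // one weight per blended row.
    for (int x = 0; x < width; ++x) {
      pred[x] = static_cast<uint16_t>(
          (pred[x] * m + obmc_pred[x] * (64 - m) + 32) >> 6);
    }
    pred += pred_stride;
    obmc_pred += obmc_pred_stride;
  }
}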
+ const __m128i mask_val = _mm_shuffle_epi8(Load4(mask + y), mask_shuffler); + const __m128i masks = + _mm_sub_epi8(mask_inverter, _mm_sign_epi8(mask_val, mask_preinverter)); + const __m128i masks_lo = _mm_cvtepi8_epi16(masks); + const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8)); + + const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + pred_stride); + const __m128i obmc_pred_val = + LoadHi8(LoadLo8(obmc_pred), obmc_pred + obmc_pred_stride); + const __m128i terms_lo = _mm_unpacklo_epi16(obmc_pred_val, pred_val); + const __m128i terms_hi = _mm_unpackhi_epi16(obmc_pred_val, pred_val); + const __m128i result_lo = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend); + const __m128i result_hi = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend); + const __m128i packed_result = _mm_packus_epi32(result_lo, result_hi); + + StoreLo8(pred, packed_result); + StoreHi8(pred + pred_stride, packed_result); + pred += pred_stride2; + obmc_pred += obmc_pred_stride2; + y += 2; + } while (y < compute_height); +} + +void OverlapBlendFromTop10bpp_SSE4_1(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint16_t*>(prediction); + const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction); + const ptrdiff_t pred_stride = prediction_stride / sizeof(pred[0]); + const ptrdiff_t obmc_pred_stride = + obmc_prediction_stride / sizeof(obmc_pred[0]); + + if (width == 2) { + OverlapBlendFromTop2xH_SSE4_1(pred, pred_stride, height, obmc_pred, + obmc_pred_stride); + return; + } + if (width == 4) { + OverlapBlendFromTop4xH_SSE4_1(pred, pred_stride, height, obmc_pred, + obmc_pred_stride); + return; + } + + const __m128i mask_inverter = _mm_set1_epi8(64); + const int compute_height = height - (height >> 2); + const uint8_t* mask = kObmcMask + height - 2; + pred = static_cast<uint16_t*>(prediction); + obmc_pred = static_cast<const uint16_t*>(obmc_prediction); + int y = 0; + do { + const __m128i mask_val = _mm_set1_epi8(mask[y]); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val); + const __m128i masks_lo = _mm_cvtepi8_epi16(masks); + const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8)); + int x = 0; + do { + const __m128i pred_val = LoadUnaligned16(pred + x); + const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred + x); + const __m128i terms_lo = _mm_unpacklo_epi16(pred_val, obmc_pred_val); + const __m128i terms_hi = _mm_unpackhi_epi16(pred_val, obmc_pred_val); + const __m128i result_lo = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend); + const __m128i result_hi = RightShiftWithRounding_U32( + _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend); + StoreUnaligned16(pred + x, _mm_packus_epi32(result_lo, result_hi)); + x += 8; + } while (x < width); + pred += pred_stride; + obmc_pred += obmc_pred_stride; + } while (++y < compute_height); +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); +#if DSP_ENABLED_10BPP_SSE4_1(ObmcVertical) + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop10bpp_SSE4_1; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(ObmcHorizontal) + dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft10bpp_SSE4_1; +#endif +} + +} 
// namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void ObmcInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/obmc_sse4.h b/src/dsp/x86/obmc_sse4.h index bd8b416..448d2cf 100644 --- a/src/dsp/x86/obmc_sse4.h +++ b/src/dsp/x86/obmc_sse4.h @@ -38,6 +38,12 @@ void ObmcInit_SSE4_1(); #ifndef LIBGAV1_Dsp8bpp_ObmcHorizontal #define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_SSE4_1 #endif +#ifndef LIBGAV1_Dsp10bpp_ObmcVertical +#define LIBGAV1_Dsp10bpp_ObmcVertical LIBGAV1_CPU_SSE4_1 +#endif +#ifndef LIBGAV1_Dsp10bpp_ObmcHorizontal +#define LIBGAV1_Dsp10bpp_ObmcHorizontal LIBGAV1_CPU_SSE4_1 +#endif #endif // LIBGAV1_TARGETING_SSE4_1 #endif // LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_ diff --git a/src/dsp/x86/super_res_sse4.cc b/src/dsp/x86/super_res_sse4.cc index b2bdfd2..85d05bc 100644 --- a/src/dsp/x86/super_res_sse4.cc +++ b/src/dsp/x86/super_res_sse4.cc @@ -91,10 +91,10 @@ void SuperResCoefficients_SSE4_1(const int upscaled_width, } void SuperRes_SSE4_1(const void* const coefficients, void* const source, - const ptrdiff_t stride, const int height, + const ptrdiff_t source_stride, const int height, const int downscaled_width, const int upscaled_width, const int initial_subpixel_x, const int step, - void* const dest) { + void* const dest, const ptrdiff_t dest_stride) { auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps); auto* dst = static_cast<uint8_t*>(dest); int y = height; @@ -104,16 +104,30 @@ void SuperRes_SSE4_1(const void* const coefficients, void* const source, ExtendLine<uint8_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width, kSuperResHorizontalBorder, kSuperResHorizontalBorder); int subpixel_x = initial_subpixel_x; - // The below code calculates up to 15 extra upscaled - // pixels which will over-read up to 15 downscaled pixels in the end of each - // row. kSuperResHorizontalBorder accounts for this. + // The below code calculates up to 15 extra upscaled pixels which will + // over-read up to 15 downscaled pixels in the end of each row. + // kSuperResHorizontalPadding protects this behavior from segmentation + // faults and threading issues. int x = RightShiftWithCeiling(upscaled_width, 4); do { __m128i weighted_src[8]; for (int i = 0; i < 8; ++i, filter += 16) { - __m128i s = LoadLo8(&src[subpixel_x >> kSuperResScaleBits]); + // TODO(b/178652672): Remove Msan loads when hadd bug is resolved. + // It's fine to write uninitialized bytes outside the frame, but the + // inside-frame pixels are incorrectly labeled uninitialized if + // uninitialized values go through the hadd intrinsics. + // |src| is offset 4 pixels to the left, and there are 4 extended border + // pixels, so a difference of 0 from |downscaled_width| indicates 8 good + // bytes. A difference of 1 indicates 7 good bytes. 
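Stepping back from the MSAN bookkeeping, the per-pixel work in this loop is: take the integer part of subpixel_x as the source position, pick one of the 64 filter phases from its fractional part, apply the 8 taps, round, and clip. The scalar sketch below models that; the function and parameter names are hypothetical, and the constants are assumptions consistent with the code and the kUpscaleFilter table further down (14 fractional scale bits, the next 6 bits selecting one of 64 phases via an 8-bit shift, taps summing to 128 so the rounding shift is 7). As in the loop above, |src| is assumed to be pre-offset left by half the tap count. The MSAN-annotated loads the diff adds below exist only to keep the documented over-read from tripping the sanitizer.

#include <algorithm>
#include <cstdint>

// Scalar model of one output row of the horizontal super-resolution upscale.
template <typename Pixel>
void SuperResRowScalar(const Pixel* src, const int16_t filter[64][8],
                       int upscaled_width, int initial_subpixel_x, int step,
                       int bitdepth, Pixel* dst) {
  constexpr int kScaleBits = 14;  // assumed kSuperResScaleBits.
  constexpr int kExtraBits = 8;   // assumed kSuperResExtraBits.
  int subpixel_x = initial_subpixel_x;
  for (int x = 0; x < upscaled_width; ++x) {
    const Pixel* const taps = &src[subpixel_x >> kScaleBits];
    const int16_t* const f =
        filter[(subpixel_x & ((1 << kScaleBits) - 1)) >> kExtraBits];
    int sum = 0;
    for (int k = 0; k < 8; ++k) sum += taps[k] * f[k];
    const int rounded = (sum + 64) >> 7;  // RightShiftWithRounding(sum, 7).
    dst[x] = static_cast<Pixel>(
        std::min(std::max(rounded, 0), (1 << bitdepth) - 1));
    subpixel_x += step;
  }
}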
+ const int msan_bytes_lo = + (subpixel_x >> kSuperResScaleBits) - downscaled_width; + __m128i s = + LoadLo8Msan(&src[subpixel_x >> kSuperResScaleBits], msan_bytes_lo); subpixel_x += step; - s = LoadHi8(s, &src[subpixel_x >> kSuperResScaleBits]); + const int msan_bytes_hi = + (subpixel_x >> kSuperResScaleBits) - downscaled_width; + s = LoadHi8Msan(s, &src[subpixel_x >> kSuperResScaleBits], + msan_bytes_hi); subpixel_x += step; const __m128i f = LoadAligned16(filter); weighted_src[i] = _mm_maddubs_epi16(s, f); @@ -135,26 +149,165 @@ void SuperRes_SSE4_1(const void* const coefficients, void* const source, StoreAligned16(dst_ptr, _mm_packus_epi16(a[0], a[1])); dst_ptr += 16; } while (--x != 0); - src += stride; - dst += stride; + src += source_stride; + dst += dest_stride; } while (--y != 0); } void Init8bpp() { Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8); +#if DSP_ENABLED_8BPP_SSE4_1(SuperResCoefficients) dsp->super_res_coefficients = SuperResCoefficients_SSE4_1; +#endif // DSP_ENABLED_8BPP_SSE4_1(SuperResCoefficients) +#if DSP_ENABLED_8BPP_SSE4_1(SuperRes) dsp->super_res = SuperRes_SSE4_1; +#endif // DSP_ENABLED_8BPP_SSE4_1(SuperRes) } } // namespace } // namespace low_bitdepth -void SuperResInit_SSE4_1() { low_bitdepth::Init8bpp(); } +//------------------------------------------------------------------------------ +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +// Upscale_Filter as defined in AV1 Section 7.16 +alignas(16) const int16_t + kUpscaleFilter[kSuperResFilterShifts][kSuperResFilterTaps] = { + {0, 0, 0, 128, 0, 0, 0, 0}, {0, 0, -1, 128, 2, -1, 0, 0}, + {0, 1, -3, 127, 4, -2, 1, 0}, {0, 1, -4, 127, 6, -3, 1, 0}, + {0, 2, -6, 126, 8, -3, 1, 0}, {0, 2, -7, 125, 11, -4, 1, 0}, + {-1, 2, -8, 125, 13, -5, 2, 0}, {-1, 3, -9, 124, 15, -6, 2, 0}, + {-1, 3, -10, 123, 18, -6, 2, -1}, {-1, 3, -11, 122, 20, -7, 3, -1}, + {-1, 4, -12, 121, 22, -8, 3, -1}, {-1, 4, -13, 120, 25, -9, 3, -1}, + {-1, 4, -14, 118, 28, -9, 3, -1}, {-1, 4, -15, 117, 30, -10, 4, -1}, + {-1, 5, -16, 116, 32, -11, 4, -1}, {-1, 5, -16, 114, 35, -12, 4, -1}, + {-1, 5, -17, 112, 38, -12, 4, -1}, {-1, 5, -18, 111, 40, -13, 5, -1}, + {-1, 5, -18, 109, 43, -14, 5, -1}, {-1, 6, -19, 107, 45, -14, 5, -1}, + {-1, 6, -19, 105, 48, -15, 5, -1}, {-1, 6, -19, 103, 51, -16, 5, -1}, + {-1, 6, -20, 101, 53, -16, 6, -1}, {-1, 6, -20, 99, 56, -17, 6, -1}, + {-1, 6, -20, 97, 58, -17, 6, -1}, {-1, 6, -20, 95, 61, -18, 6, -1}, + {-2, 7, -20, 93, 64, -18, 6, -2}, {-2, 7, -20, 91, 66, -19, 6, -1}, + {-2, 7, -20, 88, 69, -19, 6, -1}, {-2, 7, -20, 86, 71, -19, 6, -1}, + {-2, 7, -20, 84, 74, -20, 7, -2}, {-2, 7, -20, 81, 76, -20, 7, -1}, + {-2, 7, -20, 79, 79, -20, 7, -2}, {-1, 7, -20, 76, 81, -20, 7, -2}, + {-2, 7, -20, 74, 84, -20, 7, -2}, {-1, 6, -19, 71, 86, -20, 7, -2}, + {-1, 6, -19, 69, 88, -20, 7, -2}, {-1, 6, -19, 66, 91, -20, 7, -2}, + {-2, 6, -18, 64, 93, -20, 7, -2}, {-1, 6, -18, 61, 95, -20, 6, -1}, + {-1, 6, -17, 58, 97, -20, 6, -1}, {-1, 6, -17, 56, 99, -20, 6, -1}, + {-1, 6, -16, 53, 101, -20, 6, -1}, {-1, 5, -16, 51, 103, -19, 6, -1}, + {-1, 5, -15, 48, 105, -19, 6, -1}, {-1, 5, -14, 45, 107, -19, 6, -1}, + {-1, 5, -14, 43, 109, -18, 5, -1}, {-1, 5, -13, 40, 111, -18, 5, -1}, + {-1, 4, -12, 38, 112, -17, 5, -1}, {-1, 4, -12, 35, 114, -16, 5, -1}, + {-1, 4, -11, 32, 116, -16, 5, -1}, {-1, 4, -10, 30, 117, -15, 4, -1}, + {-1, 3, -9, 28, 118, -14, 4, -1}, {-1, 3, -9, 25, 120, -13, 4, -1}, + {-1, 3, -8, 22, 121, -12, 4, -1}, {-1, 3, -7, 20, 122, -11, 3, -1}, + {-1, 2, -6, 18, 123, -10, 3, -1}, 
{0, 2, -6, 15, 124, -9, 3, -1}, + {0, 2, -5, 13, 125, -8, 2, -1}, {0, 1, -4, 11, 125, -7, 2, 0}, + {0, 1, -3, 8, 126, -6, 2, 0}, {0, 1, -3, 6, 127, -4, 1, 0}, + {0, 1, -2, 4, 127, -3, 1, 0}, {0, 0, -1, 2, 128, -1, 0, 0}, +}; + +void SuperResCoefficients_SSE4_1(const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const coefficients) { + auto* dst = static_cast<uint16_t*>(coefficients); + int subpixel_x = initial_subpixel_x; + int x = RightShiftWithCeiling(upscaled_width, 3); + do { + for (int i = 0; i < 8; ++i, dst += 8) { + int remainder = subpixel_x & kSuperResScaleMask; + __m128i filter = + LoadAligned16(kUpscaleFilter[remainder >> kSuperResExtraBits]); + subpixel_x += step; + StoreAligned16(dst, filter); + } + } while (--x != 0); +} + +template <int bitdepth> +void SuperRes_SSE4_1(const void* const coefficients, void* const source, + const ptrdiff_t source_stride, const int height, + const int downscaled_width, const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const dest, const ptrdiff_t dest_stride) { + auto* src = static_cast<uint16_t*>(source) - DivideBy2(kSuperResFilterTaps); + auto* dst = static_cast<uint16_t*>(dest); + int y = height; + do { + const auto* filter = static_cast<const uint16_t*>(coefficients); + uint16_t* dst_ptr = dst; + ExtendLine<uint16_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width, + kSuperResHorizontalBorder, kSuperResHorizontalPadding); + int subpixel_x = initial_subpixel_x; + // The below code calculates up to 7 extra upscaled + // pixels which will over-read up to 7 downscaled pixels in the end of each + // row. kSuperResHorizontalPadding accounts for this. + int x = RightShiftWithCeiling(upscaled_width, 3); + do { + __m128i weighted_src[8]; + for (int i = 0; i < 8; ++i, filter += 8) { + const __m128i s = + LoadUnaligned16(&src[subpixel_x >> kSuperResScaleBits]); + subpixel_x += step; + const __m128i f = LoadAligned16(filter); + weighted_src[i] = _mm_madd_epi16(s, f); + } + + __m128i a[4]; + a[0] = _mm_hadd_epi32(weighted_src[0], weighted_src[1]); + a[1] = _mm_hadd_epi32(weighted_src[2], weighted_src[3]); + a[2] = _mm_hadd_epi32(weighted_src[4], weighted_src[5]); + a[3] = _mm_hadd_epi32(weighted_src[6], weighted_src[7]); + + a[0] = _mm_hadd_epi32(a[0], a[1]); + a[1] = _mm_hadd_epi32(a[2], a[3]); + a[0] = RightShiftWithRounding_S32(a[0], kFilterBits); + a[1] = RightShiftWithRounding_S32(a[1], kFilterBits); + + // Clip the values at (1 << bd) - 1 + const __m128i clipped_16 = _mm_min_epi16( + _mm_packus_epi32(a[0], a[1]), _mm_set1_epi16((1 << bitdepth) - 1)); + StoreAligned16(dst_ptr, clipped_16); + dst_ptr += 8; + } while (--x != 0); + src += source_stride; + dst += dest_stride; + } while (--y != 0); +} + +void Init10bpp() { + Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_10BPP_SSE4_1(SuperResCoefficients) + dsp->super_res_coefficients = SuperResCoefficients_SSE4_1; +#else + static_cast<void>(SuperResCoefficients_SSE4_1); +#endif +#if DSP_ENABLED_10BPP_SSE4_1(SuperRes) + dsp->super_res = SuperRes_SSE4_1<10>; +#else + static_cast<void>(SuperRes_SSE4_1); +#endif +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void SuperResInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { 
namespace dsp { diff --git a/src/dsp/x86/super_res_sse4.h b/src/dsp/x86/super_res_sse4.h index aef5147..07a7ef4 100644 --- a/src/dsp/x86/super_res_sse4.h +++ b/src/dsp/x86/super_res_sse4.h @@ -30,9 +30,21 @@ void SuperResInit_SSE4_1(); } // namespace libgav1 #if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_SuperResCoefficients +#define LIBGAV1_Dsp8bpp_SuperResCoefficients LIBGAV1_CPU_SSE4_1 +#endif + #ifndef LIBGAV1_Dsp8bpp_SuperRes #define LIBGAV1_Dsp8bpp_SuperRes LIBGAV1_CPU_SSE4_1 #endif + +#ifndef LIBGAV1_Dsp10bpp_SuperResCoefficients +#define LIBGAV1_Dsp10bpp_SuperResCoefficients LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_SuperRes +#define LIBGAV1_Dsp10bpp_SuperRes LIBGAV1_CPU_SSE4_1 +#endif #endif // LIBGAV1_TARGETING_SSE4_1 #endif // LIBGAV1_SRC_DSP_X86_SUPER_RES_SSE4_H_ diff --git a/src/dsp/x86/transpose_sse4.h b/src/dsp/x86/transpose_sse4.h index 208b301..9726495 100644 --- a/src/dsp/x86/transpose_sse4.h +++ b/src/dsp/x86/transpose_sse4.h @@ -30,9 +30,9 @@ LIBGAV1_ALWAYS_INLINE void Transpose2x16_U16(const __m128i* const in, __m128i* const out) { // Unpack 16 bit elements. Goes from: // in[0]: 00 01 10 11 20 21 30 31 - // in[0]: 40 41 50 51 60 61 70 71 - // in[0]: 80 81 90 91 a0 a1 b0 b1 - // in[0]: c0 c1 d0 d1 e0 e1 f0 f1 + // in[1]: 40 41 50 51 60 61 70 71 + // in[2]: 80 81 90 91 a0 a1 b0 b1 + // in[3]: c0 c1 d0 d1 e0 e1 f0 f1 // to: // a0: 00 40 01 41 10 50 11 51 // a1: 20 60 21 61 30 70 31 71 diff --git a/src/dsp/x86/warp_sse4.cc b/src/dsp/x86/warp_sse4.cc index 43279ab..9ddfeac 100644 --- a/src/dsp/x86/warp_sse4.cc +++ b/src/dsp/x86/warp_sse4.cc @@ -513,7 +513,7 @@ void WarpInit_SSE4_1() { low_bitdepth::Init8bpp(); } } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/weight_mask_sse4.cc b/src/dsp/x86/weight_mask_sse4.cc index dfd5662..08a1739 100644 --- a/src/dsp/x86/weight_mask_sse4.cc +++ b/src/dsp/x86/weight_mask_sse4.cc @@ -36,47 +36,65 @@ namespace { constexpr int kRoundingBits8bpp = 4; -template <bool mask_is_inverse> -inline void WeightMask8_SSE4(const int16_t* prediction_0, - const int16_t* prediction_1, uint8_t* mask) { - const __m128i pred_0 = LoadAligned16(prediction_0); - const __m128i pred_1 = LoadAligned16(prediction_1); - const __m128i difference = RightShiftWithRounding_U16( - _mm_abs_epi16(_mm_sub_epi16(pred_0, pred_1)), kRoundingBits8bpp); - const __m128i scaled_difference = _mm_srli_epi16(difference, 4); +template <bool mask_is_inverse, bool is_store_16> +inline void WeightMask16_SSE4(const int16_t* prediction_0, + const int16_t* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const __m128i pred_00 = LoadAligned16(prediction_0); + const __m128i pred_10 = LoadAligned16(prediction_1); + const __m128i difference_0 = RightShiftWithRounding_U16( + _mm_abs_epi16(_mm_sub_epi16(pred_00, pred_10)), kRoundingBits8bpp); + const __m128i scaled_difference_0 = _mm_srli_epi16(difference_0, 4); + + const __m128i pred_01 = LoadAligned16(prediction_0 + 8); + const __m128i pred_11 = LoadAligned16(prediction_1 + 8); + const __m128i difference_1 = RightShiftWithRounding_U16( + _mm_abs_epi16(_mm_sub_epi16(pred_01, pred_11)), kRoundingBits8bpp); + const __m128i scaled_difference_1 = _mm_srli_epi16(difference_1, 4); + const __m128i difference_offset = _mm_set1_epi8(38); const __m128i adjusted_difference = - _mm_adds_epu8(_mm_packus_epi16(scaled_difference, scaled_difference), + _mm_adds_epu8(_mm_packus_epi16(scaled_difference_0, 
scaled_difference_1), difference_offset); const __m128i mask_ceiling = _mm_set1_epi8(64); const __m128i mask_value = _mm_min_epi8(adjusted_difference, mask_ceiling); if (mask_is_inverse) { const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value); - StoreLo8(mask, inverted_mask_value); + if (is_store_16) { + StoreAligned16(mask, inverted_mask_value); + } else { + StoreLo8(mask, inverted_mask_value); + StoreHi8(mask + mask_stride, inverted_mask_value); + } } else { - StoreLo8(mask, mask_value); + if (is_store_16) { + StoreAligned16(mask, mask_value); + } else { + StoreLo8(mask, mask_value); + StoreHi8(mask + mask_stride, mask_value); + } } } -#define WEIGHT8_WITHOUT_STRIDE \ - WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask) +#define WEIGHT8_PAIR_WITHOUT_STRIDE \ + WeightMask16_SSE4<mask_is_inverse, false>(pred_0, pred_1, mask, mask_stride) -#define WEIGHT8_AND_STRIDE \ - WEIGHT8_WITHOUT_STRIDE; \ - pred_0 += 8; \ - pred_1 += 8; \ - mask += mask_stride +#define WEIGHT8_PAIR_AND_STRIDE \ + WEIGHT8_PAIR_WITHOUT_STRIDE; \ + pred_0 += 8 << 1; \ + pred_1 += 8 << 1; \ + mask += mask_stride << 1 template <bool mask_is_inverse> void WeightMask8x8_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y = 0; - do { - WEIGHT8_AND_STRIDE; - } while (++y < 7); - WEIGHT8_WITHOUT_STRIDE; + + WEIGHT8_PAIR_AND_STRIDE; + WEIGHT8_PAIR_AND_STRIDE; + WEIGHT8_PAIR_AND_STRIDE; + WEIGHT8_PAIR_WITHOUT_STRIDE; } template <bool mask_is_inverse> @@ -84,13 +102,13 @@ void WeightMask8x16_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y3 = 0; + int y3 = 3; do { - WEIGHT8_AND_STRIDE; - WEIGHT8_AND_STRIDE; - WEIGHT8_AND_STRIDE; - } while (++y3 < 5); - WEIGHT8_WITHOUT_STRIDE; + WEIGHT8_PAIR_AND_STRIDE; + WEIGHT8_PAIR_AND_STRIDE; + } while (--y3 != 0); + WEIGHT8_PAIR_AND_STRIDE; + WEIGHT8_PAIR_WITHOUT_STRIDE; } template <bool mask_is_inverse> @@ -98,21 +116,17 @@ void WeightMask8x32_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y5 = 0; + int y5 = 5; do { - WEIGHT8_AND_STRIDE; - WEIGHT8_AND_STRIDE; - WEIGHT8_AND_STRIDE; - WEIGHT8_AND_STRIDE; - WEIGHT8_AND_STRIDE; - } while (++y5 < 6); - WEIGHT8_AND_STRIDE; - WEIGHT8_WITHOUT_STRIDE; + WEIGHT8_PAIR_AND_STRIDE; + WEIGHT8_PAIR_AND_STRIDE; + WEIGHT8_PAIR_AND_STRIDE; + } while (--y5 != 0); + WEIGHT8_PAIR_WITHOUT_STRIDE; } -#define WEIGHT16_WITHOUT_STRIDE \ - WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8) +#define WEIGHT16_WITHOUT_STRIDE \ + WeightMask16_SSE4<mask_is_inverse, true>(pred_0, pred_1, mask, mask_stride) #define WEIGHT16_AND_STRIDE \ WEIGHT16_WITHOUT_STRIDE; \ @@ -125,10 +139,10 @@ void WeightMask16x8_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y = 0; + int y = 7; do { WEIGHT16_AND_STRIDE; - } while (++y < 7); + } while (--y != 0); 
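All of the restructured kernels above reduce to the same per-pixel rule, which is easiest to read in scalar form: round the absolute difference of the two compound predictions by kRoundingBits8bpp (4), scale it by a further >> 4, add the offset 38, and cap the result at 64; the inverse mask is 64 minus that value. The sketch below is a hypothetical scalar model of the 8bpp case that ignores the SIMD saturating packs (the 10bpp variant later in this file uses kRoundingBits10bpp = 6 in place of 4). The WEIGHT16_WITHOUT_STRIDE macro that follows simply emits the final row without advancing the strides.

#include <algorithm>
#include <cstdint>
#include <cstdlib>

// Scalar model of one mask value produced by WeightMask16_SSE4 (8bpp).
inline uint8_t WeightMaskPixel8bpp(int16_t pred_0, int16_t pred_1,
                                   bool mask_is_inverse) {
  const int difference = (std::abs(pred_0 - pred_1) + 8) >> 4;  // rounded >> 4.
  const int mask_value = std::min(64, 38 + (difference >> 4));
  return static_cast<uint8_t>(mask_is_inverse ? 64 - mask_value : mask_value);
}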
WEIGHT16_WITHOUT_STRIDE; } @@ -137,12 +151,12 @@ void WeightMask16x16_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y3 = 0; + int y3 = 5; do { WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; - } while (++y3 < 5); + } while (--y3 != 0); WEIGHT16_WITHOUT_STRIDE; } @@ -151,14 +165,14 @@ void WeightMask16x32_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y5 = 0; + int y5 = 6; do { WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; - } while (++y5 < 6); + } while (--y5 != 0); WEIGHT16_AND_STRIDE; WEIGHT16_WITHOUT_STRIDE; } @@ -168,20 +182,19 @@ void WeightMask16x64_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y3 = 0; + int y3 = 21; do { WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; WEIGHT16_AND_STRIDE; - } while (++y3 < 21); + } while (--y3 != 0); WEIGHT16_WITHOUT_STRIDE; } -#define WEIGHT32_WITHOUT_STRIDE \ - WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24) +#define WEIGHT32_WITHOUT_STRIDE \ + WeightMask16_SSE4<mask_is_inverse, true>(pred_0, pred_1, mask, mask_stride); \ + WeightMask16_SSE4<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \ + mask + 16, mask_stride) #define WEIGHT32_AND_STRIDE \ WEIGHT32_WITHOUT_STRIDE; \ @@ -209,12 +222,12 @@ void WeightMask32x16_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y3 = 0; + int y3 = 5; do { WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; - } while (++y3 < 5); + } while (--y3 != 0); WEIGHT32_WITHOUT_STRIDE; } @@ -223,14 +236,14 @@ void WeightMask32x32_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y5 = 0; + int y5 = 6; do { WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; - } while (++y5 < 6); + } while (--y5 != 0); WEIGHT32_AND_STRIDE; WEIGHT32_WITHOUT_STRIDE; } @@ -240,24 +253,23 @@ void WeightMask32x64_SSE4(const void* prediction_0, const void* prediction_1, uint8_t* mask, ptrdiff_t mask_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int y3 = 0; + int y3 = 21; do { WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; WEIGHT32_AND_STRIDE; - } while (++y3 < 21); + } while (--y3 != 0); WEIGHT32_WITHOUT_STRIDE; } -#define WEIGHT64_WITHOUT_STRIDE \ - WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \ - 
WeightMask8_SSE4<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \ - WeightMask8_SSE4<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56) +#define WEIGHT64_WITHOUT_STRIDE \ + WeightMask16_SSE4<mask_is_inverse, true>(pred_0, pred_1, mask, mask_stride); \ + WeightMask16_SSE4<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \ + mask + 16, mask_stride); \ + WeightMask16_SSE4<mask_is_inverse, true>(pred_0 + 32, pred_1 + 32, \ + mask + 32, mask_stride); \ + WeightMask16_SSE4<mask_is_inverse, true>(pred_0 + 48, pred_1 + 48, \ + mask + 48, mask_stride) #define WEIGHT64_AND_STRIDE \ WEIGHT64_WITHOUT_STRIDE; \ @@ -447,12 +459,491 @@ void Init8bpp() { } // namespace } // namespace low_bitdepth -void WeightMaskInit_SSE4_1() { low_bitdepth::Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +constexpr int kRoundingBits10bpp = 6; +constexpr int kScaledDiffShift = 4; + +template <bool mask_is_inverse, bool is_store_16> +inline void WeightMask16_10bpp_SSE4(const uint16_t* prediction_0, + const uint16_t* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const __m128i diff_offset = _mm_set1_epi8(38); + const __m128i mask_ceiling = _mm_set1_epi8(64); + const __m128i zero = _mm_setzero_si128(); + + // Range of prediction: [3988, 61532]. + const __m128i pred_00 = LoadAligned16(prediction_0); + const __m128i pred_10 = LoadAligned16(prediction_1); + const __m128i pred_lo_00 = _mm_cvtepu16_epi32(pred_00); + const __m128i pred_lo_10 = _mm_cvtepu16_epi32(pred_10); + const __m128i diff_lo_0 = RightShiftWithRounding_U32( + _mm_abs_epi32(_mm_sub_epi32(pred_lo_00, pred_lo_10)), kRoundingBits10bpp); + + const __m128i pred_hi_00 = _mm_unpackhi_epi16(pred_00, zero); + const __m128i pred_hi_10 = _mm_unpackhi_epi16(pred_10, zero); + const __m128i diff_hi_0 = RightShiftWithRounding_U32( + _mm_abs_epi32(_mm_sub_epi32(pred_hi_00, pred_hi_10)), kRoundingBits10bpp); + + const __m128i diff_0 = _mm_packus_epi32(diff_lo_0, diff_hi_0); + const __m128i scaled_diff_0 = _mm_srli_epi16(diff_0, kScaledDiffShift); + + const __m128i pred_01 = LoadAligned16(prediction_0 + 8); + const __m128i pred_11 = LoadAligned16(prediction_1 + 8); + const __m128i pred_lo_01 = _mm_cvtepu16_epi32(pred_01); + const __m128i pred_lo_11 = _mm_cvtepu16_epi32(pred_11); + const __m128i diff_lo_1 = RightShiftWithRounding_U32( + _mm_abs_epi32(_mm_sub_epi32(pred_lo_01, pred_lo_11)), kRoundingBits10bpp); + + const __m128i pred_hi_01 = _mm_unpackhi_epi16(pred_01, zero); + const __m128i pred_hi_11 = _mm_unpackhi_epi16(pred_11, zero); + const __m128i diff_hi_1 = RightShiftWithRounding_U32( + _mm_abs_epi32(_mm_sub_epi32(pred_hi_01, pred_hi_11)), kRoundingBits10bpp); + + const __m128i diff_1 = _mm_packus_epi32(diff_lo_1, diff_hi_1); + const __m128i scaled_diff_1 = _mm_srli_epi16(diff_1, kScaledDiffShift); + + const __m128i adjusted_diff = _mm_adds_epu8( + _mm_packus_epi16(scaled_diff_0, scaled_diff_1), diff_offset); + const __m128i mask_value = _mm_min_epi8(adjusted_diff, mask_ceiling); + + if (mask_is_inverse) { + const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value); + if (is_store_16) { + StoreAligned16(mask, inverted_mask_value); + } else { + StoreLo8(mask, inverted_mask_value); + StoreHi8(mask + 
mask_stride, inverted_mask_value); + } + } else { + if (is_store_16) { + StoreAligned16(mask, mask_value); + } else { + StoreLo8(mask, mask_value); + StoreHi8(mask + mask_stride, mask_value); + } + } +} + +#define WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP \ + WeightMask16_10bpp_SSE4<mask_is_inverse, false>(pred_0, pred_1, mask, \ + mask_stride) + +#define WEIGHT8_PAIR_AND_STRIDE_10BPP \ + WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP; \ + pred_0 += 8 << 1; \ + pred_1 += 8 << 1; \ + mask += mask_stride << 1 + +template <bool mask_is_inverse> +void WeightMask8x8_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + + WEIGHT8_PAIR_AND_STRIDE_10BPP; + WEIGHT8_PAIR_AND_STRIDE_10BPP; + WEIGHT8_PAIR_AND_STRIDE_10BPP; + WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask8x16_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 3; + do { + WEIGHT8_PAIR_AND_STRIDE_10BPP; + WEIGHT8_PAIR_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT8_PAIR_AND_STRIDE_10BPP; + WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask8x32_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y5 = 5; + do { + WEIGHT8_PAIR_AND_STRIDE_10BPP; + WEIGHT8_PAIR_AND_STRIDE_10BPP; + WEIGHT8_PAIR_AND_STRIDE_10BPP; + } while (--y5 != 0); + WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP; +} + +#define WEIGHT16_WITHOUT_STRIDE_10BPP \ + WeightMask16_10bpp_SSE4<mask_is_inverse, true>(pred_0, pred_1, mask, \ + mask_stride) + +#define WEIGHT16_AND_STRIDE_10BPP \ + WEIGHT16_WITHOUT_STRIDE_10BPP; \ + pred_0 += 16; \ + pred_1 += 16; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask16x8_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y = 7; + do { + WEIGHT16_AND_STRIDE_10BPP; + } while (--y != 0); + WEIGHT16_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask16x16_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 5; + do { + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT16_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask16x32_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y5 = 6; + do { + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + } while (--y5 != 0); + WEIGHT16_AND_STRIDE_10BPP; + 
WEIGHT16_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask16x64_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 21; + do { + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + WEIGHT16_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT16_WITHOUT_STRIDE_10BPP; +} + +#define WEIGHT32_WITHOUT_STRIDE_10BPP \ + WeightMask16_10bpp_SSE4<mask_is_inverse, true>(pred_0, pred_1, mask, \ + mask_stride); \ + WeightMask16_10bpp_SSE4<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \ + mask + 16, mask_stride) + +#define WEIGHT32_AND_STRIDE_10BPP \ + WEIGHT32_WITHOUT_STRIDE_10BPP; \ + pred_0 += 32; \ + pred_1 += 32; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask32x8_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask32x16_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 5; + do { + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT32_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask32x32_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y5 = 6; + do { + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + } while (--y5 != 0); + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask32x64_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 21; + do { + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + WEIGHT32_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT32_WITHOUT_STRIDE_10BPP; +} + +#define WEIGHT64_WITHOUT_STRIDE_10BPP \ + WeightMask16_10bpp_SSE4<mask_is_inverse, true>(pred_0, pred_1, mask, \ + mask_stride); \ + WeightMask16_10bpp_SSE4<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \ + mask + 16, mask_stride); \ + WeightMask16_10bpp_SSE4<mask_is_inverse, true>(pred_0 + 32, pred_1 + 32, \ + mask + 32, mask_stride); \ + WeightMask16_10bpp_SSE4<mask_is_inverse, true>(pred_0 + 48, pred_1 + 48, \ + mask + 48, mask_stride) + +#define WEIGHT64_AND_STRIDE_10BPP \ + WEIGHT64_WITHOUT_STRIDE_10BPP; \ + pred_0 += 64; \ + pred_1 += 64; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask64x16_10bpp_SSE4(const void* 
prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 5; + do { + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT64_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask64x32_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y5 = 6; + do { + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + } while (--y5 != 0); + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask64x64_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 21; + do { + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT64_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask64x128_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 42; + do { + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_AND_STRIDE_10BPP; + } while (--y3 != 0); + WEIGHT64_AND_STRIDE_10BPP; + WEIGHT64_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask128x64_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 21; + const ptrdiff_t adjusted_mask_stride = mask_stride - 64; + do { + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + } while (--y3 != 0); + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; +} + +template <bool mask_is_inverse> +void WeightMask128x128_10bpp_SSE4(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + int y3 = 42; + const ptrdiff_t adjusted_mask_stride = mask_stride - 64; + do { + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 
64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + } while (--y3 != 0); + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE_10BPP; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE_10BPP; +} + +#define INIT_WEIGHT_MASK_10BPP(width, height, w_index, h_index) \ + dsp->weight_mask[w_index][h_index][0] = \ + WeightMask##width##x##height##_10bpp_SSE4<0>; \ + dsp->weight_mask[w_index][h_index][1] = \ + WeightMask##width##x##height##_10bpp_SSE4<1> +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + INIT_WEIGHT_MASK_10BPP(8, 8, 0, 0); + INIT_WEIGHT_MASK_10BPP(8, 16, 0, 1); + INIT_WEIGHT_MASK_10BPP(8, 32, 0, 2); + INIT_WEIGHT_MASK_10BPP(16, 8, 1, 0); + INIT_WEIGHT_MASK_10BPP(16, 16, 1, 1); + INIT_WEIGHT_MASK_10BPP(16, 32, 1, 2); + INIT_WEIGHT_MASK_10BPP(16, 64, 1, 3); + INIT_WEIGHT_MASK_10BPP(32, 8, 2, 0); + INIT_WEIGHT_MASK_10BPP(32, 16, 2, 1); + INIT_WEIGHT_MASK_10BPP(32, 32, 2, 2); + INIT_WEIGHT_MASK_10BPP(32, 64, 2, 3); + INIT_WEIGHT_MASK_10BPP(64, 16, 3, 1); + INIT_WEIGHT_MASK_10BPP(64, 32, 3, 2); + INIT_WEIGHT_MASK_10BPP(64, 64, 3, 3); + INIT_WEIGHT_MASK_10BPP(64, 128, 3, 4); + INIT_WEIGHT_MASK_10BPP(128, 64, 4, 3); + INIT_WEIGHT_MASK_10BPP(128, 128, 4, 4); +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void WeightMaskInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_TARGETING_SSE4_1 +#else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { diff --git a/src/dsp/x86/weight_mask_sse4.h b/src/dsp/x86/weight_mask_sse4.h index 07636b7..e5d9d70 100644 --- a/src/dsp/x86/weight_mask_sse4.h +++ b/src/dsp/x86/weight_mask_sse4.h @@ -99,6 +99,73 @@ void WeightMaskInit_SSE4_1(); #define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_SSE4_1 #endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x8 +#define LIBGAV1_Dsp10bpp_WeightMask_8x8 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x16 +#define LIBGAV1_Dsp10bpp_WeightMask_8x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x32 +#define LIBGAV1_Dsp10bpp_WeightMask_8x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x8 +#define LIBGAV1_Dsp10bpp_WeightMask_16x8 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x16 +#define LIBGAV1_Dsp10bpp_WeightMask_16x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x32 +#define LIBGAV1_Dsp10bpp_WeightMask_16x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x64 +#define LIBGAV1_Dsp10bpp_WeightMask_16x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x8 +#define LIBGAV1_Dsp10bpp_WeightMask_32x8 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x16 +#define LIBGAV1_Dsp10bpp_WeightMask_32x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x32 +#define LIBGAV1_Dsp10bpp_WeightMask_32x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x64 +#define 
LIBGAV1_Dsp10bpp_WeightMask_32x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x16 +#define LIBGAV1_Dsp10bpp_WeightMask_64x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x32 +#define LIBGAV1_Dsp10bpp_WeightMask_64x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x64 +#define LIBGAV1_Dsp10bpp_WeightMask_64x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x128 +#define LIBGAV1_Dsp10bpp_WeightMask_64x128 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_128x64 +#define LIBGAV1_Dsp10bpp_WeightMask_128x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WeightMask_128x128 +#define LIBGAV1_Dsp10bpp_WeightMask_128x128 LIBGAV1_CPU_SSE4_1 +#endif #endif  // LIBGAV1_TARGETING_SSE4_1 #endif  // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_SSE4_H_
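
For reference, the per-pixel math vectorized by WeightMask16_SSE4 in the 8bpp path above reduces to the scalar form below. This is an illustrative sketch built from the constants visible in the diff (kRoundingBits8bpp = 4, the >> 4 scaling, the +38 offset, the 64 mask ceiling); the helper name is hypothetical, not libgav1 code.

#include <algorithm>
#include <cstdint>
#include <cstdlib>

// Scalar sketch of the 8bpp weight-mask computation for one pixel.
inline uint8_t WeightMaskPixel8bpp(int16_t pred_0, int16_t pred_1,
                                   bool mask_is_inverse) {
  constexpr int kRoundingBits8bpp = 4;
  // Rounded absolute difference of the two compound predictions.
  const int difference =
      (std::abs(pred_0 - pred_1) + (1 << (kRoundingBits8bpp - 1))) >>
      kRoundingBits8bpp;
  // Scale, bias by 38, and clamp to the 64 mask ceiling.
  const int mask_value = std::min((difference >> 4) + 38, 64);
  return static_cast<uint8_t>(mask_is_inverse ? 64 - mask_value : mask_value);
}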
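
The 10bpp kernel WeightMask16_10bpp_SSE4 follows the same shape but widens to 32 bits before rounding, since the compound predictions are unsigned 16-bit values. A scalar sketch under the same assumptions (kRoundingBits10bpp = 6, kScaledDiffShift = 4, same offset and ceiling; helper name hypothetical, same includes as the previous sketch):

// Scalar sketch of the 10bpp weight-mask computation for one pixel.
inline uint8_t WeightMaskPixel10bpp(uint16_t pred_0, uint16_t pred_1,
                                    bool mask_is_inverse) {
  constexpr int kRoundingBits10bpp = 6;
  constexpr int kScaledDiffShift = 4;
  const int diff =
      std::abs(static_cast<int>(pred_0) - static_cast<int>(pred_1));
  const int rounded =
      (diff + (1 << (kRoundingBits10bpp - 1))) >> kRoundingBits10bpp;
  const int mask_value = std::min((rounded >> kScaledDiffShift) + 38, 64);
  return static_cast<uint8_t>(mask_is_inverse ? 64 - mask_value : mask_value);
}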
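
The is_store_16 = false specialization handles 8-wide blocks two rows per call: the two 8-sample prediction rows are contiguous in memory, so one 16-lane computation feeds StoreLo8 for the current mask row and StoreHi8 for the next, and the WEIGHT8_PAIR_AND_STRIDE macros then advance the predictions by 16 samples and the mask by two strides. A scalar sketch of that pairing, reusing the WeightMaskPixel8bpp helper sketched above (illustrative only):

// Processes rows y and y+1 of an 8-wide block in one call.
inline void WeightMask8RowPair_Sketch(const int16_t* pred_0,
                                      const int16_t* pred_1, uint8_t* mask,
                                      ptrdiff_t mask_stride,
                                      bool mask_is_inverse) {
  for (int x = 0; x < 8; ++x) {
    // Row y comes from the first 8 prediction samples...
    mask[x] = WeightMaskPixel8bpp(pred_0[x], pred_1[x], mask_is_inverse);
    // ...row y+1 from the next 8, written one mask stride below.
    mask[mask_stride + x] =
        WeightMaskPixel8bpp(pred_0[8 + x], pred_1[8 + x], mask_is_inverse);
  }
}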
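
The Init10bpp table setup above is generated by INIT_WEIGHT_MASK_10BPP. As a concrete example, INIT_WEIGHT_MASK_10BPP(16, 16, 1, 1) expands to the two assignments below, registering the regular and inverse-mask variants for the 16x16 block size:

dsp->weight_mask[1][1][0] = WeightMask16x16_10bpp_SSE4<0>;
dsp->weight_mask[1][1][1] = WeightMask16x16_10bpp_SSE4<1>;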