diff options
Diffstat (limited to 'src/dsp/x86/distance_weighted_blend_sse4.cc')
-rw-r--r-- | src/dsp/x86/distance_weighted_blend_sse4.cc | 223 |
1 files changed, 221 insertions, 2 deletions
diff --git a/src/dsp/x86/distance_weighted_blend_sse4.cc b/src/dsp/x86/distance_weighted_blend_sse4.cc
index deb57ef..3c29b19 100644
--- a/src/dsp/x86/distance_weighted_blend_sse4.cc
+++ b/src/dsp/x86/distance_weighted_blend_sse4.cc
@@ -30,6 +30,7 @@
 
 namespace libgav1 {
 namespace dsp {
+namespace low_bitdepth {
 namespace {
 
 constexpr int kInterPostRoundBit = 4;
@@ -212,13 +213,231 @@ void Init8bpp() {
 }
 
 }  // namespace
+}  // namespace low_bitdepth
 
-void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+constexpr int kMax10bppSample = (1 << 10) - 1;
+constexpr int kInterPostRoundBit = 4;
+
+inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
+                                       const __m128i& pred1,
+                                       const __m128i& weight0,
+                                       const __m128i& weight1) {
+  // This offset is a combination of round_factor and round_offset
+  // which are to be added and subtracted respectively.
+  // Here kInterPostRoundBit + 4 is considering bitdepth=10.
+  constexpr int offset =
+      (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i bias = _mm_set1_epi32(offset);
+  const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
+
+  __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
+  __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
+  __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
+  __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
+  __m128i sum = _mm_add_epi32(mult0, mult1);
+  sum = _mm_add_epi32(sum, bias);
+  const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
+
+  prediction0 = _mm_unpackhi_epi16(pred0, zero);
+  mult0 = _mm_mullo_epi32(prediction0, weight0);
+  prediction1 = _mm_unpackhi_epi16(pred1, zero);
+  mult1 = _mm_mullo_epi32(prediction1, weight1);
+  sum = _mm_add_epi32(mult0, mult1);
+  sum = _mm_add_epi32(sum, bias);
+  const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
+  const __m128i pack = _mm_packus_epi32(result0, result1);
+
+  return _mm_min_epi16(pack, clip_high);
+}
+
+template <int height>
+inline void DistanceWeightedBlend4xH_SSE4_1(
+    const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint16_t*>(dest);
+  const __m128i weight0 = _mm_set1_epi32(weight_0);
+  const __m128i weight1 = _mm_set1_epi32(weight_1);
+
+  int y = height;
+  do {
+    const __m128i src_00 = LoadLo8(pred_0);
+    const __m128i src_10 = LoadLo8(pred_1);
+    pred_0 += 4;
+    pred_1 += 4;
+    __m128i src_0 = LoadHi8(src_00, pred_0);
+    __m128i src_1 = LoadHi8(src_10, pred_1);
+    pred_0 += 4;
+    pred_1 += 4;
+    const __m128i res0 =
+        ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
+
+    const __m128i src_01 = LoadLo8(pred_0);
+    const __m128i src_11 = LoadLo8(pred_1);
+    pred_0 += 4;
+    pred_1 += 4;
+    src_0 = LoadHi8(src_01, pred_0);
+    src_1 = LoadHi8(src_11, pred_1);
+    pred_0 += 4;
+    pred_1 += 4;
+    const __m128i res1 =
+        ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
+
+    StoreLo8(dst, res0);
+    dst += dest_stride;
+    StoreHi8(dst, res0);
+    dst += dest_stride;
+    StoreLo8(dst, res1);
+    dst += dest_stride;
+    StoreHi8(dst, res1);
+    dst += dest_stride;
+    y -= 4;
+  } while (y != 0);
+}
+
+template <int height>
+inline void DistanceWeightedBlend8xH_SSE4_1(
+    const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint16_t*>(dest);
+  const __m128i weight0 = _mm_set1_epi32(weight_0);
+  const __m128i weight1 = _mm_set1_epi32(weight_1);
+
+  int y = height;
+  do {
+    const __m128i src_00 = LoadAligned16(pred_0);
+    const __m128i src_10 = LoadAligned16(pred_1);
+    pred_0 += 8;
+    pred_1 += 8;
+    const __m128i res0 =
+        ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
+
+    const __m128i src_01 = LoadAligned16(pred_0);
+    const __m128i src_11 = LoadAligned16(pred_1);
+    pred_0 += 8;
+    pred_1 += 8;
+    const __m128i res1 =
+        ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
+
+    StoreUnaligned16(dst, res0);
+    dst += dest_stride;
+    StoreUnaligned16(dst, res1);
+    dst += dest_stride;
+    y -= 2;
+  } while (y != 0);
+}
+
+inline void DistanceWeightedBlendLarge_SSE4_1(
+    const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, const int width, const int height, void* const dest,
+    const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint16_t*>(dest);
+  const __m128i weight0 = _mm_set1_epi32(weight_0);
+  const __m128i weight1 = _mm_set1_epi32(weight_1);
+
+  int y = height;
+  do {
+    int x = 0;
+    do {
+      const __m128i src_0_lo = LoadAligned16(pred_0 + x);
+      const __m128i src_1_lo = LoadAligned16(pred_1 + x);
+      const __m128i res_lo =
+          ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
+
+      const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
+      const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
+      const __m128i res_hi =
+          ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
+
+      StoreUnaligned16(dst + x, res_lo);
+      x += 8;
+      StoreUnaligned16(dst + x, res_hi);
+      x += 8;
+    } while (x < width);
+    dst += dest_stride;
+    pred_0 += width;
+    pred_1 += width;
+  } while (--y != 0);
+}
+
+void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
+                                  const void* prediction_1,
+                                  const uint8_t weight_0,
+                                  const uint8_t weight_1, const int width,
+                                  const int height, void* const dest,
+                                  const ptrdiff_t dest_stride) {
+  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
+  const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
+  if (width == 4) {
+    if (height == 4) {
+      DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+                                         dest, dst_stride);
+    } else if (height == 8) {
+      DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+                                         dest, dst_stride);
+    } else {
+      assert(height == 16);
+      DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+                                          dest, dst_stride);
+    }
+    return;
+  }
+
+  if (width == 8) {
+    switch (height) {
+      case 4:
+        DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+                                           dest, dst_stride);
+        return;
+      case 8:
+        DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+                                           dest, dst_stride);
+        return;
+      case 16:
+        DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+                                            dest, dst_stride);
+        return;
+      default:
+        assert(height == 32);
+        DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
+                                            dest, dst_stride);
+
+        return;
+    }
+  }
+
+  DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
+                                    height, dest, dst_stride);
+}
+
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+#if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
+  dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
+#endif
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void DistanceWeightedBlendInit_SSE4_1() {
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif
+}
 
 }  // namespace dsp
 }  // namespace libgav1
 
-#else  // !LIBGAV1_TARGETING_SSE4_1
+#else   // !LIBGAV1_TARGETING_SSE4_1
 
 namespace libgav1 {
 namespace dsp {