aboutsummaryrefslogtreecommitdiff
path: root/src/dsp/x86/distance_weighted_blend_sse4.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/dsp/x86/distance_weighted_blend_sse4.cc')
-rw-r--r--src/dsp/x86/distance_weighted_blend_sse4.cc223
1 files changed, 221 insertions, 2 deletions
diff --git a/src/dsp/x86/distance_weighted_blend_sse4.cc b/src/dsp/x86/distance_weighted_blend_sse4.cc
index deb57ef..3c29b19 100644
--- a/src/dsp/x86/distance_weighted_blend_sse4.cc
+++ b/src/dsp/x86/distance_weighted_blend_sse4.cc
@@ -30,6 +30,7 @@
namespace libgav1 {
namespace dsp {
+namespace low_bitdepth {
namespace {
constexpr int kInterPostRoundBit = 4;
@@ -212,13 +213,231 @@ void Init8bpp() {
}
} // namespace
+} // namespace low_bitdepth
-void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+constexpr int kMax10bppSample = (1 << 10) - 1;
+constexpr int kInterPostRoundBit = 4;
+
+inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
+ const __m128i& pred1,
+ const __m128i& weight0,
+ const __m128i& weight1) {
+ // This offset is a combination of round_factor and round_offset
+ // which are to be added and subtracted respectively.
+ // Here kInterPostRoundBit + 4 is considering bitdepth=10.
+ constexpr int offset =
+ (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
+ const __m128i zero = _mm_setzero_si128();
+ const __m128i bias = _mm_set1_epi32(offset);
+ const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
+
+ __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
+ __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
+ __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
+ __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
+ __m128i sum = _mm_add_epi32(mult0, mult1);
+ sum = _mm_add_epi32(sum, bias);
+ const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
+
+ prediction0 = _mm_unpackhi_epi16(pred0, zero);
+ mult0 = _mm_mullo_epi32(prediction0, weight0);
+ prediction1 = _mm_unpackhi_epi16(pred1, zero);
+ mult1 = _mm_mullo_epi32(prediction1, weight1);
+ sum = _mm_add_epi32(mult0, mult1);
+ sum = _mm_add_epi32(sum, bias);
+ const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
+ const __m128i pack = _mm_packus_epi32(result0, result1);
+
+ return _mm_min_epi16(pack, clip_high);
+}
+
+template <int height>
+inline void DistanceWeightedBlend4xH_SSE4_1(
+ const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
+ const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint16_t*>(dest);
+ const __m128i weight0 = _mm_set1_epi32(weight_0);
+ const __m128i weight1 = _mm_set1_epi32(weight_1);
+
+ int y = height;
+ do {
+ const __m128i src_00 = LoadLo8(pred_0);
+ const __m128i src_10 = LoadLo8(pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ __m128i src_0 = LoadHi8(src_00, pred_0);
+ __m128i src_1 = LoadHi8(src_10, pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ const __m128i res0 =
+ ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
+
+ const __m128i src_01 = LoadLo8(pred_0);
+ const __m128i src_11 = LoadLo8(pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ src_0 = LoadHi8(src_01, pred_0);
+ src_1 = LoadHi8(src_11, pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ const __m128i res1 =
+ ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
+
+ StoreLo8(dst, res0);
+ dst += dest_stride;
+ StoreHi8(dst, res0);
+ dst += dest_stride;
+ StoreLo8(dst, res1);
+ dst += dest_stride;
+ StoreHi8(dst, res1);
+ dst += dest_stride;
+ y -= 4;
+ } while (y != 0);
+}
+
+template <int height>
+inline void DistanceWeightedBlend8xH_SSE4_1(
+ const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
+ const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint16_t*>(dest);
+ const __m128i weight0 = _mm_set1_epi32(weight_0);
+ const __m128i weight1 = _mm_set1_epi32(weight_1);
+
+ int y = height;
+ do {
+ const __m128i src_00 = LoadAligned16(pred_0);
+ const __m128i src_10 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
+ const __m128i res0 =
+ ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
+
+ const __m128i src_01 = LoadAligned16(pred_0);
+ const __m128i src_11 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
+ const __m128i res1 =
+ ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
+
+ StoreUnaligned16(dst, res0);
+ dst += dest_stride;
+ StoreUnaligned16(dst, res1);
+ dst += dest_stride;
+ y -= 2;
+ } while (y != 0);
+}
+
+inline void DistanceWeightedBlendLarge_SSE4_1(
+ const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
+ const uint8_t weight_1, const int width, const int height, void* const dest,
+ const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint16_t*>(dest);
+ const __m128i weight0 = _mm_set1_epi32(weight_0);
+ const __m128i weight1 = _mm_set1_epi32(weight_1);
+
+ int y = height;
+ do {
+ int x = 0;
+ do {
+ const __m128i src_0_lo = LoadAligned16(pred_0 + x);
+ const __m128i src_1_lo = LoadAligned16(pred_1 + x);
+ const __m128i res_lo =
+ ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
+
+ const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
+ const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
+ const __m128i res_hi =
+ ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
+
+ StoreUnaligned16(dst + x, res_lo);
+ x += 8;
+ StoreUnaligned16(dst + x, res_hi);
+ x += 8;
+ } while (x < width);
+ dst += dest_stride;
+ pred_0 += width;
+ pred_1 += width;
+ } while (--y != 0);
+}
+
+void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
+ const void* prediction_1,
+ const uint8_t weight_0,
+ const uint8_t weight_1, const int width,
+ const int height, void* const dest,
+ const ptrdiff_t dest_stride) {
+ const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
+ const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
+ if (width == 4) {
+ if (height == 4) {
+ DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+ dest, dst_stride);
+ } else if (height == 8) {
+ DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+ dest, dst_stride);
+ } else {
+ assert(height == 16);
+ DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+ dest, dst_stride);
+ }
+ return;
+ }
+
+ if (width == 8) {
+ switch (height) {
+ case 4:
+ DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+ dest, dst_stride);
+ return;
+ case 8:
+ DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+ dest, dst_stride);
+ return;
+ case 16:
+ DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+ dest, dst_stride);
+ return;
+ default:
+ assert(height == 32);
+ DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
+ dest, dst_stride);
+
+ return;
+ }
+ }
+
+ DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
+ height, dest, dst_stride);
+}
+
+void Init10bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+ assert(dsp != nullptr);
+#if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
+ dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
+#endif
+}
+
+} // namespace
+} // namespace high_bitdepth
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+void DistanceWeightedBlendInit_SSE4_1() {
+ low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ high_bitdepth::Init10bpp();
+#endif
+}
} // namespace dsp
} // namespace libgav1
-#else // !LIBGAV1_TARGETING_SSE4_1
+#else // !LIBGAV1_TARGETING_SSE4_1
namespace libgav1 {
namespace dsp {