diff options
Diffstat (limited to 'src/dsp/x86/distance_weighted_blend_sse4.cc')
-rw-r--r-- | src/dsp/x86/distance_weighted_blend_sse4.cc | 152 |
1 files changed, 71 insertions, 81 deletions
diff --git a/src/dsp/x86/distance_weighted_blend_sse4.cc b/src/dsp/x86/distance_weighted_blend_sse4.cc index c813df4..8c32117 100644 --- a/src/dsp/x86/distance_weighted_blend_sse4.cc +++ b/src/dsp/x86/distance_weighted_blend_sse4.cc @@ -34,54 +34,50 @@ namespace low_bitdepth { namespace { constexpr int kInterPostRoundBit = 4; +constexpr int kInterPostRhsAdjust = 1 << (16 - kInterPostRoundBit - 1); inline __m128i ComputeWeightedAverage8(const __m128i& pred0, const __m128i& pred1, - const __m128i& weights) { - // TODO(https://issuetracker.google.com/issues/150325685): Investigate range. - const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1); - const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights); - const __m128i result_lo = - RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4); - - const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1); - const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights); - const __m128i result_hi = - RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4); - - return _mm_packs_epi32(result_lo, result_hi); + const __m128i& weight) { + // Given: p0,p1 in range [-5132,9212] and w0 = 16 - w1, w1 = 16 - w0 + // Output: (p0 * w0 + p1 * w1 + 128(=rounding bit)) >> + // 8(=kInterPostRoundBit + 4) + // The formula is manipulated to avoid lengthening to 32 bits. + // p0 * w0 + p1 * w1 = p0 * w0 + (16 - w0) * p1 + // = (p0 - p1) * w0 + 16 * p1 + // Maximum value of p0 - p1 is 9212 + 5132 = 0x3808. 
+ const __m128i diff = _mm_slli_epi16(_mm_sub_epi16(pred0, pred1), 1); + // (((p0 - p1) * (w0 << 12)) >> 16) + ((16 * p1) >> 4) + const __m128i weighted_diff = _mm_mulhi_epi16(diff, weight); + // ((p0 - p1) * w0 >> 4) + p1 + const __m128i upscaled_average = _mm_add_epi16(weighted_diff, pred1); + // (x << 11) >> 15 == x >> 4 + const __m128i right_shift_prep = _mm_set1_epi16(kInterPostRhsAdjust); + // (((p0 - p1) * w0 >> 4) + p1 + (128 >> 4)) >> 4 + return _mm_mulhrs_epi16(upscaled_average, right_shift_prep); } template <int height> inline void DistanceWeightedBlend4xH_SSE4_1( const int16_t* LIBGAV1_RESTRICT pred_0, - const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, - const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest, - const ptrdiff_t dest_stride) { + const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight, + void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast<uint8_t*>(dest); - const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); + // Upscale the weight for mulhi. + const __m128i weights = _mm_set1_epi16(weight << 11); for (int y = 0; y < height; y += 4) { - // TODO(b/150326556): Use larger loads. 
- const __m128i src_00 = LoadLo8(pred_0); - const __m128i src_10 = LoadLo8(pred_1); - pred_0 += 4; - pred_1 += 4; - __m128i src_0 = LoadHi8(src_00, pred_0); - __m128i src_1 = LoadHi8(src_10, pred_1); - pred_0 += 4; - pred_1 += 4; - const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights); - - const __m128i src_01 = LoadLo8(pred_0); - const __m128i src_11 = LoadLo8(pred_1); - pred_0 += 4; - pred_1 += 4; - src_0 = LoadHi8(src_01, pred_0); - src_1 = LoadHi8(src_11, pred_1); - pred_0 += 4; - pred_1 += 4; - const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights); + const __m128i src_00 = LoadAligned16(pred_0); + const __m128i src_10 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; + const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights); + + const __m128i src_01 = LoadAligned16(pred_0); + const __m128i src_11 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; + const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights); const __m128i result_pixels = _mm_packus_epi16(res0, res1); Store4(dst, result_pixels); @@ -101,11 +97,11 @@ inline void DistanceWeightedBlend4xH_SSE4_1( template <int height> inline void DistanceWeightedBlend8xH_SSE4_1( const int16_t* LIBGAV1_RESTRICT pred_0, - const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, - const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest, - const ptrdiff_t dest_stride) { + const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight, + void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast<uint8_t*>(dest); - const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); + // Upscale the weight for mulhi. 
+ const __m128i weights = _mm_set1_epi16(weight << 11); for (int y = 0; y < height; y += 2) { const __m128i src_00 = LoadAligned16(pred_0); @@ -130,11 +126,12 @@ inline void DistanceWeightedBlend8xH_SSE4_1( inline void DistanceWeightedBlendLarge_SSE4_1( const int16_t* LIBGAV1_RESTRICT pred_0, - const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, - const uint8_t weight_1, const int width, const int height, - void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { + const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight, + const int width, const int height, void* LIBGAV1_RESTRICT const dest, + const ptrdiff_t dest_stride) { auto* dst = static_cast<uint8_t*>(dest); - const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); + // Upscale the weight for mulhi. + const __m128i weights = _mm_set1_epi16(weight << 11); int y = height; do { @@ -162,23 +159,24 @@ inline void DistanceWeightedBlendLarge_SSE4_1( void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0, const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0, - const uint8_t weight_1, const int width, + const uint8_t /*weight_1*/, const int width, const int height, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + const uint8_t weight = weight_0; if (width == 4) { if (height == 4) { - DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, - dest, dest_stride); + DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight, dest, + dest_stride); } else if (height == 8) { - DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, - dest, dest_stride); + DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight, dest, + dest_stride); } else { assert(height == 16); - DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, - dest, dest_stride); + 
DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight, dest, + dest_stride); } return; } @@ -186,28 +184,28 @@ void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0, if (width == 8) { switch (height) { case 4: - DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, - dest, dest_stride); + DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight, dest, + dest_stride); return; case 8: - DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, - dest, dest_stride); + DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight, dest, + dest_stride); return; case 16: - DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, - dest, dest_stride); + DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight, dest, + dest_stride); return; default: assert(height == 32); - DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1, - dest, dest_stride); + DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight, dest, + dest_stride); return; } } - DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width, - height, dest, dest_stride); + DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight, width, height, dest, + dest_stride); } void Init8bpp() { @@ -273,27 +271,19 @@ inline void DistanceWeightedBlend4xH_SSE4_1( int y = height; do { - const __m128i src_00 = LoadLo8(pred_0); - const __m128i src_10 = LoadLo8(pred_1); - pred_0 += 4; - pred_1 += 4; - __m128i src_0 = LoadHi8(src_00, pred_0); - __m128i src_1 = LoadHi8(src_10, pred_1); - pred_0 += 4; - pred_1 += 4; + const __m128i src_00 = LoadAligned16(pred_0); + const __m128i src_10 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; const __m128i res0 = - ComputeWeightedAverage8(src_0, src_1, weight0, weight1); - - const __m128i src_01 = LoadLo8(pred_0); - const __m128i src_11 = LoadLo8(pred_1); - pred_0 += 4; - pred_1 += 4; - src_0 = LoadHi8(src_01, pred_0); - src_1 = LoadHi8(src_11, pred_1); - pred_0 += 4; - pred_1 += 
4; + ComputeWeightedAverage8(src_00, src_10, weight0, weight1); + + const __m128i src_01 = LoadAligned16(pred_0); + const __m128i src_11 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; const __m128i res1 = - ComputeWeightedAverage8(src_0, src_1, weight0, weight1); + ComputeWeightedAverage8(src_01, src_11, weight0, weight1); StoreLo8(dst, res0); dst += dest_stride; |