aboutsummaryrefslogtreecommitdiff
path: root/src/dsp/x86/distance_weighted_blend_sse4.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/dsp/x86/distance_weighted_blend_sse4.cc')
-rw-r--r-- src/dsp/x86/distance_weighted_blend_sse4.cc 152
1 files changed, 71 insertions, 81 deletions
diff --git a/src/dsp/x86/distance_weighted_blend_sse4.cc b/src/dsp/x86/distance_weighted_blend_sse4.cc
index c813df4..8c32117 100644
--- a/src/dsp/x86/distance_weighted_blend_sse4.cc
+++ b/src/dsp/x86/distance_weighted_blend_sse4.cc
@@ -34,54 +34,50 @@ namespace low_bitdepth {
namespace {
constexpr int kInterPostRoundBit = 4;
+constexpr int kInterPostRhsAdjust = 1 << (16 - kInterPostRoundBit - 1);
inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
const __m128i& pred1,
- const __m128i& weights) {
- // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
- const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
- const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
- const __m128i result_lo =
- RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
-
- const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
- const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
- const __m128i result_hi =
- RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
-
- return _mm_packs_epi32(result_lo, result_hi);
+ const __m128i& weight) {
+ // Given: p0,p1 in range [-5132,9212] and w0 = 16 - w1, w1 = 16 - w0
+ // Output: (p0 * w0 + p1 * w1 + 128(=rounding bit)) >>
+ // 8(=kInterPostRoundBit + 4)
+ // The formula is manipulated to avoid lengthening to 32 bits.
+ // p0 * w0 + p1 * w1 = p0 * w0 + (16 - w0) * p1
+ // = (p0 - p1) * w0 + 16 * p1
+ // Maximum value of p0 - p1 is 9212 + 5132 = 0x3808.
+ const __m128i diff = _mm_slli_epi16(_mm_sub_epi16(pred0, pred1), 1);
+ // (((p0 - p1) * (w0 << 12)) >> 16) + ((16 * p1) >> 4)
+ const __m128i weighted_diff = _mm_mulhi_epi16(diff, weight);
+ // ((p0 - p1) * w0 >> 4) + p1
+ const __m128i upscaled_average = _mm_add_epi16(weighted_diff, pred1);
+ // (x << 11) >> 15 == x >> 4
+ const __m128i right_shift_prep = _mm_set1_epi16(kInterPostRhsAdjust);
+ // (((p0 - p1) * w0 >> 4) + p1 + (128 >> 4)) >> 4
+ return _mm_mulhrs_epi16(upscaled_average, right_shift_prep);
}
template <int height>
inline void DistanceWeightedBlend4xH_SSE4_1(
const int16_t* LIBGAV1_RESTRICT pred_0,
- const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
- const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
- const ptrdiff_t dest_stride) {
+ const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
+ void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
auto* dst = static_cast<uint8_t*>(dest);
- const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+ // Upscale the weight for mulhi.
+ const __m128i weights = _mm_set1_epi16(weight << 11);
for (int y = 0; y < height; y += 4) {
- // TODO(b/150326556): Use larger loads.
- const __m128i src_00 = LoadLo8(pred_0);
- const __m128i src_10 = LoadLo8(pred_1);
- pred_0 += 4;
- pred_1 += 4;
- __m128i src_0 = LoadHi8(src_00, pred_0);
- __m128i src_1 = LoadHi8(src_10, pred_1);
- pred_0 += 4;
- pred_1 += 4;
- const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
-
- const __m128i src_01 = LoadLo8(pred_0);
- const __m128i src_11 = LoadLo8(pred_1);
- pred_0 += 4;
- pred_1 += 4;
- src_0 = LoadHi8(src_01, pred_0);
- src_1 = LoadHi8(src_11, pred_1);
- pred_0 += 4;
- pred_1 += 4;
- const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
+ const __m128i src_00 = LoadAligned16(pred_0);
+ const __m128i src_10 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
+ const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
+
+ const __m128i src_01 = LoadAligned16(pred_0);
+ const __m128i src_11 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
+ const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
const __m128i result_pixels = _mm_packus_epi16(res0, res1);
Store4(dst, result_pixels);
@@ -101,11 +97,11 @@ inline void DistanceWeightedBlend4xH_SSE4_1(
template <int height>
inline void DistanceWeightedBlend8xH_SSE4_1(
const int16_t* LIBGAV1_RESTRICT pred_0,
- const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
- const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
- const ptrdiff_t dest_stride) {
+ const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
+ void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
auto* dst = static_cast<uint8_t*>(dest);
- const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+ // Upscale the weight for mulhi.
+ const __m128i weights = _mm_set1_epi16(weight << 11);
for (int y = 0; y < height; y += 2) {
const __m128i src_00 = LoadAligned16(pred_0);
@@ -130,11 +126,12 @@ inline void DistanceWeightedBlend8xH_SSE4_1(
inline void DistanceWeightedBlendLarge_SSE4_1(
const int16_t* LIBGAV1_RESTRICT pred_0,
- const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
- const uint8_t weight_1, const int width, const int height,
- void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
+ const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
+ const int width, const int height, void* LIBGAV1_RESTRICT const dest,
+ const ptrdiff_t dest_stride) {
auto* dst = static_cast<uint8_t*>(dest);
- const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+ // Upscale the weight for mulhi.
+ const __m128i weights = _mm_set1_epi16(weight << 11);
int y = height;
do {
@@ -162,23 +159,24 @@ inline void DistanceWeightedBlendLarge_SSE4_1(
void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
const void* LIBGAV1_RESTRICT prediction_1,
const uint8_t weight_0,
- const uint8_t weight_1, const int width,
+ const uint8_t /*weight_1*/, const int width,
const int height,
void* LIBGAV1_RESTRICT const dest,
const ptrdiff_t dest_stride) {
const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ const uint8_t weight = weight_0;
if (width == 4) {
if (height == 4) {
- DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
- dest, dest_stride);
+ DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight, dest,
+ dest_stride);
} else if (height == 8) {
- DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
- dest, dest_stride);
+ DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight, dest,
+ dest_stride);
} else {
assert(height == 16);
- DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
- dest, dest_stride);
+ DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight, dest,
+ dest_stride);
}
return;
}
@@ -186,28 +184,28 @@ void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
if (width == 8) {
switch (height) {
case 4:
- DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
- dest, dest_stride);
+ DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight, dest,
+ dest_stride);
return;
case 8:
- DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
- dest, dest_stride);
+ DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight, dest,
+ dest_stride);
return;
case 16:
- DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
- dest, dest_stride);
+ DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight, dest,
+ dest_stride);
return;
default:
assert(height == 32);
- DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
- dest, dest_stride);
+ DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight, dest,
+ dest_stride);
return;
}
}
- DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
- height, dest, dest_stride);
+ DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight, width, height, dest,
+ dest_stride);
}
void Init8bpp() {
@@ -273,27 +271,19 @@ inline void DistanceWeightedBlend4xH_SSE4_1(
int y = height;
do {
- const __m128i src_00 = LoadLo8(pred_0);
- const __m128i src_10 = LoadLo8(pred_1);
- pred_0 += 4;
- pred_1 += 4;
- __m128i src_0 = LoadHi8(src_00, pred_0);
- __m128i src_1 = LoadHi8(src_10, pred_1);
- pred_0 += 4;
- pred_1 += 4;
+ const __m128i src_00 = LoadAligned16(pred_0);
+ const __m128i src_10 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
const __m128i res0 =
- ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
-
- const __m128i src_01 = LoadLo8(pred_0);
- const __m128i src_11 = LoadLo8(pred_1);
- pred_0 += 4;
- pred_1 += 4;
- src_0 = LoadHi8(src_01, pred_0);
- src_1 = LoadHi8(src_11, pred_1);
- pred_0 += 4;
- pred_1 += 4;
+ ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
+
+ const __m128i src_01 = LoadAligned16(pred_0);
+ const __m128i src_11 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
const __m128i res1 =
- ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
+ ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
StoreLo8(dst, res0);
dst += dest_stride;