aboutsummaryrefslogtreecommitdiff
path: root/src/dsp/x86/distance_weighted_blend_sse4.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/dsp/x86/distance_weighted_blend_sse4.cc')
-rw-r--r--src/dsp/x86/distance_weighted_blend_sse4.cc230
1 files changed, 230 insertions, 0 deletions
diff --git a/src/dsp/x86/distance_weighted_blend_sse4.cc b/src/dsp/x86/distance_weighted_blend_sse4.cc
new file mode 100644
index 0000000..deb57ef
--- /dev/null
+++ b/src/dsp/x86/distance_weighted_blend_sse4.cc
@@ -0,0 +1,230 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/distance_weighted_blend.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_TARGETING_SSE4_1
+
+#include <xmmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kInterPostRoundBit = 4;
+
+inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
+ const __m128i& pred1,
+ const __m128i& weights) {
+ // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
+ const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
+ const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
+ const __m128i result_lo =
+ RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
+
+ const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
+ const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
+ const __m128i result_hi =
+ RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
+
+ return _mm_packs_epi32(result_lo, result_hi);
+}
+
+template <int height>
+inline void DistanceWeightedBlend4xH_SSE4_1(
+ const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
+ const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+
+ for (int y = 0; y < height; y += 4) {
+ // TODO(b/150326556): Use larger loads.
+ const __m128i src_00 = LoadLo8(pred_0);
+ const __m128i src_10 = LoadLo8(pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ __m128i src_0 = LoadHi8(src_00, pred_0);
+ __m128i src_1 = LoadHi8(src_10, pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
+
+ const __m128i src_01 = LoadLo8(pred_0);
+ const __m128i src_11 = LoadLo8(pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ src_0 = LoadHi8(src_01, pred_0);
+ src_1 = LoadHi8(src_11, pred_1);
+ pred_0 += 4;
+ pred_1 += 4;
+ const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
+
+ const __m128i result_pixels = _mm_packus_epi16(res0, res1);
+ Store4(dst, result_pixels);
+ dst += dest_stride;
+ const int result_1 = _mm_extract_epi32(result_pixels, 1);
+ memcpy(dst, &result_1, sizeof(result_1));
+ dst += dest_stride;
+ const int result_2 = _mm_extract_epi32(result_pixels, 2);
+ memcpy(dst, &result_2, sizeof(result_2));
+ dst += dest_stride;
+ const int result_3 = _mm_extract_epi32(result_pixels, 3);
+ memcpy(dst, &result_3, sizeof(result_3));
+ dst += dest_stride;
+ }
+}
+
+template <int height>
+inline void DistanceWeightedBlend8xH_SSE4_1(
+ const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
+ const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+
+ for (int y = 0; y < height; y += 2) {
+ const __m128i src_00 = LoadAligned16(pred_0);
+ const __m128i src_10 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
+ const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
+
+ const __m128i src_01 = LoadAligned16(pred_0);
+ const __m128i src_11 = LoadAligned16(pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
+ const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
+
+ const __m128i result_pixels = _mm_packus_epi16(res0, res1);
+ StoreLo8(dst, result_pixels);
+ dst += dest_stride;
+ StoreHi8(dst, result_pixels);
+ dst += dest_stride;
+ }
+}
+
+inline void DistanceWeightedBlendLarge_SSE4_1(
+ const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
+ const uint8_t weight_1, const int width, const int height, void* const dest,
+ const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+
+ int y = height;
+ do {
+ int x = 0;
+ do {
+ const __m128i src_0_lo = LoadAligned16(pred_0 + x);
+ const __m128i src_1_lo = LoadAligned16(pred_1 + x);
+ const __m128i res_lo =
+ ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
+
+ const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
+ const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
+ const __m128i res_hi =
+ ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
+
+ StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
+ x += 16;
+ } while (x < width);
+ dst += dest_stride;
+ pred_0 += width;
+ pred_1 += width;
+ } while (--y != 0);
+}
+
+void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
+ const void* prediction_1,
+ const uint8_t weight_0,
+ const uint8_t weight_1, const int width,
+ const int height, void* const dest,
+ const ptrdiff_t dest_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ if (width == 4) {
+ if (height == 4) {
+ DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+ dest, dest_stride);
+ } else if (height == 8) {
+ DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+ dest, dest_stride);
+ } else {
+ assert(height == 16);
+ DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+ dest, dest_stride);
+ }
+ return;
+ }
+
+ if (width == 8) {
+ switch (height) {
+ case 4:
+ DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+ dest, dest_stride);
+ return;
+ case 8:
+ DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+ dest, dest_stride);
+ return;
+ case 16:
+ DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+ dest, dest_stride);
+ return;
+ default:
+ assert(height == 32);
+ DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
+ dest, dest_stride);
+
+ return;
+ }
+ }
+
+ DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
+ height, dest, dest_stride);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+#if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
+ dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
+#endif
+}
+
+} // namespace
+
+void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_TARGETING_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void DistanceWeightedBlendInit_SSE4_1() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_TARGETING_SSE4_1