// Copyright 2019 The libgav1 Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/dsp/distance_weighted_blend.h" #include "src/utils/cpu.h" #if LIBGAV1_TARGETING_SSE4_1 #include #include #include #include #include "src/dsp/constants.h" #include "src/dsp/dsp.h" #include "src/dsp/x86/common_sse4.h" #include "src/utils/common.h" namespace libgav1 { namespace dsp { namespace low_bitdepth { namespace { constexpr int kInterPostRoundBit = 4; inline __m128i ComputeWeightedAverage8(const __m128i& pred0, const __m128i& pred1, const __m128i& weights) { // TODO(https://issuetracker.google.com/issues/150325685): Investigate range. const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1); const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights); const __m128i result_lo = RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4); const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1); const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights); const __m128i result_hi = RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4); return _mm_packs_epi32(result_lo, result_hi); } template inline void DistanceWeightedBlend4xH_SSE4_1( const int16_t* LIBGAV1_RESTRICT pred_0, const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast(dest); const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); for (int y = 0; y < height; y += 4) { // TODO(b/150326556): Use larger loads. const __m128i src_00 = LoadLo8(pred_0); const __m128i src_10 = LoadLo8(pred_1); pred_0 += 4; pred_1 += 4; __m128i src_0 = LoadHi8(src_00, pred_0); __m128i src_1 = LoadHi8(src_10, pred_1); pred_0 += 4; pred_1 += 4; const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights); const __m128i src_01 = LoadLo8(pred_0); const __m128i src_11 = LoadLo8(pred_1); pred_0 += 4; pred_1 += 4; src_0 = LoadHi8(src_01, pred_0); src_1 = LoadHi8(src_11, pred_1); pred_0 += 4; pred_1 += 4; const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights); const __m128i result_pixels = _mm_packus_epi16(res0, res1); Store4(dst, result_pixels); dst += dest_stride; const int result_1 = _mm_extract_epi32(result_pixels, 1); memcpy(dst, &result_1, sizeof(result_1)); dst += dest_stride; const int result_2 = _mm_extract_epi32(result_pixels, 2); memcpy(dst, &result_2, sizeof(result_2)); dst += dest_stride; const int result_3 = _mm_extract_epi32(result_pixels, 3); memcpy(dst, &result_3, sizeof(result_3)); dst += dest_stride; } } template inline void DistanceWeightedBlend8xH_SSE4_1( const int16_t* LIBGAV1_RESTRICT pred_0, const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast(dest); const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); for (int y = 0; y < height; y += 2) { const __m128i src_00 = LoadAligned16(pred_0); const __m128i src_10 = LoadAligned16(pred_1); pred_0 += 8; pred_1 += 8; const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights); const __m128i src_01 = LoadAligned16(pred_0); const __m128i src_11 = LoadAligned16(pred_1); pred_0 += 8; pred_1 += 8; const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights); const __m128i result_pixels = _mm_packus_epi16(res0, res1); StoreLo8(dst, result_pixels); dst += dest_stride; StoreHi8(dst, result_pixels); dst += dest_stride; } } inline void DistanceWeightedBlendLarge_SSE4_1( const int16_t* LIBGAV1_RESTRICT pred_0, const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, const uint8_t weight_1, const int width, const int height, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast(dest); const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); int y = height; do { int x = 0; do { const __m128i src_0_lo = LoadAligned16(pred_0 + x); const __m128i src_1_lo = LoadAligned16(pred_1 + x); const __m128i res_lo = ComputeWeightedAverage8(src_0_lo, src_1_lo, weights); const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8); const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8); const __m128i res_hi = ComputeWeightedAverage8(src_0_hi, src_1_hi, weights); StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi)); x += 16; } while (x < width); dst += dest_stride; pred_0 += width; pred_1 += width; } while (--y != 0); } void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0, const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0, const uint8_t weight_1, const int width, const int height, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { const auto* pred_0 = static_cast(prediction_0); const auto* pred_1 = static_cast(prediction_1); if (width == 4) { if (height == 4) { DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, dest, dest_stride); } else if (height == 8) { DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, dest, dest_stride); } else { assert(height == 16); DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, dest, dest_stride); } return; } if (width == 8) { switch (height) { case 4: DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, dest, dest_stride); return; case 8: DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, dest, dest_stride); return; case 16: DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, dest, dest_stride); return; default: assert(height == 32); DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1, dest, dest_stride); return; } } DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width, height, dest, dest_stride); } void Init8bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); assert(dsp != nullptr); #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend) dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1; #endif } } // namespace } // namespace low_bitdepth #if LIBGAV1_MAX_BITDEPTH >= 10 namespace high_bitdepth { namespace { constexpr int kMax10bppSample = (1 << 10) - 1; constexpr int kInterPostRoundBit = 4; inline __m128i ComputeWeightedAverage8(const __m128i& pred0, const __m128i& pred1, const __m128i& weight0, const __m128i& weight1) { // This offset is a combination of round_factor and round_offset // which are to be added and subtracted respectively. // Here kInterPostRoundBit + 4 is considering bitdepth=10. constexpr int offset = (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4); const __m128i zero = _mm_setzero_si128(); const __m128i bias = _mm_set1_epi32(offset); const __m128i clip_high = _mm_set1_epi16(kMax10bppSample); __m128i prediction0 = _mm_cvtepu16_epi32(pred0); __m128i mult0 = _mm_mullo_epi32(prediction0, weight0); __m128i prediction1 = _mm_cvtepu16_epi32(pred1); __m128i mult1 = _mm_mullo_epi32(prediction1, weight1); __m128i sum = _mm_add_epi32(mult0, mult1); sum = _mm_add_epi32(sum, bias); const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4); prediction0 = _mm_unpackhi_epi16(pred0, zero); mult0 = _mm_mullo_epi32(prediction0, weight0); prediction1 = _mm_unpackhi_epi16(pred1, zero); mult1 = _mm_mullo_epi32(prediction1, weight1); sum = _mm_add_epi32(mult0, mult1); sum = _mm_add_epi32(sum, bias); const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4); const __m128i pack = _mm_packus_epi32(result0, result1); return _mm_min_epi16(pack, clip_high); } template inline void DistanceWeightedBlend4xH_SSE4_1( const uint16_t* LIBGAV1_RESTRICT pred_0, const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast(dest); const __m128i weight0 = _mm_set1_epi32(weight_0); const __m128i weight1 = _mm_set1_epi32(weight_1); int y = height; do { const __m128i src_00 = LoadLo8(pred_0); const __m128i src_10 = LoadLo8(pred_1); pred_0 += 4; pred_1 += 4; __m128i src_0 = LoadHi8(src_00, pred_0); __m128i src_1 = LoadHi8(src_10, pred_1); pred_0 += 4; pred_1 += 4; const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weight0, weight1); const __m128i src_01 = LoadLo8(pred_0); const __m128i src_11 = LoadLo8(pred_1); pred_0 += 4; pred_1 += 4; src_0 = LoadHi8(src_01, pred_0); src_1 = LoadHi8(src_11, pred_1); pred_0 += 4; pred_1 += 4; const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weight0, weight1); StoreLo8(dst, res0); dst += dest_stride; StoreHi8(dst, res0); dst += dest_stride; StoreLo8(dst, res1); dst += dest_stride; StoreHi8(dst, res1); dst += dest_stride; y -= 4; } while (y != 0); } template inline void DistanceWeightedBlend8xH_SSE4_1( const uint16_t* LIBGAV1_RESTRICT pred_0, const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast(dest); const __m128i weight0 = _mm_set1_epi32(weight_0); const __m128i weight1 = _mm_set1_epi32(weight_1); int y = height; do { const __m128i src_00 = LoadAligned16(pred_0); const __m128i src_10 = LoadAligned16(pred_1); pred_0 += 8; pred_1 += 8; const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weight0, weight1); const __m128i src_01 = LoadAligned16(pred_0); const __m128i src_11 = LoadAligned16(pred_1); pred_0 += 8; pred_1 += 8; const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weight0, weight1); StoreUnaligned16(dst, res0); dst += dest_stride; StoreUnaligned16(dst, res1); dst += dest_stride; y -= 2; } while (y != 0); } inline void DistanceWeightedBlendLarge_SSE4_1( const uint16_t* LIBGAV1_RESTRICT pred_0, const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0, const uint8_t weight_1, const int width, const int height, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast(dest); const __m128i weight0 = _mm_set1_epi32(weight_0); const __m128i weight1 = _mm_set1_epi32(weight_1); int y = height; do { int x = 0; do { const __m128i src_0_lo = LoadAligned16(pred_0 + x); const __m128i src_1_lo = LoadAligned16(pred_1 + x); const __m128i res_lo = ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1); const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8); const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8); const __m128i res_hi = ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1); StoreUnaligned16(dst + x, res_lo); x += 8; StoreUnaligned16(dst + x, res_hi); x += 8; } while (x < width); dst += dest_stride; pred_0 += width; pred_1 += width; } while (--y != 0); } void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0, const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0, const uint8_t weight_1, const int width, const int height, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { const auto* pred_0 = static_cast(prediction_0); const auto* pred_1 = static_cast(prediction_1); const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0); if (width == 4) { if (height == 4) { DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, dest, dst_stride); } else if (height == 8) { DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, dest, dst_stride); } else { assert(height == 16); DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, dest, dst_stride); } return; } if (width == 8) { switch (height) { case 4: DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, dest, dst_stride); return; case 8: DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, dest, dst_stride); return; case 16: DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, dest, dst_stride); return; default: assert(height == 32); DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1, dest, dst_stride); return; } } DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width, height, dest, dst_stride); } void Init10bpp() { Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); assert(dsp != nullptr); #if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend) dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1; #endif } } // namespace } // namespace high_bitdepth #endif // LIBGAV1_MAX_BITDEPTH >= 10 void DistanceWeightedBlendInit_SSE4_1() { low_bitdepth::Init8bpp(); #if LIBGAV1_MAX_BITDEPTH >= 10 high_bitdepth::Init10bpp(); #endif } } // namespace dsp } // namespace libgav1 #else // !LIBGAV1_TARGETING_SSE4_1 namespace libgav1 { namespace dsp { void DistanceWeightedBlendInit_SSE4_1() {} } // namespace dsp } // namespace libgav1 #endif // LIBGAV1_TARGETING_SSE4_1