about | summary | refs | log | tree | commit | diff
path: root/src/dsp/arm/distance_weighted_blend_neon.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/dsp/arm/distance_weighted_blend_neon.cc')
-rw-r--r-- src/dsp/arm/distance_weighted_blend_neon.cc 162
1 files changed, 159 insertions, 3 deletions
diff --git a/src/dsp/arm/distance_weighted_blend_neon.cc b/src/dsp/arm/distance_weighted_blend_neon.cc
index 04952ab..a0cd0ac 100644
--- a/src/dsp/arm/distance_weighted_blend_neon.cc
+++ b/src/dsp/arm/distance_weighted_blend_neon.cc
@@ -30,10 +30,12 @@
namespace libgav1 {
namespace dsp {
-namespace {
constexpr int kInterPostRoundBit = 4;
+namespace low_bitdepth {
+namespace {
+
inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0,
const int16x8_t pred1,
const int16x4_t weights[2]) {
@@ -185,13 +187,167 @@ void Init8bpp() {
}
} // namespace
+} // namespace low_bitdepth
+
+//------------------------------------------------------------------------------
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+// Blends 8 16-bit predictions, passed as two 4-lane halves in |pred0| and
+// |pred1|. Per lane this computes
+//   weights[0] * pred0 + weights[1] * pred1 - kCompoundOffset * 16
+// in 32 bits, applies a rounding right shift by kInterPostRoundBit + 4 with
+// unsigned saturation to 16 bits, and clamps to (1 << kBitdepth10) - 1.
+inline uint16x4x2_t ComputeWeightedAverage8(const uint16x4x2_t pred0,
+ const uint16x4x2_t pred1,
+ const uint16x4_t weights[2]) {
+ // Widening multiply, then multiply-accumulate, so the sums are 32 bits wide.
+ const uint32x4_t wpred0_lo = vmull_u16(weights[0], pred0.val[0]);
+ const uint32x4_t wpred0_hi = vmull_u16(weights[0], pred0.val[1]);
+ const uint32x4_t blended_lo = vmlal_u16(wpred0_lo, weights[1], pred1.val[0]);
+ const uint32x4_t blended_hi = vmlal_u16(wpred0_hi, weights[1], pred1.val[1]);
+ // NOTE(review): the factor 16 presumably equals weights[0] + weights[1], so
+ // this removes the compound offset carried by both inputs -- confirm against
+ // the caller's weight tables.
+ const int32x4_t offset = vdupq_n_s32(kCompoundOffset * 16);
+ // The difference is treated as signed; vqrshrun_n_s32 below saturates any
+ // negative lane to 0 while narrowing.
+ const int32x4_t res_lo = vsubq_s32(vreinterpretq_s32_u32(blended_lo), offset);
+ const int32x4_t res_hi = vsubq_s32(vreinterpretq_s32_u32(blended_hi), offset);
+ const uint16x4_t bd_max = vdup_n_u16((1 << kBitdepth10) - 1);
+ // Round, narrow and clip the result at (1 << kBitdepth10) - 1.
+ uint16x4x2_t result;
+ result.val[0] =
+ vmin_u16(vqrshrun_n_s32(res_lo, kInterPostRoundBit + 4), bd_max);
+ result.val[1] =
+ vmin_u16(vqrshrun_n_s32(res_hi, kInterPostRoundBit + 4), bd_max);
+ return result;
+}
+
+// Overload operating on four 4-lane vectors (16 values) per call. Each lane
+// gets the same treatment as in the uint16x4x2_t overload:
+//   weights[0] * pred0 + weights[1] * pred1 - kCompoundOffset * 16,
+// followed by a rounding right shift by kInterPostRoundBit + 4 with unsigned
+// saturation to 16 bits and a clamp to (1 << kBitdepth10) - 1.
+inline uint16x4x4_t ComputeWeightedAverage8(const uint16x4x4_t pred0,
+ const uint16x4x4_t pred1,
+ const uint16x4_t weights[2]) {
+ // NOTE(review): the factor 16 presumably equals weights[0] + weights[1] --
+ // confirm against the caller's weight tables.
+ const int32x4_t offset = vdupq_n_s32(kCompoundOffset * 16);
+ // Vectors 0 and 1: widening multiply-accumulate, then remove the offset.
+ const uint32x4_t wpred0 = vmull_u16(weights[0], pred0.val[0]);
+ const uint32x4_t wpred1 = vmull_u16(weights[0], pred0.val[1]);
+ const uint32x4_t blended0 = vmlal_u16(wpred0, weights[1], pred1.val[0]);
+ const uint32x4_t blended1 = vmlal_u16(wpred1, weights[1], pred1.val[1]);
+ const int32x4_t res0 = vsubq_s32(vreinterpretq_s32_u32(blended0), offset);
+ const int32x4_t res1 = vsubq_s32(vreinterpretq_s32_u32(blended1), offset);
+ // Vectors 2 and 3: same computation.
+ const uint32x4_t wpred2 = vmull_u16(weights[0], pred0.val[2]);
+ const uint32x4_t wpred3 = vmull_u16(weights[0], pred0.val[3]);
+ const uint32x4_t blended2 = vmlal_u16(wpred2, weights[1], pred1.val[2]);
+ const uint32x4_t blended3 = vmlal_u16(wpred3, weights[1], pred1.val[3]);
+ const int32x4_t res2 = vsubq_s32(vreinterpretq_s32_u32(blended2), offset);
+ const int32x4_t res3 = vsubq_s32(vreinterpretq_s32_u32(blended3), offset);
+ const uint16x4_t bd_max = vdup_n_u16((1 << kBitdepth10) - 1);
+ // Round, narrow (negative lanes saturate to 0) and clip the result at
+ // (1 << kBitdepth10) - 1.
+ uint16x4x4_t result;
+ result.val[0] =
+ vmin_u16(vqrshrun_n_s32(res0, kInterPostRoundBit + 4), bd_max);
+ result.val[1] =
+ vmin_u16(vqrshrun_n_s32(res1, kInterPostRoundBit + 4), bd_max);
+ result.val[2] =
+ vmin_u16(vqrshrun_n_s32(res2, kInterPostRoundBit + 4), bd_max);
+ result.val[3] =
+ vmin_u16(vqrshrun_n_s32(res3, kInterPostRoundBit + 4), bd_max);
+
+ return result;
+}
+
+// Loads 8 consecutive uint16_t values starting at |ptr| into two 4-lane
+// vectors: val[0] holds ptr[0..3] and val[1] holds ptr[4..7].
+// We could use vld1_u16_x2, but for compatibility reasons, use this function
+// instead. The compiler optimizes to the correct instruction.
+inline uint16x4x2_t LoadU16x4_x2(uint16_t const* ptr) {
+ uint16x4x2_t x;
+ // gcc/clang (64 bit) optimizes the following to ldp.
+ x.val[0] = vld1_u16(ptr);
+ x.val[1] = vld1_u16(ptr + 4);
+ return x;
+}
+
+// Loads 16 consecutive uint16_t values starting at |ptr| into four 4-lane
+// vectors: val[i] holds ptr[4 * i .. 4 * i + 3].
+// We could use vld1_u16_x4, but for compatibility reasons, use this function
+// instead. The compiler optimizes to a pair of vld1_u16_x2, which showed better
+// performance in the speed tests.
+inline uint16x4x4_t LoadU16x4_x4(uint16_t const* ptr) {
+ uint16x4x4_t x;
+ x.val[0] = vld1_u16(ptr);
+ x.val[1] = vld1_u16(ptr + 4);
+ x.val[2] = vld1_u16(ptr + 8);
+ x.val[3] = vld1_u16(ptr + 12);
+ return x;
+}
+
+// Writes the weighted blend of |prediction_0| and |prediction_1| to |dest|.
+//
+// |prediction_0| and |prediction_1| are read as uint16_t buffers whose rows
+// are packed back to back (the row pitch equals |width|). |weight_0| and
+// |weight_1| are broadcast and applied to |prediction_0| and |prediction_1|
+// respectively. |dest| is written as uint16_t; |dest_stride| is given in
+// bytes and converted to elements below.
+//
+// The 4- and 8-wide paths consume two rows per iteration and stop only when
+// the row counter reaches exactly 0, so they assume |height| is a positive
+// even number. The remaining path works in 16-value column steps, so it
+// assumes |width| is a multiple of 16 -- TODO confirm both with the callers.
+void DistanceWeightedBlend_NEON(const void* prediction_0,
+ const void* prediction_1,
+ const uint8_t weight_0, const uint8_t weight_1,
+ const int width, const int height,
+ void* const dest, const ptrdiff_t dest_stride) {
+ const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
+ auto* dst = static_cast<uint16_t*>(dest);
+ // Convert the byte stride into a stride in uint16_t elements.
+ const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
+ const uint16x4_t weights[2] = {vdup_n_u16(weight_0), vdup_n_u16(weight_1)};
-void DistanceWeightedBlendInit_NEON() { Init8bpp(); }
+ if (width == 4) {
+ // Two 4-wide rows are contiguous in the prediction buffers, so one
+ // 8-value load covers both; each half is stored to its own output row.
+ int y = height;
+ do {
+ const uint16x4x2_t src0 = LoadU16x4_x2(pred_0);
+ const uint16x4x2_t src1 = LoadU16x4_x2(pred_1);
+ const uint16x4x2_t res = ComputeWeightedAverage8(src0, src1, weights);
+ vst1_u16(dst, res.val[0]);
+ vst1_u16(dst + dst_stride, res.val[1]);
+ dst += dst_stride << 1;
+ pred_0 += 8;
+ pred_1 += 8;
+ y -= 2;
+ } while (y != 0);
+ } else if (width == 8) {
+ // Two 8-wide rows per iteration: val[0..1] form the first output row and
+ // val[2..3] form the second.
+ int y = height;
+ do {
+ const uint16x4x4_t src0 = LoadU16x4_x4(pred_0);
+ const uint16x4x4_t src1 = LoadU16x4_x4(pred_1);
+ const uint16x4x4_t res = ComputeWeightedAverage8(src0, src1, weights);
+ vst1_u16(dst, res.val[0]);
+ vst1_u16(dst + 4, res.val[1]);
+ vst1_u16(dst + dst_stride, res.val[2]);
+ vst1_u16(dst + dst_stride + 4, res.val[3]);
+ dst += dst_stride << 1;
+ pred_0 += 16;
+ pred_1 += 16;
+ y -= 2;
+ } while (y != 0);
+ } else {
+ // Wider blocks: one row per outer iteration, 16 values per inner step.
+ int y = height;
+ do {
+ int x = 0;
+ do {
+ const uint16x4x4_t src0 = LoadU16x4_x4(pred_0 + x);
+ const uint16x4x4_t src1 = LoadU16x4_x4(pred_1 + x);
+ const uint16x4x4_t res = ComputeWeightedAverage8(src0, src1, weights);
+ vst1_u16(dst + x, res.val[0]);
+ vst1_u16(dst + x + 4, res.val[1]);
+ vst1_u16(dst + x + 8, res.val[2]);
+ vst1_u16(dst + x + 12, res.val[3]);
+ x += 16;
+ } while (x < width);
+ dst += dst_stride;
+ pred_0 += width;
+ pred_1 += width;
+ } while (--y != 0);
+ }
+}
+
+// Installs DistanceWeightedBlend_NEON as the distance_weighted_blend entry of
+// the writable Dsp table for kBitdepth10.
+void Init10bpp() {
+ Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+ // The table is expected to exist; this is only checked in assert-enabled
+ // builds.
+ assert(dsp != nullptr);
+ dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
+}
+
+} // namespace
+} // namespace high_bitdepth
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+// Runs the NEON distance weighted blend initializers: always the 8bpp one,
+// and the 10bpp one only when the build sets LIBGAV1_MAX_BITDEPTH >= 10.
+void DistanceWeightedBlendInit_NEON() {
+ low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ high_bitdepth::Init10bpp();
+#endif
+}
} // namespace dsp
} // namespace libgav1
-#else // !LIBGAV1_ENABLE_NEON
+#else // !LIBGAV1_ENABLE_NEON
namespace libgav1 {
namespace dsp {