diff options
Diffstat (limited to 'src/dsp/arm/distance_weighted_blend_neon.cc')
-rw-r--r-- | src/dsp/arm/distance_weighted_blend_neon.cc | 162 |
1 files changed, 159 insertions, 3 deletions
diff --git a/src/dsp/arm/distance_weighted_blend_neon.cc b/src/dsp/arm/distance_weighted_blend_neon.cc index 04952ab..a0cd0ac 100644 --- a/src/dsp/arm/distance_weighted_blend_neon.cc +++ b/src/dsp/arm/distance_weighted_blend_neon.cc @@ -30,10 +30,12 @@ namespace libgav1 { namespace dsp { -namespace { constexpr int kInterPostRoundBit = 4; +namespace low_bitdepth { +namespace { + inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0, const int16x8_t pred1, const int16x4_t weights[2]) { @@ -185,13 +187,167 @@ void Init8bpp() { } } // namespace +} // namespace low_bitdepth + +//------------------------------------------------------------------------------ +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +inline uint16x4x2_t ComputeWeightedAverage8(const uint16x4x2_t pred0, + const uint16x4x2_t pred1, + const uint16x4_t weights[2]) { + const uint32x4_t wpred0_lo = vmull_u16(weights[0], pred0.val[0]); + const uint32x4_t wpred0_hi = vmull_u16(weights[0], pred0.val[1]); + const uint32x4_t blended_lo = vmlal_u16(wpred0_lo, weights[1], pred1.val[0]); + const uint32x4_t blended_hi = vmlal_u16(wpred0_hi, weights[1], pred1.val[1]); + const int32x4_t offset = vdupq_n_s32(kCompoundOffset * 16); + const int32x4_t res_lo = vsubq_s32(vreinterpretq_s32_u32(blended_lo), offset); + const int32x4_t res_hi = vsubq_s32(vreinterpretq_s32_u32(blended_hi), offset); + const uint16x4_t bd_max = vdup_n_u16((1 << kBitdepth10) - 1); + // Clip the result at (1 << bd) - 1. + uint16x4x2_t result; + result.val[0] = + vmin_u16(vqrshrun_n_s32(res_lo, kInterPostRoundBit + 4), bd_max); + result.val[1] = + vmin_u16(vqrshrun_n_s32(res_hi, kInterPostRoundBit + 4), bd_max); + return result; +} + +inline uint16x4x4_t ComputeWeightedAverage8(const uint16x4x4_t pred0, + const uint16x4x4_t pred1, + const uint16x4_t weights[2]) { + const int32x4_t offset = vdupq_n_s32(kCompoundOffset * 16); + const uint32x4_t wpred0 = vmull_u16(weights[0], pred0.val[0]); + const uint32x4_t wpred1 = vmull_u16(weights[0], pred0.val[1]); + const uint32x4_t blended0 = vmlal_u16(wpred0, weights[1], pred1.val[0]); + const uint32x4_t blended1 = vmlal_u16(wpred1, weights[1], pred1.val[1]); + const int32x4_t res0 = vsubq_s32(vreinterpretq_s32_u32(blended0), offset); + const int32x4_t res1 = vsubq_s32(vreinterpretq_s32_u32(blended1), offset); + const uint32x4_t wpred2 = vmull_u16(weights[0], pred0.val[2]); + const uint32x4_t wpred3 = vmull_u16(weights[0], pred0.val[3]); + const uint32x4_t blended2 = vmlal_u16(wpred2, weights[1], pred1.val[2]); + const uint32x4_t blended3 = vmlal_u16(wpred3, weights[1], pred1.val[3]); + const int32x4_t res2 = vsubq_s32(vreinterpretq_s32_u32(blended2), offset); + const int32x4_t res3 = vsubq_s32(vreinterpretq_s32_u32(blended3), offset); + const uint16x4_t bd_max = vdup_n_u16((1 << kBitdepth10) - 1); + // Clip the result at (1 << bd) - 1. + uint16x4x4_t result; + result.val[0] = + vmin_u16(vqrshrun_n_s32(res0, kInterPostRoundBit + 4), bd_max); + result.val[1] = + vmin_u16(vqrshrun_n_s32(res1, kInterPostRoundBit + 4), bd_max); + result.val[2] = + vmin_u16(vqrshrun_n_s32(res2, kInterPostRoundBit + 4), bd_max); + result.val[3] = + vmin_u16(vqrshrun_n_s32(res3, kInterPostRoundBit + 4), bd_max); + + return result; +} + +// We could use vld1_u16_x2, but for compatibility reasons, use this function +// instead. The compiler optimizes to the correct instruction. +inline uint16x4x2_t LoadU16x4_x2(uint16_t const* ptr) { + uint16x4x2_t x; + // gcc/clang (64 bit) optimizes the following to ldp. + x.val[0] = vld1_u16(ptr); + x.val[1] = vld1_u16(ptr + 4); + return x; +} + +// We could use vld1_u16_x4, but for compatibility reasons, use this function +// instead. The compiler optimizes to a pair of vld1_u16_x2, which showed better +// performance in the speed tests. +inline uint16x4x4_t LoadU16x4_x4(uint16_t const* ptr) { + uint16x4x4_t x; + x.val[0] = vld1_u16(ptr); + x.val[1] = vld1_u16(ptr + 4); + x.val[2] = vld1_u16(ptr + 8); + x.val[3] = vld1_u16(ptr + 12); + return x; +} + +void DistanceWeightedBlend_NEON(const void* prediction_0, + const void* prediction_1, + const uint8_t weight_0, const uint8_t weight_1, + const int width, const int height, + void* const dest, const ptrdiff_t dest_stride) { + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + auto* dst = static_cast<uint16_t*>(dest); + const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]); + const uint16x4_t weights[2] = {vdup_n_u16(weight_0), vdup_n_u16(weight_1)}; -void DistanceWeightedBlendInit_NEON() { Init8bpp(); } + if (width == 4) { + int y = height; + do { + const uint16x4x2_t src0 = LoadU16x4_x2(pred_0); + const uint16x4x2_t src1 = LoadU16x4_x2(pred_1); + const uint16x4x2_t res = ComputeWeightedAverage8(src0, src1, weights); + vst1_u16(dst, res.val[0]); + vst1_u16(dst + dst_stride, res.val[1]); + dst += dst_stride << 1; + pred_0 += 8; + pred_1 += 8; + y -= 2; + } while (y != 0); + } else if (width == 8) { + int y = height; + do { + const uint16x4x4_t src0 = LoadU16x4_x4(pred_0); + const uint16x4x4_t src1 = LoadU16x4_x4(pred_1); + const uint16x4x4_t res = ComputeWeightedAverage8(src0, src1, weights); + vst1_u16(dst, res.val[0]); + vst1_u16(dst + 4, res.val[1]); + vst1_u16(dst + dst_stride, res.val[2]); + vst1_u16(dst + dst_stride + 4, res.val[3]); + dst += dst_stride << 1; + pred_0 += 16; + pred_1 += 16; + y -= 2; + } while (y != 0); + } else { + int y = height; + do { + int x = 0; + do { + const uint16x4x4_t src0 = LoadU16x4_x4(pred_0 + x); + const uint16x4x4_t src1 = LoadU16x4_x4(pred_1 + x); + const uint16x4x4_t res = ComputeWeightedAverage8(src0, src1, weights); + vst1_u16(dst + x, res.val[0]); + vst1_u16(dst + x + 4, res.val[1]); + vst1_u16(dst + x + 8, res.val[2]); + vst1_u16(dst + x + 12, res.val[3]); + x += 16; + } while (x < width); + dst += dst_stride; + pred_0 += width; + pred_1 += width; + } while (--y != 0); + } +} + +void Init10bpp() { + Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->distance_weighted_blend = DistanceWeightedBlend_NEON; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void DistanceWeightedBlendInit_NEON() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} } // namespace dsp } // namespace libgav1 -#else // !LIBGAV1_ENABLE_NEON +#else // !LIBGAV1_ENABLE_NEON namespace libgav1 { namespace dsp { |