diff options
Diffstat (limited to 'src/dsp/arm/distance_weighted_blend_neon.cc')
-rw-r--r-- | src/dsp/arm/distance_weighted_blend_neon.cc | 105 |
1 files changed, 41 insertions, 64 deletions
diff --git a/src/dsp/arm/distance_weighted_blend_neon.cc b/src/dsp/arm/distance_weighted_blend_neon.cc index 7d287c8..6087276 100644 --- a/src/dsp/arm/distance_weighted_blend_neon.cc +++ b/src/dsp/arm/distance_weighted_blend_neon.cc @@ -36,44 +36,48 @@ constexpr int kInterPostRoundBit = 4; namespace low_bitdepth { namespace { -inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0, +inline uint8x8_t ComputeWeightedAverage8(const int16x8_t pred0, const int16x8_t pred1, - const int16x4_t weights[2]) { - // TODO(https://issuetracker.google.com/issues/150325685): Investigate range. - const int32x4_t wpred0_lo = vmull_s16(weights[0], vget_low_s16(pred0)); - const int32x4_t wpred0_hi = vmull_s16(weights[0], vget_high_s16(pred0)); - const int32x4_t blended_lo = - vmlal_s16(wpred0_lo, weights[1], vget_low_s16(pred1)); - const int32x4_t blended_hi = - vmlal_s16(wpred0_hi, weights[1], vget_high_s16(pred1)); - - return vcombine_s16(vqrshrn_n_s32(blended_lo, kInterPostRoundBit + 4), - vqrshrn_n_s32(blended_hi, kInterPostRoundBit + 4)); + const int16x8_t weight) { + // Given: p0,p1 in range [-5132,9212] and w0 = 16 - w1, w1 = 16 - w0 + // Output: (p0 * w0 + p1 * w1 + 128(=rounding bit)) >> + // 8(=kInterPostRoundBit + 4) + // The formula is manipulated to avoid lengthening to 32 bits. + // p0 * w0 + p1 * w1 = p0 * w0 + (16 - w0) * p1 + // = (p0 - p1) * w0 + 16 * p1 + // Maximum value of p0 - p1 is 9212 + 5132 = 0x3808. + const int16x8_t diff = vsubq_s16(pred0, pred1); + // (((p0 - p1) * (w0 << 11) << 1) >> 16) + ((16 * p1) >> 4) + const int16x8_t weighted_diff = vqdmulhq_s16(diff, weight); + // ((p0 - p1) * w0 >> 4) + p1 + const int16x8_t upscaled_average = vaddq_s16(weighted_diff, pred1); + // (((p0 - p1) * w0 >> 4) + p1 + (128 >> 4)) >> 4 + return vqrshrun_n_s16(upscaled_average, kInterPostRoundBit); } -template <int width, int height> +template <int width> inline void DistanceWeightedBlendSmall_NEON( const int16_t* LIBGAV1_RESTRICT prediction_0, - const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x4_t weights[2], - void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { + const int16_t* LIBGAV1_RESTRICT prediction_1, const int height, + const int16x8_t weight, void* LIBGAV1_RESTRICT const dest, + const ptrdiff_t dest_stride) { auto* dst = static_cast<uint8_t*>(dest); constexpr int step = 16 / width; - for (int y = 0; y < height; y += step) { + int y = height; + do { const int16x8_t src_00 = vld1q_s16(prediction_0); const int16x8_t src_10 = vld1q_s16(prediction_1); prediction_0 += 8; prediction_1 += 8; - const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights); + const uint8x8_t result0 = ComputeWeightedAverage8(src_00, src_10, weight); const int16x8_t src_01 = vld1q_s16(prediction_0); const int16x8_t src_11 = vld1q_s16(prediction_1); prediction_0 += 8; prediction_1 += 8; - const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights); + const uint8x8_t result1 = ComputeWeightedAverage8(src_01, src_11, weight); - const uint8x8_t result0 = vqmovun_s16(res0); - const uint8x8_t result1 = vqmovun_s16(res1); if (width == 4) { StoreLo4(dst, result0); dst += dest_stride; @@ -90,12 +94,13 @@ inline void DistanceWeightedBlendSmall_NEON( vst1_u8(dst, result1); dst += dest_stride; } - } + y -= step; + } while (y != 0); } inline void DistanceWeightedBlendLarge_NEON( const int16_t* LIBGAV1_RESTRICT prediction_0, - const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x4_t weights[2], + const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x8_t weight, const int width, const int height, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { auto* dst = static_cast<uint8_t*>(dest); @@ -106,16 +111,15 @@ inline void DistanceWeightedBlendLarge_NEON( do { const int16x8_t src0_lo = vld1q_s16(prediction_0 + x); const int16x8_t src1_lo = vld1q_s16(prediction_1 + x); - const int16x8_t res_lo = - ComputeWeightedAverage8(src0_lo, src1_lo, weights); + const uint8x8_t res_lo = + ComputeWeightedAverage8(src0_lo, src1_lo, weight); const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8); const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8); - const int16x8_t res_hi = - ComputeWeightedAverage8(src0_hi, src1_hi, weights); + const uint8x8_t res_hi = + ComputeWeightedAverage8(src0_hi, src1_hi, weight); - const uint8x16_t result = - vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi)); + const uint8x16_t result = vcombine_u8(res_lo, res_hi); vst1q_u8(dst + x, result); x += 16; } while (x < width); @@ -128,52 +132,25 @@ inline void DistanceWeightedBlendLarge_NEON( inline void DistanceWeightedBlend_NEON( const void* LIBGAV1_RESTRICT prediction_0, const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0, - const uint8_t weight_1, const int width, const int height, + const uint8_t /*weight_1*/, const int width, const int height, void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) { const auto* pred_0 = static_cast<const int16_t*>(prediction_0); const auto* pred_1 = static_cast<const int16_t*>(prediction_1); - int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)}; - // TODO(johannkoenig): Investigate the branching. May be fine to call with a - // variable height. + // Upscale the weight for vqdmulh. + const int16x8_t weight = vdupq_n_s16(weight_0 << 11); if (width == 4) { - if (height == 4) { - DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest, - dest_stride); - } else if (height == 8) { - DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest, - dest_stride); - } else { - assert(height == 16); - DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest, - dest_stride); - } + DistanceWeightedBlendSmall_NEON<4>(pred_0, pred_1, height, weight, dest, + dest_stride); return; } if (width == 8) { - switch (height) { - case 4: - DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest, - dest_stride); - return; - case 8: - DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest, - dest_stride); - return; - case 16: - DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest, - dest_stride); - return; - default: - assert(height == 32); - DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest, - dest_stride); - - return; - } + DistanceWeightedBlendSmall_NEON<8>(pred_0, pred_1, height, weight, dest, + dest_stride); + return; } - DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest, + DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weight, width, height, dest, dest_stride); } |