Diffstat (limited to 'src/dsp/arm/mask_blend_neon.cc')
-rw-r--r-- | src/dsp/arm/mask_blend_neon.cc | 375
1 file changed, 215 insertions, 160 deletions
diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc
index 853f949..ecc67f8 100644
--- a/src/dsp/arm/mask_blend_neon.cc
+++ b/src/dsp/arm/mask_blend_neon.cc
@@ -33,50 +33,40 @@ namespace dsp {
 namespace low_bitdepth {
 namespace {
 
-// TODO(b/150461164): Consider combining with GetInterIntraMask4x2().
-// Compound predictors use int16_t values and need to multiply long because the
-// Convolve range * 64 is 20 bits. Unfortunately there is no multiply int16_t by
-// int8_t and accumulate into int32_t instruction.
-template <int subsampling_x, int subsampling_y>
-inline int16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
-  if (subsampling_x == 1) {
-    const int16x4_t mask_val0 = vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask)));
-    const int16x4_t mask_val1 = vreinterpret_s16_u16(
-        vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y))));
-    int16x8_t final_val;
-    if (subsampling_y == 1) {
-      const int16x4_t next_mask_val0 =
-          vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride)));
-      const int16x4_t next_mask_val1 =
-          vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride * 3)));
-      final_val = vaddq_s16(vcombine_s16(mask_val0, mask_val1),
-                            vcombine_s16(next_mask_val0, next_mask_val1));
-    } else {
-      final_val = vreinterpretq_s16_u16(
-          vpaddlq_u8(vreinterpretq_u8_s16(vcombine_s16(mask_val0, mask_val1))));
-    }
-    return vrshrq_n_s16(final_val, subsampling_y + 1);
+template <int subsampling_y>
+inline uint8x8_t GetMask4x2(const uint8_t* mask) {
+  if (subsampling_y == 1) {
+    const uint8x16x2_t mask_val = vld2q_u8(mask);
+    const uint8x16_t combined_horz = vaddq_u8(mask_val.val[0], mask_val.val[1]);
+    const uint32x2_t row_01 = vreinterpret_u32_u8(vget_low_u8(combined_horz));
+    const uint32x2_t row_23 = vreinterpret_u32_u8(vget_high_u8(combined_horz));
+
+    const uint32x2x2_t row_02_13 = vtrn_u32(row_01, row_23);
+    // Use a halving add to work around the case where all |mask| values are 64.
+    return vrshr_n_u8(vhadd_u8(vreinterpret_u8_u32(row_02_13.val[0]),
+                               vreinterpret_u8_u32(row_02_13.val[1])),
+                      1);
   }
-  assert(subsampling_y == 0 && subsampling_x == 0);
-  const uint8x8_t mask_val0 = Load4(mask);
-  const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
-  return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+  // subsampling_x == 1
+  const uint8x8x2_t mask_val = vld2_u8(mask);
+  return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
 }
 
 template <int subsampling_x, int subsampling_y>
-inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+inline uint8x8_t GetMask8(const uint8_t* mask) {
+  if (subsampling_x == 1 && subsampling_y == 1) {
+    const uint8x16x2_t mask_val = vld2q_u8(mask);
+    const uint8x16_t combined_horz = vaddq_u8(mask_val.val[0], mask_val.val[1]);
+    // Use a halving add to work around the case where all |mask| values are 64.
+    return vrshr_n_u8(
+        vhadd_u8(vget_low_u8(combined_horz), vget_high_u8(combined_horz)), 1);
+  }
   if (subsampling_x == 1) {
-    int16x8_t mask_val = vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask)));
-    if (subsampling_y == 1) {
-      const int16x8_t next_mask_val =
-          vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask + mask_stride)));
-      mask_val = vaddq_s16(mask_val, next_mask_val);
-    }
-    return vrshrq_n_s16(mask_val, 1 + subsampling_y);
+    const uint8x8x2_t mask_val = vld2_u8(mask);
+    return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
   }
   assert(subsampling_y == 0 && subsampling_x == 0);
-  const uint8x8_t mask_val = vld1_u8(mask);
-  return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+  return vld1_u8(mask);
 }
 
 inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
@@ -109,89 +99,162 @@ inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
   StoreHi4(dst + dst_stride, result);
 }
 
-template <int subsampling_x, int subsampling_y>
+template <int subsampling_y>
 inline void MaskBlending4x4_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                  const int16_t* LIBGAV1_RESTRICT pred_1,
                                  const uint8_t* LIBGAV1_RESTRICT mask,
-                                 const ptrdiff_t mask_stride,
                                  uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
+  constexpr int subsampling_x = 1;
+  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
   const int16x8_t mask_inverter = vdupq_n_s16(64);
-  int16x8_t pred_mask_0 =
-      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  // Compound predictors use int16_t values and need to multiply long because
+  // the Convolve range * 64 is 20 bits. Unfortunately there is no multiply
+  // int16_t by int8_t and accumulate into int32_t instruction.
+  int16x8_t pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
   int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
   WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                         dst_stride);
-  // TODO(b/150461164): Arm tends to do better with load(val); val += stride
-  // It may be possible to turn this into a loop with a templated height.
-  pred_0 += 4 << 1;
-  pred_1 += 4 << 1;
-  mask += mask_stride << (1 + subsampling_y);
-  dst += dst_stride << 1;
+  pred_0 += 4 << subsampling_x;
+  pred_1 += 4 << subsampling_x;
+  mask += mask_stride << (subsampling_x + subsampling_y);
+  dst += dst_stride << subsampling_x;
 
-  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
   pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
   WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                         dst_stride);
 }
 
-template <int subsampling_x, int subsampling_y>
+template <int subsampling_y>
 inline void MaskBlending4xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
                                  const int16_t* LIBGAV1_RESTRICT pred_1,
                                  const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
-                                 const ptrdiff_t mask_stride, const int height,
+                                 const int height,
                                  uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
   const uint8_t* mask = mask_ptr;
   if (height == 4) {
-    MaskBlending4x4_NEON<subsampling_x, subsampling_y>(
-        pred_0, pred_1, mask, mask_stride, dst, dst_stride);
+    MaskBlending4x4_NEON<subsampling_y>(pred_0, pred_1, mask, dst, dst_stride);
     return;
   }
+  constexpr int subsampling_x = 1;
+  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
   const int16x8_t mask_inverter = vdupq_n_s16(64);
   int y = 0;
   do {
     int16x8_t pred_mask_0 =
-        GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+        vreinterpretq_s16_u16(vmovl_u8(GetMask4x2<subsampling_y>(mask)));
     int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
     WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                           dst_stride);
-    pred_0 += 4 << 1;
-    pred_1 += 4 << 1;
-    mask += mask_stride << (1 + subsampling_y);
-    dst += dst_stride << 1;
+    pred_0 += 4 << subsampling_x;
+    pred_1 += 4 << subsampling_x;
+    mask += mask_stride << (subsampling_x + subsampling_y);
+    dst += dst_stride << subsampling_x;
 
-    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
     pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
     WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                           dst_stride);
-    pred_0 += 4 << 1;
-    pred_1 += 4 << 1;
-    mask += mask_stride << (1 + subsampling_y);
-    dst += dst_stride << 1;
+    pred_0 += 4 << subsampling_x;
+    pred_1 += 4 << subsampling_x;
+    mask += mask_stride << (subsampling_x + subsampling_y);
+    dst += dst_stride << subsampling_x;
 
-    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
     WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                           dst_stride);
-    pred_0 += 4 << 1;
-    pred_1 += 4 << 1;
-    mask += mask_stride << (1 + subsampling_y);
-    dst += dst_stride << 1;
+    pred_0 += 4 << subsampling_x;
+    pred_1 += 4 << subsampling_x;
+    mask += mask_stride << (subsampling_x + subsampling_y);
+    dst += dst_stride << subsampling_x;
 
-    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
     pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
     WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                           dst_stride);
-    pred_0 += 4 << 1;
-    pred_1 += 4 << 1;
-    mask += mask_stride << (1 + subsampling_y);
-    dst += dst_stride << 1;
+    pred_0 += 4 << subsampling_x;
+    pred_1 += 4 << subsampling_x;
+    mask += mask_stride << (subsampling_x + subsampling_y);
+    dst += dst_stride << subsampling_x;
 
     y += 8;
   } while (y < height);
 }
 
+inline uint8x8_t CombinePred8(const int16_t* LIBGAV1_RESTRICT pred_0,
+                              const int16_t* LIBGAV1_RESTRICT pred_1,
+                              const int16x8_t pred_mask_0,
+                              const int16x8_t pred_mask_1) {
+  // First 8 values.
+  const int16x8_t pred_val_0 = vld1q_s16(pred_0);
+  const int16x8_t pred_val_1 = vld1q_s16(pred_1);
+  // int res = (mask_value * prediction_0[x] +
+  //            (64 - mask_value) * prediction_1[x]) >> 6;
+  const int32x4_t weighted_pred_lo =
+      vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
+  const int32x4_t weighted_pred_hi =
+      vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
+  const int32x4_t weighted_combo_lo = vmlal_s16(
+      weighted_pred_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
+  const int32x4_t weighted_combo_hi = vmlal_s16(
+      weighted_pred_hi, vget_high_s16(pred_mask_1), vget_high_s16(pred_val_1));
+
+  // dst[x] = static_cast<Pixel>(
+  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+  //           (1 << kBitdepth8) - 1));
+  return vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+                                     vshrn_n_s32(weighted_combo_hi, 6)),
+                        4);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending8xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
+                                 const int16_t* LIBGAV1_RESTRICT pred_1,
+                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
+                                 const int height,
+                                 uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t dst_stride) {
+  const uint8_t* mask = mask_ptr;
+  const int16x8_t mask_inverter = vdupq_n_s16(64);
+  int y = height;
+  do {
+    const int16x8_t pred_mask_0 =
+        ZeroExtend(GetMask8<subsampling_x, subsampling_y>(mask));
+    // 64 - mask
+    const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+    const uint8x8_t result =
+        CombinePred8(pred_0, pred_1, pred_mask_0, pred_mask_1);
+    vst1_u8(dst, result);
+    dst += dst_stride;
+    mask += 8 << (subsampling_x + subsampling_y);
+    pred_0 += 8;
+    pred_1 += 8;
+  } while (--y != 0);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint8x16_t GetMask16(const uint8_t* mask, const ptrdiff_t mask_stride) {
+  if (subsampling_x == 1 && subsampling_y == 1) {
+    const uint8x16x2_t mask_val0 = vld2q_u8(mask);
+    const uint8x16x2_t mask_val1 = vld2q_u8(mask + mask_stride);
+    const uint8x16_t combined_horz0 =
+        vaddq_u8(mask_val0.val[0], mask_val0.val[1]);
+    const uint8x16_t combined_horz1 =
+        vaddq_u8(mask_val1.val[0], mask_val1.val[1]);
+    // Use a halving add to work around the case where all |mask| values are 64.
+    return vrshrq_n_u8(vhaddq_u8(combined_horz0, combined_horz1), 1);
+  }
+  if (subsampling_x == 1) {
+    const uint8x16x2_t mask_val = vld2q_u8(mask);
+    return vrhaddq_u8(mask_val.val[0], mask_val.val[1]);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  return vld1q_u8(mask);
+}
+
 template <int subsampling_x, int subsampling_y>
 inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
                            const void* LIBGAV1_RESTRICT prediction_1,
@@ -204,8 +267,13 @@ inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   if (width == 4) {
-    MaskBlending4xH_NEON<subsampling_x, subsampling_y>(
-        pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride);
+    MaskBlending4xH_NEON<subsampling_y>(pred_0, pred_1, mask_ptr, height, dst,
+                                        dst_stride);
+    return;
+  }
+  if (width == 8) {
+    MaskBlending8xH_NEON<subsampling_x, subsampling_y>(pred_0, pred_1, mask_ptr,
+                                                       height, dst, dst_stride);
     return;
   }
   const uint8_t* mask = mask_ptr;
@@ -214,35 +282,24 @@ inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
   do {
     int x = 0;
     do {
-      const int16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
+      const uint8x16_t pred_mask_0 = GetMask16<subsampling_x, subsampling_y>(
           mask + (x << subsampling_x), mask_stride);
+      const int16x8_t pred_mask_0_lo = ZeroExtend(vget_low_u8(pred_mask_0));
+      const int16x8_t pred_mask_0_hi = ZeroExtend(vget_high_u8(pred_mask_0));
       // 64 - mask
-      const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
-      const int16x8_t pred_val_0 = vld1q_s16(pred_0 + x);
-      const int16x8_t pred_val_1 = vld1q_s16(pred_1 + x);
+      const int16x8_t pred_mask_1_lo = vsubq_s16(mask_inverter, pred_mask_0_lo);
+      const int16x8_t pred_mask_1_hi = vsubq_s16(mask_inverter, pred_mask_0_hi);
+
       uint8x8_t result;
-      // int res = (mask_value * prediction_0[x] +
-      //            (64 - mask_value) * prediction_1[x]) >> 6;
-      const int32x4_t weighted_pred_0_lo =
-          vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
-      const int32x4_t weighted_pred_0_hi =
-          vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
-      const int32x4_t weighted_combo_lo =
-          vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1),
-                    vget_low_s16(pred_val_1));
-      const int32x4_t weighted_combo_hi =
-          vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
-                    vget_high_s16(pred_val_1));
-
-      // dst[x] = static_cast<Pixel>(
-      //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
-      //           (1 << kBitdepth8) - 1));
-      result = vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
-                                           vshrn_n_s32(weighted_combo_hi, 6)),
-                              4);
+      result =
+          CombinePred8(pred_0 + x, pred_1 + x, pred_mask_0_lo, pred_mask_1_lo);
       vst1_u8(dst + x, result);
 
-      x += 8;
+      result = CombinePred8(pred_0 + x + 8, pred_1 + x + 8, pred_mask_0_hi,
+                            pred_mask_1_hi);
+      vst1_u8(dst + x + 8, result);
+
+      x += 16;
     } while (x < width);
     dst += dst_stride;
     pred_0 += width;
@@ -251,63 +308,19 @@ inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
   } while (++y < height);
 }
 
-// TODO(b/150461164): This is much faster for inter_intra (input is Pixel
-// values) but regresses compound versions (input is int16_t). Try to
-// consolidate these.
 template <int subsampling_x, int subsampling_y>
 inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask,
                                       ptrdiff_t mask_stride) {
   if (subsampling_x == 1) {
-    const uint8x8_t mask_val =
-        vpadd_u8(vld1_u8(mask), vld1_u8(mask + (mask_stride << subsampling_y)));
-    if (subsampling_y == 1) {
-      const uint8x8_t next_mask_val = vpadd_u8(vld1_u8(mask + mask_stride),
-                                               vld1_u8(mask + mask_stride * 3));
-
-      // Use a saturating add to work around the case where all |mask| values
-      // are 64. Together with the rounding shift this ensures the correct
-      // result.
-      const uint8x8_t sum = vqadd_u8(mask_val, next_mask_val);
-      return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
-    }
-
-    return vrshr_n_u8(mask_val, /*subsampling_x=*/1);
+    return GetMask4x2<subsampling_y>(mask);
   }
-
+  // When using intra or difference weighted masks, the function doesn't use
+  // subsampling, so |mask_stride| may be 4 or 8.
   assert(subsampling_y == 0 && subsampling_x == 0);
   const uint8x8_t mask_val0 = Load4(mask);
-  // TODO(b/150461164): Investigate the source of |mask| and see if the stride
-  // can be removed.
-  // TODO(b/150461164): The unit tests start at 8x8. Does this get run?
   return Load4<1>(mask + mask_stride, mask_val0);
 }
 
-template <int subsampling_x, int subsampling_y>
-inline uint8x8_t GetInterIntraMask8(const uint8_t* mask,
-                                    ptrdiff_t mask_stride) {
-  if (subsampling_x == 1) {
-    const uint8x16_t mask_val = vld1q_u8(mask);
-    const uint8x8_t mask_paired =
-        vpadd_u8(vget_low_u8(mask_val), vget_high_u8(mask_val));
-    if (subsampling_y == 1) {
-      const uint8x16_t next_mask_val = vld1q_u8(mask + mask_stride);
-      const uint8x8_t next_mask_paired =
-          vpadd_u8(vget_low_u8(next_mask_val), vget_high_u8(next_mask_val));
-
-      // Use a saturating add to work around the case where all |mask| values
-      // are 64. Together with the rounding shift this ensures the correct
-      // result.
-      const uint8x8_t sum = vqadd_u8(mask_paired, next_mask_paired);
-      return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
-    }
-
-    return vrshr_n_u8(mask_paired, /*subsampling_x=*/1);
-  }
-
-  assert(subsampling_y == 0 && subsampling_x == 0);
-  return vld1_u8(mask);
-}
-
 inline void InterIntraWriteMaskBlendLine8bpp4x2(
     const uint8_t* LIBGAV1_RESTRICT const pred_0,
     uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
@@ -374,6 +387,32 @@ inline void InterIntraMaskBlending8bpp4xH_NEON(
 }
 
 template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp8xH_NEON(
+    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
+    const ptrdiff_t mask_stride, const int height) {
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  int y = height;
+  do {
+    const uint8x8_t pred_mask_1 = GetMask8<subsampling_x, subsampling_y>(mask);
+    // 64 - mask
+    const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+    const uint8x8_t pred_val_0 = vld1_u8(pred_0);
+    const uint8x8_t pred_val_1 = vld1_u8(pred_1);
+    const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+    // weighted_pred0 + weighted_pred1
+    const uint16x8_t weighted_combo =
+        vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+    const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+    vst1_u8(pred_1, result);
+
+    pred_0 += 8;
+    pred_1 += pred_stride_1;
+    mask += mask_stride << subsampling_y;
+  } while (--y != 0);
+}
+
+template <int subsampling_x, int subsampling_y>
 inline void InterIntraMaskBlend8bpp_NEON(
     const uint8_t* LIBGAV1_RESTRICT prediction_0,
     uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
@@ -385,30 +424,46 @@ inline void InterIntraMaskBlend8bpp_NEON(
         height);
     return;
   }
+  if (width == 8) {
+    InterIntraMaskBlending8bpp8xH_NEON<subsampling_x, subsampling_y>(
+        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
+        height);
+    return;
+  }
   const uint8_t* mask = mask_ptr;
-  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  const uint8x16_t mask_inverter = vdupq_n_u8(64);
   int y = 0;
   do {
     int x = 0;
     do {
-      // TODO(b/150461164): Consider a 16 wide specialization (at least for the
-      // unsampled version) to take advantage of vld1q_u8().
-      const uint8x8_t pred_mask_1 =
-          GetInterIntraMask8<subsampling_x, subsampling_y>(
-              mask + (x << subsampling_x), mask_stride);
+      const uint8x16_t pred_mask_1 = GetMask16<subsampling_x, subsampling_y>(
+          mask + (x << subsampling_x), mask_stride);
       // 64 - mask
-      const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
-      const uint8x8_t pred_val_0 = vld1_u8(prediction_0);
+      const uint8x16_t pred_mask_0 = vsubq_u8(mask_inverter, pred_mask_1);
+      const uint8x8_t pred_val_0_lo = vld1_u8(prediction_0);
+      prediction_0 += 8;
+      const uint8x8_t pred_val_0_hi = vld1_u8(prediction_0);
       prediction_0 += 8;
-      const uint8x8_t pred_val_1 = vld1_u8(prediction_1 + x);
-      const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+      // Ensure armv7 build combines the load.
+      const uint8x16_t pred_val_1 = vld1q_u8(prediction_1 + x);
+      const uint8x8_t pred_val_1_lo = vget_low_u8(pred_val_1);
+      const uint8x8_t pred_val_1_hi = vget_high_u8(pred_val_1);
+      const uint16x8_t weighted_pred_0_lo =
+          vmull_u8(vget_low_u8(pred_mask_0), pred_val_0_lo);
      // weighted_pred0 + weighted_pred1
-      const uint16x8_t weighted_combo =
-          vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
-      const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
-      vst1_u8(prediction_1 + x, result);
+      const uint16x8_t weighted_combo_lo =
+          vmlal_u8(weighted_pred_0_lo, vget_low_u8(pred_mask_1), pred_val_1_lo);
+      const uint8x8_t result_lo = vrshrn_n_u16(weighted_combo_lo, 6);
+      vst1_u8(prediction_1 + x, result_lo);
+      const uint16x8_t weighted_pred_0_hi =
+          vmull_u8(vget_high_u8(pred_mask_0), pred_val_0_hi);
+      // weighted_pred0 + weighted_pred1
+      const uint16x8_t weighted_combo_hi = vmlal_u8(
+          weighted_pred_0_hi, vget_high_u8(pred_mask_1), pred_val_1_hi);
+      const uint8x8_t result_hi = vrshrn_n_u16(weighted_combo_hi, 6);
+      vst1_u8(prediction_1 + x + 8, result_hi);
 
-      x += 8;
+      x += 16;
     } while (x < width);
     prediction_1 += prediction_stride_1;
     mask += mask_stride << subsampling_y;
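For reference, the per-pixel arithmetic that these NEON paths vectorize (quoted in the in-code comments above) can be sketched in scalar C++ as below. This is an illustrative sketch only, not part of the commit: the helper names DownsampleMask2x2 and BlendCompoundPixel are hypothetical, and the rounding-shift amount of 4 is inferred from the vqrshrun_n_s16(..., 4) calls in the diff.

#include <algorithm>
#include <cstdint>
#include <cstddef>

namespace {

// Rounded right shift, as referenced in the comments above.
inline int RightShiftWithRounding(int value, int bits) {
  return (value + (1 << (bits - 1))) >> bits;
}

// Hypothetical scalar equivalent of the 2x2 mask downsampling performed by
// GetMask4x2()/GetMask8()/GetMask16() when subsampling_x == 1 and
// subsampling_y == 1. |mask| values lie in [0, 64], so four of them can sum
// to 256; the NEON code splits this into a plain add, a halving add and a
// rounding shift so the intermediate sum never overflows a uint8_t lane.
inline int DownsampleMask2x2(const uint8_t* mask, ptrdiff_t stride) {
  const int sum = mask[0] + mask[1] + mask[stride] + mask[stride + 1];
  return RightShiftWithRounding(sum, 2);
}

// Hypothetical scalar equivalent of CombinePred8(): blends one pair of
// compound predictor values (int16_t Convolve output) with a weight in
// [0, 64], then rounds and clips the result to an 8-bit pixel.
inline uint8_t BlendCompoundPixel(int16_t pred_0, int16_t pred_1, int weight) {
  const int res = (weight * pred_0 + (64 - weight) * pred_1) >> 6;
  // Rounding shift by 4, matching vqrshrun_n_s16(..., 4) in the NEON code.
  return static_cast<uint8_t>(
      std::min(255, std::max(0, RightShiftWithRounding(res, 4))));
}

}  // namespace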