path: root/src/dsp/arm/mask_blend_neon.cc
Diffstat (limited to 'src/dsp/arm/mask_blend_neon.cc')
-rw-r--r--  src/dsp/arm/mask_blend_neon.cc  375
1 file changed, 215 insertions(+), 160 deletions(-)
diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc
index 853f949..ecc67f8 100644
--- a/src/dsp/arm/mask_blend_neon.cc
+++ b/src/dsp/arm/mask_blend_neon.cc
@@ -33,50 +33,40 @@ namespace dsp {
namespace low_bitdepth {
namespace {
-// TODO(b/150461164): Consider combining with GetInterIntraMask4x2().
-// Compound predictors use int16_t values and need to multiply long because the
-// Convolve range * 64 is 20 bits. Unfortunately there is no multiply int16_t by
-// int8_t and accumulate into int32_t instruction.
-template <int subsampling_x, int subsampling_y>
-inline int16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
- if (subsampling_x == 1) {
- const int16x4_t mask_val0 = vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask)));
- const int16x4_t mask_val1 = vreinterpret_s16_u16(
- vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y))));
- int16x8_t final_val;
- if (subsampling_y == 1) {
- const int16x4_t next_mask_val0 =
- vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride)));
- const int16x4_t next_mask_val1 =
- vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride * 3)));
- final_val = vaddq_s16(vcombine_s16(mask_val0, mask_val1),
- vcombine_s16(next_mask_val0, next_mask_val1));
- } else {
- final_val = vreinterpretq_s16_u16(
- vpaddlq_u8(vreinterpretq_u8_s16(vcombine_s16(mask_val0, mask_val1))));
- }
- return vrshrq_n_s16(final_val, subsampling_y + 1);
+template <int subsampling_y>
+inline uint8x8_t GetMask4x2(const uint8_t* mask) {
+ if (subsampling_y == 1) {
+ const uint8x16x2_t mask_val = vld2q_u8(mask);
+ const uint8x16_t combined_horz = vaddq_u8(mask_val.val[0], mask_val.val[1]);
+ const uint32x2_t row_01 = vreinterpret_u32_u8(vget_low_u8(combined_horz));
+ const uint32x2_t row_23 = vreinterpret_u32_u8(vget_high_u8(combined_horz));
+
+ const uint32x2x2_t row_02_13 = vtrn_u32(row_01, row_23);
+ // Use a halving add to work around the case where all |mask| values are 64.
+ return vrshr_n_u8(vhadd_u8(vreinterpret_u8_u32(row_02_13.val[0]),
+ vreinterpret_u8_u32(row_02_13.val[1])),
+ 1);
}
- assert(subsampling_y == 0 && subsampling_x == 0);
- const uint8x8_t mask_val0 = Load4(mask);
- const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
- return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+ // subsampling_x == 1
+ const uint8x8x2_t mask_val = vld2_u8(mask);
+ return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
}
template <int subsampling_x, int subsampling_y>
-inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+inline uint8x8_t GetMask8(const uint8_t* mask) {
+ if (subsampling_x == 1 && subsampling_y == 1) {
+ const uint8x16x2_t mask_val = vld2q_u8(mask);
+ const uint8x16_t combined_horz = vaddq_u8(mask_val.val[0], mask_val.val[1]);
+ // Use a halving add to work around the case where all |mask| values are 64.
+ return vrshr_n_u8(
+ vhadd_u8(vget_low_u8(combined_horz), vget_high_u8(combined_horz)), 1);
+ }
if (subsampling_x == 1) {
- int16x8_t mask_val = vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask)));
- if (subsampling_y == 1) {
- const int16x8_t next_mask_val =
- vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask + mask_stride)));
- mask_val = vaddq_s16(mask_val, next_mask_val);
- }
- return vrshrq_n_s16(mask_val, 1 + subsampling_y);
+ const uint8x8x2_t mask_val = vld2_u8(mask);
+ return vrhadd_u8(mask_val.val[0], mask_val.val[1]);
}
assert(subsampling_y == 0 && subsampling_x == 0);
- const uint8x8_t mask_val = vld1_u8(mask);
- return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+ return vld1_u8(mask);
}
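
For reference, the subsampled mask fetch that the new GetMask4x2/GetMask8 helpers vectorize reduces to the scalar form below. This is a minimal sketch (the helper name ScalarSubsampledMask is ours, not libgav1 API); it also spells out why the halving add is used: adding the two columns of a 2x2 block tops out at 128, and halving the two row sums before the final rounding shift keeps the value in a uint8_t lane even when every |mask| value is 64.

#include <cstdint>

// Scalar model of one subsampled mask value (sketch only, assumed semantics).
// |subsampling_x| / |subsampling_y| follow the template parameters above; |x|
// is the output column and |stride| the mask row stride in bytes.
inline uint8_t ScalarSubsampledMask(const uint8_t* mask, int x, int stride,
                                    int subsampling_x, int subsampling_y) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    // Rounded average of a 2x2 block; matches add + vhadd + vrshr_n_u8(1).
    const int sum = mask[2 * x] + mask[2 * x + 1] + mask[stride + 2 * x] +
                    mask[stride + 2 * x + 1];
    return static_cast<uint8_t>((sum + 2) >> 2);
  }
  if (subsampling_x == 1) {
    // Rounded average of two adjacent columns; matches vrhadd_u8.
    return static_cast<uint8_t>((mask[2 * x] + mask[2 * x + 1] + 1) >> 1);
  }
  return mask[x];  // No subsampling: pass the value through.
}
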
inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
@@ -109,89 +99,162 @@ inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
StoreHi4(dst + dst_stride, result);
}
-template <int subsampling_x, int subsampling_y>
+template <int subsampling_y>
inline void MaskBlending4x4_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
const int16_t* LIBGAV1_RESTRICT pred_1,
const uint8_t* LIBGAV1_RESTRICT mask,
- const ptrdiff_t mask_stride,
uint8_t* LIBGAV1_RESTRICT dst,
const ptrdiff_t dst_stride) {
+ constexpr int subsampling_x = 1;
+ constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
const int16x8_t mask_inverter = vdupq_n_s16(64);
- int16x8_t pred_mask_0 =
- GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ // Compound predictors use int16_t values and need to multiply long because
+ // the Convolve range * 64 is 20 bits. Unfortunately there is no multiply
+ // int16_t by int8_t and accumulate into int32_t instruction.
+ int16x8_t pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
dst_stride);
- // TODO(b/150461164): Arm tends to do better with load(val); val += stride
- // It may be possible to turn this into a loop with a templated height.
- pred_0 += 4 << 1;
- pred_1 += 4 << 1;
- mask += mask_stride << (1 + subsampling_y);
- dst += dst_stride << 1;
+ pred_0 += 4 << subsampling_x;
+ pred_1 += 4 << subsampling_x;
+ mask += mask_stride << (subsampling_x + subsampling_y);
+ dst += dst_stride << subsampling_x;
- pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
dst_stride);
}
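
A quick check of the widened-multiply comment in MaskBlending4x4_NEON above (a sketch; the 14-bit compound sample width is an assumption chosen to be consistent with the "20 bits" figure in the comment, not a value restated in this diff):

#include <cstdint>

// Why vmull_s16/vmlal_s16 into int32x4_t: the mask weight (up to 64) times a
// compound predictor sample (assumed ~14 significant bits) needs ~20 bits,
// which does not fit in an int16_t lane.
constexpr int kMaxMaskValue = 64;
constexpr int kAssumedCompoundSampleBits = 14;
constexpr int64_t kMaxProduct =
    static_cast<int64_t>(kMaxMaskValue) *
    ((int64_t{1} << kAssumedCompoundSampleBits) - 1);
static_assert(kMaxProduct > INT16_MAX, "product needs more than 16 bits");
static_assert(kMaxProduct <= INT32_MAX, "product fits the 32-bit accumulator");
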
-template <int subsampling_x, int subsampling_y>
+template <int subsampling_y>
inline void MaskBlending4xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
const int16_t* LIBGAV1_RESTRICT pred_1,
const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
- const ptrdiff_t mask_stride, const int height,
+ const int height,
uint8_t* LIBGAV1_RESTRICT dst,
const ptrdiff_t dst_stride) {
const uint8_t* mask = mask_ptr;
if (height == 4) {
- MaskBlending4x4_NEON<subsampling_x, subsampling_y>(
- pred_0, pred_1, mask, mask_stride, dst, dst_stride);
+ MaskBlending4x4_NEON<subsampling_y>(pred_0, pred_1, mask, dst, dst_stride);
return;
}
+ constexpr int subsampling_x = 1;
+ constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
const int16x8_t mask_inverter = vdupq_n_s16(64);
int y = 0;
do {
int16x8_t pred_mask_0 =
- GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ vreinterpretq_s16_u16(vmovl_u8(GetMask4x2<subsampling_y>(mask)));
int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
dst_stride);
- pred_0 += 4 << 1;
- pred_1 += 4 << 1;
- mask += mask_stride << (1 + subsampling_y);
- dst += dst_stride << 1;
+ pred_0 += 4 << subsampling_x;
+ pred_1 += 4 << subsampling_x;
+ mask += mask_stride << (subsampling_x + subsampling_y);
+ dst += dst_stride << subsampling_x;
- pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
dst_stride);
- pred_0 += 4 << 1;
- pred_1 += 4 << 1;
- mask += mask_stride << (1 + subsampling_y);
- dst += dst_stride << 1;
+ pred_0 += 4 << subsampling_x;
+ pred_1 += 4 << subsampling_x;
+ mask += mask_stride << (subsampling_x + subsampling_y);
+ dst += dst_stride << subsampling_x;
- pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
dst_stride);
- pred_0 += 4 << 1;
- pred_1 += 4 << 1;
- mask += mask_stride << (1 + subsampling_y);
- dst += dst_stride << 1;
+ pred_0 += 4 << subsampling_x;
+ pred_1 += 4 << subsampling_x;
+ mask += mask_stride << (subsampling_x + subsampling_y);
+ dst += dst_stride << subsampling_x;
- pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_0 = ZeroExtend(GetMask4x2<subsampling_y>(mask));
pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
dst_stride);
- pred_0 += 4 << 1;
- pred_1 += 4 << 1;
- mask += mask_stride << (1 + subsampling_y);
- dst += dst_stride << 1;
+ pred_0 += 4 << subsampling_x;
+ pred_1 += 4 << subsampling_x;
+ mask += mask_stride << (subsampling_x + subsampling_y);
+ dst += dst_stride << subsampling_x;
y += 8;
} while (y < height);
}
+inline uint8x8_t CombinePred8(const int16_t* LIBGAV1_RESTRICT pred_0,
+ const int16_t* LIBGAV1_RESTRICT pred_1,
+ const int16x8_t pred_mask_0,
+ const int16x8_t pred_mask_1) {
+ // First 8 values.
+ const int16x8_t pred_val_0 = vld1q_s16(pred_0);
+ const int16x8_t pred_val_1 = vld1q_s16(pred_1);
+ // int res = (mask_value * prediction_0[x] +
+ // (64 - mask_value) * prediction_1[x]) >> 6;
+ const int32x4_t weighted_pred_lo =
+ vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
+ const int32x4_t weighted_pred_hi =
+ vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
+ const int32x4_t weighted_combo_lo = vmlal_s16(
+ weighted_pred_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
+ const int32x4_t weighted_combo_hi = vmlal_s16(
+ weighted_pred_hi, vget_high_s16(pred_mask_1), vget_high_s16(pred_val_1));
+
+ // dst[x] = static_cast<Pixel>(
+ // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+ // (1 << kBitdepth8) - 1));
+ return vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+ vshrn_n_s32(weighted_combo_hi, 6)),
+ 4);
+}
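+
The commented formulas inside CombinePred8 correspond to the scalar routine below (a sketch with an assumed helper name, ScalarCombinePred). The 6-bit blend shift and the 4-bit post-round shift are read off the shift amounts in the intrinsics; vqrshrun_n_s16 at the end both rounds and clamps to [0, 255], which the sketch does explicitly.

#include <algorithm>
#include <cstdint>

// Scalar equivalent of one blended compound output pixel (sketch).
inline uint8_t ScalarCombinePred(int16_t pred_0, int16_t pred_1, int mask) {
  // res = (mask * pred_0 + (64 - mask) * pred_1) >> 6, in 32-bit arithmetic.
  const int32_t res = (mask * static_cast<int32_t>(pred_0) +
                       (64 - mask) * static_cast<int32_t>(pred_1)) >> 6;
  // RightShiftWithRounding(res, 4), then clip to the 8-bit pixel range, as in
  // vqrshrun_n_s16(..., 4).
  const int32_t rounded = (res + 8) >> 4;
  return static_cast<uint8_t>(std::min(255, std::max(0, rounded)));
}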
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending8xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
+ const int16_t* LIBGAV1_RESTRICT pred_1,
+ const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
+ const int height,
+ uint8_t* LIBGAV1_RESTRICT dst,
+ const ptrdiff_t dst_stride) {
+ const uint8_t* mask = mask_ptr;
+ const int16x8_t mask_inverter = vdupq_n_s16(64);
+ int y = height;
+ do {
+ const int16x8_t pred_mask_0 =
+ ZeroExtend(GetMask8<subsampling_x, subsampling_y>(mask));
+ // 64 - mask
+ const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+ const uint8x8_t result =
+ CombinePred8(pred_0, pred_1, pred_mask_0, pred_mask_1);
+ vst1_u8(dst, result);
+ dst += dst_stride;
+ mask += 8 << (subsampling_x + subsampling_y);
+ pred_0 += 8;
+ pred_1 += 8;
+ } while (--y != 0);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint8x16_t GetMask16(const uint8_t* mask, const ptrdiff_t mask_stride) {
+ if (subsampling_x == 1 && subsampling_y == 1) {
+ const uint8x16x2_t mask_val0 = vld2q_u8(mask);
+ const uint8x16x2_t mask_val1 = vld2q_u8(mask + mask_stride);
+ const uint8x16_t combined_horz0 =
+ vaddq_u8(mask_val0.val[0], mask_val0.val[1]);
+ const uint8x16_t combined_horz1 =
+ vaddq_u8(mask_val1.val[0], mask_val1.val[1]);
+ // Use a halving add to work around the case where all |mask| values are 64.
+ return vrshrq_n_u8(vhaddq_u8(combined_horz0, combined_horz1), 1);
+ }
+ if (subsampling_x == 1) {
+ const uint8x16x2_t mask_val = vld2q_u8(mask);
+ return vrhaddq_u8(mask_val.val[0], mask_val.val[1]);
+ }
+ assert(subsampling_y == 0 && subsampling_x == 0);
+ return vld1q_u8(mask);
+}
+
template <int subsampling_x, int subsampling_y>
inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
const void* LIBGAV1_RESTRICT prediction_1,
@@ -204,8 +267,13 @@ inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
if (width == 4) {
- MaskBlending4xH_NEON<subsampling_x, subsampling_y>(
- pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride);
+ MaskBlending4xH_NEON<subsampling_y>(pred_0, pred_1, mask_ptr, height, dst,
+ dst_stride);
+ return;
+ }
+ if (width == 8) {
+ MaskBlending8xH_NEON<subsampling_x, subsampling_y>(pred_0, pred_1, mask_ptr,
+ height, dst, dst_stride);
return;
}
const uint8_t* mask = mask_ptr;
@@ -214,35 +282,24 @@ inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
do {
int x = 0;
do {
- const int16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
+ const uint8x16_t pred_mask_0 = GetMask16<subsampling_x, subsampling_y>(
mask + (x << subsampling_x), mask_stride);
+ const int16x8_t pred_mask_0_lo = ZeroExtend(vget_low_u8(pred_mask_0));
+ const int16x8_t pred_mask_0_hi = ZeroExtend(vget_high_u8(pred_mask_0));
// 64 - mask
- const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
- const int16x8_t pred_val_0 = vld1q_s16(pred_0 + x);
- const int16x8_t pred_val_1 = vld1q_s16(pred_1 + x);
+ const int16x8_t pred_mask_1_lo = vsubq_s16(mask_inverter, pred_mask_0_lo);
+ const int16x8_t pred_mask_1_hi = vsubq_s16(mask_inverter, pred_mask_0_hi);
+
uint8x8_t result;
- // int res = (mask_value * prediction_0[x] +
- // (64 - mask_value) * prediction_1[x]) >> 6;
- const int32x4_t weighted_pred_0_lo =
- vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
- const int32x4_t weighted_pred_0_hi =
- vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
- const int32x4_t weighted_combo_lo =
- vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1),
- vget_low_s16(pred_val_1));
- const int32x4_t weighted_combo_hi =
- vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
- vget_high_s16(pred_val_1));
-
- // dst[x] = static_cast<Pixel>(
- // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
- // (1 << kBitdepth8) - 1));
- result = vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
- vshrn_n_s32(weighted_combo_hi, 6)),
- 4);
+ result =
+ CombinePred8(pred_0 + x, pred_1 + x, pred_mask_0_lo, pred_mask_1_lo);
vst1_u8(dst + x, result);
- x += 8;
+ result = CombinePred8(pred_0 + x + 8, pred_1 + x + 8, pred_mask_0_hi,
+ pred_mask_1_hi);
+ vst1_u8(dst + x + 8, result);
+
+ x += 16;
} while (x < width);
dst += dst_stride;
pred_0 += width;
@@ -251,63 +308,19 @@ inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
} while (++y < height);
}
-// TODO(b/150461164): This is much faster for inter_intra (input is Pixel
-// values) but regresses compound versions (input is int16_t). Try to
-// consolidate these.
template <int subsampling_x, int subsampling_y>
inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask,
ptrdiff_t mask_stride) {
if (subsampling_x == 1) {
- const uint8x8_t mask_val =
- vpadd_u8(vld1_u8(mask), vld1_u8(mask + (mask_stride << subsampling_y)));
- if (subsampling_y == 1) {
- const uint8x8_t next_mask_val = vpadd_u8(vld1_u8(mask + mask_stride),
- vld1_u8(mask + mask_stride * 3));
-
- // Use a saturating add to work around the case where all |mask| values
- // are 64. Together with the rounding shift this ensures the correct
- // result.
- const uint8x8_t sum = vqadd_u8(mask_val, next_mask_val);
- return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
- }
-
- return vrshr_n_u8(mask_val, /*subsampling_x=*/1);
+ return GetMask4x2<subsampling_y>(mask);
}
-
+ // When using intra or difference weighted masks, the function doesn't use
+ // subsampling, so |mask_stride| may be 4 or 8.
assert(subsampling_y == 0 && subsampling_x == 0);
const uint8x8_t mask_val0 = Load4(mask);
- // TODO(b/150461164): Investigate the source of |mask| and see if the stride
- // can be removed.
- // TODO(b/150461164): The unit tests start at 8x8. Does this get run?
return Load4<1>(mask + mask_stride, mask_val0);
}
-template <int subsampling_x, int subsampling_y>
-inline uint8x8_t GetInterIntraMask8(const uint8_t* mask,
- ptrdiff_t mask_stride) {
- if (subsampling_x == 1) {
- const uint8x16_t mask_val = vld1q_u8(mask);
- const uint8x8_t mask_paired =
- vpadd_u8(vget_low_u8(mask_val), vget_high_u8(mask_val));
- if (subsampling_y == 1) {
- const uint8x16_t next_mask_val = vld1q_u8(mask + mask_stride);
- const uint8x8_t next_mask_paired =
- vpadd_u8(vget_low_u8(next_mask_val), vget_high_u8(next_mask_val));
-
- // Use a saturating add to work around the case where all |mask| values
- // are 64. Together with the rounding shift this ensures the correct
- // result.
- const uint8x8_t sum = vqadd_u8(mask_paired, next_mask_paired);
- return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
- }
-
- return vrshr_n_u8(mask_paired, /*subsampling_x=*/1);
- }
-
- assert(subsampling_y == 0 && subsampling_x == 0);
- return vld1_u8(mask);
-}
-
inline void InterIntraWriteMaskBlendLine8bpp4x2(
const uint8_t* LIBGAV1_RESTRICT const pred_0,
uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
@@ -374,6 +387,32 @@ inline void InterIntraMaskBlending8bpp4xH_NEON(
}
template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp8xH_NEON(
+ const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+ const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
+ const ptrdiff_t mask_stride, const int height) {
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ int y = height;
+ do {
+ const uint8x8_t pred_mask_1 = GetMask8<subsampling_x, subsampling_y>(mask);
+ // 64 - mask
+ const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+ const uint8x8_t pred_val_0 = vld1_u8(pred_0);
+ const uint8x8_t pred_val_1 = vld1_u8(pred_1);
+ const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+ // weighted_pred0 + weighted_pred1
+ const uint16x8_t weighted_combo =
+ vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+ const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+ vst1_u8(pred_1, result);
+
+ pred_0 += 8;
+ pred_1 += pred_stride_1;
+ mask += mask_stride << subsampling_y;
+ } while (--y != 0);
+}
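+
For the inter-intra path the arithmetic stays in unsigned 8/16 bits: |mask| weights the second predictor and the blended value overwrites it after a single rounding shift by 6. A scalar sketch of one pixel (helper name ours, not libgav1 API):

#include <cstdint>

// Scalar model of one inter-intra blended pixel (sketch). |mask| weights
// pred_1; in the NEON routine above the result is stored back over pred_1.
inline uint8_t ScalarInterIntraBlend(uint8_t pred_0, uint8_t pred_1,
                                     uint8_t mask) {
  const uint32_t weighted = static_cast<uint32_t>(64 - mask) * pred_0 +
                            static_cast<uint32_t>(mask) * pred_1;
  return static_cast<uint8_t>((weighted + 32) >> 6);  // vrshrn_n_u16(x, 6)
}
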
+
+template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend8bpp_NEON(
const uint8_t* LIBGAV1_RESTRICT prediction_0,
uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
@@ -385,30 +424,46 @@ inline void InterIntraMaskBlend8bpp_NEON(
height);
return;
}
+ if (width == 8) {
+ InterIntraMaskBlending8bpp8xH_NEON<subsampling_x, subsampling_y>(
+ prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
+ height);
+ return;
+ }
const uint8_t* mask = mask_ptr;
- const uint8x8_t mask_inverter = vdup_n_u8(64);
+ const uint8x16_t mask_inverter = vdupq_n_u8(64);
int y = 0;
do {
int x = 0;
do {
- // TODO(b/150461164): Consider a 16 wide specialization (at least for the
- // unsampled version) to take advantage of vld1q_u8().
- const uint8x8_t pred_mask_1 =
- GetInterIntraMask8<subsampling_x, subsampling_y>(
- mask + (x << subsampling_x), mask_stride);
+ const uint8x16_t pred_mask_1 = GetMask16<subsampling_x, subsampling_y>(
+ mask + (x << subsampling_x), mask_stride);
// 64 - mask
- const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
- const uint8x8_t pred_val_0 = vld1_u8(prediction_0);
+ const uint8x16_t pred_mask_0 = vsubq_u8(mask_inverter, pred_mask_1);
+ const uint8x8_t pred_val_0_lo = vld1_u8(prediction_0);
+ prediction_0 += 8;
+ const uint8x8_t pred_val_0_hi = vld1_u8(prediction_0);
prediction_0 += 8;
- const uint8x8_t pred_val_1 = vld1_u8(prediction_1 + x);
- const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+ // Ensure armv7 build combines the load.
+ const uint8x16_t pred_val_1 = vld1q_u8(prediction_1 + x);
+ const uint8x8_t pred_val_1_lo = vget_low_u8(pred_val_1);
+ const uint8x8_t pred_val_1_hi = vget_high_u8(pred_val_1);
+ const uint16x8_t weighted_pred_0_lo =
+ vmull_u8(vget_low_u8(pred_mask_0), pred_val_0_lo);
// weighted_pred0 + weighted_pred1
- const uint16x8_t weighted_combo =
- vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
- const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
- vst1_u8(prediction_1 + x, result);
+ const uint16x8_t weighted_combo_lo =
+ vmlal_u8(weighted_pred_0_lo, vget_low_u8(pred_mask_1), pred_val_1_lo);
+ const uint8x8_t result_lo = vrshrn_n_u16(weighted_combo_lo, 6);
+ vst1_u8(prediction_1 + x, result_lo);
+ const uint16x8_t weighted_pred_0_hi =
+ vmull_u8(vget_high_u8(pred_mask_0), pred_val_0_hi);
+ // weighted_pred0 + weighted_pred1
+ const uint16x8_t weighted_combo_hi = vmlal_u8(
+ weighted_pred_0_hi, vget_high_u8(pred_mask_1), pred_val_1_hi);
+ const uint8x8_t result_hi = vrshrn_n_u16(weighted_combo_hi, 6);
+ vst1_u8(prediction_1 + x + 8, result_hi);
- x += 8;
+ x += 16;
} while (x < width);
prediction_1 += prediction_stride_1;
mask += mask_stride << subsampling_y;