about · summary · refs · log · tree · commit · diff
path: root/src/dsp/arm/mask_blend_neon.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/dsp/arm/mask_blend_neon.cc')
-rw-r--r--  src/dsp/arm/mask_blend_neon.cc  352
1 file changed, 321 insertions, 31 deletions
diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc
index ee50923..853f949 100644
--- a/src/dsp/arm/mask_blend_neon.cc
+++ b/src/dsp/arm/mask_blend_neon.cc
@@ -79,10 +79,11 @@ inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
return vreinterpretq_s16_u16(vmovl_u8(mask_val));
}
-inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
- const int16_t* const pred_1,
+inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
+ const int16_t* LIBGAV1_RESTRICT const pred_1,
const int16x8_t pred_mask_0,
- const int16x8_t pred_mask_1, uint8_t* dst,
+ const int16x8_t pred_mask_1,
+ uint8_t* LIBGAV1_RESTRICT dst,
const ptrdiff_t dst_stride) {
const int16x8_t pred_val_0 = vld1q_s16(pred_0);
const int16x8_t pred_val_1 = vld1q_s16(pred_1);
@@ -109,9 +110,11 @@ inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
}
template <int subsampling_x, int subsampling_y>
-inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1,
- const uint8_t* mask,
- const ptrdiff_t mask_stride, uint8_t* dst,
+inline void MaskBlending4x4_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
+ const int16_t* LIBGAV1_RESTRICT pred_1,
+ const uint8_t* LIBGAV1_RESTRICT mask,
+ const ptrdiff_t mask_stride,
+ uint8_t* LIBGAV1_RESTRICT dst,
const ptrdiff_t dst_stride) {
const int16x8_t mask_inverter = vdupq_n_s16(64);
int16x8_t pred_mask_0 =
@@ -133,10 +136,12 @@ inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1,
}
template <int subsampling_x, int subsampling_y>
-inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1,
- const uint8_t* const mask_ptr,
+inline void MaskBlending4xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
+ const int16_t* LIBGAV1_RESTRICT pred_1,
+ const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
const ptrdiff_t mask_stride, const int height,
- uint8_t* dst, const ptrdiff_t dst_stride) {
+ uint8_t* LIBGAV1_RESTRICT dst,
+ const ptrdiff_t dst_stride) {
const uint8_t* mask = mask_ptr;
if (height == 4) {
MaskBlending4x4_NEON<subsampling_x, subsampling_y>(
@@ -188,11 +193,12 @@ inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1,
}
template <int subsampling_x, int subsampling_y>
-inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1,
+inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+ const void* LIBGAV1_RESTRICT prediction_1,
const ptrdiff_t /*prediction_stride_1*/,
- const uint8_t* const mask_ptr,
+ const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
const ptrdiff_t mask_stride, const int width,
- const int height, void* dest,
+ const int height, void* LIBGAV1_RESTRICT dest,
const ptrdiff_t dst_stride) {
auto* dst = static_cast<uint8_t*>(dest);
const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
@@ -302,11 +308,10 @@ inline uint8x8_t GetInterIntraMask8(const uint8_t* mask,
return vld1_u8(mask);
}
-inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
- uint8_t* const pred_1,
- const ptrdiff_t pred_stride_1,
- const uint8x8_t pred_mask_0,
- const uint8x8_t pred_mask_1) {
+inline void InterIntraWriteMaskBlendLine8bpp4x2(
+ const uint8_t* LIBGAV1_RESTRICT const pred_0,
+ uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
+ const uint8x8_t pred_mask_0, const uint8x8_t pred_mask_1) {
const uint8x8_t pred_val_0 = vld1_u8(pred_0);
uint8x8_t pred_val_1 = Load4(pred_1);
pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1);
@@ -320,11 +325,10 @@ inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
}
template <int subsampling_x, int subsampling_y>
-inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0,
- uint8_t* pred_1,
- const ptrdiff_t pred_stride_1,
- const uint8_t* mask,
- const ptrdiff_t mask_stride) {
+inline void InterIntraMaskBlending8bpp4x4_NEON(
+ const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+ const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
+ const ptrdiff_t mask_stride) {
const uint8x8_t mask_inverter = vdup_n_u8(64);
uint8x8_t pred_mask_1 =
GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
@@ -344,8 +348,9 @@ inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0,
template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4xH_NEON(
- const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1,
- const uint8_t* mask, const ptrdiff_t mask_stride, const int height) {
+ const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+ const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
+ const ptrdiff_t mask_stride, const int height) {
if (height == 4) {
InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
pred_0, pred_1, pred_stride_1, mask, mask_stride);
@@ -369,12 +374,11 @@ inline void InterIntraMaskBlending8bpp4xH_NEON(
}
template <int subsampling_x, int subsampling_y>
-inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0,
- uint8_t* prediction_1,
- const ptrdiff_t prediction_stride_1,
- const uint8_t* const mask_ptr,
- const ptrdiff_t mask_stride,
- const int width, const int height) {
+inline void InterIntraMaskBlend8bpp_NEON(
+ const uint8_t* LIBGAV1_RESTRICT prediction_0,
+ uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
+ const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+ const int width, const int height) {
if (width == 4) {
InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>(
prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
@@ -427,7 +431,293 @@ void Init8bpp() {
} // namespace
} // namespace low_bitdepth
-void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+template <int subsampling_x, int subsampling_y>
+inline uint16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+ if (subsampling_x == 1) {
+ const uint8x8_t mask_val0 = vld1_u8(mask);
+ const uint8x8_t mask_val1 = vld1_u8(mask + (mask_stride << subsampling_y));
+ uint16x8_t final_val = vpaddlq_u8(vcombine_u8(mask_val0, mask_val1));
+ if (subsampling_y == 1) {
+ const uint8x8_t next_mask_val0 = vld1_u8(mask + mask_stride);
+ const uint8x8_t next_mask_val1 = vld1_u8(mask + mask_stride * 3);
+ final_val = vaddq_u16(
+ final_val, vpaddlq_u8(vcombine_u8(next_mask_val0, next_mask_val1)));
+ }
+ return vrshrq_n_u16(final_val, subsampling_y + 1);
+ }
+ assert(subsampling_y == 0 && subsampling_x == 0);
+ const uint8x8_t mask_val0 = Load4(mask);
+ const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
+ return vmovl_u8(mask_val);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+ if (subsampling_x == 1) {
+ uint16x8_t mask_val = vpaddlq_u8(vld1q_u8(mask));
+ if (subsampling_y == 1) {
+ const uint16x8_t next_mask_val = vpaddlq_u8(vld1q_u8(mask + mask_stride));
+ mask_val = vaddq_u16(mask_val, next_mask_val);
+ }
+ return vrshrq_n_u16(mask_val, 1 + subsampling_y);
+ }
+ assert(subsampling_y == 0 && subsampling_x == 0);
+ const uint8x8_t mask_val = vld1_u8(mask);
+ return vmovl_u8(mask_val);
+}
+
+template <bool is_inter_intra>
+uint16x8_t SumWeightedPred(const uint16x8_t pred_mask_0,
+ const uint16x8_t pred_mask_1,
+ const uint16x8_t pred_val_0,
+ const uint16x8_t pred_val_1) {
+ if (is_inter_intra) {
+ // dst[x] = static_cast<Pixel>(RightShiftWithRounding(
+ // mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6));
+ uint16x8_t sum = vmulq_u16(pred_mask_1, pred_val_0);
+ sum = vmlaq_u16(sum, pred_mask_0, pred_val_1);
+ return vrshrq_n_u16(sum, 6);
+ } else {
+ // int res = (mask_value * prediction_0[x] +
+ // (64 - mask_value) * prediction_1[x]) >> 6;
+ const uint32x4_t weighted_pred_0_lo =
+ vmull_u16(vget_low_u16(pred_mask_0), vget_low_u16(pred_val_0));
+ const uint32x4_t weighted_pred_0_hi = VMullHighU16(pred_mask_0, pred_val_0);
+ uint32x4x2_t sum;
+ sum.val[0] = vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1),
+ vget_low_u16(pred_val_1));
+ sum.val[1] = VMlalHighU16(weighted_pred_0_hi, pred_mask_1, pred_val_1);
+ return vcombine_u16(vshrn_n_u32(sum.val[0], 6), vshrn_n_u32(sum.val[1], 6));
+ }
+}
+
+template <bool is_inter_intra, int width, int bitdepth = 10>
+inline void StoreShiftedResult(uint8_t* dst, const uint16x8_t result,
+ const ptrdiff_t dst_stride = 0) {
+ if (is_inter_intra) {
+ if (width == 4) {
+ // Store 2 lines of width 4.
+ assert(dst_stride != 0);
+ vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(result));
+ vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
+ vget_high_u16(result));
+ } else {
+ // Store 1 line of width 8.
+ vst1q_u16(reinterpret_cast<uint16_t*>(dst), result);
+ }
+ } else {
+ // res -= (bitdepth == 8) ? 0 : kCompoundOffset;
+ // dst[x] = static_cast<Pixel>(
+ // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+ // (1 << kBitdepth8) - 1));
+ constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
+ const uint16x8_t compound_result =
+ vminq_u16(vrshrq_n_u16(vqsubq_u16(result, vdupq_n_u16(kCompoundOffset)),
+ inter_post_round_bits),
+ vdupq_n_u16((1 << bitdepth) - 1));
+ if (width == 4) {
+ // Store 2 lines of width 4.
+ assert(dst_stride != 0);
+ vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(compound_result));
+ vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
+ vget_high_u16(compound_result));
+ } else {
+ // Store 1 line of width 8.
+ vst1q_u16(reinterpret_cast<uint16_t*>(dst), compound_result);
+ }
+ }
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlend4x2_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+ const uint16_t* LIBGAV1_RESTRICT pred_1,
+ const ptrdiff_t pred_stride_1,
+ const uint8_t* LIBGAV1_RESTRICT mask,
+ const uint16x8_t mask_inverter,
+ const ptrdiff_t mask_stride,
+ uint8_t* LIBGAV1_RESTRICT dst,
+ const ptrdiff_t dst_stride) {
+ // This works because stride == width == 4.
+ const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
+ const uint16x8_t pred_val_1 =
+ is_inter_intra
+ ? vcombine_u16(vld1_u16(pred_1), vld1_u16(pred_1 + pred_stride_1))
+ : vld1q_u16(pred_1);
+ const uint16x8_t pred_mask_0 =
+ GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+ const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
+ pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);
+
+ StoreShiftedResult<is_inter_intra, 4>(dst, weighted_pred_sum, dst_stride);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlending4x4_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+ const uint16_t* LIBGAV1_RESTRICT pred_1,
+ const ptrdiff_t pred_stride_1,
+ const uint8_t* LIBGAV1_RESTRICT mask,
+ const ptrdiff_t mask_stride,
+ uint8_t* LIBGAV1_RESTRICT dst,
+ const ptrdiff_t dst_stride) {
+ // Double stride because the function works on 2 lines at a time.
+ const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
+ const ptrdiff_t dst_stride_y = dst_stride << 1;
+ const uint16x8_t mask_inverter = vdupq_n_u16(64);
+
+ MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+ dst_stride);
+
+ pred_0 += 4 << 1;
+ pred_1 += pred_stride_1 << 1;
+ mask += mask_stride_y;
+ dst += dst_stride_y;
+
+ MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+ dst_stride);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlending4xH_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+ const uint16_t* LIBGAV1_RESTRICT pred_1,
+ const ptrdiff_t pred_stride_1,
+ const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
+ const ptrdiff_t mask_stride, const int height,
+ uint8_t* LIBGAV1_RESTRICT dst,
+ const ptrdiff_t dst_stride) {
+ const uint8_t* mask = mask_ptr;
+ if (height == 4) {
+ MaskBlending4x4_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
+ return;
+ }
+ // Double stride because the function works on 2 lines at a time.
+ const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
+ const ptrdiff_t dst_stride_y = dst_stride << 1;
+ const uint16x8_t mask_inverter = vdupq_n_u16(64);
+ int y = 0;
+ do {
+ MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += pred_stride_1 << 1;
+ mask += mask_stride_y;
+ dst += dst_stride_y;
+
+ MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += pred_stride_1 << 1;
+ mask += mask_stride_y;
+ dst += dst_stride_y;
+
+ MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += pred_stride_1 << 1;
+ mask += mask_stride_y;
+ dst += dst_stride_y;
+
+ MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += pred_stride_1 << 1;
+ mask += mask_stride_y;
+ dst += dst_stride_y;
+ y += 8;
+ } while (y < height);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+void MaskBlend8_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+ const uint16_t* LIBGAV1_RESTRICT pred_1,
+ const uint8_t* LIBGAV1_RESTRICT mask,
+ const uint16x8_t mask_inverter,
+ const ptrdiff_t mask_stride,
+ uint8_t* LIBGAV1_RESTRICT dst) {
+ const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
+ const uint16x8_t pred_val_1 = vld1q_u16(pred_1);
+ const uint16x8_t pred_mask_0 =
+ GetMask8<subsampling_x, subsampling_y>(mask, mask_stride);
+ const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+ const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
+ pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);
+
+ StoreShiftedResult<is_inter_intra, 8>(dst, weighted_pred_sum);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+ const void* LIBGAV1_RESTRICT prediction_1,
+ const ptrdiff_t prediction_stride_1,
+ const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
+ const ptrdiff_t mask_stride, const int width,
+ const int height, void* LIBGAV1_RESTRICT dest,
+ const ptrdiff_t dst_stride) {
+ if (!is_inter_intra) {
+ assert(prediction_stride_1 == width);
+ }
+ auto* dst = static_cast<uint8_t*>(dest);
+ const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
+ if (width == 4) {
+ MaskBlending4xH_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0, pred_1, prediction_stride_1, mask_ptr, mask_stride, height, dst,
+ dst_stride);
+ return;
+ }
+ const ptrdiff_t mask_stride_y = mask_stride << subsampling_y;
+ const uint8_t* mask = mask_ptr;
+ const uint16x8_t mask_inverter = vdupq_n_u16(64);
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ MaskBlend8_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+ pred_0 + x, pred_1 + x, mask + (x << subsampling_x), mask_inverter,
+ mask_stride,
+ reinterpret_cast<uint8_t*>(reinterpret_cast<uint16_t*>(dst) + x));
+ x += 8;
+ } while (x < width);
+ dst += dst_stride;
+ pred_0 += width;
+ pred_1 += prediction_stride_1;
+ mask += mask_stride_y;
+ } while (++y < height);
+}
+
+void Init10bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+ assert(dsp != nullptr);
+ dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0, false>;
+ dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0, false>;
+ dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1, false>;
+
+ dsp->mask_blend[0][1] = MaskBlend_NEON<0, 0, true>;
+ dsp->mask_blend[1][1] = MaskBlend_NEON<1, 0, true>;
+ dsp->mask_blend[2][1] = MaskBlend_NEON<1, 1, true>;
+}
+
+} // namespace
+} // namespace high_bitdepth
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+void MaskBlendInit_NEON() {
+ low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ high_bitdepth::Init10bpp();
+#endif
+}
} // namespace dsp
} // namespace libgav1