diff options
Diffstat (limited to 'src/dsp/arm/mask_blend_neon.cc')
-rw-r--r-- | src/dsp/arm/mask_blend_neon.cc | 352 |
1 files changed, 321 insertions, 31 deletions
diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc index ee50923..853f949 100644 --- a/src/dsp/arm/mask_blend_neon.cc +++ b/src/dsp/arm/mask_blend_neon.cc @@ -79,10 +79,11 @@ inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) { return vreinterpretq_s16_u16(vmovl_u8(mask_val)); } -inline void WriteMaskBlendLine4x2(const int16_t* const pred_0, - const int16_t* const pred_1, +inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0, + const int16_t* LIBGAV1_RESTRICT const pred_1, const int16x8_t pred_mask_0, - const int16x8_t pred_mask_1, uint8_t* dst, + const int16x8_t pred_mask_1, + uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) { const int16x8_t pred_val_0 = vld1q_s16(pred_0); const int16x8_t pred_val_1 = vld1q_s16(pred_1); @@ -109,9 +110,11 @@ inline void WriteMaskBlendLine4x2(const int16_t* const pred_0, } template <int subsampling_x, int subsampling_y> -inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1, - const uint8_t* mask, - const ptrdiff_t mask_stride, uint8_t* dst, +inline void MaskBlending4x4_NEON(const int16_t* LIBGAV1_RESTRICT pred_0, + const int16_t* LIBGAV1_RESTRICT pred_1, + const uint8_t* LIBGAV1_RESTRICT mask, + const ptrdiff_t mask_stride, + uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) { const int16x8_t mask_inverter = vdupq_n_s16(64); int16x8_t pred_mask_0 = @@ -133,10 +136,12 @@ inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1, } template <int subsampling_x, int subsampling_y> -inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1, - const uint8_t* const mask_ptr, +inline void MaskBlending4xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0, + const int16_t* LIBGAV1_RESTRICT pred_1, + const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride, const int height, - uint8_t* dst, const ptrdiff_t dst_stride) { + uint8_t* LIBGAV1_RESTRICT dst, + const ptrdiff_t dst_stride) { const uint8_t* mask = mask_ptr; if (height == 4) { MaskBlending4x4_NEON<subsampling_x, subsampling_y>( @@ -188,11 +193,12 @@ inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1, } template <int subsampling_x, int subsampling_y> -inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1, +inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0, + const void* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t /*prediction_stride_1*/, - const uint8_t* const mask_ptr, + const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride, const int width, - const int height, void* dest, + const int height, void* LIBGAV1_RESTRICT dest, const ptrdiff_t dst_stride) { auto* dst = static_cast<uint8_t*>(dest); const auto* pred_0 = static_cast<const int16_t*>(prediction_0); @@ -302,11 +308,10 @@ inline uint8x8_t GetInterIntraMask8(const uint8_t* mask, return vld1_u8(mask); } -inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0, - uint8_t* const pred_1, - const ptrdiff_t pred_stride_1, - const uint8x8_t pred_mask_0, - const uint8x8_t pred_mask_1) { +inline void InterIntraWriteMaskBlendLine8bpp4x2( + const uint8_t* LIBGAV1_RESTRICT const pred_0, + uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1, + const uint8x8_t pred_mask_0, const uint8x8_t pred_mask_1) { const uint8x8_t pred_val_0 = vld1_u8(pred_0); uint8x8_t pred_val_1 = Load4(pred_1); pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1); @@ -320,11 +325,10 @@ inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0, } template <int subsampling_x, int subsampling_y> -inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0, - uint8_t* pred_1, - const ptrdiff_t pred_stride_1, - const uint8_t* mask, - const ptrdiff_t mask_stride) { +inline void InterIntraMaskBlending8bpp4x4_NEON( + const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1, + const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask, + const ptrdiff_t mask_stride) { const uint8x8_t mask_inverter = vdup_n_u8(64); uint8x8_t pred_mask_1 = GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); @@ -344,8 +348,9 @@ inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0, template <int subsampling_x, int subsampling_y> inline void InterIntraMaskBlending8bpp4xH_NEON( - const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1, - const uint8_t* mask, const ptrdiff_t mask_stride, const int height) { + const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1, + const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask, + const ptrdiff_t mask_stride, const int height) { if (height == 4) { InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>( pred_0, pred_1, pred_stride_1, mask, mask_stride); @@ -369,12 +374,11 @@ inline void InterIntraMaskBlending8bpp4xH_NEON( } template <int subsampling_x, int subsampling_y> -inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0, - uint8_t* prediction_1, - const ptrdiff_t prediction_stride_1, - const uint8_t* const mask_ptr, - const ptrdiff_t mask_stride, - const int width, const int height) { +inline void InterIntraMaskBlend8bpp_NEON( + const uint8_t* LIBGAV1_RESTRICT prediction_0, + uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1, + const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride, + const int width, const int height) { if (width == 4) { InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>( prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride, @@ -427,7 +431,293 @@ void Init8bpp() { } // namespace } // namespace low_bitdepth -void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); } +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +template <int subsampling_x, int subsampling_y> +inline uint16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const uint8x8_t mask_val0 = vld1_u8(mask); + const uint8x8_t mask_val1 = vld1_u8(mask + (mask_stride << subsampling_y)); + uint16x8_t final_val = vpaddlq_u8(vcombine_u8(mask_val0, mask_val1)); + if (subsampling_y == 1) { + const uint8x8_t next_mask_val0 = vld1_u8(mask + mask_stride); + const uint8x8_t next_mask_val1 = vld1_u8(mask + mask_stride * 3); + final_val = vaddq_u16( + final_val, vpaddlq_u8(vcombine_u8(next_mask_val0, next_mask_val1))); + } + return vrshrq_n_u16(final_val, subsampling_y + 1); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val0 = Load4(mask); + const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0); + return vmovl_u8(mask_val); +} + +template <int subsampling_x, int subsampling_y> +inline uint16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + uint16x8_t mask_val = vpaddlq_u8(vld1q_u8(mask)); + if (subsampling_y == 1) { + const uint16x8_t next_mask_val = vpaddlq_u8(vld1q_u8(mask + mask_stride)); + mask_val = vaddq_u16(mask_val, next_mask_val); + } + return vrshrq_n_u16(mask_val, 1 + subsampling_y); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val = vld1_u8(mask); + return vmovl_u8(mask_val); +} + +template <bool is_inter_intra> +uint16x8_t SumWeightedPred(const uint16x8_t pred_mask_0, + const uint16x8_t pred_mask_1, + const uint16x8_t pred_val_0, + const uint16x8_t pred_val_1) { + if (is_inter_intra) { + // dst[x] = static_cast<Pixel>(RightShiftWithRounding( + // mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6)); + uint16x8_t sum = vmulq_u16(pred_mask_1, pred_val_0); + sum = vmlaq_u16(sum, pred_mask_0, pred_val_1); + return vrshrq_n_u16(sum, 6); + } else { + // int res = (mask_value * prediction_0[x] + + // (64 - mask_value) * prediction_1[x]) >> 6; + const uint32x4_t weighted_pred_0_lo = + vmull_u16(vget_low_u16(pred_mask_0), vget_low_u16(pred_val_0)); + const uint32x4_t weighted_pred_0_hi = VMullHighU16(pred_mask_0, pred_val_0); + uint32x4x2_t sum; + sum.val[0] = vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1), + vget_low_u16(pred_val_1)); + sum.val[1] = VMlalHighU16(weighted_pred_0_hi, pred_mask_1, pred_val_1); + return vcombine_u16(vshrn_n_u32(sum.val[0], 6), vshrn_n_u32(sum.val[1], 6)); + } +} + +template <bool is_inter_intra, int width, int bitdepth = 10> +inline void StoreShiftedResult(uint8_t* dst, const uint16x8_t result, + const ptrdiff_t dst_stride = 0) { + if (is_inter_intra) { + if (width == 4) { + // Store 2 lines of width 4. + assert(dst_stride != 0); + vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(result)); + vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride), + vget_high_u16(result)); + } else { + // Store 1 line of width 8. + vst1q_u16(reinterpret_cast<uint16_t*>(dst), result); + } + } else { + // res -= (bitdepth == 8) ? 0 : kCompoundOffset; + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4; + const uint16x8_t compound_result = + vminq_u16(vrshrq_n_u16(vqsubq_u16(result, vdupq_n_u16(kCompoundOffset)), + inter_post_round_bits), + vdupq_n_u16((1 << bitdepth) - 1)); + if (width == 4) { + // Store 2 lines of width 4. + assert(dst_stride != 0); + vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(compound_result)); + vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride), + vget_high_u16(compound_result)); + } else { + // Store 1 line of width 8. + vst1q_u16(reinterpret_cast<uint16_t*>(dst), compound_result); + } + } +} + +template <int subsampling_x, int subsampling_y, bool is_inter_intra> +inline void MaskBlend4x2_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0, + const uint16_t* LIBGAV1_RESTRICT pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* LIBGAV1_RESTRICT mask, + const uint16x8_t mask_inverter, + const ptrdiff_t mask_stride, + uint8_t* LIBGAV1_RESTRICT dst, + const ptrdiff_t dst_stride) { + // This works because stride == width == 4. + const uint16x8_t pred_val_0 = vld1q_u16(pred_0); + const uint16x8_t pred_val_1 = + is_inter_intra + ? vcombine_u16(vld1_u16(pred_1), vld1_u16(pred_1 + pred_stride_1)) + : vld1q_u16(pred_1); + const uint16x8_t pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0); + const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>( + pred_mask_0, pred_mask_1, pred_val_0, pred_val_1); + + StoreShiftedResult<is_inter_intra, 4>(dst, weighted_pred_sum, dst_stride); +} + +template <int subsampling_x, int subsampling_y, bool is_inter_intra> +inline void MaskBlending4x4_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0, + const uint16_t* LIBGAV1_RESTRICT pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* LIBGAV1_RESTRICT mask, + const ptrdiff_t mask_stride, + uint8_t* LIBGAV1_RESTRICT dst, + const ptrdiff_t dst_stride) { + // Double stride because the function works on 2 lines at a time. + const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1); + const ptrdiff_t dst_stride_y = dst_stride << 1; + const uint16x8_t mask_inverter = vdupq_n_u16(64); + + MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst, + dst_stride); + + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride_y; + dst += dst_stride_y; + + MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst, + dst_stride); +} + +template <int subsampling_x, int subsampling_y, bool is_inter_intra> +inline void MaskBlending4xH_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0, + const uint16_t* LIBGAV1_RESTRICT pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* LIBGAV1_RESTRICT const mask_ptr, + const ptrdiff_t mask_stride, const int height, + uint8_t* LIBGAV1_RESTRICT dst, + const ptrdiff_t dst_stride) { + const uint8_t* mask = mask_ptr; + if (height == 4) { + MaskBlending4x4_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride); + return; + } + // Double stride because the function works on 2 lines at a time. + const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1); + const ptrdiff_t dst_stride_y = dst_stride << 1; + const uint16x8_t mask_inverter = vdupq_n_u16(64); + int y = 0; + do { + MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride_y; + dst += dst_stride_y; + + MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride_y; + dst += dst_stride_y; + + MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride_y; + dst += dst_stride_y; + + MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride_y; + dst += dst_stride_y; + y += 8; + } while (y < height); +} + +template <int subsampling_x, int subsampling_y, bool is_inter_intra> +void MaskBlend8_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0, + const uint16_t* LIBGAV1_RESTRICT pred_1, + const uint8_t* LIBGAV1_RESTRICT mask, + const uint16x8_t mask_inverter, + const ptrdiff_t mask_stride, + uint8_t* LIBGAV1_RESTRICT dst) { + const uint16x8_t pred_val_0 = vld1q_u16(pred_0); + const uint16x8_t pred_val_1 = vld1q_u16(pred_1); + const uint16x8_t pred_mask_0 = + GetMask8<subsampling_x, subsampling_y>(mask, mask_stride); + const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0); + const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>( + pred_mask_0, pred_mask_1, pred_val_0, pred_val_1); + + StoreShiftedResult<is_inter_intra, 8>(dst, weighted_pred_sum); +} + +template <int subsampling_x, int subsampling_y, bool is_inter_intra> +inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0, + const void* LIBGAV1_RESTRICT prediction_1, + const ptrdiff_t prediction_stride_1, + const uint8_t* LIBGAV1_RESTRICT const mask_ptr, + const ptrdiff_t mask_stride, const int width, + const int height, void* LIBGAV1_RESTRICT dest, + const ptrdiff_t dst_stride) { + if (!is_inter_intra) { + assert(prediction_stride_1 == width); + } + auto* dst = static_cast<uint8_t*>(dest); + const auto* pred_0 = static_cast<const uint16_t*>(prediction_0); + const auto* pred_1 = static_cast<const uint16_t*>(prediction_1); + if (width == 4) { + MaskBlending4xH_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0, pred_1, prediction_stride_1, mask_ptr, mask_stride, height, dst, + dst_stride); + return; + } + const ptrdiff_t mask_stride_y = mask_stride << subsampling_y; + const uint8_t* mask = mask_ptr; + const uint16x8_t mask_inverter = vdupq_n_u16(64); + int y = 0; + do { + int x = 0; + do { + MaskBlend8_NEON<subsampling_x, subsampling_y, is_inter_intra>( + pred_0 + x, pred_1 + x, mask + (x << subsampling_x), mask_inverter, + mask_stride, + reinterpret_cast<uint8_t*>(reinterpret_cast<uint16_t*>(dst) + x)); + x += 8; + } while (x < width); + dst += dst_stride; + pred_0 += width; + pred_1 += prediction_stride_1; + mask += mask_stride_y; + } while (++y < height); +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0, false>; + dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0, false>; + dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1, false>; + + dsp->mask_blend[0][1] = MaskBlend_NEON<0, 0, true>; + dsp->mask_blend[1][1] = MaskBlend_NEON<1, 0, true>; + dsp->mask_blend[2][1] = MaskBlend_NEON<1, 1, true>; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void MaskBlendInit_NEON() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} } // namespace dsp } // namespace libgav1 |