Diffstat (limited to 'src/dsp/arm/mask_blend_neon.cc')
-rw-r--r--  src/dsp/arm/mask_blend_neon.cc | 444 +++++++++++++++++++++++++++
1 file changed, 444 insertions, 0 deletions
diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc
new file mode 100644
index 0000000..084f42f
--- /dev/null
+++ b/src/dsp/arm/mask_blend_neon.cc
@@ -0,0 +1,444 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/mask_blend.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// TODO(b/150461164): Consider combining with GetInterIntraMask4x2().
+// Compound predictors use int16_t values and need to multiply long because
+// the Convolve range * 64 is 20 bits. Unfortunately there is no multiply
+// int16_t by int8_t and accumulate into int32_t instruction.
+template <int subsampling_x, int subsampling_y>
+inline int16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const int16x4_t mask_val0 = vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask)));
+    const int16x4_t mask_val1 = vreinterpret_s16_u16(
+        vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y))));
+    int16x8_t final_val;
+    if (subsampling_y == 1) {
+      const int16x4_t next_mask_val0 =
+          vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride)));
+      const int16x4_t next_mask_val1 =
+          vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride * 3)));
+      final_val = vaddq_s16(vcombine_s16(mask_val0, mask_val1),
+                            vcombine_s16(next_mask_val0, next_mask_val1));
+    } else {
+      final_val = vreinterpretq_s16_u16(vpaddlq_u8(
+          vreinterpretq_u8_s16(vcombine_s16(mask_val0, mask_val1))));
+    }
+    return vrshrq_n_s16(final_val, subsampling_y + 1);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val0 = Load4(mask);
+  const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
+  return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+}
+
+template <int subsampling_x, int subsampling_y>
+inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    int16x8_t mask_val = vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask)));
+    if (subsampling_y == 1) {
+      const int16x8_t next_mask_val =
+          vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask + mask_stride)));
+      mask_val = vaddq_s16(mask_val, next_mask_val);
+    }
+    return vrshrq_n_s16(mask_val, 1 + subsampling_y);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val = vld1_u8(mask);
+  return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+}
+
+inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
+                                  const int16_t* const pred_1,
+                                  const int16x8_t pred_mask_0,
+                                  const int16x8_t pred_mask_1, uint8_t* dst,
+                                  const ptrdiff_t dst_stride) {
+  const int16x8_t pred_val_0 = vld1q_s16(pred_0);
+  const int16x8_t pred_val_1 = vld1q_s16(pred_1);
+  // int res = (mask_value * prediction_0[x] +
+  //            (64 - mask_value) * prediction_1[x]) >> 6;
+  const int32x4_t weighted_pred_0_lo =
+      vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
+  const int32x4_t weighted_pred_0_hi =
+      vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
+  const int32x4_t weighted_combo_lo = vmlal_s16(
+      weighted_pred_0_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
+  const int32x4_t weighted_combo_hi =
+      vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
+                vget_high_s16(pred_val_1));
+  // dst[x] = static_cast<Pixel>(
+  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+  //           (1 << kBitdepth8) - 1));
+  const uint8x8_t result =
+      vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+                                  vshrn_n_s32(weighted_combo_hi, 6)),
+                     4);
+  StoreLo4(dst, result);
+  StoreHi4(dst + dst_stride, result);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1,
+                                 const uint8_t* mask,
+                                 const ptrdiff_t mask_stride, uint8_t* dst,
+                                 const ptrdiff_t dst_stride) {
+  const int16x8_t mask_inverter = vdupq_n_s16(64);
+  int16x8_t pred_mask_0 =
+      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                        dst_stride);
+  // TODO(b/150461164): Arm tends to do better with load(val); val += stride
+  // It may be possible to turn this into a loop with a templated height.
+  pred_0 += 4 << 1;
+  pred_1 += 4 << 1;
+  mask += mask_stride << (1 + subsampling_y);
+  dst += dst_stride << 1;
+
+  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                        dst_stride);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1,
+                                 const uint8_t* const mask_ptr,
+                                 const ptrdiff_t mask_stride, const int height,
+                                 uint8_t* dst, const ptrdiff_t dst_stride) {
+  const uint8_t* mask = mask_ptr;
+  if (height == 4) {
+    MaskBlending4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, mask, mask_stride, dst, dst_stride);
+    return;
+  }
+  const int16x8_t mask_inverter = vdupq_n_s16(64);
+  int y = 0;
+  do {
+    int16x8_t pred_mask_0 =
+        GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+    y += 8;
+  } while (y < height);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1,
+                           const ptrdiff_t /*prediction_stride_1*/,
+                           const uint8_t* const mask_ptr,
+                           const ptrdiff_t mask_stride, const int width,
+                           const int height, void* dest,
+                           const ptrdiff_t dst_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  if (width == 4) {
+    MaskBlending4xH_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride);
+    return;
+  }
+  const uint8_t* mask = mask_ptr;
+  const int16x8_t mask_inverter = vdupq_n_s16(64);
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const int16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
+          mask + (x << subsampling_x), mask_stride);
+      // 64 - mask
+      const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+      const int16x8_t pred_val_0 = vld1q_s16(pred_0 + x);
+      const int16x8_t pred_val_1 = vld1q_s16(pred_1 + x);
+      uint8x8_t result;
+      // int res = (mask_value * prediction_0[x] +
+      //            (64 - mask_value) * prediction_1[x]) >> 6;
+      const int32x4_t weighted_pred_0_lo =
+          vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
+      const int32x4_t weighted_pred_0_hi =
+          vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
+      const int32x4_t weighted_combo_lo =
+          vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1),
+                    vget_low_s16(pred_val_1));
+      const int32x4_t weighted_combo_hi =
+          vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
+                    vget_high_s16(pred_val_1));
+
+      // dst[x] = static_cast<Pixel>(
+      //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+      //           (1 << kBitdepth8) - 1));
+      result = vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+                                           vshrn_n_s32(weighted_combo_hi, 6)),
+                              4);
+      vst1_u8(dst + x, result);
+
+      x += 8;
+    } while (x < width);
+    dst += dst_stride;
+    pred_0 += width;
+    pred_1 += width;
+    mask += mask_stride << subsampling_y;
+  } while (++y < height);
+}
+
+// TODO(b/150461164): This is much faster for inter_intra (input is Pixel
+// values) but regresses compound versions (input is int16_t). Try to
+// consolidate these.
+template <int subsampling_x, int subsampling_y>
+inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask,
+                                      ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const uint8x8_t mask_val = vpadd_u8(
+        vld1_u8(mask), vld1_u8(mask + (mask_stride << subsampling_y)));
+    if (subsampling_y == 1) {
+      const uint8x8_t next_mask_val = vpadd_u8(
+          vld1_u8(mask + mask_stride), vld1_u8(mask + mask_stride * 3));
+
+      // Use a saturating add to work around the case where all |mask| values
+      // are 64. Together with the rounding shift this ensures the correct
+      // result.
+      const uint8x8_t sum = vqadd_u8(mask_val, next_mask_val);
+      return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
+    }
+
+    return vrshr_n_u8(mask_val, /*subsampling_x=*/1);
+  }
+
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val0 = Load4(mask);
+  // TODO(b/150461164): Investigate the source of |mask| and see if the stride
+  // can be removed.
+  // TODO(b/150461164): The unit tests start at 8x8. Does this get run?
+  return Load4<1>(mask + mask_stride, mask_val0);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint8x8_t GetInterIntraMask8(const uint8_t* mask,
+                                    ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const uint8x16_t mask_val = vld1q_u8(mask);
+    const uint8x8_t mask_paired =
+        vpadd_u8(vget_low_u8(mask_val), vget_high_u8(mask_val));
+    if (subsampling_y == 1) {
+      const uint8x16_t next_mask_val = vld1q_u8(mask + mask_stride);
+      const uint8x8_t next_mask_paired =
+          vpadd_u8(vget_low_u8(next_mask_val), vget_high_u8(next_mask_val));
+
+      // Use a saturating add to work around the case where all |mask| values
+      // are 64. Together with the rounding shift this ensures the correct
+      // result.
+      const uint8x8_t sum = vqadd_u8(mask_paired, next_mask_paired);
+      return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
+    }
+
+    return vrshr_n_u8(mask_paired, /*subsampling_x=*/1);
+  }
+
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  return vld1_u8(mask);
+}
+
+inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
+                                                uint8_t* const pred_1,
+                                                const ptrdiff_t pred_stride_1,
+                                                const uint8x8_t pred_mask_0,
+                                                const uint8x8_t pred_mask_1) {
+  const uint8x8_t pred_val_0 = vld1_u8(pred_0);
+  uint8x8_t pred_val_1 = Load4(pred_1);
+  pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1);
+
+  const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+  const uint16x8_t weighted_combo =
+      vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+  const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+  StoreLo4(pred_1, result);
+  StoreHi4(pred_1 + pred_stride_1, result);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0,
+                                               uint8_t* pred_1,
+                                               const ptrdiff_t pred_stride_1,
+                                               const uint8_t* mask,
+                                               const ptrdiff_t mask_stride) {
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  uint8x8_t pred_mask_1 =
+      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+                                      pred_mask_0, pred_mask_1);
+  pred_0 += 4 << 1;
+  pred_1 += pred_stride_1 << 1;
+  mask += mask_stride << (1 + subsampling_y);
+
+  pred_mask_1 =
+      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+                                      pred_mask_0, pred_mask_1);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4xH_NEON(
+    const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1,
+    const uint8_t* mask, const ptrdiff_t mask_stride, const int height) {
+  if (height == 4) {
+    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    return;
+  }
+  int y = 0;
+  do {
+    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    pred_0 += 4 << 2;
+    pred_1 += pred_stride_1 << 2;
+    mask += mask_stride << (2 + subsampling_y);
+
+    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    pred_0 += 4 << 2;
+    pred_1 += pred_stride_1 << 2;
+    mask += mask_stride << (2 + subsampling_y);
+    y += 8;
+  } while (y < height);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0,
+                                         uint8_t* prediction_1,
+                                         const ptrdiff_t prediction_stride_1,
+                                         const uint8_t* const mask_ptr,
+                                         const ptrdiff_t mask_stride,
+                                         const int width, const int height) {
+  if (width == 4) {
+    InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>(
+        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
+        height);
+    return;
+  }
+  const uint8_t* mask = mask_ptr;
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      // TODO(b/150461164): Consider a 16 wide specialization (at least for
+      // the unsampled version) to take advantage of vld1q_u8().
+      const uint8x8_t pred_mask_1 =
+          GetInterIntraMask8<subsampling_x, subsampling_y>(
+              mask + (x << subsampling_x), mask_stride);
+      // 64 - mask
+      const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+      const uint8x8_t pred_val_0 = vld1_u8(prediction_0);
+      prediction_0 += 8;
+      const uint8x8_t pred_val_1 = vld1_u8(prediction_1 + x);
+      const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+      // weighted_pred0 + weighted_pred1
+      const uint16x8_t weighted_combo =
+          vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+      const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+      vst1_u8(prediction_1 + x, result);
+
+      x += 8;
+    } while (x < width);
+    prediction_1 += prediction_stride_1;
+    mask += mask_stride << subsampling_y;
+  } while (++y < height);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0>;
+  dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0>;
+  dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1>;
+  // The is_inter_intra index of mask_blend[][] is replaced by
+  // inter_intra_mask_blend_8bpp[] in 8-bit.
+  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_NEON<0, 0>;
+  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_NEON<1, 0>;
+  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_NEON<1, 1>;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void MaskBlendInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
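
The compound blend in WriteMaskBlendLine4x2() and MaskBlend_NEON() computes, per pixel, res = (mask * pred_0 + (64 - mask) * pred_1) >> 6 followed by a rounding post-shift and clip, split across vshrn_n_s32(..., 6) and vqrshrun_n_s16(..., 4). A minimal scalar sketch of the same arithmetic, assuming 8-bit output and the 4-bit post-round shift implied by the vqrshrun_n_s16(..., 4); BlendPixel is an illustrative name, not part of this file:

#include <algorithm>
#include <cstdint>

// Scalar model of one compound blended pixel (illustration only). The first
// shift mirrors the truncating vshrn_n_s32(..., 6); the rounding shift and
// clamp mirror the saturating vqrshrun_n_s16(..., 4).
inline uint8_t BlendPixel(int16_t prediction_0, int16_t prediction_1,
                          int mask_value) {
  const int32_t res =
      (mask_value * prediction_0 + (64 - mask_value) * prediction_1) >> 6;
  return static_cast<uint8_t>(std::clamp((res + 8) >> 4, 0, 255));
}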
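
The subsampled mask fetches are rounded block averages: GetMask8<1, 1>() forms horizontal pair sums with vpaddlq_u8, adds the next mask row, and rounds with vrshrq_n_s16(..., 2). A scalar sketch of one output lane, with SubsampledMaskValue as an illustrative name:

#include <cstddef>
#include <cstdint>

// Scalar model of GetMask8<1, 1> (illustration only): the rounded average of
// a 2x2 block of 4:2:0 mask values.
inline int16_t SubsampledMaskValue(const uint8_t* mask,
                                   const ptrdiff_t mask_stride, const int x) {
  const int sum = mask[2 * x] + mask[2 * x + 1] + mask[mask_stride + 2 * x] +
                  mask[mask_stride + 2 * x + 1];
  return static_cast<int16_t>((sum + 2) >> 2);  // vrshrq_n_s16(sum, 2)
}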
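
The saturating add in the inter-intra mask helpers matters only at the top of the mask range: with an all-64 mask each pairwise sum is 128, so a plain add would wrap 256 to 0 in uint8_t, while vqadd_u8 clamps to 255 and the rounding shift still recovers the right value, since (255 + 2) >> 2 == (256 + 2) >> 2 == 64. The 8bpp inter-intra blend itself, which overwrites prediction_1 in place, reduces to the following scalar form; InterIntraBlendPixel is an illustrative name:

#include <cstdint>

// Scalar model of the 8bpp inter-intra blend (illustration only): (64 - mask)
// weights prediction_0, mask weights prediction_1, and the +32 applies the
// rounding of vrshrn_n_u16(..., 6). The sum stays below 1 << 16, so the
// uint16 accumulators in the NEON path cannot overflow.
inline uint8_t InterIntraBlendPixel(const uint8_t prediction_0,
                                    const uint8_t prediction_1,
                                    const int mask_value) {
  return static_cast<uint8_t>(
      ((64 - mask_value) * prediction_0 + mask_value * prediction_1 + 32) >>
      6);
}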