| author | qinxialei <xialeiqin@gmail.com> | 2020-10-29 11:26:59 +0800 |
|---|---|---|
| committer | qinxialei <xialeiqin@gmail.com> | 2020-10-29 11:26:59 +0800 |
| commit | e8d277081293b6fb2a5d469616baaa7a06f52496 (patch) | |
| tree | 1179bb07d3927d1837d4a90bd81b2034c4c696a9 /src/dsp/arm | |
| download | libgav1-e8d277081293b6fb2a5d469616baaa7a06f52496.tar.gz, libgav1-e8d277081293b6fb2a5d469616baaa7a06f52496.tar.bz2, libgav1-e8d277081293b6fb2a5d469616baaa7a06f52496.zip | |
Import Upstream version 0.16.0
Diffstat (limited to 'src/dsp/arm')
39 files changed, 19650 insertions, 0 deletions
diff --git a/src/dsp/arm/average_blend_neon.cc b/src/dsp/arm/average_blend_neon.cc new file mode 100644 index 0000000..834e8b4 --- /dev/null +++ b/src/dsp/arm/average_blend_neon.cc @@ -0,0 +1,146 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/average_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kInterPostRoundBit = + kInterRoundBitsVertical - kInterRoundBitsCompoundVertical; + +inline uint8x8_t AverageBlend8Row(const int16_t* prediction_0, + const int16_t* prediction_1) { + const int16x8_t pred0 = vld1q_s16(prediction_0); + const int16x8_t pred1 = vld1q_s16(prediction_1); + const int16x8_t res = vaddq_s16(pred0, pred1); + return vqrshrun_n_s16(res, kInterPostRoundBit + 1); +} + +inline void AverageBlendLargeRow(const int16_t* prediction_0, + const int16_t* prediction_1, const int width, + uint8_t* dest) { + int x = width; + do { + const int16x8_t pred_00 = vld1q_s16(prediction_0); + const int16x8_t pred_01 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res0 = vaddq_s16(pred_00, pred_01); + const uint8x8_t res_out0 = vqrshrun_n_s16(res0, kInterPostRoundBit + 1); + const int16x8_t pred_10 = vld1q_s16(prediction_0); + const int16x8_t pred_11 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res1 = vaddq_s16(pred_10, pred_11); + const uint8x8_t res_out1 = vqrshrun_n_s16(res1, kInterPostRoundBit + 1); + vst1q_u8(dest, vcombine_u8(res_out0, res_out1)); + dest += 16; + x -= 16; + } while (x != 0); +} + +void AverageBlend_NEON(const void* prediction_0, const void* prediction_1, + const int width, const int height, void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = height; + + if (width == 4) { + do { + const uint8x8_t result = AverageBlend8Row(pred_0, pred_1); + pred_0 += 8; + pred_1 += 8; + + StoreLo4(dst, result); + dst += dest_stride; + StoreHi4(dst, result); + dst += dest_stride; + y -= 2; + } while (y != 0); + return; + } + + if (width == 8) { + do { + vst1_u8(dst, AverageBlend8Row(pred_0, pred_1)); + dst += dest_stride; + pred_0 += 8; + pred_1 += 8; + + vst1_u8(dst, AverageBlend8Row(pred_0, pred_1)); + dst += dest_stride; + pred_0 += 8; + pred_1 += 8; + + y -= 2; + } while (y != 0); + return; + } + + do { + AverageBlendLargeRow(pred_0, pred_1, width, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + AverageBlendLargeRow(pred_0, pred_1, width, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + y -= 2; + } while (y != 0); +} + 
+void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->average_blend = AverageBlend_NEON; +} + +} // namespace + +void AverageBlendInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void AverageBlendInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/average_blend_neon.h b/src/dsp/arm/average_blend_neon.h new file mode 100644 index 0000000..d13bcd6 --- /dev/null +++ b/src/dsp/arm/average_blend_neon.h @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::average_blend. This function is not thread-safe. +void AverageBlendInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_ diff --git a/src/dsp/arm/cdef_neon.cc b/src/dsp/arm/cdef_neon.cc new file mode 100644 index 0000000..4d0e76f --- /dev/null +++ b/src/dsp/arm/cdef_neon.cc @@ -0,0 +1,697 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/cdef.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +#include "src/dsp/cdef.inc" + +// ---------------------------------------------------------------------------- +// Refer to CdefDirection_C(). 
+// +// int32_t partial[8][15] = {}; +// for (int i = 0; i < 8; ++i) { +// for (int j = 0; j < 8; ++j) { +// const int x = 1; +// partial[0][i + j] += x; +// partial[1][i + j / 2] += x; +// partial[2][i] += x; +// partial[3][3 + i - j / 2] += x; +// partial[4][7 + i - j] += x; +// partial[5][3 - i / 2 + j] += x; +// partial[6][j] += x; +// partial[7][i / 2 + j] += x; +// } +// } +// +// Using the code above, generate the position count for partial[8][15]. +// +// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// +// The SIMD code shifts the input horizontally, then adds vertically to get the +// correct partial value for the given position. +// ---------------------------------------------------------------------------- + +// ---------------------------------------------------------------------------- +// partial[0][i + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 00 10 11 12 13 14 15 16 17 00 00 00 00 00 00 +// 00 00 20 21 22 23 24 25 26 27 00 00 00 00 00 +// 00 00 00 30 31 32 33 34 35 36 37 00 00 00 00 +// 00 00 00 00 40 41 42 43 44 45 46 47 00 00 00 +// 00 00 00 00 00 50 51 52 53 54 55 56 57 00 00 +// 00 00 00 00 00 00 60 61 62 63 64 65 66 67 00 +// 00 00 00 00 00 00 00 70 71 72 73 74 75 76 77 +// +// partial[4] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(uint8x8_t* v_src, + uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + const uint8x8_t v_zero = vdup_n_u8(0); + // 00 01 02 03 04 05 06 07 + // 00 10 11 12 13 14 15 16 + *partial_lo = vaddl_u8(v_src[0], vext_u8(v_zero, v_src[1], 7)); + + // 00 00 20 21 22 23 24 25 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[2], 6)); + // 17 00 00 00 00 00 00 00 + // 26 27 00 00 00 00 00 00 + *partial_hi = + vaddl_u8(vext_u8(v_src[1], v_zero, 7), vext_u8(v_src[2], v_zero, 6)); + + // 00 00 00 30 31 32 33 34 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[3], 5)); + // 35 36 37 00 00 00 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[3], v_zero, 5)); + + // 00 00 00 00 40 41 42 43 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[4], 4)); + // 44 45 46 47 00 00 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[4], v_zero, 4)); + + // 00 00 00 00 00 50 51 52 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[5], 3)); + // 53 54 55 56 57 00 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[5], v_zero, 3)); + + // 00 00 00 00 00 00 60 61 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[6], 2)); + // 62 63 64 65 66 67 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[6], v_zero, 2)); + + // 00 00 00 00 00 00 00 70 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[7], 1)); + // 71 72 73 74 75 76 77 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[7], v_zero, 1)); +} + +// ---------------------------------------------------------------------------- +// partial[1][i + j / 2] += x; +// +// A0 = src[0] + src[1], A1 = src[2] + src[3], ... 
+// +// A0 A1 A2 A3 00 00 00 00 00 00 00 00 00 00 00 +// 00 B0 B1 B2 B3 00 00 00 00 00 00 00 00 00 00 +// 00 00 C0 C1 C2 C3 00 00 00 00 00 00 00 00 00 +// 00 00 00 D0 D1 D2 D3 00 00 00 00 00 00 00 00 +// 00 00 00 00 E0 E1 E2 E3 00 00 00 00 00 00 00 +// 00 00 00 00 00 F0 F1 F2 F3 00 00 00 00 00 00 +// 00 00 00 00 00 00 G0 G1 G2 G3 00 00 00 00 00 +// 00 00 00 00 00 00 00 H0 H1 H2 H3 00 00 00 00 +// +// partial[3] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(uint8x8_t* v_src, + uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + uint8x16_t v_d1_temp[8]; + const uint8x8_t v_zero = vdup_n_u8(0); + const uint8x16_t v_zero_16 = vdupq_n_u8(0); + + for (int i = 0; i < 8; ++i) { + v_d1_temp[i] = vcombine_u8(v_src[i], v_zero); + } + + *partial_lo = *partial_hi = vdupq_n_u16(0); + // A0 A1 A2 A3 00 00 00 00 + *partial_lo = vpadalq_u8(*partial_lo, v_d1_temp[0]); + + // 00 B0 B1 B2 B3 00 00 00 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[1], 14)); + + // 00 00 C0 C1 C2 C3 00 00 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[2], 12)); + // 00 00 00 D0 D1 D2 D3 00 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[3], 10)); + // 00 00 00 00 E0 E1 E2 E3 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[4], 8)); + + // 00 00 00 00 00 F0 F1 F2 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[5], 6)); + // F3 00 00 00 00 00 00 00 + *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[5], v_zero_16, 6)); + + // 00 00 00 00 00 00 G0 G1 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[6], 4)); + // G2 G3 00 00 00 00 00 00 + *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[6], v_zero_16, 4)); + + // 00 00 00 00 00 00 00 H0 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[7], 2)); + // H1 H2 H3 00 00 00 00 00 + *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[7], v_zero_16, 2)); +} + +// ---------------------------------------------------------------------------- +// partial[7][i / 2 + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 +// 00 20 21 22 23 24 25 26 27 00 00 00 00 00 00 +// 00 30 31 32 33 34 35 36 37 00 00 00 00 00 00 +// 00 00 40 41 42 43 44 45 46 47 00 00 00 00 00 +// 00 00 50 51 52 53 54 55 56 57 00 00 00 00 00 +// 00 00 00 60 61 62 63 64 65 66 67 00 00 00 00 +// 00 00 00 70 71 72 73 74 75 76 77 00 00 00 00 +// +// partial[5] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(uint8x8_t* v_src, + uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + const uint16x8_t v_zero = vdupq_n_u16(0); + uint16x8_t v_pair_add[4]; + // Add vertical source pairs. 
+ v_pair_add[0] = vaddl_u8(v_src[0], v_src[1]); + v_pair_add[1] = vaddl_u8(v_src[2], v_src[3]); + v_pair_add[2] = vaddl_u8(v_src[4], v_src[5]); + v_pair_add[3] = vaddl_u8(v_src[6], v_src[7]); + + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + *partial_lo = v_pair_add[0]; + // 00 00 00 00 00 00 00 00 + // 00 00 00 00 00 00 00 00 + *partial_hi = vdupq_n_u16(0); + + // 00 20 21 22 23 24 25 26 + // 00 30 31 32 33 34 35 36 + *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[1], 7)); + // 27 00 00 00 00 00 00 00 + // 37 00 00 00 00 00 00 00 + *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[1], v_zero, 7)); + + // 00 00 40 41 42 43 44 45 + // 00 00 50 51 52 53 54 55 + *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[2], 6)); + // 46 47 00 00 00 00 00 00 + // 56 57 00 00 00 00 00 00 + *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[2], v_zero, 6)); + + // 00 00 00 60 61 62 63 64 + // 00 00 00 70 71 72 73 74 + *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[3], 5)); + // 65 66 67 00 00 00 00 00 + // 75 76 77 00 00 00 00 00 + *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[3], v_zero, 5)); +} + +LIBGAV1_ALWAYS_INLINE void AddPartial(const void* const source, + ptrdiff_t stride, uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + const auto* src = static_cast<const uint8_t*>(source); + + // 8x8 input + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + // 20 21 22 23 24 25 26 27 + // 30 31 32 33 34 35 36 37 + // 40 41 42 43 44 45 46 47 + // 50 51 52 53 54 55 56 57 + // 60 61 62 63 64 65 66 67 + // 70 71 72 73 74 75 76 77 + uint8x8_t v_src[8]; + for (int i = 0; i < 8; ++i) { + v_src[i] = vld1_u8(src); + src += stride; + } + + // partial for direction 2 + // -------------------------------------------------------------------------- + // partial[2][i] += x; + // 00 10 20 30 40 50 60 70 00 00 00 00 00 00 00 00 + // 01 11 21 33 41 51 61 71 00 00 00 00 00 00 00 00 + // 02 12 22 33 42 52 62 72 00 00 00 00 00 00 00 00 + // 03 13 23 33 43 53 63 73 00 00 00 00 00 00 00 00 + // 04 14 24 34 44 54 64 74 00 00 00 00 00 00 00 00 + // 05 15 25 35 45 55 65 75 00 00 00 00 00 00 00 00 + // 06 16 26 36 46 56 66 76 00 00 00 00 00 00 00 00 + // 07 17 27 37 47 57 67 77 00 00 00 00 00 00 00 00 + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[0]), partial_lo[2], 0); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[1]), partial_lo[2], 1); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[2]), partial_lo[2], 2); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[3]), partial_lo[2], 3); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[4]), partial_lo[2], 4); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[5]), partial_lo[2], 5); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[6]), partial_lo[2], 6); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[7]), partial_lo[2], 7); + + // partial for direction 6 + // -------------------------------------------------------------------------- + // partial[6][j] += x; + // 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 00 + // 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 00 + // 20 21 22 23 24 25 26 27 00 00 00 00 00 00 00 00 + // 30 31 32 33 34 35 36 37 00 00 00 00 00 00 00 00 + // 40 41 42 43 44 45 46 47 00 00 00 00 00 00 00 00 + // 50 51 52 53 54 55 56 57 00 00 00 00 00 00 00 00 + // 60 61 62 63 64 65 66 67 00 00 00 00 00 00 00 00 + // 70 71 72 73 74 75 76 77 00 00 00 00 00 00 00 00 + const uint8x8_t v_zero = vdup_n_u8(0); + partial_lo[6] = vaddl_u8(v_zero, v_src[0]); + for (int i = 1; i < 8; 
++i) { + partial_lo[6] = vaddw_u8(partial_lo[6], v_src[i]); + } + + // partial for direction 0 + AddPartial_D0_D4(v_src, &partial_lo[0], &partial_hi[0]); + + // partial for direction 1 + AddPartial_D1_D3(v_src, &partial_lo[1], &partial_hi[1]); + + // partial for direction 7 + AddPartial_D5_D7(v_src, &partial_lo[7], &partial_hi[7]); + + uint8x8_t v_src_reverse[8]; + for (int i = 0; i < 8; ++i) { + v_src_reverse[i] = vrev64_u8(v_src[i]); + } + + // partial for direction 4 + AddPartial_D0_D4(v_src_reverse, &partial_lo[4], &partial_hi[4]); + + // partial for direction 3 + AddPartial_D1_D3(v_src_reverse, &partial_lo[3], &partial_hi[3]); + + // partial for direction 5 + AddPartial_D5_D7(v_src_reverse, &partial_lo[5], &partial_hi[5]); +} + +uint32x4_t Square(uint16x4_t a) { return vmull_u16(a, a); } + +uint32x4_t SquareAccumulate(uint32x4_t a, uint16x4_t b) { + return vmlal_u16(a, b, b); +} + +// |cost[0]| and |cost[4]| square the input and sum with the corresponding +// element from the other end of the vector: +// |kCdefDivisionTable[]| element: +// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) * +// kCdefDivisionTable[i + 1]; +// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8]; +// Because everything is being summed into a single value the distributive +// property allows us to mirror the division table and accumulate once. +uint32_t Cost0Or4(const uint16x8_t a, const uint16x8_t b, + const uint32x4_t division_table[4]) { + uint32x4_t c = vmulq_u32(Square(vget_low_u16(a)), division_table[0]); + c = vmlaq_u32(c, Square(vget_high_u16(a)), division_table[1]); + c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[2]); + c = vmlaq_u32(c, Square(vget_high_u16(b)), division_table[3]); + return SumVector(c); +} + +// |cost[2]| and |cost[6]| square the input and accumulate: +// cost[2] += Square(partial[2][i]) +uint32_t SquareAccumulate(const uint16x8_t a) { + uint32x4_t c = Square(vget_low_u16(a)); + c = SquareAccumulate(c, vget_high_u16(a)); + c = vmulq_n_u32(c, kCdefDivisionTable[7]); + return SumVector(c); +} + +uint32_t CostOdd(const uint16x8_t a, const uint16x8_t b, const uint32x4_t mask, + const uint32x4_t division_table[2]) { + // Remove elements 0-2. 
+ uint32x4_t c = vandq_u32(mask, Square(vget_low_u16(a))); + c = vaddq_u32(c, Square(vget_high_u16(a))); + c = vmulq_n_u32(c, kCdefDivisionTable[7]); + + c = vmlaq_u32(c, Square(vget_low_u16(a)), division_table[0]); + c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[1]); + return SumVector(c); +} + +void CdefDirection_NEON(const void* const source, ptrdiff_t stride, + uint8_t* const direction, int* const variance) { + assert(direction != nullptr); + assert(variance != nullptr); + const auto* src = static_cast<const uint8_t*>(source); + uint32_t cost[8]; + uint16x8_t partial_lo[8], partial_hi[8]; + + AddPartial(src, stride, partial_lo, partial_hi); + + cost[2] = SquareAccumulate(partial_lo[2]); + cost[6] = SquareAccumulate(partial_lo[6]); + + const uint32x4_t division_table[4] = { + vld1q_u32(kCdefDivisionTable), vld1q_u32(kCdefDivisionTable + 4), + vld1q_u32(kCdefDivisionTable + 8), vld1q_u32(kCdefDivisionTable + 12)}; + + cost[0] = Cost0Or4(partial_lo[0], partial_hi[0], division_table); + cost[4] = Cost0Or4(partial_lo[4], partial_hi[4], division_table); + + const uint32x4_t division_table_odd[2] = { + vld1q_u32(kCdefDivisionTableOdd), vld1q_u32(kCdefDivisionTableOdd + 4)}; + + const uint32x4_t element_3_mask = {0, 0, 0, static_cast<uint32_t>(-1)}; + + cost[1] = + CostOdd(partial_lo[1], partial_hi[1], element_3_mask, division_table_odd); + cost[3] = + CostOdd(partial_lo[3], partial_hi[3], element_3_mask, division_table_odd); + cost[5] = + CostOdd(partial_lo[5], partial_hi[5], element_3_mask, division_table_odd); + cost[7] = + CostOdd(partial_lo[7], partial_hi[7], element_3_mask, division_table_odd); + + uint32_t best_cost = 0; + *direction = 0; + for (int i = 0; i < 8; ++i) { + if (cost[i] > best_cost) { + best_cost = cost[i]; + *direction = i; + } + } + *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10; +} + +// ------------------------------------------------------------------------- +// CdefFilter + +// Load 4 vectors based on the given |direction|. +void LoadDirection(const uint16_t* const src, const ptrdiff_t stride, + uint16x8_t* output, const int direction) { + // Each |direction| describes a different set of source values. Expand this + // set by negating each set. For |direction| == 0 this gives a diagonal line + // from top right to bottom left. The first value is y, the second x. Negative + // y values move up. + // a b c d + // {-1, 1}, {1, -1}, {-2, 2}, {2, -2} + // c + // a + // 0 + // b + // d + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = vld1q_u16(src + y_0 * stride + x_0); + output[1] = vld1q_u16(src - y_0 * stride - x_0); + output[2] = vld1q_u16(src + y_1 * stride + x_1); + output[3] = vld1q_u16(src - y_1 * stride - x_1); +} + +// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to +// do 2 rows at a time. 
+void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride, + uint16x8_t* output, const int direction) { + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = vcombine_u16(vld1_u16(src + y_0 * stride + x_0), + vld1_u16(src + y_0 * stride + stride + x_0)); + output[1] = vcombine_u16(vld1_u16(src - y_0 * stride - x_0), + vld1_u16(src - y_0 * stride + stride - x_0)); + output[2] = vcombine_u16(vld1_u16(src + y_1 * stride + x_1), + vld1_u16(src + y_1 * stride + stride + x_1)); + output[3] = vcombine_u16(vld1_u16(src - y_1 * stride - x_1), + vld1_u16(src - y_1 * stride + stride - x_1)); +} + +int16x8_t Constrain(const uint16x8_t pixel, const uint16x8_t reference, + const uint16x8_t threshold, const int16x8_t damping) { + // If reference > pixel, the difference will be negative, so covert to 0 or + // -1. + const uint16x8_t sign = vcgtq_u16(reference, pixel); + const uint16x8_t abs_diff = vabdq_u16(pixel, reference); + const uint16x8_t shifted_diff = vshlq_u16(abs_diff, damping); + // For bitdepth == 8, the threshold range is [0, 15] and the damping range is + // [3, 6]. If pixel == kCdefLargeValue(0x4000), shifted_diff will always be + // larger than threshold. Subtract using saturation will return 0 when pixel + // == kCdefLargeValue. + static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue"); + const uint16x8_t thresh_minus_shifted_diff = + vqsubq_u16(threshold, shifted_diff); + const uint16x8_t clamp_abs_diff = + vminq_u16(thresh_minus_shifted_diff, abs_diff); + // Restore the sign. + return vreinterpretq_s16_u16( + vsubq_u16(veorq_u16(clamp_abs_diff, sign), sign)); +} + +template <int width, bool enable_primary = true, bool enable_secondary = true> +void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride, + const int height, const int primary_strength, + const int secondary_strength, const int damping, + const int direction, void* dest, + const ptrdiff_t dst_stride) { + static_assert(width == 8 || width == 4, ""); + static_assert(enable_primary || enable_secondary, ""); + constexpr bool clipping_required = enable_primary && enable_secondary; + auto* dst = static_cast<uint8_t*>(dest); + const uint16x8_t cdef_large_value_mask = + vdupq_n_u16(static_cast<uint16_t>(~kCdefLargeValue)); + const uint16x8_t primary_threshold = vdupq_n_u16(primary_strength); + const uint16x8_t secondary_threshold = vdupq_n_u16(secondary_strength); + + int16x8_t primary_damping_shift, secondary_damping_shift; + + // FloorLog2() requires input to be > 0. + // 8-bit damping range: Y: [3, 6], UV: [2, 5]. + if (enable_primary) { + // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary + // for UV filtering. + primary_damping_shift = + vdupq_n_s16(-std::max(0, damping - FloorLog2(primary_strength))); + } + if (enable_secondary) { + // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is + // necessary. 
+ assert(damping - FloorLog2(secondary_strength) >= 0); + secondary_damping_shift = + vdupq_n_s16(-(damping - FloorLog2(secondary_strength))); + } + + const int primary_tap_0 = kCdefPrimaryTaps[primary_strength & 1][0]; + const int primary_tap_1 = kCdefPrimaryTaps[primary_strength & 1][1]; + + int y = height; + do { + uint16x8_t pixel; + if (width == 8) { + pixel = vld1q_u16(src); + } else { + pixel = vcombine_u16(vld1_u16(src), vld1_u16(src + src_stride)); + } + + uint16x8_t min = pixel; + uint16x8_t max = pixel; + int16x8_t sum; + + if (enable_primary) { + // Primary |direction|. + uint16x8_t primary_val[4]; + if (width == 8) { + LoadDirection(src, src_stride, primary_val, direction); + } else { + LoadDirection4(src, src_stride, primary_val, direction); + } + + if (clipping_required) { + min = vminq_u16(min, primary_val[0]); + min = vminq_u16(min, primary_val[1]); + min = vminq_u16(min, primary_val[2]); + min = vminq_u16(min, primary_val[3]); + + // The source is 16 bits, however, we only really care about the lower + // 8 bits. The upper 8 bits contain the "large" flag. After the final + // primary max has been calculated, zero out the upper 8 bits. Use this + // to find the "16 bit" max. + const uint8x16_t max_p01 = + vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]), + vreinterpretq_u8_u16(primary_val[1])); + const uint8x16_t max_p23 = + vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]), + vreinterpretq_u8_u16(primary_val[3])); + const uint16x8_t max_p = + vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23)); + max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask)); + } + + sum = Constrain(primary_val[0], pixel, primary_threshold, + primary_damping_shift); + sum = vmulq_n_s16(sum, primary_tap_0); + sum = vmlaq_n_s16(sum, + Constrain(primary_val[1], pixel, primary_threshold, + primary_damping_shift), + primary_tap_0); + sum = vmlaq_n_s16(sum, + Constrain(primary_val[2], pixel, primary_threshold, + primary_damping_shift), + primary_tap_1); + sum = vmlaq_n_s16(sum, + Constrain(primary_val[3], pixel, primary_threshold, + primary_damping_shift), + primary_tap_1); + } else { + sum = vdupq_n_s16(0); + } + + if (enable_secondary) { + // Secondary |direction| values (+/- 2). Clamp |direction|. 
+ uint16x8_t secondary_val[8]; + if (width == 8) { + LoadDirection(src, src_stride, secondary_val, direction + 2); + LoadDirection(src, src_stride, secondary_val + 4, direction - 2); + } else { + LoadDirection4(src, src_stride, secondary_val, direction + 2); + LoadDirection4(src, src_stride, secondary_val + 4, direction - 2); + } + + if (clipping_required) { + min = vminq_u16(min, secondary_val[0]); + min = vminq_u16(min, secondary_val[1]); + min = vminq_u16(min, secondary_val[2]); + min = vminq_u16(min, secondary_val[3]); + min = vminq_u16(min, secondary_val[4]); + min = vminq_u16(min, secondary_val[5]); + min = vminq_u16(min, secondary_val[6]); + min = vminq_u16(min, secondary_val[7]); + + const uint8x16_t max_s01 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]), + vreinterpretq_u8_u16(secondary_val[1])); + const uint8x16_t max_s23 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]), + vreinterpretq_u8_u16(secondary_val[3])); + const uint8x16_t max_s45 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]), + vreinterpretq_u8_u16(secondary_val[5])); + const uint8x16_t max_s67 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]), + vreinterpretq_u8_u16(secondary_val[7])); + const uint16x8_t max_s = vreinterpretq_u16_u8( + vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67))); + max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask)); + } + + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[0], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[1], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[2], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[3], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[4], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[5], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[6], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[7], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + } + // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max)) + const int16x8_t sum_lt_0 = vshrq_n_s16(sum, 15); + sum = vaddq_s16(sum, sum_lt_0); + int16x8_t result = vrsraq_n_s16(vreinterpretq_s16_u16(pixel), sum, 4); + if (clipping_required) { + result = vminq_s16(result, vreinterpretq_s16_u16(max)); + result = vmaxq_s16(result, vreinterpretq_s16_u16(min)); + } + + const uint8x8_t dst_pixel = vqmovun_s16(result); + if (width == 8) { + src += src_stride; + vst1_u8(dst, dst_pixel); + dst += dst_stride; + --y; + } else { + src += src_stride << 1; + StoreLo4(dst, dst_pixel); + dst += dst_stride; + StoreHi4(dst, dst_pixel); + dst += dst_stride; + y -= 2; + } + } while (y != 0); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->cdef_direction = CdefDirection_NEON; + dsp->cdef_filters[0][0] = CdefFilter_NEON<4>; + dsp->cdef_filters[0][1] = + CdefFilter_NEON<4, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = CdefFilter_NEON<4, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = 
CdefFilter_NEON<8>; + dsp->cdef_filters[1][1] = + CdefFilter_NEON<8, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = CdefFilter_NEON<8, /*enable_primary=*/false>; +} + +} // namespace +} // namespace low_bitdepth + +void CdefInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void CdefInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/cdef_neon.h b/src/dsp/arm/cdef_neon.h new file mode 100644 index 0000000..53d5f86 --- /dev/null +++ b/src/dsp/arm/cdef_neon.h @@ -0,0 +1,38 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not +// thread-safe. +void CdefInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_ diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h new file mode 100644 index 0000000..dcb7567 --- /dev/null +++ b/src/dsp/arm/common_neon.h @@ -0,0 +1,777 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_ + +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cstdint> +#include <cstring> + +#if 0 +#include <cstdio> + +#include "absl/strings/str_cat.h" + +constexpr bool kEnablePrintRegs = true; + +union DebugRegister { + int8_t i8[8]; + int16_t i16[4]; + int32_t i32[2]; + uint8_t u8[8]; + uint16_t u16[4]; + uint32_t u32[2]; +}; + +union DebugRegisterQ { + int8_t i8[16]; + int16_t i16[8]; + int32_t i32[4]; + uint8_t u8[16]; + uint16_t u16[8]; + uint32_t u32[4]; +}; + +// Quite useful macro for debugging. Left here for convenience. 
+inline void PrintVect(const DebugRegister r, const char* const name, int size) { + int n; + if (kEnablePrintRegs) { + fprintf(stderr, "%s\t: ", name); + if (size == 8) { + for (n = 0; n < 8; ++n) fprintf(stderr, "%.2x ", r.u8[n]); + } else if (size == 16) { + for (n = 0; n < 4; ++n) fprintf(stderr, "%.4x ", r.u16[n]); + } else if (size == 32) { + for (n = 0; n < 2; ++n) fprintf(stderr, "%.8x ", r.u32[n]); + } + fprintf(stderr, "\n"); + } +} + +// Debugging macro for 128-bit types. +inline void PrintVectQ(const DebugRegisterQ r, const char* const name, + int size) { + int n; + if (kEnablePrintRegs) { + fprintf(stderr, "%s\t: ", name); + if (size == 8) { + for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", r.u8[n]); + } else if (size == 16) { + for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", r.u16[n]); + } else if (size == 32) { + for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", r.u32[n]); + } + fprintf(stderr, "\n"); + } +} + +inline void PrintReg(const int32x4x2_t val, const std::string& name) { + DebugRegisterQ r; + vst1q_u32(r.u32, val.val[0]); + const std::string name0 = absl::StrCat(name, ".val[0]").c_str(); + PrintVectQ(r, name0.c_str(), 32); + vst1q_u32(r.u32, val.val[1]); + const std::string name1 = absl::StrCat(name, ".val[1]").c_str(); + PrintVectQ(r, name1.c_str(), 32); +} + +inline void PrintReg(const uint32x4_t val, const char* name) { + DebugRegisterQ r; + vst1q_u32(r.u32, val); + PrintVectQ(r, name, 32); +} + +inline void PrintReg(const uint32x2_t val, const char* name) { + DebugRegister r; + vst1_u32(r.u32, val); + PrintVect(r, name, 32); +} + +inline void PrintReg(const uint16x8_t val, const char* name) { + DebugRegisterQ r; + vst1q_u16(r.u16, val); + PrintVectQ(r, name, 16); +} + +inline void PrintReg(const uint16x4_t val, const char* name) { + DebugRegister r; + vst1_u16(r.u16, val); + PrintVect(r, name, 16); +} + +inline void PrintReg(const uint8x16_t val, const char* name) { + DebugRegisterQ r; + vst1q_u8(r.u8, val); + PrintVectQ(r, name, 8); +} + +inline void PrintReg(const uint8x8_t val, const char* name) { + DebugRegister r; + vst1_u8(r.u8, val); + PrintVect(r, name, 8); +} + +inline void PrintReg(const int32x4_t val, const char* name) { + DebugRegisterQ r; + vst1q_s32(r.i32, val); + PrintVectQ(r, name, 32); +} + +inline void PrintReg(const int32x2_t val, const char* name) { + DebugRegister r; + vst1_s32(r.i32, val); + PrintVect(r, name, 32); +} + +inline void PrintReg(const int16x8_t val, const char* name) { + DebugRegisterQ r; + vst1q_s16(r.i16, val); + PrintVectQ(r, name, 16); +} + +inline void PrintReg(const int16x4_t val, const char* name) { + DebugRegister r; + vst1_s16(r.i16, val); + PrintVect(r, name, 16); +} + +inline void PrintReg(const int8x16_t val, const char* name) { + DebugRegisterQ r; + vst1q_s8(r.i8, val); + PrintVectQ(r, name, 8); +} + +inline void PrintReg(const int8x8_t val, const char* name) { + DebugRegister r; + vst1_s8(r.i8, val); + PrintVect(r, name, 8); +} + +// Print an individual (non-vector) value in decimal format. +inline void PrintReg(const int x, const char* name) { + if (kEnablePrintRegs) { + printf("%s: %d\n", name, x); + } +} + +// Print an individual (non-vector) value in hexadecimal format. 
+inline void PrintHex(const int x, const char* name) { + if (kEnablePrintRegs) { + printf("%s: %x\n", name, x); + } +} + +#define PR(x) PrintReg(x, #x) +#define PD(x) PrintReg(x, #x) +#define PX(x) PrintHex(x, #x) + +#endif // 0 + +namespace libgav1 { +namespace dsp { + +//------------------------------------------------------------------------------ +// Load functions. + +// Load 2 uint8_t values into lanes 0 and 1. Zeros the register before loading +// the values. Use caution when using this in loops because it will re-zero the +// register before loading on every iteration. +inline uint8x8_t Load2(const void* const buf) { + const uint16x4_t zero = vdup_n_u16(0); + uint16_t temp; + memcpy(&temp, buf, 2); + return vreinterpret_u8_u16(vld1_lane_u16(&temp, zero, 0)); +} + +// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1. +template <int lane> +inline uint8x8_t Load2(const void* const buf, uint8x8_t val) { + uint16_t temp; + memcpy(&temp, buf, 2); + return vreinterpret_u8_u16( + vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane)); +} + +// Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the +// register before loading the values. Use caution when using this in loops +// because it will re-zero the register before loading on every iteration. +inline uint8x8_t Load4(const void* const buf) { + const uint32x2_t zero = vdup_n_u32(0); + uint32_t temp; + memcpy(&temp, buf, 4); + return vreinterpret_u8_u32(vld1_lane_u32(&temp, zero, 0)); +} + +// Load 4 uint8_t values into 4 lanes staring with |lane| * 4. +template <int lane> +inline uint8x8_t Load4(const void* const buf, uint8x8_t val) { + uint32_t temp; + memcpy(&temp, buf, 4); + return vreinterpret_u8_u32( + vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane)); +} + +//------------------------------------------------------------------------------ +// Store functions. + +// Propagate type information to the compiler. Without this the compiler may +// assume the required alignment of the type (4 bytes in the case of uint32_t) +// and add alignment hints to the memory access. +template <typename T> +inline void ValueToMem(void* const buf, T val) { + memcpy(buf, &val, sizeof(val)); +} + +// Store 4 int8_t values from the low half of an int8x8_t register. +inline void StoreLo4(void* const buf, const int8x8_t val) { + ValueToMem<int32_t>(buf, vget_lane_s32(vreinterpret_s32_s8(val), 0)); +} + +// Store 4 uint8_t values from the low half of a uint8x8_t register. +inline void StoreLo4(void* const buf, const uint8x8_t val) { + ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 0)); +} + +// Store 4 uint8_t values from the high half of a uint8x8_t register. +inline void StoreHi4(void* const buf, const uint8x8_t val) { + ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 1)); +} + +// Store 2 uint8_t values from |lane| * 2 and |lane| * 2 + 1 of a uint8x8_t +// register. +template <int lane> +inline void Store2(void* const buf, const uint8x8_t val) { + ValueToMem<uint16_t>(buf, vget_lane_u16(vreinterpret_u16_u8(val), lane)); +} + +// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x8_t +// register. +template <int lane> +inline void Store2(void* const buf, const uint16x8_t val) { + ValueToMem<uint32_t>(buf, vgetq_lane_u32(vreinterpretq_u32_u16(val), lane)); +} + +// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t +// register. 
+template <int lane> +inline void Store2(uint16_t* const buf, const uint16x4_t val) { + ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane)); +} + +//------------------------------------------------------------------------------ +// Bit manipulation. + +// vshXX_n_XX() requires an immediate. +template <int shift> +inline uint8x8_t LeftShift(const uint8x8_t vector) { + return vreinterpret_u8_u64(vshl_n_u64(vreinterpret_u64_u8(vector), shift)); +} + +template <int shift> +inline uint8x8_t RightShift(const uint8x8_t vector) { + return vreinterpret_u8_u64(vshr_n_u64(vreinterpret_u64_u8(vector), shift)); +} + +template <int shift> +inline int8x8_t RightShift(const int8x8_t vector) { + return vreinterpret_s8_u64(vshr_n_u64(vreinterpret_u64_s8(vector), shift)); +} + +// Shim vqtbl1_u8 for armv7. +inline uint8x8_t VQTbl1U8(const uint8x16_t a, const uint8x8_t index) { +#if defined(__aarch64__) + return vqtbl1_u8(a, index); +#else + const uint8x8x2_t b = {vget_low_u8(a), vget_high_u8(a)}; + return vtbl2_u8(b, index); +#endif +} + +// Shim vqtbl1_s8 for armv7. +inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) { +#if defined(__aarch64__) + return vqtbl1_s8(a, index); +#else + const int8x8x2_t b = {vget_low_s8(a), vget_high_s8(a)}; + return vtbl2_s8(b, vreinterpret_s8_u8(index)); +#endif +} + +//------------------------------------------------------------------------------ +// Interleave. + +// vzipN is exclusive to A64. +inline uint8x8_t InterleaveLow8(const uint8x8_t a, const uint8x8_t b) { +#if defined(__aarch64__) + return vzip1_u8(a, b); +#else + // Discard |.val[1]| + return vzip_u8(a, b).val[0]; +#endif +} + +inline uint8x8_t InterleaveLow32(const uint8x8_t a, const uint8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_u8_u32( + vzip1_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b))); +#else + // Discard |.val[1]| + return vreinterpret_u8_u32( + vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[0]); +#endif +} + +inline int8x8_t InterleaveLow32(const int8x8_t a, const int8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_s8_u32( + vzip1_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b))); +#else + // Discard |.val[1]| + return vreinterpret_s8_u32( + vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[0]); +#endif +} + +inline uint8x8_t InterleaveHigh32(const uint8x8_t a, const uint8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_u8_u32( + vzip2_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b))); +#else + // Discard |.val[0]| + return vreinterpret_u8_u32( + vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[1]); +#endif +} + +inline int8x8_t InterleaveHigh32(const int8x8_t a, const int8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_s8_u32( + vzip2_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b))); +#else + // Discard |.val[0]| + return vreinterpret_s8_u32( + vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[1]); +#endif +} + +//------------------------------------------------------------------------------ +// Sum. 
+ +inline uint16_t SumVector(const uint8x8_t a) { +#if defined(__aarch64__) + return vaddlv_u8(a); +#else + const uint16x4_t c = vpaddl_u8(a); + const uint32x2_t d = vpaddl_u16(c); + const uint64x1_t e = vpaddl_u32(d); + return static_cast<uint16_t>(vget_lane_u64(e, 0)); +#endif // defined(__aarch64__) +} + +inline uint32_t SumVector(const uint32x4_t a) { +#if defined(__aarch64__) + return vaddvq_u32(a); +#else + const uint64x2_t b = vpaddlq_u32(a); + const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b)); + return static_cast<uint32_t>(vget_lane_u64(c, 0)); +#endif +} + +//------------------------------------------------------------------------------ +// Transpose. + +// Transpose 32 bit elements such that: +// a: 00 01 +// b: 02 03 +// returns +// val[0]: 00 02 +// val[1]: 01 03 +inline uint8x8x2_t Interleave32(const uint8x8_t a, const uint8x8_t b) { + const uint32x2_t a_32 = vreinterpret_u32_u8(a); + const uint32x2_t b_32 = vreinterpret_u32_u8(b); + const uint32x2x2_t c = vtrn_u32(a_32, b_32); + const uint8x8x2_t d = {vreinterpret_u8_u32(c.val[0]), + vreinterpret_u8_u32(c.val[1])}; + return d; +} + +// Swap high and low 32 bit elements. +inline uint8x8_t Transpose32(const uint8x8_t a) { + const uint32x2_t b = vrev64_u32(vreinterpret_u32_u8(a)); + return vreinterpret_u8_u32(b); +} + +// Implement vtrnq_s64(). +// Input: +// a0: 00 01 02 03 04 05 06 07 +// a1: 16 17 18 19 20 21 22 23 +// Output: +// b0.val[0]: 00 01 02 03 16 17 18 19 +// b0.val[1]: 04 05 06 07 20 21 22 23 +inline int16x8x2_t VtrnqS64(int32x4_t a0, int32x4_t a1) { + int16x8x2_t b0; + b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)), + vreinterpret_s16_s32(vget_low_s32(a1))); + b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)), + vreinterpret_s16_s32(vget_high_s32(a1))); + return b0; +} + +inline uint16x8x2_t VtrnqU64(uint32x4_t a0, uint32x4_t a1) { + uint16x8x2_t b0; + b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)), + vreinterpret_u16_u32(vget_low_u32(a1))); + b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)), + vreinterpret_u16_u32(vget_high_u32(a1))); + return b0; +} + +// Input: +// a: 00 01 02 03 10 11 12 13 +// b: 20 21 22 23 30 31 32 33 +// Output: +// Note that columns [1] and [2] are transposed. +// a: 00 10 20 30 02 12 22 32 +// b: 01 11 21 31 03 13 23 33 +inline void Transpose4x4(uint8x8_t* a, uint8x8_t* b) { + const uint16x4x2_t c = + vtrn_u16(vreinterpret_u16_u8(*a), vreinterpret_u16_u8(*b)); + const uint32x2x2_t d = + vtrn_u32(vreinterpret_u32_u16(c.val[0]), vreinterpret_u32_u16(c.val[1])); + const uint8x8x2_t e = + vtrn_u8(vreinterpret_u8_u32(d.val[0]), vreinterpret_u8_u32(d.val[1])); + *a = e.val[0]; + *b = e.val[1]; +} + +// Reversible if the x4 values are packed next to each other. 
+// x4 input / x8 output: +// a0: 00 01 02 03 40 41 42 43 44 +// a1: 10 11 12 13 50 51 52 53 54 +// a2: 20 21 22 23 60 61 62 63 64 +// a3: 30 31 32 33 70 71 72 73 74 +// x8 input / x4 output: +// a0: 00 10 20 30 40 50 60 70 +// a1: 01 11 21 31 41 51 61 71 +// a2: 02 12 22 32 42 52 62 72 +// a3: 03 13 23 33 43 53 63 73 +inline void Transpose8x4(uint8x8_t* a0, uint8x8_t* a1, uint8x8_t* a2, + uint8x8_t* a3) { + const uint8x8x2_t b0 = vtrn_u8(*a0, *a1); + const uint8x8x2_t b1 = vtrn_u8(*a2, *a3); + + const uint16x4x2_t c0 = + vtrn_u16(vreinterpret_u16_u8(b0.val[0]), vreinterpret_u16_u8(b1.val[0])); + const uint16x4x2_t c1 = + vtrn_u16(vreinterpret_u16_u8(b0.val[1]), vreinterpret_u16_u8(b1.val[1])); + + *a0 = vreinterpret_u8_u16(c0.val[0]); + *a1 = vreinterpret_u8_u16(c1.val[0]); + *a2 = vreinterpret_u8_u16(c0.val[1]); + *a3 = vreinterpret_u8_u16(c1.val[1]); +} + +// Input: +// a[0]: 00 01 02 03 04 05 06 07 +// a[1]: 10 11 12 13 14 15 16 17 +// a[2]: 20 21 22 23 24 25 26 27 +// a[3]: 30 31 32 33 34 35 36 37 +// a[4]: 40 41 42 43 44 45 46 47 +// a[5]: 50 51 52 53 54 55 56 57 +// a[6]: 60 61 62 63 64 65 66 67 +// a[7]: 70 71 72 73 74 75 76 77 + +// Output: +// a[0]: 00 10 20 30 40 50 60 70 +// a[1]: 01 11 21 31 41 51 61 71 +// a[2]: 02 12 22 32 42 52 62 72 +// a[3]: 03 13 23 33 43 53 63 73 +// a[4]: 04 14 24 34 44 54 64 74 +// a[5]: 05 15 25 35 45 55 65 75 +// a[6]: 06 16 26 36 46 56 66 76 +// a[7]: 07 17 27 37 47 57 67 77 +inline void Transpose8x8(int8x8_t a[8]) { + // Swap 8 bit elements. Goes from: + // a[0]: 00 01 02 03 04 05 06 07 + // a[1]: 10 11 12 13 14 15 16 17 + // a[2]: 20 21 22 23 24 25 26 27 + // a[3]: 30 31 32 33 34 35 36 37 + // a[4]: 40 41 42 43 44 45 46 47 + // a[5]: 50 51 52 53 54 55 56 57 + // a[6]: 60 61 62 63 64 65 66 67 + // a[7]: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 40 50 42 52 44 54 46 56 + // b0.val[1]: 01 11 03 13 05 15 07 17 41 51 43 53 45 55 47 57 + // b1.val[0]: 20 30 22 32 24 34 26 36 60 70 62 72 64 74 66 76 + // b1.val[1]: 21 31 23 33 25 35 27 37 61 71 63 73 65 75 67 77 + const int8x16x2_t b0 = + vtrnq_s8(vcombine_s8(a[0], a[4]), vcombine_s8(a[1], a[5])); + const int8x16x2_t b1 = + vtrnq_s8(vcombine_s8(a[2], a[6]), vcombine_s8(a[3], a[7])); + + // Swap 16 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 40 50 60 70 44 54 64 74 + // c0.val[1]: 02 12 22 32 06 16 26 36 42 52 62 72 46 56 66 76 + // c1.val[0]: 01 11 21 31 05 15 25 35 41 51 61 71 45 55 65 75 + // c1.val[1]: 03 13 23 33 07 17 27 37 43 53 63 73 47 57 67 77 + const int16x8x2_t c0 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[0]), + vreinterpretq_s16_s8(b1.val[0])); + const int16x8x2_t c1 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[1]), + vreinterpretq_s16_s8(b1.val[1])); + + // Unzip 32 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71 + // d0.val[1]: 04 14 24 34 44 54 64 74 05 15 25 35 45 55 65 75 + // d1.val[0]: 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73 + // d1.val[1]: 06 16 26 36 46 56 66 76 07 17 27 37 47 57 67 77 + const int32x4x2_t d0 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[0]), + vreinterpretq_s32_s16(c1.val[0])); + const int32x4x2_t d1 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[1]), + vreinterpretq_s32_s16(c1.val[1])); + + a[0] = vreinterpret_s8_s32(vget_low_s32(d0.val[0])); + a[1] = vreinterpret_s8_s32(vget_high_s32(d0.val[0])); + a[2] = vreinterpret_s8_s32(vget_low_s32(d1.val[0])); + a[3] = vreinterpret_s8_s32(vget_high_s32(d1.val[0])); + a[4] = vreinterpret_s8_s32(vget_low_s32(d0.val[1])); + a[5] = 
vreinterpret_s8_s32(vget_high_s32(d0.val[1])); + a[6] = vreinterpret_s8_s32(vget_low_s32(d1.val[1])); + a[7] = vreinterpret_s8_s32(vget_high_s32(d1.val[1])); +} + +// Unsigned. +inline void Transpose8x8(uint8x8_t a[8]) { + const uint8x16x2_t b0 = + vtrnq_u8(vcombine_u8(a[0], a[4]), vcombine_u8(a[1], a[5])); + const uint8x16x2_t b1 = + vtrnq_u8(vcombine_u8(a[2], a[6]), vcombine_u8(a[3], a[7])); + + const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]), + vreinterpretq_u16_u8(b1.val[0])); + const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]), + vreinterpretq_u16_u8(b1.val[1])); + + const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c1.val[0])); + const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c1.val[1])); + + a[0] = vreinterpret_u8_u32(vget_low_u32(d0.val[0])); + a[1] = vreinterpret_u8_u32(vget_high_u32(d0.val[0])); + a[2] = vreinterpret_u8_u32(vget_low_u32(d1.val[0])); + a[3] = vreinterpret_u8_u32(vget_high_u32(d1.val[0])); + a[4] = vreinterpret_u8_u32(vget_low_u32(d0.val[1])); + a[5] = vreinterpret_u8_u32(vget_high_u32(d0.val[1])); + a[6] = vreinterpret_u8_u32(vget_low_u32(d1.val[1])); + a[7] = vreinterpret_u8_u32(vget_high_u32(d1.val[1])); +} + +inline void Transpose8x8(uint8x8_t in[8], uint8x16_t out[4]) { + const uint8x16x2_t a0 = + vtrnq_u8(vcombine_u8(in[0], in[4]), vcombine_u8(in[1], in[5])); + const uint8x16x2_t a1 = + vtrnq_u8(vcombine_u8(in[2], in[6]), vcombine_u8(in[3], in[7])); + + const uint16x8x2_t b0 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[0]), + vreinterpretq_u16_u8(a1.val[0])); + const uint16x8x2_t b1 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[1]), + vreinterpretq_u16_u8(a1.val[1])); + + const uint32x4x2_t c0 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[0]), + vreinterpretq_u32_u16(b1.val[0])); + const uint32x4x2_t c1 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[1]), + vreinterpretq_u32_u16(b1.val[1])); + + out[0] = vreinterpretq_u8_u32(c0.val[0]); + out[1] = vreinterpretq_u8_u32(c1.val[0]); + out[2] = vreinterpretq_u8_u32(c0.val[1]); + out[3] = vreinterpretq_u8_u32(c1.val[1]); +} + +// Input: +// a[0]: 00 01 02 03 04 05 06 07 +// a[1]: 10 11 12 13 14 15 16 17 +// a[2]: 20 21 22 23 24 25 26 27 +// a[3]: 30 31 32 33 34 35 36 37 +// a[4]: 40 41 42 43 44 45 46 47 +// a[5]: 50 51 52 53 54 55 56 57 +// a[6]: 60 61 62 63 64 65 66 67 +// a[7]: 70 71 72 73 74 75 76 77 + +// Output: +// a[0]: 00 10 20 30 40 50 60 70 +// a[1]: 01 11 21 31 41 51 61 71 +// a[2]: 02 12 22 32 42 52 62 72 +// a[3]: 03 13 23 33 43 53 63 73 +// a[4]: 04 14 24 34 44 54 64 74 +// a[5]: 05 15 25 35 45 55 65 75 +// a[6]: 06 16 26 36 46 56 66 76 +// a[7]: 07 17 27 37 47 57 67 77 +inline void Transpose8x8(int16x8_t a[8]) { + const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]); + const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]); + const int16x8x2_t b2 = vtrnq_s16(a[4], a[5]); + const int16x8x2_t b3 = vtrnq_s16(a[6], a[7]); + + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]), + vreinterpretq_s32_s16(b3.val[0])); + const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]), + vreinterpretq_s32_s16(b3.val[1])); + + const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]); + const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]); + const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]); + const 
int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]); + + a[0] = d0.val[0]; + a[1] = d1.val[0]; + a[2] = d2.val[0]; + a[3] = d3.val[0]; + a[4] = d0.val[1]; + a[5] = d1.val[1]; + a[6] = d2.val[1]; + a[7] = d3.val[1]; +} + +// Unsigned. +inline void Transpose8x8(uint16x8_t a[8]) { + const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]); + const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]); + const uint16x8x2_t b2 = vtrnq_u16(a[4], a[5]); + const uint16x8x2_t b3 = vtrnq_u16(a[6], a[7]); + + const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]), + vreinterpretq_u32_u16(b1.val[0])); + const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]), + vreinterpretq_u32_u16(b1.val[1])); + const uint32x4x2_t c2 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[0]), + vreinterpretq_u32_u16(b3.val[0])); + const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]), + vreinterpretq_u32_u16(b3.val[1])); + + const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c2.val[0]); + const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c3.val[0]); + const uint16x8x2_t d2 = VtrnqU64(c0.val[1], c2.val[1]); + const uint16x8x2_t d3 = VtrnqU64(c1.val[1], c3.val[1]); + + a[0] = d0.val[0]; + a[1] = d1.val[0]; + a[2] = d2.val[0]; + a[3] = d3.val[0]; + a[4] = d0.val[1]; + a[5] = d1.val[1]; + a[6] = d2.val[1]; + a[7] = d3.val[1]; +} + +// Input: +// a[0]: 00 01 02 03 04 05 06 07 80 81 82 83 84 85 86 87 +// a[1]: 10 11 12 13 14 15 16 17 90 91 92 93 94 95 96 97 +// a[2]: 20 21 22 23 24 25 26 27 a0 a1 a2 a3 a4 a5 a6 a7 +// a[3]: 30 31 32 33 34 35 36 37 b0 b1 b2 b3 b4 b5 b6 b7 +// a[4]: 40 41 42 43 44 45 46 47 c0 c1 c2 c3 c4 c5 c6 c7 +// a[5]: 50 51 52 53 54 55 56 57 d0 d1 d2 d3 d4 d5 d6 d7 +// a[6]: 60 61 62 63 64 65 66 67 e0 e1 e2 e3 e4 e5 e6 e7 +// a[7]: 70 71 72 73 74 75 76 77 f0 f1 f2 f3 f4 f5 f6 f7 + +// Output: +// a[0]: 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 +// a[1]: 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 +// a[2]: 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 +// a[3]: 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 +// a[4]: 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 +// a[5]: 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 +// a[6]: 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 +// a[7]: 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 +inline void Transpose8x16(uint8x16_t a[8]) { + // b0.val[0]: 00 10 02 12 04 14 06 16 80 90 82 92 84 94 86 96 + // b0.val[1]: 01 11 03 13 05 15 07 17 81 91 83 93 85 95 87 97 + // b1.val[0]: 20 30 22 32 24 34 26 36 a0 b0 a2 b2 a4 b4 a6 b6 + // b1.val[1]: 21 31 23 33 25 35 27 37 a1 b1 a3 b3 a5 b5 a7 b7 + // b2.val[0]: 40 50 42 52 44 54 46 56 c0 d0 c2 d2 c4 d4 c6 d6 + // b2.val[1]: 41 51 43 53 45 55 47 57 c1 d1 c3 d3 c5 d5 c7 d7 + // b3.val[0]: 60 70 62 72 64 74 66 76 e0 f0 e2 f2 e4 f4 e6 f6 + // b3.val[1]: 61 71 63 73 65 75 67 77 e1 f1 e3 f3 e5 f5 e7 f7 + const uint8x16x2_t b0 = vtrnq_u8(a[0], a[1]); + const uint8x16x2_t b1 = vtrnq_u8(a[2], a[3]); + const uint8x16x2_t b2 = vtrnq_u8(a[4], a[5]); + const uint8x16x2_t b3 = vtrnq_u8(a[6], a[7]); + + // c0.val[0]: 00 10 20 30 04 14 24 34 80 90 a0 b0 84 94 a4 b4 + // c0.val[1]: 02 12 22 32 06 16 26 36 82 92 a2 b2 86 96 a6 b6 + // c1.val[0]: 01 11 21 31 05 15 25 35 81 91 a1 b1 85 95 a5 b5 + // c1.val[1]: 03 13 23 33 07 17 27 37 83 93 a3 b3 87 97 a7 b7 + // c2.val[0]: 40 50 60 70 44 54 64 74 c0 d0 e0 f0 c4 d4 e4 f4 + // c2.val[1]: 42 52 62 72 46 56 66 76 c2 d2 e2 f2 c6 d6 e6 f6 + // c3.val[0]: 41 51 61 71 45 55 65 75 c1 d1 e1 f1 c5 d5 e5 f5 + // c3.val[1]: 43 53 63 73 47 57 67 77 c3 d3 e3 f3 c7 d7 e7 f7 + const uint16x8x2_t c0 = 
vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]), + vreinterpretq_u16_u8(b1.val[0])); + const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]), + vreinterpretq_u16_u8(b1.val[1])); + const uint16x8x2_t c2 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[0]), + vreinterpretq_u16_u8(b3.val[0])); + const uint16x8x2_t c3 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[1]), + vreinterpretq_u16_u8(b3.val[1])); + + // d0.val[0]: 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // d0.val[1]: 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // d1.val[0]: 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // d1.val[1]: 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // d2.val[0]: 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // d2.val[1]: 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // d3.val[0]: 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // d3.val[1]: 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c2.val[0])); + const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[0]), + vreinterpretq_u32_u16(c3.val[0])); + const uint32x4x2_t d2 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c2.val[1])); + const uint32x4x2_t d3 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[1]), + vreinterpretq_u32_u16(c3.val[1])); + + a[0] = vreinterpretq_u8_u32(d0.val[0]); + a[1] = vreinterpretq_u8_u32(d1.val[0]); + a[2] = vreinterpretq_u8_u32(d2.val[0]); + a[3] = vreinterpretq_u8_u32(d3.val[0]); + a[4] = vreinterpretq_u8_u32(d0.val[1]); + a[5] = vreinterpretq_u8_u32(d1.val[1]); + a[6] = vreinterpretq_u8_u32(d2.val[1]); + a[7] = vreinterpretq_u8_u32(d3.val[1]); +} + +inline int16x8_t ZeroExtend(const uint8x8_t in) { + return vreinterpretq_s16_u16(vmovl_u8(in)); +} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_ENABLE_NEON +#endif // LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_ diff --git a/src/dsp/arm/convolve_neon.cc b/src/dsp/arm/convolve_neon.cc new file mode 100644 index 0000000..fd9b912 --- /dev/null +++ b/src/dsp/arm/convolve_neon.cc @@ -0,0 +1,3105 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/convolve.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Include the constants and utility functions inside the anonymous namespace. +#include "src/dsp/convolve.inc" + +// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and +// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final +// sum from outranging int16_t. 
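For reference, the sign handling described above can be modeled in scalar code: the NEON routine keeps the absolute values of the halved taps and applies each tap's sign by choosing an accumulate (vmlal_u8) or a subtract-accumulate (vmlsl_u8). The sketch below is illustrative, with hypothetical names, and relies on the property the comment states: because the taps are halved, the true sum fits in int16_t, so the wrapping unsigned accumulator ends up holding the correct two's-complement bit pattern.

#include <cstdint>

// Scalar sketch of a one-pass tap sum. |abs_taps| holds the magnitudes of the
// halved filter taps and |negative| marks the taps that are subtracted.
int16_t SumOnePassTapsScalar(const uint8_t* src, const uint8_t* abs_taps,
                             const bool* negative, const int num_taps) {
  int32_t sum = 0;
  for (int k = 0; k < num_taps; ++k) {
    const int32_t product = static_cast<int32_t>(src[k]) * abs_taps[k];
    sum += negative[k] ? -product : product;
  }
  // The halved taps keep |sum| within int16_t range for 8-bit sources, so the
  // narrowing below is lossless, mirroring the vreinterpret in the NEON code.
  return static_cast<int16_t>(sum);
}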
+template <int filter_index, bool negative_outside_taps = false> +int16x8_t SumOnePassTaps(const uint8x8_t* const src, + const uint8x8_t* const taps) { + uint16x8_t sum; + if (filter_index == 0) { + // 6 taps. + - + + - + + sum = vmull_u8(src[0], taps[0]); + // Unsigned overflow will result in a valid int16_t value. + sum = vmlsl_u8(sum, src[1], taps[1]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlsl_u8(sum, src[4], taps[4]); + sum = vmlal_u8(sum, src[5], taps[5]); + } else if (filter_index == 1 && negative_outside_taps) { + // 6 taps. - + + + + - + // Set a base we can subtract from. + sum = vmull_u8(src[1], taps[1]); + sum = vmlsl_u8(sum, src[0], taps[0]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlal_u8(sum, src[4], taps[4]); + sum = vmlsl_u8(sum, src[5], taps[5]); + } else if (filter_index == 1) { + // 6 taps. All are positive. + sum = vmull_u8(src[0], taps[0]); + sum = vmlal_u8(sum, src[1], taps[1]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlal_u8(sum, src[4], taps[4]); + sum = vmlal_u8(sum, src[5], taps[5]); + } else if (filter_index == 2) { + // 8 taps. - + - + + - + - + sum = vmull_u8(src[1], taps[1]); + sum = vmlsl_u8(sum, src[0], taps[0]); + sum = vmlsl_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlal_u8(sum, src[4], taps[4]); + sum = vmlsl_u8(sum, src[5], taps[5]); + sum = vmlal_u8(sum, src[6], taps[6]); + sum = vmlsl_u8(sum, src[7], taps[7]); + } else if (filter_index == 3) { + // 2 taps. All are positive. + sum = vmull_u8(src[0], taps[0]); + sum = vmlal_u8(sum, src[1], taps[1]); + } else if (filter_index == 4) { + // 4 taps. - + + - + sum = vmull_u8(src[1], taps[1]); + sum = vmlsl_u8(sum, src[0], taps[0]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlsl_u8(sum, src[3], taps[3]); + } else if (filter_index == 5) { + // 4 taps. All are positive. 
+ sum = vmull_u8(src[0], taps[0]); + sum = vmlal_u8(sum, src[1], taps[1]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + } + return vreinterpretq_s16_u16(sum); +} + +template <int filter_index, bool negative_outside_taps> +int16x8_t SumHorizontalTaps(const uint8_t* const src, + const uint8x8_t* const v_tap) { + uint8x8_t v_src[8]; + const uint8x16_t src_long = vld1q_u8(src); + int16x8_t sum; + + if (filter_index < 2) { + v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 1)); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 2)); + v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 5)); + v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 6)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 1); + } else if (filter_index == 2) { + v_src[0] = vget_low_u8(src_long); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1)); + v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2)); + v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5)); + v_src[6] = vget_low_u8(vextq_u8(src_long, src_long, 6)); + v_src[7] = vget_low_u8(vextq_u8(src_long, src_long, 7)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap); + } else if (filter_index == 3) { + v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 3); + } else if (filter_index > 3) { + v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 2)); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 5)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 2); + } + return sum; +} + +template <int filter_index, bool negative_outside_taps> +uint8x8_t SimpleHorizontalTaps(const uint8_t* const src, + const uint8x8_t* const v_tap) { + int16x8_t sum = + SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. 
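The rounding-offset trick mentioned above can be verified in isolation: a rounding shift by a bits followed by a rounding shift by b bits equals a single rounding shift by a + b bits once the first stage's offset, 1 << (a - 1), is added to the input. The check below is a standalone sketch with illustrative shift amounts; it does not model the saturation that vqrshrun_n_s16 also provides.

#include <cassert>

int RoundShift(const int x, const int bits) {
  return (x + (1 << (bits - 1))) >> bits;
}

int main() {
  const int a = 2;  // stands in for kInterRoundBitsHorizontal - 1
  const int b = 4;  // stands in for kFilterBits - kInterRoundBitsHorizontal
  for (int x = 0; x < (1 << 15); ++x) {
    assert(RoundShift(RoundShift(x, a), b) ==
           RoundShift(x + (1 << (a - 1)), a + b));
  }
  return 0;
}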
+ constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit)); + return vqrshrun_n_s16(sum, kFilterBits - 1); +} + +template <int filter_index, bool negative_outside_taps> +uint16x8_t HorizontalTaps8To16(const uint8_t* const src, + const uint8x8_t* const v_tap) { + const int16x8_t sum = + SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap); + + return vreinterpretq_u16_s16( + vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1)); +} + +template <int filter_index> +int16x8_t SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const uint8x8_t* const v_tap) { + uint16x8_t sum; + const uint8x8_t input0 = vld1_u8(src); + src += src_stride; + const uint8x8_t input1 = vld1_u8(src); + uint8x8x2_t input = vzip_u8(input0, input1); + + if (filter_index == 3) { + // tap signs : + + + sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]); + sum = vmlal_u8(sum, input.val[1], v_tap[4]); + } else if (filter_index == 4) { + // tap signs : - + + - + sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]); + sum = vmlsl_u8(sum, RightShift<4 * 8>(input.val[0]), v_tap[2]); + sum = vmlal_u8(sum, input.val[1], v_tap[4]); + sum = vmlsl_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]); + } else { + // tap signs : + + + + + sum = vmull_u8(RightShift<4 * 8>(input.val[0]), v_tap[2]); + sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]); + sum = vmlal_u8(sum, input.val[1], v_tap[4]); + sum = vmlal_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]); + } + + return vreinterpretq_s16_u16(sum); +} + +template <int filter_index> +uint8x8_t SimpleHorizontalTaps2x2(const uint8_t* src, + const ptrdiff_t src_stride, + const uint8x8_t* const v_tap) { + int16x8_t sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. + constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit)); + return vqrshrun_n_s16(sum, kFilterBits - 1); +} + +template <int filter_index> +uint16x8_t HorizontalTaps8To16_2x2(const uint8_t* src, + const ptrdiff_t src_stride, + const uint8x8_t* const v_tap) { + const int16x8_t sum = + SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + return vreinterpretq_u16_s16( + vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1)); +} + +template <int num_taps, int step, int filter_index, + bool negative_outside_taps = true, bool is_2d = false, + bool is_compound = false> +void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, + void* const dest, const ptrdiff_t pred_stride, + const int width, const int height, + const uint8x8_t* const v_tap) { + auto* dest8 = static_cast<uint8_t*>(dest); + auto* dest16 = static_cast<uint16_t*>(dest); + + // 4 tap filters are never used when width > 4. 
+ if (num_taps != 4 && width > 4) { + int y = 0; + do { + int x = 0; + do { + if (is_2d || is_compound) { + const uint16x8_t v_sum = + HorizontalTaps8To16<filter_index, negative_outside_taps>(&src[x], + v_tap); + vst1q_u16(&dest16[x], v_sum); + } else { + const uint8x8_t result = + SimpleHorizontalTaps<filter_index, negative_outside_taps>(&src[x], + v_tap); + vst1_u8(&dest8[x], result); + } + x += step; + } while (x < width); + src += src_stride; + dest8 += pred_stride; + dest16 += pred_stride; + } while (++y < height); + return; + } + + // Horizontal passes only needs to account for |num_taps| 2 and 4 when + // |width| <= 4. + assert(width <= 4); + assert(num_taps <= 4); + if (num_taps <= 4) { + if (width == 4) { + int y = 0; + do { + if (is_2d || is_compound) { + const uint16x8_t v_sum = + HorizontalTaps8To16<filter_index, negative_outside_taps>(src, + v_tap); + vst1_u16(dest16, vget_low_u16(v_sum)); + } else { + const uint8x8_t result = + SimpleHorizontalTaps<filter_index, negative_outside_taps>(src, + v_tap); + StoreLo4(&dest8[0], result); + } + src += src_stride; + dest8 += pred_stride; + dest16 += pred_stride; + } while (++y < height); + return; + } + + if (!is_compound) { + int y = 0; + do { + if (is_2d) { + const uint16x8_t sum = + HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap); + dest16[0] = vgetq_lane_u16(sum, 0); + dest16[1] = vgetq_lane_u16(sum, 2); + dest16 += pred_stride; + dest16[0] = vgetq_lane_u16(sum, 1); + dest16[1] = vgetq_lane_u16(sum, 3); + dest16 += pred_stride; + } else { + const uint8x8_t sum = + SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + dest8[0] = vget_lane_u8(sum, 0); + dest8[1] = vget_lane_u8(sum, 2); + dest8 += pred_stride; + + dest8[0] = vget_lane_u8(sum, 1); + dest8[1] = vget_lane_u8(sum, 3); + dest8 += pred_stride; + } + + src += src_stride << 1; + y += 2; + } while (y < height - 1); + + // The 2d filters have an odd |height| because the horizontal pass + // generates context for the vertical pass. + if (is_2d) { + assert(height % 2 == 1); + uint16x8_t sum; + const uint8x8_t input = vld1_u8(src); + if (filter_index == 3) { // |num_taps| == 2 + sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]); + sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]); + } else if (filter_index == 4) { + sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]); + sum = vmlsl_u8(sum, RightShift<2 * 8>(input), v_tap[2]); + sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]); + sum = vmlsl_u8(sum, RightShift<5 * 8>(input), v_tap[5]); + } else { + assert(filter_index == 5); + sum = vmull_u8(RightShift<2 * 8>(input), v_tap[2]); + sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]); + sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]); + sum = vmlal_u8(sum, RightShift<5 * 8>(input), v_tap[5]); + } + // |sum| contains an int16_t value. + sum = vreinterpretq_u16_s16(vrshrq_n_s16( + vreinterpretq_s16_u16(sum), kInterRoundBitsHorizontal - 1)); + Store2<0>(dest16, sum); + } + } + } +} + +// Process 16 bit inputs and output 32 bits. 
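The helpers below widen the 16-bit intermediate rows to 32-bit accumulators and then narrow with a rounding shift whose size depends on whether the result feeds a compound prediction. A scalar sketch of the same arithmetic, ignoring the saturation that vqrshrn_n_s32 adds:

#include <cstddef>
#include <cstdint>

// One output sample of the 2D vertical pass: |column| points at the first of
// |num_taps| intermediate values spaced |stride| apart, and |round_bits|
// stands in for kInterRoundBitsVertical - 1 (or the compound variant).
int16_t Sum2DVerticalScalar(const int16_t* column, const ptrdiff_t stride,
                            const int16_t* taps, const int num_taps,
                            const int round_bits) {
  int32_t sum = 0;
  for (int k = 0; k < num_taps; ++k) {
    sum += static_cast<int32_t>(column[k * stride]) * taps[k];
  }
  return static_cast<int16_t>((sum + (1 << (round_bits - 1))) >> round_bits);
}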
+template <int num_taps, bool is_compound> +inline int16x4_t Sum2DVerticalTaps4(const int16x4_t* const src, + const int16x8_t taps) { + const int16x4_t taps_lo = vget_low_s16(taps); + const int16x4_t taps_hi = vget_high_s16(taps); + int32x4_t sum; + if (num_taps == 8) { + sum = vmull_lane_s16(src[0], taps_lo, 0); + sum = vmlal_lane_s16(sum, src[1], taps_lo, 1); + sum = vmlal_lane_s16(sum, src[2], taps_lo, 2); + sum = vmlal_lane_s16(sum, src[3], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[4], taps_hi, 0); + sum = vmlal_lane_s16(sum, src[5], taps_hi, 1); + sum = vmlal_lane_s16(sum, src[6], taps_hi, 2); + sum = vmlal_lane_s16(sum, src[7], taps_hi, 3); + } else if (num_taps == 6) { + sum = vmull_lane_s16(src[0], taps_lo, 1); + sum = vmlal_lane_s16(sum, src[1], taps_lo, 2); + sum = vmlal_lane_s16(sum, src[2], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[3], taps_hi, 0); + sum = vmlal_lane_s16(sum, src[4], taps_hi, 1); + sum = vmlal_lane_s16(sum, src[5], taps_hi, 2); + } else if (num_taps == 4) { + sum = vmull_lane_s16(src[0], taps_lo, 2); + sum = vmlal_lane_s16(sum, src[1], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[2], taps_hi, 0); + sum = vmlal_lane_s16(sum, src[3], taps_hi, 1); + } else if (num_taps == 2) { + sum = vmull_lane_s16(src[0], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[1], taps_hi, 0); + } + + if (is_compound) { + return vqrshrn_n_s32(sum, kInterRoundBitsCompoundVertical - 1); + } + + return vqrshrn_n_s32(sum, kInterRoundBitsVertical - 1); +} + +template <int num_taps, bool is_compound> +int16x8_t SimpleSum2DVerticalTaps(const int16x8_t* const src, + const int16x8_t taps) { + const int16x4_t taps_lo = vget_low_s16(taps); + const int16x4_t taps_hi = vget_high_s16(taps); + int32x4_t sum_lo, sum_hi; + if (num_taps == 8) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 0); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_lo, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[6]), taps_hi, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[6]), taps_hi, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3); + } else if (num_taps == 6) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 1); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 1); + sum_hi = vmlal_lane_s16(sum_hi, 
vget_high_s16(src[4]), taps_hi, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2); + } else if (num_taps == 4) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 2); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_hi, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1); + } else if (num_taps == 2) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 3); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0); + } + + if (is_compound) { + return vcombine_s16( + vqrshrn_n_s32(sum_lo, kInterRoundBitsCompoundVertical - 1), + vqrshrn_n_s32(sum_hi, kInterRoundBitsCompoundVertical - 1)); + } + + return vcombine_s16(vqrshrn_n_s32(sum_lo, kInterRoundBitsVertical - 1), + vqrshrn_n_s32(sum_hi, kInterRoundBitsVertical - 1)); +} + +template <int num_taps, bool is_compound = false> +void Filter2DVertical(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int width, + const int height, const int16x8_t taps) { + assert(width >= 8); + constexpr int next_row = num_taps - 1; + // The Horizontal pass uses |width| as |stride| for the intermediate buffer. + const ptrdiff_t src_stride = width; + + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int x = 0; + do { + int16x8_t srcs[8]; + const uint16_t* src_x = src + x; + srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + } + } + } + + int y = 0; + do { + srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + + const int16x8_t sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + vst1q_u16(dst16 + x + y * dst_stride, vreinterpretq_u16_s16(sum)); + } else { + vst1_u8(dst8 + x + y * dst_stride, vqmovun_s16(sum)); + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (++y < height); + x += 8; + } while (x < width); +} + +// Take advantage of |src_stride| == |width| to process two rows at a time. 
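Because the 2D intermediate buffer uses |width| as its stride, a 4-wide block stores its rows back to back, so one 8-lane load picks up two rows and the in-between row pairs are stitched from halves of neighboring loads instead of being reloaded. A hedged scalar illustration of that packing (names are not from the library):

#include <cstdint>
#include <cstring>

// even[] receives rows 0 and 1, next[] receives rows 2 and 3, and odd[] is
// assembled as rows 1 and 2 from halves already in hand, mirroring
// srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2])).
void PackFourWideRows(const uint16_t* intermediate /* stride == 4 */,
                      uint16_t even[8], uint16_t next[8], uint16_t odd[8]) {
  std::memcpy(even, intermediate, 8 * sizeof(uint16_t));
  std::memcpy(next, intermediate + 8, 8 * sizeof(uint16_t));
  std::memcpy(odd, even + 4, 4 * sizeof(uint16_t));
  std::memcpy(odd + 4, next, 4 * sizeof(uint16_t));
}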
+template <int num_taps, bool is_compound = false> +void Filter2DVertical4xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const int16x8_t taps) { + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int16x8_t srcs[9]; + srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + if (num_taps >= 4) { + srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2])); + if (num_taps >= 6) { + srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[3] = vcombine_s16(vget_high_s16(srcs[2]), vget_low_s16(srcs[4])); + if (num_taps == 8) { + srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[5] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[6])); + } + } + } + + int y = 0; + do { + srcs[num_taps] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[num_taps - 1] = vcombine_s16(vget_high_s16(srcs[num_taps - 2]), + vget_low_s16(srcs[num_taps])); + + const int16x8_t sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + const uint16x8_t results = vreinterpretq_u16_s16(sum); + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqmovun_s16(sum); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y += 2; + } while (y < height); +} + +// Take advantage of |src_stride| == |width| to process four rows at a time. +template <int num_taps> +void Filter2DVertical2xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const int16x8_t taps) { + constexpr int next_row = (num_taps < 6) ? 4 : 8; + + auto* dst8 = static_cast<uint8_t*>(dst); + + int16x8_t srcs[9]; + srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + if (num_taps >= 6) { + srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[1] = vextq_s16(srcs[0], srcs[4], 2); + if (num_taps == 8) { + srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4])); + srcs[3] = vextq_s16(srcs[0], srcs[4], 6); + } + } + + int y = 0; + do { + srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + if (num_taps == 2) { + srcs[1] = vextq_s16(srcs[0], srcs[4], 2); + } else if (num_taps == 4) { + srcs[1] = vextq_s16(srcs[0], srcs[4], 2); + srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4])); + srcs[3] = vextq_s16(srcs[0], srcs[4], 6); + } else if (num_taps == 6) { + srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4])); + srcs[3] = vextq_s16(srcs[0], srcs[4], 6); + srcs[5] = vextq_s16(srcs[4], srcs[8], 2); + } else if (num_taps == 8) { + srcs[5] = vextq_s16(srcs[4], srcs[8], 2); + srcs[6] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[8])); + srcs[7] = vextq_s16(srcs[4], srcs[8], 6); + } + + const int16x8_t sum = + SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps); + const uint8x8_t results = vqmovun_s16(sum); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + // When |height| <= 4 the taps are restricted to 2 and 4 tap variants. + // Therefore we don't need to check this condition when |height| > 4. 
+ if (num_taps <= 4 && height == 2) return; + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + if (num_taps == 6) { + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + } else if (num_taps == 8) { + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + } + + y += 4; + } while (y < height); +} + +template <bool is_2d = false, bool is_compound = false> +LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( + const uint8_t* const src, const ptrdiff_t src_stride, void* const dst, + const ptrdiff_t dst_stride, const int width, const int height, + const int filter_id, const int filter_index) { + // Duplicate the absolute value for each tap. Negative taps are corrected + // by using the vmlsl_u8 instruction. Positive taps use vmlal_u8. + uint8x8_t v_tap[kSubPixelTaps]; + assert(filter_id != 0); + + for (int k = 0; k < kSubPixelTaps; ++k) { + v_tap[k] = vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][filter_id][k]); + } + + if (filter_index == 2) { // 8 tap. + FilterHorizontal<8, 8, 2, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 1) { // 6 tap. + // Check if outside taps are positive. + if ((filter_id == 1) | (filter_id == 15)) { + FilterHorizontal<6, 8, 1, false, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else { + FilterHorizontal<6, 8, 1, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } + } else if (filter_index == 0) { // 6 tap. + FilterHorizontal<6, 8, 0, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 4) { // 4 tap. + FilterHorizontal<4, 8, 4, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 5) { // 4 tap. + FilterHorizontal<4, 8, 5, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else { // 2 tap. + FilterHorizontal<2, 8, 3, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } +} + +void Convolve2D_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + + // The output of the horizontal filter is guaranteed to fit in 16 bits. + uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_height = height + vertical_taps - 1; + + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset; + + DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width, + width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. 
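Before the vertical pass, it may help to see the whole 2D pipeline in outline: the horizontal pass fills an intermediate buffer of height + vertical_taps - 1 rows with stride equal to |width|, starting (vertical_taps / 2 - 1) rows above the block, and the vertical pass then reduces each column back to |height| rows. The sketch below is a generic separable filter with placeholder rounding (it assumes each tap set sums to 128 and skips intermediate rounding), not the library's exact arithmetic.

#include <cstddef>
#include <cstdint>
#include <vector>

void Convolve2DOutline(const uint8_t* src, const ptrdiff_t src_stride,
                       const int width, const int height,
                       const int16_t h_taps[8], const int16_t v_taps[8],
                       const int vertical_taps, uint8_t* dst,
                       const ptrdiff_t dst_stride) {
  const int intermediate_height = height + vertical_taps - 1;
  std::vector<int32_t> intermediate(width * intermediate_height);
  // Horizontal pass. |src| is assumed to already point at the first context
  // row and at the leftmost tap of the first output column.
  for (int y = 0; y < intermediate_height; ++y) {
    for (int x = 0; x < width; ++x) {
      int32_t sum = 0;
      for (int k = 0; k < 8; ++k) {
        sum += h_taps[k] * src[y * src_stride + x + k];
      }
      intermediate[y * width + x] = sum;  // Placeholder: no intermediate rounding.
    }
  }
  // Vertical pass over the columns of the intermediate buffer.
  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      int32_t sum = 0;
      for (int k = 0; k < vertical_taps; ++k) {
        sum += v_taps[k] * intermediate[(y + k) * width + x];
      }
      const int32_t value = (sum + (1 << 13)) >> 14;  // 128 * 128 normalization.
      dst[y * dst_stride + x] =
          static_cast<uint8_t>(value < 0 ? 0 : (value > 255 ? 255 : value));
    }
  }
}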
+ auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const int16x8_t taps = vmovl_s8( + vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id])); + + if (vertical_taps == 8) { + if (width == 2) { + Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else if (vertical_taps == 6) { + if (width == 2) { + Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else if (vertical_taps == 4) { + if (width == 2) { + Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else { // |vertical_taps| == 2 + if (width == 2) { + Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } +} + +// There are many opportunities for overreading in scaled convolve, because the +// range of starting points for filter windows is anywhere from 0 to 16 for 8 +// destination pixels, and the window sizes range from 2 to 8. To accommodate +// this range concisely, we use |grade_x| to mean the most steps in src that can +// be traversed in a single |step_x| increment, i.e. 1 or 2. When grade_x is 2, +// we are guaranteed to exceed 8 whole steps in src for every 8 |step_x| +// increments. The first load covers the initial elements of src_x, while the +// final load covers the taps. +template <int grade_x> +inline uint8x8x3_t LoadSrcVals(const uint8_t* src_x) { + uint8x8x3_t ret; + const uint8x16_t src_val = vld1q_u8(src_x); + ret.val[0] = vget_low_u8(src_val); + ret.val[1] = vget_high_u8(src_val); + if (grade_x > 1) { + ret.val[2] = vld1_u8(src_x + 16); + } + return ret; +} + +// Pre-transpose the 2 tap filters in |kAbsHalfSubPixelFilters|[3] +inline uint8x16_t GetPositive2TapFilter(const int tap_index) { + assert(tap_index < 2); + alignas( + 16) static constexpr uint8_t kAbsHalfSubPixel2TapFilterColumns[2][16] = { + {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4}, + {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}}; + + return vld1q_u8(kAbsHalfSubPixel2TapFilterColumns[tap_index]); +} + +template <int grade_x> +inline void ConvolveKernelHorizontal2Tap(const uint8_t* src, + const ptrdiff_t src_stride, + const int width, const int subpixel_x, + const int step_x, + const int intermediate_height, + int16_t* intermediate) { + // Account for the 0-taps that precede the 2 nonzero taps. 
+ const int kernel_offset = 3; + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + const uint8x16_t filter_taps0 = GetPositive2TapFilter(0); + const uint8x16_t filter_taps1 = GetPositive2TapFilter(1); + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + + int p = subpixel_x; + if (width <= 4) { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask); + // This is a special case. The 2-tap filter has no negative taps, so we + // can use unsigned values. + // For each x, a lane of tapsK has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices)}; + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x16_t src_vals = vld1q_u8(src_x); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + + // For each x, a lane of srcK contains src_x[k]. + const uint8x8_t src[2] = { + VQTbl1U8(src_vals, src_indices), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))}; + + vst1q_s16(intermediate, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate += kIntermediateStride; + } while (++y < intermediate_height); + return; + } + + // |width| >= 8 + int x = 0; + do { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // This is a special case. The 2-tap filter has no negative taps, so we + // can use unsigned values. + // For each x, a lane of tapsK has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices)}; + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + + // For each x, a lane of srcK contains src_x[k]. + const uint8x8_t src[2] = { + vtbl3_u8(src_vals, src_indices), + vtbl3_u8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))}; + + vst1q_s16(intermediate_x, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[5]. 
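The filter tables for the scaled kernels are stored pre-transposed: column k of a table holds tap k for each of the 16 sub-pixel phases, so one VQTbl1U8 lookup per tap fetches that tap for eight output pixels that each sit at a different phase. A scalar picture of the same gather, with illustrative names:

#include <cstdint>

// |columns[k][phase]| is tap k of the filter for |phase|; |filter_ids| holds
// the per-pixel phases; |taps_out[k][x]| receives tap k for output pixel x.
void GatherTransposedTaps(const uint8_t columns[4][16],
                          const uint8_t filter_ids[8], uint8_t taps_out[4][8]) {
  for (int k = 0; k < 4; ++k) {    // One table lookup per tap in the NEON code.
    for (int x = 0; x < 8; ++x) {  // Eight pixels, potentially eight phases.
      taps_out[k][x] = columns[k][filter_ids[x]];
    }
  }
}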
+inline uint8x16_t GetPositive4TapFilter(const int tap_index) { + assert(tap_index < 4); + alignas( + 16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = { + {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1}, + {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17}, + {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31}, + {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}}; + + return vld1q_u8(kSubPixel4TapPositiveFilterColumns[tap_index]); +} + +// This filter is only possible when width <= 4. +void ConvolveKernelHorizontalPositive4Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x, + const int step_x, const int intermediate_height, int16_t* intermediate) { + const int kernel_offset = 2; + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const uint8x16_t filter_taps0 = GetPositive4TapFilter(0); + const uint8x16_t filter_taps1 = GetPositive4TapFilter(1); + const uint8x16_t filter_taps2 = GetPositive4TapFilter(2); + const uint8x16_t filter_taps3 = GetPositive4TapFilter(3); + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + const int p = subpixel_x; + // First filter is special, just a 128 tap on the center. + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t filter_indices = vand_u8( + vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), filter_index_mask); + // Note that filter_id depends on x. + // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k]. + const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices), + VQTbl1U8(filter_taps2, filter_indices), + VQTbl1U8(filter_taps3, filter_indices)}; + + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + int y = 0; + do { + // Load a pool of samples to select from using stepped index vectors. + const uint8x16_t src_vals = vld1q_u8(src_x); + + // For each x, srcK contains src_x[k] where k=1. + // Whereas taps come from different arrays, src pixels are drawn from the + // same contiguous line. + const uint8x8_t src[4] = { + VQTbl1U8(src_vals, src_indices), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1))), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2))), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)))}; + + vst1q_s16(intermediate, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/5>(src, taps), + kInterRoundBitsHorizontal - 1)); + + src_x += src_stride; + intermediate += kIntermediateStride; + } while (++y < intermediate_height); +} + +// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[4]. +inline uint8x16_t GetSigned4TapFilter(const int tap_index) { + assert(tap_index < 4); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel4TapSignedFilterColumns[4][16] = { + {0, 2, 4, 5, 6, 6, 7, 6, 6, 5, 5, 5, 4, 3, 2, 1}, + {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4}, + {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63}, + {0, 1, 2, 3, 4, 5, 5, 5, 6, 6, 7, 6, 6, 5, 4, 2}}; + + return vld1q_u8(kAbsHalfSubPixel4TapSignedFilterColumns[tap_index]); +} + +// This filter is only possible when width <= 4. 
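All of the scaled horizontal kernels share the same index arithmetic: the position accumulator |p| advances in 1/1024-pel units, its top bits select the source column and the next four bits select one of the 16 filter phases. A scalar model of that walk (the shift amounts mirror the ones used in this file; the function name is illustrative):

#include <cstdint>

void ComputeScaledIndices(const int subpixel_x, const int step_x,
                          int positions[8], int filter_ids[8]) {
  for (int x = 0; x < 8; ++x) {
    const int p = subpixel_x + x * step_x;
    positions[x] = p >> 10;         // p >> kScaleSubPixelBits
    filter_ids[x] = (p >> 6) & 15;  // (p >> kFilterIndexShift) & kSubPixelMask
  }
}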
+inline void ConvolveKernelHorizontalSigned4Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x, + const int step_x, const int intermediate_height, int16_t* intermediate) { + const int kernel_offset = 2; + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const uint8x16_t filter_taps0 = GetSigned4TapFilter(0); + const uint8x16_t filter_taps1 = GetSigned4TapFilter(1); + const uint8x16_t filter_taps2 = GetSigned4TapFilter(2); + const uint8x16_t filter_taps3 = GetSigned4TapFilter(3); + const uint16x4_t index_steps = vmul_n_u16(vcreate_u16(0x0003000200010000), + static_cast<uint16_t>(step_x)); + + const int p = subpixel_x; + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x4_t p_fraction = vdup_n_u16(p & 1023); + const uint16x4_t subpel_index_offsets = vadd_u16(index_steps, p_fraction); + const uint8x8_t filter_index_offsets = vshrn_n_u16( + vcombine_u16(subpel_index_offsets, vdup_n_u16(0)), kFilterIndexShift); + const uint8x8_t filter_indices = + vand_u8(filter_index_offsets, filter_index_mask); + // Note that filter_id depends on x. + // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k]. + const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices), + VQTbl1U8(filter_taps2, filter_indices), + VQTbl1U8(filter_taps3, filter_indices)}; + + const uint8x8_t src_indices_base = + vshr_n_u8(filter_index_offsets, kScaleSubPixelBits - kFilterIndexShift); + + const uint8x8_t src_indices[4] = {src_indices_base, + vadd_u8(src_indices_base, vdup_n_u8(1)), + vadd_u8(src_indices_base, vdup_n_u8(2)), + vadd_u8(src_indices_base, vdup_n_u8(3))}; + + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x16_t src_vals = vld1q_u8(src_x); + + // For each x, srcK contains src_x[k] where k=1. + // Whereas taps come from different arrays, src pixels are drawn from the + // same contiguous line. + const uint8x8_t src[4] = { + VQTbl1U8(src_vals, src_indices[0]), VQTbl1U8(src_vals, src_indices[1]), + VQTbl1U8(src_vals, src_indices[2]), VQTbl1U8(src_vals, src_indices[3])}; + + vst1q_s16(intermediate, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/4>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate += kIntermediateStride; + } while (++y < intermediate_height); +} + +// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[0]. +inline uint8x16_t GetSigned6TapFilter(const int tap_index) { + assert(tap_index < 6); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel6TapSignedFilterColumns[6][16] = { + {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0}, + {0, 3, 5, 6, 7, 7, 8, 7, 7, 6, 6, 6, 5, 4, 2, 1}, + {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4}, + {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63}, + {0, 1, 2, 4, 5, 6, 6, 6, 7, 7, 8, 7, 7, 6, 5, 3}, + {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}; + + return vld1q_u8(kAbsHalfSubPixel6TapSignedFilterColumns[tap_index]); +} + +// This filter is only possible when width >= 8. 
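The |grade_x| template parameter used by the wide kernels below reflects how many source bytes eight outputs can span: the last output starts roughly (7 * step_x) >> 10 columns past the first, plus the tap count, so once that span exceeds the 16 bytes a single table lookup can cover, LoadSrcVals pulls an extra 8 bytes and vtbl3_u8 indexes across 24. A rough scalar version of that decision (the library derives it via |grade_x_threshold| further down):

#include <cstdint>

int SourceSpanBytes(const int step_x, const int num_taps) {
  const int last_start = (7 * step_x) >> 10;  // Column offset of the 8th output.
  return last_start + num_taps;               // One past the last tap touched.
}

bool NeedsWideLoad(const int step_x, const int num_taps) {
  return SourceSpanBytes(step_x, num_taps) > 16;
}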
+template <int grade_x> +inline void ConvolveKernelHorizontalSigned6Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int width, + const int subpixel_x, const int step_x, const int intermediate_height, + int16_t* intermediate) { + const int kernel_offset = 1; + const uint8x8_t one = vdup_n_u8(1); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + uint8x16_t filter_taps[6]; + for (int i = 0; i < 6; ++i) { + filter_taps[i] = GetSigned6TapFilter(i); + } + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + + int x = 0; + int p = subpixel_x; + do { + // Avoid overloading outside the reference boundaries. This means + // |trailing_width| can be up to 24. + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + uint8x8_t src_lookup[6]; + src_lookup[0] = src_indices; + for (int i = 1; i < 6; ++i) { + src_lookup[i] = vadd_u8(src_lookup[i - 1], one); + } + + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // For each x, a lane of taps[k] has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + uint8x8_t taps[6]; + for (int i = 0; i < 6; ++i) { + taps[i] = VQTbl1U8(filter_taps[i], filter_indices); + } + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + + const uint8x8_t src[6] = { + vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]), + vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]), + vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5])}; + + vst1q_s16(intermediate_x, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/0>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[1]. This filter +// has mixed positive and negative outer taps which are handled in +// GetMixed6TapFilter(). +inline uint8x16_t GetPositive6TapFilter(const int tap_index) { + assert(tap_index < 6); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel6TapPositiveFilterColumns[4][16] = { + {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1}, + {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17}, + {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31}, + {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14}}; + + return vld1q_u8(kAbsHalfSubPixel6TapPositiveFilterColumns[tap_index]); +} + +inline int8x16_t GetMixed6TapFilter(const int tap_index) { + assert(tap_index < 2); + alignas( + 16) static constexpr int8_t kHalfSubPixel6TapMixedFilterColumns[2][16] = { + {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}}; + + return vld1q_s8(kHalfSubPixel6TapMixedFilterColumns[tap_index]); +} + +// This filter is only possible when width >= 8. 
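The filter at index 1 is "mixed": its two outer taps are only ever -1, 0 or +1, so the kernel below keeps them in a small signed table and applies them with a 16-bit signed multiply, while the four inner taps stay positive and go through the unsigned widening multiply-accumulate. A scalar sketch of the same split (illustrative, not the library's code):

#include <cstdint>

int16_t MixedSixTapScalar(const uint8_t src[6], const uint8_t inner_taps[4],
                          const int8_t outer_taps[2]) {
  // Outer taps are signed but tiny; inner taps are positive halved taps.
  int32_t sum = static_cast<int32_t>(outer_taps[0]) * src[0] +
                static_cast<int32_t>(outer_taps[1]) * src[5];
  for (int k = 0; k < 4; ++k) {
    sum += static_cast<int32_t>(inner_taps[k]) * src[k + 1];
  }
  return static_cast<int16_t>(sum);
}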
+template <int grade_x> +inline void ConvolveKernelHorizontalMixed6Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int width, + const int subpixel_x, const int step_x, const int intermediate_height, + int16_t* intermediate) { + const int kernel_offset = 1; + const uint8x8_t one = vdup_n_u8(1); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + uint8x8_t taps[4]; + int16x8_t mixed_taps[2]; + uint8x16_t positive_filter_taps[4]; + for (int i = 0; i < 4; ++i) { + positive_filter_taps[i] = GetPositive6TapFilter(i); + } + int8x16_t mixed_filter_taps[2]; + mixed_filter_taps[0] = GetMixed6TapFilter(0); + mixed_filter_taps[1] = GetMixed6TapFilter(1); + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + + int x = 0; + int p = subpixel_x; + do { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + uint8x8_t src_lookup[6]; + src_lookup[0] = src_indices; + for (int i = 1; i < 6; ++i) { + src_lookup[i] = vadd_u8(src_lookup[i - 1], one); + } + + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // For each x, a lane of taps[k] has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + for (int i = 0; i < 4; ++i) { + taps[i] = VQTbl1U8(positive_filter_taps[i], filter_indices); + } + mixed_taps[0] = vmovl_s8(VQTbl1S8(mixed_filter_taps[0], filter_indices)); + mixed_taps[1] = vmovl_s8(VQTbl1S8(mixed_filter_taps[1], filter_indices)); + + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + + int16x8_t sum_mixed = vmulq_s16( + mixed_taps[0], ZeroExtend(vtbl3_u8(src_vals, src_lookup[0]))); + sum_mixed = vmlaq_s16(sum_mixed, mixed_taps[1], + ZeroExtend(vtbl3_u8(src_vals, src_lookup[5]))); + uint16x8_t sum = vreinterpretq_u16_s16(sum_mixed); + sum = vmlal_u8(sum, taps[0], vtbl3_u8(src_vals, src_lookup[1])); + sum = vmlal_u8(sum, taps[1], vtbl3_u8(src_vals, src_lookup[2])); + sum = vmlal_u8(sum, taps[2], vtbl3_u8(src_vals, src_lookup[3])); + sum = vmlal_u8(sum, taps[3], vtbl3_u8(src_vals, src_lookup[4])); + + vst1q_s16(intermediate_x, vrshrq_n_s16(vreinterpretq_s16_u16(sum), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// Pre-transpose the 8 tap filters in |kAbsHalfSubPixelFilters|[2]. 
+inline uint8x16_t GetSigned8TapFilter(const int tap_index) { + assert(tap_index < 8); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel8TapSignedFilterColumns[8][16] = { + {0, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0}, + {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1}, + {0, 3, 6, 9, 11, 11, 12, 12, 12, 11, 10, 9, 7, 5, 3, 1}, + {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4}, + {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63}, + {0, 1, 3, 5, 7, 9, 10, 11, 12, 12, 12, 11, 11, 9, 6, 3}, + {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1}, + {0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1}}; + + return vld1q_u8(kAbsHalfSubPixel8TapSignedFilterColumns[tap_index]); +} + +// This filter is only possible when width >= 8. +template <int grade_x> +inline void ConvolveKernelHorizontalSigned8Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int width, + const int subpixel_x, const int step_x, const int intermediate_height, + int16_t* intermediate) { + const uint8x8_t one = vdup_n_u8(1); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + uint8x8_t taps[8]; + uint8x16_t filter_taps[8]; + for (int i = 0; i < 8; ++i) { + filter_taps[i] = GetSigned8TapFilter(i); + } + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + int x = 0; + int p = subpixel_x; + do { + const uint8_t* src_x = &src[(p >> kScaleSubPixelBits) - ref_x]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + uint8x8_t src_lookup[8]; + src_lookup[0] = src_indices; + for (int i = 1; i < 8; ++i) { + src_lookup[i] = vadd_u8(src_lookup[i - 1], one); + } + + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // For each x, a lane of taps[k] has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + for (int i = 0; i < 8; ++i) { + taps[i] = VQTbl1U8(filter_taps[i], filter_indices); + } + + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + + const uint8x8_t src[8] = { + vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]), + vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]), + vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5]), + vtbl3_u8(src_vals, src_lookup[6]), vtbl3_u8(src_vals, src_lookup[7])}; + + vst1q_s16(intermediate_x, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/2>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// This function handles blocks of width 2 or 4. 
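The vertical scaled kernels walk the intermediate buffer the same way the horizontal ones walk the source: |p| advances by |step_y| per output row, the source row is p >> 10 and the phase is (p >> 6) & 15. Between consecutive outputs the source row therefore moves by 0, 1 or, when step_y exceeds 1024, 2 rows, which is why |grade_y| extra rows are preloaded and selected with |p_diff|. A one-line scalar model of that delta:

#include <cstdint>

// Row advance in the intermediate buffer between two consecutive output rows.
int RowDelta(const int p, const int step_y) {
  return ((p + step_y) >> 10) - (p >> 10);  // 0, 1 or 2 for the steps used here.
}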
+template <int num_taps, int grade_y, int width, bool is_compound> +void ConvolveVerticalScale4xH(const int16_t* src, const int subpixel_y, + const int filter_index, const int step_y, + const int height, void* dest, + const ptrdiff_t dest_stride) { + constexpr ptrdiff_t src_stride = kIntermediateStride; + const int16_t* src_y = src; + // |dest| is 16-bit in compound mode, Pixel otherwise. + uint16_t* dest16_y = static_cast<uint16_t*>(dest); + uint8_t* dest_y = static_cast<uint8_t*>(dest); + int16x4_t s[num_taps + grade_y]; + + int p = subpixel_y & 1023; + int prev_p = p; + int y = 0; + do { // y < height + for (int i = 0; i < num_taps; ++i) { + s[i] = vld1_s16(src_y + i * src_stride); + } + int filter_id = (p >> 6) & kSubPixelMask; + int16x8_t filter = + vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + int16x4_t sums = Sum2DVerticalTaps4<num_taps, is_compound>(s, filter); + if (is_compound) { + assert(width != 2); + const uint16x4_t result = vreinterpret_u16_s16(sums); + vst1_u16(dest16_y, result); + } else { + const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums)); + if (width == 2) { + Store2<0>(dest_y, result); + } else { + StoreLo4(dest_y, result); + } + } + p += step_y; + const int p_diff = + (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits); + prev_p = p; + // Here we load extra source in case it is needed. If |p_diff| == 0, these + // values will be unused, but it's faster to load than to branch. + s[num_taps] = vld1_s16(src_y + num_taps * src_stride); + if (grade_y > 1) { + s[num_taps + 1] = vld1_s16(src_y + (num_taps + 1) * src_stride); + } + dest16_y += dest_stride; + dest_y += dest_stride; + + filter_id = (p >> 6) & kSubPixelMask; + filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + sums = Sum2DVerticalTaps4<num_taps, is_compound>(&s[p_diff], filter); + if (is_compound) { + assert(width != 2); + const uint16x4_t result = vreinterpret_u16_s16(sums); + vst1_u16(dest16_y, result); + } else { + const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums)); + if (width == 2) { + Store2<0>(dest_y, result); + } else { + StoreLo4(dest_y, result); + } + } + p += step_y; + src_y = src + (p >> kScaleSubPixelBits) * src_stride; + prev_p = p; + dest16_y += dest_stride; + dest_y += dest_stride; + + y += 2; + } while (y < height); +} + +template <int num_taps, int grade_y, bool is_compound> +inline void ConvolveVerticalScale(const int16_t* src, const int width, + const int subpixel_y, const int filter_index, + const int step_y, const int height, + void* dest, const ptrdiff_t dest_stride) { + constexpr ptrdiff_t src_stride = kIntermediateStride; + // A possible improvement is to use arithmetic to decide how many times to + // apply filters to same source before checking whether to load new srcs. + // However, this will only improve performance with very small step sizes. + int16x8_t s[num_taps + grade_y]; + // |dest| is 16-bit in compound mode, Pixel otherwise. 
+ uint16_t* dest16_y; + uint8_t* dest_y; + + int x = 0; + do { // x < width + const int16_t* src_x = src + x; + const int16_t* src_y = src_x; + dest16_y = static_cast<uint16_t*>(dest) + x; + dest_y = static_cast<uint8_t*>(dest) + x; + int p = subpixel_y & 1023; + int prev_p = p; + int y = 0; + do { // y < height + for (int i = 0; i < num_taps; ++i) { + s[i] = vld1q_s16(src_y + i * src_stride); + } + int filter_id = (p >> 6) & kSubPixelMask; + int16x8_t filter = + vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + int16x8_t sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(s, filter); + if (is_compound) { + vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum)); + } else { + vst1_u8(dest_y, vqmovun_s16(sum)); + } + p += step_y; + const int p_diff = + (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits); + // |grade_y| > 1 always means p_diff > 0, so load vectors that may be + // needed. Otherwise, we only need to load one vector because |p_diff| + // can't exceed 1. + s[num_taps] = vld1q_s16(src_y + num_taps * src_stride); + if (grade_y > 1) { + s[num_taps + 1] = vld1q_s16(src_y + (num_taps + 1) * src_stride); + } + dest16_y += dest_stride; + dest_y += dest_stride; + + filter_id = (p >> 6) & kSubPixelMask; + filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(&s[p_diff], filter); + if (is_compound) { + vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum)); + } else { + vst1_u8(dest_y, vqmovun_s16(sum)); + } + p += step_y; + src_y = src_x + (p >> kScaleSubPixelBits) * src_stride; + prev_p = p; + dest16_y += dest_stride; + dest_y += dest_stride; + + y += 2; + } while (y < height); + x += 8; + } while (x < width); +} + +template <bool is_compound> +void ConvolveScale2D_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, const int subpixel_x, + const int subpixel_y, const int step_x, + const int step_y, const int width, const int height, + void* prediction, const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + assert(step_x <= 2048); + const int num_vert_taps = GetNumTapsInFilter(vert_filter_index); + const int intermediate_height = + (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >> + kScaleSubPixelBits) + + num_vert_taps; + assert(step_x <= 2048); + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + int16_t intermediate_result[kMaxSuperBlockSizeInPixels * + (2 * kMaxSuperBlockSizeInPixels + 8)]; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [3, 5]. + // Similarly for height. + int filter_index = GetFilterIndex(horizontal_filter_index, width); + int16_t* intermediate = intermediate_result; + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference); + const int vert_kernel_offset = (8 - num_vert_taps) / 2; + src += vert_kernel_offset * src_stride; + + // Derive the maximum value of |step_x| at which all source values fit in one + // 16-byte load. 
Final index is src_x + |num_taps| - 1 < 16 + // step_x*7 is the final base subpel index for the shuffle mask for filter + // inputs in each iteration on large blocks. When step_x is large, we need a + // larger structure and use a larger table lookup in order to gather all + // filter inputs. + // |num_taps| - 1 is the shuffle index of the final filter input. + const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index); + const int kernel_start_ceiling = 16 - num_horiz_taps; + // This truncated quotient |grade_x_threshold| selects |step_x| such that: + // (step_x * 7) >> kScaleSubPixelBits < single load limit + const int grade_x_threshold = + (kernel_start_ceiling << kScaleSubPixelBits) / 7; + switch (filter_index) { + case 0: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontalSigned6Tap<2>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } else { + ConvolveKernelHorizontalSigned6Tap<1>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } + break; + case 1: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontalMixed6Tap<2>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + + } else { + ConvolveKernelHorizontalMixed6Tap<1>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 2: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontalSigned8Tap<2>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } else { + ConvolveKernelHorizontalSigned8Tap<1>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } + break; + case 3: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontal2Tap<2>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } else { + ConvolveKernelHorizontal2Tap<1>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 4: + assert(width <= 4); + ConvolveKernelHorizontalSigned4Tap(src, src_stride, subpixel_x, step_x, + intermediate_height, intermediate); + break; + default: + assert(filter_index == 5); + ConvolveKernelHorizontalPositive4Tap(src, src_stride, subpixel_x, step_x, + intermediate_height, intermediate); + } + // Vertical filter. 
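+  // The switch below selects the vertical tap count implied by
+  // |filter_index| (6, 8, 2 or 4) and a |grade_y| of 1 or 2. With
+  // |step_y| <= 1024 the source row index advances by at most one row per
+  // output row, so a single extra row load suffices; larger steps may skip
+  // a row and need two. Narrow blocks (width 2 or 4) take the 4xH variant,
+  // which produces two output rows per iteration.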
+ filter_index = GetFilterIndex(vertical_filter_index, height); + intermediate = intermediate_result; + + switch (filter_index) { + case 0: + case 1: + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<6, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<6, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<6, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<6, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<6, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<6, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + break; + case 2: + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<8, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<8, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<8, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<8, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<8, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<8, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + break; + case 3: + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<2, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<2, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<2, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<2, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<2, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<2, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + break; + case 4: + default: + assert(filter_index == 4 || filter_index == 5); + assert(height <= 4); + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<4, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + 
ConvolveVerticalScale4xH<4, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<4, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<4, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<4, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<4, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + } +} + +void ConvolveHorizontal_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int /*vertical_filter_index*/, + const int horizontal_filter_id, + const int /*vertical_filter_id*/, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + // Set |src| to the outermost tap. + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint8_t*>(prediction); + + DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height, + horizontal_filter_id, filter_index); +} + +// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D +// Vertical calculations. +uint16x8_t Compound1DShift(const int16x8_t sum) { + return vreinterpretq_u16_s16( + vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1)); +} + +template <int filter_index, bool is_compound = false, + bool negative_outside_taps = false> +void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int width, const int height, + const uint8x8_t* const taps) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps - 1; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + assert(width >= 8); + + int x = 0; + do { + const uint8_t* src_x = src + x; + uint8x8_t srcs[8]; + srcs[0] = vld1_u8(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = vld1_u8(src_x); + src_x += src_stride; + srcs[2] = vld1_u8(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = vld1_u8(src_x); + src_x += src_stride; + srcs[4] = vld1_u8(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = vld1_u8(src_x); + src_x += src_stride; + srcs[6] = vld1_u8(src_x); + src_x += src_stride; + } + } + } + + int y = 0; + do { + srcs[next_row] = vld1_u8(src_x); + src_x += src_stride; + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + vst1q_u16(dst16 + x + y * dst_stride, results); + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + vst1_u8(dst8 + x + y * dst_stride, results); + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (++y < height); + x += 8; + } while (x < width); +} + +template <int filter_index, bool is_compound = false, + bool negative_outside_taps = false> +void 
FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const uint8x8_t* const taps) { + const int num_taps = GetNumTapsInFilter(filter_index); + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + uint8x8_t srcs[9]; + + if (num_taps == 2) { + srcs[2] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + + int y = 0; + do { + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4<0>(src, srcs[2]); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + y += 2; + } while (y < height); + } else if (num_taps == 4) { + srcs[4] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + + int y = 0; + do { + srcs[2] = Load4<1>(src, srcs[2]); + src += src_stride; + srcs[4] = Load4<0>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[2], srcs[4], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + y += 2; + } while (y < height); + } else if (num_taps == 6) { + srcs[6] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + srcs[2] = Load4<1>(src, srcs[2]); + src += src_stride; + srcs[4] = Load4(src); + src += src_stride; + srcs[3] = vext_u8(srcs[2], srcs[4], 4); + + int y = 0; + do { + srcs[4] = Load4<1>(src, srcs[4]); + src += src_stride; + srcs[6] = Load4<0>(src, srcs[6]); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[6], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + y += 2; + } while (y < height); + } else if (num_taps == 8) { + srcs[8] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + srcs[2] = Load4<1>(src, srcs[2]); + src += src_stride; + srcs[4] = Load4(src); + src += src_stride; + srcs[3] = vext_u8(srcs[2], srcs[4], 4); + srcs[4] = Load4<1>(src, srcs[4]); + src += src_stride; + srcs[6] 
= Load4(src); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[6], 4); + + int y = 0; + do { + srcs[6] = Load4<1>(src, srcs[6]); + src += src_stride; + srcs[8] = Load4<0>(src, srcs[8]); + src += src_stride; + srcs[7] = vext_u8(srcs[6], srcs[8], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + y += 2; + } while (y < height); + } +} + +template <int filter_index, bool negative_outside_taps = false> +void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const uint8x8_t* const taps) { + const int num_taps = GetNumTapsInFilter(filter_index); + auto* dst8 = static_cast<uint8_t*>(dst); + + uint8x8_t srcs[9]; + + if (num_taps == 2) { + srcs[2] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + + int y = 0; + do { + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[2] = Load2<0>(src, srcs[2]); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 2); + + // This uses srcs[0]..srcs[1]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + if (height == 2) return; + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[2]; + y += 4; + } while (y < height); + } else if (num_taps == 4) { + srcs[4] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + + int y = 0; + do { + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[4] = Load2<0>(src, srcs[4]); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[4], 2); + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + srcs[2] = vext_u8(srcs[0], srcs[4], 4); + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[0], srcs[4], 6); + + // This uses srcs[0]..srcs[3]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + if (height == 2) return; + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + y += 4; + } while (y < height); + } else if (num_taps == 6) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. 
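+    // With |height| > 4 guaranteed, the loop below can store four rows per
+    // iteration without the |height| == 2 early exit used by the 2 and 4 tap
+    // paths above.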
+ assert(height > 4); + srcs[8] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[4] = Load2(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[4], 2); + + int y = 0; + do { + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + srcs[2] = vext_u8(srcs[0], srcs[4], 4); + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[0], srcs[4], 6); + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[8], 2); + + // This uses srcs[0]..srcs[5]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + y += 4; + } while (y < height); + } else if (num_taps == 8) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. + assert(height > 4); + srcs[8] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[4] = Load2(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[4], 2); + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + srcs[2] = vext_u8(srcs[0], srcs[4], 4); + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[0], srcs[4], 6); + + int y = 0; + do { + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[8], 2); + srcs[8] = Load2<1>(src, srcs[8]); + src += src_stride; + srcs[6] = vext_u8(srcs[4], srcs[8], 4); + srcs[8] = Load2<2>(src, srcs[8]); + src += src_stride; + srcs[7] = vext_u8(srcs[4], srcs[8], 6); + + // This uses srcs[0]..srcs[7]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + y += 4; + } while (y < height); + } +} + +// This function is a simplified version of Convolve2D_C. +// It is called when it is single prediction mode, where only vertical +// filtering is required. +// The output is the single prediction of the block, clipped to valid pixel +// range. 
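+// The taps are loaded from kAbsHalfSubPixelFilters, i.e. as magnitudes; the
+// sign pattern of each filter shape is applied when the taps are summed,
+// with |negative_outside_taps| marking the filter index 1 shapes whose outer
+// taps are negative.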
+void ConvolveVertical_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int vertical_filter_index, + const int /*horizontal_filter_id*/, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + uint8x8_t taps[8]; + for (int k = 0; k < kSubPixelTaps; ++k) { + taps[k] = + vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]); + } + + if (filter_index == 0) { // 6 tap. + if (width == 2) { + FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else if (width == 4) { + FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else { + FilterVertical<0>(src, src_stride, dest, dest_stride, width, height, + taps + 1); + } + } else if ((filter_index == 1) & ((vertical_filter_id == 1) | + (vertical_filter_id == 15))) { // 5 tap. + if (width == 2) { + FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else if (width == 4) { + FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else { + FilterVertical<1>(src, src_stride, dest, dest_stride, width, height, + taps + 1); + } + } else if ((filter_index == 1) & + ((vertical_filter_id == 7) | (vertical_filter_id == 8) | + (vertical_filter_id == 9))) { // 6 tap with weird negative taps. + if (width == 2) { + FilterVertical2xH<1, + /*negative_outside_taps=*/true>( + src, src_stride, dest, dest_stride, height, taps + 1); + } else if (width == 4) { + FilterVertical4xH<1, /*is_compound=*/false, + /*negative_outside_taps=*/true>( + src, src_stride, dest, dest_stride, height, taps + 1); + } else { + FilterVertical<1, /*is_compound=*/false, /*negative_outside_taps=*/true>( + src, src_stride, dest, dest_stride, width, height, taps + 1); + } + } else if (filter_index == 2) { // 8 tap. + if (width == 2) { + FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps); + } else if (width == 4) { + FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps); + } else { + FilterVertical<2>(src, src_stride, dest, dest_stride, width, height, + taps); + } + } else if (filter_index == 3) { // 2 tap. + if (width == 2) { + FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height, + taps + 3); + } else if (width == 4) { + FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height, + taps + 3); + } else { + FilterVertical<3>(src, src_stride, dest, dest_stride, width, height, + taps + 3); + } + } else if (filter_index == 4) { // 4 tap. + // Outside taps are negative. + if (width == 2) { + FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else if (width == 4) { + FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else { + FilterVertical<4>(src, src_stride, dest, dest_stride, width, height, + taps + 2); + } + } else { + // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed + // below map to 4 tap filters. 
+ assert(filter_index == 5 || + (filter_index == 1 && + (vertical_filter_id == 2 || vertical_filter_id == 3 || + vertical_filter_id == 4 || vertical_filter_id == 5 || + vertical_filter_id == 6 || vertical_filter_id == 10 || + vertical_filter_id == 11 || vertical_filter_id == 12 || + vertical_filter_id == 13 || vertical_filter_id == 14))); + // According to GetNumTapsInFilter() this has 6 taps but here we are + // treating it as though it has 4. + if (filter_index == 1) src += src_stride; + if (width == 2) { + FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else if (width == 4) { + FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else { + FilterVertical<5>(src, src_stride, dest, dest_stride, width, height, + taps + 2); + } + } +} + +void ConvolveCompoundCopy_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const auto* src = static_cast<const uint8_t*>(reference); + const ptrdiff_t src_stride = reference_stride; + auto* dest = static_cast<uint16_t*>(prediction); + constexpr int final_shift = + kInterRoundBitsVertical - kInterRoundBitsCompoundVertical; + + if (width >= 16) { + int y = 0; + do { + int x = 0; + do { + const uint8x16_t v_src = vld1q_u8(&src[x]); + const uint16x8_t v_dest_lo = + vshll_n_u8(vget_low_u8(v_src), final_shift); + const uint16x8_t v_dest_hi = + vshll_n_u8(vget_high_u8(v_src), final_shift); + vst1q_u16(&dest[x], v_dest_lo); + x += 8; + vst1q_u16(&dest[x], v_dest_hi); + x += 8; + } while (x < width); + src += src_stride; + dest += width; + } while (++y < height); + } else if (width == 8) { + int y = 0; + do { + const uint8x8_t v_src = vld1_u8(&src[0]); + const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift); + vst1q_u16(&dest[0], v_dest); + src += src_stride; + dest += width; + } while (++y < height); + } else { /* width == 4 */ + uint8x8_t v_src = vdup_n_u8(0); + + int y = 0; + do { + v_src = Load4<0>(&src[0], v_src); + src += src_stride; + v_src = Load4<1>(&src[0], v_src); + src += src_stride; + const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift); + vst1q_u16(&dest[0], v_dest); + dest += 4 << 1; + y += 2; + } while (y < height); + } +} + +void ConvolveCompoundVertical_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int vertical_filter_index, + const int /*horizontal_filter_id*/, const int vertical_filter_id, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint16_t*>(prediction); + assert(vertical_filter_id != 0); + + uint8x8_t taps[8]; + for (int k = 0; k < kSubPixelTaps; ++k) { + taps[k] = + vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]); + } + + if (filter_index == 0) { // 6 tap. 
+ if (width == 4) { + FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 1); + } else { + FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 1); + } + } else if ((filter_index == 1) & ((vertical_filter_id == 1) | + (vertical_filter_id == 15))) { // 5 tap. + if (width == 4) { + FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 1); + } else { + FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 1); + } + } else if ((filter_index == 1) & + ((vertical_filter_id == 7) | (vertical_filter_id == 8) | + (vertical_filter_id == 9))) { // 6 tap with weird negative taps. + if (width == 4) { + FilterVertical4xH<1, /*is_compound=*/true, + /*negative_outside_taps=*/true>(src, src_stride, dest, + 4, height, taps + 1); + } else { + FilterVertical<1, /*is_compound=*/true, /*negative_outside_taps=*/true>( + src, src_stride, dest, width, width, height, taps + 1); + } + } else if (filter_index == 2) { // 8 tap. + if (width == 4) { + FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); + } else { + FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps); + } + } else if (filter_index == 3) { // 2 tap. + if (width == 4) { + FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 3); + } else { + FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 3); + } + } else if (filter_index == 4) { // 4 tap. + if (width == 4) { + FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 2); + } else { + FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 2); + } + } else { + // 4 tap. When |filter_index| == 1 the |filter_id| values listed below map + // to 4 tap filters. + assert(filter_index == 5 || + (filter_index == 1 && + (vertical_filter_id == 2 || vertical_filter_id == 3 || + vertical_filter_id == 4 || vertical_filter_id == 5 || + vertical_filter_id == 6 || vertical_filter_id == 10 || + vertical_filter_id == 11 || vertical_filter_id == 12 || + vertical_filter_id == 13 || vertical_filter_id == 14))); + // According to GetNumTapsInFilter() this has 6 taps but here we are + // treating it as though it has 4. 
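+    // |src| was positioned two rows above the block for the nominal 6 tap
+    // filter; stepping down one row re-centers it on the 4 taps in use.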
+ if (filter_index == 1) src += src_stride; + if (width == 4) { + FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 2); + } else { + FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 2); + } + } +} + +void ConvolveCompoundHorizontal_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int horizontal_filter_index, const int /*vertical_filter_index*/, + const int horizontal_filter_id, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint16_t*>(prediction); + + DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>( + src, reference_stride, dest, width, width, height, horizontal_filter_id, + filter_index); +} + +void ConvolveCompound2D_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [4, 5]. + // Similarly for height. + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + const int intermediate_height = height + vertical_taps - 1; + const ptrdiff_t src_stride = reference_stride; + const auto* const src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - + kHorizontalOffset; + + DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>( + src, src_stride, intermediate_result, width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. 
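+  // The second pass filters the 16-bit intermediate rows and writes the
+  // compound prediction, also 16-bit, with a stride equal to |width|;
+  // width 4 blocks take the dedicated 4xH path.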
+ auto* dest = static_cast<uint16_t*>(prediction); + assert(vertical_filter_id != 0); + + const ptrdiff_t dest_stride = width; + const int16x8_t taps = vmovl_s8( + vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id])); + + if (vertical_taps == 8) { + if (width == 4) { + Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<8, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 6) { + if (width == 4) { + Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<6, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 4) { + if (width == 4) { + Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<4, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else { // |vertical_taps| == 2 + if (width == 4) { + Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<2, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } +} + +inline void HalfAddHorizontal(const uint8_t* src, uint8_t* dst) { + const uint8x16_t left = vld1q_u8(src); + const uint8x16_t right = vld1q_u8(src + 1); + vst1q_u8(dst, vrhaddq_u8(left, right)); +} + +template <int width> +inline void IntraBlockCopyHorizontal(const uint8_t* src, + const ptrdiff_t src_stride, + const int height, uint8_t* dst, + const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 16); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16); + + int y = 0; + do { + HalfAddHorizontal(src, dst); + if (width >= 32) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + if (width >= 64) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + if (width == 128) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + } + } + } + src += src_remainder_stride; + dst += dst_remainder_stride; + } while (++y < height); +} + +void ConvolveIntraBlockCopyHorizontal_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*subpixel_x*/, const int /*subpixel_y*/, const int width, + const int height, void* const prediction, const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + + if (width == 128) { + IntraBlockCopyHorizontal<128>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 64) { + IntraBlockCopyHorizontal<64>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 32) { + IntraBlockCopyHorizontal<32>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 16) { + IntraBlockCopyHorizontal<16>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 8) { + int y = 0; + do { + const uint8x8_t left = vld1_u8(src); + const uint8x8_t right = vld1_u8(src + 1); + vst1_u8(dest, 
vrhadd_u8(left, right)); + + src += reference_stride; + dest += pred_stride; + } while (++y < height); + } else if (width == 4) { + uint8x8_t left = vdup_n_u8(0); + uint8x8_t right = vdup_n_u8(0); + int y = 0; + do { + left = Load4<0>(src, left); + right = Load4<0>(src + 1, right); + src += reference_stride; + left = Load4<1>(src, left); + right = Load4<1>(src + 1, right); + src += reference_stride; + + const uint8x8_t result = vrhadd_u8(left, right); + + StoreLo4(dest, result); + dest += pred_stride; + StoreHi4(dest, result); + dest += pred_stride; + y += 2; + } while (y < height); + } else { + assert(width == 2); + uint8x8_t left = vdup_n_u8(0); + uint8x8_t right = vdup_n_u8(0); + int y = 0; + do { + left = Load2<0>(src, left); + right = Load2<0>(src + 1, right); + src += reference_stride; + left = Load2<1>(src, left); + right = Load2<1>(src + 1, right); + src += reference_stride; + + const uint8x8_t result = vrhadd_u8(left, right); + + Store2<0>(dest, result); + dest += pred_stride; + Store2<1>(dest, result); + dest += pred_stride; + y += 2; + } while (y < height); + } +} + +template <int width> +inline void IntraBlockCopyVertical(const uint8_t* src, + const ptrdiff_t src_stride, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 16); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16); + uint8x16_t row[8], below[8]; + + row[0] = vld1q_u8(src); + if (width >= 32) { + src += 16; + row[1] = vld1q_u8(src); + if (width >= 64) { + src += 16; + row[2] = vld1q_u8(src); + src += 16; + row[3] = vld1q_u8(src); + if (width == 128) { + src += 16; + row[4] = vld1q_u8(src); + src += 16; + row[5] = vld1q_u8(src); + src += 16; + row[6] = vld1q_u8(src); + src += 16; + row[7] = vld1q_u8(src); + } + } + } + src += src_remainder_stride; + + int y = 0; + do { + below[0] = vld1q_u8(src); + if (width >= 32) { + src += 16; + below[1] = vld1q_u8(src); + if (width >= 64) { + src += 16; + below[2] = vld1q_u8(src); + src += 16; + below[3] = vld1q_u8(src); + if (width == 128) { + src += 16; + below[4] = vld1q_u8(src); + src += 16; + below[5] = vld1q_u8(src); + src += 16; + below[6] = vld1q_u8(src); + src += 16; + below[7] = vld1q_u8(src); + } + } + } + src += src_remainder_stride; + + vst1q_u8(dst, vrhaddq_u8(row[0], below[0])); + row[0] = below[0]; + if (width >= 32) { + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[1], below[1])); + row[1] = below[1]; + if (width >= 64) { + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[2], below[2])); + row[2] = below[2]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[3], below[3])); + row[3] = below[3]; + if (width >= 128) { + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[4], below[4])); + row[4] = below[4]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[5], below[5])); + row[5] = below[5]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[6], below[6])); + row[6] = below[6]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[7], below[7])); + row[7] = below[7]; + } + } + } + dst += dst_remainder_stride; + } while (++y < height); +} + +void ConvolveIntraBlockCopyVertical_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* const prediction, + const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + + if (width == 128) { + 
IntraBlockCopyVertical<128>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 64) { + IntraBlockCopyVertical<64>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 32) { + IntraBlockCopyVertical<32>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 16) { + IntraBlockCopyVertical<16>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 8) { + uint8x8_t row, below; + row = vld1_u8(src); + src += reference_stride; + + int y = 0; + do { + below = vld1_u8(src); + src += reference_stride; + + vst1_u8(dest, vrhadd_u8(row, below)); + dest += pred_stride; + + row = below; + } while (++y < height); + } else if (width == 4) { + uint8x8_t row = Load4(src); + uint8x8_t below = vdup_n_u8(0); + src += reference_stride; + + int y = 0; + do { + below = Load4<0>(src, below); + src += reference_stride; + + StoreLo4(dest, vrhadd_u8(row, below)); + dest += pred_stride; + + row = below; + } while (++y < height); + } else { + assert(width == 2); + uint8x8_t row = Load2(src); + uint8x8_t below = vdup_n_u8(0); + src += reference_stride; + + int y = 0; + do { + below = Load2<0>(src, below); + src += reference_stride; + + Store2<0>(dest, vrhadd_u8(row, below)); + dest += pred_stride; + + row = below; + } while (++y < height); + } +} + +template <int width> +inline void IntraBlockCopy2D(const uint8_t* src, const ptrdiff_t src_stride, + const int height, uint8_t* dst, + const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 8); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8); + uint16x8_t row[16]; + row[0] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width >= 16) { + src += 8; + row[1] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width >= 32) { + src += 8; + row[2] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[3] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width >= 64) { + src += 8; + row[4] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[5] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[6] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[7] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width == 128) { + src += 8; + row[8] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[9] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[10] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[11] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[12] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[13] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[14] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[15] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + } + } + } + } + src += src_remainder_stride; + + int y = 0; + do { + const uint16x8_t below_0 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[0], below_0), 2)); + row[0] = below_0; + if (width >= 16) { + src += 8; + dst += 8; + + const uint16x8_t below_1 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[1], below_1), 2)); + row[1] = below_1; + if (width >= 32) { + src += 8; + dst += 8; + + const uint16x8_t below_2 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[2], below_2), 2)); + row[2] = below_2; + src += 8; + dst += 8; + + const uint16x8_t below_3 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[3], below_3), 2)); + row[3] = 
below_3; + if (width >= 64) { + src += 8; + dst += 8; + + const uint16x8_t below_4 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[4], below_4), 2)); + row[4] = below_4; + src += 8; + dst += 8; + + const uint16x8_t below_5 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[5], below_5), 2)); + row[5] = below_5; + src += 8; + dst += 8; + + const uint16x8_t below_6 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[6], below_6), 2)); + row[6] = below_6; + src += 8; + dst += 8; + + const uint16x8_t below_7 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[7], below_7), 2)); + row[7] = below_7; + if (width == 128) { + src += 8; + dst += 8; + + const uint16x8_t below_8 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[8], below_8), 2)); + row[8] = below_8; + src += 8; + dst += 8; + + const uint16x8_t below_9 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[9], below_9), 2)); + row[9] = below_9; + src += 8; + dst += 8; + + const uint16x8_t below_10 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[10], below_10), 2)); + row[10] = below_10; + src += 8; + dst += 8; + + const uint16x8_t below_11 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[11], below_11), 2)); + row[11] = below_11; + src += 8; + dst += 8; + + const uint16x8_t below_12 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[12], below_12), 2)); + row[12] = below_12; + src += 8; + dst += 8; + + const uint16x8_t below_13 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[13], below_13), 2)); + row[13] = below_13; + src += 8; + dst += 8; + + const uint16x8_t below_14 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[14], below_14), 2)); + row[14] = below_14; + src += 8; + dst += 8; + + const uint16x8_t below_15 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[15], below_15), 2)); + row[15] = below_15; + } + } + } + } + src += src_remainder_stride; + dst += dst_remainder_stride; + } while (++y < height); +} + +void ConvolveIntraBlockCopy2D_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* const prediction, + const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + // Note: allow vertical access to height + 1. Because this function is only + // for u/v plane of intra block copy, such access is guaranteed to be within + // the prediction block. 
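+  // Each output pixel is the 2x2 average (tl + tr + bl + br + 2) >> 2; the
+  // widened horizontal pair sums are kept per row so they can be reused as
+  // the top row when the row below is processed.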
+ + if (width == 128) { + IntraBlockCopy2D<128>(src, reference_stride, height, dest, pred_stride); + } else if (width == 64) { + IntraBlockCopy2D<64>(src, reference_stride, height, dest, pred_stride); + } else if (width == 32) { + IntraBlockCopy2D<32>(src, reference_stride, height, dest, pred_stride); + } else if (width == 16) { + IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride); + } else if (width == 8) { + IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride); + } else if (width == 4) { + uint8x8_t left = Load4(src); + uint8x8_t right = Load4(src + 1); + src += reference_stride; + + uint16x4_t row = vget_low_u16(vaddl_u8(left, right)); + + int y = 0; + do { + left = Load4<0>(src, left); + right = Load4<0>(src + 1, right); + src += reference_stride; + left = Load4<1>(src, left); + right = Load4<1>(src + 1, right); + src += reference_stride; + + const uint16x8_t below = vaddl_u8(left, right); + + const uint8x8_t result = vrshrn_n_u16( + vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2); + StoreLo4(dest, result); + dest += pred_stride; + StoreHi4(dest, result); + dest += pred_stride; + + row = vget_high_u16(below); + y += 2; + } while (y < height); + } else { + uint8x8_t left = Load2(src); + uint8x8_t right = Load2(src + 1); + src += reference_stride; + + uint16x4_t row = vget_low_u16(vaddl_u8(left, right)); + + int y = 0; + do { + left = Load2<0>(src, left); + right = Load2<0>(src + 1, right); + src += reference_stride; + left = Load2<2>(src, left); + right = Load2<2>(src + 1, right); + src += reference_stride; + + const uint16x8_t below = vaddl_u8(left, right); + + const uint8x8_t result = vrshrn_n_u16( + vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2); + Store2<0>(dest, result); + dest += pred_stride; + Store2<2>(dest, result); + dest += pred_stride; + + row = vget_high_u16(below); + y += 2; + } while (y < height); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON; + dsp->convolve[0][0][1][0] = ConvolveVertical_NEON; + dsp->convolve[0][0][1][1] = Convolve2D_NEON; + + dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON; + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON; + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON; + dsp->convolve[0][1][1][1] = ConvolveCompound2D_NEON; + + dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_NEON; + dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_NEON; + dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_NEON; + + dsp->convolve_scale[0] = ConvolveScale2D_NEON<false>; + dsp->convolve_scale[1] = ConvolveScale2D_NEON<true>; +} + +} // namespace +} // namespace low_bitdepth + +void ConvolveInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void ConvolveInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/convolve_neon.h b/src/dsp/arm/convolve_neon.h new file mode 100644 index 0000000..948ef4d --- /dev/null +++ b/src/dsp/arm/convolve_neon.h @@ -0,0 +1,50 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::convolve. This function is not thread-safe. +void ConvolveInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyHorizontal LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy2D LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_ConvolveScale2D LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_ diff --git a/src/dsp/arm/distance_weighted_blend_neon.cc b/src/dsp/arm/distance_weighted_blend_neon.cc new file mode 100644 index 0000000..04952ab --- /dev/null +++ b/src/dsp/arm/distance_weighted_blend_neon.cc @@ -0,0 +1,203 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/distance_weighted_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kInterPostRoundBit = 4; + +inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0, + const int16x8_t pred1, + const int16x4_t weights[2]) { + // TODO(https://issuetracker.google.com/issues/150325685): Investigate range. 
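+  // The weighted sum is accumulated in 32 bits; vqrshrn_n_s32 then rounds,
+  // shifts down by kInterPostRoundBit + 4 (the two weights sum to 16, hence
+  // the extra 4 bits) and saturates the result back to 16 bits.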
+ const int32x4_t wpred0_lo = vmull_s16(weights[0], vget_low_s16(pred0)); + const int32x4_t wpred0_hi = vmull_s16(weights[0], vget_high_s16(pred0)); + const int32x4_t blended_lo = + vmlal_s16(wpred0_lo, weights[1], vget_low_s16(pred1)); + const int32x4_t blended_hi = + vmlal_s16(wpred0_hi, weights[1], vget_high_s16(pred1)); + + return vcombine_s16(vqrshrn_n_s32(blended_lo, kInterPostRoundBit + 4), + vqrshrn_n_s32(blended_hi, kInterPostRoundBit + 4)); +} + +template <int width, int height> +inline void DistanceWeightedBlendSmall_NEON(const int16_t* prediction_0, + const int16_t* prediction_1, + const int16x4_t weights[2], + void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + constexpr int step = 16 / width; + + for (int y = 0; y < height; y += step) { + const int16x8_t src_00 = vld1q_s16(prediction_0); + const int16x8_t src_10 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights); + + const int16x8_t src_01 = vld1q_s16(prediction_0); + const int16x8_t src_11 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights); + + const uint8x8_t result0 = vqmovun_s16(res0); + const uint8x8_t result1 = vqmovun_s16(res1); + if (width == 4) { + StoreLo4(dst, result0); + dst += dest_stride; + StoreHi4(dst, result0); + dst += dest_stride; + StoreLo4(dst, result1); + dst += dest_stride; + StoreHi4(dst, result1); + dst += dest_stride; + } else { + assert(width == 8); + vst1_u8(dst, result0); + dst += dest_stride; + vst1_u8(dst, result1); + dst += dest_stride; + } + } +} + +inline void DistanceWeightedBlendLarge_NEON(const int16_t* prediction_0, + const int16_t* prediction_1, + const int16x4_t weights[2], + const int width, const int height, + void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + + int y = height; + do { + int x = 0; + do { + const int16x8_t src0_lo = vld1q_s16(prediction_0 + x); + const int16x8_t src1_lo = vld1q_s16(prediction_1 + x); + const int16x8_t res_lo = + ComputeWeightedAverage8(src0_lo, src1_lo, weights); + + const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8); + const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8); + const int16x8_t res_hi = + ComputeWeightedAverage8(src0_hi, src1_hi, weights); + + const uint8x16_t result = + vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi)); + vst1q_u8(dst + x, result); + x += 16; + } while (x < width); + dst += dest_stride; + prediction_0 += width; + prediction_1 += width; + } while (--y != 0); +} + +inline void DistanceWeightedBlend_NEON(const void* prediction_0, + const void* prediction_1, + const uint8_t weight_0, + const uint8_t weight_1, const int width, + const int height, void* const dest, + const ptrdiff_t dest_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)}; + // TODO(johannkoenig): Investigate the branching. May be fine to call with a + // variable height. 
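+  // The small-block helper is specialized on both width and height, so its
+  // loop has a compile-time trip count; a 4-wide block packs two rows into
+  // each 8-lane vector, an 8-wide block packs one.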
+ if (width == 4) { + if (height == 4) { + DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest, + dest_stride); + } else if (height == 8) { + DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest, + dest_stride); + } else { + assert(height == 16); + DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest, + dest_stride); + } + return; + } + + if (width == 8) { + switch (height) { + case 4: + DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest, + dest_stride); + return; + case 8: + DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest, + dest_stride); + return; + case 16: + DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest, + dest_stride); + return; + default: + assert(height == 32); + DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest, + dest_stride); + + return; + } + } + + DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest, + dest_stride); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->distance_weighted_blend = DistanceWeightedBlend_NEON; +} + +} // namespace + +void DistanceWeightedBlendInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void DistanceWeightedBlendInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/distance_weighted_blend_neon.h b/src/dsp/arm/distance_weighted_blend_neon.h new file mode 100644 index 0000000..4d8824c --- /dev/null +++ b/src/dsp/arm/distance_weighted_blend_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::distance_weighted_blend. This function is not thread-safe. +void DistanceWeightedBlendInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +// If NEON is enabled signal the NEON implementation should be used instead of +// normal C. +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_ diff --git a/src/dsp/arm/film_grain_neon.cc b/src/dsp/arm/film_grain_neon.cc new file mode 100644 index 0000000..2612466 --- /dev/null +++ b/src/dsp/arm/film_grain_neon.cc @@ -0,0 +1,1188 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/film_grain.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <new> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/arm/film_grain_neon.h" +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/film_grain_common.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" +#include "src/utils/logging.h" + +namespace libgav1 { +namespace dsp { +namespace film_grain { +namespace { + +// These functions are overloaded for both possible sizes in order to simplify +// loading and storing to and from intermediate value types from within a +// template function. +inline int16x8_t GetSignedSource8(const int8_t* src) { + return vmovl_s8(vld1_s8(src)); +} + +inline int16x8_t GetSignedSource8(const uint8_t* src) { + return ZeroExtend(vld1_u8(src)); +} + +inline void StoreUnsigned8(uint8_t* dest, const uint16x8_t data) { + vst1_u8(dest, vmovn_u16(data)); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +inline int16x8_t GetSignedSource8(const int16_t* src) { return vld1q_s16(src); } + +inline int16x8_t GetSignedSource8(const uint16_t* src) { + return vreinterpretq_s16_u16(vld1q_u16(src)); +} + +inline void StoreUnsigned8(uint16_t* dest, const uint16x8_t data) { + vst1q_u16(dest, data); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +// Each element in |sum| represents one destination value's running +// autoregression formula. The fixed source values in |grain_lo| and |grain_hi| +// allow for a sliding window in successive calls to this function. +template <int position_offset> +inline int32x4x2_t AccumulateWeightedGrain(const int16x8_t grain_lo, + const int16x8_t grain_hi, + int16_t coeff, int32x4x2_t sum) { + const int16x8_t grain = vextq_s16(grain_lo, grain_hi, position_offset); + sum.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(grain), coeff); + sum.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(grain), coeff); + return sum; +} + +// Because the autoregressive filter requires the output of each pixel to +// compute pixels that come after in the row, we have to finish the calculations +// one at a time. 
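+// Conceptually, each filtered sample becomes
+//   grain[x] = Clip3(grain[x] + RightShiftWithRounding(sum, shift),
+//                    GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>())
+// where |sum| accumulates coefficient * neighbor over the rows above (done
+// with vectors in AccumulateWeightedGrain) and over the samples to the left
+// in the current row (done here, one lane at a time), which is why the lanes
+// must be finished in order.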
+template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegression(int8_t* grain_cursor, int32x4x2_t sum, + const int8_t* coeffs, int pos, int shift) { + int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3); + + for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) { + result += grain_cursor[lane + delta_col] * coeffs[pos]; + ++pos; + } + grain_cursor[lane] = + Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift), + GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>()); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegression(int16_t* grain_cursor, int32x4x2_t sum, + const int8_t* coeffs, int pos, int shift) { + int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3); + + for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) { + result += grain_cursor[lane + delta_col] * coeffs[pos]; + ++pos; + } + grain_cursor[lane] = + Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift), + GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>()); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +// Because the autoregressive filter requires the output of each pixel to +// compute pixels that come after in the row, we have to finish the calculations +// one at a time. +template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegressionChroma(int8_t* u_grain_cursor, + int8_t* v_grain_cursor, + int32x4x2_t sum_u, int32x4x2_t sum_v, + const int8_t* coeffs_u, + const int8_t* coeffs_v, int pos, + int shift) { + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + u_grain_cursor, sum_u, coeffs_u, pos, shift); + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + v_grain_cursor, sum_v, coeffs_v, pos, shift); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegressionChroma(int16_t* u_grain_cursor, + int16_t* v_grain_cursor, + int32x4x2_t sum_u, int32x4x2_t sum_v, + const int8_t* coeffs_u, + const int8_t* coeffs_v, int pos, + int shift) { + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + u_grain_cursor, sum_u, coeffs_u, pos, shift); + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + v_grain_cursor, sum_v, coeffs_v, pos, shift); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +inline void SetZero(int32x4x2_t* v) { + v->val[0] = vdupq_n_s32(0); + v->val[1] = vdupq_n_s32(0); +} + +// Computes subsampled luma for use with chroma, by averaging in the x direction +// or y direction when applicable. +int16x8_t GetSubsampledLuma(const int8_t* const luma, int subsampling_x, + int subsampling_y, ptrdiff_t stride) { + if (subsampling_y != 0) { + assert(subsampling_x != 0); + const int8x16_t src0 = vld1q_s8(luma); + const int8x16_t src1 = vld1q_s8(luma + stride); + const int16x8_t ret0 = vcombine_s16(vpaddl_s8(vget_low_s8(src0)), + vpaddl_s8(vget_high_s8(src0))); + const int16x8_t ret1 = vcombine_s16(vpaddl_s8(vget_low_s8(src1)), + vpaddl_s8(vget_high_s8(src1))); + return vrshrq_n_s16(vaddq_s16(ret0, ret1), 2); + } + if (subsampling_x != 0) { + const int8x16_t src = vld1q_s8(luma); + return vrshrq_n_s16( + vcombine_s16(vpaddl_s8(vget_low_s8(src)), vpaddl_s8(vget_high_s8(src))), + 1); + } + return vmovl_s8(vld1_s8(luma)); +} + +// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed. 
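+// For example, with |subsampling_x| == 1 an input of {l0, l1, l2, l3, ...}
+// produces {(l0 + l1 + 1) >> 1, (l2 + l3 + 1) >> 1, ...}, i.e. rounded
+// pairwise averages; with no subsampling the pixels are simply widened.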
+inline uint16x8_t GetAverageLuma(const uint8_t* const luma, int subsampling_x) { + if (subsampling_x != 0) { + const uint8x16_t src = vld1q_u8(luma); + return vrshrq_n_u16(vpaddlq_u8(src), 1); + } + return vmovl_u8(vld1_u8(luma)); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +// Computes subsampled luma for use with chroma, by averaging in the x direction +// or y direction when applicable. +int16x8_t GetSubsampledLuma(const int16_t* const luma, int subsampling_x, + int subsampling_y, ptrdiff_t stride) { + if (subsampling_y != 0) { + assert(subsampling_x != 0); + int16x8_t src0_lo = vld1q_s16(luma); + int16x8_t src0_hi = vld1q_s16(luma + 8); + const int16x8_t src1_lo = vld1q_s16(luma + stride); + const int16x8_t src1_hi = vld1q_s16(luma + stride + 8); + const int16x8_t src0 = + vcombine_s16(vpadd_s16(vget_low_s16(src0_lo), vget_high_s16(src0_lo)), + vpadd_s16(vget_low_s16(src0_hi), vget_high_s16(src0_hi))); + const int16x8_t src1 = + vcombine_s16(vpadd_s16(vget_low_s16(src1_lo), vget_high_s16(src1_lo)), + vpadd_s16(vget_low_s16(src1_hi), vget_high_s16(src1_hi))); + return vrshrq_n_s16(vaddq_s16(src0, src1), 2); + } + if (subsampling_x != 0) { + const int16x8_t src_lo = vld1q_s16(luma); + const int16x8_t src_hi = vld1q_s16(luma + 8); + const int16x8_t ret = + vcombine_s16(vpadd_s16(vget_low_s16(src_lo), vget_high_s16(src_lo)), + vpadd_s16(vget_low_s16(src_hi), vget_high_s16(src_hi))); + return vrshrq_n_s16(ret, 1); + } + return vld1q_s16(luma); +} + +// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed. +inline uint16x8_t GetAverageLuma(const uint16_t* const luma, + int subsampling_x) { + if (subsampling_x != 0) { + const uint16x8x2_t src = vld2q_u16(luma); + return vrhaddq_u16(src.val[0], src.val[1]); + } + return vld1q_u16(luma); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +template <int bitdepth, typename GrainType, int auto_regression_coeff_lag, + bool use_luma> +void ApplyAutoRegressiveFilterToChromaGrains_NEON(const FilmGrainParams& params, + const void* luma_grain_buffer, + int subsampling_x, + int subsampling_y, + void* u_grain_buffer, + void* v_grain_buffer) { + static_assert(auto_regression_coeff_lag <= 3, "Invalid autoregression lag."); + const auto* luma_grain = static_cast<const GrainType*>(luma_grain_buffer); + auto* u_grain = static_cast<GrainType*>(u_grain_buffer); + auto* v_grain = static_cast<GrainType*>(v_grain_buffer); + const int auto_regression_shift = params.auto_regression_shift; + const int chroma_width = + (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth; + const int chroma_height = + (subsampling_y == 0) ? kMaxChromaHeight : kMinChromaHeight; + // When |chroma_width| == 44, we write 8 at a time from x in [3, 34], + // leaving [35, 40] to write at the end. + const int chroma_width_remainder = + (chroma_width - 2 * kAutoRegressionBorder) & 7; + + int y = kAutoRegressionBorder; + luma_grain += kLumaWidth * y; + u_grain += chroma_width * y; + v_grain += chroma_width * y; + do { + // Each row is computed 8 values at a time in the following loop. At the + // end of the loop, 4 values remain to write. They are given a special + // reduced iteration at the end. 
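+    // (For chroma the remainder is |chroma_width_remainder| values: 6 when
+    // |chroma_width| == 44 and 4 otherwise. The reduced iteration writes 4
+    // results unconditionally and 2 more only when the remainder is 6.)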
+ int x = kAutoRegressionBorder; + int luma_x = kAutoRegressionBorder; + do { + int pos = 0; + int32x4x2_t sum_u; + int32x4x2_t sum_v; + SetZero(&sum_u); + SetZero(&sum_v); + + if (auto_regression_coeff_lag > 0) { + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + // These loads may overflow to the next row, but they are never called + // on the final row of a grain block. Therefore, they will never + // exceed the block boundaries. + // Note: this could be slightly optimized to a single load in 8bpp, + // but requires making a special first iteration and accumulate + // function that takes an int8x16_t. + const int16x8_t u_grain_lo = + GetSignedSource8(u_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag); + const int16x8_t u_grain_hi = + GetSignedSource8(u_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); + const int16x8_t v_grain_lo = + GetSignedSource8(v_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag); + const int16x8_t v_grain_hi = + GetSignedSource8(v_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); +#define ACCUMULATE_WEIGHTED_GRAIN(offset) \ + sum_u = AccumulateWeightedGrain<offset>( \ + u_grain_lo, u_grain_hi, params.auto_regression_coeff_u[pos], sum_u); \ + sum_v = AccumulateWeightedGrain<offset>( \ + v_grain_lo, v_grain_hi, params.auto_regression_coeff_v[pos++], sum_v) + + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + } + + if (use_luma) { + const int16x8_t luma = GetSubsampledLuma( + luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth); + + // Luma samples get the final coefficient in the formula, but are best + // computed all at once before the final row. + const int coeff_u = + params.auto_regression_coeff_u[pos + auto_regression_coeff_lag]; + const int coeff_v = + params.auto_regression_coeff_v[pos + auto_regression_coeff_lag]; + + sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u); + sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u); + sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v); + sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v); + } + // At this point in the filter, the source addresses and destination + // addresses overlap. Because this is an auto-regressive filter, the + // higher lanes cannot be computed without the results of the lower lanes. + // Each call to WriteFinalAutoRegression incorporates preceding values + // on the final row, and writes a single sample. This allows the next + // pixel's value to be computed in the next call. 
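+      // For example, with a nonzero lag the lane 1 write reads u_grain[x] and
+      // v_grain[x], which the lane 0 call has just produced.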
+#define WRITE_AUTO_REGRESSION_RESULT(lane) \ + WriteFinalAutoRegressionChroma<bitdepth, auto_regression_coeff_lag, lane>( \ + u_grain + x, v_grain + x, sum_u, sum_v, params.auto_regression_coeff_u, \ + params.auto_regression_coeff_v, pos, auto_regression_shift) + + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + WRITE_AUTO_REGRESSION_RESULT(4); + WRITE_AUTO_REGRESSION_RESULT(5); + WRITE_AUTO_REGRESSION_RESULT(6); + WRITE_AUTO_REGRESSION_RESULT(7); + + x += 8; + luma_x += 8 << subsampling_x; + } while (x < chroma_width - kAutoRegressionBorder - chroma_width_remainder); + + // This is the "final iteration" of the above loop over width. We fill in + // the remainder of the width, which is less than 8. + int pos = 0; + int32x4x2_t sum_u; + int32x4x2_t sum_v; + SetZero(&sum_u); + SetZero(&sum_v); + + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + // These loads may overflow to the next row, but they are never called on + // the final row of a grain block. Therefore, they will never exceed the + // block boundaries. + const int16x8_t u_grain_lo = GetSignedSource8( + u_grain + x + delta_row * chroma_width - auto_regression_coeff_lag); + const int16x8_t u_grain_hi = + GetSignedSource8(u_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); + const int16x8_t v_grain_lo = GetSignedSource8( + v_grain + x + delta_row * chroma_width - auto_regression_coeff_lag); + const int16x8_t v_grain_hi = + GetSignedSource8(v_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); + + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + + if (use_luma) { + const int16x8_t luma = GetSubsampledLuma( + luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth); + + // Luma samples get the final coefficient in the formula, but are best + // computed all at once before the final row. + const int coeff_u = + params.auto_regression_coeff_u[pos + auto_regression_coeff_lag]; + const int coeff_v = + params.auto_regression_coeff_v[pos + auto_regression_coeff_lag]; + + sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u); + sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u); + sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v); + sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v); + } + + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + if (chroma_width_remainder == 6) { + WRITE_AUTO_REGRESSION_RESULT(4); + WRITE_AUTO_REGRESSION_RESULT(5); + } + + luma_grain += kLumaWidth << subsampling_y; + u_grain += chroma_width; + v_grain += chroma_width; + } while (++y < chroma_height); +#undef ACCUMULATE_WEIGHTED_GRAIN +#undef WRITE_AUTO_REGRESSION_RESULT +} + +// Applies an auto-regressive filter to the white noise in luma_grain. 
+template <int bitdepth, typename GrainType, int auto_regression_coeff_lag> +void ApplyAutoRegressiveFilterToLumaGrain_NEON(const FilmGrainParams& params, + void* luma_grain_buffer) { + static_assert(auto_regression_coeff_lag > 0, ""); + const int8_t* const auto_regression_coeff_y = params.auto_regression_coeff_y; + const uint8_t auto_regression_shift = params.auto_regression_shift; + + int y = kAutoRegressionBorder; + auto* luma_grain = + static_cast<GrainType*>(luma_grain_buffer) + kLumaWidth * y; + do { + // Each row is computed 8 values at a time in the following loop. At the + // end of the loop, 4 values remain to write. They are given a special + // reduced iteration at the end. + int x = kAutoRegressionBorder; + do { + int pos = 0; + int32x4x2_t sum; + SetZero(&sum); + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + // These loads may overflow to the next row, but they are never called + // on the final row of a grain block. Therefore, they will never exceed + // the block boundaries. + const int16x8_t src_grain_lo = + GetSignedSource8(luma_grain + x + delta_row * kLumaWidth - + auto_regression_coeff_lag); + const int16x8_t src_grain_hi = + GetSignedSource8(luma_grain + x + delta_row * kLumaWidth - + auto_regression_coeff_lag + 8); + + // A pictorial representation of the auto-regressive filter for + // various values of params.auto_regression_coeff_lag. The letter 'O' + // represents the current sample. (The filter always operates on the + // current sample with filter coefficient 1.) The letters 'X' + // represent the neighboring samples that the filter operates on, below + // their corresponding "offset" number. + // + // params.auto_regression_coeff_lag == 3: + // 0 1 2 3 4 5 6 + // X X X X X X X + // X X X X X X X + // X X X X X X X + // X X X O + // params.auto_regression_coeff_lag == 2: + // 0 1 2 3 4 + // X X X X X + // X X X X X + // X X O + // params.auto_regression_coeff_lag == 1: + // 0 1 2 + // X X X + // X O + // params.auto_regression_coeff_lag == 0: + // O + // The function relies on the caller to skip the call in the 0 lag + // case. + +#define ACCUMULATE_WEIGHTED_GRAIN(offset) \ + sum = AccumulateWeightedGrain<offset>(src_grain_lo, src_grain_hi, \ + auto_regression_coeff_y[pos++], sum) + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + // At this point in the filter, the source addresses and destination + // addresses overlap. Because this is an auto-regressive filter, the + // higher lanes cannot be computed without the results of the lower lanes. + // Each call to WriteFinalAutoRegression incorporates preceding values + // on the final row, and writes a single sample. This allows the next + // pixel's value to be computed in the next call. 
+#define WRITE_AUTO_REGRESSION_RESULT(lane) \ + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( \ + luma_grain + x, sum, auto_regression_coeff_y, pos, \ + auto_regression_shift) + + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + WRITE_AUTO_REGRESSION_RESULT(4); + WRITE_AUTO_REGRESSION_RESULT(5); + WRITE_AUTO_REGRESSION_RESULT(6); + WRITE_AUTO_REGRESSION_RESULT(7); + x += 8; + // Leave the final four pixels for the special iteration below. + } while (x < kLumaWidth - kAutoRegressionBorder - 4); + + // Final 4 pixels in the row. + int pos = 0; + int32x4x2_t sum; + SetZero(&sum); + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + const int16x8_t src_grain_lo = GetSignedSource8( + luma_grain + x + delta_row * kLumaWidth - auto_regression_coeff_lag); + const int16x8_t src_grain_hi = + GetSignedSource8(luma_grain + x + delta_row * kLumaWidth - + auto_regression_coeff_lag + 8); + + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + // delta_row == 0 + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + luma_grain += kLumaWidth; + } while (++y < kLumaHeight); + +#undef WRITE_AUTO_REGRESSION_RESULT +#undef ACCUMULATE_WEIGHTED_GRAIN +} + +void InitializeScalingLookupTable_NEON( + int num_points, const uint8_t point_value[], const uint8_t point_scaling[], + uint8_t scaling_lut[kScalingLookupTableSize]) { + if (num_points == 0) { + memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize); + return; + } + static_assert(sizeof(scaling_lut[0]) == 1, ""); + memset(scaling_lut, point_scaling[0], point_value[0]); + const uint32x4_t steps = vmovl_u16(vcreate_u16(0x0003000200010000)); + const uint32x4_t offset = vdupq_n_u32(32768); + for (int i = 0; i < num_points - 1; ++i) { + const int delta_y = point_scaling[i + 1] - point_scaling[i]; + const int delta_x = point_value[i + 1] - point_value[i]; + const int delta = delta_y * ((65536 + (delta_x >> 1)) / delta_x); + const int delta4 = delta << 2; + const uint8x8_t base_point = vdup_n_u8(point_scaling[i]); + uint32x4_t upscaled_points0 = vmlaq_n_u32(offset, steps, delta); + const uint32x4_t line_increment4 = vdupq_n_u32(delta4); + // Get the second set of 4 points by adding 4 steps to the first set. + uint32x4_t upscaled_points1 = vaddq_u32(upscaled_points0, line_increment4); + // We obtain the next set of 8 points by adding 8 steps to each of the + // current 8 points. + const uint32x4_t line_increment8 = vshlq_n_u32(line_increment4, 1); + int x = 0; + do { + const uint16x4_t interp_points0 = vshrn_n_u32(upscaled_points0, 16); + const uint16x4_t interp_points1 = vshrn_n_u32(upscaled_points1, 16); + const uint8x8_t interp_points = + vmovn_u16(vcombine_u16(interp_points0, interp_points1)); + // The spec guarantees that the max value of |point_value[i]| + x is 255. + // Writing 8 bytes starting at the final table byte, leaves 7 bytes of + // required padding. 
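+      // Each of the 8 bytes stored below is
+      //   point_scaling[i] + ((32768 + (x + k) * delta) >> 16), k = 0..7,
+      // i.e. a linear interpolation toward point_scaling[i + 1] in 16.16
+      // fixed point.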
+ vst1_u8(&scaling_lut[point_value[i] + x], + vadd_u8(interp_points, base_point)); + upscaled_points0 = vaddq_u32(upscaled_points0, line_increment8); + upscaled_points1 = vaddq_u32(upscaled_points1, line_increment8); + x += 8; + } while (x < delta_x); + } + const uint8_t last_point_value = point_value[num_points - 1]; + memset(&scaling_lut[last_point_value], point_scaling[num_points - 1], + kScalingLookupTableSize - last_point_value); +} + +inline int16x8_t Clip3(const int16x8_t value, const int16x8_t low, + const int16x8_t high) { + const int16x8_t clipped_to_ceiling = vminq_s16(high, value); + return vmaxq_s16(low, clipped_to_ceiling); +} + +template <int bitdepth, typename Pixel> +inline int16x8_t GetScalingFactors( + const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* source) { + int16_t start_vals[8]; + if (bitdepth == 8) { + start_vals[0] = scaling_lut[source[0]]; + start_vals[1] = scaling_lut[source[1]]; + start_vals[2] = scaling_lut[source[2]]; + start_vals[3] = scaling_lut[source[3]]; + start_vals[4] = scaling_lut[source[4]]; + start_vals[5] = scaling_lut[source[5]]; + start_vals[6] = scaling_lut[source[6]]; + start_vals[7] = scaling_lut[source[7]]; + return vld1q_s16(start_vals); + } + int16_t end_vals[8]; + // TODO(petersonab): Precompute this into a larger table for direct lookups. + int index = source[0] >> 2; + start_vals[0] = scaling_lut[index]; + end_vals[0] = scaling_lut[index + 1]; + index = source[1] >> 2; + start_vals[1] = scaling_lut[index]; + end_vals[1] = scaling_lut[index + 1]; + index = source[2] >> 2; + start_vals[2] = scaling_lut[index]; + end_vals[2] = scaling_lut[index + 1]; + index = source[3] >> 2; + start_vals[3] = scaling_lut[index]; + end_vals[3] = scaling_lut[index + 1]; + index = source[4] >> 2; + start_vals[4] = scaling_lut[index]; + end_vals[4] = scaling_lut[index + 1]; + index = source[5] >> 2; + start_vals[5] = scaling_lut[index]; + end_vals[5] = scaling_lut[index + 1]; + index = source[6] >> 2; + start_vals[6] = scaling_lut[index]; + end_vals[6] = scaling_lut[index + 1]; + index = source[7] >> 2; + start_vals[7] = scaling_lut[index]; + end_vals[7] = scaling_lut[index + 1]; + const int16x8_t start = vld1q_s16(start_vals); + const int16x8_t end = vld1q_s16(end_vals); + int16x8_t remainder = GetSignedSource8(source); + remainder = vandq_s16(remainder, vdupq_n_s16(3)); + const int16x8_t delta = vmulq_s16(vsubq_s16(end, start), remainder); + return vaddq_s16(start, vrshrq_n_s16(delta, 2)); +} + +inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling, + const int16x8_t scaling_shift_vect) { + const int16x8_t upscaled_noise = vmulq_s16(noise, scaling); + return vrshlq_s16(upscaled_noise, scaling_shift_vect); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling, + const int32x4_t scaling_shift_vect) { + // TODO(petersonab): Try refactoring scaling lookup table to int16_t and + // upscaling by 7 bits to permit high half multiply. This would eliminate + // the intermediate 32x4 registers. Also write the averaged values directly + // into the table so it doesn't have to be done for every pixel in + // the frame. 
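+  // Like the version above, this computes
+  //   RightShiftWithRounding(noise * scaling, scaling_shift)
+  // per lane; |scaling_shift_vect| holds -scaling_shift so that vrshlq
+  // performs a rounding right shift, but the product is widened to 32 bits
+  // first to protect the sign bit.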
+ const int32x4_t upscaled_noise_lo = + vmull_s16(vget_low_s16(noise), vget_low_s16(scaling)); + const int32x4_t upscaled_noise_hi = + vmull_s16(vget_high_s16(noise), vget_high_s16(scaling)); + const int16x4_t noise_lo = + vmovn_s32(vrshlq_s32(upscaled_noise_lo, scaling_shift_vect)); + const int16x4_t noise_hi = + vmovn_s32(vrshlq_s32(upscaled_noise_hi, scaling_shift_vect)); + return vcombine_s16(noise_lo, noise_hi); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageLuma_NEON( + const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift, + int width, int height, int start_height, + const uint8_t scaling_lut_y[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y, + ptrdiff_t dest_stride_y) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y_row = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + auto* out_y_row = static_cast<Pixel*>(dest_plane_y); + dest_stride_y /= sizeof(Pixel); + const int16x8_t floor = vdupq_n_s16(min_value); + const int16x8_t ceiling = vdupq_n_s16(max_luma); + // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe + // for 16 bit signed integers. In higher bitdepths, however, we have to + // expand to 32 to protect the sign bit. + const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift); +#if LIBGAV1_MAX_BITDEPTH >= 10 + const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + + int y = 0; + do { + int x = 0; + do { + // This operation on the unsigned input is safe in 8bpp because the vector + // is widened before it is reinterpreted. + const int16x8_t orig = GetSignedSource8(&in_y_row[x]); + const int16x8_t scaling = + GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, &in_y_row[x]); + int16x8_t noise = + GetSignedSource8(&(noise_image[kPlaneY][y + start_height][x])); + + if (bitdepth == 8) { + noise = ScaleNoise(noise, scaling, scaling_shift_vect16); + } else { +#if LIBGAV1_MAX_BITDEPTH >= 10 + noise = ScaleNoise(noise, scaling, scaling_shift_vect32); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + } + const int16x8_t combined = vaddq_s16(orig, noise); + // In 8bpp, when params_.clip_to_restricted_range == false, we can replace + // clipping with vqmovun_s16, but it's not likely to be worth copying the + // function for just that case, though the gain would be very small. 
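+      // The value stored below is
+      //   Clip3(orig + RightShiftWithRounding(noise * scaling, scaling_shift),
+      //         min_value, max_luma)
+      // where |scaling| is the (in 10bpp, interpolated) lookup of the source
+      // pixel in |scaling_lut_y|.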
+ StoreUnsigned8(&out_y_row[x], + vreinterpretq_u16_s16(Clip3(combined, floor, ceiling))); + x += 8; + } while (x < width); + in_y_row += source_stride_y; + out_y_row += dest_stride_y; + } while (++y < height); +} + +template <int bitdepth, typename GrainType, typename Pixel> +inline int16x8_t BlendChromaValsWithCfl( + const Pixel* average_luma_buffer, + const uint8_t scaling_lut[kScalingLookupTableSize], + const Pixel* chroma_cursor, const GrainType* noise_image_cursor, + const int16x8_t scaling_shift_vect16, + const int32x4_t scaling_shift_vect32) { + const int16x8_t scaling = + GetScalingFactors<bitdepth, Pixel>(scaling_lut, average_luma_buffer); + const int16x8_t orig = GetSignedSource8(chroma_cursor); + int16x8_t noise = GetSignedSource8(noise_image_cursor); + if (bitdepth == 8) { + noise = ScaleNoise(noise, scaling, scaling_shift_vect16); + } else { + noise = ScaleNoise(noise, scaling, scaling_shift_vect32); + } + return vaddq_s16(orig, noise); +} + +template <int bitdepth, typename GrainType, typename Pixel> +LIBGAV1_ALWAYS_INLINE void BlendChromaPlaneWithCfl_NEON( + const Array2D<GrainType>& noise_image, int min_value, int max_chroma, + int width, int height, int start_height, int subsampling_x, + int subsampling_y, int scaling_shift, + const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* in_y_row, + ptrdiff_t source_stride_y, const Pixel* in_chroma_row, + ptrdiff_t source_stride_chroma, Pixel* out_chroma_row, + ptrdiff_t dest_stride) { + const int16x8_t floor = vdupq_n_s16(min_value); + const int16x8_t ceiling = vdupq_n_s16(max_chroma); + Pixel luma_buffer[16]; + memset(luma_buffer, 0, sizeof(luma_buffer)); + // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe + // for 16 bit signed integers. In higher bitdepths, however, we have to + // expand to 32 to protect the sign bit. + const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift); + const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift); + + const int chroma_height = (height + subsampling_y) >> subsampling_y; + const int chroma_width = (width + subsampling_x) >> subsampling_x; + const int safe_chroma_width = chroma_width & ~7; + + // Writing to this buffer avoids the cost of doing 8 lane lookups in a row + // in GetScalingFactors. + Pixel average_luma_buffer[8]; + assert(start_height % 2 == 0); + start_height >>= subsampling_y; + int y = 0; + do { + int x = 0; + do { + const int luma_x = x << subsampling_x; + // TODO(petersonab): Consider specializing by subsampling_x. In the 444 + // case &in_y_row[x] can be passed to GetScalingFactors directly. + const uint16x8_t average_luma = + GetAverageLuma(&in_y_row[luma_x], subsampling_x); + StoreUnsigned8(average_luma_buffer, average_luma); + + const int16x8_t blended = + BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>( + average_luma_buffer, scaling_lut, &in_chroma_row[x], + &(noise_image[y + start_height][x]), scaling_shift_vect16, + scaling_shift_vect32); + + // In 8bpp, when params_.clip_to_restricted_range == false, we can replace + // clipping with vqmovun_s16, but it's not likely to be worth copying the + // function for just that case. 
+ StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + x += 8; + } while (x < safe_chroma_width); + + if (x < chroma_width) { + const int luma_x = x << subsampling_x; + const int valid_range = width - luma_x; + memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0])); + luma_buffer[valid_range] = in_y_row[width - 1]; + const uint16x8_t average_luma = + GetAverageLuma(luma_buffer, subsampling_x); + StoreUnsigned8(average_luma_buffer, average_luma); + + const int16x8_t blended = + BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>( + average_luma_buffer, scaling_lut, &in_chroma_row[x], + &(noise_image[y + start_height][x]), scaling_shift_vect16, + scaling_shift_vect32); + // In 8bpp, when params_.clip_to_restricted_range == false, we can replace + // clipping with vqmovun_s16, but it's not likely to be worth copying the + // function for just that case. + StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + } + + in_y_row += source_stride_y << subsampling_y; + in_chroma_row += source_stride_chroma; + out_chroma_row += dest_stride; + } while (++y < chroma_height); +} + +// This function is for the case params_.chroma_scaling_from_luma == true. +// This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y. +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageChromaWithCfl_NEON( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + + const auto* in_uv = static_cast<const Pixel*>(source_plane_uv); + source_stride_uv /= sizeof(Pixel); + auto* out_uv = static_cast<Pixel*>(dest_plane_uv); + dest_stride_uv /= sizeof(Pixel); + // Looping over one plane at a time is faster in higher resolutions, despite + // re-computing luma. + BlendChromaPlaneWithCfl_NEON<bitdepth, GrainType, Pixel>( + noise_image[plane], min_value, max_chroma, width, height, start_height, + subsampling_x, subsampling_y, params.chroma_scaling, scaling_lut, in_y, + source_stride_y, in_uv, source_stride_uv, out_uv, dest_stride_uv); +} + +} // namespace + +namespace low_bitdepth { +namespace { + +inline int16x8_t BlendChromaValsNoCfl( + const uint8_t scaling_lut[kScalingLookupTableSize], + const uint8_t* chroma_cursor, const int8_t* noise_image_cursor, + const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect, + const int16x8_t& offset, int luma_multiplier, int chroma_multiplier) { + uint8_t merged_buffer[8]; + const int16x8_t orig = GetSignedSource8(chroma_cursor); + const int16x8_t weighted_luma = vmulq_n_s16(average_luma, luma_multiplier); + const int16x8_t weighted_chroma = vmulq_n_s16(orig, chroma_multiplier); + // Maximum value of |combined_u| is 127*255 = 0x7E81. + const int16x8_t combined = vhaddq_s16(weighted_luma, weighted_chroma); + // Maximum value of u_offset is (255 << 5) = 0x1FE0. + // 0x7E81 + 0x1FE0 = 0x9E61, therefore another halving add is required. 
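+  // Ignoring the truncation in the halving adds, |merged| below approximates
+  //   clip_to_u8(chroma_offset +
+  //              ((average_luma * luma_multiplier +
+  //                orig * chroma_multiplier) >> 6))
+  // which is then used as the index for the scaling factor lookup.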
+ const uint8x8_t merged = vqshrun_n_s16(vhaddq_s16(offset, combined), 4); + vst1_u8(merged_buffer, merged); + const int16x8_t scaling = + GetScalingFactors<8, uint8_t>(scaling_lut, merged_buffer); + int16x8_t noise = GetSignedSource8(noise_image_cursor); + noise = ScaleNoise(noise, scaling, scaling_shift_vect); + return vaddq_s16(orig, noise); +} + +LIBGAV1_ALWAYS_INLINE void BlendChromaPlane8bpp_NEON( + const Array2D<int8_t>& noise_image, int min_value, int max_chroma, + int width, int height, int start_height, int subsampling_x, + int subsampling_y, int scaling_shift, int chroma_offset, + int chroma_multiplier, int luma_multiplier, + const uint8_t scaling_lut[kScalingLookupTableSize], const uint8_t* in_y_row, + ptrdiff_t source_stride_y, const uint8_t* in_chroma_row, + ptrdiff_t source_stride_chroma, uint8_t* out_chroma_row, + ptrdiff_t dest_stride) { + const int16x8_t floor = vdupq_n_s16(min_value); + const int16x8_t ceiling = vdupq_n_s16(max_chroma); + // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe + // for 16 bit signed integers. In higher bitdepths, however, we have to + // expand to 32 to protect the sign bit. + const int16x8_t scaling_shift_vect = vdupq_n_s16(-scaling_shift); + + const int chroma_height = (height + subsampling_y) >> subsampling_y; + const int chroma_width = (width + subsampling_x) >> subsampling_x; + const int safe_chroma_width = chroma_width & ~7; + uint8_t luma_buffer[16]; + const int16x8_t offset = vdupq_n_s16(chroma_offset << 5); + + start_height >>= subsampling_y; + int y = 0; + do { + int x = 0; + do { + const int luma_x = x << subsampling_x; + const int16x8_t average_luma = vreinterpretq_s16_u16( + GetAverageLuma(&in_y_row[luma_x], subsampling_x)); + const int16x8_t blended = BlendChromaValsNoCfl( + scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]), + average_luma, scaling_shift_vect, offset, luma_multiplier, + chroma_multiplier); + // In 8bpp, when params_.clip_to_restricted_range == false, we can + // replace clipping with vqmovun_s16, but the gain would be small. + StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + + x += 8; + } while (x < safe_chroma_width); + + if (x < chroma_width) { + // Begin right edge iteration. Same as the normal iterations, but the + // |average_luma| computation requires a duplicated luma value at the + // end. + const int luma_x = x << subsampling_x; + const int valid_range = width - luma_x; + memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0])); + luma_buffer[valid_range] = in_y_row[width - 1]; + + const int16x8_t average_luma = + vreinterpretq_s16_u16(GetAverageLuma(luma_buffer, subsampling_x)); + const int16x8_t blended = BlendChromaValsNoCfl( + scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]), + average_luma, scaling_shift_vect, offset, luma_multiplier, + chroma_multiplier); + StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + // End of right edge iteration. + } + + in_y_row += source_stride_y << subsampling_y; + in_chroma_row += source_stride_chroma; + out_chroma_row += dest_stride; + } while (++y < chroma_height); +} + +// This function is for the case params_.chroma_scaling_from_luma == false. 
+void BlendNoiseWithImageChroma8bpp_NEON( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + assert(plane == kPlaneU || plane == kPlaneV); + const auto* noise_image = + static_cast<const Array2D<int8_t>*>(noise_image_ptr); + const auto* in_y = static_cast<const uint8_t*>(source_plane_y); + const auto* in_uv = static_cast<const uint8_t*>(source_plane_uv); + auto* out_uv = static_cast<uint8_t*>(dest_plane_uv); + + const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset; + const int luma_multiplier = + (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier; + const int multiplier = + (plane == kPlaneU) ? params.u_multiplier : params.v_multiplier; + BlendChromaPlane8bpp_NEON(noise_image[plane], min_value, max_chroma, width, + height, start_height, subsampling_x, subsampling_y, + params.chroma_scaling, offset, multiplier, + luma_multiplier, scaling_lut, in_y, source_stride_y, + in_uv, source_stride_uv, out_uv, dest_stride_uv); +} + +inline void WriteOverlapLine8bpp_NEON(const int8_t* noise_stripe_row, + const int8_t* noise_stripe_row_prev, + int plane_width, + const int8x8_t grain_coeff, + const int8x8_t old_coeff, + int8_t* noise_image_row) { + int x = 0; + do { + // Note that these reads may exceed noise_stripe_row's width by up to 7 + // bytes. + const int8x8_t source_grain = vld1_s8(noise_stripe_row + x); + const int8x8_t source_old = vld1_s8(noise_stripe_row_prev + x); + const int16x8_t weighted_grain = vmull_s8(grain_coeff, source_grain); + const int16x8_t grain = vmlal_s8(weighted_grain, old_coeff, source_old); + // Note that this write may exceed noise_image_row's width by up to 7 bytes. 
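+    // Each output value is the saturated, rounded blend
+    //   (grain_coeff * source_grain[k] + old_coeff * source_old[k] + 16) >> 5
+    // of the new stripe with the bottom rows of the previous stripe.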
+ vst1_s8(noise_image_row + x, vqrshrn_n_s16(grain, 5)); + x += 8; + } while (x < plane_width); +} + +void ConstructNoiseImageOverlap8bpp_NEON(const void* noise_stripes_buffer, + int width, int height, + int subsampling_x, int subsampling_y, + void* noise_image_buffer) { + const auto* noise_stripes = + static_cast<const Array2DView<int8_t>*>(noise_stripes_buffer); + auto* noise_image = static_cast<Array2D<int8_t>*>(noise_image_buffer); + const int plane_width = (width + subsampling_x) >> subsampling_x; + const int plane_height = (height + subsampling_y) >> subsampling_y; + const int stripe_height = 32 >> subsampling_y; + const int stripe_mask = stripe_height - 1; + int y = stripe_height; + int luma_num = 1; + if (subsampling_y == 0) { + const int8x8_t first_row_grain_coeff = vdup_n_s8(17); + const int8x8_t first_row_old_coeff = vdup_n_s8(27); + const int8x8_t second_row_grain_coeff = first_row_old_coeff; + const int8x8_t second_row_old_coeff = first_row_grain_coeff; + for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) { + const int8_t* noise_stripe = (*noise_stripes)[luma_num]; + const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine8bpp_NEON( + noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width, + first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]); + + WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width], + &noise_stripe_prev[(32 + 1) * plane_width], + plane_width, second_row_grain_coeff, + second_row_old_coeff, (*noise_image)[y + 1]); + } + // Either one partial stripe remains (remaining_height > 0), + // OR image is less than one stripe high (remaining_height < 0), + // OR all stripes are completed (remaining_height == 0). + const int remaining_height = plane_height - y; + if (remaining_height <= 0) { + return; + } + const int8_t* noise_stripe = (*noise_stripes)[luma_num]; + const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine8bpp_NEON( + noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width, + first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]); + + if (remaining_height > 1) { + WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width], + &noise_stripe_prev[(32 + 1) * plane_width], + plane_width, second_row_grain_coeff, + second_row_old_coeff, (*noise_image)[y + 1]); + } + } else { // subsampling_y == 1 + const int8x8_t first_row_grain_coeff = vdup_n_s8(22); + const int8x8_t first_row_old_coeff = vdup_n_s8(23); + for (; y < plane_height; ++luma_num, y += stripe_height) { + const int8_t* noise_stripe = (*noise_stripes)[luma_num]; + const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine8bpp_NEON( + noise_stripe, &noise_stripe_prev[16 * plane_width], plane_width, + first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]); + } + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + + // LumaAutoRegressionFunc + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 1>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 2>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 3>; + + // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag] + // Chroma autoregression should never be called when lag is 0 and use_luma + // is false. 
+ dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, true>; + + dsp->film_grain.construct_noise_image_overlap = + ConstructNoiseImageOverlap8bpp_NEON; + + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON; + + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_NEON<8, int8_t, uint8_t>; + dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma8bpp_NEON; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_NEON<8, int8_t, uint8_t>; +} + +} // namespace +} // namespace low_bitdepth + +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + + // LumaAutoRegressionFunc + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 1>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 2>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 3>; + + // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag][subsampling] + // Chroma autoregression should never be called when lag is 0 and use_luma + // is false. 
+ dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, true>; + + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON; + + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_NEON<10, int16_t, uint16_t>; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_NEON<10, int16_t, uint16_t>; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +} // namespace film_grain + +void FilmGrainInit_NEON() { + film_grain::low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + film_grain::high_bitdepth::Init10bpp(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void FilmGrainInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/film_grain_neon.h b/src/dsp/arm/film_grain_neon.h new file mode 100644 index 0000000..44b3d1d --- /dev/null +++ b/src/dsp/arm/film_grain_neon.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initialize members of Dsp::film_grain. This function is not thread-safe. 
+void FilmGrainInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_ diff --git a/src/dsp/arm/intra_edge_neon.cc b/src/dsp/arm/intra_edge_neon.cc new file mode 100644 index 0000000..00b186a --- /dev/null +++ b/src/dsp/arm/intra_edge_neon.cc @@ -0,0 +1,301 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intra_edge.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" // RightShiftWithRounding() + +namespace libgav1 { +namespace dsp { +namespace { + +// Simplified version of intra_edge.cc:kKernels[][]. Only |strength| 1 and 2 are +// required. +constexpr int kKernelsNEON[3][2] = {{4, 8}, {5, 6}}; + +void IntraEdgeFilter_NEON(void* buffer, const int size, const int strength) { + assert(strength == 1 || strength == 2 || strength == 3); + const int kernel_index = strength - 1; + auto* const dst_buffer = static_cast<uint8_t*>(buffer); + + // The first element is not written out (but it is input) so the number of + // elements written is |size| - 1. + if (size == 1) return; + + // |strength| 1 and 2 use a 3 tap filter. + if (strength < 3) { + // The last value requires extending the buffer (duplicating + // |dst_buffer[size - 1]). Calculate it here to avoid extra processing in + // neon. + const uint8_t last_val = RightShiftWithRounding( + kKernelsNEON[kernel_index][0] * dst_buffer[size - 2] + + kKernelsNEON[kernel_index][1] * dst_buffer[size - 1] + + kKernelsNEON[kernel_index][0] * dst_buffer[size - 1], + 4); + + const uint8x8_t krn1 = vdup_n_u8(kKernelsNEON[kernel_index][1]); + + // The first value we need gets overwritten by the output from the + // previous iteration. + uint8x16_t src_0 = vld1q_u8(dst_buffer); + int i = 1; + + // Process blocks until there are less than 16 values remaining. 
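+    // Each output is the 3 tap filter
+    //   (k0 * in[i - 1] + k1 * in[i] + k0 * in[i + 1] + 8) >> 4
+    // with (k0, k1) taken from kKernelsNEON: (4, 8) for strength 1 and (5, 6)
+    // for strength 2.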
+ for (; i < size - 15; i += 16) { + // Loading these at the end of the block with |src_0| will read past the + // end of |top_row_data[160]|, the source of |buffer|. + const uint8x16_t src_1 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1); + uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2)); + sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]); + sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1); + uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2)); + sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]); + sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + // Load the next row before overwriting. This loads an extra 15 values + // past |size| on the trailing iteration. + src_0 = vld1q_u8(dst_buffer + i + 15); + + vst1q_u8(dst_buffer + i, result); + } + + // The last output value |last_val| was already calculated so if + // |remainder| == 1 then we don't have to do anything. + const int remainder = (size - 1) & 0xf; + if (remainder > 1) { + uint8_t temp[16]; + const uint8x16_t src_1 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1); + + uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2)); + sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]); + sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1); + uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2)); + sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]); + sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + vst1q_u8(temp, result); + memcpy(dst_buffer + i, temp, remainder); + } + + dst_buffer[size - 1] = last_val; + return; + } + + assert(strength == 3); + // 5 tap filter. The first element requires duplicating |buffer[0]| and the + // last two elements require duplicating |buffer[size - 1]|. + uint8_t special_vals[3]; + special_vals[0] = RightShiftWithRounding( + (dst_buffer[0] << 1) + (dst_buffer[0] << 2) + (dst_buffer[1] << 2) + + (dst_buffer[2] << 2) + (dst_buffer[3] << 1), + 4); + // Clamp index for very small |size| values. + const int first_index_min = std::max(size - 4, 0); + const int second_index_min = std::max(size - 3, 0); + const int third_index_min = std::max(size - 2, 0); + special_vals[1] = RightShiftWithRounding( + (dst_buffer[first_index_min] << 1) + (dst_buffer[second_index_min] << 2) + + (dst_buffer[third_index_min] << 2) + (dst_buffer[size - 1] << 2) + + (dst_buffer[size - 1] << 1), + 4); + special_vals[2] = RightShiftWithRounding( + (dst_buffer[second_index_min] << 1) + (dst_buffer[third_index_min] << 2) + + // x << 2 + x << 2 == x << 3 + (dst_buffer[size - 1] << 3) + (dst_buffer[size - 1] << 1), + 4); + + // The first two values we need get overwritten by the output from the + // previous iteration. + uint8x16_t src_0 = vld1q_u8(dst_buffer - 1); + uint8x16_t src_1 = vld1q_u8(dst_buffer); + int i = 1; + + for (; i < size - 15; i += 16) { + // Loading these at the end of the block with |src_[01]| will read past + // the end of |top_row_data[160]|, the source of |buffer|. 
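+    // The sums below implement the 5 tap filter
+    //   (2 * (in[i - 2] + in[i + 2]) +
+    //    4 * (in[i - 1] + in[i] + in[i + 1]) + 8) >> 4
+    // i.e. the strength 3 kernel {2, 4, 4, 4, 2}.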
+ const uint8x16_t src_2 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1); + const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2); + + uint16x8_t sum_lo = + vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1); + const uint16x8_t sum_123_lo = vaddw_u8( + vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3)); + sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2)); + + uint16x8_t sum_hi = + vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1); + const uint16x8_t sum_123_hi = + vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)), + vget_high_u8(src_3)); + sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2)); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + src_0 = vld1q_u8(dst_buffer + i + 14); + src_1 = vld1q_u8(dst_buffer + i + 15); + + vst1q_u8(dst_buffer + i, result); + } + + const int remainder = (size - 1) & 0xf; + // Like the 3 tap but if there are two remaining values we have already + // calculated them. + if (remainder > 2) { + uint8_t temp[16]; + const uint8x16_t src_2 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1); + const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2); + + uint16x8_t sum_lo = + vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1); + const uint16x8_t sum_123_lo = vaddw_u8( + vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3)); + sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2)); + + uint16x8_t sum_hi = + vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1); + const uint16x8_t sum_123_hi = + vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)), + vget_high_u8(src_3)); + sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2)); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + vst1q_u8(temp, result); + memcpy(dst_buffer + i, temp, remainder); + } + + dst_buffer[1] = special_vals[0]; + // Avoid overwriting |dst_buffer[0]|. + if (size > 2) dst_buffer[size - 2] = special_vals[1]; + dst_buffer[size - 1] = special_vals[2]; +} + +// (-|src0| + |src1| * 9 + |src2| * 9 - |src3|) >> 4 +uint8x8_t Upsample(const uint8x8_t src0, const uint8x8_t src1, + const uint8x8_t src2, const uint8x8_t src3) { + const uint16x8_t middle = vmulq_n_u16(vaddl_u8(src1, src2), 9); + const uint16x8_t ends = vaddl_u8(src0, src3); + const int16x8_t sum = + vsubq_s16(vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(ends)); + return vqrshrun_n_s16(sum, 4); +} + +void IntraEdgeUpsampler_NEON(void* buffer, const int size) { + assert(size % 4 == 0 && size <= 16); + auto* const pixel_buffer = static_cast<uint8_t*>(buffer); + // This is OK because we don't read this value for |size| 4 or 8 but if we + // write |pixel_buffer[size]| and then vld() it, that seems to introduce + // some latency. + pixel_buffer[-2] = pixel_buffer[-1]; + if (size == 4) { + // This uses one load and two vtbl() which is better than 4x Load{Lo,Hi}4(). + const uint8x8_t src = vld1_u8(pixel_buffer - 1); + // The outside values are negated so put those in the same vector. + const uint8x8_t src03 = vtbl1_u8(src, vcreate_u8(0x0404030202010000)); + // Reverse |src1| and |src2| so we can use |src2| for the interleave at the + // end. 
+ const uint8x8_t src21 = vtbl1_u8(src, vcreate_u8(0x0302010004030201)); + + const uint16x8_t middle = vmull_u8(src21, vdup_n_u8(9)); + const int16x8_t half_sum = vsubq_s16( + vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(vmovl_u8(src03))); + const int16x4_t sum = + vadd_s16(vget_low_s16(half_sum), vget_high_s16(half_sum)); + const uint8x8_t result = vqrshrun_n_s16(vcombine_s16(sum, sum), 4); + + vst1_u8(pixel_buffer - 1, InterleaveLow8(result, src21)); + return; + } else if (size == 8) { + // Likewise, one load + multiple vtbls seems preferred to multiple loads. + const uint8x16_t src = vld1q_u8(pixel_buffer - 1); + const uint8x8_t src0 = VQTbl1U8(src, vcreate_u8(0x0605040302010000)); + const uint8x8_t src1 = vget_low_u8(src); + const uint8x8_t src2 = VQTbl1U8(src, vcreate_u8(0x0807060504030201)); + const uint8x8_t src3 = VQTbl1U8(src, vcreate_u8(0x0808070605040302)); + + const uint8x8x2_t output = {Upsample(src0, src1, src2, src3), src2}; + vst2_u8(pixel_buffer - 1, output); + return; + } + assert(size == 12 || size == 16); + // Extend the input borders to avoid branching later. + pixel_buffer[size] = pixel_buffer[size - 1]; + const uint8x16_t src0 = vld1q_u8(pixel_buffer - 2); + const uint8x16_t src1 = vld1q_u8(pixel_buffer - 1); + const uint8x16_t src2 = vld1q_u8(pixel_buffer); + const uint8x16_t src3 = vld1q_u8(pixel_buffer + 1); + + const uint8x8_t result_lo = Upsample(vget_low_u8(src0), vget_low_u8(src1), + vget_low_u8(src2), vget_low_u8(src3)); + + const uint8x8x2_t output_lo = {result_lo, vget_low_u8(src2)}; + vst2_u8(pixel_buffer - 1, output_lo); + + const uint8x8_t result_hi = Upsample(vget_high_u8(src0), vget_high_u8(src1), + vget_high_u8(src2), vget_high_u8(src3)); + + if (size == 12) { + vst1_u8(pixel_buffer + 15, InterleaveLow8(result_hi, vget_high_u8(src2))); + } else /* size == 16 */ { + const uint8x8x2_t output_hi = {result_hi, vget_high_u8(src2)}; + vst2_u8(pixel_buffer + 15, output_hi); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->intra_edge_filter = IntraEdgeFilter_NEON; + dsp->intra_edge_upsampler = IntraEdgeUpsampler_NEON; +} + +} // namespace + +void IntraEdgeInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraEdgeInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intra_edge_neon.h b/src/dsp/arm/intra_edge_neon.h new file mode 100644 index 0000000..d3bb243 --- /dev/null +++ b/src/dsp/arm/intra_edge_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_edge_filter and Dsp::intra_edge_upsampler. 
This +// function is not thread-safe. +void IntraEdgeInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_ diff --git a/src/dsp/arm/intrapred_cfl_neon.cc b/src/dsp/arm/intrapred_cfl_neon.cc new file mode 100644 index 0000000..45fe33b --- /dev/null +++ b/src/dsp/arm/intrapred_cfl_neon.cc @@ -0,0 +1,479 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +uint8x16_t Set2ValuesQ(const uint8_t* a) { + uint16_t combined_values = a[0] | a[1] << 8; + return vreinterpretq_u8_u16(vdupq_n_u16(combined_values)); +} + +uint32_t SumVector(uint32x2_t a) { +#if defined(__aarch64__) + return vaddv_u32(a); +#else + const uint64x1_t b = vpaddl_u32(a); + return vget_lane_u32(vreinterpret_u32_u64(b), 0); +#endif // defined(__aarch64__) +} + +uint32_t SumVector(uint32x4_t a) { +#if defined(__aarch64__) + return vaddvq_u32(a); +#else + const uint64x2_t b = vpaddlq_u32(a); + const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b)); + return vget_lane_u32(vreinterpret_u32_u64(c), 0); +#endif // defined(__aarch64__) +} + +// Divide by the number of elements. +uint32_t Average(const uint32_t sum, const int width, const int height) { + return RightShiftWithRounding(sum, FloorLog2(width) + FloorLog2(height)); +} + +// Subtract |val| from every element in |a|. 
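+// (BlockSubtract() follows just below. For orientation, here is a plain
+// scalar sketch of the complete average-and-subtract step that Average()
+// above and BlockSubtract() together perform for the CfL subsamplers. It is
+// illustrative only and is not called by the NEON paths; the function name
+// is ours, but RightShiftWithRounding() and FloorLog2() are the same helpers
+// used by Average().)
+inline void CflSubtractAverage_Sketch(
+    int16_t a[kCflLumaBufferStride][kCflLumaBufferStride], const int width,
+    const int height) {
+  uint32_t sum = 0;
+  for (int y = 0; y < height; ++y) {
+    for (int x = 0; x < width; ++x) sum += a[y][x];
+  }
+  // |width| and |height| are powers of two, so this divides by
+  // width * height with rounding to nearest.
+  const uint32_t average =
+      RightShiftWithRounding(sum, FloorLog2(width) + FloorLog2(height));
+  for (int y = 0; y < height; ++y) {
+    for (int x = 0; x < width; ++x) {
+      a[y][x] -= static_cast<int16_t>(average);
+    }
+  }
+}
+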
+void BlockSubtract(const uint32_t val, + int16_t a[kCflLumaBufferStride][kCflLumaBufferStride], + const int width, const int height) { + assert(val <= INT16_MAX); + const int16x8_t val_v = vdupq_n_s16(static_cast<int16_t>(val)); + + for (int y = 0; y < height; ++y) { + if (width == 4) { + const int16x4_t b = vld1_s16(a[y]); + vst1_s16(a[y], vsub_s16(b, vget_low_s16(val_v))); + } else if (width == 8) { + const int16x8_t b = vld1q_s16(a[y]); + vst1q_s16(a[y], vsubq_s16(b, val_v)); + } else if (width == 16) { + const int16x8_t b = vld1q_s16(a[y]); + const int16x8_t c = vld1q_s16(a[y] + 8); + vst1q_s16(a[y], vsubq_s16(b, val_v)); + vst1q_s16(a[y] + 8, vsubq_s16(c, val_v)); + } else /* block_width == 32 */ { + const int16x8_t b = vld1q_s16(a[y]); + const int16x8_t c = vld1q_s16(a[y] + 8); + const int16x8_t d = vld1q_s16(a[y] + 16); + const int16x8_t e = vld1q_s16(a[y] + 24); + vst1q_s16(a[y], vsubq_s16(b, val_v)); + vst1q_s16(a[y] + 8, vsubq_s16(c, val_v)); + vst1q_s16(a[y] + 16, vsubq_s16(d, val_v)); + vst1q_s16(a[y] + 24, vsubq_s16(e, val_v)); + } + } +} + +template <int block_width, int block_height> +void CflSubsampler420_NEON( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, const ptrdiff_t stride) { + const auto* src = static_cast<const uint8_t*>(source); + uint32_t sum; + if (block_width == 4) { + assert(max_luma_width >= 8); + uint32x2_t running_sum = vdup_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + const uint8x8_t row0 = vld1_u8(src); + const uint8x8_t row1 = vld1_u8(src + stride); + + uint16x4_t sum_row = vpadal_u8(vpaddl_u8(row0), row1); + sum_row = vshl_n_u16(sum_row, 1); + running_sum = vpadal_u16(running_sum, sum_row); + vst1_s16(luma[y], vreinterpret_s16_u16(sum_row)); + + if (y << 1 < max_luma_height - 2) { + // Once this threshold is reached the loop could be simplified. 
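+        // While the condition holds, two source rows are consumed per output
+        // row; afterwards |src| stops advancing and the last valid pair of
+        // luma rows is reused for the remaining outputs.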
+ src += stride << 1; + } + } + + sum = SumVector(running_sum); + } else if (block_width == 8) { + const uint8x16_t x_index = {0, 0, 2, 2, 4, 4, 6, 6, + 8, 8, 10, 10, 12, 12, 14, 14}; + const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2); + const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index); + + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + const uint8x16_t x_max0 = Set2ValuesQ(src + max_luma_width - 2); + const uint8x16_t x_max1 = Set2ValuesQ(src + max_luma_width - 2 + stride); + + uint8x16_t row0 = vld1q_u8(src); + row0 = vbslq_u8(x_mask, row0, x_max0); + uint8x16_t row1 = vld1q_u8(src + stride); + row1 = vbslq_u8(x_mask, row1, x_max1); + + uint16x8_t sum_row = vpadalq_u8(vpaddlq_u8(row0), row1); + sum_row = vshlq_n_u16(sum_row, 1); + running_sum = vpadalq_u16(running_sum, sum_row); + vst1q_s16(luma[y], vreinterpretq_s16_u16(sum_row)); + + if (y << 1 < max_luma_height - 2) { + src += stride << 1; + } + } + + sum = SumVector(running_sum); + } else /* block_width >= 16 */ { + const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2); + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + uint8x16_t x_index = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + const uint8x16_t x_max00 = vdupq_n_u8(src[max_luma_width - 2]); + const uint8x16_t x_max01 = vdupq_n_u8(src[max_luma_width - 2 + 1]); + const uint8x16_t x_max10 = vdupq_n_u8(src[stride + max_luma_width - 2]); + const uint8x16_t x_max11 = + vdupq_n_u8(src[stride + max_luma_width - 2 + 1]); + for (int x = 0; x < block_width; x += 16) { + const ptrdiff_t src_x_offset = x << 1; + const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index); + const uint8x16x2_t row0 = vld2q_u8(src + src_x_offset); + const uint8x16x2_t row1 = vld2q_u8(src + src_x_offset + stride); + const uint8x16_t row_masked_00 = vbslq_u8(x_mask, row0.val[0], x_max00); + const uint8x16_t row_masked_01 = vbslq_u8(x_mask, row0.val[1], x_max01); + const uint8x16_t row_masked_10 = vbslq_u8(x_mask, row1.val[0], x_max10); + const uint8x16_t row_masked_11 = vbslq_u8(x_mask, row1.val[1], x_max11); + + uint16x8_t sum_row_lo = + vaddl_u8(vget_low_u8(row_masked_00), vget_low_u8(row_masked_01)); + sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_10)); + sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_11)); + sum_row_lo = vshlq_n_u16(sum_row_lo, 1); + running_sum = vpadalq_u16(running_sum, sum_row_lo); + vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(sum_row_lo)); + + uint16x8_t sum_row_hi = + vaddl_u8(vget_high_u8(row_masked_00), vget_high_u8(row_masked_01)); + sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_10)); + sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_11)); + sum_row_hi = vshlq_n_u16(sum_row_hi, 1); + running_sum = vpadalq_u16(running_sum, sum_row_hi); + vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(sum_row_hi)); + + x_index = vaddq_u8(x_index, vdupq_n_u8(32)); + } + if (y << 1 < max_luma_height - 2) { + src += stride << 1; + } + } + sum = SumVector(running_sum); + } + + const uint32_t average = Average(sum, block_width, block_height); + BlockSubtract(average, luma, block_width, block_height); +} + +template <int block_width, int block_height> +void CflSubsampler444_NEON( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, const ptrdiff_t stride) { + const auto* src = static_cast<const uint8_t*>(source); + uint32_t sum; + if (block_width == 
4) { + assert(max_luma_width >= 4); + uint32x4_t running_sum = vdupq_n_u32(0); + uint8x8_t row = vdup_n_u8(0); + + for (int y = 0; y < block_height; y += 2) { + row = Load4<0>(src, row); + row = Load4<1>(src + stride, row); + if (y < (max_luma_height - 1)) { + src += stride << 1; + } + + const uint16x8_t row_shifted = vshll_n_u8(row, 3); + running_sum = vpadalq_u16(running_sum, row_shifted); + vst1_s16(luma[y], vreinterpret_s16_u16(vget_low_u16(row_shifted))); + vst1_s16(luma[y + 1], vreinterpret_s16_u16(vget_high_u16(row_shifted))); + } + + sum = SumVector(running_sum); + } else if (block_width == 8) { + const uint8x8_t x_index = {0, 1, 2, 3, 4, 5, 6, 7}; + const uint8x8_t x_max_index = vdup_n_u8(max_luma_width - 1); + const uint8x8_t x_mask = vclt_u8(x_index, x_max_index); + + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + const uint8x8_t x_max = vdup_n_u8(src[max_luma_width - 1]); + const uint8x8_t row = vbsl_u8(x_mask, vld1_u8(src), x_max); + + const uint16x8_t row_shifted = vshll_n_u8(row, 3); + running_sum = vpadalq_u16(running_sum, row_shifted); + vst1q_s16(luma[y], vreinterpretq_s16_u16(row_shifted)); + + if (y < max_luma_height - 1) { + src += stride; + } + } + + sum = SumVector(running_sum); + } else /* block_width >= 16 */ { + const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 1); + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + uint8x16_t x_index = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + const uint8x16_t x_max = vdupq_n_u8(src[max_luma_width - 1]); + for (int x = 0; x < block_width; x += 16) { + const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index); + const uint8x16_t row = vbslq_u8(x_mask, vld1q_u8(src + x), x_max); + + const uint16x8_t row_shifted_low = vshll_n_u8(vget_low_u8(row), 3); + const uint16x8_t row_shifted_high = vshll_n_u8(vget_high_u8(row), 3); + running_sum = vpadalq_u16(running_sum, row_shifted_low); + running_sum = vpadalq_u16(running_sum, row_shifted_high); + vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(row_shifted_low)); + vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(row_shifted_high)); + + x_index = vaddq_u8(x_index, vdupq_n_u8(16)); + } + if (y < max_luma_height - 1) { + src += stride; + } + } + sum = SumVector(running_sum); + } + + const uint32_t average = Average(sum, block_width, block_height); + BlockSubtract(average, luma, block_width, block_height); +} + +// Saturate |dc + ((alpha * luma) >> 6))| to uint8_t. +inline uint8x8_t Combine8(const int16x8_t luma, const int alpha, + const int16x8_t dc) { + const int16x8_t la = vmulq_n_s16(luma, alpha); + // Subtract the sign bit to round towards zero. + const int16x8_t sub_sign = vsraq_n_s16(la, la, 15); + // Shift and accumulate. + const int16x8_t result = vrsraq_n_s16(dc, sub_sign, 6); + return vqmovun_s16(result); +} + +// The range of luma/alpha is not really important because it gets saturated to +// uint8_t. Saturated int16_t >> 6 outranges uint8_t. 
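+// For reference, the per-pixel operation that Combine8() vectorizes is
+// roughly
+//   pixel = Clip(dc + RoundedShift(alpha * luma, 6)) to [0, 255],
+// where the shift rounds halves away from zero. A scalar sketch (the name is
+// ours and it is not called by the NEON code; the clamp that vqmovun_s16()
+// provides is written out by hand here):
+inline uint8_t CflCombinePixel_Sketch(const int16_t luma, const int alpha,
+                                      const int dc) {
+  const int scaled = alpha * luma;
+  // Round |scaled| / 64 with halves going away from zero. This matches the
+  // "subtract the sign bit, then round half up" trick used in Combine8().
+  const int rounded =
+      (scaled >= 0) ? (scaled + 32) >> 6 : -((-scaled + 32) >> 6);
+  const int value = dc + rounded;
+  return static_cast<uint8_t>((value < 0) ? 0 : (value > 255) ? 255 : value);
+}
+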
+template <int block_height> +inline void CflIntraPredictor4xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; y += 2) { + const int16x4_t luma_row0 = vld1_s16(luma[y]); + const int16x4_t luma_row1 = vld1_s16(luma[y + 1]); + const uint8x8_t sum = + Combine8(vcombine_s16(luma_row0, luma_row1), alpha, dc); + StoreLo4(dst, sum); + dst += stride; + StoreHi4(dst, sum); + dst += stride; + } +} + +template <int block_height> +inline void CflIntraPredictor8xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; ++y) { + const int16x8_t luma_row = vld1q_s16(luma[y]); + const uint8x8_t sum = Combine8(luma_row, alpha, dc); + vst1_u8(dst, sum); + dst += stride; + } +} + +template <int block_height> +inline void CflIntraPredictor16xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; ++y) { + const int16x8_t luma_row_0 = vld1q_s16(luma[y]); + const int16x8_t luma_row_1 = vld1q_s16(luma[y] + 8); + const uint8x8_t sum_0 = Combine8(luma_row_0, alpha, dc); + const uint8x8_t sum_1 = Combine8(luma_row_1, alpha, dc); + vst1_u8(dst, sum_0); + vst1_u8(dst + 8, sum_1); + dst += stride; + } +} + +template <int block_height> +inline void CflIntraPredictor32xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; ++y) { + const int16x8_t luma_row_0 = vld1q_s16(luma[y]); + const int16x8_t luma_row_1 = vld1q_s16(luma[y] + 8); + const int16x8_t luma_row_2 = vld1q_s16(luma[y] + 16); + const int16x8_t luma_row_3 = vld1q_s16(luma[y] + 24); + const uint8x8_t sum_0 = Combine8(luma_row_0, alpha, dc); + const uint8x8_t sum_1 = Combine8(luma_row_1, alpha, dc); + const uint8x8_t sum_2 = Combine8(luma_row_2, alpha, dc); + const uint8x8_t sum_3 = Combine8(luma_row_3, alpha, dc); + vst1_u8(dst, sum_0); + vst1_u8(dst + 8, sum_1); + vst1_u8(dst + 16, sum_2); + vst1_u8(dst + 24, sum_3); + dst += stride; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] = + CflSubsampler420_NEON<4, 4>; + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] = + CflSubsampler420_NEON<4, 8>; + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] = + CflSubsampler420_NEON<4, 16>; + + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] = + CflSubsampler420_NEON<8, 4>; + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] = + CflSubsampler420_NEON<8, 8>; + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] = + CflSubsampler420_NEON<8, 16>; + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] = + CflSubsampler420_NEON<8, 32>; + + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] = + CflSubsampler420_NEON<16, 4>; + 
dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] = + CflSubsampler420_NEON<16, 8>; + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] = + CflSubsampler420_NEON<16, 16>; + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] = + CflSubsampler420_NEON<16, 32>; + + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] = + CflSubsampler420_NEON<32, 8>; + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] = + CflSubsampler420_NEON<32, 16>; + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] = + CflSubsampler420_NEON<32, 32>; + + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] = + CflSubsampler444_NEON<4, 4>; + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] = + CflSubsampler444_NEON<4, 8>; + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] = + CflSubsampler444_NEON<4, 16>; + + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] = + CflSubsampler444_NEON<8, 4>; + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] = + CflSubsampler444_NEON<8, 8>; + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] = + CflSubsampler444_NEON<8, 16>; + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] = + CflSubsampler444_NEON<8, 32>; + + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] = + CflSubsampler444_NEON<16, 4>; + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] = + CflSubsampler444_NEON<16, 8>; + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] = + CflSubsampler444_NEON<16, 16>; + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] = + CflSubsampler444_NEON<16, 32>; + + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] = + CflSubsampler444_NEON<32, 8>; + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] = + CflSubsampler444_NEON<32, 16>; + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] = + CflSubsampler444_NEON<32, 32>; + + dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor4xN_NEON<4>; + dsp->cfl_intra_predictors[kTransformSize4x8] = CflIntraPredictor4xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize4x16] = CflIntraPredictor4xN_NEON<16>; + + dsp->cfl_intra_predictors[kTransformSize8x4] = CflIntraPredictor8xN_NEON<4>; + dsp->cfl_intra_predictors[kTransformSize8x8] = CflIntraPredictor8xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize8x16] = CflIntraPredictor8xN_NEON<16>; + dsp->cfl_intra_predictors[kTransformSize8x32] = CflIntraPredictor8xN_NEON<32>; + + dsp->cfl_intra_predictors[kTransformSize16x4] = CflIntraPredictor16xN_NEON<4>; + dsp->cfl_intra_predictors[kTransformSize16x8] = CflIntraPredictor16xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize16x16] = + CflIntraPredictor16xN_NEON<16>; + dsp->cfl_intra_predictors[kTransformSize16x32] = + CflIntraPredictor16xN_NEON<32>; + + dsp->cfl_intra_predictors[kTransformSize32x8] = CflIntraPredictor32xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize32x16] = + CflIntraPredictor32xN_NEON<16>; + dsp->cfl_intra_predictors[kTransformSize32x32] = + CflIntraPredictor32xN_NEON<32>; + // Max Cfl predictor size is 32x32. 
+} + +} // namespace +} // namespace low_bitdepth + +void IntraPredCflInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredCflInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_directional_neon.cc b/src/dsp/arm/intrapred_directional_neon.cc new file mode 100644 index 0000000..805ba81 --- /dev/null +++ b/src/dsp/arm/intrapred_directional_neon.cc @@ -0,0 +1,926 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> // std::min +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> // memset + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Blend two values based on a 32 bit weight. +inline uint8x8_t WeightedBlend(const uint8x8_t a, const uint8x8_t b, + const uint8x8_t a_weight, + const uint8x8_t b_weight) { + const uint16x8_t a_product = vmull_u8(a, a_weight); + const uint16x8_t b_product = vmull_u8(b, b_weight); + + return vrshrn_n_u16(vaddq_u16(a_product, b_product), 5); +} + +// For vertical operations the weights are one constant value. +inline uint8x8_t WeightedBlend(const uint8x8_t a, const uint8x8_t b, + const uint8_t weight) { + return WeightedBlend(a, b, vdup_n_u8(32 - weight), vdup_n_u8(weight)); +} + +// Fill |left| and |right| with the appropriate values for a given |base_step|. +inline void LoadStepwise(const uint8_t* const source, const uint8x8_t left_step, + const uint8x8_t right_step, uint8x8_t* left, + uint8x8_t* right) { + const uint8x16_t mixed = vld1q_u8(source); + *left = VQTbl1U8(mixed, left_step); + *right = VQTbl1U8(mixed, right_step); +} + +// Handle signed step arguments by ignoring the sign. Negative values are +// considered out of range and overwritten later. +inline void LoadStepwise(const uint8_t* const source, const int8x8_t left_step, + const int8x8_t right_step, uint8x8_t* left, + uint8x8_t* right) { + LoadStepwise(source, vreinterpret_u8_s8(left_step), + vreinterpret_u8_s8(right_step), left, right); +} + +// Process 4 or 8 |width| by any |height|. 
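+// (DirectionalZone1_WxH() follows just below. For orientation, this is a
+// plain scalar sketch of the zone 1 interpolation it vectorizes, written for
+// the non-upsampled case. It is illustrative only and is not called
+// anywhere; the name is ours.)
+inline void DirectionalZone1_ScalarSketch(uint8_t* dst, const ptrdiff_t stride,
+                                          const int width, const int height,
+                                          const uint8_t* const top,
+                                          const int xstep) {
+  // Positions along |top| are tracked in 1/64 pel. The low bits become a
+  // 1/32 pel blend weight, as in the NEON paths below.
+  const int max_base_x = width + height - 1;
+  for (int y = 0; y < height; ++y) {
+    const int top_x = (y + 1) * xstep;
+    const int base = top_x >> 6;
+    const int shift = (top_x & 0x3F) >> 1;
+    for (int x = 0; x < width; ++x) {
+      const int p = base + x;
+      dst[x] = (p < max_base_x)
+                   ? static_cast<uint8_t>(
+                         (top[p] * (32 - shift) + top[p + 1] * shift + 16) >> 5)
+                   : top[max_base_x];
+    }
+    dst += stride;
+  }
+}
+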
+template <int width> +inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride, + const int height, const uint8_t* const top, + const int xstep, const bool upsampled) { + assert(width == 4 || width == 8); + + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + + const int max_base_x = (width + height - 1) << upsample_shift; + const int8x8_t max_base = vdup_n_s8(max_base_x); + const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]); + + const int8x8_t all = vcreate_s8(0x0706050403020100); + const int8x8_t even = vcreate_s8(0x0e0c0a0806040200); + const int8x8_t base_step = upsampled ? even : all; + const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1)); + + int top_x = xstep; + int y = 0; + do { + const int top_base_x = top_x >> scale_bits; + + if (top_base_x >= max_base_x) { + for (int i = y; i < height; ++i) { + memset(dst, top[max_base_x], 4 /* width */); + dst += stride; + } + return; + } + + const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1; + + // Zone2 uses negative values for xstep. Use signed values to compare + // |top_base_x| to |max_base_x|. + const int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step); + + const uint8x8_t max_base_mask = vclt_s8(base_v, max_base); + + // 4 wide subsamples the output. 8 wide subsamples the input. + if (width == 4) { + const uint8x8_t left_values = vld1_u8(top + top_base_x); + const uint8x8_t right_values = RightShift<8>(left_values); + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + + // If |upsampled| is true then extract every other value for output. + const uint8x8_t value_stepped = + vtbl1_u8(value, vreinterpret_u8_s8(base_step)); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value_stepped, top_max_base); + + StoreLo4(dst, masked_value); + } else /* width == 8 */ { + uint8x8_t left_values, right_values; + // WeightedBlend() steps up to Q registers. Downsample the input to avoid + // doing extra calculations. + LoadStepwise(top + top_base_x, base_step, right_step, &left_values, + &right_values); + + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value, top_max_base); + + vst1_u8(dst, masked_value); + } + dst += stride; + top_x += xstep; + } while (++y < height); +} + +// Process a multiple of 8 |width| by any |height|. Processes horizontally +// before vertically in the hopes of being a little more cache friendly. +inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride, + const int width, const int height, + const uint8_t* const top, const int xstep, + const bool upsampled) { + assert(width % 8 == 0); + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + + const int max_base_x = (width + height - 1) << upsample_shift; + const int8x8_t max_base = vdup_n_s8(max_base_x); + const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]); + + const int8x8_t all = vcreate_s8(0x0706050403020100); + const int8x8_t even = vcreate_s8(0x0e0c0a0806040200); + const int8x8_t base_step = upsampled ? 
even : all; + const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1)); + const int8x8_t block_step = vdup_n_s8(8 << upsample_shift); + + int top_x = xstep; + int y = 0; + do { + const int top_base_x = top_x >> scale_bits; + + if (top_base_x >= max_base_x) { + for (int i = y; i < height; ++i) { + memset(dst, top[max_base_x], 4 /* width */); + dst += stride; + } + return; + } + + const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1; + + // Zone2 uses negative values for xstep. Use signed values to compare + // |top_base_x| to |max_base_x|. + int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step); + + int x = 0; + do { + const uint8x8_t max_base_mask = vclt_s8(base_v, max_base); + + // Extract the input values based on |upsampled| here to avoid doing twice + // as many calculations. + uint8x8_t left_values, right_values; + LoadStepwise(top + top_base_x + x, base_step, right_step, &left_values, + &right_values); + + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value, top_max_base); + + vst1_u8(dst + x, masked_value); + + base_v = vadd_s8(base_v, block_step); + x += 8; + } while (x < width); + top_x += xstep; + dst += stride; + } while (++y < height); +} + +void DirectionalIntraPredictorZone1_NEON(void* const dest, + const ptrdiff_t stride, + const void* const top_row, + const int width, const int height, + const int xstep, + const bool upsampled_top) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + uint8_t* dst = static_cast<uint8_t*>(dest); + + assert(xstep > 0); + + const int upsample_shift = static_cast<int>(upsampled_top); + + const uint8x8_t all = vcreate_u8(0x0706050403020100); + + if (xstep == 64) { + assert(!upsampled_top); + const uint8_t* top_ptr = top + 1; + int y = 0; + do { + memcpy(dst, top_ptr, width); + memcpy(dst + stride, top_ptr + 1, width); + memcpy(dst + 2 * stride, top_ptr + 2, width); + memcpy(dst + 3 * stride, top_ptr + 3, width); + dst += 4 * stride; + top_ptr += 4; + y += 4; + } while (y < height); + } else if (width == 4) { + DirectionalZone1_WxH<4>(dst, stride, height, top, xstep, upsampled_top); + } else if (xstep > 51) { + // 7.11.2.10. Intra edge upsample selection process + // if ( d <= 0 || d >= 40 ) useUpsample = 0 + // For |upsample_top| the delta is from vertical so |prediction_angle - 90|. + // In |kDirectionalIntraPredictorDerivative[]| angles less than 51 will meet + // this criteria. The |xstep| value for angle 51 happens to be 51 as well. + // Shallower angles have greater xstep values. + assert(!upsampled_top); + const int max_base_x = ((width + height) - 1); + const uint8x8_t max_base = vdup_n_u8(max_base_x); + const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]); + const uint8x8_t block_step = vdup_n_u8(8); + + int top_x = xstep; + int y = 0; + do { + const int top_base_x = top_x >> 6; + const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1; + uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), all); + int x = 0; + // Only calculate a block of 8 when at least one of the output values is + // within range. Otherwise it can read off the end of |top|. + const int must_calculate_width = + std::min(width, max_base_x - top_base_x + 7) & ~7; + for (; x < must_calculate_width; x += 8) { + const uint8x8_t max_base_mask = vclt_u8(base_v, max_base); + + // Since these |xstep| values can not be upsampled the load is + // simplified. 
+ const uint8x8_t left_values = vld1_u8(top + top_base_x + x); + const uint8x8_t right_values = vld1_u8(top + top_base_x + x + 1); + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value, top_max_base); + + vst1_u8(dst + x, masked_value); + base_v = vadd_u8(base_v, block_step); + } + memset(dst + x, top[max_base_x], width - x); + dst += stride; + top_x += xstep; + } while (++y < height); + } else { + DirectionalZone1_WxH(dst, stride, width, height, top, xstep, upsampled_top); + } +} + +// Process 4 or 8 |width| by 4 or 8 |height|. +template <int width> +inline void DirectionalZone3_WxH(uint8_t* dest, const ptrdiff_t stride, + const int height, + const uint8_t* const left_column, + const int base_left_y, const int ystep, + const int upsample_shift) { + assert(width == 4 || width == 8); + assert(height == 4 || height == 8); + const int scale_bits = 6 - upsample_shift; + + // Zone3 never runs out of left_column values. + assert((width + height - 1) << upsample_shift > // max_base_y + ((ystep * width) >> scale_bits) + + (/* base_step */ 1 << upsample_shift) * + (height - 1)); // left_base_y + + // Limited improvement for 8x8. ~20% faster for 64x64. + const uint8x8_t all = vcreate_u8(0x0706050403020100); + const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200); + const uint8x8_t base_step = upsample_shift ? even : all; + const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1)); + + uint8_t* dst = dest; + uint8x8_t left_v[8], right_v[8], value_v[8]; + const uint8_t* const left = left_column; + + const int index_0 = base_left_y; + LoadStepwise(left + (index_0 >> scale_bits), base_step, right_step, + &left_v[0], &right_v[0]); + value_v[0] = WeightedBlend(left_v[0], right_v[0], + ((index_0 << upsample_shift) & 0x3F) >> 1); + + const int index_1 = base_left_y + ystep; + LoadStepwise(left + (index_1 >> scale_bits), base_step, right_step, + &left_v[1], &right_v[1]); + value_v[1] = WeightedBlend(left_v[1], right_v[1], + ((index_1 << upsample_shift) & 0x3F) >> 1); + + const int index_2 = base_left_y + ystep * 2; + LoadStepwise(left + (index_2 >> scale_bits), base_step, right_step, + &left_v[2], &right_v[2]); + value_v[2] = WeightedBlend(left_v[2], right_v[2], + ((index_2 << upsample_shift) & 0x3F) >> 1); + + const int index_3 = base_left_y + ystep * 3; + LoadStepwise(left + (index_3 >> scale_bits), base_step, right_step, + &left_v[3], &right_v[3]); + value_v[3] = WeightedBlend(left_v[3], right_v[3], + ((index_3 << upsample_shift) & 0x3F) >> 1); + + const int index_4 = base_left_y + ystep * 4; + LoadStepwise(left + (index_4 >> scale_bits), base_step, right_step, + &left_v[4], &right_v[4]); + value_v[4] = WeightedBlend(left_v[4], right_v[4], + ((index_4 << upsample_shift) & 0x3F) >> 1); + + const int index_5 = base_left_y + ystep * 5; + LoadStepwise(left + (index_5 >> scale_bits), base_step, right_step, + &left_v[5], &right_v[5]); + value_v[5] = WeightedBlend(left_v[5], right_v[5], + ((index_5 << upsample_shift) & 0x3F) >> 1); + + const int index_6 = base_left_y + ystep * 6; + LoadStepwise(left + (index_6 >> scale_bits), base_step, right_step, + &left_v[6], &right_v[6]); + value_v[6] = WeightedBlend(left_v[6], right_v[6], + ((index_6 << upsample_shift) & 0x3F) >> 1); + + const int index_7 = base_left_y + ystep * 7; + LoadStepwise(left + (index_7 >> scale_bits), base_step, right_step, + &left_v[7], &right_v[7]); + value_v[7] = WeightedBlend(left_v[7], right_v[7], + ((index_7 << upsample_shift) & 0x3F) >> 1); + + // 8x8 
transpose. + const uint8x16x2_t b0 = vtrnq_u8(vcombine_u8(value_v[0], value_v[4]), + vcombine_u8(value_v[1], value_v[5])); + const uint8x16x2_t b1 = vtrnq_u8(vcombine_u8(value_v[2], value_v[6]), + vcombine_u8(value_v[3], value_v[7])); + + const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]), + vreinterpretq_u16_u8(b1.val[0])); + const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]), + vreinterpretq_u16_u8(b1.val[1])); + + const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c1.val[0])); + const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c1.val[1])); + + if (width == 4) { + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0]))); + if (height == 4) return; + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1]))); + } else { + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0]))); + if (height == 4) return; + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1]))); + } +} + +// Because the source values "move backwards" as the row index increases, the +// indices derived from ystep are generally negative. This is accommodated by +// making sure the relative indices are within [-15, 0] when the function is +// called, and sliding them into the inclusive range [0, 15], relative to a +// lower base address. +constexpr int kPositiveIndexOffset = 15; + +// Process 4 or 8 |width| by any |height|. +template <int width> +inline void DirectionalZone2FromLeftCol_WxH(uint8_t* dst, + const ptrdiff_t stride, + const int height, + const uint8_t* const left_column, + const int16x8_t left_y, + const int upsample_shift) { + assert(width == 4 || width == 8); + + // The shift argument must be a constant. + int16x8_t offset_y, shift_upsampled = left_y; + if (upsample_shift) { + offset_y = vshrq_n_s16(left_y, 5); + shift_upsampled = vshlq_n_s16(shift_upsampled, 1); + } else { + offset_y = vshrq_n_s16(left_y, 6); + } + + // Select values to the left of the starting point. + // The 15th element (and 16th) will be all the way at the end, to the right. + // With a negative ystep everything else will be "left" of them. + // This supports cumulative steps up to 15. We could support up to 16 by doing + // separate loads for |left_values| and |right_values|. vtbl supports 2 Q + // registers as input which would allow for cumulative offsets of 32. 
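+  // |offset_y| holds per-column offsets in the range [-15, 0]; adding
+  // kPositiveIndexOffset maps them to byte indices in [0, 15] within the
+  // 16-byte window that is loaded (per output row) starting at
+  // |left_column| - kPositiveIndexOffset below.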
+ const int16x8_t sampler = + vaddq_s16(offset_y, vdupq_n_s16(kPositiveIndexOffset)); + const uint8x8_t left_values = vqmovun_s16(sampler); + const uint8x8_t right_values = vadd_u8(left_values, vdup_n_u8(1)); + + const int16x8_t shift_masked = vandq_s16(shift_upsampled, vdupq_n_s16(0x3f)); + const uint8x8_t shift_mul = vreinterpret_u8_s8(vshrn_n_s16(shift_masked, 1)); + const uint8x8_t inv_shift_mul = vsub_u8(vdup_n_u8(32), shift_mul); + + int y = 0; + do { + uint8x8_t src_left, src_right; + LoadStepwise(left_column - kPositiveIndexOffset + (y << upsample_shift), + left_values, right_values, &src_left, &src_right); + const uint8x8_t val = + WeightedBlend(src_left, src_right, inv_shift_mul, shift_mul); + + if (width == 4) { + StoreLo4(dst, val); + } else { + vst1_u8(dst, val); + } + dst += stride; + } while (++y < height); +} + +// Process 4 or 8 |width| by any |height|. +template <int width> +inline void DirectionalZone1Blend_WxH(uint8_t* dest, const ptrdiff_t stride, + const int height, + const uint8_t* const top_row, + int zone_bounds, int top_x, + const int xstep, + const int upsample_shift) { + assert(width == 4 || width == 8); + + const int scale_bits_x = 6 - upsample_shift; + + const uint8x8_t all = vcreate_u8(0x0706050403020100); + const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200); + const uint8x8_t base_step = upsample_shift ? even : all; + const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1)); + + int y = 0; + do { + const uint8_t* const src = top_row + (top_x >> scale_bits_x); + uint8x8_t left, right; + LoadStepwise(src, base_step, right_step, &left, &right); + + const uint8_t shift = ((top_x << upsample_shift) & 0x3f) >> 1; + const uint8x8_t val = WeightedBlend(left, right, shift); + + uint8x8_t dst_blend = vld1_u8(dest); + // |zone_bounds| values can be negative. + uint8x8_t blend = + vcge_s8(vreinterpret_s8_u8(all), vdup_n_s8((zone_bounds >> 6))); + uint8x8_t output = vbsl_u8(blend, val, dst_blend); + + if (width == 4) { + StoreLo4(dest, output); + } else { + vst1_u8(dest, output); + } + dest += stride; + zone_bounds += xstep; + top_x -= xstep; + } while (++y < height); +} + +// The height at which a load of 16 bytes will not contain enough source pixels +// from |left_column| to supply an accurate row when computing 8 pixels at a +// time. The values are found by inspection. By coincidence, all angles that +// satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up +// by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15. +constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = { + 1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40}; + +// 7.11.2.4 (8) 90 < angle > 180 +// The strategy for these functions (4xH and 8+xH) is to know how many blocks +// can be processed with just pixels from |top_ptr|, then handle mixed blocks, +// then handle only blocks that take from |left_ptr|. Additionally, a fast +// index-shuffle approach is used for pred values from |left_column| in sections +// that permit it. +inline void DirectionalZone2_4xH(uint8_t* dst, const ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int height, const int xstep, + const int ystep, const bool upsampled_top, + const bool upsampled_left) { + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + + // Helper vector. + const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7}; + + // Loop incrementers for moving by block (4xN). 
Vertical still steps by 8. If + // it's only 4, it will be finished in the first iteration. + const ptrdiff_t stride8 = stride << 3; + const int xstep8 = xstep << 3; + + const int min_height = (height == 4) ? 4 : 8; + + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute and can therefore call the Zone1 functions. This assumes |xstep| is + // at least 3. + assert(xstep >= 3); + const int min_top_only_x = std::min((height * xstep) >> 6, /* width */ 4); + + // For steep angles, the source pixels from |left_column| may not fit in a + // 16-byte load for shuffling. + // TODO(petersonab): Find a more precise formula for this subject to x. + // TODO(johannkoenig): Revisit this for |width| == 4. + const int max_shuffle_height = + std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height); + + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1; + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which is covered under the left_column + // offset. The following values need the full ystep as a relative offset. + int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep); + left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder)); + + // This loop treats each set of 4 columns in 3 stages with y-value boundaries. + // The first stage, before the first y-loop, covers blocks that are only + // computed from the top row. The second stage, comprising two y-loops, covers + // blocks that have a mixture of values computed from top or left. The final + // stage covers blocks that are only computed from the left. + if (min_top_only_x > 0) { + // Round down to the nearest multiple of 8. + // TODO(johannkoenig): This never hits for Wx4 blocks but maybe it should. + const int max_top_only_y = std::min((1 << 6) / xstep, height) & ~7; + DirectionalZone1_WxH<4>(dst, stride, max_top_only_y, top_row, -xstep, + upsampled_top); + + if (max_top_only_y == height) return; + + int y = max_top_only_y; + dst += stride * y; + const int xstep_y = xstep * y; + + // All rows from |min_left_only_y| down for this set of columns only need + // |left_column| to compute. + const int min_left_only_y = std::min((4 << 6) / xstep, height); + // At high angles such that min_left_only_y < 8, ystep is low and xstep is + // high. This means that max_shuffle_height is unbounded and xstep_bounds + // will overflow in 16 bits. This is prevented by stopping the first + // blending loop at min_left_only_y for such cases, which means we skip over + // the second blending loop as well. + const int left_shuffle_stop_y = + std::min(max_shuffle_height, min_left_only_y); + int xstep_bounds = xstep_bounds_base + xstep_y; + int top_x = -xstep - xstep_y; + + // +8 increment is OK because if height is 4 this only goes once. + for (; y < left_shuffle_stop_y; + y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone2FromLeftCol_WxH<4>( + dst, stride, min_height, + left_column + ((y - left_base_increment) << upsample_left_shift), + left_y, upsample_left_shift); + + DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row, + xstep_bounds, top_x, xstep, + upsample_top_shift); + } + + // Pick up from the last y-value, using the slower but secure method for + // left prediction. 
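+    // ("Slower but secure" refers to DirectionalZone3_WxH(), which derives
+    // every base index from |base_left_y| and |ystep| rather than from the
+    // 16-byte shuffle window, so it is not limited by |max_shuffle_height|.)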
+ const int16_t base_left_y = vgetq_lane_s16(left_y, 0); + for (; y < min_left_only_y; + y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone3_WxH<4>( + dst, stride, min_height, + left_column + ((y - left_base_increment) << upsample_left_shift), + base_left_y, -ystep, upsample_left_shift); + + DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row, + xstep_bounds, top_x, xstep, + upsample_top_shift); + } + // Loop over y for left_only rows. + for (; y < height; y += 8, dst += stride8) { + DirectionalZone3_WxH<4>( + dst, stride, min_height, + left_column + ((y - left_base_increment) << upsample_left_shift), + base_left_y, -ystep, upsample_left_shift); + } + } else { + DirectionalZone1_WxH<4>(dst, stride, height, top_row, -xstep, + upsampled_top); + } +} + +// Process a multiple of 8 |width|. +inline void DirectionalZone2_8(uint8_t* const dst, const ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int width, const int height, + const int xstep, const int ystep, + const bool upsampled_top, + const bool upsampled_left) { + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + + // Helper vector. + const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7}; + + // Loop incrementers for moving by block (8x8). This function handles blocks + // with height 4 as well. They are calculated in one pass so these variables + // do not get used. + const ptrdiff_t stride8 = stride << 3; + const int xstep8 = xstep << 3; + const int ystep8 = ystep << 3; + + // Process Wx4 blocks. + const int min_height = (height == 4) ? 4 : 8; + + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute and can therefore call the Zone1 functions. This assumes |xstep| is + // at least 3. + assert(xstep >= 3); + const int min_top_only_x = std::min((height * xstep) >> 6, width); + + // For steep angles, the source pixels from |left_column| may not fit in a + // 16-byte load for shuffling. + // TODO(petersonab): Find a more precise formula for this subject to x. + const int max_shuffle_height = + std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height); + + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1; + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + + const int left_base_increment8 = ystep8 >> 6; + const int ystep_remainder8 = ystep8 & 0x3F; + const int16x8_t increment_left8 = vdupq_n_s16(ystep_remainder8); + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which is covered under the left_column + // offset. Following values need the full ystep as a relative offset. + int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep); + left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder)); + + // This loop treats each set of 4 columns in 3 stages with y-value boundaries. + // The first stage, before the first y-loop, covers blocks that are only + // computed from the top row. The second stage, comprising two y-loops, covers + // blocks that have a mixture of values computed from top or left. The final + // stage covers blocks that are only computed from the left. 
+ int x = 0; + for (int left_offset = -left_base_increment; x < min_top_only_x; x += 8, + xstep_bounds_base -= (8 << 6), + left_y = vsubq_s16(left_y, increment_left8), + left_offset -= left_base_increment8) { + uint8_t* dst_x = dst + x; + + // Round down to the nearest multiple of 8. + const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7; + DirectionalZone1_WxH<8>(dst_x, stride, max_top_only_y, + top_row + (x << upsample_top_shift), -xstep, + upsampled_top); + + if (max_top_only_y == height) continue; + + int y = max_top_only_y; + dst_x += stride * y; + const int xstep_y = xstep * y; + + // All rows from |min_left_only_y| down for this set of columns only need + // |left_column| to compute. + const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height); + // At high angles such that min_left_only_y < 8, ystep is low and xstep is + // high. This means that max_shuffle_height is unbounded and xstep_bounds + // will overflow in 16 bits. This is prevented by stopping the first + // blending loop at min_left_only_y for such cases, which means we skip over + // the second blending loop as well. + const int left_shuffle_stop_y = + std::min(max_shuffle_height, min_left_only_y); + int xstep_bounds = xstep_bounds_base + xstep_y; + int top_x = -xstep - xstep_y; + + for (; y < left_shuffle_stop_y; + y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone2FromLeftCol_WxH<8>( + dst_x, stride, min_height, + left_column + ((left_offset + y) << upsample_left_shift), left_y, + upsample_left_shift); + + DirectionalZone1Blend_WxH<8>( + dst_x, stride, min_height, top_row + (x << upsample_top_shift), + xstep_bounds, top_x, xstep, upsample_top_shift); + } + + // Pick up from the last y-value, using the slower but secure method for + // left prediction. + const int16_t base_left_y = vgetq_lane_s16(left_y, 0); + for (; y < min_left_only_y; + y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone3_WxH<8>( + dst_x, stride, min_height, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep, upsample_left_shift); + + DirectionalZone1Blend_WxH<8>( + dst_x, stride, min_height, top_row + (x << upsample_top_shift), + xstep_bounds, top_x, xstep, upsample_top_shift); + } + // Loop over y for left_only rows. + for (; y < height; y += 8, dst_x += stride8) { + DirectionalZone3_WxH<8>( + dst_x, stride, min_height, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep, upsample_left_shift); + } + } + // TODO(johannkoenig): May be able to remove this branch. + if (x < width) { + DirectionalZone1_WxH(dst + x, stride, width - x, height, + top_row + (x << upsample_top_shift), -xstep, + upsampled_top); + } +} + +void DirectionalIntraPredictorZone2_NEON( + void* const dest, const ptrdiff_t stride, const void* const top_row, + const void* const left_column, const int width, const int height, + const int xstep, const int ystep, const bool upsampled_top, + const bool upsampled_left) { + // Increasing the negative buffer for this function allows more rows to be + // processed at a time without branching in an inner loop to check the base. 
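+  // Layout: 160 valid bytes (starting 16 before the caller's pointers) are
+  // copied to offset 128 of each 288-byte buffer, and the working pointers
+  // are placed at offset 144. Loads may therefore reach up to 144 bytes
+  // below the working pointers without leaving the local arrays; bytes read
+  // from that slack are treated as out of range by the callers.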
+ uint8_t top_buffer[288]; + uint8_t left_buffer[288]; + memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160); + memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160); + const uint8_t* top_ptr = top_buffer + 144; + const uint8_t* left_ptr = left_buffer + 144; + auto* dst = static_cast<uint8_t*>(dest); + + if (width == 4) { + DirectionalZone2_4xH(dst, stride, top_ptr, left_ptr, height, xstep, ystep, + upsampled_top, upsampled_left); + } else { + DirectionalZone2_8(dst, stride, top_ptr, left_ptr, width, height, xstep, + ystep, upsampled_top, upsampled_left); + } +} + +void DirectionalIntraPredictorZone3_NEON(void* const dest, + const ptrdiff_t stride, + const void* const left_column, + const int width, const int height, + const int ystep, + const bool upsampled_left) { + const auto* const left = static_cast<const uint8_t*>(left_column); + + assert(ystep > 0); + + const int upsample_shift = static_cast<int>(upsampled_left); + const int scale_bits = 6 - upsample_shift; + const int base_step = 1 << upsample_shift; + + if (width == 4 || height == 4) { + // This block can handle all sizes but the specializations for other sizes + // are faster. + const uint8x8_t all = vcreate_u8(0x0706050403020100); + const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200); + const uint8x8_t base_step_v = upsampled_left ? even : all; + const uint8x8_t right_step = vadd_u8(base_step_v, vdup_n_u8(1)); + + int y = 0; + do { + int x = 0; + do { + uint8_t* dst = static_cast<uint8_t*>(dest); + dst += y * stride + x; + uint8x8_t left_v[4], right_v[4], value_v[4]; + const int ystep_base = ystep * x; + const int offset = y * base_step; + + const int index_0 = ystep_base + ystep * 1; + LoadStepwise(left + offset + (index_0 >> scale_bits), base_step_v, + right_step, &left_v[0], &right_v[0]); + value_v[0] = WeightedBlend(left_v[0], right_v[0], + ((index_0 << upsample_shift) & 0x3F) >> 1); + + const int index_1 = ystep_base + ystep * 2; + LoadStepwise(left + offset + (index_1 >> scale_bits), base_step_v, + right_step, &left_v[1], &right_v[1]); + value_v[1] = WeightedBlend(left_v[1], right_v[1], + ((index_1 << upsample_shift) & 0x3F) >> 1); + + const int index_2 = ystep_base + ystep * 3; + LoadStepwise(left + offset + (index_2 >> scale_bits), base_step_v, + right_step, &left_v[2], &right_v[2]); + value_v[2] = WeightedBlend(left_v[2], right_v[2], + ((index_2 << upsample_shift) & 0x3F) >> 1); + + const int index_3 = ystep_base + ystep * 4; + LoadStepwise(left + offset + (index_3 >> scale_bits), base_step_v, + right_step, &left_v[3], &right_v[3]); + value_v[3] = WeightedBlend(left_v[3], right_v[3], + ((index_3 << upsample_shift) & 0x3F) >> 1); + + // 8x4 transpose. 
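+        // Each |value_v[i]| holds 8 outputs running down column x + i; the
+        // transpose turns them into rows so they can be stored 4 wide (rows
+        // 0-3 from the low halves, rows 4-7 from the high halves when
+        // height > 4).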
+ const uint8x8x2_t b0 = vtrn_u8(value_v[0], value_v[1]); + const uint8x8x2_t b1 = vtrn_u8(value_v[2], value_v[3]); + + const uint16x4x2_t c0 = vtrn_u16(vreinterpret_u16_u8(b0.val[0]), + vreinterpret_u16_u8(b1.val[0])); + const uint16x4x2_t c1 = vtrn_u16(vreinterpret_u16_u8(b0.val[1]), + vreinterpret_u16_u8(b1.val[1])); + + StoreLo4(dst, vreinterpret_u8_u16(c0.val[0])); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u16(c1.val[0])); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u16(c0.val[1])); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u16(c1.val[1])); + + if (height > 4) { + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c0.val[0])); + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c1.val[0])); + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c0.val[1])); + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c1.val[1])); + } + x += 4; + } while (x < width); + y += 8; + } while (y < height); + } else { // 8x8 at a time. + // Limited improvement for 8x8. ~20% faster for 64x64. + int y = 0; + do { + int x = 0; + do { + uint8_t* dst = static_cast<uint8_t*>(dest); + dst += y * stride + x; + const int ystep_base = ystep * (x + 1); + + DirectionalZone3_WxH<8>(dst, stride, 8, left + (y << upsample_shift), + ystep_base, ystep, upsample_shift); + x += 8; + } while (x < width); + y += 8; + } while (y < height); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->directional_intra_predictor_zone1 = DirectionalIntraPredictorZone1_NEON; + dsp->directional_intra_predictor_zone2 = DirectionalIntraPredictorZone2_NEON; + dsp->directional_intra_predictor_zone3 = DirectionalIntraPredictorZone3_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredDirectionalInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredDirectionalInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_filter_intra_neon.cc b/src/dsp/arm/intrapred_filter_intra_neon.cc new file mode 100644 index 0000000..411708e --- /dev/null +++ b/src/dsp/arm/intrapred_filter_intra_neon.cc @@ -0,0 +1,176 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { + +namespace low_bitdepth { +namespace { + +// Transpose kFilterIntraTaps and convert the first row to unsigned values. +// +// With the previous orientation we were able to multiply all the input values +// by a single tap. 
This required that all the input values be in one vector +// which requires expensive set up operations (shifts, vext, vtbl). All the +// elements of the result needed to be summed (easy on A64 - vaddvq_s16) but +// then the shifting, rounding, and clamping was done in GP registers. +// +// Switching to unsigned values allows multiplying the 8 bit inputs directly. +// When one value was negative we needed to vmovl_u8 first so that the results +// maintained the proper sign. +// +// We take this into account when summing the values by subtracting the product +// of the first row. +alignas(8) constexpr uint8_t kTransposedTaps[kNumFilterIntraPredictors][7][8] = + {{{6, 5, 3, 3, 4, 3, 3, 3}, // Original values are negative. + {10, 2, 1, 1, 6, 2, 2, 1}, + {0, 10, 1, 1, 0, 6, 2, 2}, + {0, 0, 10, 2, 0, 0, 6, 2}, + {0, 0, 0, 10, 0, 0, 0, 6}, + {12, 9, 7, 5, 2, 2, 2, 3}, + {0, 0, 0, 0, 12, 9, 7, 5}}, + {{10, 6, 4, 2, 10, 6, 4, 2}, // Original values are negative. + {16, 0, 0, 0, 16, 0, 0, 0}, + {0, 16, 0, 0, 0, 16, 0, 0}, + {0, 0, 16, 0, 0, 0, 16, 0}, + {0, 0, 0, 16, 0, 0, 0, 16}, + {10, 6, 4, 2, 0, 0, 0, 0}, + {0, 0, 0, 0, 10, 6, 4, 2}}, + {{8, 8, 8, 8, 4, 4, 4, 4}, // Original values are negative. + {8, 0, 0, 0, 4, 0, 0, 0}, + {0, 8, 0, 0, 0, 4, 0, 0}, + {0, 0, 8, 0, 0, 0, 4, 0}, + {0, 0, 0, 8, 0, 0, 0, 4}, + {16, 16, 16, 16, 0, 0, 0, 0}, + {0, 0, 0, 0, 16, 16, 16, 16}}, + {{2, 1, 1, 0, 1, 1, 1, 1}, // Original values are negative. + {8, 3, 2, 1, 4, 3, 2, 2}, + {0, 8, 3, 2, 0, 4, 3, 2}, + {0, 0, 8, 3, 0, 0, 4, 3}, + {0, 0, 0, 8, 0, 0, 0, 4}, + {10, 6, 4, 2, 3, 4, 4, 3}, + {0, 0, 0, 0, 10, 6, 4, 3}}, + {{12, 10, 9, 8, 10, 9, 8, 7}, // Original values are negative. + {14, 0, 0, 0, 12, 1, 0, 0}, + {0, 14, 0, 0, 0, 12, 0, 0}, + {0, 0, 14, 0, 0, 0, 12, 1}, + {0, 0, 0, 14, 0, 0, 0, 12}, + {14, 12, 11, 10, 0, 0, 1, 1}, + {0, 0, 0, 0, 14, 12, 11, 9}}}; + +void FilterIntraPredictor_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + FilterIntraPredictor pred, int width, + int height) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + + assert(width <= 32 && height <= 32); + + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x8_t transposed_taps[7]; + for (int i = 0; i < 7; ++i) { + transposed_taps[i] = vld1_u8(kTransposedTaps[pred][i]); + } + + uint8_t relative_top_left = top[-1]; + const uint8_t* relative_top = top; + uint8_t relative_left[2] = {left[0], left[1]}; + + int y = 0; + do { + uint8_t* row_dst = dst; + int x = 0; + do { + uint16x8_t sum = vdupq_n_u16(0); + const uint16x8_t subtrahend = + vmull_u8(transposed_taps[0], vdup_n_u8(relative_top_left)); + for (int i = 1; i < 5; ++i) { + sum = vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_top[i - 1])); + } + for (int i = 5; i < 7; ++i) { + sum = + vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_left[i - 5])); + } + + const int16x8_t sum_signed = + vreinterpretq_s16_u16(vsubq_u16(sum, subtrahend)); + const int16x8_t sum_shifted = vrshrq_n_s16(sum_signed, 4); + + uint8x8_t sum_saturated = vqmovun_s16(sum_shifted); + + StoreLo4(row_dst, sum_saturated); + StoreHi4(row_dst + stride, sum_saturated); + + // Progress across + relative_top_left = relative_top[3]; + relative_top += 4; + relative_left[0] = row_dst[3]; + relative_left[1] = row_dst[3 + stride]; + row_dst += 4; + x += 4; + } while (x < width); + + // Progress down. 
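+ // The filter references the block it is building: the next two rows use
+ // the second row just written as their "top", left[y + 1] as the top-left,
+ // and left[y + 2] / left[y + 3] as their left neighbors.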
+ relative_top_left = left[y + 1]; + relative_top = dst + stride; + relative_left[0] = left[y + 2]; + relative_left[1] = left[y + 3]; + + dst += 2 * stride; + y += 2; + } while (y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->filter_intra_predictor = FilterIntraPredictor_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredFilterIntraInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredFilterIntraInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_neon.cc b/src/dsp/arm/intrapred_neon.cc new file mode 100644 index 0000000..c967d82 --- /dev/null +++ b/src/dsp/arm/intrapred_neon.cc @@ -0,0 +1,1144 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" + +namespace libgav1 { +namespace dsp { +namespace { + +//------------------------------------------------------------------------------ +// DcPredFuncs_NEON + +using DcSumFunc = uint32x2_t (*)(const void* ref_0, const int ref_0_size_log2, + const bool use_ref_1, const void* ref_1, + const int ref_1_size_log2); +using DcStoreFunc = void (*)(void* dest, ptrdiff_t stride, const uint32x2_t dc); + +// DC intra-predictors for square blocks. 
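+// The same template also backs the rectangular sizes through DcDefs below;
+// when block_width_log2 != block_height_log2 the combined edge length is not
+// a power of two, so Dc() falls back to a scalar round-and-divide.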
+template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +struct DcPredFuncs_NEON { + DcPredFuncs_NEON() = delete; + + static void DcTop(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void DcLeft(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Dc(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); +}; + +template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, + storefn>::DcTop(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* /*left_column*/) { + const uint32x2_t sum = sumfn(top_row, block_width_log2, false, nullptr, 0); + const uint32x2_t dc = vrshr_n_u32(sum, block_width_log2); + storefn(dest, stride, dc); +} + +template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, + storefn>::DcLeft(void* const dest, ptrdiff_t stride, + const void* /*top_row*/, + const void* const left_column) { + const uint32x2_t sum = + sumfn(left_column, block_height_log2, false, nullptr, 0); + const uint32x2_t dc = vrshr_n_u32(sum, block_height_log2); + storefn(dest, stride, dc); +} + +template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, storefn>::Dc( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* const left_column) { + const uint32x2_t sum = + sumfn(top_row, block_width_log2, true, left_column, block_height_log2); + if (block_width_log2 == block_height_log2) { + const uint32x2_t dc = vrshr_n_u32(sum, block_width_log2 + 1); + storefn(dest, stride, dc); + } else { + // TODO(johannkoenig): Compare this to mul/shift in vectors. + const int divisor = (1 << block_width_log2) + (1 << block_height_log2); + uint32_t dc = vget_lane_u32(sum, 0); + dc += divisor >> 1; + dc /= divisor; + storefn(dest, stride, vdup_n_u32(dc)); + } +} + +// Sum all the elements in the vector into the low 32 bits. +inline uint32x2_t Sum(const uint16x4_t val) { + const uint32x2_t sum = vpaddl_u16(val); + return vpadd_u32(sum, sum); +} + +// Sum all the elements in the vector into the low 32 bits. +inline uint32x2_t Sum(const uint16x8_t val) { + const uint32x4_t sum_0 = vpaddlq_u16(val); + const uint64x2_t sum_1 = vpaddlq_u32(sum_0); + return vadd_u32(vget_low_u32(vreinterpretq_u32_u64(sum_1)), + vget_high_u32(vreinterpretq_u32_u64(sum_1))); +} + +} // namespace + +//------------------------------------------------------------------------------ +namespace low_bitdepth { +namespace { + +// Add and expand the elements in the |val_[01]| to uint16_t but do not sum the +// entire vector. +inline uint16x8_t Add(const uint8x16_t val_0, const uint8x16_t val_1) { + const uint16x8_t sum_0 = vpaddlq_u8(val_0); + const uint16x8_t sum_1 = vpaddlq_u8(val_1); + return vaddq_u16(sum_0, sum_1); +} + +// Add and expand the elements in the |val_[0123]| to uint16_t but do not sum +// the entire vector. +inline uint16x8_t Add(const uint8x16_t val_0, const uint8x16_t val_1, + const uint8x16_t val_2, const uint8x16_t val_3) { + const uint16x8_t sum_0 = Add(val_0, val_1); + const uint16x8_t sum_1 = Add(val_2, val_3); + return vaddq_u16(sum_0, sum_1); +} + +// Load and combine 32 uint8_t values. 
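+// Pairwise widening (vpaddlq_u8) keeps the running totals in 16-bit lanes;
+// even the largest case (64 top + 64 left samples) accumulates at most 16
+// bytes per lane, well within uint16_t range before Sum() widens to 32 bits.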
+inline uint16x8_t LoadAndAdd32(const uint8_t* buf) { + const uint8x16_t val_0 = vld1q_u8(buf); + const uint8x16_t val_1 = vld1q_u8(buf + 16); + return Add(val_0, val_1); +} + +// Load and combine 64 uint8_t values. +inline uint16x8_t LoadAndAdd64(const uint8_t* buf) { + const uint8x16_t val_0 = vld1q_u8(buf); + const uint8x16_t val_1 = vld1q_u8(buf + 16); + const uint8x16_t val_2 = vld1q_u8(buf + 32); + const uint8x16_t val_3 = vld1q_u8(buf + 48); + return Add(val_0, val_1, val_2, val_3); +} + +// |ref_[01]| each point to 1 << |ref[01]_size_log2| packed uint8_t values. +// If |use_ref_1| is false then only sum |ref_0|. +// For |ref[01]_size_log2| == 4 this relies on |ref_[01]| being aligned to +// uint32_t. +inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2, + const bool use_ref_1, const void* ref_1, + const int ref_1_size_log2) { + const auto* const ref_0_u8 = static_cast<const uint8_t*>(ref_0); + const auto* const ref_1_u8 = static_cast<const uint8_t*>(ref_1); + if (ref_0_size_log2 == 2) { + uint8x8_t val = Load4(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 4x4 + val = Load4<1>(ref_1_u8, val); + return Sum(vpaddl_u8(val)); + } else if (ref_1_size_log2 == 3) { // 4x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + const uint16x4_t sum_0 = vpaddl_u8(val); + const uint16x4_t sum_1 = vpaddl_u8(val_1); + return Sum(vadd_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 4) { // 4x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_1), val)); + } + } + // 4x1 + const uint16x4_t sum = vpaddl_u8(val); + return vpaddl_u16(sum); + } else if (ref_0_size_log2 == 3) { + const uint8x8_t val_0 = vld1_u8(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 8x4 + const uint8x8_t val_1 = Load4(ref_1_u8); + const uint16x4_t sum_0 = vpaddl_u8(val_0); + const uint16x4_t sum_1 = vpaddl_u8(val_1); + return Sum(vadd_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 3) { // 8x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + const uint16x4_t sum_0 = vpaddl_u8(val_0); + const uint16x4_t sum_1 = vpaddl_u8(val_1); + return Sum(vadd_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 4) { // 8x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_1), val_0)); + } else if (ref_1_size_log2 == 5) { // 8x32 + return Sum(vaddw_u8(LoadAndAdd32(ref_1_u8), val_0)); + } + } + // 8x1 + return Sum(vpaddl_u8(val_0)); + } else if (ref_0_size_log2 == 4) { + const uint8x16_t val_0 = vld1q_u8(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 16x4 + const uint8x8_t val_1 = Load4(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1)); + } else if (ref_1_size_log2 == 3) { // 16x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1)); + } else if (ref_1_size_log2 == 4) { // 16x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + return Sum(Add(val_0, val_1)); + } else if (ref_1_size_log2 == 5) { // 16x32 + const uint16x8_t sum_0 = vpaddlq_u8(val_0); + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 16x64 + const uint16x8_t sum_0 = vpaddlq_u8(val_0); + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 16x1 + return Sum(vpaddlq_u8(val_0)); + } else if (ref_0_size_log2 == 5) { + const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 3) { // 32x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + 
return Sum(vaddw_u8(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 32x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + const uint16x8_t sum_1 = vpaddlq_u8(val_1); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 32x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 32x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 32x1 + return Sum(sum_0); + } + + assert(ref_0_size_log2 == 6); + const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 4) { // 64x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + const uint16x8_t sum_1 = vpaddlq_u8(val_1); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 64x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 64x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 64x1 + return Sum(sum_0); +} + +template <int width, int height> +inline void DcStore_NEON(void* const dest, ptrdiff_t stride, + const uint32x2_t dc) { + const uint8x16_t dc_dup = vdupq_lane_u8(vreinterpret_u8_u32(dc), 0); + auto* dst = static_cast<uint8_t*>(dest); + if (width == 4) { + int i = height - 1; + do { + StoreLo4(dst, vget_low_u8(dc_dup)); + dst += stride; + } while (--i != 0); + StoreLo4(dst, vget_low_u8(dc_dup)); + } else if (width == 8) { + int i = height - 1; + do { + vst1_u8(dst, vget_low_u8(dc_dup)); + dst += stride; + } while (--i != 0); + vst1_u8(dst, vget_low_u8(dc_dup)); + } else if (width == 16) { + int i = height - 1; + do { + vst1q_u8(dst, dc_dup); + dst += stride; + } while (--i != 0); + vst1q_u8(dst, dc_dup); + } else if (width == 32) { + int i = height - 1; + do { + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + dst += stride; + } while (--i != 0); + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + } else { + assert(width == 64); + int i = height - 1; + do { + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + vst1q_u8(dst + 32, dc_dup); + vst1q_u8(dst + 48, dc_dup); + dst += stride; + } while (--i != 0); + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + vst1q_u8(dst + 32, dc_dup); + vst1q_u8(dst + 48, dc_dup); + } +} + +template <int width, int height> +inline void Paeth4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + auto* dest_u8 = static_cast<uint8_t*>(dest); + const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row); + const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column); + + const uint8x8_t top_left = vdup_n_u8(top_row_u8[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row_u8[-1] + top_row_u8[-1]); + uint8x8_t top; + if (width == 4) { + top = Load4(top_row_u8); + } else { // width == 8 + top = vld1_u8(top_row_u8); + } + + for (int y = 0; y < height; ++y) { + const uint8x8_t left = vdup_n_u8(left_col_u8[y]); + + const uint8x8_t left_dist = vabd_u8(top, top_left); + const uint8x8_t top_dist = vabd_u8(left, top_left); + const uint16x8_t top_left_dist = + vabdq_u16(vaddl_u8(top, left), top_left_x2); + + const uint8x8_t left_le_top = vcle_u8(left_dist, top_dist); + const uint8x8_t left_le_top_left = + vmovn_u16(vcleq_u16(vmovl_u8(left_dist), top_left_dist)); + const uint8x8_t top_le_top_left = + vmovn_u16(vcleq_u16(vmovl_u8(top_dist), top_left_dist)); + + // if 
(left_dist <= top_dist && left_dist <= top_left_dist) + const uint8x8_t left_mask = vand_u8(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + uint8x8_t result = vbsl_u8(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. + const uint8x8_t left_or_top_mask = vorr_u8(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + result = vbsl_u8(left_or_top_mask, result, top_left); + + if (width == 4) { + StoreLo4(dest_u8, result); + } else { // width == 8 + vst1_u8(dest_u8, result); + } + dest_u8 += stride; + } +} + +// Calculate X distance <= TopLeft distance and pack the resulting mask into +// uint8x8_t. +inline uint8x16_t XLeTopLeft(const uint8x16_t x_dist, + const uint16x8_t top_left_dist_low, + const uint16x8_t top_left_dist_high) { + // TODO(johannkoenig): cle() should work with vmovn(top_left_dist) instead of + // using movl(x_dist). + const uint8x8_t x_le_top_left_low = + vmovn_u16(vcleq_u16(vmovl_u8(vget_low_u8(x_dist)), top_left_dist_low)); + const uint8x8_t x_le_top_left_high = + vmovn_u16(vcleq_u16(vmovl_u8(vget_high_u8(x_dist)), top_left_dist_high)); + return vcombine_u8(x_le_top_left_low, x_le_top_left_high); +} + +// Select the closest values and collect them. +inline uint8x16_t SelectPaeth(const uint8x16_t top, const uint8x16_t left, + const uint8x16_t top_left, + const uint8x16_t left_le_top, + const uint8x16_t left_le_top_left, + const uint8x16_t top_le_top_left) { + // if (left_dist <= top_dist && left_dist <= top_left_dist) + const uint8x16_t left_mask = vandq_u8(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + uint8x16_t result = vbslq_u8(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. + const uint8x16_t left_or_top_mask = vorrq_u8(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + return vbslq_u8(left_or_top_mask, result, top_left); +} + +// Generate numbered and high/low versions of top_left_dist. +#define TOP_LEFT_DIST(num) \ + const uint16x8_t top_left_##num##_dist_low = vabdq_u16( \ + vaddl_u8(vget_low_u8(top[num]), vget_low_u8(left)), top_left_x2); \ + const uint16x8_t top_left_##num##_dist_high = vabdq_u16( \ + vaddl_u8(vget_high_u8(top[num]), vget_low_u8(left)), top_left_x2) + +// Generate numbered versions of XLeTopLeft with x = left. +#define LEFT_LE_TOP_LEFT(num) \ + const uint8x16_t left_le_top_left_##num = \ + XLeTopLeft(left_##num##_dist, top_left_##num##_dist_low, \ + top_left_##num##_dist_high) + +// Generate numbered versions of XLeTopLeft with x = top. 
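+// The numbered macros are expanded once per 16-byte group (0-3) in
+// Paeth16PlusxN_NEON so that wider blocks reuse the same comparison pattern
+// without repeating the expressions inline.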
+#define TOP_LE_TOP_LEFT(num) \ + const uint8x16_t top_le_top_left_##num = XLeTopLeft( \ + top_dist, top_left_##num##_dist_low, top_left_##num##_dist_high) + +template <int width, int height> +inline void Paeth16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + auto* dest_u8 = static_cast<uint8_t*>(dest); + const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row); + const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column); + + const uint8x16_t top_left = vdupq_n_u8(top_row_u8[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row_u8[-1] + top_row_u8[-1]); + uint8x16_t top[4]; + top[0] = vld1q_u8(top_row_u8); + if (width > 16) { + top[1] = vld1q_u8(top_row_u8 + 16); + if (width == 64) { + top[2] = vld1q_u8(top_row_u8 + 32); + top[3] = vld1q_u8(top_row_u8 + 48); + } + } + + for (int y = 0; y < height; ++y) { + const uint8x16_t left = vdupq_n_u8(left_col_u8[y]); + + const uint8x16_t top_dist = vabdq_u8(left, top_left); + + const uint8x16_t left_0_dist = vabdq_u8(top[0], top_left); + TOP_LEFT_DIST(0); + const uint8x16_t left_0_le_top = vcleq_u8(left_0_dist, top_dist); + LEFT_LE_TOP_LEFT(0); + TOP_LE_TOP_LEFT(0); + + const uint8x16_t result_0 = + SelectPaeth(top[0], left, top_left, left_0_le_top, left_le_top_left_0, + top_le_top_left_0); + vst1q_u8(dest_u8, result_0); + + if (width > 16) { + const uint8x16_t left_1_dist = vabdq_u8(top[1], top_left); + TOP_LEFT_DIST(1); + const uint8x16_t left_1_le_top = vcleq_u8(left_1_dist, top_dist); + LEFT_LE_TOP_LEFT(1); + TOP_LE_TOP_LEFT(1); + + const uint8x16_t result_1 = + SelectPaeth(top[1], left, top_left, left_1_le_top, left_le_top_left_1, + top_le_top_left_1); + vst1q_u8(dest_u8 + 16, result_1); + + if (width == 64) { + const uint8x16_t left_2_dist = vabdq_u8(top[2], top_left); + TOP_LEFT_DIST(2); + const uint8x16_t left_2_le_top = vcleq_u8(left_2_dist, top_dist); + LEFT_LE_TOP_LEFT(2); + TOP_LE_TOP_LEFT(2); + + const uint8x16_t result_2 = + SelectPaeth(top[2], left, top_left, left_2_le_top, + left_le_top_left_2, top_le_top_left_2); + vst1q_u8(dest_u8 + 32, result_2); + + const uint8x16_t left_3_dist = vabdq_u8(top[3], top_left); + TOP_LEFT_DIST(3); + const uint8x16_t left_3_le_top = vcleq_u8(left_3_dist, top_dist); + LEFT_LE_TOP_LEFT(3); + TOP_LE_TOP_LEFT(3); + + const uint8x16_t result_3 = + SelectPaeth(top[3], left, top_left, left_3_le_top, + left_le_top_left_3, top_le_top_left_3); + vst1q_u8(dest_u8 + 48, result_3); + } + } + + dest_u8 += stride; + } +} + +struct DcDefs { + DcDefs() = delete; + + using _4x4 = DcPredFuncs_NEON<2, 2, DcSum_NEON, DcStore_NEON<4, 4>>; + using _4x8 = DcPredFuncs_NEON<2, 3, DcSum_NEON, DcStore_NEON<4, 8>>; + using _4x16 = DcPredFuncs_NEON<2, 4, DcSum_NEON, DcStore_NEON<4, 16>>; + using _8x4 = DcPredFuncs_NEON<3, 2, DcSum_NEON, DcStore_NEON<8, 4>>; + using _8x8 = DcPredFuncs_NEON<3, 3, DcSum_NEON, DcStore_NEON<8, 8>>; + using _8x16 = DcPredFuncs_NEON<3, 4, DcSum_NEON, DcStore_NEON<8, 16>>; + using _8x32 = DcPredFuncs_NEON<3, 5, DcSum_NEON, DcStore_NEON<8, 32>>; + using _16x4 = DcPredFuncs_NEON<4, 2, DcSum_NEON, DcStore_NEON<16, 4>>; + using _16x8 = DcPredFuncs_NEON<4, 3, DcSum_NEON, DcStore_NEON<16, 8>>; + using _16x16 = DcPredFuncs_NEON<4, 4, DcSum_NEON, DcStore_NEON<16, 16>>; + using _16x32 = DcPredFuncs_NEON<4, 5, DcSum_NEON, DcStore_NEON<16, 32>>; + using _16x64 = DcPredFuncs_NEON<4, 6, DcSum_NEON, DcStore_NEON<16, 64>>; + using _32x8 = DcPredFuncs_NEON<5, 3, DcSum_NEON, DcStore_NEON<32, 8>>; + using _32x16 = 
DcPredFuncs_NEON<5, 4, DcSum_NEON, DcStore_NEON<32, 16>>; + using _32x32 = DcPredFuncs_NEON<5, 5, DcSum_NEON, DcStore_NEON<32, 32>>; + using _32x64 = DcPredFuncs_NEON<5, 6, DcSum_NEON, DcStore_NEON<32, 64>>; + using _64x16 = DcPredFuncs_NEON<6, 4, DcSum_NEON, DcStore_NEON<64, 16>>; + using _64x32 = DcPredFuncs_NEON<6, 5, DcSum_NEON, DcStore_NEON<64, 32>>; + using _64x64 = DcPredFuncs_NEON<6, 6, DcSum_NEON, DcStore_NEON<64, 64>>; +}; + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + // 4x4 + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + DcDefs::_4x4::DcTop; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + DcDefs::_4x4::DcLeft; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = + DcDefs::_4x4::Dc; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<4, 4>; + + // 4x8 + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] = + DcDefs::_4x8::DcTop; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] = + DcDefs::_4x8::DcLeft; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] = + DcDefs::_4x8::Dc; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<4, 8>; + + // 4x16 + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] = + DcDefs::_4x16::DcTop; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] = + DcDefs::_4x16::DcLeft; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] = + DcDefs::_4x16::Dc; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<4, 16>; + + // 8x4 + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] = + DcDefs::_8x4::DcTop; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] = + DcDefs::_8x4::DcLeft; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] = + DcDefs::_8x4::Dc; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 4>; + + // 8x8 + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] = + DcDefs::_8x8::DcTop; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] = + DcDefs::_8x8::DcLeft; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] = + DcDefs::_8x8::Dc; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 8>; + + // 8x16 + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] = + DcDefs::_8x16::DcTop; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] = + DcDefs::_8x16::DcLeft; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] = + DcDefs::_8x16::Dc; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 16>; + + // 8x32 + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] = + DcDefs::_8x32::DcTop; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] = + DcDefs::_8x32::DcLeft; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] = + DcDefs::_8x32::Dc; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 32>; + + // 16x4 + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] = + DcDefs::_16x4::DcTop; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] = + DcDefs::_16x4::DcLeft; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] = + DcDefs::_16x4::Dc; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 4>; + + // 
16x8 + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] = + DcDefs::_16x8::DcTop; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] = + DcDefs::_16x8::DcLeft; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] = + DcDefs::_16x8::Dc; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 8>; + + // 16x16 + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] = + DcDefs::_16x16::DcTop; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] = + DcDefs::_16x16::DcLeft; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] = + DcDefs::_16x16::Dc; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 16>; + + // 16x32 + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] = + DcDefs::_16x32::DcTop; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] = + DcDefs::_16x32::DcLeft; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] = + DcDefs::_16x32::Dc; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 32>; + + // 16x64 + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] = + DcDefs::_16x64::DcTop; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] = + DcDefs::_16x64::DcLeft; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] = + DcDefs::_16x64::Dc; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 64>; + + // 32x8 + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] = + DcDefs::_32x8::DcTop; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] = + DcDefs::_32x8::DcLeft; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] = + DcDefs::_32x8::Dc; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 8>; + + // 32x16 + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] = + DcDefs::_32x16::DcTop; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] = + DcDefs::_32x16::DcLeft; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] = + DcDefs::_32x16::Dc; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 16>; + + // 32x32 + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] = + DcDefs::_32x32::DcTop; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] = + DcDefs::_32x32::DcLeft; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] = + DcDefs::_32x32::Dc; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 32>; + + // 32x64 + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] = + DcDefs::_32x64::DcTop; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] = + DcDefs::_32x64::DcLeft; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] = + DcDefs::_32x64::Dc; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 64>; + + // 64x16 + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] = + DcDefs::_64x16::DcTop; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] = + DcDefs::_64x16::DcLeft; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] = + DcDefs::_64x16::Dc; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<64, 16>; + + // 64x32 + 
dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] = + DcDefs::_64x32::DcTop; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] = + DcDefs::_64x32::DcLeft; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] = + DcDefs::_64x32::Dc; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<64, 32>; + + // 64x64 + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] = + DcDefs::_64x64::DcTop; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] = + DcDefs::_64x64::DcLeft; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] = + DcDefs::_64x64::Dc; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<64, 64>; +} + +} // namespace +} // namespace low_bitdepth + +//------------------------------------------------------------------------------ +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +// Add the elements in the given vectors together but do not sum the entire +// vector. +inline uint16x8_t Add(const uint16x8_t val_0, const uint16x8_t val_1, + const uint16x8_t val_2, const uint16x8_t val_3) { + const uint16x8_t sum_0 = vaddq_u16(val_0, val_1); + const uint16x8_t sum_1 = vaddq_u16(val_2, val_3); + return vaddq_u16(sum_0, sum_1); +} + +// Load and combine 16 uint16_t values. +inline uint16x8_t LoadAndAdd16(const uint16_t* buf) { + const uint16x8_t val_0 = vld1q_u16(buf); + const uint16x8_t val_1 = vld1q_u16(buf + 8); + return vaddq_u16(val_0, val_1); +} + +// Load and combine 32 uint16_t values. +inline uint16x8_t LoadAndAdd32(const uint16_t* buf) { + const uint16x8_t val_0 = vld1q_u16(buf); + const uint16x8_t val_1 = vld1q_u16(buf + 8); + const uint16x8_t val_2 = vld1q_u16(buf + 16); + const uint16x8_t val_3 = vld1q_u16(buf + 24); + return Add(val_0, val_1, val_2, val_3); +} + +// Load and combine 64 uint16_t values. +inline uint16x8_t LoadAndAdd64(const uint16_t* buf) { + const uint16x8_t val_0 = vld1q_u16(buf); + const uint16x8_t val_1 = vld1q_u16(buf + 8); + const uint16x8_t val_2 = vld1q_u16(buf + 16); + const uint16x8_t val_3 = vld1q_u16(buf + 24); + const uint16x8_t val_4 = vld1q_u16(buf + 32); + const uint16x8_t val_5 = vld1q_u16(buf + 40); + const uint16x8_t val_6 = vld1q_u16(buf + 48); + const uint16x8_t val_7 = vld1q_u16(buf + 56); + const uint16x8_t sum_0 = Add(val_0, val_1, val_2, val_3); + const uint16x8_t sum_1 = Add(val_4, val_5, val_6, val_7); + return vaddq_u16(sum_0, sum_1); +} + +// |ref_[01]| each point to 1 << |ref[01]_size_log2| packed uint16_t values. +// If |use_ref_1| is false then only sum |ref_0|. 
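+// Accumulating in uint16_t lanes is safe at 10 bpp: the largest case sums
+// 64 + 64 samples into 8 lanes, i.e. at most 16 values of 1023 each, which
+// stays below 65535 until Sum() widens the total to 32 bits.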
+inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2, + const bool use_ref_1, const void* ref_1, + const int ref_1_size_log2) { + const auto* ref_0_u16 = static_cast<const uint16_t*>(ref_0); + const auto* ref_1_u16 = static_cast<const uint16_t*>(ref_1); + if (ref_0_size_log2 == 2) { + const uint16x4_t val_0 = vld1_u16(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 4x4 + const uint16x4_t val_1 = vld1_u16(ref_1_u16); + return Sum(vadd_u16(val_0, val_1)); + } else if (ref_1_size_log2 == 3) { // 4x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0); + return Sum(vaddq_u16(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 4x16 + const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0); + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 4x1 + return Sum(val_0); + } else if (ref_0_size_log2 == 3) { + const uint16x8_t val_0 = vld1q_u16(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 8x4 + const uint16x4_t val_1 = vld1_u16(ref_1_u16); + const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1); + return Sum(vaddq_u16(val_0, sum_1)); + } else if (ref_1_size_log2 == 3) { // 8x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + return Sum(vaddq_u16(val_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 8x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(val_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 8x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(val_0, sum_1)); + } + } + // 8x1 + return Sum(val_0); + } else if (ref_0_size_log2 == 4) { + const uint16x8_t sum_0 = LoadAndAdd16(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 16x4 + const uint16x4_t val_1 = vld1_u16(ref_1_u16); + const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 3) { // 16x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + return Sum(vaddq_u16(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 16x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 16x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 16x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 16x1 + return Sum(sum_0); + } else if (ref_0_size_log2 == 5) { + const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 3) { // 32x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + return Sum(vaddq_u16(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 32x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 32x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 32x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 32x1 + return Sum(sum_0); + } + + assert(ref_0_size_log2 == 6); + const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 4) { // 64x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 64x32 + const uint16x8_t sum_1 = 
LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 64x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 64x1 + return Sum(sum_0); +} + +template <int width, int height> +inline void DcStore_NEON(void* const dest, ptrdiff_t stride, + const uint32x2_t dc) { + auto* dest_u16 = static_cast<uint16_t*>(dest); + ptrdiff_t stride_u16 = stride >> 1; + const uint16x8_t dc_dup = vdupq_lane_u16(vreinterpret_u16_u32(dc), 0); + if (width == 4) { + int i = height - 1; + do { + vst1_u16(dest_u16, vget_low_u16(dc_dup)); + dest_u16 += stride_u16; + } while (--i != 0); + vst1_u16(dest_u16, vget_low_u16(dc_dup)); + } else if (width == 8) { + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + } else if (width == 16) { + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + } else if (width == 32) { + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + } else { + assert(width == 64); + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + vst1q_u16(dest_u16 + 32, dc_dup); + vst1q_u16(dest_u16 + 40, dc_dup); + vst1q_u16(dest_u16 + 48, dc_dup); + vst1q_u16(dest_u16 + 56, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + vst1q_u16(dest_u16 + 32, dc_dup); + vst1q_u16(dest_u16 + 40, dc_dup); + vst1q_u16(dest_u16 + 48, dc_dup); + vst1q_u16(dest_u16 + 56, dc_dup); + } +} + +struct DcDefs { + DcDefs() = delete; + + using _4x4 = DcPredFuncs_NEON<2, 2, DcSum_NEON, DcStore_NEON<4, 4>>; + using _4x8 = DcPredFuncs_NEON<2, 3, DcSum_NEON, DcStore_NEON<4, 8>>; + using _4x16 = DcPredFuncs_NEON<2, 4, DcSum_NEON, DcStore_NEON<4, 16>>; + using _8x4 = DcPredFuncs_NEON<3, 2, DcSum_NEON, DcStore_NEON<8, 4>>; + using _8x8 = DcPredFuncs_NEON<3, 3, DcSum_NEON, DcStore_NEON<8, 8>>; + using _8x16 = DcPredFuncs_NEON<3, 4, DcSum_NEON, DcStore_NEON<8, 16>>; + using _8x32 = DcPredFuncs_NEON<3, 5, DcSum_NEON, DcStore_NEON<8, 32>>; + using _16x4 = DcPredFuncs_NEON<4, 2, DcSum_NEON, DcStore_NEON<16, 4>>; + using _16x8 = DcPredFuncs_NEON<4, 3, DcSum_NEON, DcStore_NEON<16, 8>>; + using _16x16 = DcPredFuncs_NEON<4, 4, DcSum_NEON, DcStore_NEON<16, 16>>; + using _16x32 = DcPredFuncs_NEON<4, 5, DcSum_NEON, DcStore_NEON<16, 32>>; + using _16x64 = DcPredFuncs_NEON<4, 6, DcSum_NEON, DcStore_NEON<16, 64>>; + using _32x8 = DcPredFuncs_NEON<5, 3, DcSum_NEON, DcStore_NEON<32, 8>>; + using _32x16 = DcPredFuncs_NEON<5, 4, DcSum_NEON, DcStore_NEON<32, 16>>; + using _32x32 = DcPredFuncs_NEON<5, 5, DcSum_NEON, DcStore_NEON<32, 32>>; + using _32x64 = DcPredFuncs_NEON<5, 6, DcSum_NEON, DcStore_NEON<32, 64>>; + using _64x16 = DcPredFuncs_NEON<6, 4, DcSum_NEON, DcStore_NEON<64, 16>>; + using _64x32 = DcPredFuncs_NEON<6, 5, DcSum_NEON, DcStore_NEON<64, 32>>; + using 
_64x64 = DcPredFuncs_NEON<6, 6, DcSum_NEON, DcStore_NEON<64, 64>>; +}; + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + DcDefs::_4x4::DcTop; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + DcDefs::_4x4::DcLeft; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = + DcDefs::_4x4::Dc; + + // 4x8 + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] = + DcDefs::_4x8::DcTop; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] = + DcDefs::_4x8::DcLeft; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] = + DcDefs::_4x8::Dc; + + // 4x16 + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] = + DcDefs::_4x16::DcTop; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] = + DcDefs::_4x16::DcLeft; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] = + DcDefs::_4x16::Dc; + + // 8x4 + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] = + DcDefs::_8x4::DcTop; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] = + DcDefs::_8x4::DcLeft; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] = + DcDefs::_8x4::Dc; + + // 8x8 + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] = + DcDefs::_8x8::DcTop; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] = + DcDefs::_8x8::DcLeft; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] = + DcDefs::_8x8::Dc; + + // 8x16 + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] = + DcDefs::_8x16::DcTop; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] = + DcDefs::_8x16::DcLeft; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] = + DcDefs::_8x16::Dc; + + // 8x32 + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] = + DcDefs::_8x32::DcTop; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] = + DcDefs::_8x32::DcLeft; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] = + DcDefs::_8x32::Dc; + + // 16x4 + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] = + DcDefs::_16x4::DcTop; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] = + DcDefs::_16x4::DcLeft; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] = + DcDefs::_16x4::Dc; + + // 16x8 + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] = + DcDefs::_16x8::DcTop; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] = + DcDefs::_16x8::DcLeft; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] = + DcDefs::_16x8::Dc; + + // 16x16 + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] = + DcDefs::_16x16::DcTop; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] = + DcDefs::_16x16::DcLeft; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] = + DcDefs::_16x16::Dc; + + // 16x32 + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] = + DcDefs::_16x32::DcTop; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] = + DcDefs::_16x32::DcLeft; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] = + DcDefs::_16x32::Dc; + + // 16x64 + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] = + DcDefs::_16x64::DcTop; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] = + DcDefs::_16x64::DcLeft; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] = 
+ DcDefs::_16x64::Dc; + + // 32x8 + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] = + DcDefs::_32x8::DcTop; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] = + DcDefs::_32x8::DcLeft; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] = + DcDefs::_32x8::Dc; + + // 32x16 + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] = + DcDefs::_32x16::DcTop; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] = + DcDefs::_32x16::DcLeft; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] = + DcDefs::_32x16::Dc; + + // 32x32 + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] = + DcDefs::_32x32::DcTop; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] = + DcDefs::_32x32::DcLeft; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] = + DcDefs::_32x32::Dc; + + // 32x64 + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] = + DcDefs::_32x64::DcTop; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] = + DcDefs::_32x64::DcLeft; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] = + DcDefs::_32x64::Dc; + + // 64x16 + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] = + DcDefs::_64x16::DcTop; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] = + DcDefs::_64x16::DcLeft; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] = + DcDefs::_64x16::Dc; + + // 64x32 + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] = + DcDefs::_64x32::DcTop; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] = + DcDefs::_64x32::DcLeft; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] = + DcDefs::_64x32::Dc; + + // 64x64 + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] = + DcDefs::_64x64::DcTop; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] = + DcDefs::_64x64::DcLeft; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] = + DcDefs::_64x64::Dc; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void IntraPredInit_NEON() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_neon.h b/src/dsp/arm/intrapred_neon.h new file mode 100644 index 0000000..16f858c --- /dev/null +++ b/src/dsp/arm/intrapred_neon.h @@ -0,0 +1,418 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*, +// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and +// Dsp::filter_intra_predictor, see the defines below for specifics. These +// functions are not thread-safe. +void IntraPredCflInit_NEON(); +void IntraPredDirectionalInit_NEON(); +void IntraPredFilterIntraInit_NEON(); +void IntraPredInit_NEON(); +void IntraPredSmoothInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +// 8 bit +#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_NEON + +// 4x4 +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_NEON + +// 4x8 +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 4x16 +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x4 +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define 
LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x8 +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x16 +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x32 +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x4 +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define 
LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x8 +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x16 +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x32 +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x64 +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define 
LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 32x8 +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 32x16 +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 32x32 +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_NEON + +// 32x64 +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 64x16 +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define 
LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 64x32 +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 64x64 +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 10 bit +// 4x4 +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON + +// 4x8 +#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 4x16 +#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x4 +#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x8 +#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x16 +#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x32 +#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x4 +#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x8 +#define 
LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x16 +#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x32 +#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x64 +#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x8 +#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x16 +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x32 +#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x64 +#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON + +// 64x16 +#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 64x32 +#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 64x64 +#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_ diff --git a/src/dsp/arm/intrapred_smooth_neon.cc b/src/dsp/arm/intrapred_smooth_neon.cc new file mode 100644 index 0000000..abc93e8 --- /dev/null +++ b/src/dsp/arm/intrapred_smooth_neon.cc @@ -0,0 +1,616 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" + +namespace libgav1 { +namespace dsp { + +namespace low_bitdepth { +namespace { + +// Note these constants are duplicated from intrapred.cc to allow the compiler +// to have visibility of the values. This helps reduce loads and in the +// creation of the inverse weights. +constexpr uint8_t kSmoothWeights[] = { + // block dimension = 4 + 255, 149, 85, 64, + // block dimension = 8 + 255, 197, 146, 105, 73, 50, 37, 32, + // block dimension = 16 + 255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16, + // block dimension = 32 + 255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74, + 66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8, + // block dimension = 64 + 255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156, + 150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73, + 69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16, + 15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4}; + +// TODO(b/150459137): Keeping the intermediate values in uint16_t would allow +// processing more values at once. At the high end, it could do 4x4 or 8x2 at a +// time. +inline uint16x4_t CalculatePred(const uint16x4_t weighted_top, + const uint16x4_t weighted_left, + const uint16x4_t weighted_bl, + const uint16x4_t weighted_tr) { + const uint32x4_t pred_0 = vaddl_u16(weighted_top, weighted_left); + const uint32x4_t pred_1 = vaddl_u16(weighted_bl, weighted_tr); + const uint32x4_t pred_2 = vaddq_u32(pred_0, pred_1); + return vrshrn_n_u32(pred_2, kSmoothWeightScale + 1); +} + +template <int width, int height> +inline void Smooth4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x8_t top_v; + if (width == 4) { + top_v = Load4(top); + } else { // width == 8 + top_v = vld1_u8(top); + } + const uint8x8_t top_right_v = vdup_n_u8(top_right); + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + // Over-reads for 4xN but still within the array. 
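+  // For width == 4 only the low half of |weights_x_v| is used; the store at
+  // the bottom of the loop writes just 4 bytes.
+  // Each output pixel blends top[x], left[y], top_right and bottom_left:
+  //   pred(x, y) = (w_y[y] * top[x] + (256 - w_y[y]) * bottom_left +
+  //                 w_x[x] * left[y] + (256 - w_x[x]) * top_right + 256) >> 9
+  // CalculatePred() evaluates this with a single rounding narrowing shift by
+  // kSmoothWeightScale + 1.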
+ const uint8x8_t weights_x_v = vld1_u8(kSmoothWeights + width - 4); + // 256 - weights = vneg_s8(weights) + const uint8x8_t scaled_weights_x = + vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_x_v))); + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + + const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v); + const uint16x8_t weighted_left = vmull_u8(weights_x_v, left_v); + const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v); + const uint16x4_t dest_0 = + CalculatePred(vget_low_u16(weighted_top), vget_low_u16(weighted_left), + vget_low_u16(weighted_tr), vget_low_u16(weighted_bl)); + + if (width == 4) { + StoreLo4(dst, vmovn_u16(vcombine_u16(dest_0, dest_0))); + } else { // width == 8 + const uint16x4_t dest_1 = CalculatePred( + vget_high_u16(weighted_top), vget_high_u16(weighted_left), + vget_high_u16(weighted_tr), vget_high_u16(weighted_bl)); + vst1_u8(dst, vmovn_u16(vcombine_u16(dest_0, dest_1))); + } + dst += stride; + } +} + +inline uint8x16_t CalculateWeightsAndPred( + const uint8x16_t top, const uint8x8_t left, const uint8x8_t top_right, + const uint8x8_t weights_y, const uint8x16_t weights_x, + const uint8x16_t scaled_weights_x, const uint16x8_t weighted_bl) { + const uint16x8_t weighted_top_low = vmull_u8(weights_y, vget_low_u8(top)); + const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left); + const uint16x8_t weighted_tr_low = + vmull_u8(vget_low_u8(scaled_weights_x), top_right); + const uint16x4_t dest_0 = CalculatePred( + vget_low_u16(weighted_top_low), vget_low_u16(weighted_left_low), + vget_low_u16(weighted_tr_low), vget_low_u16(weighted_bl)); + const uint16x4_t dest_1 = CalculatePred( + vget_high_u16(weighted_top_low), vget_high_u16(weighted_left_low), + vget_high_u16(weighted_tr_low), vget_high_u16(weighted_bl)); + const uint8x8_t dest_0_u8 = vmovn_u16(vcombine_u16(dest_0, dest_1)); + + const uint16x8_t weighted_top_high = vmull_u8(weights_y, vget_high_u8(top)); + const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left); + const uint16x8_t weighted_tr_high = + vmull_u8(vget_high_u8(scaled_weights_x), top_right); + const uint16x4_t dest_2 = CalculatePred( + vget_low_u16(weighted_top_high), vget_low_u16(weighted_left_high), + vget_low_u16(weighted_tr_high), vget_low_u16(weighted_bl)); + const uint16x4_t dest_3 = CalculatePred( + vget_high_u16(weighted_top_high), vget_high_u16(weighted_left_high), + vget_high_u16(weighted_tr_high), vget_high_u16(weighted_bl)); + const uint8x8_t dest_1_u8 = vmovn_u16(vcombine_u16(dest_2, dest_3)); + + return vcombine_u8(dest_0_u8, dest_1_u8); +} + +template <int width, int height> +inline void Smooth16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x16_t top_v[4]; + top_v[0] = vld1q_u8(top); + if (width > 16) { + top_v[1] = vld1q_u8(top + 16); + if (width == 64) { + top_v[2] = vld1q_u8(top + 32); + top_v[3] = vld1q_u8(top + 48); + } 
+ } + + const uint8x8_t top_right_v = vdup_n_u8(top_right); + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + + // TODO(johannkoenig): Consider re-reading top_v and weights_x_v in the loop. + // This currently has a performance slope similar to Paeth so it does not + // appear to be register bound for arm64. + uint8x16_t weights_x_v[4]; + weights_x_v[0] = vld1q_u8(kSmoothWeights + width - 4); + if (width > 16) { + weights_x_v[1] = vld1q_u8(kSmoothWeights + width + 16 - 4); + if (width == 64) { + weights_x_v[2] = vld1q_u8(kSmoothWeights + width + 32 - 4); + weights_x_v[3] = vld1q_u8(kSmoothWeights + width + 48 - 4); + } + } + + uint8x16_t scaled_weights_x[4]; + scaled_weights_x[0] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[0]))); + if (width > 16) { + scaled_weights_x[1] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[1]))); + if (width == 64) { + scaled_weights_x[2] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[2]))); + scaled_weights_x[3] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[3]))); + } + } + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + + vst1q_u8(dst, CalculateWeightsAndPred(top_v[0], left_v, top_right_v, + weights_y_v, weights_x_v[0], + scaled_weights_x[0], weighted_bl)); + + if (width > 16) { + vst1q_u8(dst + 16, CalculateWeightsAndPred( + top_v[1], left_v, top_right_v, weights_y_v, + weights_x_v[1], scaled_weights_x[1], weighted_bl)); + if (width == 64) { + vst1q_u8(dst + 32, + CalculateWeightsAndPred(top_v[2], left_v, top_right_v, + weights_y_v, weights_x_v[2], + scaled_weights_x[2], weighted_bl)); + vst1q_u8(dst + 48, + CalculateWeightsAndPred(top_v[3], left_v, top_right_v, + weights_y_v, weights_x_v[3], + scaled_weights_x[3], weighted_bl)); + } + } + + dst += stride; + } +} + +template <int width, int height> +inline void SmoothVertical4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x8_t top_v; + if (width == 4) { + top_v = Load4(top); + } else { // width == 8 + top_v = vld1_u8(top); + } + + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + + for (int y = 0; y < height; ++y) { + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + + const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + const uint16x8_t pred = vaddq_u16(weighted_top, weighted_bl); + const uint8x8_t pred_scaled = vrshrn_n_u16(pred, kSmoothWeightScale); + + if (width == 4) { + StoreLo4(dst, pred_scaled); + } else { // width == 8 + vst1_u8(dst, pred_scaled); + } + dst += stride; + } +} + +inline uint8x16_t CalculateVerticalWeightsAndPred( + const uint8x16_t top, const uint8x8_t weights_y, + const uint16x8_t weighted_bl) { + const uint16x8_t weighted_top_low = vmull_u8(weights_y, vget_low_u8(top)); + const uint16x8_t weighted_top_high = vmull_u8(weights_y, 
vget_high_u8(top)); + const uint16x8_t pred_low = vaddq_u16(weighted_top_low, weighted_bl); + const uint16x8_t pred_high = vaddq_u16(weighted_top_high, weighted_bl); + const uint8x8_t pred_scaled_low = vrshrn_n_u16(pred_low, kSmoothWeightScale); + const uint8x8_t pred_scaled_high = + vrshrn_n_u16(pred_high, kSmoothWeightScale); + return vcombine_u8(pred_scaled_low, pred_scaled_high); +} + +template <int width, int height> +inline void SmoothVertical16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x16_t top_v[4]; + top_v[0] = vld1q_u8(top); + if (width > 16) { + top_v[1] = vld1q_u8(top + 16); + if (width == 64) { + top_v[2] = vld1q_u8(top + 32); + top_v[3] = vld1q_u8(top + 48); + } + } + + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + + for (int y = 0; y < height; ++y) { + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + + const uint8x16_t pred_0 = + CalculateVerticalWeightsAndPred(top_v[0], weights_y_v, weighted_bl); + vst1q_u8(dst, pred_0); + + if (width > 16) { + const uint8x16_t pred_1 = + CalculateVerticalWeightsAndPred(top_v[1], weights_y_v, weighted_bl); + vst1q_u8(dst + 16, pred_1); + + if (width == 64) { + const uint8x16_t pred_2 = + CalculateVerticalWeightsAndPred(top_v[2], weights_y_v, weighted_bl); + vst1q_u8(dst + 32, pred_2); + + const uint8x16_t pred_3 = + CalculateVerticalWeightsAndPred(top_v[3], weights_y_v, weighted_bl); + vst1q_u8(dst + 48, pred_3); + } + } + + dst += stride; + } +} + +template <int width, int height> +inline void SmoothHorizontal4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + uint8_t* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t top_right_v = vdup_n_u8(top_right); + // Over-reads for 4xN but still within the array. 
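+  // For width == 4 only the low half of |weights_x| is used.
+  // Only the left column and the top-right sample contribute here:
+  //   pred(x, y) = (w_x[x] * left[y] + (256 - w_x[x]) * top_right + 128) >> 8
+  // which the vmull_u8/vaddq_u16/vrshrn_n_u16 sequence in the loop computes.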
+ const uint8x8_t weights_x = vld1_u8(kSmoothWeights + width - 4); + // 256 - weights = vneg_s8(weights) + const uint8x8_t scaled_weights_x = + vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_x))); + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + + const uint16x8_t weighted_left = vmull_u8(weights_x, left_v); + const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v); + const uint16x8_t pred = vaddq_u16(weighted_left, weighted_tr); + const uint8x8_t pred_scaled = vrshrn_n_u16(pred, kSmoothWeightScale); + + if (width == 4) { + StoreLo4(dst, pred_scaled); + } else { // width == 8 + vst1_u8(dst, pred_scaled); + } + dst += stride; + } +} + +inline uint8x16_t CalculateHorizontalWeightsAndPred( + const uint8x8_t left, const uint8x8_t top_right, const uint8x16_t weights_x, + const uint8x16_t scaled_weights_x) { + const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left); + const uint16x8_t weighted_tr_low = + vmull_u8(vget_low_u8(scaled_weights_x), top_right); + const uint16x8_t pred_low = vaddq_u16(weighted_left_low, weighted_tr_low); + const uint8x8_t pred_scaled_low = vrshrn_n_u16(pred_low, kSmoothWeightScale); + + const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left); + const uint16x8_t weighted_tr_high = + vmull_u8(vget_high_u8(scaled_weights_x), top_right); + const uint16x8_t pred_high = vaddq_u16(weighted_left_high, weighted_tr_high); + const uint8x8_t pred_scaled_high = + vrshrn_n_u16(pred_high, kSmoothWeightScale); + + return vcombine_u8(pred_scaled_low, pred_scaled_high); +} + +template <int width, int height> +inline void SmoothHorizontal16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + uint8_t* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t top_right_v = vdup_n_u8(top_right); + + uint8x16_t weights_x[4]; + weights_x[0] = vld1q_u8(kSmoothWeights + width - 4); + if (width > 16) { + weights_x[1] = vld1q_u8(kSmoothWeights + width + 16 - 4); + if (width == 64) { + weights_x[2] = vld1q_u8(kSmoothWeights + width + 32 - 4); + weights_x[3] = vld1q_u8(kSmoothWeights + width + 48 - 4); + } + } + + uint8x16_t scaled_weights_x[4]; + scaled_weights_x[0] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[0]))); + if (width > 16) { + scaled_weights_x[1] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[1]))); + if (width == 64) { + scaled_weights_x[2] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[2]))); + scaled_weights_x[3] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[3]))); + } + } + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + + const uint8x16_t pred_0 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[0], scaled_weights_x[0]); + vst1q_u8(dst, pred_0); + + if (width > 16) { + const uint8x16_t pred_1 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[1], scaled_weights_x[1]); + vst1q_u8(dst + 16, pred_1); + + if (width == 64) { + const uint8x16_t pred_2 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[2], scaled_weights_x[2]); + vst1q_u8(dst + 32, pred_2); + + const uint8x16_t pred_3 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[3], scaled_weights_x[3]); + 
vst1q_u8(dst + 48, pred_3); + } + } + dst += stride; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + // 4x4 + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<4, 4>; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<4, 4>; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<4, 4>; + + // 4x8 + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<4, 8>; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<4, 8>; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<4, 8>; + + // 4x16 + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<4, 16>; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<4, 16>; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<4, 16>; + + // 8x4 + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 4>; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 4>; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 4>; + + // 8x8 + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 8>; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 8>; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 8>; + + // 8x16 + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 16>; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 16>; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 16>; + + // 8x32 + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 32>; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 32>; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 32>; + + // 16x4 + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 4>; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 4>; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 4>; + + // 16x8 + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 8>; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 8>; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 8>; + + // 16x16 + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 16>; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 16>; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 
16>; + + // 16x32 + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 32>; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 32>; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 32>; + + // 16x64 + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 64>; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 64>; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 64>; + + // 32x8 + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 8>; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 8>; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 8>; + + // 32x16 + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 16>; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 16>; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 16>; + + // 32x32 + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 32>; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 32>; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 32>; + + // 32x64 + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 64>; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 64>; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 64>; + + // 64x16 + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<64, 16>; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<64, 16>; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<64, 16>; + + // 64x32 + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<64, 32>; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<64, 32>; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<64, 32>; + + // 64x64 + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<64, 64>; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<64, 64>; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<64, 64>; +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredSmoothInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredSmoothInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git 
a/src/dsp/arm/inverse_transform_neon.cc b/src/dsp/arm/inverse_transform_neon.cc new file mode 100644 index 0000000..072991a --- /dev/null +++ b/src/dsp/arm/inverse_transform_neon.cc @@ -0,0 +1,3128 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/inverse_transform.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/array_2d.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Include the constants and utility functions inside the anonymous namespace. +#include "src/dsp/inverse_transform.inc" + +//------------------------------------------------------------------------------ + +// TODO(slavarnway): Move transpose functions to transpose_neon.h or +// common_neon.h. + +LIBGAV1_ALWAYS_INLINE void Transpose4x4(const int16x8_t in[4], + int16x8_t out[4]) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // to: + // b0.val[0]: 00 10 02 12 + // b0.val[1]: 01 11 03 13 + // b1.val[0]: 20 30 22 32 + // b1.val[1]: 21 31 23 33 + const int16x4_t a0 = vget_low_s16(in[0]); + const int16x4_t a1 = vget_low_s16(in[1]); + const int16x4_t a2 = vget_low_s16(in[2]); + const int16x4_t a3 = vget_low_s16(in[3]); + + const int16x4x2_t b0 = vtrn_s16(a0, a1); + const int16x4x2_t b1 = vtrn_s16(a2, a3); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + const int32x2x2_t c0 = vtrn_s32(vreinterpret_s32_s16(b0.val[0]), + vreinterpret_s32_s16(b1.val[0])); + const int32x2x2_t c1 = vtrn_s32(vreinterpret_s32_s16(b0.val[1]), + vreinterpret_s32_s16(b1.val[1])); + + const int16x4_t d0 = vreinterpret_s16_s32(c0.val[0]); + const int16x4_t d1 = vreinterpret_s16_s32(c1.val[0]); + const int16x4_t d2 = vreinterpret_s16_s32(c0.val[1]); + const int16x4_t d3 = vreinterpret_s16_s32(c1.val[1]); + + out[0] = vcombine_s16(d0, d0); + out[1] = vcombine_s16(d1, d1); + out[2] = vcombine_s16(d2, d2); + out[3] = vcombine_s16(d3, d3); +} + +// Note this is only used in the final stage of Dct32/64 and Adst16 as the in +// place version causes additional stack usage with clang. +LIBGAV1_ALWAYS_INLINE void Transpose8x8(const int16x8_t in[8], + int16x8_t out[8]) { + // Swap 16 bit elements. 
Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // a4: 40 41 42 43 44 45 46 47 + // a5: 50 51 52 53 54 55 56 57 + // a6: 60 61 62 63 64 65 66 67 + // a7: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + // b2.val[0]: 40 50 42 52 44 54 46 56 + // b2.val[1]: 41 51 43 53 45 55 47 57 + // b3.val[0]: 60 70 62 72 64 74 66 76 + // b3.val[1]: 61 71 63 73 65 75 67 77 + + const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]); + const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]); + const int16x8x2_t b2 = vtrnq_s16(in[4], in[5]); + const int16x8x2_t b3 = vtrnq_s16(in[6], in[7]); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + // c2.val[0]: 40 50 60 70 44 54 64 74 + // c2.val[1]: 42 52 62 72 46 56 66 76 + // c3.val[0]: 41 51 61 71 45 55 65 75 + // c3.val[1]: 43 53 63 73 47 57 67 77 + + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]), + vreinterpretq_s32_s16(b3.val[0])); + const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]), + vreinterpretq_s32_s16(b3.val[1])); + + // Swap 64 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 04 14 24 34 44 54 64 74 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 05 15 25 35 45 55 65 75 + // d2.val[0]: 02 12 22 32 42 52 62 72 + // d2.val[1]: 06 16 26 36 46 56 66 76 + // d3.val[0]: 03 13 23 33 43 53 63 73 + // d3.val[1]: 07 17 27 37 47 57 67 77 + const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]); + const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]); + const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]); + const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]); + + out[0] = d0.val[0]; + out[1] = d1.val[0]; + out[2] = d2.val[0]; + out[3] = d3.val[0]; + out[4] = d0.val[1]; + out[5] = d1.val[1]; + out[6] = d2.val[1]; + out[7] = d3.val[1]; +} + +LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const uint16x8_t in[8], + uint16x8_t out[4]) { + // Swap 16 bit elements. 
Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // a4: 40 41 42 43 + // a5: 50 51 52 53 + // a6: 60 61 62 63 + // a7: 70 71 72 73 + // to: + // b0.val[0]: 00 10 02 12 + // b0.val[1]: 01 11 03 13 + // b1.val[0]: 20 30 22 32 + // b1.val[1]: 21 31 23 33 + // b2.val[0]: 40 50 42 52 + // b2.val[1]: 41 51 43 53 + // b3.val[0]: 60 70 62 72 + // b3.val[1]: 61 71 63 73 + + uint16x4x2_t b0 = vtrn_u16(vget_low_u16(in[0]), vget_low_u16(in[1])); + uint16x4x2_t b1 = vtrn_u16(vget_low_u16(in[2]), vget_low_u16(in[3])); + uint16x4x2_t b2 = vtrn_u16(vget_low_u16(in[4]), vget_low_u16(in[5])); + uint16x4x2_t b3 = vtrn_u16(vget_low_u16(in[6]), vget_low_u16(in[7])); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 + // c0.val[1]: 02 12 22 32 + // c1.val[0]: 01 11 21 31 + // c1.val[1]: 03 13 23 33 + // c2.val[0]: 40 50 60 70 + // c2.val[1]: 42 52 62 72 + // c3.val[0]: 41 51 61 71 + // c3.val[1]: 43 53 63 73 + + uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), + vreinterpret_u32_u16(b1.val[0])); + uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), + vreinterpret_u32_u16(b1.val[1])); + uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b2.val[0]), + vreinterpret_u32_u16(b3.val[0])); + uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b2.val[1]), + vreinterpret_u32_u16(b3.val[1])); + + // Swap 64 bit elements resulting in: + // o0: 00 10 20 30 40 50 60 70 + // o1: 01 11 21 31 41 51 61 71 + // o2: 02 12 22 32 42 52 62 72 + // o3: 03 13 23 33 43 53 63 73 + + out[0] = vcombine_u16(vreinterpret_u16_u32(c0.val[0]), + vreinterpret_u16_u32(c2.val[0])); + out[1] = vcombine_u16(vreinterpret_u16_u32(c1.val[0]), + vreinterpret_u16_u32(c3.val[0])); + out[2] = vcombine_u16(vreinterpret_u16_u32(c0.val[1]), + vreinterpret_u16_u32(c2.val[1])); + out[3] = vcombine_u16(vreinterpret_u16_u32(c1.val[1]), + vreinterpret_u16_u32(c3.val[1])); +} + +LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const int16x8_t in[8], + int16x8_t out[4]) { + Transpose4x8To8x4(reinterpret_cast<const uint16x8_t*>(in), + reinterpret_cast<uint16x8_t*>(out)); +} + +LIBGAV1_ALWAYS_INLINE void Transpose8x4To4x8(const int16x8_t in[4], + int16x8_t out[8]) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]); + const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + + // The upper 8 bytes are don't cares. 
+ // out[0]: 00 10 20 30 04 14 24 34 + // out[1]: 01 11 21 31 05 15 25 35 + // out[2]: 02 12 22 32 06 16 26 36 + // out[3]: 03 13 23 33 07 17 27 37 + // out[4]: 04 14 24 34 04 14 24 34 + // out[5]: 05 15 25 35 05 15 25 35 + // out[6]: 06 16 26 36 06 16 26 36 + // out[7]: 07 17 27 37 07 17 27 37 + out[0] = vreinterpretq_s16_s32(c0.val[0]); + out[1] = vreinterpretq_s16_s32(c1.val[0]); + out[2] = vreinterpretq_s16_s32(c0.val[1]); + out[3] = vreinterpretq_s16_s32(c1.val[1]); + out[4] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c0.val[0]), vget_high_s32(c0.val[0]))); + out[5] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c1.val[0]), vget_high_s32(c1.val[0]))); + out[6] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c0.val[1]), vget_high_s32(c0.val[1]))); + out[7] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c1.val[1]), vget_high_s32(c1.val[1]))); +} + +//------------------------------------------------------------------------------ +template <int store_width, int store_count> +LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx, + const int16x8_t* const s) { + assert(store_count % 4 == 0); + assert(store_width == 8 || store_width == 16); + // NOTE: It is expected that the compiler will unroll these loops. + if (store_width == 16) { + for (int i = 0; i < store_count; i += 4) { + vst1q_s16(&dst[i * stride + idx], (s[i])); + vst1q_s16(&dst[(i + 1) * stride + idx], (s[i + 1])); + vst1q_s16(&dst[(i + 2) * stride + idx], (s[i + 2])); + vst1q_s16(&dst[(i + 3) * stride + idx], (s[i + 3])); + } + } else { + // store_width == 8 + for (int i = 0; i < store_count; i += 4) { + vst1_s16(&dst[i * stride + idx], vget_low_s16(s[i])); + vst1_s16(&dst[(i + 1) * stride + idx], vget_low_s16(s[i + 1])); + vst1_s16(&dst[(i + 2) * stride + idx], vget_low_s16(s[i + 2])); + vst1_s16(&dst[(i + 3) * stride + idx], vget_low_s16(s[i + 3])); + } + } +} + +template <int load_width, int load_count> +LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* src, int32_t stride, + int32_t idx, int16x8_t* x) { + assert(load_count % 4 == 0); + assert(load_width == 8 || load_width == 16); + // NOTE: It is expected that the compiler will unroll these loops. + if (load_width == 16) { + for (int i = 0; i < load_count; i += 4) { + x[i] = vld1q_s16(&src[i * stride + idx]); + x[i + 1] = vld1q_s16(&src[(i + 1) * stride + idx]); + x[i + 2] = vld1q_s16(&src[(i + 2) * stride + idx]); + x[i + 3] = vld1q_s16(&src[(i + 3) * stride + idx]); + } + } else { + // load_width == 8 + const int64x2_t zero = vdupq_n_s64(0); + for (int i = 0; i < load_count; i += 4) { + // The src buffer is aligned to 32 bytes. Each load will always be 8 + // byte aligned. + x[i] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[i * stride + idx]), zero, 0)); + x[i + 1] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[(i + 1) * stride + idx]), zero, + 0)); + x[i + 2] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[(i + 2) * stride + idx]), zero, + 0)); + x[i + 3] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[(i + 3) * stride + idx]), zero, + 0)); + } + } +} + +// Butterfly rotate 4 values. 
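+// ButterflyRotation_4/_8 implement the rotation
+//   x = a * cos(angle * pi / 128) - b * sin(angle * pi / 128)
+//   y = a * sin(angle * pi / 128) + b * cos(angle * pi / 128)
+// using the 12-bit fixed-point values from Cos128()/Sin128(). A per-lane
+// scalar sketch (pseudo-code; the two versions differ only in how many lanes
+// they process at once):
+//   const int32_t x32 = a * cos128 - b * sin128;
+//   const int32_t y32 = a * sin128 + b * cos128;
+//   x = saturate_to_int16((x32 + 2048) >> 12);  // vqrshrn_n_s32(x32, 12)
+//   y = saturate_to_int16((y32 + 2048) >> 12);
+// When |flip| is true the two outputs are swapped.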
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int16x8_t* a, int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128); + const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128); + const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128); + const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128); + const int16x4_t x1 = vqrshrn_n_s32(x0, 12); + const int16x4_t y1 = vqrshrn_n_s32(y0, 12); + const int16x8_t x = vcombine_s16(x1, x1); + const int16x8_t y = vcombine_s16(y1, y1); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +// Butterfly rotate 8 values. +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(int16x8_t* a, int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128); + const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128); + const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128); + const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128); + const int16x4_t x1 = vqrshrn_n_s32(x0, 12); + const int16x4_t y1 = vqrshrn_n_s32(y0, 12); + + const int32x4_t acc_x_hi = vmull_n_s16(vget_high_s16(*a), cos128); + const int32x4_t acc_y_hi = vmull_n_s16(vget_high_s16(*a), sin128); + const int32x4_t x0_hi = vmlsl_n_s16(acc_x_hi, vget_high_s16(*b), sin128); + const int32x4_t y0_hi = vmlal_n_s16(acc_y_hi, vget_high_s16(*b), cos128); + const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12); + const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12); + + const int16x8_t x = vcombine_s16(x1, x1_hi); + const int16x8_t y = vcombine_s16(y1, y1_hi); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int16x8_t* a, + int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + // For this function, the max value returned by Sin128() is 4091, which fits + // inside 12 bits. This leaves room for the sign bit and the 3 left shifted + // bits. + assert(sin128 <= 0xfff); + const int16x8_t x = vqrdmulhq_n_s16(*b, -sin128 << 3); + const int16x8_t y = vqrdmulhq_n_s16(*b, cos128 << 3); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int16x8_t* a, + int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const int16x8_t x = vqrdmulhq_n_s16(*a, cos128 << 3); + const int16x8_t y = vqrdmulhq_n_s16(*a, sin128 << 3); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void HadamardRotation(int16x8_t* a, int16x8_t* b, + bool flip) { + int16x8_t x, y; + if (flip) { + y = vqaddq_s16(*b, *a); + x = vqsubq_s16(*b, *a); + } else { + x = vqaddq_s16(*a, *b); + y = vqsubq_s16(*a, *b); + } + *a = x; + *b = y; +} + +using ButterflyRotationFunc = void (*)(int16x8_t* a, int16x8_t* b, int angle, + bool flip); + +//------------------------------------------------------------------------------ +// Discrete Cosine Transforms (DCT). 
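+// DctDcOnly handles blocks where only the DC coefficient is nonzero
+// (adjusted_tx_height <= 1). A 1-D inverse DCT of such a vector is flat:
+// every output equals dc * cos(32 * pi / 128), i.e. dc / sqrt(2), so the row
+// pass reduces to one vqrdmulh by Cos128(32) (plus the optional
+// |should_round| scaling and the row shift), replicated across the row.
+// DctDcOnlyColumn does the same for the column pass and then copies the
+// first row down the rest of the block.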
+ +template <int width> +LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x8_t v_src = vdupq_n_s16(dst[0]); + const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + const int16x8_t s0 = vbslq_s16(v_mask, v_src_round, v_src); + const int16_t cos128 = Cos128(32); + const int16x8_t xy = vqrdmulhq_n_s16(s0, cos128 << 3); + // vqrshlq_s16 will shift right if shift value is negative. + const int16x8_t xy_shifted = vqrshlq_s16(xy, vdupq_n_s16(-row_shift)); + + if (width == 4) { + vst1_s16(dst, vget_low_s16(xy_shifted)); + } else { + for (int i = 0; i < width; i += 8) { + vst1q_s16(dst, xy_shifted); + dst += 8; + } + } + return true; +} + +template <int height> +LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16_t cos128 = Cos128(32); + + // Calculate dc values for first row. + if (width == 4) { + const int16x4_t v_src = vld1_s16(dst); + const int16x4_t xy = vqrdmulh_n_s16(v_src, cos128 << 3); + vst1_s16(dst, xy); + } else { + int i = 0; + do { + const int16x8_t v_src = vld1q_s16(&dst[i]); + const int16x8_t xy = vqrdmulhq_n_s16(v_src, cos128 << 3); + vst1q_s16(&dst[i], xy); + i += 8; + } while (i < width); + } + + // Copy first row to the rest of the block. + for (int y = 1; y < height; ++y) { + memcpy(&dst[y * width], dst, width * sizeof(dst[0])); + } + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct4Stages(int16x8_t* s) { + // stage 12. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true); + ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false); + } else { + butterfly_rotation(&s[0], &s[1], 32, true); + butterfly_rotation(&s[2], &s[3], 48, false); + } + + // stage 17. + HadamardRotation(&s[0], &s[3], false); + HadamardRotation(&s[1], &s[2], false); +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, int32_t step, bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[4], x[4]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[8]; + LoadSrc<8, 8>(dst, step, 0, input); + Transpose4x8To8x4(input, x); + } else { + LoadSrc<16, 4>(dst, step, 0, x); + } + } else { + LoadSrc<8, 4>(dst, step, 0, x); + if (transpose) { + Transpose4x4(x, x); + } + } + + // stage 1. + // kBitReverseLookup 0, 2, 1, 3 + s[0] = x[0]; + s[1] = x[2]; + s[2] = x[1]; + s[3] = x[3]; + + Dct4Stages<butterfly_rotation>(s); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[8]; + Transpose8x4To4x8(s, output); + StoreDst<8, 8>(dst, step, 0, output); + } else { + StoreDst<16, 4>(dst, step, 0, s); + } + } else { + if (transpose) { + Transpose4x4(s, s); + } + StoreDst<8, 4>(dst, step, 0, s); + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct8Stages(int16x8_t* s) { + // stage 8. 
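+  // When |is_fast_butterfly| is set (the Dct64 path), one input of each pair
+  // below is known to be zero, so the rotation collapses to the two-multiply
+  // ButterflyRotation_FirstIsZero/_SecondIsZero forms.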
+ if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false); + ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false); + } else { + butterfly_rotation(&s[4], &s[7], 56, false); + butterfly_rotation(&s[5], &s[6], 24, false); + } + + // stage 13. + HadamardRotation(&s[4], &s[5], false); + HadamardRotation(&s[6], &s[7], true); + + // stage 18. + butterfly_rotation(&s[6], &s[5], 32, true); + + // stage 22. + HadamardRotation(&s[0], &s[7], false); + HadamardRotation(&s[1], &s[6], false); + HadamardRotation(&s[2], &s[5], false); + HadamardRotation(&s[3], &s[4], false); +} + +// Process dct8 rows or columns, depending on the transpose flag. +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, int32_t step, bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[8], x[8]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + } else { + LoadSrc<8, 8>(dst, step, 0, x); + } + } else if (transpose) { + LoadSrc<16, 8>(dst, step, 0, x); + dsp::Transpose8x8(x); + } else { + LoadSrc<16, 8>(dst, step, 0, x); + } + + // stage 1. + // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7, + s[0] = x[0]; + s[1] = x[4]; + s[2] = x[2]; + s[3] = x[6]; + s[4] = x[1]; + s[5] = x[5]; + s[6] = x[3]; + s[7] = x[7]; + + Dct4Stages<butterfly_rotation>(s); + Dct8Stages<butterfly_rotation>(s); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[4]; + Transpose4x8To8x4(s, output); + StoreDst<16, 4>(dst, step, 0, output); + } else { + StoreDst<8, 8>(dst, step, 0, s); + } + } else if (transpose) { + dsp::Transpose8x8(s); + StoreDst<16, 8>(dst, step, 0, s); + } else { + StoreDst<16, 8>(dst, step, 0, s); + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct16Stages(int16x8_t* s) { + // stage 5. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false); + ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false); + ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false); + ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false); + } else { + butterfly_rotation(&s[8], &s[15], 60, false); + butterfly_rotation(&s[9], &s[14], 28, false); + butterfly_rotation(&s[10], &s[13], 44, false); + butterfly_rotation(&s[11], &s[12], 12, false); + } + + // stage 9. + HadamardRotation(&s[8], &s[9], false); + HadamardRotation(&s[10], &s[11], true); + HadamardRotation(&s[12], &s[13], false); + HadamardRotation(&s[14], &s[15], true); + + // stage 14. + butterfly_rotation(&s[14], &s[9], 48, true); + butterfly_rotation(&s[13], &s[10], 112, true); + + // stage 19. + HadamardRotation(&s[8], &s[11], false); + HadamardRotation(&s[9], &s[10], false); + HadamardRotation(&s[12], &s[15], true); + HadamardRotation(&s[13], &s[14], true); + + // stage 23. + butterfly_rotation(&s[13], &s[10], 32, true); + butterfly_rotation(&s[12], &s[11], 32, true); + + // stage 26. + HadamardRotation(&s[0], &s[15], false); + HadamardRotation(&s[1], &s[14], false); + HadamardRotation(&s[2], &s[13], false); + HadamardRotation(&s[3], &s[12], false); + HadamardRotation(&s[4], &s[11], false); + HadamardRotation(&s[5], &s[10], false); + HadamardRotation(&s[6], &s[9], false); + HadamardRotation(&s[7], &s[8], false); +} + +// Process dct16 rows or columns, depending on the transpose flag. 
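+// Unlike Dct4/Dct8 this takes |is_row| rather than |transpose|: row passes
+// additionally shift the results right by |row_shift| (vqrshlq_s16 with a
+// negative shift) before they are stored.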
+template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, int32_t step, bool is_row, + int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[16], x[16]; + + if (stage_is_rectangular) { + if (is_row) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + LoadSrc<16, 4>(dst, step, 8, input); + Transpose8x4To4x8(input, &x[8]); + } else { + LoadSrc<8, 16>(dst, step, 0, x); + } + } else if (is_row) { + for (int idx = 0; idx < 16; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + LoadSrc<16, 16>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, + s[0] = x[0]; + s[1] = x[8]; + s[2] = x[4]; + s[3] = x[12]; + s[4] = x[2]; + s[5] = x[10]; + s[6] = x[6]; + s[7] = x[14]; + s[8] = x[1]; + s[9] = x[9]; + s[10] = x[5]; + s[11] = x[13]; + s[12] = x[3]; + s[13] = x[11]; + s[14] = x[7]; + s[15] = x[15]; + + Dct4Stages<butterfly_rotation>(s); + Dct8Stages<butterfly_rotation>(s); + Dct16Stages<butterfly_rotation>(s); + + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int i = 0; i < 16; ++i) { + s[i] = vqrshlq_s16(s[i], v_row_shift); + } + } + + if (stage_is_rectangular) { + if (is_row) { + int16x8_t output[4]; + Transpose4x8To8x4(s, output); + StoreDst<16, 4>(dst, step, 0, output); + Transpose4x8To8x4(&s[8], output); + StoreDst<16, 4>(dst, step, 8, output); + } else { + StoreDst<8, 16>(dst, step, 0, s); + } + } else if (is_row) { + for (int idx = 0; idx < 16; idx += 8) { + dsp::Transpose8x8(&s[idx]); + StoreDst<16, 8>(dst, step, idx, &s[idx]); + } + } else { + StoreDst<16, 16>(dst, step, 0, s); + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct32Stages(int16x8_t* s) { + // stage 3 + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false); + ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false); + ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false); + ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false); + ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false); + ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false); + ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false); + ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false); + } else { + butterfly_rotation(&s[16], &s[31], 62, false); + butterfly_rotation(&s[17], &s[30], 30, false); + butterfly_rotation(&s[18], &s[29], 46, false); + butterfly_rotation(&s[19], &s[28], 14, false); + butterfly_rotation(&s[20], &s[27], 54, false); + butterfly_rotation(&s[21], &s[26], 22, false); + butterfly_rotation(&s[22], &s[25], 38, false); + butterfly_rotation(&s[23], &s[24], 6, false); + } + // stage 6. + HadamardRotation(&s[16], &s[17], false); + HadamardRotation(&s[18], &s[19], true); + HadamardRotation(&s[20], &s[21], false); + HadamardRotation(&s[22], &s[23], true); + HadamardRotation(&s[24], &s[25], false); + HadamardRotation(&s[26], &s[27], true); + HadamardRotation(&s[28], &s[29], false); + HadamardRotation(&s[30], &s[31], true); + + // stage 10. + butterfly_rotation(&s[30], &s[17], 24 + 32, true); + butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true); + butterfly_rotation(&s[26], &s[21], 24, true); + butterfly_rotation(&s[25], &s[22], 24 + 64, true); + + // stage 15. 
+ HadamardRotation(&s[16], &s[19], false); + HadamardRotation(&s[17], &s[18], false); + HadamardRotation(&s[20], &s[23], true); + HadamardRotation(&s[21], &s[22], true); + HadamardRotation(&s[24], &s[27], false); + HadamardRotation(&s[25], &s[26], false); + HadamardRotation(&s[28], &s[31], true); + HadamardRotation(&s[29], &s[30], true); + + // stage 20. + butterfly_rotation(&s[29], &s[18], 48, true); + butterfly_rotation(&s[28], &s[19], 48, true); + butterfly_rotation(&s[27], &s[20], 48 + 64, true); + butterfly_rotation(&s[26], &s[21], 48 + 64, true); + + // stage 24. + HadamardRotation(&s[16], &s[23], false); + HadamardRotation(&s[17], &s[22], false); + HadamardRotation(&s[18], &s[21], false); + HadamardRotation(&s[19], &s[20], false); + HadamardRotation(&s[24], &s[31], true); + HadamardRotation(&s[25], &s[30], true); + HadamardRotation(&s[26], &s[29], true); + HadamardRotation(&s[27], &s[28], true); + + // stage 27. + butterfly_rotation(&s[27], &s[20], 32, true); + butterfly_rotation(&s[26], &s[21], 32, true); + butterfly_rotation(&s[25], &s[22], 32, true); + butterfly_rotation(&s[24], &s[23], 32, true); + + // stage 29. + HadamardRotation(&s[0], &s[31], false); + HadamardRotation(&s[1], &s[30], false); + HadamardRotation(&s[2], &s[29], false); + HadamardRotation(&s[3], &s[28], false); + HadamardRotation(&s[4], &s[27], false); + HadamardRotation(&s[5], &s[26], false); + HadamardRotation(&s[6], &s[25], false); + HadamardRotation(&s[7], &s[24], false); + HadamardRotation(&s[8], &s[23], false); + HadamardRotation(&s[9], &s[22], false); + HadamardRotation(&s[10], &s[21], false); + HadamardRotation(&s[11], &s[20], false); + HadamardRotation(&s[12], &s[19], false); + HadamardRotation(&s[13], &s[18], false); + HadamardRotation(&s[14], &s[17], false); + HadamardRotation(&s[15], &s[16], false); +} + +// Process dct32 rows or columns, depending on the transpose flag. +LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const int32_t step, + const bool is_row, int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[32], x[32]; + + if (is_row) { + for (int idx = 0; idx < 32; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + LoadSrc<16, 32>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup + // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, + s[0] = x[0]; + s[1] = x[16]; + s[2] = x[8]; + s[3] = x[24]; + s[4] = x[4]; + s[5] = x[20]; + s[6] = x[12]; + s[7] = x[28]; + s[8] = x[2]; + s[9] = x[18]; + s[10] = x[10]; + s[11] = x[26]; + s[12] = x[6]; + s[13] = x[22]; + s[14] = x[14]; + s[15] = x[30]; + + // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31, + s[16] = x[1]; + s[17] = x[17]; + s[18] = x[9]; + s[19] = x[25]; + s[20] = x[5]; + s[21] = x[21]; + s[22] = x[13]; + s[23] = x[29]; + s[24] = x[3]; + s[25] = x[19]; + s[26] = x[11]; + s[27] = x[27]; + s[28] = x[7]; + s[29] = x[23]; + s[30] = x[15]; + s[31] = x[31]; + + Dct4Stages<ButterflyRotation_8>(s); + Dct8Stages<ButterflyRotation_8>(s); + Dct16Stages<ButterflyRotation_8>(s); + Dct32Stages<ButterflyRotation_8>(s); + + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int idx = 0; idx < 32; idx += 8) { + int16x8_t output[8]; + Transpose8x8(&s[idx], output); + for (int i = 0; i < 8; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 32>(dst, step, 0, s); + } +} + +// Allow the compiler to call this function instead of force inlining. 
Tests +// show the performance is slightly faster. +void Dct64_NEON(void* dest, int32_t step, bool is_row, int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[64], x[32]; + + if (is_row) { + // The last 32 values of every row are always zero if the |tx_width| is + // 64. + for (int idx = 0; idx < 32; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + // The last 32 values of every column are always zero if the |tx_height| is + // 64. + LoadSrc<16, 32>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup + // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60, + s[0] = x[0]; + s[2] = x[16]; + s[4] = x[8]; + s[6] = x[24]; + s[8] = x[4]; + s[10] = x[20]; + s[12] = x[12]; + s[14] = x[28]; + + // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62, + s[16] = x[2]; + s[18] = x[18]; + s[20] = x[10]; + s[22] = x[26]; + s[24] = x[6]; + s[26] = x[22]; + s[28] = x[14]; + s[30] = x[30]; + + // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61, + s[32] = x[1]; + s[34] = x[17]; + s[36] = x[9]; + s[38] = x[25]; + s[40] = x[5]; + s[42] = x[21]; + s[44] = x[13]; + s[46] = x[29]; + + // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63 + s[48] = x[3]; + s[50] = x[19]; + s[52] = x[11]; + s[54] = x[27]; + s[56] = x[7]; + s[58] = x[23]; + s[60] = x[15]; + s[62] = x[31]; + + Dct4Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct8Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct16Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct32Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + + //-- start dct 64 stages + // stage 2. + ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false); + ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false); + ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false); + ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false); + ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false); + ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false); + ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false); + ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false); + ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false); + ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false); + ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false); + ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false); + ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false); + ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false); + ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false); + ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false); + + // stage 4. + HadamardRotation(&s[32], &s[33], false); + HadamardRotation(&s[34], &s[35], true); + HadamardRotation(&s[36], &s[37], false); + HadamardRotation(&s[38], &s[39], true); + HadamardRotation(&s[40], &s[41], false); + HadamardRotation(&s[42], &s[43], true); + HadamardRotation(&s[44], &s[45], false); + HadamardRotation(&s[46], &s[47], true); + HadamardRotation(&s[48], &s[49], false); + HadamardRotation(&s[50], &s[51], true); + HadamardRotation(&s[52], &s[53], false); + HadamardRotation(&s[54], &s[55], true); + HadamardRotation(&s[56], &s[57], false); + HadamardRotation(&s[58], &s[59], true); + HadamardRotation(&s[60], &s[61], false); + HadamardRotation(&s[62], &s[63], true); + + // stage 7. 
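+  // Angles are left as expressions (e.g. 60 - 32 + 64) to keep the
+  // derivation visible. In the cos128()/sin128() convention used here, 64
+  // units are a quarter turn, so a "+ 64" term turns cos into -sin and sin
+  // into cos.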
+ ButterflyRotation_8(&s[62], &s[33], 60 - 0, true); + ButterflyRotation_8(&s[61], &s[34], 60 - 0 + 64, true); + ButterflyRotation_8(&s[58], &s[37], 60 - 32, true); + ButterflyRotation_8(&s[57], &s[38], 60 - 32 + 64, true); + ButterflyRotation_8(&s[54], &s[41], 60 - 16, true); + ButterflyRotation_8(&s[53], &s[42], 60 - 16 + 64, true); + ButterflyRotation_8(&s[50], &s[45], 60 - 48, true); + ButterflyRotation_8(&s[49], &s[46], 60 - 48 + 64, true); + + // stage 11. + HadamardRotation(&s[32], &s[35], false); + HadamardRotation(&s[33], &s[34], false); + HadamardRotation(&s[36], &s[39], true); + HadamardRotation(&s[37], &s[38], true); + HadamardRotation(&s[40], &s[43], false); + HadamardRotation(&s[41], &s[42], false); + HadamardRotation(&s[44], &s[47], true); + HadamardRotation(&s[45], &s[46], true); + HadamardRotation(&s[48], &s[51], false); + HadamardRotation(&s[49], &s[50], false); + HadamardRotation(&s[52], &s[55], true); + HadamardRotation(&s[53], &s[54], true); + HadamardRotation(&s[56], &s[59], false); + HadamardRotation(&s[57], &s[58], false); + HadamardRotation(&s[60], &s[63], true); + HadamardRotation(&s[61], &s[62], true); + + // stage 16. + ButterflyRotation_8(&s[61], &s[34], 56, true); + ButterflyRotation_8(&s[60], &s[35], 56, true); + ButterflyRotation_8(&s[59], &s[36], 56 + 64, true); + ButterflyRotation_8(&s[58], &s[37], 56 + 64, true); + ButterflyRotation_8(&s[53], &s[42], 56 - 32, true); + ButterflyRotation_8(&s[52], &s[43], 56 - 32, true); + ButterflyRotation_8(&s[51], &s[44], 56 - 32 + 64, true); + ButterflyRotation_8(&s[50], &s[45], 56 - 32 + 64, true); + + // stage 21. + HadamardRotation(&s[32], &s[39], false); + HadamardRotation(&s[33], &s[38], false); + HadamardRotation(&s[34], &s[37], false); + HadamardRotation(&s[35], &s[36], false); + HadamardRotation(&s[40], &s[47], true); + HadamardRotation(&s[41], &s[46], true); + HadamardRotation(&s[42], &s[45], true); + HadamardRotation(&s[43], &s[44], true); + HadamardRotation(&s[48], &s[55], false); + HadamardRotation(&s[49], &s[54], false); + HadamardRotation(&s[50], &s[53], false); + HadamardRotation(&s[51], &s[52], false); + HadamardRotation(&s[56], &s[63], true); + HadamardRotation(&s[57], &s[62], true); + HadamardRotation(&s[58], &s[61], true); + HadamardRotation(&s[59], &s[60], true); + + // stage 25. + ButterflyRotation_8(&s[59], &s[36], 48, true); + ButterflyRotation_8(&s[58], &s[37], 48, true); + ButterflyRotation_8(&s[57], &s[38], 48, true); + ButterflyRotation_8(&s[56], &s[39], 48, true); + ButterflyRotation_8(&s[55], &s[40], 112, true); + ButterflyRotation_8(&s[54], &s[41], 112, true); + ButterflyRotation_8(&s[53], &s[42], 112, true); + ButterflyRotation_8(&s[52], &s[43], 112, true); + + // stage 28. + HadamardRotation(&s[32], &s[47], false); + HadamardRotation(&s[33], &s[46], false); + HadamardRotation(&s[34], &s[45], false); + HadamardRotation(&s[35], &s[44], false); + HadamardRotation(&s[36], &s[43], false); + HadamardRotation(&s[37], &s[42], false); + HadamardRotation(&s[38], &s[41], false); + HadamardRotation(&s[39], &s[40], false); + HadamardRotation(&s[48], &s[63], true); + HadamardRotation(&s[49], &s[62], true); + HadamardRotation(&s[50], &s[61], true); + HadamardRotation(&s[51], &s[60], true); + HadamardRotation(&s[52], &s[59], true); + HadamardRotation(&s[53], &s[58], true); + HadamardRotation(&s[54], &s[57], true); + HadamardRotation(&s[55], &s[56], true); + + // stage 30. 
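+  // Angle 32 is the pi/4 rotation (cos == sin); stage 31 below then finishes
+  // the dct64 with 32 saturating add/sub pairs.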
+ ButterflyRotation_8(&s[55], &s[40], 32, true); + ButterflyRotation_8(&s[54], &s[41], 32, true); + ButterflyRotation_8(&s[53], &s[42], 32, true); + ButterflyRotation_8(&s[52], &s[43], 32, true); + ButterflyRotation_8(&s[51], &s[44], 32, true); + ButterflyRotation_8(&s[50], &s[45], 32, true); + ButterflyRotation_8(&s[49], &s[46], 32, true); + ButterflyRotation_8(&s[48], &s[47], 32, true); + + // stage 31. + for (int i = 0; i < 32; i += 4) { + HadamardRotation(&s[i], &s[63 - i], false); + HadamardRotation(&s[i + 1], &s[63 - i - 1], false); + HadamardRotation(&s[i + 2], &s[63 - i - 2], false); + HadamardRotation(&s[i + 3], &s[63 - i - 3], false); + } + //-- end dct 64 stages + + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int idx = 0; idx < 64; idx += 8) { + int16x8_t output[8]; + Transpose8x8(&s[idx], output); + for (int i = 0; i < 8; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 64>(dst, step, 0, s); + } +} + +//------------------------------------------------------------------------------ +// Asymmetric Discrete Sine Transforms (ADST). +template <bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int32x4_t s[8]; + int16x8_t x[4]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[8]; + LoadSrc<8, 8>(dst, step, 0, input); + Transpose4x8To8x4(input, x); + } else { + LoadSrc<16, 4>(dst, step, 0, x); + } + } else { + LoadSrc<8, 4>(dst, step, 0, x); + if (transpose) { + Transpose4x4(x, x); + } + } + + // stage 1. + s[5] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[1]); + s[6] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[3]); + + // stage 2. + const int32x4_t a7 = vsubl_s16(vget_low_s16(x[0]), vget_low_s16(x[2])); + const int32x4_t b7 = vaddw_s16(a7, vget_low_s16(x[3])); + + // stage 3. + s[0] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[0]); + s[1] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[1]); + // s[0] = s[0] + s[3] + s[0] = vmlal_n_s16(s[0], vget_low_s16(x[2]), kAdst4Multiplier[3]); + // s[1] = s[1] - s[4] + s[1] = vmlsl_n_s16(s[1], vget_low_s16(x[2]), kAdst4Multiplier[0]); + + s[3] = vmull_n_s16(vget_low_s16(x[1]), kAdst4Multiplier[2]); + s[2] = vmulq_n_s32(b7, kAdst4Multiplier[2]); + + // stage 4. + s[0] = vaddq_s32(s[0], s[5]); + s[1] = vsubq_s32(s[1], s[6]); + + // stages 5 and 6. 
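+  // x0 = s0 + s3, x1 = s1 + s3, x2 = s2, x3 = s0 + s1 - s3. The multipliers
+  // are Q12 constants, so vqrshrn_n_s32(..., 12) removes the scale while
+  // narrowing back to 16 bits with rounding.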
+ const int32x4_t x0 = vaddq_s32(s[0], s[3]); + const int32x4_t x1 = vaddq_s32(s[1], s[3]); + const int32x4_t x3_a = vaddq_s32(s[0], s[1]); + const int32x4_t x3 = vsubq_s32(x3_a, s[3]); + const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12); + const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12); + const int16x4_t dst_2 = vqrshrn_n_s32(s[2], 12); + const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12); + + x[0] = vcombine_s16(dst_0, dst_0); + x[1] = vcombine_s16(dst_1, dst_1); + x[2] = vcombine_s16(dst_2, dst_2); + x[3] = vcombine_s16(dst_3, dst_3); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[8]; + Transpose8x4To4x8(x, output); + StoreDst<8, 8>(dst, step, 0, output); + } else { + StoreDst<16, 4>(dst, step, 0, x); + } + } else { + if (transpose) { + Transpose4x4(x, x); + } + StoreDst<8, 4>(dst, step, 0, x); + } +} + +alignas(8) constexpr int16_t kAdst4DcOnlyMultiplier[4] = {1321, 2482, 3344, + 2482}; + +LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int32x4_t s[2]; + + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int16x4_t kAdst4DcOnlyMultipliers = vld1_s16(kAdst4DcOnlyMultiplier); + s[1] = vdupq_n_s32(0); + + // s0*k0 s0*k1 s0*k2 s0*k1 + s[0] = vmull_s16(kAdst4DcOnlyMultipliers, v_src); + // 0 0 0 s0*k0 + s[1] = vextq_s32(s[1], s[0], 1); + + const int32x4_t x3 = vaddq_s32(s[0], s[1]); + const int16x4_t dst_0 = vqrshrn_n_s32(x3, 12); + + // vqrshlq_s16 will shift right if shift value is negative. + vst1_s16(dst, vqrshl_s16(dst_0, vdup_n_s16(-row_shift))); + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int32x4_t s[4]; + + int i = 0; + do { + const int16x4_t v_src = vld1_s16(&dst[i]); + + s[0] = vmull_n_s16(v_src, kAdst4Multiplier[0]); + s[1] = vmull_n_s16(v_src, kAdst4Multiplier[1]); + s[2] = vmull_n_s16(v_src, kAdst4Multiplier[2]); + + const int32x4_t x0 = s[0]; + const int32x4_t x1 = s[1]; + const int32x4_t x2 = s[2]; + const int32x4_t x3 = vaddq_s32(s[0], s[1]); + const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12); + const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12); + const int16x4_t dst_2 = vqrshrn_n_s32(x2, 12); + const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12); + + vst1_s16(&dst[i], dst_0); + vst1_s16(&dst[i + width * 1], dst_1); + vst1_s16(&dst[i + width * 2], dst_2); + vst1_s16(&dst[i + width * 3], dst_3); + + i += 4; + } while (i < width); + + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[8], x[8]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + } else { + LoadSrc<8, 8>(dst, step, 0, x); + } + } else { + if (transpose) { + LoadSrc<16, 8>(dst, step, 0, x); + dsp::Transpose8x8(x); + } else { + LoadSrc<16, 8>(dst, step, 0, x); + } + } + + // stage 1. 
+ s[0] = x[7]; + s[1] = x[0]; + s[2] = x[5]; + s[3] = x[2]; + s[4] = x[3]; + s[5] = x[4]; + s[6] = x[1]; + s[7] = x[6]; + + // stage 2. + butterfly_rotation(&s[0], &s[1], 60 - 0, true); + butterfly_rotation(&s[2], &s[3], 60 - 16, true); + butterfly_rotation(&s[4], &s[5], 60 - 32, true); + butterfly_rotation(&s[6], &s[7], 60 - 48, true); + + // stage 3. + HadamardRotation(&s[0], &s[4], false); + HadamardRotation(&s[1], &s[5], false); + HadamardRotation(&s[2], &s[6], false); + HadamardRotation(&s[3], &s[7], false); + + // stage 4. + butterfly_rotation(&s[4], &s[5], 48 - 0, true); + butterfly_rotation(&s[7], &s[6], 48 - 32, true); + + // stage 5. + HadamardRotation(&s[0], &s[2], false); + HadamardRotation(&s[4], &s[6], false); + HadamardRotation(&s[1], &s[3], false); + HadamardRotation(&s[5], &s[7], false); + + // stage 6. + butterfly_rotation(&s[2], &s[3], 32, true); + butterfly_rotation(&s[6], &s[7], 32, true); + + // stage 7. + x[0] = s[0]; + x[1] = vqnegq_s16(s[4]); + x[2] = s[6]; + x[3] = vqnegq_s16(s[2]); + x[4] = s[3]; + x[5] = vqnegq_s16(s[7]); + x[6] = s[5]; + x[7] = vqnegq_s16(s[1]); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[4]; + Transpose4x8To8x4(x, output); + StoreDst<16, 4>(dst, step, 0, output); + } else { + StoreDst<8, 8>(dst, step, 0, x); + } + } else { + if (transpose) { + dsp::Transpose8x8(x); + StoreDst<16, 8>(dst, step, 0, x); + } else { + StoreDst<16, 8>(dst, step, 0, x); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int16x8_t s[8]; + + const int16x8_t v_src = vdupq_n_s16(dst[0]); + const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + // stage 1. + s[1] = vbslq_s16(v_mask, v_src_round, v_src); + + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true); + + // stage 3. + s[4] = s[0]; + s[5] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[4], &s[5], 48, true); + + // stage 5. + s[2] = s[0]; + s[3] = s[1]; + s[6] = s[4]; + s[7] = s[5]; + + // stage 6. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + + // stage 7. + int16x8_t x[8]; + x[0] = s[0]; + x[1] = vqnegq_s16(s[4]); + x[2] = s[6]; + x[3] = vqnegq_s16(s[2]); + x[4] = s[3]; + x[5] = vqnegq_s16(s[7]); + x[6] = s[5]; + x[7] = vqnegq_s16(s[1]); + + for (int i = 0; i < 8; ++i) { + // vqrshlq_s16 will shift right if shift value is negative. + x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift)); + vst1q_lane_s16(&dst[i], x[i], 0); + } + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int16x8_t s[8]; + + int i = 0; + do { + const int16x8_t v_src = vld1q_s16(dst); + // stage 1. + s[1] = v_src; + + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true); + + // stage 3. + s[4] = s[0]; + s[5] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[4], &s[5], 48, true); + + // stage 5. + s[2] = s[0]; + s[3] = s[1]; + s[6] = s[4]; + s[7] = s[5]; + + // stage 6. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + + // stage 7. 
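+    // Stage 7 applies the output permutation with alternating saturating
+    // negation. Only the low four lanes are stored: this path writes a
+    // 4-wide strip of the column buffer per iteration.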
+ int16x8_t x[8]; + x[0] = s[0]; + x[1] = vqnegq_s16(s[4]); + x[2] = s[6]; + x[3] = vqnegq_s16(s[2]); + x[4] = s[3]; + x[5] = vqnegq_s16(s[7]); + x[6] = s[5]; + x[7] = vqnegq_s16(s[1]); + + for (int j = 0; j < 8; ++j) { + vst1_s16(&dst[j * width], vget_low_s16(x[j])); + } + i += 4; + dst += 4; + } while (i < width); + + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, int32_t step, bool is_row, + int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[16], x[16]; + + if (stage_is_rectangular) { + if (is_row) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + LoadSrc<16, 4>(dst, step, 8, input); + Transpose8x4To4x8(input, &x[8]); + } else { + LoadSrc<8, 16>(dst, step, 0, x); + } + } else { + if (is_row) { + for (int idx = 0; idx < 16; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + LoadSrc<16, 16>(dst, step, 0, x); + } + } + + // stage 1. + s[0] = x[15]; + s[1] = x[0]; + s[2] = x[13]; + s[3] = x[2]; + s[4] = x[11]; + s[5] = x[4]; + s[6] = x[9]; + s[7] = x[6]; + s[8] = x[7]; + s[9] = x[8]; + s[10] = x[5]; + s[11] = x[10]; + s[12] = x[3]; + s[13] = x[12]; + s[14] = x[1]; + s[15] = x[14]; + + // stage 2. + butterfly_rotation(&s[0], &s[1], 62 - 0, true); + butterfly_rotation(&s[2], &s[3], 62 - 8, true); + butterfly_rotation(&s[4], &s[5], 62 - 16, true); + butterfly_rotation(&s[6], &s[7], 62 - 24, true); + butterfly_rotation(&s[8], &s[9], 62 - 32, true); + butterfly_rotation(&s[10], &s[11], 62 - 40, true); + butterfly_rotation(&s[12], &s[13], 62 - 48, true); + butterfly_rotation(&s[14], &s[15], 62 - 56, true); + + // stage 3. + HadamardRotation(&s[0], &s[8], false); + HadamardRotation(&s[1], &s[9], false); + HadamardRotation(&s[2], &s[10], false); + HadamardRotation(&s[3], &s[11], false); + HadamardRotation(&s[4], &s[12], false); + HadamardRotation(&s[5], &s[13], false); + HadamardRotation(&s[6], &s[14], false); + HadamardRotation(&s[7], &s[15], false); + + // stage 4. + butterfly_rotation(&s[8], &s[9], 56 - 0, true); + butterfly_rotation(&s[13], &s[12], 8 + 0, true); + butterfly_rotation(&s[10], &s[11], 56 - 32, true); + butterfly_rotation(&s[15], &s[14], 8 + 32, true); + + // stage 5. + HadamardRotation(&s[0], &s[4], false); + HadamardRotation(&s[8], &s[12], false); + HadamardRotation(&s[1], &s[5], false); + HadamardRotation(&s[9], &s[13], false); + HadamardRotation(&s[2], &s[6], false); + HadamardRotation(&s[10], &s[14], false); + HadamardRotation(&s[3], &s[7], false); + HadamardRotation(&s[11], &s[15], false); + + // stage 6. + butterfly_rotation(&s[4], &s[5], 48 - 0, true); + butterfly_rotation(&s[12], &s[13], 48 - 0, true); + butterfly_rotation(&s[7], &s[6], 48 - 32, true); + butterfly_rotation(&s[15], &s[14], 48 - 32, true); + + // stage 7. + HadamardRotation(&s[0], &s[2], false); + HadamardRotation(&s[4], &s[6], false); + HadamardRotation(&s[8], &s[10], false); + HadamardRotation(&s[12], &s[14], false); + HadamardRotation(&s[1], &s[3], false); + HadamardRotation(&s[5], &s[7], false); + HadamardRotation(&s[9], &s[11], false); + HadamardRotation(&s[13], &s[15], false); + + // stage 8. + butterfly_rotation(&s[2], &s[3], 32, true); + butterfly_rotation(&s[6], &s[7], 32, true); + butterfly_rotation(&s[10], &s[11], 32, true); + butterfly_rotation(&s[14], &s[15], 32, true); + + // stage 9. 
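+  // Stage 9 writes the adst16 outputs with alternating saturating negation;
+  // the row pass below also applies |row_shift| before the transposed store.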
+ x[0] = s[0]; + x[1] = vqnegq_s16(s[8]); + x[2] = s[12]; + x[3] = vqnegq_s16(s[4]); + x[4] = s[6]; + x[5] = vqnegq_s16(s[14]); + x[6] = s[10]; + x[7] = vqnegq_s16(s[2]); + x[8] = s[3]; + x[9] = vqnegq_s16(s[11]); + x[10] = s[15]; + x[11] = vqnegq_s16(s[7]); + x[12] = s[5]; + x[13] = vqnegq_s16(s[13]); + x[14] = s[9]; + x[15] = vqnegq_s16(s[1]); + + if (stage_is_rectangular) { + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + int16x8_t output[4]; + Transpose4x8To8x4(x, output); + for (int i = 0; i < 4; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 4>(dst, step, 0, output); + Transpose4x8To8x4(&x[8], output); + for (int i = 0; i < 4; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 4>(dst, step, 8, output); + } else { + StoreDst<8, 16>(dst, step, 0, x); + } + } else { + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int idx = 0; idx < 16; idx += 8) { + int16x8_t output[8]; + Transpose8x8(&x[idx], output); + for (int i = 0; i < 8; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 16>(dst, step, 0, x); + } + } +} + +LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int16x8_t* s, int16x8_t* x) { + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true); + + // stage 3. + s[8] = s[0]; + s[9] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[8], &s[9], 56, true); + + // stage 5. + s[4] = s[0]; + s[12] = s[8]; + s[5] = s[1]; + s[13] = s[9]; + + // stage 6. + ButterflyRotation_4(&s[4], &s[5], 48, true); + ButterflyRotation_4(&s[12], &s[13], 48, true); + + // stage 7. + s[2] = s[0]; + s[6] = s[4]; + s[10] = s[8]; + s[14] = s[12]; + s[3] = s[1]; + s[7] = s[5]; + s[11] = s[9]; + s[15] = s[13]; + + // stage 8. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + ButterflyRotation_4(&s[10], &s[11], 32, true); + ButterflyRotation_4(&s[14], &s[15], 32, true); + + // stage 9. + x[0] = s[0]; + x[1] = vqnegq_s16(s[8]); + x[2] = s[12]; + x[3] = vqnegq_s16(s[4]); + x[4] = s[6]; + x[5] = vqnegq_s16(s[14]); + x[6] = s[10]; + x[7] = vqnegq_s16(s[2]); + x[8] = s[3]; + x[9] = vqnegq_s16(s[11]); + x[10] = s[15]; + x[11] = vqnegq_s16(s[7]); + x[12] = s[5]; + x[13] = vqnegq_s16(s[13]); + x[14] = s[9]; + x[15] = vqnegq_s16(s[1]); +} + +LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int16x8_t s[16]; + int16x8_t x[16]; + + const int16x8_t v_src = vdupq_n_s16(dst[0]); + const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + // stage 1. + s[1] = vbslq_s16(v_mask, v_src_round, v_src); + + Adst16DcOnlyInternal(s, x); + + for (int i = 0; i < 16; ++i) { + // vqrshlq_s16 will shift right if shift value is negative. + x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift)); + vst1q_lane_s16(&dst[i], x[i], 0); + } + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest, + int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int i = 0; + do { + int16x8_t s[16]; + int16x8_t x[16]; + const int16x8_t v_src = vld1q_s16(dst); + // stage 1. 
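+    // Only s[1] needs to be seeded; the ButterflyRotation_FirstIsZero call in
+    // Adst16DcOnlyInternal() treats s[0] as zero.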
+ s[1] = v_src; + + Adst16DcOnlyInternal(s, x); + + for (int j = 0; j < 16; ++j) { + vst1_s16(&dst[j * width], vget_low_s16(x[j])); + } + i += 4; + dst += 4; + } while (i < width); + + return true; +} + +//------------------------------------------------------------------------------ +// Identity Transforms. + +template <bool is_row_shift> +LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + if (is_row_shift) { + const int shift = 1; + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + for (int i = 0; i < 4; i += 2) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + const int32x4_t v_src_mult_lo = + vmlal_s16(v_dual_round, vget_low_s16(v_src), v_multiplier); + const int32x4_t v_src_mult_hi = + vmlal_s16(v_dual_round, vget_high_s16(v_src), v_multiplier); + const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift); + const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift); + vst1q_s16(&dst[i * step], + vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi))); + } + } else { + for (int i = 0; i < 4; i += 2) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + const int16x8_t a = + vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3); + const int16x8_t b = vqaddq_s16(v_src, a); + vst1q_s16(&dst[i * step], b); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int tx_height) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int shift = tx_height < 16 ? 
0 : 1; + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + const int32x4_t v_src_mult_lo = vmlal_s16(v_dual_round, v_src, v_multiplier); + const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift); + vst1_lane_s16(dst, vqmovn_s32(dst_0), 0); + return true; +} + +template <int identity_size> +LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + if (identity_size < 32) { + if (tx_width == 4) { + uint8x8_t frame_data = vdup_n_u8(0); + int i = 0; + do { + const int16x4_t v_src = vld1_s16(&source[i * tx_width]); + + int16x4_t v_dst_i; + if (identity_size == 4) { + const int16x4_t v_src_fraction = + vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3); + v_dst_i = vqadd_s16(v_src, v_src_fraction); + } else if (identity_size == 8) { + v_dst_i = vqadd_s16(v_src, v_src); + } else { // identity_size == 16 + const int16x4_t v_src_mult = + vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 4); + const int16x4_t v_srcx2 = vqadd_s16(v_src, v_src); + v_dst_i = vqadd_s16(v_srcx2, v_src_mult); + } + + frame_data = Load4<0>(dst, frame_data); + const int16x4_t a = vrshr_n_s16(v_dst_i, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, d); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const int16x8_t v_src = vld1q_s16(&source[row + j]); + + int16x8_t v_dst_i; + if (identity_size == 4) { + const int16x8_t v_src_fraction = + vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3); + v_dst_i = vqaddq_s16(v_src, v_src_fraction); + } else if (identity_size == 8) { + v_dst_i = vqaddq_s16(v_src, v_src); + } else { // identity_size == 16 + const int16x8_t v_src_mult = + vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 4); + const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src); + v_dst_i = vqaddq_s16(v_src_mult, v_srcx2); + } + + const uint8x8_t frame_data = vld1_u8(dst + j); + const int16x8_t a = vrshrq_n_s16(v_dst_i, 4); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst + j, d); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const int16x8_t v_dst_i = vld1q_s16(&source[row + j]); + const uint8x8_t frame_data = vld1_u8(dst + j); + const int16x8_t a = vrshrq_n_s16(v_dst_i, 2); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst + j, d); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + if (tx_width == 4) { + uint8x8_t frame_data = vdup_n_u8(0); + int i = 0; + do { + const int16x4_t v_src = vld1_s16(&source[i * 
tx_width]); + const int16x4_t v_src_mult = + vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3); + const int16x4_t v_dst_row = vqadd_s16(v_src, v_src_mult); + const int16x4_t v_src_mult2 = + vqrdmulh_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3); + const int16x4_t v_dst_col = vqadd_s16(v_dst_row, v_src_mult2); + frame_data = Load4<0>(dst, frame_data); + const int16x4_t a = vrshr_n_s16(v_dst_col, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, d); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const int16x8_t v_src = vld1q_s16(&source[row + j]); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + const int16x8_t v_dst_row = vqaddq_s16(v_src_round, v_src_round); + const int16x8_t v_src_mult2 = + vqrdmulhq_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3); + const int16x8_t v_dst_col = vqaddq_s16(v_dst_row, v_src_mult2); + const uint8x8_t frame_data = vld1_u8(dst + j); + const int16x8_t a = vrshrq_n_s16(v_dst_col, 4); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst + j, d); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + // When combining the identity8 multiplier with the row shift, the + // calculations for tx_height equal to 32 can be simplified from + // ((A * 2) + 2) >> 2) to ((A + 1) >> 1). + for (int i = 0; i < 4; ++i) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + const int16x8_t a = vrshrq_n_s16(v_src, 1); + vst1q_s16(&dst[i * step], a); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + for (int i = 0; i < 4; ++i) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + // For bitdepth == 8, the identity row clamps to a signed 16bit value, so + // saturating add here is ok. + const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src); + vst1q_s16(&dst[i * step], v_srcx2); + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 
0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int32x4_t v_srcx2 = vaddl_s16(v_src, v_src); + const int32x4_t dst_0 = vqrshlq_s32(v_srcx2, vdupq_n_s32(-row_shift)); + vst1_lane_s16(dst, vqmovn_s32(dst_0), 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, int32_t step, + int shift) { + auto* const dst = static_cast<int16_t*>(dest); + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + const int16x8_t v_src = vld1q_s16(&dst[i * step + j * 8]); + const int32x4_t v_src_mult_lo = + vmlal_n_s16(v_dual_round, vget_low_s16(v_src), kIdentity16Multiplier); + const int32x4_t v_src_mult_hi = vmlal_n_s16( + v_dual_round, vget_high_s16(v_src), kIdentity16Multiplier); + const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift); + const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift); + vst1q_s16(&dst[i * step + j * 8], + vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi))); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int16x4_t v_multiplier = vdup_n_s16(kIdentity16Multiplier); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + const int32x4_t v_src_mult_lo = + vmlal_s16(v_dual_round, (v_src), v_multiplier); + const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift); + vst1_lane_s16(dst, vqmovn_s32(dst_0), 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest, + const int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + // When combining the identity32 multiplier with the row shift, the + // calculation for tx_height equal to 16 can be simplified from + // ((A * 4) + 1) >> 1) to (A * 2). + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 32; j += 8) { + const int16x8_t v_src = vld1q_s16(&dst[i * step + j]); + // For bitdepth == 8, the identity row clamps to a signed 16bit value, so + // saturating add here is ok. + const int16x8_t v_dst_i = vqaddq_s16(v_src, v_src); + vst1q_s16(&dst[i * step + j], v_dst_i); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest, + int adjusted_tx_height) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const int16x4_t v_src = vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + // When combining the identity32 multiplier with the row shift, the + // calculation for tx_height equal to 16 can be simplified from + // ((A * 4) + 1) >> 1) to (A * 2). + const int16x4_t v_dst_0 = vqadd_s16(v_src, v_src); + vst1_lane_s16(dst, v_dst_0, 0); + return true; +} + +//------------------------------------------------------------------------------ +// Walsh Hadamard Transform. + +// Transposes a 4x4 matrix and then permutes the rows of the transposed matrix +// for the WHT. 
The input matrix is in two "wide" int16x8_t variables. The +// output matrix is in four int16x4_t variables. +// +// Input: +// in[0]: 00 01 02 03 10 11 12 13 +// in[1]: 20 21 22 23 30 31 32 33 +// Output: +// out[0]: 00 10 20 30 +// out[1]: 03 13 23 33 +// out[2]: 01 11 21 31 +// out[3]: 02 12 22 32 +LIBGAV1_ALWAYS_INLINE void TransposeAndPermute4x4WideInput( + const int16x8_t in[2], int16x4_t out[4]) { + // Swap 32 bit elements. Goes from: + // in[0]: 00 01 02 03 10 11 12 13 + // in[1]: 20 21 22 23 30 31 32 33 + // to: + // b0.val[0]: 00 01 20 21 10 11 30 31 + // b0.val[1]: 02 03 22 23 12 13 32 33 + + const int32x4x2_t b0 = + vtrnq_s32(vreinterpretq_s32_s16(in[0]), vreinterpretq_s32_s16(in[1])); + + // Swap 16 bit elements. Goes from: + // vget_low_s32(b0.val[0]): 00 01 20 21 + // vget_high_s32(b0.val[0]): 10 11 30 31 + // vget_low_s32(b0.val[1]): 02 03 22 23 + // vget_high_s32(b0.val[1]): 12 13 32 33 + // to: + // c0.val[0]: 00 10 20 30 + // c0.val[1]: 01 11 21 32 + // c1.val[0]: 02 12 22 32 + // c1.val[1]: 03 13 23 33 + + const int16x4x2_t c0 = + vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[0])), + vreinterpret_s16_s32(vget_high_s32(b0.val[0]))); + const int16x4x2_t c1 = + vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[1])), + vreinterpret_s16_s32(vget_high_s32(b0.val[1]))); + + out[0] = c0.val[0]; + out[1] = c1.val[1]; + out[2] = c0.val[1]; + out[3] = c1.val[0]; +} + +// Process 4 wht4 rows and columns. +LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint8_t* dst, const int dst_stride, + const void* source, + const int adjusted_tx_height) { + const auto* const src = static_cast<const int16_t*>(source); + int16x4_t s[4]; + + if (adjusted_tx_height == 1) { + // Special case: only src[0] is nonzero. + // src[0] 0 0 0 + // 0 0 0 0 + // 0 0 0 0 + // 0 0 0 0 + // + // After the row and column transforms are applied, we have: + // f h h h + // g i i i + // g i i i + // g i i i + // where f, g, h, i are computed as follows. + int16_t f = (src[0] >> 2) - (src[0] >> 3); + const int16_t g = f >> 1; + f = f - (f >> 1); + const int16_t h = (src[0] >> 3) - (src[0] >> 4); + const int16_t i = (src[0] >> 4); + s[0] = vdup_n_s16(h); + s[0] = vset_lane_s16(f, s[0], 0); + s[1] = vdup_n_s16(i); + s[1] = vset_lane_s16(g, s[1], 0); + s[2] = s[3] = s[1]; + } else { + // Load the 4x4 source in transposed form. + int16x4x4_t columns = vld4_s16(src); + // Shift right and permute the columns for the WHT. + s[0] = vshr_n_s16(columns.val[0], 2); + s[2] = vshr_n_s16(columns.val[1], 2); + s[3] = vshr_n_s16(columns.val[2], 2); + s[1] = vshr_n_s16(columns.val[3], 2); + + // Row transforms. + s[0] = vadd_s16(s[0], s[2]); + s[3] = vsub_s16(s[3], s[1]); + int16x4_t e = vhsub_s16(s[0], s[3]); // e = (s[0] - s[3]) >> 1 + s[1] = vsub_s16(e, s[1]); + s[2] = vsub_s16(e, s[2]); + s[0] = vsub_s16(s[0], s[1]); + s[3] = vadd_s16(s[3], s[2]); + + int16x8_t x[2]; + x[0] = vcombine_s16(s[0], s[1]); + x[1] = vcombine_s16(s[2], s[3]); + TransposeAndPermute4x4WideInput(x, s); + + // Column transforms. + s[0] = vadd_s16(s[0], s[2]); + s[3] = vsub_s16(s[3], s[1]); + e = vhsub_s16(s[0], s[3]); // e = (s[0] - s[3]) >> 1 + s[1] = vsub_s16(e, s[1]); + s[2] = vsub_s16(e, s[2]); + s[0] = vsub_s16(s[0], s[1]); + s[3] = vadd_s16(s[3], s[2]); + } + + // Store to frame. 
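+  // Add the 4x4 residual to the destination two rows at a time: widen the
+  // uint8 pixels, add the int16 residual, then saturate back to uint8 with
+  // vqmovun_s16.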
+ uint8x8_t frame_data = vdup_n_u8(0); + for (int row = 0; row < 4; row += 2) { + frame_data = Load4<0>(dst, frame_data); + frame_data = Load4<1>(dst + dst_stride, frame_data); + const int16x8_t residual = vcombine_s16(s[row], s[row + 1]); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(residual), frame_data); + frame_data = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, frame_data); + dst += dst_stride; + StoreHi4(dst, frame_data); + dst += dst_stride; + } +} + +//------------------------------------------------------------------------------ +// row/column transform loops + +template <int tx_height> +LIBGAV1_ALWAYS_INLINE void FlipColumns(int16_t* source, int tx_width) { + if (tx_width >= 16) { + int i = 0; + do { + const int16x8_t a = vld1q_s16(&source[i]); + const int16x8_t b = vld1q_s16(&source[i + 8]); + const int16x8_t c = vrev64q_s16(a); + const int16x8_t d = vrev64q_s16(b); + vst1q_s16(&source[i], vcombine_s16(vget_high_s16(d), vget_low_s16(d))); + vst1q_s16(&source[i + 8], + vcombine_s16(vget_high_s16(c), vget_low_s16(c))); + i += 16; + } while (i < tx_width * tx_height); + } else if (tx_width == 8) { + for (int i = 0; i < 8 * tx_height; i += 8) { + const int16x8_t a = vld1q_s16(&source[i]); + const int16x8_t b = vrev64q_s16(a); + vst1q_s16(&source[i], vcombine_s16(vget_high_s16(b), vget_low_s16(b))); + } + } else { + // Process two rows per iteration. + for (int i = 0; i < 4 * tx_height; i += 8) { + const int16x8_t a = vld1q_s16(&source[i]); + vst1q_s16(&source[i], vrev64q_s16(a)); + } + } +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) { + if (tx_width == 4) { + // Process two rows per iteration. + int i = 0; + do { + const int16x8_t a = vld1q_s16(&source[i]); + const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3); + vst1q_s16(&source[i], b); + i += 8; + } while (i < tx_width * num_rows); + } else { + int i = 0; + do { + // The last 32 values of every row are always zero if the |tx_width| is + // 64. + const int non_zero_width = (tx_width < 64) ? tx_width : 32; + int j = 0; + do { + const int16x8_t a = vld1q_s16(&source[i * tx_width + j]); + const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3); + vst1q_s16(&source[i * tx_width + j], b); + j += 8; + } while (j < non_zero_width); + } while (++i < num_rows); + } +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows, + int row_shift) { + // vqrshlq_s16 will shift right if shift value is negative. + row_shift = -row_shift; + + if (tx_width == 4) { + // Process two rows per iteration. + int i = 0; + do { + const int16x8_t residual = vld1q_s16(&source[i]); + vst1q_s16(&source[i], vqrshlq_s16(residual, vdupq_n_s16(row_shift))); + i += 8; + } while (i < tx_width * num_rows); + } else { + int i = 0; + do { + for (int j = 0; j < tx_width; j += 8) { + const int16x8_t residual = vld1q_s16(&source[i * tx_width + j]); + const int16x8_t residual_shifted = + vqrshlq_s16(residual, vdupq_n_s16(row_shift)); + vst1q_s16(&source[i * tx_width + j], residual_shifted); + } + } while (++i < num_rows); + } +} + +template <int tx_height, bool enable_flip_rows = false> +LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int16_t* source, TransformType tx_type) { + const bool flip_rows = + enable_flip_rows ? 
kTransformFlipRowsMask.Contains(tx_type) : false; + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + // Enable for 4x4, 4x8, 4x16 + if (tx_height < 32 && tx_width == 4) { + uint8x8_t frame_data = vdup_n_u8(0); + for (int i = 0; i < tx_height; ++i) { + const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4; + const int16x4_t residual = vld1_s16(&source[row]); + frame_data = Load4<0>(dst, frame_data); + const int16x4_t a = vrshr_n_s16(residual, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, d); + dst += stride; + } + // Enable for 8x4, 8x8, 8x16, 8x32 + } else if (tx_height < 64 && tx_width == 8) { + for (int i = 0; i < tx_height; ++i) { + const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8; + const int16x8_t residual = vld1q_s16(&source[row]); + const uint8x8_t frame_data = vld1_u8(dst); + const int16x8_t a = vrshrq_n_s16(residual, 4); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst, d); + dst += stride; + } + // Remaining widths >= 16. + } else { + for (int i = 0; i < tx_height; ++i) { + const int y = start_y + i; + const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width; + int j = 0; + do { + const int x = start_x + j; + const int16x8_t residual = vld1q_s16(&source[row + j]); + const int16x8_t residual_hi = vld1q_s16(&source[row + j + 8]); + const uint8x16_t frame_data = vld1q_u8(frame[y] + x); + const int16x8_t a = vrshrq_n_s16(residual, 4); + const int16x8_t a_hi = vrshrq_n_s16(residual_hi, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(a), vget_low_u8(frame_data)); + const uint16x8_t b_hi = + vaddw_u8(vreinterpretq_u16_s16(a_hi), vget_high_u8(frame_data)); + vst1q_u8(frame[y] + x, + vcombine_u8(vqmovun_s16(vreinterpretq_s16_u16(b)), + vqmovun_s16(vreinterpretq_s16_u16(b_hi)))); + j += 16; + } while (j < tx_width); + } + } +} + +void Dct4TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = (tx_height == 8); + const int row_shift = (tx_height == 16); + + if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d dct4 rows in parallel. + Dct4_NEON<ButterflyRotation_4, false>(src, /*step=*/4, /*transpose=*/true); + } else { + // Process 8 1d dct4 rows in parallel per iteration. 
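+    // Each iteration consumes 8 rows x 4 coefficients = 32 int16 values,
+    // hence the += 32 advance.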
+ int i = adjusted_tx_height; + auto* data = src; + do { + Dct4_NEON<ButterflyRotation_8, true>(data, /*step=*/4, + /*transpose=*/true); + data += 32; + i -= 8; + } while (i != 0); + } + if (tx_height == 16) { + RowShift<4>(src, adjusted_tx_height, 1); + } +} + +void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct4 columns in parallel. + Dct4_NEON<ButterflyRotation_4, false>(src, tx_width, /*transpose=*/false); + } else { + // Process 8 1d dct4 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct4_NEON<ButterflyRotation_8, true>(data, tx_width, + /*transpose=*/false); + data += 8; + i -= 8; + } while (i != 0); + } + } + + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct8TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d dct8 rows in parallel. + Dct8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true); + } else { + // Process 8 1d dct8 rows in parallel per iteration. + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + Dct8_NEON<ButterflyRotation_8, false>(data, /*step=*/8, + /*transpose=*/true); + data += 64; + i -= 8; + } while (i != 0); + } + if (row_shift > 0) { + RowShift<8>(src, adjusted_tx_height, row_shift); + } +} + +void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct8 columns in parallel. + Dct8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + // Process 8 1d dct8 columns in parallel per iteration. 
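+      // The columns are transformed in place: |step| is the coefficient-row
+      // pitch (tx_width) and each iteration advances by 8 columns.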
+ int i = tx_width; + auto* data = src; + do { + Dct8_NEON<ButterflyRotation_8, false>(data, tx_width, + /*transpose=*/false); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d dct16 rows in parallel. + Dct16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift); + } else { + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + do { + // Process 8 1d dct16 rows in parallel per iteration. + Dct16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true, + row_shift); + src += 128; + i -= 8; + } while (i != 0); + } +} + +void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct16 columns in parallel. + Dct16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false, + /*row_shift=*/0); + } else { + int i = tx_width; + auto* data = src; + do { + // Process 8 1d dct16 columns in parallel per iteration. + Dct16_NEON<ButterflyRotation_8, false>(data, tx_width, /*is_row=*/false, + /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct32TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<32>(src, adjusted_tx_height); + } + // Process 8 1d dct32 rows in parallel per iteration. + int i = 0; + do { + Dct32_NEON(&src[i * 32], 32, /*is_row=*/true, row_shift); + i += 8; + } while (i < adjusted_tx_height); +} + +void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) { + // Process 8 1d dct32 columns in parallel per iteration. 
+ int i = tx_width; + auto* data = src; + do { + Dct32_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct64TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<64>(src, adjusted_tx_height); + } + // Process 8 1d dct64 rows in parallel per iteration. + int i = 0; + do { + Dct64_NEON(&src[i * 64], 64, /*is_row=*/true, row_shift); + i += 8; + } while (i < adjusted_tx_height); +} + +void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) { + // Process 8 1d dct64 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct64_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Adst4TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const int row_shift = static_cast<int>(tx_height == 16); + const bool should_round = (tx_height == 8); + + if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + // Process 4 1d adst4 rows in parallel per iteration. + int i = adjusted_tx_height; + auto* data = src; + do { + Adst4_NEON<false>(data, /*step=*/4, /*transpose=*/true); + data += 16; + i -= 4; + } while (i != 0); + + if (tx_height == 16) { + RowShift<4>(src, adjusted_tx_height, 1); + } +} + +void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + // Process 4 1d adst4 columns in parallel per iteration. 
+ int i = tx_width; + auto* data = src; + do { + Adst4_NEON<false>(data, tx_width, /*transpose=*/false); + data += 4; + i -= 4; + } while (i != 0); + } + + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Adst8TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d adst8 rows in parallel. + Adst8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true); + } else { + // Process 8 1d adst8 rows in parallel per iteration. + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + Adst8_NEON<ButterflyRotation_8, false>(data, /*step=*/8, + /*transpose=*/true); + data += 64; + i -= 8; + } while (i != 0); + } + if (row_shift > 0) { + RowShift<8>(src, adjusted_tx_height, row_shift); + } +} + +void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d adst8 columns in parallel. + Adst8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + // Process 8 1d adst8 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Adst8_NEON<ButterflyRotation_8, false>(data, tx_width, + /*transpose=*/false); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Adst16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d adst16 rows in parallel. + Adst16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift); + } else { + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + do { + // Process 8 1d adst16 rows in parallel per iteration. 
+ Adst16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true, + row_shift); + src += 128; + i -= 8; + } while (i != 0); + } +} + +void Adst16TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d adst16 columns in parallel. + Adst16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false, + /*row_shift=*/0); + } else { + int i = tx_width; + auto* data = src; + do { + // Process 8 1d adst16 columns in parallel per iteration. + Adst16_NEON<ButterflyRotation_8, false>( + data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Identity4TransformLoopRow_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. + if (tx_type == kTransformTypeIdentityIdentity && + tx_size == kTransformSize4x4) { + return; + } + + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = (tx_height == 8); + + if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + if (tx_height < 16) { + int i = adjusted_tx_height; + do { + Identity4_NEON<false>(src, /*step=*/4); + src += 16; + i -= 4; + } while (i != 0); + } else { + int i = adjusted_tx_height; + do { + Identity4_NEON<true>(src, /*step=*/4); + src += 16; + i -= 4; + } while (i != 0); + } +} + +void Identity4TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + // Special case: Process row calculations during column transform call. + if (tx_type == kTransformTypeIdentityIdentity && + (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) { + Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); + return; + } + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity8TransformLoopRow_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. 
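  // [Annotation, not part of the upstream patch.] For the identity-identity
  // 4x4 and 8x4 blocks the row pass (here and in Identity4TransformLoopRow_NEON
  // above) returns immediately; Identity4TransformLoopColumn_NEON then applies
  // both 1-D identity transforms in one step via
  // Identity4RowColumnStoreToFrame() while writing to the frame, saving a full
  // pass over the coefficient buffer.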
+ if (tx_type == kTransformTypeIdentityIdentity && + tx_size == kTransformSize8x4) { + return; + } + + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + // When combining the identity8 multiplier with the row shift, the + // calculations for tx_height == 8 and tx_height == 16 can be simplified + // from ((A * 2) + 1) >> 1) to A. + if ((tx_height & 0x18) != 0) { + return; + } + if (tx_height == 32) { + int i = adjusted_tx_height; + do { + Identity8Row32_NEON(src, /*step=*/8); + src += 32; + i -= 4; + } while (i != 0); + return; + } + + assert(tx_size == kTransformSize8x4); + int i = adjusted_tx_height; + do { + Identity8Row4_NEON(src, /*step=*/8); + src += 32; + i -= 4; + } while (i != 0); +} + +void Identity8TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + int i = adjusted_tx_height; + do { + Identity16Row_NEON(src, /*step=*/16, kTransformRowShift[tx_size]); + src += 64; + i -= 4; + } while (i != 0); +} + +void Identity16TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity32TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + const int tx_height = kTransformHeight[tx_size]; + + // When combining the identity32 multiplier with the row shift, the + // calculations for tx_height == 8 and tx_height == 32 can be simplified + // from ((A * 4) + 2) >> 2) to A. + if ((tx_height & 0x28) != 0) { + return; + } + + // Process kTransformSize32x16. The src is always rounded before the + // identity transform and shifted by 1 afterwards. 
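  // [Annotation, not part of the upstream patch.] The hex masks in these
  // identity row drivers encode the heights for which the identity multiplier
  // cancels against the row shift: 0x18 above matches tx_height 8 or 16 (the
  // identity8 cases where ((A * 2) + 1) >> 1 == A) and 0x28 here matches
  // tx_height 8 or 32 (the identity32 cases where ((A * 4) + 2) >> 2 == A).
  // For those heights nothing is left to do after the optional rounding, so
  // the row pass returns early; the remaining heights fall through to the
  // Identity*Row*_NEON loops.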
+ auto* src = static_cast<int16_t*>(src_buffer); + if (Identity32DcOnly(src, adjusted_tx_height)) { + return; + } + + assert(tx_size == kTransformSize32x16); + ApplyRounding<32>(src, adjusted_tx_height); + int i = adjusted_tx_height; + do { + Identity32Row16_NEON(src, /*step=*/32); + src += 128; + i -= 4; + } while (i != 0); +} + +void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Wht4TransformLoopRow_NEON(TransformType tx_type, TransformSize tx_size, + int /*adjusted_tx_height*/, void* /*src_buffer*/, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + assert(tx_type == kTransformTypeDctDct); + assert(tx_size == kTransformSize4x4); + static_cast<void>(tx_type); + static_cast<void>(tx_size); + // Do both row and column transforms in the column-transform pass. +} + +void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + assert(tx_type == kTransformTypeDctDct); + assert(tx_size == kTransformSize4x4); + static_cast<void>(tx_type); + static_cast<void>(tx_size); + + // Process 4 1d wht4 rows and columns in parallel. + const auto* src = static_cast<int16_t*>(src_buffer); + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + uint8_t* dst = frame[start_y] + start_x; + const int dst_stride = frame.columns(); + Wht4_NEON(dst, dst_stride, src, adjusted_tx_height); +} + +//------------------------------------------------------------------------------ + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + // Maximum transform size for Dct is 64. + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + Dct4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + Dct4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + Dct8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + Dct8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + Dct16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + Dct16TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + Dct32TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + Dct32TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + Dct64TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + Dct64TransformLoopColumn_NEON; + + // Maximum transform size for Adst is 16. 
+ dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + Adst4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + Adst4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + Adst8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + Adst8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + Adst16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + Adst16TransformLoopColumn_NEON; + + // Maximum transform size for Identity transform is 32. + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + Identity4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + Identity4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + Identity8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + Identity8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + Identity16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + Identity16TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = + Identity32TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = + Identity32TransformLoopColumn_NEON; + + // Maximum transform size for Wht is 4. + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = + Wht4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = + Wht4TransformLoopColumn_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void InverseTransformInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void InverseTransformInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/inverse_transform_neon.h b/src/dsp/arm/inverse_transform_neon.h new file mode 100644 index 0000000..af647e8 --- /dev/null +++ b/src/dsp/arm/inverse_transform_neon.h @@ -0,0 +1,52 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::inverse_transforms, see the defines below for specifics. +// This function is not thread-safe. 
+void InverseTransformInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_ diff --git a/src/dsp/arm/loop_filter_neon.cc b/src/dsp/arm/loop_filter_neon.cc new file mode 100644 index 0000000..146c983 --- /dev/null +++ b/src/dsp/arm/loop_filter_neon.cc @@ -0,0 +1,1190 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
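Annotation (editorial illustration, not part of the upstream patch): the NEON loop filters in this file operate on packed p/q vectors, but the per-pixel rules they implement are the scalar checks quoted in their comments. The standalone sketch below restates those rules and the Filter4() arithmetic in plain C++; the formulas come from the comments in the code that follows, while the function names and the int-pixel interface are illustrative only.

#include <cstdlib>

namespace {

inline int Clip3(int value, int low, int high) {
  return (value < low) ? low : (value > high) ? high : value;
}

// (abs(p1 - p0) > thresh) || (abs(q1 - q0) > thresh)
inline bool HevScalar(int p1, int p0, int q0, int q1, int thresh) {
  return std::abs(p1 - p0) > thresh || std::abs(q1 - q0) > thresh;
}

// abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh
inline bool OuterThresholdScalar(int p1, int p0, int q0, int q1,
                                 int outer_thresh) {
  return std::abs(p0 - q0) * 2 + std::abs(p1 - q1) / 2 <= outer_thresh;
}

// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
// OuterThreshold()
inline bool NeedsFilter4Scalar(int p1, int p0, int q0, int q1,
                               int inner_thresh, int outer_thresh) {
  return std::abs(p1 - p0) <= inner_thresh &&
         std::abs(q1 - q0) <= inner_thresh &&
         OuterThresholdScalar(p1, p0, q0, q1, outer_thresh);
}

// One pixel's worth of Filter4()/Filter2(), following the comments in the
// vector code; assumes arithmetic >> on negative values, as vshr_n_s8 does.
inline void Filter4Scalar(int p1, int p0, int q0, int q1, bool hev,
                          int* op1, int* op0, int* oq0, int* oq1) {
  const int a = 3 * (q0 - p0) + (hev ? Clip3(p1 - q1, -128, 127) : 0);
  const int a1 = Clip3(a + 4, -128, 127) >> 3;  // applied to q0
  const int a2 = Clip3(a + 3, -128, 127) >> 3;  // applied to p0
  const int a3 = (a1 + 1) >> 1;                 // applied to p1/q1 when !hev
  *op0 = Clip3(p0 + a2, 0, 255);
  *oq0 = Clip3(q0 - a1, 0, 255);
  *op1 = hev ? p1 : Clip3(p1 + a3, 0, 255);
  *oq1 = hev ? q1 : Clip3(q1 - a3, 0, 255);
}

}  // namespace

The vector functions below compute the same quantities eight lanes at a time (four p-side and four q-side pixels per 64-bit register), with the int8/uint8 clamps supplied by the saturating narrows vqmovn_s16 and vqmovun_s16.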
+ +#include "src/dsp/loop_filter.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// (abs(p1 - p0) > thresh) || (abs(q1 - q0) > thresh) +inline uint8x8_t Hev(const uint8x8_t abd_p0p1_q0q1, const uint8_t thresh) { + const uint8x8_t a = vcgt_u8(abd_p0p1_q0q1, vdup_n_u8(thresh)); + return vorr_u8(a, RightShift<32>(a)); +} + +// abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh +inline uint8x8_t OuterThreshold(const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t outer_thresh) { + const uint8x8x2_t a = Interleave32(p0q0, p1q1); + const uint8x8_t b = vabd_u8(a.val[0], a.val[1]); + const uint8x8_t p0q0_double = vqadd_u8(b, b); + const uint8x8_t p1q1_half = RightShift<32>(vshr_n_u8(b, 1)); + const uint8x8_t c = vqadd_u8(p0q0_double, p1q1_half); + return vcle_u8(c, vdup_n_u8(outer_thresh)); +} + +// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh && +// OuterThreshhold() +inline uint8x8_t NeedsFilter4(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t inner_thresh, + const uint8_t outer_thresh) { + const uint8x8_t a = vcle_u8(abd_p0p1_q0q1, vdup_n_u8(inner_thresh)); + const uint8x8_t inner_mask = vand_u8(a, RightShift<32>(a)); + const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh); + return vand_u8(inner_mask, outer_mask); +} + +inline void Filter4Masks(const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t hev_thresh, const uint8_t outer_thresh, + const uint8_t inner_thresh, uint8x8_t* const hev_mask, + uint8x8_t* const needs_filter4_mask) { + const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1); + // This includes cases where NeedsFilter4() is not true and so Filter2() will + // not be applied. + const uint8x8_t hev_tmp_mask = Hev(p0p1_q0q1, hev_thresh); + + *needs_filter4_mask = + NeedsFilter4(p0p1_q0q1, p0q0, p1q1, inner_thresh, outer_thresh); + + // Filter2() will only be applied if both NeedsFilter4() and Hev() are true. + *hev_mask = vand_u8(hev_tmp_mask, *needs_filter4_mask); +} + +// Calculate Filter4() or Filter2() based on |hev_mask|. +inline void Filter4(const uint8x8_t q0p1, const uint8x8_t p0q1, + const uint8x8_t hev_mask, uint8x8_t* const p1q1_result, + uint8x8_t* const p0q0_result) { + const int16x4_t zero = vdup_n_s16(0); + + // a = 3 * (q0 - p0) + Clip3(p1 - q1, min_signed_val, max_signed_val); + const int16x8_t q0mp0_p1mq1 = vreinterpretq_s16_u16(vsubl_u8(q0p1, p0q1)); + const int16x4_t q0mp0_3 = vmul_n_s16(vget_low_s16(q0mp0_p1mq1), 3); + + // If this is for Filter2() then include |p1mq1|. Otherwise zero it. + const int16x4_t p1mq1 = vget_high_s16(q0mp0_p1mq1); + const int8x8_t p1mq1_saturated = vqmovn_s16(vcombine_s16(p1mq1, zero)); + const int8x8_t hev_option = + vand_s8(vreinterpret_s8_u8(hev_mask), p1mq1_saturated); + + const int16x4_t a = + vget_low_s16(vaddw_s8(vcombine_s16(q0mp0_3, zero), hev_option)); + + // We can not shift with rounding because the clamp comes *before* the + // shifting. 
a1 = Clip3(a + 4, min_signed_val, max_signed_val) >> 3; a2 = + // Clip3(a + 3, min_signed_val, max_signed_val) >> 3; + const int16x4_t plus_four = vadd_s16(a, vdup_n_s16(4)); + const int16x4_t plus_three = vadd_s16(a, vdup_n_s16(3)); + const int8x8_t a2_a1 = + vshr_n_s8(vqmovn_s16(vcombine_s16(plus_three, plus_four)), 3); + + // a3 is in the high 4 values. + // a3 = (a1 + 1) >> 1; + const int8x8_t a3 = vrshr_n_s8(a2_a1, 1); + + const int16x8_t p0q1_l = vreinterpretq_s16_u16(vmovl_u8(p0q1)); + const int16x8_t q0p1_l = vreinterpretq_s16_u16(vmovl_u8(q0p1)); + + const int16x8_t p1q1_l = + vcombine_s16(vget_high_s16(q0p1_l), vget_high_s16(p0q1_l)); + + const int8x8_t a3_ma3 = InterleaveHigh32(a3, vneg_s8(a3)); + const int16x8_t p1q1_a3 = vaddw_s8(p1q1_l, a3_ma3); + + const int16x8_t p0q0_l = + vcombine_s16(vget_low_s16(p0q1_l), vget_low_s16(q0p1_l)); + // Need to shift the second term or we end up with a2_ma2. + const int8x8_t a2_ma1 = + InterleaveLow32(a2_a1, RightShift<32>(vneg_s8(a2_a1))); + const int16x8_t p0q0_a = vaddw_s8(p0q0_l, a2_ma1); + + *p1q1_result = vqmovun_s16(p1q1_a3); + *p0q0_result = vqmovun_s16(p0q0_a); +} + +void Horizontal4_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + uint8_t* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + + uint8x8_t hev_mask; + uint8x8_t needs_filter4_mask; + Filter4Masks(p0q0, p1q1, hev_thresh, outer_thresh, inner_thresh, &hev_mask, + &needs_filter4_mask); + + // Copy the masks to the high bits for packed comparisons later. + hev_mask = InterleaveLow32(hev_mask, hev_mask); + needs_filter4_mask = InterleaveLow32(needs_filter4_mask, needs_filter4_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter4_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + + // Already integrated the Hev mask when calculating the filtered values. + const uint8x8_t p0q0_output = vbsl_u8(needs_filter4_mask, f_p0q0, p0q0); + + // p1/q1 are unmodified if only Hev() is true. This works because it was and'd + // with |needs_filter4_mask| previously. + const uint8x8_t p1q1_mask = veor_u8(hev_mask, needs_filter4_mask); + const uint8x8_t p1q1_output = vbsl_u8(p1q1_mask, f_p1q1, p1q1); + + StoreLo4(dst - 2 * stride, p1q1_output); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); + StoreHi4(dst + stride, p1q1_output); +} + +void Vertical4_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + uint8_t* dst = static_cast<uint8_t*>(dest); + + // Move |dst| to the left side of the filter window. + dst -= 2; + + // |p1q0| and |p0q1| are named for the values they will contain after the + // transpose. 
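  // [Annotation, not part of the upstream patch.] Throughout this file a
  // uint8x8_t named pNqN packs four p-side pixels in its low 32 bits and the
  // four matching q-side pixels in its high 32 bits, so one vector operation
  // filters both sides of the edge. Transpose32() swaps the two halves
  // (pNqN -> qNpN), RightShift<32>() moves the q-side lanes down onto the
  // p-side lanes, and InterleaveLow32(x, x) broadcasts the low half so a mask
  // computed there applies to both halves.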
+ const uint8x8_t row0 = Load4(dst); + uint8x8_t p1q0 = Load4<1>(dst + stride, row0); + const uint8x8_t row2 = Load4(dst + 2 * stride); + uint8x8_t p0q1 = Load4<1>(dst + 3 * stride, row2); + + Transpose4x4(&p1q0, &p0q1); + // Rearrange. + const uint8x8x2_t p1q1xq0p0 = Interleave32(p1q0, Transpose32(p0q1)); + const uint8x8x2_t p1q1xp0q0 = {p1q1xq0p0.val[0], + Transpose32(p1q1xq0p0.val[1])}; + + uint8x8_t hev_mask; + uint8x8_t needs_filter4_mask; + Filter4Masks(p1q1xp0q0.val[1], p1q1xp0q0.val[0], hev_thresh, outer_thresh, + inner_thresh, &hev_mask, &needs_filter4_mask); + + // Copy the masks to the high bits for packed comparisons later. + hev_mask = InterleaveLow32(hev_mask, hev_mask); + needs_filter4_mask = InterleaveLow32(needs_filter4_mask, needs_filter4_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter4_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + Filter4(Transpose32(p1q0), p0q1, hev_mask, &f_p1q1, &f_p0q0); + + // Already integrated the Hev mask when calculating the filtered values. + const uint8x8_t p0q0_output = + vbsl_u8(needs_filter4_mask, f_p0q0, p1q1xp0q0.val[1]); + + // p1/q1 are unmodified if only Hev() is true. This works because it was and'd + // with |needs_filter4_mask| previously. + const uint8x8_t p1q1_mask = veor_u8(hev_mask, needs_filter4_mask); + const uint8x8_t p1q1_output = vbsl_u8(p1q1_mask, f_p1q1, p1q1xp0q0.val[0]); + + // Put things back in order to reverse the transpose. + const uint8x8x2_t p1p0xq1q0 = Interleave32(p1q1_output, p0q0_output); + uint8x8_t output_0 = p1p0xq1q0.val[0], + output_1 = Transpose32(p1p0xq1q0.val[1]); + + Transpose4x4(&output_0, &output_1); + + StoreLo4(dst, output_0); + StoreLo4(dst + stride, output_1); + StoreHi4(dst + 2 * stride, output_0); + StoreHi4(dst + 3 * stride, output_1); +} + +// abs(p1 - p0) <= flat_thresh && abs(q1 - q0) <= flat_thresh && +// abs(p2 - p0) <= flat_thresh && abs(q2 - q0) <= flat_thresh +// |flat_thresh| == 1 for 8 bit decode. 
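// [Annotation, not part of the upstream patch.] The flatness masks below pick
// between the two kinds of filtering: where the neighbourhood is "flat" (every
// tap within +/-1 of p0/q0 at 8 bits) the wider low-pass filters are used,
// e.g. Filter6() computes p1' = (3*p2 + 2*p1 + 2*p0 + q0 + 4) >> 3 per its
// comments; elsewhere only the narrow Filter4() delta adjustment is applied.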
+inline uint8x8_t IsFlat3(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t abd_p0p2_q0q2) { + const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p0p2_q0q2); + const uint8x8_t b = vcle_u8(a, vdup_n_u8(1)); + return vand_u8(b, RightShift<32>(b)); +} + +// abs(p2 - p1) <= inner_thresh && abs(p1 - p0) <= inner_thresh && +// abs(q1 - q0) <= inner_thresh && abs(q2 - q1) <= inner_thresh && +// OuterThreshhold() +inline uint8x8_t NeedsFilter6(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t abd_p1p2_q1q2, + const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t inner_thresh, + const uint8_t outer_thresh) { + const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p1p2_q1q2); + const uint8x8_t b = vcle_u8(a, vdup_n_u8(inner_thresh)); + const uint8x8_t inner_mask = vand_u8(b, RightShift<32>(b)); + const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh); + return vand_u8(inner_mask, outer_mask); +} + +inline void Filter6Masks(const uint8x8_t p2q2, const uint8x8_t p1q1, + const uint8x8_t p0q0, const uint8_t hev_thresh, + const uint8_t outer_thresh, const uint8_t inner_thresh, + uint8x8_t* const needs_filter6_mask, + uint8x8_t* const is_flat3_mask, + uint8x8_t* const hev_mask) { + const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1); + *hev_mask = Hev(p0p1_q0q1, hev_thresh); + *is_flat3_mask = IsFlat3(p0p1_q0q1, vabd_u8(p0q0, p2q2)); + *needs_filter6_mask = NeedsFilter6(p0p1_q0q1, vabd_u8(p1q1, p2q2), p0q0, p1q1, + inner_thresh, outer_thresh); +} + +inline void Filter6(const uint8x8_t p2q2, const uint8x8_t p1q1, + const uint8x8_t p0q0, uint8x8_t* const p1q1_output, + uint8x8_t* const p0q0_output) { + // Sum p1 and q1 output from opposite directions + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^^^^^^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^^^^^^^ + const uint16x8_t p2q2_double = vaddl_u8(p2q2, p2q2); + uint16x8_t sum = vaddw_u8(p2q2_double, p2q2); + + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^^^^^^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p1q1, p1q1), sum); + + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^^^^^^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p0q0, p0q0), sum); + + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^ + const uint8x8_t q0p0 = Transpose32(p0q0); + sum = vaddw_u8(sum, q0p0); + + *p1q1_output = vrshrn_n_u16(sum, 3); + + // Convert to p0 and q0 output: + // p0 = p1 - (2 * p2) + q0 + q1 + // q0 = q1 - (2 * q2) + p0 + p1 + sum = vsubq_u16(sum, p2q2_double); + const uint8x8_t q1p1 = Transpose32(p1q1); + sum = vaddq_u16(vaddl_u8(q0p0, q1p1), sum); + + *p0q0_output = vrshrn_n_u16(sum, 3); +} + +void Horizontal6_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p2_v = Load4(dst - 3 * stride); + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v); + + uint8x8_t needs_filter6_mask, is_flat3_mask, hev_mask; + Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter6_mask, &is_flat3_mask, &hev_mask); + + needs_filter6_mask = InterleaveLow32(needs_filter6_mask, needs_filter6_mask); + is_flat3_mask = InterleaveLow32(is_flat3_mask, is_flat3_mask); + 
hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter6_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f6_p1q1, f6_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(vand_u8(is_flat3_mask, needs_filter6_mask)) == 0) { + // Filter6() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f6_p1q1 = zero; + f6_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat3_mask, f6_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter6_mask, p1q1_output, p1q1); + StoreLo4(dst - 2 * stride, p1q1_output); + StoreHi4(dst + stride, p1q1_output); + + uint8x8_t p0q0_output = vbsl_u8(is_flat3_mask, f6_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter6_mask, p0q0_output, p0q0); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); +} + +void Vertical6_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + // Move |dst| to the left side of the filter window. + dst -= 3; + + // |p2q1|, |p1q2|, |p0xx| and |q0xx| are named for the values they will + // contain after the transpose. + // These over-read by 2 bytes. We only need 6. + uint8x8_t p2q1 = vld1_u8(dst); + uint8x8_t p1q2 = vld1_u8(dst + stride); + uint8x8_t p0xx = vld1_u8(dst + 2 * stride); + uint8x8_t q0xx = vld1_u8(dst + 3 * stride); + + Transpose8x4(&p2q1, &p1q2, &p0xx, &q0xx); + + const uint8x8x2_t p2q2xq1p1 = Interleave32(p2q1, Transpose32(p1q2)); + const uint8x8_t p2q2 = p2q2xq1p1.val[0]; + const uint8x8_t p1q1 = Transpose32(p2q2xq1p1.val[1]); + const uint8x8_t p0q0 = InterleaveLow32(p0xx, q0xx); + + uint8x8_t needs_filter6_mask, is_flat3_mask, hev_mask; + Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter6_mask, &is_flat3_mask, &hev_mask); + + needs_filter6_mask = InterleaveLow32(needs_filter6_mask, needs_filter6_mask); + is_flat3_mask = InterleaveLow32(is_flat3_mask, is_flat3_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter6_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. 
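  // [Annotation, not part of the upstream patch.] vbsl_u8(mask, a, b) selects
  // bits of |a| where |mask| is set and |b| elsewhere, so the next statement
  // restores the original p1/q1 in high-edge-variance lanes: when Hev() is
  // true only p0/q0 are adjusted, matching the p1q1_mask handling in
  // Horizontal4_NEON/Vertical4_NEON.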
+ f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f6_p1q1, f6_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(vand_u8(is_flat3_mask, needs_filter6_mask)) == 0) { + // Filter6() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f6_p1q1 = zero; + f6_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat3_mask, f6_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter6_mask, p1q1_output, p1q1); + + uint8x8_t p0q0_output = vbsl_u8(is_flat3_mask, f6_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter6_mask, p0q0_output, p0q0); + + // The six tap filter is only six taps on input. Output is limited to p1-q1. + dst += 1; + // Put things back in order to reverse the transpose. + const uint8x8x2_t p1p0xq1q0 = Interleave32(p1q1_output, p0q0_output); + uint8x8_t output_0 = p1p0xq1q0.val[0]; + uint8x8_t output_1 = Transpose32(p1p0xq1q0.val[1]); + + Transpose4x4(&output_0, &output_1); + + StoreLo4(dst, output_0); + StoreLo4(dst + stride, output_1); + StoreHi4(dst + 2 * stride, output_0); + StoreHi4(dst + 3 * stride, output_1); +} + +// IsFlat4 uses N=1, IsFlatOuter4 uses N=4. +// abs(p[N] - p0) <= flat_thresh && abs(q[N] - q0) <= flat_thresh && +// abs(p[N+1] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh && +// abs(p[N+2] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh +// |flat_thresh| == 1 for 8 bit decode. +inline uint8x8_t IsFlat4(const uint8x8_t abd_p0n0_q0n0, + const uint8x8_t abd_p0n1_q0n1, + const uint8x8_t abd_p0n2_q0n2) { + const uint8x8_t a = vmax_u8(abd_p0n0_q0n0, abd_p0n1_q0n1); + const uint8x8_t b = vmax_u8(a, abd_p0n2_q0n2); + const uint8x8_t c = vcle_u8(b, vdup_n_u8(1)); + return vand_u8(c, RightShift<32>(c)); +} + +// abs(p3 - p2) <= inner_thresh && abs(p2 - p1) <= inner_thresh && +// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh && +// abs(q2 - q1) <= inner_thresh && abs(q3 - q2) <= inner_thresh +// OuterThreshhold() +inline uint8x8_t NeedsFilter8(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t abd_p1p2_q1q2, + const uint8x8_t abd_p2p3_q2q3, + const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t inner_thresh, + const uint8_t outer_thresh) { + const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p1p2_q1q2); + const uint8x8_t b = vmax_u8(a, abd_p2p3_q2q3); + const uint8x8_t c = vcle_u8(b, vdup_n_u8(inner_thresh)); + const uint8x8_t inner_mask = vand_u8(c, RightShift<32>(c)); + const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh); + return vand_u8(inner_mask, outer_mask); +} + +inline void Filter8Masks(const uint8x8_t p3q3, const uint8x8_t p2q2, + const uint8x8_t p1q1, const uint8x8_t p0q0, + const uint8_t hev_thresh, const uint8_t outer_thresh, + const uint8_t inner_thresh, + uint8x8_t* const needs_filter8_mask, + uint8x8_t* const is_flat4_mask, + uint8x8_t* const hev_mask) { + const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1); + *hev_mask = Hev(p0p1_q0q1, hev_thresh); + *is_flat4_mask = IsFlat4(p0p1_q0q1, vabd_u8(p0q0, p2q2), vabd_u8(p0q0, p3q3)); + *needs_filter8_mask = + NeedsFilter8(p0p1_q0q1, vabd_u8(p1q1, p2q2), vabd_u8(p2q2, p3q3), p0q0, + p1q1, inner_thresh, outer_thresh); +} + +inline void Filter8(const uint8x8_t p3q3, const uint8x8_t p2q2, + const uint8x8_t p1q1, const uint8x8_t p0q0, + uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output, + uint8x8_t* const p0q0_output) { + // Sum p2 and q2 output from opposite directions + // p2 = (3 * p3) + (2 * p2) + p1 + 
p0 + q0 + // ^^^^^^^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^^^^^^^ + uint16x8_t sum = vaddw_u8(vaddl_u8(p3q3, p3q3), p3q3); + + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^^^^^^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p2q2, p2q2), sum); + + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^^^^^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^^^^^^ + sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum); + + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^ + const uint8x8_t q0p0 = Transpose32(p0q0); + sum = vaddw_u8(sum, q0p0); + + *p2q2_output = vrshrn_n_u16(sum, 3); + + // Convert to p1 and q1 output: + // p1 = p2 - p3 - p2 + p1 + q1 + // q1 = q2 - q3 - q2 + q0 + p1 + sum = vsubq_u16(sum, vaddl_u8(p3q3, p2q2)); + const uint8x8_t q1p1 = Transpose32(p1q1); + sum = vaddq_u16(vaddl_u8(p1q1, q1p1), sum); + + *p1q1_output = vrshrn_n_u16(sum, 3); + + // Convert to p0 and q0 output: + // p0 = p1 - p3 - p1 + p0 + q2 + // q0 = q1 - q3 - q1 + q0 + p2 + sum = vsubq_u16(sum, vaddl_u8(p3q3, p1q1)); + const uint8x8_t q2p2 = Transpose32(p2q2); + sum = vaddq_u16(vaddl_u8(p0q0, q2p2), sum); + + *p0q0_output = vrshrn_n_u16(sum, 3); +} + +void Horizontal8_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p3_v = Load4(dst - 4 * stride); + const uint8x8_t p2_v = Load4(dst - 3 * stride); + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v); + const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() does not apply. 
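    // [Annotation, not part of the upstream patch.] vaddv_u8() sums all eight
    // lanes; with 0x00/0xff mask lanes a zero sum means no lane needs
    // Filter8(), which is why this shortcut is guarded by __aarch64__ (the
    // code notes there is no equally quick check on armv7). The zero
    // assignments below merely keep f8_* initialized; the later vbsl_u8()
    // selects never read them when |is_flat4_mask| is all zero.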
+ const uint8x8_t zero = vdup_n_u8(0); + f8_p2q2 = zero; + f8_p1q1 = zero; + f8_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + + const uint8x8_t p2p2_output = vbsl_u8(is_flat4_mask, f8_p2q2, p2q2); + StoreLo4(dst - 3 * stride, p2p2_output); + StoreHi4(dst + 2 * stride, p2p2_output); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat4_mask, f8_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + StoreLo4(dst - 2 * stride, p1q1_output); + StoreHi4(dst + stride, p1q1_output); + + uint8x8_t p0q0_output = vbsl_u8(is_flat4_mask, f8_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); +} + +void Vertical8_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + // Move |dst| to the left side of the filter window. + dst -= 4; + + // |p3q0|, |p2q1|, |p1q2| and |p0q3| are named for the values they will + // contain after the transpose. + uint8x8_t p3q0 = vld1_u8(dst); + uint8x8_t p2q1 = vld1_u8(dst + stride); + uint8x8_t p1q2 = vld1_u8(dst + 2 * stride); + uint8x8_t p0q3 = vld1_u8(dst + 3 * stride); + + Transpose8x4(&p3q0, &p2q1, &p1q2, &p0q3); + const uint8x8x2_t p3q3xq0p0 = Interleave32(p3q0, Transpose32(p0q3)); + const uint8x8_t p3q3 = p3q3xq0p0.val[0]; + const uint8x8_t p0q0 = Transpose32(p3q3xq0p0.val[1]); + const uint8x8x2_t p2q2xq1p1 = Interleave32(p2q1, Transpose32(p1q2)); + const uint8x8_t p2q2 = p2q2xq1p1.val[0]; + const uint8x8_t p1q1 = Transpose32(p2q2xq1p1.val[1]); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f8_p2q2 = zero; + f8_p1q1 = zero; + f8_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + // Always prepare and store p2/q2 because we need to transpose it anyway. 
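  // [Annotation, not part of the upstream patch.] Unlike Horizontal8_NEON,
  // which stores the p2/q2 rows only inside the is_flat4 branch, the vertical
  // path re-transposes and writes back whole 8-byte rows (p3..q3), so p2/q2
  // must hold valid values in every lane even when Filter8() was skipped.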
+ const uint8x8_t p2q2_output = vbsl_u8(is_flat4_mask, f8_p2q2, p2q2); + + uint8x8_t p1q1_output = vbsl_u8(is_flat4_mask, f8_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + + uint8x8_t p0q0_output = vbsl_u8(is_flat4_mask, f8_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + + // Write out p3/q3 as well. There isn't a good way to write out 6 bytes. + // Variable names reflect the values before transposition. + const uint8x8x2_t p3q0xq3p0_output = + Interleave32(p3q3, Transpose32(p0q0_output)); + uint8x8_t p3q0_output = p3q0xq3p0_output.val[0]; + uint8x8_t p0q3_output = Transpose32(p3q0xq3p0_output.val[1]); + const uint8x8x2_t p2q1xq2p1_output = + Interleave32(p2q2_output, Transpose32(p1q1_output)); + uint8x8_t p2q1_output = p2q1xq2p1_output.val[0]; + uint8x8_t p1q2_output = Transpose32(p2q1xq2p1_output.val[1]); + + Transpose8x4(&p3q0_output, &p2q1_output, &p1q2_output, &p0q3_output); + + vst1_u8(dst, p3q0_output); + vst1_u8(dst + stride, p2q1_output); + vst1_u8(dst + 2 * stride, p1q2_output); + vst1_u8(dst + 3 * stride, p0q3_output); +} + +inline void Filter14(const uint8x8_t p6q6, const uint8x8_t p5q5, + const uint8x8_t p4q4, const uint8x8_t p3q3, + const uint8x8_t p2q2, const uint8x8_t p1q1, + const uint8x8_t p0q0, uint8x8_t* const p5q5_output, + uint8x8_t* const p4q4_output, uint8x8_t* const p3q3_output, + uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output, + uint8x8_t* const p0q0_output) { + // Sum p5 and q5 output from opposite directions + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^ + uint16x8_t sum = vsubw_u8(vshll_n_u8(p6q6, 3), p6q6); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p5q5, p5q5), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p4q4, p4q4), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^ + sum = vaddq_u16(vaddl_u8(p3q3, p2q2), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^ + sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^ + const uint8x8_t q0p0 = Transpose32(p0q0); + sum = vaddw_u8(sum, q0p0); + + *p5q5_output = vrshrn_n_u16(sum, 4); + + // Convert to p4 and q4 output: + // p4 = p5 - (2 * p6) + p3 + q1 + // q4 = q5 - (2 * q6) + q3 + p1 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p6q6)); + const uint8x8_t q1p1 = Transpose32(p1q1); + sum = vaddq_u16(vaddl_u8(p3q3, q1p1), sum); + + *p4q4_output = vrshrn_n_u16(sum, 4); + + // Convert to p3 and q3 output: + // p3 = p4 - p6 - p5 + p2 + q2 + // q3 = q4 - q6 - q5 + q2 + p2 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p5q5)); + const uint8x8_t q2p2 = Transpose32(p2q2); + sum = vaddq_u16(vaddl_u8(p2q2, q2p2), sum); + + *p3q3_output = vrshrn_n_u16(sum, 4); + + // Convert to p2 and q2 output: + // p2 = p3 - p6 - p4 + p1 + q3 + // q2 = q3 - q6 - q4 + q1 + p3 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p4q4)); 
+ const uint8x8_t q3p3 = Transpose32(p3q3); + sum = vaddq_u16(vaddl_u8(p1q1, q3p3), sum); + + *p2q2_output = vrshrn_n_u16(sum, 4); + + // Convert to p1 and q1 output: + // p1 = p2 - p6 - p3 + p0 + q4 + // q1 = q2 - q6 - q3 + q0 + p4 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p3q3)); + const uint8x8_t q4p4 = Transpose32(p4q4); + sum = vaddq_u16(vaddl_u8(p0q0, q4p4), sum); + + *p1q1_output = vrshrn_n_u16(sum, 4); + + // Convert to p0 and q0 output: + // p0 = p1 - p6 - p2 + q0 + q5 + // q0 = q1 - q6 - q2 + p0 + p5 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p2q2)); + const uint8x8_t q5p5 = Transpose32(p5q5); + sum = vaddq_u16(vaddl_u8(q0p0, q5p5), sum); + + *p0q0_output = vrshrn_n_u16(sum, 4); +} + +void Horizontal14_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p6_v = Load4(dst - 7 * stride); + const uint8x8_t p5_v = Load4(dst - 6 * stride); + const uint8x8_t p4_v = Load4(dst - 5 * stride); + const uint8x8_t p3_v = Load4(dst - 4 * stride); + const uint8x8_t p2_v = Load4(dst - 3 * stride); + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v); + const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v); + const uint8x8_t p4q4 = Load4<1>(dst + 4 * stride, p4_v); + const uint8x8_t p5q5 = Load4<1>(dst + 5 * stride, p5_v); + const uint8x8_t p6q6 = Load4<1>(dst + 6 * stride, p6_v); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + // Decide between Filter8() and Filter14(). + uint8x8_t is_flat_outer4_mask = + IsFlat4(vabd_u8(p0q0, p4q4), vabd_u8(p0q0, p5q5), vabd_u8(p0q0, p6q6)); + is_flat_outer4_mask = vand_u8(is_flat4_mask, is_flat_outer4_mask); + is_flat_outer4_mask = + InterleaveLow32(is_flat_outer4_mask, is_flat_outer4_mask); + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f8_p1q1, f8_p0q0; + uint8x8_t f14_p2q2, f14_p1q1, f14_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() and Filter14() do not apply. 
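    // [Annotation, not part of the upstream patch.] The 14-tap path picks its
    // output with a cascade of vbsl_u8() selects: Filter14() results where
    // |is_flat_outer4_mask| is set, else Filter8() results where
    // |is_flat4_mask| is set, else Filter4() results where
    // |needs_filter8_mask| is set, else the original pixels. The zero
    // assignments below only keep the unused operands of that cascade
    // initialized.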
+ const uint8x8_t zero = vdup_n_u8(0); + f8_p1q1 = zero; + f8_p0q0 = zero; + f14_p1q1 = zero; + f14_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + uint8x8_t f8_p2q2; + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + +#if defined(__aarch64__) + if (vaddv_u8(is_flat_outer4_mask) == 0) { + // Filter14() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f14_p2q2 = zero; + f14_p1q1 = zero; + f14_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + uint8x8_t f14_p5q5, f14_p4q4, f14_p3q3; + Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4, + &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0); + + const uint8x8_t p5q5_output = + vbsl_u8(is_flat_outer4_mask, f14_p5q5, p5q5); + StoreLo4(dst - 6 * stride, p5q5_output); + StoreHi4(dst + 5 * stride, p5q5_output); + + const uint8x8_t p4q4_output = + vbsl_u8(is_flat_outer4_mask, f14_p4q4, p4q4); + StoreLo4(dst - 5 * stride, p4q4_output); + StoreHi4(dst + 4 * stride, p4q4_output); + + const uint8x8_t p3q3_output = + vbsl_u8(is_flat_outer4_mask, f14_p3q3, p3q3); + StoreLo4(dst - 4 * stride, p3q3_output); + StoreHi4(dst + 3 * stride, p3q3_output); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p2q2_output = vbsl_u8(is_flat_outer4_mask, f14_p2q2, f8_p2q2); + p2q2_output = vbsl_u8(is_flat4_mask, p2q2_output, p2q2); + StoreLo4(dst - 3 * stride, p2q2_output); + StoreHi4(dst + 2 * stride, p2q2_output); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat_outer4_mask, f14_p1q1, f8_p1q1); + p1q1_output = vbsl_u8(is_flat4_mask, p1q1_output, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + StoreLo4(dst - 2 * stride, p1q1_output); + StoreHi4(dst + stride, p1q1_output); + + uint8x8_t p0q0_output = vbsl_u8(is_flat_outer4_mask, f14_p0q0, f8_p0q0); + p0q0_output = vbsl_u8(is_flat4_mask, p0q0_output, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); +} + +void Vertical14_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + dst -= 8; + // input + // p7 p6 p5 p4 p3 p2 p1 p0 q0 q1 q2 q3 q4 q5 q6 q7 + const uint8x16_t x0 = vld1q_u8(dst); + dst += stride; + const uint8x16_t x1 = vld1q_u8(dst); + dst += stride; + const uint8x16_t x2 = vld1q_u8(dst); + dst += stride; + const uint8x16_t x3 = vld1q_u8(dst); + dst -= (stride * 3); + + // re-order input +#if defined(__aarch64__) + const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607); + const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203); + const uint8x16_t index_qp7toqp0 = vcombine_u8(index_qp3toqp0, index_qp7toqp4); + + uint8x16_t input_0 = vqtbl1q_u8(x0, index_qp7toqp0); + uint8x16_t input_1 = vqtbl1q_u8(x1, index_qp7toqp0); + uint8x16_t input_2 = vqtbl1q_u8(x2, index_qp7toqp0); + uint8x16_t input_3 = vqtbl1q_u8(x3, index_qp7toqp0); +#else + const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607); + const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203); + + const uint8x8_t x0_qp3qp0 = VQTbl1U8(x0, index_qp3toqp0); + const uint8x8_t x1_qp3qp0 = VQTbl1U8(x1, index_qp3toqp0); + const uint8x8_t x2_qp3qp0 = VQTbl1U8(x2, index_qp3toqp0); + const uint8x8_t x3_qp3qp0 = VQTbl1U8(x3, index_qp3toqp0); + + const uint8x8_t x0_qp7qp4 = VQTbl1U8(x0, index_qp7toqp4); + const uint8x8_t x1_qp7qp4 = VQTbl1U8(x1, index_qp7toqp4); + const 
uint8x8_t x2_qp7qp4 = VQTbl1U8(x2, index_qp7toqp4); + const uint8x8_t x3_qp7qp4 = VQTbl1U8(x3, index_qp7toqp4); + + const uint8x16_t input_0 = vcombine_u8(x0_qp3qp0, x0_qp7qp4); + const uint8x16_t input_1 = vcombine_u8(x1_qp3qp0, x1_qp7qp4); + const uint8x16_t input_2 = vcombine_u8(x2_qp3qp0, x2_qp7qp4); + const uint8x16_t input_3 = vcombine_u8(x3_qp3qp0, x3_qp7qp4); +#endif + // input after re-order + // p0 p1 p2 p3 q0 q1 q2 q3 p4 p5 p6 p7 q4 q5 q6 q7 + + const uint8x16x2_t in01 = vtrnq_u8(input_0, input_1); + const uint8x16x2_t in23 = vtrnq_u8(input_2, input_3); + const uint16x8x2_t in02 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[0]), + vreinterpretq_u16_u8(in23.val[0])); + const uint16x8x2_t in13 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[1]), + vreinterpretq_u16_u8(in23.val[1])); + + const uint8x8_t p0q0 = vget_low_u8(vreinterpretq_u8_u16(in02.val[0])); + const uint8x8_t p1q1 = vget_low_u8(vreinterpretq_u8_u16(in13.val[0])); + + const uint8x8_t p2q2 = vget_low_u8(vreinterpretq_u8_u16(in02.val[1])); + const uint8x8_t p3q3 = vget_low_u8(vreinterpretq_u8_u16(in13.val[1])); + + const uint8x8_t p4q4 = vget_high_u8(vreinterpretq_u8_u16(in02.val[0])); + const uint8x8_t p5q5 = vget_high_u8(vreinterpretq_u8_u16(in13.val[0])); + + const uint8x8_t p6q6 = vget_high_u8(vreinterpretq_u8_u16(in02.val[1])); + const uint8x8_t p7q7 = vget_high_u8(vreinterpretq_u8_u16(in13.val[1])); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + // Decide between Filter8() and Filter14(). + uint8x8_t is_flat_outer4_mask = + IsFlat4(vabd_u8(p0q0, p4q4), vabd_u8(p0q0, p5q5), vabd_u8(p0q0, p6q6)); + is_flat_outer4_mask = vand_u8(is_flat4_mask, is_flat_outer4_mask); + is_flat_outer4_mask = + InterleaveLow32(is_flat_outer4_mask, is_flat_outer4_mask); + + uint8x8_t f_p0q0, f_p1q1; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t p1q1_output, p0q0_output; + uint8x8_t p5q5_output, p4q4_output, p3q3_output, p2q2_output; + +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() and Filter14() do not apply. + p1q1_output = p1q1; + p0q0_output = p0q0; + + p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = p2q2; + } else { +#endif // defined(__aarch64__) + uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0; + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + +#if defined(__aarch64__) + if (vaddv_u8(is_flat_outer4_mask) == 0) { + // Filter14() does not apply. 
+ p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = f8_p2q2; + p1q1_output = f8_p1q1; + p0q0_output = f8_p0q0; + } else { +#endif // defined(__aarch64__) + uint8x8_t f14_p5q5, f14_p4q4, f14_p3q3, f14_p2q2, f14_p1q1, f14_p0q0; + Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4, + &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0); + + p5q5_output = vbsl_u8(is_flat_outer4_mask, f14_p5q5, p5q5); + p4q4_output = vbsl_u8(is_flat_outer4_mask, f14_p4q4, p4q4); + p3q3_output = vbsl_u8(is_flat_outer4_mask, f14_p3q3, p3q3); + p2q2_output = vbsl_u8(is_flat_outer4_mask, f14_p2q2, f8_p2q2); + p1q1_output = vbsl_u8(is_flat_outer4_mask, f14_p1q1, f8_p1q1); + p0q0_output = vbsl_u8(is_flat_outer4_mask, f14_p0q0, f8_p0q0); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + p2q2_output = vbsl_u8(is_flat4_mask, p2q2_output, p2q2); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + p1q1_output = vbsl_u8(is_flat4_mask, p1q1_output, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + p0q0_output = vbsl_u8(is_flat4_mask, p0q0_output, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + + const uint8x16_t p0q0_p4q4 = vcombine_u8(p0q0_output, p4q4_output); + const uint8x16_t p2q2_p6q6 = vcombine_u8(p2q2_output, p6q6); + const uint8x16_t p1q1_p5q5 = vcombine_u8(p1q1_output, p5q5_output); + const uint8x16_t p3q3_p7q7 = vcombine_u8(p3q3_output, p7q7); + + const uint16x8x2_t out02 = vtrnq_u16(vreinterpretq_u16_u8(p0q0_p4q4), + vreinterpretq_u16_u8(p2q2_p6q6)); + const uint16x8x2_t out13 = vtrnq_u16(vreinterpretq_u16_u8(p1q1_p5q5), + vreinterpretq_u16_u8(p3q3_p7q7)); + const uint8x16x2_t out01 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[0]), + vreinterpretq_u8_u16(out13.val[0])); + const uint8x16x2_t out23 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[1]), + vreinterpretq_u8_u16(out13.val[1])); + +#if defined(__aarch64__) + const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b); + const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504); + const uint8x16_t index_p7toq7 = vcombine_u8(index_p7top0, index_q7toq0); + + const uint8x16_t output_0 = vqtbl1q_u8(out01.val[0], index_p7toq7); + const uint8x16_t output_1 = vqtbl1q_u8(out01.val[1], index_p7toq7); + const uint8x16_t output_2 = vqtbl1q_u8(out23.val[0], index_p7toq7); + const uint8x16_t output_3 = vqtbl1q_u8(out23.val[1], index_p7toq7); +#else + const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b); + const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504); + + const uint8x8_t x0_p7p0 = VQTbl1U8(out01.val[0], index_p7top0); + const uint8x8_t x1_p7p0 = VQTbl1U8(out01.val[1], index_p7top0); + const uint8x8_t x2_p7p0 = VQTbl1U8(out23.val[0], index_p7top0); + const uint8x8_t x3_p7p0 = VQTbl1U8(out23.val[1], index_p7top0); + + const uint8x8_t x0_q7q0 = VQTbl1U8(out01.val[0], index_q7toq0); + const uint8x8_t x1_q7q0 = VQTbl1U8(out01.val[1], index_q7toq0); + const uint8x8_t x2_q7q0 = VQTbl1U8(out23.val[0], index_q7toq0); + const uint8x8_t x3_q7q0 = VQTbl1U8(out23.val[1], index_q7toq0); + + const uint8x16_t output_0 = vcombine_u8(x0_p7p0, x0_q7q0); + const uint8x16_t output_1 = vcombine_u8(x1_p7p0, x1_q7q0); + const uint8x16_t output_2 = vcombine_u8(x2_p7p0, x2_q7q0); + const uint8x16_t output_3 = vcombine_u8(x3_p7p0, x3_q7q0); +#endif + + vst1q_u8(dst, output_0); + dst += stride; + vst1q_u8(dst, output_1); + dst += stride; + vst1q_u8(dst, output_2); + dst += stride; + vst1q_u8(dst, output_3); +} + +void Init8bpp() { + Dsp* const dsp = 
dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = + Horizontal4_NEON; + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = Vertical4_NEON; + + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = + Horizontal6_NEON; + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = Vertical6_NEON; + + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = + Horizontal8_NEON; + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = Vertical8_NEON; + + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Horizontal14_NEON; + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = + Vertical14_NEON; +} +} // namespace +} // namespace low_bitdepth + +void LoopFilterInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void LoopFilterInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/loop_filter_neon.h b/src/dsp/arm/loop_filter_neon.h new file mode 100644 index 0000000..5f79200 --- /dev/null +++ b/src/dsp/arm/loop_filter_neon.h @@ -0,0 +1,53 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_filters, see the defines below for specifics. This +// function is not thread-safe. +void LoopFilterInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_ diff --git a/src/dsp/arm/loop_restoration_neon.cc b/src/dsp/arm/loop_restoration_neon.cc new file mode 100644 index 0000000..337c9b4 --- /dev/null +++ b/src/dsp/arm/loop_restoration_neon.cc @@ -0,0 +1,1901 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_restoration.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +template <int bytes> +inline uint8x8_t VshrU128(const uint8x8x2_t src) { + return vext_u8(src.val[0], src.val[1], bytes); +} + +template <int bytes> +inline uint16x8_t VshrU128(const uint16x8x2_t src) { + return vextq_u16(src.val[0], src.val[1], bytes / 2); +} + +// Wiener + +// Must make a local copy of coefficients to help compiler know that they have +// no overlap with other buffers. Using 'const' keyword is not enough. Actually +// compiler doesn't make a copy, since there is enough registers in this case. +inline void PopulateWienerCoefficients( + const RestorationUnitInfo& restoration_info, const int direction, + int16_t filter[4]) { + // In order to keep the horizontal pass intermediate values within 16 bits we + // offset |filter[3]| by 128. The 128 offset will be added back in the loop. + for (int i = 0; i < 4; ++i) { + filter[i] = restoration_info.wiener_info.filter[direction][i]; + } + if (direction == WienerInfo::kHorizontal) { + filter[3] -= 128; + } +} + +inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1, + const int16_t filter, const int16x8_t sum) { + const int16x8_t ss = vreinterpretq_s16_u16(vaddl_u8(s0, s1)); + return vmlaq_n_s16(sum, ss, filter); +} + +inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1, + const int16_t filter, + const int16x8x2_t sum) { + int16x8x2_t d; + d.val[0] = + WienerHorizontal2(vget_low_u8(s0), vget_low_u8(s1), filter, sum.val[0]); + d.val[1] = + WienerHorizontal2(vget_high_u8(s0), vget_high_u8(s1), filter, sum.val[1]); + return d; +} + +inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4], + int16x8_t sum, int16_t* const wiener_buffer) { + constexpr int offset = + 1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1); + constexpr int limit = (offset << 2) - 1; + const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2])); + const int16x8_t s_1 = ZeroExtend(s[1]); + sum = vmlaq_n_s16(sum, s_0_2, filter[2]); + sum = vmlaq_n_s16(sum, s_1, filter[3]); + // Calculate scaled down offset correction, and add to sum here to prevent + // signed 16 bit outranging. 
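// Editor's note (annotation, not in the upstream file): PopulateWienerCoefficients()
// subtracted 128 from the center tap filter[3], so |sum| above is missing a
// 128 * s_1 term. Because 128 * s_1 is an exact multiple of
// 1 << kInterRoundBitsHorizontal, that term can be re-added after the shift
// with no rounding error:
//   RightShiftWithRounding(sum + 128 * s_1, kInterRoundBitsHorizontal)
//     == (s_1 << (7 - kInterRoundBitsHorizontal))
//        + RightShiftWithRounding(sum, kInterRoundBitsHorizontal)
// which is what the vrsraq_n_s16() below computes, keeping the pre-shift
// accumulator within int16. Assuming the 8-bit constants
// kWienerFilterBits == 7 and kInterRoundBitsHorizontal == 3, |offset| is 2048
// and |limit| is 8191, so the clamp keeps stored values within 13 bits,
// matching the "saturated to 13 bits" note in WienerFilter_NEON().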
+ sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum, + kInterRoundBitsHorizontal); + sum = vmaxq_s16(sum, vdupq_n_s16(-offset)); + sum = vminq_s16(sum, vdupq_n_s16(limit - offset)); + vst1q_s16(wiener_buffer, sum); +} + +inline void WienerHorizontalSum(const uint8x16_t src[3], + const int16_t filter[4], int16x8x2_t sum, + int16_t* const wiener_buffer) { + uint8x8_t s[3]; + s[0] = vget_low_u8(src[0]); + s[1] = vget_low_u8(src[1]); + s[2] = vget_low_u8(src[2]); + WienerHorizontalSum(s, filter, sum.val[0], wiener_buffer); + s[0] = vget_high_u8(src[0]); + s[1] = vget_high_u8(src[1]); + s[2] = vget_high_u8(src[2]); + WienerHorizontalSum(s, filter, sum.val[1], wiener_buffer + 8); +} + +inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int16_t filter[4], + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + uint8x16_t s[8]; + s[0] = vld1q_u8(src_ptr); + ptrdiff_t x = width; + do { + src_ptr += 16; + s[7] = vld1q_u8(src_ptr); + s[1] = vextq_u8(s[0], s[7], 1); + s[2] = vextq_u8(s[0], s[7], 2); + s[3] = vextq_u8(s[0], s[7], 3); + s[4] = vextq_u8(s[0], s[7], 4); + s[5] = vextq_u8(s[0], s[7], 5); + s[6] = vextq_u8(s[0], s[7], 6); + int16x8x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s16(0); + sum = WienerHorizontal2(s[0], s[6], filter[0], sum); + sum = WienerHorizontal2(s[1], s[5], filter[1], sum); + WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer); + s[0] = s[7]; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int16_t filter[4], + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + uint8x16_t s[6]; + s[0] = vld1q_u8(src_ptr); + ptrdiff_t x = width; + do { + src_ptr += 16; + s[5] = vld1q_u8(src_ptr); + s[1] = vextq_u8(s[0], s[5], 1); + s[2] = vextq_u8(s[0], s[5], 2); + s[3] = vextq_u8(s[0], s[5], 3); + s[4] = vextq_u8(s[0], s[5], 4); + int16x8x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s16(0); + sum = WienerHorizontal2(s[0], s[4], filter[1], sum); + WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer); + s[0] = s[5]; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int16_t filter[4], + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + uint8x16_t s[4]; + s[0] = vld1q_u8(src_ptr); + ptrdiff_t x = width; + do { + src_ptr += 16; + s[3] = vld1q_u8(src_ptr); + s[1] = vextq_u8(s[0], s[3], 1); + s[2] = vextq_u8(s[0], s[3], 2); + int16x8x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s16(0); + WienerHorizontalSum(s, filter, sum, *wiener_buffer); + s[0] = s[3]; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + ptrdiff_t x = width; + do { + const uint8x16_t s = vld1q_u8(src_ptr); + const uint8x8_t s0 = vget_low_u8(s); + const uint8x8_t s1 = vget_high_u8(s); + const int16x8_t d0 = vreinterpretq_s16_u16(vshll_n_u8(s0, 4)); + const int16x8_t d1 = 
vreinterpretq_s16_u16(vshll_n_u8(s1, 4)); + vst1q_s16(*wiener_buffer + 0, d0); + vst1q_s16(*wiener_buffer + 8, d1); + src_ptr += 16; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1, + const int16_t filter, + const int32x4x2_t sum) { + const int16x8_t a = vaddq_s16(a0, a1); + int32x4x2_t d; + d.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(a), filter); + d.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(a), filter); + return d; +} + +inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4], + const int32x4x2_t sum) { + int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum); + d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]); + d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]); + const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11); + const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11); + return vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16)); +} + +inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4], + int16x8_t a[7]) { + int32x4x2_t sum; + a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride); + a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride); + a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride); + a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[0], a[6], filter[0], sum); + sum = WienerVertical2(a[1], a[5], filter[1], sum); + a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride); + a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride); + a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride); + return WienerVertical(a + 2, filter, sum); +} + +inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4]) { + int16x8_t a[8]; + int32x4x2_t sum; + uint8x8x2_t d; + d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a); + a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[1], a[7], filter[0], sum); + sum = WienerVertical2(a[2], a[6], filter[1], sum); + d.val[1] = WienerVertical(a + 3, filter, sum); + return d; +} + +inline void WienerVerticalTap7(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t filter[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + uint8x8x2_t d[2]; + d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter); + d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter); + vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0])); + vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1])); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + int16x8_t a[7]; + const uint8x8_t d0 = + WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a); + const uint8x8_t d1 = + WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a); + vst1q_u8(dst, vcombine_u8(d0, d1)); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4], + 
int16x8_t a[5]) { + a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride); + a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride); + a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride); + a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride); + a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride); + int32x4x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[0], a[4], filter[1], sum); + return WienerVertical(a + 1, filter, sum); +} + +inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4]) { + int16x8_t a[6]; + int32x4x2_t sum; + uint8x8x2_t d; + d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a); + a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[1], a[5], filter[1], sum); + d.val[1] = WienerVertical(a + 2, filter, sum); + return d; +} + +inline void WienerVerticalTap5(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t filter[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + uint8x8x2_t d[2]; + d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter); + d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter); + vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0])); + vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1])); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + int16x8_t a[5]; + const uint8x8_t d0 = + WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a); + const uint8x8_t d1 = + WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a); + vst1q_u8(dst, vcombine_u8(d0, d1)); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4], + int16x8_t a[3]) { + a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride); + a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride); + a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride); + int32x4x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + return WienerVertical(a, filter, sum); +} + +inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4]) { + int16x8_t a[4]; + int32x4x2_t sum; + uint8x8x2_t d; + d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a); + a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + d.val[1] = WienerVertical(a + 1, filter, sum); + return d; +} + +inline void WienerVerticalTap3(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t filter[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + uint8x8x2_t d[2]; + d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter); + d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter); + vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0])); + vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1])); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += 
width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + int16x8_t a[3]; + const uint8x8_t d0 = + WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a); + const uint8x8_t d1 = + WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a); + vst1q_u8(dst, vcombine_u8(d0, d1)); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer, + uint8_t* const dst) { + const int16x8_t a0 = vld1q_s16(wiener_buffer + 0); + const int16x8_t a1 = vld1q_s16(wiener_buffer + 8); + const uint8x8_t d0 = vqrshrun_n_s16(a0, 4); + const uint8x8_t d1 = vqrshrun_n_s16(a1, 4); + vst1q_u8(dst, vcombine_u8(d0, d1)); +} + +inline void WienerVerticalTap1(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + WienerVerticalTap1Kernel(wiener_buffer, dst_ptr); + WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + WienerVerticalTap1Kernel(wiener_buffer, dst); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +// For width 16 and up, store the horizontal results, and then do the vertical +// filter row by row. This is faster than doing it column by column when +// considering cache issues. +void WienerFilter_NEON(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, const ptrdiff_t stride, + const int width, const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + const int16_t* const number_leading_zero_coefficients = + restoration_info.wiener_info.number_leading_zero_coefficients; + const int number_rows_to_skip = std::max( + static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]), + 1); + const ptrdiff_t wiener_stride = Align(width, 16); + int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer; + // The values are saturated to 13 bits before storing. + int16_t* wiener_buffer_horizontal = + wiener_buffer_vertical + number_rows_to_skip * wiener_stride; + int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2]; + int16_t filter_vertical[(kWienerFilterTaps + 1) / 2]; + PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal, + filter_horizontal); + PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical, + filter_vertical); + + // horizontal filtering. + // Over-reads up to 15 - |kRestorationHorizontalBorder| values. 
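// Editor's note (annotation, not in the upstream file): the Wiener filter is
// symmetric, so |number_leading_zero_coefficients| per direction counts the
// zero outer taps. The dispatch below picks the matching horizontal kernel
// and moves the source pointer so the live taps stay centered:
//   0 -> WienerHorizontalTap7(), source offset -3
//   1 -> WienerHorizontalTap5(), source offset -2
//   2 -> WienerHorizontalTap3(), source offset -1
//   3 -> WienerHorizontalTap1(), no offset (center tap only)
// The vertical pass further down is selected the same way.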
+ const int height_horizontal = + height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip; + const int height_extra = (height_horizontal - height) >> 1; + assert(height_extra <= 2); + const auto* const src = static_cast<const uint8_t*>(source); + const auto* const top = static_cast<const uint8_t*>(top_border); + const auto* const bottom = static_cast<const uint8_t*>(bottom_border); + if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { + WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, + wiener_stride, height_extra, filter_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + filter_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, + wiener_stride, height_extra, filter_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + filter_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { + // The maximum over-reads happen here. + WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, + wiener_stride, height_extra, filter_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + filter_horizontal, &wiener_buffer_horizontal); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); + WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, + wiener_stride, height_extra, + &wiener_buffer_horizontal); + WienerHorizontalTap1(src, stride, wiener_stride, height, + &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, + &wiener_buffer_horizontal); + } + + // vertical filtering. + // Over-writes up to 15 values. + auto* dst = static_cast<uint8_t*>(dest); + if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) { + // Because the top row of |source| is a duplicate of the second row, and the + // bottom row of |source| is a duplicate of its above row, we can duplicate + // the top and bottom row of |wiener_buffer| accordingly. 
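// Editor's note (annotation, not in the upstream file): this is the 7-tap
// vertical case. Only |height_extra| == 2 border rows per side went through
// the horizontal pass, but a 7-tap vertical kernel needs three rows of
// context above and below, so the two memcpy() calls below replicate the
// first and last horizontally filtered rows to supply the missing row on
// each side.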
+ memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride, + sizeof(*wiener_buffer_horizontal) * wiener_stride); + memcpy(restoration_buffer->wiener_buffer, + restoration_buffer->wiener_buffer + wiener_stride, + sizeof(*restoration_buffer->wiener_buffer) * wiener_stride); + WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height, + filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) { + WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride, + height, filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) { + WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride, + wiener_stride, height, filter_vertical, dst, stride); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3); + WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride, + wiener_stride, height, dst, stride); + } +} + +//------------------------------------------------------------------------------ +// SGR + +inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) { + dst[0] = VshrU128<0>(src); + dst[1] = VshrU128<1>(src); + dst[2] = VshrU128<2>(src); +} + +inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3], + uint16x4_t high[3]) { + uint16x8_t s[3]; + s[0] = VshrU128<0>(src); + s[1] = VshrU128<2>(src); + s[2] = VshrU128<4>(src); + low[0] = vget_low_u16(s[0]); + low[1] = vget_low_u16(s[1]); + low[2] = vget_low_u16(s[2]); + high[0] = vget_high_u16(s[0]); + high[1] = vget_high_u16(s[1]); + high[2] = vget_high_u16(s[2]); +} + +inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) { + dst[0] = VshrU128<0>(src); + dst[1] = VshrU128<1>(src); + dst[2] = VshrU128<2>(src); + dst[3] = VshrU128<3>(src); + dst[4] = VshrU128<4>(src); +} + +inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5], + uint16x4_t high[5]) { + Prepare3_16(src, low, high); + const uint16x8_t s3 = VshrU128<6>(src); + const uint16x8_t s4 = VshrU128<8>(src); + low[3] = vget_low_u16(s3); + low[4] = vget_low_u16(s4); + high[3] = vget_high_u16(s3); + high[4] = vget_high_u16(s4); +} + +inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1, + const uint16x8_t src2) { + const uint16x8_t sum = vaddq_u16(src0, src1); + return vaddq_u16(sum, src2); +} + +inline uint16x8_t Sum3_16(const uint16x8_t src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1, + const uint32x4_t src2) { + const uint32x4_t sum = vaddq_u32(src0, src1); + return vaddq_u32(sum, src2); +} + +inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) { + uint32x4x2_t d; + d.val[0] = Sum3_32(src[0].val[0], src[1].val[0], src[2].val[0]); + d.val[1] = Sum3_32(src[0].val[1], src[1].val[1], src[2].val[1]); + return d; +} + +inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) { + const uint16x8_t sum = vaddl_u8(src[0], src[1]); + return vaddw_u8(sum, src[2]); +} + +inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) { + const uint32x4_t sum = vaddl_u16(src[0], src[1]); + return vaddw_u16(sum, src[2]); +} + +inline uint16x8_t Sum5_16(const uint16x8_t src[5]) { + const uint16x8_t sum01 = vaddq_u16(src[0], src[1]); + const uint16x8_t sum23 = vaddq_u16(src[2], src[3]); + const uint16x8_t sum = vaddq_u16(sum01, sum23); + return vaddq_u16(sum, src[4]); +} + +inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1, + const uint32x4_t src2, const uint32x4_t src3, + const uint32x4_t src4) { + 
const uint32x4_t sum01 = vaddq_u32(src0, src1); + const uint32x4_t sum23 = vaddq_u32(src2, src3); + const uint32x4_t sum = vaddq_u32(sum01, sum23); + return vaddq_u32(sum, src4); +} + +inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) { + uint32x4x2_t d; + d.val[0] = Sum5_32(src[0].val[0], src[1].val[0], src[2].val[0], src[3].val[0], + src[4].val[0]); + d.val[1] = Sum5_32(src[0].val[1], src[1].val[1], src[2].val[1], src[3].val[1], + src[4].val[1]); + return d; +} + +inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) { + const uint32x4_t sum01 = vaddl_u16(src[0], src[1]); + const uint32x4_t sum23 = vaddl_u16(src[2], src[3]); + const uint32x4_t sum0123 = vaddq_u32(sum01, sum23); + return vaddw_u16(sum0123, src[4]); +} + +inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) { + uint8x8_t s[3]; + Prepare3_8(src, s); + return Sum3W_16(s); +} + +inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) { + uint16x4_t low[3], high[3]; + uint32x4x2_t sum; + Prepare3_16(src, low, high); + sum.val[0] = Sum3W_32(low); + sum.val[1] = Sum3W_32(high); + return sum; +} + +inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) { + uint8x8_t s[5]; + Prepare5_8(src, s); + const uint16x8_t sum01 = vaddl_u8(s[0], s[1]); + const uint16x8_t sum23 = vaddl_u8(s[2], s[3]); + const uint16x8_t sum0123 = vaddq_u16(sum01, sum23); + return vaddw_u8(sum0123, s[4]); +} + +inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) { + uint16x4_t low[5], high[5]; + Prepare5_16(src, low, high); + uint32x4x2_t sum; + sum.val[0] = Sum5W_32(low); + sum.val[1] = Sum5W_32(high); + return sum; +} + +void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3, + uint32x4_t* const row_sq5) { + const uint32x4_t sum04 = vaddl_u16(src[0], src[4]); + const uint32x4_t sum12 = vaddl_u16(src[1], src[2]); + *row_sq3 = vaddw_u16(sum12, src[3]); + *row_sq5 = vaddq_u32(sum04, *row_sq3); +} + +void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq, + uint16x8_t* const row3, uint16x8_t* const row5, + uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) { + uint8x8_t s[5]; + Prepare5_8(src, s); + const uint16x8_t sum04 = vaddl_u8(s[0], s[4]); + const uint16x8_t sum12 = vaddl_u8(s[1], s[2]); + *row3 = vaddw_u8(sum12, s[3]); + *row5 = vaddq_u16(sum04, *row3); + uint16x4_t low[5], high[5]; + Prepare5_16(sq, low, high); + SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]); + SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]); +} + +inline uint16x8_t Sum343(const uint8x8x2_t src) { + uint8x8_t s[3]; + Prepare3_8(src, s); + const uint16x8_t sum = Sum3W_16(s); + const uint16x8_t sum3 = Sum3_16(sum, sum, sum); + return vaddw_u8(sum3, s[1]); +} + +inline uint32x4_t Sum343W(const uint16x4_t src[3]) { + const uint32x4_t sum = Sum3W_32(src); + const uint32x4_t sum3 = Sum3_32(sum, sum, sum); + return vaddw_u16(sum3, src[1]); +} + +inline uint32x4x2_t Sum343W(const uint16x8x2_t src) { + uint16x4_t low[3], high[3]; + uint32x4x2_t d; + Prepare3_16(src, low, high); + d.val[0] = Sum343W(low); + d.val[1] = Sum343W(high); + return d; +} + +inline uint16x8_t Sum565(const uint8x8x2_t src) { + uint8x8_t s[3]; + Prepare3_8(src, s); + const uint16x8_t sum = Sum3W_16(s); + const uint16x8_t sum4 = vshlq_n_u16(sum, 2); + const uint16x8_t sum5 = vaddq_u16(sum4, sum); + return vaddw_u8(sum5, s[1]); +} + +inline uint32x4_t Sum565W(const uint16x4_t src[3]) { + const uint32x4_t sum = Sum3W_32(src); + const uint32x4_t sum4 = vshlq_n_u32(sum, 2); + const uint32x4_t sum5 = vaddq_u32(sum4, sum); + return vaddw_u16(sum5, src[1]); +} + 
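// Editor's note (annotation, not in the upstream file): Sum343()/Sum343W()
// and Sum565()/Sum565W() form the weighted sums used by the self-guided
// (SGR) filter. For three adjacent inputs a, b, c they reduce to
//   Sum343: 3 * (a + b + c) + b               == 3*a + 4*b + 3*c
//   Sum565: 4 * (a + b + c) + (a + b + c) + b == 5*a + 6*b + 5*c
// i.e. the 3-4-3 and 5-6-5 weights. Store343_444() further down builds the
// per-row 3-4-3 and 4-4-4 sums that the second pass later combines across
// rows.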
+inline uint32x4x2_t Sum565W(const uint16x8x2_t src) { + uint16x4_t low[3], high[3]; + uint32x4x2_t d; + Prepare3_16(src, low, high); + d.val[0] = Sum565W(low); + d.val[1] = Sum565W(high); + return d; +} + +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const int height, const ptrdiff_t sum_stride, uint16_t* sum3, + uint16_t* sum5, uint32_t* square_sum3, + uint32_t* square_sum5) { + int y = height; + do { + uint8x8x2_t s; + uint16x8x2_t sq; + s.val[0] = vld1_u8(src); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + ptrdiff_t x = 0; + do { + uint16x8_t row3, row5; + uint32x4x2_t row_sq3, row_sq5; + s.val[1] = vld1_u8(src + x + 8); + sq.val[1] = vmull_u8(s.val[1], s.val[1]); + SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5); + vst1q_u16(sum3, row3); + vst1q_u16(sum5, row5); + vst1q_u32(square_sum3 + 0, row_sq3.val[0]); + vst1q_u32(square_sum3 + 4, row_sq3.val[1]); + vst1q_u32(square_sum5 + 0, row_sq5.val[0]); + vst1q_u32(square_sum5 + 4, row_sq5.val[1]); + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + sum3 += 8; + sum5 += 8; + square_sum3 += 8; + square_sum5 += 8; + x += 8; + } while (x < sum_stride); + src += src_stride; + } while (--y != 0); +} + +template <int size> +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const int height, const ptrdiff_t sum_stride, uint16_t* sums, + uint32_t* square_sums) { + static_assert(size == 3 || size == 5, ""); + int y = height; + do { + uint8x8x2_t s; + uint16x8x2_t sq; + s.val[0] = vld1_u8(src); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + ptrdiff_t x = 0; + do { + uint16x8_t row; + uint32x4x2_t row_sq; + s.val[1] = vld1_u8(src + x + 8); + sq.val[1] = vmull_u8(s.val[1], s.val[1]); + if (size == 3) { + row = Sum3Horizontal(s); + row_sq = Sum3WHorizontal(sq); + } else { + row = Sum5Horizontal(s); + row_sq = Sum5WHorizontal(sq); + } + vst1q_u16(sums, row); + vst1q_u32(square_sums + 0, row_sq.val[0]); + vst1q_u32(square_sums + 4, row_sq.val[1]); + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + sums += 8; + square_sums += 8; + x += 8; + } while (x < sum_stride); + src += src_stride; + } while (--y != 0); +} + +template <int n> +inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq, + const uint32_t scale) { + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const uint32x4_t dxd = vmull_u16(sum, sum); + const uint32x4_t axn = vmulq_n_u32(sum_sq, n); + // Ensure |p| does not underflow by using saturating subtraction. + const uint32x4_t p = vqsubq_u32(axn, dxd); + const uint32x4_t pxs = vmulq_n_u32(p, scale); + // vrshrn_n_u32() (narrowing shift) can only shift by 16 and kSgrProjScaleBits + // is 20. + const uint32x4_t shifted = vrshrq_n_u32(pxs, kSgrProjScaleBits); + return vmovn_u32(shifted); +} + +template <int n> +inline void CalculateIntermediate(const uint16x8_t sum, + const uint32x4x2_t sum_sq, + const uint32_t scale, uint8x8_t* const ma, + uint16x8_t* const b) { + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale); + const uint16x4_t z1 = + CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale); + const uint16x8_t z01 = vcombine_u16(z0, z1); + // Using vqmovn_u16() needs an extra sign extension instruction. + const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255)); + // Using vgetq_lane_s16() can save the sign extension instruction. 
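// Editor's note (annotation, not in the upstream file): in scalar terms this
// computes, per pixel,
//   p  = max(0, sum_sq * n - sum * sum)   // n = 25 (pass 1) or 9 (pass 2)
//   z  = min(255, (p * scale + (1 << (kSgrProjScaleBits - 1)))
//                     >> kSgrProjScaleBits)
//   ma = kSgrMaLookup[z]
//   b  = (ma * sum * one_over_n + (1 << (kSgrProjReciprocalBits - 1)))
//            >> kSgrProjReciprocalBits
// The table lookup is done lane by lane below; as the comments above note,
// reading lanes with vgetq_lane_s16() avoids an extra sign-extension
// instruction compared with narrowing first.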
+ const uint8_t lookup[8] = { + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]}; + *ma = vld1_u8(lookup); + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const uint16x8_t maq = vmovl_u8(*ma); + const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum)); + const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum)); + const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n); + const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n); + const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits); + const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits); + *b = vcombine_u16(b_lo, b_hi); +} + +inline void CalculateIntermediate5(const uint16x8_t s5[5], + const uint32x4x2_t sq5[5], + const uint32_t scale, uint8x8_t* const ma, + uint16x8_t* const b) { + const uint16x8_t sum = Sum5_16(s5); + const uint32x4x2_t sum_sq = Sum5_32(sq5); + CalculateIntermediate<25>(sum, sum_sq, scale, ma, b); +} + +inline void CalculateIntermediate3(const uint16x8_t s3[3], + const uint32x4x2_t sq3[3], + const uint32_t scale, uint8x8_t* const ma, + uint16x8_t* const b) { + const uint16x8_t sum = Sum3_16(s3); + const uint32x4x2_t sum_sq = Sum3_32(sq3); + CalculateIntermediate<9>(sum, sum_sq, scale, ma, b); +} + +inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, + const ptrdiff_t x, uint16x8_t* const sum_ma343, + uint16x8_t* const sum_ma444, + uint32x4x2_t* const sum_b343, + uint32x4x2_t* const sum_b444, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + uint8x8_t s[3]; + Prepare3_8(ma3, s); + const uint16x8_t sum_ma111 = Sum3W_16(s); + *sum_ma444 = vshlq_n_u16(sum_ma111, 2); + const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111); + *sum_ma343 = vaddw_u8(sum333, s[1]); + uint16x4_t low[3], high[3]; + uint32x4x2_t sum_b111; + Prepare3_16(b3, low, high); + sum_b111.val[0] = Sum3W_32(low); + sum_b111.val[1] = Sum3W_32(high); + sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2); + sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2); + sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]); + sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]); + sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]); + sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]); + vst1q_u16(ma343 + x, *sum_ma343); + vst1q_u16(ma444 + x, *sum_ma444); + vst1q_u32(b343 + x + 0, sum_b343->val[0]); + vst1q_u32(b343 + x + 4, sum_b343->val[1]); + vst1q_u32(b444 + x + 0, sum_b444->val[0]); + vst1q_u32(b444 + x + 4, sum_b444->val[1]); +} + +inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t 
b3, + const ptrdiff_t x, uint16x8_t* const sum_ma343, + uint32x4x2_t* const sum_b343, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + uint16x8_t sum_ma444; + uint32x4x2_t sum_b444; + Store343_444(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + uint16x8_t sum_ma343; + uint32x4x2_t sum_b343; + Store343_444(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, b444); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x, + const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], uint8x8x2_t s[2], uint16x8x2_t sq[2], + uint8x8_t* const ma, uint16x8_t* const b) { + uint16x8_t s5[5]; + uint32x4x2_t sq5[5]; + s[0].val[1] = vld1_u8(src0 + x + 8); + s[1].val[1] = vld1_u8(src1 + x + 8); + sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]); + sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]); + s5[3] = Sum5Horizontal(s[0]); + s5[4] = Sum5Horizontal(s[1]); + sq5[3] = Sum5WHorizontal(sq[0]); + sq5[4] = Sum5WHorizontal(sq[1]); + vst1q_u16(sum5[3] + x, s5[3]); + vst1q_u16(sum5[4] + x, s5[4]); + vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + CalculateIntermediate5(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow( + const uint8_t* const src, const ptrdiff_t x, const uint32_t scale, + const uint16_t* const sum5[5], const uint32_t* const square_sum5[5], + uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma, + uint16x8_t* const b) { + uint16x8_t s5[5]; + uint32x4x2_t sq5[5]; + s->val[1] = vld1_u8(src + x + 8); + sq->val[1] = vmull_u8(s->val[1], s->val[1]); + s5[3] = s5[4] = Sum5Horizontal(*s); + sq5[3] = sq5[4] = Sum5WHorizontal(*sq); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + CalculateIntermediate5(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( + const uint8_t* const src, const ptrdiff_t x, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma, + uint16x8_t* const b) { + uint16x8_t s3[3]; + uint32x4x2_t sq3[3]; + s->val[1] = vld1_u8(src + x + 8); + sq->val[1] = vmull_u8(s->val[1], s->val[1]); + s3[2] = Sum3Horizontal(*s); + sq3[2] = Sum3WHorizontal(*sq); + vst1q_u16(sum3[2] + x, s3[2]); + vst1q_u32(square_sum3[2] + x + 
0, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]); + s3[0] = vld1q_u16(sum3[0] + x); + s3[1] = vld1q_u16(sum3[1] + x); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); + CalculateIntermediate3(s3, sq3, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( + const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint8x8x2_t s[2], uint16x8x2_t sq[2], uint8x8_t* const ma3_0, + uint8x8_t* const ma3_1, uint16x8_t* const b3_0, uint16x8_t* const b3_1, + uint8x8_t* const ma5, uint16x8_t* const b5) { + uint16x8_t s3[4], s5[5]; + uint32x4x2_t sq3[4], sq5[5]; + s[0].val[1] = vld1_u8(src0 + x + 8); + s[1].val[1] = vld1_u8(src1 + x + 8); + sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]); + sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]); + SumHorizontal(s[0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]); + SumHorizontal(s[1], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]); + vst1q_u16(sum3[2] + x, s3[2]); + vst1q_u16(sum3[3] + x, s3[3]); + vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]); + vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]); + vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]); + vst1q_u16(sum5[3] + x, s5[3]); + vst1q_u16(sum5[4] + x, s5[4]); + vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]); + s3[0] = vld1q_u16(sum3[0] + x); + s3[1] = vld1q_u16(sum3[1] + x); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + CalculateIntermediate3(s3, sq3, scales[1], ma3_0, b3_0); + CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], ma3_1, b3_1); + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( + const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2], + const uint16_t* const sum3[4], const uint16_t* const sum5[5], + const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5], + uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3, + uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) { + uint16x8_t s3[3], s5[5]; + uint32x4x2_t sq3[3], sq5[5]; + s->val[1] = vld1_u8(src + x + 8); + sq->val[1] = vmull_u8(s->val[1], s->val[1]); + SumHorizontal(*s, *sq, &s3[2], &s5[3], &sq3[2], &sq5[3]); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + s5[4] = s5[3]; + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + 
sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + sq5[4] = sq5[3]; + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); + s3[0] = vld1q_u16(sum3[0] + x); + s3[1] = vld1q_u16(sum3[1] + x); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); + CalculateIntermediate3(s3, sq3, scales[1], ma3, b3); +} + +inline void BoxSumFilterPreProcess5(const uint8_t* const src0, + const uint8_t* const src1, const int width, + const uint32_t scale, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + uint16_t* ma565, uint32_t* b565) { + uint8x8x2_t s[2], mas; + uint16x8x2_t sq[2], bs; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq, + &mas.val[0], &bs.val[0]); + + int x = 0; + do { + s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq, + &mas.val[1], &bs.val[1]); + const uint16x8_t ma = Sum565(mas); + const uint32x4x2_t b = Sum565W(bs); + vst1q_u16(ma565, ma); + vst1q_u32(b565 + 0, b.val[0]); + vst1q_u32(b565 + 4, b.val[1]); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + ma565 += 8; + b565 += 8; + x += 8; + } while (x < width); +} + +template <bool calculate444> +LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( + const uint8_t* const src, const int width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343, + uint16_t* ma444, uint32_t* b343, uint32_t* b444) { + uint8x8x2_t s, mas; + uint16x8x2_t sq, bs; + s.val[0] = vld1_u8(src); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcess3(src, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0], + &bs.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, &s, &sq, + &mas.val[1], &bs.val[1]); + if (calculate444) { + Store343_444(mas, bs, 0, ma343, ma444, b343, b444); + ma444 += 8; + b444 += 8; + } else { + const uint16x8_t ma = Sum343(mas); + const uint32x4x2_t b = Sum343W(bs); + vst1q_u16(ma343, ma); + vst1q_u32(b343 + 0, b.val[0]); + vst1q_u32(b343 + 4, b.val[1]); + } + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + ma343 += 8; + b343 += 8; + x += 8; + } while (x < width); +} + +inline void BoxSumFilterPreProcess( + const uint8_t* const src0, const uint8_t* const src1, const int width, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565, + uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) { + uint8x8x2_t s[2]; + uint8x8x2_t ma3[2], ma5; + uint16x8x2_t sq[2], b3[2], b5; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0], + &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]); + + int x = 0; + do { + 
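// Editor's note (annotation, not in the upstream file): this loop, like the
// other BoxSum*/BoxFilter* loops in this file, keeps a sliding 16-pixel
// window in a uint8x8x2_t: val[0] holds the current 8 source pixels, each
// preprocess call loads the following 8 into val[1] so the vext()-based
// Prepare*() helpers can window across the pair, and the val[1] -> val[0]
// copies at the top of the body advance the window by 8 pixels per
// iteration.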
s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1], + &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]); + uint16x8_t ma = Sum343(ma3[0]); + uint32x4x2_t b = Sum343W(b3[0]); + vst1q_u16(ma343[0] + x, ma); + vst1q_u32(b343[0] + x, b.val[0]); + vst1q_u32(b343[0] + x + 4, b.val[1]); + Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); + ma = Sum565(ma5); + b = Sum565W(b5); + vst1q_u16(ma565, ma); + vst1q_u32(b565 + 0, b.val[0]); + vst1q_u32(b565 + 4, b.val[1]); + ma3[0].val[0] = ma3[0].val[1]; + ma3[1].val[0] = ma3[1].val[1]; + b3[0].val[0] = b3[0].val[1]; + b3[1].val[0] = b3[1].val[1]; + ma5.val[0] = ma5.val[1]; + b5.val[0] = b5.val[1]; + ma565 += 8; + b565 += 8; + x += 8; + } while (x < width); +} + +template <int shift> +inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma, + const uint32x4_t b) { + // ma: 255 * 32 = 8160 (13 bits) + // b: 65088 * 32 = 2082816 (21 bits) + // v: b - ma * 255 (22 bits) + const int32x4_t v = vreinterpretq_s32_u32(vmlsl_u16(b, ma, src)); + // kSgrProjSgrBits = 8 + // kSgrProjRestoreBits = 4 + // shift = 4 or 5 + // v >> 8 or 9 (13 bits) + return vrshrn_n_s32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits); +} + +template <int shift> +inline int16x8_t CalculateFilteredOutput(const uint8x8_t src, + const uint16x8_t ma, + const uint32x4x2_t b) { + const uint16x8_t src_u16 = vmovl_u8(src); + const int16x4_t dst_lo = + FilterOutput<shift>(vget_low_u16(src_u16), vget_low_u16(ma), b.val[0]); + const int16x4_t dst_hi = + FilterOutput<shift>(vget_high_u16(src_u16), vget_high_u16(ma), b.val[1]); + return vcombine_s16(dst_lo, dst_hi); // 13 bits +} + +inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s, + uint16x8_t ma[2], + uint32x4x2_t b[2]) { + const uint16x8_t ma_sum = vaddq_u16(ma[0], ma[1]); + uint32x4x2_t b_sum; + b_sum.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]); + b_sum.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]); + return CalculateFilteredOutput<5>(s, ma_sum, b_sum); +} + +inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s, + uint16x8_t ma[3], + uint32x4x2_t b[3]) { + const uint16x8_t ma_sum = Sum3_16(ma); + const uint32x4x2_t b_sum = Sum3_32(b); + return CalculateFilteredOutput<5>(s, ma_sum, b_sum); +} + +inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2], + uint8_t* const dst) { + const int16x4_t v_lo = + vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const int16x4_t v_hi = + vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const int16x8_t vv = vcombine_s16(v_lo, v_hi); + const int16x8_t s = ZeroExtend(src); + const int16x8_t d = vaddq_s16(s, vv); + vst1_u8(dst, vqmovun_s16(d)); +} + +inline void SelfGuidedDoubleMultiplier(const uint8x8_t src, + const int16x8_t filter[2], const int w0, + const int w2, uint8_t* const dst) { + int32x4_t v[2]; + v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0); + v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0); + v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2); + v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2); + SelfGuidedFinal(src, v, dst); +} + +inline void SelfGuidedSingleMultiplier(const uint8x8_t src, + const int16x8_t filter, const int w0, + uint8_t* const dst) { + // weight: -96 to 96 (Sgrproj_Xqd_Min/Max) + int32x4_t v[2]; + v[0] = vmull_n_s16(vget_low_s16(filter), w0); + v[1] = 
vmull_n_s16(vget_high_s16(filter), w0); + SelfGuidedFinal(src, v, dst); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], const int width, const uint32_t scale, + const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2], + uint8_t* const dst) { + uint8x8x2_t s[2], mas; + uint16x8x2_t sq[2], bs; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq, + &mas.val[0], &bs.val[0]); + + int x = 0; + do { + s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq, + &mas.val[1], &bs.val[1]); + uint16x8_t ma[2]; + uint32x4x2_t b[2]; + ma[1] = Sum565(mas); + b[1] = Sum565W(bs); + vst1q_u16(ma565[1] + x, ma[1]); + vst1q_u32(b565[1] + x + 0, b[1].val[0]); + vst1q_u32(b565[1] + x + 4, b[1].val[1]); + const uint8x8_t sr0 = vld1_u8(src + x); + const uint8x8_t sr1 = vld1_u8(src + stride + x); + int16x8_t p0, p1; + ma[0] = vld1q_u16(ma565[0] + x); + b[0].val[0] = vld1q_u32(b565[0] + x + 0); + b[0].val[1] = vld1q_u32(b565[0] + x + 4); + p0 = CalculateFilteredOutputPass1(sr0, ma, b); + p1 = CalculateFilteredOutput<4>(sr1, ma[1], b[1]); + SelfGuidedSingleMultiplier(sr0, p0, w0, dst + x); + SelfGuidedSingleMultiplier(sr1, p1, w0, dst + stride + x); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + x += 8; + } while (x < width); +} + +inline void BoxFilterPass1LastRow(const uint8_t* const src, + const uint8_t* const src0, const int width, + const uint32_t scale, const int16_t w0, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + uint16_t* ma565, uint32_t* b565, + uint8_t* const dst) { + uint8x8x2_t s, mas; + uint16x8x2_t sq, bs; + s.val[0] = vld1_u8(src0); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcess5LastRow(src0, 0, scale, sum5, square_sum5, &s, &sq, + &mas.val[0], &bs.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcess5LastRow(src0, x + 8, scale, sum5, square_sum5, &s, &sq, + &mas.val[1], &bs.val[1]); + uint16x8_t ma[2]; + uint32x4x2_t b[2]; + ma[1] = Sum565(mas); + b[1] = Sum565W(bs); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + ma[0] = vld1q_u16(ma565); + b[0].val[0] = vld1q_u32(b565 + 0); + b[0].val[1] = vld1q_u32(b565 + 4); + const uint8x8_t sr = vld1_u8(src + x); + const int16x8_t p = CalculateFilteredOutputPass1(sr, ma, b); + SelfGuidedSingleMultiplier(sr, p, w0, dst + x); + ma565 += 8; + b565 += 8; + x += 8; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass2( + const uint8_t* const src, const uint8_t* const src0, const int width, + const uint32_t scale, const int16_t w0, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], uint16_t* const ma343[3], + uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2], + uint8_t* const dst) { + uint8x8x2_t s, mas; + uint16x8x2_t sq, bs; + s.val[0] = vld1_u8(src0); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcess3(src0, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0], + &bs.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, &s, &sq, + 
&mas.val[1], &bs.val[1]); + uint16x8_t ma[3]; + uint32x4x2_t b[3]; + Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2], + b444[1]); + const uint8x8_t sr = vld1_u8(src + x); + ma[0] = vld1q_u16(ma343[0] + x); + ma[1] = vld1q_u16(ma444[0] + x); + b[0].val[0] = vld1q_u32(b343[0] + x + 0); + b[0].val[1] = vld1q_u32(b343[0] + x + 4); + b[1].val[0] = vld1q_u32(b444[0] + x + 0); + b[1].val[1] = vld1q_u32(b444[0] + x + 4); + const int16x8_t p = CalculateFilteredOutputPass2(sr, ma, b); + SelfGuidedSingleMultiplier(sr, p, w0, dst + x); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + x += 8; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilter( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], + uint32_t* const b565[2], uint8_t* const dst) { + uint8x8x2_t s[2], ma3[2], ma5; + uint16x8x2_t sq[2], b3[2], b5; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0], + &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]); + + int x = 0; + do { + s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1], + &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]); + uint16x8_t ma[3][3]; + uint32x4x2_t b[3][3]; + Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1], + ma343[2], ma444[1], b343[2], b444[1]); + Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2], + b343[3], b444[2]); + ma[0][1] = Sum565(ma5); + b[0][1] = Sum565W(b5); + vst1q_u16(ma565[1] + x, ma[0][1]); + vst1q_u32(b565[1] + x, b[0][1].val[0]); + vst1q_u32(b565[1] + x + 4, b[0][1].val[1]); + ma3[0].val[0] = ma3[0].val[1]; + ma3[1].val[0] = ma3[1].val[1]; + b3[0].val[0] = b3[0].val[1]; + b3[1].val[0] = b3[1].val[1]; + ma5.val[0] = ma5.val[1]; + b5.val[0] = b5.val[1]; + int16x8_t p[2][2]; + const uint8x8_t sr0 = vld1_u8(src + x); + const uint8x8_t sr1 = vld1_u8(src + stride + x); + ma[0][0] = vld1q_u16(ma565[0] + x); + b[0][0].val[0] = vld1q_u32(b565[0] + x); + b[0][0].val[1] = vld1q_u32(b565[0] + x + 4); + p[0][0] = CalculateFilteredOutputPass1(sr0, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1, ma[0][1], b[0][1]); + ma[1][0] = vld1q_u16(ma343[0] + x); + ma[1][1] = vld1q_u16(ma444[0] + x); + b[1][0].val[0] = vld1q_u32(b343[0] + x); + b[1][0].val[1] = vld1q_u32(b343[0] + x + 4); + b[1][1].val[0] = vld1q_u32(b444[0] + x); + b[1][1].val[1] = vld1q_u32(b444[0] + x + 4); + p[0][1] = CalculateFilteredOutputPass2(sr0, ma[1], b[1]); + ma[2][0] = vld1q_u16(ma343[1] + x); + b[2][0].val[0] = vld1q_u32(b343[1] + x); + b[2][0].val[1] = vld1q_u32(b343[1] + x + 4); + p[1][1] = CalculateFilteredOutputPass2(sr1, ma[2], b[2]); + SelfGuidedDoubleMultiplier(sr0, p[0], w0, w2, dst + x); + SelfGuidedDoubleMultiplier(sr1, p[1], w0, w2, dst + stride + x); + x += 8; + } while (x < width); 
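+ // Note on the loop above: pass 1 (the 5x5 box) is only evaluated on every
+ // other source row and the resulting ma565/b565 sums are shared by the row
+ // pair, while pass 2 (the 3x3 box) is evaluated for both rows. The
+ // 565/343/444 buffer names refer to the vertical weights carried by those
+ // sums: 5,6,5 for the alternate-row 5x5 pass and 3,4,3 / 4,4,4 for the
+ // 3x3 pass.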
+} + +inline void BoxFilterLastRow( + const uint8_t* const src, const uint8_t* const src0, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], + uint32_t* const b565[2], uint8_t* const dst) { + uint8x8x2_t s, ma3, ma5; + uint16x8x2_t sq, b3, b5; + uint16x8_t ma[3]; + uint32x4x2_t b[3]; + s.val[0] = vld1_u8(src0); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcessLastRow(src0, 0, scales, sum3, sum5, square_sum3, + square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0], + &b3.val[0], &b5.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcessLastRow(src0, x + 8, scales, sum3, sum5, square_sum3, + square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1], + &b3.val[1], &b5.val[1]); + ma[1] = Sum565(ma5); + b[1] = Sum565W(b5); + ma5.val[0] = ma5.val[1]; + b5.val[0] = b5.val[1]; + ma[2] = Sum343(ma3); + b[2] = Sum343W(b3); + ma3.val[0] = ma3.val[1]; + b3.val[0] = b3.val[1]; + const uint8x8_t sr = vld1_u8(src + x); + int16x8_t p[2]; + ma[0] = vld1q_u16(ma565[0] + x); + b[0].val[0] = vld1q_u32(b565[0] + x + 0); + b[0].val[1] = vld1q_u32(b565[0] + x + 4); + p[0] = CalculateFilteredOutputPass1(sr, ma, b); + ma[0] = vld1q_u16(ma343[0] + x); + ma[1] = vld1q_u16(ma444[0] + x); + b[0].val[0] = vld1q_u32(b343[0] + x + 0); + b[0].val[1] = vld1q_u32(b343[0] + x + 4); + b[1].val[0] = vld1q_u32(b444[0] + x + 0); + b[1].val[1] = vld1q_u32(b444[0] + x + 4); + p[1] = CalculateFilteredOutputPass2(sr, ma, b); + SelfGuidedDoubleMultiplier(sr, p, w0, w2, dst + x); + x += 8; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( + const RestorationUnitInfo& restoration_info, const uint8_t* src, + const uint8_t* const top_border, const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. 
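+ // The three self-guided weights computed below satisfy
+ // w0 + w1 + w2 == 1 << kSgrProjPrecisionBits: w0 scales the pass 1 (5x5)
+ // output, w2 scales the pass 2 (3x3) output, and the remaining weight w1 is
+ // effectively what is left on the unfiltered source pixel.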
+ const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1; + uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2]; + uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 3; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma444[0] = sgr_buffer->ma444; + b444[0] = sgr_buffer->b444; + for (int i = 1; i <= 2; ++i) { + ma444[i] = ma444[i - 1] + temp_stride; + b444[i] = b444[i - 1] + temp_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scales[0] != 0); + assert(scales[1] != 0); + BoxSum(top_border, stride, 2, sum_stride, sum3[0], sum5[1], square_sum3[0], + square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, + square_sum5, ma343, ma444, ma565[0], b343, b444, + b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width, + scales, w0, w2, sum3, sum5, square_sum3, square_sum5, ma343, + ma444, ma565, b343, b444, b565, dst); + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5, + square_sum3, square_sum5, ma343, ma444, ma565, b343, b444, b565, + dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + BoxFilterLastRow(src + 3, bottom_border + stride, width, scales, w0, w2, + sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565, + b343, b444, b565, dst); + } +} + +inline 
void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + uint16_t *sum5[5], *ma565[2]; + uint32_t *square_sum5[5], *b565[2]; + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scale != 0); + BoxSum<5>(top_border, stride, 2, sum_stride, sum5[1], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, ma565[0], + b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5, + square_sum5, width, scale, w0, ma565, b565, dst); + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width, + scale, w0, ma565, b565, dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + } + BoxFilterPass1LastRow(src + 3, bottom_border + stride, width, scale, w0, + sum5, square_sum5, ma565[0], b565[0], dst); + } +} + +inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + assert(restoration_info.sgr_proj_info.multiplier[0] == 0); + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12. 
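+ // Pass 2 only: multiplier[0] is 0, so the whole non-source weight
+ // (1 << kSgrProjPrecisionBits) - w1 is applied to the 3x3 filter output,
+ // and only the 3-tap sum/ma/b buffers set up below are needed.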
+ uint16_t *sum3[3], *ma343[3], *ma444[2]; + uint32_t *square_sum3[3], *b343[3], *b444[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 2; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + ma444[0] = sgr_buffer->ma444; + ma444[1] = ma444[0] + temp_stride; + b444[0] = sgr_buffer->b444; + b444[1] = b444[0] + temp_stride; + assert(scale != 0); + BoxSum<3>(top_border, stride, 2, sum_stride, sum3[0], square_sum3[0]); + BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0], + nullptr, b343[0], nullptr); + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + const uint8_t* s; + if (height > 1) { + s = src + stride; + } else { + s = bottom_border; + bottom_border += stride; + } + BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, ma343[1], + ma444[0], b343[1], b444[0]); + + for (int y = height - 2; y > 0; --y) { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src + 2, src + 2 * stride, width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } + + src += 2; + int y = std::min(height, 2); + do { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src, bottom_border, width, scale, w0, sum3, square_sum3, + ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + bottom_border += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } while (--y != 0); +} + +// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in +// the end of each row. It is safe to overwrite the output as it will not be +// part of the visible frame. +void SelfGuidedFilter_NEON( + const RestorationUnitInfo& restoration_info, const void* const source, + const void* const top_border, const void* const bottom_border, + const ptrdiff_t stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { + const int index = restoration_info.sgr_proj_info.index; + const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 + const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0 + const auto* const src = static_cast<const uint8_t*>(source); + const auto* top = static_cast<const uint8_t*>(top_border); + const auto* bottom = static_cast<const uint8_t*>(bottom_border); + auto* const dst = static_cast<uint8_t*>(dest); + SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer; + if (radius_pass_1 == 0) { + // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the + // following assertion. 
+ assert(radius_pass_0 != 0); + BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, + stride, width, height, sgr_buffer, dst); + } else if (radius_pass_0 == 0) { + BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, + stride, width, height, sgr_buffer, dst); + } else { + BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, + width, height, sgr_buffer, dst); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->loop_restorations[0] = WienerFilter_NEON; + dsp->loop_restorations[1] = SelfGuidedFilter_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void LoopRestorationInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/loop_restoration_neon.h b/src/dsp/arm/loop_restoration_neon.h new file mode 100644 index 0000000..b551610 --- /dev/null +++ b/src/dsp/arm/loop_restoration_neon.h @@ -0,0 +1,40 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_restorations, see the defines below for specifics. +// This function is not thread-safe. +void LoopRestorationInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_ diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc new file mode 100644 index 0000000..084f42f --- /dev/null +++ b/src/dsp/arm/mask_blend_neon.cc @@ -0,0 +1,444 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
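+
+// The mask passed to these kernels is stored at the unsubsampled (luma) block
+// resolution. The GetMask*() helpers collapse it to one 6-bit weight in
+// [0, 64] per output pixel: a horizontal pairwise add for 4:2:2, pairwise
+// adds over two rows for 4:2:0, followed by a rounding shift.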
+ +#include "src/dsp/mask_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// TODO(b/150461164): Consider combining with GetInterIntraMask4x2(). +// Compound predictors use int16_t values and need to multiply long because the +// Convolve range * 64 is 20 bits. Unfortunately there is no multiply int16_t by +// int8_t and accumulate into int32_t instruction. +template <int subsampling_x, int subsampling_y> +inline int16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const int16x4_t mask_val0 = vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask))); + const int16x4_t mask_val1 = vreinterpret_s16_u16( + vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y)))); + int16x8_t final_val; + if (subsampling_y == 1) { + const int16x4_t next_mask_val0 = + vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride))); + const int16x4_t next_mask_val1 = + vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride * 3))); + final_val = vaddq_s16(vcombine_s16(mask_val0, mask_val1), + vcombine_s16(next_mask_val0, next_mask_val1)); + } else { + final_val = vreinterpretq_s16_u16( + vpaddlq_u8(vreinterpretq_u8_s16(vcombine_s16(mask_val0, mask_val1)))); + } + return vrshrq_n_s16(final_val, subsampling_y + 1); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val0 = Load4(mask); + const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0); + return vreinterpretq_s16_u16(vmovl_u8(mask_val)); +} + +template <int subsampling_x, int subsampling_y> +inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + int16x8_t mask_val = vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask))); + if (subsampling_y == 1) { + const int16x8_t next_mask_val = + vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask + mask_stride))); + mask_val = vaddq_s16(mask_val, next_mask_val); + } + return vrshrq_n_s16(mask_val, 1 + subsampling_y); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val = vld1_u8(mask); + return vreinterpretq_s16_u16(vmovl_u8(mask_val)); +} + +inline void WriteMaskBlendLine4x2(const int16_t* const pred_0, + const int16_t* const pred_1, + const int16x8_t pred_mask_0, + const int16x8_t pred_mask_1, uint8_t* dst, + const ptrdiff_t dst_stride) { + const int16x8_t pred_val_0 = vld1q_s16(pred_0); + const int16x8_t pred_val_1 = vld1q_s16(pred_1); + // int res = (mask_value * prediction_0[x] + + // (64 - mask_value) * prediction_1[x]) >> 6; + const int32x4_t weighted_pred_0_lo = + vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0)); + const int32x4_t weighted_pred_0_hi = + vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0)); + const int32x4_t weighted_combo_lo = vmlal_s16( + weighted_pred_0_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1)); + const int32x4_t weighted_combo_hi = + vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1), + vget_high_s16(pred_val_1)); + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + const uint8x8_t result = + vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6), + vshrn_n_s32(weighted_combo_hi, 6)), + 4); + StoreLo4(dst, result); + 
StoreHi4(dst + dst_stride, result); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1, + const uint8_t* mask, + const ptrdiff_t mask_stride, uint8_t* dst, + const ptrdiff_t dst_stride) { + const int16x8_t mask_inverter = vdupq_n_s16(64); + int16x8_t pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + // TODO(b/150461164): Arm tends to do better with load(val); val += stride + // It may be possible to turn this into a loop with a templated height. + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + const uint8_t* mask = mask_ptr; + if (height == 4) { + MaskBlending4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, mask, mask_stride, dst, dst_stride); + return; + } + const int16x8_t mask_inverter = vdupq_n_s16(64); + int y = 0; + do { + int16x8_t pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + y += 8; + } while (y < height); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1, + const ptrdiff_t /*prediction_stride_1*/, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int width, + const int height, void* dest, + const ptrdiff_t dst_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + if (width == 4) { + MaskBlending4xH_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, mask_ptr, mask_stride, 
height, dst, dst_stride); + return; + } + const uint8_t* mask = mask_ptr; + const int16x8_t mask_inverter = vdupq_n_s16(64); + int y = 0; + do { + int x = 0; + do { + const int16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride); + // 64 - mask + const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + const int16x8_t pred_val_0 = vld1q_s16(pred_0 + x); + const int16x8_t pred_val_1 = vld1q_s16(pred_1 + x); + uint8x8_t result; + // int res = (mask_value * prediction_0[x] + + // (64 - mask_value) * prediction_1[x]) >> 6; + const int32x4_t weighted_pred_0_lo = + vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0)); + const int32x4_t weighted_pred_0_hi = + vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0)); + const int32x4_t weighted_combo_lo = + vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1), + vget_low_s16(pred_val_1)); + const int32x4_t weighted_combo_hi = + vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1), + vget_high_s16(pred_val_1)); + + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + result = vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6), + vshrn_n_s32(weighted_combo_hi, 6)), + 4); + vst1_u8(dst + x, result); + + x += 8; + } while (x < width); + dst += dst_stride; + pred_0 += width; + pred_1 += width; + mask += mask_stride << subsampling_y; + } while (++y < height); +} + +// TODO(b/150461164): This is much faster for inter_intra (input is Pixel +// values) but regresses compound versions (input is int16_t). Try to +// consolidate these. +template <int subsampling_x, int subsampling_y> +inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask, + ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const uint8x8_t mask_val = + vpadd_u8(vld1_u8(mask), vld1_u8(mask + (mask_stride << subsampling_y))); + if (subsampling_y == 1) { + const uint8x8_t next_mask_val = vpadd_u8(vld1_u8(mask + mask_stride), + vld1_u8(mask + mask_stride * 3)); + + // Use a saturating add to work around the case where all |mask| values + // are 64. Together with the rounding shift this ensures the correct + // result. + const uint8x8_t sum = vqadd_u8(mask_val, next_mask_val); + return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y); + } + + return vrshr_n_u8(mask_val, /*subsampling_x=*/1); + } + + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val0 = Load4(mask); + // TODO(b/150461164): Investigate the source of |mask| and see if the stride + // can be removed. + // TODO(b/150461164): The unit tests start at 8x8. Does this get run? + return Load4<1>(mask + mask_stride, mask_val0); +} + +template <int subsampling_x, int subsampling_y> +inline uint8x8_t GetInterIntraMask8(const uint8_t* mask, + ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const uint8x16_t mask_val = vld1q_u8(mask); + const uint8x8_t mask_paired = + vpadd_u8(vget_low_u8(mask_val), vget_high_u8(mask_val)); + if (subsampling_y == 1) { + const uint8x16_t next_mask_val = vld1q_u8(mask + mask_stride); + const uint8x8_t next_mask_paired = + vpadd_u8(vget_low_u8(next_mask_val), vget_high_u8(next_mask_val)); + + // Use a saturating add to work around the case where all |mask| values + // are 64. Together with the rounding shift this ensures the correct + // result. 
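+ // For example, with 4:2:0 subsampling four mask samples of 64 sum to 256,
+ // which does not fit in uint8_t; vqadd_u8() saturates the sum to 255 and
+ // the rounding shift still yields (255 + 2) >> 2 = 64.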
+ const uint8x8_t sum = vqadd_u8(mask_paired, next_mask_paired); + return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y); + } + + return vrshr_n_u8(mask_paired, /*subsampling_x=*/1); + } + + assert(subsampling_y == 0 && subsampling_x == 0); + return vld1_u8(mask); +} + +inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0, + uint8_t* const pred_1, + const ptrdiff_t pred_stride_1, + const uint8x8_t pred_mask_0, + const uint8x8_t pred_mask_1) { + const uint8x8_t pred_val_0 = vld1_u8(pred_0); + uint8x8_t pred_val_1 = Load4(pred_1); + pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1); + + const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0); + const uint16x8_t weighted_combo = + vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1); + const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6); + StoreLo4(pred_1, result); + StoreHi4(pred_1 + pred_stride_1, result); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0, + uint8_t* pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* mask, + const ptrdiff_t mask_stride) { + const uint8x8_t mask_inverter = vdup_n_u8(64); + uint8x8_t pred_mask_1 = + GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1); + InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride << (1 + subsampling_y); + + pred_mask_1 = + GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1); + InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlending8bpp4xH_NEON( + const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1, + const uint8_t* mask, const ptrdiff_t mask_stride, const int height) { + if (height == 4) { + InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + return; + } + int y = 0; + do { + InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + pred_0 += 4 << 2; + pred_1 += pred_stride_1 << 2; + mask += mask_stride << (2 + subsampling_y); + + InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + pred_0 += 4 << 2; + pred_1 += pred_stride_1 << 2; + mask += mask_stride << (2 + subsampling_y); + y += 8; + } while (y < height); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0, + uint8_t* prediction_1, + const ptrdiff_t prediction_stride_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, + const int width, const int height) { + if (width == 4) { + InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>( + prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride, + height); + return; + } + const uint8_t* mask = mask_ptr; + const uint8x8_t mask_inverter = vdup_n_u8(64); + int y = 0; + do { + int x = 0; + do { + // TODO(b/150461164): Consider a 16 wide specialization (at least for the + // unsampled version) to take advantage of vld1q_u8(). 
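+ // Per pixel the body below computes
+ //   result = (pred_mask_0 * pred_val_0 + pred_mask_1 * pred_val_1 + 32) >> 6
+ // entirely in 16 bits, since both inputs are 8-bit and the two mask weights
+ // sum to 64.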
+ const uint8x8_t pred_mask_1 = + GetInterIntraMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride); + // 64 - mask + const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1); + const uint8x8_t pred_val_0 = vld1_u8(prediction_0); + prediction_0 += 8; + const uint8x8_t pred_val_1 = vld1_u8(prediction_1 + x); + const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0); + // weighted_pred0 + weighted_pred1 + const uint16x8_t weighted_combo = + vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1); + const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6); + vst1_u8(prediction_1 + x, result); + + x += 8; + } while (x < width); + prediction_1 += prediction_stride_1; + mask += mask_stride << subsampling_y; + } while (++y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0>; + dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0>; + dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1>; + // The is_inter_intra index of mask_blend[][] is replaced by + // inter_intra_mask_blend_8bpp[] in 8-bit. + dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_NEON<0, 0>; + dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_NEON<1, 0>; + dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_NEON<1, 1>; +} + +} // namespace +} // namespace low_bitdepth + +void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void MaskBlendInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/mask_blend_neon.h b/src/dsp/arm/mask_blend_neon.h new file mode 100644 index 0000000..3829274 --- /dev/null +++ b/src/dsp/arm/mask_blend_neon.h @@ -0,0 +1,41 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mask_blend. This function is not thread-safe. 
+void MaskBlendInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_ diff --git a/src/dsp/arm/motion_field_projection_neon.cc b/src/dsp/arm/motion_field_projection_neon.cc new file mode 100644 index 0000000..8caba7d --- /dev/null +++ b/src/dsp/arm/motion_field_projection_neon.cc @@ -0,0 +1,393 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_field_projection.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline int16x8_t LoadDivision(const int8x8x2_t division_table, + const int8x8_t reference_offset) { + const int8x8_t kOne = vcreate_s8(0x0100010001000100); + const int8x16_t kOneQ = vcombine_s8(kOne, kOne); + const int8x8_t t = vadd_s8(reference_offset, reference_offset); + const int8x8x2_t tt = vzip_s8(t, t); + const int8x16_t t1 = vcombine_s8(tt.val[0], tt.val[1]); + const int8x16_t idx = vaddq_s8(t1, kOneQ); + const int8x8_t idx_low = vget_low_s8(idx); + const int8x8_t idx_high = vget_high_s8(idx); + const int16x4_t d0 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_low)); + const int16x4_t d1 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_high)); + return vcombine_s16(d0, d1); +} + +inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator, + const int numerator) { + const int32x4_t m0 = vmull_s16(mv, denominator); + const int32x4_t m = vmulq_n_s32(m0, numerator); + // Add the sign (0 or -1) to round towards zero. 
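+ // vsraq_n_s32(m, m, 31) adds m >> 31, i.e. 0 for non-negative products and
+ // -1 for negative ones, so the saturating rounding narrow by 14 bits below
+ // rounds negative and positive motion vectors symmetrically, matching the
+ // scalar GetMvProjection().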
+ const int32x4_t add_sign = vsraq_n_s32(m, m, 31); + return vqrshrn_n_s32(add_sign, 14); +} + +inline int16x8_t MvProjectionClip(const int16x8_t mv, + const int16x8_t denominator, + const int numerator) { + const int16x4_t mv0 = vget_low_s16(mv); + const int16x4_t mv1 = vget_high_s16(mv); + const int16x4_t s0 = MvProjection(mv0, vget_low_s16(denominator), numerator); + const int16x4_t s1 = MvProjection(mv1, vget_high_s16(denominator), numerator); + const int16x8_t projection = vcombine_s16(s0, s1); + const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp); + const int16x8_t clamp = vminq_s16(projection, projection_mv_clamp); + return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp)); +} + +inline int8x8_t Project_NEON(const int16x8_t delta, const int16x8_t dst_sign) { + // Add 63 to negative delta so that it shifts towards zero. + const int16x8_t delta_sign = vshrq_n_s16(delta, 15); + const uint16x8_t delta_u = vreinterpretq_u16_s16(delta); + const uint16x8_t delta_sign_u = vreinterpretq_u16_s16(delta_sign); + const uint16x8_t delta_adjust_u = vsraq_n_u16(delta_u, delta_sign_u, 10); + const int16x8_t delta_adjust = vreinterpretq_s16_u16(delta_adjust_u); + const int16x8_t offset0 = vshrq_n_s16(delta_adjust, 6); + const int16x8_t offset1 = veorq_s16(offset0, dst_sign); + const int16x8_t offset2 = vsubq_s16(offset1, dst_sign); + return vqmovn_s16(offset2); +} + +inline void GetPosition( + const int8x8x2_t division_table, const MotionVector* const mv, + const int numerator, const int x8_start, const int x8_end, const int x8, + const int8x8_t r_offsets, const int8x8_t source_reference_type8, + const int8x8_t skip_r, const int8x8_t y8_floor8, const int8x8_t y8_ceiling8, + const int16x8_t d_sign, const int delta, int8x8_t* const r, + int8x8_t* const position_y8, int8x8_t* const position_x8, + int64_t* const skip_64, int32x4_t mvs[2]) { + const auto* const mv_int = reinterpret_cast<const int32_t*>(mv + x8); + *r = vtbl1_s8(r_offsets, source_reference_type8); + const int16x8_t denorm = LoadDivision(division_table, source_reference_type8); + int16x8_t projection_mv[2]; + mvs[0] = vld1q_s32(mv_int + 0); + mvs[1] = vld1q_s32(mv_int + 4); + // Deinterlace x and y components + const int16x8_t mv0 = vreinterpretq_s16_s32(mvs[0]); + const int16x8_t mv1 = vreinterpretq_s16_s32(mvs[1]); + const int16x8x2_t mv_yx = vuzpq_s16(mv0, mv1); + // numerator could be 0. + projection_mv[0] = MvProjectionClip(mv_yx.val[0], denorm, numerator); + projection_mv[1] = MvProjectionClip(mv_yx.val[1], denorm, numerator); + // Do not update the motion vector if the block position is not valid or + // if position_x8 is outside the current range of x8_start and x8_end. + // Note that position_y8 will always be within the range of y8_start and + // y8_end. + // After subtracting the base, valid projections are within 8-bit. 
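+ // k01234567 below adds each lane's column index, so lane i holds the
+ // projected column of block x8 + i expressed relative to x8; the
+ // floor/ceiling checks that follow use the same relative coordinates.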
+ *position_y8 = Project_NEON(projection_mv[0], d_sign); + const int8x8_t position_x = Project_NEON(projection_mv[1], d_sign); + const int8x8_t k01234567 = vcreate_s8(uint64_t{0x0706050403020100}); + *position_x8 = vqadd_s8(position_x, k01234567); + const int8x16_t position_xy = vcombine_s8(*position_x8, *position_y8); + const int x8_floor = std::max( + x8_start - x8, delta - kProjectionMvMaxHorizontalOffset); // [-8, 8] + const int x8_ceiling = std::min( + x8_end - x8, delta + 8 + kProjectionMvMaxHorizontalOffset); // [0, 16] + const int8x8_t x8_floor8 = vdup_n_s8(x8_floor); + const int8x8_t x8_ceiling8 = vdup_n_s8(x8_ceiling); + const int8x16_t floor_xy = vcombine_s8(x8_floor8, y8_floor8); + const int8x16_t ceiling_xy = vcombine_s8(x8_ceiling8, y8_ceiling8); + const uint8x16_t underflow = vcltq_s8(position_xy, floor_xy); + const uint8x16_t overflow = vcgeq_s8(position_xy, ceiling_xy); + const int8x16_t out = vreinterpretq_s8_u8(vorrq_u8(underflow, overflow)); + const int8x8_t skip_low = vorr_s8(skip_r, vget_low_s8(out)); + const int8x8_t skip = vorr_s8(skip_low, vget_high_s8(out)); + *skip_64 = vget_lane_s64(vreinterpret_s64_s8(skip), 0); +} + +template <int idx> +inline void Store(const int16x8_t position, const int8x8_t reference_offset, + const int32x4_t mv, int8_t* dst_reference_offset, + MotionVector* dst_mv) { + const ptrdiff_t offset = vgetq_lane_s16(position, idx); + auto* const d_mv = reinterpret_cast<int32_t*>(&dst_mv[offset]); + vst1q_lane_s32(d_mv, mv, idx & 3); + vst1_lane_s8(&dst_reference_offset[offset], reference_offset, idx); +} + +template <int idx> +inline void CheckStore(const int8_t* skips, const int16x8_t position, + const int8x8_t reference_offset, const int32x4_t mv, + int8_t* dst_reference_offset, MotionVector* dst_mv) { + if (skips[idx] == 0) { + Store<idx>(position, reference_offset, mv, dst_reference_offset, dst_mv); + } +} + +// 7.9.2. +void MotionFieldProjectionKernel_NEON(const ReferenceInfo& reference_info, + const int reference_to_current_with_sign, + const int dst_sign, const int y8_start, + const int y8_end, const int x8_start, + const int x8_end, + TemporalMotionField* const motion_field) { + const ptrdiff_t stride = motion_field->mv.columns(); + // The column range has to be offset by kProjectionMvMaxHorizontalOffset since + // coordinates in that range could end up being position_x8 because of + // projection. 
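+ // Blocks up to kProjectionMvMaxHorizontalOffset columns outside of
+ // [x8_start, x8_end) are therefore also read, since their motion vectors
+ // can project into the current column range; the floor/ceiling checks in
+ // GetPosition() keep the writes themselves inside [x8_start, x8_end).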
+ const int adjusted_x8_start = + std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0); + const int adjusted_x8_end = std::min( + x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride)); + const int adjusted_x8_end8 = adjusted_x8_end & ~7; + const int leftover = adjusted_x8_end - adjusted_x8_end8; + const int8_t* const reference_offsets = + reference_info.relative_distance_to.data(); + const bool* const skip_references = reference_info.skip_references.data(); + const int16_t* const projection_divisions = + reference_info.projection_divisions.data(); + const ReferenceFrameType* source_reference_types = + &reference_info.motion_field_reference_frame[y8_start][0]; + const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0]; + int8_t* dst_reference_offset = motion_field->reference_offset[y8_start]; + MotionVector* dst_mv = motion_field->mv[y8_start]; + const int16x8_t d_sign = vdupq_n_s16(dst_sign); + + static_assert(sizeof(int8_t) == sizeof(bool), ""); + static_assert(sizeof(int8_t) == sizeof(ReferenceFrameType), ""); + static_assert(sizeof(int32_t) == sizeof(MotionVector), ""); + assert(dst_sign == 0 || dst_sign == -1); + assert(stride == motion_field->reference_offset.columns()); + assert((y8_start & 7) == 0); + assert((adjusted_x8_start & 7) == 0); + // The final position calculation is represented with int16_t. Valid + // position_y8 from its base is at most 7. After considering the horizontal + // offset which is at most |stride - 1|, we have the following assertion, + // which means this optimization works for frame width up to 32K (each + // position is a 8x8 block). + assert(8 * stride <= 32768); + const int8x8_t skip_reference = + vld1_s8(reinterpret_cast<const int8_t*>(skip_references)); + const int8x8_t r_offsets = vld1_s8(reference_offsets); + const int8x16_t table = vreinterpretq_s8_s16(vld1q_s16(projection_divisions)); + int8x8x2_t division_table; + division_table.val[0] = vget_low_s8(table); + division_table.val[1] = vget_high_s8(table); + + int y8 = y8_start; + do { + const int y8_floor = (y8 & ~7) - y8; // [-7, 0] + const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8); // [1, 8] + const int8x8_t y8_floor8 = vdup_n_s8(y8_floor); + const int8x8_t y8_ceiling8 = vdup_n_s8(y8_ceiling); + int x8; + + for (x8 = adjusted_x8_start; x8 < adjusted_x8_end8; x8 += 8) { + const int8x8_t source_reference_type8 = + vld1_s8(reinterpret_cast<const int8_t*>(source_reference_types + x8)); + const int8x8_t skip_r = vtbl1_s8(skip_reference, source_reference_type8); + const int64_t early_skip = vget_lane_s64(vreinterpret_s64_s8(skip_r), 0); + // Early termination #1 if all are skips. Chance is typically ~30-40%. + if (early_skip == -1) continue; + int64_t skip_64; + int8x8_t r, position_x8, position_y8; + int32x4_t mvs[2]; + GetPosition(division_table, mv, reference_to_current_with_sign, x8_start, + x8_end, x8, r_offsets, source_reference_type8, skip_r, + y8_floor8, y8_ceiling8, d_sign, 0, &r, &position_y8, + &position_x8, &skip_64, mvs); + // Early termination #2 if all are skips. + // Chance is typically ~15-25% after Early termination #1. + if (skip_64 == -1) continue; + const int16x8_t p_y = vmovl_s8(position_y8); + const int16x8_t p_x = vmovl_s8(position_x8); + const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride); + const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8)); + if (skip_64 == 0) { + // Store all. Chance is typically ~70-85% after Early termination #2. 
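+ // Each Store<i>() scatters one lane: it writes block i's source motion
+ // vector and reference offset to linear index position[i] (its projected
+ // location) in the destination motion field.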
+ Store<0>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv); + } else { + // Check and store each. + // Chance is typically ~15-30% after Early termination #2. + // The compiler is smart enough to not create the local buffer skips[]. + int8_t skips[8]; + memcpy(skips, &skip_64, sizeof(skips)); + CheckStore<0>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + } + } + + // The following leftover processing cannot be moved out of the do...while + // loop. Doing so may change the result storing orders of the same position. + if (leftover > 0) { + // Use SIMD only when leftover is at least 4, and there are at least 8 + // elements in a row. + if (leftover >= 4 && adjusted_x8_start < adjusted_x8_end8) { + // Process the last 8 elements to avoid loading invalid memory. Some + // elements may have been processed in the above loop, which is OK. + const int delta = 8 - leftover; + x8 = adjusted_x8_end - 8; + const int8x8_t source_reference_type8 = vld1_s8( + reinterpret_cast<const int8_t*>(source_reference_types + x8)); + const int8x8_t skip_r = + vtbl1_s8(skip_reference, source_reference_type8); + const int64_t early_skip = + vget_lane_s64(vreinterpret_s64_s8(skip_r), 0); + // Early termination #1 if all are skips. + if (early_skip != -1) { + int64_t skip_64; + int8x8_t r, position_x8, position_y8; + int32x4_t mvs[2]; + GetPosition(division_table, mv, reference_to_current_with_sign, + x8_start, x8_end, x8, r_offsets, source_reference_type8, + skip_r, y8_floor8, y8_ceiling8, d_sign, delta, &r, + &position_y8, &position_x8, &skip_64, mvs); + // Early termination #2 if all are skips. + if (skip_64 != -1) { + const int16x8_t p_y = vmovl_s8(position_y8); + const int16x8_t p_x = vmovl_s8(position_x8); + const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride); + const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8)); + // Store up to 7 elements since leftover is at most 7. + if (skip_64 == 0) { + // Store all. + Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv); + } else { + // Check and store each. + // The compiler is smart enough to not create the local buffer + // skips[]. 
+ int8_t skips[8]; + memcpy(skips, &skip_64, sizeof(skips)); + CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + } + } + } + } else { + for (; x8 < adjusted_x8_end; ++x8) { + const int source_reference_type = source_reference_types[x8]; + if (skip_references[source_reference_type]) continue; + MotionVector projection_mv; + // reference_to_current_with_sign could be 0. + GetMvProjection(mv[x8], reference_to_current_with_sign, + projection_divisions[source_reference_type], + &projection_mv); + // Do not update the motion vector if the block position is not valid + // or if position_x8 is outside the current range of x8_start and + // x8_end. Note that position_y8 will always be within the range of + // y8_start and y8_end. + const int position_y8 = Project(0, projection_mv.mv[0], dst_sign); + if (position_y8 < y8_floor || position_y8 >= y8_ceiling) continue; + const int x8_base = x8 & ~7; + const int x8_floor = + std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset); + const int x8_ceiling = + std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset); + const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign); + if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue; + dst_mv[position_y8 * stride + position_x8] = mv[x8]; + dst_reference_offset[position_y8 * stride + position_x8] = + reference_offsets[source_reference_type]; + } + } + } + + source_reference_types += stride; + mv += stride; + dst_reference_offset += stride; + dst_mv += stride; + } while (++y8 < y8_end); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON; +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON; +} +#endif + +} // namespace + +void MotionFieldProjectionInit_NEON() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void MotionFieldProjectionInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/motion_field_projection_neon.h b/src/dsp/arm/motion_field_projection_neon.h new file mode 100644 index 0000000..41ab6a6 --- /dev/null +++ b/src/dsp/arm/motion_field_projection_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::motion_field_projection_kernel. This function is not +// thread-safe. +void MotionFieldProjectionInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_ diff --git a/src/dsp/arm/motion_vector_search_neon.cc b/src/dsp/arm/motion_vector_search_neon.cc new file mode 100644 index 0000000..8a403a6 --- /dev/null +++ b/src/dsp/arm/motion_vector_search_neon.cc @@ -0,0 +1,267 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_vector_search.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator, + const int32x4_t numerator) { + const int32x4_t m0 = vmull_s16(mv, denominator); + const int32x4_t m = vmulq_s32(m0, numerator); + // Add the sign (0 or -1) to round towards zero. 
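+ // (Same trick as in motion_field_projection_neon.cc. Here |numerator| is a
+ // full vector, so the compound path can project each temporal MV against
+ // reference_offsets[0] in its first two lanes and reference_offsets[1] in
+ // its last two.)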
+ const int32x4_t add_sign = vsraq_n_s32(m, m, 31); + return vqrshrn_n_s32(add_sign, 14); +} + +inline int16x4_t MvProjectionCompound(const int16x4_t mv, + const int temporal_reference_offsets, + const int reference_offsets[2]) { + const int16x4_t denominator = + vdup_n_s16(kProjectionMvDivisionLookup[temporal_reference_offsets]); + const int32x2_t offset = vld1_s32(reference_offsets); + const int32x2x2_t offsets = vzip_s32(offset, offset); + const int32x4_t numerator = vcombine_s32(offsets.val[0], offsets.val[1]); + return MvProjection(mv, denominator, numerator); +} + +inline int16x8_t ProjectionClip(const int16x4_t mv0, const int16x4_t mv1) { + const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp); + const int16x8_t mv = vcombine_s16(mv0, mv1); + const int16x8_t clamp = vminq_s16(mv, projection_mv_clamp); + return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp)); +} + +inline int16x8_t MvProjectionCompoundClip( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, + const int reference_offsets[2]) { + const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs); + const int32x2_t temporal_mv = vld1_s32(tmvs); + const int16x4_t tmv0 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 0)); + const int16x4_t tmv1 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 1)); + const int16x4_t mv0 = MvProjectionCompound( + tmv0, temporal_reference_offsets[0], reference_offsets); + const int16x4_t mv1 = MvProjectionCompound( + tmv1, temporal_reference_offsets[1], reference_offsets); + return ProjectionClip(mv0, mv1); +} + +inline int16x8_t MvProjectionSingleClip( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, const int reference_offset, + int16x4_t* const lookup) { + const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs); + const int16x8_t temporal_mv = vld1q_s16(tmvs); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[0]], *lookup, 0); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[1]], *lookup, 1); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[2]], *lookup, 2); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[3]], *lookup, 3); + const int16x4x2_t denominator = vzip_s16(*lookup, *lookup); + const int16x4_t tmv0 = vget_low_s16(temporal_mv); + const int16x4_t tmv1 = vget_high_s16(temporal_mv); + const int32x4_t numerator = vdupq_n_s32(reference_offset); + const int16x4_t mv0 = MvProjection(tmv0, denominator.val[0], numerator); + const int16x4_t mv1 = MvProjection(tmv1, denominator.val[1], numerator); + return ProjectionClip(mv0, mv1); +} + +inline void LowPrecision(const int16x8_t mv, void* const candidate_mvs) { + const int16x8_t kRoundDownMask = vdupq_n_s16(1); + const uint16x8_t mvu = vreinterpretq_u16_s16(mv); + const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15)); + const int16x8_t mv1 = vbicq_s16(mv0, kRoundDownMask); + vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv1); +} + +inline void ForceInteger(const int16x8_t mv, void* const candidate_mvs) { + const int16x8_t kRoundDownMask = vdupq_n_s16(7); + const uint16x8_t mvu = vreinterpretq_u16_s16(mv); + const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15)); + const int16x8_t mv1 = vaddq_s16(mv0, vdupq_n_s16(3)); + const int16x8_t mv2 = vbicq_s16(mv1, kRoundDownMask); + vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv2); +} + +void 
MvProjectionCompoundLowPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int loop_count = (count + 1) >> 1; + do { + const int16x8_t mv = MvProjectionCompoundClip( + temporal_mvs, temporal_reference_offsets, offsets); + LowPrecision(mv, candidate_mvs); + temporal_mvs += 2; + temporal_reference_offsets += 2; + candidate_mvs += 2; + } while (--loop_count); +} + +void MvProjectionCompoundForceInteger_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int loop_count = (count + 1) >> 1; + do { + const int16x8_t mv = MvProjectionCompoundClip( + temporal_mvs, temporal_reference_offsets, offsets); + ForceInteger(mv, candidate_mvs); + temporal_mvs += 2; + temporal_reference_offsets += 2; + candidate_mvs += 2; + } while (--loop_count); +} + +void MvProjectionCompoundHighPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int loop_count = (count + 1) >> 1; + do { + const int16x8_t mv = MvProjectionCompoundClip( + temporal_mvs, temporal_reference_offsets, offsets); + vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv); + temporal_mvs += 2; + temporal_reference_offsets += 2; + candidate_mvs += 2; + } while (--loop_count); +} + +void MvProjectionSingleLowPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int loop_count = (count + 3) >> 2; + int16x4_t lookup = vdup_n_s16(0); + do { + const int16x8_t mv = MvProjectionSingleClip( + temporal_mvs, temporal_reference_offsets, reference_offset, &lookup); + LowPrecision(mv, candidate_mvs); + temporal_mvs += 4; + temporal_reference_offsets += 4; + candidate_mvs += 4; + } while (--loop_count); +} + +void MvProjectionSingleForceInteger_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. 
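+  // (count + 3) >> 2 == ceil(count / 4): each iteration projects four
+  // temporal MVs, so e.g. count == 5 still runs two iterations; the trailing
+  // lanes are over-computed and only the first |count| outputs are
+  // meaningful. Scalar shape of the loop (a sketch, one MV at a time):
+  //   for (int i = 0; i < count; ++i) {
+  //     mv = Project(temporal_mvs[i], temporal_reference_offsets[i],
+  //                  reference_offset);          // MvProjection() per lane
+  //     candidate_mvs[i] = ForceIntegerMv(mv);   // round to full-pel
+  //   }
+  // Project()/ForceIntegerMv() are illustrative names, not real helpers.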
+ int loop_count = (count + 3) >> 2; + int16x4_t lookup = vdup_n_s16(0); + do { + const int16x8_t mv = MvProjectionSingleClip( + temporal_mvs, temporal_reference_offsets, reference_offset, &lookup); + ForceInteger(mv, candidate_mvs); + temporal_mvs += 4; + temporal_reference_offsets += 4; + candidate_mvs += 4; + } while (--loop_count); +} + +void MvProjectionSingleHighPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int loop_count = (count + 3) >> 2; + int16x4_t lookup = vdup_n_s16(0); + do { + const int16x8_t mv = MvProjectionSingleClip( + temporal_mvs, temporal_reference_offsets, reference_offset, &lookup); + vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv); + temporal_mvs += 4; + temporal_reference_offsets += 4; + candidate_mvs += 4; + } while (--loop_count); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON; +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON; +} +#endif + +} // namespace + +void MotionVectorSearchInit_NEON() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void MotionVectorSearchInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/motion_vector_search_neon.h b/src/dsp/arm/motion_vector_search_neon.h new file mode 100644 index 0000000..19b4519 --- /dev/null +++ b/src/dsp/arm/motion_vector_search_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This +// function is not thread-safe. +void MotionVectorSearchInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_MotionVectorSearch LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_ diff --git a/src/dsp/arm/obmc_neon.cc b/src/dsp/arm/obmc_neon.cc new file mode 100644 index 0000000..66ad663 --- /dev/null +++ b/src/dsp/arm/obmc_neon.cc @@ -0,0 +1,392 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/obmc.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +#include "src/dsp/obmc.inc" + +inline void WriteObmcLine4(uint8_t* const pred, const uint8_t* const obmc_pred, + const uint8x8_t pred_mask, + const uint8x8_t obmc_pred_mask) { + const uint8x8_t pred_val = Load4(pred); + const uint8x8_t obmc_pred_val = Load4(obmc_pred); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + StoreLo4(pred, result); +} + +template <bool from_left> +inline void OverlapBlend2xH_NEON(uint8_t* const prediction, + const ptrdiff_t prediction_stride, + const int height, + const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8_t* obmc_pred = obmc_prediction; + uint8x8_t pred_mask; + uint8x8_t obmc_pred_mask; + int compute_height; + const int mask_offset = height - 2; + if (from_left) { + pred_mask = Load2(kObmcMask); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + compute_height = height; + } else { + // Weights for the last line are all 64, which is a no-op. 
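+    // Each blended pixel is (mask * pred + (64 - mask) * obmc_pred + 32) >> 6,
+    // so a mask of 64 gives (64 * p + 0 + 32) >> 6 == p and the row can be
+    // skipped entirely.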
+ compute_height = height - 1; + } + uint8x8_t pred_val = vdup_n_u8(0); + uint8x8_t obmc_pred_val = vdup_n_u8(0); + int y = 0; + do { + if (!from_left) { + pred_mask = vdup_n_u8(kObmcMask[mask_offset + y]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + } + pred_val = Load2<0>(pred, pred_val); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + obmc_pred_val = Load2<0>(obmc_pred, obmc_pred_val); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + Store2<0>(pred, result); + + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y != compute_height); +} + +inline void OverlapBlendFromLeft4xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8x8_t pred_mask = Load4(kObmcMask + 2); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + int y = 0; + do { + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + y += 2; + } while (y != height); +} + +inline void OverlapBlendFromLeft8xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8x8_t pred_mask = vld1_u8(kObmcMask + 6); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + int y = 0; + do { + const uint8x8_t pred_val = vld1_u8(pred); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + + vst1_u8(pred, result); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y != height); +} + +void OverlapBlendFromLeft_NEON(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint8_t*>(prediction); + const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction); + + if (width == 2) { + OverlapBlend2xH_NEON<true>(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 4) { + OverlapBlendFromLeft4xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 8) { + OverlapBlendFromLeft8xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + const uint8x16_t mask_inverter = vdupq_n_u8(64); + const uint8_t* mask = kObmcMask + width - 2; + int x = 0; + do { + pred = static_cast<uint8_t*>(prediction) + x; + obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x; + const uint8x16_t pred_mask = vld1q_u8(mask + x); + // 64 - mask + const uint8x16_t obmc_pred_mask = vsubq_u8(mask_inverter, pred_mask); + int y = 0; + do { + const uint8x16_t pred_val = vld1q_u8(pred); + const uint8x16_t obmc_pred_val 
= vld1q_u8(obmc_pred); + const uint16x8_t weighted_pred_lo = + vmull_u8(vget_low_u8(pred_mask), vget_low_u8(pred_val)); + const uint8x8_t result_lo = + vrshrn_n_u16(vmlal_u8(weighted_pred_lo, vget_low_u8(obmc_pred_mask), + vget_low_u8(obmc_pred_val)), + 6); + const uint16x8_t weighted_pred_hi = + vmull_u8(vget_high_u8(pred_mask), vget_high_u8(pred_val)); + const uint8x8_t result_hi = + vrshrn_n_u16(vmlal_u8(weighted_pred_hi, vget_high_u8(obmc_pred_mask), + vget_high_u8(obmc_pred_val)), + 6); + vst1q_u8(pred, vcombine_u8(result_lo, result_hi)); + + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y < height); + x += 16; + } while (x < width); +} + +inline void OverlapBlendFromTop4x4_NEON(uint8_t* const prediction, + const ptrdiff_t prediction_stride, + const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride, + const int height) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + uint8x8_t pred_mask = vdup_n_u8(kObmcMask[height - 2]); + const uint8x8_t mask_inverter = vdup_n_u8(64); + uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + if (height == 2) { + return; + } + + pred_mask = vdup_n_u8(kObmcMask[3]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(kObmcMask[4]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); +} + +inline void OverlapBlendFromTop4xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + if (height < 8) { + OverlapBlendFromTop4x4_NEON(prediction, prediction_stride, obmc_prediction, + obmc_prediction_stride, height); + return; + } + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const uint8_t* mask = kObmcMask + height - 2; + const uint8x8_t mask_inverter = vdup_n_u8(64); + int y = 0; + // Compute 6 lines for height 8, or 12 lines for height 16. The remaining + // lines are unchanged as the corresponding mask value is 64. 
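+  // The loop below writes six rows per iteration and stops at height - 4:
+  // height 8 runs once (rows 0-5), height 16 runs twice (rows 0-11); the
+  // bottom rows keep the original prediction.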
+ do { + uint8x8_t pred_mask = vdup_n_u8(mask[y]); + uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 1]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 2]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 3]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 4]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 5]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + // Increment for the right mask index. + y += 6; + } while (y < height - 4); +} + +inline void OverlapBlendFromTop8xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8_t* mask = kObmcMask + height - 2; + const int compute_height = height - (height >> 2); + int y = 0; + do { + const uint8x8_t pred_mask = vdup_n_u8(mask[y]); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + const uint8x8_t pred_val = vld1_u8(pred); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + + vst1_u8(pred, result); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y != compute_height); +} + +void OverlapBlendFromTop_NEON(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint8_t*>(prediction); + const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction); + + if (width == 2) { + OverlapBlend2xH_NEON<false>(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 4) { + OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + + if (width == 8) { + OverlapBlendFromTop8xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + + const uint8_t* mask = kObmcMask + height - 2; + const uint8x8_t mask_inverter = vdup_n_u8(64); + // Stop when mask value becomes 64. This is inferred for 4xH. 
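+  // compute_height == height - height / 4, i.e. only the top three quarters
+  // of the rows are blended; e.g. height 32 blends rows 0-23 and leaves rows
+  // 24-31 untouched.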
+ const int compute_height = height - (height >> 2); + int y = 0; + do { + const uint8x8_t pred_mask = vdup_n_u8(mask[y]); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + int x = 0; + do { + const uint8x16_t pred_val = vld1q_u8(pred + x); + const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred + x); + const uint16x8_t weighted_pred_lo = + vmull_u8(pred_mask, vget_low_u8(pred_val)); + const uint8x8_t result_lo = + vrshrn_n_u16(vmlal_u8(weighted_pred_lo, obmc_pred_mask, + vget_low_u8(obmc_pred_val)), + 6); + const uint16x8_t weighted_pred_hi = + vmull_u8(pred_mask, vget_high_u8(pred_val)); + const uint8x8_t result_hi = + vrshrn_n_u16(vmlal_u8(weighted_pred_hi, obmc_pred_mask, + vget_high_u8(obmc_pred_val)), + 6); + vst1q_u8(pred + x, vcombine_u8(result_lo, result_hi)); + + x += 16; + } while (x < width); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y < compute_height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON; + dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON; +} + +} // namespace + +void ObmcInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void ObmcInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/obmc_neon.h b/src/dsp/arm/obmc_neon.h new file mode 100644 index 0000000..d5c9d9c --- /dev/null +++ b/src/dsp/arm/obmc_neon.h @@ -0,0 +1,38 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::obmc_blend. This function is not thread-safe. +void ObmcInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +// If NEON is enabled, signal the NEON implementation should be used. +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_ diff --git a/src/dsp/arm/super_res_neon.cc b/src/dsp/arm/super_res_neon.cc new file mode 100644 index 0000000..1680450 --- /dev/null +++ b/src/dsp/arm/super_res_neon.cc @@ -0,0 +1,166 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/super_res.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { + +namespace low_bitdepth { +namespace { + +void SuperResCoefficients_NEON(const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const coefficients) { + auto* dst = static_cast<uint8_t*>(coefficients); + int subpixel_x = initial_subpixel_x; + int x = RightShiftWithCeiling(upscaled_width, 3); + do { + uint8x8_t filter[8]; + uint8x16_t d[kSuperResFilterTaps / 2]; + for (int i = 0; i < 8; ++i, subpixel_x += step) { + filter[i] = + vld1_u8(kUpscaleFilterUnsigned[(subpixel_x & kSuperResScaleMask) >> + kSuperResExtraBits]); + } + Transpose8x8(filter, d); + vst1q_u8(dst, d[0]); + dst += 16; + vst1q_u8(dst, d[1]); + dst += 16; + vst1q_u8(dst, d[2]); + dst += 16; + vst1q_u8(dst, d[3]); + dst += 16; + } while (--x != 0); +} + +// Maximum sum of positive taps: 171 = 7 + 86 + 71 + 7 +// Maximum sum: 255*171 == 0xAA55 +// The sum is clipped to [0, 255], so adding all positive and then +// subtracting all negative with saturation is sufficient. +// 0 1 2 3 4 5 6 7 +// tap sign: - + - + + - + - +inline uint8x8_t SuperRes(const uint8x8_t src[kSuperResFilterTaps], + const uint8_t** coefficients) { + uint8x16_t f[kSuperResFilterTaps / 2]; + for (int i = 0; i < kSuperResFilterTaps / 2; ++i, *coefficients += 16) { + f[i] = vld1q_u8(*coefficients); + } + uint16x8_t res = vmull_u8(src[1], vget_high_u8(f[0])); + res = vmlal_u8(res, src[3], vget_high_u8(f[1])); + res = vmlal_u8(res, src[4], vget_low_u8(f[2])); + res = vmlal_u8(res, src[6], vget_low_u8(f[3])); + uint16x8_t temp = vmull_u8(src[0], vget_low_u8(f[0])); + temp = vmlal_u8(temp, src[2], vget_low_u8(f[1])); + temp = vmlal_u8(temp, src[5], vget_high_u8(f[2])); + temp = vmlal_u8(temp, src[7], vget_high_u8(f[3])); + res = vqsubq_u16(res, temp); + return vqrshrn_n_u16(res, kFilterBits); +} + +void SuperRes_NEON(const void* const coefficients, void* const source, + const ptrdiff_t stride, const int height, + const int downscaled_width, const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const dest) { + auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps); + auto* dst = static_cast<uint8_t*>(dest); + int y = height; + do { + const auto* filter = static_cast<const uint8_t*>(coefficients); + uint8_t* dst_ptr = dst; + ExtendLine<uint8_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width, + kSuperResHorizontalBorder, kSuperResHorizontalBorder); + int subpixel_x = initial_subpixel_x; + uint8x8_t sr[8]; + uint8x16_t s[8]; + int x = RightShiftWithCeiling(upscaled_width, 4); + // The below code calculates up to 15 extra upscaled + // pixels which will over-read up to 15 downscaled pixels in the end of each + // row. kSuperResHorizontalBorder accounts for this. 
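+    // Per output pixel the computation is (scalar sketch; the NEON code below
+    // evaluates it 16 pixels at a time with the taps split by sign):
+    //   src_x = subpixel_x >> kSuperResScaleBits;
+    //   phase = (subpixel_x & kSuperResScaleMask) >> kSuperResExtraBits;
+    //   sum   = sum of filter[phase][k] * src[src_x + k] over the 8 taps;
+    //   dst_ptr[x] = 8-bit clip of RightShiftWithRounding(sum, kFilterBits);
+    //   subpixel_x += step;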
+ do { + for (int i = 0; i < 8; ++i, subpixel_x += step) { + sr[i] = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]); + } + for (int i = 0; i < 8; ++i, subpixel_x += step) { + const uint8x8_t s_hi = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]); + s[i] = vcombine_u8(sr[i], s_hi); + } + Transpose8x16(s); + // Do not use loop for the following 8 instructions, since the compiler + // will generate redundant code. + sr[0] = vget_low_u8(s[0]); + sr[1] = vget_low_u8(s[1]); + sr[2] = vget_low_u8(s[2]); + sr[3] = vget_low_u8(s[3]); + sr[4] = vget_low_u8(s[4]); + sr[5] = vget_low_u8(s[5]); + sr[6] = vget_low_u8(s[6]); + sr[7] = vget_low_u8(s[7]); + const uint8x8_t d0 = SuperRes(sr, &filter); + // Do not use loop for the following 8 instructions, since the compiler + // will generate redundant code. + sr[0] = vget_high_u8(s[0]); + sr[1] = vget_high_u8(s[1]); + sr[2] = vget_high_u8(s[2]); + sr[3] = vget_high_u8(s[3]); + sr[4] = vget_high_u8(s[4]); + sr[5] = vget_high_u8(s[5]); + sr[6] = vget_high_u8(s[6]); + sr[7] = vget_high_u8(s[7]); + const uint8x8_t d1 = SuperRes(sr, &filter); + vst1q_u8(dst_ptr, vcombine_u8(d0, d1)); + dst_ptr += 16; + } while (--x != 0); + src += stride; + dst += stride; + } while (--y != 0); +} + +void Init8bpp() { + Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + dsp->super_res_coefficients = SuperResCoefficients_NEON; + dsp->super_res = SuperRes_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void SuperResInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void SuperResInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/super_res_neon.h b/src/dsp/arm/super_res_neon.h new file mode 100644 index 0000000..f51785d --- /dev/null +++ b/src/dsp/arm/super_res_neon.h @@ -0,0 +1,37 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::super_res. This function is not thread-safe. +void SuperResInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_SuperRes LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_SuperResClip LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_ diff --git a/src/dsp/arm/warp_neon.cc b/src/dsp/arm/warp_neon.cc new file mode 100644 index 0000000..7a41998 --- /dev/null +++ b/src/dsp/arm/warp_neon.cc @@ -0,0 +1,453 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/warp.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <type_traits> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Number of extra bits of precision in warped filtering. +constexpr int kWarpedDiffPrecisionBits = 10; +constexpr int kFirstPassOffset = 1 << 14; +constexpr int kOffsetRemoval = + (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128; + +// Applies the horizontal filter to one source row and stores the result in +// |intermediate_result_row|. |intermediate_result_row| is a row in the 15x8 +// |intermediate_result| two-dimensional array. +// +// src_row_centered contains 16 "centered" samples of a source row. (We center +// the samples by subtracting 128 from the samples.) +void HorizontalFilter(const int sx4, const int16_t alpha, + const int8x16_t src_row_centered, + int16_t intermediate_result_row[8]) { + int sx = sx4 - MultiplyBy4(alpha); + int8x8_t filter[8]; + for (int x = 0; x < 8; ++x) { + const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + filter[x] = vld1_s8(kWarpedFilters8[offset]); + sx += alpha; + } + Transpose8x8(filter); + // Add kFirstPassOffset to ensure |sum| stays within uint16_t. + // Add 128 (offset) * 128 (filter sum) (also 1 << 14) to account for the + // centering of the source samples. These combined are 1 << 15 or -32768. + int16x8_t sum = + vdupq_n_s16(static_cast<int16_t>(kFirstPassOffset + 128 * 128)); + // Unrolled k = 0..7 loop. We need to manually unroll the loop because the + // third argument (an index value) to vextq_s8() must be a constant + // (immediate). src_row_window is a sliding window of length 8 into + // src_row_centered. + // k = 0. + int8x8_t src_row_window = vget_low_s8(src_row_centered); + sum = vmlal_s8(sum, filter[0], src_row_window); + // k = 1. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 1)); + sum = vmlal_s8(sum, filter[1], src_row_window); + // k = 2. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 2)); + sum = vmlal_s8(sum, filter[2], src_row_window); + // k = 3. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 3)); + sum = vmlal_s8(sum, filter[3], src_row_window); + // k = 4. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 4)); + sum = vmlal_s8(sum, filter[4], src_row_window); + // k = 5. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 5)); + sum = vmlal_s8(sum, filter[5], src_row_window); + // k = 6. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 6)); + sum = vmlal_s8(sum, filter[6], src_row_window); + // k = 7. 
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 7)); + sum = vmlal_s8(sum, filter[7], src_row_window); + // End of unrolled k = 0..7 loop. + // Due to the offset |sum| is guaranteed to be unsigned. + uint16x8_t sum_unsigned = vreinterpretq_u16_s16(sum); + sum_unsigned = vrshrq_n_u16(sum_unsigned, kInterRoundBitsHorizontal); + // After the shift |sum_unsigned| will fit into int16_t. + vst1q_s16(intermediate_result_row, vreinterpretq_s16_u16(sum_unsigned)); +} + +template <bool is_compound> +void Warp_NEON(const void* const source, const ptrdiff_t source_stride, + const int source_width, const int source_height, + const int* const warp_params, const int subsampling_x, + const int subsampling_y, const int block_start_x, + const int block_start_y, const int block_width, + const int block_height, const int16_t alpha, const int16_t beta, + const int16_t gamma, const int16_t delta, void* dest, + const ptrdiff_t dest_stride) { + constexpr int kRoundBitsVertical = + is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical; + union { + // Intermediate_result is the output of the horizontal filtering and + // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 - + // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t + // type so that we can multiply it by kWarpedFilters (which has signed + // values) using vmlal_s16(). + int16_t intermediate_result[15][8]; // 15 rows, 8 columns. + // In the simple special cases where the samples in each row are all the + // same, store one sample per row in a column vector. + int16_t intermediate_result_column[15]; + }; + + const auto* const src = static_cast<const uint8_t*>(source); + using DestType = + typename std::conditional<is_compound, int16_t, uint8_t>::type; + auto* dst = static_cast<DestType*>(dest); + + assert(block_width >= 8); + assert(block_height >= 8); + + // Warp process applies for each 8x8 block. + int start_y = block_start_y; + do { + int start_x = block_start_x; + do { + const int src_x = (start_x + 4) << subsampling_x; + const int src_y = (start_y + 4) << subsampling_y; + const int dst_x = + src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0]; + const int dst_y = + src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1]; + const int x4 = dst_x >> subsampling_x; + const int y4 = dst_y >> subsampling_y; + const int ix4 = x4 >> kWarpedModelPrecisionBits; + const int iy4 = y4 >> kWarpedModelPrecisionBits; + // A prediction block may fall outside the frame's boundaries. If a + // prediction block is calculated using only samples outside the frame's + // boundary, the filtering can be simplified. We can divide the plane + // into several regions and handle them differently. + // + // | | + // 1 | 3 | 1 + // | | + // -------+-----------+------- + // |***********| + // 2 |*****4*****| 2 + // |***********| + // -------+-----------+------- + // | | + // 1 | 3 | 1 + // | | + // + // At the center, region 4 represents the frame and is the general case. + // + // In regions 1 and 2, the prediction block is outside the frame's + // boundary horizontally. Therefore the horizontal filtering can be + // simplified. Furthermore, in the region 1 (at the four corners), the + // prediction is outside the frame's boundary both horizontally and + // vertically, so we get a constant prediction block. + // + // In region 3, the prediction block is outside the frame's boundary + // vertically. 
Unfortunately because we apply the horizontal filters + // first, by the time we apply the vertical filters, they no longer see + // simple inputs. So the only simplification is that all the rows are + // the same, but we still need to apply all the horizontal and vertical + // filters. + + // Check for two simple special cases, where the horizontal filter can + // be significantly simplified. + // + // In general, for each row, the horizontal filter is calculated as + // follows: + // for (int x = -4; x < 4; ++x) { + // const int offset = ...; + // int sum = first_pass_offset; + // for (int k = 0; k < 8; ++k) { + // const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1); + // sum += kWarpedFilters[offset][k] * src_row[column]; + // } + // ... + // } + // The column index before clipping, ix4 + x + k - 3, varies in the range + // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1 + // or ix4 + 7 <= 0, then all the column indexes are clipped to the same + // border index (source_width - 1 or 0, respectively). Then for each x, + // the inner for loop of the horizontal filter is reduced to multiplying + // the border pixel by the sum of the filter coefficients. + if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) { + // Regions 1 and 2. + // Points to the left or right border of the first row of |src|. + const uint8_t* first_row_border = + (ix4 + 7 <= 0) ? src : src + source_width - 1; + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) { + // Region 1. + // Every sample used to calculate the prediction block has the same + // value. So the whole prediction block has the same value. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const uint8_t row_border_pixel = + first_row_border[row * source_stride]; + + DestType* dst_row = dst + start_x - block_start_x; + for (int y = 0; y < 8; ++y) { + if (is_compound) { + const int16x8_t sum = + vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical - + kRoundBitsVertical)); + vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum); + } else { + memset(dst_row, row_border_pixel, 8); + } + dst_row += dest_stride; + } + // End of region 1. Continue the |start_x| do-while loop. + start_x += 8; + continue; + } + + // Region 2. + // Horizontal filter. + // The input values in this region are generated by extending the border + // which makes them identical in the horizontal direction. This + // computation could be inlined in the vertical pass but most + // implementations will need a transpose of some sort. + // It is not necessary to use the offset values here because the + // horizontal pass is a simple shift and the vertical pass will always + // require using 32 bits. + for (int y = -7; y < 8; ++y) { + // We may over-read up to 13 pixels above the top source row, or up + // to 13 pixels below the bottom source row. This is proved in + // warp.cc. + const int row = iy4 + y; + int sum = first_row_border[row * source_stride]; + sum <<= (kFilterBits - kInterRoundBitsHorizontal); + intermediate_result_column[y + 7] = sum; + } + // Vertical filter. 
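+      // Each output row y, column x is (scalar sketch):
+      //   sum = sum over k in [0, 8) of
+      //         kWarpedFilters[offset][k] * intermediate_result_column[y + k]
+      //   with offset derived from sy = sy4 + (x - 4) * gamma, then
+      //   RightShiftWithRounding(sum, kRoundBitsVertical) and, for the
+      //   non-compound path, clipping to 8 bits.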
+ DestType* dst_row = dst + start_x - block_start_x; + int sy4 = + (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); +#if defined(__aarch64__) + const int16x8_t intermediate = + vld1q_s16(&intermediate_result_column[y]); + int16_t tmp[8]; + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]); + const int32x4_t product_low = + vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate)); + const int32x4_t product_high = + vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate)); + // vaddvq_s32 is only available on __aarch64__. + const int32_t sum = + vaddvq_s32(product_low) + vaddvq_s32(product_high); + const int16_t sum_descale = + RightShiftWithRounding(sum, kRoundBitsVertical); + if (is_compound) { + dst_row[x] = sum_descale; + } else { + tmp[x] = sum_descale; + } + sy += gamma; + } + if (!is_compound) { + const int16x8_t sum = vld1q_s16(tmp); + vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum)); + } +#else // !defined(__aarch64__) + int16x8_t filter[8]; + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + filter[x] = vld1q_s16(kWarpedFilters[offset]); + sy += gamma; + } + Transpose8x8(filter); + int32x4_t sum_low = vdupq_n_s32(0); + int32x4_t sum_high = sum_low; + for (int k = 0; k < 8; ++k) { + const int16_t intermediate = intermediate_result_column[y + k]; + sum_low = + vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate); + sum_high = + vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate); + } + const int16x8_t sum = + vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical), + vrshrn_n_s32(sum_high, kRoundBitsVertical)); + if (is_compound) { + vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum); + } else { + vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum)); + } +#endif // defined(__aarch64__) + dst_row += dest_stride; + sy4 += delta; + } + // End of region 2. Continue the |start_x| do-while loop. + start_x += 8; + continue; + } + + // Regions 3 and 4. + // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0. + + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) { + // Region 3. + // Horizontal filter. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const uint8_t* const src_row = src + row * source_stride; + // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also + // read but is ignored. + // + // NOTE: This may read up to 13 bytes before src_row[0] or up to 14 + // bytes after src_row[source_width - 1]. We assume the source frame + // has left and right borders of at least 13 bytes that extend the + // frame boundary pixels. We also assume there is at least one extra + // padding byte after the right border of the last source row. + const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]); + // Convert src_row_v to int8 (subtract 128). 
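+        // Centering keeps every sample in [-128, 127] so the signed 8-bit
+        // filter taps can be accumulated with vmlal_s8(); HorizontalFilter()
+        // adds back 128 (offset) * 128 (filter tap sum) to compensate.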
+ const int8x16_t src_row_centered = + vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128))); + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + HorizontalFilter(sx4, alpha, src_row_centered, + intermediate_result[y + 7]); + sx4 += beta; + } + } else { + // Region 4. + // Horizontal filter. + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + // We may over-read up to 13 pixels above the top source row, or up + // to 13 pixels below the bottom source row. This is proved in + // warp.cc. + const int row = iy4 + y; + const uint8_t* const src_row = src + row * source_stride; + // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also + // read but is ignored. + // + // NOTE: This may read up to 13 bytes before src_row[0] or up to 14 + // bytes after src_row[source_width - 1]. We assume the source frame + // has left and right borders of at least 13 bytes that extend the + // frame boundary pixels. We also assume there is at least one extra + // padding byte after the right border of the last source row. + const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]); + // Convert src_row_v to int8 (subtract 128). + const int8x16_t src_row_centered = + vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128))); + HorizontalFilter(sx4, alpha, src_row_centered, + intermediate_result[y + 7]); + sx4 += beta; + } + } + + // Regions 3 and 4. + // Vertical filter. + DestType* dst_row = dst + start_x - block_start_x; + int sy4 = + (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); + int16x8_t filter[8]; + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + filter[x] = vld1q_s16(kWarpedFilters[offset]); + sy += gamma; + } + Transpose8x8(filter); + int32x4_t sum_low = vdupq_n_s32(-kOffsetRemoval); + int32x4_t sum_high = sum_low; + for (int k = 0; k < 8; ++k) { + const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]); + sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]), + vget_low_s16(intermediate)); + sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]), + vget_high_s16(intermediate)); + } + const int16x8_t sum = + vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical), + vrshrn_n_s32(sum_high, kRoundBitsVertical)); + if (is_compound) { + vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum); + } else { + vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum)); + } + dst_row += dest_stride; + sy4 += delta; + } + start_x += 8; + } while (start_x < block_start_x + block_width); + dst += 8 * dest_stride; + start_y += 8; + } while (start_y < block_start_y + block_height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->warp = Warp_NEON</*is_compound=*/false>; + dsp->warp_compound = Warp_NEON</*is_compound=*/true>; +} + +} // namespace +} // namespace low_bitdepth + +void WarpInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void WarpInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/warp_neon.h b/src/dsp/arm/warp_neon.h new file mode 100644 index 0000000..dbcaa23 --- /dev/null +++ b/src/dsp/arm/warp_neon.h @@ -0,0 +1,37 @@ +/* + * Copyright 2019 The 
libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::warp. This function is not thread-safe. +void WarpInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WarpCompound LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_ diff --git a/src/dsp/arm/weight_mask_neon.cc b/src/dsp/arm/weight_mask_neon.cc new file mode 100644 index 0000000..49d3be0 --- /dev/null +++ b/src/dsp/arm/weight_mask_neon.cc @@ -0,0 +1,463 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
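+
+// Per pixel the mask is derived from the two compound predictions as
+// (a scalar sketch of WeightMask8_NEON() below):
+//   diff = RightShiftWithRounding(abs(pred_0[i] - pred_1[i]), 4);
+//   mask[i] = min(38 + (diff >> 4), 64);
+// and the inverse variants store 64 - mask[i] instead.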
+ +#include "src/dsp/arm/weight_mask_neon.h" + +#include "src/dsp/weight_mask.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +constexpr int kRoundingBits8bpp = 4; + +template <bool mask_is_inverse> +inline void WeightMask8_NEON(const int16_t* prediction_0, + const int16_t* prediction_1, uint8_t* mask) { + const int16x8_t pred_0 = vld1q_s16(prediction_0); + const int16x8_t pred_1 = vld1q_s16(prediction_1); + const uint8x8_t difference_offset = vdup_n_u8(38); + const uint8x8_t mask_ceiling = vdup_n_u8(64); + const uint16x8_t difference = vrshrq_n_u16( + vreinterpretq_u16_s16(vabdq_s16(pred_0, pred_1)), kRoundingBits8bpp); + const uint8x8_t adjusted_difference = + vqadd_u8(vqshrn_n_u16(difference, 4), difference_offset); + const uint8x8_t mask_value = vmin_u8(adjusted_difference, mask_ceiling); + if (mask_is_inverse) { + const uint8x8_t inverted_mask_value = vsub_u8(mask_ceiling, mask_value); + vst1_u8(mask, inverted_mask_value); + } else { + vst1_u8(mask, mask_value); + } +} + +#define WEIGHT8_WITHOUT_STRIDE \ + WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask) + +#define WEIGHT8_AND_STRIDE \ + WEIGHT8_WITHOUT_STRIDE; \ + pred_0 += 8; \ + pred_1 += 8; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask8x8_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = 0; + do { + WEIGHT8_AND_STRIDE; + } while (++y < 7); + WEIGHT8_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask8x16_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + } while (++y3 < 5); + WEIGHT8_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask8x32_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + } while (++y5 < 6); + WEIGHT8_AND_STRIDE; + WEIGHT8_WITHOUT_STRIDE; +} + +#define WEIGHT16_WITHOUT_STRIDE \ + WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8) + +#define WEIGHT16_AND_STRIDE \ + WEIGHT16_WITHOUT_STRIDE; \ + pred_0 += 16; \ + pred_1 += 16; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask16x8_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = 0; + do { + WEIGHT16_AND_STRIDE; + } while (++y < 7); + WEIGHT16_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask16x16_NEON(const void* 
prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x32_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT16_AND_STRIDE;
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x64_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+#define WEIGHT32_WITHOUT_STRIDE                                           \
+  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask);                \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24)
+
+#define WEIGHT32_AND_STRIDE \
+  WEIGHT32_WITHOUT_STRIDE;  \
+  pred_0 += 32;             \
+  pred_1 += 32;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask32x8_NEON(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x16_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x32_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x64_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+#define WEIGHT64_WITHOUT_STRIDE                                           \
+  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask);                \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56)
+
+#define WEIGHT64_AND_STRIDE \
+  WEIGHT64_WITHOUT_STRIDE;  \
+  pred_0 += 64;             \
+  pred_1 += 64;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask64x16_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x32_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT64_AND_STRIDE;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x64_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x128_NEON(const void* prediction_0, const void* prediction_1,
+                           uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 42);
+  WEIGHT64_AND_STRIDE;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x64_NEON(const void* prediction_0, const void* prediction_1,
+                           uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
+  do {
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+  } while (++y3 < 21);
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x128_NEON(const void* prediction_0, const void* prediction_1,
+                            uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
+  do {
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+  } while (++y3 < 42);
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += adjusted_mask_stride;
+
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+#define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
+  dsp->weight_mask[w_index][h_index][0] =                      \
+      WeightMask##width##x##height##_NEON<0>;                  \
+  dsp->weight_mask[w_index][h_index][1] = WeightMask##width##x##height##_NEON<1>
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
+  INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
+  INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
+  INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
+  INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
+  INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
+  INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
+  INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
+  INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
+  INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
+  INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
+  INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
+  INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
+  INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
+  INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
+  INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
+  INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void WeightMaskInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void WeightMaskInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/weight_mask_neon.h b/src/dsp/arm/weight_mask_neon.h
new file mode 100644
index 0000000..b4749ec
--- /dev/null
+++ b/src/dsp/arm/weight_mask_neon.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::weight_mask. This function is not thread-safe.
+void WeightMaskInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x128 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_128x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
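
For context, the INIT_WEIGHT_MASK_8BPP entries in the patch above populate dsp->weight_mask[w_index][h_index][is_inverse], where the indices map block dimensions 8/16/32/64/128 to 0..4 and the predictions are contiguous int16_t compound-prediction buffers. A minimal caller-side sketch follows; it is not part of this patch or of the libgav1 API, and the helper names (DimensionToIndex, BuildWeightMask16x32) and the 16x32 block size are illustrative assumptions.

// Sketch only: invokes whichever weight-mask kernel (NEON or fallback) was
// registered in the Dsp table for an 8bpp 16x32 block.
#include <cstddef>
#include <cstdint>

#include "src/dsp/dsp.h"  // declares libgav1::dsp::Dsp and its weight_mask table

namespace {

// Maps a block dimension to the table index used by INIT_WEIGHT_MASK_8BPP:
// 8 -> 0, 16 -> 1, 32 -> 2, 64 -> 3, 128 -> 4.
int DimensionToIndex(int dimension) {
  int index = 0;
  while ((8 << index) < dimension) ++index;
  return index;
}

}  // namespace

// Hypothetical helper: fills |mask| for one 16x32 block from two compound
// predictions stored contiguously as int16_t, row after row.
void BuildWeightMask16x32(const libgav1::dsp::Dsp& dsp,
                          const int16_t* prediction_0,
                          const int16_t* prediction_1, bool mask_is_inverse,
                          uint8_t* mask, ptrdiff_t mask_stride) {
  const int w_index = DimensionToIndex(16);  // == 1
  const int h_index = DimensionToIndex(32);  // == 2
  dsp.weight_mask[w_index][h_index][mask_is_inverse ? 1 : 0](
      prediction_0, prediction_1, mask, mask_stride);
}

The lookup assumes the Dsp table has already been initialized (so that WeightMaskInit_NEON, declared in the header above, has had a chance to install the NEON kernels when the CPU supports them).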