diff options
Diffstat (limited to 'src/dsp')
127 files changed, 61385 insertions, 0 deletions
diff --git a/src/dsp/arm/average_blend_neon.cc b/src/dsp/arm/average_blend_neon.cc new file mode 100644 index 0000000..834e8b4 --- /dev/null +++ b/src/dsp/arm/average_blend_neon.cc @@ -0,0 +1,146 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/average_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kInterPostRoundBit = + kInterRoundBitsVertical - kInterRoundBitsCompoundVertical; + +inline uint8x8_t AverageBlend8Row(const int16_t* prediction_0, + const int16_t* prediction_1) { + const int16x8_t pred0 = vld1q_s16(prediction_0); + const int16x8_t pred1 = vld1q_s16(prediction_1); + const int16x8_t res = vaddq_s16(pred0, pred1); + return vqrshrun_n_s16(res, kInterPostRoundBit + 1); +} + +inline void AverageBlendLargeRow(const int16_t* prediction_0, + const int16_t* prediction_1, const int width, + uint8_t* dest) { + int x = width; + do { + const int16x8_t pred_00 = vld1q_s16(prediction_0); + const int16x8_t pred_01 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res0 = vaddq_s16(pred_00, pred_01); + const uint8x8_t res_out0 = vqrshrun_n_s16(res0, kInterPostRoundBit + 1); + const int16x8_t pred_10 = vld1q_s16(prediction_0); + const int16x8_t pred_11 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res1 = vaddq_s16(pred_10, pred_11); + const uint8x8_t res_out1 = vqrshrun_n_s16(res1, kInterPostRoundBit + 1); + vst1q_u8(dest, vcombine_u8(res_out0, res_out1)); + dest += 16; + x -= 16; + } while (x != 0); +} + +void AverageBlend_NEON(const void* prediction_0, const void* prediction_1, + const int width, const int height, void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = height; + + if (width == 4) { + do { + const uint8x8_t result = AverageBlend8Row(pred_0, pred_1); + pred_0 += 8; + pred_1 += 8; + + StoreLo4(dst, result); + dst += dest_stride; + StoreHi4(dst, result); + dst += dest_stride; + y -= 2; + } while (y != 0); + return; + } + + if (width == 8) { + do { + vst1_u8(dst, AverageBlend8Row(pred_0, pred_1)); + dst += dest_stride; + pred_0 += 8; + pred_1 += 8; + + vst1_u8(dst, AverageBlend8Row(pred_0, pred_1)); + dst += dest_stride; + pred_0 += 8; + pred_1 += 8; + + y -= 2; + } while (y != 0); + return; + } + + do { + AverageBlendLargeRow(pred_0, pred_1, width, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + AverageBlendLargeRow(pred_0, pred_1, width, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + y -= 2; + } while (y != 0); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->average_blend = AverageBlend_NEON; +} + +} // namespace + +void AverageBlendInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void AverageBlendInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/average_blend_neon.h b/src/dsp/arm/average_blend_neon.h new file mode 100644 index 0000000..d13bcd6 --- /dev/null +++ b/src/dsp/arm/average_blend_neon.h @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::average_blend. This function is not thread-safe. +void AverageBlendInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_ diff --git a/src/dsp/arm/cdef_neon.cc b/src/dsp/arm/cdef_neon.cc new file mode 100644 index 0000000..4d0e76f --- /dev/null +++ b/src/dsp/arm/cdef_neon.cc @@ -0,0 +1,697 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/cdef.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +#include "src/dsp/cdef.inc" + +// ---------------------------------------------------------------------------- +// Refer to CdefDirection_C(). +// +// int32_t partial[8][15] = {}; +// for (int i = 0; i < 8; ++i) { +// for (int j = 0; j < 8; ++j) { +// const int x = 1; +// partial[0][i + j] += x; +// partial[1][i + j / 2] += x; +// partial[2][i] += x; +// partial[3][3 + i - j / 2] += x; +// partial[4][7 + i - j] += x; +// partial[5][3 - i / 2 + j] += x; +// partial[6][j] += x; +// partial[7][i / 2 + j] += x; +// } +// } +// +// Using the code above, generate the position count for partial[8][15]. +// +// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// +// The SIMD code shifts the input horizontally, then adds vertically to get the +// correct partial value for the given position. +// ---------------------------------------------------------------------------- + +// ---------------------------------------------------------------------------- +// partial[0][i + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 00 10 11 12 13 14 15 16 17 00 00 00 00 00 00 +// 00 00 20 21 22 23 24 25 26 27 00 00 00 00 00 +// 00 00 00 30 31 32 33 34 35 36 37 00 00 00 00 +// 00 00 00 00 40 41 42 43 44 45 46 47 00 00 00 +// 00 00 00 00 00 50 51 52 53 54 55 56 57 00 00 +// 00 00 00 00 00 00 60 61 62 63 64 65 66 67 00 +// 00 00 00 00 00 00 00 70 71 72 73 74 75 76 77 +// +// partial[4] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(uint8x8_t* v_src, + uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + const uint8x8_t v_zero = vdup_n_u8(0); + // 00 01 02 03 04 05 06 07 + // 00 10 11 12 13 14 15 16 + *partial_lo = vaddl_u8(v_src[0], vext_u8(v_zero, v_src[1], 7)); + + // 00 00 20 21 22 23 24 25 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[2], 6)); + // 17 00 00 00 00 00 00 00 + // 26 27 00 00 00 00 00 00 + *partial_hi = + vaddl_u8(vext_u8(v_src[1], v_zero, 7), vext_u8(v_src[2], v_zero, 6)); + + // 00 00 00 30 31 32 33 34 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[3], 5)); + // 35 36 37 00 00 00 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[3], v_zero, 5)); + + // 00 00 00 00 40 41 42 43 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[4], 4)); + // 44 45 46 47 00 00 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[4], v_zero, 4)); + + // 00 00 00 00 00 50 51 52 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[5], 3)); + // 53 54 55 56 57 00 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[5], v_zero, 3)); + + // 00 00 00 00 00 00 60 61 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[6], 2)); + // 62 63 64 65 66 67 00 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[6], v_zero, 2)); + + // 00 00 00 00 00 00 00 70 + *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[7], 1)); + // 71 72 73 74 75 76 77 00 + *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[7], v_zero, 1)); +} + +// ---------------------------------------------------------------------------- +// partial[1][i + j / 2] += x; +// +// A0 = src[0] + src[1], A1 = src[2] + src[3], ... +// +// A0 A1 A2 A3 00 00 00 00 00 00 00 00 00 00 00 +// 00 B0 B1 B2 B3 00 00 00 00 00 00 00 00 00 00 +// 00 00 C0 C1 C2 C3 00 00 00 00 00 00 00 00 00 +// 00 00 00 D0 D1 D2 D3 00 00 00 00 00 00 00 00 +// 00 00 00 00 E0 E1 E2 E3 00 00 00 00 00 00 00 +// 00 00 00 00 00 F0 F1 F2 F3 00 00 00 00 00 00 +// 00 00 00 00 00 00 G0 G1 G2 G3 00 00 00 00 00 +// 00 00 00 00 00 00 00 H0 H1 H2 H3 00 00 00 00 +// +// partial[3] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(uint8x8_t* v_src, + uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + uint8x16_t v_d1_temp[8]; + const uint8x8_t v_zero = vdup_n_u8(0); + const uint8x16_t v_zero_16 = vdupq_n_u8(0); + + for (int i = 0; i < 8; ++i) { + v_d1_temp[i] = vcombine_u8(v_src[i], v_zero); + } + + *partial_lo = *partial_hi = vdupq_n_u16(0); + // A0 A1 A2 A3 00 00 00 00 + *partial_lo = vpadalq_u8(*partial_lo, v_d1_temp[0]); + + // 00 B0 B1 B2 B3 00 00 00 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[1], 14)); + + // 00 00 C0 C1 C2 C3 00 00 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[2], 12)); + // 00 00 00 D0 D1 D2 D3 00 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[3], 10)); + // 00 00 00 00 E0 E1 E2 E3 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[4], 8)); + + // 00 00 00 00 00 F0 F1 F2 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[5], 6)); + // F3 00 00 00 00 00 00 00 + *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[5], v_zero_16, 6)); + + // 00 00 00 00 00 00 G0 G1 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[6], 4)); + // G2 G3 00 00 00 00 00 00 + *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[6], v_zero_16, 4)); + + // 00 00 00 00 00 00 00 H0 + *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[7], 2)); + // H1 H2 H3 00 00 00 00 00 + *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[7], v_zero_16, 2)); +} + +// ---------------------------------------------------------------------------- +// partial[7][i / 2 + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 +// 00 20 21 22 23 24 25 26 27 00 00 00 00 00 00 +// 00 30 31 32 33 34 35 36 37 00 00 00 00 00 00 +// 00 00 40 41 42 43 44 45 46 47 00 00 00 00 00 +// 00 00 50 51 52 53 54 55 56 57 00 00 00 00 00 +// 00 00 00 60 61 62 63 64 65 66 67 00 00 00 00 +// 00 00 00 70 71 72 73 74 75 76 77 00 00 00 00 +// +// partial[5] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(uint8x8_t* v_src, + uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + const uint16x8_t v_zero = vdupq_n_u16(0); + uint16x8_t v_pair_add[4]; + // Add vertical source pairs. + v_pair_add[0] = vaddl_u8(v_src[0], v_src[1]); + v_pair_add[1] = vaddl_u8(v_src[2], v_src[3]); + v_pair_add[2] = vaddl_u8(v_src[4], v_src[5]); + v_pair_add[3] = vaddl_u8(v_src[6], v_src[7]); + + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + *partial_lo = v_pair_add[0]; + // 00 00 00 00 00 00 00 00 + // 00 00 00 00 00 00 00 00 + *partial_hi = vdupq_n_u16(0); + + // 00 20 21 22 23 24 25 26 + // 00 30 31 32 33 34 35 36 + *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[1], 7)); + // 27 00 00 00 00 00 00 00 + // 37 00 00 00 00 00 00 00 + *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[1], v_zero, 7)); + + // 00 00 40 41 42 43 44 45 + // 00 00 50 51 52 53 54 55 + *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[2], 6)); + // 46 47 00 00 00 00 00 00 + // 56 57 00 00 00 00 00 00 + *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[2], v_zero, 6)); + + // 00 00 00 60 61 62 63 64 + // 00 00 00 70 71 72 73 74 + *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[3], 5)); + // 65 66 67 00 00 00 00 00 + // 75 76 77 00 00 00 00 00 + *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[3], v_zero, 5)); +} + +LIBGAV1_ALWAYS_INLINE void AddPartial(const void* const source, + ptrdiff_t stride, uint16x8_t* partial_lo, + uint16x8_t* partial_hi) { + const auto* src = static_cast<const uint8_t*>(source); + + // 8x8 input + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + // 20 21 22 23 24 25 26 27 + // 30 31 32 33 34 35 36 37 + // 40 41 42 43 44 45 46 47 + // 50 51 52 53 54 55 56 57 + // 60 61 62 63 64 65 66 67 + // 70 71 72 73 74 75 76 77 + uint8x8_t v_src[8]; + for (int i = 0; i < 8; ++i) { + v_src[i] = vld1_u8(src); + src += stride; + } + + // partial for direction 2 + // -------------------------------------------------------------------------- + // partial[2][i] += x; + // 00 10 20 30 40 50 60 70 00 00 00 00 00 00 00 00 + // 01 11 21 33 41 51 61 71 00 00 00 00 00 00 00 00 + // 02 12 22 33 42 52 62 72 00 00 00 00 00 00 00 00 + // 03 13 23 33 43 53 63 73 00 00 00 00 00 00 00 00 + // 04 14 24 34 44 54 64 74 00 00 00 00 00 00 00 00 + // 05 15 25 35 45 55 65 75 00 00 00 00 00 00 00 00 + // 06 16 26 36 46 56 66 76 00 00 00 00 00 00 00 00 + // 07 17 27 37 47 57 67 77 00 00 00 00 00 00 00 00 + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[0]), partial_lo[2], 0); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[1]), partial_lo[2], 1); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[2]), partial_lo[2], 2); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[3]), partial_lo[2], 3); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[4]), partial_lo[2], 4); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[5]), partial_lo[2], 5); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[6]), partial_lo[2], 6); + partial_lo[2] = vsetq_lane_u16(SumVector(v_src[7]), partial_lo[2], 7); + + // partial for direction 6 + // -------------------------------------------------------------------------- + // partial[6][j] += x; + // 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 00 + // 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 00 + // 20 21 22 23 24 25 26 27 00 00 00 00 00 00 00 00 + // 30 31 32 33 34 35 36 37 00 00 00 00 00 00 00 00 + // 40 41 42 43 44 45 46 47 00 00 00 00 00 00 00 00 + // 50 51 52 53 54 55 56 57 00 00 00 00 00 00 00 00 + // 60 61 62 63 64 65 66 67 00 00 00 00 00 00 00 00 + // 70 71 72 73 74 75 76 77 00 00 00 00 00 00 00 00 + const uint8x8_t v_zero = vdup_n_u8(0); + partial_lo[6] = vaddl_u8(v_zero, v_src[0]); + for (int i = 1; i < 8; ++i) { + partial_lo[6] = vaddw_u8(partial_lo[6], v_src[i]); + } + + // partial for direction 0 + AddPartial_D0_D4(v_src, &partial_lo[0], &partial_hi[0]); + + // partial for direction 1 + AddPartial_D1_D3(v_src, &partial_lo[1], &partial_hi[1]); + + // partial for direction 7 + AddPartial_D5_D7(v_src, &partial_lo[7], &partial_hi[7]); + + uint8x8_t v_src_reverse[8]; + for (int i = 0; i < 8; ++i) { + v_src_reverse[i] = vrev64_u8(v_src[i]); + } + + // partial for direction 4 + AddPartial_D0_D4(v_src_reverse, &partial_lo[4], &partial_hi[4]); + + // partial for direction 3 + AddPartial_D1_D3(v_src_reverse, &partial_lo[3], &partial_hi[3]); + + // partial for direction 5 + AddPartial_D5_D7(v_src_reverse, &partial_lo[5], &partial_hi[5]); +} + +uint32x4_t Square(uint16x4_t a) { return vmull_u16(a, a); } + +uint32x4_t SquareAccumulate(uint32x4_t a, uint16x4_t b) { + return vmlal_u16(a, b, b); +} + +// |cost[0]| and |cost[4]| square the input and sum with the corresponding +// element from the other end of the vector: +// |kCdefDivisionTable[]| element: +// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) * +// kCdefDivisionTable[i + 1]; +// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8]; +// Because everything is being summed into a single value the distributive +// property allows us to mirror the division table and accumulate once. +uint32_t Cost0Or4(const uint16x8_t a, const uint16x8_t b, + const uint32x4_t division_table[4]) { + uint32x4_t c = vmulq_u32(Square(vget_low_u16(a)), division_table[0]); + c = vmlaq_u32(c, Square(vget_high_u16(a)), division_table[1]); + c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[2]); + c = vmlaq_u32(c, Square(vget_high_u16(b)), division_table[3]); + return SumVector(c); +} + +// |cost[2]| and |cost[6]| square the input and accumulate: +// cost[2] += Square(partial[2][i]) +uint32_t SquareAccumulate(const uint16x8_t a) { + uint32x4_t c = Square(vget_low_u16(a)); + c = SquareAccumulate(c, vget_high_u16(a)); + c = vmulq_n_u32(c, kCdefDivisionTable[7]); + return SumVector(c); +} + +uint32_t CostOdd(const uint16x8_t a, const uint16x8_t b, const uint32x4_t mask, + const uint32x4_t division_table[2]) { + // Remove elements 0-2. + uint32x4_t c = vandq_u32(mask, Square(vget_low_u16(a))); + c = vaddq_u32(c, Square(vget_high_u16(a))); + c = vmulq_n_u32(c, kCdefDivisionTable[7]); + + c = vmlaq_u32(c, Square(vget_low_u16(a)), division_table[0]); + c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[1]); + return SumVector(c); +} + +void CdefDirection_NEON(const void* const source, ptrdiff_t stride, + uint8_t* const direction, int* const variance) { + assert(direction != nullptr); + assert(variance != nullptr); + const auto* src = static_cast<const uint8_t*>(source); + uint32_t cost[8]; + uint16x8_t partial_lo[8], partial_hi[8]; + + AddPartial(src, stride, partial_lo, partial_hi); + + cost[2] = SquareAccumulate(partial_lo[2]); + cost[6] = SquareAccumulate(partial_lo[6]); + + const uint32x4_t division_table[4] = { + vld1q_u32(kCdefDivisionTable), vld1q_u32(kCdefDivisionTable + 4), + vld1q_u32(kCdefDivisionTable + 8), vld1q_u32(kCdefDivisionTable + 12)}; + + cost[0] = Cost0Or4(partial_lo[0], partial_hi[0], division_table); + cost[4] = Cost0Or4(partial_lo[4], partial_hi[4], division_table); + + const uint32x4_t division_table_odd[2] = { + vld1q_u32(kCdefDivisionTableOdd), vld1q_u32(kCdefDivisionTableOdd + 4)}; + + const uint32x4_t element_3_mask = {0, 0, 0, static_cast<uint32_t>(-1)}; + + cost[1] = + CostOdd(partial_lo[1], partial_hi[1], element_3_mask, division_table_odd); + cost[3] = + CostOdd(partial_lo[3], partial_hi[3], element_3_mask, division_table_odd); + cost[5] = + CostOdd(partial_lo[5], partial_hi[5], element_3_mask, division_table_odd); + cost[7] = + CostOdd(partial_lo[7], partial_hi[7], element_3_mask, division_table_odd); + + uint32_t best_cost = 0; + *direction = 0; + for (int i = 0; i < 8; ++i) { + if (cost[i] > best_cost) { + best_cost = cost[i]; + *direction = i; + } + } + *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10; +} + +// ------------------------------------------------------------------------- +// CdefFilter + +// Load 4 vectors based on the given |direction|. +void LoadDirection(const uint16_t* const src, const ptrdiff_t stride, + uint16x8_t* output, const int direction) { + // Each |direction| describes a different set of source values. Expand this + // set by negating each set. For |direction| == 0 this gives a diagonal line + // from top right to bottom left. The first value is y, the second x. Negative + // y values move up. + // a b c d + // {-1, 1}, {1, -1}, {-2, 2}, {2, -2} + // c + // a + // 0 + // b + // d + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = vld1q_u16(src + y_0 * stride + x_0); + output[1] = vld1q_u16(src - y_0 * stride - x_0); + output[2] = vld1q_u16(src + y_1 * stride + x_1); + output[3] = vld1q_u16(src - y_1 * stride - x_1); +} + +// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to +// do 2 rows at a time. +void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride, + uint16x8_t* output, const int direction) { + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = vcombine_u16(vld1_u16(src + y_0 * stride + x_0), + vld1_u16(src + y_0 * stride + stride + x_0)); + output[1] = vcombine_u16(vld1_u16(src - y_0 * stride - x_0), + vld1_u16(src - y_0 * stride + stride - x_0)); + output[2] = vcombine_u16(vld1_u16(src + y_1 * stride + x_1), + vld1_u16(src + y_1 * stride + stride + x_1)); + output[3] = vcombine_u16(vld1_u16(src - y_1 * stride - x_1), + vld1_u16(src - y_1 * stride + stride - x_1)); +} + +int16x8_t Constrain(const uint16x8_t pixel, const uint16x8_t reference, + const uint16x8_t threshold, const int16x8_t damping) { + // If reference > pixel, the difference will be negative, so covert to 0 or + // -1. + const uint16x8_t sign = vcgtq_u16(reference, pixel); + const uint16x8_t abs_diff = vabdq_u16(pixel, reference); + const uint16x8_t shifted_diff = vshlq_u16(abs_diff, damping); + // For bitdepth == 8, the threshold range is [0, 15] and the damping range is + // [3, 6]. If pixel == kCdefLargeValue(0x4000), shifted_diff will always be + // larger than threshold. Subtract using saturation will return 0 when pixel + // == kCdefLargeValue. + static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue"); + const uint16x8_t thresh_minus_shifted_diff = + vqsubq_u16(threshold, shifted_diff); + const uint16x8_t clamp_abs_diff = + vminq_u16(thresh_minus_shifted_diff, abs_diff); + // Restore the sign. + return vreinterpretq_s16_u16( + vsubq_u16(veorq_u16(clamp_abs_diff, sign), sign)); +} + +template <int width, bool enable_primary = true, bool enable_secondary = true> +void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride, + const int height, const int primary_strength, + const int secondary_strength, const int damping, + const int direction, void* dest, + const ptrdiff_t dst_stride) { + static_assert(width == 8 || width == 4, ""); + static_assert(enable_primary || enable_secondary, ""); + constexpr bool clipping_required = enable_primary && enable_secondary; + auto* dst = static_cast<uint8_t*>(dest); + const uint16x8_t cdef_large_value_mask = + vdupq_n_u16(static_cast<uint16_t>(~kCdefLargeValue)); + const uint16x8_t primary_threshold = vdupq_n_u16(primary_strength); + const uint16x8_t secondary_threshold = vdupq_n_u16(secondary_strength); + + int16x8_t primary_damping_shift, secondary_damping_shift; + + // FloorLog2() requires input to be > 0. + // 8-bit damping range: Y: [3, 6], UV: [2, 5]. + if (enable_primary) { + // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary + // for UV filtering. + primary_damping_shift = + vdupq_n_s16(-std::max(0, damping - FloorLog2(primary_strength))); + } + if (enable_secondary) { + // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is + // necessary. + assert(damping - FloorLog2(secondary_strength) >= 0); + secondary_damping_shift = + vdupq_n_s16(-(damping - FloorLog2(secondary_strength))); + } + + const int primary_tap_0 = kCdefPrimaryTaps[primary_strength & 1][0]; + const int primary_tap_1 = kCdefPrimaryTaps[primary_strength & 1][1]; + + int y = height; + do { + uint16x8_t pixel; + if (width == 8) { + pixel = vld1q_u16(src); + } else { + pixel = vcombine_u16(vld1_u16(src), vld1_u16(src + src_stride)); + } + + uint16x8_t min = pixel; + uint16x8_t max = pixel; + int16x8_t sum; + + if (enable_primary) { + // Primary |direction|. + uint16x8_t primary_val[4]; + if (width == 8) { + LoadDirection(src, src_stride, primary_val, direction); + } else { + LoadDirection4(src, src_stride, primary_val, direction); + } + + if (clipping_required) { + min = vminq_u16(min, primary_val[0]); + min = vminq_u16(min, primary_val[1]); + min = vminq_u16(min, primary_val[2]); + min = vminq_u16(min, primary_val[3]); + + // The source is 16 bits, however, we only really care about the lower + // 8 bits. The upper 8 bits contain the "large" flag. After the final + // primary max has been calculated, zero out the upper 8 bits. Use this + // to find the "16 bit" max. + const uint8x16_t max_p01 = + vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]), + vreinterpretq_u8_u16(primary_val[1])); + const uint8x16_t max_p23 = + vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]), + vreinterpretq_u8_u16(primary_val[3])); + const uint16x8_t max_p = + vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23)); + max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask)); + } + + sum = Constrain(primary_val[0], pixel, primary_threshold, + primary_damping_shift); + sum = vmulq_n_s16(sum, primary_tap_0); + sum = vmlaq_n_s16(sum, + Constrain(primary_val[1], pixel, primary_threshold, + primary_damping_shift), + primary_tap_0); + sum = vmlaq_n_s16(sum, + Constrain(primary_val[2], pixel, primary_threshold, + primary_damping_shift), + primary_tap_1); + sum = vmlaq_n_s16(sum, + Constrain(primary_val[3], pixel, primary_threshold, + primary_damping_shift), + primary_tap_1); + } else { + sum = vdupq_n_s16(0); + } + + if (enable_secondary) { + // Secondary |direction| values (+/- 2). Clamp |direction|. + uint16x8_t secondary_val[8]; + if (width == 8) { + LoadDirection(src, src_stride, secondary_val, direction + 2); + LoadDirection(src, src_stride, secondary_val + 4, direction - 2); + } else { + LoadDirection4(src, src_stride, secondary_val, direction + 2); + LoadDirection4(src, src_stride, secondary_val + 4, direction - 2); + } + + if (clipping_required) { + min = vminq_u16(min, secondary_val[0]); + min = vminq_u16(min, secondary_val[1]); + min = vminq_u16(min, secondary_val[2]); + min = vminq_u16(min, secondary_val[3]); + min = vminq_u16(min, secondary_val[4]); + min = vminq_u16(min, secondary_val[5]); + min = vminq_u16(min, secondary_val[6]); + min = vminq_u16(min, secondary_val[7]); + + const uint8x16_t max_s01 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]), + vreinterpretq_u8_u16(secondary_val[1])); + const uint8x16_t max_s23 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]), + vreinterpretq_u8_u16(secondary_val[3])); + const uint8x16_t max_s45 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]), + vreinterpretq_u8_u16(secondary_val[5])); + const uint8x16_t max_s67 = + vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]), + vreinterpretq_u8_u16(secondary_val[7])); + const uint16x8_t max_s = vreinterpretq_u16_u8( + vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67))); + max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask)); + } + + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[0], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[1], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[2], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[3], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[4], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[5], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap0); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[6], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + sum = vmlaq_n_s16(sum, + Constrain(secondary_val[7], pixel, secondary_threshold, + secondary_damping_shift), + kCdefSecondaryTap1); + } + // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max)) + const int16x8_t sum_lt_0 = vshrq_n_s16(sum, 15); + sum = vaddq_s16(sum, sum_lt_0); + int16x8_t result = vrsraq_n_s16(vreinterpretq_s16_u16(pixel), sum, 4); + if (clipping_required) { + result = vminq_s16(result, vreinterpretq_s16_u16(max)); + result = vmaxq_s16(result, vreinterpretq_s16_u16(min)); + } + + const uint8x8_t dst_pixel = vqmovun_s16(result); + if (width == 8) { + src += src_stride; + vst1_u8(dst, dst_pixel); + dst += dst_stride; + --y; + } else { + src += src_stride << 1; + StoreLo4(dst, dst_pixel); + dst += dst_stride; + StoreHi4(dst, dst_pixel); + dst += dst_stride; + y -= 2; + } + } while (y != 0); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->cdef_direction = CdefDirection_NEON; + dsp->cdef_filters[0][0] = CdefFilter_NEON<4>; + dsp->cdef_filters[0][1] = + CdefFilter_NEON<4, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = CdefFilter_NEON<4, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_NEON<8>; + dsp->cdef_filters[1][1] = + CdefFilter_NEON<8, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = CdefFilter_NEON<8, /*enable_primary=*/false>; +} + +} // namespace +} // namespace low_bitdepth + +void CdefInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void CdefInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/cdef_neon.h b/src/dsp/arm/cdef_neon.h new file mode 100644 index 0000000..53d5f86 --- /dev/null +++ b/src/dsp/arm/cdef_neon.h @@ -0,0 +1,38 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not +// thread-safe. +void CdefInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_ diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h new file mode 100644 index 0000000..dcb7567 --- /dev/null +++ b/src/dsp/arm/common_neon.h @@ -0,0 +1,777 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_ + +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cstdint> +#include <cstring> + +#if 0 +#include <cstdio> + +#include "absl/strings/str_cat.h" + +constexpr bool kEnablePrintRegs = true; + +union DebugRegister { + int8_t i8[8]; + int16_t i16[4]; + int32_t i32[2]; + uint8_t u8[8]; + uint16_t u16[4]; + uint32_t u32[2]; +}; + +union DebugRegisterQ { + int8_t i8[16]; + int16_t i16[8]; + int32_t i32[4]; + uint8_t u8[16]; + uint16_t u16[8]; + uint32_t u32[4]; +}; + +// Quite useful macro for debugging. Left here for convenience. +inline void PrintVect(const DebugRegister r, const char* const name, int size) { + int n; + if (kEnablePrintRegs) { + fprintf(stderr, "%s\t: ", name); + if (size == 8) { + for (n = 0; n < 8; ++n) fprintf(stderr, "%.2x ", r.u8[n]); + } else if (size == 16) { + for (n = 0; n < 4; ++n) fprintf(stderr, "%.4x ", r.u16[n]); + } else if (size == 32) { + for (n = 0; n < 2; ++n) fprintf(stderr, "%.8x ", r.u32[n]); + } + fprintf(stderr, "\n"); + } +} + +// Debugging macro for 128-bit types. +inline void PrintVectQ(const DebugRegisterQ r, const char* const name, + int size) { + int n; + if (kEnablePrintRegs) { + fprintf(stderr, "%s\t: ", name); + if (size == 8) { + for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", r.u8[n]); + } else if (size == 16) { + for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", r.u16[n]); + } else if (size == 32) { + for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", r.u32[n]); + } + fprintf(stderr, "\n"); + } +} + +inline void PrintReg(const int32x4x2_t val, const std::string& name) { + DebugRegisterQ r; + vst1q_u32(r.u32, val.val[0]); + const std::string name0 = absl::StrCat(name, ".val[0]").c_str(); + PrintVectQ(r, name0.c_str(), 32); + vst1q_u32(r.u32, val.val[1]); + const std::string name1 = absl::StrCat(name, ".val[1]").c_str(); + PrintVectQ(r, name1.c_str(), 32); +} + +inline void PrintReg(const uint32x4_t val, const char* name) { + DebugRegisterQ r; + vst1q_u32(r.u32, val); + PrintVectQ(r, name, 32); +} + +inline void PrintReg(const uint32x2_t val, const char* name) { + DebugRegister r; + vst1_u32(r.u32, val); + PrintVect(r, name, 32); +} + +inline void PrintReg(const uint16x8_t val, const char* name) { + DebugRegisterQ r; + vst1q_u16(r.u16, val); + PrintVectQ(r, name, 16); +} + +inline void PrintReg(const uint16x4_t val, const char* name) { + DebugRegister r; + vst1_u16(r.u16, val); + PrintVect(r, name, 16); +} + +inline void PrintReg(const uint8x16_t val, const char* name) { + DebugRegisterQ r; + vst1q_u8(r.u8, val); + PrintVectQ(r, name, 8); +} + +inline void PrintReg(const uint8x8_t val, const char* name) { + DebugRegister r; + vst1_u8(r.u8, val); + PrintVect(r, name, 8); +} + +inline void PrintReg(const int32x4_t val, const char* name) { + DebugRegisterQ r; + vst1q_s32(r.i32, val); + PrintVectQ(r, name, 32); +} + +inline void PrintReg(const int32x2_t val, const char* name) { + DebugRegister r; + vst1_s32(r.i32, val); + PrintVect(r, name, 32); +} + +inline void PrintReg(const int16x8_t val, const char* name) { + DebugRegisterQ r; + vst1q_s16(r.i16, val); + PrintVectQ(r, name, 16); +} + +inline void PrintReg(const int16x4_t val, const char* name) { + DebugRegister r; + vst1_s16(r.i16, val); + PrintVect(r, name, 16); +} + +inline void PrintReg(const int8x16_t val, const char* name) { + DebugRegisterQ r; + vst1q_s8(r.i8, val); + PrintVectQ(r, name, 8); +} + +inline void PrintReg(const int8x8_t val, const char* name) { + DebugRegister r; + vst1_s8(r.i8, val); + PrintVect(r, name, 8); +} + +// Print an individual (non-vector) value in decimal format. +inline void PrintReg(const int x, const char* name) { + if (kEnablePrintRegs) { + printf("%s: %d\n", name, x); + } +} + +// Print an individual (non-vector) value in hexadecimal format. +inline void PrintHex(const int x, const char* name) { + if (kEnablePrintRegs) { + printf("%s: %x\n", name, x); + } +} + +#define PR(x) PrintReg(x, #x) +#define PD(x) PrintReg(x, #x) +#define PX(x) PrintHex(x, #x) + +#endif // 0 + +namespace libgav1 { +namespace dsp { + +//------------------------------------------------------------------------------ +// Load functions. + +// Load 2 uint8_t values into lanes 0 and 1. Zeros the register before loading +// the values. Use caution when using this in loops because it will re-zero the +// register before loading on every iteration. +inline uint8x8_t Load2(const void* const buf) { + const uint16x4_t zero = vdup_n_u16(0); + uint16_t temp; + memcpy(&temp, buf, 2); + return vreinterpret_u8_u16(vld1_lane_u16(&temp, zero, 0)); +} + +// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1. +template <int lane> +inline uint8x8_t Load2(const void* const buf, uint8x8_t val) { + uint16_t temp; + memcpy(&temp, buf, 2); + return vreinterpret_u8_u16( + vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane)); +} + +// Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the +// register before loading the values. Use caution when using this in loops +// because it will re-zero the register before loading on every iteration. +inline uint8x8_t Load4(const void* const buf) { + const uint32x2_t zero = vdup_n_u32(0); + uint32_t temp; + memcpy(&temp, buf, 4); + return vreinterpret_u8_u32(vld1_lane_u32(&temp, zero, 0)); +} + +// Load 4 uint8_t values into 4 lanes staring with |lane| * 4. +template <int lane> +inline uint8x8_t Load4(const void* const buf, uint8x8_t val) { + uint32_t temp; + memcpy(&temp, buf, 4); + return vreinterpret_u8_u32( + vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane)); +} + +//------------------------------------------------------------------------------ +// Store functions. + +// Propagate type information to the compiler. Without this the compiler may +// assume the required alignment of the type (4 bytes in the case of uint32_t) +// and add alignment hints to the memory access. +template <typename T> +inline void ValueToMem(void* const buf, T val) { + memcpy(buf, &val, sizeof(val)); +} + +// Store 4 int8_t values from the low half of an int8x8_t register. +inline void StoreLo4(void* const buf, const int8x8_t val) { + ValueToMem<int32_t>(buf, vget_lane_s32(vreinterpret_s32_s8(val), 0)); +} + +// Store 4 uint8_t values from the low half of a uint8x8_t register. +inline void StoreLo4(void* const buf, const uint8x8_t val) { + ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 0)); +} + +// Store 4 uint8_t values from the high half of a uint8x8_t register. +inline void StoreHi4(void* const buf, const uint8x8_t val) { + ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 1)); +} + +// Store 2 uint8_t values from |lane| * 2 and |lane| * 2 + 1 of a uint8x8_t +// register. +template <int lane> +inline void Store2(void* const buf, const uint8x8_t val) { + ValueToMem<uint16_t>(buf, vget_lane_u16(vreinterpret_u16_u8(val), lane)); +} + +// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x8_t +// register. +template <int lane> +inline void Store2(void* const buf, const uint16x8_t val) { + ValueToMem<uint32_t>(buf, vgetq_lane_u32(vreinterpretq_u32_u16(val), lane)); +} + +// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t +// register. +template <int lane> +inline void Store2(uint16_t* const buf, const uint16x4_t val) { + ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane)); +} + +//------------------------------------------------------------------------------ +// Bit manipulation. + +// vshXX_n_XX() requires an immediate. +template <int shift> +inline uint8x8_t LeftShift(const uint8x8_t vector) { + return vreinterpret_u8_u64(vshl_n_u64(vreinterpret_u64_u8(vector), shift)); +} + +template <int shift> +inline uint8x8_t RightShift(const uint8x8_t vector) { + return vreinterpret_u8_u64(vshr_n_u64(vreinterpret_u64_u8(vector), shift)); +} + +template <int shift> +inline int8x8_t RightShift(const int8x8_t vector) { + return vreinterpret_s8_u64(vshr_n_u64(vreinterpret_u64_s8(vector), shift)); +} + +// Shim vqtbl1_u8 for armv7. +inline uint8x8_t VQTbl1U8(const uint8x16_t a, const uint8x8_t index) { +#if defined(__aarch64__) + return vqtbl1_u8(a, index); +#else + const uint8x8x2_t b = {vget_low_u8(a), vget_high_u8(a)}; + return vtbl2_u8(b, index); +#endif +} + +// Shim vqtbl1_s8 for armv7. +inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) { +#if defined(__aarch64__) + return vqtbl1_s8(a, index); +#else + const int8x8x2_t b = {vget_low_s8(a), vget_high_s8(a)}; + return vtbl2_s8(b, vreinterpret_s8_u8(index)); +#endif +} + +//------------------------------------------------------------------------------ +// Interleave. + +// vzipN is exclusive to A64. +inline uint8x8_t InterleaveLow8(const uint8x8_t a, const uint8x8_t b) { +#if defined(__aarch64__) + return vzip1_u8(a, b); +#else + // Discard |.val[1]| + return vzip_u8(a, b).val[0]; +#endif +} + +inline uint8x8_t InterleaveLow32(const uint8x8_t a, const uint8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_u8_u32( + vzip1_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b))); +#else + // Discard |.val[1]| + return vreinterpret_u8_u32( + vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[0]); +#endif +} + +inline int8x8_t InterleaveLow32(const int8x8_t a, const int8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_s8_u32( + vzip1_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b))); +#else + // Discard |.val[1]| + return vreinterpret_s8_u32( + vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[0]); +#endif +} + +inline uint8x8_t InterleaveHigh32(const uint8x8_t a, const uint8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_u8_u32( + vzip2_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b))); +#else + // Discard |.val[0]| + return vreinterpret_u8_u32( + vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[1]); +#endif +} + +inline int8x8_t InterleaveHigh32(const int8x8_t a, const int8x8_t b) { +#if defined(__aarch64__) + return vreinterpret_s8_u32( + vzip2_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b))); +#else + // Discard |.val[0]| + return vreinterpret_s8_u32( + vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[1]); +#endif +} + +//------------------------------------------------------------------------------ +// Sum. + +inline uint16_t SumVector(const uint8x8_t a) { +#if defined(__aarch64__) + return vaddlv_u8(a); +#else + const uint16x4_t c = vpaddl_u8(a); + const uint32x2_t d = vpaddl_u16(c); + const uint64x1_t e = vpaddl_u32(d); + return static_cast<uint16_t>(vget_lane_u64(e, 0)); +#endif // defined(__aarch64__) +} + +inline uint32_t SumVector(const uint32x4_t a) { +#if defined(__aarch64__) + return vaddvq_u32(a); +#else + const uint64x2_t b = vpaddlq_u32(a); + const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b)); + return static_cast<uint32_t>(vget_lane_u64(c, 0)); +#endif +} + +//------------------------------------------------------------------------------ +// Transpose. + +// Transpose 32 bit elements such that: +// a: 00 01 +// b: 02 03 +// returns +// val[0]: 00 02 +// val[1]: 01 03 +inline uint8x8x2_t Interleave32(const uint8x8_t a, const uint8x8_t b) { + const uint32x2_t a_32 = vreinterpret_u32_u8(a); + const uint32x2_t b_32 = vreinterpret_u32_u8(b); + const uint32x2x2_t c = vtrn_u32(a_32, b_32); + const uint8x8x2_t d = {vreinterpret_u8_u32(c.val[0]), + vreinterpret_u8_u32(c.val[1])}; + return d; +} + +// Swap high and low 32 bit elements. +inline uint8x8_t Transpose32(const uint8x8_t a) { + const uint32x2_t b = vrev64_u32(vreinterpret_u32_u8(a)); + return vreinterpret_u8_u32(b); +} + +// Implement vtrnq_s64(). +// Input: +// a0: 00 01 02 03 04 05 06 07 +// a1: 16 17 18 19 20 21 22 23 +// Output: +// b0.val[0]: 00 01 02 03 16 17 18 19 +// b0.val[1]: 04 05 06 07 20 21 22 23 +inline int16x8x2_t VtrnqS64(int32x4_t a0, int32x4_t a1) { + int16x8x2_t b0; + b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)), + vreinterpret_s16_s32(vget_low_s32(a1))); + b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)), + vreinterpret_s16_s32(vget_high_s32(a1))); + return b0; +} + +inline uint16x8x2_t VtrnqU64(uint32x4_t a0, uint32x4_t a1) { + uint16x8x2_t b0; + b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)), + vreinterpret_u16_u32(vget_low_u32(a1))); + b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)), + vreinterpret_u16_u32(vget_high_u32(a1))); + return b0; +} + +// Input: +// a: 00 01 02 03 10 11 12 13 +// b: 20 21 22 23 30 31 32 33 +// Output: +// Note that columns [1] and [2] are transposed. +// a: 00 10 20 30 02 12 22 32 +// b: 01 11 21 31 03 13 23 33 +inline void Transpose4x4(uint8x8_t* a, uint8x8_t* b) { + const uint16x4x2_t c = + vtrn_u16(vreinterpret_u16_u8(*a), vreinterpret_u16_u8(*b)); + const uint32x2x2_t d = + vtrn_u32(vreinterpret_u32_u16(c.val[0]), vreinterpret_u32_u16(c.val[1])); + const uint8x8x2_t e = + vtrn_u8(vreinterpret_u8_u32(d.val[0]), vreinterpret_u8_u32(d.val[1])); + *a = e.val[0]; + *b = e.val[1]; +} + +// Reversible if the x4 values are packed next to each other. +// x4 input / x8 output: +// a0: 00 01 02 03 40 41 42 43 44 +// a1: 10 11 12 13 50 51 52 53 54 +// a2: 20 21 22 23 60 61 62 63 64 +// a3: 30 31 32 33 70 71 72 73 74 +// x8 input / x4 output: +// a0: 00 10 20 30 40 50 60 70 +// a1: 01 11 21 31 41 51 61 71 +// a2: 02 12 22 32 42 52 62 72 +// a3: 03 13 23 33 43 53 63 73 +inline void Transpose8x4(uint8x8_t* a0, uint8x8_t* a1, uint8x8_t* a2, + uint8x8_t* a3) { + const uint8x8x2_t b0 = vtrn_u8(*a0, *a1); + const uint8x8x2_t b1 = vtrn_u8(*a2, *a3); + + const uint16x4x2_t c0 = + vtrn_u16(vreinterpret_u16_u8(b0.val[0]), vreinterpret_u16_u8(b1.val[0])); + const uint16x4x2_t c1 = + vtrn_u16(vreinterpret_u16_u8(b0.val[1]), vreinterpret_u16_u8(b1.val[1])); + + *a0 = vreinterpret_u8_u16(c0.val[0]); + *a1 = vreinterpret_u8_u16(c1.val[0]); + *a2 = vreinterpret_u8_u16(c0.val[1]); + *a3 = vreinterpret_u8_u16(c1.val[1]); +} + +// Input: +// a[0]: 00 01 02 03 04 05 06 07 +// a[1]: 10 11 12 13 14 15 16 17 +// a[2]: 20 21 22 23 24 25 26 27 +// a[3]: 30 31 32 33 34 35 36 37 +// a[4]: 40 41 42 43 44 45 46 47 +// a[5]: 50 51 52 53 54 55 56 57 +// a[6]: 60 61 62 63 64 65 66 67 +// a[7]: 70 71 72 73 74 75 76 77 + +// Output: +// a[0]: 00 10 20 30 40 50 60 70 +// a[1]: 01 11 21 31 41 51 61 71 +// a[2]: 02 12 22 32 42 52 62 72 +// a[3]: 03 13 23 33 43 53 63 73 +// a[4]: 04 14 24 34 44 54 64 74 +// a[5]: 05 15 25 35 45 55 65 75 +// a[6]: 06 16 26 36 46 56 66 76 +// a[7]: 07 17 27 37 47 57 67 77 +inline void Transpose8x8(int8x8_t a[8]) { + // Swap 8 bit elements. Goes from: + // a[0]: 00 01 02 03 04 05 06 07 + // a[1]: 10 11 12 13 14 15 16 17 + // a[2]: 20 21 22 23 24 25 26 27 + // a[3]: 30 31 32 33 34 35 36 37 + // a[4]: 40 41 42 43 44 45 46 47 + // a[5]: 50 51 52 53 54 55 56 57 + // a[6]: 60 61 62 63 64 65 66 67 + // a[7]: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 40 50 42 52 44 54 46 56 + // b0.val[1]: 01 11 03 13 05 15 07 17 41 51 43 53 45 55 47 57 + // b1.val[0]: 20 30 22 32 24 34 26 36 60 70 62 72 64 74 66 76 + // b1.val[1]: 21 31 23 33 25 35 27 37 61 71 63 73 65 75 67 77 + const int8x16x2_t b0 = + vtrnq_s8(vcombine_s8(a[0], a[4]), vcombine_s8(a[1], a[5])); + const int8x16x2_t b1 = + vtrnq_s8(vcombine_s8(a[2], a[6]), vcombine_s8(a[3], a[7])); + + // Swap 16 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 40 50 60 70 44 54 64 74 + // c0.val[1]: 02 12 22 32 06 16 26 36 42 52 62 72 46 56 66 76 + // c1.val[0]: 01 11 21 31 05 15 25 35 41 51 61 71 45 55 65 75 + // c1.val[1]: 03 13 23 33 07 17 27 37 43 53 63 73 47 57 67 77 + const int16x8x2_t c0 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[0]), + vreinterpretq_s16_s8(b1.val[0])); + const int16x8x2_t c1 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[1]), + vreinterpretq_s16_s8(b1.val[1])); + + // Unzip 32 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71 + // d0.val[1]: 04 14 24 34 44 54 64 74 05 15 25 35 45 55 65 75 + // d1.val[0]: 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73 + // d1.val[1]: 06 16 26 36 46 56 66 76 07 17 27 37 47 57 67 77 + const int32x4x2_t d0 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[0]), + vreinterpretq_s32_s16(c1.val[0])); + const int32x4x2_t d1 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[1]), + vreinterpretq_s32_s16(c1.val[1])); + + a[0] = vreinterpret_s8_s32(vget_low_s32(d0.val[0])); + a[1] = vreinterpret_s8_s32(vget_high_s32(d0.val[0])); + a[2] = vreinterpret_s8_s32(vget_low_s32(d1.val[0])); + a[3] = vreinterpret_s8_s32(vget_high_s32(d1.val[0])); + a[4] = vreinterpret_s8_s32(vget_low_s32(d0.val[1])); + a[5] = vreinterpret_s8_s32(vget_high_s32(d0.val[1])); + a[6] = vreinterpret_s8_s32(vget_low_s32(d1.val[1])); + a[7] = vreinterpret_s8_s32(vget_high_s32(d1.val[1])); +} + +// Unsigned. +inline void Transpose8x8(uint8x8_t a[8]) { + const uint8x16x2_t b0 = + vtrnq_u8(vcombine_u8(a[0], a[4]), vcombine_u8(a[1], a[5])); + const uint8x16x2_t b1 = + vtrnq_u8(vcombine_u8(a[2], a[6]), vcombine_u8(a[3], a[7])); + + const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]), + vreinterpretq_u16_u8(b1.val[0])); + const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]), + vreinterpretq_u16_u8(b1.val[1])); + + const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c1.val[0])); + const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c1.val[1])); + + a[0] = vreinterpret_u8_u32(vget_low_u32(d0.val[0])); + a[1] = vreinterpret_u8_u32(vget_high_u32(d0.val[0])); + a[2] = vreinterpret_u8_u32(vget_low_u32(d1.val[0])); + a[3] = vreinterpret_u8_u32(vget_high_u32(d1.val[0])); + a[4] = vreinterpret_u8_u32(vget_low_u32(d0.val[1])); + a[5] = vreinterpret_u8_u32(vget_high_u32(d0.val[1])); + a[6] = vreinterpret_u8_u32(vget_low_u32(d1.val[1])); + a[7] = vreinterpret_u8_u32(vget_high_u32(d1.val[1])); +} + +inline void Transpose8x8(uint8x8_t in[8], uint8x16_t out[4]) { + const uint8x16x2_t a0 = + vtrnq_u8(vcombine_u8(in[0], in[4]), vcombine_u8(in[1], in[5])); + const uint8x16x2_t a1 = + vtrnq_u8(vcombine_u8(in[2], in[6]), vcombine_u8(in[3], in[7])); + + const uint16x8x2_t b0 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[0]), + vreinterpretq_u16_u8(a1.val[0])); + const uint16x8x2_t b1 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[1]), + vreinterpretq_u16_u8(a1.val[1])); + + const uint32x4x2_t c0 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[0]), + vreinterpretq_u32_u16(b1.val[0])); + const uint32x4x2_t c1 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[1]), + vreinterpretq_u32_u16(b1.val[1])); + + out[0] = vreinterpretq_u8_u32(c0.val[0]); + out[1] = vreinterpretq_u8_u32(c1.val[0]); + out[2] = vreinterpretq_u8_u32(c0.val[1]); + out[3] = vreinterpretq_u8_u32(c1.val[1]); +} + +// Input: +// a[0]: 00 01 02 03 04 05 06 07 +// a[1]: 10 11 12 13 14 15 16 17 +// a[2]: 20 21 22 23 24 25 26 27 +// a[3]: 30 31 32 33 34 35 36 37 +// a[4]: 40 41 42 43 44 45 46 47 +// a[5]: 50 51 52 53 54 55 56 57 +// a[6]: 60 61 62 63 64 65 66 67 +// a[7]: 70 71 72 73 74 75 76 77 + +// Output: +// a[0]: 00 10 20 30 40 50 60 70 +// a[1]: 01 11 21 31 41 51 61 71 +// a[2]: 02 12 22 32 42 52 62 72 +// a[3]: 03 13 23 33 43 53 63 73 +// a[4]: 04 14 24 34 44 54 64 74 +// a[5]: 05 15 25 35 45 55 65 75 +// a[6]: 06 16 26 36 46 56 66 76 +// a[7]: 07 17 27 37 47 57 67 77 +inline void Transpose8x8(int16x8_t a[8]) { + const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]); + const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]); + const int16x8x2_t b2 = vtrnq_s16(a[4], a[5]); + const int16x8x2_t b3 = vtrnq_s16(a[6], a[7]); + + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]), + vreinterpretq_s32_s16(b3.val[0])); + const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]), + vreinterpretq_s32_s16(b3.val[1])); + + const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]); + const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]); + const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]); + const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]); + + a[0] = d0.val[0]; + a[1] = d1.val[0]; + a[2] = d2.val[0]; + a[3] = d3.val[0]; + a[4] = d0.val[1]; + a[5] = d1.val[1]; + a[6] = d2.val[1]; + a[7] = d3.val[1]; +} + +// Unsigned. +inline void Transpose8x8(uint16x8_t a[8]) { + const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]); + const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]); + const uint16x8x2_t b2 = vtrnq_u16(a[4], a[5]); + const uint16x8x2_t b3 = vtrnq_u16(a[6], a[7]); + + const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]), + vreinterpretq_u32_u16(b1.val[0])); + const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]), + vreinterpretq_u32_u16(b1.val[1])); + const uint32x4x2_t c2 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[0]), + vreinterpretq_u32_u16(b3.val[0])); + const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]), + vreinterpretq_u32_u16(b3.val[1])); + + const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c2.val[0]); + const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c3.val[0]); + const uint16x8x2_t d2 = VtrnqU64(c0.val[1], c2.val[1]); + const uint16x8x2_t d3 = VtrnqU64(c1.val[1], c3.val[1]); + + a[0] = d0.val[0]; + a[1] = d1.val[0]; + a[2] = d2.val[0]; + a[3] = d3.val[0]; + a[4] = d0.val[1]; + a[5] = d1.val[1]; + a[6] = d2.val[1]; + a[7] = d3.val[1]; +} + +// Input: +// a[0]: 00 01 02 03 04 05 06 07 80 81 82 83 84 85 86 87 +// a[1]: 10 11 12 13 14 15 16 17 90 91 92 93 94 95 96 97 +// a[2]: 20 21 22 23 24 25 26 27 a0 a1 a2 a3 a4 a5 a6 a7 +// a[3]: 30 31 32 33 34 35 36 37 b0 b1 b2 b3 b4 b5 b6 b7 +// a[4]: 40 41 42 43 44 45 46 47 c0 c1 c2 c3 c4 c5 c6 c7 +// a[5]: 50 51 52 53 54 55 56 57 d0 d1 d2 d3 d4 d5 d6 d7 +// a[6]: 60 61 62 63 64 65 66 67 e0 e1 e2 e3 e4 e5 e6 e7 +// a[7]: 70 71 72 73 74 75 76 77 f0 f1 f2 f3 f4 f5 f6 f7 + +// Output: +// a[0]: 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 +// a[1]: 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 +// a[2]: 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 +// a[3]: 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 +// a[4]: 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 +// a[5]: 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 +// a[6]: 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 +// a[7]: 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 +inline void Transpose8x16(uint8x16_t a[8]) { + // b0.val[0]: 00 10 02 12 04 14 06 16 80 90 82 92 84 94 86 96 + // b0.val[1]: 01 11 03 13 05 15 07 17 81 91 83 93 85 95 87 97 + // b1.val[0]: 20 30 22 32 24 34 26 36 a0 b0 a2 b2 a4 b4 a6 b6 + // b1.val[1]: 21 31 23 33 25 35 27 37 a1 b1 a3 b3 a5 b5 a7 b7 + // b2.val[0]: 40 50 42 52 44 54 46 56 c0 d0 c2 d2 c4 d4 c6 d6 + // b2.val[1]: 41 51 43 53 45 55 47 57 c1 d1 c3 d3 c5 d5 c7 d7 + // b3.val[0]: 60 70 62 72 64 74 66 76 e0 f0 e2 f2 e4 f4 e6 f6 + // b3.val[1]: 61 71 63 73 65 75 67 77 e1 f1 e3 f3 e5 f5 e7 f7 + const uint8x16x2_t b0 = vtrnq_u8(a[0], a[1]); + const uint8x16x2_t b1 = vtrnq_u8(a[2], a[3]); + const uint8x16x2_t b2 = vtrnq_u8(a[4], a[5]); + const uint8x16x2_t b3 = vtrnq_u8(a[6], a[7]); + + // c0.val[0]: 00 10 20 30 04 14 24 34 80 90 a0 b0 84 94 a4 b4 + // c0.val[1]: 02 12 22 32 06 16 26 36 82 92 a2 b2 86 96 a6 b6 + // c1.val[0]: 01 11 21 31 05 15 25 35 81 91 a1 b1 85 95 a5 b5 + // c1.val[1]: 03 13 23 33 07 17 27 37 83 93 a3 b3 87 97 a7 b7 + // c2.val[0]: 40 50 60 70 44 54 64 74 c0 d0 e0 f0 c4 d4 e4 f4 + // c2.val[1]: 42 52 62 72 46 56 66 76 c2 d2 e2 f2 c6 d6 e6 f6 + // c3.val[0]: 41 51 61 71 45 55 65 75 c1 d1 e1 f1 c5 d5 e5 f5 + // c3.val[1]: 43 53 63 73 47 57 67 77 c3 d3 e3 f3 c7 d7 e7 f7 + const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]), + vreinterpretq_u16_u8(b1.val[0])); + const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]), + vreinterpretq_u16_u8(b1.val[1])); + const uint16x8x2_t c2 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[0]), + vreinterpretq_u16_u8(b3.val[0])); + const uint16x8x2_t c3 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[1]), + vreinterpretq_u16_u8(b3.val[1])); + + // d0.val[0]: 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // d0.val[1]: 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // d1.val[0]: 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // d1.val[1]: 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // d2.val[0]: 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // d2.val[1]: 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // d3.val[0]: 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // d3.val[1]: 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c2.val[0])); + const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[0]), + vreinterpretq_u32_u16(c3.val[0])); + const uint32x4x2_t d2 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c2.val[1])); + const uint32x4x2_t d3 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[1]), + vreinterpretq_u32_u16(c3.val[1])); + + a[0] = vreinterpretq_u8_u32(d0.val[0]); + a[1] = vreinterpretq_u8_u32(d1.val[0]); + a[2] = vreinterpretq_u8_u32(d2.val[0]); + a[3] = vreinterpretq_u8_u32(d3.val[0]); + a[4] = vreinterpretq_u8_u32(d0.val[1]); + a[5] = vreinterpretq_u8_u32(d1.val[1]); + a[6] = vreinterpretq_u8_u32(d2.val[1]); + a[7] = vreinterpretq_u8_u32(d3.val[1]); +} + +inline int16x8_t ZeroExtend(const uint8x8_t in) { + return vreinterpretq_s16_u16(vmovl_u8(in)); +} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_ENABLE_NEON +#endif // LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_ diff --git a/src/dsp/arm/convolve_neon.cc b/src/dsp/arm/convolve_neon.cc new file mode 100644 index 0000000..fd9b912 --- /dev/null +++ b/src/dsp/arm/convolve_neon.cc @@ -0,0 +1,3105 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/convolve.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Include the constants and utility functions inside the anonymous namespace. +#include "src/dsp/convolve.inc" + +// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and +// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final +// sum from outranging int16_t. +template <int filter_index, bool negative_outside_taps = false> +int16x8_t SumOnePassTaps(const uint8x8_t* const src, + const uint8x8_t* const taps) { + uint16x8_t sum; + if (filter_index == 0) { + // 6 taps. + - + + - + + sum = vmull_u8(src[0], taps[0]); + // Unsigned overflow will result in a valid int16_t value. + sum = vmlsl_u8(sum, src[1], taps[1]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlsl_u8(sum, src[4], taps[4]); + sum = vmlal_u8(sum, src[5], taps[5]); + } else if (filter_index == 1 && negative_outside_taps) { + // 6 taps. - + + + + - + // Set a base we can subtract from. + sum = vmull_u8(src[1], taps[1]); + sum = vmlsl_u8(sum, src[0], taps[0]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlal_u8(sum, src[4], taps[4]); + sum = vmlsl_u8(sum, src[5], taps[5]); + } else if (filter_index == 1) { + // 6 taps. All are positive. + sum = vmull_u8(src[0], taps[0]); + sum = vmlal_u8(sum, src[1], taps[1]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlal_u8(sum, src[4], taps[4]); + sum = vmlal_u8(sum, src[5], taps[5]); + } else if (filter_index == 2) { + // 8 taps. - + - + + - + - + sum = vmull_u8(src[1], taps[1]); + sum = vmlsl_u8(sum, src[0], taps[0]); + sum = vmlsl_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + sum = vmlal_u8(sum, src[4], taps[4]); + sum = vmlsl_u8(sum, src[5], taps[5]); + sum = vmlal_u8(sum, src[6], taps[6]); + sum = vmlsl_u8(sum, src[7], taps[7]); + } else if (filter_index == 3) { + // 2 taps. All are positive. + sum = vmull_u8(src[0], taps[0]); + sum = vmlal_u8(sum, src[1], taps[1]); + } else if (filter_index == 4) { + // 4 taps. - + + - + sum = vmull_u8(src[1], taps[1]); + sum = vmlsl_u8(sum, src[0], taps[0]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlsl_u8(sum, src[3], taps[3]); + } else if (filter_index == 5) { + // 4 taps. All are positive. + sum = vmull_u8(src[0], taps[0]); + sum = vmlal_u8(sum, src[1], taps[1]); + sum = vmlal_u8(sum, src[2], taps[2]); + sum = vmlal_u8(sum, src[3], taps[3]); + } + return vreinterpretq_s16_u16(sum); +} + +template <int filter_index, bool negative_outside_taps> +int16x8_t SumHorizontalTaps(const uint8_t* const src, + const uint8x8_t* const v_tap) { + uint8x8_t v_src[8]; + const uint8x16_t src_long = vld1q_u8(src); + int16x8_t sum; + + if (filter_index < 2) { + v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 1)); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 2)); + v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 5)); + v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 6)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 1); + } else if (filter_index == 2) { + v_src[0] = vget_low_u8(src_long); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1)); + v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2)); + v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5)); + v_src[6] = vget_low_u8(vextq_u8(src_long, src_long, 6)); + v_src[7] = vget_low_u8(vextq_u8(src_long, src_long, 7)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap); + } else if (filter_index == 3) { + v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 3); + } else if (filter_index > 3) { + v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 2)); + v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 3)); + v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 4)); + v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 5)); + sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 2); + } + return sum; +} + +template <int filter_index, bool negative_outside_taps> +uint8x8_t SimpleHorizontalTaps(const uint8_t* const src, + const uint8x8_t* const v_tap) { + int16x8_t sum = + SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. + constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit)); + return vqrshrun_n_s16(sum, kFilterBits - 1); +} + +template <int filter_index, bool negative_outside_taps> +uint16x8_t HorizontalTaps8To16(const uint8_t* const src, + const uint8x8_t* const v_tap) { + const int16x8_t sum = + SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap); + + return vreinterpretq_u16_s16( + vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1)); +} + +template <int filter_index> +int16x8_t SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const uint8x8_t* const v_tap) { + uint16x8_t sum; + const uint8x8_t input0 = vld1_u8(src); + src += src_stride; + const uint8x8_t input1 = vld1_u8(src); + uint8x8x2_t input = vzip_u8(input0, input1); + + if (filter_index == 3) { + // tap signs : + + + sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]); + sum = vmlal_u8(sum, input.val[1], v_tap[4]); + } else if (filter_index == 4) { + // tap signs : - + + - + sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]); + sum = vmlsl_u8(sum, RightShift<4 * 8>(input.val[0]), v_tap[2]); + sum = vmlal_u8(sum, input.val[1], v_tap[4]); + sum = vmlsl_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]); + } else { + // tap signs : + + + + + sum = vmull_u8(RightShift<4 * 8>(input.val[0]), v_tap[2]); + sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]); + sum = vmlal_u8(sum, input.val[1], v_tap[4]); + sum = vmlal_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]); + } + + return vreinterpretq_s16_u16(sum); +} + +template <int filter_index> +uint8x8_t SimpleHorizontalTaps2x2(const uint8_t* src, + const ptrdiff_t src_stride, + const uint8x8_t* const v_tap) { + int16x8_t sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. + constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit)); + return vqrshrun_n_s16(sum, kFilterBits - 1); +} + +template <int filter_index> +uint16x8_t HorizontalTaps8To16_2x2(const uint8_t* src, + const ptrdiff_t src_stride, + const uint8x8_t* const v_tap) { + const int16x8_t sum = + SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + return vreinterpretq_u16_s16( + vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1)); +} + +template <int num_taps, int step, int filter_index, + bool negative_outside_taps = true, bool is_2d = false, + bool is_compound = false> +void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, + void* const dest, const ptrdiff_t pred_stride, + const int width, const int height, + const uint8x8_t* const v_tap) { + auto* dest8 = static_cast<uint8_t*>(dest); + auto* dest16 = static_cast<uint16_t*>(dest); + + // 4 tap filters are never used when width > 4. + if (num_taps != 4 && width > 4) { + int y = 0; + do { + int x = 0; + do { + if (is_2d || is_compound) { + const uint16x8_t v_sum = + HorizontalTaps8To16<filter_index, negative_outside_taps>(&src[x], + v_tap); + vst1q_u16(&dest16[x], v_sum); + } else { + const uint8x8_t result = + SimpleHorizontalTaps<filter_index, negative_outside_taps>(&src[x], + v_tap); + vst1_u8(&dest8[x], result); + } + x += step; + } while (x < width); + src += src_stride; + dest8 += pred_stride; + dest16 += pred_stride; + } while (++y < height); + return; + } + + // Horizontal passes only needs to account for |num_taps| 2 and 4 when + // |width| <= 4. + assert(width <= 4); + assert(num_taps <= 4); + if (num_taps <= 4) { + if (width == 4) { + int y = 0; + do { + if (is_2d || is_compound) { + const uint16x8_t v_sum = + HorizontalTaps8To16<filter_index, negative_outside_taps>(src, + v_tap); + vst1_u16(dest16, vget_low_u16(v_sum)); + } else { + const uint8x8_t result = + SimpleHorizontalTaps<filter_index, negative_outside_taps>(src, + v_tap); + StoreLo4(&dest8[0], result); + } + src += src_stride; + dest8 += pred_stride; + dest16 += pred_stride; + } while (++y < height); + return; + } + + if (!is_compound) { + int y = 0; + do { + if (is_2d) { + const uint16x8_t sum = + HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap); + dest16[0] = vgetq_lane_u16(sum, 0); + dest16[1] = vgetq_lane_u16(sum, 2); + dest16 += pred_stride; + dest16[0] = vgetq_lane_u16(sum, 1); + dest16[1] = vgetq_lane_u16(sum, 3); + dest16 += pred_stride; + } else { + const uint8x8_t sum = + SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + dest8[0] = vget_lane_u8(sum, 0); + dest8[1] = vget_lane_u8(sum, 2); + dest8 += pred_stride; + + dest8[0] = vget_lane_u8(sum, 1); + dest8[1] = vget_lane_u8(sum, 3); + dest8 += pred_stride; + } + + src += src_stride << 1; + y += 2; + } while (y < height - 1); + + // The 2d filters have an odd |height| because the horizontal pass + // generates context for the vertical pass. + if (is_2d) { + assert(height % 2 == 1); + uint16x8_t sum; + const uint8x8_t input = vld1_u8(src); + if (filter_index == 3) { // |num_taps| == 2 + sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]); + sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]); + } else if (filter_index == 4) { + sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]); + sum = vmlsl_u8(sum, RightShift<2 * 8>(input), v_tap[2]); + sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]); + sum = vmlsl_u8(sum, RightShift<5 * 8>(input), v_tap[5]); + } else { + assert(filter_index == 5); + sum = vmull_u8(RightShift<2 * 8>(input), v_tap[2]); + sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]); + sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]); + sum = vmlal_u8(sum, RightShift<5 * 8>(input), v_tap[5]); + } + // |sum| contains an int16_t value. + sum = vreinterpretq_u16_s16(vrshrq_n_s16( + vreinterpretq_s16_u16(sum), kInterRoundBitsHorizontal - 1)); + Store2<0>(dest16, sum); + } + } + } +} + +// Process 16 bit inputs and output 32 bits. +template <int num_taps, bool is_compound> +inline int16x4_t Sum2DVerticalTaps4(const int16x4_t* const src, + const int16x8_t taps) { + const int16x4_t taps_lo = vget_low_s16(taps); + const int16x4_t taps_hi = vget_high_s16(taps); + int32x4_t sum; + if (num_taps == 8) { + sum = vmull_lane_s16(src[0], taps_lo, 0); + sum = vmlal_lane_s16(sum, src[1], taps_lo, 1); + sum = vmlal_lane_s16(sum, src[2], taps_lo, 2); + sum = vmlal_lane_s16(sum, src[3], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[4], taps_hi, 0); + sum = vmlal_lane_s16(sum, src[5], taps_hi, 1); + sum = vmlal_lane_s16(sum, src[6], taps_hi, 2); + sum = vmlal_lane_s16(sum, src[7], taps_hi, 3); + } else if (num_taps == 6) { + sum = vmull_lane_s16(src[0], taps_lo, 1); + sum = vmlal_lane_s16(sum, src[1], taps_lo, 2); + sum = vmlal_lane_s16(sum, src[2], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[3], taps_hi, 0); + sum = vmlal_lane_s16(sum, src[4], taps_hi, 1); + sum = vmlal_lane_s16(sum, src[5], taps_hi, 2); + } else if (num_taps == 4) { + sum = vmull_lane_s16(src[0], taps_lo, 2); + sum = vmlal_lane_s16(sum, src[1], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[2], taps_hi, 0); + sum = vmlal_lane_s16(sum, src[3], taps_hi, 1); + } else if (num_taps == 2) { + sum = vmull_lane_s16(src[0], taps_lo, 3); + sum = vmlal_lane_s16(sum, src[1], taps_hi, 0); + } + + if (is_compound) { + return vqrshrn_n_s32(sum, kInterRoundBitsCompoundVertical - 1); + } + + return vqrshrn_n_s32(sum, kInterRoundBitsVertical - 1); +} + +template <int num_taps, bool is_compound> +int16x8_t SimpleSum2DVerticalTaps(const int16x8_t* const src, + const int16x8_t taps) { + const int16x4_t taps_lo = vget_low_s16(taps); + const int16x4_t taps_hi = vget_high_s16(taps); + int32x4_t sum_lo, sum_hi; + if (num_taps == 8) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 0); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_lo, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[6]), taps_hi, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[6]), taps_hi, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3); + } else if (num_taps == 6) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 1); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2); + } else if (num_taps == 4) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 2); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_hi, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1); + } else if (num_taps == 2) { + sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 3); + sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 3); + + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0); + } + + if (is_compound) { + return vcombine_s16( + vqrshrn_n_s32(sum_lo, kInterRoundBitsCompoundVertical - 1), + vqrshrn_n_s32(sum_hi, kInterRoundBitsCompoundVertical - 1)); + } + + return vcombine_s16(vqrshrn_n_s32(sum_lo, kInterRoundBitsVertical - 1), + vqrshrn_n_s32(sum_hi, kInterRoundBitsVertical - 1)); +} + +template <int num_taps, bool is_compound = false> +void Filter2DVertical(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int width, + const int height, const int16x8_t taps) { + assert(width >= 8); + constexpr int next_row = num_taps - 1; + // The Horizontal pass uses |width| as |stride| for the intermediate buffer. + const ptrdiff_t src_stride = width; + + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int x = 0; + do { + int16x8_t srcs[8]; + const uint16_t* src_x = src + x; + srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + } + } + } + + int y = 0; + do { + srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src_x)); + src_x += src_stride; + + const int16x8_t sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + vst1q_u16(dst16 + x + y * dst_stride, vreinterpretq_u16_s16(sum)); + } else { + vst1_u8(dst8 + x + y * dst_stride, vqmovun_s16(sum)); + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (++y < height); + x += 8; + } while (x < width); +} + +// Take advantage of |src_stride| == |width| to process two rows at a time. +template <int num_taps, bool is_compound = false> +void Filter2DVertical4xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const int16x8_t taps) { + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int16x8_t srcs[9]; + srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + if (num_taps >= 4) { + srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2])); + if (num_taps >= 6) { + srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[3] = vcombine_s16(vget_high_s16(srcs[2]), vget_low_s16(srcs[4])); + if (num_taps == 8) { + srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[5] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[6])); + } + } + } + + int y = 0; + do { + srcs[num_taps] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[num_taps - 1] = vcombine_s16(vget_high_s16(srcs[num_taps - 2]), + vget_low_s16(srcs[num_taps])); + + const int16x8_t sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + const uint16x8_t results = vreinterpretq_u16_s16(sum); + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqmovun_s16(sum); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y += 2; + } while (y < height); +} + +// Take advantage of |src_stride| == |width| to process four rows at a time. +template <int num_taps> +void Filter2DVertical2xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const int16x8_t taps) { + constexpr int next_row = (num_taps < 6) ? 4 : 8; + + auto* dst8 = static_cast<uint8_t*>(dst); + + int16x8_t srcs[9]; + srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + if (num_taps >= 6) { + srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + srcs[1] = vextq_s16(srcs[0], srcs[4], 2); + if (num_taps == 8) { + srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4])); + srcs[3] = vextq_s16(srcs[0], srcs[4], 6); + } + } + + int y = 0; + do { + srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src)); + src += 8; + if (num_taps == 2) { + srcs[1] = vextq_s16(srcs[0], srcs[4], 2); + } else if (num_taps == 4) { + srcs[1] = vextq_s16(srcs[0], srcs[4], 2); + srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4])); + srcs[3] = vextq_s16(srcs[0], srcs[4], 6); + } else if (num_taps == 6) { + srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4])); + srcs[3] = vextq_s16(srcs[0], srcs[4], 6); + srcs[5] = vextq_s16(srcs[4], srcs[8], 2); + } else if (num_taps == 8) { + srcs[5] = vextq_s16(srcs[4], srcs[8], 2); + srcs[6] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[8])); + srcs[7] = vextq_s16(srcs[4], srcs[8], 6); + } + + const int16x8_t sum = + SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps); + const uint8x8_t results = vqmovun_s16(sum); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + // When |height| <= 4 the taps are restricted to 2 and 4 tap variants. + // Therefore we don't need to check this condition when |height| > 4. + if (num_taps <= 4 && height == 2) return; + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + if (num_taps == 6) { + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + } else if (num_taps == 8) { + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + } + + y += 4; + } while (y < height); +} + +template <bool is_2d = false, bool is_compound = false> +LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( + const uint8_t* const src, const ptrdiff_t src_stride, void* const dst, + const ptrdiff_t dst_stride, const int width, const int height, + const int filter_id, const int filter_index) { + // Duplicate the absolute value for each tap. Negative taps are corrected + // by using the vmlsl_u8 instruction. Positive taps use vmlal_u8. + uint8x8_t v_tap[kSubPixelTaps]; + assert(filter_id != 0); + + for (int k = 0; k < kSubPixelTaps; ++k) { + v_tap[k] = vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][filter_id][k]); + } + + if (filter_index == 2) { // 8 tap. + FilterHorizontal<8, 8, 2, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 1) { // 6 tap. + // Check if outside taps are positive. + if ((filter_id == 1) | (filter_id == 15)) { + FilterHorizontal<6, 8, 1, false, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else { + FilterHorizontal<6, 8, 1, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } + } else if (filter_index == 0) { // 6 tap. + FilterHorizontal<6, 8, 0, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 4) { // 4 tap. + FilterHorizontal<4, 8, 4, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 5) { // 4 tap. + FilterHorizontal<4, 8, 5, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else { // 2 tap. + FilterHorizontal<2, 8, 3, true, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } +} + +void Convolve2D_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + + // The output of the horizontal filter is guaranteed to fit in 16 bits. + uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_height = height + vertical_taps - 1; + + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset; + + DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width, + width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + const int16x8_t taps = vmovl_s8( + vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id])); + + if (vertical_taps == 8) { + if (width == 2) { + Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else if (vertical_taps == 6) { + if (width == 2) { + Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else if (vertical_taps == 4) { + if (width == 2) { + Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else { // |vertical_taps| == 2 + if (width == 2) { + Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } +} + +// There are many opportunities for overreading in scaled convolve, because the +// range of starting points for filter windows is anywhere from 0 to 16 for 8 +// destination pixels, and the window sizes range from 2 to 8. To accommodate +// this range concisely, we use |grade_x| to mean the most steps in src that can +// be traversed in a single |step_x| increment, i.e. 1 or 2. When grade_x is 2, +// we are guaranteed to exceed 8 whole steps in src for every 8 |step_x| +// increments. The first load covers the initial elements of src_x, while the +// final load covers the taps. +template <int grade_x> +inline uint8x8x3_t LoadSrcVals(const uint8_t* src_x) { + uint8x8x3_t ret; + const uint8x16_t src_val = vld1q_u8(src_x); + ret.val[0] = vget_low_u8(src_val); + ret.val[1] = vget_high_u8(src_val); + if (grade_x > 1) { + ret.val[2] = vld1_u8(src_x + 16); + } + return ret; +} + +// Pre-transpose the 2 tap filters in |kAbsHalfSubPixelFilters|[3] +inline uint8x16_t GetPositive2TapFilter(const int tap_index) { + assert(tap_index < 2); + alignas( + 16) static constexpr uint8_t kAbsHalfSubPixel2TapFilterColumns[2][16] = { + {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4}, + {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}}; + + return vld1q_u8(kAbsHalfSubPixel2TapFilterColumns[tap_index]); +} + +template <int grade_x> +inline void ConvolveKernelHorizontal2Tap(const uint8_t* src, + const ptrdiff_t src_stride, + const int width, const int subpixel_x, + const int step_x, + const int intermediate_height, + int16_t* intermediate) { + // Account for the 0-taps that precede the 2 nonzero taps. + const int kernel_offset = 3; + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + const uint8x16_t filter_taps0 = GetPositive2TapFilter(0); + const uint8x16_t filter_taps1 = GetPositive2TapFilter(1); + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + + int p = subpixel_x; + if (width <= 4) { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask); + // This is a special case. The 2-tap filter has no negative taps, so we + // can use unsigned values. + // For each x, a lane of tapsK has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices)}; + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x16_t src_vals = vld1q_u8(src_x); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + + // For each x, a lane of srcK contains src_x[k]. + const uint8x8_t src[2] = { + VQTbl1U8(src_vals, src_indices), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))}; + + vst1q_s16(intermediate, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate += kIntermediateStride; + } while (++y < intermediate_height); + return; + } + + // |width| >= 8 + int x = 0; + do { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // This is a special case. The 2-tap filter has no negative taps, so we + // can use unsigned values. + // For each x, a lane of tapsK has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices)}; + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + + // For each x, a lane of srcK contains src_x[k]. + const uint8x8_t src[2] = { + vtbl3_u8(src_vals, src_indices), + vtbl3_u8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))}; + + vst1q_s16(intermediate_x, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[5]. +inline uint8x16_t GetPositive4TapFilter(const int tap_index) { + assert(tap_index < 4); + alignas( + 16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = { + {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1}, + {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17}, + {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31}, + {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}}; + + return vld1q_u8(kSubPixel4TapPositiveFilterColumns[tap_index]); +} + +// This filter is only possible when width <= 4. +void ConvolveKernelHorizontalPositive4Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x, + const int step_x, const int intermediate_height, int16_t* intermediate) { + const int kernel_offset = 2; + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const uint8x16_t filter_taps0 = GetPositive4TapFilter(0); + const uint8x16_t filter_taps1 = GetPositive4TapFilter(1); + const uint8x16_t filter_taps2 = GetPositive4TapFilter(2); + const uint8x16_t filter_taps3 = GetPositive4TapFilter(3); + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + const int p = subpixel_x; + // First filter is special, just a 128 tap on the center. + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t filter_indices = vand_u8( + vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), filter_index_mask); + // Note that filter_id depends on x. + // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k]. + const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices), + VQTbl1U8(filter_taps2, filter_indices), + VQTbl1U8(filter_taps3, filter_indices)}; + + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + int y = 0; + do { + // Load a pool of samples to select from using stepped index vectors. + const uint8x16_t src_vals = vld1q_u8(src_x); + + // For each x, srcK contains src_x[k] where k=1. + // Whereas taps come from different arrays, src pixels are drawn from the + // same contiguous line. + const uint8x8_t src[4] = { + VQTbl1U8(src_vals, src_indices), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1))), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2))), + VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)))}; + + vst1q_s16(intermediate, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/5>(src, taps), + kInterRoundBitsHorizontal - 1)); + + src_x += src_stride; + intermediate += kIntermediateStride; + } while (++y < intermediate_height); +} + +// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[4]. +inline uint8x16_t GetSigned4TapFilter(const int tap_index) { + assert(tap_index < 4); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel4TapSignedFilterColumns[4][16] = { + {0, 2, 4, 5, 6, 6, 7, 6, 6, 5, 5, 5, 4, 3, 2, 1}, + {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4}, + {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63}, + {0, 1, 2, 3, 4, 5, 5, 5, 6, 6, 7, 6, 6, 5, 4, 2}}; + + return vld1q_u8(kAbsHalfSubPixel4TapSignedFilterColumns[tap_index]); +} + +// This filter is only possible when width <= 4. +inline void ConvolveKernelHorizontalSigned4Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x, + const int step_x, const int intermediate_height, int16_t* intermediate) { + const int kernel_offset = 2; + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const uint8x16_t filter_taps0 = GetSigned4TapFilter(0); + const uint8x16_t filter_taps1 = GetSigned4TapFilter(1); + const uint8x16_t filter_taps2 = GetSigned4TapFilter(2); + const uint8x16_t filter_taps3 = GetSigned4TapFilter(3); + const uint16x4_t index_steps = vmul_n_u16(vcreate_u16(0x0003000200010000), + static_cast<uint16_t>(step_x)); + + const int p = subpixel_x; + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x4_t p_fraction = vdup_n_u16(p & 1023); + const uint16x4_t subpel_index_offsets = vadd_u16(index_steps, p_fraction); + const uint8x8_t filter_index_offsets = vshrn_n_u16( + vcombine_u16(subpel_index_offsets, vdup_n_u16(0)), kFilterIndexShift); + const uint8x8_t filter_indices = + vand_u8(filter_index_offsets, filter_index_mask); + // Note that filter_id depends on x. + // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k]. + const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices), + VQTbl1U8(filter_taps1, filter_indices), + VQTbl1U8(filter_taps2, filter_indices), + VQTbl1U8(filter_taps3, filter_indices)}; + + const uint8x8_t src_indices_base = + vshr_n_u8(filter_index_offsets, kScaleSubPixelBits - kFilterIndexShift); + + const uint8x8_t src_indices[4] = {src_indices_base, + vadd_u8(src_indices_base, vdup_n_u8(1)), + vadd_u8(src_indices_base, vdup_n_u8(2)), + vadd_u8(src_indices_base, vdup_n_u8(3))}; + + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x16_t src_vals = vld1q_u8(src_x); + + // For each x, srcK contains src_x[k] where k=1. + // Whereas taps come from different arrays, src pixels are drawn from the + // same contiguous line. + const uint8x8_t src[4] = { + VQTbl1U8(src_vals, src_indices[0]), VQTbl1U8(src_vals, src_indices[1]), + VQTbl1U8(src_vals, src_indices[2]), VQTbl1U8(src_vals, src_indices[3])}; + + vst1q_s16(intermediate, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/4>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate += kIntermediateStride; + } while (++y < intermediate_height); +} + +// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[0]. +inline uint8x16_t GetSigned6TapFilter(const int tap_index) { + assert(tap_index < 6); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel6TapSignedFilterColumns[6][16] = { + {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0}, + {0, 3, 5, 6, 7, 7, 8, 7, 7, 6, 6, 6, 5, 4, 2, 1}, + {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4}, + {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63}, + {0, 1, 2, 4, 5, 6, 6, 6, 7, 7, 8, 7, 7, 6, 5, 3}, + {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}; + + return vld1q_u8(kAbsHalfSubPixel6TapSignedFilterColumns[tap_index]); +} + +// This filter is only possible when width >= 8. +template <int grade_x> +inline void ConvolveKernelHorizontalSigned6Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int width, + const int subpixel_x, const int step_x, const int intermediate_height, + int16_t* intermediate) { + const int kernel_offset = 1; + const uint8x8_t one = vdup_n_u8(1); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + uint8x16_t filter_taps[6]; + for (int i = 0; i < 6; ++i) { + filter_taps[i] = GetSigned6TapFilter(i); + } + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + + int x = 0; + int p = subpixel_x; + do { + // Avoid overloading outside the reference boundaries. This means + // |trailing_width| can be up to 24. + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + uint8x8_t src_lookup[6]; + src_lookup[0] = src_indices; + for (int i = 1; i < 6; ++i) { + src_lookup[i] = vadd_u8(src_lookup[i - 1], one); + } + + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // For each x, a lane of taps[k] has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + uint8x8_t taps[6]; + for (int i = 0; i < 6; ++i) { + taps[i] = VQTbl1U8(filter_taps[i], filter_indices); + } + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + + const uint8x8_t src[6] = { + vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]), + vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]), + vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5])}; + + vst1q_s16(intermediate_x, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/0>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[1]. This filter +// has mixed positive and negative outer taps which are handled in +// GetMixed6TapFilter(). +inline uint8x16_t GetPositive6TapFilter(const int tap_index) { + assert(tap_index < 6); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel6TapPositiveFilterColumns[4][16] = { + {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1}, + {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17}, + {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31}, + {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14}}; + + return vld1q_u8(kAbsHalfSubPixel6TapPositiveFilterColumns[tap_index]); +} + +inline int8x16_t GetMixed6TapFilter(const int tap_index) { + assert(tap_index < 2); + alignas( + 16) static constexpr int8_t kHalfSubPixel6TapMixedFilterColumns[2][16] = { + {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}}; + + return vld1q_s8(kHalfSubPixel6TapMixedFilterColumns[tap_index]); +} + +// This filter is only possible when width >= 8. +template <int grade_x> +inline void ConvolveKernelHorizontalMixed6Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int width, + const int subpixel_x, const int step_x, const int intermediate_height, + int16_t* intermediate) { + const int kernel_offset = 1; + const uint8x8_t one = vdup_n_u8(1); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + uint8x8_t taps[4]; + int16x8_t mixed_taps[2]; + uint8x16_t positive_filter_taps[4]; + for (int i = 0; i < 4; ++i) { + positive_filter_taps[i] = GetPositive6TapFilter(i); + } + int8x16_t mixed_filter_taps[2]; + mixed_filter_taps[0] = GetMixed6TapFilter(0); + mixed_filter_taps[1] = GetMixed6TapFilter(1); + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + + int x = 0; + int p = subpixel_x; + do { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + uint8x8_t src_lookup[6]; + src_lookup[0] = src_indices; + for (int i = 1; i < 6; ++i) { + src_lookup[i] = vadd_u8(src_lookup[i - 1], one); + } + + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // For each x, a lane of taps[k] has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + for (int i = 0; i < 4; ++i) { + taps[i] = VQTbl1U8(positive_filter_taps[i], filter_indices); + } + mixed_taps[0] = vmovl_s8(VQTbl1S8(mixed_filter_taps[0], filter_indices)); + mixed_taps[1] = vmovl_s8(VQTbl1S8(mixed_filter_taps[1], filter_indices)); + + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + + int16x8_t sum_mixed = vmulq_s16( + mixed_taps[0], ZeroExtend(vtbl3_u8(src_vals, src_lookup[0]))); + sum_mixed = vmlaq_s16(sum_mixed, mixed_taps[1], + ZeroExtend(vtbl3_u8(src_vals, src_lookup[5]))); + uint16x8_t sum = vreinterpretq_u16_s16(sum_mixed); + sum = vmlal_u8(sum, taps[0], vtbl3_u8(src_vals, src_lookup[1])); + sum = vmlal_u8(sum, taps[1], vtbl3_u8(src_vals, src_lookup[2])); + sum = vmlal_u8(sum, taps[2], vtbl3_u8(src_vals, src_lookup[3])); + sum = vmlal_u8(sum, taps[3], vtbl3_u8(src_vals, src_lookup[4])); + + vst1q_s16(intermediate_x, vrshrq_n_s16(vreinterpretq_s16_u16(sum), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// Pre-transpose the 8 tap filters in |kAbsHalfSubPixelFilters|[2]. +inline uint8x16_t GetSigned8TapFilter(const int tap_index) { + assert(tap_index < 8); + alignas(16) static constexpr uint8_t + kAbsHalfSubPixel8TapSignedFilterColumns[8][16] = { + {0, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0}, + {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1}, + {0, 3, 6, 9, 11, 11, 12, 12, 12, 11, 10, 9, 7, 5, 3, 1}, + {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4}, + {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63}, + {0, 1, 3, 5, 7, 9, 10, 11, 12, 12, 12, 11, 11, 9, 6, 3}, + {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1}, + {0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1}}; + + return vld1q_u8(kAbsHalfSubPixel8TapSignedFilterColumns[tap_index]); +} + +// This filter is only possible when width >= 8. +template <int grade_x> +inline void ConvolveKernelHorizontalSigned8Tap( + const uint8_t* src, const ptrdiff_t src_stride, const int width, + const int subpixel_x, const int step_x, const int intermediate_height, + int16_t* intermediate) { + const uint8x8_t one = vdup_n_u8(1); + const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + uint8x8_t taps[8]; + uint8x16_t filter_taps[8]; + for (int i = 0; i < 8; ++i) { + filter_taps[i] = GetSigned8TapFilter(i); + } + const uint16x8_t index_steps = vmulq_n_u16( + vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x)); + int x = 0; + int p = subpixel_x; + do { + const uint8_t* src_x = &src[(p >> kScaleSubPixelBits) - ref_x]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const uint16x8_t p_fraction = vdupq_n_u16(p & 1023); + const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction); + const uint8x8_t src_indices = + vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)); + uint8x8_t src_lookup[8]; + src_lookup[0] = src_indices; + for (int i = 1; i < 8; ++i) { + src_lookup[i] = vadd_u8(src_lookup[i - 1], one); + } + + const uint8x8_t filter_indices = + vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), + filter_index_mask); + // For each x, a lane of taps[k] has + // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends + // on x. + for (int i = 0; i < 8; ++i) { + taps[i] = VQTbl1U8(filter_taps[i], filter_indices); + } + + int y = 0; + do { + // Load a pool of samples to select from using stepped indices. + const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x); + + const uint8x8_t src[8] = { + vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]), + vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]), + vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5]), + vtbl3_u8(src_vals, src_lookup[6]), vtbl3_u8(src_vals, src_lookup[7])}; + + vst1q_s16(intermediate_x, + vrshrq_n_s16(SumOnePassTaps</*filter_index=*/2>(src, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (++y < intermediate_height); + x += 8; + p += step_x8; + } while (x < width); +} + +// This function handles blocks of width 2 or 4. +template <int num_taps, int grade_y, int width, bool is_compound> +void ConvolveVerticalScale4xH(const int16_t* src, const int subpixel_y, + const int filter_index, const int step_y, + const int height, void* dest, + const ptrdiff_t dest_stride) { + constexpr ptrdiff_t src_stride = kIntermediateStride; + const int16_t* src_y = src; + // |dest| is 16-bit in compound mode, Pixel otherwise. + uint16_t* dest16_y = static_cast<uint16_t*>(dest); + uint8_t* dest_y = static_cast<uint8_t*>(dest); + int16x4_t s[num_taps + grade_y]; + + int p = subpixel_y & 1023; + int prev_p = p; + int y = 0; + do { // y < height + for (int i = 0; i < num_taps; ++i) { + s[i] = vld1_s16(src_y + i * src_stride); + } + int filter_id = (p >> 6) & kSubPixelMask; + int16x8_t filter = + vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + int16x4_t sums = Sum2DVerticalTaps4<num_taps, is_compound>(s, filter); + if (is_compound) { + assert(width != 2); + const uint16x4_t result = vreinterpret_u16_s16(sums); + vst1_u16(dest16_y, result); + } else { + const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums)); + if (width == 2) { + Store2<0>(dest_y, result); + } else { + StoreLo4(dest_y, result); + } + } + p += step_y; + const int p_diff = + (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits); + prev_p = p; + // Here we load extra source in case it is needed. If |p_diff| == 0, these + // values will be unused, but it's faster to load than to branch. + s[num_taps] = vld1_s16(src_y + num_taps * src_stride); + if (grade_y > 1) { + s[num_taps + 1] = vld1_s16(src_y + (num_taps + 1) * src_stride); + } + dest16_y += dest_stride; + dest_y += dest_stride; + + filter_id = (p >> 6) & kSubPixelMask; + filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + sums = Sum2DVerticalTaps4<num_taps, is_compound>(&s[p_diff], filter); + if (is_compound) { + assert(width != 2); + const uint16x4_t result = vreinterpret_u16_s16(sums); + vst1_u16(dest16_y, result); + } else { + const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums)); + if (width == 2) { + Store2<0>(dest_y, result); + } else { + StoreLo4(dest_y, result); + } + } + p += step_y; + src_y = src + (p >> kScaleSubPixelBits) * src_stride; + prev_p = p; + dest16_y += dest_stride; + dest_y += dest_stride; + + y += 2; + } while (y < height); +} + +template <int num_taps, int grade_y, bool is_compound> +inline void ConvolveVerticalScale(const int16_t* src, const int width, + const int subpixel_y, const int filter_index, + const int step_y, const int height, + void* dest, const ptrdiff_t dest_stride) { + constexpr ptrdiff_t src_stride = kIntermediateStride; + // A possible improvement is to use arithmetic to decide how many times to + // apply filters to same source before checking whether to load new srcs. + // However, this will only improve performance with very small step sizes. + int16x8_t s[num_taps + grade_y]; + // |dest| is 16-bit in compound mode, Pixel otherwise. + uint16_t* dest16_y; + uint8_t* dest_y; + + int x = 0; + do { // x < width + const int16_t* src_x = src + x; + const int16_t* src_y = src_x; + dest16_y = static_cast<uint16_t*>(dest) + x; + dest_y = static_cast<uint8_t*>(dest) + x; + int p = subpixel_y & 1023; + int prev_p = p; + int y = 0; + do { // y < height + for (int i = 0; i < num_taps; ++i) { + s[i] = vld1q_s16(src_y + i * src_stride); + } + int filter_id = (p >> 6) & kSubPixelMask; + int16x8_t filter = + vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + int16x8_t sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(s, filter); + if (is_compound) { + vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum)); + } else { + vst1_u8(dest_y, vqmovun_s16(sum)); + } + p += step_y; + const int p_diff = + (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits); + // |grade_y| > 1 always means p_diff > 0, so load vectors that may be + // needed. Otherwise, we only need to load one vector because |p_diff| + // can't exceed 1. + s[num_taps] = vld1q_s16(src_y + num_taps * src_stride); + if (grade_y > 1) { + s[num_taps + 1] = vld1q_s16(src_y + (num_taps + 1) * src_stride); + } + dest16_y += dest_stride; + dest_y += dest_stride; + + filter_id = (p >> 6) & kSubPixelMask; + filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id])); + sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(&s[p_diff], filter); + if (is_compound) { + vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum)); + } else { + vst1_u8(dest_y, vqmovun_s16(sum)); + } + p += step_y; + src_y = src_x + (p >> kScaleSubPixelBits) * src_stride; + prev_p = p; + dest16_y += dest_stride; + dest_y += dest_stride; + + y += 2; + } while (y < height); + x += 8; + } while (x < width); +} + +template <bool is_compound> +void ConvolveScale2D_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, const int subpixel_x, + const int subpixel_y, const int step_x, + const int step_y, const int width, const int height, + void* prediction, const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + assert(step_x <= 2048); + const int num_vert_taps = GetNumTapsInFilter(vert_filter_index); + const int intermediate_height = + (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >> + kScaleSubPixelBits) + + num_vert_taps; + assert(step_x <= 2048); + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + int16_t intermediate_result[kMaxSuperBlockSizeInPixels * + (2 * kMaxSuperBlockSizeInPixels + 8)]; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [3, 5]. + // Similarly for height. + int filter_index = GetFilterIndex(horizontal_filter_index, width); + int16_t* intermediate = intermediate_result; + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference); + const int vert_kernel_offset = (8 - num_vert_taps) / 2; + src += vert_kernel_offset * src_stride; + + // Derive the maximum value of |step_x| at which all source values fit in one + // 16-byte load. Final index is src_x + |num_taps| - 1 < 16 + // step_x*7 is the final base subpel index for the shuffle mask for filter + // inputs in each iteration on large blocks. When step_x is large, we need a + // larger structure and use a larger table lookup in order to gather all + // filter inputs. + // |num_taps| - 1 is the shuffle index of the final filter input. + const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index); + const int kernel_start_ceiling = 16 - num_horiz_taps; + // This truncated quotient |grade_x_threshold| selects |step_x| such that: + // (step_x * 7) >> kScaleSubPixelBits < single load limit + const int grade_x_threshold = + (kernel_start_ceiling << kScaleSubPixelBits) / 7; + switch (filter_index) { + case 0: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontalSigned6Tap<2>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } else { + ConvolveKernelHorizontalSigned6Tap<1>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } + break; + case 1: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontalMixed6Tap<2>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + + } else { + ConvolveKernelHorizontalMixed6Tap<1>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 2: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontalSigned8Tap<2>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } else { + ConvolveKernelHorizontalSigned8Tap<1>( + src, src_stride, width, subpixel_x, step_x, intermediate_height, + intermediate); + } + break; + case 3: + if (step_x > grade_x_threshold) { + ConvolveKernelHorizontal2Tap<2>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } else { + ConvolveKernelHorizontal2Tap<1>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 4: + assert(width <= 4); + ConvolveKernelHorizontalSigned4Tap(src, src_stride, subpixel_x, step_x, + intermediate_height, intermediate); + break; + default: + assert(filter_index == 5); + ConvolveKernelHorizontalPositive4Tap(src, src_stride, subpixel_x, step_x, + intermediate_height, intermediate); + } + // Vertical filter. + filter_index = GetFilterIndex(vertical_filter_index, height); + intermediate = intermediate_result; + + switch (filter_index) { + case 0: + case 1: + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<6, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<6, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<6, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<6, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<6, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<6, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + break; + case 2: + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<8, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<8, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<8, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<8, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<8, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<8, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + break; + case 3: + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<2, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<2, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<2, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<2, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<2, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<2, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + break; + case 4: + default: + assert(filter_index == 4 || filter_index == 5); + assert(height <= 4); + if (step_y <= 1024) { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<4, 1, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<4, 1, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<4, 1, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } else { + if (!is_compound && width == 2) { + ConvolveVerticalScale4xH<4, 2, 2, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale4xH<4, 2, 4, is_compound>( + intermediate, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<4, 2, is_compound>( + intermediate, width, subpixel_y, filter_index, step_y, height, + prediction, pred_stride); + } + } + } +} + +void ConvolveHorizontal_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int /*vertical_filter_index*/, + const int horizontal_filter_id, + const int /*vertical_filter_id*/, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + // Set |src| to the outermost tap. + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint8_t*>(prediction); + + DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height, + horizontal_filter_id, filter_index); +} + +// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D +// Vertical calculations. +uint16x8_t Compound1DShift(const int16x8_t sum) { + return vreinterpretq_u16_s16( + vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1)); +} + +template <int filter_index, bool is_compound = false, + bool negative_outside_taps = false> +void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int width, const int height, + const uint8x8_t* const taps) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps - 1; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + assert(width >= 8); + + int x = 0; + do { + const uint8_t* src_x = src + x; + uint8x8_t srcs[8]; + srcs[0] = vld1_u8(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = vld1_u8(src_x); + src_x += src_stride; + srcs[2] = vld1_u8(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = vld1_u8(src_x); + src_x += src_stride; + srcs[4] = vld1_u8(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = vld1_u8(src_x); + src_x += src_stride; + srcs[6] = vld1_u8(src_x); + src_x += src_stride; + } + } + } + + int y = 0; + do { + srcs[next_row] = vld1_u8(src_x); + src_x += src_stride; + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + vst1q_u16(dst16 + x + y * dst_stride, results); + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + vst1_u8(dst8 + x + y * dst_stride, results); + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (++y < height); + x += 8; + } while (x < width); +} + +template <int filter_index, bool is_compound = false, + bool negative_outside_taps = false> +void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const uint8x8_t* const taps) { + const int num_taps = GetNumTapsInFilter(filter_index); + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + uint8x8_t srcs[9]; + + if (num_taps == 2) { + srcs[2] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + + int y = 0; + do { + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4<0>(src, srcs[2]); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + y += 2; + } while (y < height); + } else if (num_taps == 4) { + srcs[4] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + + int y = 0; + do { + srcs[2] = Load4<1>(src, srcs[2]); + src += src_stride; + srcs[4] = Load4<0>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[2], srcs[4], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + y += 2; + } while (y < height); + } else if (num_taps == 6) { + srcs[6] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + srcs[2] = Load4<1>(src, srcs[2]); + src += src_stride; + srcs[4] = Load4(src); + src += src_stride; + srcs[3] = vext_u8(srcs[2], srcs[4], 4); + + int y = 0; + do { + srcs[4] = Load4<1>(src, srcs[4]); + src += src_stride; + srcs[6] = Load4<0>(src, srcs[6]); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[6], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + y += 2; + } while (y < height); + } else if (num_taps == 8) { + srcs[8] = vdup_n_u8(0); + + srcs[0] = Load4(src); + src += src_stride; + srcs[0] = Load4<1>(src, srcs[0]); + src += src_stride; + srcs[2] = Load4(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 4); + srcs[2] = Load4<1>(src, srcs[2]); + src += src_stride; + srcs[4] = Load4(src); + src += src_stride; + srcs[3] = vext_u8(srcs[2], srcs[4], 4); + srcs[4] = Load4<1>(src, srcs[4]); + src += src_stride; + srcs[6] = Load4(src); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[6], 4); + + int y = 0; + do { + srcs[6] = Load4<1>(src, srcs[6]); + src += src_stride; + srcs[8] = Load4<0>(src, srcs[8]); + src += src_stride; + srcs[7] = vext_u8(srcs[6], srcs[8], 4); + + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + if (is_compound) { + const uint16x8_t results = Compound1DShift(sums); + + vst1q_u16(dst16, results); + dst16 += 4 << 1; + } else { + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + StoreLo4(dst8, results); + dst8 += dst_stride; + StoreHi4(dst8, results); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + y += 2; + } while (y < height); + } +} + +template <int filter_index, bool negative_outside_taps = false> +void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const uint8x8_t* const taps) { + const int num_taps = GetNumTapsInFilter(filter_index); + auto* dst8 = static_cast<uint8_t*>(dst); + + uint8x8_t srcs[9]; + + if (num_taps == 2) { + srcs[2] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + + int y = 0; + do { + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[2] = Load2<0>(src, srcs[2]); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[2], 2); + + // This uses srcs[0]..srcs[1]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + if (height == 2) return; + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[2]; + y += 4; + } while (y < height); + } else if (num_taps == 4) { + srcs[4] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + + int y = 0; + do { + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[4] = Load2<0>(src, srcs[4]); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[4], 2); + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + srcs[2] = vext_u8(srcs[0], srcs[4], 4); + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[0], srcs[4], 6); + + // This uses srcs[0]..srcs[3]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + if (height == 2) return; + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + y += 4; + } while (y < height); + } else if (num_taps == 6) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. + assert(height > 4); + srcs[8] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[4] = Load2(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[4], 2); + + int y = 0; + do { + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + srcs[2] = vext_u8(srcs[0], srcs[4], 4); + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[0], srcs[4], 6); + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[8], 2); + + // This uses srcs[0]..srcs[5]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + y += 4; + } while (y < height); + } else if (num_taps == 8) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. + assert(height > 4); + srcs[8] = vdup_n_u8(0); + + srcs[0] = Load2(src); + src += src_stride; + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + srcs[4] = Load2(src); + src += src_stride; + srcs[1] = vext_u8(srcs[0], srcs[4], 2); + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + srcs[2] = vext_u8(srcs[0], srcs[4], 4); + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + srcs[3] = vext_u8(srcs[0], srcs[4], 6); + + int y = 0; + do { + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + srcs[5] = vext_u8(srcs[4], srcs[8], 2); + srcs[8] = Load2<1>(src, srcs[8]); + src += src_stride; + srcs[6] = vext_u8(srcs[4], srcs[8], 4); + srcs[8] = Load2<2>(src, srcs[8]); + src += src_stride; + srcs[7] = vext_u8(srcs[4], srcs[8], 6); + + // This uses srcs[0]..srcs[7]. + const int16x8_t sums = + SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps); + const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1); + + Store2<0>(dst8, results); + dst8 += dst_stride; + Store2<1>(dst8, results); + dst8 += dst_stride; + Store2<2>(dst8, results); + dst8 += dst_stride; + Store2<3>(dst8, results); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + y += 4; + } while (y < height); + } +} + +// This function is a simplified version of Convolve2D_C. +// It is called when it is single prediction mode, where only vertical +// filtering is required. +// The output is the single prediction of the block, clipped to valid pixel +// range. +void ConvolveVertical_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int vertical_filter_index, + const int /*horizontal_filter_id*/, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + uint8x8_t taps[8]; + for (int k = 0; k < kSubPixelTaps; ++k) { + taps[k] = + vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]); + } + + if (filter_index == 0) { // 6 tap. + if (width == 2) { + FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else if (width == 4) { + FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else { + FilterVertical<0>(src, src_stride, dest, dest_stride, width, height, + taps + 1); + } + } else if ((filter_index == 1) & ((vertical_filter_id == 1) | + (vertical_filter_id == 15))) { // 5 tap. + if (width == 2) { + FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else if (width == 4) { + FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height, + taps + 1); + } else { + FilterVertical<1>(src, src_stride, dest, dest_stride, width, height, + taps + 1); + } + } else if ((filter_index == 1) & + ((vertical_filter_id == 7) | (vertical_filter_id == 8) | + (vertical_filter_id == 9))) { // 6 tap with weird negative taps. + if (width == 2) { + FilterVertical2xH<1, + /*negative_outside_taps=*/true>( + src, src_stride, dest, dest_stride, height, taps + 1); + } else if (width == 4) { + FilterVertical4xH<1, /*is_compound=*/false, + /*negative_outside_taps=*/true>( + src, src_stride, dest, dest_stride, height, taps + 1); + } else { + FilterVertical<1, /*is_compound=*/false, /*negative_outside_taps=*/true>( + src, src_stride, dest, dest_stride, width, height, taps + 1); + } + } else if (filter_index == 2) { // 8 tap. + if (width == 2) { + FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps); + } else if (width == 4) { + FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps); + } else { + FilterVertical<2>(src, src_stride, dest, dest_stride, width, height, + taps); + } + } else if (filter_index == 3) { // 2 tap. + if (width == 2) { + FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height, + taps + 3); + } else if (width == 4) { + FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height, + taps + 3); + } else { + FilterVertical<3>(src, src_stride, dest, dest_stride, width, height, + taps + 3); + } + } else if (filter_index == 4) { // 4 tap. + // Outside taps are negative. + if (width == 2) { + FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else if (width == 4) { + FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else { + FilterVertical<4>(src, src_stride, dest, dest_stride, width, height, + taps + 2); + } + } else { + // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed + // below map to 4 tap filters. + assert(filter_index == 5 || + (filter_index == 1 && + (vertical_filter_id == 2 || vertical_filter_id == 3 || + vertical_filter_id == 4 || vertical_filter_id == 5 || + vertical_filter_id == 6 || vertical_filter_id == 10 || + vertical_filter_id == 11 || vertical_filter_id == 12 || + vertical_filter_id == 13 || vertical_filter_id == 14))); + // According to GetNumTapsInFilter() this has 6 taps but here we are + // treating it as though it has 4. + if (filter_index == 1) src += src_stride; + if (width == 2) { + FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else if (width == 4) { + FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height, + taps + 2); + } else { + FilterVertical<5>(src, src_stride, dest, dest_stride, width, height, + taps + 2); + } + } +} + +void ConvolveCompoundCopy_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const auto* src = static_cast<const uint8_t*>(reference); + const ptrdiff_t src_stride = reference_stride; + auto* dest = static_cast<uint16_t*>(prediction); + constexpr int final_shift = + kInterRoundBitsVertical - kInterRoundBitsCompoundVertical; + + if (width >= 16) { + int y = 0; + do { + int x = 0; + do { + const uint8x16_t v_src = vld1q_u8(&src[x]); + const uint16x8_t v_dest_lo = + vshll_n_u8(vget_low_u8(v_src), final_shift); + const uint16x8_t v_dest_hi = + vshll_n_u8(vget_high_u8(v_src), final_shift); + vst1q_u16(&dest[x], v_dest_lo); + x += 8; + vst1q_u16(&dest[x], v_dest_hi); + x += 8; + } while (x < width); + src += src_stride; + dest += width; + } while (++y < height); + } else if (width == 8) { + int y = 0; + do { + const uint8x8_t v_src = vld1_u8(&src[0]); + const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift); + vst1q_u16(&dest[0], v_dest); + src += src_stride; + dest += width; + } while (++y < height); + } else { /* width == 4 */ + uint8x8_t v_src = vdup_n_u8(0); + + int y = 0; + do { + v_src = Load4<0>(&src[0], v_src); + src += src_stride; + v_src = Load4<1>(&src[0], v_src); + src += src_stride; + const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift); + vst1q_u16(&dest[0], v_dest); + dest += 4 << 1; + y += 2; + } while (y < height); + } +} + +void ConvolveCompoundVertical_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int vertical_filter_index, + const int /*horizontal_filter_id*/, const int vertical_filter_id, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint16_t*>(prediction); + assert(vertical_filter_id != 0); + + uint8x8_t taps[8]; + for (int k = 0; k < kSubPixelTaps; ++k) { + taps[k] = + vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]); + } + + if (filter_index == 0) { // 6 tap. + if (width == 4) { + FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 1); + } else { + FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 1); + } + } else if ((filter_index == 1) & ((vertical_filter_id == 1) | + (vertical_filter_id == 15))) { // 5 tap. + if (width == 4) { + FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 1); + } else { + FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 1); + } + } else if ((filter_index == 1) & + ((vertical_filter_id == 7) | (vertical_filter_id == 8) | + (vertical_filter_id == 9))) { // 6 tap with weird negative taps. + if (width == 4) { + FilterVertical4xH<1, /*is_compound=*/true, + /*negative_outside_taps=*/true>(src, src_stride, dest, + 4, height, taps + 1); + } else { + FilterVertical<1, /*is_compound=*/true, /*negative_outside_taps=*/true>( + src, src_stride, dest, width, width, height, taps + 1); + } + } else if (filter_index == 2) { // 8 tap. + if (width == 4) { + FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); + } else { + FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps); + } + } else if (filter_index == 3) { // 2 tap. + if (width == 4) { + FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 3); + } else { + FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 3); + } + } else if (filter_index == 4) { // 4 tap. + if (width == 4) { + FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 2); + } else { + FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 2); + } + } else { + // 4 tap. When |filter_index| == 1 the |filter_id| values listed below map + // to 4 tap filters. + assert(filter_index == 5 || + (filter_index == 1 && + (vertical_filter_id == 2 || vertical_filter_id == 3 || + vertical_filter_id == 4 || vertical_filter_id == 5 || + vertical_filter_id == 6 || vertical_filter_id == 10 || + vertical_filter_id == 11 || vertical_filter_id == 12 || + vertical_filter_id == 13 || vertical_filter_id == 14))); + // According to GetNumTapsInFilter() this has 6 taps but here we are + // treating it as though it has 4. + if (filter_index == 1) src += src_stride; + if (width == 4) { + FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps + 2); + } else { + FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps + 2); + } + } +} + +void ConvolveCompoundHorizontal_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int horizontal_filter_index, const int /*vertical_filter_index*/, + const int horizontal_filter_id, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint16_t*>(prediction); + + DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>( + src, reference_stride, dest, width, width, height, horizontal_filter_id, + filter_index); +} + +void ConvolveCompound2D_NEON(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [4, 5]. + // Similarly for height. + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + const int intermediate_height = height + vertical_taps - 1; + const ptrdiff_t src_stride = reference_stride; + const auto* const src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - + kHorizontalOffset; + + DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>( + src, src_stride, intermediate_result, width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. + auto* dest = static_cast<uint16_t*>(prediction); + assert(vertical_filter_id != 0); + + const ptrdiff_t dest_stride = width; + const int16x8_t taps = vmovl_s8( + vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id])); + + if (vertical_taps == 8) { + if (width == 4) { + Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<8, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 6) { + if (width == 4) { + Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<6, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 4) { + if (width == 4) { + Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<4, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else { // |vertical_taps| == 2 + if (width == 4) { + Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<2, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } +} + +inline void HalfAddHorizontal(const uint8_t* src, uint8_t* dst) { + const uint8x16_t left = vld1q_u8(src); + const uint8x16_t right = vld1q_u8(src + 1); + vst1q_u8(dst, vrhaddq_u8(left, right)); +} + +template <int width> +inline void IntraBlockCopyHorizontal(const uint8_t* src, + const ptrdiff_t src_stride, + const int height, uint8_t* dst, + const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 16); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16); + + int y = 0; + do { + HalfAddHorizontal(src, dst); + if (width >= 32) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + if (width >= 64) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + if (width == 128) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + } + } + } + src += src_remainder_stride; + dst += dst_remainder_stride; + } while (++y < height); +} + +void ConvolveIntraBlockCopyHorizontal_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*subpixel_x*/, const int /*subpixel_y*/, const int width, + const int height, void* const prediction, const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + + if (width == 128) { + IntraBlockCopyHorizontal<128>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 64) { + IntraBlockCopyHorizontal<64>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 32) { + IntraBlockCopyHorizontal<32>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 16) { + IntraBlockCopyHorizontal<16>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 8) { + int y = 0; + do { + const uint8x8_t left = vld1_u8(src); + const uint8x8_t right = vld1_u8(src + 1); + vst1_u8(dest, vrhadd_u8(left, right)); + + src += reference_stride; + dest += pred_stride; + } while (++y < height); + } else if (width == 4) { + uint8x8_t left = vdup_n_u8(0); + uint8x8_t right = vdup_n_u8(0); + int y = 0; + do { + left = Load4<0>(src, left); + right = Load4<0>(src + 1, right); + src += reference_stride; + left = Load4<1>(src, left); + right = Load4<1>(src + 1, right); + src += reference_stride; + + const uint8x8_t result = vrhadd_u8(left, right); + + StoreLo4(dest, result); + dest += pred_stride; + StoreHi4(dest, result); + dest += pred_stride; + y += 2; + } while (y < height); + } else { + assert(width == 2); + uint8x8_t left = vdup_n_u8(0); + uint8x8_t right = vdup_n_u8(0); + int y = 0; + do { + left = Load2<0>(src, left); + right = Load2<0>(src + 1, right); + src += reference_stride; + left = Load2<1>(src, left); + right = Load2<1>(src + 1, right); + src += reference_stride; + + const uint8x8_t result = vrhadd_u8(left, right); + + Store2<0>(dest, result); + dest += pred_stride; + Store2<1>(dest, result); + dest += pred_stride; + y += 2; + } while (y < height); + } +} + +template <int width> +inline void IntraBlockCopyVertical(const uint8_t* src, + const ptrdiff_t src_stride, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 16); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16); + uint8x16_t row[8], below[8]; + + row[0] = vld1q_u8(src); + if (width >= 32) { + src += 16; + row[1] = vld1q_u8(src); + if (width >= 64) { + src += 16; + row[2] = vld1q_u8(src); + src += 16; + row[3] = vld1q_u8(src); + if (width == 128) { + src += 16; + row[4] = vld1q_u8(src); + src += 16; + row[5] = vld1q_u8(src); + src += 16; + row[6] = vld1q_u8(src); + src += 16; + row[7] = vld1q_u8(src); + } + } + } + src += src_remainder_stride; + + int y = 0; + do { + below[0] = vld1q_u8(src); + if (width >= 32) { + src += 16; + below[1] = vld1q_u8(src); + if (width >= 64) { + src += 16; + below[2] = vld1q_u8(src); + src += 16; + below[3] = vld1q_u8(src); + if (width == 128) { + src += 16; + below[4] = vld1q_u8(src); + src += 16; + below[5] = vld1q_u8(src); + src += 16; + below[6] = vld1q_u8(src); + src += 16; + below[7] = vld1q_u8(src); + } + } + } + src += src_remainder_stride; + + vst1q_u8(dst, vrhaddq_u8(row[0], below[0])); + row[0] = below[0]; + if (width >= 32) { + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[1], below[1])); + row[1] = below[1]; + if (width >= 64) { + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[2], below[2])); + row[2] = below[2]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[3], below[3])); + row[3] = below[3]; + if (width >= 128) { + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[4], below[4])); + row[4] = below[4]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[5], below[5])); + row[5] = below[5]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[6], below[6])); + row[6] = below[6]; + dst += 16; + vst1q_u8(dst, vrhaddq_u8(row[7], below[7])); + row[7] = below[7]; + } + } + } + dst += dst_remainder_stride; + } while (++y < height); +} + +void ConvolveIntraBlockCopyVertical_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* const prediction, + const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + + if (width == 128) { + IntraBlockCopyVertical<128>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 64) { + IntraBlockCopyVertical<64>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 32) { + IntraBlockCopyVertical<32>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 16) { + IntraBlockCopyVertical<16>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 8) { + uint8x8_t row, below; + row = vld1_u8(src); + src += reference_stride; + + int y = 0; + do { + below = vld1_u8(src); + src += reference_stride; + + vst1_u8(dest, vrhadd_u8(row, below)); + dest += pred_stride; + + row = below; + } while (++y < height); + } else if (width == 4) { + uint8x8_t row = Load4(src); + uint8x8_t below = vdup_n_u8(0); + src += reference_stride; + + int y = 0; + do { + below = Load4<0>(src, below); + src += reference_stride; + + StoreLo4(dest, vrhadd_u8(row, below)); + dest += pred_stride; + + row = below; + } while (++y < height); + } else { + assert(width == 2); + uint8x8_t row = Load2(src); + uint8x8_t below = vdup_n_u8(0); + src += reference_stride; + + int y = 0; + do { + below = Load2<0>(src, below); + src += reference_stride; + + Store2<0>(dest, vrhadd_u8(row, below)); + dest += pred_stride; + + row = below; + } while (++y < height); + } +} + +template <int width> +inline void IntraBlockCopy2D(const uint8_t* src, const ptrdiff_t src_stride, + const int height, uint8_t* dst, + const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 8); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8); + uint16x8_t row[16]; + row[0] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width >= 16) { + src += 8; + row[1] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width >= 32) { + src += 8; + row[2] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[3] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width >= 64) { + src += 8; + row[4] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[5] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[6] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[7] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + if (width == 128) { + src += 8; + row[8] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[9] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[10] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[11] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[12] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[13] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[14] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + src += 8; + row[15] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + } + } + } + } + src += src_remainder_stride; + + int y = 0; + do { + const uint16x8_t below_0 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[0], below_0), 2)); + row[0] = below_0; + if (width >= 16) { + src += 8; + dst += 8; + + const uint16x8_t below_1 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[1], below_1), 2)); + row[1] = below_1; + if (width >= 32) { + src += 8; + dst += 8; + + const uint16x8_t below_2 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[2], below_2), 2)); + row[2] = below_2; + src += 8; + dst += 8; + + const uint16x8_t below_3 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[3], below_3), 2)); + row[3] = below_3; + if (width >= 64) { + src += 8; + dst += 8; + + const uint16x8_t below_4 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[4], below_4), 2)); + row[4] = below_4; + src += 8; + dst += 8; + + const uint16x8_t below_5 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[5], below_5), 2)); + row[5] = below_5; + src += 8; + dst += 8; + + const uint16x8_t below_6 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[6], below_6), 2)); + row[6] = below_6; + src += 8; + dst += 8; + + const uint16x8_t below_7 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[7], below_7), 2)); + row[7] = below_7; + if (width == 128) { + src += 8; + dst += 8; + + const uint16x8_t below_8 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[8], below_8), 2)); + row[8] = below_8; + src += 8; + dst += 8; + + const uint16x8_t below_9 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[9], below_9), 2)); + row[9] = below_9; + src += 8; + dst += 8; + + const uint16x8_t below_10 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[10], below_10), 2)); + row[10] = below_10; + src += 8; + dst += 8; + + const uint16x8_t below_11 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[11], below_11), 2)); + row[11] = below_11; + src += 8; + dst += 8; + + const uint16x8_t below_12 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[12], below_12), 2)); + row[12] = below_12; + src += 8; + dst += 8; + + const uint16x8_t below_13 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[13], below_13), 2)); + row[13] = below_13; + src += 8; + dst += 8; + + const uint16x8_t below_14 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[14], below_14), 2)); + row[14] = below_14; + src += 8; + dst += 8; + + const uint16x8_t below_15 = + vaddl_u8(vld1_u8(src), vld1_u8(src + 1)); + vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[15], below_15), 2)); + row[15] = below_15; + } + } + } + } + src += src_remainder_stride; + dst += dst_remainder_stride; + } while (++y < height); +} + +void ConvolveIntraBlockCopy2D_NEON( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* const prediction, + const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + // Note: allow vertical access to height + 1. Because this function is only + // for u/v plane of intra block copy, such access is guaranteed to be within + // the prediction block. + + if (width == 128) { + IntraBlockCopy2D<128>(src, reference_stride, height, dest, pred_stride); + } else if (width == 64) { + IntraBlockCopy2D<64>(src, reference_stride, height, dest, pred_stride); + } else if (width == 32) { + IntraBlockCopy2D<32>(src, reference_stride, height, dest, pred_stride); + } else if (width == 16) { + IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride); + } else if (width == 8) { + IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride); + } else if (width == 4) { + uint8x8_t left = Load4(src); + uint8x8_t right = Load4(src + 1); + src += reference_stride; + + uint16x4_t row = vget_low_u16(vaddl_u8(left, right)); + + int y = 0; + do { + left = Load4<0>(src, left); + right = Load4<0>(src + 1, right); + src += reference_stride; + left = Load4<1>(src, left); + right = Load4<1>(src + 1, right); + src += reference_stride; + + const uint16x8_t below = vaddl_u8(left, right); + + const uint8x8_t result = vrshrn_n_u16( + vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2); + StoreLo4(dest, result); + dest += pred_stride; + StoreHi4(dest, result); + dest += pred_stride; + + row = vget_high_u16(below); + y += 2; + } while (y < height); + } else { + uint8x8_t left = Load2(src); + uint8x8_t right = Load2(src + 1); + src += reference_stride; + + uint16x4_t row = vget_low_u16(vaddl_u8(left, right)); + + int y = 0; + do { + left = Load2<0>(src, left); + right = Load2<0>(src + 1, right); + src += reference_stride; + left = Load2<2>(src, left); + right = Load2<2>(src + 1, right); + src += reference_stride; + + const uint16x8_t below = vaddl_u8(left, right); + + const uint8x8_t result = vrshrn_n_u16( + vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2); + Store2<0>(dest, result); + dest += pred_stride; + Store2<2>(dest, result); + dest += pred_stride; + + row = vget_high_u16(below); + y += 2; + } while (y < height); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON; + dsp->convolve[0][0][1][0] = ConvolveVertical_NEON; + dsp->convolve[0][0][1][1] = Convolve2D_NEON; + + dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON; + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON; + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON; + dsp->convolve[0][1][1][1] = ConvolveCompound2D_NEON; + + dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_NEON; + dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_NEON; + dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_NEON; + + dsp->convolve_scale[0] = ConvolveScale2D_NEON<false>; + dsp->convolve_scale[1] = ConvolveScale2D_NEON<true>; +} + +} // namespace +} // namespace low_bitdepth + +void ConvolveInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void ConvolveInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/convolve_neon.h b/src/dsp/arm/convolve_neon.h new file mode 100644 index 0000000..948ef4d --- /dev/null +++ b/src/dsp/arm/convolve_neon.h @@ -0,0 +1,50 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::convolve. This function is not thread-safe. +void ConvolveInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyHorizontal LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy2D LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_ConvolveScale2D LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_ diff --git a/src/dsp/arm/distance_weighted_blend_neon.cc b/src/dsp/arm/distance_weighted_blend_neon.cc new file mode 100644 index 0000000..04952ab --- /dev/null +++ b/src/dsp/arm/distance_weighted_blend_neon.cc @@ -0,0 +1,203 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/distance_weighted_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kInterPostRoundBit = 4; + +inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0, + const int16x8_t pred1, + const int16x4_t weights[2]) { + // TODO(https://issuetracker.google.com/issues/150325685): Investigate range. + const int32x4_t wpred0_lo = vmull_s16(weights[0], vget_low_s16(pred0)); + const int32x4_t wpred0_hi = vmull_s16(weights[0], vget_high_s16(pred0)); + const int32x4_t blended_lo = + vmlal_s16(wpred0_lo, weights[1], vget_low_s16(pred1)); + const int32x4_t blended_hi = + vmlal_s16(wpred0_hi, weights[1], vget_high_s16(pred1)); + + return vcombine_s16(vqrshrn_n_s32(blended_lo, kInterPostRoundBit + 4), + vqrshrn_n_s32(blended_hi, kInterPostRoundBit + 4)); +} + +template <int width, int height> +inline void DistanceWeightedBlendSmall_NEON(const int16_t* prediction_0, + const int16_t* prediction_1, + const int16x4_t weights[2], + void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + constexpr int step = 16 / width; + + for (int y = 0; y < height; y += step) { + const int16x8_t src_00 = vld1q_s16(prediction_0); + const int16x8_t src_10 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights); + + const int16x8_t src_01 = vld1q_s16(prediction_0); + const int16x8_t src_11 = vld1q_s16(prediction_1); + prediction_0 += 8; + prediction_1 += 8; + const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights); + + const uint8x8_t result0 = vqmovun_s16(res0); + const uint8x8_t result1 = vqmovun_s16(res1); + if (width == 4) { + StoreLo4(dst, result0); + dst += dest_stride; + StoreHi4(dst, result0); + dst += dest_stride; + StoreLo4(dst, result1); + dst += dest_stride; + StoreHi4(dst, result1); + dst += dest_stride; + } else { + assert(width == 8); + vst1_u8(dst, result0); + dst += dest_stride; + vst1_u8(dst, result1); + dst += dest_stride; + } + } +} + +inline void DistanceWeightedBlendLarge_NEON(const int16_t* prediction_0, + const int16_t* prediction_1, + const int16x4_t weights[2], + const int width, const int height, + void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + + int y = height; + do { + int x = 0; + do { + const int16x8_t src0_lo = vld1q_s16(prediction_0 + x); + const int16x8_t src1_lo = vld1q_s16(prediction_1 + x); + const int16x8_t res_lo = + ComputeWeightedAverage8(src0_lo, src1_lo, weights); + + const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8); + const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8); + const int16x8_t res_hi = + ComputeWeightedAverage8(src0_hi, src1_hi, weights); + + const uint8x16_t result = + vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi)); + vst1q_u8(dst + x, result); + x += 16; + } while (x < width); + dst += dest_stride; + prediction_0 += width; + prediction_1 += width; + } while (--y != 0); +} + +inline void DistanceWeightedBlend_NEON(const void* prediction_0, + const void* prediction_1, + const uint8_t weight_0, + const uint8_t weight_1, const int width, + const int height, void* const dest, + const ptrdiff_t dest_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)}; + // TODO(johannkoenig): Investigate the branching. May be fine to call with a + // variable height. + if (width == 4) { + if (height == 4) { + DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest, + dest_stride); + } else if (height == 8) { + DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest, + dest_stride); + } else { + assert(height == 16); + DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest, + dest_stride); + } + return; + } + + if (width == 8) { + switch (height) { + case 4: + DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest, + dest_stride); + return; + case 8: + DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest, + dest_stride); + return; + case 16: + DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest, + dest_stride); + return; + default: + assert(height == 32); + DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest, + dest_stride); + + return; + } + } + + DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest, + dest_stride); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->distance_weighted_blend = DistanceWeightedBlend_NEON; +} + +} // namespace + +void DistanceWeightedBlendInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void DistanceWeightedBlendInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/distance_weighted_blend_neon.h b/src/dsp/arm/distance_weighted_blend_neon.h new file mode 100644 index 0000000..4d8824c --- /dev/null +++ b/src/dsp/arm/distance_weighted_blend_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::distance_weighted_blend. This function is not thread-safe. +void DistanceWeightedBlendInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +// If NEON is enabled signal the NEON implementation should be used instead of +// normal C. +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_ diff --git a/src/dsp/arm/film_grain_neon.cc b/src/dsp/arm/film_grain_neon.cc new file mode 100644 index 0000000..2612466 --- /dev/null +++ b/src/dsp/arm/film_grain_neon.cc @@ -0,0 +1,1188 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/film_grain.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <new> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/arm/film_grain_neon.h" +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/film_grain_common.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" +#include "src/utils/logging.h" + +namespace libgav1 { +namespace dsp { +namespace film_grain { +namespace { + +// These functions are overloaded for both possible sizes in order to simplify +// loading and storing to and from intermediate value types from within a +// template function. +inline int16x8_t GetSignedSource8(const int8_t* src) { + return vmovl_s8(vld1_s8(src)); +} + +inline int16x8_t GetSignedSource8(const uint8_t* src) { + return ZeroExtend(vld1_u8(src)); +} + +inline void StoreUnsigned8(uint8_t* dest, const uint16x8_t data) { + vst1_u8(dest, vmovn_u16(data)); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +inline int16x8_t GetSignedSource8(const int16_t* src) { return vld1q_s16(src); } + +inline int16x8_t GetSignedSource8(const uint16_t* src) { + return vreinterpretq_s16_u16(vld1q_u16(src)); +} + +inline void StoreUnsigned8(uint16_t* dest, const uint16x8_t data) { + vst1q_u16(dest, data); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +// Each element in |sum| represents one destination value's running +// autoregression formula. The fixed source values in |grain_lo| and |grain_hi| +// allow for a sliding window in successive calls to this function. +template <int position_offset> +inline int32x4x2_t AccumulateWeightedGrain(const int16x8_t grain_lo, + const int16x8_t grain_hi, + int16_t coeff, int32x4x2_t sum) { + const int16x8_t grain = vextq_s16(grain_lo, grain_hi, position_offset); + sum.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(grain), coeff); + sum.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(grain), coeff); + return sum; +} + +// Because the autoregressive filter requires the output of each pixel to +// compute pixels that come after in the row, we have to finish the calculations +// one at a time. +template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegression(int8_t* grain_cursor, int32x4x2_t sum, + const int8_t* coeffs, int pos, int shift) { + int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3); + + for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) { + result += grain_cursor[lane + delta_col] * coeffs[pos]; + ++pos; + } + grain_cursor[lane] = + Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift), + GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>()); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegression(int16_t* grain_cursor, int32x4x2_t sum, + const int8_t* coeffs, int pos, int shift) { + int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3); + + for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) { + result += grain_cursor[lane + delta_col] * coeffs[pos]; + ++pos; + } + grain_cursor[lane] = + Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift), + GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>()); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +// Because the autoregressive filter requires the output of each pixel to +// compute pixels that come after in the row, we have to finish the calculations +// one at a time. +template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegressionChroma(int8_t* u_grain_cursor, + int8_t* v_grain_cursor, + int32x4x2_t sum_u, int32x4x2_t sum_v, + const int8_t* coeffs_u, + const int8_t* coeffs_v, int pos, + int shift) { + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + u_grain_cursor, sum_u, coeffs_u, pos, shift); + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + v_grain_cursor, sum_v, coeffs_v, pos, shift); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +template <int bitdepth, int auto_regression_coeff_lag, int lane> +inline void WriteFinalAutoRegressionChroma(int16_t* u_grain_cursor, + int16_t* v_grain_cursor, + int32x4x2_t sum_u, int32x4x2_t sum_v, + const int8_t* coeffs_u, + const int8_t* coeffs_v, int pos, + int shift) { + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + u_grain_cursor, sum_u, coeffs_u, pos, shift); + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( + v_grain_cursor, sum_v, coeffs_v, pos, shift); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +inline void SetZero(int32x4x2_t* v) { + v->val[0] = vdupq_n_s32(0); + v->val[1] = vdupq_n_s32(0); +} + +// Computes subsampled luma for use with chroma, by averaging in the x direction +// or y direction when applicable. +int16x8_t GetSubsampledLuma(const int8_t* const luma, int subsampling_x, + int subsampling_y, ptrdiff_t stride) { + if (subsampling_y != 0) { + assert(subsampling_x != 0); + const int8x16_t src0 = vld1q_s8(luma); + const int8x16_t src1 = vld1q_s8(luma + stride); + const int16x8_t ret0 = vcombine_s16(vpaddl_s8(vget_low_s8(src0)), + vpaddl_s8(vget_high_s8(src0))); + const int16x8_t ret1 = vcombine_s16(vpaddl_s8(vget_low_s8(src1)), + vpaddl_s8(vget_high_s8(src1))); + return vrshrq_n_s16(vaddq_s16(ret0, ret1), 2); + } + if (subsampling_x != 0) { + const int8x16_t src = vld1q_s8(luma); + return vrshrq_n_s16( + vcombine_s16(vpaddl_s8(vget_low_s8(src)), vpaddl_s8(vget_high_s8(src))), + 1); + } + return vmovl_s8(vld1_s8(luma)); +} + +// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed. +inline uint16x8_t GetAverageLuma(const uint8_t* const luma, int subsampling_x) { + if (subsampling_x != 0) { + const uint8x16_t src = vld1q_u8(luma); + return vrshrq_n_u16(vpaddlq_u8(src), 1); + } + return vmovl_u8(vld1_u8(luma)); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +// Computes subsampled luma for use with chroma, by averaging in the x direction +// or y direction when applicable. +int16x8_t GetSubsampledLuma(const int16_t* const luma, int subsampling_x, + int subsampling_y, ptrdiff_t stride) { + if (subsampling_y != 0) { + assert(subsampling_x != 0); + int16x8_t src0_lo = vld1q_s16(luma); + int16x8_t src0_hi = vld1q_s16(luma + 8); + const int16x8_t src1_lo = vld1q_s16(luma + stride); + const int16x8_t src1_hi = vld1q_s16(luma + stride + 8); + const int16x8_t src0 = + vcombine_s16(vpadd_s16(vget_low_s16(src0_lo), vget_high_s16(src0_lo)), + vpadd_s16(vget_low_s16(src0_hi), vget_high_s16(src0_hi))); + const int16x8_t src1 = + vcombine_s16(vpadd_s16(vget_low_s16(src1_lo), vget_high_s16(src1_lo)), + vpadd_s16(vget_low_s16(src1_hi), vget_high_s16(src1_hi))); + return vrshrq_n_s16(vaddq_s16(src0, src1), 2); + } + if (subsampling_x != 0) { + const int16x8_t src_lo = vld1q_s16(luma); + const int16x8_t src_hi = vld1q_s16(luma + 8); + const int16x8_t ret = + vcombine_s16(vpadd_s16(vget_low_s16(src_lo), vget_high_s16(src_lo)), + vpadd_s16(vget_low_s16(src_hi), vget_high_s16(src_hi))); + return vrshrq_n_s16(ret, 1); + } + return vld1q_s16(luma); +} + +// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed. +inline uint16x8_t GetAverageLuma(const uint16_t* const luma, + int subsampling_x) { + if (subsampling_x != 0) { + const uint16x8x2_t src = vld2q_u16(luma); + return vrhaddq_u16(src.val[0], src.val[1]); + } + return vld1q_u16(luma); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +template <int bitdepth, typename GrainType, int auto_regression_coeff_lag, + bool use_luma> +void ApplyAutoRegressiveFilterToChromaGrains_NEON(const FilmGrainParams& params, + const void* luma_grain_buffer, + int subsampling_x, + int subsampling_y, + void* u_grain_buffer, + void* v_grain_buffer) { + static_assert(auto_regression_coeff_lag <= 3, "Invalid autoregression lag."); + const auto* luma_grain = static_cast<const GrainType*>(luma_grain_buffer); + auto* u_grain = static_cast<GrainType*>(u_grain_buffer); + auto* v_grain = static_cast<GrainType*>(v_grain_buffer); + const int auto_regression_shift = params.auto_regression_shift; + const int chroma_width = + (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth; + const int chroma_height = + (subsampling_y == 0) ? kMaxChromaHeight : kMinChromaHeight; + // When |chroma_width| == 44, we write 8 at a time from x in [3, 34], + // leaving [35, 40] to write at the end. + const int chroma_width_remainder = + (chroma_width - 2 * kAutoRegressionBorder) & 7; + + int y = kAutoRegressionBorder; + luma_grain += kLumaWidth * y; + u_grain += chroma_width * y; + v_grain += chroma_width * y; + do { + // Each row is computed 8 values at a time in the following loop. At the + // end of the loop, 4 values remain to write. They are given a special + // reduced iteration at the end. + int x = kAutoRegressionBorder; + int luma_x = kAutoRegressionBorder; + do { + int pos = 0; + int32x4x2_t sum_u; + int32x4x2_t sum_v; + SetZero(&sum_u); + SetZero(&sum_v); + + if (auto_regression_coeff_lag > 0) { + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + // These loads may overflow to the next row, but they are never called + // on the final row of a grain block. Therefore, they will never + // exceed the block boundaries. + // Note: this could be slightly optimized to a single load in 8bpp, + // but requires making a special first iteration and accumulate + // function that takes an int8x16_t. + const int16x8_t u_grain_lo = + GetSignedSource8(u_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag); + const int16x8_t u_grain_hi = + GetSignedSource8(u_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); + const int16x8_t v_grain_lo = + GetSignedSource8(v_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag); + const int16x8_t v_grain_hi = + GetSignedSource8(v_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); +#define ACCUMULATE_WEIGHTED_GRAIN(offset) \ + sum_u = AccumulateWeightedGrain<offset>( \ + u_grain_lo, u_grain_hi, params.auto_regression_coeff_u[pos], sum_u); \ + sum_v = AccumulateWeightedGrain<offset>( \ + v_grain_lo, v_grain_hi, params.auto_regression_coeff_v[pos++], sum_v) + + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + } + + if (use_luma) { + const int16x8_t luma = GetSubsampledLuma( + luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth); + + // Luma samples get the final coefficient in the formula, but are best + // computed all at once before the final row. + const int coeff_u = + params.auto_regression_coeff_u[pos + auto_regression_coeff_lag]; + const int coeff_v = + params.auto_regression_coeff_v[pos + auto_regression_coeff_lag]; + + sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u); + sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u); + sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v); + sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v); + } + // At this point in the filter, the source addresses and destination + // addresses overlap. Because this is an auto-regressive filter, the + // higher lanes cannot be computed without the results of the lower lanes. + // Each call to WriteFinalAutoRegression incorporates preceding values + // on the final row, and writes a single sample. This allows the next + // pixel's value to be computed in the next call. +#define WRITE_AUTO_REGRESSION_RESULT(lane) \ + WriteFinalAutoRegressionChroma<bitdepth, auto_regression_coeff_lag, lane>( \ + u_grain + x, v_grain + x, sum_u, sum_v, params.auto_regression_coeff_u, \ + params.auto_regression_coeff_v, pos, auto_regression_shift) + + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + WRITE_AUTO_REGRESSION_RESULT(4); + WRITE_AUTO_REGRESSION_RESULT(5); + WRITE_AUTO_REGRESSION_RESULT(6); + WRITE_AUTO_REGRESSION_RESULT(7); + + x += 8; + luma_x += 8 << subsampling_x; + } while (x < chroma_width - kAutoRegressionBorder - chroma_width_remainder); + + // This is the "final iteration" of the above loop over width. We fill in + // the remainder of the width, which is less than 8. + int pos = 0; + int32x4x2_t sum_u; + int32x4x2_t sum_v; + SetZero(&sum_u); + SetZero(&sum_v); + + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + // These loads may overflow to the next row, but they are never called on + // the final row of a grain block. Therefore, they will never exceed the + // block boundaries. + const int16x8_t u_grain_lo = GetSignedSource8( + u_grain + x + delta_row * chroma_width - auto_regression_coeff_lag); + const int16x8_t u_grain_hi = + GetSignedSource8(u_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); + const int16x8_t v_grain_lo = GetSignedSource8( + v_grain + x + delta_row * chroma_width - auto_regression_coeff_lag); + const int16x8_t v_grain_hi = + GetSignedSource8(v_grain + x + delta_row * chroma_width - + auto_regression_coeff_lag + 8); + + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + + if (use_luma) { + const int16x8_t luma = GetSubsampledLuma( + luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth); + + // Luma samples get the final coefficient in the formula, but are best + // computed all at once before the final row. + const int coeff_u = + params.auto_regression_coeff_u[pos + auto_regression_coeff_lag]; + const int coeff_v = + params.auto_regression_coeff_v[pos + auto_regression_coeff_lag]; + + sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u); + sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u); + sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v); + sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v); + } + + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + if (chroma_width_remainder == 6) { + WRITE_AUTO_REGRESSION_RESULT(4); + WRITE_AUTO_REGRESSION_RESULT(5); + } + + luma_grain += kLumaWidth << subsampling_y; + u_grain += chroma_width; + v_grain += chroma_width; + } while (++y < chroma_height); +#undef ACCUMULATE_WEIGHTED_GRAIN +#undef WRITE_AUTO_REGRESSION_RESULT +} + +// Applies an auto-regressive filter to the white noise in luma_grain. +template <int bitdepth, typename GrainType, int auto_regression_coeff_lag> +void ApplyAutoRegressiveFilterToLumaGrain_NEON(const FilmGrainParams& params, + void* luma_grain_buffer) { + static_assert(auto_regression_coeff_lag > 0, ""); + const int8_t* const auto_regression_coeff_y = params.auto_regression_coeff_y; + const uint8_t auto_regression_shift = params.auto_regression_shift; + + int y = kAutoRegressionBorder; + auto* luma_grain = + static_cast<GrainType*>(luma_grain_buffer) + kLumaWidth * y; + do { + // Each row is computed 8 values at a time in the following loop. At the + // end of the loop, 4 values remain to write. They are given a special + // reduced iteration at the end. + int x = kAutoRegressionBorder; + do { + int pos = 0; + int32x4x2_t sum; + SetZero(&sum); + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + // These loads may overflow to the next row, but they are never called + // on the final row of a grain block. Therefore, they will never exceed + // the block boundaries. + const int16x8_t src_grain_lo = + GetSignedSource8(luma_grain + x + delta_row * kLumaWidth - + auto_regression_coeff_lag); + const int16x8_t src_grain_hi = + GetSignedSource8(luma_grain + x + delta_row * kLumaWidth - + auto_regression_coeff_lag + 8); + + // A pictorial representation of the auto-regressive filter for + // various values of params.auto_regression_coeff_lag. The letter 'O' + // represents the current sample. (The filter always operates on the + // current sample with filter coefficient 1.) The letters 'X' + // represent the neighboring samples that the filter operates on, below + // their corresponding "offset" number. + // + // params.auto_regression_coeff_lag == 3: + // 0 1 2 3 4 5 6 + // X X X X X X X + // X X X X X X X + // X X X X X X X + // X X X O + // params.auto_regression_coeff_lag == 2: + // 0 1 2 3 4 + // X X X X X + // X X X X X + // X X O + // params.auto_regression_coeff_lag == 1: + // 0 1 2 + // X X X + // X O + // params.auto_regression_coeff_lag == 0: + // O + // The function relies on the caller to skip the call in the 0 lag + // case. + +#define ACCUMULATE_WEIGHTED_GRAIN(offset) \ + sum = AccumulateWeightedGrain<offset>(src_grain_lo, src_grain_hi, \ + auto_regression_coeff_y[pos++], sum) + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + // At this point in the filter, the source addresses and destination + // addresses overlap. Because this is an auto-regressive filter, the + // higher lanes cannot be computed without the results of the lower lanes. + // Each call to WriteFinalAutoRegression incorporates preceding values + // on the final row, and writes a single sample. This allows the next + // pixel's value to be computed in the next call. +#define WRITE_AUTO_REGRESSION_RESULT(lane) \ + WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( \ + luma_grain + x, sum, auto_regression_coeff_y, pos, \ + auto_regression_shift) + + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + WRITE_AUTO_REGRESSION_RESULT(4); + WRITE_AUTO_REGRESSION_RESULT(5); + WRITE_AUTO_REGRESSION_RESULT(6); + WRITE_AUTO_REGRESSION_RESULT(7); + x += 8; + // Leave the final four pixels for the special iteration below. + } while (x < kLumaWidth - kAutoRegressionBorder - 4); + + // Final 4 pixels in the row. + int pos = 0; + int32x4x2_t sum; + SetZero(&sum); + for (int delta_row = -auto_regression_coeff_lag; delta_row < 0; + ++delta_row) { + const int16x8_t src_grain_lo = GetSignedSource8( + luma_grain + x + delta_row * kLumaWidth - auto_regression_coeff_lag); + const int16x8_t src_grain_hi = + GetSignedSource8(luma_grain + x + delta_row * kLumaWidth - + auto_regression_coeff_lag + 8); + + ACCUMULATE_WEIGHTED_GRAIN(0); + ACCUMULATE_WEIGHTED_GRAIN(1); + ACCUMULATE_WEIGHTED_GRAIN(2); + // The horizontal |auto_regression_coeff_lag| loop is replaced with + // if-statements to give vextq_s16 an immediate param. + if (auto_regression_coeff_lag > 1) { + ACCUMULATE_WEIGHTED_GRAIN(3); + ACCUMULATE_WEIGHTED_GRAIN(4); + } + if (auto_regression_coeff_lag > 2) { + assert(auto_regression_coeff_lag == 3); + ACCUMULATE_WEIGHTED_GRAIN(5); + ACCUMULATE_WEIGHTED_GRAIN(6); + } + } + // delta_row == 0 + WRITE_AUTO_REGRESSION_RESULT(0); + WRITE_AUTO_REGRESSION_RESULT(1); + WRITE_AUTO_REGRESSION_RESULT(2); + WRITE_AUTO_REGRESSION_RESULT(3); + luma_grain += kLumaWidth; + } while (++y < kLumaHeight); + +#undef WRITE_AUTO_REGRESSION_RESULT +#undef ACCUMULATE_WEIGHTED_GRAIN +} + +void InitializeScalingLookupTable_NEON( + int num_points, const uint8_t point_value[], const uint8_t point_scaling[], + uint8_t scaling_lut[kScalingLookupTableSize]) { + if (num_points == 0) { + memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize); + return; + } + static_assert(sizeof(scaling_lut[0]) == 1, ""); + memset(scaling_lut, point_scaling[0], point_value[0]); + const uint32x4_t steps = vmovl_u16(vcreate_u16(0x0003000200010000)); + const uint32x4_t offset = vdupq_n_u32(32768); + for (int i = 0; i < num_points - 1; ++i) { + const int delta_y = point_scaling[i + 1] - point_scaling[i]; + const int delta_x = point_value[i + 1] - point_value[i]; + const int delta = delta_y * ((65536 + (delta_x >> 1)) / delta_x); + const int delta4 = delta << 2; + const uint8x8_t base_point = vdup_n_u8(point_scaling[i]); + uint32x4_t upscaled_points0 = vmlaq_n_u32(offset, steps, delta); + const uint32x4_t line_increment4 = vdupq_n_u32(delta4); + // Get the second set of 4 points by adding 4 steps to the first set. + uint32x4_t upscaled_points1 = vaddq_u32(upscaled_points0, line_increment4); + // We obtain the next set of 8 points by adding 8 steps to each of the + // current 8 points. + const uint32x4_t line_increment8 = vshlq_n_u32(line_increment4, 1); + int x = 0; + do { + const uint16x4_t interp_points0 = vshrn_n_u32(upscaled_points0, 16); + const uint16x4_t interp_points1 = vshrn_n_u32(upscaled_points1, 16); + const uint8x8_t interp_points = + vmovn_u16(vcombine_u16(interp_points0, interp_points1)); + // The spec guarantees that the max value of |point_value[i]| + x is 255. + // Writing 8 bytes starting at the final table byte, leaves 7 bytes of + // required padding. + vst1_u8(&scaling_lut[point_value[i] + x], + vadd_u8(interp_points, base_point)); + upscaled_points0 = vaddq_u32(upscaled_points0, line_increment8); + upscaled_points1 = vaddq_u32(upscaled_points1, line_increment8); + x += 8; + } while (x < delta_x); + } + const uint8_t last_point_value = point_value[num_points - 1]; + memset(&scaling_lut[last_point_value], point_scaling[num_points - 1], + kScalingLookupTableSize - last_point_value); +} + +inline int16x8_t Clip3(const int16x8_t value, const int16x8_t low, + const int16x8_t high) { + const int16x8_t clipped_to_ceiling = vminq_s16(high, value); + return vmaxq_s16(low, clipped_to_ceiling); +} + +template <int bitdepth, typename Pixel> +inline int16x8_t GetScalingFactors( + const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* source) { + int16_t start_vals[8]; + if (bitdepth == 8) { + start_vals[0] = scaling_lut[source[0]]; + start_vals[1] = scaling_lut[source[1]]; + start_vals[2] = scaling_lut[source[2]]; + start_vals[3] = scaling_lut[source[3]]; + start_vals[4] = scaling_lut[source[4]]; + start_vals[5] = scaling_lut[source[5]]; + start_vals[6] = scaling_lut[source[6]]; + start_vals[7] = scaling_lut[source[7]]; + return vld1q_s16(start_vals); + } + int16_t end_vals[8]; + // TODO(petersonab): Precompute this into a larger table for direct lookups. + int index = source[0] >> 2; + start_vals[0] = scaling_lut[index]; + end_vals[0] = scaling_lut[index + 1]; + index = source[1] >> 2; + start_vals[1] = scaling_lut[index]; + end_vals[1] = scaling_lut[index + 1]; + index = source[2] >> 2; + start_vals[2] = scaling_lut[index]; + end_vals[2] = scaling_lut[index + 1]; + index = source[3] >> 2; + start_vals[3] = scaling_lut[index]; + end_vals[3] = scaling_lut[index + 1]; + index = source[4] >> 2; + start_vals[4] = scaling_lut[index]; + end_vals[4] = scaling_lut[index + 1]; + index = source[5] >> 2; + start_vals[5] = scaling_lut[index]; + end_vals[5] = scaling_lut[index + 1]; + index = source[6] >> 2; + start_vals[6] = scaling_lut[index]; + end_vals[6] = scaling_lut[index + 1]; + index = source[7] >> 2; + start_vals[7] = scaling_lut[index]; + end_vals[7] = scaling_lut[index + 1]; + const int16x8_t start = vld1q_s16(start_vals); + const int16x8_t end = vld1q_s16(end_vals); + int16x8_t remainder = GetSignedSource8(source); + remainder = vandq_s16(remainder, vdupq_n_s16(3)); + const int16x8_t delta = vmulq_s16(vsubq_s16(end, start), remainder); + return vaddq_s16(start, vrshrq_n_s16(delta, 2)); +} + +inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling, + const int16x8_t scaling_shift_vect) { + const int16x8_t upscaled_noise = vmulq_s16(noise, scaling); + return vrshlq_s16(upscaled_noise, scaling_shift_vect); +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling, + const int32x4_t scaling_shift_vect) { + // TODO(petersonab): Try refactoring scaling lookup table to int16_t and + // upscaling by 7 bits to permit high half multiply. This would eliminate + // the intermediate 32x4 registers. Also write the averaged values directly + // into the table so it doesn't have to be done for every pixel in + // the frame. + const int32x4_t upscaled_noise_lo = + vmull_s16(vget_low_s16(noise), vget_low_s16(scaling)); + const int32x4_t upscaled_noise_hi = + vmull_s16(vget_high_s16(noise), vget_high_s16(scaling)); + const int16x4_t noise_lo = + vmovn_s32(vrshlq_s32(upscaled_noise_lo, scaling_shift_vect)); + const int16x4_t noise_hi = + vmovn_s32(vrshlq_s32(upscaled_noise_hi, scaling_shift_vect)); + return vcombine_s16(noise_lo, noise_hi); +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageLuma_NEON( + const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift, + int width, int height, int start_height, + const uint8_t scaling_lut_y[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y, + ptrdiff_t dest_stride_y) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y_row = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + auto* out_y_row = static_cast<Pixel*>(dest_plane_y); + dest_stride_y /= sizeof(Pixel); + const int16x8_t floor = vdupq_n_s16(min_value); + const int16x8_t ceiling = vdupq_n_s16(max_luma); + // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe + // for 16 bit signed integers. In higher bitdepths, however, we have to + // expand to 32 to protect the sign bit. + const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift); +#if LIBGAV1_MAX_BITDEPTH >= 10 + const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + + int y = 0; + do { + int x = 0; + do { + // This operation on the unsigned input is safe in 8bpp because the vector + // is widened before it is reinterpreted. + const int16x8_t orig = GetSignedSource8(&in_y_row[x]); + const int16x8_t scaling = + GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, &in_y_row[x]); + int16x8_t noise = + GetSignedSource8(&(noise_image[kPlaneY][y + start_height][x])); + + if (bitdepth == 8) { + noise = ScaleNoise(noise, scaling, scaling_shift_vect16); + } else { +#if LIBGAV1_MAX_BITDEPTH >= 10 + noise = ScaleNoise(noise, scaling, scaling_shift_vect32); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + } + const int16x8_t combined = vaddq_s16(orig, noise); + // In 8bpp, when params_.clip_to_restricted_range == false, we can replace + // clipping with vqmovun_s16, but it's not likely to be worth copying the + // function for just that case, though the gain would be very small. + StoreUnsigned8(&out_y_row[x], + vreinterpretq_u16_s16(Clip3(combined, floor, ceiling))); + x += 8; + } while (x < width); + in_y_row += source_stride_y; + out_y_row += dest_stride_y; + } while (++y < height); +} + +template <int bitdepth, typename GrainType, typename Pixel> +inline int16x8_t BlendChromaValsWithCfl( + const Pixel* average_luma_buffer, + const uint8_t scaling_lut[kScalingLookupTableSize], + const Pixel* chroma_cursor, const GrainType* noise_image_cursor, + const int16x8_t scaling_shift_vect16, + const int32x4_t scaling_shift_vect32) { + const int16x8_t scaling = + GetScalingFactors<bitdepth, Pixel>(scaling_lut, average_luma_buffer); + const int16x8_t orig = GetSignedSource8(chroma_cursor); + int16x8_t noise = GetSignedSource8(noise_image_cursor); + if (bitdepth == 8) { + noise = ScaleNoise(noise, scaling, scaling_shift_vect16); + } else { + noise = ScaleNoise(noise, scaling, scaling_shift_vect32); + } + return vaddq_s16(orig, noise); +} + +template <int bitdepth, typename GrainType, typename Pixel> +LIBGAV1_ALWAYS_INLINE void BlendChromaPlaneWithCfl_NEON( + const Array2D<GrainType>& noise_image, int min_value, int max_chroma, + int width, int height, int start_height, int subsampling_x, + int subsampling_y, int scaling_shift, + const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* in_y_row, + ptrdiff_t source_stride_y, const Pixel* in_chroma_row, + ptrdiff_t source_stride_chroma, Pixel* out_chroma_row, + ptrdiff_t dest_stride) { + const int16x8_t floor = vdupq_n_s16(min_value); + const int16x8_t ceiling = vdupq_n_s16(max_chroma); + Pixel luma_buffer[16]; + memset(luma_buffer, 0, sizeof(luma_buffer)); + // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe + // for 16 bit signed integers. In higher bitdepths, however, we have to + // expand to 32 to protect the sign bit. + const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift); + const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift); + + const int chroma_height = (height + subsampling_y) >> subsampling_y; + const int chroma_width = (width + subsampling_x) >> subsampling_x; + const int safe_chroma_width = chroma_width & ~7; + + // Writing to this buffer avoids the cost of doing 8 lane lookups in a row + // in GetScalingFactors. + Pixel average_luma_buffer[8]; + assert(start_height % 2 == 0); + start_height >>= subsampling_y; + int y = 0; + do { + int x = 0; + do { + const int luma_x = x << subsampling_x; + // TODO(petersonab): Consider specializing by subsampling_x. In the 444 + // case &in_y_row[x] can be passed to GetScalingFactors directly. + const uint16x8_t average_luma = + GetAverageLuma(&in_y_row[luma_x], subsampling_x); + StoreUnsigned8(average_luma_buffer, average_luma); + + const int16x8_t blended = + BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>( + average_luma_buffer, scaling_lut, &in_chroma_row[x], + &(noise_image[y + start_height][x]), scaling_shift_vect16, + scaling_shift_vect32); + + // In 8bpp, when params_.clip_to_restricted_range == false, we can replace + // clipping with vqmovun_s16, but it's not likely to be worth copying the + // function for just that case. + StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + x += 8; + } while (x < safe_chroma_width); + + if (x < chroma_width) { + const int luma_x = x << subsampling_x; + const int valid_range = width - luma_x; + memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0])); + luma_buffer[valid_range] = in_y_row[width - 1]; + const uint16x8_t average_luma = + GetAverageLuma(luma_buffer, subsampling_x); + StoreUnsigned8(average_luma_buffer, average_luma); + + const int16x8_t blended = + BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>( + average_luma_buffer, scaling_lut, &in_chroma_row[x], + &(noise_image[y + start_height][x]), scaling_shift_vect16, + scaling_shift_vect32); + // In 8bpp, when params_.clip_to_restricted_range == false, we can replace + // clipping with vqmovun_s16, but it's not likely to be worth copying the + // function for just that case. + StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + } + + in_y_row += source_stride_y << subsampling_y; + in_chroma_row += source_stride_chroma; + out_chroma_row += dest_stride; + } while (++y < chroma_height); +} + +// This function is for the case params_.chroma_scaling_from_luma == true. +// This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y. +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageChromaWithCfl_NEON( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + + const auto* in_uv = static_cast<const Pixel*>(source_plane_uv); + source_stride_uv /= sizeof(Pixel); + auto* out_uv = static_cast<Pixel*>(dest_plane_uv); + dest_stride_uv /= sizeof(Pixel); + // Looping over one plane at a time is faster in higher resolutions, despite + // re-computing luma. + BlendChromaPlaneWithCfl_NEON<bitdepth, GrainType, Pixel>( + noise_image[plane], min_value, max_chroma, width, height, start_height, + subsampling_x, subsampling_y, params.chroma_scaling, scaling_lut, in_y, + source_stride_y, in_uv, source_stride_uv, out_uv, dest_stride_uv); +} + +} // namespace + +namespace low_bitdepth { +namespace { + +inline int16x8_t BlendChromaValsNoCfl( + const uint8_t scaling_lut[kScalingLookupTableSize], + const uint8_t* chroma_cursor, const int8_t* noise_image_cursor, + const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect, + const int16x8_t& offset, int luma_multiplier, int chroma_multiplier) { + uint8_t merged_buffer[8]; + const int16x8_t orig = GetSignedSource8(chroma_cursor); + const int16x8_t weighted_luma = vmulq_n_s16(average_luma, luma_multiplier); + const int16x8_t weighted_chroma = vmulq_n_s16(orig, chroma_multiplier); + // Maximum value of |combined_u| is 127*255 = 0x7E81. + const int16x8_t combined = vhaddq_s16(weighted_luma, weighted_chroma); + // Maximum value of u_offset is (255 << 5) = 0x1FE0. + // 0x7E81 + 0x1FE0 = 0x9E61, therefore another halving add is required. + const uint8x8_t merged = vqshrun_n_s16(vhaddq_s16(offset, combined), 4); + vst1_u8(merged_buffer, merged); + const int16x8_t scaling = + GetScalingFactors<8, uint8_t>(scaling_lut, merged_buffer); + int16x8_t noise = GetSignedSource8(noise_image_cursor); + noise = ScaleNoise(noise, scaling, scaling_shift_vect); + return vaddq_s16(orig, noise); +} + +LIBGAV1_ALWAYS_INLINE void BlendChromaPlane8bpp_NEON( + const Array2D<int8_t>& noise_image, int min_value, int max_chroma, + int width, int height, int start_height, int subsampling_x, + int subsampling_y, int scaling_shift, int chroma_offset, + int chroma_multiplier, int luma_multiplier, + const uint8_t scaling_lut[kScalingLookupTableSize], const uint8_t* in_y_row, + ptrdiff_t source_stride_y, const uint8_t* in_chroma_row, + ptrdiff_t source_stride_chroma, uint8_t* out_chroma_row, + ptrdiff_t dest_stride) { + const int16x8_t floor = vdupq_n_s16(min_value); + const int16x8_t ceiling = vdupq_n_s16(max_chroma); + // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe + // for 16 bit signed integers. In higher bitdepths, however, we have to + // expand to 32 to protect the sign bit. + const int16x8_t scaling_shift_vect = vdupq_n_s16(-scaling_shift); + + const int chroma_height = (height + subsampling_y) >> subsampling_y; + const int chroma_width = (width + subsampling_x) >> subsampling_x; + const int safe_chroma_width = chroma_width & ~7; + uint8_t luma_buffer[16]; + const int16x8_t offset = vdupq_n_s16(chroma_offset << 5); + + start_height >>= subsampling_y; + int y = 0; + do { + int x = 0; + do { + const int luma_x = x << subsampling_x; + const int16x8_t average_luma = vreinterpretq_s16_u16( + GetAverageLuma(&in_y_row[luma_x], subsampling_x)); + const int16x8_t blended = BlendChromaValsNoCfl( + scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]), + average_luma, scaling_shift_vect, offset, luma_multiplier, + chroma_multiplier); + // In 8bpp, when params_.clip_to_restricted_range == false, we can + // replace clipping with vqmovun_s16, but the gain would be small. + StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + + x += 8; + } while (x < safe_chroma_width); + + if (x < chroma_width) { + // Begin right edge iteration. Same as the normal iterations, but the + // |average_luma| computation requires a duplicated luma value at the + // end. + const int luma_x = x << subsampling_x; + const int valid_range = width - luma_x; + memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0])); + luma_buffer[valid_range] = in_y_row[width - 1]; + + const int16x8_t average_luma = + vreinterpretq_s16_u16(GetAverageLuma(luma_buffer, subsampling_x)); + const int16x8_t blended = BlendChromaValsNoCfl( + scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]), + average_luma, scaling_shift_vect, offset, luma_multiplier, + chroma_multiplier); + StoreUnsigned8(&out_chroma_row[x], + vreinterpretq_u16_s16(Clip3(blended, floor, ceiling))); + // End of right edge iteration. + } + + in_y_row += source_stride_y << subsampling_y; + in_chroma_row += source_stride_chroma; + out_chroma_row += dest_stride; + } while (++y < chroma_height); +} + +// This function is for the case params_.chroma_scaling_from_luma == false. +void BlendNoiseWithImageChroma8bpp_NEON( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + assert(plane == kPlaneU || plane == kPlaneV); + const auto* noise_image = + static_cast<const Array2D<int8_t>*>(noise_image_ptr); + const auto* in_y = static_cast<const uint8_t*>(source_plane_y); + const auto* in_uv = static_cast<const uint8_t*>(source_plane_uv); + auto* out_uv = static_cast<uint8_t*>(dest_plane_uv); + + const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset; + const int luma_multiplier = + (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier; + const int multiplier = + (plane == kPlaneU) ? params.u_multiplier : params.v_multiplier; + BlendChromaPlane8bpp_NEON(noise_image[plane], min_value, max_chroma, width, + height, start_height, subsampling_x, subsampling_y, + params.chroma_scaling, offset, multiplier, + luma_multiplier, scaling_lut, in_y, source_stride_y, + in_uv, source_stride_uv, out_uv, dest_stride_uv); +} + +inline void WriteOverlapLine8bpp_NEON(const int8_t* noise_stripe_row, + const int8_t* noise_stripe_row_prev, + int plane_width, + const int8x8_t grain_coeff, + const int8x8_t old_coeff, + int8_t* noise_image_row) { + int x = 0; + do { + // Note that these reads may exceed noise_stripe_row's width by up to 7 + // bytes. + const int8x8_t source_grain = vld1_s8(noise_stripe_row + x); + const int8x8_t source_old = vld1_s8(noise_stripe_row_prev + x); + const int16x8_t weighted_grain = vmull_s8(grain_coeff, source_grain); + const int16x8_t grain = vmlal_s8(weighted_grain, old_coeff, source_old); + // Note that this write may exceed noise_image_row's width by up to 7 bytes. + vst1_s8(noise_image_row + x, vqrshrn_n_s16(grain, 5)); + x += 8; + } while (x < plane_width); +} + +void ConstructNoiseImageOverlap8bpp_NEON(const void* noise_stripes_buffer, + int width, int height, + int subsampling_x, int subsampling_y, + void* noise_image_buffer) { + const auto* noise_stripes = + static_cast<const Array2DView<int8_t>*>(noise_stripes_buffer); + auto* noise_image = static_cast<Array2D<int8_t>*>(noise_image_buffer); + const int plane_width = (width + subsampling_x) >> subsampling_x; + const int plane_height = (height + subsampling_y) >> subsampling_y; + const int stripe_height = 32 >> subsampling_y; + const int stripe_mask = stripe_height - 1; + int y = stripe_height; + int luma_num = 1; + if (subsampling_y == 0) { + const int8x8_t first_row_grain_coeff = vdup_n_s8(17); + const int8x8_t first_row_old_coeff = vdup_n_s8(27); + const int8x8_t second_row_grain_coeff = first_row_old_coeff; + const int8x8_t second_row_old_coeff = first_row_grain_coeff; + for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) { + const int8_t* noise_stripe = (*noise_stripes)[luma_num]; + const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine8bpp_NEON( + noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width, + first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]); + + WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width], + &noise_stripe_prev[(32 + 1) * plane_width], + plane_width, second_row_grain_coeff, + second_row_old_coeff, (*noise_image)[y + 1]); + } + // Either one partial stripe remains (remaining_height > 0), + // OR image is less than one stripe high (remaining_height < 0), + // OR all stripes are completed (remaining_height == 0). + const int remaining_height = plane_height - y; + if (remaining_height <= 0) { + return; + } + const int8_t* noise_stripe = (*noise_stripes)[luma_num]; + const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine8bpp_NEON( + noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width, + first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]); + + if (remaining_height > 1) { + WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width], + &noise_stripe_prev[(32 + 1) * plane_width], + plane_width, second_row_grain_coeff, + second_row_old_coeff, (*noise_image)[y + 1]); + } + } else { // subsampling_y == 1 + const int8x8_t first_row_grain_coeff = vdup_n_s8(22); + const int8x8_t first_row_old_coeff = vdup_n_s8(23); + for (; y < plane_height; ++luma_num, y += stripe_height) { + const int8_t* noise_stripe = (*noise_stripes)[luma_num]; + const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine8bpp_NEON( + noise_stripe, &noise_stripe_prev[16 * plane_width], plane_width, + first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]); + } + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + + // LumaAutoRegressionFunc + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 1>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 2>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 3>; + + // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag] + // Chroma autoregression should never be called when lag is 0 and use_luma + // is false. + dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, true>; + + dsp->film_grain.construct_noise_image_overlap = + ConstructNoiseImageOverlap8bpp_NEON; + + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON; + + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_NEON<8, int8_t, uint8_t>; + dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma8bpp_NEON; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_NEON<8, int8_t, uint8_t>; +} + +} // namespace +} // namespace low_bitdepth + +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + + // LumaAutoRegressionFunc + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 1>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 2>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 3>; + + // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag][subsampling] + // Chroma autoregression should never be called when lag is 0 and use_luma + // is false. + dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, true>; + + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON; + + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_NEON<10, int16_t, uint16_t>; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_NEON<10, int16_t, uint16_t>; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +} // namespace film_grain + +void FilmGrainInit_NEON() { + film_grain::low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + film_grain::high_bitdepth::Init10bpp(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void FilmGrainInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/film_grain_neon.h b/src/dsp/arm/film_grain_neon.h new file mode 100644 index 0000000..44b3d1d --- /dev/null +++ b/src/dsp/arm/film_grain_neon.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initialize members of Dsp::film_grain. This function is not thread-safe. +void FilmGrainInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON +#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_ diff --git a/src/dsp/arm/intra_edge_neon.cc b/src/dsp/arm/intra_edge_neon.cc new file mode 100644 index 0000000..00b186a --- /dev/null +++ b/src/dsp/arm/intra_edge_neon.cc @@ -0,0 +1,301 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intra_edge.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" // RightShiftWithRounding() + +namespace libgav1 { +namespace dsp { +namespace { + +// Simplified version of intra_edge.cc:kKernels[][]. Only |strength| 1 and 2 are +// required. +constexpr int kKernelsNEON[3][2] = {{4, 8}, {5, 6}}; + +void IntraEdgeFilter_NEON(void* buffer, const int size, const int strength) { + assert(strength == 1 || strength == 2 || strength == 3); + const int kernel_index = strength - 1; + auto* const dst_buffer = static_cast<uint8_t*>(buffer); + + // The first element is not written out (but it is input) so the number of + // elements written is |size| - 1. + if (size == 1) return; + + // |strength| 1 and 2 use a 3 tap filter. + if (strength < 3) { + // The last value requires extending the buffer (duplicating + // |dst_buffer[size - 1]). Calculate it here to avoid extra processing in + // neon. + const uint8_t last_val = RightShiftWithRounding( + kKernelsNEON[kernel_index][0] * dst_buffer[size - 2] + + kKernelsNEON[kernel_index][1] * dst_buffer[size - 1] + + kKernelsNEON[kernel_index][0] * dst_buffer[size - 1], + 4); + + const uint8x8_t krn1 = vdup_n_u8(kKernelsNEON[kernel_index][1]); + + // The first value we need gets overwritten by the output from the + // previous iteration. + uint8x16_t src_0 = vld1q_u8(dst_buffer); + int i = 1; + + // Process blocks until there are less than 16 values remaining. + for (; i < size - 15; i += 16) { + // Loading these at the end of the block with |src_0| will read past the + // end of |top_row_data[160]|, the source of |buffer|. + const uint8x16_t src_1 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1); + uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2)); + sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]); + sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1); + uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2)); + sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]); + sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + // Load the next row before overwriting. This loads an extra 15 values + // past |size| on the trailing iteration. + src_0 = vld1q_u8(dst_buffer + i + 15); + + vst1q_u8(dst_buffer + i, result); + } + + // The last output value |last_val| was already calculated so if + // |remainder| == 1 then we don't have to do anything. + const int remainder = (size - 1) & 0xf; + if (remainder > 1) { + uint8_t temp[16]; + const uint8x16_t src_1 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1); + + uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2)); + sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]); + sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1); + uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2)); + sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]); + sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + vst1q_u8(temp, result); + memcpy(dst_buffer + i, temp, remainder); + } + + dst_buffer[size - 1] = last_val; + return; + } + + assert(strength == 3); + // 5 tap filter. The first element requires duplicating |buffer[0]| and the + // last two elements require duplicating |buffer[size - 1]|. + uint8_t special_vals[3]; + special_vals[0] = RightShiftWithRounding( + (dst_buffer[0] << 1) + (dst_buffer[0] << 2) + (dst_buffer[1] << 2) + + (dst_buffer[2] << 2) + (dst_buffer[3] << 1), + 4); + // Clamp index for very small |size| values. + const int first_index_min = std::max(size - 4, 0); + const int second_index_min = std::max(size - 3, 0); + const int third_index_min = std::max(size - 2, 0); + special_vals[1] = RightShiftWithRounding( + (dst_buffer[first_index_min] << 1) + (dst_buffer[second_index_min] << 2) + + (dst_buffer[third_index_min] << 2) + (dst_buffer[size - 1] << 2) + + (dst_buffer[size - 1] << 1), + 4); + special_vals[2] = RightShiftWithRounding( + (dst_buffer[second_index_min] << 1) + (dst_buffer[third_index_min] << 2) + + // x << 2 + x << 2 == x << 3 + (dst_buffer[size - 1] << 3) + (dst_buffer[size - 1] << 1), + 4); + + // The first two values we need get overwritten by the output from the + // previous iteration. + uint8x16_t src_0 = vld1q_u8(dst_buffer - 1); + uint8x16_t src_1 = vld1q_u8(dst_buffer); + int i = 1; + + for (; i < size - 15; i += 16) { + // Loading these at the end of the block with |src_[01]| will read past + // the end of |top_row_data[160]|, the source of |buffer|. + const uint8x16_t src_2 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1); + const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2); + + uint16x8_t sum_lo = + vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1); + const uint16x8_t sum_123_lo = vaddw_u8( + vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3)); + sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2)); + + uint16x8_t sum_hi = + vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1); + const uint16x8_t sum_123_hi = + vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)), + vget_high_u8(src_3)); + sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2)); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + src_0 = vld1q_u8(dst_buffer + i + 14); + src_1 = vld1q_u8(dst_buffer + i + 15); + + vst1q_u8(dst_buffer + i, result); + } + + const int remainder = (size - 1) & 0xf; + // Like the 3 tap but if there are two remaining values we have already + // calculated them. + if (remainder > 2) { + uint8_t temp[16]; + const uint8x16_t src_2 = vld1q_u8(dst_buffer + i); + const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1); + const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2); + + uint16x8_t sum_lo = + vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1); + const uint16x8_t sum_123_lo = vaddw_u8( + vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3)); + sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2)); + + uint16x8_t sum_hi = + vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1); + const uint16x8_t sum_123_hi = + vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)), + vget_high_u8(src_3)); + sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2)); + + const uint8x16_t result = + vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4)); + + vst1q_u8(temp, result); + memcpy(dst_buffer + i, temp, remainder); + } + + dst_buffer[1] = special_vals[0]; + // Avoid overwriting |dst_buffer[0]|. + if (size > 2) dst_buffer[size - 2] = special_vals[1]; + dst_buffer[size - 1] = special_vals[2]; +} + +// (-|src0| + |src1| * 9 + |src2| * 9 - |src3|) >> 4 +uint8x8_t Upsample(const uint8x8_t src0, const uint8x8_t src1, + const uint8x8_t src2, const uint8x8_t src3) { + const uint16x8_t middle = vmulq_n_u16(vaddl_u8(src1, src2), 9); + const uint16x8_t ends = vaddl_u8(src0, src3); + const int16x8_t sum = + vsubq_s16(vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(ends)); + return vqrshrun_n_s16(sum, 4); +} + +void IntraEdgeUpsampler_NEON(void* buffer, const int size) { + assert(size % 4 == 0 && size <= 16); + auto* const pixel_buffer = static_cast<uint8_t*>(buffer); + // This is OK because we don't read this value for |size| 4 or 8 but if we + // write |pixel_buffer[size]| and then vld() it, that seems to introduce + // some latency. + pixel_buffer[-2] = pixel_buffer[-1]; + if (size == 4) { + // This uses one load and two vtbl() which is better than 4x Load{Lo,Hi}4(). + const uint8x8_t src = vld1_u8(pixel_buffer - 1); + // The outside values are negated so put those in the same vector. + const uint8x8_t src03 = vtbl1_u8(src, vcreate_u8(0x0404030202010000)); + // Reverse |src1| and |src2| so we can use |src2| for the interleave at the + // end. + const uint8x8_t src21 = vtbl1_u8(src, vcreate_u8(0x0302010004030201)); + + const uint16x8_t middle = vmull_u8(src21, vdup_n_u8(9)); + const int16x8_t half_sum = vsubq_s16( + vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(vmovl_u8(src03))); + const int16x4_t sum = + vadd_s16(vget_low_s16(half_sum), vget_high_s16(half_sum)); + const uint8x8_t result = vqrshrun_n_s16(vcombine_s16(sum, sum), 4); + + vst1_u8(pixel_buffer - 1, InterleaveLow8(result, src21)); + return; + } else if (size == 8) { + // Likewise, one load + multiple vtbls seems preferred to multiple loads. + const uint8x16_t src = vld1q_u8(pixel_buffer - 1); + const uint8x8_t src0 = VQTbl1U8(src, vcreate_u8(0x0605040302010000)); + const uint8x8_t src1 = vget_low_u8(src); + const uint8x8_t src2 = VQTbl1U8(src, vcreate_u8(0x0807060504030201)); + const uint8x8_t src3 = VQTbl1U8(src, vcreate_u8(0x0808070605040302)); + + const uint8x8x2_t output = {Upsample(src0, src1, src2, src3), src2}; + vst2_u8(pixel_buffer - 1, output); + return; + } + assert(size == 12 || size == 16); + // Extend the input borders to avoid branching later. + pixel_buffer[size] = pixel_buffer[size - 1]; + const uint8x16_t src0 = vld1q_u8(pixel_buffer - 2); + const uint8x16_t src1 = vld1q_u8(pixel_buffer - 1); + const uint8x16_t src2 = vld1q_u8(pixel_buffer); + const uint8x16_t src3 = vld1q_u8(pixel_buffer + 1); + + const uint8x8_t result_lo = Upsample(vget_low_u8(src0), vget_low_u8(src1), + vget_low_u8(src2), vget_low_u8(src3)); + + const uint8x8x2_t output_lo = {result_lo, vget_low_u8(src2)}; + vst2_u8(pixel_buffer - 1, output_lo); + + const uint8x8_t result_hi = Upsample(vget_high_u8(src0), vget_high_u8(src1), + vget_high_u8(src2), vget_high_u8(src3)); + + if (size == 12) { + vst1_u8(pixel_buffer + 15, InterleaveLow8(result_hi, vget_high_u8(src2))); + } else /* size == 16 */ { + const uint8x8x2_t output_hi = {result_hi, vget_high_u8(src2)}; + vst2_u8(pixel_buffer + 15, output_hi); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->intra_edge_filter = IntraEdgeFilter_NEON; + dsp->intra_edge_upsampler = IntraEdgeUpsampler_NEON; +} + +} // namespace + +void IntraEdgeInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraEdgeInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intra_edge_neon.h b/src/dsp/arm/intra_edge_neon.h new file mode 100644 index 0000000..d3bb243 --- /dev/null +++ b/src/dsp/arm/intra_edge_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_edge_filter and Dsp::intra_edge_upsampler. This +// function is not thread-safe. +void IntraEdgeInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_ diff --git a/src/dsp/arm/intrapred_cfl_neon.cc b/src/dsp/arm/intrapred_cfl_neon.cc new file mode 100644 index 0000000..45fe33b --- /dev/null +++ b/src/dsp/arm/intrapred_cfl_neon.cc @@ -0,0 +1,479 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +uint8x16_t Set2ValuesQ(const uint8_t* a) { + uint16_t combined_values = a[0] | a[1] << 8; + return vreinterpretq_u8_u16(vdupq_n_u16(combined_values)); +} + +uint32_t SumVector(uint32x2_t a) { +#if defined(__aarch64__) + return vaddv_u32(a); +#else + const uint64x1_t b = vpaddl_u32(a); + return vget_lane_u32(vreinterpret_u32_u64(b), 0); +#endif // defined(__aarch64__) +} + +uint32_t SumVector(uint32x4_t a) { +#if defined(__aarch64__) + return vaddvq_u32(a); +#else + const uint64x2_t b = vpaddlq_u32(a); + const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b)); + return vget_lane_u32(vreinterpret_u32_u64(c), 0); +#endif // defined(__aarch64__) +} + +// Divide by the number of elements. +uint32_t Average(const uint32_t sum, const int width, const int height) { + return RightShiftWithRounding(sum, FloorLog2(width) + FloorLog2(height)); +} + +// Subtract |val| from every element in |a|. +void BlockSubtract(const uint32_t val, + int16_t a[kCflLumaBufferStride][kCflLumaBufferStride], + const int width, const int height) { + assert(val <= INT16_MAX); + const int16x8_t val_v = vdupq_n_s16(static_cast<int16_t>(val)); + + for (int y = 0; y < height; ++y) { + if (width == 4) { + const int16x4_t b = vld1_s16(a[y]); + vst1_s16(a[y], vsub_s16(b, vget_low_s16(val_v))); + } else if (width == 8) { + const int16x8_t b = vld1q_s16(a[y]); + vst1q_s16(a[y], vsubq_s16(b, val_v)); + } else if (width == 16) { + const int16x8_t b = vld1q_s16(a[y]); + const int16x8_t c = vld1q_s16(a[y] + 8); + vst1q_s16(a[y], vsubq_s16(b, val_v)); + vst1q_s16(a[y] + 8, vsubq_s16(c, val_v)); + } else /* block_width == 32 */ { + const int16x8_t b = vld1q_s16(a[y]); + const int16x8_t c = vld1q_s16(a[y] + 8); + const int16x8_t d = vld1q_s16(a[y] + 16); + const int16x8_t e = vld1q_s16(a[y] + 24); + vst1q_s16(a[y], vsubq_s16(b, val_v)); + vst1q_s16(a[y] + 8, vsubq_s16(c, val_v)); + vst1q_s16(a[y] + 16, vsubq_s16(d, val_v)); + vst1q_s16(a[y] + 24, vsubq_s16(e, val_v)); + } + } +} + +template <int block_width, int block_height> +void CflSubsampler420_NEON( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, const ptrdiff_t stride) { + const auto* src = static_cast<const uint8_t*>(source); + uint32_t sum; + if (block_width == 4) { + assert(max_luma_width >= 8); + uint32x2_t running_sum = vdup_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + const uint8x8_t row0 = vld1_u8(src); + const uint8x8_t row1 = vld1_u8(src + stride); + + uint16x4_t sum_row = vpadal_u8(vpaddl_u8(row0), row1); + sum_row = vshl_n_u16(sum_row, 1); + running_sum = vpadal_u16(running_sum, sum_row); + vst1_s16(luma[y], vreinterpret_s16_u16(sum_row)); + + if (y << 1 < max_luma_height - 2) { + // Once this threshold is reached the loop could be simplified. + src += stride << 1; + } + } + + sum = SumVector(running_sum); + } else if (block_width == 8) { + const uint8x16_t x_index = {0, 0, 2, 2, 4, 4, 6, 6, + 8, 8, 10, 10, 12, 12, 14, 14}; + const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2); + const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index); + + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + const uint8x16_t x_max0 = Set2ValuesQ(src + max_luma_width - 2); + const uint8x16_t x_max1 = Set2ValuesQ(src + max_luma_width - 2 + stride); + + uint8x16_t row0 = vld1q_u8(src); + row0 = vbslq_u8(x_mask, row0, x_max0); + uint8x16_t row1 = vld1q_u8(src + stride); + row1 = vbslq_u8(x_mask, row1, x_max1); + + uint16x8_t sum_row = vpadalq_u8(vpaddlq_u8(row0), row1); + sum_row = vshlq_n_u16(sum_row, 1); + running_sum = vpadalq_u16(running_sum, sum_row); + vst1q_s16(luma[y], vreinterpretq_s16_u16(sum_row)); + + if (y << 1 < max_luma_height - 2) { + src += stride << 1; + } + } + + sum = SumVector(running_sum); + } else /* block_width >= 16 */ { + const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2); + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + uint8x16_t x_index = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + const uint8x16_t x_max00 = vdupq_n_u8(src[max_luma_width - 2]); + const uint8x16_t x_max01 = vdupq_n_u8(src[max_luma_width - 2 + 1]); + const uint8x16_t x_max10 = vdupq_n_u8(src[stride + max_luma_width - 2]); + const uint8x16_t x_max11 = + vdupq_n_u8(src[stride + max_luma_width - 2 + 1]); + for (int x = 0; x < block_width; x += 16) { + const ptrdiff_t src_x_offset = x << 1; + const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index); + const uint8x16x2_t row0 = vld2q_u8(src + src_x_offset); + const uint8x16x2_t row1 = vld2q_u8(src + src_x_offset + stride); + const uint8x16_t row_masked_00 = vbslq_u8(x_mask, row0.val[0], x_max00); + const uint8x16_t row_masked_01 = vbslq_u8(x_mask, row0.val[1], x_max01); + const uint8x16_t row_masked_10 = vbslq_u8(x_mask, row1.val[0], x_max10); + const uint8x16_t row_masked_11 = vbslq_u8(x_mask, row1.val[1], x_max11); + + uint16x8_t sum_row_lo = + vaddl_u8(vget_low_u8(row_masked_00), vget_low_u8(row_masked_01)); + sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_10)); + sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_11)); + sum_row_lo = vshlq_n_u16(sum_row_lo, 1); + running_sum = vpadalq_u16(running_sum, sum_row_lo); + vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(sum_row_lo)); + + uint16x8_t sum_row_hi = + vaddl_u8(vget_high_u8(row_masked_00), vget_high_u8(row_masked_01)); + sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_10)); + sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_11)); + sum_row_hi = vshlq_n_u16(sum_row_hi, 1); + running_sum = vpadalq_u16(running_sum, sum_row_hi); + vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(sum_row_hi)); + + x_index = vaddq_u8(x_index, vdupq_n_u8(32)); + } + if (y << 1 < max_luma_height - 2) { + src += stride << 1; + } + } + sum = SumVector(running_sum); + } + + const uint32_t average = Average(sum, block_width, block_height); + BlockSubtract(average, luma, block_width, block_height); +} + +template <int block_width, int block_height> +void CflSubsampler444_NEON( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, const ptrdiff_t stride) { + const auto* src = static_cast<const uint8_t*>(source); + uint32_t sum; + if (block_width == 4) { + assert(max_luma_width >= 4); + uint32x4_t running_sum = vdupq_n_u32(0); + uint8x8_t row = vdup_n_u8(0); + + for (int y = 0; y < block_height; y += 2) { + row = Load4<0>(src, row); + row = Load4<1>(src + stride, row); + if (y < (max_luma_height - 1)) { + src += stride << 1; + } + + const uint16x8_t row_shifted = vshll_n_u8(row, 3); + running_sum = vpadalq_u16(running_sum, row_shifted); + vst1_s16(luma[y], vreinterpret_s16_u16(vget_low_u16(row_shifted))); + vst1_s16(luma[y + 1], vreinterpret_s16_u16(vget_high_u16(row_shifted))); + } + + sum = SumVector(running_sum); + } else if (block_width == 8) { + const uint8x8_t x_index = {0, 1, 2, 3, 4, 5, 6, 7}; + const uint8x8_t x_max_index = vdup_n_u8(max_luma_width - 1); + const uint8x8_t x_mask = vclt_u8(x_index, x_max_index); + + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + const uint8x8_t x_max = vdup_n_u8(src[max_luma_width - 1]); + const uint8x8_t row = vbsl_u8(x_mask, vld1_u8(src), x_max); + + const uint16x8_t row_shifted = vshll_n_u8(row, 3); + running_sum = vpadalq_u16(running_sum, row_shifted); + vst1q_s16(luma[y], vreinterpretq_s16_u16(row_shifted)); + + if (y < max_luma_height - 1) { + src += stride; + } + } + + sum = SumVector(running_sum); + } else /* block_width >= 16 */ { + const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 1); + uint32x4_t running_sum = vdupq_n_u32(0); + + for (int y = 0; y < block_height; ++y) { + uint8x16_t x_index = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + const uint8x16_t x_max = vdupq_n_u8(src[max_luma_width - 1]); + for (int x = 0; x < block_width; x += 16) { + const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index); + const uint8x16_t row = vbslq_u8(x_mask, vld1q_u8(src + x), x_max); + + const uint16x8_t row_shifted_low = vshll_n_u8(vget_low_u8(row), 3); + const uint16x8_t row_shifted_high = vshll_n_u8(vget_high_u8(row), 3); + running_sum = vpadalq_u16(running_sum, row_shifted_low); + running_sum = vpadalq_u16(running_sum, row_shifted_high); + vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(row_shifted_low)); + vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(row_shifted_high)); + + x_index = vaddq_u8(x_index, vdupq_n_u8(16)); + } + if (y < max_luma_height - 1) { + src += stride; + } + } + sum = SumVector(running_sum); + } + + const uint32_t average = Average(sum, block_width, block_height); + BlockSubtract(average, luma, block_width, block_height); +} + +// Saturate |dc + ((alpha * luma) >> 6))| to uint8_t. +inline uint8x8_t Combine8(const int16x8_t luma, const int alpha, + const int16x8_t dc) { + const int16x8_t la = vmulq_n_s16(luma, alpha); + // Subtract the sign bit to round towards zero. + const int16x8_t sub_sign = vsraq_n_s16(la, la, 15); + // Shift and accumulate. + const int16x8_t result = vrsraq_n_s16(dc, sub_sign, 6); + return vqmovun_s16(result); +} + +// The range of luma/alpha is not really important because it gets saturated to +// uint8_t. Saturated int16_t >> 6 outranges uint8_t. +template <int block_height> +inline void CflIntraPredictor4xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; y += 2) { + const int16x4_t luma_row0 = vld1_s16(luma[y]); + const int16x4_t luma_row1 = vld1_s16(luma[y + 1]); + const uint8x8_t sum = + Combine8(vcombine_s16(luma_row0, luma_row1), alpha, dc); + StoreLo4(dst, sum); + dst += stride; + StoreHi4(dst, sum); + dst += stride; + } +} + +template <int block_height> +inline void CflIntraPredictor8xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; ++y) { + const int16x8_t luma_row = vld1q_s16(luma[y]); + const uint8x8_t sum = Combine8(luma_row, alpha, dc); + vst1_u8(dst, sum); + dst += stride; + } +} + +template <int block_height> +inline void CflIntraPredictor16xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; ++y) { + const int16x8_t luma_row_0 = vld1q_s16(luma[y]); + const int16x8_t luma_row_1 = vld1q_s16(luma[y] + 8); + const uint8x8_t sum_0 = Combine8(luma_row_0, alpha, dc); + const uint8x8_t sum_1 = Combine8(luma_row_1, alpha, dc); + vst1_u8(dst, sum_0); + vst1_u8(dst + 8, sum_1); + dst += stride; + } +} + +template <int block_height> +inline void CflIntraPredictor32xN_NEON( + void* const dest, const ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const int16x8_t dc = vdupq_n_s16(dst[0]); + for (int y = 0; y < block_height; ++y) { + const int16x8_t luma_row_0 = vld1q_s16(luma[y]); + const int16x8_t luma_row_1 = vld1q_s16(luma[y] + 8); + const int16x8_t luma_row_2 = vld1q_s16(luma[y] + 16); + const int16x8_t luma_row_3 = vld1q_s16(luma[y] + 24); + const uint8x8_t sum_0 = Combine8(luma_row_0, alpha, dc); + const uint8x8_t sum_1 = Combine8(luma_row_1, alpha, dc); + const uint8x8_t sum_2 = Combine8(luma_row_2, alpha, dc); + const uint8x8_t sum_3 = Combine8(luma_row_3, alpha, dc); + vst1_u8(dst, sum_0); + vst1_u8(dst + 8, sum_1); + vst1_u8(dst + 16, sum_2); + vst1_u8(dst + 24, sum_3); + dst += stride; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] = + CflSubsampler420_NEON<4, 4>; + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] = + CflSubsampler420_NEON<4, 8>; + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] = + CflSubsampler420_NEON<4, 16>; + + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] = + CflSubsampler420_NEON<8, 4>; + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] = + CflSubsampler420_NEON<8, 8>; + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] = + CflSubsampler420_NEON<8, 16>; + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] = + CflSubsampler420_NEON<8, 32>; + + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] = + CflSubsampler420_NEON<16, 4>; + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] = + CflSubsampler420_NEON<16, 8>; + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] = + CflSubsampler420_NEON<16, 16>; + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] = + CflSubsampler420_NEON<16, 32>; + + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] = + CflSubsampler420_NEON<32, 8>; + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] = + CflSubsampler420_NEON<32, 16>; + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] = + CflSubsampler420_NEON<32, 32>; + + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] = + CflSubsampler444_NEON<4, 4>; + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] = + CflSubsampler444_NEON<4, 8>; + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] = + CflSubsampler444_NEON<4, 16>; + + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] = + CflSubsampler444_NEON<8, 4>; + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] = + CflSubsampler444_NEON<8, 8>; + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] = + CflSubsampler444_NEON<8, 16>; + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] = + CflSubsampler444_NEON<8, 32>; + + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] = + CflSubsampler444_NEON<16, 4>; + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] = + CflSubsampler444_NEON<16, 8>; + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] = + CflSubsampler444_NEON<16, 16>; + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] = + CflSubsampler444_NEON<16, 32>; + + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] = + CflSubsampler444_NEON<32, 8>; + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] = + CflSubsampler444_NEON<32, 16>; + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] = + CflSubsampler444_NEON<32, 32>; + + dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor4xN_NEON<4>; + dsp->cfl_intra_predictors[kTransformSize4x8] = CflIntraPredictor4xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize4x16] = CflIntraPredictor4xN_NEON<16>; + + dsp->cfl_intra_predictors[kTransformSize8x4] = CflIntraPredictor8xN_NEON<4>; + dsp->cfl_intra_predictors[kTransformSize8x8] = CflIntraPredictor8xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize8x16] = CflIntraPredictor8xN_NEON<16>; + dsp->cfl_intra_predictors[kTransformSize8x32] = CflIntraPredictor8xN_NEON<32>; + + dsp->cfl_intra_predictors[kTransformSize16x4] = CflIntraPredictor16xN_NEON<4>; + dsp->cfl_intra_predictors[kTransformSize16x8] = CflIntraPredictor16xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize16x16] = + CflIntraPredictor16xN_NEON<16>; + dsp->cfl_intra_predictors[kTransformSize16x32] = + CflIntraPredictor16xN_NEON<32>; + + dsp->cfl_intra_predictors[kTransformSize32x8] = CflIntraPredictor32xN_NEON<8>; + dsp->cfl_intra_predictors[kTransformSize32x16] = + CflIntraPredictor32xN_NEON<16>; + dsp->cfl_intra_predictors[kTransformSize32x32] = + CflIntraPredictor32xN_NEON<32>; + // Max Cfl predictor size is 32x32. +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredCflInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredCflInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_directional_neon.cc b/src/dsp/arm/intrapred_directional_neon.cc new file mode 100644 index 0000000..805ba81 --- /dev/null +++ b/src/dsp/arm/intrapred_directional_neon.cc @@ -0,0 +1,926 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> // std::min +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> // memset + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Blend two values based on a 32 bit weight. +inline uint8x8_t WeightedBlend(const uint8x8_t a, const uint8x8_t b, + const uint8x8_t a_weight, + const uint8x8_t b_weight) { + const uint16x8_t a_product = vmull_u8(a, a_weight); + const uint16x8_t b_product = vmull_u8(b, b_weight); + + return vrshrn_n_u16(vaddq_u16(a_product, b_product), 5); +} + +// For vertical operations the weights are one constant value. +inline uint8x8_t WeightedBlend(const uint8x8_t a, const uint8x8_t b, + const uint8_t weight) { + return WeightedBlend(a, b, vdup_n_u8(32 - weight), vdup_n_u8(weight)); +} + +// Fill |left| and |right| with the appropriate values for a given |base_step|. +inline void LoadStepwise(const uint8_t* const source, const uint8x8_t left_step, + const uint8x8_t right_step, uint8x8_t* left, + uint8x8_t* right) { + const uint8x16_t mixed = vld1q_u8(source); + *left = VQTbl1U8(mixed, left_step); + *right = VQTbl1U8(mixed, right_step); +} + +// Handle signed step arguments by ignoring the sign. Negative values are +// considered out of range and overwritten later. +inline void LoadStepwise(const uint8_t* const source, const int8x8_t left_step, + const int8x8_t right_step, uint8x8_t* left, + uint8x8_t* right) { + LoadStepwise(source, vreinterpret_u8_s8(left_step), + vreinterpret_u8_s8(right_step), left, right); +} + +// Process 4 or 8 |width| by any |height|. +template <int width> +inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride, + const int height, const uint8_t* const top, + const int xstep, const bool upsampled) { + assert(width == 4 || width == 8); + + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + + const int max_base_x = (width + height - 1) << upsample_shift; + const int8x8_t max_base = vdup_n_s8(max_base_x); + const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]); + + const int8x8_t all = vcreate_s8(0x0706050403020100); + const int8x8_t even = vcreate_s8(0x0e0c0a0806040200); + const int8x8_t base_step = upsampled ? even : all; + const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1)); + + int top_x = xstep; + int y = 0; + do { + const int top_base_x = top_x >> scale_bits; + + if (top_base_x >= max_base_x) { + for (int i = y; i < height; ++i) { + memset(dst, top[max_base_x], 4 /* width */); + dst += stride; + } + return; + } + + const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1; + + // Zone2 uses negative values for xstep. Use signed values to compare + // |top_base_x| to |max_base_x|. + const int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step); + + const uint8x8_t max_base_mask = vclt_s8(base_v, max_base); + + // 4 wide subsamples the output. 8 wide subsamples the input. + if (width == 4) { + const uint8x8_t left_values = vld1_u8(top + top_base_x); + const uint8x8_t right_values = RightShift<8>(left_values); + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + + // If |upsampled| is true then extract every other value for output. + const uint8x8_t value_stepped = + vtbl1_u8(value, vreinterpret_u8_s8(base_step)); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value_stepped, top_max_base); + + StoreLo4(dst, masked_value); + } else /* width == 8 */ { + uint8x8_t left_values, right_values; + // WeightedBlend() steps up to Q registers. Downsample the input to avoid + // doing extra calculations. + LoadStepwise(top + top_base_x, base_step, right_step, &left_values, + &right_values); + + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value, top_max_base); + + vst1_u8(dst, masked_value); + } + dst += stride; + top_x += xstep; + } while (++y < height); +} + +// Process a multiple of 8 |width| by any |height|. Processes horizontally +// before vertically in the hopes of being a little more cache friendly. +inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride, + const int width, const int height, + const uint8_t* const top, const int xstep, + const bool upsampled) { + assert(width % 8 == 0); + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + + const int max_base_x = (width + height - 1) << upsample_shift; + const int8x8_t max_base = vdup_n_s8(max_base_x); + const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]); + + const int8x8_t all = vcreate_s8(0x0706050403020100); + const int8x8_t even = vcreate_s8(0x0e0c0a0806040200); + const int8x8_t base_step = upsampled ? even : all; + const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1)); + const int8x8_t block_step = vdup_n_s8(8 << upsample_shift); + + int top_x = xstep; + int y = 0; + do { + const int top_base_x = top_x >> scale_bits; + + if (top_base_x >= max_base_x) { + for (int i = y; i < height; ++i) { + memset(dst, top[max_base_x], 4 /* width */); + dst += stride; + } + return; + } + + const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1; + + // Zone2 uses negative values for xstep. Use signed values to compare + // |top_base_x| to |max_base_x|. + int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step); + + int x = 0; + do { + const uint8x8_t max_base_mask = vclt_s8(base_v, max_base); + + // Extract the input values based on |upsampled| here to avoid doing twice + // as many calculations. + uint8x8_t left_values, right_values; + LoadStepwise(top + top_base_x + x, base_step, right_step, &left_values, + &right_values); + + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value, top_max_base); + + vst1_u8(dst + x, masked_value); + + base_v = vadd_s8(base_v, block_step); + x += 8; + } while (x < width); + top_x += xstep; + dst += stride; + } while (++y < height); +} + +void DirectionalIntraPredictorZone1_NEON(void* const dest, + const ptrdiff_t stride, + const void* const top_row, + const int width, const int height, + const int xstep, + const bool upsampled_top) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + uint8_t* dst = static_cast<uint8_t*>(dest); + + assert(xstep > 0); + + const int upsample_shift = static_cast<int>(upsampled_top); + + const uint8x8_t all = vcreate_u8(0x0706050403020100); + + if (xstep == 64) { + assert(!upsampled_top); + const uint8_t* top_ptr = top + 1; + int y = 0; + do { + memcpy(dst, top_ptr, width); + memcpy(dst + stride, top_ptr + 1, width); + memcpy(dst + 2 * stride, top_ptr + 2, width); + memcpy(dst + 3 * stride, top_ptr + 3, width); + dst += 4 * stride; + top_ptr += 4; + y += 4; + } while (y < height); + } else if (width == 4) { + DirectionalZone1_WxH<4>(dst, stride, height, top, xstep, upsampled_top); + } else if (xstep > 51) { + // 7.11.2.10. Intra edge upsample selection process + // if ( d <= 0 || d >= 40 ) useUpsample = 0 + // For |upsample_top| the delta is from vertical so |prediction_angle - 90|. + // In |kDirectionalIntraPredictorDerivative[]| angles less than 51 will meet + // this criteria. The |xstep| value for angle 51 happens to be 51 as well. + // Shallower angles have greater xstep values. + assert(!upsampled_top); + const int max_base_x = ((width + height) - 1); + const uint8x8_t max_base = vdup_n_u8(max_base_x); + const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]); + const uint8x8_t block_step = vdup_n_u8(8); + + int top_x = xstep; + int y = 0; + do { + const int top_base_x = top_x >> 6; + const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1; + uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), all); + int x = 0; + // Only calculate a block of 8 when at least one of the output values is + // within range. Otherwise it can read off the end of |top|. + const int must_calculate_width = + std::min(width, max_base_x - top_base_x + 7) & ~7; + for (; x < must_calculate_width; x += 8) { + const uint8x8_t max_base_mask = vclt_u8(base_v, max_base); + + // Since these |xstep| values can not be upsampled the load is + // simplified. + const uint8x8_t left_values = vld1_u8(top + top_base_x + x); + const uint8x8_t right_values = vld1_u8(top + top_base_x + x + 1); + const uint8x8_t value = WeightedBlend(left_values, right_values, shift); + const uint8x8_t masked_value = + vbsl_u8(max_base_mask, value, top_max_base); + + vst1_u8(dst + x, masked_value); + base_v = vadd_u8(base_v, block_step); + } + memset(dst + x, top[max_base_x], width - x); + dst += stride; + top_x += xstep; + } while (++y < height); + } else { + DirectionalZone1_WxH(dst, stride, width, height, top, xstep, upsampled_top); + } +} + +// Process 4 or 8 |width| by 4 or 8 |height|. +template <int width> +inline void DirectionalZone3_WxH(uint8_t* dest, const ptrdiff_t stride, + const int height, + const uint8_t* const left_column, + const int base_left_y, const int ystep, + const int upsample_shift) { + assert(width == 4 || width == 8); + assert(height == 4 || height == 8); + const int scale_bits = 6 - upsample_shift; + + // Zone3 never runs out of left_column values. + assert((width + height - 1) << upsample_shift > // max_base_y + ((ystep * width) >> scale_bits) + + (/* base_step */ 1 << upsample_shift) * + (height - 1)); // left_base_y + + // Limited improvement for 8x8. ~20% faster for 64x64. + const uint8x8_t all = vcreate_u8(0x0706050403020100); + const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200); + const uint8x8_t base_step = upsample_shift ? even : all; + const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1)); + + uint8_t* dst = dest; + uint8x8_t left_v[8], right_v[8], value_v[8]; + const uint8_t* const left = left_column; + + const int index_0 = base_left_y; + LoadStepwise(left + (index_0 >> scale_bits), base_step, right_step, + &left_v[0], &right_v[0]); + value_v[0] = WeightedBlend(left_v[0], right_v[0], + ((index_0 << upsample_shift) & 0x3F) >> 1); + + const int index_1 = base_left_y + ystep; + LoadStepwise(left + (index_1 >> scale_bits), base_step, right_step, + &left_v[1], &right_v[1]); + value_v[1] = WeightedBlend(left_v[1], right_v[1], + ((index_1 << upsample_shift) & 0x3F) >> 1); + + const int index_2 = base_left_y + ystep * 2; + LoadStepwise(left + (index_2 >> scale_bits), base_step, right_step, + &left_v[2], &right_v[2]); + value_v[2] = WeightedBlend(left_v[2], right_v[2], + ((index_2 << upsample_shift) & 0x3F) >> 1); + + const int index_3 = base_left_y + ystep * 3; + LoadStepwise(left + (index_3 >> scale_bits), base_step, right_step, + &left_v[3], &right_v[3]); + value_v[3] = WeightedBlend(left_v[3], right_v[3], + ((index_3 << upsample_shift) & 0x3F) >> 1); + + const int index_4 = base_left_y + ystep * 4; + LoadStepwise(left + (index_4 >> scale_bits), base_step, right_step, + &left_v[4], &right_v[4]); + value_v[4] = WeightedBlend(left_v[4], right_v[4], + ((index_4 << upsample_shift) & 0x3F) >> 1); + + const int index_5 = base_left_y + ystep * 5; + LoadStepwise(left + (index_5 >> scale_bits), base_step, right_step, + &left_v[5], &right_v[5]); + value_v[5] = WeightedBlend(left_v[5], right_v[5], + ((index_5 << upsample_shift) & 0x3F) >> 1); + + const int index_6 = base_left_y + ystep * 6; + LoadStepwise(left + (index_6 >> scale_bits), base_step, right_step, + &left_v[6], &right_v[6]); + value_v[6] = WeightedBlend(left_v[6], right_v[6], + ((index_6 << upsample_shift) & 0x3F) >> 1); + + const int index_7 = base_left_y + ystep * 7; + LoadStepwise(left + (index_7 >> scale_bits), base_step, right_step, + &left_v[7], &right_v[7]); + value_v[7] = WeightedBlend(left_v[7], right_v[7], + ((index_7 << upsample_shift) & 0x3F) >> 1); + + // 8x8 transpose. + const uint8x16x2_t b0 = vtrnq_u8(vcombine_u8(value_v[0], value_v[4]), + vcombine_u8(value_v[1], value_v[5])); + const uint8x16x2_t b1 = vtrnq_u8(vcombine_u8(value_v[2], value_v[6]), + vcombine_u8(value_v[3], value_v[7])); + + const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]), + vreinterpretq_u16_u8(b1.val[0])); + const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]), + vreinterpretq_u16_u8(b1.val[1])); + + const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c1.val[0])); + const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c1.val[1])); + + if (width == 4) { + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0]))); + if (height == 4) return; + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1]))); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1]))); + } else { + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0]))); + if (height == 4) return; + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1]))); + dst += stride; + vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1]))); + } +} + +// Because the source values "move backwards" as the row index increases, the +// indices derived from ystep are generally negative. This is accommodated by +// making sure the relative indices are within [-15, 0] when the function is +// called, and sliding them into the inclusive range [0, 15], relative to a +// lower base address. +constexpr int kPositiveIndexOffset = 15; + +// Process 4 or 8 |width| by any |height|. +template <int width> +inline void DirectionalZone2FromLeftCol_WxH(uint8_t* dst, + const ptrdiff_t stride, + const int height, + const uint8_t* const left_column, + const int16x8_t left_y, + const int upsample_shift) { + assert(width == 4 || width == 8); + + // The shift argument must be a constant. + int16x8_t offset_y, shift_upsampled = left_y; + if (upsample_shift) { + offset_y = vshrq_n_s16(left_y, 5); + shift_upsampled = vshlq_n_s16(shift_upsampled, 1); + } else { + offset_y = vshrq_n_s16(left_y, 6); + } + + // Select values to the left of the starting point. + // The 15th element (and 16th) will be all the way at the end, to the right. + // With a negative ystep everything else will be "left" of them. + // This supports cumulative steps up to 15. We could support up to 16 by doing + // separate loads for |left_values| and |right_values|. vtbl supports 2 Q + // registers as input which would allow for cumulative offsets of 32. + const int16x8_t sampler = + vaddq_s16(offset_y, vdupq_n_s16(kPositiveIndexOffset)); + const uint8x8_t left_values = vqmovun_s16(sampler); + const uint8x8_t right_values = vadd_u8(left_values, vdup_n_u8(1)); + + const int16x8_t shift_masked = vandq_s16(shift_upsampled, vdupq_n_s16(0x3f)); + const uint8x8_t shift_mul = vreinterpret_u8_s8(vshrn_n_s16(shift_masked, 1)); + const uint8x8_t inv_shift_mul = vsub_u8(vdup_n_u8(32), shift_mul); + + int y = 0; + do { + uint8x8_t src_left, src_right; + LoadStepwise(left_column - kPositiveIndexOffset + (y << upsample_shift), + left_values, right_values, &src_left, &src_right); + const uint8x8_t val = + WeightedBlend(src_left, src_right, inv_shift_mul, shift_mul); + + if (width == 4) { + StoreLo4(dst, val); + } else { + vst1_u8(dst, val); + } + dst += stride; + } while (++y < height); +} + +// Process 4 or 8 |width| by any |height|. +template <int width> +inline void DirectionalZone1Blend_WxH(uint8_t* dest, const ptrdiff_t stride, + const int height, + const uint8_t* const top_row, + int zone_bounds, int top_x, + const int xstep, + const int upsample_shift) { + assert(width == 4 || width == 8); + + const int scale_bits_x = 6 - upsample_shift; + + const uint8x8_t all = vcreate_u8(0x0706050403020100); + const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200); + const uint8x8_t base_step = upsample_shift ? even : all; + const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1)); + + int y = 0; + do { + const uint8_t* const src = top_row + (top_x >> scale_bits_x); + uint8x8_t left, right; + LoadStepwise(src, base_step, right_step, &left, &right); + + const uint8_t shift = ((top_x << upsample_shift) & 0x3f) >> 1; + const uint8x8_t val = WeightedBlend(left, right, shift); + + uint8x8_t dst_blend = vld1_u8(dest); + // |zone_bounds| values can be negative. + uint8x8_t blend = + vcge_s8(vreinterpret_s8_u8(all), vdup_n_s8((zone_bounds >> 6))); + uint8x8_t output = vbsl_u8(blend, val, dst_blend); + + if (width == 4) { + StoreLo4(dest, output); + } else { + vst1_u8(dest, output); + } + dest += stride; + zone_bounds += xstep; + top_x -= xstep; + } while (++y < height); +} + +// The height at which a load of 16 bytes will not contain enough source pixels +// from |left_column| to supply an accurate row when computing 8 pixels at a +// time. The values are found by inspection. By coincidence, all angles that +// satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up +// by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15. +constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = { + 1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40}; + +// 7.11.2.4 (8) 90 < angle > 180 +// The strategy for these functions (4xH and 8+xH) is to know how many blocks +// can be processed with just pixels from |top_ptr|, then handle mixed blocks, +// then handle only blocks that take from |left_ptr|. Additionally, a fast +// index-shuffle approach is used for pred values from |left_column| in sections +// that permit it. +inline void DirectionalZone2_4xH(uint8_t* dst, const ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int height, const int xstep, + const int ystep, const bool upsampled_top, + const bool upsampled_left) { + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + + // Helper vector. + const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7}; + + // Loop incrementers for moving by block (4xN). Vertical still steps by 8. If + // it's only 4, it will be finished in the first iteration. + const ptrdiff_t stride8 = stride << 3; + const int xstep8 = xstep << 3; + + const int min_height = (height == 4) ? 4 : 8; + + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute and can therefore call the Zone1 functions. This assumes |xstep| is + // at least 3. + assert(xstep >= 3); + const int min_top_only_x = std::min((height * xstep) >> 6, /* width */ 4); + + // For steep angles, the source pixels from |left_column| may not fit in a + // 16-byte load for shuffling. + // TODO(petersonab): Find a more precise formula for this subject to x. + // TODO(johannkoenig): Revisit this for |width| == 4. + const int max_shuffle_height = + std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height); + + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1; + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which is covered under the left_column + // offset. The following values need the full ystep as a relative offset. + int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep); + left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder)); + + // This loop treats each set of 4 columns in 3 stages with y-value boundaries. + // The first stage, before the first y-loop, covers blocks that are only + // computed from the top row. The second stage, comprising two y-loops, covers + // blocks that have a mixture of values computed from top or left. The final + // stage covers blocks that are only computed from the left. + if (min_top_only_x > 0) { + // Round down to the nearest multiple of 8. + // TODO(johannkoenig): This never hits for Wx4 blocks but maybe it should. + const int max_top_only_y = std::min((1 << 6) / xstep, height) & ~7; + DirectionalZone1_WxH<4>(dst, stride, max_top_only_y, top_row, -xstep, + upsampled_top); + + if (max_top_only_y == height) return; + + int y = max_top_only_y; + dst += stride * y; + const int xstep_y = xstep * y; + + // All rows from |min_left_only_y| down for this set of columns only need + // |left_column| to compute. + const int min_left_only_y = std::min((4 << 6) / xstep, height); + // At high angles such that min_left_only_y < 8, ystep is low and xstep is + // high. This means that max_shuffle_height is unbounded and xstep_bounds + // will overflow in 16 bits. This is prevented by stopping the first + // blending loop at min_left_only_y for such cases, which means we skip over + // the second blending loop as well. + const int left_shuffle_stop_y = + std::min(max_shuffle_height, min_left_only_y); + int xstep_bounds = xstep_bounds_base + xstep_y; + int top_x = -xstep - xstep_y; + + // +8 increment is OK because if height is 4 this only goes once. + for (; y < left_shuffle_stop_y; + y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone2FromLeftCol_WxH<4>( + dst, stride, min_height, + left_column + ((y - left_base_increment) << upsample_left_shift), + left_y, upsample_left_shift); + + DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row, + xstep_bounds, top_x, xstep, + upsample_top_shift); + } + + // Pick up from the last y-value, using the slower but secure method for + // left prediction. + const int16_t base_left_y = vgetq_lane_s16(left_y, 0); + for (; y < min_left_only_y; + y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone3_WxH<4>( + dst, stride, min_height, + left_column + ((y - left_base_increment) << upsample_left_shift), + base_left_y, -ystep, upsample_left_shift); + + DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row, + xstep_bounds, top_x, xstep, + upsample_top_shift); + } + // Loop over y for left_only rows. + for (; y < height; y += 8, dst += stride8) { + DirectionalZone3_WxH<4>( + dst, stride, min_height, + left_column + ((y - left_base_increment) << upsample_left_shift), + base_left_y, -ystep, upsample_left_shift); + } + } else { + DirectionalZone1_WxH<4>(dst, stride, height, top_row, -xstep, + upsampled_top); + } +} + +// Process a multiple of 8 |width|. +inline void DirectionalZone2_8(uint8_t* const dst, const ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int width, const int height, + const int xstep, const int ystep, + const bool upsampled_top, + const bool upsampled_left) { + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + + // Helper vector. + const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7}; + + // Loop incrementers for moving by block (8x8). This function handles blocks + // with height 4 as well. They are calculated in one pass so these variables + // do not get used. + const ptrdiff_t stride8 = stride << 3; + const int xstep8 = xstep << 3; + const int ystep8 = ystep << 3; + + // Process Wx4 blocks. + const int min_height = (height == 4) ? 4 : 8; + + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute and can therefore call the Zone1 functions. This assumes |xstep| is + // at least 3. + assert(xstep >= 3); + const int min_top_only_x = std::min((height * xstep) >> 6, width); + + // For steep angles, the source pixels from |left_column| may not fit in a + // 16-byte load for shuffling. + // TODO(petersonab): Find a more precise formula for this subject to x. + const int max_shuffle_height = + std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height); + + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1; + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + + const int left_base_increment8 = ystep8 >> 6; + const int ystep_remainder8 = ystep8 & 0x3F; + const int16x8_t increment_left8 = vdupq_n_s16(ystep_remainder8); + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which is covered under the left_column + // offset. Following values need the full ystep as a relative offset. + int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep); + left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder)); + + // This loop treats each set of 4 columns in 3 stages with y-value boundaries. + // The first stage, before the first y-loop, covers blocks that are only + // computed from the top row. The second stage, comprising two y-loops, covers + // blocks that have a mixture of values computed from top or left. The final + // stage covers blocks that are only computed from the left. + int x = 0; + for (int left_offset = -left_base_increment; x < min_top_only_x; x += 8, + xstep_bounds_base -= (8 << 6), + left_y = vsubq_s16(left_y, increment_left8), + left_offset -= left_base_increment8) { + uint8_t* dst_x = dst + x; + + // Round down to the nearest multiple of 8. + const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7; + DirectionalZone1_WxH<8>(dst_x, stride, max_top_only_y, + top_row + (x << upsample_top_shift), -xstep, + upsampled_top); + + if (max_top_only_y == height) continue; + + int y = max_top_only_y; + dst_x += stride * y; + const int xstep_y = xstep * y; + + // All rows from |min_left_only_y| down for this set of columns only need + // |left_column| to compute. + const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height); + // At high angles such that min_left_only_y < 8, ystep is low and xstep is + // high. This means that max_shuffle_height is unbounded and xstep_bounds + // will overflow in 16 bits. This is prevented by stopping the first + // blending loop at min_left_only_y for such cases, which means we skip over + // the second blending loop as well. + const int left_shuffle_stop_y = + std::min(max_shuffle_height, min_left_only_y); + int xstep_bounds = xstep_bounds_base + xstep_y; + int top_x = -xstep - xstep_y; + + for (; y < left_shuffle_stop_y; + y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone2FromLeftCol_WxH<8>( + dst_x, stride, min_height, + left_column + ((left_offset + y) << upsample_left_shift), left_y, + upsample_left_shift); + + DirectionalZone1Blend_WxH<8>( + dst_x, stride, min_height, top_row + (x << upsample_top_shift), + xstep_bounds, top_x, xstep, upsample_top_shift); + } + + // Pick up from the last y-value, using the slower but secure method for + // left prediction. + const int16_t base_left_y = vgetq_lane_s16(left_y, 0); + for (; y < min_left_only_y; + y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) { + DirectionalZone3_WxH<8>( + dst_x, stride, min_height, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep, upsample_left_shift); + + DirectionalZone1Blend_WxH<8>( + dst_x, stride, min_height, top_row + (x << upsample_top_shift), + xstep_bounds, top_x, xstep, upsample_top_shift); + } + // Loop over y for left_only rows. + for (; y < height; y += 8, dst_x += stride8) { + DirectionalZone3_WxH<8>( + dst_x, stride, min_height, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep, upsample_left_shift); + } + } + // TODO(johannkoenig): May be able to remove this branch. + if (x < width) { + DirectionalZone1_WxH(dst + x, stride, width - x, height, + top_row + (x << upsample_top_shift), -xstep, + upsampled_top); + } +} + +void DirectionalIntraPredictorZone2_NEON( + void* const dest, const ptrdiff_t stride, const void* const top_row, + const void* const left_column, const int width, const int height, + const int xstep, const int ystep, const bool upsampled_top, + const bool upsampled_left) { + // Increasing the negative buffer for this function allows more rows to be + // processed at a time without branching in an inner loop to check the base. + uint8_t top_buffer[288]; + uint8_t left_buffer[288]; + memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160); + memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160); + const uint8_t* top_ptr = top_buffer + 144; + const uint8_t* left_ptr = left_buffer + 144; + auto* dst = static_cast<uint8_t*>(dest); + + if (width == 4) { + DirectionalZone2_4xH(dst, stride, top_ptr, left_ptr, height, xstep, ystep, + upsampled_top, upsampled_left); + } else { + DirectionalZone2_8(dst, stride, top_ptr, left_ptr, width, height, xstep, + ystep, upsampled_top, upsampled_left); + } +} + +void DirectionalIntraPredictorZone3_NEON(void* const dest, + const ptrdiff_t stride, + const void* const left_column, + const int width, const int height, + const int ystep, + const bool upsampled_left) { + const auto* const left = static_cast<const uint8_t*>(left_column); + + assert(ystep > 0); + + const int upsample_shift = static_cast<int>(upsampled_left); + const int scale_bits = 6 - upsample_shift; + const int base_step = 1 << upsample_shift; + + if (width == 4 || height == 4) { + // This block can handle all sizes but the specializations for other sizes + // are faster. + const uint8x8_t all = vcreate_u8(0x0706050403020100); + const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200); + const uint8x8_t base_step_v = upsampled_left ? even : all; + const uint8x8_t right_step = vadd_u8(base_step_v, vdup_n_u8(1)); + + int y = 0; + do { + int x = 0; + do { + uint8_t* dst = static_cast<uint8_t*>(dest); + dst += y * stride + x; + uint8x8_t left_v[4], right_v[4], value_v[4]; + const int ystep_base = ystep * x; + const int offset = y * base_step; + + const int index_0 = ystep_base + ystep * 1; + LoadStepwise(left + offset + (index_0 >> scale_bits), base_step_v, + right_step, &left_v[0], &right_v[0]); + value_v[0] = WeightedBlend(left_v[0], right_v[0], + ((index_0 << upsample_shift) & 0x3F) >> 1); + + const int index_1 = ystep_base + ystep * 2; + LoadStepwise(left + offset + (index_1 >> scale_bits), base_step_v, + right_step, &left_v[1], &right_v[1]); + value_v[1] = WeightedBlend(left_v[1], right_v[1], + ((index_1 << upsample_shift) & 0x3F) >> 1); + + const int index_2 = ystep_base + ystep * 3; + LoadStepwise(left + offset + (index_2 >> scale_bits), base_step_v, + right_step, &left_v[2], &right_v[2]); + value_v[2] = WeightedBlend(left_v[2], right_v[2], + ((index_2 << upsample_shift) & 0x3F) >> 1); + + const int index_3 = ystep_base + ystep * 4; + LoadStepwise(left + offset + (index_3 >> scale_bits), base_step_v, + right_step, &left_v[3], &right_v[3]); + value_v[3] = WeightedBlend(left_v[3], right_v[3], + ((index_3 << upsample_shift) & 0x3F) >> 1); + + // 8x4 transpose. + const uint8x8x2_t b0 = vtrn_u8(value_v[0], value_v[1]); + const uint8x8x2_t b1 = vtrn_u8(value_v[2], value_v[3]); + + const uint16x4x2_t c0 = vtrn_u16(vreinterpret_u16_u8(b0.val[0]), + vreinterpret_u16_u8(b1.val[0])); + const uint16x4x2_t c1 = vtrn_u16(vreinterpret_u16_u8(b0.val[1]), + vreinterpret_u16_u8(b1.val[1])); + + StoreLo4(dst, vreinterpret_u8_u16(c0.val[0])); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u16(c1.val[0])); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u16(c0.val[1])); + dst += stride; + StoreLo4(dst, vreinterpret_u8_u16(c1.val[1])); + + if (height > 4) { + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c0.val[0])); + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c1.val[0])); + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c0.val[1])); + dst += stride; + StoreHi4(dst, vreinterpret_u8_u16(c1.val[1])); + } + x += 4; + } while (x < width); + y += 8; + } while (y < height); + } else { // 8x8 at a time. + // Limited improvement for 8x8. ~20% faster for 64x64. + int y = 0; + do { + int x = 0; + do { + uint8_t* dst = static_cast<uint8_t*>(dest); + dst += y * stride + x; + const int ystep_base = ystep * (x + 1); + + DirectionalZone3_WxH<8>(dst, stride, 8, left + (y << upsample_shift), + ystep_base, ystep, upsample_shift); + x += 8; + } while (x < width); + y += 8; + } while (y < height); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->directional_intra_predictor_zone1 = DirectionalIntraPredictorZone1_NEON; + dsp->directional_intra_predictor_zone2 = DirectionalIntraPredictorZone2_NEON; + dsp->directional_intra_predictor_zone3 = DirectionalIntraPredictorZone3_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredDirectionalInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredDirectionalInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_filter_intra_neon.cc b/src/dsp/arm/intrapred_filter_intra_neon.cc new file mode 100644 index 0000000..411708e --- /dev/null +++ b/src/dsp/arm/intrapred_filter_intra_neon.cc @@ -0,0 +1,176 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { + +namespace low_bitdepth { +namespace { + +// Transpose kFilterIntraTaps and convert the first row to unsigned values. +// +// With the previous orientation we were able to multiply all the input values +// by a single tap. This required that all the input values be in one vector +// which requires expensive set up operations (shifts, vext, vtbl). All the +// elements of the result needed to be summed (easy on A64 - vaddvq_s16) but +// then the shifting, rounding, and clamping was done in GP registers. +// +// Switching to unsigned values allows multiplying the 8 bit inputs directly. +// When one value was negative we needed to vmovl_u8 first so that the results +// maintained the proper sign. +// +// We take this into account when summing the values by subtracting the product +// of the first row. +alignas(8) constexpr uint8_t kTransposedTaps[kNumFilterIntraPredictors][7][8] = + {{{6, 5, 3, 3, 4, 3, 3, 3}, // Original values are negative. + {10, 2, 1, 1, 6, 2, 2, 1}, + {0, 10, 1, 1, 0, 6, 2, 2}, + {0, 0, 10, 2, 0, 0, 6, 2}, + {0, 0, 0, 10, 0, 0, 0, 6}, + {12, 9, 7, 5, 2, 2, 2, 3}, + {0, 0, 0, 0, 12, 9, 7, 5}}, + {{10, 6, 4, 2, 10, 6, 4, 2}, // Original values are negative. + {16, 0, 0, 0, 16, 0, 0, 0}, + {0, 16, 0, 0, 0, 16, 0, 0}, + {0, 0, 16, 0, 0, 0, 16, 0}, + {0, 0, 0, 16, 0, 0, 0, 16}, + {10, 6, 4, 2, 0, 0, 0, 0}, + {0, 0, 0, 0, 10, 6, 4, 2}}, + {{8, 8, 8, 8, 4, 4, 4, 4}, // Original values are negative. + {8, 0, 0, 0, 4, 0, 0, 0}, + {0, 8, 0, 0, 0, 4, 0, 0}, + {0, 0, 8, 0, 0, 0, 4, 0}, + {0, 0, 0, 8, 0, 0, 0, 4}, + {16, 16, 16, 16, 0, 0, 0, 0}, + {0, 0, 0, 0, 16, 16, 16, 16}}, + {{2, 1, 1, 0, 1, 1, 1, 1}, // Original values are negative. + {8, 3, 2, 1, 4, 3, 2, 2}, + {0, 8, 3, 2, 0, 4, 3, 2}, + {0, 0, 8, 3, 0, 0, 4, 3}, + {0, 0, 0, 8, 0, 0, 0, 4}, + {10, 6, 4, 2, 3, 4, 4, 3}, + {0, 0, 0, 0, 10, 6, 4, 3}}, + {{12, 10, 9, 8, 10, 9, 8, 7}, // Original values are negative. + {14, 0, 0, 0, 12, 1, 0, 0}, + {0, 14, 0, 0, 0, 12, 0, 0}, + {0, 0, 14, 0, 0, 0, 12, 1}, + {0, 0, 0, 14, 0, 0, 0, 12}, + {14, 12, 11, 10, 0, 0, 1, 1}, + {0, 0, 0, 0, 14, 12, 11, 9}}}; + +void FilterIntraPredictor_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + FilterIntraPredictor pred, int width, + int height) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + + assert(width <= 32 && height <= 32); + + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x8_t transposed_taps[7]; + for (int i = 0; i < 7; ++i) { + transposed_taps[i] = vld1_u8(kTransposedTaps[pred][i]); + } + + uint8_t relative_top_left = top[-1]; + const uint8_t* relative_top = top; + uint8_t relative_left[2] = {left[0], left[1]}; + + int y = 0; + do { + uint8_t* row_dst = dst; + int x = 0; + do { + uint16x8_t sum = vdupq_n_u16(0); + const uint16x8_t subtrahend = + vmull_u8(transposed_taps[0], vdup_n_u8(relative_top_left)); + for (int i = 1; i < 5; ++i) { + sum = vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_top[i - 1])); + } + for (int i = 5; i < 7; ++i) { + sum = + vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_left[i - 5])); + } + + const int16x8_t sum_signed = + vreinterpretq_s16_u16(vsubq_u16(sum, subtrahend)); + const int16x8_t sum_shifted = vrshrq_n_s16(sum_signed, 4); + + uint8x8_t sum_saturated = vqmovun_s16(sum_shifted); + + StoreLo4(row_dst, sum_saturated); + StoreHi4(row_dst + stride, sum_saturated); + + // Progress across + relative_top_left = relative_top[3]; + relative_top += 4; + relative_left[0] = row_dst[3]; + relative_left[1] = row_dst[3 + stride]; + row_dst += 4; + x += 4; + } while (x < width); + + // Progress down. + relative_top_left = left[y + 1]; + relative_top = dst + stride; + relative_left[0] = left[y + 2]; + relative_left[1] = left[y + 3]; + + dst += 2 * stride; + y += 2; + } while (y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->filter_intra_predictor = FilterIntraPredictor_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredFilterIntraInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredFilterIntraInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_neon.cc b/src/dsp/arm/intrapred_neon.cc new file mode 100644 index 0000000..c967d82 --- /dev/null +++ b/src/dsp/arm/intrapred_neon.cc @@ -0,0 +1,1144 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" + +namespace libgav1 { +namespace dsp { +namespace { + +//------------------------------------------------------------------------------ +// DcPredFuncs_NEON + +using DcSumFunc = uint32x2_t (*)(const void* ref_0, const int ref_0_size_log2, + const bool use_ref_1, const void* ref_1, + const int ref_1_size_log2); +using DcStoreFunc = void (*)(void* dest, ptrdiff_t stride, const uint32x2_t dc); + +// DC intra-predictors for square blocks. +template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +struct DcPredFuncs_NEON { + DcPredFuncs_NEON() = delete; + + static void DcTop(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void DcLeft(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Dc(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); +}; + +template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, + storefn>::DcTop(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* /*left_column*/) { + const uint32x2_t sum = sumfn(top_row, block_width_log2, false, nullptr, 0); + const uint32x2_t dc = vrshr_n_u32(sum, block_width_log2); + storefn(dest, stride, dc); +} + +template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, + storefn>::DcLeft(void* const dest, ptrdiff_t stride, + const void* /*top_row*/, + const void* const left_column) { + const uint32x2_t sum = + sumfn(left_column, block_height_log2, false, nullptr, 0); + const uint32x2_t dc = vrshr_n_u32(sum, block_height_log2); + storefn(dest, stride, dc); +} + +template <int block_width_log2, int block_height_log2, DcSumFunc sumfn, + DcStoreFunc storefn> +void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, storefn>::Dc( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* const left_column) { + const uint32x2_t sum = + sumfn(top_row, block_width_log2, true, left_column, block_height_log2); + if (block_width_log2 == block_height_log2) { + const uint32x2_t dc = vrshr_n_u32(sum, block_width_log2 + 1); + storefn(dest, stride, dc); + } else { + // TODO(johannkoenig): Compare this to mul/shift in vectors. + const int divisor = (1 << block_width_log2) + (1 << block_height_log2); + uint32_t dc = vget_lane_u32(sum, 0); + dc += divisor >> 1; + dc /= divisor; + storefn(dest, stride, vdup_n_u32(dc)); + } +} + +// Sum all the elements in the vector into the low 32 bits. +inline uint32x2_t Sum(const uint16x4_t val) { + const uint32x2_t sum = vpaddl_u16(val); + return vpadd_u32(sum, sum); +} + +// Sum all the elements in the vector into the low 32 bits. +inline uint32x2_t Sum(const uint16x8_t val) { + const uint32x4_t sum_0 = vpaddlq_u16(val); + const uint64x2_t sum_1 = vpaddlq_u32(sum_0); + return vadd_u32(vget_low_u32(vreinterpretq_u32_u64(sum_1)), + vget_high_u32(vreinterpretq_u32_u64(sum_1))); +} + +} // namespace + +//------------------------------------------------------------------------------ +namespace low_bitdepth { +namespace { + +// Add and expand the elements in the |val_[01]| to uint16_t but do not sum the +// entire vector. +inline uint16x8_t Add(const uint8x16_t val_0, const uint8x16_t val_1) { + const uint16x8_t sum_0 = vpaddlq_u8(val_0); + const uint16x8_t sum_1 = vpaddlq_u8(val_1); + return vaddq_u16(sum_0, sum_1); +} + +// Add and expand the elements in the |val_[0123]| to uint16_t but do not sum +// the entire vector. +inline uint16x8_t Add(const uint8x16_t val_0, const uint8x16_t val_1, + const uint8x16_t val_2, const uint8x16_t val_3) { + const uint16x8_t sum_0 = Add(val_0, val_1); + const uint16x8_t sum_1 = Add(val_2, val_3); + return vaddq_u16(sum_0, sum_1); +} + +// Load and combine 32 uint8_t values. +inline uint16x8_t LoadAndAdd32(const uint8_t* buf) { + const uint8x16_t val_0 = vld1q_u8(buf); + const uint8x16_t val_1 = vld1q_u8(buf + 16); + return Add(val_0, val_1); +} + +// Load and combine 64 uint8_t values. +inline uint16x8_t LoadAndAdd64(const uint8_t* buf) { + const uint8x16_t val_0 = vld1q_u8(buf); + const uint8x16_t val_1 = vld1q_u8(buf + 16); + const uint8x16_t val_2 = vld1q_u8(buf + 32); + const uint8x16_t val_3 = vld1q_u8(buf + 48); + return Add(val_0, val_1, val_2, val_3); +} + +// |ref_[01]| each point to 1 << |ref[01]_size_log2| packed uint8_t values. +// If |use_ref_1| is false then only sum |ref_0|. +// For |ref[01]_size_log2| == 4 this relies on |ref_[01]| being aligned to +// uint32_t. +inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2, + const bool use_ref_1, const void* ref_1, + const int ref_1_size_log2) { + const auto* const ref_0_u8 = static_cast<const uint8_t*>(ref_0); + const auto* const ref_1_u8 = static_cast<const uint8_t*>(ref_1); + if (ref_0_size_log2 == 2) { + uint8x8_t val = Load4(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 4x4 + val = Load4<1>(ref_1_u8, val); + return Sum(vpaddl_u8(val)); + } else if (ref_1_size_log2 == 3) { // 4x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + const uint16x4_t sum_0 = vpaddl_u8(val); + const uint16x4_t sum_1 = vpaddl_u8(val_1); + return Sum(vadd_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 4) { // 4x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_1), val)); + } + } + // 4x1 + const uint16x4_t sum = vpaddl_u8(val); + return vpaddl_u16(sum); + } else if (ref_0_size_log2 == 3) { + const uint8x8_t val_0 = vld1_u8(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 8x4 + const uint8x8_t val_1 = Load4(ref_1_u8); + const uint16x4_t sum_0 = vpaddl_u8(val_0); + const uint16x4_t sum_1 = vpaddl_u8(val_1); + return Sum(vadd_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 3) { // 8x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + const uint16x4_t sum_0 = vpaddl_u8(val_0); + const uint16x4_t sum_1 = vpaddl_u8(val_1); + return Sum(vadd_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 4) { // 8x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_1), val_0)); + } else if (ref_1_size_log2 == 5) { // 8x32 + return Sum(vaddw_u8(LoadAndAdd32(ref_1_u8), val_0)); + } + } + // 8x1 + return Sum(vpaddl_u8(val_0)); + } else if (ref_0_size_log2 == 4) { + const uint8x16_t val_0 = vld1q_u8(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 16x4 + const uint8x8_t val_1 = Load4(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1)); + } else if (ref_1_size_log2 == 3) { // 16x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1)); + } else if (ref_1_size_log2 == 4) { // 16x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + return Sum(Add(val_0, val_1)); + } else if (ref_1_size_log2 == 5) { // 16x32 + const uint16x8_t sum_0 = vpaddlq_u8(val_0); + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 16x64 + const uint16x8_t sum_0 = vpaddlq_u8(val_0); + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 16x1 + return Sum(vpaddlq_u8(val_0)); + } else if (ref_0_size_log2 == 5) { + const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 3) { // 32x8 + const uint8x8_t val_1 = vld1_u8(ref_1_u8); + return Sum(vaddw_u8(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 32x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + const uint16x8_t sum_1 = vpaddlq_u8(val_1); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 32x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 32x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 32x1 + return Sum(sum_0); + } + + assert(ref_0_size_log2 == 6); + const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u8); + if (use_ref_1) { + if (ref_1_size_log2 == 4) { // 64x16 + const uint8x16_t val_1 = vld1q_u8(ref_1_u8); + const uint16x8_t sum_1 = vpaddlq_u8(val_1); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 64x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 64x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 64x1 + return Sum(sum_0); +} + +template <int width, int height> +inline void DcStore_NEON(void* const dest, ptrdiff_t stride, + const uint32x2_t dc) { + const uint8x16_t dc_dup = vdupq_lane_u8(vreinterpret_u8_u32(dc), 0); + auto* dst = static_cast<uint8_t*>(dest); + if (width == 4) { + int i = height - 1; + do { + StoreLo4(dst, vget_low_u8(dc_dup)); + dst += stride; + } while (--i != 0); + StoreLo4(dst, vget_low_u8(dc_dup)); + } else if (width == 8) { + int i = height - 1; + do { + vst1_u8(dst, vget_low_u8(dc_dup)); + dst += stride; + } while (--i != 0); + vst1_u8(dst, vget_low_u8(dc_dup)); + } else if (width == 16) { + int i = height - 1; + do { + vst1q_u8(dst, dc_dup); + dst += stride; + } while (--i != 0); + vst1q_u8(dst, dc_dup); + } else if (width == 32) { + int i = height - 1; + do { + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + dst += stride; + } while (--i != 0); + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + } else { + assert(width == 64); + int i = height - 1; + do { + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + vst1q_u8(dst + 32, dc_dup); + vst1q_u8(dst + 48, dc_dup); + dst += stride; + } while (--i != 0); + vst1q_u8(dst, dc_dup); + vst1q_u8(dst + 16, dc_dup); + vst1q_u8(dst + 32, dc_dup); + vst1q_u8(dst + 48, dc_dup); + } +} + +template <int width, int height> +inline void Paeth4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + auto* dest_u8 = static_cast<uint8_t*>(dest); + const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row); + const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column); + + const uint8x8_t top_left = vdup_n_u8(top_row_u8[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row_u8[-1] + top_row_u8[-1]); + uint8x8_t top; + if (width == 4) { + top = Load4(top_row_u8); + } else { // width == 8 + top = vld1_u8(top_row_u8); + } + + for (int y = 0; y < height; ++y) { + const uint8x8_t left = vdup_n_u8(left_col_u8[y]); + + const uint8x8_t left_dist = vabd_u8(top, top_left); + const uint8x8_t top_dist = vabd_u8(left, top_left); + const uint16x8_t top_left_dist = + vabdq_u16(vaddl_u8(top, left), top_left_x2); + + const uint8x8_t left_le_top = vcle_u8(left_dist, top_dist); + const uint8x8_t left_le_top_left = + vmovn_u16(vcleq_u16(vmovl_u8(left_dist), top_left_dist)); + const uint8x8_t top_le_top_left = + vmovn_u16(vcleq_u16(vmovl_u8(top_dist), top_left_dist)); + + // if (left_dist <= top_dist && left_dist <= top_left_dist) + const uint8x8_t left_mask = vand_u8(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + uint8x8_t result = vbsl_u8(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. + const uint8x8_t left_or_top_mask = vorr_u8(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + result = vbsl_u8(left_or_top_mask, result, top_left); + + if (width == 4) { + StoreLo4(dest_u8, result); + } else { // width == 8 + vst1_u8(dest_u8, result); + } + dest_u8 += stride; + } +} + +// Calculate X distance <= TopLeft distance and pack the resulting mask into +// uint8x8_t. +inline uint8x16_t XLeTopLeft(const uint8x16_t x_dist, + const uint16x8_t top_left_dist_low, + const uint16x8_t top_left_dist_high) { + // TODO(johannkoenig): cle() should work with vmovn(top_left_dist) instead of + // using movl(x_dist). + const uint8x8_t x_le_top_left_low = + vmovn_u16(vcleq_u16(vmovl_u8(vget_low_u8(x_dist)), top_left_dist_low)); + const uint8x8_t x_le_top_left_high = + vmovn_u16(vcleq_u16(vmovl_u8(vget_high_u8(x_dist)), top_left_dist_high)); + return vcombine_u8(x_le_top_left_low, x_le_top_left_high); +} + +// Select the closest values and collect them. +inline uint8x16_t SelectPaeth(const uint8x16_t top, const uint8x16_t left, + const uint8x16_t top_left, + const uint8x16_t left_le_top, + const uint8x16_t left_le_top_left, + const uint8x16_t top_le_top_left) { + // if (left_dist <= top_dist && left_dist <= top_left_dist) + const uint8x16_t left_mask = vandq_u8(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + uint8x16_t result = vbslq_u8(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. + const uint8x16_t left_or_top_mask = vorrq_u8(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + return vbslq_u8(left_or_top_mask, result, top_left); +} + +// Generate numbered and high/low versions of top_left_dist. +#define TOP_LEFT_DIST(num) \ + const uint16x8_t top_left_##num##_dist_low = vabdq_u16( \ + vaddl_u8(vget_low_u8(top[num]), vget_low_u8(left)), top_left_x2); \ + const uint16x8_t top_left_##num##_dist_high = vabdq_u16( \ + vaddl_u8(vget_high_u8(top[num]), vget_low_u8(left)), top_left_x2) + +// Generate numbered versions of XLeTopLeft with x = left. +#define LEFT_LE_TOP_LEFT(num) \ + const uint8x16_t left_le_top_left_##num = \ + XLeTopLeft(left_##num##_dist, top_left_##num##_dist_low, \ + top_left_##num##_dist_high) + +// Generate numbered versions of XLeTopLeft with x = top. +#define TOP_LE_TOP_LEFT(num) \ + const uint8x16_t top_le_top_left_##num = XLeTopLeft( \ + top_dist, top_left_##num##_dist_low, top_left_##num##_dist_high) + +template <int width, int height> +inline void Paeth16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + auto* dest_u8 = static_cast<uint8_t*>(dest); + const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row); + const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column); + + const uint8x16_t top_left = vdupq_n_u8(top_row_u8[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row_u8[-1] + top_row_u8[-1]); + uint8x16_t top[4]; + top[0] = vld1q_u8(top_row_u8); + if (width > 16) { + top[1] = vld1q_u8(top_row_u8 + 16); + if (width == 64) { + top[2] = vld1q_u8(top_row_u8 + 32); + top[3] = vld1q_u8(top_row_u8 + 48); + } + } + + for (int y = 0; y < height; ++y) { + const uint8x16_t left = vdupq_n_u8(left_col_u8[y]); + + const uint8x16_t top_dist = vabdq_u8(left, top_left); + + const uint8x16_t left_0_dist = vabdq_u8(top[0], top_left); + TOP_LEFT_DIST(0); + const uint8x16_t left_0_le_top = vcleq_u8(left_0_dist, top_dist); + LEFT_LE_TOP_LEFT(0); + TOP_LE_TOP_LEFT(0); + + const uint8x16_t result_0 = + SelectPaeth(top[0], left, top_left, left_0_le_top, left_le_top_left_0, + top_le_top_left_0); + vst1q_u8(dest_u8, result_0); + + if (width > 16) { + const uint8x16_t left_1_dist = vabdq_u8(top[1], top_left); + TOP_LEFT_DIST(1); + const uint8x16_t left_1_le_top = vcleq_u8(left_1_dist, top_dist); + LEFT_LE_TOP_LEFT(1); + TOP_LE_TOP_LEFT(1); + + const uint8x16_t result_1 = + SelectPaeth(top[1], left, top_left, left_1_le_top, left_le_top_left_1, + top_le_top_left_1); + vst1q_u8(dest_u8 + 16, result_1); + + if (width == 64) { + const uint8x16_t left_2_dist = vabdq_u8(top[2], top_left); + TOP_LEFT_DIST(2); + const uint8x16_t left_2_le_top = vcleq_u8(left_2_dist, top_dist); + LEFT_LE_TOP_LEFT(2); + TOP_LE_TOP_LEFT(2); + + const uint8x16_t result_2 = + SelectPaeth(top[2], left, top_left, left_2_le_top, + left_le_top_left_2, top_le_top_left_2); + vst1q_u8(dest_u8 + 32, result_2); + + const uint8x16_t left_3_dist = vabdq_u8(top[3], top_left); + TOP_LEFT_DIST(3); + const uint8x16_t left_3_le_top = vcleq_u8(left_3_dist, top_dist); + LEFT_LE_TOP_LEFT(3); + TOP_LE_TOP_LEFT(3); + + const uint8x16_t result_3 = + SelectPaeth(top[3], left, top_left, left_3_le_top, + left_le_top_left_3, top_le_top_left_3); + vst1q_u8(dest_u8 + 48, result_3); + } + } + + dest_u8 += stride; + } +} + +struct DcDefs { + DcDefs() = delete; + + using _4x4 = DcPredFuncs_NEON<2, 2, DcSum_NEON, DcStore_NEON<4, 4>>; + using _4x8 = DcPredFuncs_NEON<2, 3, DcSum_NEON, DcStore_NEON<4, 8>>; + using _4x16 = DcPredFuncs_NEON<2, 4, DcSum_NEON, DcStore_NEON<4, 16>>; + using _8x4 = DcPredFuncs_NEON<3, 2, DcSum_NEON, DcStore_NEON<8, 4>>; + using _8x8 = DcPredFuncs_NEON<3, 3, DcSum_NEON, DcStore_NEON<8, 8>>; + using _8x16 = DcPredFuncs_NEON<3, 4, DcSum_NEON, DcStore_NEON<8, 16>>; + using _8x32 = DcPredFuncs_NEON<3, 5, DcSum_NEON, DcStore_NEON<8, 32>>; + using _16x4 = DcPredFuncs_NEON<4, 2, DcSum_NEON, DcStore_NEON<16, 4>>; + using _16x8 = DcPredFuncs_NEON<4, 3, DcSum_NEON, DcStore_NEON<16, 8>>; + using _16x16 = DcPredFuncs_NEON<4, 4, DcSum_NEON, DcStore_NEON<16, 16>>; + using _16x32 = DcPredFuncs_NEON<4, 5, DcSum_NEON, DcStore_NEON<16, 32>>; + using _16x64 = DcPredFuncs_NEON<4, 6, DcSum_NEON, DcStore_NEON<16, 64>>; + using _32x8 = DcPredFuncs_NEON<5, 3, DcSum_NEON, DcStore_NEON<32, 8>>; + using _32x16 = DcPredFuncs_NEON<5, 4, DcSum_NEON, DcStore_NEON<32, 16>>; + using _32x32 = DcPredFuncs_NEON<5, 5, DcSum_NEON, DcStore_NEON<32, 32>>; + using _32x64 = DcPredFuncs_NEON<5, 6, DcSum_NEON, DcStore_NEON<32, 64>>; + using _64x16 = DcPredFuncs_NEON<6, 4, DcSum_NEON, DcStore_NEON<64, 16>>; + using _64x32 = DcPredFuncs_NEON<6, 5, DcSum_NEON, DcStore_NEON<64, 32>>; + using _64x64 = DcPredFuncs_NEON<6, 6, DcSum_NEON, DcStore_NEON<64, 64>>; +}; + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + // 4x4 + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + DcDefs::_4x4::DcTop; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + DcDefs::_4x4::DcLeft; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = + DcDefs::_4x4::Dc; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<4, 4>; + + // 4x8 + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] = + DcDefs::_4x8::DcTop; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] = + DcDefs::_4x8::DcLeft; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] = + DcDefs::_4x8::Dc; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<4, 8>; + + // 4x16 + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] = + DcDefs::_4x16::DcTop; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] = + DcDefs::_4x16::DcLeft; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] = + DcDefs::_4x16::Dc; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<4, 16>; + + // 8x4 + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] = + DcDefs::_8x4::DcTop; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] = + DcDefs::_8x4::DcLeft; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] = + DcDefs::_8x4::Dc; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 4>; + + // 8x8 + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] = + DcDefs::_8x8::DcTop; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] = + DcDefs::_8x8::DcLeft; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] = + DcDefs::_8x8::Dc; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 8>; + + // 8x16 + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] = + DcDefs::_8x16::DcTop; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] = + DcDefs::_8x16::DcLeft; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] = + DcDefs::_8x16::Dc; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 16>; + + // 8x32 + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] = + DcDefs::_8x32::DcTop; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] = + DcDefs::_8x32::DcLeft; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] = + DcDefs::_8x32::Dc; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorPaeth] = + Paeth4Or8xN_NEON<8, 32>; + + // 16x4 + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] = + DcDefs::_16x4::DcTop; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] = + DcDefs::_16x4::DcLeft; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] = + DcDefs::_16x4::Dc; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 4>; + + // 16x8 + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] = + DcDefs::_16x8::DcTop; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] = + DcDefs::_16x8::DcLeft; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] = + DcDefs::_16x8::Dc; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 8>; + + // 16x16 + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] = + DcDefs::_16x16::DcTop; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] = + DcDefs::_16x16::DcLeft; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] = + DcDefs::_16x16::Dc; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 16>; + + // 16x32 + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] = + DcDefs::_16x32::DcTop; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] = + DcDefs::_16x32::DcLeft; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] = + DcDefs::_16x32::Dc; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 32>; + + // 16x64 + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] = + DcDefs::_16x64::DcTop; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] = + DcDefs::_16x64::DcLeft; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] = + DcDefs::_16x64::Dc; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<16, 64>; + + // 32x8 + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] = + DcDefs::_32x8::DcTop; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] = + DcDefs::_32x8::DcLeft; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] = + DcDefs::_32x8::Dc; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 8>; + + // 32x16 + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] = + DcDefs::_32x16::DcTop; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] = + DcDefs::_32x16::DcLeft; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] = + DcDefs::_32x16::Dc; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 16>; + + // 32x32 + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] = + DcDefs::_32x32::DcTop; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] = + DcDefs::_32x32::DcLeft; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] = + DcDefs::_32x32::Dc; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 32>; + + // 32x64 + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] = + DcDefs::_32x64::DcTop; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] = + DcDefs::_32x64::DcLeft; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] = + DcDefs::_32x64::Dc; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<32, 64>; + + // 64x16 + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] = + DcDefs::_64x16::DcTop; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] = + DcDefs::_64x16::DcLeft; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] = + DcDefs::_64x16::Dc; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<64, 16>; + + // 64x32 + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] = + DcDefs::_64x32::DcTop; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] = + DcDefs::_64x32::DcLeft; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] = + DcDefs::_64x32::Dc; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<64, 32>; + + // 64x64 + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] = + DcDefs::_64x64::DcTop; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] = + DcDefs::_64x64::DcLeft; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] = + DcDefs::_64x64::Dc; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorPaeth] = + Paeth16PlusxN_NEON<64, 64>; +} + +} // namespace +} // namespace low_bitdepth + +//------------------------------------------------------------------------------ +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +// Add the elements in the given vectors together but do not sum the entire +// vector. +inline uint16x8_t Add(const uint16x8_t val_0, const uint16x8_t val_1, + const uint16x8_t val_2, const uint16x8_t val_3) { + const uint16x8_t sum_0 = vaddq_u16(val_0, val_1); + const uint16x8_t sum_1 = vaddq_u16(val_2, val_3); + return vaddq_u16(sum_0, sum_1); +} + +// Load and combine 16 uint16_t values. +inline uint16x8_t LoadAndAdd16(const uint16_t* buf) { + const uint16x8_t val_0 = vld1q_u16(buf); + const uint16x8_t val_1 = vld1q_u16(buf + 8); + return vaddq_u16(val_0, val_1); +} + +// Load and combine 32 uint16_t values. +inline uint16x8_t LoadAndAdd32(const uint16_t* buf) { + const uint16x8_t val_0 = vld1q_u16(buf); + const uint16x8_t val_1 = vld1q_u16(buf + 8); + const uint16x8_t val_2 = vld1q_u16(buf + 16); + const uint16x8_t val_3 = vld1q_u16(buf + 24); + return Add(val_0, val_1, val_2, val_3); +} + +// Load and combine 64 uint16_t values. +inline uint16x8_t LoadAndAdd64(const uint16_t* buf) { + const uint16x8_t val_0 = vld1q_u16(buf); + const uint16x8_t val_1 = vld1q_u16(buf + 8); + const uint16x8_t val_2 = vld1q_u16(buf + 16); + const uint16x8_t val_3 = vld1q_u16(buf + 24); + const uint16x8_t val_4 = vld1q_u16(buf + 32); + const uint16x8_t val_5 = vld1q_u16(buf + 40); + const uint16x8_t val_6 = vld1q_u16(buf + 48); + const uint16x8_t val_7 = vld1q_u16(buf + 56); + const uint16x8_t sum_0 = Add(val_0, val_1, val_2, val_3); + const uint16x8_t sum_1 = Add(val_4, val_5, val_6, val_7); + return vaddq_u16(sum_0, sum_1); +} + +// |ref_[01]| each point to 1 << |ref[01]_size_log2| packed uint16_t values. +// If |use_ref_1| is false then only sum |ref_0|. +inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2, + const bool use_ref_1, const void* ref_1, + const int ref_1_size_log2) { + const auto* ref_0_u16 = static_cast<const uint16_t*>(ref_0); + const auto* ref_1_u16 = static_cast<const uint16_t*>(ref_1); + if (ref_0_size_log2 == 2) { + const uint16x4_t val_0 = vld1_u16(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 4x4 + const uint16x4_t val_1 = vld1_u16(ref_1_u16); + return Sum(vadd_u16(val_0, val_1)); + } else if (ref_1_size_log2 == 3) { // 4x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0); + return Sum(vaddq_u16(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 4x16 + const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0); + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 4x1 + return Sum(val_0); + } else if (ref_0_size_log2 == 3) { + const uint16x8_t val_0 = vld1q_u16(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 8x4 + const uint16x4_t val_1 = vld1_u16(ref_1_u16); + const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1); + return Sum(vaddq_u16(val_0, sum_1)); + } else if (ref_1_size_log2 == 3) { // 8x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + return Sum(vaddq_u16(val_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 8x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(val_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 8x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(val_0, sum_1)); + } + } + // 8x1 + return Sum(val_0); + } else if (ref_0_size_log2 == 4) { + const uint16x8_t sum_0 = LoadAndAdd16(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 2) { // 16x4 + const uint16x4_t val_1 = vld1_u16(ref_1_u16); + const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 3) { // 16x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + return Sum(vaddq_u16(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 16x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 16x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 16x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 16x1 + return Sum(sum_0); + } else if (ref_0_size_log2 == 5) { + const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 3) { // 32x8 + const uint16x8_t val_1 = vld1q_u16(ref_1_u16); + return Sum(vaddq_u16(sum_0, val_1)); + } else if (ref_1_size_log2 == 4) { // 32x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 32x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 32x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 32x1 + return Sum(sum_0); + } + + assert(ref_0_size_log2 == 6); + const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u16); + if (use_ref_1) { + if (ref_1_size_log2 == 4) { // 64x16 + const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 5) { // 64x32 + const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } else if (ref_1_size_log2 == 6) { // 64x64 + const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16); + return Sum(vaddq_u16(sum_0, sum_1)); + } + } + // 64x1 + return Sum(sum_0); +} + +template <int width, int height> +inline void DcStore_NEON(void* const dest, ptrdiff_t stride, + const uint32x2_t dc) { + auto* dest_u16 = static_cast<uint16_t*>(dest); + ptrdiff_t stride_u16 = stride >> 1; + const uint16x8_t dc_dup = vdupq_lane_u16(vreinterpret_u16_u32(dc), 0); + if (width == 4) { + int i = height - 1; + do { + vst1_u16(dest_u16, vget_low_u16(dc_dup)); + dest_u16 += stride_u16; + } while (--i != 0); + vst1_u16(dest_u16, vget_low_u16(dc_dup)); + } else if (width == 8) { + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + } else if (width == 16) { + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + } else if (width == 32) { + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + } else { + assert(width == 64); + int i = height - 1; + do { + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + vst1q_u16(dest_u16 + 32, dc_dup); + vst1q_u16(dest_u16 + 40, dc_dup); + vst1q_u16(dest_u16 + 48, dc_dup); + vst1q_u16(dest_u16 + 56, dc_dup); + dest_u16 += stride_u16; + } while (--i != 0); + vst1q_u16(dest_u16, dc_dup); + vst1q_u16(dest_u16 + 8, dc_dup); + vst1q_u16(dest_u16 + 16, dc_dup); + vst1q_u16(dest_u16 + 24, dc_dup); + vst1q_u16(dest_u16 + 32, dc_dup); + vst1q_u16(dest_u16 + 40, dc_dup); + vst1q_u16(dest_u16 + 48, dc_dup); + vst1q_u16(dest_u16 + 56, dc_dup); + } +} + +struct DcDefs { + DcDefs() = delete; + + using _4x4 = DcPredFuncs_NEON<2, 2, DcSum_NEON, DcStore_NEON<4, 4>>; + using _4x8 = DcPredFuncs_NEON<2, 3, DcSum_NEON, DcStore_NEON<4, 8>>; + using _4x16 = DcPredFuncs_NEON<2, 4, DcSum_NEON, DcStore_NEON<4, 16>>; + using _8x4 = DcPredFuncs_NEON<3, 2, DcSum_NEON, DcStore_NEON<8, 4>>; + using _8x8 = DcPredFuncs_NEON<3, 3, DcSum_NEON, DcStore_NEON<8, 8>>; + using _8x16 = DcPredFuncs_NEON<3, 4, DcSum_NEON, DcStore_NEON<8, 16>>; + using _8x32 = DcPredFuncs_NEON<3, 5, DcSum_NEON, DcStore_NEON<8, 32>>; + using _16x4 = DcPredFuncs_NEON<4, 2, DcSum_NEON, DcStore_NEON<16, 4>>; + using _16x8 = DcPredFuncs_NEON<4, 3, DcSum_NEON, DcStore_NEON<16, 8>>; + using _16x16 = DcPredFuncs_NEON<4, 4, DcSum_NEON, DcStore_NEON<16, 16>>; + using _16x32 = DcPredFuncs_NEON<4, 5, DcSum_NEON, DcStore_NEON<16, 32>>; + using _16x64 = DcPredFuncs_NEON<4, 6, DcSum_NEON, DcStore_NEON<16, 64>>; + using _32x8 = DcPredFuncs_NEON<5, 3, DcSum_NEON, DcStore_NEON<32, 8>>; + using _32x16 = DcPredFuncs_NEON<5, 4, DcSum_NEON, DcStore_NEON<32, 16>>; + using _32x32 = DcPredFuncs_NEON<5, 5, DcSum_NEON, DcStore_NEON<32, 32>>; + using _32x64 = DcPredFuncs_NEON<5, 6, DcSum_NEON, DcStore_NEON<32, 64>>; + using _64x16 = DcPredFuncs_NEON<6, 4, DcSum_NEON, DcStore_NEON<64, 16>>; + using _64x32 = DcPredFuncs_NEON<6, 5, DcSum_NEON, DcStore_NEON<64, 32>>; + using _64x64 = DcPredFuncs_NEON<6, 6, DcSum_NEON, DcStore_NEON<64, 64>>; +}; + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + DcDefs::_4x4::DcTop; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + DcDefs::_4x4::DcLeft; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = + DcDefs::_4x4::Dc; + + // 4x8 + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] = + DcDefs::_4x8::DcTop; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] = + DcDefs::_4x8::DcLeft; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] = + DcDefs::_4x8::Dc; + + // 4x16 + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] = + DcDefs::_4x16::DcTop; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] = + DcDefs::_4x16::DcLeft; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] = + DcDefs::_4x16::Dc; + + // 8x4 + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] = + DcDefs::_8x4::DcTop; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] = + DcDefs::_8x4::DcLeft; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] = + DcDefs::_8x4::Dc; + + // 8x8 + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] = + DcDefs::_8x8::DcTop; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] = + DcDefs::_8x8::DcLeft; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] = + DcDefs::_8x8::Dc; + + // 8x16 + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] = + DcDefs::_8x16::DcTop; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] = + DcDefs::_8x16::DcLeft; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] = + DcDefs::_8x16::Dc; + + // 8x32 + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] = + DcDefs::_8x32::DcTop; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] = + DcDefs::_8x32::DcLeft; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] = + DcDefs::_8x32::Dc; + + // 16x4 + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] = + DcDefs::_16x4::DcTop; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] = + DcDefs::_16x4::DcLeft; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] = + DcDefs::_16x4::Dc; + + // 16x8 + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] = + DcDefs::_16x8::DcTop; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] = + DcDefs::_16x8::DcLeft; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] = + DcDefs::_16x8::Dc; + + // 16x16 + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] = + DcDefs::_16x16::DcTop; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] = + DcDefs::_16x16::DcLeft; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] = + DcDefs::_16x16::Dc; + + // 16x32 + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] = + DcDefs::_16x32::DcTop; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] = + DcDefs::_16x32::DcLeft; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] = + DcDefs::_16x32::Dc; + + // 16x64 + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] = + DcDefs::_16x64::DcTop; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] = + DcDefs::_16x64::DcLeft; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] = + DcDefs::_16x64::Dc; + + // 32x8 + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] = + DcDefs::_32x8::DcTop; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] = + DcDefs::_32x8::DcLeft; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] = + DcDefs::_32x8::Dc; + + // 32x16 + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] = + DcDefs::_32x16::DcTop; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] = + DcDefs::_32x16::DcLeft; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] = + DcDefs::_32x16::Dc; + + // 32x32 + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] = + DcDefs::_32x32::DcTop; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] = + DcDefs::_32x32::DcLeft; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] = + DcDefs::_32x32::Dc; + + // 32x64 + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] = + DcDefs::_32x64::DcTop; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] = + DcDefs::_32x64::DcLeft; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] = + DcDefs::_32x64::Dc; + + // 64x16 + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] = + DcDefs::_64x16::DcTop; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] = + DcDefs::_64x16::DcLeft; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] = + DcDefs::_64x16::Dc; + + // 64x32 + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] = + DcDefs::_64x32::DcTop; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] = + DcDefs::_64x32::DcLeft; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] = + DcDefs::_64x32::Dc; + + // 64x64 + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] = + DcDefs::_64x64::DcTop; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] = + DcDefs::_64x64::DcLeft; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] = + DcDefs::_64x64::Dc; +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void IntraPredInit_NEON() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/intrapred_neon.h b/src/dsp/arm/intrapred_neon.h new file mode 100644 index 0000000..16f858c --- /dev/null +++ b/src/dsp/arm/intrapred_neon.h @@ -0,0 +1,418 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*, +// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and +// Dsp::filter_intra_predictor, see the defines below for specifics. These +// functions are not thread-safe. +void IntraPredCflInit_NEON(); +void IntraPredDirectionalInit_NEON(); +void IntraPredFilterIntraInit_NEON(); +void IntraPredInit_NEON(); +void IntraPredSmoothInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +// 8 bit +#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_NEON + +// 4x4 +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_NEON + +// 4x8 +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 4x16 +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x4 +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x8 +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x16 +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 8x32 +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x4 +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x8 +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x16 +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x32 +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_NEON + +// 16x64 +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 32x8 +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_NEON + +// 32x16 +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_NEON + +// 32x32 +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_NEON + +// 32x64 +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 64x16 +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 64x32 +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 64x64 +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_NEON + +// 10 bit +// 4x4 +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON + +// 4x8 +#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 4x16 +#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x4 +#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x8 +#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x16 +#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 8x32 +#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x4 +#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x8 +#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x16 +#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x32 +#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 16x64 +#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x8 +#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x16 +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x32 +#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 32x64 +#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON + +// 64x16 +#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON + +// 64x32 +#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON + +// 64x64 +#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_ diff --git a/src/dsp/arm/intrapred_smooth_neon.cc b/src/dsp/arm/intrapred_smooth_neon.cc new file mode 100644 index 0000000..abc93e8 --- /dev/null +++ b/src/dsp/arm/intrapred_smooth_neon.cc @@ -0,0 +1,616 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" + +namespace libgav1 { +namespace dsp { + +namespace low_bitdepth { +namespace { + +// Note these constants are duplicated from intrapred.cc to allow the compiler +// to have visibility of the values. This helps reduce loads and in the +// creation of the inverse weights. +constexpr uint8_t kSmoothWeights[] = { + // block dimension = 4 + 255, 149, 85, 64, + // block dimension = 8 + 255, 197, 146, 105, 73, 50, 37, 32, + // block dimension = 16 + 255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16, + // block dimension = 32 + 255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74, + 66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8, + // block dimension = 64 + 255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156, + 150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73, + 69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16, + 15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4}; + +// TODO(b/150459137): Keeping the intermediate values in uint16_t would allow +// processing more values at once. At the high end, it could do 4x4 or 8x2 at a +// time. +inline uint16x4_t CalculatePred(const uint16x4_t weighted_top, + const uint16x4_t weighted_left, + const uint16x4_t weighted_bl, + const uint16x4_t weighted_tr) { + const uint32x4_t pred_0 = vaddl_u16(weighted_top, weighted_left); + const uint32x4_t pred_1 = vaddl_u16(weighted_bl, weighted_tr); + const uint32x4_t pred_2 = vaddq_u32(pred_0, pred_1); + return vrshrn_n_u32(pred_2, kSmoothWeightScale + 1); +} + +template <int width, int height> +inline void Smooth4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x8_t top_v; + if (width == 4) { + top_v = Load4(top); + } else { // width == 8 + top_v = vld1_u8(top); + } + const uint8x8_t top_right_v = vdup_n_u8(top_right); + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + // Over-reads for 4xN but still within the array. + const uint8x8_t weights_x_v = vld1_u8(kSmoothWeights + width - 4); + // 256 - weights = vneg_s8(weights) + const uint8x8_t scaled_weights_x = + vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_x_v))); + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + + const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v); + const uint16x8_t weighted_left = vmull_u8(weights_x_v, left_v); + const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v); + const uint16x4_t dest_0 = + CalculatePred(vget_low_u16(weighted_top), vget_low_u16(weighted_left), + vget_low_u16(weighted_tr), vget_low_u16(weighted_bl)); + + if (width == 4) { + StoreLo4(dst, vmovn_u16(vcombine_u16(dest_0, dest_0))); + } else { // width == 8 + const uint16x4_t dest_1 = CalculatePred( + vget_high_u16(weighted_top), vget_high_u16(weighted_left), + vget_high_u16(weighted_tr), vget_high_u16(weighted_bl)); + vst1_u8(dst, vmovn_u16(vcombine_u16(dest_0, dest_1))); + } + dst += stride; + } +} + +inline uint8x16_t CalculateWeightsAndPred( + const uint8x16_t top, const uint8x8_t left, const uint8x8_t top_right, + const uint8x8_t weights_y, const uint8x16_t weights_x, + const uint8x16_t scaled_weights_x, const uint16x8_t weighted_bl) { + const uint16x8_t weighted_top_low = vmull_u8(weights_y, vget_low_u8(top)); + const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left); + const uint16x8_t weighted_tr_low = + vmull_u8(vget_low_u8(scaled_weights_x), top_right); + const uint16x4_t dest_0 = CalculatePred( + vget_low_u16(weighted_top_low), vget_low_u16(weighted_left_low), + vget_low_u16(weighted_tr_low), vget_low_u16(weighted_bl)); + const uint16x4_t dest_1 = CalculatePred( + vget_high_u16(weighted_top_low), vget_high_u16(weighted_left_low), + vget_high_u16(weighted_tr_low), vget_high_u16(weighted_bl)); + const uint8x8_t dest_0_u8 = vmovn_u16(vcombine_u16(dest_0, dest_1)); + + const uint16x8_t weighted_top_high = vmull_u8(weights_y, vget_high_u8(top)); + const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left); + const uint16x8_t weighted_tr_high = + vmull_u8(vget_high_u8(scaled_weights_x), top_right); + const uint16x4_t dest_2 = CalculatePred( + vget_low_u16(weighted_top_high), vget_low_u16(weighted_left_high), + vget_low_u16(weighted_tr_high), vget_low_u16(weighted_bl)); + const uint16x4_t dest_3 = CalculatePred( + vget_high_u16(weighted_top_high), vget_high_u16(weighted_left_high), + vget_high_u16(weighted_tr_high), vget_high_u16(weighted_bl)); + const uint8x8_t dest_1_u8 = vmovn_u16(vcombine_u16(dest_2, dest_3)); + + return vcombine_u8(dest_0_u8, dest_1_u8); +} + +template <int width, int height> +inline void Smooth16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x16_t top_v[4]; + top_v[0] = vld1q_u8(top); + if (width > 16) { + top_v[1] = vld1q_u8(top + 16); + if (width == 64) { + top_v[2] = vld1q_u8(top + 32); + top_v[3] = vld1q_u8(top + 48); + } + } + + const uint8x8_t top_right_v = vdup_n_u8(top_right); + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + + // TODO(johannkoenig): Consider re-reading top_v and weights_x_v in the loop. + // This currently has a performance slope similar to Paeth so it does not + // appear to be register bound for arm64. + uint8x16_t weights_x_v[4]; + weights_x_v[0] = vld1q_u8(kSmoothWeights + width - 4); + if (width > 16) { + weights_x_v[1] = vld1q_u8(kSmoothWeights + width + 16 - 4); + if (width == 64) { + weights_x_v[2] = vld1q_u8(kSmoothWeights + width + 32 - 4); + weights_x_v[3] = vld1q_u8(kSmoothWeights + width + 48 - 4); + } + } + + uint8x16_t scaled_weights_x[4]; + scaled_weights_x[0] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[0]))); + if (width > 16) { + scaled_weights_x[1] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[1]))); + if (width == 64) { + scaled_weights_x[2] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[2]))); + scaled_weights_x[3] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[3]))); + } + } + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + + vst1q_u8(dst, CalculateWeightsAndPred(top_v[0], left_v, top_right_v, + weights_y_v, weights_x_v[0], + scaled_weights_x[0], weighted_bl)); + + if (width > 16) { + vst1q_u8(dst + 16, CalculateWeightsAndPred( + top_v[1], left_v, top_right_v, weights_y_v, + weights_x_v[1], scaled_weights_x[1], weighted_bl)); + if (width == 64) { + vst1q_u8(dst + 32, + CalculateWeightsAndPred(top_v[2], left_v, top_right_v, + weights_y_v, weights_x_v[2], + scaled_weights_x[2], weighted_bl)); + vst1q_u8(dst + 48, + CalculateWeightsAndPred(top_v[3], left_v, top_right_v, + weights_y_v, weights_x_v[3], + scaled_weights_x[3], weighted_bl)); + } + } + + dst += stride; + } +} + +template <int width, int height> +inline void SmoothVertical4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x8_t top_v; + if (width == 4) { + top_v = Load4(top); + } else { // width == 8 + top_v = vld1_u8(top); + } + + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + + for (int y = 0; y < height; ++y) { + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + + const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + const uint16x8_t pred = vaddq_u16(weighted_top, weighted_bl); + const uint8x8_t pred_scaled = vrshrn_n_u16(pred, kSmoothWeightScale); + + if (width == 4) { + StoreLo4(dst, pred_scaled); + } else { // width == 8 + vst1_u8(dst, pred_scaled); + } + dst += stride; + } +} + +inline uint8x16_t CalculateVerticalWeightsAndPred( + const uint8x16_t top, const uint8x8_t weights_y, + const uint16x8_t weighted_bl) { + const uint16x8_t weighted_top_low = vmull_u8(weights_y, vget_low_u8(top)); + const uint16x8_t weighted_top_high = vmull_u8(weights_y, vget_high_u8(top)); + const uint16x8_t pred_low = vaddq_u16(weighted_top_low, weighted_bl); + const uint16x8_t pred_high = vaddq_u16(weighted_top_high, weighted_bl); + const uint8x8_t pred_scaled_low = vrshrn_n_u16(pred_low, kSmoothWeightScale); + const uint8x8_t pred_scaled_high = + vrshrn_n_u16(pred_high, kSmoothWeightScale); + return vcombine_u8(pred_scaled_low, pred_scaled_high); +} + +template <int width, int height> +inline void SmoothVertical16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t bottom_left = left[height - 1]; + const uint8_t* const weights_y = kSmoothWeights + height - 4; + uint8_t* dst = static_cast<uint8_t*>(dest); + + uint8x16_t top_v[4]; + top_v[0] = vld1q_u8(top); + if (width > 16) { + top_v[1] = vld1q_u8(top + 16); + if (width == 64) { + top_v[2] = vld1q_u8(top + 32); + top_v[3] = vld1q_u8(top + 48); + } + } + + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + + for (int y = 0; y < height; ++y) { + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + + const uint8x16_t pred_0 = + CalculateVerticalWeightsAndPred(top_v[0], weights_y_v, weighted_bl); + vst1q_u8(dst, pred_0); + + if (width > 16) { + const uint8x16_t pred_1 = + CalculateVerticalWeightsAndPred(top_v[1], weights_y_v, weighted_bl); + vst1q_u8(dst + 16, pred_1); + + if (width == 64) { + const uint8x16_t pred_2 = + CalculateVerticalWeightsAndPred(top_v[2], weights_y_v, weighted_bl); + vst1q_u8(dst + 32, pred_2); + + const uint8x16_t pred_3 = + CalculateVerticalWeightsAndPred(top_v[3], weights_y_v, weighted_bl); + vst1q_u8(dst + 48, pred_3); + } + } + + dst += stride; + } +} + +template <int width, int height> +inline void SmoothHorizontal4Or8xN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + uint8_t* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t top_right_v = vdup_n_u8(top_right); + // Over-reads for 4xN but still within the array. + const uint8x8_t weights_x = vld1_u8(kSmoothWeights + width - 4); + // 256 - weights = vneg_s8(weights) + const uint8x8_t scaled_weights_x = + vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_x))); + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + + const uint16x8_t weighted_left = vmull_u8(weights_x, left_v); + const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v); + const uint16x8_t pred = vaddq_u16(weighted_left, weighted_tr); + const uint8x8_t pred_scaled = vrshrn_n_u16(pred, kSmoothWeightScale); + + if (width == 4) { + StoreLo4(dst, pred_scaled); + } else { // width == 8 + vst1_u8(dst, pred_scaled); + } + dst += stride; + } +} + +inline uint8x16_t CalculateHorizontalWeightsAndPred( + const uint8x8_t left, const uint8x8_t top_right, const uint8x16_t weights_x, + const uint8x16_t scaled_weights_x) { + const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left); + const uint16x8_t weighted_tr_low = + vmull_u8(vget_low_u8(scaled_weights_x), top_right); + const uint16x8_t pred_low = vaddq_u16(weighted_left_low, weighted_tr_low); + const uint8x8_t pred_scaled_low = vrshrn_n_u16(pred_low, kSmoothWeightScale); + + const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left); + const uint16x8_t weighted_tr_high = + vmull_u8(vget_high_u8(scaled_weights_x), top_right); + const uint16x8_t pred_high = vaddq_u16(weighted_left_high, weighted_tr_high); + const uint8x8_t pred_scaled_high = + vrshrn_n_u16(pred_high, kSmoothWeightScale); + + return vcombine_u8(pred_scaled_low, pred_scaled_high); +} + +template <int width, int height> +inline void SmoothHorizontal16PlusxN_NEON(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const uint8_t* const top = static_cast<const uint8_t*>(top_row); + const uint8_t* const left = static_cast<const uint8_t*>(left_column); + const uint8_t top_right = top[width - 1]; + uint8_t* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t top_right_v = vdup_n_u8(top_right); + + uint8x16_t weights_x[4]; + weights_x[0] = vld1q_u8(kSmoothWeights + width - 4); + if (width > 16) { + weights_x[1] = vld1q_u8(kSmoothWeights + width + 16 - 4); + if (width == 64) { + weights_x[2] = vld1q_u8(kSmoothWeights + width + 32 - 4); + weights_x[3] = vld1q_u8(kSmoothWeights + width + 48 - 4); + } + } + + uint8x16_t scaled_weights_x[4]; + scaled_weights_x[0] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[0]))); + if (width > 16) { + scaled_weights_x[1] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[1]))); + if (width == 64) { + scaled_weights_x[2] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[2]))); + scaled_weights_x[3] = + vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[3]))); + } + } + + for (int y = 0; y < height; ++y) { + const uint8x8_t left_v = vdup_n_u8(left[y]); + + const uint8x16_t pred_0 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[0], scaled_weights_x[0]); + vst1q_u8(dst, pred_0); + + if (width > 16) { + const uint8x16_t pred_1 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[1], scaled_weights_x[1]); + vst1q_u8(dst + 16, pred_1); + + if (width == 64) { + const uint8x16_t pred_2 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[2], scaled_weights_x[2]); + vst1q_u8(dst + 32, pred_2); + + const uint8x16_t pred_3 = CalculateHorizontalWeightsAndPred( + left_v, top_right_v, weights_x[3], scaled_weights_x[3]); + vst1q_u8(dst + 48, pred_3); + } + } + dst += stride; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + // 4x4 + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<4, 4>; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<4, 4>; + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<4, 4>; + + // 4x8 + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<4, 8>; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<4, 8>; + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<4, 8>; + + // 4x16 + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<4, 16>; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<4, 16>; + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<4, 16>; + + // 8x4 + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 4>; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 4>; + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 4>; + + // 8x8 + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 8>; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 8>; + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 8>; + + // 8x16 + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 16>; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 16>; + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 16>; + + // 8x32 + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] = + Smooth4Or8xN_NEON<8, 32>; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] = + SmoothVertical4Or8xN_NEON<8, 32>; + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4Or8xN_NEON<8, 32>; + + // 16x4 + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 4>; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 4>; + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 4>; + + // 16x8 + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 8>; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 8>; + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 8>; + + // 16x16 + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 16>; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 16>; + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 16>; + + // 16x32 + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 32>; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 32>; + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 32>; + + // 16x64 + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<16, 64>; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<16, 64>; + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<16, 64>; + + // 32x8 + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 8>; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 8>; + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 8>; + + // 32x16 + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 16>; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 16>; + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 16>; + + // 32x32 + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 32>; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 32>; + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 32>; + + // 32x64 + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<32, 64>; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<32, 64>; + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<32, 64>; + + // 64x16 + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<64, 16>; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<64, 16>; + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<64, 16>; + + // 64x32 + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<64, 32>; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<64, 32>; + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<64, 32>; + + // 64x64 + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmooth] = + Smooth16PlusxN_NEON<64, 64>; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothVertical] = + SmoothVertical16PlusxN_NEON<64, 64>; + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16PlusxN_NEON<64, 64>; +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredSmoothInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void IntraPredSmoothInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/inverse_transform_neon.cc b/src/dsp/arm/inverse_transform_neon.cc new file mode 100644 index 0000000..072991a --- /dev/null +++ b/src/dsp/arm/inverse_transform_neon.cc @@ -0,0 +1,3128 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/inverse_transform.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/array_2d.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Include the constants and utility functions inside the anonymous namespace. +#include "src/dsp/inverse_transform.inc" + +//------------------------------------------------------------------------------ + +// TODO(slavarnway): Move transpose functions to transpose_neon.h or +// common_neon.h. + +LIBGAV1_ALWAYS_INLINE void Transpose4x4(const int16x8_t in[4], + int16x8_t out[4]) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // to: + // b0.val[0]: 00 10 02 12 + // b0.val[1]: 01 11 03 13 + // b1.val[0]: 20 30 22 32 + // b1.val[1]: 21 31 23 33 + const int16x4_t a0 = vget_low_s16(in[0]); + const int16x4_t a1 = vget_low_s16(in[1]); + const int16x4_t a2 = vget_low_s16(in[2]); + const int16x4_t a3 = vget_low_s16(in[3]); + + const int16x4x2_t b0 = vtrn_s16(a0, a1); + const int16x4x2_t b1 = vtrn_s16(a2, a3); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + const int32x2x2_t c0 = vtrn_s32(vreinterpret_s32_s16(b0.val[0]), + vreinterpret_s32_s16(b1.val[0])); + const int32x2x2_t c1 = vtrn_s32(vreinterpret_s32_s16(b0.val[1]), + vreinterpret_s32_s16(b1.val[1])); + + const int16x4_t d0 = vreinterpret_s16_s32(c0.val[0]); + const int16x4_t d1 = vreinterpret_s16_s32(c1.val[0]); + const int16x4_t d2 = vreinterpret_s16_s32(c0.val[1]); + const int16x4_t d3 = vreinterpret_s16_s32(c1.val[1]); + + out[0] = vcombine_s16(d0, d0); + out[1] = vcombine_s16(d1, d1); + out[2] = vcombine_s16(d2, d2); + out[3] = vcombine_s16(d3, d3); +} + +// Note this is only used in the final stage of Dct32/64 and Adst16 as the in +// place version causes additional stack usage with clang. +LIBGAV1_ALWAYS_INLINE void Transpose8x8(const int16x8_t in[8], + int16x8_t out[8]) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // a4: 40 41 42 43 44 45 46 47 + // a5: 50 51 52 53 54 55 56 57 + // a6: 60 61 62 63 64 65 66 67 + // a7: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + // b2.val[0]: 40 50 42 52 44 54 46 56 + // b2.val[1]: 41 51 43 53 45 55 47 57 + // b3.val[0]: 60 70 62 72 64 74 66 76 + // b3.val[1]: 61 71 63 73 65 75 67 77 + + const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]); + const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]); + const int16x8x2_t b2 = vtrnq_s16(in[4], in[5]); + const int16x8x2_t b3 = vtrnq_s16(in[6], in[7]); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + // c2.val[0]: 40 50 60 70 44 54 64 74 + // c2.val[1]: 42 52 62 72 46 56 66 76 + // c3.val[0]: 41 51 61 71 45 55 65 75 + // c3.val[1]: 43 53 63 73 47 57 67 77 + + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]), + vreinterpretq_s32_s16(b3.val[0])); + const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]), + vreinterpretq_s32_s16(b3.val[1])); + + // Swap 64 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 04 14 24 34 44 54 64 74 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 05 15 25 35 45 55 65 75 + // d2.val[0]: 02 12 22 32 42 52 62 72 + // d2.val[1]: 06 16 26 36 46 56 66 76 + // d3.val[0]: 03 13 23 33 43 53 63 73 + // d3.val[1]: 07 17 27 37 47 57 67 77 + const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]); + const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]); + const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]); + const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]); + + out[0] = d0.val[0]; + out[1] = d1.val[0]; + out[2] = d2.val[0]; + out[3] = d3.val[0]; + out[4] = d0.val[1]; + out[5] = d1.val[1]; + out[6] = d2.val[1]; + out[7] = d3.val[1]; +} + +LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const uint16x8_t in[8], + uint16x8_t out[4]) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // a4: 40 41 42 43 + // a5: 50 51 52 53 + // a6: 60 61 62 63 + // a7: 70 71 72 73 + // to: + // b0.val[0]: 00 10 02 12 + // b0.val[1]: 01 11 03 13 + // b1.val[0]: 20 30 22 32 + // b1.val[1]: 21 31 23 33 + // b2.val[0]: 40 50 42 52 + // b2.val[1]: 41 51 43 53 + // b3.val[0]: 60 70 62 72 + // b3.val[1]: 61 71 63 73 + + uint16x4x2_t b0 = vtrn_u16(vget_low_u16(in[0]), vget_low_u16(in[1])); + uint16x4x2_t b1 = vtrn_u16(vget_low_u16(in[2]), vget_low_u16(in[3])); + uint16x4x2_t b2 = vtrn_u16(vget_low_u16(in[4]), vget_low_u16(in[5])); + uint16x4x2_t b3 = vtrn_u16(vget_low_u16(in[6]), vget_low_u16(in[7])); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 + // c0.val[1]: 02 12 22 32 + // c1.val[0]: 01 11 21 31 + // c1.val[1]: 03 13 23 33 + // c2.val[0]: 40 50 60 70 + // c2.val[1]: 42 52 62 72 + // c3.val[0]: 41 51 61 71 + // c3.val[1]: 43 53 63 73 + + uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), + vreinterpret_u32_u16(b1.val[0])); + uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), + vreinterpret_u32_u16(b1.val[1])); + uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b2.val[0]), + vreinterpret_u32_u16(b3.val[0])); + uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b2.val[1]), + vreinterpret_u32_u16(b3.val[1])); + + // Swap 64 bit elements resulting in: + // o0: 00 10 20 30 40 50 60 70 + // o1: 01 11 21 31 41 51 61 71 + // o2: 02 12 22 32 42 52 62 72 + // o3: 03 13 23 33 43 53 63 73 + + out[0] = vcombine_u16(vreinterpret_u16_u32(c0.val[0]), + vreinterpret_u16_u32(c2.val[0])); + out[1] = vcombine_u16(vreinterpret_u16_u32(c1.val[0]), + vreinterpret_u16_u32(c3.val[0])); + out[2] = vcombine_u16(vreinterpret_u16_u32(c0.val[1]), + vreinterpret_u16_u32(c2.val[1])); + out[3] = vcombine_u16(vreinterpret_u16_u32(c1.val[1]), + vreinterpret_u16_u32(c3.val[1])); +} + +LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const int16x8_t in[8], + int16x8_t out[4]) { + Transpose4x8To8x4(reinterpret_cast<const uint16x8_t*>(in), + reinterpret_cast<uint16x8_t*>(out)); +} + +LIBGAV1_ALWAYS_INLINE void Transpose8x4To4x8(const int16x8_t in[4], + int16x8_t out[8]) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]); + const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + + // The upper 8 bytes are don't cares. + // out[0]: 00 10 20 30 04 14 24 34 + // out[1]: 01 11 21 31 05 15 25 35 + // out[2]: 02 12 22 32 06 16 26 36 + // out[3]: 03 13 23 33 07 17 27 37 + // out[4]: 04 14 24 34 04 14 24 34 + // out[5]: 05 15 25 35 05 15 25 35 + // out[6]: 06 16 26 36 06 16 26 36 + // out[7]: 07 17 27 37 07 17 27 37 + out[0] = vreinterpretq_s16_s32(c0.val[0]); + out[1] = vreinterpretq_s16_s32(c1.val[0]); + out[2] = vreinterpretq_s16_s32(c0.val[1]); + out[3] = vreinterpretq_s16_s32(c1.val[1]); + out[4] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c0.val[0]), vget_high_s32(c0.val[0]))); + out[5] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c1.val[0]), vget_high_s32(c1.val[0]))); + out[6] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c0.val[1]), vget_high_s32(c0.val[1]))); + out[7] = vreinterpretq_s16_s32( + vcombine_s32(vget_high_s32(c1.val[1]), vget_high_s32(c1.val[1]))); +} + +//------------------------------------------------------------------------------ +template <int store_width, int store_count> +LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx, + const int16x8_t* const s) { + assert(store_count % 4 == 0); + assert(store_width == 8 || store_width == 16); + // NOTE: It is expected that the compiler will unroll these loops. + if (store_width == 16) { + for (int i = 0; i < store_count; i += 4) { + vst1q_s16(&dst[i * stride + idx], (s[i])); + vst1q_s16(&dst[(i + 1) * stride + idx], (s[i + 1])); + vst1q_s16(&dst[(i + 2) * stride + idx], (s[i + 2])); + vst1q_s16(&dst[(i + 3) * stride + idx], (s[i + 3])); + } + } else { + // store_width == 8 + for (int i = 0; i < store_count; i += 4) { + vst1_s16(&dst[i * stride + idx], vget_low_s16(s[i])); + vst1_s16(&dst[(i + 1) * stride + idx], vget_low_s16(s[i + 1])); + vst1_s16(&dst[(i + 2) * stride + idx], vget_low_s16(s[i + 2])); + vst1_s16(&dst[(i + 3) * stride + idx], vget_low_s16(s[i + 3])); + } + } +} + +template <int load_width, int load_count> +LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* src, int32_t stride, + int32_t idx, int16x8_t* x) { + assert(load_count % 4 == 0); + assert(load_width == 8 || load_width == 16); + // NOTE: It is expected that the compiler will unroll these loops. + if (load_width == 16) { + for (int i = 0; i < load_count; i += 4) { + x[i] = vld1q_s16(&src[i * stride + idx]); + x[i + 1] = vld1q_s16(&src[(i + 1) * stride + idx]); + x[i + 2] = vld1q_s16(&src[(i + 2) * stride + idx]); + x[i + 3] = vld1q_s16(&src[(i + 3) * stride + idx]); + } + } else { + // load_width == 8 + const int64x2_t zero = vdupq_n_s64(0); + for (int i = 0; i < load_count; i += 4) { + // The src buffer is aligned to 32 bytes. Each load will always be 8 + // byte aligned. + x[i] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[i * stride + idx]), zero, 0)); + x[i + 1] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[(i + 1) * stride + idx]), zero, + 0)); + x[i + 2] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[(i + 2) * stride + idx]), zero, + 0)); + x[i + 3] = vreinterpretq_s16_s64(vld1q_lane_s64( + reinterpret_cast<const int64_t*>(&src[(i + 3) * stride + idx]), zero, + 0)); + } + } +} + +// Butterfly rotate 4 values. +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int16x8_t* a, int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128); + const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128); + const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128); + const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128); + const int16x4_t x1 = vqrshrn_n_s32(x0, 12); + const int16x4_t y1 = vqrshrn_n_s32(y0, 12); + const int16x8_t x = vcombine_s16(x1, x1); + const int16x8_t y = vcombine_s16(y1, y1); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +// Butterfly rotate 8 values. +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(int16x8_t* a, int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128); + const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128); + const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128); + const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128); + const int16x4_t x1 = vqrshrn_n_s32(x0, 12); + const int16x4_t y1 = vqrshrn_n_s32(y0, 12); + + const int32x4_t acc_x_hi = vmull_n_s16(vget_high_s16(*a), cos128); + const int32x4_t acc_y_hi = vmull_n_s16(vget_high_s16(*a), sin128); + const int32x4_t x0_hi = vmlsl_n_s16(acc_x_hi, vget_high_s16(*b), sin128); + const int32x4_t y0_hi = vmlal_n_s16(acc_y_hi, vget_high_s16(*b), cos128); + const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12); + const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12); + + const int16x8_t x = vcombine_s16(x1, x1_hi); + const int16x8_t y = vcombine_s16(y1, y1_hi); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int16x8_t* a, + int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + // For this function, the max value returned by Sin128() is 4091, which fits + // inside 12 bits. This leaves room for the sign bit and the 3 left shifted + // bits. + assert(sin128 <= 0xfff); + const int16x8_t x = vqrdmulhq_n_s16(*b, -sin128 << 3); + const int16x8_t y = vqrdmulhq_n_s16(*b, cos128 << 3); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int16x8_t* a, + int16x8_t* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const int16x8_t x = vqrdmulhq_n_s16(*a, cos128 << 3); + const int16x8_t y = vqrdmulhq_n_s16(*a, sin128 << 3); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void HadamardRotation(int16x8_t* a, int16x8_t* b, + bool flip) { + int16x8_t x, y; + if (flip) { + y = vqaddq_s16(*b, *a); + x = vqsubq_s16(*b, *a); + } else { + x = vqaddq_s16(*a, *b); + y = vqsubq_s16(*a, *b); + } + *a = x; + *b = y; +} + +using ButterflyRotationFunc = void (*)(int16x8_t* a, int16x8_t* b, int angle, + bool flip); + +//------------------------------------------------------------------------------ +// Discrete Cosine Transforms (DCT). + +template <int width> +LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x8_t v_src = vdupq_n_s16(dst[0]); + const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + const int16x8_t s0 = vbslq_s16(v_mask, v_src_round, v_src); + const int16_t cos128 = Cos128(32); + const int16x8_t xy = vqrdmulhq_n_s16(s0, cos128 << 3); + // vqrshlq_s16 will shift right if shift value is negative. + const int16x8_t xy_shifted = vqrshlq_s16(xy, vdupq_n_s16(-row_shift)); + + if (width == 4) { + vst1_s16(dst, vget_low_s16(xy_shifted)); + } else { + for (int i = 0; i < width; i += 8) { + vst1q_s16(dst, xy_shifted); + dst += 8; + } + } + return true; +} + +template <int height> +LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16_t cos128 = Cos128(32); + + // Calculate dc values for first row. + if (width == 4) { + const int16x4_t v_src = vld1_s16(dst); + const int16x4_t xy = vqrdmulh_n_s16(v_src, cos128 << 3); + vst1_s16(dst, xy); + } else { + int i = 0; + do { + const int16x8_t v_src = vld1q_s16(&dst[i]); + const int16x8_t xy = vqrdmulhq_n_s16(v_src, cos128 << 3); + vst1q_s16(&dst[i], xy); + i += 8; + } while (i < width); + } + + // Copy first row to the rest of the block. + for (int y = 1; y < height; ++y) { + memcpy(&dst[y * width], dst, width * sizeof(dst[0])); + } + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct4Stages(int16x8_t* s) { + // stage 12. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true); + ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false); + } else { + butterfly_rotation(&s[0], &s[1], 32, true); + butterfly_rotation(&s[2], &s[3], 48, false); + } + + // stage 17. + HadamardRotation(&s[0], &s[3], false); + HadamardRotation(&s[1], &s[2], false); +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, int32_t step, bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[4], x[4]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[8]; + LoadSrc<8, 8>(dst, step, 0, input); + Transpose4x8To8x4(input, x); + } else { + LoadSrc<16, 4>(dst, step, 0, x); + } + } else { + LoadSrc<8, 4>(dst, step, 0, x); + if (transpose) { + Transpose4x4(x, x); + } + } + + // stage 1. + // kBitReverseLookup 0, 2, 1, 3 + s[0] = x[0]; + s[1] = x[2]; + s[2] = x[1]; + s[3] = x[3]; + + Dct4Stages<butterfly_rotation>(s); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[8]; + Transpose8x4To4x8(s, output); + StoreDst<8, 8>(dst, step, 0, output); + } else { + StoreDst<16, 4>(dst, step, 0, s); + } + } else { + if (transpose) { + Transpose4x4(s, s); + } + StoreDst<8, 4>(dst, step, 0, s); + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct8Stages(int16x8_t* s) { + // stage 8. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false); + ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false); + } else { + butterfly_rotation(&s[4], &s[7], 56, false); + butterfly_rotation(&s[5], &s[6], 24, false); + } + + // stage 13. + HadamardRotation(&s[4], &s[5], false); + HadamardRotation(&s[6], &s[7], true); + + // stage 18. + butterfly_rotation(&s[6], &s[5], 32, true); + + // stage 22. + HadamardRotation(&s[0], &s[7], false); + HadamardRotation(&s[1], &s[6], false); + HadamardRotation(&s[2], &s[5], false); + HadamardRotation(&s[3], &s[4], false); +} + +// Process dct8 rows or columns, depending on the transpose flag. +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, int32_t step, bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[8], x[8]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + } else { + LoadSrc<8, 8>(dst, step, 0, x); + } + } else if (transpose) { + LoadSrc<16, 8>(dst, step, 0, x); + dsp::Transpose8x8(x); + } else { + LoadSrc<16, 8>(dst, step, 0, x); + } + + // stage 1. + // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7, + s[0] = x[0]; + s[1] = x[4]; + s[2] = x[2]; + s[3] = x[6]; + s[4] = x[1]; + s[5] = x[5]; + s[6] = x[3]; + s[7] = x[7]; + + Dct4Stages<butterfly_rotation>(s); + Dct8Stages<butterfly_rotation>(s); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[4]; + Transpose4x8To8x4(s, output); + StoreDst<16, 4>(dst, step, 0, output); + } else { + StoreDst<8, 8>(dst, step, 0, s); + } + } else if (transpose) { + dsp::Transpose8x8(s); + StoreDst<16, 8>(dst, step, 0, s); + } else { + StoreDst<16, 8>(dst, step, 0, s); + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct16Stages(int16x8_t* s) { + // stage 5. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false); + ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false); + ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false); + ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false); + } else { + butterfly_rotation(&s[8], &s[15], 60, false); + butterfly_rotation(&s[9], &s[14], 28, false); + butterfly_rotation(&s[10], &s[13], 44, false); + butterfly_rotation(&s[11], &s[12], 12, false); + } + + // stage 9. + HadamardRotation(&s[8], &s[9], false); + HadamardRotation(&s[10], &s[11], true); + HadamardRotation(&s[12], &s[13], false); + HadamardRotation(&s[14], &s[15], true); + + // stage 14. + butterfly_rotation(&s[14], &s[9], 48, true); + butterfly_rotation(&s[13], &s[10], 112, true); + + // stage 19. + HadamardRotation(&s[8], &s[11], false); + HadamardRotation(&s[9], &s[10], false); + HadamardRotation(&s[12], &s[15], true); + HadamardRotation(&s[13], &s[14], true); + + // stage 23. + butterfly_rotation(&s[13], &s[10], 32, true); + butterfly_rotation(&s[12], &s[11], 32, true); + + // stage 26. + HadamardRotation(&s[0], &s[15], false); + HadamardRotation(&s[1], &s[14], false); + HadamardRotation(&s[2], &s[13], false); + HadamardRotation(&s[3], &s[12], false); + HadamardRotation(&s[4], &s[11], false); + HadamardRotation(&s[5], &s[10], false); + HadamardRotation(&s[6], &s[9], false); + HadamardRotation(&s[7], &s[8], false); +} + +// Process dct16 rows or columns, depending on the transpose flag. +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, int32_t step, bool is_row, + int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[16], x[16]; + + if (stage_is_rectangular) { + if (is_row) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + LoadSrc<16, 4>(dst, step, 8, input); + Transpose8x4To4x8(input, &x[8]); + } else { + LoadSrc<8, 16>(dst, step, 0, x); + } + } else if (is_row) { + for (int idx = 0; idx < 16; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + LoadSrc<16, 16>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, + s[0] = x[0]; + s[1] = x[8]; + s[2] = x[4]; + s[3] = x[12]; + s[4] = x[2]; + s[5] = x[10]; + s[6] = x[6]; + s[7] = x[14]; + s[8] = x[1]; + s[9] = x[9]; + s[10] = x[5]; + s[11] = x[13]; + s[12] = x[3]; + s[13] = x[11]; + s[14] = x[7]; + s[15] = x[15]; + + Dct4Stages<butterfly_rotation>(s); + Dct8Stages<butterfly_rotation>(s); + Dct16Stages<butterfly_rotation>(s); + + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int i = 0; i < 16; ++i) { + s[i] = vqrshlq_s16(s[i], v_row_shift); + } + } + + if (stage_is_rectangular) { + if (is_row) { + int16x8_t output[4]; + Transpose4x8To8x4(s, output); + StoreDst<16, 4>(dst, step, 0, output); + Transpose4x8To8x4(&s[8], output); + StoreDst<16, 4>(dst, step, 8, output); + } else { + StoreDst<8, 16>(dst, step, 0, s); + } + } else if (is_row) { + for (int idx = 0; idx < 16; idx += 8) { + dsp::Transpose8x8(&s[idx]); + StoreDst<16, 8>(dst, step, idx, &s[idx]); + } + } else { + StoreDst<16, 16>(dst, step, 0, s); + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct32Stages(int16x8_t* s) { + // stage 3 + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false); + ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false); + ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false); + ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false); + ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false); + ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false); + ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false); + ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false); + } else { + butterfly_rotation(&s[16], &s[31], 62, false); + butterfly_rotation(&s[17], &s[30], 30, false); + butterfly_rotation(&s[18], &s[29], 46, false); + butterfly_rotation(&s[19], &s[28], 14, false); + butterfly_rotation(&s[20], &s[27], 54, false); + butterfly_rotation(&s[21], &s[26], 22, false); + butterfly_rotation(&s[22], &s[25], 38, false); + butterfly_rotation(&s[23], &s[24], 6, false); + } + // stage 6. + HadamardRotation(&s[16], &s[17], false); + HadamardRotation(&s[18], &s[19], true); + HadamardRotation(&s[20], &s[21], false); + HadamardRotation(&s[22], &s[23], true); + HadamardRotation(&s[24], &s[25], false); + HadamardRotation(&s[26], &s[27], true); + HadamardRotation(&s[28], &s[29], false); + HadamardRotation(&s[30], &s[31], true); + + // stage 10. + butterfly_rotation(&s[30], &s[17], 24 + 32, true); + butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true); + butterfly_rotation(&s[26], &s[21], 24, true); + butterfly_rotation(&s[25], &s[22], 24 + 64, true); + + // stage 15. + HadamardRotation(&s[16], &s[19], false); + HadamardRotation(&s[17], &s[18], false); + HadamardRotation(&s[20], &s[23], true); + HadamardRotation(&s[21], &s[22], true); + HadamardRotation(&s[24], &s[27], false); + HadamardRotation(&s[25], &s[26], false); + HadamardRotation(&s[28], &s[31], true); + HadamardRotation(&s[29], &s[30], true); + + // stage 20. + butterfly_rotation(&s[29], &s[18], 48, true); + butterfly_rotation(&s[28], &s[19], 48, true); + butterfly_rotation(&s[27], &s[20], 48 + 64, true); + butterfly_rotation(&s[26], &s[21], 48 + 64, true); + + // stage 24. + HadamardRotation(&s[16], &s[23], false); + HadamardRotation(&s[17], &s[22], false); + HadamardRotation(&s[18], &s[21], false); + HadamardRotation(&s[19], &s[20], false); + HadamardRotation(&s[24], &s[31], true); + HadamardRotation(&s[25], &s[30], true); + HadamardRotation(&s[26], &s[29], true); + HadamardRotation(&s[27], &s[28], true); + + // stage 27. + butterfly_rotation(&s[27], &s[20], 32, true); + butterfly_rotation(&s[26], &s[21], 32, true); + butterfly_rotation(&s[25], &s[22], 32, true); + butterfly_rotation(&s[24], &s[23], 32, true); + + // stage 29. + HadamardRotation(&s[0], &s[31], false); + HadamardRotation(&s[1], &s[30], false); + HadamardRotation(&s[2], &s[29], false); + HadamardRotation(&s[3], &s[28], false); + HadamardRotation(&s[4], &s[27], false); + HadamardRotation(&s[5], &s[26], false); + HadamardRotation(&s[6], &s[25], false); + HadamardRotation(&s[7], &s[24], false); + HadamardRotation(&s[8], &s[23], false); + HadamardRotation(&s[9], &s[22], false); + HadamardRotation(&s[10], &s[21], false); + HadamardRotation(&s[11], &s[20], false); + HadamardRotation(&s[12], &s[19], false); + HadamardRotation(&s[13], &s[18], false); + HadamardRotation(&s[14], &s[17], false); + HadamardRotation(&s[15], &s[16], false); +} + +// Process dct32 rows or columns, depending on the transpose flag. +LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const int32_t step, + const bool is_row, int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[32], x[32]; + + if (is_row) { + for (int idx = 0; idx < 32; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + LoadSrc<16, 32>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup + // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, + s[0] = x[0]; + s[1] = x[16]; + s[2] = x[8]; + s[3] = x[24]; + s[4] = x[4]; + s[5] = x[20]; + s[6] = x[12]; + s[7] = x[28]; + s[8] = x[2]; + s[9] = x[18]; + s[10] = x[10]; + s[11] = x[26]; + s[12] = x[6]; + s[13] = x[22]; + s[14] = x[14]; + s[15] = x[30]; + + // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31, + s[16] = x[1]; + s[17] = x[17]; + s[18] = x[9]; + s[19] = x[25]; + s[20] = x[5]; + s[21] = x[21]; + s[22] = x[13]; + s[23] = x[29]; + s[24] = x[3]; + s[25] = x[19]; + s[26] = x[11]; + s[27] = x[27]; + s[28] = x[7]; + s[29] = x[23]; + s[30] = x[15]; + s[31] = x[31]; + + Dct4Stages<ButterflyRotation_8>(s); + Dct8Stages<ButterflyRotation_8>(s); + Dct16Stages<ButterflyRotation_8>(s); + Dct32Stages<ButterflyRotation_8>(s); + + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int idx = 0; idx < 32; idx += 8) { + int16x8_t output[8]; + Transpose8x8(&s[idx], output); + for (int i = 0; i < 8; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 32>(dst, step, 0, s); + } +} + +// Allow the compiler to call this function instead of force inlining. Tests +// show the performance is slightly faster. +void Dct64_NEON(void* dest, int32_t step, bool is_row, int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[64], x[32]; + + if (is_row) { + // The last 32 values of every row are always zero if the |tx_width| is + // 64. + for (int idx = 0; idx < 32; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + // The last 32 values of every column are always zero if the |tx_height| is + // 64. + LoadSrc<16, 32>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup + // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60, + s[0] = x[0]; + s[2] = x[16]; + s[4] = x[8]; + s[6] = x[24]; + s[8] = x[4]; + s[10] = x[20]; + s[12] = x[12]; + s[14] = x[28]; + + // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62, + s[16] = x[2]; + s[18] = x[18]; + s[20] = x[10]; + s[22] = x[26]; + s[24] = x[6]; + s[26] = x[22]; + s[28] = x[14]; + s[30] = x[30]; + + // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61, + s[32] = x[1]; + s[34] = x[17]; + s[36] = x[9]; + s[38] = x[25]; + s[40] = x[5]; + s[42] = x[21]; + s[44] = x[13]; + s[46] = x[29]; + + // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63 + s[48] = x[3]; + s[50] = x[19]; + s[52] = x[11]; + s[54] = x[27]; + s[56] = x[7]; + s[58] = x[23]; + s[60] = x[15]; + s[62] = x[31]; + + Dct4Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct8Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct16Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct32Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + + //-- start dct 64 stages + // stage 2. + ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false); + ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false); + ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false); + ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false); + ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false); + ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false); + ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false); + ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false); + ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false); + ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false); + ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false); + ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false); + ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false); + ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false); + ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false); + ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false); + + // stage 4. + HadamardRotation(&s[32], &s[33], false); + HadamardRotation(&s[34], &s[35], true); + HadamardRotation(&s[36], &s[37], false); + HadamardRotation(&s[38], &s[39], true); + HadamardRotation(&s[40], &s[41], false); + HadamardRotation(&s[42], &s[43], true); + HadamardRotation(&s[44], &s[45], false); + HadamardRotation(&s[46], &s[47], true); + HadamardRotation(&s[48], &s[49], false); + HadamardRotation(&s[50], &s[51], true); + HadamardRotation(&s[52], &s[53], false); + HadamardRotation(&s[54], &s[55], true); + HadamardRotation(&s[56], &s[57], false); + HadamardRotation(&s[58], &s[59], true); + HadamardRotation(&s[60], &s[61], false); + HadamardRotation(&s[62], &s[63], true); + + // stage 7. + ButterflyRotation_8(&s[62], &s[33], 60 - 0, true); + ButterflyRotation_8(&s[61], &s[34], 60 - 0 + 64, true); + ButterflyRotation_8(&s[58], &s[37], 60 - 32, true); + ButterflyRotation_8(&s[57], &s[38], 60 - 32 + 64, true); + ButterflyRotation_8(&s[54], &s[41], 60 - 16, true); + ButterflyRotation_8(&s[53], &s[42], 60 - 16 + 64, true); + ButterflyRotation_8(&s[50], &s[45], 60 - 48, true); + ButterflyRotation_8(&s[49], &s[46], 60 - 48 + 64, true); + + // stage 11. + HadamardRotation(&s[32], &s[35], false); + HadamardRotation(&s[33], &s[34], false); + HadamardRotation(&s[36], &s[39], true); + HadamardRotation(&s[37], &s[38], true); + HadamardRotation(&s[40], &s[43], false); + HadamardRotation(&s[41], &s[42], false); + HadamardRotation(&s[44], &s[47], true); + HadamardRotation(&s[45], &s[46], true); + HadamardRotation(&s[48], &s[51], false); + HadamardRotation(&s[49], &s[50], false); + HadamardRotation(&s[52], &s[55], true); + HadamardRotation(&s[53], &s[54], true); + HadamardRotation(&s[56], &s[59], false); + HadamardRotation(&s[57], &s[58], false); + HadamardRotation(&s[60], &s[63], true); + HadamardRotation(&s[61], &s[62], true); + + // stage 16. + ButterflyRotation_8(&s[61], &s[34], 56, true); + ButterflyRotation_8(&s[60], &s[35], 56, true); + ButterflyRotation_8(&s[59], &s[36], 56 + 64, true); + ButterflyRotation_8(&s[58], &s[37], 56 + 64, true); + ButterflyRotation_8(&s[53], &s[42], 56 - 32, true); + ButterflyRotation_8(&s[52], &s[43], 56 - 32, true); + ButterflyRotation_8(&s[51], &s[44], 56 - 32 + 64, true); + ButterflyRotation_8(&s[50], &s[45], 56 - 32 + 64, true); + + // stage 21. + HadamardRotation(&s[32], &s[39], false); + HadamardRotation(&s[33], &s[38], false); + HadamardRotation(&s[34], &s[37], false); + HadamardRotation(&s[35], &s[36], false); + HadamardRotation(&s[40], &s[47], true); + HadamardRotation(&s[41], &s[46], true); + HadamardRotation(&s[42], &s[45], true); + HadamardRotation(&s[43], &s[44], true); + HadamardRotation(&s[48], &s[55], false); + HadamardRotation(&s[49], &s[54], false); + HadamardRotation(&s[50], &s[53], false); + HadamardRotation(&s[51], &s[52], false); + HadamardRotation(&s[56], &s[63], true); + HadamardRotation(&s[57], &s[62], true); + HadamardRotation(&s[58], &s[61], true); + HadamardRotation(&s[59], &s[60], true); + + // stage 25. + ButterflyRotation_8(&s[59], &s[36], 48, true); + ButterflyRotation_8(&s[58], &s[37], 48, true); + ButterflyRotation_8(&s[57], &s[38], 48, true); + ButterflyRotation_8(&s[56], &s[39], 48, true); + ButterflyRotation_8(&s[55], &s[40], 112, true); + ButterflyRotation_8(&s[54], &s[41], 112, true); + ButterflyRotation_8(&s[53], &s[42], 112, true); + ButterflyRotation_8(&s[52], &s[43], 112, true); + + // stage 28. + HadamardRotation(&s[32], &s[47], false); + HadamardRotation(&s[33], &s[46], false); + HadamardRotation(&s[34], &s[45], false); + HadamardRotation(&s[35], &s[44], false); + HadamardRotation(&s[36], &s[43], false); + HadamardRotation(&s[37], &s[42], false); + HadamardRotation(&s[38], &s[41], false); + HadamardRotation(&s[39], &s[40], false); + HadamardRotation(&s[48], &s[63], true); + HadamardRotation(&s[49], &s[62], true); + HadamardRotation(&s[50], &s[61], true); + HadamardRotation(&s[51], &s[60], true); + HadamardRotation(&s[52], &s[59], true); + HadamardRotation(&s[53], &s[58], true); + HadamardRotation(&s[54], &s[57], true); + HadamardRotation(&s[55], &s[56], true); + + // stage 30. + ButterflyRotation_8(&s[55], &s[40], 32, true); + ButterflyRotation_8(&s[54], &s[41], 32, true); + ButterflyRotation_8(&s[53], &s[42], 32, true); + ButterflyRotation_8(&s[52], &s[43], 32, true); + ButterflyRotation_8(&s[51], &s[44], 32, true); + ButterflyRotation_8(&s[50], &s[45], 32, true); + ButterflyRotation_8(&s[49], &s[46], 32, true); + ButterflyRotation_8(&s[48], &s[47], 32, true); + + // stage 31. + for (int i = 0; i < 32; i += 4) { + HadamardRotation(&s[i], &s[63 - i], false); + HadamardRotation(&s[i + 1], &s[63 - i - 1], false); + HadamardRotation(&s[i + 2], &s[63 - i - 2], false); + HadamardRotation(&s[i + 3], &s[63 - i - 3], false); + } + //-- end dct 64 stages + + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int idx = 0; idx < 64; idx += 8) { + int16x8_t output[8]; + Transpose8x8(&s[idx], output); + for (int i = 0; i < 8; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 64>(dst, step, 0, s); + } +} + +//------------------------------------------------------------------------------ +// Asymmetric Discrete Sine Transforms (ADST). +template <bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int32x4_t s[8]; + int16x8_t x[4]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[8]; + LoadSrc<8, 8>(dst, step, 0, input); + Transpose4x8To8x4(input, x); + } else { + LoadSrc<16, 4>(dst, step, 0, x); + } + } else { + LoadSrc<8, 4>(dst, step, 0, x); + if (transpose) { + Transpose4x4(x, x); + } + } + + // stage 1. + s[5] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[1]); + s[6] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[3]); + + // stage 2. + const int32x4_t a7 = vsubl_s16(vget_low_s16(x[0]), vget_low_s16(x[2])); + const int32x4_t b7 = vaddw_s16(a7, vget_low_s16(x[3])); + + // stage 3. + s[0] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[0]); + s[1] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[1]); + // s[0] = s[0] + s[3] + s[0] = vmlal_n_s16(s[0], vget_low_s16(x[2]), kAdst4Multiplier[3]); + // s[1] = s[1] - s[4] + s[1] = vmlsl_n_s16(s[1], vget_low_s16(x[2]), kAdst4Multiplier[0]); + + s[3] = vmull_n_s16(vget_low_s16(x[1]), kAdst4Multiplier[2]); + s[2] = vmulq_n_s32(b7, kAdst4Multiplier[2]); + + // stage 4. + s[0] = vaddq_s32(s[0], s[5]); + s[1] = vsubq_s32(s[1], s[6]); + + // stages 5 and 6. + const int32x4_t x0 = vaddq_s32(s[0], s[3]); + const int32x4_t x1 = vaddq_s32(s[1], s[3]); + const int32x4_t x3_a = vaddq_s32(s[0], s[1]); + const int32x4_t x3 = vsubq_s32(x3_a, s[3]); + const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12); + const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12); + const int16x4_t dst_2 = vqrshrn_n_s32(s[2], 12); + const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12); + + x[0] = vcombine_s16(dst_0, dst_0); + x[1] = vcombine_s16(dst_1, dst_1); + x[2] = vcombine_s16(dst_2, dst_2); + x[3] = vcombine_s16(dst_3, dst_3); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[8]; + Transpose8x4To4x8(x, output); + StoreDst<8, 8>(dst, step, 0, output); + } else { + StoreDst<16, 4>(dst, step, 0, x); + } + } else { + if (transpose) { + Transpose4x4(x, x); + } + StoreDst<8, 4>(dst, step, 0, x); + } +} + +alignas(8) constexpr int16_t kAdst4DcOnlyMultiplier[4] = {1321, 2482, 3344, + 2482}; + +LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int32x4_t s[2]; + + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int16x4_t kAdst4DcOnlyMultipliers = vld1_s16(kAdst4DcOnlyMultiplier); + s[1] = vdupq_n_s32(0); + + // s0*k0 s0*k1 s0*k2 s0*k1 + s[0] = vmull_s16(kAdst4DcOnlyMultipliers, v_src); + // 0 0 0 s0*k0 + s[1] = vextq_s32(s[1], s[0], 1); + + const int32x4_t x3 = vaddq_s32(s[0], s[1]); + const int16x4_t dst_0 = vqrshrn_n_s32(x3, 12); + + // vqrshlq_s16 will shift right if shift value is negative. + vst1_s16(dst, vqrshl_s16(dst_0, vdup_n_s16(-row_shift))); + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int32x4_t s[4]; + + int i = 0; + do { + const int16x4_t v_src = vld1_s16(&dst[i]); + + s[0] = vmull_n_s16(v_src, kAdst4Multiplier[0]); + s[1] = vmull_n_s16(v_src, kAdst4Multiplier[1]); + s[2] = vmull_n_s16(v_src, kAdst4Multiplier[2]); + + const int32x4_t x0 = s[0]; + const int32x4_t x1 = s[1]; + const int32x4_t x2 = s[2]; + const int32x4_t x3 = vaddq_s32(s[0], s[1]); + const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12); + const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12); + const int16x4_t dst_2 = vqrshrn_n_s32(x2, 12); + const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12); + + vst1_s16(&dst[i], dst_0); + vst1_s16(&dst[i + width * 1], dst_1); + vst1_s16(&dst[i + width * 2], dst_2); + vst1_s16(&dst[i + width * 3], dst_3); + + i += 4; + } while (i < width); + + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[8], x[8]; + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + } else { + LoadSrc<8, 8>(dst, step, 0, x); + } + } else { + if (transpose) { + LoadSrc<16, 8>(dst, step, 0, x); + dsp::Transpose8x8(x); + } else { + LoadSrc<16, 8>(dst, step, 0, x); + } + } + + // stage 1. + s[0] = x[7]; + s[1] = x[0]; + s[2] = x[5]; + s[3] = x[2]; + s[4] = x[3]; + s[5] = x[4]; + s[6] = x[1]; + s[7] = x[6]; + + // stage 2. + butterfly_rotation(&s[0], &s[1], 60 - 0, true); + butterfly_rotation(&s[2], &s[3], 60 - 16, true); + butterfly_rotation(&s[4], &s[5], 60 - 32, true); + butterfly_rotation(&s[6], &s[7], 60 - 48, true); + + // stage 3. + HadamardRotation(&s[0], &s[4], false); + HadamardRotation(&s[1], &s[5], false); + HadamardRotation(&s[2], &s[6], false); + HadamardRotation(&s[3], &s[7], false); + + // stage 4. + butterfly_rotation(&s[4], &s[5], 48 - 0, true); + butterfly_rotation(&s[7], &s[6], 48 - 32, true); + + // stage 5. + HadamardRotation(&s[0], &s[2], false); + HadamardRotation(&s[4], &s[6], false); + HadamardRotation(&s[1], &s[3], false); + HadamardRotation(&s[5], &s[7], false); + + // stage 6. + butterfly_rotation(&s[2], &s[3], 32, true); + butterfly_rotation(&s[6], &s[7], 32, true); + + // stage 7. + x[0] = s[0]; + x[1] = vqnegq_s16(s[4]); + x[2] = s[6]; + x[3] = vqnegq_s16(s[2]); + x[4] = s[3]; + x[5] = vqnegq_s16(s[7]); + x[6] = s[5]; + x[7] = vqnegq_s16(s[1]); + + if (stage_is_rectangular) { + if (transpose) { + int16x8_t output[4]; + Transpose4x8To8x4(x, output); + StoreDst<16, 4>(dst, step, 0, output); + } else { + StoreDst<8, 8>(dst, step, 0, x); + } + } else { + if (transpose) { + dsp::Transpose8x8(x); + StoreDst<16, 8>(dst, step, 0, x); + } else { + StoreDst<16, 8>(dst, step, 0, x); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int16x8_t s[8]; + + const int16x8_t v_src = vdupq_n_s16(dst[0]); + const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + // stage 1. + s[1] = vbslq_s16(v_mask, v_src_round, v_src); + + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true); + + // stage 3. + s[4] = s[0]; + s[5] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[4], &s[5], 48, true); + + // stage 5. + s[2] = s[0]; + s[3] = s[1]; + s[6] = s[4]; + s[7] = s[5]; + + // stage 6. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + + // stage 7. + int16x8_t x[8]; + x[0] = s[0]; + x[1] = vqnegq_s16(s[4]); + x[2] = s[6]; + x[3] = vqnegq_s16(s[2]); + x[4] = s[3]; + x[5] = vqnegq_s16(s[7]); + x[6] = s[5]; + x[7] = vqnegq_s16(s[1]); + + for (int i = 0; i < 8; ++i) { + // vqrshlq_s16 will shift right if shift value is negative. + x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift)); + vst1q_lane_s16(&dst[i], x[i], 0); + } + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int16x8_t s[8]; + + int i = 0; + do { + const int16x8_t v_src = vld1q_s16(dst); + // stage 1. + s[1] = v_src; + + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true); + + // stage 3. + s[4] = s[0]; + s[5] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[4], &s[5], 48, true); + + // stage 5. + s[2] = s[0]; + s[3] = s[1]; + s[6] = s[4]; + s[7] = s[5]; + + // stage 6. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + + // stage 7. + int16x8_t x[8]; + x[0] = s[0]; + x[1] = vqnegq_s16(s[4]); + x[2] = s[6]; + x[3] = vqnegq_s16(s[2]); + x[4] = s[3]; + x[5] = vqnegq_s16(s[7]); + x[6] = s[5]; + x[7] = vqnegq_s16(s[1]); + + for (int j = 0; j < 8; ++j) { + vst1_s16(&dst[j * width], vget_low_s16(x[j])); + } + i += 4; + dst += 4; + } while (i < width); + + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, int32_t step, bool is_row, + int row_shift) { + auto* const dst = static_cast<int16_t*>(dest); + int16x8_t s[16], x[16]; + + if (stage_is_rectangular) { + if (is_row) { + int16x8_t input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8(input, x); + LoadSrc<16, 4>(dst, step, 8, input); + Transpose8x4To4x8(input, &x[8]); + } else { + LoadSrc<8, 16>(dst, step, 0, x); + } + } else { + if (is_row) { + for (int idx = 0; idx < 16; idx += 8) { + LoadSrc<16, 8>(dst, step, idx, &x[idx]); + dsp::Transpose8x8(&x[idx]); + } + } else { + LoadSrc<16, 16>(dst, step, 0, x); + } + } + + // stage 1. + s[0] = x[15]; + s[1] = x[0]; + s[2] = x[13]; + s[3] = x[2]; + s[4] = x[11]; + s[5] = x[4]; + s[6] = x[9]; + s[7] = x[6]; + s[8] = x[7]; + s[9] = x[8]; + s[10] = x[5]; + s[11] = x[10]; + s[12] = x[3]; + s[13] = x[12]; + s[14] = x[1]; + s[15] = x[14]; + + // stage 2. + butterfly_rotation(&s[0], &s[1], 62 - 0, true); + butterfly_rotation(&s[2], &s[3], 62 - 8, true); + butterfly_rotation(&s[4], &s[5], 62 - 16, true); + butterfly_rotation(&s[6], &s[7], 62 - 24, true); + butterfly_rotation(&s[8], &s[9], 62 - 32, true); + butterfly_rotation(&s[10], &s[11], 62 - 40, true); + butterfly_rotation(&s[12], &s[13], 62 - 48, true); + butterfly_rotation(&s[14], &s[15], 62 - 56, true); + + // stage 3. + HadamardRotation(&s[0], &s[8], false); + HadamardRotation(&s[1], &s[9], false); + HadamardRotation(&s[2], &s[10], false); + HadamardRotation(&s[3], &s[11], false); + HadamardRotation(&s[4], &s[12], false); + HadamardRotation(&s[5], &s[13], false); + HadamardRotation(&s[6], &s[14], false); + HadamardRotation(&s[7], &s[15], false); + + // stage 4. + butterfly_rotation(&s[8], &s[9], 56 - 0, true); + butterfly_rotation(&s[13], &s[12], 8 + 0, true); + butterfly_rotation(&s[10], &s[11], 56 - 32, true); + butterfly_rotation(&s[15], &s[14], 8 + 32, true); + + // stage 5. + HadamardRotation(&s[0], &s[4], false); + HadamardRotation(&s[8], &s[12], false); + HadamardRotation(&s[1], &s[5], false); + HadamardRotation(&s[9], &s[13], false); + HadamardRotation(&s[2], &s[6], false); + HadamardRotation(&s[10], &s[14], false); + HadamardRotation(&s[3], &s[7], false); + HadamardRotation(&s[11], &s[15], false); + + // stage 6. + butterfly_rotation(&s[4], &s[5], 48 - 0, true); + butterfly_rotation(&s[12], &s[13], 48 - 0, true); + butterfly_rotation(&s[7], &s[6], 48 - 32, true); + butterfly_rotation(&s[15], &s[14], 48 - 32, true); + + // stage 7. + HadamardRotation(&s[0], &s[2], false); + HadamardRotation(&s[4], &s[6], false); + HadamardRotation(&s[8], &s[10], false); + HadamardRotation(&s[12], &s[14], false); + HadamardRotation(&s[1], &s[3], false); + HadamardRotation(&s[5], &s[7], false); + HadamardRotation(&s[9], &s[11], false); + HadamardRotation(&s[13], &s[15], false); + + // stage 8. + butterfly_rotation(&s[2], &s[3], 32, true); + butterfly_rotation(&s[6], &s[7], 32, true); + butterfly_rotation(&s[10], &s[11], 32, true); + butterfly_rotation(&s[14], &s[15], 32, true); + + // stage 9. + x[0] = s[0]; + x[1] = vqnegq_s16(s[8]); + x[2] = s[12]; + x[3] = vqnegq_s16(s[4]); + x[4] = s[6]; + x[5] = vqnegq_s16(s[14]); + x[6] = s[10]; + x[7] = vqnegq_s16(s[2]); + x[8] = s[3]; + x[9] = vqnegq_s16(s[11]); + x[10] = s[15]; + x[11] = vqnegq_s16(s[7]); + x[12] = s[5]; + x[13] = vqnegq_s16(s[13]); + x[14] = s[9]; + x[15] = vqnegq_s16(s[1]); + + if (stage_is_rectangular) { + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + int16x8_t output[4]; + Transpose4x8To8x4(x, output); + for (int i = 0; i < 4; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 4>(dst, step, 0, output); + Transpose4x8To8x4(&x[8], output); + for (int i = 0; i < 4; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 4>(dst, step, 8, output); + } else { + StoreDst<8, 16>(dst, step, 0, x); + } + } else { + if (is_row) { + const int16x8_t v_row_shift = vdupq_n_s16(-row_shift); + for (int idx = 0; idx < 16; idx += 8) { + int16x8_t output[8]; + Transpose8x8(&x[idx], output); + for (int i = 0; i < 8; ++i) { + output[i] = vqrshlq_s16(output[i], v_row_shift); + } + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 16>(dst, step, 0, x); + } + } +} + +LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int16x8_t* s, int16x8_t* x) { + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true); + + // stage 3. + s[8] = s[0]; + s[9] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[8], &s[9], 56, true); + + // stage 5. + s[4] = s[0]; + s[12] = s[8]; + s[5] = s[1]; + s[13] = s[9]; + + // stage 6. + ButterflyRotation_4(&s[4], &s[5], 48, true); + ButterflyRotation_4(&s[12], &s[13], 48, true); + + // stage 7. + s[2] = s[0]; + s[6] = s[4]; + s[10] = s[8]; + s[14] = s[12]; + s[3] = s[1]; + s[7] = s[5]; + s[11] = s[9]; + s[15] = s[13]; + + // stage 8. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + ButterflyRotation_4(&s[10], &s[11], 32, true); + ButterflyRotation_4(&s[14], &s[15], 32, true); + + // stage 9. + x[0] = s[0]; + x[1] = vqnegq_s16(s[8]); + x[2] = s[12]; + x[3] = vqnegq_s16(s[4]); + x[4] = s[6]; + x[5] = vqnegq_s16(s[14]); + x[6] = s[10]; + x[7] = vqnegq_s16(s[2]); + x[8] = s[3]; + x[9] = vqnegq_s16(s[11]); + x[10] = s[15]; + x[11] = vqnegq_s16(s[7]); + x[12] = s[5]; + x[13] = vqnegq_s16(s[13]); + x[14] = s[9]; + x[15] = vqnegq_s16(s[1]); +} + +LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int16x8_t s[16]; + int16x8_t x[16]; + + const int16x8_t v_src = vdupq_n_s16(dst[0]); + const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + // stage 1. + s[1] = vbslq_s16(v_mask, v_src_round, v_src); + + Adst16DcOnlyInternal(s, x); + + for (int i = 0; i < 16; ++i) { + // vqrshlq_s16 will shift right if shift value is negative. + x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift)); + vst1q_lane_s16(&dst[i], x[i], 0); + } + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest, + int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int i = 0; + do { + int16x8_t s[16]; + int16x8_t x[16]; + const int16x8_t v_src = vld1q_s16(dst); + // stage 1. + s[1] = v_src; + + Adst16DcOnlyInternal(s, x); + + for (int j = 0; j < 16; ++j) { + vst1_s16(&dst[j * width], vget_low_s16(x[j])); + } + i += 4; + dst += 4; + } while (i < width); + + return true; +} + +//------------------------------------------------------------------------------ +// Identity Transforms. + +template <bool is_row_shift> +LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + if (is_row_shift) { + const int shift = 1; + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + for (int i = 0; i < 4; i += 2) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + const int32x4_t v_src_mult_lo = + vmlal_s16(v_dual_round, vget_low_s16(v_src), v_multiplier); + const int32x4_t v_src_mult_hi = + vmlal_s16(v_dual_round, vget_high_s16(v_src), v_multiplier); + const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift); + const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift); + vst1q_s16(&dst[i * step], + vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi))); + } + } else { + for (int i = 0; i < 4; i += 2) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + const int16x8_t a = + vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3); + const int16x8_t b = vqaddq_s16(v_src, a); + vst1q_s16(&dst[i * step], b); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int tx_height) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int shift = tx_height < 16 ? 0 : 1; + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + const int32x4_t v_src_mult_lo = vmlal_s16(v_dual_round, v_src, v_multiplier); + const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift); + vst1_lane_s16(dst, vqmovn_s32(dst_0), 0); + return true; +} + +template <int identity_size> +LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + if (identity_size < 32) { + if (tx_width == 4) { + uint8x8_t frame_data = vdup_n_u8(0); + int i = 0; + do { + const int16x4_t v_src = vld1_s16(&source[i * tx_width]); + + int16x4_t v_dst_i; + if (identity_size == 4) { + const int16x4_t v_src_fraction = + vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3); + v_dst_i = vqadd_s16(v_src, v_src_fraction); + } else if (identity_size == 8) { + v_dst_i = vqadd_s16(v_src, v_src); + } else { // identity_size == 16 + const int16x4_t v_src_mult = + vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 4); + const int16x4_t v_srcx2 = vqadd_s16(v_src, v_src); + v_dst_i = vqadd_s16(v_srcx2, v_src_mult); + } + + frame_data = Load4<0>(dst, frame_data); + const int16x4_t a = vrshr_n_s16(v_dst_i, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, d); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const int16x8_t v_src = vld1q_s16(&source[row + j]); + + int16x8_t v_dst_i; + if (identity_size == 4) { + const int16x8_t v_src_fraction = + vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3); + v_dst_i = vqaddq_s16(v_src, v_src_fraction); + } else if (identity_size == 8) { + v_dst_i = vqaddq_s16(v_src, v_src); + } else { // identity_size == 16 + const int16x8_t v_src_mult = + vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 4); + const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src); + v_dst_i = vqaddq_s16(v_src_mult, v_srcx2); + } + + const uint8x8_t frame_data = vld1_u8(dst + j); + const int16x8_t a = vrshrq_n_s16(v_dst_i, 4); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst + j, d); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const int16x8_t v_dst_i = vld1q_s16(&source[row + j]); + const uint8x8_t frame_data = vld1_u8(dst + j); + const int16x8_t a = vrshrq_n_s16(v_dst_i, 2); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst + j, d); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + if (tx_width == 4) { + uint8x8_t frame_data = vdup_n_u8(0); + int i = 0; + do { + const int16x4_t v_src = vld1_s16(&source[i * tx_width]); + const int16x4_t v_src_mult = + vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3); + const int16x4_t v_dst_row = vqadd_s16(v_src, v_src_mult); + const int16x4_t v_src_mult2 = + vqrdmulh_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3); + const int16x4_t v_dst_col = vqadd_s16(v_dst_row, v_src_mult2); + frame_data = Load4<0>(dst, frame_data); + const int16x4_t a = vrshr_n_s16(v_dst_col, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, d); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const int16x8_t v_src = vld1q_s16(&source[row + j]); + const int16x8_t v_src_round = + vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3); + const int16x8_t v_dst_row = vqaddq_s16(v_src_round, v_src_round); + const int16x8_t v_src_mult2 = + vqrdmulhq_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3); + const int16x8_t v_dst_col = vqaddq_s16(v_dst_row, v_src_mult2); + const uint8x8_t frame_data = vld1_u8(dst + j); + const int16x8_t a = vrshrq_n_s16(v_dst_col, 4); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst + j, d); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + // When combining the identity8 multiplier with the row shift, the + // calculations for tx_height equal to 32 can be simplified from + // ((A * 2) + 2) >> 2) to ((A + 1) >> 1). + for (int i = 0; i < 4; ++i) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + const int16x8_t a = vrshrq_n_s16(v_src, 1); + vst1q_s16(&dst[i * step], a); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + for (int i = 0; i < 4; ++i) { + const int16x8_t v_src = vld1q_s16(&dst[i * step]); + // For bitdepth == 8, the identity row clamps to a signed 16bit value, so + // saturating add here is ok. + const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src); + vst1q_s16(&dst[i * step], v_srcx2); + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int32x4_t v_srcx2 = vaddl_s16(v_src, v_src); + const int32x4_t dst_0 = vqrshlq_s32(v_srcx2, vdupq_n_s32(-row_shift)); + vst1_lane_s16(dst, vqmovn_s32(dst_0), 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, int32_t step, + int shift) { + auto* const dst = static_cast<int16_t*>(dest); + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + const int16x8_t v_src = vld1q_s16(&dst[i * step + j * 8]); + const int32x4_t v_src_mult_lo = + vmlal_n_s16(v_dual_round, vget_low_s16(v_src), kIdentity16Multiplier); + const int32x4_t v_src_mult_hi = vmlal_n_s16( + v_dual_round, vget_high_s16(v_src), kIdentity16Multiplier); + const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift); + const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift); + vst1q_s16(&dst[i * step + j * 8], + vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi))); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0); + const int16x4_t v_src_round = + vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0); + const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11); + const int16x4_t v_multiplier = vdup_n_s16(kIdentity16Multiplier); + const int32x4_t v_shift = vdupq_n_s32(-(12 + shift)); + const int32x4_t v_src_mult_lo = + vmlal_s16(v_dual_round, (v_src), v_multiplier); + const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift); + vst1_lane_s16(dst, vqmovn_s32(dst_0), 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest, + const int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + // When combining the identity32 multiplier with the row shift, the + // calculation for tx_height equal to 16 can be simplified from + // ((A * 4) + 1) >> 1) to (A * 2). + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 32; j += 8) { + const int16x8_t v_src = vld1q_s16(&dst[i * step + j]); + // For bitdepth == 8, the identity row clamps to a signed 16bit value, so + // saturating add here is ok. + const int16x8_t v_dst_i = vqaddq_s16(v_src, v_src); + vst1q_s16(&dst[i * step + j], v_dst_i); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest, + int adjusted_tx_height) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16x4_t v_src0 = vdup_n_s16(dst[0]); + const int16x4_t v_src = vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3); + // When combining the identity32 multiplier with the row shift, the + // calculation for tx_height equal to 16 can be simplified from + // ((A * 4) + 1) >> 1) to (A * 2). + const int16x4_t v_dst_0 = vqadd_s16(v_src, v_src); + vst1_lane_s16(dst, v_dst_0, 0); + return true; +} + +//------------------------------------------------------------------------------ +// Walsh Hadamard Transform. + +// Transposes a 4x4 matrix and then permutes the rows of the transposed matrix +// for the WHT. The input matrix is in two "wide" int16x8_t variables. The +// output matrix is in four int16x4_t variables. +// +// Input: +// in[0]: 00 01 02 03 10 11 12 13 +// in[1]: 20 21 22 23 30 31 32 33 +// Output: +// out[0]: 00 10 20 30 +// out[1]: 03 13 23 33 +// out[2]: 01 11 21 31 +// out[3]: 02 12 22 32 +LIBGAV1_ALWAYS_INLINE void TransposeAndPermute4x4WideInput( + const int16x8_t in[2], int16x4_t out[4]) { + // Swap 32 bit elements. Goes from: + // in[0]: 00 01 02 03 10 11 12 13 + // in[1]: 20 21 22 23 30 31 32 33 + // to: + // b0.val[0]: 00 01 20 21 10 11 30 31 + // b0.val[1]: 02 03 22 23 12 13 32 33 + + const int32x4x2_t b0 = + vtrnq_s32(vreinterpretq_s32_s16(in[0]), vreinterpretq_s32_s16(in[1])); + + // Swap 16 bit elements. Goes from: + // vget_low_s32(b0.val[0]): 00 01 20 21 + // vget_high_s32(b0.val[0]): 10 11 30 31 + // vget_low_s32(b0.val[1]): 02 03 22 23 + // vget_high_s32(b0.val[1]): 12 13 32 33 + // to: + // c0.val[0]: 00 10 20 30 + // c0.val[1]: 01 11 21 32 + // c1.val[0]: 02 12 22 32 + // c1.val[1]: 03 13 23 33 + + const int16x4x2_t c0 = + vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[0])), + vreinterpret_s16_s32(vget_high_s32(b0.val[0]))); + const int16x4x2_t c1 = + vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[1])), + vreinterpret_s16_s32(vget_high_s32(b0.val[1]))); + + out[0] = c0.val[0]; + out[1] = c1.val[1]; + out[2] = c0.val[1]; + out[3] = c1.val[0]; +} + +// Process 4 wht4 rows and columns. +LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint8_t* dst, const int dst_stride, + const void* source, + const int adjusted_tx_height) { + const auto* const src = static_cast<const int16_t*>(source); + int16x4_t s[4]; + + if (adjusted_tx_height == 1) { + // Special case: only src[0] is nonzero. + // src[0] 0 0 0 + // 0 0 0 0 + // 0 0 0 0 + // 0 0 0 0 + // + // After the row and column transforms are applied, we have: + // f h h h + // g i i i + // g i i i + // g i i i + // where f, g, h, i are computed as follows. + int16_t f = (src[0] >> 2) - (src[0] >> 3); + const int16_t g = f >> 1; + f = f - (f >> 1); + const int16_t h = (src[0] >> 3) - (src[0] >> 4); + const int16_t i = (src[0] >> 4); + s[0] = vdup_n_s16(h); + s[0] = vset_lane_s16(f, s[0], 0); + s[1] = vdup_n_s16(i); + s[1] = vset_lane_s16(g, s[1], 0); + s[2] = s[3] = s[1]; + } else { + // Load the 4x4 source in transposed form. + int16x4x4_t columns = vld4_s16(src); + // Shift right and permute the columns for the WHT. + s[0] = vshr_n_s16(columns.val[0], 2); + s[2] = vshr_n_s16(columns.val[1], 2); + s[3] = vshr_n_s16(columns.val[2], 2); + s[1] = vshr_n_s16(columns.val[3], 2); + + // Row transforms. + s[0] = vadd_s16(s[0], s[2]); + s[3] = vsub_s16(s[3], s[1]); + int16x4_t e = vhsub_s16(s[0], s[3]); // e = (s[0] - s[3]) >> 1 + s[1] = vsub_s16(e, s[1]); + s[2] = vsub_s16(e, s[2]); + s[0] = vsub_s16(s[0], s[1]); + s[3] = vadd_s16(s[3], s[2]); + + int16x8_t x[2]; + x[0] = vcombine_s16(s[0], s[1]); + x[1] = vcombine_s16(s[2], s[3]); + TransposeAndPermute4x4WideInput(x, s); + + // Column transforms. + s[0] = vadd_s16(s[0], s[2]); + s[3] = vsub_s16(s[3], s[1]); + e = vhsub_s16(s[0], s[3]); // e = (s[0] - s[3]) >> 1 + s[1] = vsub_s16(e, s[1]); + s[2] = vsub_s16(e, s[2]); + s[0] = vsub_s16(s[0], s[1]); + s[3] = vadd_s16(s[3], s[2]); + } + + // Store to frame. + uint8x8_t frame_data = vdup_n_u8(0); + for (int row = 0; row < 4; row += 2) { + frame_data = Load4<0>(dst, frame_data); + frame_data = Load4<1>(dst + dst_stride, frame_data); + const int16x8_t residual = vcombine_s16(s[row], s[row + 1]); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(residual), frame_data); + frame_data = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, frame_data); + dst += dst_stride; + StoreHi4(dst, frame_data); + dst += dst_stride; + } +} + +//------------------------------------------------------------------------------ +// row/column transform loops + +template <int tx_height> +LIBGAV1_ALWAYS_INLINE void FlipColumns(int16_t* source, int tx_width) { + if (tx_width >= 16) { + int i = 0; + do { + const int16x8_t a = vld1q_s16(&source[i]); + const int16x8_t b = vld1q_s16(&source[i + 8]); + const int16x8_t c = vrev64q_s16(a); + const int16x8_t d = vrev64q_s16(b); + vst1q_s16(&source[i], vcombine_s16(vget_high_s16(d), vget_low_s16(d))); + vst1q_s16(&source[i + 8], + vcombine_s16(vget_high_s16(c), vget_low_s16(c))); + i += 16; + } while (i < tx_width * tx_height); + } else if (tx_width == 8) { + for (int i = 0; i < 8 * tx_height; i += 8) { + const int16x8_t a = vld1q_s16(&source[i]); + const int16x8_t b = vrev64q_s16(a); + vst1q_s16(&source[i], vcombine_s16(vget_high_s16(b), vget_low_s16(b))); + } + } else { + // Process two rows per iteration. + for (int i = 0; i < 4 * tx_height; i += 8) { + const int16x8_t a = vld1q_s16(&source[i]); + vst1q_s16(&source[i], vrev64q_s16(a)); + } + } +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) { + if (tx_width == 4) { + // Process two rows per iteration. + int i = 0; + do { + const int16x8_t a = vld1q_s16(&source[i]); + const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3); + vst1q_s16(&source[i], b); + i += 8; + } while (i < tx_width * num_rows); + } else { + int i = 0; + do { + // The last 32 values of every row are always zero if the |tx_width| is + // 64. + const int non_zero_width = (tx_width < 64) ? tx_width : 32; + int j = 0; + do { + const int16x8_t a = vld1q_s16(&source[i * tx_width + j]); + const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3); + vst1q_s16(&source[i * tx_width + j], b); + j += 8; + } while (j < non_zero_width); + } while (++i < num_rows); + } +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows, + int row_shift) { + // vqrshlq_s16 will shift right if shift value is negative. + row_shift = -row_shift; + + if (tx_width == 4) { + // Process two rows per iteration. + int i = 0; + do { + const int16x8_t residual = vld1q_s16(&source[i]); + vst1q_s16(&source[i], vqrshlq_s16(residual, vdupq_n_s16(row_shift))); + i += 8; + } while (i < tx_width * num_rows); + } else { + int i = 0; + do { + for (int j = 0; j < tx_width; j += 8) { + const int16x8_t residual = vld1q_s16(&source[i * tx_width + j]); + const int16x8_t residual_shifted = + vqrshlq_s16(residual, vdupq_n_s16(row_shift)); + vst1q_s16(&source[i * tx_width + j], residual_shifted); + } + } while (++i < num_rows); + } +} + +template <int tx_height, bool enable_flip_rows = false> +LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int16_t* source, TransformType tx_type) { + const bool flip_rows = + enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false; + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + // Enable for 4x4, 4x8, 4x16 + if (tx_height < 32 && tx_width == 4) { + uint8x8_t frame_data = vdup_n_u8(0); + for (int i = 0; i < tx_height; ++i) { + const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4; + const int16x4_t residual = vld1_s16(&source[row]); + frame_data = Load4<0>(dst, frame_data); + const int16x4_t a = vrshr_n_s16(residual, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + StoreLo4(dst, d); + dst += stride; + } + // Enable for 8x4, 8x8, 8x16, 8x32 + } else if (tx_height < 64 && tx_width == 8) { + for (int i = 0; i < tx_height; ++i) { + const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8; + const int16x8_t residual = vld1q_s16(&source[row]); + const uint8x8_t frame_data = vld1_u8(dst); + const int16x8_t a = vrshrq_n_s16(residual, 4); + const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data); + const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b)); + vst1_u8(dst, d); + dst += stride; + } + // Remaining widths >= 16. + } else { + for (int i = 0; i < tx_height; ++i) { + const int y = start_y + i; + const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width; + int j = 0; + do { + const int x = start_x + j; + const int16x8_t residual = vld1q_s16(&source[row + j]); + const int16x8_t residual_hi = vld1q_s16(&source[row + j + 8]); + const uint8x16_t frame_data = vld1q_u8(frame[y] + x); + const int16x8_t a = vrshrq_n_s16(residual, 4); + const int16x8_t a_hi = vrshrq_n_s16(residual_hi, 4); + const uint16x8_t b = + vaddw_u8(vreinterpretq_u16_s16(a), vget_low_u8(frame_data)); + const uint16x8_t b_hi = + vaddw_u8(vreinterpretq_u16_s16(a_hi), vget_high_u8(frame_data)); + vst1q_u8(frame[y] + x, + vcombine_u8(vqmovun_s16(vreinterpretq_s16_u16(b)), + vqmovun_s16(vreinterpretq_s16_u16(b_hi)))); + j += 16; + } while (j < tx_width); + } + } +} + +void Dct4TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = (tx_height == 8); + const int row_shift = (tx_height == 16); + + if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d dct4 rows in parallel. + Dct4_NEON<ButterflyRotation_4, false>(src, /*step=*/4, /*transpose=*/true); + } else { + // Process 8 1d dct4 rows in parallel per iteration. + int i = adjusted_tx_height; + auto* data = src; + do { + Dct4_NEON<ButterflyRotation_8, true>(data, /*step=*/4, + /*transpose=*/true); + data += 32; + i -= 8; + } while (i != 0); + } + if (tx_height == 16) { + RowShift<4>(src, adjusted_tx_height, 1); + } +} + +void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct4 columns in parallel. + Dct4_NEON<ButterflyRotation_4, false>(src, tx_width, /*transpose=*/false); + } else { + // Process 8 1d dct4 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct4_NEON<ButterflyRotation_8, true>(data, tx_width, + /*transpose=*/false); + data += 8; + i -= 8; + } while (i != 0); + } + } + + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct8TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d dct8 rows in parallel. + Dct8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true); + } else { + // Process 8 1d dct8 rows in parallel per iteration. + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + Dct8_NEON<ButterflyRotation_8, false>(data, /*step=*/8, + /*transpose=*/true); + data += 64; + i -= 8; + } while (i != 0); + } + if (row_shift > 0) { + RowShift<8>(src, adjusted_tx_height, row_shift); + } +} + +void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct8 columns in parallel. + Dct8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + // Process 8 1d dct8 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct8_NEON<ButterflyRotation_8, false>(data, tx_width, + /*transpose=*/false); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d dct16 rows in parallel. + Dct16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift); + } else { + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + do { + // Process 8 1d dct16 rows in parallel per iteration. + Dct16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true, + row_shift); + src += 128; + i -= 8; + } while (i != 0); + } +} + +void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct16 columns in parallel. + Dct16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false, + /*row_shift=*/0); + } else { + int i = tx_width; + auto* data = src; + do { + // Process 8 1d dct16 columns in parallel per iteration. + Dct16_NEON<ButterflyRotation_8, false>(data, tx_width, /*is_row=*/false, + /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct32TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<32>(src, adjusted_tx_height); + } + // Process 8 1d dct32 rows in parallel per iteration. + int i = 0; + do { + Dct32_NEON(&src[i * 32], 32, /*is_row=*/true, row_shift); + i += 8; + } while (i < adjusted_tx_height); +} + +void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) { + // Process 8 1d dct32 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct32_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct64TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<64>(src, adjusted_tx_height); + } + // Process 8 1d dct64 rows in parallel per iteration. + int i = 0; + do { + Dct64_NEON(&src[i * 64], 64, /*is_row=*/true, row_shift); + i += 8; + } while (i < adjusted_tx_height); +} + +void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) { + // Process 8 1d dct64 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct64_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Adst4TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const int row_shift = static_cast<int>(tx_height == 16); + const bool should_round = (tx_height == 8); + + if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + // Process 4 1d adst4 rows in parallel per iteration. + int i = adjusted_tx_height; + auto* data = src; + do { + Adst4_NEON<false>(data, /*step=*/4, /*transpose=*/true); + data += 16; + i -= 4; + } while (i != 0); + + if (tx_height == 16) { + RowShift<4>(src, adjusted_tx_height, 1); + } +} + +void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + // Process 4 1d adst4 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Adst4_NEON<false>(data, tx_width, /*transpose=*/false); + data += 4; + i -= 4; + } while (i != 0); + } + + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Adst8TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d adst8 rows in parallel. + Adst8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true); + } else { + // Process 8 1d adst8 rows in parallel per iteration. + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + Adst8_NEON<ButterflyRotation_8, false>(data, /*step=*/8, + /*transpose=*/true); + data += 64; + i -= 8; + } while (i != 0); + } + if (row_shift > 0) { + RowShift<8>(src, adjusted_tx_height, row_shift); + } +} + +void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d adst8 columns in parallel. + Adst8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + // Process 8 1d adst8 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Adst8_NEON<ButterflyRotation_8, false>(data, tx_width, + /*transpose=*/false); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Adst16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + if (adjusted_tx_height == 4) { + // Process 4 1d adst16 rows in parallel. + Adst16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift); + } else { + assert(adjusted_tx_height % 8 == 0); + int i = adjusted_tx_height; + do { + // Process 8 1d adst16 rows in parallel per iteration. + Adst16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true, + row_shift); + src += 128; + i -= 8; + } while (i != 0); + } +} + +void Adst16TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d adst16 columns in parallel. + Adst16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false, + /*row_shift=*/0); + } else { + int i = tx_width; + auto* data = src; + do { + // Process 8 1d adst16 columns in parallel per iteration. + Adst16_NEON<ButterflyRotation_8, false>( + data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 8; + i -= 8; + } while (i != 0); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Identity4TransformLoopRow_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. + if (tx_type == kTransformTypeIdentityIdentity && + tx_size == kTransformSize4x4) { + return; + } + + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = (tx_height == 8); + + if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + if (tx_height < 16) { + int i = adjusted_tx_height; + do { + Identity4_NEON<false>(src, /*step=*/4); + src += 16; + i -= 4; + } while (i != 0); + } else { + int i = adjusted_tx_height; + do { + Identity4_NEON<true>(src, /*step=*/4); + src += 16; + i -= 4; + } while (i != 0); + } +} + +void Identity4TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + // Special case: Process row calculations during column transform call. + if (tx_type == kTransformTypeIdentityIdentity && + (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) { + Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); + return; + } + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity8TransformLoopRow_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. + if (tx_type == kTransformTypeIdentityIdentity && + tx_size == kTransformSize8x4) { + return; + } + + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + // When combining the identity8 multiplier with the row shift, the + // calculations for tx_height == 8 and tx_height == 16 can be simplified + // from ((A * 2) + 1) >> 1) to A. + if ((tx_height & 0x18) != 0) { + return; + } + if (tx_height == 32) { + int i = adjusted_tx_height; + do { + Identity8Row32_NEON(src, /*step=*/8); + src += 32; + i -= 4; + } while (i != 0); + return; + } + + assert(tx_size == kTransformSize8x4); + int i = adjusted_tx_height; + do { + Identity8Row4_NEON(src, /*step=*/8); + src += 32; + i -= 4; + } while (i != 0); +} + +void Identity8TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + int i = adjusted_tx_height; + do { + Identity16Row_NEON(src, /*step=*/16, kTransformRowShift[tx_size]); + src += 64; + i -= 4; + } while (i != 0); +} + +void Identity16TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity32TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + const int tx_height = kTransformHeight[tx_size]; + + // When combining the identity32 multiplier with the row shift, the + // calculations for tx_height == 8 and tx_height == 32 can be simplified + // from ((A * 4) + 2) >> 2) to A. + if ((tx_height & 0x28) != 0) { + return; + } + + // Process kTransformSize32x16. The src is always rounded before the + // identity transform and shifted by 1 afterwards. + auto* src = static_cast<int16_t*>(src_buffer); + if (Identity32DcOnly(src, adjusted_tx_height)) { + return; + } + + assert(tx_size == kTransformSize32x16); + ApplyRounding<32>(src, adjusted_tx_height); + int i = adjusted_tx_height; + do { + Identity32Row16_NEON(src, /*step=*/32); + src += 128; + i -= 4; + } while (i != 0); +} + +void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Wht4TransformLoopRow_NEON(TransformType tx_type, TransformSize tx_size, + int /*adjusted_tx_height*/, void* /*src_buffer*/, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + assert(tx_type == kTransformTypeDctDct); + assert(tx_size == kTransformSize4x4); + static_cast<void>(tx_type); + static_cast<void>(tx_size); + // Do both row and column transforms in the column-transform pass. +} + +void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + assert(tx_type == kTransformTypeDctDct); + assert(tx_size == kTransformSize4x4); + static_cast<void>(tx_type); + static_cast<void>(tx_size); + + // Process 4 1d wht4 rows and columns in parallel. + const auto* src = static_cast<int16_t*>(src_buffer); + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + uint8_t* dst = frame[start_y] + start_x; + const int dst_stride = frame.columns(); + Wht4_NEON(dst, dst_stride, src, adjusted_tx_height); +} + +//------------------------------------------------------------------------------ + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + // Maximum transform size for Dct is 64. + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + Dct4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + Dct4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + Dct8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + Dct8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + Dct16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + Dct16TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + Dct32TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + Dct32TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + Dct64TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + Dct64TransformLoopColumn_NEON; + + // Maximum transform size for Adst is 16. + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + Adst4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + Adst4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + Adst8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + Adst8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + Adst16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + Adst16TransformLoopColumn_NEON; + + // Maximum transform size for Identity transform is 32. + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + Identity4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + Identity4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + Identity8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + Identity8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + Identity16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + Identity16TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = + Identity32TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = + Identity32TransformLoopColumn_NEON; + + // Maximum transform size for Wht is 4. + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = + Wht4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = + Wht4TransformLoopColumn_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void InverseTransformInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void InverseTransformInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/inverse_transform_neon.h b/src/dsp/arm/inverse_transform_neon.h new file mode 100644 index 0000000..af647e8 --- /dev/null +++ b/src/dsp/arm/inverse_transform_neon.h @@ -0,0 +1,52 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::inverse_transforms, see the defines below for specifics. +// This function is not thread-safe. +void InverseTransformInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_ diff --git a/src/dsp/arm/loop_filter_neon.cc b/src/dsp/arm/loop_filter_neon.cc new file mode 100644 index 0000000..146c983 --- /dev/null +++ b/src/dsp/arm/loop_filter_neon.cc @@ -0,0 +1,1190 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_filter.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// (abs(p1 - p0) > thresh) || (abs(q1 - q0) > thresh) +inline uint8x8_t Hev(const uint8x8_t abd_p0p1_q0q1, const uint8_t thresh) { + const uint8x8_t a = vcgt_u8(abd_p0p1_q0q1, vdup_n_u8(thresh)); + return vorr_u8(a, RightShift<32>(a)); +} + +// abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh +inline uint8x8_t OuterThreshold(const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t outer_thresh) { + const uint8x8x2_t a = Interleave32(p0q0, p1q1); + const uint8x8_t b = vabd_u8(a.val[0], a.val[1]); + const uint8x8_t p0q0_double = vqadd_u8(b, b); + const uint8x8_t p1q1_half = RightShift<32>(vshr_n_u8(b, 1)); + const uint8x8_t c = vqadd_u8(p0q0_double, p1q1_half); + return vcle_u8(c, vdup_n_u8(outer_thresh)); +} + +// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh && +// OuterThreshhold() +inline uint8x8_t NeedsFilter4(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t inner_thresh, + const uint8_t outer_thresh) { + const uint8x8_t a = vcle_u8(abd_p0p1_q0q1, vdup_n_u8(inner_thresh)); + const uint8x8_t inner_mask = vand_u8(a, RightShift<32>(a)); + const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh); + return vand_u8(inner_mask, outer_mask); +} + +inline void Filter4Masks(const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t hev_thresh, const uint8_t outer_thresh, + const uint8_t inner_thresh, uint8x8_t* const hev_mask, + uint8x8_t* const needs_filter4_mask) { + const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1); + // This includes cases where NeedsFilter4() is not true and so Filter2() will + // not be applied. + const uint8x8_t hev_tmp_mask = Hev(p0p1_q0q1, hev_thresh); + + *needs_filter4_mask = + NeedsFilter4(p0p1_q0q1, p0q0, p1q1, inner_thresh, outer_thresh); + + // Filter2() will only be applied if both NeedsFilter4() and Hev() are true. + *hev_mask = vand_u8(hev_tmp_mask, *needs_filter4_mask); +} + +// Calculate Filter4() or Filter2() based on |hev_mask|. +inline void Filter4(const uint8x8_t q0p1, const uint8x8_t p0q1, + const uint8x8_t hev_mask, uint8x8_t* const p1q1_result, + uint8x8_t* const p0q0_result) { + const int16x4_t zero = vdup_n_s16(0); + + // a = 3 * (q0 - p0) + Clip3(p1 - q1, min_signed_val, max_signed_val); + const int16x8_t q0mp0_p1mq1 = vreinterpretq_s16_u16(vsubl_u8(q0p1, p0q1)); + const int16x4_t q0mp0_3 = vmul_n_s16(vget_low_s16(q0mp0_p1mq1), 3); + + // If this is for Filter2() then include |p1mq1|. Otherwise zero it. + const int16x4_t p1mq1 = vget_high_s16(q0mp0_p1mq1); + const int8x8_t p1mq1_saturated = vqmovn_s16(vcombine_s16(p1mq1, zero)); + const int8x8_t hev_option = + vand_s8(vreinterpret_s8_u8(hev_mask), p1mq1_saturated); + + const int16x4_t a = + vget_low_s16(vaddw_s8(vcombine_s16(q0mp0_3, zero), hev_option)); + + // We can not shift with rounding because the clamp comes *before* the + // shifting. a1 = Clip3(a + 4, min_signed_val, max_signed_val) >> 3; a2 = + // Clip3(a + 3, min_signed_val, max_signed_val) >> 3; + const int16x4_t plus_four = vadd_s16(a, vdup_n_s16(4)); + const int16x4_t plus_three = vadd_s16(a, vdup_n_s16(3)); + const int8x8_t a2_a1 = + vshr_n_s8(vqmovn_s16(vcombine_s16(plus_three, plus_four)), 3); + + // a3 is in the high 4 values. + // a3 = (a1 + 1) >> 1; + const int8x8_t a3 = vrshr_n_s8(a2_a1, 1); + + const int16x8_t p0q1_l = vreinterpretq_s16_u16(vmovl_u8(p0q1)); + const int16x8_t q0p1_l = vreinterpretq_s16_u16(vmovl_u8(q0p1)); + + const int16x8_t p1q1_l = + vcombine_s16(vget_high_s16(q0p1_l), vget_high_s16(p0q1_l)); + + const int8x8_t a3_ma3 = InterleaveHigh32(a3, vneg_s8(a3)); + const int16x8_t p1q1_a3 = vaddw_s8(p1q1_l, a3_ma3); + + const int16x8_t p0q0_l = + vcombine_s16(vget_low_s16(p0q1_l), vget_low_s16(q0p1_l)); + // Need to shift the second term or we end up with a2_ma2. + const int8x8_t a2_ma1 = + InterleaveLow32(a2_a1, RightShift<32>(vneg_s8(a2_a1))); + const int16x8_t p0q0_a = vaddw_s8(p0q0_l, a2_ma1); + + *p1q1_result = vqmovun_s16(p1q1_a3); + *p0q0_result = vqmovun_s16(p0q0_a); +} + +void Horizontal4_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + uint8_t* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + + uint8x8_t hev_mask; + uint8x8_t needs_filter4_mask; + Filter4Masks(p0q0, p1q1, hev_thresh, outer_thresh, inner_thresh, &hev_mask, + &needs_filter4_mask); + + // Copy the masks to the high bits for packed comparisons later. + hev_mask = InterleaveLow32(hev_mask, hev_mask); + needs_filter4_mask = InterleaveLow32(needs_filter4_mask, needs_filter4_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter4_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + + // Already integrated the Hev mask when calculating the filtered values. + const uint8x8_t p0q0_output = vbsl_u8(needs_filter4_mask, f_p0q0, p0q0); + + // p1/q1 are unmodified if only Hev() is true. This works because it was and'd + // with |needs_filter4_mask| previously. + const uint8x8_t p1q1_mask = veor_u8(hev_mask, needs_filter4_mask); + const uint8x8_t p1q1_output = vbsl_u8(p1q1_mask, f_p1q1, p1q1); + + StoreLo4(dst - 2 * stride, p1q1_output); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); + StoreHi4(dst + stride, p1q1_output); +} + +void Vertical4_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + uint8_t* dst = static_cast<uint8_t*>(dest); + + // Move |dst| to the left side of the filter window. + dst -= 2; + + // |p1q0| and |p0q1| are named for the values they will contain after the + // transpose. + const uint8x8_t row0 = Load4(dst); + uint8x8_t p1q0 = Load4<1>(dst + stride, row0); + const uint8x8_t row2 = Load4(dst + 2 * stride); + uint8x8_t p0q1 = Load4<1>(dst + 3 * stride, row2); + + Transpose4x4(&p1q0, &p0q1); + // Rearrange. + const uint8x8x2_t p1q1xq0p0 = Interleave32(p1q0, Transpose32(p0q1)); + const uint8x8x2_t p1q1xp0q0 = {p1q1xq0p0.val[0], + Transpose32(p1q1xq0p0.val[1])}; + + uint8x8_t hev_mask; + uint8x8_t needs_filter4_mask; + Filter4Masks(p1q1xp0q0.val[1], p1q1xp0q0.val[0], hev_thresh, outer_thresh, + inner_thresh, &hev_mask, &needs_filter4_mask); + + // Copy the masks to the high bits for packed comparisons later. + hev_mask = InterleaveLow32(hev_mask, hev_mask); + needs_filter4_mask = InterleaveLow32(needs_filter4_mask, needs_filter4_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter4_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + Filter4(Transpose32(p1q0), p0q1, hev_mask, &f_p1q1, &f_p0q0); + + // Already integrated the Hev mask when calculating the filtered values. + const uint8x8_t p0q0_output = + vbsl_u8(needs_filter4_mask, f_p0q0, p1q1xp0q0.val[1]); + + // p1/q1 are unmodified if only Hev() is true. This works because it was and'd + // with |needs_filter4_mask| previously. + const uint8x8_t p1q1_mask = veor_u8(hev_mask, needs_filter4_mask); + const uint8x8_t p1q1_output = vbsl_u8(p1q1_mask, f_p1q1, p1q1xp0q0.val[0]); + + // Put things back in order to reverse the transpose. + const uint8x8x2_t p1p0xq1q0 = Interleave32(p1q1_output, p0q0_output); + uint8x8_t output_0 = p1p0xq1q0.val[0], + output_1 = Transpose32(p1p0xq1q0.val[1]); + + Transpose4x4(&output_0, &output_1); + + StoreLo4(dst, output_0); + StoreLo4(dst + stride, output_1); + StoreHi4(dst + 2 * stride, output_0); + StoreHi4(dst + 3 * stride, output_1); +} + +// abs(p1 - p0) <= flat_thresh && abs(q1 - q0) <= flat_thresh && +// abs(p2 - p0) <= flat_thresh && abs(q2 - q0) <= flat_thresh +// |flat_thresh| == 1 for 8 bit decode. +inline uint8x8_t IsFlat3(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t abd_p0p2_q0q2) { + const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p0p2_q0q2); + const uint8x8_t b = vcle_u8(a, vdup_n_u8(1)); + return vand_u8(b, RightShift<32>(b)); +} + +// abs(p2 - p1) <= inner_thresh && abs(p1 - p0) <= inner_thresh && +// abs(q1 - q0) <= inner_thresh && abs(q2 - q1) <= inner_thresh && +// OuterThreshhold() +inline uint8x8_t NeedsFilter6(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t abd_p1p2_q1q2, + const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t inner_thresh, + const uint8_t outer_thresh) { + const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p1p2_q1q2); + const uint8x8_t b = vcle_u8(a, vdup_n_u8(inner_thresh)); + const uint8x8_t inner_mask = vand_u8(b, RightShift<32>(b)); + const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh); + return vand_u8(inner_mask, outer_mask); +} + +inline void Filter6Masks(const uint8x8_t p2q2, const uint8x8_t p1q1, + const uint8x8_t p0q0, const uint8_t hev_thresh, + const uint8_t outer_thresh, const uint8_t inner_thresh, + uint8x8_t* const needs_filter6_mask, + uint8x8_t* const is_flat3_mask, + uint8x8_t* const hev_mask) { + const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1); + *hev_mask = Hev(p0p1_q0q1, hev_thresh); + *is_flat3_mask = IsFlat3(p0p1_q0q1, vabd_u8(p0q0, p2q2)); + *needs_filter6_mask = NeedsFilter6(p0p1_q0q1, vabd_u8(p1q1, p2q2), p0q0, p1q1, + inner_thresh, outer_thresh); +} + +inline void Filter6(const uint8x8_t p2q2, const uint8x8_t p1q1, + const uint8x8_t p0q0, uint8x8_t* const p1q1_output, + uint8x8_t* const p0q0_output) { + // Sum p1 and q1 output from opposite directions + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^^^^^^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^^^^^^^ + const uint16x8_t p2q2_double = vaddl_u8(p2q2, p2q2); + uint16x8_t sum = vaddw_u8(p2q2_double, p2q2); + + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^^^^^^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p1q1, p1q1), sum); + + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^^^^^^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p0q0, p0q0), sum); + + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3) + // ^^ + const uint8x8_t q0p0 = Transpose32(p0q0); + sum = vaddw_u8(sum, q0p0); + + *p1q1_output = vrshrn_n_u16(sum, 3); + + // Convert to p0 and q0 output: + // p0 = p1 - (2 * p2) + q0 + q1 + // q0 = q1 - (2 * q2) + p0 + p1 + sum = vsubq_u16(sum, p2q2_double); + const uint8x8_t q1p1 = Transpose32(p1q1); + sum = vaddq_u16(vaddl_u8(q0p0, q1p1), sum); + + *p0q0_output = vrshrn_n_u16(sum, 3); +} + +void Horizontal6_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p2_v = Load4(dst - 3 * stride); + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v); + + uint8x8_t needs_filter6_mask, is_flat3_mask, hev_mask; + Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter6_mask, &is_flat3_mask, &hev_mask); + + needs_filter6_mask = InterleaveLow32(needs_filter6_mask, needs_filter6_mask); + is_flat3_mask = InterleaveLow32(is_flat3_mask, is_flat3_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter6_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f6_p1q1, f6_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(vand_u8(is_flat3_mask, needs_filter6_mask)) == 0) { + // Filter6() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f6_p1q1 = zero; + f6_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat3_mask, f6_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter6_mask, p1q1_output, p1q1); + StoreLo4(dst - 2 * stride, p1q1_output); + StoreHi4(dst + stride, p1q1_output); + + uint8x8_t p0q0_output = vbsl_u8(is_flat3_mask, f6_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter6_mask, p0q0_output, p0q0); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); +} + +void Vertical6_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + // Move |dst| to the left side of the filter window. + dst -= 3; + + // |p2q1|, |p1q2|, |p0xx| and |q0xx| are named for the values they will + // contain after the transpose. + // These over-read by 2 bytes. We only need 6. + uint8x8_t p2q1 = vld1_u8(dst); + uint8x8_t p1q2 = vld1_u8(dst + stride); + uint8x8_t p0xx = vld1_u8(dst + 2 * stride); + uint8x8_t q0xx = vld1_u8(dst + 3 * stride); + + Transpose8x4(&p2q1, &p1q2, &p0xx, &q0xx); + + const uint8x8x2_t p2q2xq1p1 = Interleave32(p2q1, Transpose32(p1q2)); + const uint8x8_t p2q2 = p2q2xq1p1.val[0]; + const uint8x8_t p1q1 = Transpose32(p2q2xq1p1.val[1]); + const uint8x8_t p0q0 = InterleaveLow32(p0xx, q0xx); + + uint8x8_t needs_filter6_mask, is_flat3_mask, hev_mask; + Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter6_mask, &is_flat3_mask, &hev_mask); + + needs_filter6_mask = InterleaveLow32(needs_filter6_mask, needs_filter6_mask); + is_flat3_mask = InterleaveLow32(is_flat3_mask, is_flat3_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter6_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f6_p1q1, f6_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(vand_u8(is_flat3_mask, needs_filter6_mask)) == 0) { + // Filter6() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f6_p1q1 = zero; + f6_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat3_mask, f6_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter6_mask, p1q1_output, p1q1); + + uint8x8_t p0q0_output = vbsl_u8(is_flat3_mask, f6_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter6_mask, p0q0_output, p0q0); + + // The six tap filter is only six taps on input. Output is limited to p1-q1. + dst += 1; + // Put things back in order to reverse the transpose. + const uint8x8x2_t p1p0xq1q0 = Interleave32(p1q1_output, p0q0_output); + uint8x8_t output_0 = p1p0xq1q0.val[0]; + uint8x8_t output_1 = Transpose32(p1p0xq1q0.val[1]); + + Transpose4x4(&output_0, &output_1); + + StoreLo4(dst, output_0); + StoreLo4(dst + stride, output_1); + StoreHi4(dst + 2 * stride, output_0); + StoreHi4(dst + 3 * stride, output_1); +} + +// IsFlat4 uses N=1, IsFlatOuter4 uses N=4. +// abs(p[N] - p0) <= flat_thresh && abs(q[N] - q0) <= flat_thresh && +// abs(p[N+1] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh && +// abs(p[N+2] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh +// |flat_thresh| == 1 for 8 bit decode. +inline uint8x8_t IsFlat4(const uint8x8_t abd_p0n0_q0n0, + const uint8x8_t abd_p0n1_q0n1, + const uint8x8_t abd_p0n2_q0n2) { + const uint8x8_t a = vmax_u8(abd_p0n0_q0n0, abd_p0n1_q0n1); + const uint8x8_t b = vmax_u8(a, abd_p0n2_q0n2); + const uint8x8_t c = vcle_u8(b, vdup_n_u8(1)); + return vand_u8(c, RightShift<32>(c)); +} + +// abs(p3 - p2) <= inner_thresh && abs(p2 - p1) <= inner_thresh && +// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh && +// abs(q2 - q1) <= inner_thresh && abs(q3 - q2) <= inner_thresh +// OuterThreshhold() +inline uint8x8_t NeedsFilter8(const uint8x8_t abd_p0p1_q0q1, + const uint8x8_t abd_p1p2_q1q2, + const uint8x8_t abd_p2p3_q2q3, + const uint8x8_t p0q0, const uint8x8_t p1q1, + const uint8_t inner_thresh, + const uint8_t outer_thresh) { + const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p1p2_q1q2); + const uint8x8_t b = vmax_u8(a, abd_p2p3_q2q3); + const uint8x8_t c = vcle_u8(b, vdup_n_u8(inner_thresh)); + const uint8x8_t inner_mask = vand_u8(c, RightShift<32>(c)); + const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh); + return vand_u8(inner_mask, outer_mask); +} + +inline void Filter8Masks(const uint8x8_t p3q3, const uint8x8_t p2q2, + const uint8x8_t p1q1, const uint8x8_t p0q0, + const uint8_t hev_thresh, const uint8_t outer_thresh, + const uint8_t inner_thresh, + uint8x8_t* const needs_filter8_mask, + uint8x8_t* const is_flat4_mask, + uint8x8_t* const hev_mask) { + const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1); + *hev_mask = Hev(p0p1_q0q1, hev_thresh); + *is_flat4_mask = IsFlat4(p0p1_q0q1, vabd_u8(p0q0, p2q2), vabd_u8(p0q0, p3q3)); + *needs_filter8_mask = + NeedsFilter8(p0p1_q0q1, vabd_u8(p1q1, p2q2), vabd_u8(p2q2, p3q3), p0q0, + p1q1, inner_thresh, outer_thresh); +} + +inline void Filter8(const uint8x8_t p3q3, const uint8x8_t p2q2, + const uint8x8_t p1q1, const uint8x8_t p0q0, + uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output, + uint8x8_t* const p0q0_output) { + // Sum p2 and q2 output from opposite directions + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^^^^^^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^^^^^^^ + uint16x8_t sum = vaddw_u8(vaddl_u8(p3q3, p3q3), p3q3); + + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^^^^^^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p2q2, p2q2), sum); + + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^^^^^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^^^^^^ + sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum); + + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^ + const uint8x8_t q0p0 = Transpose32(p0q0); + sum = vaddw_u8(sum, q0p0); + + *p2q2_output = vrshrn_n_u16(sum, 3); + + // Convert to p1 and q1 output: + // p1 = p2 - p3 - p2 + p1 + q1 + // q1 = q2 - q3 - q2 + q0 + p1 + sum = vsubq_u16(sum, vaddl_u8(p3q3, p2q2)); + const uint8x8_t q1p1 = Transpose32(p1q1); + sum = vaddq_u16(vaddl_u8(p1q1, q1p1), sum); + + *p1q1_output = vrshrn_n_u16(sum, 3); + + // Convert to p0 and q0 output: + // p0 = p1 - p3 - p1 + p0 + q2 + // q0 = q1 - q3 - q1 + q0 + p2 + sum = vsubq_u16(sum, vaddl_u8(p3q3, p1q1)); + const uint8x8_t q2p2 = Transpose32(p2q2); + sum = vaddq_u16(vaddl_u8(p0q0, q2p2), sum); + + *p0q0_output = vrshrn_n_u16(sum, 3); +} + +void Horizontal8_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p3_v = Load4(dst - 4 * stride); + const uint8x8_t p2_v = Load4(dst - 3 * stride); + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v); + const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f8_p2q2 = zero; + f8_p1q1 = zero; + f8_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + + const uint8x8_t p2p2_output = vbsl_u8(is_flat4_mask, f8_p2q2, p2q2); + StoreLo4(dst - 3 * stride, p2p2_output); + StoreHi4(dst + 2 * stride, p2p2_output); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat4_mask, f8_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + StoreLo4(dst - 2 * stride, p1q1_output); + StoreHi4(dst + stride, p1q1_output); + + uint8x8_t p0q0_output = vbsl_u8(is_flat4_mask, f8_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); +} + +void Vertical8_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + // Move |dst| to the left side of the filter window. + dst -= 4; + + // |p3q0|, |p2q1|, |p1q2| and |p0q3| are named for the values they will + // contain after the transpose. + uint8x8_t p3q0 = vld1_u8(dst); + uint8x8_t p2q1 = vld1_u8(dst + stride); + uint8x8_t p1q2 = vld1_u8(dst + 2 * stride); + uint8x8_t p0q3 = vld1_u8(dst + 3 * stride); + + Transpose8x4(&p3q0, &p2q1, &p1q2, &p0q3); + const uint8x8x2_t p3q3xq0p0 = Interleave32(p3q0, Transpose32(p0q3)); + const uint8x8_t p3q3 = p3q3xq0p0.val[0]; + const uint8x8_t p0q0 = Transpose32(p3q3xq0p0.val[1]); + const uint8x8x2_t p2q2xq1p1 = Interleave32(p2q1, Transpose32(p1q2)); + const uint8x8_t p2q2 = p2q2xq1p1.val[0]; + const uint8x8_t p1q1 = Transpose32(p2q2xq1p1.val[1]); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f8_p2q2 = zero; + f8_p1q1 = zero; + f8_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + // Always prepare and store p2/q2 because we need to transpose it anyway. + const uint8x8_t p2q2_output = vbsl_u8(is_flat4_mask, f8_p2q2, p2q2); + + uint8x8_t p1q1_output = vbsl_u8(is_flat4_mask, f8_p1q1, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + + uint8x8_t p0q0_output = vbsl_u8(is_flat4_mask, f8_p0q0, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + + // Write out p3/q3 as well. There isn't a good way to write out 6 bytes. + // Variable names reflect the values before transposition. + const uint8x8x2_t p3q0xq3p0_output = + Interleave32(p3q3, Transpose32(p0q0_output)); + uint8x8_t p3q0_output = p3q0xq3p0_output.val[0]; + uint8x8_t p0q3_output = Transpose32(p3q0xq3p0_output.val[1]); + const uint8x8x2_t p2q1xq2p1_output = + Interleave32(p2q2_output, Transpose32(p1q1_output)); + uint8x8_t p2q1_output = p2q1xq2p1_output.val[0]; + uint8x8_t p1q2_output = Transpose32(p2q1xq2p1_output.val[1]); + + Transpose8x4(&p3q0_output, &p2q1_output, &p1q2_output, &p0q3_output); + + vst1_u8(dst, p3q0_output); + vst1_u8(dst + stride, p2q1_output); + vst1_u8(dst + 2 * stride, p1q2_output); + vst1_u8(dst + 3 * stride, p0q3_output); +} + +inline void Filter14(const uint8x8_t p6q6, const uint8x8_t p5q5, + const uint8x8_t p4q4, const uint8x8_t p3q3, + const uint8x8_t p2q2, const uint8x8_t p1q1, + const uint8x8_t p0q0, uint8x8_t* const p5q5_output, + uint8x8_t* const p4q4_output, uint8x8_t* const p3q3_output, + uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output, + uint8x8_t* const p0q0_output) { + // Sum p5 and q5 output from opposite directions + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^ + uint16x8_t sum = vsubw_u8(vshll_n_u8(p6q6, 3), p6q6); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p5q5, p5q5), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^ + sum = vaddq_u16(vaddl_u8(p4q4, p4q4), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^ + sum = vaddq_u16(vaddl_u8(p3q3, p2q2), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^ + sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^ + const uint8x8_t q0p0 = Transpose32(p0q0); + sum = vaddw_u8(sum, q0p0); + + *p5q5_output = vrshrn_n_u16(sum, 4); + + // Convert to p4 and q4 output: + // p4 = p5 - (2 * p6) + p3 + q1 + // q4 = q5 - (2 * q6) + q3 + p1 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p6q6)); + const uint8x8_t q1p1 = Transpose32(p1q1); + sum = vaddq_u16(vaddl_u8(p3q3, q1p1), sum); + + *p4q4_output = vrshrn_n_u16(sum, 4); + + // Convert to p3 and q3 output: + // p3 = p4 - p6 - p5 + p2 + q2 + // q3 = q4 - q6 - q5 + q2 + p2 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p5q5)); + const uint8x8_t q2p2 = Transpose32(p2q2); + sum = vaddq_u16(vaddl_u8(p2q2, q2p2), sum); + + *p3q3_output = vrshrn_n_u16(sum, 4); + + // Convert to p2 and q2 output: + // p2 = p3 - p6 - p4 + p1 + q3 + // q2 = q3 - q6 - q4 + q1 + p3 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p4q4)); + const uint8x8_t q3p3 = Transpose32(p3q3); + sum = vaddq_u16(vaddl_u8(p1q1, q3p3), sum); + + *p2q2_output = vrshrn_n_u16(sum, 4); + + // Convert to p1 and q1 output: + // p1 = p2 - p6 - p3 + p0 + q4 + // q1 = q2 - q6 - q3 + q0 + p4 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p3q3)); + const uint8x8_t q4p4 = Transpose32(p4q4); + sum = vaddq_u16(vaddl_u8(p0q0, q4p4), sum); + + *p1q1_output = vrshrn_n_u16(sum, 4); + + // Convert to p0 and q0 output: + // p0 = p1 - p6 - p2 + q0 + q5 + // q0 = q1 - q6 - q2 + p0 + p5 + sum = vsubq_u16(sum, vaddl_u8(p6q6, p2q2)); + const uint8x8_t q5p5 = Transpose32(p5q5); + sum = vaddq_u16(vaddl_u8(q0p0, q5p5), sum); + + *p0q0_output = vrshrn_n_u16(sum, 4); +} + +void Horizontal14_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + + const uint8x8_t p6_v = Load4(dst - 7 * stride); + const uint8x8_t p5_v = Load4(dst - 6 * stride); + const uint8x8_t p4_v = Load4(dst - 5 * stride); + const uint8x8_t p3_v = Load4(dst - 4 * stride); + const uint8x8_t p2_v = Load4(dst - 3 * stride); + const uint8x8_t p1_v = Load4(dst - 2 * stride); + const uint8x8_t p0_v = Load4(dst - stride); + const uint8x8_t p0q0 = Load4<1>(dst, p0_v); + const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v); + const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v); + const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v); + const uint8x8_t p4q4 = Load4<1>(dst + 4 * stride, p4_v); + const uint8x8_t p5q5 = Load4<1>(dst + 5 * stride, p5_v); + const uint8x8_t p6q6 = Load4<1>(dst + 6 * stride, p6_v); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + // Decide between Filter8() and Filter14(). + uint8x8_t is_flat_outer4_mask = + IsFlat4(vabd_u8(p0q0, p4q4), vabd_u8(p0q0, p5q5), vabd_u8(p0q0, p6q6)); + is_flat_outer4_mask = vand_u8(is_flat4_mask, is_flat_outer4_mask); + is_flat_outer4_mask = + InterleaveLow32(is_flat_outer4_mask, is_flat_outer4_mask); + + uint8x8_t f_p1q1; + uint8x8_t f_p0q0; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t f8_p1q1, f8_p0q0; + uint8x8_t f14_p2q2, f14_p1q1, f14_p0q0; +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() and Filter14() do not apply. + const uint8x8_t zero = vdup_n_u8(0); + f8_p1q1 = zero; + f8_p0q0 = zero; + f14_p1q1 = zero; + f14_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + uint8x8_t f8_p2q2; + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + +#if defined(__aarch64__) + if (vaddv_u8(is_flat_outer4_mask) == 0) { + // Filter14() does not apply. + const uint8x8_t zero = vdup_n_u8(0); + f14_p2q2 = zero; + f14_p1q1 = zero; + f14_p0q0 = zero; + } else { +#endif // defined(__aarch64__) + uint8x8_t f14_p5q5, f14_p4q4, f14_p3q3; + Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4, + &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0); + + const uint8x8_t p5q5_output = + vbsl_u8(is_flat_outer4_mask, f14_p5q5, p5q5); + StoreLo4(dst - 6 * stride, p5q5_output); + StoreHi4(dst + 5 * stride, p5q5_output); + + const uint8x8_t p4q4_output = + vbsl_u8(is_flat_outer4_mask, f14_p4q4, p4q4); + StoreLo4(dst - 5 * stride, p4q4_output); + StoreHi4(dst + 4 * stride, p4q4_output); + + const uint8x8_t p3q3_output = + vbsl_u8(is_flat_outer4_mask, f14_p3q3, p3q3); + StoreLo4(dst - 4 * stride, p3q3_output); + StoreHi4(dst + 3 * stride, p3q3_output); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p2q2_output = vbsl_u8(is_flat_outer4_mask, f14_p2q2, f8_p2q2); + p2q2_output = vbsl_u8(is_flat4_mask, p2q2_output, p2q2); + StoreLo4(dst - 3 * stride, p2q2_output); + StoreHi4(dst + 2 * stride, p2q2_output); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + uint8x8_t p1q1_output = vbsl_u8(is_flat_outer4_mask, f14_p1q1, f8_p1q1); + p1q1_output = vbsl_u8(is_flat4_mask, p1q1_output, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + StoreLo4(dst - 2 * stride, p1q1_output); + StoreHi4(dst + stride, p1q1_output); + + uint8x8_t p0q0_output = vbsl_u8(is_flat_outer4_mask, f14_p0q0, f8_p0q0); + p0q0_output = vbsl_u8(is_flat4_mask, p0q0_output, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + StoreLo4(dst - stride, p0q0_output); + StoreHi4(dst, p0q0_output); +} + +void Vertical14_NEON(void* const dest, const ptrdiff_t stride, + const int outer_thresh, const int inner_thresh, + const int hev_thresh) { + auto* dst = static_cast<uint8_t*>(dest); + dst -= 8; + // input + // p7 p6 p5 p4 p3 p2 p1 p0 q0 q1 q2 q3 q4 q5 q6 q7 + const uint8x16_t x0 = vld1q_u8(dst); + dst += stride; + const uint8x16_t x1 = vld1q_u8(dst); + dst += stride; + const uint8x16_t x2 = vld1q_u8(dst); + dst += stride; + const uint8x16_t x3 = vld1q_u8(dst); + dst -= (stride * 3); + + // re-order input +#if defined(__aarch64__) + const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607); + const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203); + const uint8x16_t index_qp7toqp0 = vcombine_u8(index_qp3toqp0, index_qp7toqp4); + + uint8x16_t input_0 = vqtbl1q_u8(x0, index_qp7toqp0); + uint8x16_t input_1 = vqtbl1q_u8(x1, index_qp7toqp0); + uint8x16_t input_2 = vqtbl1q_u8(x2, index_qp7toqp0); + uint8x16_t input_3 = vqtbl1q_u8(x3, index_qp7toqp0); +#else + const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607); + const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203); + + const uint8x8_t x0_qp3qp0 = VQTbl1U8(x0, index_qp3toqp0); + const uint8x8_t x1_qp3qp0 = VQTbl1U8(x1, index_qp3toqp0); + const uint8x8_t x2_qp3qp0 = VQTbl1U8(x2, index_qp3toqp0); + const uint8x8_t x3_qp3qp0 = VQTbl1U8(x3, index_qp3toqp0); + + const uint8x8_t x0_qp7qp4 = VQTbl1U8(x0, index_qp7toqp4); + const uint8x8_t x1_qp7qp4 = VQTbl1U8(x1, index_qp7toqp4); + const uint8x8_t x2_qp7qp4 = VQTbl1U8(x2, index_qp7toqp4); + const uint8x8_t x3_qp7qp4 = VQTbl1U8(x3, index_qp7toqp4); + + const uint8x16_t input_0 = vcombine_u8(x0_qp3qp0, x0_qp7qp4); + const uint8x16_t input_1 = vcombine_u8(x1_qp3qp0, x1_qp7qp4); + const uint8x16_t input_2 = vcombine_u8(x2_qp3qp0, x2_qp7qp4); + const uint8x16_t input_3 = vcombine_u8(x3_qp3qp0, x3_qp7qp4); +#endif + // input after re-order + // p0 p1 p2 p3 q0 q1 q2 q3 p4 p5 p6 p7 q4 q5 q6 q7 + + const uint8x16x2_t in01 = vtrnq_u8(input_0, input_1); + const uint8x16x2_t in23 = vtrnq_u8(input_2, input_3); + const uint16x8x2_t in02 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[0]), + vreinterpretq_u16_u8(in23.val[0])); + const uint16x8x2_t in13 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[1]), + vreinterpretq_u16_u8(in23.val[1])); + + const uint8x8_t p0q0 = vget_low_u8(vreinterpretq_u8_u16(in02.val[0])); + const uint8x8_t p1q1 = vget_low_u8(vreinterpretq_u8_u16(in13.val[0])); + + const uint8x8_t p2q2 = vget_low_u8(vreinterpretq_u8_u16(in02.val[1])); + const uint8x8_t p3q3 = vget_low_u8(vreinterpretq_u8_u16(in13.val[1])); + + const uint8x8_t p4q4 = vget_high_u8(vreinterpretq_u8_u16(in02.val[0])); + const uint8x8_t p5q5 = vget_high_u8(vreinterpretq_u8_u16(in13.val[0])); + + const uint8x8_t p6q6 = vget_high_u8(vreinterpretq_u8_u16(in02.val[1])); + const uint8x8_t p7q7 = vget_high_u8(vreinterpretq_u8_u16(in13.val[1])); + + uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask; + Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh, + &needs_filter8_mask, &is_flat4_mask, &hev_mask); + + needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask); + is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask); + is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask); + hev_mask = InterleaveLow32(hev_mask, hev_mask); + +#if defined(__aarch64__) + // This provides a good speedup for the unit test. Not sure how applicable it + // is to valid streams though. + // Consider doing this on armv7 if there is a quick way to check if a vector + // is zero. + if (vaddv_u8(needs_filter8_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // defined(__aarch64__) + + // Decide between Filter8() and Filter14(). + uint8x8_t is_flat_outer4_mask = + IsFlat4(vabd_u8(p0q0, p4q4), vabd_u8(p0q0, p5q5), vabd_u8(p0q0, p6q6)); + is_flat_outer4_mask = vand_u8(is_flat4_mask, is_flat_outer4_mask); + is_flat_outer4_mask = + InterleaveLow32(is_flat_outer4_mask, is_flat_outer4_mask); + + uint8x8_t f_p0q0, f_p1q1; + const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1); + Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0); + // Reset the outer values if only a Hev() mask was required. + f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1); + + uint8x8_t p1q1_output, p0q0_output; + uint8x8_t p5q5_output, p4q4_output, p3q3_output, p2q2_output; + +#if defined(__aarch64__) + if (vaddv_u8(is_flat4_mask) == 0) { + // Filter8() and Filter14() do not apply. + p1q1_output = p1q1; + p0q0_output = p0q0; + + p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = p2q2; + } else { +#endif // defined(__aarch64__) + uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0; + Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + +#if defined(__aarch64__) + if (vaddv_u8(is_flat_outer4_mask) == 0) { + // Filter14() does not apply. + p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = f8_p2q2; + p1q1_output = f8_p1q1; + p0q0_output = f8_p0q0; + } else { +#endif // defined(__aarch64__) + uint8x8_t f14_p5q5, f14_p4q4, f14_p3q3, f14_p2q2, f14_p1q1, f14_p0q0; + Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4, + &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0); + + p5q5_output = vbsl_u8(is_flat_outer4_mask, f14_p5q5, p5q5); + p4q4_output = vbsl_u8(is_flat_outer4_mask, f14_p4q4, p4q4); + p3q3_output = vbsl_u8(is_flat_outer4_mask, f14_p3q3, p3q3); + p2q2_output = vbsl_u8(is_flat_outer4_mask, f14_p2q2, f8_p2q2); + p1q1_output = vbsl_u8(is_flat_outer4_mask, f14_p1q1, f8_p1q1); + p0q0_output = vbsl_u8(is_flat_outer4_mask, f14_p0q0, f8_p0q0); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + p2q2_output = vbsl_u8(is_flat4_mask, p2q2_output, p2q2); +#if defined(__aarch64__) + } +#endif // defined(__aarch64__) + + p1q1_output = vbsl_u8(is_flat4_mask, p1q1_output, f_p1q1); + p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1); + p0q0_output = vbsl_u8(is_flat4_mask, p0q0_output, f_p0q0); + p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0); + + const uint8x16_t p0q0_p4q4 = vcombine_u8(p0q0_output, p4q4_output); + const uint8x16_t p2q2_p6q6 = vcombine_u8(p2q2_output, p6q6); + const uint8x16_t p1q1_p5q5 = vcombine_u8(p1q1_output, p5q5_output); + const uint8x16_t p3q3_p7q7 = vcombine_u8(p3q3_output, p7q7); + + const uint16x8x2_t out02 = vtrnq_u16(vreinterpretq_u16_u8(p0q0_p4q4), + vreinterpretq_u16_u8(p2q2_p6q6)); + const uint16x8x2_t out13 = vtrnq_u16(vreinterpretq_u16_u8(p1q1_p5q5), + vreinterpretq_u16_u8(p3q3_p7q7)); + const uint8x16x2_t out01 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[0]), + vreinterpretq_u8_u16(out13.val[0])); + const uint8x16x2_t out23 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[1]), + vreinterpretq_u8_u16(out13.val[1])); + +#if defined(__aarch64__) + const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b); + const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504); + const uint8x16_t index_p7toq7 = vcombine_u8(index_p7top0, index_q7toq0); + + const uint8x16_t output_0 = vqtbl1q_u8(out01.val[0], index_p7toq7); + const uint8x16_t output_1 = vqtbl1q_u8(out01.val[1], index_p7toq7); + const uint8x16_t output_2 = vqtbl1q_u8(out23.val[0], index_p7toq7); + const uint8x16_t output_3 = vqtbl1q_u8(out23.val[1], index_p7toq7); +#else + const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b); + const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504); + + const uint8x8_t x0_p7p0 = VQTbl1U8(out01.val[0], index_p7top0); + const uint8x8_t x1_p7p0 = VQTbl1U8(out01.val[1], index_p7top0); + const uint8x8_t x2_p7p0 = VQTbl1U8(out23.val[0], index_p7top0); + const uint8x8_t x3_p7p0 = VQTbl1U8(out23.val[1], index_p7top0); + + const uint8x8_t x0_q7q0 = VQTbl1U8(out01.val[0], index_q7toq0); + const uint8x8_t x1_q7q0 = VQTbl1U8(out01.val[1], index_q7toq0); + const uint8x8_t x2_q7q0 = VQTbl1U8(out23.val[0], index_q7toq0); + const uint8x8_t x3_q7q0 = VQTbl1U8(out23.val[1], index_q7toq0); + + const uint8x16_t output_0 = vcombine_u8(x0_p7p0, x0_q7q0); + const uint8x16_t output_1 = vcombine_u8(x1_p7p0, x1_q7q0); + const uint8x16_t output_2 = vcombine_u8(x2_p7p0, x2_q7q0); + const uint8x16_t output_3 = vcombine_u8(x3_p7p0, x3_q7q0); +#endif + + vst1q_u8(dst, output_0); + dst += stride; + vst1q_u8(dst, output_1); + dst += stride; + vst1q_u8(dst, output_2); + dst += stride; + vst1q_u8(dst, output_3); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = + Horizontal4_NEON; + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = Vertical4_NEON; + + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = + Horizontal6_NEON; + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = Vertical6_NEON; + + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = + Horizontal8_NEON; + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = Vertical8_NEON; + + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Horizontal14_NEON; + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = + Vertical14_NEON; +} +} // namespace +} // namespace low_bitdepth + +void LoopFilterInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void LoopFilterInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/loop_filter_neon.h b/src/dsp/arm/loop_filter_neon.h new file mode 100644 index 0000000..5f79200 --- /dev/null +++ b/src/dsp/arm/loop_filter_neon.h @@ -0,0 +1,53 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_filters, see the defines below for specifics. This +// function is not thread-safe. +void LoopFilterInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_ diff --git a/src/dsp/arm/loop_restoration_neon.cc b/src/dsp/arm/loop_restoration_neon.cc new file mode 100644 index 0000000..337c9b4 --- /dev/null +++ b/src/dsp/arm/loop_restoration_neon.cc @@ -0,0 +1,1901 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_restoration.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +template <int bytes> +inline uint8x8_t VshrU128(const uint8x8x2_t src) { + return vext_u8(src.val[0], src.val[1], bytes); +} + +template <int bytes> +inline uint16x8_t VshrU128(const uint16x8x2_t src) { + return vextq_u16(src.val[0], src.val[1], bytes / 2); +} + +// Wiener + +// Must make a local copy of coefficients to help compiler know that they have +// no overlap with other buffers. Using 'const' keyword is not enough. Actually +// compiler doesn't make a copy, since there is enough registers in this case. +inline void PopulateWienerCoefficients( + const RestorationUnitInfo& restoration_info, const int direction, + int16_t filter[4]) { + // In order to keep the horizontal pass intermediate values within 16 bits we + // offset |filter[3]| by 128. The 128 offset will be added back in the loop. + for (int i = 0; i < 4; ++i) { + filter[i] = restoration_info.wiener_info.filter[direction][i]; + } + if (direction == WienerInfo::kHorizontal) { + filter[3] -= 128; + } +} + +inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1, + const int16_t filter, const int16x8_t sum) { + const int16x8_t ss = vreinterpretq_s16_u16(vaddl_u8(s0, s1)); + return vmlaq_n_s16(sum, ss, filter); +} + +inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1, + const int16_t filter, + const int16x8x2_t sum) { + int16x8x2_t d; + d.val[0] = + WienerHorizontal2(vget_low_u8(s0), vget_low_u8(s1), filter, sum.val[0]); + d.val[1] = + WienerHorizontal2(vget_high_u8(s0), vget_high_u8(s1), filter, sum.val[1]); + return d; +} + +inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4], + int16x8_t sum, int16_t* const wiener_buffer) { + constexpr int offset = + 1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1); + constexpr int limit = (offset << 2) - 1; + const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2])); + const int16x8_t s_1 = ZeroExtend(s[1]); + sum = vmlaq_n_s16(sum, s_0_2, filter[2]); + sum = vmlaq_n_s16(sum, s_1, filter[3]); + // Calculate scaled down offset correction, and add to sum here to prevent + // signed 16 bit outranging. + sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum, + kInterRoundBitsHorizontal); + sum = vmaxq_s16(sum, vdupq_n_s16(-offset)); + sum = vminq_s16(sum, vdupq_n_s16(limit - offset)); + vst1q_s16(wiener_buffer, sum); +} + +inline void WienerHorizontalSum(const uint8x16_t src[3], + const int16_t filter[4], int16x8x2_t sum, + int16_t* const wiener_buffer) { + uint8x8_t s[3]; + s[0] = vget_low_u8(src[0]); + s[1] = vget_low_u8(src[1]); + s[2] = vget_low_u8(src[2]); + WienerHorizontalSum(s, filter, sum.val[0], wiener_buffer); + s[0] = vget_high_u8(src[0]); + s[1] = vget_high_u8(src[1]); + s[2] = vget_high_u8(src[2]); + WienerHorizontalSum(s, filter, sum.val[1], wiener_buffer + 8); +} + +inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int16_t filter[4], + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + uint8x16_t s[8]; + s[0] = vld1q_u8(src_ptr); + ptrdiff_t x = width; + do { + src_ptr += 16; + s[7] = vld1q_u8(src_ptr); + s[1] = vextq_u8(s[0], s[7], 1); + s[2] = vextq_u8(s[0], s[7], 2); + s[3] = vextq_u8(s[0], s[7], 3); + s[4] = vextq_u8(s[0], s[7], 4); + s[5] = vextq_u8(s[0], s[7], 5); + s[6] = vextq_u8(s[0], s[7], 6); + int16x8x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s16(0); + sum = WienerHorizontal2(s[0], s[6], filter[0], sum); + sum = WienerHorizontal2(s[1], s[5], filter[1], sum); + WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer); + s[0] = s[7]; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int16_t filter[4], + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + uint8x16_t s[6]; + s[0] = vld1q_u8(src_ptr); + ptrdiff_t x = width; + do { + src_ptr += 16; + s[5] = vld1q_u8(src_ptr); + s[1] = vextq_u8(s[0], s[5], 1); + s[2] = vextq_u8(s[0], s[5], 2); + s[3] = vextq_u8(s[0], s[5], 3); + s[4] = vextq_u8(s[0], s[5], 4); + int16x8x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s16(0); + sum = WienerHorizontal2(s[0], s[4], filter[1], sum); + WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer); + s[0] = s[5]; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int16_t filter[4], + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + uint8x16_t s[4]; + s[0] = vld1q_u8(src_ptr); + ptrdiff_t x = width; + do { + src_ptr += 16; + s[3] = vld1q_u8(src_ptr); + s[1] = vextq_u8(s[0], s[3], 1); + s[2] = vextq_u8(s[0], s[3], 2); + int16x8x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s16(0); + WienerHorizontalSum(s, filter, sum, *wiener_buffer); + s[0] = s[3]; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + const uint8_t* src_ptr = src; + ptrdiff_t x = width; + do { + const uint8x16_t s = vld1q_u8(src_ptr); + const uint8x8_t s0 = vget_low_u8(s); + const uint8x8_t s1 = vget_high_u8(s); + const int16x8_t d0 = vreinterpretq_s16_u16(vshll_n_u8(s0, 4)); + const int16x8_t d1 = vreinterpretq_s16_u16(vshll_n_u8(s1, 4)); + vst1q_s16(*wiener_buffer + 0, d0); + vst1q_s16(*wiener_buffer + 8, d1); + src_ptr += 16; + *wiener_buffer += 16; + x -= 16; + } while (x != 0); + src += src_stride; + } +} + +inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1, + const int16_t filter, + const int32x4x2_t sum) { + const int16x8_t a = vaddq_s16(a0, a1); + int32x4x2_t d; + d.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(a), filter); + d.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(a), filter); + return d; +} + +inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4], + const int32x4x2_t sum) { + int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum); + d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]); + d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]); + const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11); + const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11); + return vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16)); +} + +inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4], + int16x8_t a[7]) { + int32x4x2_t sum; + a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride); + a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride); + a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride); + a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[0], a[6], filter[0], sum); + sum = WienerVertical2(a[1], a[5], filter[1], sum); + a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride); + a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride); + a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride); + return WienerVertical(a + 2, filter, sum); +} + +inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4]) { + int16x8_t a[8]; + int32x4x2_t sum; + uint8x8x2_t d; + d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a); + a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[1], a[7], filter[0], sum); + sum = WienerVertical2(a[2], a[6], filter[1], sum); + d.val[1] = WienerVertical(a + 3, filter, sum); + return d; +} + +inline void WienerVerticalTap7(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t filter[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + uint8x8x2_t d[2]; + d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter); + d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter); + vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0])); + vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1])); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + int16x8_t a[7]; + const uint8x8_t d0 = + WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a); + const uint8x8_t d1 = + WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a); + vst1q_u8(dst, vcombine_u8(d0, d1)); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4], + int16x8_t a[5]) { + a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride); + a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride); + a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride); + a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride); + a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride); + int32x4x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[0], a[4], filter[1], sum); + return WienerVertical(a + 1, filter, sum); +} + +inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4]) { + int16x8_t a[6]; + int32x4x2_t sum; + uint8x8x2_t d; + d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a); + a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + sum = WienerVertical2(a[1], a[5], filter[1], sum); + d.val[1] = WienerVertical(a + 2, filter, sum); + return d; +} + +inline void WienerVerticalTap5(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t filter[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + uint8x8x2_t d[2]; + d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter); + d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter); + vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0])); + vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1])); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + int16x8_t a[5]; + const uint8x8_t d0 = + WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a); + const uint8x8_t d1 = + WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a); + vst1q_u8(dst, vcombine_u8(d0, d1)); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4], + int16x8_t a[3]) { + a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride); + a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride); + a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride); + int32x4x2_t sum; + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + return WienerVertical(a, filter, sum); +} + +inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer, + const ptrdiff_t wiener_stride, + const int16_t filter[4]) { + int16x8_t a[4]; + int32x4x2_t sum; + uint8x8x2_t d; + d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a); + a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride); + sum.val[0] = sum.val[1] = vdupq_n_s32(0); + d.val[1] = WienerVertical(a + 1, filter, sum); + return d; +} + +inline void WienerVerticalTap3(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t filter[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + uint8x8x2_t d[2]; + d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter); + d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter); + vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0])); + vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1])); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + int16x8_t a[3]; + const uint8x8_t d0 = + WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a); + const uint8x8_t d1 = + WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a); + vst1q_u8(dst, vcombine_u8(d0, d1)); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer, + uint8_t* const dst) { + const int16x8_t a0 = vld1q_s16(wiener_buffer + 0); + const int16x8_t a1 = vld1q_s16(wiener_buffer + 8); + const uint8x8_t d0 = vqrshrun_n_s16(a0, 4); + const uint8x8_t d1 = vqrshrun_n_s16(a1, 4); + vst1q_u8(dst, vcombine_u8(d0, d1)); +} + +inline void WienerVerticalTap1(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + for (int y = height >> 1; y != 0; --y) { + uint8_t* dst_ptr = dst; + ptrdiff_t x = width; + do { + WienerVerticalTap1Kernel(wiener_buffer, dst_ptr); + WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride); + wiener_buffer += 16; + dst_ptr += 16; + x -= 16; + } while (x != 0); + wiener_buffer += width; + dst += 2 * dst_stride; + } + + if ((height & 1) != 0) { + ptrdiff_t x = width; + do { + WienerVerticalTap1Kernel(wiener_buffer, dst); + wiener_buffer += 16; + dst += 16; + x -= 16; + } while (x != 0); + } +} + +// For width 16 and up, store the horizontal results, and then do the vertical +// filter row by row. This is faster than doing it column by column when +// considering cache issues. +void WienerFilter_NEON(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, const ptrdiff_t stride, + const int width, const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + const int16_t* const number_leading_zero_coefficients = + restoration_info.wiener_info.number_leading_zero_coefficients; + const int number_rows_to_skip = std::max( + static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]), + 1); + const ptrdiff_t wiener_stride = Align(width, 16); + int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer; + // The values are saturated to 13 bits before storing. + int16_t* wiener_buffer_horizontal = + wiener_buffer_vertical + number_rows_to_skip * wiener_stride; + int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2]; + int16_t filter_vertical[(kWienerFilterTaps + 1) / 2]; + PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal, + filter_horizontal); + PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical, + filter_vertical); + + // horizontal filtering. + // Over-reads up to 15 - |kRestorationHorizontalBorder| values. + const int height_horizontal = + height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip; + const int height_extra = (height_horizontal - height) >> 1; + assert(height_extra <= 2); + const auto* const src = static_cast<const uint8_t*>(source); + const auto* const top = static_cast<const uint8_t*>(top_border); + const auto* const bottom = static_cast<const uint8_t*>(bottom_border); + if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { + WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, + wiener_stride, height_extra, filter_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + filter_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, + wiener_stride, height_extra, filter_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + filter_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { + // The maximum over-reads happen here. + WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, + wiener_stride, height_extra, filter_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + filter_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + filter_horizontal, &wiener_buffer_horizontal); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); + WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, + wiener_stride, height_extra, + &wiener_buffer_horizontal); + WienerHorizontalTap1(src, stride, wiener_stride, height, + &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, + &wiener_buffer_horizontal); + } + + // vertical filtering. + // Over-writes up to 15 values. + auto* dst = static_cast<uint8_t*>(dest); + if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) { + // Because the top row of |source| is a duplicate of the second row, and the + // bottom row of |source| is a duplicate of its above row, we can duplicate + // the top and bottom row of |wiener_buffer| accordingly. + memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride, + sizeof(*wiener_buffer_horizontal) * wiener_stride); + memcpy(restoration_buffer->wiener_buffer, + restoration_buffer->wiener_buffer + wiener_stride, + sizeof(*restoration_buffer->wiener_buffer) * wiener_stride); + WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height, + filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) { + WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride, + height, filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) { + WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride, + wiener_stride, height, filter_vertical, dst, stride); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3); + WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride, + wiener_stride, height, dst, stride); + } +} + +//------------------------------------------------------------------------------ +// SGR + +inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) { + dst[0] = VshrU128<0>(src); + dst[1] = VshrU128<1>(src); + dst[2] = VshrU128<2>(src); +} + +inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3], + uint16x4_t high[3]) { + uint16x8_t s[3]; + s[0] = VshrU128<0>(src); + s[1] = VshrU128<2>(src); + s[2] = VshrU128<4>(src); + low[0] = vget_low_u16(s[0]); + low[1] = vget_low_u16(s[1]); + low[2] = vget_low_u16(s[2]); + high[0] = vget_high_u16(s[0]); + high[1] = vget_high_u16(s[1]); + high[2] = vget_high_u16(s[2]); +} + +inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) { + dst[0] = VshrU128<0>(src); + dst[1] = VshrU128<1>(src); + dst[2] = VshrU128<2>(src); + dst[3] = VshrU128<3>(src); + dst[4] = VshrU128<4>(src); +} + +inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5], + uint16x4_t high[5]) { + Prepare3_16(src, low, high); + const uint16x8_t s3 = VshrU128<6>(src); + const uint16x8_t s4 = VshrU128<8>(src); + low[3] = vget_low_u16(s3); + low[4] = vget_low_u16(s4); + high[3] = vget_high_u16(s3); + high[4] = vget_high_u16(s4); +} + +inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1, + const uint16x8_t src2) { + const uint16x8_t sum = vaddq_u16(src0, src1); + return vaddq_u16(sum, src2); +} + +inline uint16x8_t Sum3_16(const uint16x8_t src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1, + const uint32x4_t src2) { + const uint32x4_t sum = vaddq_u32(src0, src1); + return vaddq_u32(sum, src2); +} + +inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) { + uint32x4x2_t d; + d.val[0] = Sum3_32(src[0].val[0], src[1].val[0], src[2].val[0]); + d.val[1] = Sum3_32(src[0].val[1], src[1].val[1], src[2].val[1]); + return d; +} + +inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) { + const uint16x8_t sum = vaddl_u8(src[0], src[1]); + return vaddw_u8(sum, src[2]); +} + +inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) { + const uint32x4_t sum = vaddl_u16(src[0], src[1]); + return vaddw_u16(sum, src[2]); +} + +inline uint16x8_t Sum5_16(const uint16x8_t src[5]) { + const uint16x8_t sum01 = vaddq_u16(src[0], src[1]); + const uint16x8_t sum23 = vaddq_u16(src[2], src[3]); + const uint16x8_t sum = vaddq_u16(sum01, sum23); + return vaddq_u16(sum, src[4]); +} + +inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1, + const uint32x4_t src2, const uint32x4_t src3, + const uint32x4_t src4) { + const uint32x4_t sum01 = vaddq_u32(src0, src1); + const uint32x4_t sum23 = vaddq_u32(src2, src3); + const uint32x4_t sum = vaddq_u32(sum01, sum23); + return vaddq_u32(sum, src4); +} + +inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) { + uint32x4x2_t d; + d.val[0] = Sum5_32(src[0].val[0], src[1].val[0], src[2].val[0], src[3].val[0], + src[4].val[0]); + d.val[1] = Sum5_32(src[0].val[1], src[1].val[1], src[2].val[1], src[3].val[1], + src[4].val[1]); + return d; +} + +inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) { + const uint32x4_t sum01 = vaddl_u16(src[0], src[1]); + const uint32x4_t sum23 = vaddl_u16(src[2], src[3]); + const uint32x4_t sum0123 = vaddq_u32(sum01, sum23); + return vaddw_u16(sum0123, src[4]); +} + +inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) { + uint8x8_t s[3]; + Prepare3_8(src, s); + return Sum3W_16(s); +} + +inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) { + uint16x4_t low[3], high[3]; + uint32x4x2_t sum; + Prepare3_16(src, low, high); + sum.val[0] = Sum3W_32(low); + sum.val[1] = Sum3W_32(high); + return sum; +} + +inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) { + uint8x8_t s[5]; + Prepare5_8(src, s); + const uint16x8_t sum01 = vaddl_u8(s[0], s[1]); + const uint16x8_t sum23 = vaddl_u8(s[2], s[3]); + const uint16x8_t sum0123 = vaddq_u16(sum01, sum23); + return vaddw_u8(sum0123, s[4]); +} + +inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) { + uint16x4_t low[5], high[5]; + Prepare5_16(src, low, high); + uint32x4x2_t sum; + sum.val[0] = Sum5W_32(low); + sum.val[1] = Sum5W_32(high); + return sum; +} + +void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3, + uint32x4_t* const row_sq5) { + const uint32x4_t sum04 = vaddl_u16(src[0], src[4]); + const uint32x4_t sum12 = vaddl_u16(src[1], src[2]); + *row_sq3 = vaddw_u16(sum12, src[3]); + *row_sq5 = vaddq_u32(sum04, *row_sq3); +} + +void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq, + uint16x8_t* const row3, uint16x8_t* const row5, + uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) { + uint8x8_t s[5]; + Prepare5_8(src, s); + const uint16x8_t sum04 = vaddl_u8(s[0], s[4]); + const uint16x8_t sum12 = vaddl_u8(s[1], s[2]); + *row3 = vaddw_u8(sum12, s[3]); + *row5 = vaddq_u16(sum04, *row3); + uint16x4_t low[5], high[5]; + Prepare5_16(sq, low, high); + SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]); + SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]); +} + +inline uint16x8_t Sum343(const uint8x8x2_t src) { + uint8x8_t s[3]; + Prepare3_8(src, s); + const uint16x8_t sum = Sum3W_16(s); + const uint16x8_t sum3 = Sum3_16(sum, sum, sum); + return vaddw_u8(sum3, s[1]); +} + +inline uint32x4_t Sum343W(const uint16x4_t src[3]) { + const uint32x4_t sum = Sum3W_32(src); + const uint32x4_t sum3 = Sum3_32(sum, sum, sum); + return vaddw_u16(sum3, src[1]); +} + +inline uint32x4x2_t Sum343W(const uint16x8x2_t src) { + uint16x4_t low[3], high[3]; + uint32x4x2_t d; + Prepare3_16(src, low, high); + d.val[0] = Sum343W(low); + d.val[1] = Sum343W(high); + return d; +} + +inline uint16x8_t Sum565(const uint8x8x2_t src) { + uint8x8_t s[3]; + Prepare3_8(src, s); + const uint16x8_t sum = Sum3W_16(s); + const uint16x8_t sum4 = vshlq_n_u16(sum, 2); + const uint16x8_t sum5 = vaddq_u16(sum4, sum); + return vaddw_u8(sum5, s[1]); +} + +inline uint32x4_t Sum565W(const uint16x4_t src[3]) { + const uint32x4_t sum = Sum3W_32(src); + const uint32x4_t sum4 = vshlq_n_u32(sum, 2); + const uint32x4_t sum5 = vaddq_u32(sum4, sum); + return vaddw_u16(sum5, src[1]); +} + +inline uint32x4x2_t Sum565W(const uint16x8x2_t src) { + uint16x4_t low[3], high[3]; + uint32x4x2_t d; + Prepare3_16(src, low, high); + d.val[0] = Sum565W(low); + d.val[1] = Sum565W(high); + return d; +} + +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const int height, const ptrdiff_t sum_stride, uint16_t* sum3, + uint16_t* sum5, uint32_t* square_sum3, + uint32_t* square_sum5) { + int y = height; + do { + uint8x8x2_t s; + uint16x8x2_t sq; + s.val[0] = vld1_u8(src); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + ptrdiff_t x = 0; + do { + uint16x8_t row3, row5; + uint32x4x2_t row_sq3, row_sq5; + s.val[1] = vld1_u8(src + x + 8); + sq.val[1] = vmull_u8(s.val[1], s.val[1]); + SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5); + vst1q_u16(sum3, row3); + vst1q_u16(sum5, row5); + vst1q_u32(square_sum3 + 0, row_sq3.val[0]); + vst1q_u32(square_sum3 + 4, row_sq3.val[1]); + vst1q_u32(square_sum5 + 0, row_sq5.val[0]); + vst1q_u32(square_sum5 + 4, row_sq5.val[1]); + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + sum3 += 8; + sum5 += 8; + square_sum3 += 8; + square_sum5 += 8; + x += 8; + } while (x < sum_stride); + src += src_stride; + } while (--y != 0); +} + +template <int size> +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const int height, const ptrdiff_t sum_stride, uint16_t* sums, + uint32_t* square_sums) { + static_assert(size == 3 || size == 5, ""); + int y = height; + do { + uint8x8x2_t s; + uint16x8x2_t sq; + s.val[0] = vld1_u8(src); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + ptrdiff_t x = 0; + do { + uint16x8_t row; + uint32x4x2_t row_sq; + s.val[1] = vld1_u8(src + x + 8); + sq.val[1] = vmull_u8(s.val[1], s.val[1]); + if (size == 3) { + row = Sum3Horizontal(s); + row_sq = Sum3WHorizontal(sq); + } else { + row = Sum5Horizontal(s); + row_sq = Sum5WHorizontal(sq); + } + vst1q_u16(sums, row); + vst1q_u32(square_sums + 0, row_sq.val[0]); + vst1q_u32(square_sums + 4, row_sq.val[1]); + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + sums += 8; + square_sums += 8; + x += 8; + } while (x < sum_stride); + src += src_stride; + } while (--y != 0); +} + +template <int n> +inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq, + const uint32_t scale) { + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const uint32x4_t dxd = vmull_u16(sum, sum); + const uint32x4_t axn = vmulq_n_u32(sum_sq, n); + // Ensure |p| does not underflow by using saturating subtraction. + const uint32x4_t p = vqsubq_u32(axn, dxd); + const uint32x4_t pxs = vmulq_n_u32(p, scale); + // vrshrn_n_u32() (narrowing shift) can only shift by 16 and kSgrProjScaleBits + // is 20. + const uint32x4_t shifted = vrshrq_n_u32(pxs, kSgrProjScaleBits); + return vmovn_u32(shifted); +} + +template <int n> +inline void CalculateIntermediate(const uint16x8_t sum, + const uint32x4x2_t sum_sq, + const uint32_t scale, uint8x8_t* const ma, + uint16x8_t* const b) { + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale); + const uint16x4_t z1 = + CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale); + const uint16x8_t z01 = vcombine_u16(z0, z1); + // Using vqmovn_u16() needs an extra sign extension instruction. + const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255)); + // Using vgetq_lane_s16() can save the sign extension instruction. + const uint8_t lookup[8] = { + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)], + kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]}; + *ma = vld1_u8(lookup); + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const uint16x8_t maq = vmovl_u8(*ma); + const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum)); + const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum)); + const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n); + const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n); + const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits); + const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits); + *b = vcombine_u16(b_lo, b_hi); +} + +inline void CalculateIntermediate5(const uint16x8_t s5[5], + const uint32x4x2_t sq5[5], + const uint32_t scale, uint8x8_t* const ma, + uint16x8_t* const b) { + const uint16x8_t sum = Sum5_16(s5); + const uint32x4x2_t sum_sq = Sum5_32(sq5); + CalculateIntermediate<25>(sum, sum_sq, scale, ma, b); +} + +inline void CalculateIntermediate3(const uint16x8_t s3[3], + const uint32x4x2_t sq3[3], + const uint32_t scale, uint8x8_t* const ma, + uint16x8_t* const b) { + const uint16x8_t sum = Sum3_16(s3); + const uint32x4x2_t sum_sq = Sum3_32(sq3); + CalculateIntermediate<9>(sum, sum_sq, scale, ma, b); +} + +inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, + const ptrdiff_t x, uint16x8_t* const sum_ma343, + uint16x8_t* const sum_ma444, + uint32x4x2_t* const sum_b343, + uint32x4x2_t* const sum_b444, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + uint8x8_t s[3]; + Prepare3_8(ma3, s); + const uint16x8_t sum_ma111 = Sum3W_16(s); + *sum_ma444 = vshlq_n_u16(sum_ma111, 2); + const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111); + *sum_ma343 = vaddw_u8(sum333, s[1]); + uint16x4_t low[3], high[3]; + uint32x4x2_t sum_b111; + Prepare3_16(b3, low, high); + sum_b111.val[0] = Sum3W_32(low); + sum_b111.val[1] = Sum3W_32(high); + sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2); + sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2); + sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]); + sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]); + sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]); + sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]); + vst1q_u16(ma343 + x, *sum_ma343); + vst1q_u16(ma444 + x, *sum_ma444); + vst1q_u32(b343 + x + 0, sum_b343->val[0]); + vst1q_u32(b343 + x + 4, sum_b343->val[1]); + vst1q_u32(b444 + x + 0, sum_b444->val[0]); + vst1q_u32(b444 + x + 4, sum_b444->val[1]); +} + +inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, + const ptrdiff_t x, uint16x8_t* const sum_ma343, + uint32x4x2_t* const sum_b343, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + uint16x8_t sum_ma444; + uint32x4x2_t sum_b444; + Store343_444(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3, + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + uint16x8_t sum_ma343; + uint32x4x2_t sum_b343; + Store343_444(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, b444); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x, + const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], uint8x8x2_t s[2], uint16x8x2_t sq[2], + uint8x8_t* const ma, uint16x8_t* const b) { + uint16x8_t s5[5]; + uint32x4x2_t sq5[5]; + s[0].val[1] = vld1_u8(src0 + x + 8); + s[1].val[1] = vld1_u8(src1 + x + 8); + sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]); + sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]); + s5[3] = Sum5Horizontal(s[0]); + s5[4] = Sum5Horizontal(s[1]); + sq5[3] = Sum5WHorizontal(sq[0]); + sq5[4] = Sum5WHorizontal(sq[1]); + vst1q_u16(sum5[3] + x, s5[3]); + vst1q_u16(sum5[4] + x, s5[4]); + vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + CalculateIntermediate5(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow( + const uint8_t* const src, const ptrdiff_t x, const uint32_t scale, + const uint16_t* const sum5[5], const uint32_t* const square_sum5[5], + uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma, + uint16x8_t* const b) { + uint16x8_t s5[5]; + uint32x4x2_t sq5[5]; + s->val[1] = vld1_u8(src + x + 8); + sq->val[1] = vmull_u8(s->val[1], s->val[1]); + s5[3] = s5[4] = Sum5Horizontal(*s); + sq5[3] = sq5[4] = Sum5WHorizontal(*sq); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + CalculateIntermediate5(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( + const uint8_t* const src, const ptrdiff_t x, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma, + uint16x8_t* const b) { + uint16x8_t s3[3]; + uint32x4x2_t sq3[3]; + s->val[1] = vld1_u8(src + x + 8); + sq->val[1] = vmull_u8(s->val[1], s->val[1]); + s3[2] = Sum3Horizontal(*s); + sq3[2] = Sum3WHorizontal(*sq); + vst1q_u16(sum3[2] + x, s3[2]); + vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]); + s3[0] = vld1q_u16(sum3[0] + x); + s3[1] = vld1q_u16(sum3[1] + x); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); + CalculateIntermediate3(s3, sq3, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( + const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint8x8x2_t s[2], uint16x8x2_t sq[2], uint8x8_t* const ma3_0, + uint8x8_t* const ma3_1, uint16x8_t* const b3_0, uint16x8_t* const b3_1, + uint8x8_t* const ma5, uint16x8_t* const b5) { + uint16x8_t s3[4], s5[5]; + uint32x4x2_t sq3[4], sq5[5]; + s[0].val[1] = vld1_u8(src0 + x + 8); + s[1].val[1] = vld1_u8(src1 + x + 8); + sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]); + sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]); + SumHorizontal(s[0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]); + SumHorizontal(s[1], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]); + vst1q_u16(sum3[2] + x, s3[2]); + vst1q_u16(sum3[3] + x, s3[3]); + vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]); + vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]); + vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]); + vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]); + vst1q_u16(sum5[3] + x, s5[3]); + vst1q_u16(sum5[4] + x, s5[4]); + vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]); + vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]); + vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]); + vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]); + s3[0] = vld1q_u16(sum3[0] + x); + s3[1] = vld1q_u16(sum3[1] + x); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + CalculateIntermediate3(s3, sq3, scales[1], ma3_0, b3_0); + CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], ma3_1, b3_1); + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( + const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2], + const uint16_t* const sum3[4], const uint16_t* const sum5[5], + const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5], + uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3, + uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) { + uint16x8_t s3[3], s5[5]; + uint32x4x2_t sq3[3], sq5[5]; + s->val[1] = vld1_u8(src + x + 8); + sq->val[1] = vmull_u8(s->val[1], s->val[1]); + SumHorizontal(*s, *sq, &s3[2], &s5[3], &sq3[2], &sq5[3]); + s5[0] = vld1q_u16(sum5[0] + x); + s5[1] = vld1q_u16(sum5[1] + x); + s5[2] = vld1q_u16(sum5[2] + x); + s5[4] = s5[3]; + sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0); + sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4); + sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0); + sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4); + sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0); + sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4); + sq5[4] = sq5[3]; + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); + s3[0] = vld1q_u16(sum3[0] + x); + s3[1] = vld1q_u16(sum3[1] + x); + sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0); + sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4); + sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0); + sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4); + CalculateIntermediate3(s3, sq3, scales[1], ma3, b3); +} + +inline void BoxSumFilterPreProcess5(const uint8_t* const src0, + const uint8_t* const src1, const int width, + const uint32_t scale, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + uint16_t* ma565, uint32_t* b565) { + uint8x8x2_t s[2], mas; + uint16x8x2_t sq[2], bs; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq, + &mas.val[0], &bs.val[0]); + + int x = 0; + do { + s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq, + &mas.val[1], &bs.val[1]); + const uint16x8_t ma = Sum565(mas); + const uint32x4x2_t b = Sum565W(bs); + vst1q_u16(ma565, ma); + vst1q_u32(b565 + 0, b.val[0]); + vst1q_u32(b565 + 4, b.val[1]); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + ma565 += 8; + b565 += 8; + x += 8; + } while (x < width); +} + +template <bool calculate444> +LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( + const uint8_t* const src, const int width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343, + uint16_t* ma444, uint32_t* b343, uint32_t* b444) { + uint8x8x2_t s, mas; + uint16x8x2_t sq, bs; + s.val[0] = vld1_u8(src); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcess3(src, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0], + &bs.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, &s, &sq, + &mas.val[1], &bs.val[1]); + if (calculate444) { + Store343_444(mas, bs, 0, ma343, ma444, b343, b444); + ma444 += 8; + b444 += 8; + } else { + const uint16x8_t ma = Sum343(mas); + const uint32x4x2_t b = Sum343W(bs); + vst1q_u16(ma343, ma); + vst1q_u32(b343 + 0, b.val[0]); + vst1q_u32(b343 + 4, b.val[1]); + } + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + ma343 += 8; + b343 += 8; + x += 8; + } while (x < width); +} + +inline void BoxSumFilterPreProcess( + const uint8_t* const src0, const uint8_t* const src1, const int width, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565, + uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) { + uint8x8x2_t s[2]; + uint8x8x2_t ma3[2], ma5; + uint16x8x2_t sq[2], b3[2], b5; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0], + &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]); + + int x = 0; + do { + s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1], + &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]); + uint16x8_t ma = Sum343(ma3[0]); + uint32x4x2_t b = Sum343W(b3[0]); + vst1q_u16(ma343[0] + x, ma); + vst1q_u32(b343[0] + x, b.val[0]); + vst1q_u32(b343[0] + x + 4, b.val[1]); + Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); + ma = Sum565(ma5); + b = Sum565W(b5); + vst1q_u16(ma565, ma); + vst1q_u32(b565 + 0, b.val[0]); + vst1q_u32(b565 + 4, b.val[1]); + ma3[0].val[0] = ma3[0].val[1]; + ma3[1].val[0] = ma3[1].val[1]; + b3[0].val[0] = b3[0].val[1]; + b3[1].val[0] = b3[1].val[1]; + ma5.val[0] = ma5.val[1]; + b5.val[0] = b5.val[1]; + ma565 += 8; + b565 += 8; + x += 8; + } while (x < width); +} + +template <int shift> +inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma, + const uint32x4_t b) { + // ma: 255 * 32 = 8160 (13 bits) + // b: 65088 * 32 = 2082816 (21 bits) + // v: b - ma * 255 (22 bits) + const int32x4_t v = vreinterpretq_s32_u32(vmlsl_u16(b, ma, src)); + // kSgrProjSgrBits = 8 + // kSgrProjRestoreBits = 4 + // shift = 4 or 5 + // v >> 8 or 9 (13 bits) + return vrshrn_n_s32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits); +} + +template <int shift> +inline int16x8_t CalculateFilteredOutput(const uint8x8_t src, + const uint16x8_t ma, + const uint32x4x2_t b) { + const uint16x8_t src_u16 = vmovl_u8(src); + const int16x4_t dst_lo = + FilterOutput<shift>(vget_low_u16(src_u16), vget_low_u16(ma), b.val[0]); + const int16x4_t dst_hi = + FilterOutput<shift>(vget_high_u16(src_u16), vget_high_u16(ma), b.val[1]); + return vcombine_s16(dst_lo, dst_hi); // 13 bits +} + +inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s, + uint16x8_t ma[2], + uint32x4x2_t b[2]) { + const uint16x8_t ma_sum = vaddq_u16(ma[0], ma[1]); + uint32x4x2_t b_sum; + b_sum.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]); + b_sum.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]); + return CalculateFilteredOutput<5>(s, ma_sum, b_sum); +} + +inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s, + uint16x8_t ma[3], + uint32x4x2_t b[3]) { + const uint16x8_t ma_sum = Sum3_16(ma); + const uint32x4x2_t b_sum = Sum3_32(b); + return CalculateFilteredOutput<5>(s, ma_sum, b_sum); +} + +inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2], + uint8_t* const dst) { + const int16x4_t v_lo = + vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const int16x4_t v_hi = + vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const int16x8_t vv = vcombine_s16(v_lo, v_hi); + const int16x8_t s = ZeroExtend(src); + const int16x8_t d = vaddq_s16(s, vv); + vst1_u8(dst, vqmovun_s16(d)); +} + +inline void SelfGuidedDoubleMultiplier(const uint8x8_t src, + const int16x8_t filter[2], const int w0, + const int w2, uint8_t* const dst) { + int32x4_t v[2]; + v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0); + v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0); + v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2); + v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2); + SelfGuidedFinal(src, v, dst); +} + +inline void SelfGuidedSingleMultiplier(const uint8x8_t src, + const int16x8_t filter, const int w0, + uint8_t* const dst) { + // weight: -96 to 96 (Sgrproj_Xqd_Min/Max) + int32x4_t v[2]; + v[0] = vmull_n_s16(vget_low_s16(filter), w0); + v[1] = vmull_n_s16(vget_high_s16(filter), w0); + SelfGuidedFinal(src, v, dst); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], const int width, const uint32_t scale, + const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2], + uint8_t* const dst) { + uint8x8x2_t s[2], mas; + uint16x8x2_t sq[2], bs; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq, + &mas.val[0], &bs.val[0]); + + int x = 0; + do { + s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq, + &mas.val[1], &bs.val[1]); + uint16x8_t ma[2]; + uint32x4x2_t b[2]; + ma[1] = Sum565(mas); + b[1] = Sum565W(bs); + vst1q_u16(ma565[1] + x, ma[1]); + vst1q_u32(b565[1] + x + 0, b[1].val[0]); + vst1q_u32(b565[1] + x + 4, b[1].val[1]); + const uint8x8_t sr0 = vld1_u8(src + x); + const uint8x8_t sr1 = vld1_u8(src + stride + x); + int16x8_t p0, p1; + ma[0] = vld1q_u16(ma565[0] + x); + b[0].val[0] = vld1q_u32(b565[0] + x + 0); + b[0].val[1] = vld1q_u32(b565[0] + x + 4); + p0 = CalculateFilteredOutputPass1(sr0, ma, b); + p1 = CalculateFilteredOutput<4>(sr1, ma[1], b[1]); + SelfGuidedSingleMultiplier(sr0, p0, w0, dst + x); + SelfGuidedSingleMultiplier(sr1, p1, w0, dst + stride + x); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + x += 8; + } while (x < width); +} + +inline void BoxFilterPass1LastRow(const uint8_t* const src, + const uint8_t* const src0, const int width, + const uint32_t scale, const int16_t w0, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + uint16_t* ma565, uint32_t* b565, + uint8_t* const dst) { + uint8x8x2_t s, mas; + uint16x8x2_t sq, bs; + s.val[0] = vld1_u8(src0); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcess5LastRow(src0, 0, scale, sum5, square_sum5, &s, &sq, + &mas.val[0], &bs.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcess5LastRow(src0, x + 8, scale, sum5, square_sum5, &s, &sq, + &mas.val[1], &bs.val[1]); + uint16x8_t ma[2]; + uint32x4x2_t b[2]; + ma[1] = Sum565(mas); + b[1] = Sum565W(bs); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + ma[0] = vld1q_u16(ma565); + b[0].val[0] = vld1q_u32(b565 + 0); + b[0].val[1] = vld1q_u32(b565 + 4); + const uint8x8_t sr = vld1_u8(src + x); + const int16x8_t p = CalculateFilteredOutputPass1(sr, ma, b); + SelfGuidedSingleMultiplier(sr, p, w0, dst + x); + ma565 += 8; + b565 += 8; + x += 8; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass2( + const uint8_t* const src, const uint8_t* const src0, const int width, + const uint32_t scale, const int16_t w0, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], uint16_t* const ma343[3], + uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2], + uint8_t* const dst) { + uint8x8x2_t s, mas; + uint16x8x2_t sq, bs; + s.val[0] = vld1_u8(src0); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcess3(src0, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0], + &bs.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, &s, &sq, + &mas.val[1], &bs.val[1]); + uint16x8_t ma[3]; + uint32x4x2_t b[3]; + Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2], + b444[1]); + const uint8x8_t sr = vld1_u8(src + x); + ma[0] = vld1q_u16(ma343[0] + x); + ma[1] = vld1q_u16(ma444[0] + x); + b[0].val[0] = vld1q_u32(b343[0] + x + 0); + b[0].val[1] = vld1q_u32(b343[0] + x + 4); + b[1].val[0] = vld1q_u32(b444[0] + x + 0); + b[1].val[1] = vld1q_u32(b444[0] + x + 4); + const int16x8_t p = CalculateFilteredOutputPass2(sr, ma, b); + SelfGuidedSingleMultiplier(sr, p, w0, dst + x); + mas.val[0] = mas.val[1]; + bs.val[0] = bs.val[1]; + x += 8; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilter( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], + uint32_t* const b565[2], uint8_t* const dst) { + uint8x8x2_t s[2], ma3[2], ma5; + uint16x8x2_t sq[2], b3[2], b5; + s[0].val[0] = vld1_u8(src0); + s[1].val[0] = vld1_u8(src1); + sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]); + sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]); + BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0], + &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]); + + int x = 0; + do { + s[0].val[0] = s[0].val[1]; + s[1].val[0] = s[1].val[1]; + sq[0].val[0] = sq[0].val[1]; + sq[1].val[0] = sq[1].val[1]; + BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3, + square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1], + &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]); + uint16x8_t ma[3][3]; + uint32x4x2_t b[3][3]; + Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1], + ma343[2], ma444[1], b343[2], b444[1]); + Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2], + b343[3], b444[2]); + ma[0][1] = Sum565(ma5); + b[0][1] = Sum565W(b5); + vst1q_u16(ma565[1] + x, ma[0][1]); + vst1q_u32(b565[1] + x, b[0][1].val[0]); + vst1q_u32(b565[1] + x + 4, b[0][1].val[1]); + ma3[0].val[0] = ma3[0].val[1]; + ma3[1].val[0] = ma3[1].val[1]; + b3[0].val[0] = b3[0].val[1]; + b3[1].val[0] = b3[1].val[1]; + ma5.val[0] = ma5.val[1]; + b5.val[0] = b5.val[1]; + int16x8_t p[2][2]; + const uint8x8_t sr0 = vld1_u8(src + x); + const uint8x8_t sr1 = vld1_u8(src + stride + x); + ma[0][0] = vld1q_u16(ma565[0] + x); + b[0][0].val[0] = vld1q_u32(b565[0] + x); + b[0][0].val[1] = vld1q_u32(b565[0] + x + 4); + p[0][0] = CalculateFilteredOutputPass1(sr0, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1, ma[0][1], b[0][1]); + ma[1][0] = vld1q_u16(ma343[0] + x); + ma[1][1] = vld1q_u16(ma444[0] + x); + b[1][0].val[0] = vld1q_u32(b343[0] + x); + b[1][0].val[1] = vld1q_u32(b343[0] + x + 4); + b[1][1].val[0] = vld1q_u32(b444[0] + x); + b[1][1].val[1] = vld1q_u32(b444[0] + x + 4); + p[0][1] = CalculateFilteredOutputPass2(sr0, ma[1], b[1]); + ma[2][0] = vld1q_u16(ma343[1] + x); + b[2][0].val[0] = vld1q_u32(b343[1] + x); + b[2][0].val[1] = vld1q_u32(b343[1] + x + 4); + p[1][1] = CalculateFilteredOutputPass2(sr1, ma[2], b[2]); + SelfGuidedDoubleMultiplier(sr0, p[0], w0, w2, dst + x); + SelfGuidedDoubleMultiplier(sr1, p[1], w0, w2, dst + stride + x); + x += 8; + } while (x < width); +} + +inline void BoxFilterLastRow( + const uint8_t* const src, const uint8_t* const src0, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], + uint32_t* const b565[2], uint8_t* const dst) { + uint8x8x2_t s, ma3, ma5; + uint16x8x2_t sq, b3, b5; + uint16x8_t ma[3]; + uint32x4x2_t b[3]; + s.val[0] = vld1_u8(src0); + sq.val[0] = vmull_u8(s.val[0], s.val[0]); + BoxFilterPreProcessLastRow(src0, 0, scales, sum3, sum5, square_sum3, + square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0], + &b3.val[0], &b5.val[0]); + + int x = 0; + do { + s.val[0] = s.val[1]; + sq.val[0] = sq.val[1]; + BoxFilterPreProcessLastRow(src0, x + 8, scales, sum3, sum5, square_sum3, + square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1], + &b3.val[1], &b5.val[1]); + ma[1] = Sum565(ma5); + b[1] = Sum565W(b5); + ma5.val[0] = ma5.val[1]; + b5.val[0] = b5.val[1]; + ma[2] = Sum343(ma3); + b[2] = Sum343W(b3); + ma3.val[0] = ma3.val[1]; + b3.val[0] = b3.val[1]; + const uint8x8_t sr = vld1_u8(src + x); + int16x8_t p[2]; + ma[0] = vld1q_u16(ma565[0] + x); + b[0].val[0] = vld1q_u32(b565[0] + x + 0); + b[0].val[1] = vld1q_u32(b565[0] + x + 4); + p[0] = CalculateFilteredOutputPass1(sr, ma, b); + ma[0] = vld1q_u16(ma343[0] + x); + ma[1] = vld1q_u16(ma444[0] + x); + b[0].val[0] = vld1q_u32(b343[0] + x + 0); + b[0].val[1] = vld1q_u32(b343[0] + x + 4); + b[1].val[0] = vld1q_u32(b444[0] + x + 0); + b[1].val[1] = vld1q_u32(b444[0] + x + 4); + p[1] = CalculateFilteredOutputPass2(sr, ma, b); + SelfGuidedDoubleMultiplier(sr, p, w0, w2, dst + x); + x += 8; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( + const RestorationUnitInfo& restoration_info, const uint8_t* src, + const uint8_t* const top_border, const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1; + uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2]; + uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 3; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma444[0] = sgr_buffer->ma444; + b444[0] = sgr_buffer->b444; + for (int i = 1; i <= 2; ++i) { + ma444[i] = ma444[i - 1] + temp_stride; + b444[i] = b444[i - 1] + temp_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scales[0] != 0); + assert(scales[1] != 0); + BoxSum(top_border, stride, 2, sum_stride, sum3[0], sum5[1], square_sum3[0], + square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, + square_sum5, ma343, ma444, ma565[0], b343, b444, + b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width, + scales, w0, w2, sum3, sum5, square_sum3, square_sum5, ma343, + ma444, ma565, b343, b444, b565, dst); + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5, + square_sum3, square_sum5, ma343, ma444, ma565, b343, b444, b565, + dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + BoxFilterLastRow(src + 3, bottom_border + stride, width, scales, w0, w2, + sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565, + b343, b444, b565, dst); + } +} + +inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + uint16_t *sum5[5], *ma565[2]; + uint32_t *square_sum5[5], *b565[2]; + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scale != 0); + BoxSum<5>(top_border, stride, 2, sum_stride, sum5[1], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, ma565[0], + b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5, + square_sum5, width, scale, w0, ma565, b565, dst); + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width, + scale, w0, ma565, b565, dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + } + BoxFilterPass1LastRow(src + 3, bottom_border + stride, width, scale, w0, + sum5, square_sum5, ma565[0], b565[0], dst); + } +} + +inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + assert(restoration_info.sgr_proj_info.multiplier[0] == 0); + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12. + uint16_t *sum3[3], *ma343[3], *ma444[2]; + uint32_t *square_sum3[3], *b343[3], *b444[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 2; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + ma444[0] = sgr_buffer->ma444; + ma444[1] = ma444[0] + temp_stride; + b444[0] = sgr_buffer->b444; + b444[1] = b444[0] + temp_stride; + assert(scale != 0); + BoxSum<3>(top_border, stride, 2, sum_stride, sum3[0], square_sum3[0]); + BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0], + nullptr, b343[0], nullptr); + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + const uint8_t* s; + if (height > 1) { + s = src + stride; + } else { + s = bottom_border; + bottom_border += stride; + } + BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, ma343[1], + ma444[0], b343[1], b444[0]); + + for (int y = height - 2; y > 0; --y) { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src + 2, src + 2 * stride, width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } + + src += 2; + int y = std::min(height, 2); + do { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src, bottom_border, width, scale, w0, sum3, square_sum3, + ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + bottom_border += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } while (--y != 0); +} + +// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in +// the end of each row. It is safe to overwrite the output as it will not be +// part of the visible frame. +void SelfGuidedFilter_NEON( + const RestorationUnitInfo& restoration_info, const void* const source, + const void* const top_border, const void* const bottom_border, + const ptrdiff_t stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { + const int index = restoration_info.sgr_proj_info.index; + const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 + const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0 + const auto* const src = static_cast<const uint8_t*>(source); + const auto* top = static_cast<const uint8_t*>(top_border); + const auto* bottom = static_cast<const uint8_t*>(bottom_border); + auto* const dst = static_cast<uint8_t*>(dest); + SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer; + if (radius_pass_1 == 0) { + // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the + // following assertion. + assert(radius_pass_0 != 0); + BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, + stride, width, height, sgr_buffer, dst); + } else if (radius_pass_0 == 0) { + BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, + stride, width, height, sgr_buffer, dst); + } else { + BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, + width, height, sgr_buffer, dst); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->loop_restorations[0] = WienerFilter_NEON; + dsp->loop_restorations[1] = SelfGuidedFilter_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void LoopRestorationInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/loop_restoration_neon.h b/src/dsp/arm/loop_restoration_neon.h new file mode 100644 index 0000000..b551610 --- /dev/null +++ b/src/dsp/arm/loop_restoration_neon.h @@ -0,0 +1,40 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_restorations, see the defines below for specifics. +// This function is not thread-safe. +void LoopRestorationInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_ diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc new file mode 100644 index 0000000..084f42f --- /dev/null +++ b/src/dsp/arm/mask_blend_neon.cc @@ -0,0 +1,444 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/mask_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// TODO(b/150461164): Consider combining with GetInterIntraMask4x2(). +// Compound predictors use int16_t values and need to multiply long because the +// Convolve range * 64 is 20 bits. Unfortunately there is no multiply int16_t by +// int8_t and accumulate into int32_t instruction. +template <int subsampling_x, int subsampling_y> +inline int16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const int16x4_t mask_val0 = vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask))); + const int16x4_t mask_val1 = vreinterpret_s16_u16( + vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y)))); + int16x8_t final_val; + if (subsampling_y == 1) { + const int16x4_t next_mask_val0 = + vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride))); + const int16x4_t next_mask_val1 = + vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride * 3))); + final_val = vaddq_s16(vcombine_s16(mask_val0, mask_val1), + vcombine_s16(next_mask_val0, next_mask_val1)); + } else { + final_val = vreinterpretq_s16_u16( + vpaddlq_u8(vreinterpretq_u8_s16(vcombine_s16(mask_val0, mask_val1)))); + } + return vrshrq_n_s16(final_val, subsampling_y + 1); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val0 = Load4(mask); + const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0); + return vreinterpretq_s16_u16(vmovl_u8(mask_val)); +} + +template <int subsampling_x, int subsampling_y> +inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + int16x8_t mask_val = vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask))); + if (subsampling_y == 1) { + const int16x8_t next_mask_val = + vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask + mask_stride))); + mask_val = vaddq_s16(mask_val, next_mask_val); + } + return vrshrq_n_s16(mask_val, 1 + subsampling_y); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val = vld1_u8(mask); + return vreinterpretq_s16_u16(vmovl_u8(mask_val)); +} + +inline void WriteMaskBlendLine4x2(const int16_t* const pred_0, + const int16_t* const pred_1, + const int16x8_t pred_mask_0, + const int16x8_t pred_mask_1, uint8_t* dst, + const ptrdiff_t dst_stride) { + const int16x8_t pred_val_0 = vld1q_s16(pred_0); + const int16x8_t pred_val_1 = vld1q_s16(pred_1); + // int res = (mask_value * prediction_0[x] + + // (64 - mask_value) * prediction_1[x]) >> 6; + const int32x4_t weighted_pred_0_lo = + vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0)); + const int32x4_t weighted_pred_0_hi = + vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0)); + const int32x4_t weighted_combo_lo = vmlal_s16( + weighted_pred_0_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1)); + const int32x4_t weighted_combo_hi = + vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1), + vget_high_s16(pred_val_1)); + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + const uint8x8_t result = + vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6), + vshrn_n_s32(weighted_combo_hi, 6)), + 4); + StoreLo4(dst, result); + StoreHi4(dst + dst_stride, result); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1, + const uint8_t* mask, + const ptrdiff_t mask_stride, uint8_t* dst, + const ptrdiff_t dst_stride) { + const int16x8_t mask_inverter = vdupq_n_s16(64); + int16x8_t pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + // TODO(b/150461164): Arm tends to do better with load(val); val += stride + // It may be possible to turn this into a loop with a templated height. + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + const uint8_t* mask = mask_ptr; + if (height == 4) { + MaskBlending4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, mask, mask_stride, dst, dst_stride); + return; + } + const int16x8_t mask_inverter = vdupq_n_s16(64); + int y = 0; + do { + int16x8_t pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + y += 8; + } while (y < height); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1, + const ptrdiff_t /*prediction_stride_1*/, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int width, + const int height, void* dest, + const ptrdiff_t dst_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + if (width == 4) { + MaskBlending4xH_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride); + return; + } + const uint8_t* mask = mask_ptr; + const int16x8_t mask_inverter = vdupq_n_s16(64); + int y = 0; + do { + int x = 0; + do { + const int16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride); + // 64 - mask + const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0); + const int16x8_t pred_val_0 = vld1q_s16(pred_0 + x); + const int16x8_t pred_val_1 = vld1q_s16(pred_1 + x); + uint8x8_t result; + // int res = (mask_value * prediction_0[x] + + // (64 - mask_value) * prediction_1[x]) >> 6; + const int32x4_t weighted_pred_0_lo = + vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0)); + const int32x4_t weighted_pred_0_hi = + vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0)); + const int32x4_t weighted_combo_lo = + vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1), + vget_low_s16(pred_val_1)); + const int32x4_t weighted_combo_hi = + vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1), + vget_high_s16(pred_val_1)); + + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + result = vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6), + vshrn_n_s32(weighted_combo_hi, 6)), + 4); + vst1_u8(dst + x, result); + + x += 8; + } while (x < width); + dst += dst_stride; + pred_0 += width; + pred_1 += width; + mask += mask_stride << subsampling_y; + } while (++y < height); +} + +// TODO(b/150461164): This is much faster for inter_intra (input is Pixel +// values) but regresses compound versions (input is int16_t). Try to +// consolidate these. +template <int subsampling_x, int subsampling_y> +inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask, + ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const uint8x8_t mask_val = + vpadd_u8(vld1_u8(mask), vld1_u8(mask + (mask_stride << subsampling_y))); + if (subsampling_y == 1) { + const uint8x8_t next_mask_val = vpadd_u8(vld1_u8(mask + mask_stride), + vld1_u8(mask + mask_stride * 3)); + + // Use a saturating add to work around the case where all |mask| values + // are 64. Together with the rounding shift this ensures the correct + // result. + const uint8x8_t sum = vqadd_u8(mask_val, next_mask_val); + return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y); + } + + return vrshr_n_u8(mask_val, /*subsampling_x=*/1); + } + + assert(subsampling_y == 0 && subsampling_x == 0); + const uint8x8_t mask_val0 = Load4(mask); + // TODO(b/150461164): Investigate the source of |mask| and see if the stride + // can be removed. + // TODO(b/150461164): The unit tests start at 8x8. Does this get run? + return Load4<1>(mask + mask_stride, mask_val0); +} + +template <int subsampling_x, int subsampling_y> +inline uint8x8_t GetInterIntraMask8(const uint8_t* mask, + ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const uint8x16_t mask_val = vld1q_u8(mask); + const uint8x8_t mask_paired = + vpadd_u8(vget_low_u8(mask_val), vget_high_u8(mask_val)); + if (subsampling_y == 1) { + const uint8x16_t next_mask_val = vld1q_u8(mask + mask_stride); + const uint8x8_t next_mask_paired = + vpadd_u8(vget_low_u8(next_mask_val), vget_high_u8(next_mask_val)); + + // Use a saturating add to work around the case where all |mask| values + // are 64. Together with the rounding shift this ensures the correct + // result. + const uint8x8_t sum = vqadd_u8(mask_paired, next_mask_paired); + return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y); + } + + return vrshr_n_u8(mask_paired, /*subsampling_x=*/1); + } + + assert(subsampling_y == 0 && subsampling_x == 0); + return vld1_u8(mask); +} + +inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0, + uint8_t* const pred_1, + const ptrdiff_t pred_stride_1, + const uint8x8_t pred_mask_0, + const uint8x8_t pred_mask_1) { + const uint8x8_t pred_val_0 = vld1_u8(pred_0); + uint8x8_t pred_val_1 = Load4(pred_1); + pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1); + + const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0); + const uint16x8_t weighted_combo = + vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1); + const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6); + StoreLo4(pred_1, result); + StoreHi4(pred_1 + pred_stride_1, result); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0, + uint8_t* pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* mask, + const ptrdiff_t mask_stride) { + const uint8x8_t mask_inverter = vdup_n_u8(64); + uint8x8_t pred_mask_1 = + GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1); + InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + mask += mask_stride << (1 + subsampling_y); + + pred_mask_1 = + GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1); + InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlending8bpp4xH_NEON( + const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1, + const uint8_t* mask, const ptrdiff_t mask_stride, const int height) { + if (height == 4) { + InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + return; + } + int y = 0; + do { + InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + pred_0 += 4 << 2; + pred_1 += pred_stride_1 << 2; + mask += mask_stride << (2 + subsampling_y); + + InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + pred_0 += 4 << 2; + pred_1 += pred_stride_1 << 2; + mask += mask_stride << (2 + subsampling_y); + y += 8; + } while (y < height); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0, + uint8_t* prediction_1, + const ptrdiff_t prediction_stride_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, + const int width, const int height) { + if (width == 4) { + InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>( + prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride, + height); + return; + } + const uint8_t* mask = mask_ptr; + const uint8x8_t mask_inverter = vdup_n_u8(64); + int y = 0; + do { + int x = 0; + do { + // TODO(b/150461164): Consider a 16 wide specialization (at least for the + // unsampled version) to take advantage of vld1q_u8(). + const uint8x8_t pred_mask_1 = + GetInterIntraMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride); + // 64 - mask + const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1); + const uint8x8_t pred_val_0 = vld1_u8(prediction_0); + prediction_0 += 8; + const uint8x8_t pred_val_1 = vld1_u8(prediction_1 + x); + const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0); + // weighted_pred0 + weighted_pred1 + const uint16x8_t weighted_combo = + vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1); + const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6); + vst1_u8(prediction_1 + x, result); + + x += 8; + } while (x < width); + prediction_1 += prediction_stride_1; + mask += mask_stride << subsampling_y; + } while (++y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0>; + dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0>; + dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1>; + // The is_inter_intra index of mask_blend[][] is replaced by + // inter_intra_mask_blend_8bpp[] in 8-bit. + dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_NEON<0, 0>; + dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_NEON<1, 0>; + dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_NEON<1, 1>; +} + +} // namespace +} // namespace low_bitdepth + +void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void MaskBlendInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/mask_blend_neon.h b/src/dsp/arm/mask_blend_neon.h new file mode 100644 index 0000000..3829274 --- /dev/null +++ b/src/dsp/arm/mask_blend_neon.h @@ -0,0 +1,41 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mask_blend. This function is not thread-safe. +void MaskBlendInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_ diff --git a/src/dsp/arm/motion_field_projection_neon.cc b/src/dsp/arm/motion_field_projection_neon.cc new file mode 100644 index 0000000..8caba7d --- /dev/null +++ b/src/dsp/arm/motion_field_projection_neon.cc @@ -0,0 +1,393 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_field_projection.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline int16x8_t LoadDivision(const int8x8x2_t division_table, + const int8x8_t reference_offset) { + const int8x8_t kOne = vcreate_s8(0x0100010001000100); + const int8x16_t kOneQ = vcombine_s8(kOne, kOne); + const int8x8_t t = vadd_s8(reference_offset, reference_offset); + const int8x8x2_t tt = vzip_s8(t, t); + const int8x16_t t1 = vcombine_s8(tt.val[0], tt.val[1]); + const int8x16_t idx = vaddq_s8(t1, kOneQ); + const int8x8_t idx_low = vget_low_s8(idx); + const int8x8_t idx_high = vget_high_s8(idx); + const int16x4_t d0 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_low)); + const int16x4_t d1 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_high)); + return vcombine_s16(d0, d1); +} + +inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator, + const int numerator) { + const int32x4_t m0 = vmull_s16(mv, denominator); + const int32x4_t m = vmulq_n_s32(m0, numerator); + // Add the sign (0 or -1) to round towards zero. + const int32x4_t add_sign = vsraq_n_s32(m, m, 31); + return vqrshrn_n_s32(add_sign, 14); +} + +inline int16x8_t MvProjectionClip(const int16x8_t mv, + const int16x8_t denominator, + const int numerator) { + const int16x4_t mv0 = vget_low_s16(mv); + const int16x4_t mv1 = vget_high_s16(mv); + const int16x4_t s0 = MvProjection(mv0, vget_low_s16(denominator), numerator); + const int16x4_t s1 = MvProjection(mv1, vget_high_s16(denominator), numerator); + const int16x8_t projection = vcombine_s16(s0, s1); + const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp); + const int16x8_t clamp = vminq_s16(projection, projection_mv_clamp); + return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp)); +} + +inline int8x8_t Project_NEON(const int16x8_t delta, const int16x8_t dst_sign) { + // Add 63 to negative delta so that it shifts towards zero. + const int16x8_t delta_sign = vshrq_n_s16(delta, 15); + const uint16x8_t delta_u = vreinterpretq_u16_s16(delta); + const uint16x8_t delta_sign_u = vreinterpretq_u16_s16(delta_sign); + const uint16x8_t delta_adjust_u = vsraq_n_u16(delta_u, delta_sign_u, 10); + const int16x8_t delta_adjust = vreinterpretq_s16_u16(delta_adjust_u); + const int16x8_t offset0 = vshrq_n_s16(delta_adjust, 6); + const int16x8_t offset1 = veorq_s16(offset0, dst_sign); + const int16x8_t offset2 = vsubq_s16(offset1, dst_sign); + return vqmovn_s16(offset2); +} + +inline void GetPosition( + const int8x8x2_t division_table, const MotionVector* const mv, + const int numerator, const int x8_start, const int x8_end, const int x8, + const int8x8_t r_offsets, const int8x8_t source_reference_type8, + const int8x8_t skip_r, const int8x8_t y8_floor8, const int8x8_t y8_ceiling8, + const int16x8_t d_sign, const int delta, int8x8_t* const r, + int8x8_t* const position_y8, int8x8_t* const position_x8, + int64_t* const skip_64, int32x4_t mvs[2]) { + const auto* const mv_int = reinterpret_cast<const int32_t*>(mv + x8); + *r = vtbl1_s8(r_offsets, source_reference_type8); + const int16x8_t denorm = LoadDivision(division_table, source_reference_type8); + int16x8_t projection_mv[2]; + mvs[0] = vld1q_s32(mv_int + 0); + mvs[1] = vld1q_s32(mv_int + 4); + // Deinterlace x and y components + const int16x8_t mv0 = vreinterpretq_s16_s32(mvs[0]); + const int16x8_t mv1 = vreinterpretq_s16_s32(mvs[1]); + const int16x8x2_t mv_yx = vuzpq_s16(mv0, mv1); + // numerator could be 0. + projection_mv[0] = MvProjectionClip(mv_yx.val[0], denorm, numerator); + projection_mv[1] = MvProjectionClip(mv_yx.val[1], denorm, numerator); + // Do not update the motion vector if the block position is not valid or + // if position_x8 is outside the current range of x8_start and x8_end. + // Note that position_y8 will always be within the range of y8_start and + // y8_end. + // After subtracting the base, valid projections are within 8-bit. + *position_y8 = Project_NEON(projection_mv[0], d_sign); + const int8x8_t position_x = Project_NEON(projection_mv[1], d_sign); + const int8x8_t k01234567 = vcreate_s8(uint64_t{0x0706050403020100}); + *position_x8 = vqadd_s8(position_x, k01234567); + const int8x16_t position_xy = vcombine_s8(*position_x8, *position_y8); + const int x8_floor = std::max( + x8_start - x8, delta - kProjectionMvMaxHorizontalOffset); // [-8, 8] + const int x8_ceiling = std::min( + x8_end - x8, delta + 8 + kProjectionMvMaxHorizontalOffset); // [0, 16] + const int8x8_t x8_floor8 = vdup_n_s8(x8_floor); + const int8x8_t x8_ceiling8 = vdup_n_s8(x8_ceiling); + const int8x16_t floor_xy = vcombine_s8(x8_floor8, y8_floor8); + const int8x16_t ceiling_xy = vcombine_s8(x8_ceiling8, y8_ceiling8); + const uint8x16_t underflow = vcltq_s8(position_xy, floor_xy); + const uint8x16_t overflow = vcgeq_s8(position_xy, ceiling_xy); + const int8x16_t out = vreinterpretq_s8_u8(vorrq_u8(underflow, overflow)); + const int8x8_t skip_low = vorr_s8(skip_r, vget_low_s8(out)); + const int8x8_t skip = vorr_s8(skip_low, vget_high_s8(out)); + *skip_64 = vget_lane_s64(vreinterpret_s64_s8(skip), 0); +} + +template <int idx> +inline void Store(const int16x8_t position, const int8x8_t reference_offset, + const int32x4_t mv, int8_t* dst_reference_offset, + MotionVector* dst_mv) { + const ptrdiff_t offset = vgetq_lane_s16(position, idx); + auto* const d_mv = reinterpret_cast<int32_t*>(&dst_mv[offset]); + vst1q_lane_s32(d_mv, mv, idx & 3); + vst1_lane_s8(&dst_reference_offset[offset], reference_offset, idx); +} + +template <int idx> +inline void CheckStore(const int8_t* skips, const int16x8_t position, + const int8x8_t reference_offset, const int32x4_t mv, + int8_t* dst_reference_offset, MotionVector* dst_mv) { + if (skips[idx] == 0) { + Store<idx>(position, reference_offset, mv, dst_reference_offset, dst_mv); + } +} + +// 7.9.2. +void MotionFieldProjectionKernel_NEON(const ReferenceInfo& reference_info, + const int reference_to_current_with_sign, + const int dst_sign, const int y8_start, + const int y8_end, const int x8_start, + const int x8_end, + TemporalMotionField* const motion_field) { + const ptrdiff_t stride = motion_field->mv.columns(); + // The column range has to be offset by kProjectionMvMaxHorizontalOffset since + // coordinates in that range could end up being position_x8 because of + // projection. + const int adjusted_x8_start = + std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0); + const int adjusted_x8_end = std::min( + x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride)); + const int adjusted_x8_end8 = adjusted_x8_end & ~7; + const int leftover = adjusted_x8_end - adjusted_x8_end8; + const int8_t* const reference_offsets = + reference_info.relative_distance_to.data(); + const bool* const skip_references = reference_info.skip_references.data(); + const int16_t* const projection_divisions = + reference_info.projection_divisions.data(); + const ReferenceFrameType* source_reference_types = + &reference_info.motion_field_reference_frame[y8_start][0]; + const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0]; + int8_t* dst_reference_offset = motion_field->reference_offset[y8_start]; + MotionVector* dst_mv = motion_field->mv[y8_start]; + const int16x8_t d_sign = vdupq_n_s16(dst_sign); + + static_assert(sizeof(int8_t) == sizeof(bool), ""); + static_assert(sizeof(int8_t) == sizeof(ReferenceFrameType), ""); + static_assert(sizeof(int32_t) == sizeof(MotionVector), ""); + assert(dst_sign == 0 || dst_sign == -1); + assert(stride == motion_field->reference_offset.columns()); + assert((y8_start & 7) == 0); + assert((adjusted_x8_start & 7) == 0); + // The final position calculation is represented with int16_t. Valid + // position_y8 from its base is at most 7. After considering the horizontal + // offset which is at most |stride - 1|, we have the following assertion, + // which means this optimization works for frame width up to 32K (each + // position is a 8x8 block). + assert(8 * stride <= 32768); + const int8x8_t skip_reference = + vld1_s8(reinterpret_cast<const int8_t*>(skip_references)); + const int8x8_t r_offsets = vld1_s8(reference_offsets); + const int8x16_t table = vreinterpretq_s8_s16(vld1q_s16(projection_divisions)); + int8x8x2_t division_table; + division_table.val[0] = vget_low_s8(table); + division_table.val[1] = vget_high_s8(table); + + int y8 = y8_start; + do { + const int y8_floor = (y8 & ~7) - y8; // [-7, 0] + const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8); // [1, 8] + const int8x8_t y8_floor8 = vdup_n_s8(y8_floor); + const int8x8_t y8_ceiling8 = vdup_n_s8(y8_ceiling); + int x8; + + for (x8 = adjusted_x8_start; x8 < adjusted_x8_end8; x8 += 8) { + const int8x8_t source_reference_type8 = + vld1_s8(reinterpret_cast<const int8_t*>(source_reference_types + x8)); + const int8x8_t skip_r = vtbl1_s8(skip_reference, source_reference_type8); + const int64_t early_skip = vget_lane_s64(vreinterpret_s64_s8(skip_r), 0); + // Early termination #1 if all are skips. Chance is typically ~30-40%. + if (early_skip == -1) continue; + int64_t skip_64; + int8x8_t r, position_x8, position_y8; + int32x4_t mvs[2]; + GetPosition(division_table, mv, reference_to_current_with_sign, x8_start, + x8_end, x8, r_offsets, source_reference_type8, skip_r, + y8_floor8, y8_ceiling8, d_sign, 0, &r, &position_y8, + &position_x8, &skip_64, mvs); + // Early termination #2 if all are skips. + // Chance is typically ~15-25% after Early termination #1. + if (skip_64 == -1) continue; + const int16x8_t p_y = vmovl_s8(position_y8); + const int16x8_t p_x = vmovl_s8(position_x8); + const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride); + const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8)); + if (skip_64 == 0) { + // Store all. Chance is typically ~70-85% after Early termination #2. + Store<0>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv); + } else { + // Check and store each. + // Chance is typically ~15-30% after Early termination #2. + // The compiler is smart enough to not create the local buffer skips[]. + int8_t skips[8]; + memcpy(skips, &skip_64, sizeof(skips)); + CheckStore<0>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + } + } + + // The following leftover processing cannot be moved out of the do...while + // loop. Doing so may change the result storing orders of the same position. + if (leftover > 0) { + // Use SIMD only when leftover is at least 4, and there are at least 8 + // elements in a row. + if (leftover >= 4 && adjusted_x8_start < adjusted_x8_end8) { + // Process the last 8 elements to avoid loading invalid memory. Some + // elements may have been processed in the above loop, which is OK. + const int delta = 8 - leftover; + x8 = adjusted_x8_end - 8; + const int8x8_t source_reference_type8 = vld1_s8( + reinterpret_cast<const int8_t*>(source_reference_types + x8)); + const int8x8_t skip_r = + vtbl1_s8(skip_reference, source_reference_type8); + const int64_t early_skip = + vget_lane_s64(vreinterpret_s64_s8(skip_r), 0); + // Early termination #1 if all are skips. + if (early_skip != -1) { + int64_t skip_64; + int8x8_t r, position_x8, position_y8; + int32x4_t mvs[2]; + GetPosition(division_table, mv, reference_to_current_with_sign, + x8_start, x8_end, x8, r_offsets, source_reference_type8, + skip_r, y8_floor8, y8_ceiling8, d_sign, delta, &r, + &position_y8, &position_x8, &skip_64, mvs); + // Early termination #2 if all are skips. + if (skip_64 != -1) { + const int16x8_t p_y = vmovl_s8(position_y8); + const int16x8_t p_x = vmovl_s8(position_x8); + const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride); + const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8)); + // Store up to 7 elements since leftover is at most 7. + if (skip_64 == 0) { + // Store all. + Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv); + } else { + // Check and store each. + // The compiler is smart enough to not create the local buffer + // skips[]. + int8_t skips[8]; + memcpy(skips, &skip_64, sizeof(skips)); + CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + } + } + } + } else { + for (; x8 < adjusted_x8_end; ++x8) { + const int source_reference_type = source_reference_types[x8]; + if (skip_references[source_reference_type]) continue; + MotionVector projection_mv; + // reference_to_current_with_sign could be 0. + GetMvProjection(mv[x8], reference_to_current_with_sign, + projection_divisions[source_reference_type], + &projection_mv); + // Do not update the motion vector if the block position is not valid + // or if position_x8 is outside the current range of x8_start and + // x8_end. Note that position_y8 will always be within the range of + // y8_start and y8_end. + const int position_y8 = Project(0, projection_mv.mv[0], dst_sign); + if (position_y8 < y8_floor || position_y8 >= y8_ceiling) continue; + const int x8_base = x8 & ~7; + const int x8_floor = + std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset); + const int x8_ceiling = + std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset); + const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign); + if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue; + dst_mv[position_y8 * stride + position_x8] = mv[x8]; + dst_reference_offset[position_y8 * stride + position_x8] = + reference_offsets[source_reference_type]; + } + } + } + + source_reference_types += stride; + mv += stride; + dst_reference_offset += stride; + dst_mv += stride; + } while (++y8 < y8_end); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON; +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON; +} +#endif + +} // namespace + +void MotionFieldProjectionInit_NEON() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void MotionFieldProjectionInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/motion_field_projection_neon.h b/src/dsp/arm/motion_field_projection_neon.h new file mode 100644 index 0000000..41ab6a6 --- /dev/null +++ b/src/dsp/arm/motion_field_projection_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::motion_field_projection_kernel. This function is not +// thread-safe. +void MotionFieldProjectionInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_ diff --git a/src/dsp/arm/motion_vector_search_neon.cc b/src/dsp/arm/motion_vector_search_neon.cc new file mode 100644 index 0000000..8a403a6 --- /dev/null +++ b/src/dsp/arm/motion_vector_search_neon.cc @@ -0,0 +1,267 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_vector_search.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator, + const int32x4_t numerator) { + const int32x4_t m0 = vmull_s16(mv, denominator); + const int32x4_t m = vmulq_s32(m0, numerator); + // Add the sign (0 or -1) to round towards zero. + const int32x4_t add_sign = vsraq_n_s32(m, m, 31); + return vqrshrn_n_s32(add_sign, 14); +} + +inline int16x4_t MvProjectionCompound(const int16x4_t mv, + const int temporal_reference_offsets, + const int reference_offsets[2]) { + const int16x4_t denominator = + vdup_n_s16(kProjectionMvDivisionLookup[temporal_reference_offsets]); + const int32x2_t offset = vld1_s32(reference_offsets); + const int32x2x2_t offsets = vzip_s32(offset, offset); + const int32x4_t numerator = vcombine_s32(offsets.val[0], offsets.val[1]); + return MvProjection(mv, denominator, numerator); +} + +inline int16x8_t ProjectionClip(const int16x4_t mv0, const int16x4_t mv1) { + const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp); + const int16x8_t mv = vcombine_s16(mv0, mv1); + const int16x8_t clamp = vminq_s16(mv, projection_mv_clamp); + return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp)); +} + +inline int16x8_t MvProjectionCompoundClip( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, + const int reference_offsets[2]) { + const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs); + const int32x2_t temporal_mv = vld1_s32(tmvs); + const int16x4_t tmv0 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 0)); + const int16x4_t tmv1 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 1)); + const int16x4_t mv0 = MvProjectionCompound( + tmv0, temporal_reference_offsets[0], reference_offsets); + const int16x4_t mv1 = MvProjectionCompound( + tmv1, temporal_reference_offsets[1], reference_offsets); + return ProjectionClip(mv0, mv1); +} + +inline int16x8_t MvProjectionSingleClip( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, const int reference_offset, + int16x4_t* const lookup) { + const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs); + const int16x8_t temporal_mv = vld1q_s16(tmvs); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[0]], *lookup, 0); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[1]], *lookup, 1); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[2]], *lookup, 2); + *lookup = vld1_lane_s16( + &kProjectionMvDivisionLookup[temporal_reference_offsets[3]], *lookup, 3); + const int16x4x2_t denominator = vzip_s16(*lookup, *lookup); + const int16x4_t tmv0 = vget_low_s16(temporal_mv); + const int16x4_t tmv1 = vget_high_s16(temporal_mv); + const int32x4_t numerator = vdupq_n_s32(reference_offset); + const int16x4_t mv0 = MvProjection(tmv0, denominator.val[0], numerator); + const int16x4_t mv1 = MvProjection(tmv1, denominator.val[1], numerator); + return ProjectionClip(mv0, mv1); +} + +inline void LowPrecision(const int16x8_t mv, void* const candidate_mvs) { + const int16x8_t kRoundDownMask = vdupq_n_s16(1); + const uint16x8_t mvu = vreinterpretq_u16_s16(mv); + const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15)); + const int16x8_t mv1 = vbicq_s16(mv0, kRoundDownMask); + vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv1); +} + +inline void ForceInteger(const int16x8_t mv, void* const candidate_mvs) { + const int16x8_t kRoundDownMask = vdupq_n_s16(7); + const uint16x8_t mvu = vreinterpretq_u16_s16(mv); + const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15)); + const int16x8_t mv1 = vaddq_s16(mv0, vdupq_n_s16(3)); + const int16x8_t mv2 = vbicq_s16(mv1, kRoundDownMask); + vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv2); +} + +void MvProjectionCompoundLowPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int loop_count = (count + 1) >> 1; + do { + const int16x8_t mv = MvProjectionCompoundClip( + temporal_mvs, temporal_reference_offsets, offsets); + LowPrecision(mv, candidate_mvs); + temporal_mvs += 2; + temporal_reference_offsets += 2; + candidate_mvs += 2; + } while (--loop_count); +} + +void MvProjectionCompoundForceInteger_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int loop_count = (count + 1) >> 1; + do { + const int16x8_t mv = MvProjectionCompoundClip( + temporal_mvs, temporal_reference_offsets, offsets); + ForceInteger(mv, candidate_mvs); + temporal_mvs += 2; + temporal_reference_offsets += 2; + candidate_mvs += 2; + } while (--loop_count); +} + +void MvProjectionCompoundHighPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int loop_count = (count + 1) >> 1; + do { + const int16x8_t mv = MvProjectionCompoundClip( + temporal_mvs, temporal_reference_offsets, offsets); + vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv); + temporal_mvs += 2; + temporal_reference_offsets += 2; + candidate_mvs += 2; + } while (--loop_count); +} + +void MvProjectionSingleLowPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int loop_count = (count + 3) >> 2; + int16x4_t lookup = vdup_n_s16(0); + do { + const int16x8_t mv = MvProjectionSingleClip( + temporal_mvs, temporal_reference_offsets, reference_offset, &lookup); + LowPrecision(mv, candidate_mvs); + temporal_mvs += 4; + temporal_reference_offsets += 4; + candidate_mvs += 4; + } while (--loop_count); +} + +void MvProjectionSingleForceInteger_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int loop_count = (count + 3) >> 2; + int16x4_t lookup = vdup_n_s16(0); + do { + const int16x8_t mv = MvProjectionSingleClip( + temporal_mvs, temporal_reference_offsets, reference_offset, &lookup); + ForceInteger(mv, candidate_mvs); + temporal_mvs += 4; + temporal_reference_offsets += 4; + candidate_mvs += 4; + } while (--loop_count); +} + +void MvProjectionSingleHighPrecision_NEON( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int loop_count = (count + 3) >> 2; + int16x4_t lookup = vdup_n_s16(0); + do { + const int16x8_t mv = MvProjectionSingleClip( + temporal_mvs, temporal_reference_offsets, reference_offset, &lookup); + vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv); + temporal_mvs += 4; + temporal_reference_offsets += 4; + candidate_mvs += 4; + } while (--loop_count); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON; +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON; +} +#endif + +} // namespace + +void MotionVectorSearchInit_NEON() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void MotionVectorSearchInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/motion_vector_search_neon.h b/src/dsp/arm/motion_vector_search_neon.h new file mode 100644 index 0000000..19b4519 --- /dev/null +++ b/src/dsp/arm/motion_vector_search_neon.h @@ -0,0 +1,39 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This +// function is not thread-safe. +void MotionVectorSearchInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON + +#define LIBGAV1_Dsp8bpp_MotionVectorSearch LIBGAV1_CPU_NEON + +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_ diff --git a/src/dsp/arm/obmc_neon.cc b/src/dsp/arm/obmc_neon.cc new file mode 100644 index 0000000..66ad663 --- /dev/null +++ b/src/dsp/arm/obmc_neon.cc @@ -0,0 +1,392 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/obmc.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +#include "src/dsp/obmc.inc" + +inline void WriteObmcLine4(uint8_t* const pred, const uint8_t* const obmc_pred, + const uint8x8_t pred_mask, + const uint8x8_t obmc_pred_mask) { + const uint8x8_t pred_val = Load4(pred); + const uint8x8_t obmc_pred_val = Load4(obmc_pred); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + StoreLo4(pred, result); +} + +template <bool from_left> +inline void OverlapBlend2xH_NEON(uint8_t* const prediction, + const ptrdiff_t prediction_stride, + const int height, + const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8_t* obmc_pred = obmc_prediction; + uint8x8_t pred_mask; + uint8x8_t obmc_pred_mask; + int compute_height; + const int mask_offset = height - 2; + if (from_left) { + pred_mask = Load2(kObmcMask); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + compute_height = height; + } else { + // Weights for the last line are all 64, which is a no-op. + compute_height = height - 1; + } + uint8x8_t pred_val = vdup_n_u8(0); + uint8x8_t obmc_pred_val = vdup_n_u8(0); + int y = 0; + do { + if (!from_left) { + pred_mask = vdup_n_u8(kObmcMask[mask_offset + y]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + } + pred_val = Load2<0>(pred, pred_val); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + obmc_pred_val = Load2<0>(obmc_pred, obmc_pred_val); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + Store2<0>(pred, result); + + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y != compute_height); +} + +inline void OverlapBlendFromLeft4xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8x8_t pred_mask = Load4(kObmcMask + 2); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + int y = 0; + do { + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + y += 2; + } while (y != height); +} + +inline void OverlapBlendFromLeft8xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8x8_t pred_mask = vld1_u8(kObmcMask + 6); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + int y = 0; + do { + const uint8x8_t pred_val = vld1_u8(pred); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + + vst1_u8(pred, result); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y != height); +} + +void OverlapBlendFromLeft_NEON(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint8_t*>(prediction); + const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction); + + if (width == 2) { + OverlapBlend2xH_NEON<true>(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 4) { + OverlapBlendFromLeft4xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 8) { + OverlapBlendFromLeft8xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + const uint8x16_t mask_inverter = vdupq_n_u8(64); + const uint8_t* mask = kObmcMask + width - 2; + int x = 0; + do { + pred = static_cast<uint8_t*>(prediction) + x; + obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x; + const uint8x16_t pred_mask = vld1q_u8(mask + x); + // 64 - mask + const uint8x16_t obmc_pred_mask = vsubq_u8(mask_inverter, pred_mask); + int y = 0; + do { + const uint8x16_t pred_val = vld1q_u8(pred); + const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred); + const uint16x8_t weighted_pred_lo = + vmull_u8(vget_low_u8(pred_mask), vget_low_u8(pred_val)); + const uint8x8_t result_lo = + vrshrn_n_u16(vmlal_u8(weighted_pred_lo, vget_low_u8(obmc_pred_mask), + vget_low_u8(obmc_pred_val)), + 6); + const uint16x8_t weighted_pred_hi = + vmull_u8(vget_high_u8(pred_mask), vget_high_u8(pred_val)); + const uint8x8_t result_hi = + vrshrn_n_u16(vmlal_u8(weighted_pred_hi, vget_high_u8(obmc_pred_mask), + vget_high_u8(obmc_pred_val)), + 6); + vst1q_u8(pred, vcombine_u8(result_lo, result_hi)); + + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y < height); + x += 16; + } while (x < width); +} + +inline void OverlapBlendFromTop4x4_NEON(uint8_t* const prediction, + const ptrdiff_t prediction_stride, + const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride, + const int height) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + uint8x8_t pred_mask = vdup_n_u8(kObmcMask[height - 2]); + const uint8x8_t mask_inverter = vdup_n_u8(64); + uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + if (height == 2) { + return; + } + + pred_mask = vdup_n_u8(kObmcMask[3]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(kObmcMask[4]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); +} + +inline void OverlapBlendFromTop4xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + if (height < 8) { + OverlapBlendFromTop4x4_NEON(prediction, prediction_stride, obmc_prediction, + obmc_prediction_stride, height); + return; + } + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const uint8_t* mask = kObmcMask + height - 2; + const uint8x8_t mask_inverter = vdup_n_u8(64); + int y = 0; + // Compute 6 lines for height 8, or 12 lines for height 16. The remaining + // lines are unchanged as the corresponding mask value is 64. + do { + uint8x8_t pred_mask = vdup_n_u8(mask[y]); + uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 1]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 2]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 3]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 4]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + pred_mask = vdup_n_u8(mask[y + 5]); + obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + // Increment for the right mask index. + y += 6; + } while (y < height - 4); +} + +inline void OverlapBlendFromTop8xH_NEON( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const uint8x8_t mask_inverter = vdup_n_u8(64); + const uint8_t* mask = kObmcMask + height - 2; + const int compute_height = height - (height >> 2); + int y = 0; + do { + const uint8x8_t pred_mask = vdup_n_u8(mask[y]); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + const uint8x8_t pred_val = vld1_u8(pred); + const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val); + const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred); + const uint8x8_t result = + vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6); + + vst1_u8(pred, result); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y != compute_height); +} + +void OverlapBlendFromTop_NEON(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint8_t*>(prediction); + const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction); + + if (width == 2) { + OverlapBlend2xH_NEON<false>(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 4) { + OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + + if (width == 8) { + OverlapBlendFromTop8xH_NEON(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + + const uint8_t* mask = kObmcMask + height - 2; + const uint8x8_t mask_inverter = vdup_n_u8(64); + // Stop when mask value becomes 64. This is inferred for 4xH. + const int compute_height = height - (height >> 2); + int y = 0; + do { + const uint8x8_t pred_mask = vdup_n_u8(mask[y]); + // 64 - mask + const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask); + int x = 0; + do { + const uint8x16_t pred_val = vld1q_u8(pred + x); + const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred + x); + const uint16x8_t weighted_pred_lo = + vmull_u8(pred_mask, vget_low_u8(pred_val)); + const uint8x8_t result_lo = + vrshrn_n_u16(vmlal_u8(weighted_pred_lo, obmc_pred_mask, + vget_low_u8(obmc_pred_val)), + 6); + const uint16x8_t weighted_pred_hi = + vmull_u8(pred_mask, vget_high_u8(pred_val)); + const uint8x8_t result_hi = + vrshrn_n_u16(vmlal_u8(weighted_pred_hi, obmc_pred_mask, + vget_high_u8(obmc_pred_val)), + 6); + vst1q_u8(pred + x, vcombine_u8(result_lo, result_hi)); + + x += 16; + } while (x < width); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y < compute_height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON; + dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON; +} + +} // namespace + +void ObmcInit_NEON() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void ObmcInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/obmc_neon.h b/src/dsp/arm/obmc_neon.h new file mode 100644 index 0000000..d5c9d9c --- /dev/null +++ b/src/dsp/arm/obmc_neon.h @@ -0,0 +1,38 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::obmc_blend. This function is not thread-safe. +void ObmcInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +// If NEON is enabled, signal the NEON implementation should be used. +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_ diff --git a/src/dsp/arm/super_res_neon.cc b/src/dsp/arm/super_res_neon.cc new file mode 100644 index 0000000..1680450 --- /dev/null +++ b/src/dsp/arm/super_res_neon.cc @@ -0,0 +1,166 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/super_res.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { + +namespace low_bitdepth { +namespace { + +void SuperResCoefficients_NEON(const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const coefficients) { + auto* dst = static_cast<uint8_t*>(coefficients); + int subpixel_x = initial_subpixel_x; + int x = RightShiftWithCeiling(upscaled_width, 3); + do { + uint8x8_t filter[8]; + uint8x16_t d[kSuperResFilterTaps / 2]; + for (int i = 0; i < 8; ++i, subpixel_x += step) { + filter[i] = + vld1_u8(kUpscaleFilterUnsigned[(subpixel_x & kSuperResScaleMask) >> + kSuperResExtraBits]); + } + Transpose8x8(filter, d); + vst1q_u8(dst, d[0]); + dst += 16; + vst1q_u8(dst, d[1]); + dst += 16; + vst1q_u8(dst, d[2]); + dst += 16; + vst1q_u8(dst, d[3]); + dst += 16; + } while (--x != 0); +} + +// Maximum sum of positive taps: 171 = 7 + 86 + 71 + 7 +// Maximum sum: 255*171 == 0xAA55 +// The sum is clipped to [0, 255], so adding all positive and then +// subtracting all negative with saturation is sufficient. +// 0 1 2 3 4 5 6 7 +// tap sign: - + - + + - + - +inline uint8x8_t SuperRes(const uint8x8_t src[kSuperResFilterTaps], + const uint8_t** coefficients) { + uint8x16_t f[kSuperResFilterTaps / 2]; + for (int i = 0; i < kSuperResFilterTaps / 2; ++i, *coefficients += 16) { + f[i] = vld1q_u8(*coefficients); + } + uint16x8_t res = vmull_u8(src[1], vget_high_u8(f[0])); + res = vmlal_u8(res, src[3], vget_high_u8(f[1])); + res = vmlal_u8(res, src[4], vget_low_u8(f[2])); + res = vmlal_u8(res, src[6], vget_low_u8(f[3])); + uint16x8_t temp = vmull_u8(src[0], vget_low_u8(f[0])); + temp = vmlal_u8(temp, src[2], vget_low_u8(f[1])); + temp = vmlal_u8(temp, src[5], vget_high_u8(f[2])); + temp = vmlal_u8(temp, src[7], vget_high_u8(f[3])); + res = vqsubq_u16(res, temp); + return vqrshrn_n_u16(res, kFilterBits); +} + +void SuperRes_NEON(const void* const coefficients, void* const source, + const ptrdiff_t stride, const int height, + const int downscaled_width, const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const dest) { + auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps); + auto* dst = static_cast<uint8_t*>(dest); + int y = height; + do { + const auto* filter = static_cast<const uint8_t*>(coefficients); + uint8_t* dst_ptr = dst; + ExtendLine<uint8_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width, + kSuperResHorizontalBorder, kSuperResHorizontalBorder); + int subpixel_x = initial_subpixel_x; + uint8x8_t sr[8]; + uint8x16_t s[8]; + int x = RightShiftWithCeiling(upscaled_width, 4); + // The below code calculates up to 15 extra upscaled + // pixels which will over-read up to 15 downscaled pixels in the end of each + // row. kSuperResHorizontalBorder accounts for this. + do { + for (int i = 0; i < 8; ++i, subpixel_x += step) { + sr[i] = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]); + } + for (int i = 0; i < 8; ++i, subpixel_x += step) { + const uint8x8_t s_hi = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]); + s[i] = vcombine_u8(sr[i], s_hi); + } + Transpose8x16(s); + // Do not use loop for the following 8 instructions, since the compiler + // will generate redundant code. + sr[0] = vget_low_u8(s[0]); + sr[1] = vget_low_u8(s[1]); + sr[2] = vget_low_u8(s[2]); + sr[3] = vget_low_u8(s[3]); + sr[4] = vget_low_u8(s[4]); + sr[5] = vget_low_u8(s[5]); + sr[6] = vget_low_u8(s[6]); + sr[7] = vget_low_u8(s[7]); + const uint8x8_t d0 = SuperRes(sr, &filter); + // Do not use loop for the following 8 instructions, since the compiler + // will generate redundant code. + sr[0] = vget_high_u8(s[0]); + sr[1] = vget_high_u8(s[1]); + sr[2] = vget_high_u8(s[2]); + sr[3] = vget_high_u8(s[3]); + sr[4] = vget_high_u8(s[4]); + sr[5] = vget_high_u8(s[5]); + sr[6] = vget_high_u8(s[6]); + sr[7] = vget_high_u8(s[7]); + const uint8x8_t d1 = SuperRes(sr, &filter); + vst1q_u8(dst_ptr, vcombine_u8(d0, d1)); + dst_ptr += 16; + } while (--x != 0); + src += stride; + dst += stride; + } while (--y != 0); +} + +void Init8bpp() { + Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + dsp->super_res_coefficients = SuperResCoefficients_NEON; + dsp->super_res = SuperRes_NEON; +} + +} // namespace +} // namespace low_bitdepth + +void SuperResInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void SuperResInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/super_res_neon.h b/src/dsp/arm/super_res_neon.h new file mode 100644 index 0000000..f51785d --- /dev/null +++ b/src/dsp/arm/super_res_neon.h @@ -0,0 +1,37 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::super_res. This function is not thread-safe. +void SuperResInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_SuperRes LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_SuperResClip LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_ diff --git a/src/dsp/arm/warp_neon.cc b/src/dsp/arm/warp_neon.cc new file mode 100644 index 0000000..7a41998 --- /dev/null +++ b/src/dsp/arm/warp_neon.cc @@ -0,0 +1,453 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/warp.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <type_traits> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Number of extra bits of precision in warped filtering. +constexpr int kWarpedDiffPrecisionBits = 10; +constexpr int kFirstPassOffset = 1 << 14; +constexpr int kOffsetRemoval = + (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128; + +// Applies the horizontal filter to one source row and stores the result in +// |intermediate_result_row|. |intermediate_result_row| is a row in the 15x8 +// |intermediate_result| two-dimensional array. +// +// src_row_centered contains 16 "centered" samples of a source row. (We center +// the samples by subtracting 128 from the samples.) +void HorizontalFilter(const int sx4, const int16_t alpha, + const int8x16_t src_row_centered, + int16_t intermediate_result_row[8]) { + int sx = sx4 - MultiplyBy4(alpha); + int8x8_t filter[8]; + for (int x = 0; x < 8; ++x) { + const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + filter[x] = vld1_s8(kWarpedFilters8[offset]); + sx += alpha; + } + Transpose8x8(filter); + // Add kFirstPassOffset to ensure |sum| stays within uint16_t. + // Add 128 (offset) * 128 (filter sum) (also 1 << 14) to account for the + // centering of the source samples. These combined are 1 << 15 or -32768. + int16x8_t sum = + vdupq_n_s16(static_cast<int16_t>(kFirstPassOffset + 128 * 128)); + // Unrolled k = 0..7 loop. We need to manually unroll the loop because the + // third argument (an index value) to vextq_s8() must be a constant + // (immediate). src_row_window is a sliding window of length 8 into + // src_row_centered. + // k = 0. + int8x8_t src_row_window = vget_low_s8(src_row_centered); + sum = vmlal_s8(sum, filter[0], src_row_window); + // k = 1. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 1)); + sum = vmlal_s8(sum, filter[1], src_row_window); + // k = 2. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 2)); + sum = vmlal_s8(sum, filter[2], src_row_window); + // k = 3. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 3)); + sum = vmlal_s8(sum, filter[3], src_row_window); + // k = 4. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 4)); + sum = vmlal_s8(sum, filter[4], src_row_window); + // k = 5. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 5)); + sum = vmlal_s8(sum, filter[5], src_row_window); + // k = 6. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 6)); + sum = vmlal_s8(sum, filter[6], src_row_window); + // k = 7. + src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 7)); + sum = vmlal_s8(sum, filter[7], src_row_window); + // End of unrolled k = 0..7 loop. + // Due to the offset |sum| is guaranteed to be unsigned. + uint16x8_t sum_unsigned = vreinterpretq_u16_s16(sum); + sum_unsigned = vrshrq_n_u16(sum_unsigned, kInterRoundBitsHorizontal); + // After the shift |sum_unsigned| will fit into int16_t. + vst1q_s16(intermediate_result_row, vreinterpretq_s16_u16(sum_unsigned)); +} + +template <bool is_compound> +void Warp_NEON(const void* const source, const ptrdiff_t source_stride, + const int source_width, const int source_height, + const int* const warp_params, const int subsampling_x, + const int subsampling_y, const int block_start_x, + const int block_start_y, const int block_width, + const int block_height, const int16_t alpha, const int16_t beta, + const int16_t gamma, const int16_t delta, void* dest, + const ptrdiff_t dest_stride) { + constexpr int kRoundBitsVertical = + is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical; + union { + // Intermediate_result is the output of the horizontal filtering and + // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 - + // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t + // type so that we can multiply it by kWarpedFilters (which has signed + // values) using vmlal_s16(). + int16_t intermediate_result[15][8]; // 15 rows, 8 columns. + // In the simple special cases where the samples in each row are all the + // same, store one sample per row in a column vector. + int16_t intermediate_result_column[15]; + }; + + const auto* const src = static_cast<const uint8_t*>(source); + using DestType = + typename std::conditional<is_compound, int16_t, uint8_t>::type; + auto* dst = static_cast<DestType*>(dest); + + assert(block_width >= 8); + assert(block_height >= 8); + + // Warp process applies for each 8x8 block. + int start_y = block_start_y; + do { + int start_x = block_start_x; + do { + const int src_x = (start_x + 4) << subsampling_x; + const int src_y = (start_y + 4) << subsampling_y; + const int dst_x = + src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0]; + const int dst_y = + src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1]; + const int x4 = dst_x >> subsampling_x; + const int y4 = dst_y >> subsampling_y; + const int ix4 = x4 >> kWarpedModelPrecisionBits; + const int iy4 = y4 >> kWarpedModelPrecisionBits; + // A prediction block may fall outside the frame's boundaries. If a + // prediction block is calculated using only samples outside the frame's + // boundary, the filtering can be simplified. We can divide the plane + // into several regions and handle them differently. + // + // | | + // 1 | 3 | 1 + // | | + // -------+-----------+------- + // |***********| + // 2 |*****4*****| 2 + // |***********| + // -------+-----------+------- + // | | + // 1 | 3 | 1 + // | | + // + // At the center, region 4 represents the frame and is the general case. + // + // In regions 1 and 2, the prediction block is outside the frame's + // boundary horizontally. Therefore the horizontal filtering can be + // simplified. Furthermore, in the region 1 (at the four corners), the + // prediction is outside the frame's boundary both horizontally and + // vertically, so we get a constant prediction block. + // + // In region 3, the prediction block is outside the frame's boundary + // vertically. Unfortunately because we apply the horizontal filters + // first, by the time we apply the vertical filters, they no longer see + // simple inputs. So the only simplification is that all the rows are + // the same, but we still need to apply all the horizontal and vertical + // filters. + + // Check for two simple special cases, where the horizontal filter can + // be significantly simplified. + // + // In general, for each row, the horizontal filter is calculated as + // follows: + // for (int x = -4; x < 4; ++x) { + // const int offset = ...; + // int sum = first_pass_offset; + // for (int k = 0; k < 8; ++k) { + // const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1); + // sum += kWarpedFilters[offset][k] * src_row[column]; + // } + // ... + // } + // The column index before clipping, ix4 + x + k - 3, varies in the range + // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1 + // or ix4 + 7 <= 0, then all the column indexes are clipped to the same + // border index (source_width - 1 or 0, respectively). Then for each x, + // the inner for loop of the horizontal filter is reduced to multiplying + // the border pixel by the sum of the filter coefficients. + if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) { + // Regions 1 and 2. + // Points to the left or right border of the first row of |src|. + const uint8_t* first_row_border = + (ix4 + 7 <= 0) ? src : src + source_width - 1; + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) { + // Region 1. + // Every sample used to calculate the prediction block has the same + // value. So the whole prediction block has the same value. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const uint8_t row_border_pixel = + first_row_border[row * source_stride]; + + DestType* dst_row = dst + start_x - block_start_x; + for (int y = 0; y < 8; ++y) { + if (is_compound) { + const int16x8_t sum = + vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical - + kRoundBitsVertical)); + vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum); + } else { + memset(dst_row, row_border_pixel, 8); + } + dst_row += dest_stride; + } + // End of region 1. Continue the |start_x| do-while loop. + start_x += 8; + continue; + } + + // Region 2. + // Horizontal filter. + // The input values in this region are generated by extending the border + // which makes them identical in the horizontal direction. This + // computation could be inlined in the vertical pass but most + // implementations will need a transpose of some sort. + // It is not necessary to use the offset values here because the + // horizontal pass is a simple shift and the vertical pass will always + // require using 32 bits. + for (int y = -7; y < 8; ++y) { + // We may over-read up to 13 pixels above the top source row, or up + // to 13 pixels below the bottom source row. This is proved in + // warp.cc. + const int row = iy4 + y; + int sum = first_row_border[row * source_stride]; + sum <<= (kFilterBits - kInterRoundBitsHorizontal); + intermediate_result_column[y + 7] = sum; + } + // Vertical filter. + DestType* dst_row = dst + start_x - block_start_x; + int sy4 = + (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); +#if defined(__aarch64__) + const int16x8_t intermediate = + vld1q_s16(&intermediate_result_column[y]); + int16_t tmp[8]; + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]); + const int32x4_t product_low = + vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate)); + const int32x4_t product_high = + vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate)); + // vaddvq_s32 is only available on __aarch64__. + const int32_t sum = + vaddvq_s32(product_low) + vaddvq_s32(product_high); + const int16_t sum_descale = + RightShiftWithRounding(sum, kRoundBitsVertical); + if (is_compound) { + dst_row[x] = sum_descale; + } else { + tmp[x] = sum_descale; + } + sy += gamma; + } + if (!is_compound) { + const int16x8_t sum = vld1q_s16(tmp); + vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum)); + } +#else // !defined(__aarch64__) + int16x8_t filter[8]; + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + filter[x] = vld1q_s16(kWarpedFilters[offset]); + sy += gamma; + } + Transpose8x8(filter); + int32x4_t sum_low = vdupq_n_s32(0); + int32x4_t sum_high = sum_low; + for (int k = 0; k < 8; ++k) { + const int16_t intermediate = intermediate_result_column[y + k]; + sum_low = + vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate); + sum_high = + vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate); + } + const int16x8_t sum = + vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical), + vrshrn_n_s32(sum_high, kRoundBitsVertical)); + if (is_compound) { + vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum); + } else { + vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum)); + } +#endif // defined(__aarch64__) + dst_row += dest_stride; + sy4 += delta; + } + // End of region 2. Continue the |start_x| do-while loop. + start_x += 8; + continue; + } + + // Regions 3 and 4. + // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0. + + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) { + // Region 3. + // Horizontal filter. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const uint8_t* const src_row = src + row * source_stride; + // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also + // read but is ignored. + // + // NOTE: This may read up to 13 bytes before src_row[0] or up to 14 + // bytes after src_row[source_width - 1]. We assume the source frame + // has left and right borders of at least 13 bytes that extend the + // frame boundary pixels. We also assume there is at least one extra + // padding byte after the right border of the last source row. + const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]); + // Convert src_row_v to int8 (subtract 128). + const int8x16_t src_row_centered = + vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128))); + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + HorizontalFilter(sx4, alpha, src_row_centered, + intermediate_result[y + 7]); + sx4 += beta; + } + } else { + // Region 4. + // Horizontal filter. + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + // We may over-read up to 13 pixels above the top source row, or up + // to 13 pixels below the bottom source row. This is proved in + // warp.cc. + const int row = iy4 + y; + const uint8_t* const src_row = src + row * source_stride; + // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also + // read but is ignored. + // + // NOTE: This may read up to 13 bytes before src_row[0] or up to 14 + // bytes after src_row[source_width - 1]. We assume the source frame + // has left and right borders of at least 13 bytes that extend the + // frame boundary pixels. We also assume there is at least one extra + // padding byte after the right border of the last source row. + const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]); + // Convert src_row_v to int8 (subtract 128). + const int8x16_t src_row_centered = + vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128))); + HorizontalFilter(sx4, alpha, src_row_centered, + intermediate_result[y + 7]); + sx4 += beta; + } + } + + // Regions 3 and 4. + // Vertical filter. + DestType* dst_row = dst + start_x - block_start_x; + int sy4 = + (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); + int16x8_t filter[8]; + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + filter[x] = vld1q_s16(kWarpedFilters[offset]); + sy += gamma; + } + Transpose8x8(filter); + int32x4_t sum_low = vdupq_n_s32(-kOffsetRemoval); + int32x4_t sum_high = sum_low; + for (int k = 0; k < 8; ++k) { + const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]); + sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]), + vget_low_s16(intermediate)); + sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]), + vget_high_s16(intermediate)); + } + const int16x8_t sum = + vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical), + vrshrn_n_s32(sum_high, kRoundBitsVertical)); + if (is_compound) { + vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum); + } else { + vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum)); + } + dst_row += dest_stride; + sy4 += delta; + } + start_x += 8; + } while (start_x < block_start_x + block_width); + dst += 8 * dest_stride; + start_y += 8; + } while (start_y < block_start_y + block_height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->warp = Warp_NEON</*is_compound=*/false>; + dsp->warp_compound = Warp_NEON</*is_compound=*/true>; +} + +} // namespace +} // namespace low_bitdepth + +void WarpInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_ENABLE_NEON +namespace libgav1 { +namespace dsp { + +void WarpInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/warp_neon.h b/src/dsp/arm/warp_neon.h new file mode 100644 index 0000000..dbcaa23 --- /dev/null +++ b/src/dsp/arm/warp_neon.h @@ -0,0 +1,37 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::warp. This function is not thread-safe. +void WarpInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WarpCompound LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_ diff --git a/src/dsp/arm/weight_mask_neon.cc b/src/dsp/arm/weight_mask_neon.cc new file mode 100644 index 0000000..49d3be0 --- /dev/null +++ b/src/dsp/arm/weight_mask_neon.cc @@ -0,0 +1,463 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/arm/weight_mask_neon.h" + +#include "src/dsp/weight_mask.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_ENABLE_NEON + +#include <arm_neon.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/arm/common_neon.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +constexpr int kRoundingBits8bpp = 4; + +template <bool mask_is_inverse> +inline void WeightMask8_NEON(const int16_t* prediction_0, + const int16_t* prediction_1, uint8_t* mask) { + const int16x8_t pred_0 = vld1q_s16(prediction_0); + const int16x8_t pred_1 = vld1q_s16(prediction_1); + const uint8x8_t difference_offset = vdup_n_u8(38); + const uint8x8_t mask_ceiling = vdup_n_u8(64); + const uint16x8_t difference = vrshrq_n_u16( + vreinterpretq_u16_s16(vabdq_s16(pred_0, pred_1)), kRoundingBits8bpp); + const uint8x8_t adjusted_difference = + vqadd_u8(vqshrn_n_u16(difference, 4), difference_offset); + const uint8x8_t mask_value = vmin_u8(adjusted_difference, mask_ceiling); + if (mask_is_inverse) { + const uint8x8_t inverted_mask_value = vsub_u8(mask_ceiling, mask_value); + vst1_u8(mask, inverted_mask_value); + } else { + vst1_u8(mask, mask_value); + } +} + +#define WEIGHT8_WITHOUT_STRIDE \ + WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask) + +#define WEIGHT8_AND_STRIDE \ + WEIGHT8_WITHOUT_STRIDE; \ + pred_0 += 8; \ + pred_1 += 8; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask8x8_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = 0; + do { + WEIGHT8_AND_STRIDE; + } while (++y < 7); + WEIGHT8_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask8x16_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + } while (++y3 < 5); + WEIGHT8_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask8x32_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + } while (++y5 < 6); + WEIGHT8_AND_STRIDE; + WEIGHT8_WITHOUT_STRIDE; +} + +#define WEIGHT16_WITHOUT_STRIDE \ + WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8) + +#define WEIGHT16_AND_STRIDE \ + WEIGHT16_WITHOUT_STRIDE; \ + pred_0 += 16; \ + pred_1 += 16; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask16x8_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = 0; + do { + WEIGHT16_AND_STRIDE; + } while (++y < 7); + WEIGHT16_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask16x16_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + } while (++y3 < 5); + WEIGHT16_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask16x32_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + } while (++y5 < 6); + WEIGHT16_AND_STRIDE; + WEIGHT16_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask16x64_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + } while (++y3 < 21); + WEIGHT16_WITHOUT_STRIDE; +} + +#define WEIGHT32_WITHOUT_STRIDE \ + WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24) + +#define WEIGHT32_AND_STRIDE \ + WEIGHT32_WITHOUT_STRIDE; \ + pred_0 += 32; \ + pred_1 += 32; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask32x8_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask32x16_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + } while (++y3 < 5); + WEIGHT32_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask32x32_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + } while (++y5 < 6); + WEIGHT32_AND_STRIDE; + WEIGHT32_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask32x64_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + } while (++y3 < 21); + WEIGHT32_WITHOUT_STRIDE; +} + +#define WEIGHT64_WITHOUT_STRIDE \ + WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \ + WeightMask8_NEON<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56) + +#define WEIGHT64_AND_STRIDE \ + WEIGHT64_WITHOUT_STRIDE; \ + pred_0 += 64; \ + pred_1 += 64; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask64x16_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y3 < 5); + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask64x32_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y5 < 6); + WEIGHT64_AND_STRIDE; + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask64x64_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y3 < 21); + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask64x128_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y3 < 42); + WEIGHT64_AND_STRIDE; + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask128x64_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + const ptrdiff_t adjusted_mask_stride = mask_stride - 64; + do { + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + } while (++y3 < 21); + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask128x128_NEON(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + const ptrdiff_t adjusted_mask_stride = mask_stride - 64; + do { + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + } while (++y3 < 42); + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; +} + +#define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \ + dsp->weight_mask[w_index][h_index][0] = \ + WeightMask##width##x##height##_NEON<0>; \ + dsp->weight_mask[w_index][h_index][1] = WeightMask##width##x##height##_NEON<1> +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0); + INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1); + INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2); + INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0); + INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1); + INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2); + INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3); + INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0); + INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1); + INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2); + INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3); + INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1); + INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2); + INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3); + INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4); + INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3); + INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4); +} + +} // namespace +} // namespace low_bitdepth + +void WeightMaskInit_NEON() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_ENABLE_NEON + +namespace libgav1 { +namespace dsp { + +void WeightMaskInit_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON diff --git a/src/dsp/arm/weight_mask_neon.h b/src/dsp/arm/weight_mask_neon.h new file mode 100644 index 0000000..b4749ec --- /dev/null +++ b/src/dsp/arm/weight_mask_neon.h @@ -0,0 +1,52 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_ +#define LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::weight_mask. This function is not thread-safe. +void WeightMaskInit_NEON(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_ENABLE_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_8x8 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_8x16 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_8x32 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_16x8 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_16x16 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_16x32 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_16x64 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_32x8 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_32x16 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_32x32 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_32x64 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_64x16 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_64x32 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_64x64 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_64x128 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_128x64 LIBGAV1_CPU_NEON +#define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_NEON +#endif // LIBGAV1_ENABLE_NEON + +#endif // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_ diff --git a/src/dsp/average_blend.cc b/src/dsp/average_blend.cc new file mode 100644 index 0000000..a59abb0 --- /dev/null +++ b/src/dsp/average_blend.cc @@ -0,0 +1,101 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/average_blend.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <type_traits> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +template <int bitdepth, typename Pixel> +void AverageBlend_C(const void* prediction_0, const void* prediction_1, + const int width, const int height, void* const dest, + const ptrdiff_t dest_stride) { + // 7.11.3.2 Rounding variables derivation process + // 2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7)) + constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4; + using PredType = + typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type; + const auto* pred_0 = static_cast<const PredType*>(prediction_0); + const auto* pred_1 = static_cast<const PredType*>(prediction_1); + auto* dst = static_cast<Pixel*>(dest); + const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel); + + int y = 0; + do { + int x = 0; + do { + // See warp.cc and convolve.cc for detailed prediction ranges. + int res = pred_0[x] + pred_1[x]; + res -= (bitdepth == 8) ? 0 : kCompoundOffset + kCompoundOffset; + dst[x] = static_cast<Pixel>( + Clip3(RightShiftWithRounding(res, inter_post_round_bits + 1), 0, + (1 << bitdepth) - 1)); + } while (++x < width); + + dst += dst_stride; + pred_0 += width; + pred_1 += width; + } while (++y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->average_blend = AverageBlend_C<8, uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_AverageBlend + dsp->average_blend = AverageBlend_C<8, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp10bpp_AverageBlend + dsp->average_blend = AverageBlend_C<10, uint16_t>; +#endif +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_AverageBlend + dsp->average_blend = AverageBlend_C<10, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void AverageBlendInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/average_blend.h b/src/dsp/average_blend.h new file mode 100644 index 0000000..02ecd09 --- /dev/null +++ b/src/dsp/average_blend.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_AVERAGE_BLEND_H_ +#define LIBGAV1_SRC_DSP_AVERAGE_BLEND_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/average_blend_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/average_blend_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::average_blend. This function is not thread-safe. +void AverageBlendInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_AVERAGE_BLEND_H_ diff --git a/src/dsp/cdef.cc b/src/dsp/cdef.cc new file mode 100644 index 0000000..0b50517 --- /dev/null +++ b/src/dsp/cdef.cc @@ -0,0 +1,306 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/cdef.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +#include "src/dsp/cdef.inc" + +// Silence unused function warnings when CdefDirection_C is obviated. +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp8bpp_CdefDirection) || \ + (LIBGAV1_MAX_BITDEPTH >= 10 && !defined(LIBGAV1_Dsp10bpp_CdefDirection)) +constexpr int16_t kDivisionTable[] = {840, 420, 280, 210, 168, 140, 120, 105}; + +int32_t Square(int32_t x) { return x * x; } + +template <int bitdepth, typename Pixel> +void CdefDirection_C(const void* const source, ptrdiff_t stride, + uint8_t* const direction, int* const variance) { + assert(direction != nullptr); + assert(variance != nullptr); + const auto* src = static_cast<const Pixel*>(source); + stride /= sizeof(Pixel); + int32_t cost[8] = {}; + // |partial| does not have to be int32_t for 8bpp. int16_t will suffice. We + // use int32_t to keep it simple since |cost| will have to be int32_t. + int32_t partial[8][15] = {}; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + const int x = (src[j] >> (bitdepth - 8)) - 128; + partial[0][i + j] += x; + partial[1][i + j / 2] += x; + partial[2][i] += x; + partial[3][3 + i - j / 2] += x; + partial[4][7 + i - j] += x; + partial[5][3 - i / 2 + j] += x; + partial[6][j] += x; + partial[7][i / 2 + j] += x; + } + src += stride; + } + for (int i = 0; i < 8; ++i) { + cost[2] += Square(partial[2][i]); + cost[6] += Square(partial[6][i]); + } + cost[2] *= kDivisionTable[7]; + cost[6] *= kDivisionTable[7]; + for (int i = 0; i < 7; ++i) { + cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) * + kDivisionTable[i]; + cost[4] += (Square(partial[4][i]) + Square(partial[4][14 - i])) * + kDivisionTable[i]; + } + cost[0] += Square(partial[0][7]) * kDivisionTable[7]; + cost[4] += Square(partial[4][7]) * kDivisionTable[7]; + for (int i = 1; i < 8; i += 2) { + for (int j = 0; j < 5; ++j) { + cost[i] += Square(partial[i][3 + j]); + } + cost[i] *= kDivisionTable[7]; + for (int j = 0; j < 3; ++j) { + cost[i] += (Square(partial[i][j]) + Square(partial[i][10 - j])) * + kDivisionTable[2 * j + 1]; + } + } + int32_t best_cost = 0; + *direction = 0; + for (int i = 0; i < 8; ++i) { + if (cost[i] > best_cost) { + best_cost = cost[i]; + *direction = i; + } + } + *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10; +} +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || + // !defined(LIBGAV1_Dsp8bpp_CdefDirection) || + // (LIBGAV1_MAX_BITDEPTH >= 10 && + // !defined(LIBGAV1_Dsp10bpp_CdefDirection)) + +// Silence unused function warnings when CdefFilter_C is obviated. +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp8bpp_CdefFilters) || \ + (LIBGAV1_MAX_BITDEPTH >= 10 && !defined(LIBGAV1_Dsp10bpp_CdefFilters)) + +int Constrain(int diff, int threshold, int damping) { + assert(threshold != 0); + damping = std::max(0, damping - FloorLog2(threshold)); + const int sign = (diff < 0) ? -1 : 1; + return sign * + Clip3(threshold - (std::abs(diff) >> damping), 0, std::abs(diff)); +} + +// Filters the source block. It doesn't check whether the candidate pixel is +// inside the frame. However it requires the source input to be padded with a +// constant large value (kCdefLargeValue) if at the boundary. +template <int block_width, int bitdepth, typename Pixel, + bool enable_primary = true, bool enable_secondary = true> +void CdefFilter_C(const uint16_t* src, const ptrdiff_t src_stride, + const int block_height, const int primary_strength, + const int secondary_strength, const int damping, + const int direction, void* const dest, + const ptrdiff_t dest_stride) { + static_assert(block_width == 4 || block_width == 8, "Invalid CDEF width."); + static_assert(enable_primary || enable_secondary, ""); + assert(block_height == 4 || block_height == 8); + assert(direction >= 0 && direction <= 7); + constexpr int coeff_shift = bitdepth - 8; + // Section 5.9.19. CDEF params syntax. + assert(primary_strength >= 0 && primary_strength <= 15 << coeff_shift); + assert(secondary_strength >= 0 && secondary_strength <= 4 << coeff_shift && + secondary_strength != 3 << coeff_shift); + assert(primary_strength != 0 || secondary_strength != 0); + // damping is decreased by 1 for chroma. + assert((damping >= 3 && damping <= 6 + coeff_shift) || + (damping >= 2 && damping <= 5 + coeff_shift)); + // When only primary_strength or secondary_strength are non-zero the number + // of pixels inspected (4 for primary_strength, 8 for secondary_strength) and + // the taps used don't exceed the amount the sum is + // descaled by (16) so we can skip tracking and clipping to the minimum and + // maximum value observed. + constexpr bool clipping_required = enable_primary && enable_secondary; + static constexpr int kCdefSecondaryTaps[2] = {kCdefSecondaryTap0, + kCdefSecondaryTap1}; + auto* dst = static_cast<Pixel*>(dest); + const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel); + int y = block_height; + do { + int x = 0; + do { + int16_t sum = 0; + const uint16_t pixel_value = src[x]; + uint16_t max_value = pixel_value; + uint16_t min_value = pixel_value; + for (int k = 0; k < 2; ++k) { + static constexpr int signs[] = {-1, 1}; + for (const int& sign : signs) { + if (enable_primary) { + const int dy = sign * kCdefDirections[direction][k][0]; + const int dx = sign * kCdefDirections[direction][k][1]; + const uint16_t value = src[dy * src_stride + dx + x]; + // Note: the summation can ignore the condition check in SIMD + // implementation, because Constrain() will return 0 when + // value == kCdefLargeValue. + if (value != kCdefLargeValue) { + sum += Constrain(value - pixel_value, primary_strength, damping) * + kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][k]; + if (clipping_required) { + max_value = std::max(value, max_value); + min_value = std::min(value, min_value); + } + } + } + + if (enable_secondary) { + static constexpr int offsets[] = {-2, 2}; + for (const int& offset : offsets) { + const int dy = sign * kCdefDirections[direction + offset][k][0]; + const int dx = sign * kCdefDirections[direction + offset][k][1]; + const uint16_t value = src[dy * src_stride + dx + x]; + // Note: the summation can ignore the condition check in SIMD + // implementation. + if (value != kCdefLargeValue) { + sum += Constrain(value - pixel_value, secondary_strength, + damping) * + kCdefSecondaryTaps[k]; + if (clipping_required) { + max_value = std::max(value, max_value); + min_value = std::min(value, min_value); + } + } + } + } + } + } + + const int offset = (8 + sum - (sum < 0)) >> 4; + if (clipping_required) { + dst[x] = static_cast<Pixel>( + Clip3(pixel_value + offset, min_value, max_value)); + } else { + dst[x] = static_cast<Pixel>(pixel_value + offset); + } + } while (++x < block_width); + + src += src_stride; + dst += dst_stride; + } while (--y != 0); +} +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || + // !defined(LIBGAV1_Dsp8bpp_CdefFilters) || + // (LIBGAV1_MAX_BITDEPTH >= 10 && + // !defined(LIBGAV1_Dsp10bpp_CdefFilters)) + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->cdef_direction = CdefDirection_C<8, uint8_t>; + dsp->cdef_filters[0][0] = CdefFilter_C<4, 8, uint8_t>; + dsp->cdef_filters[0][1] = CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = + CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_C<8, 8, uint8_t>; + dsp->cdef_filters[1][1] = CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = + CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/false>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_CdefDirection + dsp->cdef_direction = CdefDirection_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_CdefFilters + dsp->cdef_filters[0][0] = CdefFilter_C<4, 8, uint8_t>; + dsp->cdef_filters[0][1] = CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = + CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_C<8, 8, uint8_t>; + dsp->cdef_filters[1][1] = CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = + CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/false>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->cdef_direction = CdefDirection_C<10, uint16_t>; + dsp->cdef_filters[0][0] = CdefFilter_C<4, 10, uint16_t>; + dsp->cdef_filters[0][1] = + CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = + CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_C<8, 10, uint16_t>; + dsp->cdef_filters[1][1] = + CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = + CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/false>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_CdefDirection + dsp->cdef_direction = CdefDirection_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_CdefFilters + dsp->cdef_filters[0][0] = CdefFilter_C<4, 10, uint16_t>; + dsp->cdef_filters[0][1] = + CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = + CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_C<8, 10, uint16_t>; + dsp->cdef_filters[1][1] = + CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/true, + /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = + CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/false>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void CdefInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/cdef.h b/src/dsp/cdef.h new file mode 100644 index 0000000..2d70d2c --- /dev/null +++ b/src/dsp/cdef.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_CDEF_H_ +#define LIBGAV1_SRC_DSP_CDEF_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/cdef_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/cdef_sse4.h" +// clang-format on +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not +// thread-safe. +void CdefInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_CDEF_H_ diff --git a/src/dsp/cdef.inc b/src/dsp/cdef.inc new file mode 100644 index 0000000..c1a3136 --- /dev/null +++ b/src/dsp/cdef.inc @@ -0,0 +1,29 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Constants used for cdef implementations. +// This will be included inside an anonymous namespace on files where these are +// necessary. + +const int8_t (*const kCdefDirections)[2][2] = kCdefDirectionsPadded + 2; + +// Mirror values and pad to 16 elements. +alignas(16) constexpr uint32_t kCdefDivisionTable[] = { + 840, 420, 280, 210, 168, 140, 120, 105, + 120, 140, 168, 210, 280, 420, 840, 0}; + +// Used when calculating odd |cost[x]| values to mask off unwanted elements. +// Holds elements 1 3 5 X 5 3 1 X +alignas(16) constexpr uint32_t kCdefDivisionTableOdd[] = {420, 210, 140, 0, + 140, 210, 420, 0}; diff --git a/src/dsp/common.h b/src/dsp/common.h new file mode 100644 index 0000000..d614a81 --- /dev/null +++ b/src/dsp/common.h @@ -0,0 +1,82 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_COMMON_H_ +#define LIBGAV1_SRC_DSP_COMMON_H_ + +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/utils/constants.h" +#include "src/utils/memory.h" + +namespace libgav1 { + +enum { kSgrStride = kRestorationUnitWidth + 32 }; // anonymous enum + +// Self guided projection filter. +struct SgrProjInfo { + int index; + int multiplier[2]; +}; + +struct WienerInfo { + static const int kVertical = 0; + static const int kHorizontal = 1; + int16_t number_leading_zero_coefficients[2]; + alignas(kMaxAlignment) int16_t filter[2][(kWienerFilterTaps + 1) / 2]; +}; + +struct RestorationUnitInfo : public MaxAlignedAllocable { + LoopRestorationType type; + SgrProjInfo sgr_proj_info; + WienerInfo wiener_info; +}; + +struct SgrBuffer { + alignas(kMaxAlignment) uint16_t sum3[4 * kSgrStride]; + alignas(kMaxAlignment) uint16_t sum5[5 * kSgrStride]; + alignas(kMaxAlignment) uint32_t square_sum3[4 * kSgrStride]; + alignas(kMaxAlignment) uint32_t square_sum5[5 * kSgrStride]; + alignas(kMaxAlignment) uint16_t ma343[4 * kRestorationUnitWidth]; + alignas(kMaxAlignment) uint16_t ma444[3 * kRestorationUnitWidth]; + alignas(kMaxAlignment) uint16_t ma565[2 * kRestorationUnitWidth]; + alignas(kMaxAlignment) uint32_t b343[4 * kRestorationUnitWidth]; + alignas(kMaxAlignment) uint32_t b444[3 * kRestorationUnitWidth]; + alignas(kMaxAlignment) uint32_t b565[2 * kRestorationUnitWidth]; + // The following 2 buffers are only used by the C functions. Since SgrBuffer + // is smaller than |wiener_buffer| in RestorationBuffer which is an union, + // it's OK to always keep the following 2 buffers. + alignas(kMaxAlignment) uint8_t ma[kSgrStride]; // [0, 255] + // b is less than 2^16 for 8-bit. However, making it a template slows down the + // C function by 5%. So b is fixed to 32-bit. + alignas(kMaxAlignment) uint32_t b[kSgrStride]; +}; + +union RestorationBuffer { + // For self-guided filter. + SgrBuffer sgr_buffer; + // For wiener filter. + // The array |intermediate| in Section 7.17.4, the intermediate results + // between the horizontal and vertical filters. + alignas(kMaxAlignment) int16_t + wiener_buffer[(kRestorationUnitHeight + kWienerFilterTaps - 1) * + kRestorationUnitWidth]; +}; + +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_COMMON_H_ diff --git a/src/dsp/constants.cc b/src/dsp/constants.cc new file mode 100644 index 0000000..0099ca3 --- /dev/null +++ b/src/dsp/constants.cc @@ -0,0 +1,103 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/constants.h" + +#include <cstdint> + +namespace libgav1 { + +// Each set of 7 taps is padded with a 0 to easily align and pack into the high +// and low 8 bytes. This way, we can load 16 at a time to fit mulhi and mullo. +const int8_t kFilterIntraTaps[kNumFilterIntraPredictors][8][8] = { + {{-6, 10, 0, 0, 0, 12, 0, 0}, + {-5, 2, 10, 0, 0, 9, 0, 0}, + {-3, 1, 1, 10, 0, 7, 0, 0}, + {-3, 1, 1, 2, 10, 5, 0, 0}, + {-4, 6, 0, 0, 0, 2, 12, 0}, + {-3, 2, 6, 0, 0, 2, 9, 0}, + {-3, 2, 2, 6, 0, 2, 7, 0}, + {-3, 1, 2, 2, 6, 3, 5, 0}}, + {{-10, 16, 0, 0, 0, 10, 0, 0}, + {-6, 0, 16, 0, 0, 6, 0, 0}, + {-4, 0, 0, 16, 0, 4, 0, 0}, + {-2, 0, 0, 0, 16, 2, 0, 0}, + {-10, 16, 0, 0, 0, 0, 10, 0}, + {-6, 0, 16, 0, 0, 0, 6, 0}, + {-4, 0, 0, 16, 0, 0, 4, 0}, + {-2, 0, 0, 0, 16, 0, 2, 0}}, + {{-8, 8, 0, 0, 0, 16, 0, 0}, + {-8, 0, 8, 0, 0, 16, 0, 0}, + {-8, 0, 0, 8, 0, 16, 0, 0}, + {-8, 0, 0, 0, 8, 16, 0, 0}, + {-4, 4, 0, 0, 0, 0, 16, 0}, + {-4, 0, 4, 0, 0, 0, 16, 0}, + {-4, 0, 0, 4, 0, 0, 16, 0}, + {-4, 0, 0, 0, 4, 0, 16, 0}}, + {{-2, 8, 0, 0, 0, 10, 0, 0}, + {-1, 3, 8, 0, 0, 6, 0, 0}, + {-1, 2, 3, 8, 0, 4, 0, 0}, + {0, 1, 2, 3, 8, 2, 0, 0}, + {-1, 4, 0, 0, 0, 3, 10, 0}, + {-1, 3, 4, 0, 0, 4, 6, 0}, + {-1, 2, 3, 4, 0, 4, 4, 0}, + {-1, 2, 2, 3, 4, 3, 3, 0}}, + {{-12, 14, 0, 0, 0, 14, 0, 0}, + {-10, 0, 14, 0, 0, 12, 0, 0}, + {-9, 0, 0, 14, 0, 11, 0, 0}, + {-8, 0, 0, 0, 14, 10, 0, 0}, + {-10, 12, 0, 0, 0, 0, 14, 0}, + {-9, 1, 12, 0, 0, 0, 12, 0}, + {-8, 0, 0, 12, 0, 1, 11, 0}, + {-7, 0, 0, 1, 12, 1, 9, 0}}}; + +// A lookup table replacing the calculation of the variable s in Section 7.17.3 +// (Box filter process). The first index is sgr_proj_index (the lr_sgr_set +// syntax element in the Spec, saved in the sgr_proj_info.index field of a +// RestorationUnitInfo struct). The second index is pass (0 or 1). +// +// const uint8_t scale = kSgrProjParams[sgr_proj_index][pass * 2 + 1]; +// const uint32_t n2_with_scale = n * n * scale; +// const uint32_t s = +// ((1 << kSgrProjScaleBits) + (n2_with_scale >> 1)) / n2_with_scale; +// 0 is an invalid value, corresponding to radius = 0, where the filter is +// skipped. +const uint16_t kSgrScaleParameter[16][2] = { + {140, 3236}, {112, 2158}, {93, 1618}, {80, 1438}, {70, 1295}, {58, 1177}, + {47, 1079}, {37, 996}, {30, 925}, {25, 863}, {0, 2589}, {0, 1618}, + {0, 1177}, {0, 925}, {56, 0}, {22, 0}, +}; + +const uint8_t kCdefPrimaryTaps[2][2] = {{4, 2}, {3, 3}}; + +// This is Cdef_Directions (section 7.15.3) with 2 padding entries at the +// beginning and end of the table. The cdef direction range is [0, 7] and the +// first index is offset +/-2. This removes the need to constrain the first +// index to the same range using e.g., & 7. +const int8_t kCdefDirectionsPadded[12][2][2] = { + {{1, 0}, {2, 0}}, // Padding: Cdef_Directions[6] + {{1, 0}, {2, -1}}, // Padding: Cdef_Directions[7] + {{-1, 1}, {-2, 2}}, // Begin Cdef_Directions + {{0, 1}, {-1, 2}}, // + {{0, 1}, {0, 2}}, // + {{0, 1}, {1, 2}}, // + {{1, 1}, {2, 2}}, // + {{1, 0}, {2, 1}}, // + {{1, 0}, {2, 0}}, // + {{1, 0}, {2, -1}}, // End Cdef_Directions + {{-1, 1}, {-2, 2}}, // Padding: Cdef_Directions[0] + {{0, 1}, {-1, 2}}, // Padding: Cdef_Directions[1] +}; + +} // namespace libgav1 diff --git a/src/dsp/constants.h b/src/dsp/constants.h new file mode 100644 index 0000000..7c1b62c --- /dev/null +++ b/src/dsp/constants.h @@ -0,0 +1,71 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_CONSTANTS_H_ +#define LIBGAV1_SRC_DSP_CONSTANTS_H_ + +// This file contains DSP related constants that have a direct relationship with +// a DSP component. + +#include <cstdint> + +#include "src/utils/constants.h" + +namespace libgav1 { + +enum { + // Documentation variables. + kBitdepth8 = 8, + kBitdepth10 = 10, + kBitdepth12 = 12, + // Weights are quadratic from '1' to '1 / block_size', scaled by + // 2^kSmoothWeightScale. + kSmoothWeightScale = 8, + kCflLumaBufferStride = 32, + // InterRound0, Section 7.11.3.2. + kInterRoundBitsHorizontal = 3, // 8 & 10-bit. + kInterRoundBitsHorizontal12bpp = 5, + kInterRoundBitsCompoundVertical = 7, // 8, 10 & 12-bit compound prediction. + kInterRoundBitsVertical = 11, // 8 & 10-bit, single prediction. + kInterRoundBitsVertical12bpp = 9, + // Offset applied to 10bpp and 12bpp predictors to allow storing them in + // uint16_t. Removed before blending. + kCompoundOffset = (1 << 14) + (1 << 13), + kCdefSecondaryTap0 = 2, + kCdefSecondaryTap1 = 1, +}; // anonymous enum + +extern const int8_t kFilterIntraTaps[kNumFilterIntraPredictors][8][8]; + +// Values in this enum can be derived as the sum of subsampling_x and +// subsampling_y (since subsampling_x == 0 && subsampling_y == 1 case is never +// allowed by the bitstream). +enum SubsamplingType : uint8_t { + kSubsamplingType444, // subsampling_x = 0, subsampling_y = 0. + kSubsamplingType422, // subsampling_x = 1, subsampling_y = 0. + kSubsamplingType420, // subsampling_x = 1, subsampling_y = 1. + kNumSubsamplingTypes +}; + +extern const uint16_t kSgrScaleParameter[16][2]; + +extern const uint8_t kCdefPrimaryTaps[2][2]; + +extern const int8_t kCdefDirectionsPadded[12][2][2]; + +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_CONSTANTS_H_ diff --git a/src/dsp/convolve.cc b/src/dsp/convolve.cc new file mode 100644 index 0000000..8c6f68f --- /dev/null +++ b/src/dsp/convolve.cc @@ -0,0 +1,876 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/convolve.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kHorizontalOffset = 3; +constexpr int kVerticalOffset = 3; + +// Compound prediction output ranges from ConvolveTest.ShowRange. +// Bitdepth: 8 Input range: [ 0, 255] +// intermediate range: [ -7140, 23460] +// first pass output range: [ -1785, 5865] +// intermediate range: [ -328440, 589560] +// second pass output range: [ 0, 255] +// compound second pass output range: [ -5132, 9212] +// +// Bitdepth: 10 Input range: [ 0, 1023] +// intermediate range: [ -28644, 94116] +// first pass output range: [ -7161, 23529] +// intermediate range: [-1317624, 2365176] +// second pass output range: [ 0, 1023] +// compound second pass output range: [ 3988, 61532] +// +// Bitdepth: 12 Input range: [ 0, 4095] +// intermediate range: [ -114660, 376740] +// first pass output range: [ -7166, 23546] +// intermediate range: [-1318560, 2366880] +// second pass output range: [ 0, 4095] +// compound second pass output range: [ 3974, 61559] + +template <int bitdepth, typename Pixel> +void ConvolveScale2D_C(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, const int subpixel_x, + const int subpixel_y, const int step_x, const int step_y, + const int width, const int height, void* prediction, + const ptrdiff_t pred_stride) { + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + constexpr int kRoundBitsVertical = + (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical; + const int intermediate_height = + (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >> + kScaleSubPixelBits) + + kSubPixelTaps; + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + int16_t intermediate_result[kMaxSuperBlockSizeInPixels * + (2 * kMaxSuperBlockSizeInPixels + 8)]; + const int intermediate_stride = kMaxSuperBlockSizeInPixels; + const int max_pixel_value = (1 << bitdepth) - 1; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [4, 5]. + // Similarly for height. + int filter_index = GetFilterIndex(horizontal_filter_index, width); + int16_t* intermediate = intermediate_result; + const auto* src = static_cast<const Pixel*>(reference); + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + auto* dest = static_cast<Pixel*>(prediction); + const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + // Note: assume the input src is already aligned to the correct start + // position. + int y = 0; + do { + int p = subpixel_x; + int x = 0; + do { + int sum = 0; + const Pixel* src_x = &src[(p >> kScaleSubPixelBits) - ref_x]; + const int filter_id = (p >> 6) & kSubPixelMask; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src_x[k]; + } + intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1); + p += step_x; + } while (++x < width); + + src += src_stride; + intermediate += intermediate_stride; + } while (++y < intermediate_height); + + // Vertical filter. + filter_index = GetFilterIndex(vertical_filter_index, height); + intermediate = intermediate_result; + int p = subpixel_y & 1023; + y = 0; + do { + const int filter_id = (p >> 6) & kSubPixelMask; + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += + kHalfSubPixelFilters[filter_index][filter_id][k] * + intermediate[((p >> kScaleSubPixelBits) + k) * intermediate_stride + + x]; + } + dest[x] = Clip3(RightShiftWithRounding(sum, kRoundBitsVertical - 1), 0, + max_pixel_value); + } while (++x < width); + + dest += dest_stride; + p += step_y; + } while (++y < height); +} + +template <int bitdepth, typename Pixel> +void ConvolveCompoundScale2D_C(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int subpixel_x, const int subpixel_y, + const int step_x, const int step_y, + const int width, const int height, + void* prediction, const ptrdiff_t pred_stride) { + // All compound functions output to the predictor buffer with |pred_stride| + // equal to |width|. + assert(pred_stride == width); + // Compound functions start at 4x4. + assert(width >= 4 && height >= 4); + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + constexpr int kRoundBitsVertical = kInterRoundBitsCompoundVertical; + const int intermediate_height = + (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >> + kScaleSubPixelBits) + + kSubPixelTaps; + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + int16_t intermediate_result[kMaxSuperBlockSizeInPixels * + (2 * kMaxSuperBlockSizeInPixels + 8)]; + const int intermediate_stride = kMaxSuperBlockSizeInPixels; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [4, 5]. + // Similarly for height. + int filter_index = GetFilterIndex(horizontal_filter_index, width); + int16_t* intermediate = intermediate_result; + const auto* src = static_cast<const Pixel*>(reference); + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + auto* dest = static_cast<uint16_t*>(prediction); + const int ref_x = subpixel_x >> kScaleSubPixelBits; + // Note: assume the input src is already aligned to the correct start + // position. + int y = 0; + do { + int p = subpixel_x; + int x = 0; + do { + int sum = 0; + const Pixel* src_x = &src[(p >> kScaleSubPixelBits) - ref_x]; + const int filter_id = (p >> 6) & kSubPixelMask; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src_x[k]; + } + intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1); + p += step_x; + } while (++x < width); + + src += src_stride; + intermediate += intermediate_stride; + } while (++y < intermediate_height); + + // Vertical filter. + filter_index = GetFilterIndex(vertical_filter_index, height); + intermediate = intermediate_result; + int p = subpixel_y & 1023; + y = 0; + do { + const int filter_id = (p >> 6) & kSubPixelMask; + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += + kHalfSubPixelFilters[filter_index][filter_id][k] * + intermediate[((p >> kScaleSubPixelBits) + k) * intermediate_stride + + x]; + } + sum = RightShiftWithRounding(sum, kRoundBitsVertical - 1); + sum += (bitdepth == 8) ? 0 : kCompoundOffset; + dest[x] = sum; + } while (++x < width); + + dest += pred_stride; + p += step_y; + } while (++y < height); +} + +template <int bitdepth, typename Pixel> +void ConvolveCompound2D_C(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + // All compound functions output to the predictor buffer with |pred_stride| + // equal to |width|. + assert(pred_stride == width); + // Compound functions start at 4x4. + assert(width >= 4 && height >= 4); + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + constexpr int kRoundBitsVertical = kInterRoundBitsCompoundVertical; + const int intermediate_height = height + kSubPixelTaps - 1; + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + int16_t intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_stride = kMaxSuperBlockSizeInPixels; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [4, 5]. + // Similarly for height. + int filter_index = GetFilterIndex(horizontal_filter_index, width); + int16_t* intermediate = intermediate_result; + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + const auto* src = static_cast<const Pixel*>(reference) - + kVerticalOffset * src_stride - kHorizontalOffset; + auto* dest = static_cast<uint16_t*>(prediction); + + // If |horizontal_filter_id| == 0 then ConvolveVertical() should be called. + assert(horizontal_filter_id != 0); + int y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][horizontal_filter_id][k] * + src[x + k]; + } + intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1); + } while (++x < width); + + src += src_stride; + intermediate += intermediate_stride; + } while (++y < intermediate_height); + + // Vertical filter. + filter_index = GetFilterIndex(vertical_filter_index, height); + intermediate = intermediate_result; + // If |vertical_filter_id| == 0 then ConvolveHorizontal() should be called. + assert(vertical_filter_id != 0); + y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][vertical_filter_id][k] * + intermediate[k * intermediate_stride + x]; + } + sum = RightShiftWithRounding(sum, kRoundBitsVertical - 1); + sum += (bitdepth == 8) ? 0 : kCompoundOffset; + dest[x] = sum; + } while (++x < width); + + dest += pred_stride; + intermediate += intermediate_stride; + } while (++y < height); +} + +// This function is a simplified version of ConvolveCompound2D_C. +// It is called when it is single prediction mode, where both horizontal and +// vertical filtering are required. +// The output is the single prediction of the block, clipped to valid pixel +// range. +template <int bitdepth, typename Pixel> +void Convolve2D_C(const void* const reference, const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, const int vertical_filter_id, + const int width, const int height, void* prediction, + const ptrdiff_t pred_stride) { + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + constexpr int kRoundBitsVertical = + (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical; + const int intermediate_height = height + kSubPixelTaps - 1; + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + int16_t intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_stride = kMaxSuperBlockSizeInPixels; + const int max_pixel_value = (1 << bitdepth) - 1; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [4, 5]. + // Similarly for height. + int filter_index = GetFilterIndex(horizontal_filter_index, width); + int16_t* intermediate = intermediate_result; + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + const auto* src = static_cast<const Pixel*>(reference) - + kVerticalOffset * src_stride - kHorizontalOffset; + auto* dest = static_cast<Pixel*>(prediction); + const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel); + // If |horizontal_filter_id| == 0 then ConvolveVertical() should be called. + assert(horizontal_filter_id != 0); + int y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][horizontal_filter_id][k] * + src[x + k]; + } + intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1); + } while (++x < width); + + src += src_stride; + intermediate += intermediate_stride; + } while (++y < intermediate_height); + + // Vertical filter. + filter_index = GetFilterIndex(vertical_filter_index, height); + intermediate = intermediate_result; + // If |vertical_filter_id| == 0 then ConvolveHorizontal() should be called. + assert(vertical_filter_id != 0); + y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][vertical_filter_id][k] * + intermediate[k * intermediate_stride + x]; + } + dest[x] = Clip3(RightShiftWithRounding(sum, kRoundBitsVertical - 1), 0, + max_pixel_value); + } while (++x < width); + + dest += dest_stride; + intermediate += intermediate_stride; + } while (++y < height); +} + +// This function is a simplified version of Convolve2D_C. +// It is called when it is single prediction mode, where only horizontal +// filtering is required. +// The output is the single prediction of the block, clipped to valid pixel +// range. +template <int bitdepth, typename Pixel> +void ConvolveHorizontal_C(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int /*vertical_filter_index*/, + const int horizontal_filter_id, + const int /*vertical_filter_id*/, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + const int bits = kFilterBits - kRoundBitsHorizontal; + const auto* src = static_cast<const Pixel*>(reference) - kHorizontalOffset; + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + auto* dest = static_cast<Pixel*>(prediction); + const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel); + const int max_pixel_value = (1 << bitdepth) - 1; + int y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][horizontal_filter_id][k] * + src[x + k]; + } + sum = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1); + dest[x] = Clip3(RightShiftWithRounding(sum, bits), 0, max_pixel_value); + } while (++x < width); + + src += src_stride; + dest += dest_stride; + } while (++y < height); +} + +// This function is a simplified version of Convolve2D_C. +// It is called when it is single prediction mode, where only vertical +// filtering is required. +// The output is the single prediction of the block, clipped to valid pixel +// range. +template <int bitdepth, typename Pixel> +void ConvolveVertical_C(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int vertical_filter_index, + const int /*horizontal_filter_id*/, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + const auto* src = + static_cast<const Pixel*>(reference) - kVerticalOffset * src_stride; + auto* dest = static_cast<Pixel*>(prediction); + const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel); + // Copy filters must call ConvolveCopy(). + assert(vertical_filter_id != 0); + + const int max_pixel_value = (1 << bitdepth) - 1; + int y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][vertical_filter_id][k] * + src[k * src_stride + x]; + } + dest[x] = Clip3(RightShiftWithRounding(sum, kFilterBits - 1), 0, + max_pixel_value); + } while (++x < width); + + src += src_stride; + dest += dest_stride; + } while (++y < height); +} + +template <int bitdepth, typename Pixel> +void ConvolveCopy_C(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, + const int /*vertical_filter_id*/, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + int y = 0; + do { + memcpy(dest, src, width * sizeof(Pixel)); + src += reference_stride; + dest += pred_stride; + } while (++y < height); +} + +template <int bitdepth, typename Pixel> +void ConvolveCompoundCopy_C(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, + const int /*vertical_filter_id*/, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + // All compound functions output to the predictor buffer with |pred_stride| + // equal to |width|. + assert(pred_stride == width); + // Compound functions start at 4x4. + assert(width >= 4 && height >= 4); + constexpr int kRoundBitsVertical = + ((bitdepth == 12) ? kInterRoundBitsVertical12bpp + : kInterRoundBitsVertical) - + kInterRoundBitsCompoundVertical; + const auto* src = static_cast<const Pixel*>(reference); + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + auto* dest = static_cast<uint16_t*>(prediction); + int y = 0; + do { + int x = 0; + do { + int sum = (bitdepth == 8) ? 0 : ((1 << bitdepth) + (1 << (bitdepth - 1))); + sum += src[x]; + dest[x] = sum << kRoundBitsVertical; + } while (++x < width); + src += src_stride; + dest += pred_stride; + } while (++y < height); +} + +// This function is a simplified version of ConvolveCompound2D_C. +// It is called when it is compound prediction mode, where only horizontal +// filtering is required. +// The output is not clipped to valid pixel range. Its output will be +// blended with another predictor to generate the final prediction of the block. +template <int bitdepth, typename Pixel> +void ConvolveCompoundHorizontal_C( + const void* const reference, const ptrdiff_t reference_stride, + const int horizontal_filter_index, const int /*vertical_filter_index*/, + const int horizontal_filter_id, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t pred_stride) { + // All compound functions output to the predictor buffer with |pred_stride| + // equal to |width|. + assert(pred_stride == width); + // Compound functions start at 4x4. + assert(width >= 4 && height >= 4); + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + const auto* src = static_cast<const Pixel*>(reference) - kHorizontalOffset; + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + auto* dest = static_cast<uint16_t*>(prediction); + // Copy filters must call ConvolveCopy(). + assert(horizontal_filter_id != 0); + int y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][horizontal_filter_id][k] * + src[x + k]; + } + sum = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1); + sum += (bitdepth == 8) ? 0 : kCompoundOffset; + dest[x] = sum; + } while (++x < width); + + src += src_stride; + dest += pred_stride; + } while (++y < height); +} + +// This function is a simplified version of ConvolveCompound2D_C. +// It is called when it is compound prediction mode, where only vertical +// filtering is required. +// The output is not clipped to valid pixel range. Its output will be +// blended with another predictor to generate the final prediction of the block. +template <int bitdepth, typename Pixel> +void ConvolveCompoundVertical_C(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int vertical_filter_index, + const int /*horizontal_filter_id*/, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + // All compound functions output to the predictor buffer with |pred_stride| + // equal to |width|. + assert(pred_stride == width); + // Compound functions start at 4x4. + assert(width >= 4 && height >= 4); + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + const auto* src = + static_cast<const Pixel*>(reference) - kVerticalOffset * src_stride; + auto* dest = static_cast<uint16_t*>(prediction); + // Copy filters must call ConvolveCopy(). + assert(vertical_filter_id != 0); + int y = 0; + do { + int x = 0; + do { + int sum = 0; + for (int k = 0; k < kSubPixelTaps; ++k) { + sum += kHalfSubPixelFilters[filter_index][vertical_filter_id][k] * + src[k * src_stride + x]; + } + sum = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1); + sum += (bitdepth == 8) ? 0 : kCompoundOffset; + dest[x] = sum; + } while (++x < width); + src += src_stride; + dest += pred_stride; + } while (++y < height); +} + +// This function is used when intra block copy is present. +// It is called when it is single prediction mode for U/V plane, where the +// reference block is from current frame and both horizontal and vertical +// filtering are required. +// The output is the single prediction of the block, clipped to valid pixel +// range. +template <int bitdepth, typename Pixel> +void ConvolveIntraBlockCopy2D_C(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, + const int /*vertical_filter_id*/, + const int width, const int height, + void* prediction, const ptrdiff_t pred_stride) { + const auto* src = static_cast<const Pixel*>(reference); + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + auto* dest = static_cast<Pixel*>(prediction); + const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel); + const int intermediate_height = height + 1; + uint16_t intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + 1)]; + uint16_t* intermediate = intermediate_result; + // Note: allow vertical access to height + 1. Because this function is only + // for u/v plane of intra block copy, such access is guaranteed to be within + // the prediction block. + int y = 0; + do { + int x = 0; + do { + intermediate[x] = src[x] + src[x + 1]; + } while (++x < width); + + src += src_stride; + intermediate += width; + } while (++y < intermediate_height); + + intermediate = intermediate_result; + y = 0; + do { + int x = 0; + do { + dest[x] = + RightShiftWithRounding(intermediate[x] + intermediate[x + width], 2); + } while (++x < width); + + intermediate += width; + dest += dest_stride; + } while (++y < height); +} + +// This function is used when intra block copy is present. +// It is called when it is single prediction mode for U/V plane, where the +// reference block is from the current frame and only horizontal or vertical +// filtering is required. +// The output is the single prediction of the block, clipped to valid pixel +// range. +// The filtering of intra block copy is simply the average of current and +// the next pixel. +template <int bitdepth, typename Pixel, bool is_horizontal> +void ConvolveIntraBlockCopy1D_C(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, + const int /*vertical_filter_id*/, + const int width, const int height, + void* prediction, const ptrdiff_t pred_stride) { + const auto* src = static_cast<const Pixel*>(reference); + const ptrdiff_t src_stride = reference_stride / sizeof(Pixel); + auto* dest = static_cast<Pixel*>(prediction); + const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel); + const ptrdiff_t offset = is_horizontal ? 1 : src_stride; + int y = 0; + do { + int x = 0; + do { + dest[x] = RightShiftWithRounding(src[x] + src[x + offset], 1); + } while (++x < width); + + src += src_stride; + dest += dest_stride; + } while (++y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->convolve[0][0][0][0] = ConvolveCopy_C<8, uint8_t>; + dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<8, uint8_t>; + dsp->convolve[0][0][1][0] = ConvolveVertical_C<8, uint8_t>; + dsp->convolve[0][0][1][1] = Convolve2D_C<8, uint8_t>; + + dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<8, uint8_t>; + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<8, uint8_t>; + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<8, uint8_t>; + dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<8, uint8_t>; + + dsp->convolve[1][0][0][0] = ConvolveCopy_C<8, uint8_t>; + dsp->convolve[1][0][0][1] = + ConvolveIntraBlockCopy1D_C<8, uint8_t, /*is_horizontal=*/true>; + dsp->convolve[1][0][1][0] = + ConvolveIntraBlockCopy1D_C<8, uint8_t, /*is_horizontal=*/false>; + dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_C<8, uint8_t>; + + dsp->convolve[1][1][0][0] = nullptr; + dsp->convolve[1][1][0][1] = nullptr; + dsp->convolve[1][1][1][0] = nullptr; + dsp->convolve[1][1][1][1] = nullptr; + + dsp->convolve_scale[0] = ConvolveScale2D_C<8, uint8_t>; + dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<8, uint8_t>; +#else // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp8bpp_ConvolveCopy + dsp->convolve[0][0][0][0] = ConvolveCopy_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveHorizontal + dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveVertical + dsp->convolve[0][0][1][0] = ConvolveVertical_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_Convolve2D + dsp->convolve[0][0][1][1] = Convolve2D_C<8, uint8_t>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundCopy + dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundVertical + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompound2D + dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<8, uint8_t>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy + dsp->convolve[1][0][0][0] = ConvolveCopy_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyHorizontal + dsp->convolve[1][0][0][1] = + ConvolveIntraBlockCopy1D_C<8, uint8_t, /*is_horizontal=*/true>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyVertical + dsp->convolve[1][0][1][0] = + ConvolveIntraBlockCopy1D_C<8, uint8_t, /*is_horizontal=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy2D + dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_C<8, uint8_t>; +#endif + + dsp->convolve[1][1][0][0] = nullptr; + dsp->convolve[1][1][0][1] = nullptr; + dsp->convolve[1][1][1][0] = nullptr; + dsp->convolve[1][1][1][1] = nullptr; + +#ifndef LIBGAV1_Dsp8bpp_ConvolveScale2D + dsp->convolve_scale[0] = ConvolveScale2D_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D + dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<8, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->convolve[0][0][0][0] = ConvolveCopy_C<10, uint16_t>; + dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<10, uint16_t>; + dsp->convolve[0][0][1][0] = ConvolveVertical_C<10, uint16_t>; + dsp->convolve[0][0][1][1] = Convolve2D_C<10, uint16_t>; + + dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<10, uint16_t>; + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<10, uint16_t>; + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<10, uint16_t>; + dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<10, uint16_t>; + + dsp->convolve[1][0][0][0] = ConvolveCopy_C<10, uint16_t>; + dsp->convolve[1][0][0][1] = + ConvolveIntraBlockCopy1D_C<10, uint16_t, /*is_horizontal=*/true>; + dsp->convolve[1][0][1][0] = + ConvolveIntraBlockCopy1D_C<10, uint16_t, /*is_horizontal=*/false>; + dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_C<10, uint16_t>; + + dsp->convolve[1][1][0][0] = nullptr; + dsp->convolve[1][1][0][1] = nullptr; + dsp->convolve[1][1][1][0] = nullptr; + dsp->convolve[1][1][1][1] = nullptr; + + dsp->convolve_scale[0] = ConvolveScale2D_C<10, uint16_t>; + dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<10, uint16_t>; +#else // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp10bpp_ConvolveCopy + dsp->convolve[0][0][0][0] = ConvolveCopy_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveHorizontal + dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveVertical + dsp->convolve[0][0][1][0] = ConvolveVertical_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_Convolve2D + dsp->convolve[0][0][1][1] = Convolve2D_C<10, uint16_t>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundCopy + dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundHorizontal + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundVertical + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveCompound2D + dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<10, uint16_t>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlockCopy + dsp->convolve[1][0][0][0] = ConvolveCopy_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlockHorizontal + dsp->convolve[1][0][0][1] = + ConvolveIntraBlockCopy1D_C<10, uint16_t, /*is_horizontal=*/true>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlockVertical + dsp->convolve[1][0][1][0] = + ConvolveIntraBlockCopy1D_C<10, uint16_t, /*is_horizontal=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlock2D + dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_C<10, uint16_t>; +#endif + + dsp->convolve[1][1][0][0] = nullptr; + dsp->convolve[1][1][0][1] = nullptr; + dsp->convolve[1][1][1][0] = nullptr; + dsp->convolve[1][1][1][1] = nullptr; + +#ifndef LIBGAV1_Dsp10bpp_ConvolveScale2D + dsp->convolve_scale[0] = ConvolveScale2D_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundScale2D + dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<10, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void ConvolveInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/convolve.h b/src/dsp/convolve.h new file mode 100644 index 0000000..5bc0bad --- /dev/null +++ b/src/dsp/convolve.h @@ -0,0 +1,49 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_CONVOLVE_H_ +#define LIBGAV1_SRC_DSP_CONVOLVE_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/convolve_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/convolve_avx2.h" +#include "src/dsp/x86/convolve_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::convolve and Dsp::convolve_scale. This function is not +// thread-safe. +void ConvolveInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_CONVOLVE_H_ diff --git a/src/dsp/convolve.inc b/src/dsp/convolve.inc new file mode 100644 index 0000000..140648b --- /dev/null +++ b/src/dsp/convolve.inc @@ -0,0 +1,50 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Constants and utility functions used for convolve implementations. +// This will be included inside an anonymous namespace on files where these are +// necessary. + +int GetNumTapsInFilter(const int filter_index) { + if (filter_index < 2) { + // Despite the names these only use 6 taps. + // kInterpolationFilterEightTap + // kInterpolationFilterEightTapSmooth + return 6; + } + + if (filter_index == 2) { + // kInterpolationFilterEightTapSharp + return 8; + } + + if (filter_index == 3) { + // kInterpolationFilterBilinear + return 2; + } + + assert(filter_index > 3); + // For small sizes (width/height <= 4) the large filters are replaced with 4 + // tap options. + // If the original filters were |kInterpolationFilterEightTap| or + // |kInterpolationFilterEightTapSharp| then it becomes + // |kInterpolationFilterSwitchable|. + // If it was |kInterpolationFilterEightTapSmooth| then it becomes an unnamed 4 + // tap filter. + return 4; +} + +constexpr int kIntermediateStride = kMaxSuperBlockSizeInPixels; +constexpr int kHorizontalOffset = 3; +constexpr int kFilterIndexShift = 6; diff --git a/src/dsp/distance_weighted_blend.cc b/src/dsp/distance_weighted_blend.cc new file mode 100644 index 0000000..a035fbe --- /dev/null +++ b/src/dsp/distance_weighted_blend.cc @@ -0,0 +1,101 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/distance_weighted_blend.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <type_traits> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +template <int bitdepth, typename Pixel> +void DistanceWeightedBlend_C(const void* prediction_0, const void* prediction_1, + const uint8_t weight_0, const uint8_t weight_1, + const int width, const int height, + void* const dest, const ptrdiff_t dest_stride) { + // 7.11.3.2 Rounding variables derivation process + // 2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7)) + constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4; + using PredType = + typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type; + const auto* pred_0 = static_cast<const PredType*>(prediction_0); + const auto* pred_1 = static_cast<const PredType*>(prediction_1); + auto* dst = static_cast<Pixel*>(dest); + const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel); + + int y = 0; + do { + int x = 0; + do { + // See warp.cc and convolve.cc for detailed prediction ranges. + // weight_0 + weight_1 = 16. + int res = pred_0[x] * weight_0 + pred_1[x] * weight_1; + res -= (bitdepth == 8) ? 0 : kCompoundOffset * 16; + dst[x] = static_cast<Pixel>( + Clip3(RightShiftWithRounding(res, inter_post_round_bits + 4), 0, + (1 << bitdepth) - 1)); + } while (++x < width); + + dst += dst_stride; + pred_0 += width; + pred_1 += width; + } while (++y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->distance_weighted_blend = DistanceWeightedBlend_C<8, uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_DistanceWeightedBlend + dsp->distance_weighted_blend = DistanceWeightedBlend_C<8, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->distance_weighted_blend = DistanceWeightedBlend_C<10, uint16_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_DistanceWeightedBlend + dsp->distance_weighted_blend = DistanceWeightedBlend_C<10, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void DistanceWeightedBlendInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/distance_weighted_blend.h b/src/dsp/distance_weighted_blend.h new file mode 100644 index 0000000..1a782b6 --- /dev/null +++ b/src/dsp/distance_weighted_blend.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_DISTANCE_WEIGHTED_BLEND_H_ +#define LIBGAV1_SRC_DSP_DISTANCE_WEIGHTED_BLEND_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/distance_weighted_blend_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/distance_weighted_blend_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::distance_weighted_blend. This function is not thread-safe. +void DistanceWeightedBlendInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_DISTANCE_WEIGHTED_BLEND_H_ diff --git a/src/dsp/dsp.cc b/src/dsp/dsp.cc new file mode 100644 index 0000000..5b54c4e --- /dev/null +++ b/src/dsp/dsp.cc @@ -0,0 +1,150 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/dsp.h" + +#include <mutex> // NOLINT (unapproved c++11 header) + +#include "src/dsp/arm/weight_mask_neon.h" +#include "src/dsp/average_blend.h" +#include "src/dsp/cdef.h" +#include "src/dsp/convolve.h" +#include "src/dsp/distance_weighted_blend.h" +#include "src/dsp/film_grain.h" +#include "src/dsp/intra_edge.h" +#include "src/dsp/intrapred.h" +#include "src/dsp/inverse_transform.h" +#include "src/dsp/loop_filter.h" +#include "src/dsp/loop_restoration.h" +#include "src/dsp/mask_blend.h" +#include "src/dsp/motion_field_projection.h" +#include "src/dsp/motion_vector_search.h" +#include "src/dsp/obmc.h" +#include "src/dsp/super_res.h" +#include "src/dsp/warp.h" +#include "src/dsp/weight_mask.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp_internal { + +dsp::Dsp* GetWritableDspTable(int bitdepth) { + switch (bitdepth) { + case 8: { + static dsp::Dsp dsp_8bpp; + return &dsp_8bpp; + } +#if LIBGAV1_MAX_BITDEPTH >= 10 + case 10: { + static dsp::Dsp dsp_10bpp; + return &dsp_10bpp; + } +#endif + } + return nullptr; +} + +} // namespace dsp_internal + +namespace dsp { + +void DspInit() { + static std::once_flag once; + std::call_once(once, []() { + AverageBlendInit_C(); + CdefInit_C(); + ConvolveInit_C(); + DistanceWeightedBlendInit_C(); + FilmGrainInit_C(); + IntraEdgeInit_C(); + IntraPredInit_C(); + InverseTransformInit_C(); + LoopFilterInit_C(); + LoopRestorationInit_C(); + MaskBlendInit_C(); + MotionFieldProjectionInit_C(); + MotionVectorSearchInit_C(); + ObmcInit_C(); + SuperResInit_C(); + WarpInit_C(); + WeightMaskInit_C(); +#if LIBGAV1_ENABLE_SSE4_1 || LIBGAV1_ENABLE_AVX2 + const uint32_t cpu_features = GetCpuInfo(); +#if LIBGAV1_ENABLE_SSE4_1 + if ((cpu_features & kSSE4_1) != 0) { + AverageBlendInit_SSE4_1(); + CdefInit_SSE4_1(); + ConvolveInit_SSE4_1(); + DistanceWeightedBlendInit_SSE4_1(); + IntraEdgeInit_SSE4_1(); + IntraPredInit_SSE4_1(); + IntraPredCflInit_SSE4_1(); + IntraPredSmoothInit_SSE4_1(); + InverseTransformInit_SSE4_1(); + LoopFilterInit_SSE4_1(); + LoopRestorationInit_SSE4_1(); + MaskBlendInit_SSE4_1(); + MotionFieldProjectionInit_SSE4_1(); + MotionVectorSearchInit_SSE4_1(); + ObmcInit_SSE4_1(); + SuperResInit_SSE4_1(); + WarpInit_SSE4_1(); + WeightMaskInit_SSE4_1(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + LoopRestorationInit10bpp_SSE4_1(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + } +#endif // LIBGAV1_ENABLE_SSE4_1 +#if LIBGAV1_ENABLE_AVX2 + if ((cpu_features & kAVX2) != 0) { + ConvolveInit_AVX2(); + LoopRestorationInit_AVX2(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + LoopRestorationInit10bpp_AVX2(); +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + } +#endif // LIBGAV1_ENABLE_AVX2 +#endif // LIBGAV1_ENABLE_SSE4_1 || LIBGAV1_ENABLE_AVX2 +#if LIBGAV1_ENABLE_NEON + AverageBlendInit_NEON(); + CdefInit_NEON(); + ConvolveInit_NEON(); + DistanceWeightedBlendInit_NEON(); + FilmGrainInit_NEON(); + IntraEdgeInit_NEON(); + IntraPredCflInit_NEON(); + IntraPredDirectionalInit_NEON(); + IntraPredFilterIntraInit_NEON(); + IntraPredInit_NEON(); + IntraPredSmoothInit_NEON(); + InverseTransformInit_NEON(); + LoopFilterInit_NEON(); + LoopRestorationInit_NEON(); + MaskBlendInit_NEON(); + MotionFieldProjectionInit_NEON(); + MotionVectorSearchInit_NEON(); + ObmcInit_NEON(); + SuperResInit_NEON(); + WarpInit_NEON(); + WeightMaskInit_NEON(); +#endif // LIBGAV1_ENABLE_NEON + }); +} + +const Dsp* GetDspTable(int bitdepth) { + return dsp_internal::GetWritableDspTable(bitdepth); +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/dsp.h b/src/dsp/dsp.h new file mode 100644 index 0000000..fcbac3a --- /dev/null +++ b/src/dsp/dsp.h @@ -0,0 +1,910 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_DSP_H_ +#define LIBGAV1_SRC_DSP_DSP_H_ + +#include <cstddef> // ptrdiff_t +#include <cstdint> +#include <cstdlib> + +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/film_grain_common.h" +#include "src/utils/cpu.h" +#include "src/utils/reference_info.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { + +#if !defined(LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS) +#define LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS 0 +#endif + +enum IntraPredictor : uint8_t { + kIntraPredictorDcFill, + kIntraPredictorDcTop, + kIntraPredictorDcLeft, + kIntraPredictorDc, + kIntraPredictorVertical, + kIntraPredictorHorizontal, + kIntraPredictorPaeth, + kIntraPredictorSmooth, + kIntraPredictorSmoothVertical, + kIntraPredictorSmoothHorizontal, + kNumIntraPredictors +}; + +// List of valid 1D transforms. +enum Transform1D : uint8_t { + k1DTransformDct, // Discrete Cosine Transform. + k1DTransformAdst, // Asymmetric Discrete Sine Transform. + k1DTransformIdentity, + k1DTransformWht, // Walsh Hadamard Transform. + kNum1DTransforms +}; + +// List of valid 1D transform sizes. Not all transforms may be available for all +// the sizes. +enum TransformSize1D : uint8_t { + k1DTransformSize4, + k1DTransformSize8, + k1DTransformSize16, + k1DTransformSize32, + k1DTransformSize64, + kNum1DTransformSizes +}; + +// The maximum width of the loop filter, fewer pixels may be filtered depending +// on strength thresholds. +enum LoopFilterSize : uint8_t { + kLoopFilterSize4, + kLoopFilterSize6, + kLoopFilterSize8, + kLoopFilterSize14, + kNumLoopFilterSizes +}; + +enum : uint8_t { + kRow = 0, + kColumn = 1, +}; + +//------------------------------------------------------------------------------ +// ToString() +// +// These functions are meant to be used only in debug logging and within tests. +// They are defined inline to avoid including the strings in the release +// library when logging is disabled; unreferenced functions will not be added to +// any object file in that case. + +inline const char* ToString(const IntraPredictor predictor) { + switch (predictor) { + case kIntraPredictorDcFill: + return "kIntraPredictorDcFill"; + case kIntraPredictorDcTop: + return "kIntraPredictorDcTop"; + case kIntraPredictorDcLeft: + return "kIntraPredictorDcLeft"; + case kIntraPredictorDc: + return "kIntraPredictorDc"; + case kIntraPredictorVertical: + return "kIntraPredictorVertical"; + case kIntraPredictorHorizontal: + return "kIntraPredictorHorizontal"; + case kIntraPredictorPaeth: + return "kIntraPredictorPaeth"; + case kIntraPredictorSmooth: + return "kIntraPredictorSmooth"; + case kIntraPredictorSmoothVertical: + return "kIntraPredictorSmoothVertical"; + case kIntraPredictorSmoothHorizontal: + return "kIntraPredictorSmoothHorizontal"; + case kNumIntraPredictors: + return "kNumIntraPredictors"; + } + abort(); +} + +inline const char* ToString(const Transform1D transform) { + switch (transform) { + case k1DTransformDct: + return "k1DTransformDct"; + case k1DTransformAdst: + return "k1DTransformAdst"; + case k1DTransformIdentity: + return "k1DTransformIdentity"; + case k1DTransformWht: + return "k1DTransformWht"; + case kNum1DTransforms: + return "kNum1DTransforms"; + } + abort(); +} + +inline const char* ToString(const TransformSize1D transform_size) { + switch (transform_size) { + case k1DTransformSize4: + return "k1DTransformSize4"; + case k1DTransformSize8: + return "k1DTransformSize8"; + case k1DTransformSize16: + return "k1DTransformSize16"; + case k1DTransformSize32: + return "k1DTransformSize32"; + case k1DTransformSize64: + return "k1DTransformSize64"; + case kNum1DTransformSizes: + return "kNum1DTransformSizes"; + } + abort(); +} + +inline const char* ToString(const LoopFilterSize filter_size) { + switch (filter_size) { + case kLoopFilterSize4: + return "kLoopFilterSize4"; + case kLoopFilterSize6: + return "kLoopFilterSize6"; + case kLoopFilterSize8: + return "kLoopFilterSize8"; + case kLoopFilterSize14: + return "kLoopFilterSize14"; + case kNumLoopFilterSizes: + return "kNumLoopFilterSizes"; + } + abort(); +} + +inline const char* ToString(const LoopFilterType filter_type) { + switch (filter_type) { + case kLoopFilterTypeVertical: + return "kLoopFilterTypeVertical"; + case kLoopFilterTypeHorizontal: + return "kLoopFilterTypeHorizontal"; + case kNumLoopFilterTypes: + return "kNumLoopFilterTypes"; + } + abort(); +} + +//------------------------------------------------------------------------------ +// Intra predictors. Section 7.11.2. +// These require access to one or both of the top row and left column. Some may +// access the top-left (top[-1]), top-right (top[width+N]), bottom-left +// (left[height+N]) or upper-left (left[-1]). + +// Intra predictor function signature. Sections 7.11.2.2, 7.11.2.4 (#10,#11), +// 7.11.2.5, 7.11.2.6. +// |dst| is an unaligned pointer to the output block. Pixel size is determined +// by bitdepth with |stride| given in bytes. |top| is an unaligned pointer to +// the row above |dst|. |left| is an aligned vector of the column to the left +// of |dst|. top-left and bottom-left may be accessed. +using IntraPredictorFunc = void (*)(void* dst, ptrdiff_t stride, + const void* top, const void* left); +using IntraPredictorFuncs = + IntraPredictorFunc[kNumTransformSizes][kNumIntraPredictors]; + +// Directional intra predictor function signature, zone 1 (0 < angle < 90). +// Section 7.11.2.4 (#7). +// |dst| is an unaligned pointer to the output block. Pixel size is determined +// by bitdepth with |stride| given in bytes. |top| is an unaligned pointer to +// the row above |dst|. |width| and |height| give the dimensions of the block. +// |xstep| is the scaled starting index to |top| from +// kDirectionalIntraPredictorDerivative. |upsampled_top| indicates whether +// |top| has been upsampled as described in '7.11.2.11. Intra edge upsample +// process'. This can occur in cases with |width| + |height| <= 16. top-right +// is accessed. +using DirectionalIntraPredictorZone1Func = void (*)(void* dst, ptrdiff_t stride, + const void* top, int width, + int height, int xstep, + bool upsampled_top); + +// Directional intra predictor function signature, zone 2 (90 < angle < 180). +// Section 7.11.2.4 (#8). +// |dst| is an unaligned pointer to the output block. Pixel size is determined +// by bitdepth with |stride| given in bytes. |top| is an unaligned pointer to +// the row above |dst|. |left| is an aligned vector of the column to the left of +// |dst|. |width| and |height| give the dimensions of the block. |xstep| and +// |ystep| are the scaled starting index to |top| and |left|, respectively, +// from kDirectionalIntraPredictorDerivative. |upsampled_top| and +// |upsampled_left| indicate whether |top| and |left| have been upsampled as +// described in '7.11.2.11. Intra edge upsample process'. This can occur in +// cases with |width| + |height| <= 16. top-left and upper-left are accessed, +// up to [-2] in each if |upsampled_top/left| are set. +using DirectionalIntraPredictorZone2Func = void (*)( + void* dst, ptrdiff_t stride, const void* top, const void* left, int width, + int height, int xstep, int ystep, bool upsampled_top, bool upsampled_left); + +// Directional intra predictor function signature, zone 3 (180 < angle < 270). +// Section 7.11.2.4 (#9). +// |dst| is an unaligned pointer to the output block. Pixel size is determined +// by bitdepth with |stride| given in bytes. |left| is an aligned vector of the +// column to the left of |dst|. |width| and |height| give the dimensions of the +// block. |ystep| is the scaled starting index to |left| from +// kDirectionalIntraPredictorDerivative. |upsampled_left| indicates whether +// |left| has been upsampled as described in '7.11.2.11. Intra edge upsample +// process'. This can occur in cases with |width| + |height| <= 16. bottom-left +// is accessed. +using DirectionalIntraPredictorZone3Func = void (*)(void* dst, ptrdiff_t stride, + const void* left, int width, + int height, int ystep, + bool upsampled_left); + +// Filter intra predictor function signature. Section 7.11.2.3. +// |dst| is an unaligned pointer to the output block. Pixel size is determined +// by bitdepth with |stride| given in bytes. |top| is an unaligned pointer to +// the row above |dst|. |left| is an aligned vector of the column to the left +// of |dst|. |width| and |height| are the size of the block in pixels. +using FilterIntraPredictorFunc = void (*)(void* dst, ptrdiff_t stride, + const void* top, const void* left, + FilterIntraPredictor pred, int width, + int height); + +//------------------------------------------------------------------------------ +// Chroma from Luma (Cfl) prediction. Section 7.11.5. + +// Chroma from Luma (Cfl) intra prediction function signature. |dst| is an +// unaligned pointer to the output block. Pixel size is determined by bitdepth +// with |stride| given in bytes. |luma| contains subsampled luma pixels with 3 +// fractional bits of precision. |alpha| is the signed Cfl alpha value for the +// appropriate plane. +using CflIntraPredictorFunc = void (*)( + void* dst, ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], int alpha); +using CflIntraPredictorFuncs = CflIntraPredictorFunc[kNumTransformSizes]; + +// Chroma from Luma (Cfl) subsampler function signature. |luma| is an unaligned +// pointer to the output block. |src| is an unaligned pointer to the input +// block. Pixel size is determined by bitdepth with |stride| given in bytes. +using CflSubsamplerFunc = + void (*)(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + int max_luma_width, int max_luma_height, const void* source, + ptrdiff_t stride); +using CflSubsamplerFuncs = + CflSubsamplerFunc[kNumTransformSizes][kNumSubsamplingTypes]; + +//------------------------------------------------------------------------------ +// Intra Edge Filtering and Upsampling. Step 4 in section 7.11.2.4. + +// Intra edge filter function signature. |buffer| is a pointer to the top_row or +// left_column that needs to be filtered. Typically the -1'th index of |top_row| +// and |left_column| need to be filtered as well, so the caller can merely pass +// the |buffer| as top_row[-1] or left_column[-1]. Pixel size is determined by +// bitdepth. |size| is the number of pixels to be filtered. |strength| is the +// filter strength. Section 7.11.2.12 in the spec. +using IntraEdgeFilterFunc = void (*)(void* buffer, int size, int strength); + +// Intra edge upsampler function signature. |buffer| is a pointer to the top_row +// or left_column that needs to be upsampled. Pixel size is determined by +// bitdepth. |size| is the number of pixels to be upsampled; valid values are: +// 4, 8, 12, 16. This function needs access to negative indices -1 and -2 of +// the |buffer|. Section 7.11.2.11 in the spec. +using IntraEdgeUpsamplerFunc = void (*)(void* buffer, int size); + +//------------------------------------------------------------------------------ +// Inverse transform add function signature. +// +// Steps 2 and 3 of section 7.12.3 (contains the implementation of section +// 7.13.3). +// Apply the inverse transforms and add the residual to the destination frame +// for the transform type and block size |tx_size| starting at position +// |start_x| and |start_y|. |dst_frame| is a pointer to an Array2D. +// |adjusted_tx_height| is the number of rows to process based on the non-zero +// coefficient count in the block. It will be 1 (non-zero coefficient count == +// 1), 4 or a multiple of 8 up to 32 or the original transform height, +// whichever is less. +using InverseTransformAddFunc = void (*)(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame); +// The final dimension holds row and column transforms indexed with kRow and +// kColumn. +using InverseTransformAddFuncs = + InverseTransformAddFunc[kNum1DTransforms][kNum1DTransformSizes][2]; + +//------------------------------------------------------------------------------ +// Post processing. + +// Loop filter function signature. Section 7.14. +// |dst| is an unaligned pointer to the output block. Pixel size is determined +// by bitdepth with |stride| given in bytes. +using LoopFilterFunc = void (*)(void* dst, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); +using LoopFilterFuncs = + LoopFilterFunc[kNumLoopFilterSizes][kNumLoopFilterTypes]; + +// Cdef direction function signature. Section 7.15.2. +// |src| is a pointer to the source block. Pixel size is determined by bitdepth +// with |stride| given in bytes. |direction| and |variance| are output +// parameters and must not be nullptr. +using CdefDirectionFunc = void (*)(const void* src, ptrdiff_t stride, + uint8_t* direction, int* variance); + +// Cdef filtering function signature. Section 7.15.3. +// |source| is a pointer to the input block padded with kCdefLargeValue if at a +// frame border. |source_stride| is given in units of uint16_t. +// |block_width|, |block_height| are the width/height of the input block. +// |primary_strength|, |secondary_strength|, and |damping| are Cdef filtering +// parameters. +// |direction| is the filtering direction. +// |dest| is the output buffer. |dest_stride| is given in bytes. +using CdefFilteringFunc = void (*)(const uint16_t* source, + ptrdiff_t source_stride, int block_height, + int primary_strength, int secondary_strength, + int damping, int direction, void* dest, + ptrdiff_t dest_stride); + +// The first index is block width: [0]: 4, [1]: 8. The second is based on +// non-zero strengths: [0]: |primary_strength| and |secondary_strength|, [1]: +// |primary_strength| only, [2]: |secondary_strength| only. +using CdefFilteringFuncs = CdefFilteringFunc[2][3]; + +// Upscaling coefficients function signature. Section 7.16. +// This is an auxiliary function for SIMD optimizations and has no corresponding +// C function. Different SIMD versions may have different outputs. So it must +// pair with the corresponding version of SuperResFunc. +// |upscaled_width| is the width of the output frame. +// |step| is the number of subpixels to move the kernel for the next destination +// pixel. +// |initial_subpixel_x| is a base offset from which |step| increments. +// |coefficients| is the upscale filter used by each pixel in a row. +using SuperResCoefficientsFunc = void (*)(int upscaled_width, + int initial_subpixel_x, int step, + void* coefficients); + +// Upscaling process function signature. Section 7.16. +// |coefficients| is the upscale filter used by each pixel in a row. It is not +// used by the C function. +// |source| is the input frame buffer. It will be line extended. +// |dest| is the output buffer. +// |stride| is given in pixels, and shared by |source| and |dest|. +// |height| is the height of the block to be processed. +// |downscaled_width| is the width of the input frame. +// |upscaled_width| is the width of the output frame. +// |step| is the number of subpixels to move the kernel for the next destination +// pixel. +// |initial_subpixel_x| is a base offset from which |step| increments. +using SuperResFunc = void (*)(const void* coefficients, void* source, + ptrdiff_t stride, int height, + int downscaled_width, int upscaled_width, + int initial_subpixel_x, int step, void* dest); + +// Loop restoration function signature. Sections 7.16, 7.17. +// |restoration_info| contains loop restoration information, such as filter +// type, strength. +// |source| is the input frame buffer, which is deblocked and cdef filtered. +// |top_border| and |bottom_border| are the top and bottom borders. +// |dest| is the output. +// |stride| is given in pixels, and shared by |source|, |top_border|, +// |bottom_border| and |dest|. +// |restoration_buffer| contains buffers required for self guided filter and +// wiener filter. They must be initialized before calling. +using LoopRestorationFunc = void (*)( + const RestorationUnitInfo& restoration_info, const void* source, + const void* top_border, const void* bottom_border, ptrdiff_t stride, + int width, int height, RestorationBuffer* restoration_buffer, void* dest); + +// Index 0 is Wiener Filter. +// Index 1 is Self Guided Restoration Filter. +// This can be accessed as LoopRestorationType - 2. +using LoopRestorationFuncs = LoopRestorationFunc[2]; + +// Convolve function signature. Section 7.11.3.4. +// This function applies a horizontal filter followed by a vertical filter. +// |reference| is the input block (reference frame buffer). |reference_stride| +// is the corresponding frame stride. +// |vertical_filter_index|/|horizontal_filter_index| is the index to +// retrieve the type of filter to be applied for vertical/horizontal direction +// from the filter lookup table 'kSubPixelFilters'. +// |horizontal_filter_id| and |vertical_filter_id| are the filter ids. +// |width| and |height| are width and height of the block to be filtered. +// |ref_last_x| and |ref_last_y| are the last pixel of the reference frame in +// x/y direction. +// |prediction| is the output block (output frame buffer). +// Rounding precision is derived from the function being called. For horizontal +// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be +// used. For compound vertical filtering kInterRoundBitsCompoundVertical will be +// used. Otherwise kInterRoundBitsVertical & kInterRoundBitsVertical12bpp will +// be used. +using ConvolveFunc = void (*)(const void* reference, ptrdiff_t reference_stride, + int horizontal_filter_index, + int vertical_filter_index, + int horizontal_filter_id, int vertical_filter_id, + int width, int height, void* prediction, + ptrdiff_t pred_stride); + +// Convolve functions signature. Each points to one convolve function with +// a specific setting: +// ConvolveFunc[is_intra_block_copy][is_compound][has_vertical_filter] +// [has_horizontal_filter]. +// If is_compound is false, the prediction is clipped to Pixel. +// If is_compound is true, the range of prediction is: +// 8bpp: [-5132, 9212] (int16_t) +// 10bpp: [ 3988, 61532] (uint16_t) +// 12bpp: [ 3974, 61559] (uint16_t) +// See src/dsp/convolve.cc +using ConvolveFuncs = ConvolveFunc[2][2][2][2]; + +// Convolve + scale function signature. Section 7.11.3.4. +// This function applies a horizontal filter followed by a vertical filter. +// |reference| is the input block (reference frame buffer). |reference_stride| +// is the corresponding frame stride. +// |vertical_filter_index|/|horizontal_filter_index| is the index to +// retrieve the type of filter to be applied for vertical/horizontal direction +// from the filter lookup table 'kSubPixelFilters'. +// |subpixel_x| and |subpixel_y| are starting positions in units of 1/1024. +// |step_x| and |step_y| are step sizes in units of 1/1024 of a pixel. +// |width| and |height| are width and height of the block to be filtered. +// |ref_last_x| and |ref_last_y| are the last pixel of the reference frame in +// x/y direction. +// |prediction| is the output block (output frame buffer). +// Rounding precision is derived from the function being called. For horizontal +// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be +// used. For compound vertical filtering kInterRoundBitsCompoundVertical will be +// used. Otherwise kInterRoundBitsVertical & kInterRoundBitsVertical12bpp will +// be used. +using ConvolveScaleFunc = void (*)(const void* reference, + ptrdiff_t reference_stride, + int horizontal_filter_index, + int vertical_filter_index, int subpixel_x, + int subpixel_y, int step_x, int step_y, + int width, int height, void* prediction, + ptrdiff_t pred_stride); + +// Convolve functions signature for scaling version. +// 0: single predictor. 1: compound predictor. +using ConvolveScaleFuncs = ConvolveScaleFunc[2]; + +// Weight mask function signature. Section 7.11.3.12. +// |prediction_0| is the first input block. +// |prediction_1| is the second input block. Both blocks are int16_t* when +// bitdepth == 8 and uint16_t* otherwise. +// |width| and |height| are the prediction width and height. +// The stride for the input buffers is equal to |width|. +// The valid range of block size is [8x8, 128x128] for the luma plane. +// |mask| is the output buffer. |mask_stride| is the output buffer stride. +using WeightMaskFunc = void (*)(const void* prediction_0, + const void* prediction_1, uint8_t* mask, + ptrdiff_t mask_stride); + +// Weight mask functions signature. The dimensions (in order) are: +// * Width index (4 => 0, 8 => 1, 16 => 2 and so on). +// * Height index (4 => 0, 8 => 1, 16 => 2 and so on). +// * mask_is_inverse. +using WeightMaskFuncs = WeightMaskFunc[6][6][2]; + +// Average blending function signature. +// Two predictors are averaged to generate the output. +// Input predictor values are int16_t. Output type is uint8_t, with actual +// range of Pixel value. +// Average blending is in the bottom of Section 7.11.3.1 (COMPOUND_AVERAGE). +// |prediction_0| is the first input block. +// |prediction_1| is the second input block. Both blocks are int16_t* when +// bitdepth == 8 and uint16_t* otherwise. +// |width| and |height| are the same for the first and second input blocks. +// The stride for the input buffers is equal to |width|. +// The valid range of block size is [8x8, 128x128] for the luma plane. +// |dest| is the output buffer. |dest_stride| is the output buffer stride. +using AverageBlendFunc = void (*)(const void* prediction_0, + const void* prediction_1, int width, + int height, void* dest, + ptrdiff_t dest_stride); + +// Distance weighted blending function signature. +// Weights are generated in Section 7.11.3.15. +// Weighted blending is in the bottom of Section 7.11.3.1 (COMPOUND_DISTANCE). +// This function takes two blocks (inter frame prediction) and produces a +// weighted output. +// |prediction_0| is the first input block. +// |prediction_1| is the second input block. Both blocks are int16_t* when +// bitdepth == 8 and uint16_t* otherwise. +// |weight_0| is the weight for the first block. It is derived from the relative +// distance of the first reference frame and the current frame. +// |weight_1| is the weight for the second block. It is derived from the +// relative distance of the second reference frame and the current frame. +// |width| and |height| are the same for the first and second input blocks. +// The stride for the input buffers is equal to |width|. +// The valid range of block size is [8x8, 128x128] for the luma plane. +// |dest| is the output buffer. |dest_stride| is the output buffer stride. +using DistanceWeightedBlendFunc = void (*)(const void* prediction_0, + const void* prediction_1, + uint8_t weight_0, uint8_t weight_1, + int width, int height, void* dest, + ptrdiff_t dest_stride); + +// Mask blending function signature. Section 7.11.3.14. +// This function takes two blocks and produces a blended output stored into the +// output block |dest|. The blending is a weighted average process, controlled +// by values of the mask. +// |prediction_0| is the first input block. When prediction mode is inter_intra +// (or wedge_inter_intra), this refers to the inter frame prediction. It is +// int16_t* when bitdepth == 8 and uint16_t* otherwise. +// The stride for |prediction_0| is equal to |width|. +// |prediction_1| is the second input block. When prediction mode is inter_intra +// (or wedge_inter_intra), this refers to the intra frame prediction and uses +// Pixel values. It is only used for intra frame prediction when bitdepth >= 10. +// It is int16_t* when bitdepth == 8 and uint16_t* otherwise. +// |prediction_stride_1| is the stride, given in units of [u]int16_t. When +// |is_inter_intra| is false (compound prediction) then |prediction_stride_1| is +// equal to |width|. +// |mask| is an integer array, whose value indicates the weight of the blending. +// |mask_stride| is corresponding stride. +// |width|, |height| are the same for both input blocks. +// If it's inter_intra (or wedge_inter_intra), the valid range of block size is +// [8x8, 32x32]. Otherwise (including difference weighted prediction and +// compound average prediction), the valid range is [8x8, 128x128]. +// If there's subsampling, the corresponding width and height are halved for +// chroma planes. +// |subsampling_x|, |subsampling_y| are the subsampling factors. +// |is_inter_intra| stands for the prediction mode. If it is true, one of the +// prediction blocks is from intra prediction of current frame. Otherwise, two +// prediction blocks are both inter frame predictions. +// |is_wedge_inter_intra| indicates if the mask is for the wedge prediction. +// |dest| is the output block. +// |dest_stride| is the corresponding stride for dest. +using MaskBlendFunc = void (*)(const void* prediction_0, + const void* prediction_1, + ptrdiff_t prediction_stride_1, + const uint8_t* mask, ptrdiff_t mask_stride, + int width, int height, void* dest, + ptrdiff_t dest_stride); + +// Mask blending functions signature. Each points to one function with +// a specific setting: +// MaskBlendFunc[subsampling_x + subsampling_y][is_inter_intra]. +using MaskBlendFuncs = MaskBlendFunc[3][2]; + +// This function is similar to the MaskBlendFunc. It is only used when +// |is_inter_intra| is true and |bitdepth| == 8. +// |prediction_[01]| are Pixel values (uint8_t). +// |prediction_1| is also the output buffer. +using InterIntraMaskBlendFunc8bpp = void (*)(const uint8_t* prediction_0, + uint8_t* prediction_1, + ptrdiff_t prediction_stride_1, + const uint8_t* mask, + ptrdiff_t mask_stride, int width, + int height); + +// InterIntra8bpp mask blending functions signature. When is_wedge_inter_intra +// is false, the function at index 0 must be used. Otherwise, the function at +// index subsampling_x + subsampling_y must be used. +using InterIntraMaskBlendFuncs8bpp = InterIntraMaskBlendFunc8bpp[3]; + +// Obmc (overlapped block motion compensation) blending function signature. +// Section 7.11.3.10. +// This function takes two blocks and produces a blended output stored into the +// first input block. The blending is a weighted average process, controlled by +// values of the mask. +// Obmc is not a compound mode. It is different from other compound blending, +// in terms of precision. The current block is computed using convolution with +// clipping to the range of pixel values. Its above and left blocks are also +// clipped. Therefore obmc blending process doesn't need to clip the output. +// |prediction| is the first input block, which will be overwritten. +// |prediction_stride| is the stride, given in bytes. +// |width|, |height| are the same for both input blocks. +// |obmc_prediction| is the second input block. +// |obmc_prediction_stride| is its stride, given in bytes. +using ObmcBlendFunc = void (*)(void* prediction, ptrdiff_t prediction_stride, + int width, int height, + const void* obmc_prediction, + ptrdiff_t obmc_prediction_stride); +using ObmcBlendFuncs = ObmcBlendFunc[kNumObmcDirections]; + +// Warp function signature. Section 7.11.3.5. +// This function applies warp filtering for each 8x8 block inside the current +// coding block. The filtering process is similar to 2d convolve filtering. +// The horizontal filter is applied followed by the vertical filter. +// The function has to calculate corresponding pixel positions before and +// after warping. +// |source| is the input reference frame buffer. +// |source_stride|, |source_width|, |source_height| are corresponding frame +// stride, width, and height. |source_stride| is given in bytes. +// |warp_params| is the matrix of warp motion: warp_params[i] = mN. +// [x' (m2 m3 m0 [x +// z . y' = m4 m5 m1 * y +// 1] m6 m7 1) 1] +// |subsampling_x/y| is the current frame's plane subsampling factor. +// |block_start_x| and |block_start_y| are the starting position the current +// coding block. +// |block_width| and |block_height| are width and height of the current coding +// block. |block_width| and |block_height| are at least 8. +// |alpha|, |beta|, |gamma|, |delta| are valid warp parameters. See the +// comments in the definition of struct GlobalMotion for the range of their +// values. +// |dest| is the output buffer of type Pixel. The output values are clipped to +// Pixel values. +// |dest_stride| is the stride, in units of bytes. +// Rounding precision is derived from the function being called. For horizontal +// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be +// used. For vertical filtering kInterRoundBitsVertical & +// kInterRoundBitsVertical12bpp will be used. +// +// NOTE: WarpFunc assumes the source frame has left, right, top, and bottom +// borders that extend the frame boundary pixels. +// * The left and right borders must be at least 13 pixels wide. In addition, +// Warp_NEON() may read up to 14 bytes after a row in the |source| buffer. +// Therefore, there must be at least one extra padding byte after the right +// border of the last row in the source buffer. +// * The top and bottom borders must be at least 13 pixels high. +using WarpFunc = void (*)(const void* source, ptrdiff_t source_stride, + int source_width, int source_height, + const int* warp_params, int subsampling_x, + int subsampling_y, int block_start_x, + int block_start_y, int block_width, int block_height, + int16_t alpha, int16_t beta, int16_t gamma, + int16_t delta, void* dest, ptrdiff_t dest_stride); + +// Warp for compound predictions. Section 7.11.3.5. +// Similar to WarpFunc, but |dest| is a uint16_t predictor buffer, +// |dest_stride| is given in units of uint16_t and |inter_round_bits_vertical| +// is always 7 (kCompoundInterRoundBitsVertical). +// Rounding precision is derived from the function being called. For horizontal +// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be +// used. For vertical filtering kInterRoundBitsCompondVertical will be used. +using WarpCompoundFunc = WarpFunc; + +constexpr int kNumAutoRegressionLags = 4; +// Applies an auto-regressive filter to the white noise in |luma_grain_buffer|. +// Section 7.18.3.3, second code block +// |params| are parameters read from frame header, mainly providing +// auto_regression_coeff_y for the filter and auto_regression_shift to right +// shift the filter sum by. Note: This method assumes +// params.auto_regression_coeff_lag is not 0. Do not call this method if +// params.auto_regression_coeff_lag is 0. +using LumaAutoRegressionFunc = void (*)(const FilmGrainParams& params, + void* luma_grain_buffer); +// Function index is auto_regression_coeff_lag - 1. +using LumaAutoRegressionFuncs = + LumaAutoRegressionFunc[kNumAutoRegressionLags - 1]; + +// Applies an auto-regressive filter to the white noise in u_grain and v_grain. +// Section 7.18.3.3, third code block +// The |luma_grain_buffer| provides samples that are added to the autoregressive +// sum when num_y_points > 0. +// |u_grain_buffer| and |v_grain_buffer| point to the buffers of chroma noise +// that were generated from the stored Gaussian sequence, and are overwritten +// with the results of the autoregressive filter. |params| are parameters read +// from frame header, mainly providing auto_regression_coeff_u and +// auto_regression_coeff_v for each chroma plane's filter, and +// auto_regression_shift to right shift the filter sums by. +using ChromaAutoRegressionFunc = void (*)(const FilmGrainParams& params, + const void* luma_grain_buffer, + int subsampling_x, int subsampling_y, + void* u_grain_buffer, + void* v_grain_buffer); +using ChromaAutoRegressionFuncs = + ChromaAutoRegressionFunc[/*use_luma*/ 2][kNumAutoRegressionLags]; + +// Build an image-wide "stripe" of grain noise for every 32 rows in the image. +// Section 7.18.3.5, first code block. +// Each 32x32 luma block is copied at a random offset specified via +// |grain_seed| from the grain template produced by autoregression, and the same +// is done for chroma grains, subject to subsampling. +// |width| and |height| are the dimensions of the overall image. +// |noise_stripes_buffer| points to an Array2DView with one row for each stripe. +// Because this function treats all planes identically and independently, it is +// simplified to take one grain buffer at a time. This means duplicating some +// random number generations, but that work can be reduced in other ways. +using ConstructNoiseStripesFunc = void (*)(const void* grain_buffer, + int grain_seed, int width, + int height, int subsampling_x, + int subsampling_y, + void* noise_stripes_buffer); +using ConstructNoiseStripesFuncs = + ConstructNoiseStripesFunc[/*overlap_flag*/ 2]; + +// Compute the one or two overlap rows for each stripe copied to the noise +// image. +// Section 7.18.3.5, second code block. |width| and |height| are the +// dimensions of the overall image. |noise_stripes_buffer| points to an +// Array2DView with one row for each stripe. |noise_image_buffer| points to an +// Array2D containing the allocated plane for this frame. Because this function +// treats all planes identically and independently, it is simplified to take one +// grain buffer at a time. +using ConstructNoiseImageOverlapFunc = + void (*)(const void* noise_stripes_buffer, int width, int height, + int subsampling_x, int subsampling_y, void* noise_image_buffer); + +// Populate a scaling lookup table with interpolated values of a piecewise +// linear function where values in |point_value| are mapped to the values in +// |point_scaling|. +// |num_points| can be between 0 and 15. When 0, the lookup table is set to +// zero. +// |point_value| and |point_scaling| have |num_points| valid elements. +using InitializeScalingLutFunc = void (*)( + int num_points, const uint8_t point_value[], const uint8_t point_scaling[], + uint8_t scaling_lut[kScalingLookupTableSize]); + +// Blend noise with image. Section 7.18.3.5, third code block. +// |width| is the width of each row, while |height| is how many rows to compute. +// |start_height| is an offset for the noise image, to support multithreading. +// |min_value|, |max_luma|, and |max_chroma| are computed by the caller of these +// functions, according to the code in the spec. +// |source_plane_y| and |source_plane_uv| are the plane buffers of the decoded +// frame. They are blended with the film grain noise and written to +// |dest_plane_y| and |dest_plane_uv| as final output for display. +// source_plane_* and dest_plane_* may point to the same buffer, in which case +// the film grain noise is added in place. +// |scaling_lut_y| and |scaling_lut| represent a piecewise linear mapping from +// the frame's raw pixel value, to a scaling factor for the noise sample. +// |scaling_shift| is applied as a right shift after scaling, so that scaling +// down is possible. It is found in FilmGrainParams, but supplied directly to +// BlendNoiseWithImageLumaFunc because it's the only member used. +using BlendNoiseWithImageLumaFunc = + void (*)(const void* noise_image_ptr, int min_value, int max_value, + int scaling_shift, int width, int height, int start_height, + const uint8_t scaling_lut_y[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + void* dest_plane_y, ptrdiff_t dest_stride_y); + +using BlendNoiseWithImageChromaFunc = void (*)( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_value, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv); + +using BlendNoiseWithImageChromaFuncs = + BlendNoiseWithImageChromaFunc[/*chroma_scaling_from_luma*/ 2]; + +//------------------------------------------------------------------------------ + +struct FilmGrainFuncs { + LumaAutoRegressionFuncs luma_auto_regression; + ChromaAutoRegressionFuncs chroma_auto_regression; + ConstructNoiseStripesFuncs construct_noise_stripes; + ConstructNoiseImageOverlapFunc construct_noise_image_overlap; + InitializeScalingLutFunc initialize_scaling_lut; + BlendNoiseWithImageLumaFunc blend_noise_luma; + BlendNoiseWithImageChromaFuncs blend_noise_chroma; +}; + +// Motion field projection function signature. Section 7.9. +// |reference_info| provides reference information for motion field projection. +// |reference_to_current_with_sign| is the precalculated reference frame id +// distance from current frame. +// |dst_sign| is -1 for LAST_FRAME and LAST2_FRAME, or 0 (1 in spec) for others. +// |y8_start| and |y8_end| are the start and end 8x8 rows of the current tile. +// |x8_start| and |x8_end| are the start and end 8x8 columns of the current +// tile. +// |motion_field| is the output which saves the projected motion field +// information. +using MotionFieldProjectionKernelFunc = void (*)( + const ReferenceInfo& reference_info, int reference_to_current_with_sign, + int dst_sign, int y8_start, int y8_end, int x8_start, int x8_end, + TemporalMotionField* motion_field); + +// Compound temporal motion vector projection function signature. +// Section 7.9.3 and 7.10.2.10. +// |temporal_mvs| is the set of temporal reference motion vectors. +// |temporal_reference_offsets| specifies the number of frames covered by the +// original motion vector. +// |reference_offsets| specifies the number of frames to be covered by the +// projected motion vector. +// |count| is the number of the temporal motion vectors. +// |candidate_mvs| is the set of projected motion vectors. +using MvProjectionCompoundFunc = void (*)( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], int count, + CompoundMotionVector* candidate_mvs); + +// Single temporal motion vector projection function signature. +// Section 7.9.3 and 7.10.2.10. +// |temporal_mvs| is the set of temporal reference motion vectors. +// |temporal_reference_offsets| specifies the number of frames covered by the +// original motion vector. +// |reference_offset| specifies the number of frames to be covered by the +// projected motion vector. +// |count| is the number of the temporal motion vectors. +// |candidate_mvs| is the set of projected motion vectors. +using MvProjectionSingleFunc = void (*)( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + int reference_offset, int count, MotionVector* candidate_mvs); + +struct Dsp { + AverageBlendFunc average_blend; + CdefDirectionFunc cdef_direction; + CdefFilteringFuncs cdef_filters; + CflIntraPredictorFuncs cfl_intra_predictors; + CflSubsamplerFuncs cfl_subsamplers; + ConvolveFuncs convolve; + ConvolveScaleFuncs convolve_scale; + DirectionalIntraPredictorZone1Func directional_intra_predictor_zone1; + DirectionalIntraPredictorZone2Func directional_intra_predictor_zone2; + DirectionalIntraPredictorZone3Func directional_intra_predictor_zone3; + DistanceWeightedBlendFunc distance_weighted_blend; + FilmGrainFuncs film_grain; + FilterIntraPredictorFunc filter_intra_predictor; + InterIntraMaskBlendFuncs8bpp inter_intra_mask_blend_8bpp; + IntraEdgeFilterFunc intra_edge_filter; + IntraEdgeUpsamplerFunc intra_edge_upsampler; + IntraPredictorFuncs intra_predictors; + InverseTransformAddFuncs inverse_transforms; + LoopFilterFuncs loop_filters; + LoopRestorationFuncs loop_restorations; + MaskBlendFuncs mask_blend; + MotionFieldProjectionKernelFunc motion_field_projection_kernel; + MvProjectionCompoundFunc mv_projection_compound[3]; + MvProjectionSingleFunc mv_projection_single[3]; + ObmcBlendFuncs obmc_blend; + SuperResCoefficientsFunc super_res_coefficients; + SuperResFunc super_res; + WarpCompoundFunc warp_compound; + WarpFunc warp; + WeightMaskFuncs weight_mask; +}; + +// Initializes function pointers based on build config and runtime +// environment. Must be called once before first use. This function is +// thread-safe. +void DspInit(); + +// Returns the appropriate Dsp table for |bitdepth| or nullptr if one doesn't +// exist. +const Dsp* GetDspTable(int bitdepth); + +} // namespace dsp + +namespace dsp_internal { + +// Visual Studio builds don't have a way to detect SSE4_1. Only exclude the C +// functions if /arch:AVX2 is used across all sources. +#if !LIBGAV1_TARGETING_AVX2 && \ + (defined(_MSC_VER) || (defined(_M_IX86) || defined(_M_X64))) +#undef LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#define LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS 1 +#endif + +// Returns true if a more highly optimized version of |func| is not defined for +// the associated bitdepth or if it is forcibly enabled with +// LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS. The define checked for |func| corresponds +// to the LIBGAV1_Dsp<bitdepth>bpp_|func| define in the header file associated +// with the module. +// |func| is one of: +// - FunctionName, e.g., SelfGuidedFilter. +// - [sub-table-index1][...-indexN] e.g., +// TransformSize4x4_IntraPredictorDc. The indices correspond to enum values +// used as lookups with leading 'k' removed. +// +// NEON support is the only extension available for ARM and it is always +// required. Because of this restriction DSP_ENABLED_8BPP_NEON(func) is always +// true and can be omitted. +#define DSP_ENABLED_8BPP_AVX2(func) \ + (LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + LIBGAV1_Dsp8bpp_##func == LIBGAV1_CPU_AVX2) +#define DSP_ENABLED_10BPP_AVX2(func) \ + (LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + LIBGAV1_Dsp10bpp_##func == LIBGAV1_CPU_AVX2) +#define DSP_ENABLED_8BPP_SSE4_1(func) \ + (LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + LIBGAV1_Dsp8bpp_##func == LIBGAV1_CPU_SSE4_1) +#define DSP_ENABLED_10BPP_SSE4_1(func) \ + (LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + LIBGAV1_Dsp10bpp_##func == LIBGAV1_CPU_SSE4_1) + +// Returns the appropriate Dsp table for |bitdepth| or nullptr if one doesn't +// exist. This version is meant for use by test or dsp/*Init() functions only. +dsp::Dsp* GetWritableDspTable(int bitdepth); + +} // namespace dsp_internal +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_DSP_H_ diff --git a/src/dsp/film_grain.cc b/src/dsp/film_grain.cc new file mode 100644 index 0000000..41d1dd0 --- /dev/null +++ b/src/dsp/film_grain.cc @@ -0,0 +1,870 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/film_grain.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <new> + +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/film_grain_common.h" +#include "src/utils/array_2d.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" +#include "src/utils/logging.h" + +namespace libgav1 { +namespace dsp { +namespace film_grain { +namespace { + +// Making this a template function prevents it from adding to code size when it +// is not placed in the DSP table. Most functions in the dsp directory change +// behavior by bitdepth, but because this one doesn't, it receives a dummy +// parameter with one enforced value, ensuring only one copy is made. +template <int singleton> +void InitializeScalingLookupTable_C( + int num_points, const uint8_t point_value[], const uint8_t point_scaling[], + uint8_t scaling_lut[kScalingLookupTableSize]) { + static_assert(singleton == 0, + "Improper instantiation of InitializeScalingLookupTable_C. " + "There should be only one copy of this function."); + if (num_points == 0) { + memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize); + return; + } + static_assert(sizeof(scaling_lut[0]) == 1, ""); + memset(scaling_lut, point_scaling[0], point_value[0]); + for (int i = 0; i < num_points - 1; ++i) { + const int delta_y = point_scaling[i + 1] - point_scaling[i]; + const int delta_x = point_value[i + 1] - point_value[i]; + const int delta = delta_y * ((65536 + (delta_x >> 1)) / delta_x); + for (int x = 0; x < delta_x; ++x) { + const int v = point_scaling[i] + ((x * delta + 32768) >> 16); + assert(v >= 0 && v <= UINT8_MAX); + scaling_lut[point_value[i] + x] = v; + } + } + const uint8_t last_point_value = point_value[num_points - 1]; + memset(&scaling_lut[last_point_value], point_scaling[num_points - 1], + kScalingLookupTableSize - last_point_value); +} + +// Section 7.18.3.5. +// Performs a piecewise linear interpolation into the scaling table. +template <int bitdepth> +int ScaleLut(const uint8_t scaling_lut[kScalingLookupTableSize], int index) { + const int shift = bitdepth - 8; + const int quotient = index >> shift; + const int remainder = index - (quotient << shift); + if (bitdepth == 8) { + assert(quotient < kScalingLookupTableSize); + return scaling_lut[quotient]; + } + assert(quotient + 1 < kScalingLookupTableSize); + const int start = scaling_lut[quotient]; + const int end = scaling_lut[quotient + 1]; + return start + RightShiftWithRounding((end - start) * remainder, shift); +} + +// Applies an auto-regressive filter to the white noise in luma_grain. +template <int bitdepth, typename GrainType> +void ApplyAutoRegressiveFilterToLumaGrain_C(const FilmGrainParams& params, + void* luma_grain_buffer) { + auto* luma_grain = static_cast<GrainType*>(luma_grain_buffer); + const int grain_min = GetGrainMin<bitdepth>(); + const int grain_max = GetGrainMax<bitdepth>(); + const int auto_regression_coeff_lag = params.auto_regression_coeff_lag; + assert(auto_regression_coeff_lag > 0 && auto_regression_coeff_lag <= 3); + // A pictorial representation of the auto-regressive filter for various values + // of auto_regression_coeff_lag. The letter 'O' represents the current sample. + // (The filter always operates on the current sample with filter + // coefficient 1.) The letters 'X' represent the neighboring samples that the + // filter operates on. + // + // auto_regression_coeff_lag == 3: + // X X X X X X X + // X X X X X X X + // X X X X X X X + // X X X O + // auto_regression_coeff_lag == 2: + // X X X X X + // X X X X X + // X X O + // auto_regression_coeff_lag == 1: + // X X X + // X O + // auto_regression_coeff_lag == 0: + // O + // + // Note that if auto_regression_coeff_lag is 0, the filter is the identity + // filter and therefore can be skipped. This implementation assumes it is not + // called in that case. + const int shift = params.auto_regression_shift; + for (int y = kAutoRegressionBorder; y < kLumaHeight; ++y) { + for (int x = kAutoRegressionBorder; x < kLumaWidth - kAutoRegressionBorder; + ++x) { + int sum = 0; + int pos = 0; + int delta_row = -auto_regression_coeff_lag; + // The last iteration (delta_row == 0) is shorter and is handled + // separately. + do { + int delta_column = -auto_regression_coeff_lag; + do { + const int coeff = params.auto_regression_coeff_y[pos]; + sum += luma_grain[(y + delta_row) * kLumaWidth + (x + delta_column)] * + coeff; + ++pos; + } while (++delta_column <= auto_regression_coeff_lag); + } while (++delta_row < 0); + // Last iteration: delta_row == 0. + { + int delta_column = -auto_regression_coeff_lag; + do { + const int coeff = params.auto_regression_coeff_y[pos]; + sum += luma_grain[y * kLumaWidth + (x + delta_column)] * coeff; + ++pos; + } while (++delta_column < 0); + } + luma_grain[y * kLumaWidth + x] = Clip3( + luma_grain[y * kLumaWidth + x] + RightShiftWithRounding(sum, shift), + grain_min, grain_max); + } + } +} + +template <int bitdepth, typename GrainType, int auto_regression_coeff_lag, + bool use_luma> +void ApplyAutoRegressiveFilterToChromaGrains_C(const FilmGrainParams& params, + const void* luma_grain_buffer, + int subsampling_x, + int subsampling_y, + void* u_grain_buffer, + void* v_grain_buffer) { + static_assert( + auto_regression_coeff_lag >= 0 && auto_regression_coeff_lag <= 3, + "Unsupported autoregression lag for chroma."); + const auto* luma_grain = static_cast<const GrainType*>(luma_grain_buffer); + const int grain_min = GetGrainMin<bitdepth>(); + const int grain_max = GetGrainMax<bitdepth>(); + auto* u_grain = static_cast<GrainType*>(u_grain_buffer); + auto* v_grain = static_cast<GrainType*>(v_grain_buffer); + const int shift = params.auto_regression_shift; + const int chroma_height = + (subsampling_y == 0) ? kMaxChromaHeight : kMinChromaHeight; + const int chroma_width = + (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth; + for (int y = kAutoRegressionBorder; y < chroma_height; ++y) { + const int luma_y = + ((y - kAutoRegressionBorder) << subsampling_y) + kAutoRegressionBorder; + for (int x = kAutoRegressionBorder; + x < chroma_width - kAutoRegressionBorder; ++x) { + int sum_u = 0; + int sum_v = 0; + int pos = 0; + int delta_row = -auto_regression_coeff_lag; + do { + int delta_column = -auto_regression_coeff_lag; + do { + if (delta_row == 0 && delta_column == 0) { + break; + } + const int coeff_u = params.auto_regression_coeff_u[pos]; + const int coeff_v = params.auto_regression_coeff_v[pos]; + sum_u += + u_grain[(y + delta_row) * chroma_width + (x + delta_column)] * + coeff_u; + sum_v += + v_grain[(y + delta_row) * chroma_width + (x + delta_column)] * + coeff_v; + ++pos; + } while (++delta_column <= auto_regression_coeff_lag); + } while (++delta_row <= 0); + if (use_luma) { + int luma = 0; + const int luma_x = ((x - kAutoRegressionBorder) << subsampling_x) + + kAutoRegressionBorder; + int i = 0; + do { + int j = 0; + do { + luma += luma_grain[(luma_y + i) * kLumaWidth + (luma_x + j)]; + } while (++j <= subsampling_x); + } while (++i <= subsampling_y); + luma = SubsampledValue(luma, subsampling_x + subsampling_y); + const int coeff_u = params.auto_regression_coeff_u[pos]; + const int coeff_v = params.auto_regression_coeff_v[pos]; + sum_u += luma * coeff_u; + sum_v += luma * coeff_v; + } + u_grain[y * chroma_width + x] = Clip3( + u_grain[y * chroma_width + x] + RightShiftWithRounding(sum_u, shift), + grain_min, grain_max); + v_grain[y * chroma_width + x] = Clip3( + v_grain[y * chroma_width + x] + RightShiftWithRounding(sum_v, shift), + grain_min, grain_max); + } + } +} + +// This implementation is for the condition overlap_flag == false. +template <int bitdepth, typename GrainType> +void ConstructNoiseStripes_C(const void* grain_buffer, int grain_seed, + int width, int height, int subsampling_x, + int subsampling_y, void* noise_stripes_buffer) { + auto* noise_stripes = + static_cast<Array2DView<GrainType>*>(noise_stripes_buffer); + const auto* grain = static_cast<const GrainType*>(grain_buffer); + const int half_width = DivideBy2(width + 1); + const int half_height = DivideBy2(height + 1); + assert(half_width > 0); + assert(half_height > 0); + static_assert(kLumaWidth == kMaxChromaWidth, + "kLumaWidth width should be equal to kMaxChromaWidth"); + const int grain_width = + (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth; + const int plane_width = (width + subsampling_x) >> subsampling_x; + constexpr int kNoiseStripeHeight = 34; + int luma_num = 0; + int y = 0; + do { + GrainType* const noise_stripe = (*noise_stripes)[luma_num]; + uint16_t seed = grain_seed; + seed ^= ((luma_num * 37 + 178) & 255) << 8; + seed ^= ((luma_num * 173 + 105) & 255); + int x = 0; + do { + const int rand = GetFilmGrainRandomNumber(8, &seed); + const int offset_x = rand >> 4; + const int offset_y = rand & 15; + const int plane_offset_x = + (subsampling_x != 0) ? 6 + offset_x : 9 + offset_x * 2; + const int plane_offset_y = + (subsampling_y != 0) ? 6 + offset_y : 9 + offset_y * 2; + int i = 0; + do { + // Section 7.18.3.5 says: + // noiseStripe[ lumaNum ][ 0 ] is 34 samples high and w samples + // wide (a few additional samples across are actually written to + // the array, but these are never read) ... + // + // Note: The warning in the parentheses also applies to + // noiseStripe[ lumaNum ][ 1 ] and noiseStripe[ lumaNum ][ 2 ]. + // + // Writes beyond the width of each row could happen below. To + // prevent those writes, we clip the number of pixels to copy against + // the remaining width. + // TODO(petersonab): Allocate aligned stripes with extra width to cover + // the size of the final stripe block, then remove this call to min. + const int copy_size = + std::min(kNoiseStripeHeight >> subsampling_x, + plane_width - (x << (1 - subsampling_x))); + memcpy(&noise_stripe[i * plane_width + (x << (1 - subsampling_x))], + &grain[(plane_offset_y + i) * grain_width + plane_offset_x], + copy_size * sizeof(noise_stripe[0])); + } while (++i < (kNoiseStripeHeight >> subsampling_y)); + x += 16; + } while (x < half_width); + + ++luma_num; + y += 16; + } while (y < half_height); +} + +// This implementation is for the condition overlap_flag == true. +template <int bitdepth, typename GrainType> +void ConstructNoiseStripesWithOverlap_C(const void* grain_buffer, + int grain_seed, int width, int height, + int subsampling_x, int subsampling_y, + void* noise_stripes_buffer) { + auto* noise_stripes = + static_cast<Array2DView<GrainType>*>(noise_stripes_buffer); + const auto* grain = static_cast<const GrainType*>(grain_buffer); + const int half_width = DivideBy2(width + 1); + const int half_height = DivideBy2(height + 1); + assert(half_width > 0); + assert(half_height > 0); + static_assert(kLumaWidth == kMaxChromaWidth, + "kLumaWidth width should be equal to kMaxChromaWidth"); + const int grain_width = + (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth; + const int plane_width = (width + subsampling_x) >> subsampling_x; + constexpr int kNoiseStripeHeight = 34; + int luma_num = 0; + int y = 0; + do { + GrainType* const noise_stripe = (*noise_stripes)[luma_num]; + uint16_t seed = grain_seed; + seed ^= ((luma_num * 37 + 178) & 255) << 8; + seed ^= ((luma_num * 173 + 105) & 255); + // Begin special iteration for x == 0. + const int rand = GetFilmGrainRandomNumber(8, &seed); + const int offset_x = rand >> 4; + const int offset_y = rand & 15; + const int plane_offset_x = + (subsampling_x != 0) ? 6 + offset_x : 9 + offset_x * 2; + const int plane_offset_y = + (subsampling_y != 0) ? 6 + offset_y : 9 + offset_y * 2; + // The overlap computation only occurs when x > 0, so it is omitted here. + int i = 0; + do { + // TODO(petersonab): Allocate aligned stripes with extra width to cover + // the size of the final stripe block, then remove this call to min. + const int copy_size = + std::min(kNoiseStripeHeight >> subsampling_x, plane_width); + memcpy(&noise_stripe[i * plane_width], + &grain[(plane_offset_y + i) * grain_width + plane_offset_x], + copy_size * sizeof(noise_stripe[0])); + } while (++i < (kNoiseStripeHeight >> subsampling_y)); + // End special iteration for x == 0. + for (int x = 16; x < half_width; x += 16) { + const int rand = GetFilmGrainRandomNumber(8, &seed); + const int offset_x = rand >> 4; + const int offset_y = rand & 15; + const int plane_offset_x = + (subsampling_x != 0) ? 6 + offset_x : 9 + offset_x * 2; + const int plane_offset_y = + (subsampling_y != 0) ? 6 + offset_y : 9 + offset_y * 2; + int i = 0; + do { + int j = 0; + int grain_sample = + grain[(plane_offset_y + i) * grain_width + plane_offset_x]; + // The first pixel(s) of each segment of the noise_stripe are subject to + // the "overlap" computation. + if (subsampling_x == 0) { + // Corresponds to the line in the spec: + // if (j < 2 && x > 0) + // j = 0 + int old = noise_stripe[i * plane_width + x * 2]; + grain_sample = old * 27 + grain_sample * 17; + grain_sample = + Clip3(RightShiftWithRounding(grain_sample, 5), + GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>()); + noise_stripe[i * plane_width + x * 2] = grain_sample; + + // This check prevents overwriting for the iteration j = 1. The + // continue applies to the i-loop. + if (x * 2 + 1 >= plane_width) continue; + // j = 1 + grain_sample = + grain[(plane_offset_y + i) * grain_width + plane_offset_x + 1]; + old = noise_stripe[i * plane_width + x * 2 + 1]; + grain_sample = old * 17 + grain_sample * 27; + grain_sample = + Clip3(RightShiftWithRounding(grain_sample, 5), + GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>()); + noise_stripe[i * plane_width + x * 2 + 1] = grain_sample; + j = 2; + } else { + // Corresponds to the line in the spec: + // if (j == 0 && x > 0) + const int old = noise_stripe[i * plane_width + x]; + grain_sample = old * 23 + grain_sample * 22; + grain_sample = + Clip3(RightShiftWithRounding(grain_sample, 5), + GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>()); + noise_stripe[i * plane_width + x] = grain_sample; + j = 1; + } + // The following covers the rest of the loop over j as described in the + // spec. + // + // Section 7.18.3.5 says: + // noiseStripe[ lumaNum ][ 0 ] is 34 samples high and w samples + // wide (a few additional samples across are actually written to + // the array, but these are never read) ... + // + // Note: The warning in the parentheses also applies to + // noiseStripe[ lumaNum ][ 1 ] and noiseStripe[ lumaNum ][ 2 ]. + // + // Writes beyond the width of each row could happen below. To + // prevent those writes, we clip the number of pixels to copy against + // the remaining width. + // TODO(petersonab): Allocate aligned stripes with extra width to cover + // the size of the final stripe block, then remove this call to min. + const int copy_size = + std::min(kNoiseStripeHeight >> subsampling_x, + plane_width - (x << (1 - subsampling_x))) - + j; + memcpy(&noise_stripe[i * plane_width + (x << (1 - subsampling_x)) + j], + &grain[(plane_offset_y + i) * grain_width + plane_offset_x + j], + copy_size * sizeof(noise_stripe[0])); + } while (++i < (kNoiseStripeHeight >> subsampling_y)); + } + + ++luma_num; + y += 16; + } while (y < half_height); +} + +template <int bitdepth, typename GrainType> +inline void WriteOverlapLine_C(const GrainType* noise_stripe_row, + const GrainType* noise_stripe_row_prev, + int plane_width, int grain_coeff, int old_coeff, + GrainType* noise_image_row) { + int x = 0; + do { + int grain = noise_stripe_row[x]; + const int old = noise_stripe_row_prev[x]; + grain = old * old_coeff + grain * grain_coeff; + grain = Clip3(RightShiftWithRounding(grain, 5), GetGrainMin<bitdepth>(), + GetGrainMax<bitdepth>()); + noise_image_row[x] = grain; + } while (++x < plane_width); +} + +template <int bitdepth, typename GrainType> +void ConstructNoiseImageOverlap_C(const void* noise_stripes_buffer, int width, + int height, int subsampling_x, + int subsampling_y, void* noise_image_buffer) { + const auto* noise_stripes = + static_cast<const Array2DView<GrainType>*>(noise_stripes_buffer); + auto* noise_image = static_cast<Array2D<GrainType>*>(noise_image_buffer); + const int plane_width = (width + subsampling_x) >> subsampling_x; + const int plane_height = (height + subsampling_y) >> subsampling_y; + const int stripe_height = 32 >> subsampling_y; + const int stripe_mask = stripe_height - 1; + int y = stripe_height; + int luma_num = 1; + if (subsampling_y == 0) { + // Begin complete stripes section. This is when we are guaranteed to have + // two overlap rows in each stripe. + for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) { + const GrainType* noise_stripe = (*noise_stripes)[luma_num]; + const GrainType* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + // First overlap row. + WriteOverlapLine_C<bitdepth>(noise_stripe, + &noise_stripe_prev[32 * plane_width], + plane_width, 17, 27, (*noise_image)[y]); + // Second overlap row. + WriteOverlapLine_C<bitdepth>(&noise_stripe[plane_width], + &noise_stripe_prev[(32 + 1) * plane_width], + plane_width, 27, 17, (*noise_image)[y + 1]); + } + // End complete stripes section. + + const int remaining_height = plane_height - y; + // Either one partial stripe remains (remaining_height > 0), + // OR image is less than one stripe high (remaining_height < 0), + // OR all stripes are completed (remaining_height == 0). + if (remaining_height <= 0) { + return; + } + const GrainType* noise_stripe = (*noise_stripes)[luma_num]; + const GrainType* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine_C<bitdepth>(noise_stripe, + &noise_stripe_prev[32 * plane_width], + plane_width, 17, 27, (*noise_image)[y]); + + // Check if second overlap row is in the image. + if (remaining_height > 1) { + WriteOverlapLine_C<bitdepth>(&noise_stripe[plane_width], + &noise_stripe_prev[(32 + 1) * plane_width], + plane_width, 27, 17, (*noise_image)[y + 1]); + } + } else { // |subsampling_y| == 1 + // No special checks needed for partial stripes, because if one exists, the + // first and only overlap row is guaranteed to exist. + for (; y < plane_height; ++luma_num, y += stripe_height) { + const GrainType* noise_stripe = (*noise_stripes)[luma_num]; + const GrainType* noise_stripe_prev = (*noise_stripes)[luma_num - 1]; + WriteOverlapLine_C<bitdepth>(noise_stripe, + &noise_stripe_prev[16 * plane_width], + plane_width, 22, 23, (*noise_image)[y]); + } + } +} + +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageLuma_C( + const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift, + int width, int height, int start_height, + const uint8_t scaling_lut_y[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y, + ptrdiff_t dest_stride_y) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + auto* out_y = static_cast<Pixel*>(dest_plane_y); + dest_stride_y /= sizeof(Pixel); + + int y = 0; + do { + int x = 0; + do { + const int orig = in_y[y * source_stride_y + x]; + int noise = noise_image[kPlaneY][y + start_height][x]; + noise = RightShiftWithRounding( + ScaleLut<bitdepth>(scaling_lut_y, orig) * noise, scaling_shift); + out_y[y * dest_stride_y + x] = Clip3(orig + noise, min_value, max_luma); + } while (++x < width); + } while (++y < height); +} + +// This function is for the case params_.chroma_scaling_from_luma == false. +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageChroma_C( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut_uv[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + + const int chroma_width = (width + subsampling_x) >> subsampling_x; + const int chroma_height = (height + subsampling_y) >> subsampling_y; + + const auto* in_y = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + const auto* in_uv = static_cast<const Pixel*>(source_plane_uv); + source_stride_uv /= sizeof(Pixel); + auto* out_uv = static_cast<Pixel*>(dest_plane_uv); + dest_stride_uv /= sizeof(Pixel); + + const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset; + const int luma_multiplier = + (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier; + const int multiplier = + (plane == kPlaneU) ? params.u_multiplier : params.v_multiplier; + + const int scaling_shift = params.chroma_scaling; + start_height >>= subsampling_y; + int y = 0; + do { + int x = 0; + do { + const int luma_x = x << subsampling_x; + const int luma_y = y << subsampling_y; + const int luma_next_x = std::min(luma_x + 1, width - 1); + int average_luma; + if (subsampling_x != 0) { + average_luma = RightShiftWithRounding( + in_y[luma_y * source_stride_y + luma_x] + + in_y[luma_y * source_stride_y + luma_next_x], + 1); + } else { + average_luma = in_y[luma_y * source_stride_y + luma_x]; + } + const int orig = in_uv[y * source_stride_uv + x]; + const int combined = average_luma * luma_multiplier + orig * multiplier; + const int merged = + Clip3((combined >> 6) + LeftShift(offset, bitdepth - 8), 0, + (1 << bitdepth) - 1); + int noise = noise_image[plane][y + start_height][x]; + noise = RightShiftWithRounding( + ScaleLut<bitdepth>(scaling_lut_uv, merged) * noise, scaling_shift); + out_uv[y * dest_stride_uv + x] = + Clip3(orig + noise, min_value, max_chroma); + } while (++x < chroma_width); + } while (++y < chroma_height); +} + +// This function is for the case params_.chroma_scaling_from_luma == true. +// This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y. +template <int bitdepth, typename GrainType, typename Pixel> +void BlendNoiseWithImageChromaWithCfl_C( + Plane plane, const FilmGrainParams& params, const void* noise_image_ptr, + int min_value, int max_chroma, int width, int height, int start_height, + int subsampling_x, int subsampling_y, + const uint8_t scaling_lut[kScalingLookupTableSize], + const void* source_plane_y, ptrdiff_t source_stride_y, + const void* source_plane_uv, ptrdiff_t source_stride_uv, + void* dest_plane_uv, ptrdiff_t dest_stride_uv) { + const auto* noise_image = + static_cast<const Array2D<GrainType>*>(noise_image_ptr); + const auto* in_y = static_cast<const Pixel*>(source_plane_y); + source_stride_y /= sizeof(Pixel); + const auto* in_uv = static_cast<const Pixel*>(source_plane_uv); + source_stride_uv /= sizeof(Pixel); + auto* out_uv = static_cast<Pixel*>(dest_plane_uv); + dest_stride_uv /= sizeof(Pixel); + + const int chroma_width = (width + subsampling_x) >> subsampling_x; + const int chroma_height = (height + subsampling_y) >> subsampling_y; + const int scaling_shift = params.chroma_scaling; + start_height >>= subsampling_y; + int y = 0; + do { + int x = 0; + do { + const int luma_x = x << subsampling_x; + const int luma_y = y << subsampling_y; + const int luma_next_x = std::min(luma_x + 1, width - 1); + int average_luma; + if (subsampling_x != 0) { + average_luma = RightShiftWithRounding( + in_y[luma_y * source_stride_y + luma_x] + + in_y[luma_y * source_stride_y + luma_next_x], + 1); + } else { + average_luma = in_y[luma_y * source_stride_y + luma_x]; + } + const int orig_uv = in_uv[y * source_stride_uv + x]; + int noise_uv = noise_image[plane][y + start_height][x]; + noise_uv = RightShiftWithRounding( + ScaleLut<bitdepth>(scaling_lut, average_luma) * noise_uv, + scaling_shift); + out_uv[y * dest_stride_uv + x] = + Clip3(orig_uv + noise_uv, min_value, max_chroma); + } while (++x < chroma_width); + } while (++y < chroma_height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + // LumaAutoRegressionFunc + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>; + + // ChromaAutoRegressionFunc + // Chroma autoregression should never be called when lag is 0 and use_luma is + // false. + dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, true>; + + // ConstructNoiseStripesFunc + dsp->film_grain.construct_noise_stripes[0] = + ConstructNoiseStripes_C<8, int8_t>; + dsp->film_grain.construct_noise_stripes[1] = + ConstructNoiseStripesWithOverlap_C<8, int8_t>; + + // ConstructNoiseImageOverlapFunc + dsp->film_grain.construct_noise_image_overlap = + ConstructNoiseImageOverlap_C<8, int8_t>; + + // InitializeScalingLutFunc + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>; + + // BlendNoiseWithImageLumaFunc + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_C<8, int8_t, uint8_t>; + + // BlendNoiseWithImageChromaFunc + dsp->film_grain.blend_noise_chroma[0] = + BlendNoiseWithImageChroma_C<8, int8_t, uint8_t>; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_C<8, int8_t, uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_FilmGrainAutoregressionLuma + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma + // Chroma autoregression should never be called when lag is 0 and use_luma is + // false. + dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, true>; +#endif +#ifndef LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseStripes + dsp->film_grain.construct_noise_stripes[0] = + ConstructNoiseStripes_C<8, int8_t>; + dsp->film_grain.construct_noise_stripes[1] = + ConstructNoiseStripesWithOverlap_C<8, int8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap + dsp->film_grain.construct_noise_image_overlap = + ConstructNoiseImageOverlap_C<8, int8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_C<8, int8_t, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma + dsp->film_grain.blend_noise_chroma[0] = + BlendNoiseWithImageChroma_C<8, int8_t, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_C<8, int8_t, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + + // LumaAutoRegressionFunc + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>; + + // ChromaAutoRegressionFunc + // Chroma autoregression should never be called when lag is 0 and use_luma is + // false. + dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, true>; + + // ConstructNoiseStripesFunc + dsp->film_grain.construct_noise_stripes[0] = + ConstructNoiseStripes_C<10, int16_t>; + dsp->film_grain.construct_noise_stripes[1] = + ConstructNoiseStripesWithOverlap_C<10, int16_t>; + + // ConstructNoiseImageOverlapFunc + dsp->film_grain.construct_noise_image_overlap = + ConstructNoiseImageOverlap_C<10, int16_t>; + + // InitializeScalingLutFunc + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>; + + // BlendNoiseWithImageLumaFunc + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_C<10, int16_t, uint16_t>; + + // BlendNoiseWithImageChromaFunc + dsp->film_grain.blend_noise_chroma[0] = + BlendNoiseWithImageChroma_C<10, int16_t, uint16_t>; + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_C<10, int16_t, uint16_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_FilmGrainAutoregressionLuma + dsp->film_grain.luma_auto_regression[0] = + ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>; + dsp->film_grain.luma_auto_regression[1] = + ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>; + dsp->film_grain.luma_auto_regression[2] = + ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma + // Chroma autoregression should never be called when lag is 0 and use_luma is + // false. + dsp->film_grain.chroma_auto_regression[0][0] = nullptr; + dsp->film_grain.chroma_auto_regression[0][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, false>; + dsp->film_grain.chroma_auto_regression[0][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, false>; + dsp->film_grain.chroma_auto_regression[0][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, false>; + dsp->film_grain.chroma_auto_regression[1][0] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 0, true>; + dsp->film_grain.chroma_auto_regression[1][1] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, true>; + dsp->film_grain.chroma_auto_regression[1][2] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, true>; + dsp->film_grain.chroma_auto_regression[1][3] = + ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, true>; +#endif +#ifndef LIBGAV1_Dsp10bpp_FilmGrainConstructNoiseStripes + dsp->film_grain.construct_noise_stripes[0] = + ConstructNoiseStripes_C<10, int16_t>; + dsp->film_grain.construct_noise_stripes[1] = + ConstructNoiseStripesWithOverlap_C<10, int16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_FilmGrainConstructNoiseImageOverlap + dsp->film_grain.construct_noise_image_overlap = + ConstructNoiseImageOverlap_C<10, int16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc + dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma + dsp->film_grain.blend_noise_luma = + BlendNoiseWithImageLuma_C<10, int16_t, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChroma + dsp->film_grain.blend_noise_chroma[0] = + BlendNoiseWithImageChroma_C<10, int16_t, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl + dsp->film_grain.blend_noise_chroma[1] = + BlendNoiseWithImageChromaWithCfl_C<10, int16_t, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +} // namespace +} // namespace film_grain + +void FilmGrainInit_C() { + film_grain::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + film_grain::Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/film_grain.h b/src/dsp/film_grain.h new file mode 100644 index 0000000..fe93270 --- /dev/null +++ b/src/dsp/film_grain.h @@ -0,0 +1,39 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_FILM_GRAIN_H_ +#define LIBGAV1_SRC_DSP_FILM_GRAIN_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/film_grain_neon.h" + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initialize Dsp::film_grain_synthesis. This function is not thread-safe. +void FilmGrainInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_FILM_GRAIN_H_ diff --git a/src/dsp/film_grain_common.h b/src/dsp/film_grain_common.h new file mode 100644 index 0000000..64e3e8e --- /dev/null +++ b/src/dsp/film_grain_common.h @@ -0,0 +1,78 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_FILM_GRAIN_COMMON_H_ +#define LIBGAV1_SRC_DSP_FILM_GRAIN_COMMON_H_ + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <type_traits> + +#include "src/dsp/common.h" +#include "src/utils/array_2d.h" +#include "src/utils/constants.h" +#include "src/utils/cpu.h" + +namespace libgav1 { + +template <int bitdepth> +int GetGrainMax() { + return (1 << (bitdepth - 1)) - 1; +} + +template <int bitdepth> +int GetGrainMin() { + return -(1 << (bitdepth - 1)); +} + +inline int GetFilmGrainRandomNumber(int bits, uint16_t* seed) { + uint16_t s = *seed; + uint16_t bit = (s ^ (s >> 1) ^ (s >> 3) ^ (s >> 12)) & 1; + s = (s >> 1) | (bit << 15); + *seed = s; + return s >> (16 - bits); +} + +enum { + kAutoRegressionBorder = 3, + // The width of the luma noise array. + kLumaWidth = 82, + // The height of the luma noise array. + kLumaHeight = 73, + // The two possible widths of the chroma noise array. + kMinChromaWidth = 44, + kMaxChromaWidth = 82, + // The two possible heights of the chroma noise array. + kMinChromaHeight = 38, + kMaxChromaHeight = 73, + // The scaling lookup table maps bytes to bytes, so only uses 256 elements, + // plus one for overflow in 10bit lookups. + kScalingLookupTableSize = 257, + // Padding is added to the scaling lookup table to permit overwrites by + // InitializeScalingLookupTable_NEON. + kScalingLookupTablePadding = 6, + // Padding is added to each row of the noise image to permit overreads by + // BlendNoiseWithImageLuma_NEON and overwrites by WriteOverlapLine8bpp_NEON. + kNoiseImagePadding = 7, + // Padding is added to the end of the |noise_stripes_| buffer to permit + // overreads by WriteOverlapLine8bpp_NEON. + kNoiseStripePadding = 7, +}; // anonymous enum + +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_FILM_GRAIN_COMMON_H_ diff --git a/src/dsp/intra_edge.cc b/src/dsp/intra_edge.cc new file mode 100644 index 0000000..fe66db2 --- /dev/null +++ b/src/dsp/intra_edge.cc @@ -0,0 +1,115 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intra_edge.h" + +#include <cassert> +#include <cstdint> +#include <cstring> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kKernelTaps = 5; +constexpr int kKernels[3][kKernelTaps] = { + {0, 4, 8, 4, 0}, {0, 5, 6, 5, 0}, {2, 4, 4, 4, 2}}; +constexpr int kMaxUpsampleSize = 16; + +template <typename Pixel> +void IntraEdgeFilter_C(void* buffer, int size, int strength) { + assert(strength > 0); + Pixel edge[129]; + memcpy(edge, buffer, sizeof(edge[0]) * size); + auto* const dst_buffer = static_cast<Pixel*>(buffer); + const int kernel_index = strength - 1; + for (int i = 1; i < size; ++i) { + int sum = 0; + for (int j = 0; j < kKernelTaps; ++j) { + const int k = Clip3(i + j - 2, 0, size - 1); + sum += kKernels[kernel_index][j] * edge[k]; + } + dst_buffer[i] = RightShiftWithRounding(sum, 4); + } +} + +template <int bitdepth, typename Pixel> +void IntraEdgeUpsampler_C(void* buffer, int size) { + assert(size % 4 == 0 && size <= kMaxUpsampleSize); + auto* const pixel_buffer = static_cast<Pixel*>(buffer); + Pixel temp[kMaxUpsampleSize + 3]; + temp[0] = temp[1] = pixel_buffer[-1]; + memcpy(temp + 2, pixel_buffer, sizeof(temp[0]) * size); + temp[size + 2] = pixel_buffer[size - 1]; + + pixel_buffer[-2] = temp[0]; + for (int i = 0; i < size; ++i) { + const int sum = + -temp[i] + (9 * temp[i + 1]) + (9 * temp[i + 2]) - temp[i + 3]; + pixel_buffer[2 * i - 1] = + Clip3(RightShiftWithRounding(sum, 4), 0, (1 << bitdepth) - 1); + pixel_buffer[2 * i] = temp[i + 2]; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->intra_edge_filter = IntraEdgeFilter_C<uint8_t>; + dsp->intra_edge_upsampler = IntraEdgeUpsampler_C<8, uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_IntraEdgeFilter + dsp->intra_edge_filter = IntraEdgeFilter_C<uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_IntraEdgeUpsampler + dsp->intra_edge_upsampler = IntraEdgeUpsampler_C<8, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->intra_edge_filter = IntraEdgeFilter_C<uint16_t>; + dsp->intra_edge_upsampler = IntraEdgeUpsampler_C<10, uint16_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_IntraEdgeFilter + dsp->intra_edge_filter = IntraEdgeFilter_C<uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_IntraEdgeUpsampler + dsp->intra_edge_upsampler = IntraEdgeUpsampler_C<10, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void IntraEdgeInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/intra_edge.h b/src/dsp/intra_edge.h new file mode 100644 index 0000000..172ecbb --- /dev/null +++ b/src/dsp/intra_edge.h @@ -0,0 +1,48 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_INTRA_EDGE_H_ +#define LIBGAV1_SRC_DSP_INTRA_EDGE_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/intra_edge_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/intra_edge_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_edge_filter and Dsp::intra_edge_upsampler. This +// function is not thread-safe. +void IntraEdgeInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_INTRA_EDGE_H_ diff --git a/src/dsp/intrapred.cc b/src/dsp/intrapred.cc new file mode 100644 index 0000000..4bcb580 --- /dev/null +++ b/src/dsp/intrapred.cc @@ -0,0 +1,2911 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <cstring> // memset + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/memory.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr TransformSize kTransformSizesLargerThan32x32[] = { + kTransformSize16x64, kTransformSize32x64, kTransformSize64x16, + kTransformSize64x32, kTransformSize64x64}; + +template <int block_width, int block_height, typename Pixel> +struct IntraPredFuncs_C { + IntraPredFuncs_C() = delete; + + static void DcTop(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void DcLeft(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Dc(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Vertical(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Horizontal(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Paeth(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Smooth(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void SmoothVertical(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void SmoothHorizontal(void* dest, ptrdiff_t stride, + const void* top_row, const void* left_column); +}; + +// Intra-predictors that require bitdepth. +template <int block_width, int block_height, int bitdepth, typename Pixel> +struct IntraPredBppFuncs_C { + IntraPredBppFuncs_C() = delete; + + static void DcFill(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); +}; + +//------------------------------------------------------------------------------ +// IntraPredFuncs_C::DcPred + +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::DcTop( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* /*left_column*/) { + int sum = block_width >> 1; // rounder + const auto* const top = static_cast<const Pixel*>(top_row); + for (int x = 0; x < block_width; ++x) sum += top[x]; + const int dc = sum >> FloorLog2(block_width); + + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int y = 0; y < block_height; ++y) { + Memset(dst, dc, block_width); + dst += stride; + } +} + +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::DcLeft( + void* const dest, ptrdiff_t stride, const void* /*top_row*/, + const void* const left_column) { + int sum = block_height >> 1; // rounder + const auto* const left = static_cast<const Pixel*>(left_column); + for (int y = 0; y < block_height; ++y) sum += left[y]; + const int dc = sum >> FloorLog2(block_height); + + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int y = 0; y < block_height; ++y) { + Memset(dst, dc, block_width); + dst += stride; + } +} + +// Note for square blocks the divide in the Dc() function reduces to a shift. +// For rectangular block sizes the following multipliers can be used with the +// corresponding shifts. +// 8-bit +// 1:2 (e.g,, 4x8): scale = 0x5556 +// 1:4 (e.g., 4x16): scale = 0x3334 +// final_descale = 16 +// 10/12-bit +// 1:2: scale = 0xaaab +// 1:4: scale = 0x6667 +// final_descale = 17 +// Note these may be halved to the values used in 8-bit in all cases except +// when bitdepth == 12 and block_width + block_height is divisible by 5 (as +// opposed to 3). +// +// The calculation becomes: +// (dc_sum >> intermediate_descale) * scale) >> final_descale +// where intermediate_descale is: +// sum = block_width + block_height +// intermediate_descale = +// (sum <= 20) ? 2 : (sum <= 40) ? 3 : (sum <= 80) ? 4 : 5 +// +// The constants (multiplier and shifts) for a given block size are obtained +// as follows: +// - Let sum = block width + block height +// - Shift 'sum' right until we reach an odd number +// - Let the number of shifts for that block size be called 'intermediate_scale' +// and let the odd number be 'd' (d has only 2 possible values: d = 3 for a +// 1:2 rectangular block and d = 5 for a 1:4 rectangular block). +// - Find multipliers by dividing by 'd' using "Algorithm 1" in: +// http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=1467632 +// by ensuring that m + n = 16 (in that algorithm). This ensures that our 2nd +// shift will be 16, regardless of the block size. +// TODO(jzern): the base implementation could be updated to use this method. + +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::Dc( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* const left_column) { + const int divisor = block_width + block_height; + int sum = divisor >> 1; // rounder + + const auto* const top = static_cast<const Pixel*>(top_row); + const auto* const left = static_cast<const Pixel*>(left_column); + for (int x = 0; x < block_width; ++x) sum += top[x]; + for (int y = 0; y < block_height; ++y) sum += left[y]; + + const int dc = sum / divisor; + + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int y = 0; y < block_height; ++y) { + Memset(dst, dc, block_width); + dst += stride; + } +} + +//------------------------------------------------------------------------------ +// IntraPredFuncs_C directional predictors + +// IntraPredFuncs_C::Vertical -- apply top row vertically +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::Vertical( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* /*left_column*/) { + auto* dst = static_cast<uint8_t*>(dest); + for (int y = 0; y < block_height; ++y) { + memcpy(dst, top_row, block_width * sizeof(Pixel)); + dst += stride; + } +} + +// IntraPredFuncs_C::Horizontal -- apply left column horizontally +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::Horizontal( + void* const dest, ptrdiff_t stride, const void* /*top_row*/, + const void* const left_column) { + const auto* const left = static_cast<const Pixel*>(left_column); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int y = 0; y < block_height; ++y) { + Memset(dst, left[y], block_width); + dst += stride; + } +} + +template <typename Pixel> +inline Pixel Average(Pixel a, Pixel b) { + return static_cast<Pixel>((a + b + 1) >> 1); +} + +template <typename Pixel> +inline Pixel Average(Pixel a, Pixel b, Pixel c) { + return static_cast<Pixel>((a + 2 * b + c + 2) >> 2); +} + +// IntraPredFuncs_C::Paeth +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::Paeth( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const Pixel*>(top_row); + const auto* const left = static_cast<const Pixel*>(left_column); + const Pixel top_left = top[-1]; + const int top_left_x2 = top_left + top_left; + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + + for (int y = 0; y < block_height; ++y) { + const int left_pixel = left[y]; + for (int x = 0; x < block_width; ++x) { + // The Paeth filter selects the value closest to: + // top[x] + left[y] - top_left + // To calculate the absolute distance for the left value this would be: + // abs((top[x] + left[y] - top_left) - left[y]) + // or, because left[y] cancels out: + // abs(top[x] - top_left) + const int left_dist = std::abs(top[x] - top_left); + const int top_dist = std::abs(left_pixel - top_left); + const int top_left_dist = std::abs(top[x] + left_pixel - top_left_x2); + + // Select the closest value to the initial estimate of 'T + L - TL'. + if (left_dist <= top_dist && left_dist <= top_left_dist) { + dst[x] = left_pixel; + } else if (top_dist <= top_left_dist) { + dst[x] = top[x]; + } else { + dst[x] = top_left; + } + } + dst += stride; + } +} + +constexpr uint8_t kSmoothWeights[] = { + // block dimension = 4 + 255, 149, 85, 64, + // block dimension = 8 + 255, 197, 146, 105, 73, 50, 37, 32, + // block dimension = 16 + 255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16, + // block dimension = 32 + 255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74, + 66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8, + // block dimension = 64 + 255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156, + 150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73, + 69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16, + 15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4}; + +// IntraPredFuncs_C::Smooth +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::Smooth( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const Pixel*>(top_row); + const auto* const left = static_cast<const Pixel*>(left_column); + const Pixel top_right = top[block_width - 1]; + const Pixel bottom_left = left[block_height - 1]; + static_assert( + block_width >= 4 && block_height >= 4, + "Weights for smooth predictor undefined for block width/height < 4"); + const uint8_t* const weights_x = kSmoothWeights + block_width - 4; + const uint8_t* const weights_y = kSmoothWeights + block_height - 4; + const uint16_t scale_value = (1 << kSmoothWeightScale); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + + for (int y = 0; y < block_height; ++y) { + for (int x = 0; x < block_width; ++x) { + assert(scale_value >= weights_y[y] && scale_value >= weights_x[x]); + uint32_t pred = weights_y[y] * top[x]; + pred += weights_x[x] * left[y]; + pred += static_cast<uint8_t>(scale_value - weights_y[y]) * bottom_left; + pred += static_cast<uint8_t>(scale_value - weights_x[x]) * top_right; + // The maximum value of pred with the rounder is 2^9 * (2^bitdepth - 1) + // + 256. With the descale there's no need for saturation. + dst[x] = static_cast<Pixel>( + RightShiftWithRounding(pred, kSmoothWeightScale + 1)); + } + dst += stride; + } +} + +// IntraPredFuncs_C::SmoothVertical +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::SmoothVertical( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const Pixel*>(top_row); + const auto* const left = static_cast<const Pixel*>(left_column); + const Pixel bottom_left = left[block_height - 1]; + static_assert(block_height >= 4, + "Weights for smooth predictor undefined for block height < 4"); + const uint8_t* const weights_y = kSmoothWeights + block_height - 4; + const uint16_t scale_value = (1 << kSmoothWeightScale); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + + for (int y = 0; y < block_height; ++y) { + for (int x = 0; x < block_width; ++x) { + assert(scale_value >= weights_y[y]); + uint32_t pred = weights_y[y] * top[x]; + pred += static_cast<uint8_t>(scale_value - weights_y[y]) * bottom_left; + dst[x] = + static_cast<Pixel>(RightShiftWithRounding(pred, kSmoothWeightScale)); + } + dst += stride; + } +} + +// IntraPredFuncs_C::SmoothHorizontal +template <int block_width, int block_height, typename Pixel> +void IntraPredFuncs_C<block_width, block_height, Pixel>::SmoothHorizontal( + void* const dest, ptrdiff_t stride, const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const Pixel*>(top_row); + const auto* const left = static_cast<const Pixel*>(left_column); + const Pixel top_right = top[block_width - 1]; + static_assert(block_width >= 4, + "Weights for smooth predictor undefined for block width < 4"); + const uint8_t* const weights_x = kSmoothWeights + block_width - 4; + const uint16_t scale_value = (1 << kSmoothWeightScale); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + + for (int y = 0; y < block_height; ++y) { + for (int x = 0; x < block_width; ++x) { + assert(scale_value >= weights_x[x]); + uint32_t pred = weights_x[x] * left[y]; + pred += static_cast<uint8_t>(scale_value - weights_x[x]) * top_right; + dst[x] = + static_cast<Pixel>(RightShiftWithRounding(pred, kSmoothWeightScale)); + } + dst += stride; + } +} + +//------------------------------------------------------------------------------ +// IntraPredBppFuncs_C +template <int fill, typename Pixel> +inline void DcFill_C(void* const dest, ptrdiff_t stride, const int block_width, + const int block_height) { + static_assert(sizeof(Pixel) == 1 || sizeof(Pixel) == 2, + "Only 1 & 2 byte pixels are supported"); + + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int y = 0; y < block_height; ++y) { + Memset(dst, fill, block_width); + dst += stride; + } +} + +template <int block_width, int block_height, int bitdepth, typename Pixel> +void IntraPredBppFuncs_C<block_width, block_height, bitdepth, Pixel>::DcFill( + void* const dest, ptrdiff_t stride, const void* /*top_row*/, + const void* /*left_column*/) { + DcFill_C<0x80 << (bitdepth - 8), Pixel>(dest, stride, block_width, + block_height); +} + +//------------------------------------------------------------------------------ +// FilterIntraPredictor_C + +template <int bitdepth, typename Pixel> +void FilterIntraPredictor_C(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + const FilterIntraPredictor pred, const int width, + const int height) { + const int kMaxPixel = (1 << bitdepth) - 1; + const auto* const top = static_cast<const Pixel*>(top_row); + const auto* const left = static_cast<const Pixel*>(left_column); + + assert(width <= 32 && height <= 32); + + Pixel buffer[3][33]; // cache 2 rows + top & left boundaries + memcpy(buffer[0], &top[-1], (width + 1) * sizeof(top[0])); + + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + int row0 = 0, row2 = 2; + int ystep = 1; + int y = 0; + do { + buffer[1][0] = left[y]; + buffer[row2][0] = left[y + 1]; + int x = 1; + do { + const Pixel p0 = buffer[row0][x - 1]; // top-left + const Pixel p1 = buffer[row0][x + 0]; // top 0 + const Pixel p2 = buffer[row0][x + 1]; // top 1 + const Pixel p3 = buffer[row0][x + 2]; // top 2 + const Pixel p4 = buffer[row0][x + 3]; // top 3 + const Pixel p5 = buffer[1][x - 1]; // left 0 + const Pixel p6 = buffer[row2][x - 1]; // left 1 + for (int i = 0; i < 8; ++i) { + const int xoffset = i & 0x03; + const int yoffset = (i >> 2) * ystep; + const int value = kFilterIntraTaps[pred][i][0] * p0 + + kFilterIntraTaps[pred][i][1] * p1 + + kFilterIntraTaps[pred][i][2] * p2 + + kFilterIntraTaps[pred][i][3] * p3 + + kFilterIntraTaps[pred][i][4] * p4 + + kFilterIntraTaps[pred][i][5] * p5 + + kFilterIntraTaps[pred][i][6] * p6; + buffer[1 + yoffset][x + xoffset] = static_cast<Pixel>( + Clip3(RightShiftWithRounding(value, 4), 0, kMaxPixel)); + } + x += 4; + } while (x < width); + memcpy(dst, &buffer[1][1], width * sizeof(dst[0])); + dst += stride; + memcpy(dst, &buffer[row2][1], width * sizeof(dst[0])); + dst += stride; + + // The final row becomes the top for the next pass. + row0 ^= 2; + row2 ^= 2; + ystep = -ystep; + y += 2; + } while (y < height); +} + +//------------------------------------------------------------------------------ +// CflIntraPredictor_C + +// |luma| can be within +/-(((1 << bitdepth) - 1) << 3), inclusive. +// |alpha| can be -16 to 16 (inclusive). +template <int block_width, int block_height, int bitdepth, typename Pixel> +void CflIntraPredictor_C( + void* const dest, ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<Pixel*>(dest); + const int dc = dst[0]; + stride /= sizeof(Pixel); + const int max_value = (1 << bitdepth) - 1; + for (int y = 0; y < block_height; ++y) { + for (int x = 0; x < block_width; ++x) { + assert(luma[y][x] >= -(((1 << bitdepth) - 1) << 3)); + assert(luma[y][x] <= ((1 << bitdepth) - 1) << 3); + dst[x] = Clip3(dc + RightShiftWithRoundingSigned(alpha * luma[y][x], 6), + 0, max_value); + } + dst += stride; + } +} + +//------------------------------------------------------------------------------ +// CflSubsampler_C + +template <int block_width, int block_height, int bitdepth, typename Pixel, + int subsampling_x, int subsampling_y> +void CflSubsampler_C(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + const auto* src = static_cast<const Pixel*>(source); + stride /= sizeof(Pixel); + int sum = 0; + for (int y = 0; y < block_height; ++y) { + for (int x = 0; x < block_width; ++x) { + const ptrdiff_t luma_x = + std::min(x << subsampling_x, max_luma_width - (1 << subsampling_x)); + const ptrdiff_t luma_x_next = luma_x + stride; + luma[y][x] = + (src[luma_x] + ((subsampling_x != 0) ? src[luma_x + 1] : 0) + + ((subsampling_y != 0) ? (src[luma_x_next] + src[luma_x_next + 1]) + : 0)) + << (3 - subsampling_x - subsampling_y); + sum += luma[y][x]; + } + if ((y << subsampling_y) < (max_luma_height - (1 << subsampling_y))) { + src += stride << subsampling_y; + } + } + const int average = RightShiftWithRounding( + sum, FloorLog2(block_width) + FloorLog2(block_height)); + for (int y = 0; y < block_height; ++y) { + for (int x = 0; x < block_width; ++x) { + luma[y][x] -= average; + } + } +} + +//------------------------------------------------------------------------------ +// 7.11.2.4. Directional intra prediction process + +template <typename Pixel> +void DirectionalIntraPredictorZone1_C(void* const dest, ptrdiff_t stride, + const void* const top_row, + const int width, const int height, + const int xstep, + const bool upsampled_top) { + const auto* const top = static_cast<const Pixel*>(top_row); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + + assert(xstep > 0); + + // If xstep == 64 then |shift| always evaluates to 0 which sets |val| to + // |top[top_base_x]|. This corresponds to a 45 degree prediction. + if (xstep == 64) { + // 7.11.2.10. Intra edge upsample selection process + // if ( d <= 0 || d >= 40 ) useUpsample = 0 + // For |upsampled_top| the delta is |predictor_angle - 90|. Since the + // |predictor_angle| is 45 the delta is also 45. + assert(!upsampled_top); + const Pixel* top_ptr = top + 1; + for (int y = 0; y < height; ++y, dst += stride, ++top_ptr) { + memcpy(dst, top_ptr, sizeof(*top_ptr) * width); + } + return; + } + + const int upsample_shift = static_cast<int>(upsampled_top); + const int max_base_x = ((width + height) - 1) << upsample_shift; + const int scale_bits = 6 - upsample_shift; + const int base_step = 1 << upsample_shift; + int top_x = xstep; + int y = 0; + do { + int top_base_x = top_x >> scale_bits; + + if (top_base_x >= max_base_x) { + for (int i = y; i < height; ++i) { + Memset(dst, top[max_base_x], width); + dst += stride; + } + return; + } + + const int shift = ((top_x << upsample_shift) & 0x3F) >> 1; + int x = 0; + do { + if (top_base_x >= max_base_x) { + Memset(dst + x, top[max_base_x], width - x); + break; + } + + const int val = + top[top_base_x] * (32 - shift) + top[top_base_x + 1] * shift; + dst[x] = RightShiftWithRounding(val, 5); + top_base_x += base_step; + } while (++x < width); + + dst += stride; + top_x += xstep; + } while (++y < height); +} + +template <typename Pixel> +void DirectionalIntraPredictorZone2_C(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + const int width, const int height, + const int xstep, const int ystep, + const bool upsampled_top, + const bool upsampled_left) { + const auto* const top = static_cast<const Pixel*>(top_row); + const auto* const left = static_cast<const Pixel*>(left_column); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + + assert(xstep > 0); + assert(ystep > 0); + + const int upsample_top_shift = static_cast<int>(upsampled_top); + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int scale_bits_x = 6 - upsample_top_shift; + const int scale_bits_y = 6 - upsample_left_shift; + const int min_base_x = -(1 << upsample_top_shift); + const int base_step_x = 1 << upsample_top_shift; + int y = 0; + int top_x = -xstep; + do { + int top_base_x = top_x >> scale_bits_x; + int left_y = (y << 6) - ystep; + int x = 0; + do { + int val; + if (top_base_x >= min_base_x) { + const int shift = ((top_x * (1 << upsample_top_shift)) & 0x3F) >> 1; + val = top[top_base_x] * (32 - shift) + top[top_base_x + 1] * shift; + } else { + // Note this assumes an arithmetic shift to handle negative values. + const int left_base_y = left_y >> scale_bits_y; + const int shift = ((left_y * (1 << upsample_left_shift)) & 0x3F) >> 1; + assert(left_base_y >= -(1 << upsample_left_shift)); + val = left[left_base_y] * (32 - shift) + left[left_base_y + 1] * shift; + } + dst[x] = RightShiftWithRounding(val, 5); + top_base_x += base_step_x; + left_y -= ystep; + } while (++x < width); + + top_x -= xstep; + dst += stride; + } while (++y < height); +} + +template <typename Pixel> +void DirectionalIntraPredictorZone3_C(void* const dest, ptrdiff_t stride, + const void* const left_column, + const int width, const int height, + const int ystep, + const bool upsampled_left) { + const auto* const left = static_cast<const Pixel*>(left_column); + stride /= sizeof(Pixel); + + assert(ystep > 0); + + const int upsample_shift = static_cast<int>(upsampled_left); + const int scale_bits = 6 - upsample_shift; + const int base_step = 1 << upsample_shift; + // Zone3 never runs out of left_column values. + assert((width + height - 1) << upsample_shift > // max_base_y + ((ystep * width) >> scale_bits) + + base_step * (height - 1)); // left_base_y + + int left_y = ystep; + int x = 0; + do { + auto* dst = static_cast<Pixel*>(dest); + + int left_base_y = left_y >> scale_bits; + int y = 0; + do { + const int shift = ((left_y << upsample_shift) & 0x3F) >> 1; + const int val = + left[left_base_y] * (32 - shift) + left[left_base_y + 1] * shift; + dst[x] = RightShiftWithRounding(val, 5); + dst += stride; + left_base_y += base_step; + } while (++y < height); + + left_y += ystep; + } while (++x < width); +} + +//------------------------------------------------------------------------------ + +template <typename Pixel> +struct IntraPredDefs { + IntraPredDefs() = delete; + + using _4x4 = IntraPredFuncs_C<4, 4, Pixel>; + using _4x8 = IntraPredFuncs_C<4, 8, Pixel>; + using _4x16 = IntraPredFuncs_C<4, 16, Pixel>; + using _8x4 = IntraPredFuncs_C<8, 4, Pixel>; + using _8x8 = IntraPredFuncs_C<8, 8, Pixel>; + using _8x16 = IntraPredFuncs_C<8, 16, Pixel>; + using _8x32 = IntraPredFuncs_C<8, 32, Pixel>; + using _16x4 = IntraPredFuncs_C<16, 4, Pixel>; + using _16x8 = IntraPredFuncs_C<16, 8, Pixel>; + using _16x16 = IntraPredFuncs_C<16, 16, Pixel>; + using _16x32 = IntraPredFuncs_C<16, 32, Pixel>; + using _16x64 = IntraPredFuncs_C<16, 64, Pixel>; + using _32x8 = IntraPredFuncs_C<32, 8, Pixel>; + using _32x16 = IntraPredFuncs_C<32, 16, Pixel>; + using _32x32 = IntraPredFuncs_C<32, 32, Pixel>; + using _32x64 = IntraPredFuncs_C<32, 64, Pixel>; + using _64x16 = IntraPredFuncs_C<64, 16, Pixel>; + using _64x32 = IntraPredFuncs_C<64, 32, Pixel>; + using _64x64 = IntraPredFuncs_C<64, 64, Pixel>; +}; + +template <int bitdepth, typename Pixel> +struct IntraPredBppDefs { + IntraPredBppDefs() = delete; + + using _4x4 = IntraPredBppFuncs_C<4, 4, bitdepth, Pixel>; + using _4x8 = IntraPredBppFuncs_C<4, 8, bitdepth, Pixel>; + using _4x16 = IntraPredBppFuncs_C<4, 16, bitdepth, Pixel>; + using _8x4 = IntraPredBppFuncs_C<8, 4, bitdepth, Pixel>; + using _8x8 = IntraPredBppFuncs_C<8, 8, bitdepth, Pixel>; + using _8x16 = IntraPredBppFuncs_C<8, 16, bitdepth, Pixel>; + using _8x32 = IntraPredBppFuncs_C<8, 32, bitdepth, Pixel>; + using _16x4 = IntraPredBppFuncs_C<16, 4, bitdepth, Pixel>; + using _16x8 = IntraPredBppFuncs_C<16, 8, bitdepth, Pixel>; + using _16x16 = IntraPredBppFuncs_C<16, 16, bitdepth, Pixel>; + using _16x32 = IntraPredBppFuncs_C<16, 32, bitdepth, Pixel>; + using _16x64 = IntraPredBppFuncs_C<16, 64, bitdepth, Pixel>; + using _32x8 = IntraPredBppFuncs_C<32, 8, bitdepth, Pixel>; + using _32x16 = IntraPredBppFuncs_C<32, 16, bitdepth, Pixel>; + using _32x32 = IntraPredBppFuncs_C<32, 32, bitdepth, Pixel>; + using _32x64 = IntraPredBppFuncs_C<32, 64, bitdepth, Pixel>; + using _64x16 = IntraPredBppFuncs_C<64, 16, bitdepth, Pixel>; + using _64x32 = IntraPredBppFuncs_C<64, 32, bitdepth, Pixel>; + using _64x64 = IntraPredBppFuncs_C<64, 64, bitdepth, Pixel>; +}; + +using Defs = IntraPredDefs<uint8_t>; +using Defs8bpp = IntraPredBppDefs<8, uint8_t>; + +// Initializes dsp entries for kTransformSize|W|x|H| from |DEFS|/|DEFSBPP| of +// the same size. +#define INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, W, H) \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorDcFill] = \ + DEFSBPP::_##W##x##H::DcFill; \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorDcTop] = \ + DEFS::_##W##x##H::DcTop; \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorDcLeft] = \ + DEFS::_##W##x##H::DcLeft; \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorDc] = \ + DEFS::_##W##x##H::Dc; \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorVertical] = \ + DEFS::_##W##x##H::Vertical; \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorHorizontal] = \ + DEFS::_##W##x##H::Horizontal; \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorPaeth] = \ + DEFS::_##W##x##H::Paeth; \ + dsp->intra_predictors[kTransformSize##W##x##H][kIntraPredictorSmooth] = \ + DEFS::_##W##x##H::Smooth; \ + dsp->intra_predictors[kTransformSize##W##x##H] \ + [kIntraPredictorSmoothVertical] = \ + DEFS::_##W##x##H::SmoothVertical; \ + dsp->intra_predictors[kTransformSize##W##x##H] \ + [kIntraPredictorSmoothHorizontal] = \ + DEFS::_##W##x##H::SmoothHorizontal + +#define INIT_INTRAPREDICTORS(DEFS, DEFSBPP) \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 4, 4); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 4, 8); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 4, 16); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 8, 4); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 8, 8); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 8, 16); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 8, 32); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 16, 4); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 16, 8); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 16, 16); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 16, 32); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 16, 64); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 32, 8); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 32, 16); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 32, 32); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 32, 64); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 64, 16); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 64, 32); \ + INIT_INTRAPREDICTORS_WxH(DEFS, DEFSBPP, 64, 64) + +#define INIT_CFL_INTRAPREDICTOR_WxH(W, H, BITDEPTH, PIXEL) \ + dsp->cfl_intra_predictors[kTransformSize##W##x##H] = \ + CflIntraPredictor_C<W, H, BITDEPTH, PIXEL>; \ + dsp->cfl_subsamplers[kTransformSize##W##x##H][kSubsamplingType444] = \ + CflSubsampler_C<W, H, BITDEPTH, PIXEL, 0, 0>; \ + dsp->cfl_subsamplers[kTransformSize##W##x##H][kSubsamplingType422] = \ + CflSubsampler_C<W, H, BITDEPTH, PIXEL, 1, 0>; \ + dsp->cfl_subsamplers[kTransformSize##W##x##H][kSubsamplingType420] = \ + CflSubsampler_C<W, H, BITDEPTH, PIXEL, 1, 1> + +#define INIT_CFL_INTRAPREDICTORS(BITDEPTH, PIXEL) \ + INIT_CFL_INTRAPREDICTOR_WxH(4, 4, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(4, 8, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(4, 16, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(8, 4, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(8, 8, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(8, 16, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(8, 32, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(16, 4, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(16, 8, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(16, 16, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(16, 32, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(32, 8, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(32, 16, BITDEPTH, PIXEL); \ + INIT_CFL_INTRAPREDICTOR_WxH(32, 32, BITDEPTH, PIXEL) + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + INIT_INTRAPREDICTORS(Defs, Defs8bpp); + dsp->directional_intra_predictor_zone1 = + DirectionalIntraPredictorZone1_C<uint8_t>; + dsp->directional_intra_predictor_zone2 = + DirectionalIntraPredictorZone2_C<uint8_t>; + dsp->directional_intra_predictor_zone3 = + DirectionalIntraPredictorZone3_C<uint8_t>; + dsp->filter_intra_predictor = FilterIntraPredictor_C<8, uint8_t>; + INIT_CFL_INTRAPREDICTORS(8, uint8_t); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcFill] = + Defs8bpp::_4x4::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + Defs::_4x4::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + Defs::_4x4::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = Defs::_4x4::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorVertical + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorVertical] = + Defs::_4x4::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorHorizontal] = + Defs::_4x4::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorPaeth] = + Defs::_4x4::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] = + Defs::_4x4::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] = + Defs::_4x4::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] = + Defs::_4x4::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcFill] = + Defs8bpp::_4x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] = + Defs::_4x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] = + Defs::_4x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] = Defs::_4x8::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorVertical] = + Defs::_4x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorHorizontal] = + Defs::_4x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorPaeth] = + Defs::_4x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] = + Defs::_4x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] = + Defs::_4x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] = + Defs::_4x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcFill] = + Defs8bpp::_4x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] = + Defs::_4x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] = + Defs::_4x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] = + Defs::_4x16::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorVertical] = + Defs::_4x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorHorizontal] = + Defs::_4x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorPaeth] = + Defs::_4x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] = + Defs::_4x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] = + Defs::_4x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] = + Defs::_4x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcFill] = + Defs8bpp::_8x4::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] = + Defs::_8x4::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] = + Defs::_8x4::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] = Defs::_8x4::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorVertical] = + Defs::_8x4::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorHorizontal] = + Defs::_8x4::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorPaeth] = + Defs::_8x4::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] = + Defs::_8x4::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] = + Defs::_8x4::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] = + Defs::_8x4::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcFill] = + Defs8bpp::_8x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] = + Defs::_8x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] = + Defs::_8x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] = Defs::_8x8::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorVertical] = + Defs::_8x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorHorizontal] = + Defs::_8x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorPaeth] = + Defs::_8x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] = + Defs::_8x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] = + Defs::_8x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] = + Defs::_8x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcFill] = + Defs8bpp::_8x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] = + Defs::_8x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] = + Defs::_8x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] = + Defs::_8x16::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorVertical] = + Defs::_8x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorHorizontal] = + Defs::_8x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorPaeth] = + Defs::_8x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] = + Defs::_8x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] = + Defs::_8x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] = + Defs::_8x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcFill] = + Defs8bpp::_8x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] = + Defs::_8x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] = + Defs::_8x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] = + Defs::_8x32::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorVertical] = + Defs::_8x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorHorizontal] = + Defs::_8x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorPaeth] = + Defs::_8x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] = + Defs::_8x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] = + Defs::_8x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] = + Defs::_8x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcFill] = + Defs8bpp::_16x4::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] = + Defs::_16x4::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] = + Defs::_16x4::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] = + Defs::_16x4::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorVertical] = + Defs::_16x4::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorHorizontal] = + Defs::_16x4::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorPaeth] = + Defs::_16x4::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmooth] = + Defs::_16x4::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothVertical] = + Defs::_16x4::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothHorizontal] = + Defs::_16x4::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcFill] = + Defs8bpp::_16x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] = + Defs::_16x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] = + Defs::_16x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] = + Defs::_16x8::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorVertical] = + Defs::_16x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorHorizontal] = + Defs::_16x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorPaeth] = + Defs::_16x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmooth] = + Defs::_16x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothVertical] = + Defs::_16x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothHorizontal] = + Defs::_16x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcFill] = + Defs8bpp::_16x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] = + Defs::_16x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] = + Defs::_16x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] = + Defs::_16x16::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorVertical] = + Defs::_16x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorHorizontal] = + Defs::_16x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorPaeth] = + Defs::_16x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmooth] = + Defs::_16x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothVertical] = + Defs::_16x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothHorizontal] = + Defs::_16x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcFill] = + Defs8bpp::_16x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] = + Defs::_16x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] = + Defs::_16x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] = + Defs::_16x32::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorVertical] = + Defs::_16x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorHorizontal] = + Defs::_16x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorPaeth] = + Defs::_16x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmooth] = + Defs::_16x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothVertical] = + Defs::_16x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothHorizontal] = + Defs::_16x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcFill] = + Defs8bpp::_16x64::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] = + Defs::_16x64::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] = + Defs::_16x64::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] = + Defs::_16x64::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorVertical] = + Defs::_16x64::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorHorizontal] = + Defs::_16x64::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorPaeth] = + Defs::_16x64::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmooth] = + Defs::_16x64::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothVertical] = + Defs::_16x64::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothHorizontal] = + Defs::_16x64::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcFill] = + Defs8bpp::_32x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] = + Defs::_32x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] = + Defs::_32x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] = + Defs::_32x8::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorVertical] = + Defs::_32x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorHorizontal] = + Defs::_32x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorPaeth] = + Defs::_32x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmooth] = + Defs::_32x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothVertical] = + Defs::_32x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothHorizontal] = + Defs::_32x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcFill] = + Defs8bpp::_32x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] = + Defs::_32x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] = + Defs::_32x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] = + Defs::_32x16::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorVertical] = + Defs::_32x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorHorizontal] = + Defs::_32x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorPaeth] = + Defs::_32x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmooth] = + Defs::_32x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothVertical] = + Defs::_32x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothHorizontal] = + Defs::_32x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcFill] = + Defs8bpp::_32x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] = + Defs::_32x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] = + Defs::_32x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] = + Defs::_32x32::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorVertical] = + Defs::_32x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorHorizontal] = + Defs::_32x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorPaeth] = + Defs::_32x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmooth] = + Defs::_32x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothVertical] = + Defs::_32x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothHorizontal] = + Defs::_32x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcFill] = + Defs8bpp::_32x64::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] = + Defs::_32x64::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] = + Defs::_32x64::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] = + Defs::_32x64::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorVertical] = + Defs::_32x64::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorHorizontal] = + Defs::_32x64::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorPaeth] = + Defs::_32x64::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmooth] = + Defs::_32x64::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothVertical] = + Defs::_32x64::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothHorizontal] = + Defs::_32x64::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcFill] = + Defs8bpp::_64x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] = + Defs::_64x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] = + Defs::_64x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] = + Defs::_64x16::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorVertical] = + Defs::_64x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorHorizontal] = + Defs::_64x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorPaeth] = + Defs::_64x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmooth] = + Defs::_64x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothVertical] = + Defs::_64x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothHorizontal] = + Defs::_64x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcFill] = + Defs8bpp::_64x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] = + Defs::_64x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] = + Defs::_64x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] = + Defs::_64x32::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorVertical] = + Defs::_64x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorHorizontal] = + Defs::_64x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorPaeth] = + Defs::_64x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmooth] = + Defs::_64x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothVertical] = + Defs::_64x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothHorizontal] = + Defs::_64x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcFill] = + Defs8bpp::_64x64::DcFill; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] = + Defs::_64x64::DcTop; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] = + Defs::_64x64::DcLeft; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] = + Defs::_64x64::Dc; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorVertical + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorVertical] = + Defs::_64x64::Vertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorHorizontal] = + Defs::_64x64::Horizontal; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorPaeth] = + Defs::_64x64::Paeth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmooth] = + Defs::_64x64::Smooth; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothVertical] = + Defs::_64x64::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothHorizontal] = + Defs::_64x64::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 + dsp->directional_intra_predictor_zone1 = + DirectionalIntraPredictorZone1_C<uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 + dsp->directional_intra_predictor_zone2 = + DirectionalIntraPredictorZone2_C<uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 + dsp->directional_intra_predictor_zone3 = + DirectionalIntraPredictorZone3_C<uint8_t>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_FilterIntraPredictor + dsp->filter_intra_predictor = FilterIntraPredictor_C<8, uint8_t>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize4x4] = + CflIntraPredictor_C<4, 4, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] = + CflSubsampler_C<4, 4, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType422] = + CflSubsampler_C<4, 4, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] = + CflSubsampler_C<4, 4, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize4x8] = + CflIntraPredictor_C<4, 8, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] = + CflSubsampler_C<4, 8, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType422] = + CflSubsampler_C<4, 8, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] = + CflSubsampler_C<4, 8, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize4x16] = + CflIntraPredictor_C<4, 16, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] = + CflSubsampler_C<4, 16, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType422] = + CflSubsampler_C<4, 16, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] = + CflSubsampler_C<4, 16, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x4] = + CflIntraPredictor_C<8, 4, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] = + CflSubsampler_C<8, 4, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType422] = + CflSubsampler_C<8, 4, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] = + CflSubsampler_C<8, 4, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x8] = + CflIntraPredictor_C<8, 8, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] = + CflSubsampler_C<8, 8, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType422] = + CflSubsampler_C<8, 8, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] = + CflSubsampler_C<8, 8, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x16] = + CflIntraPredictor_C<8, 16, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] = + CflSubsampler_C<8, 16, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType422] = + CflSubsampler_C<8, 16, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] = + CflSubsampler_C<8, 16, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x32] = + CflIntraPredictor_C<8, 32, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] = + CflSubsampler_C<8, 32, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType422] = + CflSubsampler_C<8, 32, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] = + CflSubsampler_C<8, 32, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x4] = + CflIntraPredictor_C<16, 4, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] = + CflSubsampler_C<16, 4, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType422] = + CflSubsampler_C<16, 4, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] = + CflSubsampler_C<16, 4, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x8] = + CflIntraPredictor_C<16, 8, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] = + CflSubsampler_C<16, 8, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType422] = + CflSubsampler_C<16, 8, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] = + CflSubsampler_C<16, 8, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x16] = + CflIntraPredictor_C<16, 16, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] = + CflSubsampler_C<16, 16, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType422] = + CflSubsampler_C<16, 16, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] = + CflSubsampler_C<16, 16, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x32] = + CflIntraPredictor_C<16, 32, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] = + CflSubsampler_C<16, 32, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType422] = + CflSubsampler_C<16, 32, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] = + CflSubsampler_C<16, 32, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize32x8] = + CflIntraPredictor_C<32, 8, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] = + CflSubsampler_C<32, 8, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType422] = + CflSubsampler_C<32, 8, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] = + CflSubsampler_C<32, 8, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize32x16] = + CflIntraPredictor_C<32, 16, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] = + CflSubsampler_C<32, 16, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType422] = + CflSubsampler_C<32, 16, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] = + CflSubsampler_C<32, 16, 8, uint8_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize32x32] = + CflIntraPredictor_C<32, 32, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] = + CflSubsampler_C<32, 32, 8, uint8_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType422] = + CflSubsampler_C<32, 32, 8, uint8_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] = + CflSubsampler_C<32, 32, 8, uint8_t, 1, 1>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + // Cfl predictors are available only for transform sizes with max(width, + // height) <= 32. Set all others to nullptr. + for (const auto i : kTransformSizesLargerThan32x32) { + dsp->cfl_intra_predictors[i] = nullptr; + for (int j = 0; j < kNumSubsamplingTypes; ++j) { + dsp->cfl_subsamplers[i][j] = nullptr; + } + } +} // NOLINT(readability/fn_size) + +#if LIBGAV1_MAX_BITDEPTH >= 10 +using DefsHbd = IntraPredDefs<uint16_t>; +using Defs10bpp = IntraPredBppDefs<10, uint16_t>; + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + INIT_INTRAPREDICTORS(DefsHbd, Defs10bpp); + dsp->directional_intra_predictor_zone1 = + DirectionalIntraPredictorZone1_C<uint16_t>; + dsp->directional_intra_predictor_zone2 = + DirectionalIntraPredictorZone2_C<uint16_t>; + dsp->directional_intra_predictor_zone3 = + DirectionalIntraPredictorZone3_C<uint16_t>; + dsp->filter_intra_predictor = FilterIntraPredictor_C<10, uint16_t>; + INIT_CFL_INTRAPREDICTORS(10, uint16_t); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcFill] = + Defs10bpp::_4x4::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + DefsHbd::_4x4::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + DefsHbd::_4x4::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = + DefsHbd::_4x4::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorVertical + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorVertical] = + DefsHbd::_4x4::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorHorizontal] = + DefsHbd::_4x4::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorPaeth] = + DefsHbd::_4x4::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] = + DefsHbd::_4x4::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] = + DefsHbd::_4x4::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] = + DefsHbd::_4x4::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcFill] = + Defs10bpp::_4x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] = + DefsHbd::_4x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] = + DefsHbd::_4x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] = + DefsHbd::_4x8::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorVertical] = + DefsHbd::_4x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorHorizontal] = + DefsHbd::_4x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorPaeth] = + DefsHbd::_4x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] = + DefsHbd::_4x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] = + DefsHbd::_4x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] = + DefsHbd::_4x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcFill] = + Defs10bpp::_4x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] = + DefsHbd::_4x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] = + DefsHbd::_4x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] = + DefsHbd::_4x16::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorVertical] = + DefsHbd::_4x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorHorizontal] = + DefsHbd::_4x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorPaeth] = + DefsHbd::_4x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] = + DefsHbd::_4x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] = + DefsHbd::_4x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] = + DefsHbd::_4x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcFill] = + Defs10bpp::_8x4::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] = + DefsHbd::_8x4::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] = + DefsHbd::_8x4::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] = + DefsHbd::_8x4::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorVertical] = + DefsHbd::_8x4::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorHorizontal] = + DefsHbd::_8x4::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorPaeth] = + DefsHbd::_8x4::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] = + DefsHbd::_8x4::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] = + DefsHbd::_8x4::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] = + DefsHbd::_8x4::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcFill] = + Defs10bpp::_8x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] = + DefsHbd::_8x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] = + DefsHbd::_8x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] = + DefsHbd::_8x8::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorVertical] = + DefsHbd::_8x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorHorizontal] = + DefsHbd::_8x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorPaeth] = + DefsHbd::_8x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] = + DefsHbd::_8x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] = + DefsHbd::_8x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] = + DefsHbd::_8x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcFill] = + Defs10bpp::_8x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] = + DefsHbd::_8x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] = + DefsHbd::_8x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] = + DefsHbd::_8x16::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorVertical] = + DefsHbd::_8x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorHorizontal] = + DefsHbd::_8x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorPaeth] = + DefsHbd::_8x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] = + DefsHbd::_8x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] = + DefsHbd::_8x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] = + DefsHbd::_8x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcFill] = + Defs10bpp::_8x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] = + DefsHbd::_8x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] = + DefsHbd::_8x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] = + DefsHbd::_8x32::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorVertical] = + DefsHbd::_8x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorHorizontal] = + DefsHbd::_8x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorPaeth] = + DefsHbd::_8x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] = + DefsHbd::_8x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] = + DefsHbd::_8x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] = + DefsHbd::_8x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcFill] = + Defs10bpp::_16x4::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] = + DefsHbd::_16x4::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] = + DefsHbd::_16x4::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] = + DefsHbd::_16x4::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorVertical] = + DefsHbd::_16x4::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorHorizontal] = + DefsHbd::_16x4::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorPaeth] = + DefsHbd::_16x4::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmooth] = + DefsHbd::_16x4::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothVertical] = + DefsHbd::_16x4::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothHorizontal] = + DefsHbd::_16x4::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcFill] = + Defs10bpp::_16x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] = + DefsHbd::_16x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] = + DefsHbd::_16x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] = + DefsHbd::_16x8::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorVertical] = + DefsHbd::_16x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorHorizontal] = + DefsHbd::_16x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorPaeth] = + DefsHbd::_16x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmooth] = + DefsHbd::_16x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothVertical] = + DefsHbd::_16x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothHorizontal] = + DefsHbd::_16x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcFill] = + Defs10bpp::_16x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] = + DefsHbd::_16x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] = + DefsHbd::_16x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] = + DefsHbd::_16x16::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorVertical] = + DefsHbd::_16x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorHorizontal] = + DefsHbd::_16x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorPaeth] = + DefsHbd::_16x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmooth] = + DefsHbd::_16x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothVertical] = + DefsHbd::_16x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothHorizontal] = + DefsHbd::_16x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcFill] = + Defs10bpp::_16x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] = + DefsHbd::_16x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] = + DefsHbd::_16x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] = + DefsHbd::_16x32::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorVertical] = + DefsHbd::_16x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorHorizontal] = + DefsHbd::_16x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorPaeth] = + DefsHbd::_16x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmooth] = + DefsHbd::_16x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothVertical] = + DefsHbd::_16x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothHorizontal] = + DefsHbd::_16x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcFill] = + Defs10bpp::_16x64::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] = + DefsHbd::_16x64::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] = + DefsHbd::_16x64::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDc + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] = + DefsHbd::_16x64::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorVertical + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorVertical] = + DefsHbd::_16x64::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorHorizontal] = + DefsHbd::_16x64::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorPaeth] = + DefsHbd::_16x64::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmooth] = + DefsHbd::_16x64::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothVertical] = + DefsHbd::_16x64::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothHorizontal] = + DefsHbd::_16x64::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcFill] = + Defs10bpp::_32x8::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] = + DefsHbd::_32x8::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] = + DefsHbd::_32x8::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] = + DefsHbd::_32x8::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorVertical] = + DefsHbd::_32x8::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorHorizontal] = + DefsHbd::_32x8::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorPaeth] = + DefsHbd::_32x8::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmooth] = + DefsHbd::_32x8::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothVertical] = + DefsHbd::_32x8::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothHorizontal] = + DefsHbd::_32x8::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcFill] = + Defs10bpp::_32x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] = + DefsHbd::_32x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] = + DefsHbd::_32x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] = + DefsHbd::_32x16::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorVertical] = + DefsHbd::_32x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorHorizontal] = + DefsHbd::_32x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorPaeth] = + DefsHbd::_32x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmooth] = + DefsHbd::_32x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothVertical] = + DefsHbd::_32x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothHorizontal] = + DefsHbd::_32x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcFill] = + Defs10bpp::_32x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] = + DefsHbd::_32x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] = + DefsHbd::_32x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] = + DefsHbd::_32x32::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorVertical] = + DefsHbd::_32x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorHorizontal] = + DefsHbd::_32x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorPaeth] = + DefsHbd::_32x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmooth] = + DefsHbd::_32x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothVertical] = + DefsHbd::_32x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothHorizontal] = + DefsHbd::_32x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcFill] = + Defs10bpp::_32x64::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] = + DefsHbd::_32x64::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] = + DefsHbd::_32x64::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDc + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] = + DefsHbd::_32x64::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorVertical + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorVertical] = + DefsHbd::_32x64::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorHorizontal] = + DefsHbd::_32x64::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorPaeth] = + DefsHbd::_32x64::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmooth] = + DefsHbd::_32x64::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothVertical] = + DefsHbd::_32x64::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothHorizontal] = + DefsHbd::_32x64::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcFill] = + Defs10bpp::_64x16::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] = + DefsHbd::_64x16::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] = + DefsHbd::_64x16::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDc + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] = + DefsHbd::_64x16::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorVertical + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorVertical] = + DefsHbd::_64x16::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorHorizontal] = + DefsHbd::_64x16::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorPaeth] = + DefsHbd::_64x16::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmooth] = + DefsHbd::_64x16::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothVertical] = + DefsHbd::_64x16::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothHorizontal] = + DefsHbd::_64x16::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcFill] = + Defs10bpp::_64x32::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] = + DefsHbd::_64x32::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] = + DefsHbd::_64x32::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDc + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] = + DefsHbd::_64x32::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorVertical + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorVertical] = + DefsHbd::_64x32::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorHorizontal] = + DefsHbd::_64x32::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorPaeth] = + DefsHbd::_64x32::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmooth] = + DefsHbd::_64x32::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothVertical] = + DefsHbd::_64x32::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothHorizontal] = + DefsHbd::_64x32::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcFill + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcFill] = + Defs10bpp::_64x64::DcFill; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcTop + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] = + DefsHbd::_64x64::DcTop; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcLeft + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] = + DefsHbd::_64x64::DcLeft; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDc + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] = + DefsHbd::_64x64::Dc; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorVertical + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorVertical] = + DefsHbd::_64x64::Vertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorHorizontal + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorHorizontal] = + DefsHbd::_64x64::Horizontal; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorPaeth + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorPaeth] = + DefsHbd::_64x64::Paeth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorSmooth + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmooth] = + DefsHbd::_64x64::Smooth; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorSmoothVertical + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothVertical] = + DefsHbd::_64x64::SmoothVertical; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorSmoothHorizontal + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothHorizontal] = + DefsHbd::_64x64::SmoothHorizontal; +#endif + +#ifndef LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone1 + dsp->directional_intra_predictor_zone1 = + DirectionalIntraPredictorZone1_C<uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone2 + dsp->directional_intra_predictor_zone2 = + DirectionalIntraPredictorZone2_C<uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone3 + dsp->directional_intra_predictor_zone3 = + DirectionalIntraPredictorZone3_C<uint16_t>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_FilterIntraPredictor + dsp->filter_intra_predictor = FilterIntraPredictor_C<10, uint16_t>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize4x4] = + CflIntraPredictor_C<4, 4, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] = + CflSubsampler_C<4, 4, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType422] = + CflSubsampler_C<4, 4, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] = + CflSubsampler_C<4, 4, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize4x8] = + CflIntraPredictor_C<4, 8, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] = + CflSubsampler_C<4, 8, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType422] = + CflSubsampler_C<4, 8, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] = + CflSubsampler_C<4, 8, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize4x16] = + CflIntraPredictor_C<4, 16, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] = + CflSubsampler_C<4, 16, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType422] = + CflSubsampler_C<4, 16, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] = + CflSubsampler_C<4, 16, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x4] = + CflIntraPredictor_C<8, 4, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] = + CflSubsampler_C<8, 4, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType422] = + CflSubsampler_C<8, 4, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] = + CflSubsampler_C<8, 4, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x8] = + CflIntraPredictor_C<8, 8, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] = + CflSubsampler_C<8, 8, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType422] = + CflSubsampler_C<8, 8, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] = + CflSubsampler_C<8, 8, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x16] = + CflIntraPredictor_C<8, 16, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] = + CflSubsampler_C<8, 16, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType422] = + CflSubsampler_C<8, 16, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] = + CflSubsampler_C<8, 16, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize8x32] = + CflIntraPredictor_C<8, 32, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] = + CflSubsampler_C<8, 32, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType422] = + CflSubsampler_C<8, 32, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] = + CflSubsampler_C<8, 32, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x4] = + CflIntraPredictor_C<16, 4, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] = + CflSubsampler_C<16, 4, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType422] = + CflSubsampler_C<16, 4, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] = + CflSubsampler_C<16, 4, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x8] = + CflIntraPredictor_C<16, 8, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] = + CflSubsampler_C<16, 8, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType422] = + CflSubsampler_C<16, 8, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] = + CflSubsampler_C<16, 8, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x16] = + CflIntraPredictor_C<16, 16, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] = + CflSubsampler_C<16, 16, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType422] = + CflSubsampler_C<16, 16, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] = + CflSubsampler_C<16, 16, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize16x32] = + CflIntraPredictor_C<16, 32, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] = + CflSubsampler_C<16, 32, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType422] = + CflSubsampler_C<16, 32, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] = + CflSubsampler_C<16, 32, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize32x8] = + CflIntraPredictor_C<32, 8, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] = + CflSubsampler_C<32, 8, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType422] = + CflSubsampler_C<32, 8, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] = + CflSubsampler_C<32, 8, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize32x16] = + CflIntraPredictor_C<32, 16, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] = + CflSubsampler_C<32, 16, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType422] = + CflSubsampler_C<32, 16, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] = + CflSubsampler_C<32, 16, 10, uint16_t, 1, 1>; +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_CflIntraPredictor + dsp->cfl_intra_predictors[kTransformSize32x32] = + CflIntraPredictor_C<32, 32, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_CflSubsampler444 + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] = + CflSubsampler_C<32, 32, 10, uint16_t, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_CflSubsampler422 + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType422] = + CflSubsampler_C<32, 32, 10, uint16_t, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_CflSubsampler420 + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] = + CflSubsampler_C<32, 32, 10, uint16_t, 1, 1>; +#endif + +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + // Cfl predictors are available only for transform sizes with max(width, + // height) <= 32. Set all others to nullptr. + for (const auto i : kTransformSizesLargerThan32x32) { + dsp->cfl_intra_predictors[i] = nullptr; + for (int j = 0; j < kNumSubsamplingTypes; ++j) { + dsp->cfl_subsamplers[i][j] = nullptr; + } + } +} // NOLINT(readability/fn_size) +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +#undef INIT_CFL_INTRAPREDICTOR_WxH +#undef INIT_CFL_INTRAPREDICTORS +#undef INIT_INTRAPREDICTORS_WxH +#undef INIT_INTRAPREDICTORS + +} // namespace + +void IntraPredInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/intrapred.h b/src/dsp/intrapred.h new file mode 100644 index 0000000..c5286ef --- /dev/null +++ b/src/dsp/intrapred.h @@ -0,0 +1,49 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_INTRAPRED_H_ +#define LIBGAV1_SRC_DSP_INTRAPRED_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/intrapred_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/intrapred_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*, +// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and +// Dsp::filter_intra_predictor. This function is not thread-safe. +void IntraPredInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_INTRAPRED_H_ diff --git a/src/dsp/inverse_transform.cc b/src/dsp/inverse_transform.cc new file mode 100644 index 0000000..a03fad2 --- /dev/null +++ b/src/dsp/inverse_transform.cc @@ -0,0 +1,1636 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/inverse_transform.h" + +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstring> + +#include "src/dsp/dsp.h" +#include "src/utils/array_2d.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" +#include "src/utils/logging.h" + +namespace libgav1 { +namespace dsp { +namespace { + +// Include the constants and utility functions inside the anonymous namespace. +#include "src/dsp/inverse_transform.inc" + +constexpr uint8_t kTransformColumnShift = 4; + +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) +#undef LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK +#endif + +int32_t RangeCheckValue(int32_t value, int8_t range) { +#if defined(LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK) && \ + LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK + assert(range <= 32); + const int32_t min = -(1 << (range - 1)); + const int32_t max = (1 << (range - 1)) - 1; + if (min > value || value > max) { + LIBGAV1_DLOG(ERROR, "coeff out of bit range, value: %d bit range %d\n", + value, range); + assert(min <= value && value <= max); + } +#endif // LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK + static_cast<void>(range); + return value; +} + +template <typename Residual> +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_C(Residual* const dst, int a, + int b, int angle, bool flip, + int8_t range) { + // Note that we multiply in 32 bits and then add/subtract the products in 64 + // bits. The 32-bit multiplications do not overflow. Please see the comment + // and assert() in Cos128(). + const int64_t x = static_cast<int64_t>(dst[a] * Cos128(angle)) - + static_cast<int64_t>(dst[b] * Sin128(angle)); + const int64_t y = static_cast<int64_t>(dst[a] * Sin128(angle)) + + static_cast<int64_t>(dst[b] * Cos128(angle)); + // Section 7.13.2.1: It is a requirement of bitstream conformance that the + // values saved into the array T by this function are representable by a + // signed integer using |range| bits of precision. + dst[a] = RangeCheckValue(RightShiftWithRounding(flip ? y : x, 12), range); + dst[b] = RangeCheckValue(RightShiftWithRounding(flip ? x : y, 12), range); +} + +template <typename Residual> +void ButterflyRotationFirstIsZero_C(Residual* const dst, int a, int b, + int angle, bool flip, int8_t range) { + // Note that we multiply in 32 bits and then add/subtract the products in 64 + // bits. The 32-bit multiplications do not overflow. Please see the comment + // and assert() in Cos128(). + const auto x = static_cast<int64_t>(dst[b] * -Sin128(angle)); + const auto y = static_cast<int64_t>(dst[b] * Cos128(angle)); + // Section 7.13.2.1: It is a requirement of bitstream conformance that the + // values saved into the array T by this function are representable by a + // signed integer using |range| bits of precision. + dst[a] = RangeCheckValue(RightShiftWithRounding(flip ? y : x, 12), range); + dst[b] = RangeCheckValue(RightShiftWithRounding(flip ? x : y, 12), range); +} + +template <typename Residual> +void ButterflyRotationSecondIsZero_C(Residual* const dst, int a, int b, + int angle, bool flip, int8_t range) { + // Note that we multiply in 32 bits and then add/subtract the products in 64 + // bits. The 32-bit multiplications do not overflow. Please see the comment + // and assert() in Cos128(). + const auto x = static_cast<int64_t>(dst[a] * Cos128(angle)); + const auto y = static_cast<int64_t>(dst[a] * Sin128(angle)); + + // Section 7.13.2.1: It is a requirement of bitstream conformance that the + // values saved into the array T by this function are representable by a + // signed integer using |range| bits of precision. + dst[a] = RangeCheckValue(RightShiftWithRounding(flip ? y : x, 12), range); + dst[b] = RangeCheckValue(RightShiftWithRounding(flip ? x : y, 12), range); +} + +template <typename Residual> +void HadamardRotation_C(Residual* const dst, int a, int b, bool flip, + int8_t range) { + if (flip) std::swap(a, b); + --range; + // For Adst and Dct, the maximum possible value for range is 20. So min and + // max should always fit into int32_t. + const int32_t min = -(1 << range); + const int32_t max = (1 << range) - 1; + const int32_t x = dst[a] + dst[b]; + const int32_t y = dst[a] - dst[b]; + dst[a] = Clip3(x, min, max); + dst[b] = Clip3(y, min, max); +} + +template <int bitdepth, typename Residual> +void ClampIntermediate(Residual* const dst, int size) { + // If Residual is int16_t (which implies bitdepth is 8), we don't need to + // clip residual[i][j] to 16 bits. + if (sizeof(Residual) > 2) { + const Residual intermediate_clamp_max = + (1 << (std::max(bitdepth + 6, 16) - 1)) - 1; + const Residual intermediate_clamp_min = -intermediate_clamp_max - 1; + for (int j = 0; j < size; ++j) { + dst[j] = Clip3(dst[j], intermediate_clamp_min, intermediate_clamp_max); + } + } +} + +//------------------------------------------------------------------------------ +// Discrete Cosine Transforms (DCT). + +// Value for index (i, j) is computed as bitreverse(j) and interpreting that as +// an integer with bit-length i + 2. +// For e.g. index (2, 3) will be computed as follows: +// * bitreverse(3) = bitreverse(..000011) = 110000... +// * interpreting that as an integer with bit-length 2+2 = 4 will be 1100 = 12 +constexpr uint8_t kBitReverseLookup[kNum1DTransformSizes][64] = { + {0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, + 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, + 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3}, + {0, 4, 2, 6, 1, 5, 3, 7, 0, 4, 2, 6, 1, 5, 3, 7, 0, 4, 2, 6, 1, 5, + 3, 7, 0, 4, 2, 6, 1, 5, 3, 7, 0, 4, 2, 6, 1, 5, 3, 7, 0, 4, 2, 6, + 1, 5, 3, 7, 0, 4, 2, 6, 1, 5, 3, 7, 0, 4, 2, 6, 1, 5, 3, 7}, + {0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, + 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, + 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, + 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15}, + {0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, + 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31, + 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, + 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31}, + {0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60, + 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62, + 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61, + 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63}}; + +template <typename Residual, int size_log2> +void Dct_C(void* dest, int8_t range) { + static_assert(size_log2 >= 2 && size_log2 <= 6, ""); + auto* const dst = static_cast<Residual*>(dest); + // stage 1. + const int size = 1 << size_log2; + Residual temp[size]; + memcpy(temp, dst, sizeof(temp)); + for (int i = 0; i < size; ++i) { + dst[i] = temp[kBitReverseLookup[size_log2 - 2][i]]; + } + // stages 2-32 are dependent on the value of size_log2. + // stage 2. + if (size_log2 == 6) { + for (int i = 0; i < 16; ++i) { + ButterflyRotation_C(dst, i + 32, 63 - i, + 63 - MultiplyBy4(kBitReverseLookup[2][i]), false, + range); + } + } + // stage 3 + if (size_log2 >= 5) { + for (int i = 0; i < 8; ++i) { + ButterflyRotation_C(dst, i + 16, 31 - i, + 6 + MultiplyBy8(kBitReverseLookup[1][7 - i]), false, + range); + } + } + // stage 4. + if (size_log2 == 6) { + for (int i = 0; i < 16; ++i) { + HadamardRotation_C(dst, MultiplyBy2(i) + 32, MultiplyBy2(i) + 33, + static_cast<bool>(i & 1), range); + } + } + // stage 5. + if (size_log2 >= 4) { + for (int i = 0; i < 4; ++i) { + ButterflyRotation_C(dst, i + 8, 15 - i, + 12 + MultiplyBy16(kBitReverseLookup[0][3 - i]), false, + range); + } + } + // stage 6. + if (size_log2 >= 5) { + for (int i = 0; i < 8; ++i) { + HadamardRotation_C(dst, MultiplyBy2(i) + 16, MultiplyBy2(i) + 17, + static_cast<bool>(i & 1), range); + } + } + // stage 7. + if (size_log2 == 6) { + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + ButterflyRotation_C( + dst, 62 - MultiplyBy4(i) - j, MultiplyBy4(i) + j + 33, + 60 - MultiplyBy16(kBitReverseLookup[0][i]) + MultiplyBy64(j), true, + range); + } + } + } + // stage 8. + if (size_log2 >= 3) { + for (int i = 0; i < 2; ++i) { + ButterflyRotation_C(dst, i + 4, 7 - i, 56 - 32 * i, false, range); + } + } + // stage 9. + if (size_log2 >= 4) { + for (int i = 0; i < 4; ++i) { + HadamardRotation_C(dst, MultiplyBy2(i) + 8, MultiplyBy2(i) + 9, + static_cast<bool>(i & 1), range); + } + } + // stage 10. + if (size_log2 >= 5) { + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + ButterflyRotation_C( + dst, 30 - MultiplyBy4(i) - j, MultiplyBy4(i) + j + 17, + 24 + MultiplyBy64(j) + MultiplyBy32(1 - i), true, range); + } + } + } + // stage 11. + if (size_log2 == 6) { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 2; ++j) { + HadamardRotation_C(dst, MultiplyBy4(i) + j + 32, + MultiplyBy4(i) - j + 35, static_cast<bool>(i & 1), + range); + } + } + } + // stage 12. + for (int i = 0; i < 2; ++i) { + ButterflyRotation_C(dst, MultiplyBy2(i), MultiplyBy2(i) + 1, 32 + 16 * i, + i == 0, range); + } + // stage 13. + if (size_log2 >= 3) { + for (int i = 0; i < 2; ++i) { + HadamardRotation_C(dst, MultiplyBy2(i) + 4, MultiplyBy2(i) + 5, + /*flip=*/i != 0, range); + } + } + // stage 14. + if (size_log2 >= 4) { + for (int i = 0; i < 2; ++i) { + ButterflyRotation_C(dst, 14 - i, i + 9, 48 + 64 * i, true, range); + } + } + // stage 15. + if (size_log2 >= 5) { + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + HadamardRotation_C(dst, MultiplyBy4(i) + j + 16, + MultiplyBy4(i) - j + 19, static_cast<bool>(i & 1), + range); + } + } + } + // stage 16. + if (size_log2 == 6) { + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 4; ++j) { + ButterflyRotation_C( + dst, 61 - MultiplyBy8(i) - j, MultiplyBy8(i) + j + 34, + 56 - MultiplyBy32(i) + MultiplyBy64(DivideBy2(j)), true, range); + } + } + } + // stage 17. + for (int i = 0; i < 2; ++i) { + HadamardRotation_C(dst, i, 3 - i, false, range); + } + // stage 18. + if (size_log2 >= 3) { + ButterflyRotation_C(dst, 6, 5, 32, true, range); + } + // stage 19. + if (size_log2 >= 4) { + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + HadamardRotation_C(dst, MultiplyBy4(i) + j + 8, MultiplyBy4(i) - j + 11, + /*flip=*/i != 0, range); + } + } + } + // stage 20. + if (size_log2 >= 5) { + for (int i = 0; i < 4; ++i) { + ButterflyRotation_C(dst, 29 - i, i + 18, 48 + 64 * DivideBy2(i), true, + range); + } + } + // stage 21. + if (size_log2 == 6) { + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + HadamardRotation_C(dst, MultiplyBy8(i) + j + 32, + MultiplyBy8(i) - j + 39, static_cast<bool>(i & 1), + range); + } + } + } + // stage 22. + if (size_log2 >= 3) { + for (int i = 0; i < 4; ++i) { + HadamardRotation_C(dst, i, 7 - i, false, range); + } + } + // stage 23. + if (size_log2 >= 4) { + for (int i = 0; i < 2; ++i) { + ButterflyRotation_C(dst, 13 - i, i + 10, 32, true, range); + } + } + // stage 24. + if (size_log2 >= 5) { + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 4; ++j) { + HadamardRotation_C(dst, MultiplyBy8(i) + j + 16, + MultiplyBy8(i) - j + 23, i == 1, range); + } + } + } + // stage 25. + if (size_log2 == 6) { + for (int i = 0; i < 8; ++i) { + ButterflyRotation_C(dst, 59 - i, i + 36, (i < 4) ? 48 : 112, true, range); + } + } + // stage 26. + if (size_log2 >= 4) { + for (int i = 0; i < 8; ++i) { + HadamardRotation_C(dst, i, 15 - i, false, range); + } + } + // stage 27. + if (size_log2 >= 5) { + for (int i = 0; i < 4; ++i) { + ButterflyRotation_C(dst, 27 - i, i + 20, 32, true, range); + } + } + // stage 28. + if (size_log2 == 6) { + for (int i = 0; i < 8; ++i) { + HadamardRotation_C(dst, i + 32, 47 - i, false, range); + HadamardRotation_C(dst, i + 48, 63 - i, true, range); + } + } + // stage 29. + if (size_log2 >= 5) { + for (int i = 0; i < 16; ++i) { + HadamardRotation_C(dst, i, 31 - i, false, range); + } + } + // stage 30. + if (size_log2 == 6) { + for (int i = 0; i < 8; ++i) { + ButterflyRotation_C(dst, 55 - i, i + 40, 32, true, range); + } + } + // stage 31. + if (size_log2 == 6) { + for (int i = 0; i < 32; ++i) { + HadamardRotation_C(dst, i, 63 - i, false, range); + } + } +} + +template <int bitdepth, typename Residual, int size_log2> +void DctDcOnly_C(void* dest, int8_t range, bool should_round, int row_shift, + bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + if (is_row && should_round) { + dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + } + + ButterflyRotationSecondIsZero_C(dst, 0, 1, 32, true, range); + + if (is_row && row_shift > 0) { + dst[0] = RightShiftWithRounding(dst[0], row_shift); + } + + ClampIntermediate<bitdepth, Residual>(dst, 1); + + const int size = 1 << size_log2; + for (int i = 1; i < size; ++i) { + dst[i] = dst[0]; + } +} + +//------------------------------------------------------------------------------ +// Asymmetric Discrete Sine Transforms (ADST). + +/* + * Row transform max range in bits for bitdepths 8/10/12: 28/30/32. + * Column transform max range in bits for bitdepths 8/10/12: 28/28/30. + */ +template <typename Residual> +void Adst4_C(void* dest, int8_t range) { + auto* const dst = static_cast<Residual*>(dest); + if ((dst[0] | dst[1] | dst[2] | dst[3]) == 0) { + return; + } + + // stage 1. + // Section 7.13.2.6: It is a requirement of bitstream conformance that all + // values stored in the s and x arrays by this process are representable by + // a signed integer using range + 12 bits of precision. + int32_t s[7]; + s[0] = RangeCheckValue(kAdst4Multiplier[0] * dst[0], range + 12); + s[1] = RangeCheckValue(kAdst4Multiplier[1] * dst[0], range + 12); + s[2] = RangeCheckValue(kAdst4Multiplier[2] * dst[1], range + 12); + s[3] = RangeCheckValue(kAdst4Multiplier[3] * dst[2], range + 12); + s[4] = RangeCheckValue(kAdst4Multiplier[0] * dst[2], range + 12); + s[5] = RangeCheckValue(kAdst4Multiplier[1] * dst[3], range + 12); + s[6] = RangeCheckValue(kAdst4Multiplier[3] * dst[3], range + 12); + // stage 2. + // Section 7.13.2.6: It is a requirement of bitstream conformance that + // values stored in the variable a7 by this process are representable by a + // signed integer using range + 1 bits of precision. + const int32_t a7 = RangeCheckValue(dst[0] - dst[2], range + 1); + // Section 7.13.2.6: It is a requirement of bitstream conformance that + // values stored in the variable b7 by this process are representable by a + // signed integer using |range| bits of precision. + const int32_t b7 = RangeCheckValue(a7 + dst[3], range); + // stage 3. + s[0] = RangeCheckValue(s[0] + s[3], range + 12); + s[1] = RangeCheckValue(s[1] - s[4], range + 12); + s[3] = s[2]; + s[2] = RangeCheckValue(kAdst4Multiplier[2] * b7, range + 12); + // stage 4. + s[0] = RangeCheckValue(s[0] + s[5], range + 12); + s[1] = RangeCheckValue(s[1] - s[6], range + 12); + // stages 5 and 6. + const int32_t x0 = RangeCheckValue(s[0] + s[3], range + 12); + const int32_t x1 = RangeCheckValue(s[1] + s[3], range + 12); + int32_t x3 = RangeCheckValue(s[0] + s[1], range + 12); + x3 = RangeCheckValue(x3 - s[3], range + 12); + int32_t dst_0 = RightShiftWithRounding(x0, 12); + int32_t dst_1 = RightShiftWithRounding(x1, 12); + int32_t dst_2 = RightShiftWithRounding(s[2], 12); + int32_t dst_3 = RightShiftWithRounding(x3, 12); + if (sizeof(Residual) == 2) { + // If the first argument to RightShiftWithRounding(..., 12) is only + // slightly smaller than 2^27 - 1 (e.g., 0x7fffe4e), adding 2^11 to it + // in RightShiftWithRounding(..., 12) will cause the function to return + // 0x8000, which cannot be represented as an int16_t. Change it to 0x7fff. + dst_0 -= (dst_0 == 0x8000); + dst_1 -= (dst_1 == 0x8000); + dst_3 -= (dst_3 == 0x8000); + } + dst[0] = dst_0; + dst[1] = dst_1; + dst[2] = dst_2; + dst[3] = dst_3; +} + +template <int bitdepth, typename Residual> +void Adst4DcOnly_C(void* dest, int8_t range, bool should_round, int row_shift, + bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + if (is_row && should_round) { + dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + } + + // stage 1. + // Section 7.13.2.6: It is a requirement of bitstream conformance that all + // values stored in the s and x arrays by this process are representable by + // a signed integer using range + 12 bits of precision. + int32_t s[3]; + s[0] = RangeCheckValue(kAdst4Multiplier[0] * dst[0], range + 12); + s[1] = RangeCheckValue(kAdst4Multiplier[1] * dst[0], range + 12); + s[2] = RangeCheckValue(kAdst4Multiplier[2] * dst[0], range + 12); + // stage 3. + // stage 4. + // stages 5 and 6. + int32_t dst_0 = RightShiftWithRounding(s[0], 12); + int32_t dst_1 = RightShiftWithRounding(s[1], 12); + int32_t dst_2 = RightShiftWithRounding(s[2], 12); + int32_t dst_3 = + RightShiftWithRounding(RangeCheckValue(s[0] + s[1], range + 12), 12); + if (sizeof(Residual) == 2) { + // If the first argument to RightShiftWithRounding(..., 12) is only + // slightly smaller than 2^27 - 1 (e.g., 0x7fffe4e), adding 2^11 to it + // in RightShiftWithRounding(..., 12) will cause the function to return + // 0x8000, which cannot be represented as an int16_t. Change it to 0x7fff. + dst_0 -= (dst_0 == 0x8000); + dst_1 -= (dst_1 == 0x8000); + dst_3 -= (dst_3 == 0x8000); + } + dst[0] = dst_0; + dst[1] = dst_1; + dst[2] = dst_2; + dst[3] = dst_3; + + const int size = 4; + if (is_row && row_shift > 0) { + for (int j = 0; j < size; ++j) { + dst[j] = RightShiftWithRounding(dst[j], row_shift); + } + } + + ClampIntermediate<bitdepth, Residual>(dst, 4); +} + +template <typename Residual> +void AdstInputPermutation(int32_t* const dst, const Residual* const src, + int n) { + assert(n == 8 || n == 16); + for (int i = 0; i < n; ++i) { + dst[i] = src[((i & 1) == 0) ? n - i - 1 : i - 1]; + } +} + +constexpr int8_t kAdstOutputPermutationLookup[16] = { + 0, 8, 12, 4, 6, 14, 10, 2, 3, 11, 15, 7, 5, 13, 9, 1}; + +template <typename Residual> +void AdstOutputPermutation(Residual* const dst, const int32_t* const src, + int n) { + assert(n == 8 || n == 16); + const auto shift = static_cast<int8_t>(n == 8); + for (int i = 0; i < n; ++i) { + const int8_t index = kAdstOutputPermutationLookup[i] >> shift; + int32_t dst_i = ((i & 1) == 0) ? src[index] : -src[index]; + if (sizeof(Residual) == 2) { + // If i is odd and src[index] is -32768, dst_i will be 32768, which + // cannot be represented as an int16_t. + dst_i -= (dst_i == 0x8000); + } + dst[i] = dst_i; + } +} + +template <typename Residual> +void Adst8_C(void* dest, int8_t range) { + auto* const dst = static_cast<Residual*>(dest); + // stage 1. + int32_t temp[8]; + AdstInputPermutation(temp, dst, 8); + // stage 2. + for (int i = 0; i < 4; ++i) { + ButterflyRotation_C(temp, MultiplyBy2(i), MultiplyBy2(i) + 1, 60 - 16 * i, + true, range); + } + // stage 3. + for (int i = 0; i < 4; ++i) { + HadamardRotation_C(temp, i, i + 4, false, range); + } + // stage 4. + for (int i = 0; i < 2; ++i) { + ButterflyRotation_C(temp, i * 3 + 4, i + 5, 48 - 32 * i, true, range); + } + // stage 5. + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + HadamardRotation_C(temp, i + MultiplyBy4(j), i + MultiplyBy4(j) + 2, + false, range); + } + } + // stage 6. + for (int i = 0; i < 2; ++i) { + ButterflyRotation_C(temp, MultiplyBy4(i) + 2, MultiplyBy4(i) + 3, 32, true, + range); + } + // stage 7. + AdstOutputPermutation(dst, temp, 8); +} + +template <int bitdepth, typename Residual> +void Adst8DcOnly_C(void* dest, int8_t range, bool should_round, int row_shift, + bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + // stage 1. + int32_t temp[8]; + // After the permutation, the dc value is in temp[1]. The remaining are zero. + AdstInputPermutation(temp, dst, 8); + + if (is_row && should_round) { + temp[1] = RightShiftWithRounding(temp[1] * kTransformRowMultiplier, 12); + } + + // stage 2. + ButterflyRotationFirstIsZero_C(temp, 0, 1, 60, true, range); + + // stage 3. + temp[4] = temp[0]; + temp[5] = temp[1]; + + // stage 4. + ButterflyRotation_C(temp, 4, 5, 48, true, range); + + // stage 5. + temp[2] = temp[0]; + temp[3] = temp[1]; + temp[6] = temp[4]; + temp[7] = temp[5]; + + // stage 6. + ButterflyRotation_C(temp, 2, 3, 32, true, range); + ButterflyRotation_C(temp, 6, 7, 32, true, range); + + // stage 7. + AdstOutputPermutation(dst, temp, 8); + + const int size = 8; + if (is_row && row_shift > 0) { + for (int j = 0; j < size; ++j) { + dst[j] = RightShiftWithRounding(dst[j], row_shift); + } + } + + ClampIntermediate<bitdepth, Residual>(dst, 8); +} + +template <typename Residual> +void Adst16_C(void* dest, int8_t range) { + auto* const dst = static_cast<Residual*>(dest); + // stage 1. + int32_t temp[16]; + AdstInputPermutation(temp, dst, 16); + // stage 2. + for (int i = 0; i < 8; ++i) { + ButterflyRotation_C(temp, MultiplyBy2(i), MultiplyBy2(i) + 1, 62 - 8 * i, + true, range); + } + // stage 3. + for (int i = 0; i < 8; ++i) { + HadamardRotation_C(temp, i, i + 8, false, range); + } + // stage 4. + for (int i = 0; i < 2; ++i) { + ButterflyRotation_C(temp, MultiplyBy2(i) + 8, MultiplyBy2(i) + 9, + 56 - 32 * i, true, range); + ButterflyRotation_C(temp, MultiplyBy2(i) + 13, MultiplyBy2(i) + 12, + 8 + 32 * i, true, range); + } + // stage 5. + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + HadamardRotation_C(temp, i + MultiplyBy8(j), i + MultiplyBy8(j) + 4, + false, range); + } + } + // stage 6. + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + ButterflyRotation_C(temp, i * 3 + MultiplyBy8(j) + 4, + i + MultiplyBy8(j) + 5, 48 - 32 * i, true, range); + } + } + // stage 7. + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 4; ++j) { + HadamardRotation_C(temp, i + MultiplyBy4(j), i + MultiplyBy4(j) + 2, + false, range); + } + } + // stage 8. + for (int i = 0; i < 4; ++i) { + ButterflyRotation_C(temp, MultiplyBy4(i) + 2, MultiplyBy4(i) + 3, 32, true, + range); + } + // stage 9. + AdstOutputPermutation(dst, temp, 16); +} + +template <int bitdepth, typename Residual> +void Adst16DcOnly_C(void* dest, int8_t range, bool should_round, int row_shift, + bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + // stage 1. + int32_t temp[16]; + // After the permutation, the dc value is in temp[1]. The remaining are zero. + AdstInputPermutation(temp, dst, 16); + + if (is_row && should_round) { + temp[1] = RightShiftWithRounding(temp[1] * kTransformRowMultiplier, 12); + } + + // stage 2. + ButterflyRotationFirstIsZero_C(temp, 0, 1, 62, true, range); + + // stage 3. + temp[8] = temp[0]; + temp[9] = temp[1]; + + // stage 4. + ButterflyRotation_C(temp, 8, 9, 56, true, range); + + // stage 5. + temp[4] = temp[0]; + temp[5] = temp[1]; + temp[12] = temp[8]; + temp[13] = temp[9]; + + // stage 6. + ButterflyRotation_C(temp, 4, 5, 48, true, range); + ButterflyRotation_C(temp, 12, 13, 48, true, range); + + // stage 7. + temp[2] = temp[0]; + temp[3] = temp[1]; + temp[10] = temp[8]; + temp[11] = temp[9]; + + temp[6] = temp[4]; + temp[7] = temp[5]; + temp[14] = temp[12]; + temp[15] = temp[13]; + + // stage 8. + for (int i = 0; i < 4; ++i) { + ButterflyRotation_C(temp, MultiplyBy4(i) + 2, MultiplyBy4(i) + 3, 32, true, + range); + } + + // stage 9. + AdstOutputPermutation(dst, temp, 16); + + const int size = 16; + if (is_row && row_shift > 0) { + for (int j = 0; j < size; ++j) { + dst[j] = RightShiftWithRounding(dst[j], row_shift); + } + } + + ClampIntermediate<bitdepth, Residual>(dst, 16); +} + +//------------------------------------------------------------------------------ +// Identity Transforms. +// +// In the spec, the inverse identity transform is followed by a Round2() call: +// The row transforms with i = 0..(h-1) are applied as follows: +// ... +// * Otherwise, invoke the inverse identity transform process specified in +// section 7.13.2.15 with the input variable n equal to log2W. +// * Set Residual[ i ][ j ] equal to Round2( T[ j ], rowShift ) +// for j = 0..(w-1). +// ... +// The column transforms with j = 0..(w-1) are applied as follows: +// ... +// * Otherwise, invoke the inverse identity transform process specified in +// section 7.13.2.15 with the input variable n equal to log2H. +// * Residual[ i ][ j ] is set equal to Round2( T[ i ], colShift ) +// for i = 0..(h-1). +// +// Therefore, we define the identity transform functions to perform both the +// inverse identity transform and the Round2() call. This has two advantages: +// 1. The outputs of the inverse identity transform do not need to be stored +// in the Residual array. They can be stored in int32_t local variables, +// which have a larger range if Residual is an int16_t array. +// 2. The inverse identity transform and the Round2() call can be jointly +// optimized. +// +// The identity transform functions have the following prototype: +// void Identity_C(void* dest, int8_t shift); +// +// The |shift| parameter is the amount of shift for the Round2() call. For row +// transforms, |shift| is 0, 1, or 2. For column transforms, |shift| is always +// 4. Therefore, an identity transform function can detect whether it is being +// invoked as a row transform or a column transform by checking whether |shift| +// is equal to 4. +// +// Input Range +// +// The inputs of row transforms, stored in the 2D array Dequant, are +// representable by a signed integer using 8 + BitDepth bits of precision: +// f. Dequant[ i ][ j ] is set equal to +// Clip3( - ( 1 << ( 7 + BitDepth ) ), ( 1 << ( 7 + BitDepth ) ) - 1, dq2 ). +// +// The inputs of column transforms are representable by a signed integer using +// Max( BitDepth + 6, 16 ) bits of precision: +// Set the variable colClampRange equal to Max( BitDepth + 6, 16 ). +// ... +// Between the row and column transforms, Residual[ i ][ j ] is set equal to +// Clip3( - ( 1 << ( colClampRange - 1 ) ), +// ( 1 << (colClampRange - 1 ) ) - 1, +// Residual[ i ][ j ] ) +// for i = 0..(h-1), for j = 0..(w-1). +// +// Output Range +// +// The outputs of row transforms are representable by a signed integer using +// 8 + BitDepth + 1 = 9 + BitDepth bits of precision, because the net effect +// of the multiplicative factor of inverse identity transforms minus the +// smallest row shift is an increase of at most one bit. +// +// Transform | Multiplicative factor | Smallest row | Net increase +// width | (in bits) | shift | in bits +// --------------------------------------------------------------- +// 4 | sqrt(2) (0.5 bits) | 0 | +0.5 +// 8 | 2 (1 bit) | 0 | +1 +// 16 | 2*sqrt(2) (1.5 bits) | 1 | +0.5 +// 32 | 4 (2 bits) | 1 | +1 +// +// If BitDepth is 8 and Residual is an int16_t array, to avoid truncation we +// clip the outputs (which have 17 bits of precision) to the range of int16_t +// before storing them in the Residual array. This clipping happens to be the +// same as the required clipping after the row transform (see the spec quoted +// above), so we remain compliant with the spec. (In this case, +// TransformLoop_C() skips clipping the outputs of row transforms to avoid +// duplication of effort.) +// +// The outputs of column transforms are representable by a signed integer using +// Max( BitDepth + 6, 16 ) + 2 - 4 = Max( BitDepth + 4, 14 ) bits of precision, +// because the multiplicative factor of inverse identity transforms is at most +// 4 (2 bits) and |shift| is always 4. + +template <typename Residual> +void Identity4Row_C(void* dest, int8_t shift) { + assert(shift == 0 || shift == 1); + auto* const dst = static_cast<Residual*>(dest); + // If |shift| is 0, |rounding| should be 1 << 11. If |shift| is 1, |rounding| + // should be (1 + (1 << 1)) << 11. The following expression works for both + // values of |shift|. + const int32_t rounding = (1 + (shift << 1)) << 11; + for (int i = 0; i < 4; ++i) { + // The intermediate value here will have to fit into an int32_t for it to be + // bitstream conformant. The multiplication is promoted to int32_t by + // defining kIdentity4Multiplier as int32_t. + int32_t dst_i = (dst[i] * kIdentity4Multiplier + rounding) >> (12 + shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[i] = static_cast<Residual>(dst_i); + } +} + +template <typename Residual> +void Identity4Column_C(void* dest, int8_t /*shift*/) { + auto* const dst = static_cast<Residual*>(dest); + const int32_t rounding = (1 + (1 << kTransformColumnShift)) << 11; + for (int i = 0; i < 4; ++i) { + // The intermediate value here will have to fit into an int32_t for it to be + // bitstream conformant. The multiplication is promoted to int32_t by + // defining kIdentity4Multiplier as int32_t. + dst[i] = static_cast<Residual>((dst[i] * kIdentity4Multiplier + rounding) >> + (12 + kTransformColumnShift)); + } +} + +template <int bitdepth, typename Residual> +void Identity4DcOnly_C(void* dest, int8_t /*range*/, bool should_round, + int row_shift, bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + if (is_row) { + if (should_round) { + dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + } + + const int32_t rounding = (1 + (row_shift << 1)) << 11; + int32_t dst_i = + (dst[0] * kIdentity4Multiplier + rounding) >> (12 + row_shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[0] = static_cast<Residual>(dst_i); + + ClampIntermediate<bitdepth, Residual>(dst, 1); + return; + } + + const int32_t rounding = (1 + (1 << kTransformColumnShift)) << 11; + dst[0] = static_cast<Residual>((dst[0] * kIdentity4Multiplier + rounding) >> + (12 + kTransformColumnShift)); +} + +template <typename Residual> +void Identity8Row_C(void* dest, int8_t shift) { + assert(shift == 0 || shift == 1 || shift == 2); + auto* const dst = static_cast<Residual*>(dest); + for (int i = 0; i < 8; ++i) { + int32_t dst_i = RightShiftWithRounding(MultiplyBy2(dst[i]), shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[i] = static_cast<Residual>(dst_i); + } +} + +template <typename Residual> +void Identity8Column_C(void* dest, int8_t /*shift*/) { + auto* const dst = static_cast<Residual*>(dest); + for (int i = 0; i < 8; ++i) { + dst[i] = static_cast<Residual>( + RightShiftWithRounding(dst[i], kTransformColumnShift - 1)); + } +} + +template <int bitdepth, typename Residual> +void Identity8DcOnly_C(void* dest, int8_t /*range*/, bool should_round, + int row_shift, bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + if (is_row) { + if (should_round) { + dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + } + + int32_t dst_i = RightShiftWithRounding(MultiplyBy2(dst[0]), row_shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[0] = static_cast<Residual>(dst_i); + + // If Residual is int16_t (which implies bitdepth is 8), we don't need to + // clip residual[i][j] to 16 bits. + if (sizeof(Residual) > 2) { + const Residual intermediate_clamp_max = + (1 << (std::max(bitdepth + 6, 16) - 1)) - 1; + const Residual intermediate_clamp_min = -intermediate_clamp_max - 1; + dst[0] = Clip3(dst[0], intermediate_clamp_min, intermediate_clamp_max); + } + return; + } + + dst[0] = static_cast<Residual>( + RightShiftWithRounding(dst[0], kTransformColumnShift - 1)); +} + +template <typename Residual> +void Identity16Row_C(void* dest, int8_t shift) { + assert(shift == 1 || shift == 2); + auto* const dst = static_cast<Residual*>(dest); + const int32_t rounding = (1 + (1 << shift)) << 11; + for (int i = 0; i < 16; ++i) { + // The intermediate value here will have to fit into an int32_t for it to be + // bitstream conformant. The multiplication is promoted to int32_t by + // defining kIdentity16Multiplier as int32_t. + int32_t dst_i = (dst[i] * kIdentity16Multiplier + rounding) >> (12 + shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[i] = static_cast<Residual>(dst_i); + } +} + +template <typename Residual> +void Identity16Column_C(void* dest, int8_t /*shift*/) { + auto* const dst = static_cast<Residual*>(dest); + const int32_t rounding = (1 + (1 << kTransformColumnShift)) << 11; + for (int i = 0; i < 16; ++i) { + // The intermediate value here will have to fit into an int32_t for it to be + // bitstream conformant. The multiplication is promoted to int32_t by + // defining kIdentity16Multiplier as int32_t. + dst[i] = + static_cast<Residual>((dst[i] * kIdentity16Multiplier + rounding) >> + (12 + kTransformColumnShift)); + } +} + +template <int bitdepth, typename Residual> +void Identity16DcOnly_C(void* dest, int8_t /*range*/, bool should_round, + int row_shift, bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + if (is_row) { + if (should_round) { + dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + } + + const int32_t rounding = (1 + (1 << row_shift)) << 11; + int32_t dst_i = + (dst[0] * kIdentity16Multiplier + rounding) >> (12 + row_shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[0] = static_cast<Residual>(dst_i); + + ClampIntermediate<bitdepth, Residual>(dst, 1); + return; + } + + const int32_t rounding = (1 + (1 << kTransformColumnShift)) << 11; + dst[0] = static_cast<Residual>((dst[0] * kIdentity16Multiplier + rounding) >> + (12 + kTransformColumnShift)); +} + +template <typename Residual> +void Identity32Row_C(void* dest, int8_t shift) { + assert(shift == 1 || shift == 2); + auto* const dst = static_cast<Residual*>(dest); + for (int i = 0; i < 32; ++i) { + int32_t dst_i = RightShiftWithRounding(MultiplyBy4(dst[i]), shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[i] = static_cast<Residual>(dst_i); + } +} + +template <typename Residual> +void Identity32Column_C(void* dest, int8_t /*shift*/) { + auto* const dst = static_cast<Residual*>(dest); + for (int i = 0; i < 32; ++i) { + dst[i] = static_cast<Residual>( + RightShiftWithRounding(dst[i], kTransformColumnShift - 2)); + } +} + +template <int bitdepth, typename Residual> +void Identity32DcOnly_C(void* dest, int8_t /*range*/, bool should_round, + int row_shift, bool is_row) { + auto* const dst = static_cast<Residual*>(dest); + + if (is_row) { + if (should_round) { + dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + } + + int32_t dst_i = RightShiftWithRounding(MultiplyBy4(dst[0]), row_shift); + if (sizeof(Residual) == 2) { + dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); + } + dst[0] = static_cast<Residual>(dst_i); + + ClampIntermediate<bitdepth, Residual>(dst, 1); + return; + } + + dst[0] = static_cast<Residual>( + RightShiftWithRounding(dst[0], kTransformColumnShift - 2)); +} + +//------------------------------------------------------------------------------ +// Walsh Hadamard Transform. + +template <typename Residual> +void Wht4_C(void* dest, int8_t shift) { + auto* const dst = static_cast<Residual*>(dest); + Residual temp[4]; + temp[0] = dst[0] >> shift; + temp[2] = dst[1] >> shift; + temp[3] = dst[2] >> shift; + temp[1] = dst[3] >> shift; + temp[0] += temp[2]; + temp[3] -= temp[1]; + // This signed right shift must be an arithmetic shift. + Residual e = (temp[0] - temp[3]) >> 1; + dst[1] = e - temp[1]; + dst[2] = e - temp[2]; + dst[0] = temp[0] - dst[1]; + dst[3] = temp[3] + dst[2]; +} + +template <int bitdepth, typename Residual> +void Wht4DcOnly_C(void* dest, int8_t range, bool /*should_round*/, + int /*row_shift*/, bool /*is_row*/) { + auto* const dst = static_cast<Residual*>(dest); + const int shift = range; + + Residual temp = dst[0] >> shift; + // This signed right shift must be an arithmetic shift. + Residual e = temp >> 1; + dst[0] = temp - e; + dst[1] = e; + dst[2] = e; + dst[3] = e; + + ClampIntermediate<bitdepth, Residual>(dst, 4); +} + +//------------------------------------------------------------------------------ +// row/column transform loop + +using InverseTransform1DFunc = void (*)(void* dst, int8_t range); +using InverseTransformDcOnlyFunc = void (*)(void* dest, int8_t range, + bool should_round, int row_shift, + bool is_row); + +template <int bitdepth, typename Residual, typename Pixel, + Transform1D transform1d_type, + InverseTransformDcOnlyFunc dconly_transform1d, + InverseTransform1DFunc transform1d_func, bool is_row> +void TransformLoop_C(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, int start_x, + int start_y, void* dst_frame) { + constexpr bool lossless = transform1d_type == k1DTransformWht; + constexpr bool is_identity = transform1d_type == k1DTransformIdentity; + // The transform size of the WHT is always 4x4. Setting tx_width and + // tx_height to the constant 4 for the WHT speeds the code up. + assert(!lossless || tx_size == kTransformSize4x4); + const int tx_width = lossless ? 4 : kTransformWidth[tx_size]; + const int tx_height = lossless ? 4 : kTransformHeight[tx_size]; + const int tx_width_log2 = kTransformWidthLog2[tx_size]; + const int tx_height_log2 = kTransformHeightLog2[tx_size]; + auto* frame = static_cast<Array2DView<Pixel>*>(dst_frame); + + // Initially this points to the dequantized values. After the transforms are + // applied, this buffer contains the residual. + Array2DView<Residual> residual(tx_height, tx_width, + static_cast<Residual*>(src_buffer)); + + if (is_row) { + // Row transform. + const uint8_t row_shift = lossless ? 0 : kTransformRowShift[tx_size]; + // This is the |range| parameter of the InverseTransform1DFunc. For lossy + // transforms, this will be equal to the clamping range. + const int8_t row_clamp_range = lossless ? 2 : (bitdepth + 8); + // If the width:height ratio of the transform size is 2:1 or 1:2, multiply + // the input to the row transform by 1 / sqrt(2), which is approximated by + // the fraction 2896 / 2^12. + const bool should_round = std::abs(tx_width_log2 - tx_height_log2) == 1; + + if (adjusted_tx_height == 1) { + dconly_transform1d(residual[0], row_clamp_range, should_round, row_shift, + true); + return; + } + + // Row transforms need to be done only up to 32 because the rest of the rows + // are always all zero if |tx_height| is 64. Otherwise, only process the + // rows that have a non zero coefficients. + for (int i = 0; i < adjusted_tx_height; ++i) { + // If lossless, the transform size is 4x4, so should_round is false. + if (!lossless && should_round) { + // The last 32 values of every row are always zero if the |tx_width| is + // 64. + for (int j = 0; j < std::min(tx_width, 32); ++j) { + residual[i][j] = RightShiftWithRounding( + residual[i][j] * kTransformRowMultiplier, 12); + } + } + // For identity transform, |transform1d_func| also performs the + // Round2(T[j], rowShift) call in the spec. + transform1d_func(residual[i], is_identity ? row_shift : row_clamp_range); + if (!lossless && !is_identity && row_shift > 0) { + for (int j = 0; j < tx_width; ++j) { + residual[i][j] = RightShiftWithRounding(residual[i][j], row_shift); + } + } + + ClampIntermediate<bitdepth, Residual>(residual[i], tx_width); + } + return; + } + + assert(!is_row); + constexpr uint8_t column_shift = lossless ? 0 : kTransformColumnShift; + // This is the |range| parameter of the InverseTransform1DFunc. For lossy + // transforms, this will be equal to the clamping range. + const int8_t column_clamp_range = lossless ? 0 : std::max(bitdepth + 6, 16); + const bool flip_rows = transform1d_type == k1DTransformAdst && + kTransformFlipRowsMask.Contains(tx_type); + const bool flip_columns = + !lossless && kTransformFlipColumnsMask.Contains(tx_type); + const int min_value = 0; + const int max_value = (1 << bitdepth) - 1; + // Note: 64 is the maximum size of a 1D transform buffer (the largest + // transform size is kTransformSize64x64). + Residual tx_buffer[64]; + for (int j = 0; j < tx_width; ++j) { + const int flipped_j = flip_columns ? tx_width - j - 1 : j; + for (int i = 0; i < tx_height; ++i) { + tx_buffer[i] = residual[i][flipped_j]; + } + if (adjusted_tx_height == 1) { + dconly_transform1d(tx_buffer, column_clamp_range, false, 0, false); + } else { + // For identity transform, |transform1d_func| also performs the + // Round2(T[i], colShift) call in the spec. + transform1d_func(tx_buffer, + is_identity ? column_shift : column_clamp_range); + } + const int x = start_x + j; + for (int i = 0; i < tx_height; ++i) { + const int y = start_y + i; + const int index = flip_rows ? tx_height - i - 1 : i; + Residual residual_value = tx_buffer[index]; + if (!lossless && !is_identity) { + residual_value = RightShiftWithRounding(residual_value, column_shift); + } + (*frame)[y][x] = + Clip3((*frame)[y][x] + residual_value, min_value, max_value); + } + } +} + +//------------------------------------------------------------------------------ + +template <int bitdepth, typename Residual, typename Pixel> +void InitAll(Dsp* const dsp) { + // Maximum transform size for Dct is 64. + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 2>, Dct_C<Residual, 2>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 2>, Dct_C<Residual, 2>, + /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 3>, Dct_C<Residual, 3>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 3>, Dct_C<Residual, 3>, + /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 4>, Dct_C<Residual, 4>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 4>, Dct_C<Residual, 4>, + /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 5>, Dct_C<Residual, 5>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 5>, Dct_C<Residual, 5>, + /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 6>, Dct_C<Residual, 6>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct, + DctDcOnly_C<bitdepth, Residual, 6>, Dct_C<Residual, 6>, + /*is_row=*/false>; + + // Maximum transform size for Adst is 16. + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst, + Adst4DcOnly_C<bitdepth, Residual>, Adst4_C<Residual>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst, + Adst4DcOnly_C<bitdepth, Residual>, Adst4_C<Residual>, + /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst, + Adst8DcOnly_C<bitdepth, Residual>, Adst8_C<Residual>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst, + Adst8DcOnly_C<bitdepth, Residual>, Adst8_C<Residual>, + /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst, + Adst16DcOnly_C<bitdepth, Residual>, Adst16_C<Residual>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst, + Adst16DcOnly_C<bitdepth, Residual>, Adst16_C<Residual>, + /*is_row=*/false>; + + // Maximum transform size for Identity transform is 32. + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity4DcOnly_C<bitdepth, Residual>, + Identity4Row_C<Residual>, /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity4DcOnly_C<bitdepth, Residual>, + Identity4Column_C<Residual>, /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity8DcOnly_C<bitdepth, Residual>, + Identity8Row_C<Residual>, /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity8DcOnly_C<bitdepth, Residual>, + Identity8Column_C<Residual>, /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity16DcOnly_C<bitdepth, Residual>, + Identity16Row_C<Residual>, /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity16DcOnly_C<bitdepth, Residual>, + Identity16Column_C<Residual>, /*is_row=*/false>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity32DcOnly_C<bitdepth, Residual>, + Identity32Row_C<Residual>, /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity, + Identity32DcOnly_C<bitdepth, Residual>, + Identity32Column_C<Residual>, /*is_row=*/false>; + + // Maximum transform size for Wht is 4. + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformWht, + Wht4DcOnly_C<bitdepth, Residual>, Wht4_C<Residual>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = + TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformWht, + Wht4DcOnly_C<bitdepth, Residual>, Wht4_C<Residual>, + /*is_row=*/false>; +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); + for (auto& inverse_transform_by_size : dsp->inverse_transforms) { + for (auto& inverse_transform : inverse_transform_by_size) { + inverse_transform[kRow] = nullptr; + inverse_transform[kColumn] = nullptr; + } + } +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + InitAll<8, int16_t, uint8_t>(dsp); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 2>, Dct_C<int16_t, 2>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 2>, Dct_C<int16_t, 2>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 3>, Dct_C<int16_t, 3>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 3>, Dct_C<int16_t, 3>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 4>, Dct_C<int16_t, 4>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 4>, Dct_C<int16_t, 4>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 5>, Dct_C<int16_t, 5>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 5>, Dct_C<int16_t, 5>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 6>, Dct_C<int16_t, 6>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, + DctDcOnly_C<8, int16_t, 6>, Dct_C<int16_t, 6>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, + Adst4DcOnly_C<8, int16_t>, Adst4_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, + Adst4DcOnly_C<8, int16_t>, Adst4_C<int16_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, + Adst8DcOnly_C<8, int16_t>, Adst8_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, + Adst8DcOnly_C<8, int16_t>, Adst8_C<int16_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, + Adst16DcOnly_C<8, int16_t>, Adst16_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, + Adst16DcOnly_C<8, int16_t>, Adst16_C<int16_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity4DcOnly_C<8, int16_t>, Identity4Row_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity4DcOnly_C<8, int16_t>, Identity4Column_C<int16_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity8DcOnly_C<8, int16_t>, Identity8Row_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity8DcOnly_C<8, int16_t>, Identity8Column_C<int16_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity16DcOnly_C<8, int16_t>, Identity16Row_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity16DcOnly_C<8, int16_t>, + Identity16Column_C<int16_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity32DcOnly_C<8, int16_t>, Identity32Row_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity, + Identity32DcOnly_C<8, int16_t>, + Identity32Column_C<int16_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformWht, + Wht4DcOnly_C<8, int16_t>, Wht4_C<int16_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = + TransformLoop_C<8, int16_t, uint8_t, k1DTransformWht, + Wht4DcOnly_C<8, int16_t>, Wht4_C<int16_t>, + /*is_row=*/false>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); + for (auto& inverse_transform_by_size : dsp->inverse_transforms) { + for (auto& inverse_transform : inverse_transform_by_size) { + inverse_transform[kRow] = nullptr; + inverse_transform[kColumn] = nullptr; + } + } +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + InitAll<10, int32_t, uint16_t>(dsp); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 2>, Dct_C<int32_t, 2>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 2>, Dct_C<int32_t, 2>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 3>, Dct_C<int32_t, 3>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 3>, Dct_C<int32_t, 3>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 4>, Dct_C<int32_t, 4>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 4>, Dct_C<int32_t, 4>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 5>, Dct_C<int32_t, 5>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 5>, Dct_C<int32_t, 5>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize64_1DTransformDct + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 6>, Dct_C<int32_t, 6>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct, + DctDcOnly_C<10, int32_t, 6>, Dct_C<int32_t, 6>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformAdst + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst, + Adst4DcOnly_C<10, int32_t>, Adst4_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst, + Adst4DcOnly_C<10, int32_t>, Adst4_C<int32_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformAdst + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst, + Adst8DcOnly_C<10, int32_t>, Adst8_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst, + Adst8DcOnly_C<10, int32_t>, Adst8_C<int32_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformAdst + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst, + Adst16DcOnly_C<10, int32_t>, Adst16_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst, + Adst16DcOnly_C<10, int32_t>, Adst16_C<int32_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity4DcOnly_C<10, int32_t>, Identity4Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity4DcOnly_C<10, int32_t>, + Identity4Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity8DcOnly_C<10, int32_t>, Identity8Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity8DcOnly_C<10, int32_t>, + Identity8Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity16DcOnly_C<10, int32_t>, Identity16Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity16DcOnly_C<10, int32_t>, + Identity16Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformIdentity + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity32DcOnly_C<10, int32_t>, Identity32Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity, + Identity32DcOnly_C<10, int32_t>, + Identity32Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformWht + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformWht, + Wht4DcOnly_C<10, int32_t>, Wht4_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = + TransformLoop_C<10, int32_t, uint16_t, k1DTransformWht, + Wht4DcOnly_C<10, int32_t>, Wht4_C<int32_t>, + /*is_row=*/false>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +} // namespace + +void InverseTransformInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif + + // Local functions that may be unused depending on the optimizations + // available. + static_cast<void>(RangeCheckValue); + static_cast<void>(kBitReverseLookup); +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/inverse_transform.h b/src/dsp/inverse_transform.h new file mode 100644 index 0000000..0916665 --- /dev/null +++ b/src/dsp/inverse_transform.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_INVERSE_TRANSFORM_H_ +#define LIBGAV1_SRC_DSP_INVERSE_TRANSFORM_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/inverse_transform_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/inverse_transform_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::inverse_transforms. This function is not thread-safe. +void InverseTransformInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_INVERSE_TRANSFORM_H_ diff --git a/src/dsp/inverse_transform.inc b/src/dsp/inverse_transform.inc new file mode 100644 index 0000000..55e68b6 --- /dev/null +++ b/src/dsp/inverse_transform.inc @@ -0,0 +1,64 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Constants and utility functions used for inverse transform implementations. +// This will be included inside an anonymous namespace on files where these are +// necessary. + +// The value at index i is derived as: round(cos(pi * i / 128) * (1 << 12)). +constexpr int16_t kCos128[65] = { + 4096, 4095, 4091, 4085, 4076, 4065, 4052, 4036, 4017, 3996, 3973, + 3948, 3920, 3889, 3857, 3822, 3784, 3745, 3703, 3659, 3612, 3564, + 3513, 3461, 3406, 3349, 3290, 3229, 3166, 3102, 3035, 2967, 2896, + 2824, 2751, 2675, 2598, 2520, 2440, 2359, 2276, 2191, 2106, 2019, + 1931, 1842, 1751, 1660, 1567, 1474, 1380, 1285, 1189, 1092, 995, + 897, 799, 700, 601, 501, 401, 301, 201, 101, 0}; + +inline int16_t Cos128(int angle) { + angle &= 0xff; + + // If |angle| is 128, this function returns -4096 (= -2^12), which will + // cause the 32-bit multiplications in ButterflyRotation() to overflow if + // dst[a] or dst[b] is -2^19 (a possible corner case when |range| is 20): + // + // (-2^12) * (-2^19) = 2^31, which cannot be represented as an int32_t. + // + // Note: |range| is 20 when bitdepth is 12 and a row transform is performed. + // + // Assert that this angle is never used by DCT or ADST. + assert(angle != 128); + if (angle <= 64) return kCos128[angle]; + if (angle <= 128) return -kCos128[128 - angle]; + if (angle <= 192) return -kCos128[angle - 128]; + return kCos128[256 - angle]; +} + +inline int16_t Sin128(int angle) { return Cos128(angle - 64); } + +// The value for index i is derived as: +// round(sqrt(2) * sin(i * pi / 9) * 2 / 3 * (1 << 12)). +constexpr int16_t kAdst4Multiplier[4] = {1321, 2482, 3344, 3803}; + +constexpr uint8_t kTransformRowShift[kNumTransformSizes] = { + 0, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2}; + +constexpr bool kShouldRound[kNumTransformSizes] = { + false, true, false, true, false, true, false, false, true, false, + true, false, false, true, false, true, false, true, false}; + +constexpr int16_t kIdentity4Multiplier /* round(2^12 * sqrt(2)) */ = 0x16A1; +constexpr int16_t kIdentity4MultiplierFraction /* round(2^12 * (sqrt(2) - 1))*/ + = 0x6A1; +constexpr int16_t kIdentity16Multiplier /* 2 * round(2^12 * sqrt(2)) */ = 11586; +constexpr int16_t kTransformRowMultiplier /* round(2^12 / sqrt(2)) */ = 2896; diff --git a/src/dsp/libgav1_dsp.cmake b/src/dsp/libgav1_dsp.cmake new file mode 100644 index 0000000..960d5a7 --- /dev/null +++ b/src/dsp/libgav1_dsp.cmake @@ -0,0 +1,176 @@ +# Copyright 2019 The libgav1 Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(LIBGAV1_SRC_DSP_LIBGAV1_DSP_CMAKE_) + return() +endif() # LIBGAV1_SRC_DSP_LIBGAV1_DSP_CMAKE_ +set(LIBGAV1_SRC_DSP_LIBGAV1_DSP_CMAKE_ 1) + +include("${libgav1_root}/cmake/libgav1_targets.cmake") + +list(APPEND libgav1_dsp_sources + "${libgav1_source}/dsp/average_blend.cc" + "${libgav1_source}/dsp/average_blend.h" + "${libgav1_source}/dsp/cdef.cc" + "${libgav1_source}/dsp/cdef.h" + "${libgav1_source}/dsp/cdef.inc" + "${libgav1_source}/dsp/common.h" + "${libgav1_source}/dsp/constants.cc" + "${libgav1_source}/dsp/constants.h" + "${libgav1_source}/dsp/convolve.cc" + "${libgav1_source}/dsp/convolve.h" + "${libgav1_source}/dsp/convolve.inc" + "${libgav1_source}/dsp/distance_weighted_blend.cc" + "${libgav1_source}/dsp/distance_weighted_blend.h" + "${libgav1_source}/dsp/dsp.cc" + "${libgav1_source}/dsp/dsp.h" + "${libgav1_source}/dsp/film_grain.cc" + "${libgav1_source}/dsp/film_grain.h" + "${libgav1_source}/dsp/film_grain_common.h" + "${libgav1_source}/dsp/intra_edge.cc" + "${libgav1_source}/dsp/intra_edge.h" + "${libgav1_source}/dsp/intrapred.cc" + "${libgav1_source}/dsp/intrapred.h" + "${libgav1_source}/dsp/inverse_transform.cc" + "${libgav1_source}/dsp/inverse_transform.h" + "${libgav1_source}/dsp/inverse_transform.inc" + "${libgav1_source}/dsp/loop_filter.cc" + "${libgav1_source}/dsp/loop_filter.h" + "${libgav1_source}/dsp/loop_restoration.cc" + "${libgav1_source}/dsp/loop_restoration.h" + "${libgav1_source}/dsp/mask_blend.cc" + "${libgav1_source}/dsp/mask_blend.h" + "${libgav1_source}/dsp/motion_field_projection.cc" + "${libgav1_source}/dsp/motion_field_projection.h" + "${libgav1_source}/dsp/motion_vector_search.cc" + "${libgav1_source}/dsp/motion_vector_search.h" + "${libgav1_source}/dsp/obmc.cc" + "${libgav1_source}/dsp/obmc.h" + "${libgav1_source}/dsp/obmc.inc" + "${libgav1_source}/dsp/super_res.cc" + "${libgav1_source}/dsp/super_res.h" + "${libgav1_source}/dsp/warp.cc" + "${libgav1_source}/dsp/warp.h" + "${libgav1_source}/dsp/weight_mask.cc" + "${libgav1_source}/dsp/weight_mask.h") + +list(APPEND libgav1_dsp_sources_avx2 + ${libgav1_dsp_sources_avx2} + "${libgav1_source}/dsp/x86/convolve_avx2.cc" + "${libgav1_source}/dsp/x86/convolve_avx2.h" + "${libgav1_source}/dsp/x86/loop_restoration_10bit_avx2.cc" + "${libgav1_source}/dsp/x86/loop_restoration_avx2.cc" + "${libgav1_source}/dsp/x86/loop_restoration_avx2.h") + +list(APPEND libgav1_dsp_sources_neon + ${libgav1_dsp_sources_neon} + "${libgav1_source}/dsp/arm/average_blend_neon.cc" + "${libgav1_source}/dsp/arm/average_blend_neon.h" + "${libgav1_source}/dsp/arm/cdef_neon.cc" + "${libgav1_source}/dsp/arm/cdef_neon.h" + "${libgav1_source}/dsp/arm/common_neon.h" + "${libgav1_source}/dsp/arm/convolve_neon.cc" + "${libgav1_source}/dsp/arm/convolve_neon.h" + "${libgav1_source}/dsp/arm/distance_weighted_blend_neon.cc" + "${libgav1_source}/dsp/arm/distance_weighted_blend_neon.h" + "${libgav1_source}/dsp/arm/film_grain_neon.cc" + "${libgav1_source}/dsp/arm/film_grain_neon.h" + "${libgav1_source}/dsp/arm/intra_edge_neon.cc" + "${libgav1_source}/dsp/arm/intra_edge_neon.h" + "${libgav1_source}/dsp/arm/intrapred_cfl_neon.cc" + "${libgav1_source}/dsp/arm/intrapred_directional_neon.cc" + "${libgav1_source}/dsp/arm/intrapred_filter_intra_neon.cc" + "${libgav1_source}/dsp/arm/intrapred_neon.cc" + "${libgav1_source}/dsp/arm/intrapred_neon.h" + "${libgav1_source}/dsp/arm/intrapred_smooth_neon.cc" + "${libgav1_source}/dsp/arm/inverse_transform_neon.cc" + "${libgav1_source}/dsp/arm/inverse_transform_neon.h" + "${libgav1_source}/dsp/arm/loop_filter_neon.cc" + "${libgav1_source}/dsp/arm/loop_filter_neon.h" + "${libgav1_source}/dsp/arm/loop_restoration_neon.cc" + "${libgav1_source}/dsp/arm/loop_restoration_neon.h" + "${libgav1_source}/dsp/arm/mask_blend_neon.cc" + "${libgav1_source}/dsp/arm/mask_blend_neon.h" + "${libgav1_source}/dsp/arm/motion_field_projection_neon.cc" + "${libgav1_source}/dsp/arm/motion_field_projection_neon.h" + "${libgav1_source}/dsp/arm/motion_vector_search_neon.cc" + "${libgav1_source}/dsp/arm/motion_vector_search_neon.h" + "${libgav1_source}/dsp/arm/obmc_neon.cc" + "${libgav1_source}/dsp/arm/obmc_neon.h" + "${libgav1_source}/dsp/arm/super_res_neon.cc" + "${libgav1_source}/dsp/arm/super_res_neon.h" + "${libgav1_source}/dsp/arm/warp_neon.cc" + "${libgav1_source}/dsp/arm/warp_neon.h" + "${libgav1_source}/dsp/arm/weight_mask_neon.cc" + "${libgav1_source}/dsp/arm/weight_mask_neon.h") + +list(APPEND libgav1_dsp_sources_sse4 + ${libgav1_dsp_sources_sse4} + "${libgav1_source}/dsp/x86/average_blend_sse4.cc" + "${libgav1_source}/dsp/x86/average_blend_sse4.h" + "${libgav1_source}/dsp/x86/common_sse4.h" + "${libgav1_source}/dsp/x86/cdef_sse4.cc" + "${libgav1_source}/dsp/x86/cdef_sse4.h" + "${libgav1_source}/dsp/x86/convolve_sse4.cc" + "${libgav1_source}/dsp/x86/convolve_sse4.h" + "${libgav1_source}/dsp/x86/distance_weighted_blend_sse4.cc" + "${libgav1_source}/dsp/x86/distance_weighted_blend_sse4.h" + "${libgav1_source}/dsp/x86/intra_edge_sse4.cc" + "${libgav1_source}/dsp/x86/intra_edge_sse4.h" + "${libgav1_source}/dsp/x86/intrapred_sse4.cc" + "${libgav1_source}/dsp/x86/intrapred_sse4.h" + "${libgav1_source}/dsp/x86/intrapred_cfl_sse4.cc" + "${libgav1_source}/dsp/x86/intrapred_smooth_sse4.cc" + "${libgav1_source}/dsp/x86/inverse_transform_sse4.cc" + "${libgav1_source}/dsp/x86/inverse_transform_sse4.h" + "${libgav1_source}/dsp/x86/loop_filter_sse4.cc" + "${libgav1_source}/dsp/x86/loop_filter_sse4.h" + "${libgav1_source}/dsp/x86/loop_restoration_10bit_sse4.cc" + "${libgav1_source}/dsp/x86/loop_restoration_sse4.cc" + "${libgav1_source}/dsp/x86/loop_restoration_sse4.h" + "${libgav1_source}/dsp/x86/mask_blend_sse4.cc" + "${libgav1_source}/dsp/x86/mask_blend_sse4.h" + "${libgav1_source}/dsp/x86/motion_field_projection_sse4.cc" + "${libgav1_source}/dsp/x86/motion_field_projection_sse4.h" + "${libgav1_source}/dsp/x86/motion_vector_search_sse4.cc" + "${libgav1_source}/dsp/x86/motion_vector_search_sse4.h" + "${libgav1_source}/dsp/x86/obmc_sse4.cc" + "${libgav1_source}/dsp/x86/obmc_sse4.h" + "${libgav1_source}/dsp/x86/super_res_sse4.cc" + "${libgav1_source}/dsp/x86/super_res_sse4.h" + "${libgav1_source}/dsp/x86/transpose_sse4.h" + "${libgav1_source}/dsp/x86/warp_sse4.cc" + "${libgav1_source}/dsp/x86/warp_sse4.h" + "${libgav1_source}/dsp/x86/weight_mask_sse4.cc" + "${libgav1_source}/dsp/x86/weight_mask_sse4.h") + +macro(libgav1_add_dsp_targets) + unset(dsp_sources) + list(APPEND dsp_sources ${libgav1_dsp_sources} + ${libgav1_dsp_sources_neon} + ${libgav1_dsp_sources_avx2} + ${libgav1_dsp_sources_sse4}) + + libgav1_add_library(NAME + libgav1_dsp + TYPE + OBJECT + SOURCES + ${dsp_sources} + DEFINES + ${libgav1_defines} + $<$<CONFIG:Debug>:LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS> + INCLUDES + ${libgav1_include_paths}) +endmacro() diff --git a/src/dsp/loop_filter.cc b/src/dsp/loop_filter.cc new file mode 100644 index 0000000..6cad97d --- /dev/null +++ b/src/dsp/loop_filter.cc @@ -0,0 +1,616 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_filter.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +// 7.14.6.1. +template <int bitdepth, typename Pixel> +struct LoopFilterFuncs_C { + LoopFilterFuncs_C() = delete; + + static constexpr int kMaxPixel = (1 << bitdepth) - 1; + static constexpr int kMinSignedPixel = -(1 << (bitdepth - 1)); + static constexpr int kMaxSignedPixel = (1 << (bitdepth - 1)) - 1; + static constexpr int kFlatThresh = 1 << (bitdepth - 8); + + static void Vertical4(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal4(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Vertical6(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal6(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Vertical8(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal8(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Vertical14(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal14(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); +}; + +inline void AdjustThresholds(const int bitdepth, int* const outer_thresh, + int* const inner_thresh, int* const hev_thresh) { + *outer_thresh <<= bitdepth - 8; + *inner_thresh <<= bitdepth - 8; + *hev_thresh <<= bitdepth - 8; +} + +//------------------------------------------------------------------------------ +// 4-tap filters + +// 7.14.6.2. +template <typename Pixel> +inline bool NeedsFilter4(const Pixel* p, ptrdiff_t step, int outer_thresh, + int inner_thresh) { + const int p1 = p[-2 * step], p0 = p[-step]; + const int q0 = p[0], q1 = p[step]; + return std::abs(p1 - p0) <= inner_thresh && + std::abs(q1 - q0) <= inner_thresh && + std::abs(p0 - q0) * 2 + std::abs(p1 - q1) / 2 <= outer_thresh; +} + +// 7.14.6.2. +template <typename Pixel> +inline bool Hev(const Pixel* p, ptrdiff_t step, int thresh) { + const int p1 = p[-2 * step], p0 = p[-step], q0 = p[0], q1 = p[step]; + return (std::abs(p1 - p0) > thresh) || (std::abs(q1 - q0) > thresh); +} + +// 7.14.6.3. +// 4 pixels in, 2 pixels out. +template <int bitdepth, typename Pixel> +inline void Filter2_C(Pixel* p, ptrdiff_t step) { + const int p1 = p[-2 * step], p0 = p[-step], q0 = p[0], q1 = p[step]; + const int min_signed_val = + LoopFilterFuncs_C<bitdepth, Pixel>::kMinSignedPixel; + const int max_signed_val = + LoopFilterFuncs_C<bitdepth, Pixel>::kMaxSignedPixel; + // 8bpp: [-893,892], 10bpp: [-3581,3580], 12bpp [-14333,14332] + const int a = 3 * (q0 - p0) + Clip3(p1 - q1, min_signed_val, max_signed_val); + // 8bpp: [-16,15], 10bpp: [-64,63], 12bpp: [-256,255] + const int a1 = Clip3(a + 4, min_signed_val, max_signed_val) >> 3; + const int a2 = Clip3(a + 3, min_signed_val, max_signed_val) >> 3; + const int max_unsigned_val = LoopFilterFuncs_C<bitdepth, Pixel>::kMaxPixel; + p[-step] = Clip3(p0 + a2, 0, max_unsigned_val); + p[0] = Clip3(q0 - a1, 0, max_unsigned_val); +} + +// 7.14.6.3. +// 4 pixels in, 4 pixels out. +template <int bitdepth, typename Pixel> +inline void Filter4_C(Pixel* p, ptrdiff_t step) { + const int p1 = p[-2 * step], p0 = p[-step], q0 = p[0], q1 = p[step]; + const int a = 3 * (q0 - p0); + const int min_signed_val = + LoopFilterFuncs_C<bitdepth, Pixel>::kMinSignedPixel; + const int max_signed_val = + LoopFilterFuncs_C<bitdepth, Pixel>::kMaxSignedPixel; + const int a1 = Clip3(a + 4, min_signed_val, max_signed_val) >> 3; + const int a2 = Clip3(a + 3, min_signed_val, max_signed_val) >> 3; + const int a3 = (a1 + 1) >> 1; + const int max_unsigned_val = LoopFilterFuncs_C<bitdepth, Pixel>::kMaxPixel; + p[-2 * step] = Clip3(p1 + a3, 0, max_unsigned_val); + p[-1 * step] = Clip3(p0 + a2, 0, max_unsigned_val); + p[0 * step] = Clip3(q0 - a1, 0, max_unsigned_val); + p[1 * step] = Clip3(q1 - a3, 0, max_unsigned_val); +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Vertical4(void* dest, ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter4(dst, 1, outer_thresh, inner_thresh)) { + if (Hev(dst, 1, hev_thresh)) { + Filter2_C<bitdepth>(dst, 1); + } else { + Filter4_C<bitdepth>(dst, 1); + } + } + dst += stride; + } +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Horizontal4(void* dest, + ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter4(dst, stride, outer_thresh, inner_thresh)) { + if (Hev(dst, stride, hev_thresh)) { + Filter2_C<bitdepth>(dst, stride); + } else { + Filter4_C<bitdepth>(dst, stride); + } + } + ++dst; + } +} + +//------------------------------------------------------------------------------ +// 5-tap (chroma) filters + +// 7.14.6.2. +template <typename Pixel> +inline bool NeedsFilter6(const Pixel* p, ptrdiff_t step, int outer_thresh, + int inner_thresh) { + const int p2 = p[-3 * step], p1 = p[-2 * step], p0 = p[-step]; + const int q0 = p[0], q1 = p[step], q2 = p[2 * step]; + return std::abs(p2 - p1) <= inner_thresh && + std::abs(p1 - p0) <= inner_thresh && + std::abs(q1 - q0) <= inner_thresh && + std::abs(q2 - q1) <= inner_thresh && + std::abs(p0 - q0) * 2 + std::abs(p1 - q1) / 2 <= outer_thresh; +} + +// 7.14.6.2. +template <typename Pixel> +inline bool IsFlat3(const Pixel* p, ptrdiff_t step, int flat_thresh) { + const int p2 = p[-3 * step], p1 = p[-2 * step], p0 = p[-step]; + const int q0 = p[0], q1 = p[step], q2 = p[2 * step]; + return std::abs(p1 - p0) <= flat_thresh && std::abs(q1 - q0) <= flat_thresh && + std::abs(p2 - p0) <= flat_thresh && std::abs(q2 - q0) <= flat_thresh; +} + +template <typename Pixel> +inline Pixel ApplyFilter6(int filter_value) { + return static_cast<Pixel>(RightShiftWithRounding(filter_value, 3)); +} + +// 7.14.6.4. +// 6 pixels in, 4 pixels out. +template <typename Pixel> +inline void Filter6_C(Pixel* p, ptrdiff_t step) { + const int p2 = p[-3 * step], p1 = p[-2 * step], p0 = p[-step]; + const int q0 = p[0], q1 = p[step], q2 = p[2 * step]; + const int a1 = 2 * p1; + const int a0 = 2 * p0; + const int b0 = 2 * q0; + const int b1 = 2 * q1; + // The max is 8 * max_pixel + 4 for the rounder. + // 8bpp: 2044 (11 bits), 10bpp: 8188 (13 bits), 12bpp: 32764 (15 bits) + p[-2 * step] = ApplyFilter6<Pixel>(3 * p2 + a1 + a0 + q0); + p[-1 * step] = ApplyFilter6<Pixel>(p2 + a1 + a0 + b0 + q1); + p[0 * step] = ApplyFilter6<Pixel>(p1 + a0 + b0 + b1 + q2); + p[1 * step] = ApplyFilter6<Pixel>(p0 + b0 + b1 + 3 * q2); +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Vertical6(void* dest, ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + const int flat_thresh = LoopFilterFuncs_C<bitdepth, Pixel>::kFlatThresh; + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter6(dst, 1, outer_thresh, inner_thresh)) { + if (IsFlat3(dst, 1, flat_thresh)) { + Filter6_C(dst, 1); + } else if (Hev(dst, 1, hev_thresh)) { + Filter2_C<bitdepth>(dst, 1); + } else { + Filter4_C<bitdepth>(dst, 1); + } + } + dst += stride; + } +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Horizontal6(void* dest, + ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + const int flat_thresh = LoopFilterFuncs_C<bitdepth, Pixel>::kFlatThresh; + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter6(dst, stride, outer_thresh, inner_thresh)) { + if (IsFlat3(dst, stride, flat_thresh)) { + Filter6_C(dst, stride); + } else if (Hev(dst, stride, hev_thresh)) { + Filter2_C<bitdepth>(dst, stride); + } else { + Filter4_C<bitdepth>(dst, stride); + } + } + ++dst; + } +} + +//------------------------------------------------------------------------------ +// 7-tap filters + +// 7.14.6.2. +template <typename Pixel> +inline bool NeedsFilter8(const Pixel* p, ptrdiff_t step, int outer_thresh, + int inner_thresh) { + const int p3 = p[-4 * step], p2 = p[-3 * step], p1 = p[-2 * step], + p0 = p[-step]; + const int q0 = p[0], q1 = p[step], q2 = p[2 * step], q3 = p[3 * step]; + return std::abs(p3 - p2) <= inner_thresh && + std::abs(p2 - p1) <= inner_thresh && + std::abs(p1 - p0) <= inner_thresh && + std::abs(q1 - q0) <= inner_thresh && + std::abs(q2 - q1) <= inner_thresh && + std::abs(q3 - q2) <= inner_thresh && + std::abs(p0 - q0) * 2 + std::abs(p1 - q1) / 2 <= outer_thresh; +} + +// 7.14.6.2. +template <typename Pixel> +inline bool IsFlat4(const Pixel* p, ptrdiff_t step, int flat_thresh) { + const int p3 = p[-4 * step], p2 = p[-3 * step], p1 = p[-2 * step], + p0 = p[-step]; + const int q0 = p[0], q1 = p[step], q2 = p[2 * step], q3 = p[3 * step]; + return std::abs(p1 - p0) <= flat_thresh && std::abs(q1 - q0) <= flat_thresh && + std::abs(p2 - p0) <= flat_thresh && std::abs(q2 - q0) <= flat_thresh && + std::abs(p3 - p0) <= flat_thresh && std::abs(q3 - q0) <= flat_thresh; +} + +template <typename Pixel> +inline Pixel ApplyFilter8(int filter_value) { + return static_cast<Pixel>(RightShiftWithRounding(filter_value, 3)); +} + +// 7.14.6.4. +// 8 pixels in, 6 pixels out. +template <typename Pixel> +inline void Filter8_C(Pixel* p, ptrdiff_t step) { + const int p3 = p[-4 * step], p2 = p[-3 * step], p1 = p[-2 * step], + p0 = p[-step]; + const int q0 = p[0], q1 = p[step], q2 = p[2 * step], q3 = p[3 * step]; + // The max is 8 * max_pixel + 4 for the rounder. + // 8bpp: 2044 (11 bits), 10bpp: 8188 (13 bits), 12bpp: 32764 (15 bits) + p[-3 * step] = ApplyFilter8<Pixel>(3 * p3 + 2 * p2 + p1 + p0 + q0); + p[-2 * step] = ApplyFilter8<Pixel>(2 * p3 + p2 + 2 * p1 + p0 + q0 + q1); + p[-1 * step] = ApplyFilter8<Pixel>(p3 + p2 + p1 + 2 * p0 + q0 + q1 + q2); + p[0 * step] = ApplyFilter8<Pixel>(p2 + p1 + p0 + 2 * q0 + q1 + q2 + q3); + p[1 * step] = ApplyFilter8<Pixel>(p1 + p0 + q0 + 2 * q1 + q2 + 2 * q3); + p[2 * step] = ApplyFilter8<Pixel>(p0 + q0 + q1 + 2 * q2 + 3 * q3); +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Vertical8(void* dest, ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + const int flat_thresh = LoopFilterFuncs_C<bitdepth, Pixel>::kFlatThresh; + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter8(dst, 1, outer_thresh, inner_thresh)) { + if (IsFlat4(dst, 1, flat_thresh)) { + Filter8_C(dst, 1); + } else if (Hev(dst, 1, hev_thresh)) { + Filter2_C<bitdepth>(dst, 1); + } else { + Filter4_C<bitdepth>(dst, 1); + } + } + dst += stride; + } +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Horizontal8(void* dest, + ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + const int flat_thresh = LoopFilterFuncs_C<bitdepth, Pixel>::kFlatThresh; + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter8(dst, stride, outer_thresh, inner_thresh)) { + if (IsFlat4(dst, stride, flat_thresh)) { + Filter8_C(dst, stride); + } else if (Hev(dst, stride, hev_thresh)) { + Filter2_C<bitdepth>(dst, stride); + } else { + Filter4_C<bitdepth>(dst, stride); + } + } + ++dst; + } +} + +//------------------------------------------------------------------------------ +// 13-tap filters + +// 7.14.6.2. +template <typename Pixel> +inline bool IsFlatOuter4(const Pixel* p, ptrdiff_t step, int flat_thresh) { + const int p6 = p[-7 * step], p5 = p[-6 * step], p4 = p[-5 * step], + p0 = p[-step]; + const int q0 = p[0], q4 = p[4 * step], q5 = p[5 * step], q6 = p[6 * step]; + return std::abs(p4 - p0) <= flat_thresh && std::abs(q4 - q0) <= flat_thresh && + std::abs(p5 - p0) <= flat_thresh && std::abs(q5 - q0) <= flat_thresh && + std::abs(p6 - p0) <= flat_thresh && std::abs(q6 - q0) <= flat_thresh; +} + +template <typename Pixel> +inline Pixel ApplyFilter14(int filter_value) { + return static_cast<Pixel>(RightShiftWithRounding(filter_value, 4)); +} + +// 7.14.6.4. +// 14 pixels in, 12 pixels out. +template <typename Pixel> +inline void Filter14_C(Pixel* p, ptrdiff_t step) { + const int p6 = p[-7 * step], p5 = p[-6 * step], p4 = p[-5 * step], + p3 = p[-4 * step], p2 = p[-3 * step], p1 = p[-2 * step], + p0 = p[-step]; + const int q0 = p[0], q1 = p[step], q2 = p[2 * step], q3 = p[3 * step], + q4 = p[4 * step], q5 = p[5 * step], q6 = p[6 * step]; + // The max is 16 * max_pixel + 8 for the rounder. + // 8bpp: 4088 (12 bits), 10bpp: 16376 (14 bits), 12bpp: 65528 (16 bits) + p[-6 * step] = + ApplyFilter14<Pixel>(p6 * 7 + p5 * 2 + p4 * 2 + p3 + p2 + p1 + p0 + q0); + p[-5 * step] = ApplyFilter14<Pixel>(p6 * 5 + p5 * 2 + p4 * 2 + p3 * 2 + p2 + + p1 + p0 + q0 + q1); + p[-4 * step] = ApplyFilter14<Pixel>(p6 * 4 + p5 + p4 * 2 + p3 * 2 + p2 * 2 + + p1 + p0 + q0 + q1 + q2); + p[-3 * step] = ApplyFilter14<Pixel>(p6 * 3 + p5 + p4 + p3 * 2 + p2 * 2 + + p1 * 2 + p0 + q0 + q1 + q2 + q3); + p[-2 * step] = ApplyFilter14<Pixel>(p6 * 2 + p5 + p4 + p3 + p2 * 2 + p1 * 2 + + p0 * 2 + q0 + q1 + q2 + q3 + q4); + p[-1 * step] = ApplyFilter14<Pixel>(p6 + p5 + p4 + p3 + p2 + p1 * 2 + p0 * 2 + + q0 * 2 + q1 + q2 + q3 + q4 + q5); + p[0 * step] = ApplyFilter14<Pixel>(p5 + p4 + p3 + p2 + p1 + p0 * 2 + q0 * 2 + + q1 * 2 + q2 + q3 + q4 + q5 + q6); + p[1 * step] = ApplyFilter14<Pixel>(p4 + p3 + p2 + p1 + p0 + q0 * 2 + q1 * 2 + + q2 * 2 + q3 + q4 + q5 + q6 * 2); + p[2 * step] = ApplyFilter14<Pixel>(p3 + p2 + p1 + p0 + q0 + q1 * 2 + q2 * 2 + + q3 * 2 + q4 + q5 + q6 * 3); + p[3 * step] = ApplyFilter14<Pixel>(p2 + p1 + p0 + q0 + q1 + q2 * 2 + q3 * 2 + + q4 * 2 + q5 + q6 * 4); + p[4 * step] = ApplyFilter14<Pixel>(p1 + p0 + q0 + q1 + q2 + q3 * 2 + q4 * 2 + + q5 * 2 + q6 * 5); + p[5 * step] = + ApplyFilter14<Pixel>(p0 + q0 + q1 + q2 + q3 + q4 * 2 + q5 * 2 + q6 * 7); +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Vertical14(void* dest, + ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + const int flat_thresh = LoopFilterFuncs_C<bitdepth, Pixel>::kFlatThresh; + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter8(dst, 1, outer_thresh, inner_thresh)) { + if (IsFlat4(dst, 1, flat_thresh)) { + if (IsFlatOuter4(dst, 1, flat_thresh)) { + Filter14_C(dst, 1); + } else { + Filter8_C(dst, 1); + } + } else if (Hev(dst, 1, hev_thresh)) { + Filter2_C<bitdepth>(dst, 1); + } else { + Filter4_C<bitdepth>(dst, 1); + } + } + dst += stride; + } +} + +template <int bitdepth, typename Pixel> +void LoopFilterFuncs_C<bitdepth, Pixel>::Horizontal14(void* dest, + ptrdiff_t stride, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + const int flat_thresh = LoopFilterFuncs_C<bitdepth, Pixel>::kFlatThresh; + AdjustThresholds(bitdepth, &outer_thresh, &inner_thresh, &hev_thresh); + auto* dst = static_cast<Pixel*>(dest); + stride /= sizeof(Pixel); + for (int i = 0; i < 4; ++i) { + if (NeedsFilter8(dst, stride, outer_thresh, inner_thresh)) { + if (IsFlat4(dst, stride, flat_thresh)) { + if (IsFlatOuter4(dst, stride, flat_thresh)) { + Filter14_C(dst, stride); + } else { + Filter8_C(dst, stride); + } + } else if (Hev(dst, stride, hev_thresh)) { + Filter2_C<bitdepth>(dst, stride); + } else { + Filter4_C<bitdepth>(dst, stride); + } + } + ++dst; + } +} + +using Defs8bpp = LoopFilterFuncs_C<8, uint8_t>; + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal4; + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = + Defs8bpp::Vertical4; + + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal6; + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = + Defs8bpp::Vertical6; + + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal8; + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = + Defs8bpp::Vertical8; + + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal14; + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = + Defs8bpp::Vertical14; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal4; +#endif +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = + Defs8bpp::Vertical4; +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal6; +#endif +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = + Defs8bpp::Vertical6; +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal8; +#endif +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = + Defs8bpp::Vertical8; +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Defs8bpp::Horizontal14; +#endif +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = + Defs8bpp::Vertical14; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +using Defs10bpp = LoopFilterFuncs_C<10, uint16_t>; + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal4; + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = + Defs10bpp::Vertical4; + + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal6; + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = + Defs10bpp::Vertical6; + + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal8; + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = + Defs10bpp::Vertical8; + + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal14; + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = + Defs10bpp::Vertical14; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal4; +#endif +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = + Defs10bpp::Vertical4; +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal6; +#endif +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = + Defs10bpp::Vertical6; +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal8; +#endif +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = + Defs10bpp::Vertical8; +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeHorizontal + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal14; +#endif +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeVertical + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = + Defs10bpp::Vertical14; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +} // namespace + +void LoopFilterInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif + // Local functions that may be unused depending on the optimizations + // available. + static_cast<void>(AdjustThresholds); +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/loop_filter.h b/src/dsp/loop_filter.h new file mode 100644 index 0000000..1ddad71 --- /dev/null +++ b/src/dsp/loop_filter.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_LOOP_FILTER_H_ +#define LIBGAV1_SRC_DSP_LOOP_FILTER_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/loop_filter_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/loop_filter_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_filters. This function is not thread-safe. +void LoopFilterInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_LOOP_FILTER_H_ diff --git a/src/dsp/loop_restoration.cc b/src/dsp/loop_restoration.cc new file mode 100644 index 0000000..0909df0 --- /dev/null +++ b/src/dsp/loop_restoration.cc @@ -0,0 +1,936 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_restoration.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/common.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { + +// Section 7.17.3. +// a2: range [1, 256]. +// if (z >= 255) +// a2 = 256; +// else if (z == 0) +// a2 = 1; +// else +// a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1); +// ma = 256 - a2; +alignas(16) const uint8_t kSgrMaLookup[256] = { + 255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16, 15, 14, + 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 9, 9, 8, 8, 8, 8, 7, 7, + 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0}; + +namespace { + +template <int bitdepth, typename Pixel> +inline void WienerHorizontal(const Pixel* source, const ptrdiff_t source_stride, + const int width, const int height, + const int16_t* const filter, + const int number_zero_coefficients, + int16_t** wiener_buffer) { + constexpr int kCenterTap = kWienerFilterTaps / 2; + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + constexpr int offset = + 1 << (bitdepth + kWienerFilterBits - kRoundBitsHorizontal - 1); + constexpr int limit = (offset << 2) - 1; + for (int y = 0; y < height; ++y) { + int x = 0; + do { + // sum fits into 16 bits only when bitdepth = 8. + int sum = 0; + for (int k = number_zero_coefficients; k < kCenterTap; ++k) { + sum += + filter[k] * (source[x + k] + source[x + kWienerFilterTaps - 1 - k]); + } + sum += filter[kCenterTap] * source[x + kCenterTap]; + const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsHorizontal); + (*wiener_buffer)[x] = Clip3(rounded_sum, -offset, limit - offset); + } while (++x != width); + source += source_stride; + *wiener_buffer += width; + } +} + +template <int bitdepth, typename Pixel> +inline void WienerVertical(const int16_t* wiener_buffer, const int width, + const int height, const int16_t* const filter, + const int number_zero_coefficients, void* const dest, + const ptrdiff_t dest_stride) { + constexpr int kCenterTap = kWienerFilterTaps / 2; + constexpr int kRoundBitsVertical = + (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical; + auto* dst = static_cast<Pixel*>(dest); + int y = height; + do { + int x = 0; + do { + // sum needs 32 bits. + int sum = 0; + for (int k = number_zero_coefficients; k < kCenterTap; ++k) { + sum += filter[k] * + (wiener_buffer[k * width + x] + + wiener_buffer[(kWienerFilterTaps - 1 - k) * width + x]); + } + sum += filter[kCenterTap] * wiener_buffer[kCenterTap * width + x]; + const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsVertical); + dst[x] = static_cast<Pixel>(Clip3(rounded_sum, 0, (1 << bitdepth) - 1)); + } while (++x != width); + wiener_buffer += width; + dst += dest_stride; + } while (--y != 0); +} + +// Note: bit range for wiener filter. +// Wiener filter process first applies horizontal filtering to input pixels, +// followed by rounding with predefined bits (dependent on bitdepth). +// Then vertical filtering is applied, followed by rounding (dependent on +// bitdepth). +// The process is the same as convolution: +// <input> --> <horizontal filter> --> <rounding 0> --> <vertical filter> +// --> <rounding 1> +// By design: +// (a). horizontal/vertical filtering adds 7 bits to input. +// (b). The output of first rounding fits into 16 bits. +// (c). The output of second rounding fits into 16 bits. +// If input bitdepth > 8, the accumulator of the horizontal filter is larger +// than 16 bit and smaller than 32 bits. +// The accumulator of the vertical filter is larger than 16 bits and smaller +// than 32 bits. +// Note: range of wiener filter coefficients. +// Wiener filter coefficients are symmetric, and their sum is 1 (128). +// The range of each coefficient: +// filter[0] = filter[6], 4 bits, min = -5, max = 10. +// filter[1] = filter[5], 5 bits, min = -23, max = 8. +// filter[2] = filter[4], 6 bits, min = -17, max = 46. +// filter[3] = 128 - 2 * (filter[0] + filter[1] + filter[2]). +// The difference from libaom is that in libaom: +// filter[3] = 0 - 2 * (filter[0] + filter[1] + filter[2]). +// Thus in libaom's computation, an offset of 128 is needed for filter[3]. +template <int bitdepth, typename Pixel> +void WienerFilter_C(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, const ptrdiff_t stride, + const int width, const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + constexpr int kCenterTap = kWienerFilterTaps / 2; + const int16_t* const number_leading_zero_coefficients = + restoration_info.wiener_info.number_leading_zero_coefficients; + const int number_rows_to_skip = std::max( + static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]), + 1); + int16_t* const wiener_buffer_org = restoration_buffer->wiener_buffer; + + // horizontal filtering. + const int height_horizontal = + height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip; + const int height_extra = (height_horizontal - height) >> 1; + assert(height_extra <= 2); + const int16_t* const filter_horizontal = + restoration_info.wiener_info.filter[WienerInfo::kHorizontal]; + const auto* src = static_cast<const Pixel*>(source) - kCenterTap; + const auto* top = static_cast<const Pixel*>(top_border) - kCenterTap; + const auto* bottom = static_cast<const Pixel*>(bottom_border) - kCenterTap; + auto* wiener_buffer = wiener_buffer_org + number_rows_to_skip * width; + + if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { + WienerHorizontal<bitdepth, Pixel>(top + (2 - height_extra) * stride, stride, + width, height_extra, filter_horizontal, 0, + &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(src, stride, width, height, + filter_horizontal, 0, &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(bottom, stride, width, height_extra, + filter_horizontal, 0, &wiener_buffer); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontal<bitdepth, Pixel>(top + (2 - height_extra) * stride, stride, + width, height_extra, filter_horizontal, 1, + &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(src, stride, width, height, + filter_horizontal, 1, &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(bottom, stride, width, height_extra, + filter_horizontal, 1, &wiener_buffer); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { + WienerHorizontal<bitdepth, Pixel>(top + (2 - height_extra) * stride, stride, + width, height_extra, filter_horizontal, 2, + &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(src, stride, width, height, + filter_horizontal, 2, &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(bottom, stride, width, height_extra, + filter_horizontal, 2, &wiener_buffer); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); + WienerHorizontal<bitdepth, Pixel>(top + (2 - height_extra) * stride, stride, + width, height_extra, filter_horizontal, 3, + &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(src, stride, width, height, + filter_horizontal, 3, &wiener_buffer); + WienerHorizontal<bitdepth, Pixel>(bottom, stride, width, height_extra, + filter_horizontal, 3, &wiener_buffer); + } + + // vertical filtering. + const int16_t* const filter_vertical = + restoration_info.wiener_info.filter[WienerInfo::kVertical]; + if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) { + // Because the top row of |source| is a duplicate of the second row, and the + // bottom row of |source| is a duplicate of its above row, we can duplicate + // the top and bottom row of |wiener_buffer| accordingly. + memcpy(wiener_buffer, wiener_buffer - width, + sizeof(*wiener_buffer) * width); + memcpy(wiener_buffer_org, wiener_buffer_org + width, + sizeof(*wiener_buffer) * width); + WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height, + filter_vertical, 0, dest, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) { + WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height, + filter_vertical, 1, dest, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) { + WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height, + filter_vertical, 2, dest, stride); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3); + WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height, + filter_vertical, 3, dest, stride); + } +} + +//------------------------------------------------------------------------------ +// SGR + +// When |height| is 1, |src_stride| could be set to arbitrary value. +template <typename Pixel, int size> +LIBGAV1_ALWAYS_INLINE void BoxSum(const Pixel* src, const ptrdiff_t src_stride, + const int height, const int width, + uint16_t* const* sums, + uint32_t* const* square_sums) { + int y = height; + do { + uint32_t sum = 0; + uint32_t square_sum = 0; + for (int dx = 0; dx < size; ++dx) { + const Pixel source = src[dx]; + sum += source; + square_sum += source * source; + } + (*sums)[0] = sum; + (*square_sums)[0] = square_sum; + int x = 1; + do { + const Pixel source0 = src[x - 1]; + const Pixel source1 = src[x - 1 + size]; + sum -= source0; + sum += source1; + square_sum -= source0 * source0; + square_sum += source1 * source1; + (*sums)[x] = sum; + (*square_sums)[x] = square_sum; + } while (++x != width); + src += src_stride; + ++sums; + ++square_sums; + } while (--y != 0); +} + +// When |height| is 1, |src_stride| could be set to arbitrary value. +template <typename Pixel> +LIBGAV1_ALWAYS_INLINE void BoxSum(const Pixel* src, const ptrdiff_t src_stride, + const int height, const int width, + uint16_t* const* sum3, uint16_t* const* sum5, + uint32_t* const* square_sum3, + uint32_t* const* square_sum5) { + int y = height; + do { + uint32_t sum = 0; + uint32_t square_sum = 0; + for (int dx = 0; dx < 4; ++dx) { + const Pixel source = src[dx]; + sum += source; + square_sum += source * source; + } + int x = 0; + do { + const Pixel source0 = src[x]; + const Pixel source1 = src[x + 4]; + sum -= source0; + square_sum -= source0 * source0; + (*sum3)[x] = sum; + (*square_sum3)[x] = square_sum; + sum += source1; + square_sum += source1 * source1; + (*sum5)[x] = sum + source0; + (*square_sum5)[x] = square_sum + source0 * source0; + } while (++x != width); + src += src_stride; + ++sum3; + ++sum5; + ++square_sum3; + ++square_sum5; + } while (--y != 0); +} + +template <int bitdepth, int n> +inline void CalculateIntermediate(const uint32_t s, uint32_t a, + const uint32_t b, uint8_t* const ma_ptr, + uint32_t* const b_ptr) { + // a: before shift, max is 25 * (2^(bitdepth) - 1) * (2^(bitdepth) - 1). + // since max bitdepth = 12, max < 2^31. + // after shift, a < 2^16 * n < 2^22 regardless of bitdepth + a = RightShiftWithRounding(a, (bitdepth - 8) << 1); + // b: max is 25 * (2^(bitdepth) - 1). If bitdepth = 12, max < 2^19. + // d < 2^8 * n < 2^14 regardless of bitdepth + const uint32_t d = RightShiftWithRounding(b, bitdepth - 8); + // p: Each term in calculating p = a * n - b * b is < 2^16 * n^2 < 2^28, + // and p itself satisfies p < 2^14 * n^2 < 2^26. + // This bound on p is due to: + // https://en.wikipedia.org/wiki/Popoviciu's_inequality_on_variances + // Note: Sometimes, in high bitdepth, we can end up with a*n < b*b. + // This is an artifact of rounding, and can only happen if all pixels + // are (almost) identical, so in this case we saturate to p=0. + const uint32_t p = (a * n < d * d) ? 0 : a * n - d * d; + // p * s < (2^14 * n^2) * round(2^20 / (n^2 * scale)) < 2^34 / scale < + // 2^32 as long as scale >= 4. So p * s fits into a uint32_t, and z < 2^12 + // (this holds even after accounting for the rounding in s) + const uint32_t z = RightShiftWithRounding(p * s, kSgrProjScaleBits); + // ma: range [0, 255]. + const uint32_t ma = kSgrMaLookup[std::min(z, 255u)]; + const uint32_t one_over_n = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + // ma < 2^8, b < 2^(bitdepth) * n, + // one_over_n = round(2^12 / n) + // => the product here is < 2^(20 + bitdepth) <= 2^32, + // and b is set to a value < 2^(8 + bitdepth). + // This holds even with the rounding in one_over_n and in the overall result, + // as long as ma is strictly less than 2^8. + const uint32_t b2 = ma * b * one_over_n; + *ma_ptr = ma; + *b_ptr = RightShiftWithRounding(b2, kSgrProjReciprocalBits); +} + +template <typename T> +inline uint32_t Sum343(const T* const src) { + return 3 * (src[0] + src[2]) + 4 * src[1]; +} + +template <typename T> +inline uint32_t Sum444(const T* const src) { + return 4 * (src[0] + src[1] + src[2]); +} + +template <typename T> +inline uint32_t Sum565(const T* const src) { + return 5 * (src[0] + src[2]) + 6 * src[1]; +} + +template <int bitdepth> +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const uint16_t* const sum5[5], const uint32_t* const square_sum5[5], + const int width, const uint32_t s, SgrBuffer* const sgr_buffer, + uint16_t* const ma565, uint32_t* const b565) { + int x = 0; + do { + uint32_t a = 0; + uint32_t b = 0; + for (int dy = 0; dy < 5; ++dy) { + a += square_sum5[dy][x]; + b += sum5[dy][x]; + } + CalculateIntermediate<bitdepth, 25>(s, a, b, sgr_buffer->ma + x, + sgr_buffer->b + x); + } while (++x != width + 2); + x = 0; + do { + ma565[x] = Sum565(sgr_buffer->ma + x); + b565[x] = Sum565(sgr_buffer->b + x); + } while (++x != width); +} + +template <int bitdepth> +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( + const uint16_t* const sum3[3], const uint32_t* const square_sum3[3], + const int width, const uint32_t s, const bool calculate444, + SgrBuffer* const sgr_buffer, uint16_t* const ma343, uint32_t* const b343, + uint16_t* const ma444, uint32_t* const b444) { + int x = 0; + do { + uint32_t a = 0; + uint32_t b = 0; + for (int dy = 0; dy < 3; ++dy) { + a += square_sum3[dy][x]; + b += sum3[dy][x]; + } + CalculateIntermediate<bitdepth, 9>(s, a, b, sgr_buffer->ma + x, + sgr_buffer->b + x); + } while (++x != width + 2); + x = 0; + do { + ma343[x] = Sum343(sgr_buffer->ma + x); + b343[x] = Sum343(sgr_buffer->b + x); + } while (++x != width); + if (calculate444) { + x = 0; + do { + ma444[x] = Sum444(sgr_buffer->ma + x); + b444[x] = Sum444(sgr_buffer->b + x); + } while (++x != width); + } +} + +template <typename Pixel> +inline int CalculateFilteredOutput(const Pixel src, const uint32_t ma, + const uint32_t b, const int shift) { + const int32_t v = b - ma * src; + return RightShiftWithRounding(v, + kSgrProjSgrBits + shift - kSgrProjRestoreBits); +} + +template <typename Pixel> +inline void BoxFilterPass1Kernel(const Pixel src0, const Pixel src1, + const uint16_t* const ma565[2], + const uint32_t* const b565[2], + const ptrdiff_t x, int p[2]) { + p[0] = CalculateFilteredOutput<Pixel>(src0, ma565[0][x] + ma565[1][x], + b565[0][x] + b565[1][x], 5); + p[1] = CalculateFilteredOutput<Pixel>(src1, ma565[1][x], b565[1][x], 4); +} + +template <typename Pixel> +inline int BoxFilterPass2Kernel(const Pixel src, const uint16_t* const ma343[3], + const uint16_t* const ma444, + const uint32_t* const b343[3], + const uint32_t* const b444, const ptrdiff_t x) { + const uint32_t ma = ma343[0][x] + ma444[x] + ma343[2][x]; + const uint32_t b = b343[0][x] + b444[x] + b343[2][x]; + return CalculateFilteredOutput<Pixel>(src, ma, b, 5); +} + +template <int bitdepth, typename Pixel> +inline Pixel SelfGuidedFinal(const int src, const int v) { + // if radius_pass_0 == 0 and radius_pass_1 == 0, the range of v is: + // bits(u) + bits(w0/w1/w2) + 2 = bitdepth + 13. + // Then, range of s is bitdepth + 2. This is a rough estimation, taking the + // maximum value of each element. + const int s = src + RightShiftWithRounding( + v, kSgrProjRestoreBits + kSgrProjPrecisionBits); + return static_cast<Pixel>(Clip3(s, 0, (1 << bitdepth) - 1)); +} + +template <int bitdepth, typename Pixel> +inline Pixel SelfGuidedDoubleMultiplier(const int src, const int filter0, + const int filter1, const int16_t w0, + const int16_t w2) { + const int v = w0 * filter0 + w2 * filter1; + return SelfGuidedFinal<bitdepth, Pixel>(src, v); +} + +template <int bitdepth, typename Pixel> +inline Pixel SelfGuidedSingleMultiplier(const int src, const int filter, + const int16_t w0) { + const int v = w0 * filter; + return SelfGuidedFinal<bitdepth, Pixel>(src, v); +} + +template <int bitdepth, typename Pixel> +inline void BoxFilterPass1(const Pixel* const src, const ptrdiff_t stride, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], const int width, + const uint32_t scale, const int16_t w0, + SgrBuffer* const sgr_buffer, + uint16_t* const ma565[2], uint32_t* const b565[2], + Pixel* dst) { + BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scale, sgr_buffer, + ma565[1], b565[1]); + int x = 0; + do { + int p[2]; + BoxFilterPass1Kernel<Pixel>(src[x], src[stride + x], ma565, b565, x, p); + dst[x] = SelfGuidedSingleMultiplier<bitdepth, Pixel>(src[x], p[0], w0); + dst[stride + x] = + SelfGuidedSingleMultiplier<bitdepth, Pixel>(src[stride + x], p[1], w0); + } while (++x != width); +} + +template <int bitdepth, typename Pixel> +inline void BoxFilterPass2(const Pixel* const src, const Pixel* const src0, + const int width, const uint16_t scale, + const int16_t w0, uint16_t* const sum3[4], + uint32_t* const square_sum3[4], + SgrBuffer* const sgr_buffer, + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint32_t* const b343[4], uint32_t* const b444[3], + Pixel* dst) { + BoxSum<Pixel, 3>(src0, 0, 1, width + 2, sum3 + 2, square_sum3 + 2); + BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scale, true, + sgr_buffer, ma343[2], b343[2], ma444[1], + b444[1]); + int x = 0; + do { + const int p = + BoxFilterPass2Kernel<Pixel>(src[x], ma343, ma444[0], b343, b444[0], x); + dst[x] = SelfGuidedSingleMultiplier<bitdepth, Pixel>(src[x], p, w0); + } while (++x != width); +} + +template <int bitdepth, typename Pixel> +inline void BoxFilter(const Pixel* const src, const ptrdiff_t stride, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], + uint32_t* const square_sum5[5], const int width, + const uint16_t scales[2], const int16_t w0, + const int16_t w2, SgrBuffer* const sgr_buffer, + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint16_t* const ma565[2], uint32_t* const b343[4], + uint32_t* const b444[3], uint32_t* const b565[2], + Pixel* dst) { + BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scales[0], + sgr_buffer, ma565[1], b565[1]); + BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scales[1], true, + sgr_buffer, ma343[2], b343[2], ma444[1], + b444[1]); + BoxFilterPreProcess3<bitdepth>(sum3 + 1, square_sum3 + 1, width, scales[1], + true, sgr_buffer, ma343[3], b343[3], ma444[2], + b444[2]); + int x = 0; + do { + int p[2][2]; + BoxFilterPass1Kernel<Pixel>(src[x], src[stride + x], ma565, b565, x, p[0]); + p[1][0] = + BoxFilterPass2Kernel<Pixel>(src[x], ma343, ma444[0], b343, b444[0], x); + p[1][1] = BoxFilterPass2Kernel<Pixel>(src[stride + x], ma343 + 1, ma444[1], + b343 + 1, b444[1], x); + dst[x] = SelfGuidedDoubleMultiplier<bitdepth, Pixel>(src[x], p[0][0], + p[1][0], w0, w2); + dst[stride + x] = SelfGuidedDoubleMultiplier<bitdepth, Pixel>( + src[stride + x], p[0][1], p[1][1], w0, w2); + } while (++x != width); +} + +template <int bitdepth, typename Pixel> +inline void BoxFilterProcess(const RestorationUnitInfo& restoration_info, + const Pixel* src, const Pixel* const top_border, + const Pixel* bottom_border, const ptrdiff_t stride, + const int width, const int height, + SgrBuffer* const sgr_buffer, Pixel* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1; + uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2]; + uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 3; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma444[0] = sgr_buffer->ma444; + b444[0] = sgr_buffer->b444; + for (int i = 1; i <= 2; ++i) { + ma444[i] = ma444[i - 1] + temp_stride; + b444[i] = b444[i - 1] + temp_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scales[0] != 0); + assert(scales[1] != 0); + BoxSum<Pixel>(top_border, stride, 2, width + 2, sum3, sum5 + 1, square_sum3, + square_sum5 + 1); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + BoxSum<Pixel>(src, stride, 1, width + 2, sum3 + 2, sum5 + 3, square_sum3 + 2, + square_sum5 + 3); + const Pixel* const s = (height > 1) ? src + stride : bottom_border; + BoxSum<Pixel>(s, 0, 1, width + 2, sum3 + 3, sum5 + 4, square_sum3 + 3, + square_sum5 + 4); + BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scales[0], + sgr_buffer, ma565[0], b565[0]); + BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scales[1], false, + sgr_buffer, ma343[0], b343[0], nullptr, + nullptr); + BoxFilterPreProcess3<bitdepth>(sum3 + 1, square_sum3 + 1, width, scales[1], + true, sgr_buffer, ma343[1], b343[1], ma444[0], + b444[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxSum<Pixel>(src + 2 * stride, stride, 2, width + 2, sum3 + 2, sum5 + 3, + square_sum3 + 2, square_sum5 + 3); + BoxFilter<bitdepth, Pixel>(src + 3, stride, sum3, sum5, square_sum3, + square_sum5, width, scales, w0, w2, sgr_buffer, + ma343, ma444, ma565, b343, b444, b565, dst); + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const Pixel* sr; + ptrdiff_t s_stride; + if ((height & 1) == 0) { + sr = bottom_border; + s_stride = stride; + } else { + sr = src + 2 * stride; + s_stride = bottom_border - (src + 2 * stride); + } + BoxSum<Pixel>(sr, s_stride, 2, width + 2, sum3 + 2, sum5 + 3, + square_sum3 + 2, square_sum5 + 3); + BoxFilter<bitdepth, Pixel>(src + 3, stride, sum3, sum5, square_sum3, + square_sum5, width, scales, w0, w2, sgr_buffer, + ma343, ma444, ma565, b343, b444, b565, dst); + } + if ((height & 1) != 0) { + src += 3; + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + BoxSum<Pixel>(bottom_border + stride, stride, 1, width + 2, sum3 + 2, + sum5 + 3, square_sum3 + 2, square_sum5 + 3); + sum5[4] = sum5[3]; + square_sum5[4] = square_sum5[3]; + BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scales[0], + sgr_buffer, ma565[1], b565[1]); + BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scales[1], false, + sgr_buffer, ma343[2], b343[2], nullptr, + nullptr); + int x = 0; + do { + const int p0 = CalculateFilteredOutput<Pixel>( + src[x], ma565[0][x] + ma565[1][x], b565[0][x] + b565[1][x], 5); + const int p1 = BoxFilterPass2Kernel<Pixel>(src[x], ma343, ma444[0], b343, + b444[0], x); + dst[x] = + SelfGuidedDoubleMultiplier<bitdepth, Pixel>(src[x], p0, p1, w0, w2); + } while (++x != width); + } +} + +template <int bitdepth, typename Pixel> +inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, + const Pixel* src, + const Pixel* const top_border, + const Pixel* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + Pixel* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + uint16_t *sum5[5], *ma565[2]; + uint32_t *square_sum5[5], *b565[2]; + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scale != 0); + BoxSum<Pixel, 5>(top_border, stride, 2, width + 2, sum5 + 1, square_sum5 + 1); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + BoxSum<Pixel, 5>(src, stride, 1, width + 2, sum5 + 3, square_sum5 + 3); + const Pixel* const s = (height > 1) ? src + stride : bottom_border; + BoxSum<Pixel, 5>(s, 0, 1, width + 2, sum5 + 4, square_sum5 + 4); + BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scale, sgr_buffer, + ma565[0], b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxSum<Pixel, 5>(src + 2 * stride, stride, 2, width + 2, sum5 + 3, + square_sum5 + 3); + BoxFilterPass1<bitdepth, Pixel>(src + 3, stride, sum5, square_sum5, width, + scale, w0, sgr_buffer, ma565, b565, dst); + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const Pixel* sr; + ptrdiff_t s_stride; + if ((height & 1) == 0) { + sr = bottom_border; + s_stride = stride; + } else { + sr = src + 2 * stride; + s_stride = bottom_border - (src + 2 * stride); + } + BoxSum<Pixel, 5>(sr, s_stride, 2, width + 2, sum5 + 3, square_sum5 + 3); + BoxFilterPass1<bitdepth, Pixel>(src + 3, stride, sum5, square_sum5, width, + scale, w0, sgr_buffer, ma565, b565, dst); + } + if ((height & 1) != 0) { + src += 3; + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + } + BoxSum<Pixel, 5>(bottom_border + stride, stride, 1, width + 2, sum5 + 3, + square_sum5 + 3); + sum5[4] = sum5[3]; + square_sum5[4] = square_sum5[3]; + BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scale, sgr_buffer, + ma565[1], b565[1]); + int x = 0; + do { + const int p = CalculateFilteredOutput<Pixel>( + src[x], ma565[0][x] + ma565[1][x], b565[0][x] + b565[1][x], 5); + dst[x] = SelfGuidedSingleMultiplier<bitdepth, Pixel>(src[x], p, w0); + } while (++x != width); + } +} + +template <int bitdepth, typename Pixel> +inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, + const Pixel* src, + const Pixel* const top_border, + const Pixel* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + Pixel* dst) { + assert(restoration_info.sgr_proj_info.multiplier[0] == 0); + const auto temp_stride = Align<ptrdiff_t>(width, 8); + const ptrdiff_t sum_stride = temp_stride + 8; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12. + uint16_t *sum3[3], *ma343[3], *ma444[2]; + uint32_t *square_sum3[3], *b343[3], *b444[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 2; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + ma444[0] = sgr_buffer->ma444; + ma444[1] = ma444[0] + temp_stride; + b444[0] = sgr_buffer->b444; + b444[1] = b444[0] + temp_stride; + assert(scale != 0); + BoxSum<Pixel, 3>(top_border, stride, 2, width + 2, sum3, square_sum3); + BoxSum<Pixel, 3>(src, stride, 1, width + 2, sum3 + 2, square_sum3 + 2); + BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scale, false, + sgr_buffer, ma343[0], b343[0], nullptr, + nullptr); + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + const Pixel* s; + if (height > 1) { + s = src + stride; + } else { + s = bottom_border; + bottom_border += stride; + } + BoxSum<Pixel, 3>(s, 0, 1, width + 2, sum3 + 2, square_sum3 + 2); + BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scale, true, + sgr_buffer, ma343[1], b343[1], ma444[0], + b444[0]); + + for (int y = height - 2; y > 0; --y) { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2<bitdepth, Pixel>(src + 2, src + 2 * stride, width, scale, w0, + sum3, square_sum3, sgr_buffer, ma343, ma444, + b343, b444, dst); + src += stride; + dst += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } + + src += 2; + int y = std::min(height, 2); + do { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2<bitdepth, Pixel>(src, bottom_border, width, scale, w0, sum3, + square_sum3, sgr_buffer, ma343, ma444, b343, + b444, dst); + src += stride; + dst += stride; + bottom_border += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } while (--y != 0); +} + +template <int bitdepth, typename Pixel> +void SelfGuidedFilter_C(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, const ptrdiff_t stride, + const int width, const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + const int index = restoration_info.sgr_proj_info.index; + const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 + const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0 + const auto* src = static_cast<const Pixel*>(source); + const auto* top = static_cast<const Pixel*>(top_border); + const auto* bottom = static_cast<const Pixel*>(bottom_border); + auto* dst = static_cast<Pixel*>(dest); + SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer; + if (radius_pass_1 == 0) { + // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the + // following assertion. + assert(radius_pass_0 != 0); + BoxFilterProcessPass1<bitdepth, Pixel>(restoration_info, src - 3, top - 3, + bottom - 3, stride, width, height, + sgr_buffer, dst); + } else if (radius_pass_0 == 0) { + BoxFilterProcessPass2<bitdepth, Pixel>(restoration_info, src - 2, top - 2, + bottom - 2, stride, width, height, + sgr_buffer, dst); + } else { + BoxFilterProcess<bitdepth, Pixel>(restoration_info, src - 3, top - 3, + bottom - 3, stride, width, height, + sgr_buffer, dst); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->loop_restorations[0] = WienerFilter_C<8, uint8_t>; + dsp->loop_restorations[1] = SelfGuidedFilter_C<8, uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_WienerFilter + dsp->loop_restorations[0] = WienerFilter_C<8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_SelfGuidedFilter + dsp->loop_restorations[1] = SelfGuidedFilter_C<8, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->loop_restorations[0] = WienerFilter_C<10, uint16_t>; + dsp->loop_restorations[1] = SelfGuidedFilter_C<10, uint16_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_WienerFilter + dsp->loop_restorations[0] = WienerFilter_C<10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_SelfGuidedFilter + dsp->loop_restorations[1] = SelfGuidedFilter_C<10, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#endif // LIBGAV1_MAX_BITDEPTH >= 10 +} // namespace + +void LoopRestorationInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/loop_restoration.h b/src/dsp/loop_restoration.h new file mode 100644 index 0000000..de80926 --- /dev/null +++ b/src/dsp/loop_restoration.h @@ -0,0 +1,85 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_LOOP_RESTORATION_H_ +#define LIBGAV1_SRC_DSP_LOOP_RESTORATION_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/loop_restoration_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/loop_restoration_avx2.h" +#include "src/dsp/x86/loop_restoration_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +enum { + // Precision of a division table (mtable) + kSgrProjScaleBits = 20, + kSgrProjReciprocalBits = 12, + // Core self-guided restoration precision bits. + kSgrProjSgrBits = 8, + // Precision bits of generated values higher than source before projection. + kSgrProjRestoreBits = 4 +}; // anonymous enum + +extern const uint8_t kSgrMaLookup[256]; + +// Initializes Dsp::loop_restorations. This function is not thread-safe. +void LoopRestorationInit_C(); + +template <typename T> +void Circulate3PointersBy1(T* p[3]) { + T* const p0 = p[0]; + p[0] = p[1]; + p[1] = p[2]; + p[2] = p0; +} + +template <typename T> +void Circulate4PointersBy2(T* p[4]) { + std::swap(p[0], p[2]); + std::swap(p[1], p[3]); +} + +template <typename T> +void Circulate5PointersBy2(T* p[5]) { + T* const p0 = p[0]; + T* const p1 = p[1]; + p[0] = p[2]; + p[1] = p[3]; + p[2] = p[4]; + p[3] = p0; + p[4] = p1; +} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_LOOP_RESTORATION_H_ diff --git a/src/dsp/mask_blend.cc b/src/dsp/mask_blend.cc new file mode 100644 index 0000000..101c410 --- /dev/null +++ b/src/dsp/mask_blend.cc @@ -0,0 +1,207 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/mask_blend.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +template <int subsampling_x, int subsampling_y> +uint8_t GetMaskValue(const uint8_t* mask, const uint8_t* mask_next_row, int x) { + if ((subsampling_x | subsampling_y) == 0) { + return mask[x]; + } + if (subsampling_x == 1 && subsampling_y == 0) { + return static_cast<uint8_t>(RightShiftWithRounding( + mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1], 1)); + } + assert(subsampling_x == 1 && subsampling_y == 1); + return static_cast<uint8_t>(RightShiftWithRounding( + mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1] + + mask_next_row[MultiplyBy2(x)] + mask_next_row[MultiplyBy2(x) + 1], + 2)); +} + +template <int bitdepth, typename Pixel, bool is_inter_intra, int subsampling_x, + int subsampling_y> +void MaskBlend_C(const void* prediction_0, const void* prediction_1, + const ptrdiff_t prediction_stride_1, const uint8_t* mask, + const ptrdiff_t mask_stride, const int width, const int height, + void* dest, const ptrdiff_t dest_stride) { + static_assert(!(bitdepth == 8 && is_inter_intra), ""); + assert(mask != nullptr); + using PredType = + typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type; + const auto* pred_0 = static_cast<const PredType*>(prediction_0); + const auto* pred_1 = static_cast<const PredType*>(prediction_1); + auto* dst = static_cast<Pixel*>(dest); + const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel); + constexpr int step_y = subsampling_y ? 2 : 1; + const uint8_t* mask_next_row = mask + mask_stride; + // 7.11.3.2 Rounding variables derivation process + // 2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7)) + constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4; + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + const uint8_t mask_value = + GetMaskValue<subsampling_x, subsampling_y>(mask, mask_next_row, x); + if (is_inter_intra) { + dst[x] = static_cast<Pixel>(RightShiftWithRounding( + mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6)); + } else { + assert(prediction_stride_1 == width); + int res = (mask_value * pred_0[x] + (64 - mask_value) * pred_1[x]) >> 6; + res -= (bitdepth == 8) ? 0 : kCompoundOffset; + dst[x] = static_cast<Pixel>( + Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + (1 << bitdepth) - 1)); + } + } + dst += dst_stride; + mask += mask_stride * step_y; + mask_next_row += mask_stride * step_y; + pred_0 += width; + pred_1 += prediction_stride_1; + } +} + +template <int subsampling_x, int subsampling_y> +void InterIntraMaskBlend8bpp_C(const uint8_t* prediction_0, + uint8_t* prediction_1, + const ptrdiff_t prediction_stride_1, + const uint8_t* mask, const ptrdiff_t mask_stride, + const int width, const int height) { + assert(mask != nullptr); + constexpr int step_y = subsampling_y ? 2 : 1; + const uint8_t* mask_next_row = mask + mask_stride; + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + const uint8_t mask_value = + GetMaskValue<subsampling_x, subsampling_y>(mask, mask_next_row, x); + prediction_1[x] = static_cast<uint8_t>(RightShiftWithRounding( + mask_value * prediction_1[x] + (64 - mask_value) * prediction_0[x], + 6)); + } + mask += mask_stride * step_y; + mask_next_row += mask_stride * step_y; + prediction_0 += width; + prediction_1 += prediction_stride_1; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->mask_blend[0][0] = MaskBlend_C<8, uint8_t, false, 0, 0>; + dsp->mask_blend[1][0] = MaskBlend_C<8, uint8_t, false, 1, 0>; + dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>; + // The is_inter_intra index of mask_blend[][] is replaced by + // inter_intra_mask_blend_8bpp[] in 8-bit. + dsp->mask_blend[0][1] = nullptr; + dsp->mask_blend[1][1] = nullptr; + dsp->mask_blend[2][1] = nullptr; + dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_C<0, 0>; + dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_C<1, 0>; + dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_C<1, 1>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_MaskBlend444 + dsp->mask_blend[0][0] = MaskBlend_C<8, uint8_t, false, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_MaskBlend422 + dsp->mask_blend[1][0] = MaskBlend_C<8, uint8_t, false, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_MaskBlend420 + dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>; +#endif + // The is_inter_intra index of mask_blend[][] is replaced by + // inter_intra_mask_blend_8bpp[] in 8-bit. + dsp->mask_blend[0][1] = nullptr; + dsp->mask_blend[1][1] = nullptr; + dsp->mask_blend[2][1] = nullptr; +#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 + dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_C<0, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 + dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_C<1, 0>; +#endif +#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 + dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_C<1, 1>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->mask_blend[0][0] = MaskBlend_C<10, uint16_t, false, 0, 0>; + dsp->mask_blend[1][0] = MaskBlend_C<10, uint16_t, false, 1, 0>; + dsp->mask_blend[2][0] = MaskBlend_C<10, uint16_t, false, 1, 1>; + dsp->mask_blend[0][1] = MaskBlend_C<10, uint16_t, true, 0, 0>; + dsp->mask_blend[1][1] = MaskBlend_C<10, uint16_t, true, 1, 0>; + dsp->mask_blend[2][1] = MaskBlend_C<10, uint16_t, true, 1, 1>; + // These are only used with 8-bit. + dsp->inter_intra_mask_blend_8bpp[0] = nullptr; + dsp->inter_intra_mask_blend_8bpp[1] = nullptr; + dsp->inter_intra_mask_blend_8bpp[2] = nullptr; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_MaskBlend444 + dsp->mask_blend[0][0] = MaskBlend_C<10, uint16_t, false, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_MaskBlend422 + dsp->mask_blend[1][0] = MaskBlend_C<10, uint16_t, false, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_MaskBlend420 + dsp->mask_blend[2][0] = MaskBlend_C<10, uint16_t, false, 1, 1>; +#endif +#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra444 + dsp->mask_blend[0][1] = MaskBlend_C<10, uint16_t, true, 0, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra422 + dsp->mask_blend[1][1] = MaskBlend_C<10, uint16_t, true, 1, 0>; +#endif +#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra420 + dsp->mask_blend[2][1] = MaskBlend_C<10, uint16_t, true, 1, 1>; +#endif + // These are only used with 8-bit. + dsp->inter_intra_mask_blend_8bpp[0] = nullptr; + dsp->inter_intra_mask_blend_8bpp[1] = nullptr; + dsp->inter_intra_mask_blend_8bpp[2] = nullptr; +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void MaskBlendInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/mask_blend.h b/src/dsp/mask_blend.h new file mode 100644 index 0000000..41f5e5b --- /dev/null +++ b/src/dsp/mask_blend.h @@ -0,0 +1,49 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_MASK_BLEND_H_ +#define LIBGAV1_SRC_DSP_MASK_BLEND_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/mask_blend_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +// SSE4_1 +#include "src/dsp/x86/mask_blend_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mask_blend and Dsp::inter_intra_mask_blend_8bpp. This +// function is not thread-safe. +void MaskBlendInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_MASK_BLEND_H_ diff --git a/src/dsp/motion_field_projection.cc b/src/dsp/motion_field_projection.cc new file mode 100644 index 0000000..b51ec8f --- /dev/null +++ b/src/dsp/motion_field_projection.cc @@ -0,0 +1,138 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_field_projection.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/reference_info.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +// Silence unused function warnings when MotionFieldProjectionKernel_C is +// not used. +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel) || \ + (LIBGAV1_MAX_BITDEPTH >= 10 && \ + !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel)) + +// 7.9.2. +void MotionFieldProjectionKernel_C(const ReferenceInfo& reference_info, + int reference_to_current_with_sign, + int dst_sign, int y8_start, int y8_end, + int x8_start, int x8_end, + TemporalMotionField* motion_field) { + const ptrdiff_t stride = motion_field->mv.columns(); + // The column range has to be offset by kProjectionMvMaxHorizontalOffset since + // coordinates in that range could end up being position_x8 because of + // projection. + const int adjusted_x8_start = + std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0); + const int adjusted_x8_end = std::min( + x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride)); + const int8_t* const reference_offsets = + reference_info.relative_distance_to.data(); + const bool* const skip_references = reference_info.skip_references.data(); + const int16_t* const projection_divisions = + reference_info.projection_divisions.data(); + const ReferenceFrameType* source_reference_types = + &reference_info.motion_field_reference_frame[y8_start][0]; + const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0]; + int8_t* dst_reference_offset = motion_field->reference_offset[y8_start]; + MotionVector* dst_mv = motion_field->mv[y8_start]; + assert(stride == motion_field->reference_offset.columns()); + assert((y8_start & 7) == 0); + + int y8 = y8_start; + do { + const int y8_floor = (y8 & ~7) - y8; + const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8); + int x8 = adjusted_x8_start; + do { + const int source_reference_type = source_reference_types[x8]; + if (skip_references[source_reference_type]) continue; + MotionVector projection_mv; + // reference_to_current_with_sign could be 0. + GetMvProjection(mv[x8], reference_to_current_with_sign, + projection_divisions[source_reference_type], + &projection_mv); + // Do not update the motion vector if the block position is not valid or + // if position_x8 is outside the current range of x8_start and x8_end. + // Note that position_y8 will always be within the range of y8_start and + // y8_end. + const int position_y8 = Project(0, projection_mv.mv[0], dst_sign); + if (position_y8 < y8_floor || position_y8 >= y8_ceiling) continue; + const int x8_base = x8 & ~7; + const int x8_floor = + std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset); + const int x8_ceiling = + std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset); + const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign); + if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue; + dst_mv[position_y8 * stride + position_x8] = mv[x8]; + dst_reference_offset[position_y8 * stride + position_x8] = + reference_offsets[source_reference_type]; + } while (++x8 < adjusted_x8_end); + source_reference_types += stride; + mv += stride; + dst_reference_offset += stride; + dst_mv += stride; + } while (++y8 < y8_end); +} + +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || + // !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel) || + // (LIBGAV1_MAX_BITDEPTH >= 10 && + // !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel)) + +void Init8bpp() { +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel) + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_C; +#endif +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel) + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_C; +#endif +} +#endif + +} // namespace + +void MotionFieldProjectionInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/motion_field_projection.h b/src/dsp/motion_field_projection.h new file mode 100644 index 0000000..36de459 --- /dev/null +++ b/src/dsp/motion_field_projection.h @@ -0,0 +1,48 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_MOTION_FIELD_PROJECTION_H_ +#define LIBGAV1_SRC_DSP_MOTION_FIELD_PROJECTION_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/motion_field_projection_neon.h" +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +// SSE4_1 +#include "src/dsp/x86/motion_field_projection_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::motion_field_projection_kernel. This function is not +// thread-safe. +void MotionFieldProjectionInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_MOTION_FIELD_PROJECTION_H_ diff --git a/src/dsp/motion_vector_search.cc b/src/dsp/motion_vector_search.cc new file mode 100644 index 0000000..9402302 --- /dev/null +++ b/src/dsp/motion_vector_search.cc @@ -0,0 +1,211 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_vector_search.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +// Silence unused function warnings when the C functions are not used. +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch) || \ + (LIBGAV1_MAX_BITDEPTH >= 10 && \ + !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch)) + +void MvProjectionCompoundLowPrecision_C( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* const candidate_mvs) { + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + int index = 0; + do { + candidate_mvs[index].mv64 = 0; + for (int i = 0; i < 2; ++i) { + // |offsets| non-zero check usually equals true and could be ignored. + if (offsets[i] != 0) { + GetMvProjection( + temporal_mvs[index], offsets[i], + kProjectionMvDivisionLookup[temporal_reference_offsets[index]], + &candidate_mvs[index].mv[i]); + for (auto& mv : candidate_mvs[index].mv[i].mv) { + // The next line is equivalent to: + // if ((mv & 1) != 0) mv += (mv > 0) ? -1 : 1; + mv = (mv - (mv >> 15)) & ~1; + } + } + } + } while (++index < count); +} + +void MvProjectionCompoundForceInteger_C( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* const candidate_mvs) { + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + int index = 0; + do { + candidate_mvs[index].mv64 = 0; + for (int i = 0; i < 2; ++i) { + // |offsets| non-zero check usually equals true and could be ignored. + if (offsets[i] != 0) { + GetMvProjection( + temporal_mvs[index], offsets[i], + kProjectionMvDivisionLookup[temporal_reference_offsets[index]], + &candidate_mvs[index].mv[i]); + for (auto& mv : candidate_mvs[index].mv[i].mv) { + // The next line is equivalent to: + // const int value = (std::abs(static_cast<int>(mv)) + 3) & ~7; + // const int sign = mv >> 15; + // mv = ApplySign(value, sign); + mv = (mv + 3 - (mv >> 15)) & ~7; + } + } + } + } while (++index < count); +} + +void MvProjectionCompoundHighPrecision_C( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* const candidate_mvs) { + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + int index = 0; + do { + candidate_mvs[index].mv64 = 0; + for (int i = 0; i < 2; ++i) { + // |offsets| non-zero check usually equals true and could be ignored. + if (offsets[i] != 0) { + GetMvProjection( + temporal_mvs[index], offsets[i], + kProjectionMvDivisionLookup[temporal_reference_offsets[index]], + &candidate_mvs[index].mv[i]); + } + } + } while (++index < count); +} + +void MvProjectionSingleLowPrecision_C( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, const int reference_offset, + const int count, MotionVector* const candidate_mvs) { + int index = 0; + do { + GetMvProjection( + temporal_mvs[index], reference_offset, + kProjectionMvDivisionLookup[temporal_reference_offsets[index]], + &candidate_mvs[index]); + for (auto& mv : candidate_mvs[index].mv) { + // The next line is equivalent to: + // if ((mv & 1) != 0) mv += (mv > 0) ? -1 : 1; + mv = (mv - (mv >> 15)) & ~1; + } + } while (++index < count); +} + +void MvProjectionSingleForceInteger_C( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, const int reference_offset, + const int count, MotionVector* const candidate_mvs) { + int index = 0; + do { + GetMvProjection( + temporal_mvs[index], reference_offset, + kProjectionMvDivisionLookup[temporal_reference_offsets[index]], + &candidate_mvs[index]); + for (auto& mv : candidate_mvs[index].mv) { + // The next line is equivalent to: + // const int value = (std::abs(static_cast<int>(mv)) + 3) & ~7; + // const int sign = mv >> 15; + // mv = ApplySign(value, sign); + mv = (mv + 3 - (mv >> 15)) & ~7; + } + } while (++index < count); +} + +void MvProjectionSingleHighPrecision_C( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, const int reference_offset, + const int count, MotionVector* const candidate_mvs) { + int index = 0; + do { + GetMvProjection( + temporal_mvs[index], reference_offset, + kProjectionMvDivisionLookup[temporal_reference_offsets[index]], + &candidate_mvs[index]); + } while (++index < count); +} + +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || + // !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch) || + // (LIBGAV1_MAX_BITDEPTH >= 10 && + // !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch)) + +void Init8bpp() { +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch) + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_C; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_C; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_C; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_C; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_C; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_C; +#endif +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \ + !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch) + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_C; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_C; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_C; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_C; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_C; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_C; +#endif +} +#endif + +} // namespace + +void MotionVectorSearchInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/motion_vector_search.h b/src/dsp/motion_vector_search.h new file mode 100644 index 0000000..ae16726 --- /dev/null +++ b/src/dsp/motion_vector_search.h @@ -0,0 +1,49 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_MOTION_VECTOR_SEARCH_H_ +#define LIBGAV1_SRC_DSP_MOTION_VECTOR_SEARCH_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/motion_vector_search_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +// SSE4_1 +#include "src/dsp/x86/motion_vector_search_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This +// function is not thread-safe. +void MotionVectorSearchInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_MOTION_VECTOR_SEARCH_H_ diff --git a/src/dsp/obmc.cc b/src/dsp/obmc.cc new file mode 100644 index 0000000..46d1b5b --- /dev/null +++ b/src/dsp/obmc.cc @@ -0,0 +1,125 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/obmc.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +#include "src/dsp/obmc.inc" + +// 7.11.3.10 (from top samples). +template <typename Pixel> +void OverlapBlendVertical_C(void* const prediction, + const ptrdiff_t prediction_stride, const int width, + const int height, const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<Pixel*>(prediction); + const ptrdiff_t pred_stride = prediction_stride / sizeof(Pixel); + const auto* obmc_pred = static_cast<const Pixel*>(obmc_prediction); + const ptrdiff_t obmc_pred_stride = obmc_prediction_stride / sizeof(Pixel); + const uint8_t* const mask = kObmcMask + height - 2; + + for (int y = 0; y < height; ++y) { + const uint8_t mask_value = mask[y]; + for (int x = 0; x < width; ++x) { + pred[x] = static_cast<Pixel>(RightShiftWithRounding( + mask_value * pred[x] + (64 - mask_value) * obmc_pred[x], 6)); + } + pred += pred_stride; + obmc_pred += obmc_pred_stride; + } +} + +// 7.11.3.10 (from left samples). +template <typename Pixel> +void OverlapBlendHorizontal_C(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<Pixel*>(prediction); + const ptrdiff_t pred_stride = prediction_stride / sizeof(Pixel); + const auto* obmc_pred = static_cast<const Pixel*>(obmc_prediction); + const ptrdiff_t obmc_pred_stride = obmc_prediction_stride / sizeof(Pixel); + const uint8_t* const mask = kObmcMask + width - 2; + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + const uint8_t mask_value = mask[x]; + pred[x] = static_cast<Pixel>(RightShiftWithRounding( + mask_value * pred[x] + (64 - mask_value) * obmc_pred[x], 6)); + } + pred += pred_stride; + obmc_pred += obmc_pred_stride; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint8_t>; + dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendHorizontal_C<uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_ObmcVertical + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_ObmcHorizontal + dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendHorizontal_C<uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint16_t>; + dsp->obmc_blend[kObmcDirectionHorizontal] = + OverlapBlendHorizontal_C<uint16_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_ObmcVertical + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_ObmcHorizontal + dsp->obmc_blend[kObmcDirectionHorizontal] = + OverlapBlendHorizontal_C<uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void ObmcInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/obmc.h b/src/dsp/obmc.h new file mode 100644 index 0000000..3b826c7 --- /dev/null +++ b/src/dsp/obmc.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_OBMC_H_ +#define LIBGAV1_SRC_DSP_OBMC_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/obmc_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/obmc_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::obmc_blend. This function is not thread-safe. +void ObmcInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_OBMC_H_ diff --git a/src/dsp/obmc.inc b/src/dsp/obmc.inc new file mode 100644 index 0000000..001c6ee --- /dev/null +++ b/src/dsp/obmc.inc @@ -0,0 +1,32 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Constants and utility functions used for overlap blend implementations. +// This will be included inside an anonymous namespace on files where these are +// necessary. + +// This is a flat array of masks for each block dimension from 2 to 32. The +// starting index for each length is length-2. +constexpr uint8_t kObmcMask[62] = { + // Obmc Mask 2 + 45, 64, + // Obmc Mask 4 + 39, 50, 59, 64, + // Obmc Mask 8 + 36, 42, 48, 53, 57, 61, 64, 64, + // Obmc Mask 16 + 34, 37, 40, 43, 46, 49, 52, 54, 56, 58, 60, 61, 64, 64, 64, 64, + // Obmc Mask 32 + 33, 35, 36, 38, 40, 41, 43, 44, 45, 47, 48, 50, 51, 52, 53, 55, 56, 57, 58, + 59, 60, 60, 61, 62, 64, 64, 64, 64, 64, 64, 64, 64}; diff --git a/src/dsp/super_res.cc b/src/dsp/super_res.cc new file mode 100644 index 0000000..d041bd1 --- /dev/null +++ b/src/dsp/super_res.cc @@ -0,0 +1,109 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/super_res.h" + +#include <cassert> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +template <int bitdepth, typename Pixel> +void SuperRes_C(const void* /*coefficients*/, void* const source, + const ptrdiff_t stride, const int height, + const int downscaled_width, const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const dest) { + assert(step <= 1 << kSuperResScaleBits); + auto* src = static_cast<Pixel*>(source) - DivideBy2(kSuperResFilterTaps); + auto* dst = static_cast<Pixel*>(dest); + int y = height; + do { + ExtendLine<Pixel>(src + DivideBy2(kSuperResFilterTaps), downscaled_width, + kSuperResHorizontalBorder, kSuperResHorizontalBorder); + // If (original) upscaled_width is <= 9, the downscaled_width may be + // upscaled_width - 1 (i.e. 8, 9), and become the same (i.e. 4) when + // subsampled via RightShiftWithRounding. This leads to an edge case where + // |step| == 1 << 14. + int subpixel_x = initial_subpixel_x; + int x = 0; + do { + int sum = 0; + const Pixel* const src_x = &src[subpixel_x >> kSuperResScaleBits]; + const int src_x_subpixel = + (subpixel_x & kSuperResScaleMask) >> kSuperResExtraBits; + // The sign of each tap is: - + - + + - + - + sum -= src_x[0] * kUpscaleFilterUnsigned[src_x_subpixel][0]; + sum += src_x[1] * kUpscaleFilterUnsigned[src_x_subpixel][1]; + sum -= src_x[2] * kUpscaleFilterUnsigned[src_x_subpixel][2]; + sum += src_x[3] * kUpscaleFilterUnsigned[src_x_subpixel][3]; + sum += src_x[4] * kUpscaleFilterUnsigned[src_x_subpixel][4]; + sum -= src_x[5] * kUpscaleFilterUnsigned[src_x_subpixel][5]; + sum += src_x[6] * kUpscaleFilterUnsigned[src_x_subpixel][6]; + sum -= src_x[7] * kUpscaleFilterUnsigned[src_x_subpixel][7]; + dst[x] = Clip3(RightShiftWithRounding(sum, kFilterBits), 0, + (1 << bitdepth) - 1); + subpixel_x += step; + } while (++x < upscaled_width); + src += stride; + dst += stride; + } while (--y != 0); +} + +void Init8bpp() { + Dsp* dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); + dsp->super_res_coefficients = nullptr; +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->super_res = SuperRes_C<8, uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_SuperRes + dsp->super_res = SuperRes_C<8, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); + dsp->super_res_coefficients = nullptr; +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->super_res = SuperRes_C<10, uint16_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_SuperRes + dsp->super_res = SuperRes_C<10, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void SuperResInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/super_res.h b/src/dsp/super_res.h new file mode 100644 index 0000000..2ca9d2b --- /dev/null +++ b/src/dsp/super_res.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_SUPER_RES_H_ +#define LIBGAV1_SRC_DSP_SUPER_RES_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/super_res_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/super_res_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::super_res. This function is not thread-safe. +void SuperResInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_SUPER_RES_H_ diff --git a/src/dsp/warp.cc b/src/dsp/warp.cc new file mode 100644 index 0000000..fbde65a --- /dev/null +++ b/src/dsp/warp.cc @@ -0,0 +1,475 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/warp.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <type_traits> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/memory.h" + +namespace libgav1 { +namespace dsp { +namespace { + +// Number of extra bits of precision in warped filtering. +constexpr int kWarpedDiffPrecisionBits = 10; + +// Warp prediction output ranges from WarpTest.ShowRange. +// Bitdepth: 8 Input range: [ 0, 255] +// 8bpp intermediate offset: 16384. +// intermediate range: [ 4399, 61009] +// first pass output range: [ 550, 7626] +// 8bpp intermediate offset removal: 262144. +// intermediate range: [ -620566, 1072406] +// second pass output range: [ 0, 255] +// compound second pass output range: [ -4848, 8378] +// +// Bitdepth: 10 Input range: [ 0, 1023] +// intermediate range: [ -48081, 179025] +// first pass output range: [ -6010, 22378] +// intermediate range: [-2103516, 4198620] +// second pass output range: [ 0, 1023] +// compound second pass output range: [ 8142, 57378] +// +// Bitdepth: 12 Input range: [ 0, 4095] +// intermediate range: [ -192465, 716625] +// first pass output range: [ -6015, 22395] +// intermediate range: [-2105190, 4201830] +// second pass output range: [ 0, 4095] +// compound second pass output range: [ 8129, 57403] + +template <bool is_compound, int bitdepth, typename Pixel> +void Warp_C(const void* const source, ptrdiff_t source_stride, + const int source_width, const int source_height, + const int* const warp_params, const int subsampling_x, + const int subsampling_y, const int block_start_x, + const int block_start_y, const int block_width, + const int block_height, const int16_t alpha, const int16_t beta, + const int16_t gamma, const int16_t delta, void* dest, + ptrdiff_t dest_stride) { + assert(block_width >= 8 && block_height >= 8); + if (is_compound) { + assert(dest_stride == block_width); + } + constexpr int kRoundBitsHorizontal = (bitdepth == 12) + ? kInterRoundBitsHorizontal12bpp + : kInterRoundBitsHorizontal; + constexpr int kRoundBitsVertical = + is_compound ? kInterRoundBitsCompoundVertical + : (bitdepth == 12) ? kInterRoundBitsVertical12bpp + : kInterRoundBitsVertical; + + // Only used for 8bpp. Allows for keeping the first pass intermediates within + // uint16_t. With 10/12bpp the intermediate value will always require int32_t. + constexpr int first_pass_offset = (bitdepth == 8) ? 1 << 14 : 0; + constexpr int offset_removal = + (first_pass_offset >> kRoundBitsHorizontal) * 128; + + constexpr int kMaxPixel = (1 << bitdepth) - 1; + union { + // |intermediate_result| is the output of the horizontal filtering and + // rounding. The range is within int16_t. + int16_t intermediate_result[15][8]; // 15 rows, 8 columns. + // In the simple special cases where the samples in each row are all the + // same, store one sample per row in a column vector. + int16_t intermediate_result_column[15]; + }; + const auto* const src = static_cast<const Pixel*>(source); + source_stride /= sizeof(Pixel); + using DestType = + typename std::conditional<is_compound, uint16_t, Pixel>::type; + auto* dst = static_cast<DestType*>(dest); + if (!is_compound) dest_stride /= sizeof(dst[0]); + + assert(block_width >= 8); + assert(block_height >= 8); + + // Warp process applies for each 8x8 block (or smaller). + for (int start_y = block_start_y; start_y < block_start_y + block_height; + start_y += 8) { + for (int start_x = block_start_x; start_x < block_start_x + block_width; + start_x += 8) { + const int src_x = (start_x + 4) << subsampling_x; + const int src_y = (start_y + 4) << subsampling_y; + const int dst_x = + src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0]; + const int dst_y = + src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1]; + const int x4 = dst_x >> subsampling_x; + const int y4 = dst_y >> subsampling_y; + const int ix4 = x4 >> kWarpedModelPrecisionBits; + const int iy4 = y4 >> kWarpedModelPrecisionBits; + + // A prediction block may fall outside the frame's boundaries. If a + // prediction block is calculated using only samples outside the frame's + // boundary, the filtering can be simplified. We can divide the plane + // into several regions and handle them differently. + // + // | | + // 1 | 3 | 1 + // | | + // -------+-----------+------- + // |***********| + // 2 |*****4*****| 2 + // |***********| + // -------+-----------+------- + // | | + // 1 | 3 | 1 + // | | + // + // At the center, region 4 represents the frame and is the general case. + // + // In regions 1 and 2, the prediction block is outside the frame's + // boundary horizontally. Therefore the horizontal filtering can be + // simplified. Furthermore, in the region 1 (at the four corners), the + // prediction is outside the frame's boundary both horizontally and + // vertically, so we get a constant prediction block. + // + // In region 3, the prediction block is outside the frame's boundary + // vertically. Unfortunately because we apply the horizontal filters + // first, by the time we apply the vertical filters, they no longer see + // simple inputs. So the only simplification is that all the rows are + // the same, but we still need to apply all the horizontal and vertical + // filters. + + // Check for two simple special cases, where the horizontal filter can + // be significantly simplified. + // + // In general, for each row, the horizontal filter is calculated as + // follows: + // for (int x = -4; x < 4; ++x) { + // const int offset = ...; + // int sum = first_pass_offset; + // for (int k = 0; k < 8; ++k) { + // const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1); + // sum += kWarpedFilters[offset][k] * src_row[column]; + // } + // ... + // } + // The column index before clipping, ix4 + x + k - 3, varies in the range + // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1 + // or ix4 + 7 <= 0, then all the column indexes are clipped to the same + // border index (source_width - 1 or 0, respectively). Then for each x, + // the inner for loop of the horizontal filter is reduced to multiplying + // the border pixel by the sum of the filter coefficients. + if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) { + // Regions 1 and 2. + // Points to the left or right border of the first row of |src|. + const Pixel* first_row_border = + (ix4 + 7 <= 0) ? src : src + source_width - 1; + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) { + // Region 1. + // Every sample used to calculate the prediction block has the same + // value. So the whole prediction block has the same value. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const Pixel row_border_pixel = first_row_border[row * source_stride]; + DestType* dst_row = dst + start_x - block_start_x; + if (is_compound) { + int sum = row_border_pixel + << ((14 - kRoundBitsHorizontal) - kRoundBitsVertical); + sum += (bitdepth == 8) ? 0 : kCompoundOffset; + Memset(dst_row, sum, 8); + } else { + Memset(dst_row, row_border_pixel, 8); + } + const DestType* const first_dst_row = dst_row; + dst_row += dest_stride; + for (int y = 1; y < 8; ++y) { + memcpy(dst_row, first_dst_row, 8 * sizeof(*dst_row)); + dst_row += dest_stride; + } + // End of region 1. Continue the |start_x| for loop. + continue; + } + + // Region 2. + // Horizontal filter. + // The input values in this region are generated by extending the border + // which makes them identical in the horizontal direction. This + // computation could be inlined in the vertical pass but most + // implementations will need a transpose of some sort. + // It is not necessary to use the offset values here because the + // horizontal pass is a simple shift and the vertical pass will always + // require using 32 bits. + for (int y = -7; y < 8; ++y) { + // We may over-read up to 13 pixels above the top source row, or up + // to 13 pixels below the bottom source row. This is proved below. + const int row = iy4 + y; + int sum = first_row_border[row * source_stride]; + sum <<= kFilterBits - kRoundBitsHorizontal; + intermediate_result_column[y + 7] = sum; + } + // Vertical filter. + DestType* dst_row = dst + start_x - block_start_x; + int sy4 = + (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + assert(offset >= 0); + assert(offset < 3 * kWarpedPixelPrecisionShifts + 1); + int sum = 0; + for (int k = 0; k < 8; ++k) { + sum += + kWarpedFilters[offset][k] * intermediate_result_column[y + k]; + } + sum = RightShiftWithRounding(sum, kRoundBitsVertical); + if (is_compound) { + sum += (bitdepth == 8) ? 0 : kCompoundOffset; + dst_row[x] = static_cast<DestType>(sum); + } else { + dst_row[x] = static_cast<DestType>(Clip3(sum, 0, kMaxPixel)); + } + sy += gamma; + } + dst_row += dest_stride; + sy4 += delta; + } + // End of region 2. Continue the |start_x| for loop. + continue; + } + + // Regions 3 and 4. + // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0. + // It follows that -6 <= ix4 <= source_width + 5. This inequality is + // used below. + + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) { + // Region 3. + // Horizontal filter. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const Pixel* const src_row = src + row * source_stride; + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + int sx = sx4 - MultiplyBy4(alpha); + for (int x = -4; x < 4; ++x) { + const int offset = + RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + // Since alpha and beta have been validated by SetupShear(), one + // can prove that 0 <= offset <= 3 * 2^6. + assert(offset >= 0); + assert(offset < 3 * kWarpedPixelPrecisionShifts + 1); + // For SIMD optimization: + // |first_pass_offset| guarantees the sum fits in uint16_t for 8bpp. + // For 10/12 bit, the range of sum requires 32 bits. + int sum = first_pass_offset; + for (int k = 0; k < 8; ++k) { + // We assume the source frame has left and right borders of at + // least 13 pixels that extend the frame boundary pixels. + // + // Since -4 <= x <= 3 and 0 <= k <= 7, using the inequality on + // ix4 above, we have + // -13 <= ix4 + x + k - 3 <= source_width + 12, + // or + // -13 <= column <= (source_width - 1) + 13. + // Therefore we may over-read up to 13 pixels before the source + // row, or up to 13 pixels after the source row. + const int column = ix4 + x + k - 3; + sum += kWarpedFilters[offset][k] * src_row[column]; + } + intermediate_result[y + 7][x + 4] = + RightShiftWithRounding(sum, kRoundBitsHorizontal); + sx += alpha; + } + sx4 += beta; + } + } else { + // Region 4. + // Horizontal filter. + // At this point, we know iy4 - 7 < source_height - 1 and iy4 + 7 > 0. + // It follows that -6 <= iy4 <= source_height + 5. This inequality is + // used below. + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + // We assume the source frame has top and bottom borders of at least + // 13 pixels that extend the frame boundary pixels. + // + // Since -7 <= y <= 7, using the inequality on iy4 above, we have + // -13 <= iy4 + y <= source_height + 12, + // or + // -13 <= row <= (source_height - 1) + 13. + // Therefore we may over-read up to 13 pixels above the top source + // row, or up to 13 pixels below the bottom source row. + const int row = iy4 + y; + const Pixel* const src_row = src + row * source_stride; + int sx = sx4 - MultiplyBy4(alpha); + for (int x = -4; x < 4; ++x) { + const int offset = + RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + // Since alpha and beta have been validated by SetupShear(), one + // can prove that 0 <= offset <= 3 * 2^6. + assert(offset >= 0); + assert(offset < 3 * kWarpedPixelPrecisionShifts + 1); + // For SIMD optimization: + // |first_pass_offset| guarantees the sum fits in uint16_t for 8bpp. + // For 10/12 bit, the range of sum requires 32 bits. + int sum = first_pass_offset; + for (int k = 0; k < 8; ++k) { + // We assume the source frame has left and right borders of at + // least 13 pixels that extend the frame boundary pixels. + // + // Since -4 <= x <= 3 and 0 <= k <= 7, using the inequality on + // ix4 above, we have + // -13 <= ix4 + x + k - 3 <= source_width + 12, + // or + // -13 <= column <= (source_width - 1) + 13. + // Therefore we may over-read up to 13 pixels before the source + // row, or up to 13 pixels after the source row. + const int column = ix4 + x + k - 3; + sum += kWarpedFilters[offset][k] * src_row[column]; + } + intermediate_result[y + 7][x + 4] = + RightShiftWithRounding(sum, kRoundBitsHorizontal) - + offset_removal; + sx += alpha; + } + sx4 += beta; + } + } + + // Regions 3 and 4. + // Vertical filter. + DestType* dst_row = dst + start_x - block_start_x; + int sy4 = + (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + // The spec says we should use the following loop condition: + // y < std::min(4, block_start_y + block_height - start_y - 4); + // We can prove that block_start_y + block_height - start_y >= 8, which + // implies std::min(4, block_start_y + block_height - start_y - 4) = 4. + // So the loop condition is simply y < 4. + // + // Proof: + // start_y < block_start_y + block_height + // => block_start_y + block_height - start_y > 0 + // => block_height - (start_y - block_start_y) > 0 + // + // Since block_height >= 8 and is a power of 2, it follows that + // block_height is a multiple of 8. start_y - block_start_y is also a + // multiple of 8. Therefore their difference is a multiple of 8. Since + // their difference is > 0, their difference must be >= 8. + // + // We then add an offset of 4 to y so that the loop starts with y = 0 + // and continues if y < 8. + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); + // The spec says we should use the following loop condition: + // x < std::min(4, block_start_x + block_width - start_x - 4); + // Similar to the above, we can prove that the loop condition can be + // simplified to x < 4. + // + // We then add an offset of 4 to x so that the loop starts with x = 0 + // and continues if x < 8. + for (int x = 0; x < 8; ++x) { + const int offset = + RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + // Since gamma and delta have been validated by SetupShear(), one can + // prove that 0 <= offset <= 3 * 2^6. + assert(offset >= 0); + assert(offset < 3 * kWarpedPixelPrecisionShifts + 1); + int sum = 0; + for (int k = 0; k < 8; ++k) { + sum += kWarpedFilters[offset][k] * intermediate_result[y + k][x]; + } + sum -= offset_removal; + sum = RightShiftWithRounding(sum, kRoundBitsVertical); + if (is_compound) { + sum += (bitdepth == 8) ? 0 : kCompoundOffset; + dst_row[x] = static_cast<DestType>(sum); + } else { + dst_row[x] = static_cast<DestType>(Clip3(sum, 0, kMaxPixel)); + } + sy += gamma; + } + dst_row += dest_stride; + sy4 += delta; + } + } + dst += 8 * dest_stride; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->warp = Warp_C</*is_compound=*/false, 8, uint8_t>; + dsp->warp_compound = Warp_C</*is_compound=*/true, 8, uint8_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_Warp + dsp->warp = Warp_C</*is_compound=*/false, 8, uint8_t>; +#endif +#ifndef LIBGAV1_Dsp8bpp_WarpCompound + dsp->warp_compound = Warp_C</*is_compound=*/true, 8, uint8_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + dsp->warp = Warp_C</*is_compound=*/false, 10, uint16_t>; + dsp->warp_compound = Warp_C</*is_compound=*/true, 10, uint16_t>; +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_Warp + dsp->warp = Warp_C</*is_compound=*/false, 10, uint16_t>; +#endif +#ifndef LIBGAV1_Dsp10bpp_WarpCompound + dsp->warp_compound = Warp_C</*is_compound=*/true, 10, uint16_t>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void WarpInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/warp.h b/src/dsp/warp.h new file mode 100644 index 0000000..7367a9b --- /dev/null +++ b/src/dsp/warp.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_WARP_H_ +#define LIBGAV1_SRC_DSP_WARP_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/warp_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/warp_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::warp. This function is not thread-safe. +void WarpInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_WARP_H_ diff --git a/src/dsp/weight_mask.cc b/src/dsp/weight_mask.cc new file mode 100644 index 0000000..15d6bc6 --- /dev/null +++ b/src/dsp/weight_mask.cc @@ -0,0 +1,227 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/weight_mask.h" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <string> +#include <type_traits> + +#include "src/dsp/dsp.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +template <int width, int height, int bitdepth, bool mask_is_inverse> +void WeightMask_C(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + using PredType = + typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type; + const auto* pred_0 = static_cast<const PredType*>(prediction_0); + const auto* pred_1 = static_cast<const PredType*>(prediction_1); + static_assert(width >= 8, ""); + static_assert(height >= 8, ""); + constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4); + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + const int difference = RightShiftWithRounding( + std::abs(pred_0[x] - pred_1[x]), rounding_bits); + const auto mask_value = + static_cast<uint8_t>(std::min(DivideBy16(difference) + 38, 64)); + mask[x] = mask_is_inverse ? 64 - mask_value : mask_value; + } + pred_0 += width; + pred_1 += width; + mask += mask_stride; + } +} + +#define INIT_WEIGHT_MASK(width, height, bitdepth, w_index, h_index) \ + dsp->weight_mask[w_index][h_index][0] = \ + WeightMask_C<width, height, bitdepth, 0>; \ + dsp->weight_mask[w_index][h_index][1] = \ + WeightMask_C<width, height, bitdepth, 1> + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + INIT_WEIGHT_MASK(8, 8, 8, 0, 0); + INIT_WEIGHT_MASK(8, 16, 8, 0, 1); + INIT_WEIGHT_MASK(8, 32, 8, 0, 2); + INIT_WEIGHT_MASK(16, 8, 8, 1, 0); + INIT_WEIGHT_MASK(16, 16, 8, 1, 1); + INIT_WEIGHT_MASK(16, 32, 8, 1, 2); + INIT_WEIGHT_MASK(16, 64, 8, 1, 3); + INIT_WEIGHT_MASK(32, 8, 8, 2, 0); + INIT_WEIGHT_MASK(32, 16, 8, 2, 1); + INIT_WEIGHT_MASK(32, 32, 8, 2, 2); + INIT_WEIGHT_MASK(32, 64, 8, 2, 3); + INIT_WEIGHT_MASK(64, 16, 8, 3, 1); + INIT_WEIGHT_MASK(64, 32, 8, 3, 2); + INIT_WEIGHT_MASK(64, 64, 8, 3, 3); + INIT_WEIGHT_MASK(64, 128, 8, 3, 4); + INIT_WEIGHT_MASK(128, 64, 8, 4, 3); + INIT_WEIGHT_MASK(128, 128, 8, 4, 4); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x8 + INIT_WEIGHT_MASK(8, 8, 8, 0, 0); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x16 + INIT_WEIGHT_MASK(8, 16, 8, 0, 1); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x32 + INIT_WEIGHT_MASK(8, 32, 8, 0, 2); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x8 + INIT_WEIGHT_MASK(16, 8, 8, 1, 0); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x16 + INIT_WEIGHT_MASK(16, 16, 8, 1, 1); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x32 + INIT_WEIGHT_MASK(16, 32, 8, 1, 2); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x64 + INIT_WEIGHT_MASK(16, 64, 8, 1, 3); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x8 + INIT_WEIGHT_MASK(32, 8, 8, 2, 0); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x16 + INIT_WEIGHT_MASK(32, 16, 8, 2, 1); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x32 + INIT_WEIGHT_MASK(32, 32, 8, 2, 2); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x64 + INIT_WEIGHT_MASK(32, 64, 8, 2, 3); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x16 + INIT_WEIGHT_MASK(64, 16, 8, 3, 1); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x32 + INIT_WEIGHT_MASK(64, 32, 8, 3, 2); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x64 + INIT_WEIGHT_MASK(64, 64, 8, 3, 3); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x128 + INIT_WEIGHT_MASK(64, 128, 8, 3, 4); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_128x64 + INIT_WEIGHT_MASK(128, 64, 8, 4, 3); +#endif +#ifndef LIBGAV1_Dsp8bpp_WeightMask_128x128 + INIT_WEIGHT_MASK(128, 128, 8, 4, 4); +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + INIT_WEIGHT_MASK(8, 8, 10, 0, 0); + INIT_WEIGHT_MASK(8, 16, 10, 0, 1); + INIT_WEIGHT_MASK(8, 32, 10, 0, 2); + INIT_WEIGHT_MASK(16, 8, 10, 1, 0); + INIT_WEIGHT_MASK(16, 16, 10, 1, 1); + INIT_WEIGHT_MASK(16, 32, 10, 1, 2); + INIT_WEIGHT_MASK(16, 64, 10, 1, 3); + INIT_WEIGHT_MASK(32, 8, 10, 2, 0); + INIT_WEIGHT_MASK(32, 16, 10, 2, 1); + INIT_WEIGHT_MASK(32, 32, 10, 2, 2); + INIT_WEIGHT_MASK(32, 64, 10, 2, 3); + INIT_WEIGHT_MASK(64, 16, 10, 3, 1); + INIT_WEIGHT_MASK(64, 32, 10, 3, 2); + INIT_WEIGHT_MASK(64, 64, 10, 3, 3); + INIT_WEIGHT_MASK(64, 128, 10, 3, 4); + INIT_WEIGHT_MASK(128, 64, 10, 4, 3); + INIT_WEIGHT_MASK(128, 128, 10, 4, 4); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + static_cast<void>(dsp); +#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x8 + INIT_WEIGHT_MASK(8, 8, 10, 0, 0); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x16 + INIT_WEIGHT_MASK(8, 16, 10, 0, 1); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x32 + INIT_WEIGHT_MASK(8, 32, 10, 0, 2); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x8 + INIT_WEIGHT_MASK(16, 8, 10, 1, 0); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x16 + INIT_WEIGHT_MASK(16, 16, 10, 1, 1); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x32 + INIT_WEIGHT_MASK(16, 32, 10, 1, 2); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x64 + INIT_WEIGHT_MASK(16, 64, 10, 1, 3); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x8 + INIT_WEIGHT_MASK(32, 8, 10, 2, 0); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x16 + INIT_WEIGHT_MASK(32, 16, 10, 2, 1); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x32 + INIT_WEIGHT_MASK(32, 32, 10, 2, 2); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x64 + INIT_WEIGHT_MASK(32, 64, 10, 2, 3); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x16 + INIT_WEIGHT_MASK(64, 16, 10, 3, 1); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x32 + INIT_WEIGHT_MASK(64, 32, 10, 3, 2); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x64 + INIT_WEIGHT_MASK(64, 64, 10, 3, 3); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x128 + INIT_WEIGHT_MASK(64, 128, 10, 3, 4); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_128x64 + INIT_WEIGHT_MASK(128, 64, 10, 4, 3); +#endif +#ifndef LIBGAV1_Dsp10bpp_WeightMask_128x128 + INIT_WEIGHT_MASK(128, 128, 10, 4, 4); +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif + +} // namespace + +void WeightMaskInit_C() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 diff --git a/src/dsp/weight_mask.h b/src/dsp/weight_mask.h new file mode 100644 index 0000000..43bef05 --- /dev/null +++ b/src/dsp/weight_mask.h @@ -0,0 +1,47 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_WEIGHT_MASK_H_ +#define LIBGAV1_SRC_DSP_WEIGHT_MASK_H_ + +// Pull in LIBGAV1_DspXXX defines representing the implementation status +// of each function. The resulting value of each can be used by each module to +// determine whether an implementation is needed at compile time. +// IWYU pragma: begin_exports + +// ARM: +#include "src/dsp/arm/weight_mask_neon.h" + +// x86: +// Note includes should be sorted in logical order avx2/avx/sse4, etc. +// The order of includes is important as each tests for a superior version +// before setting the base. +// clang-format off +#include "src/dsp/x86/weight_mask_sse4.h" +// clang-format on + +// IWYU pragma: end_exports + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::weight_mask. This function is not thread-safe. +void WeightMaskInit_C(); + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_SRC_DSP_WEIGHT_MASK_H_ diff --git a/src/dsp/x86/average_blend_sse4.cc b/src/dsp/x86/average_blend_sse4.cc new file mode 100644 index 0000000..8e008d1 --- /dev/null +++ b/src/dsp/x86/average_blend_sse4.cc @@ -0,0 +1,156 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/average_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <xmmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kInterPostRoundBit = 4; + +inline void AverageBlend4Row(const int16_t* prediction_0, + const int16_t* prediction_1, uint8_t* dest) { + const __m128i pred_0 = LoadLo8(prediction_0); + const __m128i pred_1 = LoadLo8(prediction_1); + __m128i res = _mm_add_epi16(pred_0, pred_1); + res = RightShiftWithRounding_S16(res, kInterPostRoundBit + 1); + Store4(dest, _mm_packus_epi16(res, res)); +} + +inline void AverageBlend8Row(const int16_t* prediction_0, + const int16_t* prediction_1, uint8_t* dest) { + const __m128i pred_0 = LoadAligned16(prediction_0); + const __m128i pred_1 = LoadAligned16(prediction_1); + __m128i res = _mm_add_epi16(pred_0, pred_1); + res = RightShiftWithRounding_S16(res, kInterPostRoundBit + 1); + StoreLo8(dest, _mm_packus_epi16(res, res)); +} + +inline void AverageBlendLargeRow(const int16_t* prediction_0, + const int16_t* prediction_1, const int width, + uint8_t* dest) { + int x = 0; + do { + const __m128i pred_00 = LoadAligned16(&prediction_0[x]); + const __m128i pred_01 = LoadAligned16(&prediction_1[x]); + __m128i res0 = _mm_add_epi16(pred_00, pred_01); + res0 = RightShiftWithRounding_S16(res0, kInterPostRoundBit + 1); + const __m128i pred_10 = LoadAligned16(&prediction_0[x + 8]); + const __m128i pred_11 = LoadAligned16(&prediction_1[x + 8]); + __m128i res1 = _mm_add_epi16(pred_10, pred_11); + res1 = RightShiftWithRounding_S16(res1, kInterPostRoundBit + 1); + StoreUnaligned16(dest + x, _mm_packus_epi16(res0, res1)); + x += 16; + } while (x < width); +} + +void AverageBlend_SSE4_1(const void* prediction_0, const void* prediction_1, + const int width, const int height, void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = height; + + if (width == 4) { + do { + // TODO(b/150326556): |prediction_[01]| values are packed. It is possible + // to load 8 values at a time. + AverageBlend4Row(pred_0, pred_1, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + AverageBlend4Row(pred_0, pred_1, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + y -= 2; + } while (y != 0); + return; + } + + if (width == 8) { + do { + AverageBlend8Row(pred_0, pred_1, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + AverageBlend8Row(pred_0, pred_1, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + y -= 2; + } while (y != 0); + return; + } + + do { + AverageBlendLargeRow(pred_0, pred_1, width, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + AverageBlendLargeRow(pred_0, pred_1, width, dst); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + + y -= 2; + } while (y != 0); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_SSE4_1(AverageBlend) + dsp->average_blend = AverageBlend_SSE4_1; +#endif +} + +} // namespace + +void AverageBlendInit_SSE4_1() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void AverageBlendInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/average_blend_sse4.h b/src/dsp/x86/average_blend_sse4.h new file mode 100644 index 0000000..937e8e2 --- /dev/null +++ b/src/dsp/x86/average_blend_sse4.h @@ -0,0 +1,41 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::average_blend. This function is not thread-safe. +void AverageBlendInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_AverageBlend +#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_ diff --git a/src/dsp/x86/cdef_sse4.cc b/src/dsp/x86/cdef_sse4.cc new file mode 100644 index 0000000..3211a2d --- /dev/null +++ b/src/dsp/x86/cdef_sse4.cc @@ -0,0 +1,728 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/cdef.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <emmintrin.h> +#include <tmmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/dsp/x86/transpose_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +#include "src/dsp/cdef.inc" + +// Used when calculating odd |cost[x]| values. +// Holds elements 1 3 5 7 7 7 7 7 +alignas(16) constexpr uint32_t kCdefDivisionTableOddPadded[] = { + 420, 210, 140, 105, 105, 105, 105, 105}; + +// ---------------------------------------------------------------------------- +// Refer to CdefDirection_C(). +// +// int32_t partial[8][15] = {}; +// for (int i = 0; i < 8; ++i) { +// for (int j = 0; j < 8; ++j) { +// const int x = 1; +// partial[0][i + j] += x; +// partial[1][i + j / 2] += x; +// partial[2][i] += x; +// partial[3][3 + i - j / 2] += x; +// partial[4][7 + i - j] += x; +// partial[5][3 - i / 2 + j] += x; +// partial[6][j] += x; +// partial[7][i / 2 + j] += x; +// } +// } +// +// Using the code above, generate the position count for partial[8][15]. +// +// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1 +// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 +// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0 +// +// The SIMD code shifts the input horizontally, then adds vertically to get the +// correct partial value for the given position. +// ---------------------------------------------------------------------------- + +// ---------------------------------------------------------------------------- +// partial[0][i + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 00 10 11 12 13 14 15 16 17 00 00 00 00 00 00 +// 00 00 20 21 22 23 24 25 26 27 00 00 00 00 00 +// 00 00 00 30 31 32 33 34 35 36 37 00 00 00 00 +// 00 00 00 00 40 41 42 43 44 45 46 47 00 00 00 +// 00 00 00 00 00 50 51 52 53 54 55 56 57 00 00 +// 00 00 00 00 00 00 60 61 62 63 64 65 66 67 00 +// 00 00 00 00 00 00 00 70 71 72 73 74 75 76 77 +// +// partial[4] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(__m128i* v_src_16, + __m128i* partial_lo, + __m128i* partial_hi) { + // 00 01 02 03 04 05 06 07 + *partial_lo = v_src_16[0]; + // 00 00 00 00 00 00 00 00 + *partial_hi = _mm_setzero_si128(); + + // 00 10 11 12 13 14 15 16 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[1], 2)); + // 17 00 00 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[1], 14)); + + // 00 00 20 21 22 23 24 25 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[2], 4)); + // 26 27 00 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[2], 12)); + + // 00 00 00 30 31 32 33 34 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[3], 6)); + // 35 36 37 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[3], 10)); + + // 00 00 00 00 40 41 42 43 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[4], 8)); + // 44 45 46 47 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[4], 8)); + + // 00 00 00 00 00 50 51 52 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[5], 10)); + // 53 54 55 56 57 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[5], 6)); + + // 00 00 00 00 00 00 60 61 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[6], 12)); + // 62 63 64 65 66 67 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[6], 4)); + + // 00 00 00 00 00 00 00 70 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[7], 14)); + // 71 72 73 74 75 76 77 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[7], 2)); +} + +// ---------------------------------------------------------------------------- +// partial[1][i + j / 2] += x; +// +// A0 = src[0] + src[1], A1 = src[2] + src[3], ... +// +// A0 A1 A2 A3 00 00 00 00 00 00 00 00 00 00 00 +// 00 B0 B1 B2 B3 00 00 00 00 00 00 00 00 00 00 +// 00 00 C0 C1 C2 C3 00 00 00 00 00 00 00 00 00 +// 00 00 00 D0 D1 D2 D3 00 00 00 00 00 00 00 00 +// 00 00 00 00 E0 E1 E2 E3 00 00 00 00 00 00 00 +// 00 00 00 00 00 F0 F1 F2 F3 00 00 00 00 00 00 +// 00 00 00 00 00 00 G0 G1 G2 G3 00 00 00 00 00 +// 00 00 00 00 00 00 00 H0 H1 H2 H3 00 00 00 00 +// +// partial[3] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(__m128i* v_src_16, + __m128i* partial_lo, + __m128i* partial_hi) { + __m128i v_d1_temp[8]; + const __m128i v_zero = _mm_setzero_si128(); + + for (int i = 0; i < 8; ++i) { + v_d1_temp[i] = _mm_hadd_epi16(v_src_16[i], v_zero); + } + + *partial_lo = *partial_hi = v_zero; + // A0 A1 A2 A3 00 00 00 00 + *partial_lo = _mm_add_epi16(*partial_lo, v_d1_temp[0]); + + // 00 B0 B1 B2 B3 00 00 00 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[1], 2)); + + // 00 00 C0 C1 C2 C3 00 00 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[2], 4)); + // 00 00 00 D0 D1 D2 D3 00 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[3], 6)); + // 00 00 00 00 E0 E1 E2 E3 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[4], 8)); + + // 00 00 00 00 00 F0 F1 F2 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[5], 10)); + // F3 00 00 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_d1_temp[5], 6)); + + // 00 00 00 00 00 00 G0 G1 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[6], 12)); + // G2 G3 00 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_d1_temp[6], 4)); + + // 00 00 00 00 00 00 00 H0 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[7], 14)); + // H1 H2 H3 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_d1_temp[7], 2)); +} + +// ---------------------------------------------------------------------------- +// partial[7][i / 2 + j] += x; +// +// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 +// 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 +// 00 20 21 22 23 24 25 26 27 00 00 00 00 00 00 +// 00 30 31 32 33 34 35 36 37 00 00 00 00 00 00 +// 00 00 40 41 42 43 44 45 46 47 00 00 00 00 00 +// 00 00 50 51 52 53 54 55 56 57 00 00 00 00 00 +// 00 00 00 60 61 62 63 64 65 66 67 00 00 00 00 +// 00 00 00 70 71 72 73 74 75 76 77 00 00 00 00 +// +// partial[5] is the same except the source is reversed. +LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(__m128i* v_src, __m128i* partial_lo, + __m128i* partial_hi) { + __m128i v_pair_add[4]; + // Add vertical source pairs. + v_pair_add[0] = _mm_add_epi16(v_src[0], v_src[1]); + v_pair_add[1] = _mm_add_epi16(v_src[2], v_src[3]); + v_pair_add[2] = _mm_add_epi16(v_src[4], v_src[5]); + v_pair_add[3] = _mm_add_epi16(v_src[6], v_src[7]); + + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + *partial_lo = v_pair_add[0]; + // 00 00 00 00 00 00 00 00 + // 00 00 00 00 00 00 00 00 + *partial_hi = _mm_setzero_si128(); + + // 00 20 21 22 23 24 25 26 + // 00 30 31 32 33 34 35 36 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_pair_add[1], 2)); + // 27 00 00 00 00 00 00 00 + // 37 00 00 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_pair_add[1], 14)); + + // 00 00 40 41 42 43 44 45 + // 00 00 50 51 52 53 54 55 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_pair_add[2], 4)); + // 46 47 00 00 00 00 00 00 + // 56 57 00 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_pair_add[2], 12)); + + // 00 00 00 60 61 62 63 64 + // 00 00 00 70 71 72 73 74 + *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_pair_add[3], 6)); + // 65 66 67 00 00 00 00 00 + // 75 76 77 00 00 00 00 00 + *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_pair_add[3], 10)); +} + +LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* src, ptrdiff_t stride, + __m128i* partial_lo, + __m128i* partial_hi) { + // 8x8 input + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + // 20 21 22 23 24 25 26 27 + // 30 31 32 33 34 35 36 37 + // 40 41 42 43 44 45 46 47 + // 50 51 52 53 54 55 56 57 + // 60 61 62 63 64 65 66 67 + // 70 71 72 73 74 75 76 77 + __m128i v_src[8]; + for (auto& i : v_src) { + i = LoadLo8(src); + src += stride; + } + + const __m128i v_zero = _mm_setzero_si128(); + // partial for direction 2 + // -------------------------------------------------------------------------- + // partial[2][i] += x; + // 00 10 20 30 40 50 60 70 00 00 00 00 00 00 00 00 + // 01 11 21 33 41 51 61 71 00 00 00 00 00 00 00 00 + // 02 12 22 33 42 52 62 72 00 00 00 00 00 00 00 00 + // 03 13 23 33 43 53 63 73 00 00 00 00 00 00 00 00 + // 04 14 24 34 44 54 64 74 00 00 00 00 00 00 00 00 + // 05 15 25 35 45 55 65 75 00 00 00 00 00 00 00 00 + // 06 16 26 36 46 56 66 76 00 00 00 00 00 00 00 00 + // 07 17 27 37 47 57 67 77 00 00 00 00 00 00 00 00 + const __m128i v_src_4_0 = _mm_unpacklo_epi64(v_src[0], v_src[4]); + const __m128i v_src_5_1 = _mm_unpacklo_epi64(v_src[1], v_src[5]); + const __m128i v_src_6_2 = _mm_unpacklo_epi64(v_src[2], v_src[6]); + const __m128i v_src_7_3 = _mm_unpacklo_epi64(v_src[3], v_src[7]); + const __m128i v_hsum_4_0 = _mm_sad_epu8(v_src_4_0, v_zero); + const __m128i v_hsum_5_1 = _mm_sad_epu8(v_src_5_1, v_zero); + const __m128i v_hsum_6_2 = _mm_sad_epu8(v_src_6_2, v_zero); + const __m128i v_hsum_7_3 = _mm_sad_epu8(v_src_7_3, v_zero); + const __m128i v_hsum_1_0 = _mm_unpacklo_epi16(v_hsum_4_0, v_hsum_5_1); + const __m128i v_hsum_3_2 = _mm_unpacklo_epi16(v_hsum_6_2, v_hsum_7_3); + const __m128i v_hsum_5_4 = _mm_unpackhi_epi16(v_hsum_4_0, v_hsum_5_1); + const __m128i v_hsum_7_6 = _mm_unpackhi_epi16(v_hsum_6_2, v_hsum_7_3); + partial_lo[2] = + _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_hsum_1_0, v_hsum_3_2), + _mm_unpacklo_epi32(v_hsum_5_4, v_hsum_7_6)); + + __m128i v_src_16[8]; + for (int i = 0; i < 8; ++i) { + v_src_16[i] = _mm_cvtepu8_epi16(v_src[i]); + } + + // partial for direction 6 + // -------------------------------------------------------------------------- + // partial[6][j] += x; + // 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 00 + // 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 00 + // 20 21 22 23 24 25 26 27 00 00 00 00 00 00 00 00 + // 30 31 32 33 34 35 36 37 00 00 00 00 00 00 00 00 + // 40 41 42 43 44 45 46 47 00 00 00 00 00 00 00 00 + // 50 51 52 53 54 55 56 57 00 00 00 00 00 00 00 00 + // 60 61 62 63 64 65 66 67 00 00 00 00 00 00 00 00 + // 70 71 72 73 74 75 76 77 00 00 00 00 00 00 00 00 + partial_lo[6] = v_src_16[0]; + for (int i = 1; i < 8; ++i) { + partial_lo[6] = _mm_add_epi16(partial_lo[6], v_src_16[i]); + } + + // partial for direction 0 + AddPartial_D0_D4(v_src_16, &partial_lo[0], &partial_hi[0]); + + // partial for direction 1 + AddPartial_D1_D3(v_src_16, &partial_lo[1], &partial_hi[1]); + + // partial for direction 7 + AddPartial_D5_D7(v_src_16, &partial_lo[7], &partial_hi[7]); + + __m128i v_src_reverse[8]; + const __m128i reverser = + _mm_set_epi32(0x01000302, 0x05040706, 0x09080b0a, 0x0d0c0f0e); + for (int i = 0; i < 8; ++i) { + v_src_reverse[i] = _mm_shuffle_epi8(v_src_16[i], reverser); + } + + // partial for direction 4 + AddPartial_D0_D4(v_src_reverse, &partial_lo[4], &partial_hi[4]); + + // partial for direction 3 + AddPartial_D1_D3(v_src_reverse, &partial_lo[3], &partial_hi[3]); + + // partial for direction 5 + AddPartial_D5_D7(v_src_reverse, &partial_lo[5], &partial_hi[5]); +} + +inline uint32_t SumVector_S32(__m128i a) { + a = _mm_hadd_epi32(a, a); + a = _mm_add_epi32(a, _mm_srli_si128(a, 4)); + return _mm_cvtsi128_si32(a); +} + +// |cost[0]| and |cost[4]| square the input and sum with the corresponding +// element from the other end of the vector: +// |kCdefDivisionTable[]| element: +// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) * +// kCdefDivisionTable[i + 1]; +// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8]; +inline uint32_t Cost0Or4(const __m128i a, const __m128i b, + const __m128i division_table[2]) { + // Reverse and clear upper 2 bytes. + const __m128i reverser = + _mm_set_epi32(0x80800100, 0x03020504, 0x07060908, 0x0b0a0d0c); + // 14 13 12 11 10 09 08 ZZ + const __m128i b_reversed = _mm_shuffle_epi8(b, reverser); + // 00 14 01 13 02 12 03 11 + const __m128i ab_lo = _mm_unpacklo_epi16(a, b_reversed); + // 04 10 05 09 06 08 07 ZZ + const __m128i ab_hi = _mm_unpackhi_epi16(a, b_reversed); + + // Square(partial[0][i]) + Square(partial[0][14 - i]) + const __m128i square_lo = _mm_madd_epi16(ab_lo, ab_lo); + const __m128i square_hi = _mm_madd_epi16(ab_hi, ab_hi); + + const __m128i c = _mm_mullo_epi32(square_lo, division_table[0]); + const __m128i d = _mm_mullo_epi32(square_hi, division_table[1]); + return SumVector_S32(_mm_add_epi32(c, d)); +} + +inline uint32_t CostOdd(const __m128i a, const __m128i b, + const __m128i division_table[2]) { + // Reverse and clear upper 10 bytes. + const __m128i reverser = + _mm_set_epi32(0x80808080, 0x80808080, 0x80800100, 0x03020504); + // 10 09 08 ZZ ZZ ZZ ZZ ZZ + const __m128i b_reversed = _mm_shuffle_epi8(b, reverser); + // 00 10 01 09 02 08 03 ZZ + const __m128i ab_lo = _mm_unpacklo_epi16(a, b_reversed); + // 04 ZZ 05 ZZ 06 ZZ 07 ZZ + const __m128i ab_hi = _mm_unpackhi_epi16(a, b_reversed); + + // Square(partial[0][i]) + Square(partial[0][10 - i]) + const __m128i square_lo = _mm_madd_epi16(ab_lo, ab_lo); + const __m128i square_hi = _mm_madd_epi16(ab_hi, ab_hi); + + const __m128i c = _mm_mullo_epi32(square_lo, division_table[0]); + const __m128i d = _mm_mullo_epi32(square_hi, division_table[1]); + return SumVector_S32(_mm_add_epi32(c, d)); +} + +// Sum of squared elements. +inline uint32_t SquareSum_S16(const __m128i a) { + const __m128i square = _mm_madd_epi16(a, a); + return SumVector_S32(square); +} + +void CdefDirection_SSE4_1(const void* const source, ptrdiff_t stride, + uint8_t* const direction, int* const variance) { + assert(direction != nullptr); + assert(variance != nullptr); + const auto* src = static_cast<const uint8_t*>(source); + uint32_t cost[8]; + __m128i partial_lo[8], partial_hi[8]; + + AddPartial(src, stride, partial_lo, partial_hi); + + cost[2] = kCdefDivisionTable[7] * SquareSum_S16(partial_lo[2]); + cost[6] = kCdefDivisionTable[7] * SquareSum_S16(partial_lo[6]); + + const __m128i division_table[2] = {LoadUnaligned16(kCdefDivisionTable), + LoadUnaligned16(kCdefDivisionTable + 4)}; + + cost[0] = Cost0Or4(partial_lo[0], partial_hi[0], division_table); + cost[4] = Cost0Or4(partial_lo[4], partial_hi[4], division_table); + + const __m128i division_table_odd[2] = { + LoadAligned16(kCdefDivisionTableOddPadded), + LoadAligned16(kCdefDivisionTableOddPadded + 4)}; + + cost[1] = CostOdd(partial_lo[1], partial_hi[1], division_table_odd); + cost[3] = CostOdd(partial_lo[3], partial_hi[3], division_table_odd); + cost[5] = CostOdd(partial_lo[5], partial_hi[5], division_table_odd); + cost[7] = CostOdd(partial_lo[7], partial_hi[7], division_table_odd); + + uint32_t best_cost = 0; + *direction = 0; + for (int i = 0; i < 8; ++i) { + if (cost[i] > best_cost) { + best_cost = cost[i]; + *direction = i; + } + } + *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10; +} + +// ------------------------------------------------------------------------- +// CdefFilter + +// Load 4 vectors based on the given |direction|. +inline void LoadDirection(const uint16_t* const src, const ptrdiff_t stride, + __m128i* output, const int direction) { + // Each |direction| describes a different set of source values. Expand this + // set by negating each set. For |direction| == 0 this gives a diagonal line + // from top right to bottom left. The first value is y, the second x. Negative + // y values move up. + // a b c d + // {-1, 1}, {1, -1}, {-2, 2}, {2, -2} + // c + // a + // 0 + // b + // d + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = LoadUnaligned16(src - y_0 * stride - x_0); + output[1] = LoadUnaligned16(src + y_0 * stride + x_0); + output[2] = LoadUnaligned16(src - y_1 * stride - x_1); + output[3] = LoadUnaligned16(src + y_1 * stride + x_1); +} + +// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to +// do 2 rows at a time. +void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride, + __m128i* output, const int direction) { + const int y_0 = kCdefDirections[direction][0][0]; + const int x_0 = kCdefDirections[direction][0][1]; + const int y_1 = kCdefDirections[direction][1][0]; + const int x_1 = kCdefDirections[direction][1][1]; + output[0] = LoadHi8(LoadLo8(src - y_0 * stride - x_0), + src - y_0 * stride + stride - x_0); + output[1] = LoadHi8(LoadLo8(src + y_0 * stride + x_0), + src + y_0 * stride + stride + x_0); + output[2] = LoadHi8(LoadLo8(src - y_1 * stride - x_1), + src - y_1 * stride + stride - x_1); + output[3] = LoadHi8(LoadLo8(src + y_1 * stride + x_1), + src + y_1 * stride + stride + x_1); +} + +inline __m128i Constrain(const __m128i& pixel, const __m128i& reference, + const __m128i& damping, const __m128i& threshold) { + const __m128i diff = _mm_sub_epi16(pixel, reference); + const __m128i abs_diff = _mm_abs_epi16(diff); + // sign(diff) * Clip3(threshold - (std::abs(diff) >> damping), + // 0, std::abs(diff)) + const __m128i shifted_diff = _mm_srl_epi16(abs_diff, damping); + // For bitdepth == 8, the threshold range is [0, 15] and the damping range is + // [3, 6]. If pixel == kCdefLargeValue(0x4000), shifted_diff will always be + // larger than threshold. Subtract using saturation will return 0 when pixel + // == kCdefLargeValue. + static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue"); + const __m128i thresh_minus_shifted_diff = + _mm_subs_epu16(threshold, shifted_diff); + const __m128i clamp_abs_diff = + _mm_min_epi16(thresh_minus_shifted_diff, abs_diff); + // Restore the sign. + return _mm_sign_epi16(clamp_abs_diff, diff); +} + +inline __m128i ApplyConstrainAndTap(const __m128i& pixel, const __m128i& val, + const __m128i& tap, const __m128i& damping, + const __m128i& threshold) { + const __m128i constrained = Constrain(val, pixel, damping, threshold); + return _mm_mullo_epi16(constrained, tap); +} + +template <int width, bool enable_primary = true, bool enable_secondary = true> +void CdefFilter_SSE4_1(const uint16_t* src, const ptrdiff_t src_stride, + const int height, const int primary_strength, + const int secondary_strength, const int damping, + const int direction, void* dest, + const ptrdiff_t dst_stride) { + static_assert(width == 8 || width == 4, "Invalid CDEF width."); + static_assert(enable_primary || enable_secondary, ""); + constexpr bool clipping_required = enable_primary && enable_secondary; + auto* dst = static_cast<uint8_t*>(dest); + __m128i primary_damping_shift, secondary_damping_shift; + + // FloorLog2() requires input to be > 0. + // 8-bit damping range: Y: [3, 6], UV: [2, 5]. + if (enable_primary) { + // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary + // for UV filtering. + primary_damping_shift = + _mm_cvtsi32_si128(std::max(0, damping - FloorLog2(primary_strength))); + } + if (enable_secondary) { + // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is + // necessary. + assert(damping - FloorLog2(secondary_strength) >= 0); + secondary_damping_shift = + _mm_cvtsi32_si128(damping - FloorLog2(secondary_strength)); + } + + const __m128i primary_tap_0 = + _mm_set1_epi16(kCdefPrimaryTaps[primary_strength & 1][0]); + const __m128i primary_tap_1 = + _mm_set1_epi16(kCdefPrimaryTaps[primary_strength & 1][1]); + const __m128i secondary_tap_0 = _mm_set1_epi16(kCdefSecondaryTap0); + const __m128i secondary_tap_1 = _mm_set1_epi16(kCdefSecondaryTap1); + const __m128i cdef_large_value_mask = + _mm_set1_epi16(static_cast<int16_t>(~kCdefLargeValue)); + const __m128i primary_threshold = _mm_set1_epi16(primary_strength); + const __m128i secondary_threshold = _mm_set1_epi16(secondary_strength); + + int y = height; + do { + __m128i pixel; + if (width == 8) { + pixel = LoadUnaligned16(src); + } else { + pixel = LoadHi8(LoadLo8(src), src + src_stride); + } + + __m128i min = pixel; + __m128i max = pixel; + __m128i sum; + + if (enable_primary) { + // Primary |direction|. + __m128i primary_val[4]; + if (width == 8) { + LoadDirection(src, src_stride, primary_val, direction); + } else { + LoadDirection4(src, src_stride, primary_val, direction); + } + + if (clipping_required) { + min = _mm_min_epu16(min, primary_val[0]); + min = _mm_min_epu16(min, primary_val[1]); + min = _mm_min_epu16(min, primary_val[2]); + min = _mm_min_epu16(min, primary_val[3]); + + // The source is 16 bits, however, we only really care about the lower + // 8 bits. The upper 8 bits contain the "large" flag. After the final + // primary max has been calculated, zero out the upper 8 bits. Use this + // to find the "16 bit" max. + const __m128i max_p01 = _mm_max_epu8(primary_val[0], primary_val[1]); + const __m128i max_p23 = _mm_max_epu8(primary_val[2], primary_val[3]); + const __m128i max_p = _mm_max_epu8(max_p01, max_p23); + max = _mm_max_epu16(max, _mm_and_si128(max_p, cdef_large_value_mask)); + } + + sum = ApplyConstrainAndTap(pixel, primary_val[0], primary_tap_0, + primary_damping_shift, primary_threshold); + sum = _mm_add_epi16( + sum, ApplyConstrainAndTap(pixel, primary_val[1], primary_tap_0, + primary_damping_shift, primary_threshold)); + sum = _mm_add_epi16( + sum, ApplyConstrainAndTap(pixel, primary_val[2], primary_tap_1, + primary_damping_shift, primary_threshold)); + sum = _mm_add_epi16( + sum, ApplyConstrainAndTap(pixel, primary_val[3], primary_tap_1, + primary_damping_shift, primary_threshold)); + } else { + sum = _mm_setzero_si128(); + } + + if (enable_secondary) { + // Secondary |direction| values (+/- 2). Clamp |direction|. + __m128i secondary_val[8]; + if (width == 8) { + LoadDirection(src, src_stride, secondary_val, direction + 2); + LoadDirection(src, src_stride, secondary_val + 4, direction - 2); + } else { + LoadDirection4(src, src_stride, secondary_val, direction + 2); + LoadDirection4(src, src_stride, secondary_val + 4, direction - 2); + } + + if (clipping_required) { + min = _mm_min_epu16(min, secondary_val[0]); + min = _mm_min_epu16(min, secondary_val[1]); + min = _mm_min_epu16(min, secondary_val[2]); + min = _mm_min_epu16(min, secondary_val[3]); + min = _mm_min_epu16(min, secondary_val[4]); + min = _mm_min_epu16(min, secondary_val[5]); + min = _mm_min_epu16(min, secondary_val[6]); + min = _mm_min_epu16(min, secondary_val[7]); + + const __m128i max_s01 = + _mm_max_epu8(secondary_val[0], secondary_val[1]); + const __m128i max_s23 = + _mm_max_epu8(secondary_val[2], secondary_val[3]); + const __m128i max_s45 = + _mm_max_epu8(secondary_val[4], secondary_val[5]); + const __m128i max_s67 = + _mm_max_epu8(secondary_val[6], secondary_val[7]); + const __m128i max_s = _mm_max_epu8(_mm_max_epu8(max_s01, max_s23), + _mm_max_epu8(max_s45, max_s67)); + max = _mm_max_epu16(max, _mm_and_si128(max_s, cdef_large_value_mask)); + } + + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[0], secondary_tap_0, + secondary_damping_shift, secondary_threshold)); + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[1], secondary_tap_0, + secondary_damping_shift, secondary_threshold)); + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[2], secondary_tap_1, + secondary_damping_shift, secondary_threshold)); + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[3], secondary_tap_1, + secondary_damping_shift, secondary_threshold)); + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[4], secondary_tap_0, + secondary_damping_shift, secondary_threshold)); + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[5], secondary_tap_0, + secondary_damping_shift, secondary_threshold)); + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[6], secondary_tap_1, + secondary_damping_shift, secondary_threshold)); + sum = _mm_add_epi16( + sum, + ApplyConstrainAndTap(pixel, secondary_val[7], secondary_tap_1, + secondary_damping_shift, secondary_threshold)); + } + // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max)) + const __m128i sum_lt_0 = _mm_srai_epi16(sum, 15); + // 8 + sum + sum = _mm_add_epi16(sum, _mm_set1_epi16(8)); + // (... - (sum < 0)) >> 4 + sum = _mm_add_epi16(sum, sum_lt_0); + sum = _mm_srai_epi16(sum, 4); + // pixel + ... + sum = _mm_add_epi16(sum, pixel); + if (clipping_required) { + // Clip3 + sum = _mm_min_epi16(sum, max); + sum = _mm_max_epi16(sum, min); + } + + const __m128i result = _mm_packus_epi16(sum, sum); + if (width == 8) { + src += src_stride; + StoreLo8(dst, result); + dst += dst_stride; + --y; + } else { + src += src_stride << 1; + Store4(dst, result); + dst += dst_stride; + Store4(dst, _mm_srli_si128(result, 4)); + dst += dst_stride; + y -= 2; + } + } while (y != 0); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(8); + assert(dsp != nullptr); + dsp->cdef_direction = CdefDirection_SSE4_1; + dsp->cdef_filters[0][0] = CdefFilter_SSE4_1<4>; + dsp->cdef_filters[0][1] = + CdefFilter_SSE4_1<4, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[0][2] = CdefFilter_SSE4_1<4, /*enable_primary=*/false>; + dsp->cdef_filters[1][0] = CdefFilter_SSE4_1<8>; + dsp->cdef_filters[1][1] = + CdefFilter_SSE4_1<8, /*enable_primary=*/true, /*enable_secondary=*/false>; + dsp->cdef_filters[1][2] = CdefFilter_SSE4_1<8, /*enable_primary=*/false>; +} + +} // namespace +} // namespace low_bitdepth + +void CdefInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void CdefInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/cdef_sse4.h b/src/dsp/x86/cdef_sse4.h new file mode 100644 index 0000000..6631eb7 --- /dev/null +++ b/src/dsp/x86/cdef_sse4.h @@ -0,0 +1,45 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_CDEF_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_CDEF_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not +// thread-safe. +void CdefInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_CdefDirection +#define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_CdefFilters +#define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_CDEF_SSE4_H_ diff --git a/src/dsp/x86/common_avx2.h b/src/dsp/x86/common_avx2.h new file mode 100644 index 0000000..4ce7de2 --- /dev/null +++ b/src/dsp/x86/common_avx2.h @@ -0,0 +1,138 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_COMMON_AVX2_H_ +#define LIBGAV1_SRC_DSP_X86_COMMON_AVX2_H_ + +#include "src/utils/compiler_attributes.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_AVX2 + +#include <immintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +namespace libgav1 { +namespace dsp { + +//------------------------------------------------------------------------------ +// Compatibility functions. + +inline __m256i SetrM128i(const __m128i lo, const __m128i hi) { + // For compatibility with older gcc toolchains (< 8) use + // _mm256_inserti128_si256 over _mm256_setr_m128i. Newer gcc implementations + // are implemented similarly to the following, clang uses a different method + // but no differences in assembly have been observed. + return _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1); +} + +//------------------------------------------------------------------------------ +// Load functions. + +inline __m256i LoadAligned32(const void* a) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + return _mm256_load_si256(static_cast<const __m256i*>(a)); +} + +inline void LoadAligned64(const void* a, __m256i dst[2]) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + dst[0] = _mm256_load_si256(static_cast<const __m256i*>(a) + 0); + dst[1] = _mm256_load_si256(static_cast<const __m256i*>(a) + 1); +} + +inline __m256i LoadUnaligned32(const void* a) { + return _mm256_loadu_si256(static_cast<const __m256i*>(a)); +} + +//------------------------------------------------------------------------------ +// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning. + +inline __m256i MaskOverreads(const __m256i source, + const ptrdiff_t over_read_in_bytes) { + __m256i dst = source; +#if LIBGAV1_MSAN + if (over_read_in_bytes >= 32) return _mm256_setzero_si256(); + if (over_read_in_bytes > 0) { + __m128i m = _mm_set1_epi8(-1); + for (ptrdiff_t i = 0; i < over_read_in_bytes % 16; ++i) { + m = _mm_srli_si128(m, 1); + } + const __m256i mask = (over_read_in_bytes < 16) + ? SetrM128i(_mm_set1_epi8(-1), m) + : SetrM128i(m, _mm_setzero_si128()); + dst = _mm256_and_si256(dst, mask); + } +#else + static_cast<void>(over_read_in_bytes); +#endif + return dst; +} + +inline __m256i LoadAligned32Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadAligned32(source), over_read_in_bytes); +} + +inline void LoadAligned64Msan(const void* const source, + const ptrdiff_t over_read_in_bytes, + __m256i dst[2]) { + dst[0] = MaskOverreads(LoadAligned32(source), over_read_in_bytes); + dst[1] = MaskOverreads(LoadAligned32(static_cast<const __m256i*>(source) + 1), + over_read_in_bytes); +} + +inline __m256i LoadUnaligned32Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadUnaligned32(source), over_read_in_bytes); +} + +//------------------------------------------------------------------------------ +// Store functions. + +inline void StoreAligned32(void* a, const __m256i v) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + _mm256_store_si256(static_cast<__m256i*>(a), v); +} + +inline void StoreAligned64(void* a, const __m256i v[2]) { + assert((reinterpret_cast<uintptr_t>(a) & 0x1f) == 0); + _mm256_store_si256(static_cast<__m256i*>(a) + 0, v[0]); + _mm256_store_si256(static_cast<__m256i*>(a) + 1, v[1]); +} + +inline void StoreUnaligned32(void* a, const __m256i v) { + _mm256_storeu_si256(static_cast<__m256i*>(a), v); +} + +//------------------------------------------------------------------------------ +// Arithmetic utilities. + +inline __m256i RightShiftWithRounding_S16(const __m256i v_val_d, int bits) { + assert(bits <= 16); + const __m256i v_bias_d = + _mm256_set1_epi16(static_cast<int16_t>((1 << bits) >> 1)); + const __m256i v_tmp_d = _mm256_add_epi16(v_val_d, v_bias_d); + return _mm256_srai_epi16(v_tmp_d, bits); +} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_TARGETING_AVX2 +#endif // LIBGAV1_SRC_DSP_X86_COMMON_AVX2_H_ diff --git a/src/dsp/x86/common_sse4.h b/src/dsp/x86/common_sse4.h new file mode 100644 index 0000000..c510f8c --- /dev/null +++ b/src/dsp/x86/common_sse4.h @@ -0,0 +1,265 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_ + +#include "src/utils/compiler_attributes.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <emmintrin.h> +#include <smmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <cstring> + +#if 0 +#include <cinttypes> +#include <cstdio> + +// Quite useful macro for debugging. Left here for convenience. +inline void PrintReg(const __m128i r, const char* const name, int size) { + int n; + union { + __m128i r; + uint8_t i8[16]; + uint16_t i16[8]; + uint32_t i32[4]; + uint64_t i64[2]; + } tmp; + tmp.r = r; + fprintf(stderr, "%s\t: ", name); + if (size == 8) { + for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", tmp.i8[n]); + } else if (size == 16) { + for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", tmp.i16[n]); + } else if (size == 32) { + for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", tmp.i32[n]); + } else { + for (n = 0; n < 2; ++n) + fprintf(stderr, "%.16" PRIx64 " ", static_cast<uint64_t>(tmp.i64[n])); + } + fprintf(stderr, "\n"); +} + +inline void PrintReg(const int r, const char* const name) { + fprintf(stderr, "%s: %d\n", name, r); +} + +inline void PrintRegX(const int r, const char* const name) { + fprintf(stderr, "%s: %.8x\n", name, r); +} + +#define PR(var, N) PrintReg(var, #var, N) +#define PD(var) PrintReg(var, #var); +#define PX(var) PrintRegX(var, #var); +#endif // 0 + +namespace libgav1 { +namespace dsp { + +//------------------------------------------------------------------------------ +// Load functions. + +inline __m128i Load2(const void* src) { + int16_t val; + memcpy(&val, src, sizeof(val)); + return _mm_cvtsi32_si128(val); +} + +inline __m128i Load2x2(const void* src1, const void* src2) { + uint16_t val1; + uint16_t val2; + memcpy(&val1, src1, sizeof(val1)); + memcpy(&val2, src2, sizeof(val2)); + return _mm_cvtsi32_si128(val1 | (val2 << 16)); +} + +// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1. +template <int lane> +inline __m128i Load2(const void* const buf, __m128i val) { + uint16_t temp; + memcpy(&temp, buf, 2); + return _mm_insert_epi16(val, temp, lane); +} + +inline __m128i Load4(const void* src) { + // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32 + // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a + // movss instruction. + // + // Until compiler support of _mm_loadu_si32 is widespread, use of + // _mm_loadu_si32 is banned. + int val; + memcpy(&val, src, sizeof(val)); + return _mm_cvtsi32_si128(val); +} + +inline __m128i Load4x2(const void* src1, const void* src2) { + // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32 + // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a + // movss instruction. + // + // Until compiler support of _mm_loadu_si32 is widespread, use of + // _mm_loadu_si32 is banned. + int val1, val2; + memcpy(&val1, src1, sizeof(val1)); + memcpy(&val2, src2, sizeof(val2)); + return _mm_insert_epi32(_mm_cvtsi32_si128(val1), val2, 1); +} + +inline __m128i LoadLo8(const void* a) { + return _mm_loadl_epi64(static_cast<const __m128i*>(a)); +} + +inline __m128i LoadHi8(const __m128i v, const void* a) { + const __m128 x = + _mm_loadh_pi(_mm_castsi128_ps(v), static_cast<const __m64*>(a)); + return _mm_castps_si128(x); +} + +inline __m128i LoadUnaligned16(const void* a) { + return _mm_loadu_si128(static_cast<const __m128i*>(a)); +} + +inline __m128i LoadAligned16(const void* a) { + assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0); + return _mm_load_si128(static_cast<const __m128i*>(a)); +} + +//------------------------------------------------------------------------------ +// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning. + +inline __m128i MaskOverreads(const __m128i source, + const ptrdiff_t over_read_in_bytes) { + __m128i dst = source; +#if LIBGAV1_MSAN + if (over_read_in_bytes > 0) { + __m128i mask = _mm_set1_epi8(-1); + for (ptrdiff_t i = 0; i < over_read_in_bytes; ++i) { + mask = _mm_srli_si128(mask, 1); + } + dst = _mm_and_si128(dst, mask); + } +#else + static_cast<void>(over_read_in_bytes); +#endif + return dst; +} + +inline __m128i LoadLo8Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadLo8(source), over_read_in_bytes + 8); +} + +inline __m128i LoadHi8Msan(const __m128i v, const void* source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadHi8(v, source), over_read_in_bytes); +} + +inline __m128i LoadAligned16Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadAligned16(source), over_read_in_bytes); +} + +inline __m128i LoadUnaligned16Msan(const void* const source, + const ptrdiff_t over_read_in_bytes) { + return MaskOverreads(LoadUnaligned16(source), over_read_in_bytes); +} + +//------------------------------------------------------------------------------ +// Store functions. + +inline void Store2(void* dst, const __m128i x) { + const int val = _mm_cvtsi128_si32(x); + memcpy(dst, &val, 2); +} + +inline void Store4(void* dst, const __m128i x) { + const int val = _mm_cvtsi128_si32(x); + memcpy(dst, &val, sizeof(val)); +} + +inline void StoreLo8(void* a, const __m128i v) { + _mm_storel_epi64(static_cast<__m128i*>(a), v); +} + +inline void StoreHi8(void* a, const __m128i v) { + _mm_storeh_pi(static_cast<__m64*>(a), _mm_castsi128_ps(v)); +} + +inline void StoreAligned16(void* a, const __m128i v) { + assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0); + _mm_store_si128(static_cast<__m128i*>(a), v); +} + +inline void StoreUnaligned16(void* a, const __m128i v) { + _mm_storeu_si128(static_cast<__m128i*>(a), v); +} + +//------------------------------------------------------------------------------ +// Arithmetic utilities. + +inline __m128i RightShiftWithRounding_U16(const __m128i v_val_d, int bits) { + assert(bits <= 16); + // Shift out all but the last bit. + const __m128i v_tmp_d = _mm_srli_epi16(v_val_d, bits - 1); + // Avg with zero will shift by 1 and round. + return _mm_avg_epu16(v_tmp_d, _mm_setzero_si128()); +} + +inline __m128i RightShiftWithRounding_S16(const __m128i v_val_d, int bits) { + assert(bits <= 16); + const __m128i v_bias_d = + _mm_set1_epi16(static_cast<int16_t>((1 << bits) >> 1)); + const __m128i v_tmp_d = _mm_add_epi16(v_val_d, v_bias_d); + return _mm_srai_epi16(v_tmp_d, bits); +} + +inline __m128i RightShiftWithRounding_U32(const __m128i v_val_d, int bits) { + const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1); + const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d); + return _mm_srli_epi32(v_tmp_d, bits); +} + +inline __m128i RightShiftWithRounding_S32(const __m128i v_val_d, int bits) { + const __m128i v_bias_d = _mm_set1_epi32((1 << bits) >> 1); + const __m128i v_tmp_d = _mm_add_epi32(v_val_d, v_bias_d); + return _mm_srai_epi32(v_tmp_d, bits); +} + +//------------------------------------------------------------------------------ +// Masking utilities +inline __m128i MaskHighNBytes(int n) { + static constexpr uint8_t kMask[32] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + }; + + return LoadUnaligned16(kMask + n); +} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_TARGETING_SSE4_1 +#endif // LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_ diff --git a/src/dsp/x86/convolve_avx2.cc b/src/dsp/x86/convolve_avx2.cc new file mode 100644 index 0000000..3df2120 --- /dev/null +++ b/src/dsp/x86/convolve_avx2.cc @@ -0,0 +1,534 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/convolve.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_AVX2 +#include <immintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_avx2.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +constexpr int kHorizontalOffset = 3; + +// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and +// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final +// sum from outranging int16_t. +template <int filter_index> +__m256i SumOnePassTaps(const __m256i* const src, const __m256i* const taps) { + __m256i sum; + if (filter_index < 2) { + // 6 taps. + const __m256i v_madd_21 = _mm256_maddubs_epi16(src[0], taps[0]); // k2k1 + const __m256i v_madd_43 = _mm256_maddubs_epi16(src[1], taps[1]); // k4k3 + const __m256i v_madd_65 = _mm256_maddubs_epi16(src[2], taps[2]); // k6k5 + sum = _mm256_add_epi16(v_madd_21, v_madd_43); + sum = _mm256_add_epi16(sum, v_madd_65); + } else if (filter_index == 2) { + // 8 taps. + const __m256i v_madd_10 = _mm256_maddubs_epi16(src[0], taps[0]); // k1k0 + const __m256i v_madd_32 = _mm256_maddubs_epi16(src[1], taps[1]); // k3k2 + const __m256i v_madd_54 = _mm256_maddubs_epi16(src[2], taps[2]); // k5k4 + const __m256i v_madd_76 = _mm256_maddubs_epi16(src[3], taps[3]); // k7k6 + const __m256i v_sum_3210 = _mm256_add_epi16(v_madd_10, v_madd_32); + const __m256i v_sum_7654 = _mm256_add_epi16(v_madd_54, v_madd_76); + sum = _mm256_add_epi16(v_sum_7654, v_sum_3210); + } else if (filter_index == 3) { + // 2 taps. + sum = _mm256_maddubs_epi16(src[0], taps[0]); // k4k3 + } else { + // 4 taps. + const __m256i v_madd_32 = _mm256_maddubs_epi16(src[0], taps[0]); // k3k2 + const __m256i v_madd_54 = _mm256_maddubs_epi16(src[1], taps[1]); // k5k4 + sum = _mm256_add_epi16(v_madd_32, v_madd_54); + } + return sum; +} + +template <int filter_index> +__m256i SumHorizontalTaps(const __m256i* const src, + const __m256i* const v_tap) { + __m256i v_src[4]; + const __m256i src_long = *src; + const __m256i src_long_dup_lo = _mm256_unpacklo_epi8(src_long, src_long); + const __m256i src_long_dup_hi = _mm256_unpackhi_epi8(src_long, src_long); + + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 3); // _21 + v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7); // _43 + v_src[2] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 11); // _65 + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 1); // _10 + v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5); // _32 + v_src[2] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9); // _54 + v_src[3] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 13); // _76 + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7); // _43 + } else if (filter_index > 3) { + // 4 taps. + v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5); // _32 + v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9); // _54 + } + return SumOnePassTaps<filter_index>(v_src, v_tap); +} + +template <int filter_index> +__m256i SimpleHorizontalTaps(const __m256i* const src, + const __m256i* const v_tap) { + __m256i sum = SumHorizontalTaps<filter_index>(src, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. + constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = _mm256_add_epi16(sum, _mm256_set1_epi16(first_shift_rounding_bit)); + sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); + return _mm256_packus_epi16(sum, sum); +} + +template <int filter_index> +__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + // 00 01 02 03 04 05 06 07 10 11 12 13 14 15 16 17 + const __m128i v_src = LoadHi8(LoadLo8(&src[0]), &src[src_stride]); + + if (filter_index == 3) { + // 03 04 04 05 05 06 06 07 13 14 14 15 15 16 16 17 + const __m128i v_src_43 = _mm_shuffle_epi8( + v_src, _mm_set_epi32(0x0f0e0e0d, 0x0d0c0c0b, 0x07060605, 0x05040403)); + const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 + return v_sum_43; + } + + // 02 03 03 04 04 05 05 06 12 13 13 14 14 15 15 16 + const __m128i v_src_32 = _mm_shuffle_epi8( + v_src, _mm_set_epi32(0x0e0d0d0c, 0x0c0b0b0a, 0x06050504, 0x04030302)); + // 04 05 05 06 06 07 07 xx 14 15 15 16 16 17 17 xx + const __m128i v_src_54 = _mm_shuffle_epi8( + v_src, _mm_set_epi32(0x800f0f0e, 0x0e0d0d0c, 0x80070706, 0x06050504)); + const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 + const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 + const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32); + return v_sum_5432; +} + +template <int filter_index> +__m128i SimpleHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. + constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit)); + sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); + return _mm_packus_epi16(sum, sum); +} + +template <int filter_index> +__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + const __m128i sum = + SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +// Filter 2xh sizes. +template <int num_taps, int step, int filter_index, bool is_2d = false, + bool is_compound = false> +void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, + void* const dest, const ptrdiff_t pred_stride, + const int /*width*/, const int height, + const __m128i* const v_tap) { + auto* dest8 = static_cast<uint8_t*>(dest); + auto* dest16 = static_cast<uint16_t*>(dest); + + // Horizontal passes only need to account for |num_taps| 2 and 4 when + // |width| <= 4. + assert(num_taps <= 4); + if (num_taps <= 4) { + if (!is_compound) { + int y = 0; + do { + if (is_2d) { + const __m128i sum = + HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap); + Store4(&dest16[0], sum); + dest16 += pred_stride; + Store4(&dest16[0], _mm_srli_si128(sum, 8)); + dest16 += pred_stride; + } else { + const __m128i sum = + SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + Store2(dest8, sum); + dest8 += pred_stride; + Store2(dest8, _mm_srli_si128(sum, 4)); + dest8 += pred_stride; + } + + src += src_stride << 1; + y += 2; + } while (y < height - 1); + + // The 2d filters have an odd |height| because the horizontal pass + // generates context for the vertical pass. + if (is_2d) { + assert(height % 2 == 1); + __m128i sum; + const __m128i input = LoadLo8(&src[2]); + if (filter_index == 3) { + // 03 04 04 05 05 06 06 07 .... + const __m128i v_src_43 = + _mm_srli_si128(_mm_unpacklo_epi8(input, input), 3); + sum = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 + } else { + // 02 03 03 04 04 05 05 06 06 07 .... + const __m128i v_src_32 = + _mm_srli_si128(_mm_unpacklo_epi8(input, input), 1); + // 04 05 05 06 06 07 07 08 ... + const __m128i v_src_54 = _mm_srli_si128(v_src_32, 4); + const __m128i v_madd_32 = + _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 + const __m128i v_madd_54 = + _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 + sum = _mm_add_epi16(v_madd_54, v_madd_32); + } + sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); + Store4(dest16, sum); + } + } + } +} + +// Filter widths >= 4. +template <int num_taps, int step, int filter_index, bool is_2d = false, + bool is_compound = false> +void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, + void* const dest, const ptrdiff_t pred_stride, + const int width, const int height, + const __m256i* const v_tap) { + auto* dest8 = static_cast<uint8_t*>(dest); + auto* dest16 = static_cast<uint16_t*>(dest); + + if (width >= 32) { + int y = height; + do { + int x = 0; + do { + if (is_2d || is_compound) { + // placeholder + } else { + // Load src used to calculate dest8[7:0] and dest8[23:16]. + const __m256i src_long = LoadUnaligned32(&src[x]); + const __m256i result = + SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + // Load src used to calculate dest8[15:8] and dest8[31:24]. + const __m256i src_long2 = LoadUnaligned32(&src[x + 8]); + const __m256i result2 = + SimpleHorizontalTaps<filter_index>(&src_long2, v_tap); + // Combine results and store. + StoreUnaligned32(&dest8[x], _mm256_unpacklo_epi64(result, result2)); + } + x += step * 4; + } while (x < width); + src += src_stride; + dest8 += pred_stride; + dest16 += pred_stride; + } while (--y != 0); + } else if (width == 16) { + int y = height; + do { + if (is_2d || is_compound) { + // placeholder + } else { + // Load into 2 128 bit lanes. + const __m256i src_long = SetrM128i(LoadUnaligned16(&src[0]), + LoadUnaligned16(&src[src_stride])); + const __m256i result = + SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + const __m256i src_long2 = SetrM128i( + LoadUnaligned16(&src[8]), LoadUnaligned16(&src[8 + src_stride])); + const __m256i result2 = + SimpleHorizontalTaps<filter_index>(&src_long2, v_tap); + const __m256i packed_result = _mm256_unpacklo_epi64(result, result2); + StoreUnaligned16(&dest8[0], _mm256_castsi256_si128(packed_result)); + StoreUnaligned16(&dest8[pred_stride], + _mm256_extracti128_si256(packed_result, 1)); + } + src += src_stride * 2; + dest8 += pred_stride * 2; + dest16 += pred_stride * 2; + y -= 2; + } while (y != 0); + } else if (width == 8) { + int y = height; + do { + if (is_2d || is_compound) { + // placeholder + } else { + const __m128i this_row = LoadUnaligned16(&src[0]); + const __m128i next_row = LoadUnaligned16(&src[src_stride]); + // Load into 2 128 bit lanes. + const __m256i src_long = SetrM128i(this_row, next_row); + const __m256i result = + SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + StoreLo8(&dest8[0], _mm256_castsi256_si128(result)); + StoreLo8(&dest8[pred_stride], _mm256_extracti128_si256(result, 1)); + } + src += src_stride * 2; + dest8 += pred_stride * 2; + dest16 += pred_stride * 2; + y -= 2; + } while (y != 0); + } else { // width == 4 + int y = height; + do { + if (is_2d || is_compound) { + // placeholder + } else { + const __m128i this_row = LoadUnaligned16(&src[0]); + const __m128i next_row = LoadUnaligned16(&src[src_stride]); + // Load into 2 128 bit lanes. + const __m256i src_long = SetrM128i(this_row, next_row); + const __m256i result = + SimpleHorizontalTaps<filter_index>(&src_long, v_tap); + Store4(&dest8[0], _mm256_castsi256_si128(result)); + Store4(&dest8[pred_stride], _mm256_extracti128_si256(result, 1)); + } + src += src_stride * 2; + dest8 += pred_stride * 2; + dest16 += pred_stride * 2; + y -= 2; + } while (y != 0); + } +} + +template <int num_taps, bool is_2d_vertical = false> +LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, + __m128i* v_tap) { + if (num_taps == 8) { + v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0); // k1k0 + v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 + v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 + v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff); // k7k6 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]); + } + } else if (num_taps == 6) { + const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); + v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0); // k2k1 + v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 + v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa); // k6k5 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + } + } else if (num_taps == 4) { + v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 + v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + } + } else { // num_taps == 2 + const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); + v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + } + } +} + +template <int num_taps, bool is_2d_vertical = false> +LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, + __m256i* v_tap) { + if (num_taps == 8) { + v_tap[0] = _mm256_broadcastw_epi16(*filter); // k1k0 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 + v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 + v_tap[3] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 6)); // k7k6 + if (is_2d_vertical) { + // placeholder + } + } else if (num_taps == 6) { + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 1)); // k2k1 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 + v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 5)); // k6k5 + if (is_2d_vertical) { + // placeholder + } + } else if (num_taps == 4) { + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2)); // k3k2 + v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4)); // k5k4 + if (is_2d_vertical) { + // placeholder + } + } else { // num_taps == 2 + v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3)); // k4k3 + if (is_2d_vertical) { + // placeholder + } + } +} + +template <bool is_2d = false, bool is_compound = false> +LIBGAV1_ALWAYS_INLINE void DoHorizontalPass2xH( + const uint8_t* const src, const ptrdiff_t src_stride, void* const dst, + const ptrdiff_t dst_stride, const int width, const int height, + const int filter_id, const int filter_index) { + assert(filter_id != 0); + __m128i v_tap[4]; + const __m128i v_horizontal_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]); + + if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_horizontal_filter, v_tap); + FilterHorizontal<4, 8, 4, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 5) { // 4 tap. + SetupTaps<4>(&v_horizontal_filter, v_tap); + FilterHorizontal<4, 8, 5, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else { // 2 tap. + SetupTaps<2>(&v_horizontal_filter, v_tap); + FilterHorizontal<2, 8, 3, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } +} + +template <bool is_2d = false, bool is_compound = false> +LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( + const uint8_t* const src, const ptrdiff_t src_stride, void* const dst, + const ptrdiff_t dst_stride, const int width, const int height, + const int filter_id, const int filter_index) { + assert(filter_id != 0); + __m256i v_tap[4]; + const __m128i v_horizontal_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]); + + if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_horizontal_filter, v_tap); + FilterHorizontal<8, 8, 2, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 1) { // 6 tap. + SetupTaps<6>(&v_horizontal_filter, v_tap); + FilterHorizontal<6, 8, 1, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 0) { // 6 tap. + SetupTaps<6>(&v_horizontal_filter, v_tap); + FilterHorizontal<6, 8, 0, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_horizontal_filter, v_tap); + FilterHorizontal<4, 8, 4, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 5) { // 4 tap. + SetupTaps<4>(&v_horizontal_filter, v_tap); + FilterHorizontal<4, 8, 5, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else { // 2 tap. + SetupTaps<2>(&v_horizontal_filter, v_tap); + FilterHorizontal<2, 8, 3, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } +} + +void ConvolveHorizontal_AVX2(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int /*vertical_filter_index*/, + const int horizontal_filter_id, + const int /*vertical_filter_id*/, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + // Set |src| to the outermost tap. + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint8_t*>(prediction); + + if (width > 2) { + DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height, + horizontal_filter_id, filter_index); + } else { + // Use non avx2 version for smaller widths. + DoHorizontalPass2xH(src, reference_stride, dest, pred_stride, width, height, + horizontal_filter_id, filter_index); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->convolve[0][0][0][1] = ConvolveHorizontal_AVX2; +} + +} // namespace +} // namespace low_bitdepth + +void ConvolveInit_AVX2() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_AVX2 +namespace libgav1 { +namespace dsp { + +void ConvolveInit_AVX2() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_AVX2 diff --git a/src/dsp/x86/convolve_avx2.h b/src/dsp/x86/convolve_avx2.h new file mode 100644 index 0000000..6179d98 --- /dev/null +++ b/src/dsp/x86/convolve_avx2.h @@ -0,0 +1,43 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_CONVOLVE_AVX2_H_ +#define LIBGAV1_SRC_DSP_X86_CONVOLVE_AVX2_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::convolve, see the defines below for specifics. This +// function is not thread-safe. +void ConvolveInit_AVX2(); + +} // namespace dsp +} // namespace libgav1 + +// If avx2 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the avx2 implementation should be used. +#if LIBGAV1_TARGETING_AVX2 + +#ifndef LIBGAV1_Dsp8bpp_ConvolveHorizontal +#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_AVX2 +#endif + +#endif // LIBGAV1_TARGETING_AVX2 + +#endif // LIBGAV1_SRC_DSP_X86_CONVOLVE_AVX2_H_ diff --git a/src/dsp/x86/convolve_sse4.cc b/src/dsp/x86/convolve_sse4.cc new file mode 100644 index 0000000..3a0fff5 --- /dev/null +++ b/src/dsp/x86/convolve_sse4.cc @@ -0,0 +1,2830 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/convolve.h" +#include "src/utils/constants.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 +#include <smmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +#include "src/dsp/convolve.inc" + +// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and +// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final +// sum from outranging int16_t. +template <int filter_index> +__m128i SumOnePassTaps(const __m128i* const src, const __m128i* const taps) { + __m128i sum; + if (filter_index < 2) { + // 6 taps. + const __m128i v_madd_21 = _mm_maddubs_epi16(src[0], taps[0]); // k2k1 + const __m128i v_madd_43 = _mm_maddubs_epi16(src[1], taps[1]); // k4k3 + const __m128i v_madd_65 = _mm_maddubs_epi16(src[2], taps[2]); // k6k5 + sum = _mm_add_epi16(v_madd_21, v_madd_43); + sum = _mm_add_epi16(sum, v_madd_65); + } else if (filter_index == 2) { + // 8 taps. + const __m128i v_madd_10 = _mm_maddubs_epi16(src[0], taps[0]); // k1k0 + const __m128i v_madd_32 = _mm_maddubs_epi16(src[1], taps[1]); // k3k2 + const __m128i v_madd_54 = _mm_maddubs_epi16(src[2], taps[2]); // k5k4 + const __m128i v_madd_76 = _mm_maddubs_epi16(src[3], taps[3]); // k7k6 + const __m128i v_sum_3210 = _mm_add_epi16(v_madd_10, v_madd_32); + const __m128i v_sum_7654 = _mm_add_epi16(v_madd_54, v_madd_76); + sum = _mm_add_epi16(v_sum_7654, v_sum_3210); + } else if (filter_index == 3) { + // 2 taps. + sum = _mm_maddubs_epi16(src[0], taps[0]); // k4k3 + } else { + // 4 taps. + const __m128i v_madd_32 = _mm_maddubs_epi16(src[0], taps[0]); // k3k2 + const __m128i v_madd_54 = _mm_maddubs_epi16(src[1], taps[1]); // k5k4 + sum = _mm_add_epi16(v_madd_32, v_madd_54); + } + return sum; +} + +template <int filter_index> +__m128i SumHorizontalTaps(const uint8_t* const src, + const __m128i* const v_tap) { + __m128i v_src[4]; + const __m128i src_long = LoadUnaligned16(src); + const __m128i src_long_dup_lo = _mm_unpacklo_epi8(src_long, src_long); + const __m128i src_long_dup_hi = _mm_unpackhi_epi8(src_long, src_long); + + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 3); // _21 + v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7); // _43 + v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 11); // _65 + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 1); // _10 + v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5); // _32 + v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9); // _54 + v_src[3] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 13); // _76 + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7); // _43 + } else if (filter_index > 3) { + // 4 taps. + v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5); // _32 + v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9); // _54 + } + const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap); + return sum; +} + +template <int filter_index> +__m128i SimpleHorizontalTaps(const uint8_t* const src, + const __m128i* const v_tap) { + __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. + constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit)); + sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); + return _mm_packus_epi16(sum, sum); +} + +template <int filter_index> +__m128i HorizontalTaps8To16(const uint8_t* const src, + const __m128i* const v_tap) { + const __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap); + + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +template <int filter_index> +__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + const __m128i input0 = LoadLo8(&src[2]); + const __m128i input1 = LoadLo8(&src[2 + src_stride]); + + if (filter_index == 3) { + // 03 04 04 05 05 06 06 07 .... + const __m128i input0_dup = + _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 3); + // 13 14 14 15 15 16 16 17 .... + const __m128i input1_dup = + _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 3); + const __m128i v_src_43 = _mm_unpacklo_epi64(input0_dup, input1_dup); + const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 + return v_sum_43; + } + + // 02 03 03 04 04 05 05 06 06 07 .... + const __m128i input0_dup = + _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 1); + // 12 13 13 14 14 15 15 16 16 17 .... + const __m128i input1_dup = + _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 1); + // 04 05 05 06 06 07 07 08 ... + const __m128i input0_dup_54 = _mm_srli_si128(input0_dup, 4); + // 14 15 15 16 16 17 17 18 ... + const __m128i input1_dup_54 = _mm_srli_si128(input1_dup, 4); + const __m128i v_src_32 = _mm_unpacklo_epi64(input0_dup, input1_dup); + const __m128i v_src_54 = _mm_unpacklo_epi64(input0_dup_54, input1_dup_54); + const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 + const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 + const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32); + return v_sum_5432; +} + +template <int filter_index> +__m128i SimpleHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + // Normally the Horizontal pass does the downshift in two passes: + // kInterRoundBitsHorizontal - 1 and then (kFilterBits - + // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them + // requires adding the rounding offset from the skipped shift. + constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2); + + sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit)); + sum = RightShiftWithRounding_S16(sum, kFilterBits - 1); + return _mm_packus_epi16(sum, sum); +} + +template <int filter_index> +__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride, + const __m128i* const v_tap) { + const __m128i sum = + SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +template <int num_taps, int step, int filter_index, bool is_2d = false, + bool is_compound = false> +void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride, + void* const dest, const ptrdiff_t pred_stride, + const int width, const int height, + const __m128i* const v_tap) { + auto* dest8 = static_cast<uint8_t*>(dest); + auto* dest16 = static_cast<uint16_t*>(dest); + + // 4 tap filters are never used when width > 4. + if (num_taps != 4 && width > 4) { + int y = 0; + do { + int x = 0; + do { + if (is_2d || is_compound) { + const __m128i v_sum = + HorizontalTaps8To16<filter_index>(&src[x], v_tap); + if (is_2d) { + StoreAligned16(&dest16[x], v_sum); + } else { + StoreUnaligned16(&dest16[x], v_sum); + } + } else { + const __m128i result = + SimpleHorizontalTaps<filter_index>(&src[x], v_tap); + StoreLo8(&dest8[x], result); + } + x += step; + } while (x < width); + src += src_stride; + dest8 += pred_stride; + dest16 += pred_stride; + } while (++y < height); + return; + } + + // Horizontal passes only needs to account for |num_taps| 2 and 4 when + // |width| <= 4. + assert(width <= 4); + assert(num_taps <= 4); + if (num_taps <= 4) { + if (width == 4) { + int y = 0; + do { + if (is_2d || is_compound) { + const __m128i v_sum = HorizontalTaps8To16<filter_index>(src, v_tap); + StoreLo8(dest16, v_sum); + } else { + const __m128i result = SimpleHorizontalTaps<filter_index>(src, v_tap); + Store4(&dest8[0], result); + } + src += src_stride; + dest8 += pred_stride; + dest16 += pred_stride; + } while (++y < height); + return; + } + + if (!is_compound) { + int y = 0; + do { + if (is_2d) { + const __m128i sum = + HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap); + Store4(&dest16[0], sum); + dest16 += pred_stride; + Store4(&dest16[0], _mm_srli_si128(sum, 8)); + dest16 += pred_stride; + } else { + const __m128i sum = + SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap); + Store2(dest8, sum); + dest8 += pred_stride; + Store2(dest8, _mm_srli_si128(sum, 4)); + dest8 += pred_stride; + } + + src += src_stride << 1; + y += 2; + } while (y < height - 1); + + // The 2d filters have an odd |height| because the horizontal pass + // generates context for the vertical pass. + if (is_2d) { + assert(height % 2 == 1); + __m128i sum; + const __m128i input = LoadLo8(&src[2]); + if (filter_index == 3) { + // 03 04 04 05 05 06 06 07 .... + const __m128i v_src_43 = + _mm_srli_si128(_mm_unpacklo_epi8(input, input), 3); + sum = _mm_maddubs_epi16(v_src_43, v_tap[0]); // k4k3 + } else { + // 02 03 03 04 04 05 05 06 06 07 .... + const __m128i v_src_32 = + _mm_srli_si128(_mm_unpacklo_epi8(input, input), 1); + // 04 05 05 06 06 07 07 08 ... + const __m128i v_src_54 = _mm_srli_si128(v_src_32, 4); + const __m128i v_madd_32 = + _mm_maddubs_epi16(v_src_32, v_tap[0]); // k3k2 + const __m128i v_madd_54 = + _mm_maddubs_epi16(v_src_54, v_tap[1]); // k5k4 + sum = _mm_add_epi16(v_madd_54, v_madd_32); + } + sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); + Store4(dest16, sum); + } + } + } +} + +template <int num_taps, bool is_2d_vertical = false> +LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter, + __m128i* v_tap) { + if (num_taps == 8) { + v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0); // k1k0 + v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 + v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 + v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff); // k7k6 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]); + } + } else if (num_taps == 6) { + const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); + v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0); // k2k1 + v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 + v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa); // k6k5 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]); + } + } else if (num_taps == 4) { + v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55); // k3k2 + v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa); // k5k4 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]); + } + } else { // num_taps == 2 + const __m128i adjusted_filter = _mm_srli_si128(*filter, 1); + v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55); // k4k3 + if (is_2d_vertical) { + v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]); + } else { + v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]); + } + } +} + +template <int num_taps, bool is_compound> +__m128i SimpleSum2DVerticalTaps(const __m128i* const src, + const __m128i* const taps) { + __m128i sum_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[0], src[1]), taps[0]); + __m128i sum_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[0], src[1]), taps[0]); + if (num_taps >= 4) { + __m128i madd_lo = + _mm_madd_epi16(_mm_unpacklo_epi16(src[2], src[3]), taps[1]); + __m128i madd_hi = + _mm_madd_epi16(_mm_unpackhi_epi16(src[2], src[3]), taps[1]); + sum_lo = _mm_add_epi32(sum_lo, madd_lo); + sum_hi = _mm_add_epi32(sum_hi, madd_hi); + if (num_taps >= 6) { + madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[4], src[5]), taps[2]); + madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[4], src[5]), taps[2]); + sum_lo = _mm_add_epi32(sum_lo, madd_lo); + sum_hi = _mm_add_epi32(sum_hi, madd_hi); + if (num_taps == 8) { + madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[6], src[7]), taps[3]); + madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[6], src[7]), taps[3]); + sum_lo = _mm_add_epi32(sum_lo, madd_lo); + sum_hi = _mm_add_epi32(sum_hi, madd_hi); + } + } + } + + if (is_compound) { + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1), + RightShiftWithRounding_S32(sum_hi, + kInterRoundBitsCompoundVertical - 1)); + } + + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1), + RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1)); +} + +template <int num_taps, bool is_compound = false> +void Filter2DVertical(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int width, + const int height, const __m128i* const taps) { + assert(width >= 8); + constexpr int next_row = num_taps - 1; + // The Horizontal pass uses |width| as |stride| for the intermediate buffer. + const ptrdiff_t src_stride = width; + + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + int x = 0; + do { + __m128i srcs[8]; + const uint16_t* src_x = src + x; + srcs[0] = LoadAligned16(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadAligned16(src_x); + src_x += src_stride; + srcs[2] = LoadAligned16(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadAligned16(src_x); + src_x += src_stride; + srcs[4] = LoadAligned16(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadAligned16(src_x); + src_x += src_stride; + srcs[6] = LoadAligned16(src_x); + src_x += src_stride; + } + } + } + + int y = 0; + do { + srcs[next_row] = LoadAligned16(src_x); + src_x += src_stride; + + const __m128i sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + StoreUnaligned16(dst16 + x + y * dst_stride, sum); + } else { + StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(sum, sum)); + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (++y < height); + x += 8; + } while (x < width); +} + +// Take advantage of |src_stride| == |width| to process two rows at a time. +template <int num_taps, bool is_compound = false> +void Filter2DVertical4xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const __m128i* const taps) { + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + __m128i srcs[9]; + srcs[0] = LoadAligned16(src); + src += 8; + if (num_taps >= 4) { + srcs[2] = LoadAligned16(src); + src += 8; + srcs[1] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[0], 8), srcs[2]); + if (num_taps >= 6) { + srcs[4] = LoadAligned16(src); + src += 8; + srcs[3] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[2], 8), srcs[4]); + if (num_taps == 8) { + srcs[6] = LoadAligned16(src); + src += 8; + srcs[5] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[4], 8), srcs[6]); + } + } + } + + int y = 0; + do { + srcs[num_taps] = LoadAligned16(src); + src += 8; + srcs[num_taps - 1] = _mm_unpacklo_epi64( + _mm_srli_si128(srcs[num_taps - 2], 8), srcs[num_taps]); + + const __m128i sum = + SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps); + if (is_compound) { + StoreUnaligned16(dst16, sum); + dst16 += 4 << 1; + } else { + const __m128i results = _mm_packus_epi16(sum, sum); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + if (num_taps >= 4) { + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + if (num_taps >= 6) { + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + if (num_taps == 8) { + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + } + } + } + y += 2; + } while (y < height); +} + +// Take advantage of |src_stride| == |width| to process four rows at a time. +template <int num_taps> +void Filter2DVertical2xH(const uint16_t* src, void* const dst, + const ptrdiff_t dst_stride, const int height, + const __m128i* const taps) { + constexpr int next_row = (num_taps < 6) ? 4 : 8; + + auto* dst8 = static_cast<uint8_t*>(dst); + + __m128i srcs[9]; + srcs[0] = LoadAligned16(src); + src += 8; + if (num_taps >= 6) { + srcs[4] = LoadAligned16(src); + src += 8; + srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); + if (num_taps == 8) { + srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); + srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); + } + } + + int y = 0; + do { + srcs[next_row] = LoadAligned16(src); + src += 8; + if (num_taps == 2) { + srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); + } else if (num_taps == 4) { + srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4); + srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); + srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); + } else if (num_taps == 6) { + srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8); + srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12); + srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4); + } else if (num_taps == 8) { + srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4); + srcs[6] = _mm_alignr_epi8(srcs[8], srcs[4], 8); + srcs[7] = _mm_alignr_epi8(srcs[8], srcs[4], 12); + } + + const __m128i sum = + SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps); + const __m128i results = _mm_packus_epi16(sum, sum); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + // When |height| <= 4 the taps are restricted to 2 and 4 tap variants. + // Therefore we don't need to check this condition when |height| > 4. + if (num_taps <= 4 && height == 2) return; + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + if (num_taps == 6) { + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + } else if (num_taps == 8) { + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + } + + y += 4; + } while (y < height); +} + +template <bool is_2d = false, bool is_compound = false> +LIBGAV1_ALWAYS_INLINE void DoHorizontalPass( + const uint8_t* const src, const ptrdiff_t src_stride, void* const dst, + const ptrdiff_t dst_stride, const int width, const int height, + const int filter_id, const int filter_index) { + assert(filter_id != 0); + __m128i v_tap[4]; + const __m128i v_horizontal_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]); + + if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_horizontal_filter, v_tap); + FilterHorizontal<8, 8, 2, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 1) { // 6 tap. + SetupTaps<6>(&v_horizontal_filter, v_tap); + FilterHorizontal<6, 8, 1, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 0) { // 6 tap. + SetupTaps<6>(&v_horizontal_filter, v_tap); + FilterHorizontal<6, 8, 0, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_horizontal_filter, v_tap); + FilterHorizontal<4, 8, 4, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else if (filter_index == 5) { // 4 tap. + SetupTaps<4>(&v_horizontal_filter, v_tap); + FilterHorizontal<4, 8, 5, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } else { // 2 tap. + SetupTaps<2>(&v_horizontal_filter, v_tap); + FilterHorizontal<2, 8, 3, is_2d, is_compound>( + src, src_stride, dst, dst_stride, width, height, v_tap); + } +} + +void Convolve2D_SSE4_1(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + + // The output of the horizontal filter is guaranteed to fit in 16 bits. + alignas(16) uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + const int intermediate_height = height + vertical_taps - 1; + + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset; + + DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width, + width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + __m128i taps[4]; + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]); + + if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 2) { + Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else if (width == 4) { + Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height, + taps); + } else { + Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height, + taps); + } + } +} + +// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D +// Vertical calculations. +__m128i Compound1DShift(const __m128i sum) { + return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1); +} + +template <int filter_index> +__m128i SumVerticalTaps(const __m128i* const srcs, const __m128i* const v_tap) { + __m128i v_src[4]; + + if (filter_index < 2) { + // 6 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]); + } else if (filter_index == 2) { + // 8 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); + v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]); + v_src[3] = _mm_unpacklo_epi8(srcs[6], srcs[7]); + } else if (filter_index == 3) { + // 2 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + } else if (filter_index > 3) { + // 4 taps. + v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]); + v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]); + } + const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap); + return sum; +} + +template <int filter_index, bool is_compound = false> +void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int width, const int height, + const __m128i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + const int next_row = num_taps - 1; + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + assert(width >= 8); + + int x = 0; + do { + const uint8_t* src_x = src + x; + __m128i srcs[8]; + srcs[0] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps >= 4) { + srcs[1] = LoadLo8(src_x); + src_x += src_stride; + srcs[2] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps >= 6) { + srcs[3] = LoadLo8(src_x); + src_x += src_stride; + srcs[4] = LoadLo8(src_x); + src_x += src_stride; + if (num_taps == 8) { + srcs[5] = LoadLo8(src_x); + src_x += src_stride; + srcs[6] = LoadLo8(src_x); + src_x += src_stride; + } + } + } + + int y = 0; + do { + srcs[next_row] = LoadLo8(src_x); + src_x += src_stride; + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16 + x + y * dst_stride, results); + } else { + const __m128i results = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(results, results)); + } + + srcs[0] = srcs[1]; + if (num_taps >= 4) { + srcs[1] = srcs[2]; + srcs[2] = srcs[3]; + if (num_taps >= 6) { + srcs[3] = srcs[4]; + srcs[4] = srcs[5]; + if (num_taps == 8) { + srcs[5] = srcs[6]; + srcs[6] = srcs[7]; + } + } + } + } while (++y < height); + x += 8; + } while (x < width); +} + +template <int filter_index, bool is_compound = false> +void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const __m128i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + auto* dst8 = static_cast<uint8_t*>(dst); + auto* dst16 = static_cast<uint16_t*>(dst); + + __m128i srcs[9]; + + if (num_taps == 2) { + srcs[2] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + + int y = 0; + do { + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + y += 2; + } while (y < height); + } else if (num_taps == 4) { + srcs[4] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + + int y = 0; + do { + // 30 31 32 33 + const __m128i b = Load4(src); + // 20 21 22 23 30 31 32 33 + srcs[2] = _mm_unpacklo_epi32(srcs[2], b); + src += src_stride; + // 40 41 42 43 + srcs[4] = Load4(src); + src += src_stride; + // 30 31 32 33 40 41 42 43 + srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + y += 2; + } while (y < height); + } else if (num_taps == 6) { + srcs[6] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + // 30 31 32 33 + const __m128i b = Load4(src); + // 20 21 22 23 30 31 32 33 + srcs[2] = _mm_unpacklo_epi32(srcs[2], b); + src += src_stride; + // 40 41 42 43 + srcs[4] = Load4(src); + src += src_stride; + // 30 31 32 33 40 41 42 43 + srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); + + int y = 0; + do { + // 50 51 52 53 + const __m128i c = Load4(src); + // 40 41 42 43 50 51 52 53 + srcs[4] = _mm_unpacklo_epi32(srcs[4], c); + src += src_stride; + // 60 61 62 63 + srcs[6] = Load4(src); + src += src_stride; + // 50 51 52 53 60 61 62 63 + srcs[5] = _mm_unpacklo_epi32(c, srcs[6]); + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + y += 2; + } while (y < height); + } else if (num_taps == 8) { + srcs[8] = _mm_setzero_si128(); + // 00 01 02 03 + srcs[0] = Load4(src); + src += src_stride; + // 10 11 12 13 + const __m128i a = Load4(src); + // 00 01 02 03 10 11 12 13 + srcs[0] = _mm_unpacklo_epi32(srcs[0], a); + src += src_stride; + // 20 21 22 23 + srcs[2] = Load4(src); + src += src_stride; + // 10 11 12 13 20 21 22 23 + srcs[1] = _mm_unpacklo_epi32(a, srcs[2]); + // 30 31 32 33 + const __m128i b = Load4(src); + // 20 21 22 23 30 31 32 33 + srcs[2] = _mm_unpacklo_epi32(srcs[2], b); + src += src_stride; + // 40 41 42 43 + srcs[4] = Load4(src); + src += src_stride; + // 30 31 32 33 40 41 42 43 + srcs[3] = _mm_unpacklo_epi32(b, srcs[4]); + // 50 51 52 53 + const __m128i c = Load4(src); + // 40 41 42 43 50 51 52 53 + srcs[4] = _mm_unpacklo_epi32(srcs[4], c); + src += src_stride; + // 60 61 62 63 + srcs[6] = Load4(src); + src += src_stride; + // 50 51 52 53 60 61 62 63 + srcs[5] = _mm_unpacklo_epi32(c, srcs[6]); + + int y = 0; + do { + // 70 71 72 73 + const __m128i d = Load4(src); + // 60 61 62 63 70 71 72 73 + srcs[6] = _mm_unpacklo_epi32(srcs[6], d); + src += src_stride; + // 80 81 82 83 + srcs[8] = Load4(src); + src += src_stride; + // 70 71 72 73 80 81 82 83 + srcs[7] = _mm_unpacklo_epi32(d, srcs[8]); + + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + if (is_compound) { + const __m128i results = Compound1DShift(sums); + StoreUnaligned16(dst16, results); + dst16 += 4 << 1; + } else { + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + Store4(dst8, results); + dst8 += dst_stride; + Store4(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + } + + srcs[0] = srcs[2]; + srcs[1] = srcs[3]; + srcs[2] = srcs[4]; + srcs[3] = srcs[5]; + srcs[4] = srcs[6]; + srcs[5] = srcs[7]; + srcs[6] = srcs[8]; + y += 2; + } while (y < height); + } +} + +template <int filter_index, bool negative_outside_taps = false> +void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride, + void* const dst, const ptrdiff_t dst_stride, + const int height, const __m128i* const v_tap) { + const int num_taps = GetNumTapsInFilter(filter_index); + auto* dst8 = static_cast<uint8_t*>(dst); + + __m128i srcs[9]; + + if (num_taps == 2) { + srcs[2] = _mm_setzero_si128(); + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + + int y = 0; + do { + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[2] = Load2<0>(src, srcs[2]); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 + const __m128i srcs_0_2 = _mm_unpacklo_epi64(srcs[0], srcs[2]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_2, 2); + // This uses srcs[0]..srcs[1]. + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + if (height == 2) return; + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[2]; + y += 4; + } while (y < height); + } else if (num_taps == 4) { + srcs[4] = _mm_setzero_si128(); + + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + + int y = 0; + do { + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[4] = Load2<0>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_4, 2); + // 20 21 30 31 40 41 50 51 + srcs[2] = _mm_srli_si128(srcs_0_4, 4); + // 30 31 40 41 50 51 60 61 + srcs[3] = _mm_srli_si128(srcs_0_4, 6); + + // This uses srcs[0]..srcs[3]. + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + if (height == 2) return; + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + y += 4; + } while (y < height); + } else if (num_taps == 6) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. + assert(height > 4); + srcs[8] = _mm_setzero_si128(); + + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[4] = Load2(src); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4x = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_4x, 2); + + int y = 0; + do { + // 40 41 50 51 + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 70 71 + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + // 80 81 + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 20 21 30 31 40 41 50 51 + srcs[2] = _mm_srli_si128(srcs_0_4, 4); + // 30 31 40 41 50 51 60 61 + srcs[3] = _mm_srli_si128(srcs_0_4, 6); + const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]); + // 50 51 60 61 70 71 80 81 + srcs[5] = _mm_srli_si128(srcs_4_8, 2); + + // This uses srcs[0]..srcs[5]. + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[4] = srcs[8]; + y += 4; + } while (y < height); + } else if (num_taps == 8) { + // During the vertical pass the number of taps is restricted when + // |height| <= 4. + assert(height > 4); + srcs[8] = _mm_setzero_si128(); + // 00 01 + srcs[0] = Load2(src); + src += src_stride; + // 00 01 10 11 + srcs[0] = Load2<1>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 + srcs[0] = Load2<2>(src, srcs[0]); + src += src_stride; + // 00 01 10 11 20 21 30 31 + srcs[0] = Load2<3>(src, srcs[0]); + src += src_stride; + // 40 41 + srcs[4] = Load2(src); + src += src_stride; + // 40 41 50 51 + srcs[4] = Load2<1>(src, srcs[4]); + src += src_stride; + // 40 41 50 51 60 61 + srcs[4] = Load2<2>(src, srcs[4]); + src += src_stride; + + // 00 01 10 11 20 21 30 31 40 41 50 51 60 61 + const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]); + // 10 11 20 21 30 31 40 41 + srcs[1] = _mm_srli_si128(srcs_0_4, 2); + // 20 21 30 31 40 41 50 51 + srcs[2] = _mm_srli_si128(srcs_0_4, 4); + // 30 31 40 41 50 51 60 61 + srcs[3] = _mm_srli_si128(srcs_0_4, 6); + + int y = 0; + do { + // 40 41 50 51 60 61 70 71 + srcs[4] = Load2<3>(src, srcs[4]); + src += src_stride; + // 80 81 + srcs[8] = Load2<0>(src, srcs[8]); + src += src_stride; + // 80 81 90 91 + srcs[8] = Load2<1>(src, srcs[8]); + src += src_stride; + // 80 81 90 91 a0 a1 + srcs[8] = Load2<2>(src, srcs[8]); + src += src_stride; + + // 40 41 50 51 60 61 70 71 80 81 90 91 a0 a1 + const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]); + // 50 51 60 61 70 71 80 81 + srcs[5] = _mm_srli_si128(srcs_4_8, 2); + // 60 61 70 71 80 81 90 91 + srcs[6] = _mm_srli_si128(srcs_4_8, 4); + // 70 71 80 81 90 91 a0 a1 + srcs[7] = _mm_srli_si128(srcs_4_8, 6); + + // This uses srcs[0]..srcs[7]. + const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap); + const __m128i results_16 = + RightShiftWithRounding_S16(sums, kFilterBits - 1); + const __m128i results = _mm_packus_epi16(results_16, results_16); + + Store2(dst8, results); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 2)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 4)); + dst8 += dst_stride; + Store2(dst8, _mm_srli_si128(results, 6)); + dst8 += dst_stride; + + srcs[0] = srcs[4]; + srcs[1] = srcs[5]; + srcs[2] = srcs[6]; + srcs[3] = srcs[7]; + srcs[4] = srcs[8]; + y += 4; + } while (y < height); + } +} + +void ConvolveVertical_SSE4_1(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int vertical_filter_index, + const int /*horizontal_filter_id*/, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint8_t*>(prediction); + const ptrdiff_t dest_stride = pred_stride; + assert(vertical_filter_id != 0); + + __m128i taps[4]; + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]); + + if (filter_index < 2) { // 6 tap. + SetupTaps<6>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height, taps); + } else if (width == 4) { + FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height, taps); + } else { + FilterVertical<0>(src, src_stride, dest, dest_stride, width, height, + taps); + } + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps); + } else if (width == 4) { + FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps); + } else { + FilterVertical<2>(src, src_stride, dest, dest_stride, width, height, + taps); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height, taps); + } else if (width == 4) { + FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height, taps); + } else { + FilterVertical<3>(src, src_stride, dest, dest_stride, width, height, + taps); + } + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_filter, taps); + if (width == 2) { + FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, taps); + } else if (width == 4) { + FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, taps); + } else { + FilterVertical<4>(src, src_stride, dest, dest_stride, width, height, + taps); + } + } else { + // TODO(slavarnway): Investigate adding |filter_index| == 1 special cases. + // See convolve_neon.cc + SetupTaps<4>(&v_filter, taps); + + if (width == 2) { + FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height, taps); + } else if (width == 4) { + FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height, taps); + } else { + FilterVertical<5>(src, src_stride, dest, dest_stride, width, height, + taps); + } + } +} + +void ConvolveCompoundCopy_SSE4(const void* const reference, + const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, + const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, + const int /*vertical_filter_id*/, + const int width, const int height, + void* prediction, const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + const ptrdiff_t src_stride = reference_stride; + auto* dest = static_cast<uint16_t*>(prediction); + constexpr int kRoundBitsVertical = + kInterRoundBitsVertical - kInterRoundBitsCompoundVertical; + if (width >= 16) { + int y = height; + do { + int x = 0; + do { + const __m128i v_src = LoadUnaligned16(&src[x]); + const __m128i v_src_ext_lo = _mm_cvtepu8_epi16(v_src); + const __m128i v_src_ext_hi = + _mm_cvtepu8_epi16(_mm_srli_si128(v_src, 8)); + const __m128i v_dest_lo = + _mm_slli_epi16(v_src_ext_lo, kRoundBitsVertical); + const __m128i v_dest_hi = + _mm_slli_epi16(v_src_ext_hi, kRoundBitsVertical); + // TODO(slavarnway): Investigate using aligned stores. + StoreUnaligned16(&dest[x], v_dest_lo); + StoreUnaligned16(&dest[x + 8], v_dest_hi); + x += 16; + } while (x < width); + src += src_stride; + dest += pred_stride; + } while (--y != 0); + } else if (width == 8) { + int y = height; + do { + const __m128i v_src = LoadLo8(&src[0]); + const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src); + const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical); + StoreUnaligned16(&dest[0], v_dest); + src += src_stride; + dest += pred_stride; + } while (--y != 0); + } else { /* width == 4 */ + int y = height; + do { + const __m128i v_src0 = Load4(&src[0]); + const __m128i v_src1 = Load4(&src[src_stride]); + const __m128i v_src = _mm_unpacklo_epi32(v_src0, v_src1); + const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src); + const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical); + StoreLo8(&dest[0], v_dest); + StoreHi8(&dest[pred_stride], v_dest); + src += src_stride * 2; + dest += pred_stride * 2; + y -= 2; + } while (y != 0); + } +} + +void ConvolveCompoundVertical_SSE4_1( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int vertical_filter_index, + const int /*horizontal_filter_id*/, const int vertical_filter_id, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(filter_index); + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride; + auto* dest = static_cast<uint16_t*>(prediction); + assert(vertical_filter_id != 0); + + __m128i taps[4]; + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]); + + if (filter_index < 2) { // 6 tap. + SetupTaps<6>(&v_filter, taps); + if (width == 4) { + FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); + } else { + FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps); + } + } else if (filter_index == 2) { // 8 tap. + SetupTaps<8>(&v_filter, taps); + + if (width == 4) { + FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); + } else { + FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps); + } + } else if (filter_index == 3) { // 2 tap. + SetupTaps<2>(&v_filter, taps); + + if (width == 4) { + FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); + } else { + FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps); + } + } else if (filter_index == 4) { // 4 tap. + SetupTaps<4>(&v_filter, taps); + + if (width == 4) { + FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); + } else { + FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps); + } + } else { + SetupTaps<4>(&v_filter, taps); + + if (width == 4) { + FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4, + height, taps); + } else { + FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width, + width, height, taps); + } + } +} + +void ConvolveHorizontal_SSE4_1(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int /*vertical_filter_index*/, + const int horizontal_filter_id, + const int /*vertical_filter_id*/, + const int width, const int height, + void* prediction, const ptrdiff_t pred_stride) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + // Set |src| to the outermost tap. + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint8_t*>(prediction); + + DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height, + horizontal_filter_id, filter_index); +} + +void ConvolveCompoundHorizontal_SSE4_1( + const void* const reference, const ptrdiff_t reference_stride, + const int horizontal_filter_index, const int /*vertical_filter_index*/, + const int horizontal_filter_id, const int /*vertical_filter_id*/, + const int width, const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + const int filter_index = GetFilterIndex(horizontal_filter_index, width); + const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset; + auto* dest = static_cast<uint16_t*>(prediction); + + DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>( + src, reference_stride, dest, width, width, height, horizontal_filter_id, + filter_index); +} + +void ConvolveCompound2D_SSE4_1(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int horizontal_filter_id, + const int vertical_filter_id, const int width, + const int height, void* prediction, + const ptrdiff_t /*pred_stride*/) { + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + alignas(16) uint16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)]; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [4, 5]. + // Similarly for height. + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + const int vertical_taps = GetNumTapsInFilter(vert_filter_index); + const int intermediate_height = height + vertical_taps - 1; + const ptrdiff_t src_stride = reference_stride; + const auto* const src = static_cast<const uint8_t*>(reference) - + (vertical_taps / 2 - 1) * src_stride - + kHorizontalOffset; + + DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>( + src, src_stride, intermediate_result, width, width, intermediate_height, + horizontal_filter_id, horiz_filter_index); + + // Vertical filter. + auto* dest = static_cast<uint16_t*>(prediction); + assert(vertical_filter_id != 0); + + const ptrdiff_t dest_stride = width; + __m128i taps[4]; + const __m128i v_filter = + LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]); + + if (vertical_taps == 8) { + SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<8, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 6) { + SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<6, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else if (vertical_taps == 4) { + SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<4, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } else { // |vertical_taps| == 2 + SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps); + if (width == 4) { + Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest, + dest_stride, height, taps); + } else { + Filter2DVertical<2, /*is_compound=*/true>( + intermediate_result, dest, dest_stride, width, height, taps); + } + } +} + +// Pre-transposed filters. +template <int filter_index> +inline void GetHalfSubPixelFilter(__m128i* output) { + // Filter 0 + alignas( + 16) static constexpr int8_t kHalfSubPixel6TapSignedFilterColumns[6][16] = + {{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0}, + {0, -3, -5, -6, -7, -7, -8, -7, -7, -6, -6, -6, -5, -4, -2, -1}, + {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4}, + {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63}, + {0, -1, -2, -4, -5, -6, -6, -6, -7, -7, -8, -7, -7, -6, -5, -3}, + {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}; + // Filter 1 + alignas(16) static constexpr int8_t + kHalfSubPixel6TapMixedSignedFilterColumns[6][16] = { + {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0}, + {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1}, + {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17}, + {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31}, + {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14}, + {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}}; + // Filter 2 + alignas( + 16) static constexpr int8_t kHalfSubPixel8TapSignedFilterColumns[8][16] = + {{0, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, 0}, + {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1}, + {0, -3, -6, -9, -11, -11, -12, -12, -12, -11, -10, -9, -7, -5, -3, -1}, + {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4}, + {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63}, + {0, -1, -3, -5, -7, -9, -10, -11, -12, -12, -12, -11, -11, -9, -6, -3}, + {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1}, + {0, 0, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1}}; + // Filter 3 + alignas(16) static constexpr uint8_t kHalfSubPixel2TapFilterColumns[2][16] = { + {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4}, + {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}}; + // Filter 4 + alignas( + 16) static constexpr int8_t kHalfSubPixel4TapSignedFilterColumns[4][16] = + {{0, -2, -4, -5, -6, -6, -7, -6, -6, -5, -5, -5, -4, -3, -2, -1}, + {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4}, + {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63}, + {0, -1, -2, -3, -4, -5, -5, -5, -6, -6, -7, -6, -6, -5, -4, -2}}; + // Filter 5 + alignas( + 16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = { + {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1}, + {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17}, + {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31}, + {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}}; + switch (filter_index) { + case 0: + output[0] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[0]); + output[1] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[1]); + output[2] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[2]); + output[3] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[3]); + output[4] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[4]); + output[5] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[5]); + break; + case 1: + // The term "mixed" refers to the fact that the outer taps have a mix of + // negative and positive values. + output[0] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[0]); + output[1] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[1]); + output[2] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[2]); + output[3] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[3]); + output[4] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[4]); + output[5] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[5]); + break; + case 2: + output[0] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[0]); + output[1] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[1]); + output[2] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[2]); + output[3] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[3]); + output[4] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[4]); + output[5] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[5]); + output[6] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[6]); + output[7] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[7]); + break; + case 3: + output[0] = LoadAligned16(kHalfSubPixel2TapFilterColumns[0]); + output[1] = LoadAligned16(kHalfSubPixel2TapFilterColumns[1]); + break; + case 4: + output[0] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[0]); + output[1] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[1]); + output[2] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[2]); + output[3] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[3]); + break; + default: + assert(filter_index == 5); + output[0] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[0]); + output[1] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[1]); + output[2] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[2]); + output[3] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[3]); + break; + } +} + +// There are many opportunities for overreading in scaled convolve, because +// the range of starting points for filter windows is anywhere from 0 to 16 +// for 8 destination pixels, and the window sizes range from 2 to 8. To +// accommodate this range concisely, we use |grade_x| to mean the most steps +// in src that can be traversed in a single |step_x| increment, i.e. 1 or 2. +// More importantly, |grade_x| answers the question "how many vector loads are +// needed to cover the source values?" +// When |grade_x| == 1, the maximum number of source values needed is 8 separate +// starting positions plus 7 more to cover taps, all fitting into 16 bytes. +// When |grade_x| > 1, we are guaranteed to exceed 8 whole steps in src for +// every 8 |step_x| increments, on top of 8 possible taps. The first load covers +// the starting sources for each kernel, while the final load covers the taps. +// Since the offset value of src_x cannot exceed 8 and |num_taps| does not +// exceed 4 when width <= 4, |grade_x| is set to 1 regardless of the value of +// |step_x|. +template <int num_taps, int grade_x> +inline void PrepareSourceVectors(const uint8_t* src, const __m128i src_indices, + __m128i* const source /*[num_taps >> 1]*/) { + const __m128i src_vals = LoadUnaligned16(src); + source[0] = _mm_shuffle_epi8(src_vals, src_indices); + if (grade_x == 1) { + if (num_taps > 2) { + source[1] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 2), src_indices); + } + if (num_taps > 4) { + source[2] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 4), src_indices); + } + if (num_taps > 6) { + source[3] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 6), src_indices); + } + } else { + assert(grade_x > 1); + assert(num_taps != 4); + // grade_x > 1 also means width >= 8 && num_taps != 4 + const __m128i src_vals_ext = LoadLo8(src + 16); + if (num_taps > 2) { + source[1] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 2), + src_indices); + source[2] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 4), + src_indices); + } + if (num_taps > 6) { + source[3] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 6), + src_indices); + } + } +} + +template <int num_taps> +inline void PrepareHorizontalTaps(const __m128i subpel_indices, + const __m128i* filter_taps, + __m128i* out_taps) { + const __m128i scale_index_offsets = + _mm_srli_epi16(subpel_indices, kFilterIndexShift); + const __m128i filter_index_mask = _mm_set1_epi8(kSubPixelMask); + const __m128i filter_indices = + _mm_and_si128(_mm_packus_epi16(scale_index_offsets, scale_index_offsets), + filter_index_mask); + // Line up taps for maddubs_epi16. + // The unpack is also assumed to be lighter than shift+alignr. + for (int k = 0; k < (num_taps >> 1); ++k) { + const __m128i taps0 = _mm_shuffle_epi8(filter_taps[2 * k], filter_indices); + const __m128i taps1 = + _mm_shuffle_epi8(filter_taps[2 * k + 1], filter_indices); + out_taps[k] = _mm_unpacklo_epi8(taps0, taps1); + } +} + +inline __m128i HorizontalScaleIndices(const __m128i subpel_indices) { + const __m128i src_indices16 = + _mm_srli_epi16(subpel_indices, kScaleSubPixelBits); + const __m128i src_indices = _mm_packus_epi16(src_indices16, src_indices16); + return _mm_unpacklo_epi8(src_indices, + _mm_add_epi8(src_indices, _mm_set1_epi8(1))); +} + +template <int grade_x, int filter_index, int num_taps> +inline void ConvolveHorizontalScale(const uint8_t* src, ptrdiff_t src_stride, + int width, int subpixel_x, int step_x, + int intermediate_height, + int16_t* intermediate) { + // Account for the 0-taps that precede the 2 nonzero taps. + const int kernel_offset = (8 - num_taps) >> 1; + const int ref_x = subpixel_x >> kScaleSubPixelBits; + const int step_x8 = step_x << 3; + __m128i filter_taps[num_taps]; + GetHalfSubPixelFilter<filter_index>(filter_taps); + const __m128i index_steps = + _mm_mullo_epi16(_mm_set_epi16(7, 6, 5, 4, 3, 2, 1, 0), + _mm_set1_epi16(static_cast<int16_t>(step_x))); + + __m128i taps[num_taps >> 1]; + __m128i source[num_taps >> 1]; + int p = subpixel_x; + // Case when width <= 4 is possible. + if (filter_index >= 3) { + if (filter_index > 3 || width <= 4) { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + // Only add steps to the 10-bit truncated p to avoid overflow. + const __m128i p_fraction = _mm_set1_epi16(p & 1023); + const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction); + PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps); + const __m128i packed_indices = HorizontalScaleIndices(subpel_indices); + + int y = intermediate_height; + do { + // Load and line up source values with the taps. Width 4 means no need + // to load extended source. + PrepareSourceVectors<num_taps, /*grade_x=*/1>(src_x, packed_indices, + source); + + StoreLo8(intermediate, RightShiftWithRounding_S16( + SumOnePassTaps<filter_index>(source, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate += kIntermediateStride; + } while (--y != 0); + return; + } + } + + // |width| >= 8 + int x = 0; + do { + const uint8_t* src_x = + &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset]; + int16_t* intermediate_x = intermediate + x; + // Only add steps to the 10-bit truncated p to avoid overflow. + const __m128i p_fraction = _mm_set1_epi16(p & 1023); + const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction); + PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps); + const __m128i packed_indices = HorizontalScaleIndices(subpel_indices); + + int y = intermediate_height; + do { + // For each x, a lane of src_k[k] contains src_x[k]. + PrepareSourceVectors<num_taps, grade_x>(src_x, packed_indices, source); + + // Shift by one less because the taps are halved. + StoreAligned16( + intermediate_x, + RightShiftWithRounding_S16(SumOnePassTaps<filter_index>(source, taps), + kInterRoundBitsHorizontal - 1)); + src_x += src_stride; + intermediate_x += kIntermediateStride; + } while (--y != 0); + x += 8; + p += step_x8; + } while (x < width); +} + +template <int num_taps> +inline void PrepareVerticalTaps(const int8_t* taps, __m128i* output) { + // Avoid overreading the filter due to starting at kernel_offset. + // The only danger of overread is in the final filter, which has 4 taps. + const __m128i filter = + _mm_cvtepi8_epi16((num_taps > 4) ? LoadLo8(taps) : Load4(taps)); + output[0] = _mm_shuffle_epi32(filter, 0); + if (num_taps > 2) { + output[1] = _mm_shuffle_epi32(filter, 0x55); + } + if (num_taps > 4) { + output[2] = _mm_shuffle_epi32(filter, 0xAA); + } + if (num_taps > 6) { + output[3] = _mm_shuffle_epi32(filter, 0xFF); + } +} + +// Process eight 16 bit inputs and output eight 16 bit values. +template <int num_taps, bool is_compound> +inline __m128i Sum2DVerticalTaps(const __m128i* const src, + const __m128i* taps) { + const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]); + __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps[0]); + const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]); + __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps[0]); + if (num_taps > 2) { + const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]); + sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps[1])); + const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]); + sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps[1])); + } + if (num_taps > 4) { + const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]); + sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps[2])); + const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]); + sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps[2])); + } + if (num_taps > 6) { + const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]); + sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps[3])); + const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]); + sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps[3])); + } + if (is_compound) { + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1), + RightShiftWithRounding_S32(sum_hi, + kInterRoundBitsCompoundVertical - 1)); + } + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1), + RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1)); +} + +// Bottom half of each src[k] is the source for one filter, and the top half +// is the source for the other filter, for the next destination row. +template <int num_taps, bool is_compound> +__m128i Sum2DVerticalTaps4x2(const __m128i* const src, const __m128i* taps_lo, + const __m128i* taps_hi) { + const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]); + __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps_lo[0]); + const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]); + __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps_hi[0]); + if (num_taps > 2) { + const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]); + sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps_lo[1])); + const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]); + sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps_hi[1])); + } + if (num_taps > 4) { + const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]); + sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps_lo[2])); + const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]); + sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps_hi[2])); + } + if (num_taps > 6) { + const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]); + sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps_lo[3])); + const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]); + sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps_hi[3])); + } + + if (is_compound) { + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1), + RightShiftWithRounding_S32(sum_hi, + kInterRoundBitsCompoundVertical - 1)); + } + return _mm_packs_epi32( + RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1), + RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1)); +} + +// |width_class| is 2, 4, or 8, according to the Store function that should be +// used. +template <int num_taps, int width_class, bool is_compound> +#if LIBGAV1_MSAN +__attribute__((no_sanitize_memory)) void ConvolveVerticalScale( +#else +inline void ConvolveVerticalScale( +#endif + const int16_t* src, const int width, const int subpixel_y, + const int filter_index, const int step_y, const int height, void* dest, + const ptrdiff_t dest_stride) { + constexpr ptrdiff_t src_stride = kIntermediateStride; + constexpr int kernel_offset = (8 - num_taps) / 2; + const int16_t* src_y = src; + // |dest| is 16-bit in compound mode, Pixel otherwise. + auto* dest16_y = static_cast<uint16_t*>(dest); + auto* dest_y = static_cast<uint8_t*>(dest); + __m128i s[num_taps]; + + int p = subpixel_y & 1023; + int y = height; + if (width_class <= 4) { + __m128i filter_taps_lo[num_taps >> 1]; + __m128i filter_taps_hi[num_taps >> 1]; + do { // y > 0 + for (int i = 0; i < num_taps; ++i) { + s[i] = LoadLo8(src_y + i * src_stride); + } + int filter_id = (p >> 6) & kSubPixelMask; + const int8_t* filter0 = + kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset; + PrepareVerticalTaps<num_taps>(filter0, filter_taps_lo); + p += step_y; + src_y = src + (p >> kScaleSubPixelBits) * src_stride; + + for (int i = 0; i < num_taps; ++i) { + s[i] = LoadHi8(s[i], src_y + i * src_stride); + } + filter_id = (p >> 6) & kSubPixelMask; + const int8_t* filter1 = + kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset; + PrepareVerticalTaps<num_taps>(filter1, filter_taps_hi); + p += step_y; + src_y = src + (p >> kScaleSubPixelBits) * src_stride; + + const __m128i sums = Sum2DVerticalTaps4x2<num_taps, is_compound>( + s, filter_taps_lo, filter_taps_hi); + if (is_compound) { + assert(width_class > 2); + StoreLo8(dest16_y, sums); + dest16_y += dest_stride; + StoreHi8(dest16_y, sums); + dest16_y += dest_stride; + } else { + const __m128i result = _mm_packus_epi16(sums, sums); + if (width_class == 2) { + Store2(dest_y, result); + dest_y += dest_stride; + Store2(dest_y, _mm_srli_si128(result, 4)); + } else { + Store4(dest_y, result); + dest_y += dest_stride; + Store4(dest_y, _mm_srli_si128(result, 4)); + } + dest_y += dest_stride; + } + y -= 2; + } while (y != 0); + return; + } + + // |width_class| >= 8 + __m128i filter_taps[num_taps >> 1]; + do { // y > 0 + src_y = src + (p >> kScaleSubPixelBits) * src_stride; + const int filter_id = (p >> 6) & kSubPixelMask; + const int8_t* filter = + kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset; + PrepareVerticalTaps<num_taps>(filter, filter_taps); + + int x = 0; + do { // x < width + for (int i = 0; i < num_taps; ++i) { + s[i] = LoadUnaligned16(src_y + i * src_stride); + } + + const __m128i sums = + Sum2DVerticalTaps<num_taps, is_compound>(s, filter_taps); + if (is_compound) { + StoreUnaligned16(dest16_y + x, sums); + } else { + StoreLo8(dest_y + x, _mm_packus_epi16(sums, sums)); + } + x += 8; + src_y += 8; + } while (x < width); + p += step_y; + dest_y += dest_stride; + dest16_y += dest_stride; + } while (--y != 0); +} + +template <bool is_compound> +void ConvolveScale2D_SSE4_1(const void* const reference, + const ptrdiff_t reference_stride, + const int horizontal_filter_index, + const int vertical_filter_index, + const int subpixel_x, const int subpixel_y, + const int step_x, const int step_y, const int width, + const int height, void* prediction, + const ptrdiff_t pred_stride) { + const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width); + const int vert_filter_index = GetFilterIndex(vertical_filter_index, height); + assert(step_x <= 2048); + // The output of the horizontal filter, i.e. the intermediate_result, is + // guaranteed to fit in int16_t. + // TODO(petersonab): Reduce intermediate block stride to width to make smaller + // blocks faster. + alignas(16) int16_t + intermediate_result[kMaxSuperBlockSizeInPixels * + (2 * kMaxSuperBlockSizeInPixels + kSubPixelTaps)]; + const int num_vert_taps = GetNumTapsInFilter(vert_filter_index); + const int intermediate_height = + (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >> + kScaleSubPixelBits) + + num_vert_taps; + + // Horizontal filter. + // Filter types used for width <= 4 are different from those for width > 4. + // When width > 4, the valid filter index range is always [0, 3]. + // When width <= 4, the valid filter index range is always [3, 5]. + // Similarly for height. + int16_t* intermediate = intermediate_result; + const ptrdiff_t src_stride = reference_stride; + const auto* src = static_cast<const uint8_t*>(reference); + const int vert_kernel_offset = (8 - num_vert_taps) / 2; + src += vert_kernel_offset * src_stride; + + // Derive the maximum value of |step_x| at which all source values fit in one + // 16-byte load. Final index is src_x + |num_taps| - 1 < 16 + // step_x*7 is the final base sub-pixel index for the shuffle mask for filter + // inputs in each iteration on large blocks. When step_x is large, we need a + // second register and alignr in order to gather all filter inputs. + // |num_taps| - 1 is the offset for the shuffle of inputs to the final tap. + const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index); + const int kernel_start_ceiling = 16 - num_horiz_taps; + // This truncated quotient |grade_x_threshold| selects |step_x| such that: + // (step_x * 7) >> kScaleSubPixelBits < single load limit + const int grade_x_threshold = + (kernel_start_ceiling << kScaleSubPixelBits) / 7; + switch (horiz_filter_index) { + case 0: + if (step_x > grade_x_threshold) { + ConvolveHorizontalScale<2, 0, 6>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } else { + ConvolveHorizontalScale<1, 0, 6>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 1: + if (step_x > grade_x_threshold) { + ConvolveHorizontalScale<2, 1, 6>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + + } else { + ConvolveHorizontalScale<1, 1, 6>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 2: + if (step_x > grade_x_threshold) { + ConvolveHorizontalScale<2, 2, 8>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } else { + ConvolveHorizontalScale<1, 2, 8>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 3: + if (step_x > grade_x_threshold) { + ConvolveHorizontalScale<2, 3, 2>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } else { + ConvolveHorizontalScale<1, 3, 2>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + break; + case 4: + assert(width <= 4); + ConvolveHorizontalScale<1, 4, 4>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + break; + default: + assert(horiz_filter_index == 5); + assert(width <= 4); + ConvolveHorizontalScale<1, 5, 4>(src, src_stride, width, subpixel_x, + step_x, intermediate_height, + intermediate); + } + + // Vertical filter. + intermediate = intermediate_result; + switch (vert_filter_index) { + case 0: + case 1: + if (!is_compound && width == 2) { + ConvolveVerticalScale<6, 2, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale<6, 4, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<6, 8, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } + break; + case 2: + if (!is_compound && width == 2) { + ConvolveVerticalScale<8, 2, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale<8, 4, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<8, 8, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } + break; + case 3: + if (!is_compound && width == 2) { + ConvolveVerticalScale<2, 2, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale<2, 4, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<2, 8, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } + break; + default: + assert(vert_filter_index == 4 || vert_filter_index == 5); + if (!is_compound && width == 2) { + ConvolveVerticalScale<4, 2, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else if (width == 4) { + ConvolveVerticalScale<4, 4, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } else { + ConvolveVerticalScale<4, 8, is_compound>( + intermediate, width, subpixel_y, vert_filter_index, step_y, height, + prediction, pred_stride); + } + } +} + +inline void HalfAddHorizontal(const uint8_t* src, uint8_t* dst) { + const __m128i left = LoadUnaligned16(src); + const __m128i right = LoadUnaligned16(src + 1); + StoreUnaligned16(dst, _mm_avg_epu8(left, right)); +} + +template <int width> +inline void IntraBlockCopyHorizontal(const uint8_t* src, + const ptrdiff_t src_stride, + const int height, uint8_t* dst, + const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 16); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16); + + int y = height; + do { + HalfAddHorizontal(src, dst); + if (width >= 32) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + if (width >= 64) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + if (width == 128) { + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + src += 16; + dst += 16; + HalfAddHorizontal(src, dst); + } + } + } + src += src_remainder_stride; + dst += dst_remainder_stride; + } while (--y != 0); +} + +void ConvolveIntraBlockCopyHorizontal_SSE4_1( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*subpixel_x*/, const int /*subpixel_y*/, const int width, + const int height, void* const prediction, const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + + if (width == 128) { + IntraBlockCopyHorizontal<128>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 64) { + IntraBlockCopyHorizontal<64>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 32) { + IntraBlockCopyHorizontal<32>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 16) { + IntraBlockCopyHorizontal<16>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 8) { + int y = height; + do { + const __m128i left = LoadLo8(src); + const __m128i right = LoadLo8(src + 1); + StoreLo8(dest, _mm_avg_epu8(left, right)); + + src += reference_stride; + dest += pred_stride; + } while (--y != 0); + } else if (width == 4) { + int y = height; + do { + __m128i left = Load4(src); + __m128i right = Load4(src + 1); + src += reference_stride; + left = _mm_unpacklo_epi32(left, Load4(src)); + right = _mm_unpacklo_epi32(right, Load4(src + 1)); + src += reference_stride; + + const __m128i result = _mm_avg_epu8(left, right); + + Store4(dest, result); + dest += pred_stride; + Store4(dest, _mm_srli_si128(result, 4)); + dest += pred_stride; + y -= 2; + } while (y != 0); + } else { + assert(width == 2); + __m128i left = _mm_setzero_si128(); + __m128i right = _mm_setzero_si128(); + int y = height; + do { + left = Load2<0>(src, left); + right = Load2<0>(src + 1, right); + src += reference_stride; + left = Load2<1>(src, left); + right = Load2<1>(src + 1, right); + src += reference_stride; + + const __m128i result = _mm_avg_epu8(left, right); + + Store2(dest, result); + dest += pred_stride; + Store2(dest, _mm_srli_si128(result, 2)); + dest += pred_stride; + y -= 2; + } while (y != 0); + } +} + +template <int width> +inline void IntraBlockCopyVertical(const uint8_t* src, + const ptrdiff_t src_stride, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 16); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16); + __m128i row[8], below[8]; + + row[0] = LoadUnaligned16(src); + if (width >= 32) { + src += 16; + row[1] = LoadUnaligned16(src); + if (width >= 64) { + src += 16; + row[2] = LoadUnaligned16(src); + src += 16; + row[3] = LoadUnaligned16(src); + if (width == 128) { + src += 16; + row[4] = LoadUnaligned16(src); + src += 16; + row[5] = LoadUnaligned16(src); + src += 16; + row[6] = LoadUnaligned16(src); + src += 16; + row[7] = LoadUnaligned16(src); + } + } + } + src += src_remainder_stride; + + int y = height; + do { + below[0] = LoadUnaligned16(src); + if (width >= 32) { + src += 16; + below[1] = LoadUnaligned16(src); + if (width >= 64) { + src += 16; + below[2] = LoadUnaligned16(src); + src += 16; + below[3] = LoadUnaligned16(src); + if (width == 128) { + src += 16; + below[4] = LoadUnaligned16(src); + src += 16; + below[5] = LoadUnaligned16(src); + src += 16; + below[6] = LoadUnaligned16(src); + src += 16; + below[7] = LoadUnaligned16(src); + } + } + } + src += src_remainder_stride; + + StoreUnaligned16(dst, _mm_avg_epu8(row[0], below[0])); + row[0] = below[0]; + if (width >= 32) { + dst += 16; + StoreUnaligned16(dst, _mm_avg_epu8(row[1], below[1])); + row[1] = below[1]; + if (width >= 64) { + dst += 16; + StoreUnaligned16(dst, _mm_avg_epu8(row[2], below[2])); + row[2] = below[2]; + dst += 16; + StoreUnaligned16(dst, _mm_avg_epu8(row[3], below[3])); + row[3] = below[3]; + if (width >= 128) { + dst += 16; + StoreUnaligned16(dst, _mm_avg_epu8(row[4], below[4])); + row[4] = below[4]; + dst += 16; + StoreUnaligned16(dst, _mm_avg_epu8(row[5], below[5])); + row[5] = below[5]; + dst += 16; + StoreUnaligned16(dst, _mm_avg_epu8(row[6], below[6])); + row[6] = below[6]; + dst += 16; + StoreUnaligned16(dst, _mm_avg_epu8(row[7], below[7])); + row[7] = below[7]; + } + } + } + dst += dst_remainder_stride; + } while (--y != 0); +} + +void ConvolveIntraBlockCopyVertical_SSE4_1( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* const prediction, + const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + + if (width == 128) { + IntraBlockCopyVertical<128>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 64) { + IntraBlockCopyVertical<64>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 32) { + IntraBlockCopyVertical<32>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 16) { + IntraBlockCopyVertical<16>(src, reference_stride, height, dest, + pred_stride); + } else if (width == 8) { + __m128i row, below; + row = LoadLo8(src); + src += reference_stride; + + int y = height; + do { + below = LoadLo8(src); + src += reference_stride; + + StoreLo8(dest, _mm_avg_epu8(row, below)); + dest += pred_stride; + + row = below; + } while (--y != 0); + } else if (width == 4) { + __m128i row = Load4(src); + src += reference_stride; + + int y = height; + do { + __m128i below = Load4(src); + src += reference_stride; + + Store4(dest, _mm_avg_epu8(row, below)); + dest += pred_stride; + + row = below; + } while (--y != 0); + } else { + assert(width == 2); + __m128i row = Load2(src); + __m128i below = _mm_setzero_si128(); + src += reference_stride; + + int y = height; + do { + below = Load2<0>(src, below); + src += reference_stride; + + Store2(dest, _mm_avg_epu8(row, below)); + dest += pred_stride; + + row = below; + } while (--y != 0); + } +} + +// Load then add two uint8_t vectors. Return the uint16_t vector result. +inline __m128i LoadU8AndAddLong(const uint8_t* src, const uint8_t* src1) { + const __m128i a = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i b = _mm_cvtepu8_epi16(LoadLo8(src1)); + return _mm_add_epi16(a, b); +} + +inline __m128i AddU16RightShift2AndPack(__m128i v0, __m128i v1) { + const __m128i a = _mm_add_epi16(v0, v1); + const __m128i b = _mm_srli_epi16(a, 1); + // Use avg here to shift right by 1 with round. + const __m128i c = _mm_avg_epu16(b, _mm_setzero_si128()); + return _mm_packus_epi16(c, c); +} + +template <int width> +inline void IntraBlockCopy2D(const uint8_t* src, const ptrdiff_t src_stride, + const int height, uint8_t* dst, + const ptrdiff_t dst_stride) { + const ptrdiff_t src_remainder_stride = src_stride - (width - 8); + const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8); + __m128i row[16]; + row[0] = LoadU8AndAddLong(src, src + 1); + if (width >= 16) { + src += 8; + row[1] = LoadU8AndAddLong(src, src + 1); + if (width >= 32) { + src += 8; + row[2] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[3] = LoadU8AndAddLong(src, src + 1); + if (width >= 64) { + src += 8; + row[4] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[5] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[6] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[7] = LoadU8AndAddLong(src, src + 1); + if (width == 128) { + src += 8; + row[8] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[9] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[10] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[11] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[12] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[13] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[14] = LoadU8AndAddLong(src, src + 1); + src += 8; + row[15] = LoadU8AndAddLong(src, src + 1); + } + } + } + } + src += src_remainder_stride; + + int y = height; + do { + const __m128i below_0 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[0], below_0)); + row[0] = below_0; + if (width >= 16) { + src += 8; + dst += 8; + + const __m128i below_1 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[1], below_1)); + row[1] = below_1; + if (width >= 32) { + src += 8; + dst += 8; + + const __m128i below_2 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[2], below_2)); + row[2] = below_2; + src += 8; + dst += 8; + + const __m128i below_3 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[3], below_3)); + row[3] = below_3; + if (width >= 64) { + src += 8; + dst += 8; + + const __m128i below_4 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[4], below_4)); + row[4] = below_4; + src += 8; + dst += 8; + + const __m128i below_5 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[5], below_5)); + row[5] = below_5; + src += 8; + dst += 8; + + const __m128i below_6 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[6], below_6)); + row[6] = below_6; + src += 8; + dst += 8; + + const __m128i below_7 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[7], below_7)); + row[7] = below_7; + if (width == 128) { + src += 8; + dst += 8; + + const __m128i below_8 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[8], below_8)); + row[8] = below_8; + src += 8; + dst += 8; + + const __m128i below_9 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[9], below_9)); + row[9] = below_9; + src += 8; + dst += 8; + + const __m128i below_10 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[10], below_10)); + row[10] = below_10; + src += 8; + dst += 8; + + const __m128i below_11 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[11], below_11)); + row[11] = below_11; + src += 8; + dst += 8; + + const __m128i below_12 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[12], below_12)); + row[12] = below_12; + src += 8; + dst += 8; + + const __m128i below_13 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[13], below_13)); + row[13] = below_13; + src += 8; + dst += 8; + + const __m128i below_14 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[14], below_14)); + row[14] = below_14; + src += 8; + dst += 8; + + const __m128i below_15 = LoadU8AndAddLong(src, src + 1); + StoreLo8(dst, AddU16RightShift2AndPack(row[15], below_15)); + row[15] = below_15; + } + } + } + } + src += src_remainder_stride; + dst += dst_remainder_stride; + } while (--y != 0); +} + +void ConvolveIntraBlockCopy2D_SSE4_1( + const void* const reference, const ptrdiff_t reference_stride, + const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/, + const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/, + const int width, const int height, void* const prediction, + const ptrdiff_t pred_stride) { + const auto* src = static_cast<const uint8_t*>(reference); + auto* dest = static_cast<uint8_t*>(prediction); + // Note: allow vertical access to height + 1. Because this function is only + // for u/v plane of intra block copy, such access is guaranteed to be within + // the prediction block. + + if (width == 128) { + IntraBlockCopy2D<128>(src, reference_stride, height, dest, pred_stride); + } else if (width == 64) { + IntraBlockCopy2D<64>(src, reference_stride, height, dest, pred_stride); + } else if (width == 32) { + IntraBlockCopy2D<32>(src, reference_stride, height, dest, pred_stride); + } else if (width == 16) { + IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride); + } else if (width == 8) { + IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride); + } else if (width == 4) { + __m128i left = _mm_cvtepu8_epi16(Load4(src)); + __m128i right = _mm_cvtepu8_epi16(Load4(src + 1)); + src += reference_stride; + + __m128i row = _mm_add_epi16(left, right); + + int y = height; + do { + left = Load4(src); + right = Load4(src + 1); + src += reference_stride; + left = _mm_unpacklo_epi32(left, Load4(src)); + right = _mm_unpacklo_epi32(right, Load4(src + 1)); + src += reference_stride; + + const __m128i below = + _mm_add_epi16(_mm_cvtepu8_epi16(left), _mm_cvtepu8_epi16(right)); + const __m128i result = + AddU16RightShift2AndPack(_mm_unpacklo_epi64(row, below), below); + + Store4(dest, result); + dest += pred_stride; + Store4(dest, _mm_srli_si128(result, 4)); + dest += pred_stride; + + row = _mm_srli_si128(below, 8); + y -= 2; + } while (y != 0); + } else { + __m128i left = Load2(src); + __m128i right = Load2(src + 1); + src += reference_stride; + + __m128i row = + _mm_add_epi16(_mm_cvtepu8_epi16(left), _mm_cvtepu8_epi16(right)); + + int y = height; + do { + left = Load2<0>(src, left); + right = Load2<0>(src + 1, right); + src += reference_stride; + left = Load2<2>(src, left); + right = Load2<2>(src + 1, right); + src += reference_stride; + + const __m128i below = + _mm_add_epi16(_mm_cvtepu8_epi16(left), _mm_cvtepu8_epi16(right)); + const __m128i result = + AddU16RightShift2AndPack(_mm_unpacklo_epi64(row, below), below); + + Store2(dest, result); + dest += pred_stride; + Store2(dest, _mm_srli_si128(result, 4)); + dest += pred_stride; + + row = _mm_srli_si128(below, 8); + y -= 2; + } while (y != 0); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->convolve[0][0][0][1] = ConvolveHorizontal_SSE4_1; + dsp->convolve[0][0][1][0] = ConvolveVertical_SSE4_1; + dsp->convolve[0][0][1][1] = Convolve2D_SSE4_1; + + dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_SSE4; + dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_SSE4_1; + dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_SSE4_1; + dsp->convolve[0][1][1][1] = ConvolveCompound2D_SSE4_1; + + dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_SSE4_1; + dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_SSE4_1; + dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_SSE4_1; + + dsp->convolve_scale[0] = ConvolveScale2D_SSE4_1<false>; + dsp->convolve_scale[1] = ConvolveScale2D_SSE4_1<true>; +} + +} // namespace +} // namespace low_bitdepth + +void ConvolveInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void ConvolveInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/convolve_sse4.h b/src/dsp/x86/convolve_sse4.h new file mode 100644 index 0000000..d6c3155 --- /dev/null +++ b/src/dsp/x86/convolve_sse4.h @@ -0,0 +1,75 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::convolve, see the defines below for specifics. This +// function is not thread-safe. +void ConvolveInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_ConvolveHorizontal +#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveVertical +#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_Convolve2D +#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundCopy +#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal +#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundVertical +#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompound2D +#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveScale2D +#define LIBGAV1_Dsp8bpp_ConvolveScale2D LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D +#define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_ diff --git a/src/dsp/x86/distance_weighted_blend_sse4.cc b/src/dsp/x86/distance_weighted_blend_sse4.cc new file mode 100644 index 0000000..deb57ef --- /dev/null +++ b/src/dsp/x86/distance_weighted_blend_sse4.cc @@ -0,0 +1,230 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/distance_weighted_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <xmmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kInterPostRoundBit = 4; + +inline __m128i ComputeWeightedAverage8(const __m128i& pred0, + const __m128i& pred1, + const __m128i& weights) { + // TODO(https://issuetracker.google.com/issues/150325685): Investigate range. + const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1); + const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights); + const __m128i result_lo = + RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4); + + const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1); + const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights); + const __m128i result_hi = + RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4); + + return _mm_packs_epi32(result_lo, result_hi); +} + +template <int height> +inline void DistanceWeightedBlend4xH_SSE4_1( + const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0, + const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); + + for (int y = 0; y < height; y += 4) { + // TODO(b/150326556): Use larger loads. + const __m128i src_00 = LoadLo8(pred_0); + const __m128i src_10 = LoadLo8(pred_1); + pred_0 += 4; + pred_1 += 4; + __m128i src_0 = LoadHi8(src_00, pred_0); + __m128i src_1 = LoadHi8(src_10, pred_1); + pred_0 += 4; + pred_1 += 4; + const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights); + + const __m128i src_01 = LoadLo8(pred_0); + const __m128i src_11 = LoadLo8(pred_1); + pred_0 += 4; + pred_1 += 4; + src_0 = LoadHi8(src_01, pred_0); + src_1 = LoadHi8(src_11, pred_1); + pred_0 += 4; + pred_1 += 4; + const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights); + + const __m128i result_pixels = _mm_packus_epi16(res0, res1); + Store4(dst, result_pixels); + dst += dest_stride; + const int result_1 = _mm_extract_epi32(result_pixels, 1); + memcpy(dst, &result_1, sizeof(result_1)); + dst += dest_stride; + const int result_2 = _mm_extract_epi32(result_pixels, 2); + memcpy(dst, &result_2, sizeof(result_2)); + dst += dest_stride; + const int result_3 = _mm_extract_epi32(result_pixels, 3); + memcpy(dst, &result_3, sizeof(result_3)); + dst += dest_stride; + } +} + +template <int height> +inline void DistanceWeightedBlend8xH_SSE4_1( + const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0, + const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); + + for (int y = 0; y < height; y += 2) { + const __m128i src_00 = LoadAligned16(pred_0); + const __m128i src_10 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; + const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights); + + const __m128i src_01 = LoadAligned16(pred_0); + const __m128i src_11 = LoadAligned16(pred_1); + pred_0 += 8; + pred_1 += 8; + const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights); + + const __m128i result_pixels = _mm_packus_epi16(res0, res1); + StoreLo8(dst, result_pixels); + dst += dest_stride; + StoreHi8(dst, result_pixels); + dst += dest_stride; + } +} + +inline void DistanceWeightedBlendLarge_SSE4_1( + const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0, + const uint8_t weight_1, const int width, const int height, void* const dest, + const ptrdiff_t dest_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16)); + + int y = height; + do { + int x = 0; + do { + const __m128i src_0_lo = LoadAligned16(pred_0 + x); + const __m128i src_1_lo = LoadAligned16(pred_1 + x); + const __m128i res_lo = + ComputeWeightedAverage8(src_0_lo, src_1_lo, weights); + + const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8); + const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8); + const __m128i res_hi = + ComputeWeightedAverage8(src_0_hi, src_1_hi, weights); + + StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi)); + x += 16; + } while (x < width); + dst += dest_stride; + pred_0 += width; + pred_1 += width; + } while (--y != 0); +} + +void DistanceWeightedBlend_SSE4_1(const void* prediction_0, + const void* prediction_1, + const uint8_t weight_0, + const uint8_t weight_1, const int width, + const int height, void* const dest, + const ptrdiff_t dest_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + if (width == 4) { + if (height == 4) { + DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, + dest, dest_stride); + } else if (height == 8) { + DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, + dest, dest_stride); + } else { + assert(height == 16); + DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, + dest, dest_stride); + } + return; + } + + if (width == 8) { + switch (height) { + case 4: + DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1, + dest, dest_stride); + return; + case 8: + DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1, + dest, dest_stride); + return; + case 16: + DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1, + dest, dest_stride); + return; + default: + assert(height == 32); + DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1, + dest, dest_stride); + + return; + } + } + + DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width, + height, dest, dest_stride); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend) + dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1; +#endif +} + +} // namespace + +void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void DistanceWeightedBlendInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/distance_weighted_blend_sse4.h b/src/dsp/x86/distance_weighted_blend_sse4.h new file mode 100644 index 0000000..8646eca --- /dev/null +++ b/src/dsp/x86/distance_weighted_blend_sse4.h @@ -0,0 +1,41 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::distance_weighted_blend. This function is not thread-safe. +void DistanceWeightedBlendInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_DistanceWeightedBlend +#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_ diff --git a/src/dsp/x86/intra_edge_sse4.cc b/src/dsp/x86/intra_edge_sse4.cc new file mode 100644 index 0000000..4a8658d --- /dev/null +++ b/src/dsp/x86/intra_edge_sse4.cc @@ -0,0 +1,270 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intra_edge.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <xmmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> // memcpy + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kKernelTaps = 5; +constexpr int kKernels[3][kKernelTaps] = { + {0, 4, 8, 4, 0}, {0, 5, 6, 5, 0}, {2, 4, 4, 4, 2}}; +constexpr int kMaxEdgeBufferSize = 129; + +// This function applies the kernel [0, 4, 8, 4, 0] to 12 values. +// Assumes |edge| has 16 packed byte values. Produces 12 filter outputs to +// write as overlapping sets of 8-bytes. +inline void ComputeKernel1Store12(uint8_t* dest, const uint8_t* source) { + const __m128i edge_lo = LoadUnaligned16(source); + const __m128i edge_hi = _mm_srli_si128(edge_lo, 6); + // Samples matched with the '4' tap, expanded to 16-bit. + const __m128i outers_lo = _mm_cvtepu8_epi16(edge_lo); + const __m128i outers_hi = _mm_cvtepu8_epi16(edge_hi); + // Samples matched with the '8' tap, expanded to 16-bit. + const __m128i centers_lo = _mm_srli_si128(outers_lo, 2); + const __m128i centers_hi = _mm_srli_si128(outers_hi, 2); + + // Apply the taps by shifting. + const __m128i outers4_lo = _mm_slli_epi16(outers_lo, 2); + const __m128i outers4_hi = _mm_slli_epi16(outers_hi, 2); + const __m128i centers8_lo = _mm_slli_epi16(centers_lo, 3); + const __m128i centers8_hi = _mm_slli_epi16(centers_hi, 3); + // Move latter 4x values down to add with first 4x values for each output. + const __m128i partial_sums_lo = + _mm_add_epi16(outers4_lo, _mm_srli_si128(outers4_lo, 4)); + const __m128i partial_sums_hi = + _mm_add_epi16(outers4_hi, _mm_srli_si128(outers4_hi, 4)); + // Move 6x values down to add for the final kernel sum for each output. + const __m128i sums_lo = RightShiftWithRounding_U16( + _mm_add_epi16(partial_sums_lo, centers8_lo), 4); + const __m128i sums_hi = RightShiftWithRounding_U16( + _mm_add_epi16(partial_sums_hi, centers8_hi), 4); + + const __m128i result_lo = _mm_packus_epi16(sums_lo, sums_lo); + const __m128i result_hi = _mm_packus_epi16(sums_hi, sums_hi); + const __m128i result = + _mm_alignr_epi8(result_hi, _mm_slli_si128(result_lo, 10), 10); + StoreUnaligned16(dest, result); +} + +// This function applies the kernel [0, 5, 6, 5, 0] to 12 values. +// Assumes |edge| has 8 packed byte values, and that the 2 invalid values will +// be overwritten or safely discarded. +inline void ComputeKernel2Store12(uint8_t* dest, const uint8_t* source) { + const __m128i edge_lo = LoadUnaligned16(source); + const __m128i edge_hi = _mm_srli_si128(edge_lo, 6); + const __m128i outers_lo = _mm_cvtepu8_epi16(edge_lo); + const __m128i centers_lo = _mm_srli_si128(outers_lo, 2); + const __m128i outers_hi = _mm_cvtepu8_epi16(edge_hi); + const __m128i centers_hi = _mm_srli_si128(outers_hi, 2); + // Samples matched with the '5' tap, expanded to 16-bit. Add x + 4x. + const __m128i outers5_lo = + _mm_add_epi16(outers_lo, _mm_slli_epi16(outers_lo, 2)); + const __m128i outers5_hi = + _mm_add_epi16(outers_hi, _mm_slli_epi16(outers_hi, 2)); + // Samples matched with the '6' tap, expanded to 16-bit. Add 2x + 4x. + const __m128i centers6_lo = _mm_add_epi16(_mm_slli_epi16(centers_lo, 1), + _mm_slli_epi16(centers_lo, 2)); + const __m128i centers6_hi = _mm_add_epi16(_mm_slli_epi16(centers_hi, 1), + _mm_slli_epi16(centers_hi, 2)); + // Move latter 5x values down to add with first 5x values for each output. + const __m128i partial_sums_lo = + _mm_add_epi16(outers5_lo, _mm_srli_si128(outers5_lo, 4)); + // Move 6x values down to add for the final kernel sum for each output. + const __m128i sums_lo = RightShiftWithRounding_U16( + _mm_add_epi16(centers6_lo, partial_sums_lo), 4); + // Shift latter 5x values to add with first 5x values for each output. + const __m128i partial_sums_hi = + _mm_add_epi16(outers5_hi, _mm_srli_si128(outers5_hi, 4)); + // Move 6x values down to add for the final kernel sum for each output. + const __m128i sums_hi = RightShiftWithRounding_U16( + _mm_add_epi16(centers6_hi, partial_sums_hi), 4); + // First 6 values are valid outputs. + const __m128i result_lo = _mm_packus_epi16(sums_lo, sums_lo); + const __m128i result_hi = _mm_packus_epi16(sums_hi, sums_hi); + const __m128i result = + _mm_alignr_epi8(result_hi, _mm_slli_si128(result_lo, 10), 10); + StoreUnaligned16(dest, result); +} + +// This function applies the kernel [2, 4, 4, 4, 2] to 8 values. +inline void ComputeKernel3Store8(uint8_t* dest, const uint8_t* source) { + const __m128i edge_lo = LoadUnaligned16(source); + const __m128i edge_hi = _mm_srli_si128(edge_lo, 4); + // Finish |edge_lo| life cycle quickly. + // Multiply for 2x. + const __m128i source2_lo = _mm_slli_epi16(_mm_cvtepu8_epi16(edge_lo), 1); + // Multiply 2x by 2 and align. + const __m128i source4_lo = _mm_srli_si128(_mm_slli_epi16(source2_lo, 1), 2); + // Finish |source2| life cycle quickly. + // Move latter 2x values down to add with first 2x values for each output. + __m128i sum = _mm_add_epi16(source2_lo, _mm_srli_si128(source2_lo, 8)); + // First 4x values already aligned to add with running total. + sum = _mm_add_epi16(sum, source4_lo); + // Move second 4x values down to add with running total. + sum = _mm_add_epi16(sum, _mm_srli_si128(source4_lo, 2)); + // Move third 4x values down to add with running total. + sum = _mm_add_epi16(sum, _mm_srli_si128(source4_lo, 4)); + // Multiply for 2x. + const __m128i source2_hi = _mm_slli_epi16(_mm_cvtepu8_epi16(edge_hi), 1); + // Multiply 2x by 2 and align. + const __m128i source4_hi = _mm_srli_si128(_mm_slli_epi16(source2_hi, 1), 2); + // Move latter 2x values down to add with first 2x values for each output. + __m128i sum_hi = _mm_add_epi16(source2_hi, _mm_srli_si128(source2_hi, 8)); + // First 4x values already aligned to add with running total. + sum_hi = _mm_add_epi16(sum_hi, source4_hi); + // Move second 4x values down to add with running total. + sum_hi = _mm_add_epi16(sum_hi, _mm_srli_si128(source4_hi, 2)); + // Move third 4x values down to add with running total. + sum_hi = _mm_add_epi16(sum_hi, _mm_srli_si128(source4_hi, 4)); + + // Because we have only 8 values here, it is safe to align before packing down + // to 8-bit without losing data. + sum = _mm_alignr_epi8(sum_hi, _mm_slli_si128(sum, 8), 8); + sum = RightShiftWithRounding_U16(sum, 4); + StoreLo8(dest, _mm_packus_epi16(sum, sum)); +} + +void IntraEdgeFilter_SSE4_1(void* buffer, int size, int strength) { + uint8_t edge[kMaxEdgeBufferSize + 4]; + memcpy(edge, buffer, size); + auto* dst_buffer = static_cast<uint8_t*>(buffer); + + // Only process |size| - 1 elements. Nothing to do in this case. + if (size == 1) return; + + int i = 0; + switch (strength) { + case 1: + // To avoid overwriting, we stop short from the total write size plus the + // initial offset. In this case 12 valid values are written in two blocks + // of 8 bytes each. + for (; i < size - 17; i += 12) { + ComputeKernel1Store12(dst_buffer + i + 1, edge + i); + } + break; + case 2: + // See the comment for case 1. + for (; i < size - 17; i += 12) { + ComputeKernel2Store12(dst_buffer + i + 1, edge + i); + } + break; + default: + assert(strength == 3); + // The first filter input is repeated for taps of value 2 and 4. + dst_buffer[1] = RightShiftWithRounding( + (6 * edge[0] + 4 * edge[1] + 4 * edge[2] + 2 * edge[3]), 4); + // In this case, one block of 8 bytes is written in each iteration, with + // an offset of 2. + for (; i < size - 10; i += 8) { + ComputeKernel3Store8(dst_buffer + i + 2, edge + i); + } + } + const int kernel_index = strength - 1; + for (int final_index = Clip3(i, 1, size - 2); final_index < size; + ++final_index) { + int sum = 0; + for (int j = 0; j < kKernelTaps; ++j) { + const int k = Clip3(final_index + j - 2, 0, size - 1); + sum += kKernels[kernel_index][j] * edge[k]; + } + dst_buffer[final_index] = RightShiftWithRounding(sum, 4); + } +} + +constexpr int kMaxUpsampleSize = 16; + +// Applies the upsampling kernel [-1, 9, 9, -1] to alternating pixels, and +// interleaves the results with the original values. This implementation assumes +// that it is safe to write the maximum number of upsampled pixels (32) to the +// edge buffer, even when |size| is small. +void IntraEdgeUpsampler_SSE4_1(void* buffer, int size) { + assert(size % 4 == 0 && size <= kMaxUpsampleSize); + auto* const pixel_buffer = static_cast<uint8_t*>(buffer); + uint8_t temp[kMaxUpsampleSize + 8]; + temp[0] = temp[1] = pixel_buffer[-1]; + memcpy(temp + 2, pixel_buffer, sizeof(temp[0]) * size); + temp[size + 2] = pixel_buffer[size - 1]; + + pixel_buffer[-2] = temp[0]; + const __m128i data = LoadUnaligned16(temp); + const __m128i src_lo = _mm_cvtepu8_epi16(data); + const __m128i src_hi = _mm_unpackhi_epi8(data, _mm_setzero_si128()); + const __m128i src9_hi = _mm_add_epi16(src_hi, _mm_slli_epi16(src_hi, 3)); + const __m128i src9_lo = _mm_add_epi16(src_lo, _mm_slli_epi16(src_lo, 3)); + __m128i sum_lo = _mm_sub_epi16(_mm_alignr_epi8(src9_hi, src9_lo, 2), src_lo); + sum_lo = _mm_add_epi16(sum_lo, _mm_alignr_epi8(src9_hi, src9_lo, 4)); + sum_lo = _mm_sub_epi16(sum_lo, _mm_alignr_epi8(src_hi, src_lo, 6)); + sum_lo = RightShiftWithRounding_S16(sum_lo, 4); + const __m128i result_lo = _mm_unpacklo_epi8(_mm_packus_epi16(sum_lo, sum_lo), + _mm_srli_si128(data, 2)); + StoreUnaligned16(pixel_buffer - 1, result_lo); + if (size > 8) { + const __m128i src_hi_extra = _mm_cvtepu8_epi16(LoadLo8(temp + 16)); + const __m128i src9_hi_extra = + _mm_add_epi16(src_hi_extra, _mm_slli_epi16(src_hi_extra, 3)); + __m128i sum_hi = + _mm_sub_epi16(_mm_alignr_epi8(src9_hi_extra, src9_hi, 2), src_hi); + sum_hi = _mm_add_epi16(sum_hi, _mm_alignr_epi8(src9_hi_extra, src9_hi, 4)); + sum_hi = _mm_sub_epi16(sum_hi, _mm_alignr_epi8(src_hi_extra, src_hi, 6)); + sum_hi = RightShiftWithRounding_S16(sum_hi, 4); + const __m128i result_hi = + _mm_unpacklo_epi8(_mm_packus_epi16(sum_hi, sum_hi), LoadLo8(temp + 10)); + StoreUnaligned16(pixel_buffer + 15, result_hi); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_SSE4_1(IntraEdgeFilter) + dsp->intra_edge_filter = IntraEdgeFilter_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(IntraEdgeUpsampler) + dsp->intra_edge_upsampler = IntraEdgeUpsampler_SSE4_1; +#endif +} + +} // namespace + +void IntraEdgeInit_SSE4_1() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void IntraEdgeInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/intra_edge_sse4.h b/src/dsp/x86/intra_edge_sse4.h new file mode 100644 index 0000000..6ed4d40 --- /dev/null +++ b/src/dsp/x86/intra_edge_sse4.h @@ -0,0 +1,46 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_INTRA_EDGE_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_INTRA_EDGE_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_edge_filter and Dsp::intra_edge_upsampler. This +// function is not thread-safe. +void IntraEdgeInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_IntraEdgeFilter +#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_IntraEdgeUpsampler +#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_INTRA_EDGE_SSE4_H_ diff --git a/src/dsp/x86/intrapred_cfl_sse4.cc b/src/dsp/x86/intrapred_cfl_sse4.cc new file mode 100644 index 0000000..fac1556 --- /dev/null +++ b/src/dsp/x86/intrapred_cfl_sse4.cc @@ -0,0 +1,976 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +//------------------------------------------------------------------------------ +// CflIntraPredictor_SSE4_1 + +inline __m128i CflPredictUnclipped(const __m128i* input, __m128i alpha_q12, + __m128i alpha_sign, __m128i dc_q0) { + __m128i ac_q3 = LoadUnaligned16(input); + __m128i ac_sign = _mm_sign_epi16(alpha_sign, ac_q3); + __m128i scaled_luma_q0 = _mm_mulhrs_epi16(_mm_abs_epi16(ac_q3), alpha_q12); + scaled_luma_q0 = _mm_sign_epi16(scaled_luma_q0, ac_sign); + return _mm_add_epi16(scaled_luma_q0, dc_q0); +} + +template <int width, int height> +void CflIntraPredictor_SSE4_1( + void* const dest, ptrdiff_t stride, + const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int alpha) { + auto* dst = static_cast<uint8_t*>(dest); + const __m128i alpha_sign = _mm_set1_epi16(alpha); + const __m128i alpha_q12 = _mm_slli_epi16(_mm_abs_epi16(alpha_sign), 9); + auto* row = reinterpret_cast<const __m128i*>(luma); + const int kCflLumaBufferStrideLog2_16i = 5; + const int kCflLumaBufferStrideLog2_128i = kCflLumaBufferStrideLog2_16i - 3; + const __m128i* row_end = row + (height << kCflLumaBufferStrideLog2_128i); + const __m128i dc_val = _mm_set1_epi16(dst[0]); + do { + __m128i res = CflPredictUnclipped(row, alpha_q12, alpha_sign, dc_val); + if (width < 16) { + res = _mm_packus_epi16(res, res); + if (width == 4) { + Store4(dst, res); + } else { + StoreLo8(dst, res); + } + } else { + __m128i next = + CflPredictUnclipped(row + 1, alpha_q12, alpha_sign, dc_val); + res = _mm_packus_epi16(res, next); + StoreUnaligned16(dst, res); + if (width == 32) { + res = CflPredictUnclipped(row + 2, alpha_q12, alpha_sign, dc_val); + next = CflPredictUnclipped(row + 3, alpha_q12, alpha_sign, dc_val); + res = _mm_packus_epi16(res, next); + StoreUnaligned16(dst + 16, res); + } + } + dst += stride; + } while ((row += (1 << kCflLumaBufferStrideLog2_128i)) < row_end); +} + +template <int block_height_log2, bool is_inside> +void CflSubsampler444_4xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int /*max_luma_width*/, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_assert(block_height_log2 <= 4, ""); + const int block_height = 1 << block_height_log2; + const int visible_height = max_luma_height; + const auto* src = static_cast<const uint8_t*>(source); + __m128i sum = _mm_setzero_si128(); + int16_t* luma_ptr = luma[0]; + const __m128i zero = _mm_setzero_si128(); + __m128i samples; + int y = 0; + do { + samples = Load4(src); + src += stride; + int src_bytes; + memcpy(&src_bytes, src, 4); + samples = _mm_insert_epi32(samples, src_bytes, 1); + src += stride; + samples = _mm_slli_epi16(_mm_cvtepu8_epi16(samples), 3); + StoreLo8(luma_ptr, samples); + luma_ptr += kCflLumaBufferStride; + StoreHi8(luma_ptr, samples); + luma_ptr += kCflLumaBufferStride; + + // The maximum value here is 2**bd * H * 2**shift. Since the maximum H for + // 4XH is 16 = 2**4, we have 2**(8 + 4 + 3) = 2**15, which fits in 16 bits. + sum = _mm_add_epi16(sum, samples); + y += 2; + } while (y < visible_height); + + if (!is_inside) { + int y = visible_height; + do { + StoreHi8(luma_ptr, samples); + luma_ptr += kCflLumaBufferStride; + sum = _mm_add_epi16(sum, samples); + ++y; + } while (y < block_height); + } + + __m128i sum_tmp = _mm_unpackhi_epi16(sum, zero); + sum = _mm_cvtepu16_epi32(sum); + sum = _mm_add_epi32(sum, sum_tmp); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4)); + + __m128i averages = RightShiftWithRounding_U32( + sum, block_height_log2 + 2 /* log2 of width 4 */); + averages = _mm_shufflelo_epi16(averages, 0); + luma_ptr = luma[0]; + for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) { + const __m128i samples = LoadLo8(luma_ptr); + StoreLo8(luma_ptr, _mm_sub_epi16(samples, averages)); + } +} + +template <int block_height_log2> +void CflSubsampler444_4xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_assert(block_height_log2 <= 4, ""); + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + const int block_height = 1 << block_height_log2; + const int block_width = 4; + + if (block_height <= max_luma_height && block_width <= max_luma_width) { + CflSubsampler444_4xH_SSE4_1<block_height_log2, true>( + luma, max_luma_width, max_luma_height, source, stride); + } else { + CflSubsampler444_4xH_SSE4_1<block_height_log2, false>( + luma, max_luma_width, max_luma_height, source, stride); + } +} + +template <int block_height_log2, bool inside> +void CflSubsampler444_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_assert(block_height_log2 <= 5, ""); + const int block_height = 1 << block_height_log2, block_width = 8; + const int visible_height = max_luma_height; + const int invisible_width = inside ? 0 : block_width - max_luma_width; + const int visible_width = max_luma_width; + const __m128i blend_mask = + inside ? _mm_setzero_si128() : MaskHighNBytes(8 + invisible_width); + const __m128i dup16 = _mm_set1_epi32(0x01000100); + const auto* src = static_cast<const uint8_t*>(source); + int16_t* luma_ptr = luma[0]; + const __m128i zero = _mm_setzero_si128(); + // Since the maximum height is 32, if we split them by parity, each one only + // needs to accumulate 16 rows. Just like the calculation done in 4XH, we can + // store them in 16 bits without casting to 32 bits. + __m128i sum_even = _mm_setzero_si128(), sum_odd = _mm_setzero_si128(); + __m128i sum; + __m128i samples1; + + int y = 0; + do { + __m128i samples0 = LoadLo8(src); + if (!inside) { + const __m128i border0 = + _mm_set1_epi8(static_cast<int8_t>(src[visible_width - 1])); + samples0 = _mm_blendv_epi8(samples0, border0, blend_mask); + } + src += stride; + samples0 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples0), 3); + StoreUnaligned16(luma_ptr, samples0); + luma_ptr += kCflLumaBufferStride; + + sum_even = _mm_add_epi16(sum_even, samples0); + + samples1 = LoadLo8(src); + if (!inside) { + const __m128i border1 = + _mm_set1_epi8(static_cast<int8_t>(src[visible_width - 1])); + samples1 = _mm_blendv_epi8(samples1, border1, blend_mask); + } + src += stride; + samples1 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples1), 3); + StoreUnaligned16(luma_ptr, samples1); + luma_ptr += kCflLumaBufferStride; + + sum_odd = _mm_add_epi16(sum_odd, samples1); + y += 2; + } while (y < visible_height); + + if (!inside) { + for (int y = visible_height; y < block_height; y += 2) { + sum_even = _mm_add_epi16(sum_even, samples1); + StoreUnaligned16(luma_ptr, samples1); + luma_ptr += kCflLumaBufferStride; + + sum_odd = _mm_add_epi16(sum_odd, samples1); + StoreUnaligned16(luma_ptr, samples1); + luma_ptr += kCflLumaBufferStride; + } + } + + sum = _mm_add_epi32(_mm_unpackhi_epi16(sum_even, zero), + _mm_cvtepu16_epi32(sum_even)); + sum = _mm_add_epi32(sum, _mm_unpackhi_epi16(sum_odd, zero)); + sum = _mm_add_epi32(sum, _mm_cvtepu16_epi32(sum_odd)); + + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4)); + + __m128i averages = RightShiftWithRounding_U32( + sum, block_height_log2 + 3 /* log2 of width 8 */); + averages = _mm_shuffle_epi8(averages, dup16); + luma_ptr = luma[0]; + for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) { + const __m128i samples = LoadUnaligned16(luma_ptr); + StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples, averages)); + } +} + +template <int block_height_log2> +void CflSubsampler444_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_assert(block_height_log2 <= 5, ""); + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + const int block_height = 1 << block_height_log2; + const int block_width = 8; + + const int horz_inside = block_width <= max_luma_width; + const int vert_inside = block_height <= max_luma_height; + if (horz_inside && vert_inside) { + CflSubsampler444_8xH_SSE4_1<block_height_log2, true>( + luma, max_luma_width, max_luma_height, source, stride); + } else { + CflSubsampler444_8xH_SSE4_1<block_height_log2, false>( + luma, max_luma_width, max_luma_height, source, stride); + } +} + +// This function will only work for block_width 16 and 32. +template <int block_width_log2, int block_height_log2, bool inside> +void CflSubsampler444_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_assert(block_width_log2 == 4 || block_width_log2 == 5, ""); + static_assert(block_height_log2 <= 5, ""); + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + const int block_height = 1 << block_height_log2; + const int block_width = 1 << block_width_log2; + + const int visible_height = max_luma_height; + const int visible_width_16 = inside ? 16 : std::min(16, max_luma_width); + const int invisible_width_16 = 16 - visible_width_16; + const __m128i blend_mask_16 = MaskHighNBytes(invisible_width_16); + const int visible_width_32 = inside ? 32 : max_luma_width; + const int invisible_width_32 = 32 - visible_width_32; + const __m128i blend_mask_32 = + MaskHighNBytes(std::min(16, invisible_width_32)); + + const __m128i dup16 = _mm_set1_epi32(0x01000100); + const __m128i zero = _mm_setzero_si128(); + const auto* src = static_cast<const uint8_t*>(source); + int16_t* luma_ptr = luma[0]; + __m128i sum = _mm_setzero_si128(); + + __m128i samples0, samples1; + __m128i samples2, samples3; + __m128i inner_sum_lo, inner_sum_hi; + int y = 0; + do { +#if LIBGAV1_MSAN // We can load uninitialized values here. Even though they are + // then masked off by blendv, MSAN isn't smart enough to + // understand that. So we switch to a C implementation here. + uint16_t c_arr[16]; + for (int x = 0; x < 16; x++) { + const int x_index = std::min(x, visible_width_16 - 1); + c_arr[x] = src[x_index] << 3; + } + samples0 = LoadUnaligned16(c_arr); + samples1 = LoadUnaligned16(c_arr + 8); + static_cast<void>(blend_mask_16); +#else + __m128i samples01 = LoadUnaligned16(src); + + if (!inside) { + const __m128i border16 = + _mm_set1_epi8(static_cast<int8_t>(src[visible_width_16 - 1])); + samples01 = _mm_blendv_epi8(samples01, border16, blend_mask_16); + } + samples0 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples01), 3); + samples1 = _mm_slli_epi16(_mm_unpackhi_epi8(samples01, zero), 3); +#endif // LIBGAV1_MSAN + + StoreUnaligned16(luma_ptr, samples0); + StoreUnaligned16(luma_ptr + 8, samples1); + __m128i inner_sum = _mm_add_epi16(samples0, samples1); + + if (block_width == 32) { +#if LIBGAV1_MSAN // We can load uninitialized values here. Even though they are + // then masked off by blendv, MSAN isn't smart enough to + // understand that. So we switch to a C implementation here. + uint16_t c_arr[16]; + for (int x = 16; x < 32; x++) { + const int x_index = std::min(x, visible_width_32 - 1); + c_arr[x - 16] = src[x_index] << 3; + } + samples2 = LoadUnaligned16(c_arr); + samples3 = LoadUnaligned16(c_arr + 8); + static_cast<void>(blend_mask_32); +#else + __m128i samples23 = LoadUnaligned16(src + 16); + if (!inside) { + const __m128i border32 = + _mm_set1_epi8(static_cast<int8_t>(src[visible_width_32 - 1])); + samples23 = _mm_blendv_epi8(samples23, border32, blend_mask_32); + } + samples2 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples23), 3); + samples3 = _mm_slli_epi16(_mm_unpackhi_epi8(samples23, zero), 3); +#endif // LIBGAV1_MSAN + + StoreUnaligned16(luma_ptr + 16, samples2); + StoreUnaligned16(luma_ptr + 24, samples3); + inner_sum = _mm_add_epi16(samples2, inner_sum); + inner_sum = _mm_add_epi16(samples3, inner_sum); + } + + inner_sum_lo = _mm_cvtepu16_epi32(inner_sum); + inner_sum_hi = _mm_unpackhi_epi16(inner_sum, zero); + sum = _mm_add_epi32(sum, inner_sum_lo); + sum = _mm_add_epi32(sum, inner_sum_hi); + luma_ptr += kCflLumaBufferStride; + src += stride; + } while (++y < visible_height); + + if (!inside) { + for (int y = visible_height; y < block_height; + luma_ptr += kCflLumaBufferStride, ++y) { + sum = _mm_add_epi32(sum, inner_sum_lo); + StoreUnaligned16(luma_ptr, samples0); + sum = _mm_add_epi32(sum, inner_sum_hi); + StoreUnaligned16(luma_ptr + 8, samples1); + if (block_width == 32) { + StoreUnaligned16(luma_ptr + 16, samples2); + StoreUnaligned16(luma_ptr + 24, samples3); + } + } + } + + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8)); + sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4)); + + __m128i averages = + RightShiftWithRounding_U32(sum, block_width_log2 + block_height_log2); + averages = _mm_shuffle_epi8(averages, dup16); + luma_ptr = luma[0]; + for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) { + for (int x = 0; x < block_width; x += 8) { + __m128i samples = LoadUnaligned16(&luma_ptr[x]); + StoreUnaligned16(&luma_ptr[x], _mm_sub_epi16(samples, averages)); + } + } +} + +template <int block_width_log2, int block_height_log2> +void CflSubsampler444_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + static_assert(block_width_log2 == 4 || block_width_log2 == 5, ""); + static_assert(block_height_log2 <= 5, ""); + assert(max_luma_width >= 4); + assert(max_luma_height >= 4); + + const int block_height = 1 << block_height_log2; + const int block_width = 1 << block_width_log2; + const int horz_inside = block_width <= max_luma_width; + const int vert_inside = block_height <= max_luma_height; + if (horz_inside && vert_inside) { + CflSubsampler444_SSE4_1<block_width_log2, block_height_log2, true>( + luma, max_luma_width, max_luma_height, source, stride); + } else { + CflSubsampler444_SSE4_1<block_width_log2, block_height_log2, false>( + luma, max_luma_width, max_luma_height, source, stride); + } +} + +// Takes in two sums of input row pairs, and completes the computation for two +// output rows. +inline __m128i StoreLumaResults4_420(const __m128i vertical_sum0, + const __m128i vertical_sum1, + int16_t* luma_ptr) { + __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1); + result = _mm_slli_epi16(result, 1); + StoreLo8(luma_ptr, result); + StoreHi8(luma_ptr + kCflLumaBufferStride, result); + return result; +} + +// Takes two halves of a vertically added pair of rows and completes the +// computation for one output row. +inline __m128i StoreLumaResults8_420(const __m128i vertical_sum0, + const __m128i vertical_sum1, + int16_t* luma_ptr) { + __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1); + result = _mm_slli_epi16(result, 1); + StoreUnaligned16(luma_ptr, result); + return result; +} + +template <int block_height_log2> +void CflSubsampler420_4xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int /*max_luma_width*/, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + const int block_height = 1 << block_height_log2; + const auto* src = static_cast<const uint8_t*>(source); + int16_t* luma_ptr = luma[0]; + const __m128i zero = _mm_setzero_si128(); + __m128i final_sum = zero; + const int luma_height = std::min(block_height, max_luma_height >> 1); + int y = 0; + do { + // Note that double sampling and converting to 16bit makes a row fill the + // vector. + const __m128i samples_row0 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i samples_row1 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i luma_sum01 = _mm_add_epi16(samples_row0, samples_row1); + + const __m128i samples_row2 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i samples_row3 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i luma_sum23 = _mm_add_epi16(samples_row2, samples_row3); + __m128i sum = StoreLumaResults4_420(luma_sum01, luma_sum23, luma_ptr); + luma_ptr += kCflLumaBufferStride << 1; + + const __m128i samples_row4 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i samples_row5 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i luma_sum45 = _mm_add_epi16(samples_row4, samples_row5); + + const __m128i samples_row6 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i samples_row7 = _mm_cvtepu8_epi16(LoadLo8(src)); + src += stride; + const __m128i luma_sum67 = _mm_add_epi16(samples_row6, samples_row7); + sum = _mm_add_epi16( + sum, StoreLumaResults4_420(luma_sum45, luma_sum67, luma_ptr)); + luma_ptr += kCflLumaBufferStride << 1; + + final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum)); + final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero)); + y += 4; + } while (y < luma_height); + const __m128i final_fill = LoadLo8(luma_ptr - kCflLumaBufferStride); + const __m128i final_fill_to_sum = _mm_cvtepu16_epi32(final_fill); + for (; y < block_height; ++y) { + StoreLo8(luma_ptr, final_fill); + luma_ptr += kCflLumaBufferStride; + + final_sum = _mm_add_epi32(final_sum, final_fill_to_sum); + } + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8)); + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4)); + + __m128i averages = RightShiftWithRounding_U32( + final_sum, block_height_log2 + 2 /*log2 of width 4*/); + + averages = _mm_shufflelo_epi16(averages, 0); + luma_ptr = luma[0]; + for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) { + const __m128i samples = LoadLo8(luma_ptr); + StoreLo8(luma_ptr, _mm_sub_epi16(samples, averages)); + } +} + +// This duplicates the last two 16-bit values in |row|. +inline __m128i LastRowSamples(const __m128i row) { + return _mm_shuffle_epi32(row, 0xFF); +} + +// This duplicates the last 16-bit value in |row|. +inline __m128i LastRowResult(const __m128i row) { + const __m128i dup_row = _mm_shufflehi_epi16(row, 0xFF); + return _mm_shuffle_epi32(dup_row, 0xFF); +} + +template <int block_height_log2, int max_luma_width> +inline void CflSubsampler420Impl_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int /*max_luma_width*/, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + const int block_height = 1 << block_height_log2; + const auto* src = static_cast<const uint8_t*>(source); + const __m128i zero = _mm_setzero_si128(); + __m128i final_sum = zero; + int16_t* luma_ptr = luma[0]; + const int luma_height = std::min(block_height, max_luma_height >> 1); + int y = 0; + + do { + const __m128i samples_row00 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row01 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row00); + src += stride; + const __m128i samples_row10 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row11 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row10); + src += stride; + const __m128i luma_sum00 = _mm_add_epi16(samples_row00, samples_row10); + const __m128i luma_sum01 = _mm_add_epi16(samples_row01, samples_row11); + __m128i sum = StoreLumaResults8_420(luma_sum00, luma_sum01, luma_ptr); + luma_ptr += kCflLumaBufferStride; + + const __m128i samples_row20 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row21 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row20); + src += stride; + const __m128i samples_row30 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row31 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row30); + src += stride; + const __m128i luma_sum10 = _mm_add_epi16(samples_row20, samples_row30); + const __m128i luma_sum11 = _mm_add_epi16(samples_row21, samples_row31); + sum = _mm_add_epi16( + sum, StoreLumaResults8_420(luma_sum10, luma_sum11, luma_ptr)); + luma_ptr += kCflLumaBufferStride; + + const __m128i samples_row40 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row41 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row40); + src += stride; + const __m128i samples_row50 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row51 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row50); + src += stride; + const __m128i luma_sum20 = _mm_add_epi16(samples_row40, samples_row50); + const __m128i luma_sum21 = _mm_add_epi16(samples_row41, samples_row51); + sum = _mm_add_epi16( + sum, StoreLumaResults8_420(luma_sum20, luma_sum21, luma_ptr)); + luma_ptr += kCflLumaBufferStride; + + const __m128i samples_row60 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row61 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row60); + src += stride; + const __m128i samples_row70 = _mm_cvtepu8_epi16(LoadLo8(src)); + const __m128i samples_row71 = (max_luma_width == 16) + ? _mm_cvtepu8_epi16(LoadLo8(src + 8)) + : LastRowSamples(samples_row70); + src += stride; + const __m128i luma_sum30 = _mm_add_epi16(samples_row60, samples_row70); + const __m128i luma_sum31 = _mm_add_epi16(samples_row61, samples_row71); + sum = _mm_add_epi16( + sum, StoreLumaResults8_420(luma_sum30, luma_sum31, luma_ptr)); + luma_ptr += kCflLumaBufferStride; + + final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum)); + final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero)); + y += 4; + } while (y < luma_height); + // Duplicate the final row downward to the end after max_luma_height. + const __m128i final_fill = LoadUnaligned16(luma_ptr - kCflLumaBufferStride); + const __m128i final_fill_to_sum0 = _mm_cvtepi16_epi32(final_fill); + const __m128i final_fill_to_sum1 = + _mm_cvtepi16_epi32(_mm_srli_si128(final_fill, 8)); + const __m128i final_fill_to_sum = + _mm_add_epi32(final_fill_to_sum0, final_fill_to_sum1); + for (; y < block_height; ++y) { + StoreUnaligned16(luma_ptr, final_fill); + luma_ptr += kCflLumaBufferStride; + + final_sum = _mm_add_epi32(final_sum, final_fill_to_sum); + } + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8)); + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4)); + + __m128i averages = RightShiftWithRounding_S32( + final_sum, block_height_log2 + 3 /*log2 of width 8*/); + + averages = _mm_shufflelo_epi16(averages, 0); + averages = _mm_shuffle_epi32(averages, 0); + luma_ptr = luma[0]; + for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) { + const __m128i samples = LoadUnaligned16(luma_ptr); + StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples, averages)); + } +} + +template <int block_height_log2> +void CflSubsampler420_8xH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + if (max_luma_width == 8) { + CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 8>( + luma, max_luma_width, max_luma_height, source, stride); + } else { + CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 16>( + luma, max_luma_width, max_luma_height, source, stride); + } +} + +template <int block_width_log2, int block_height_log2, int max_luma_width> +inline void CflSubsampler420Impl_WxH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int /*max_luma_width*/, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + const auto* src = static_cast<const uint8_t*>(source); + const __m128i zero = _mm_setzero_si128(); + __m128i final_sum = zero; + const int block_height = 1 << block_height_log2; + const int luma_height = std::min(block_height, max_luma_height >> 1); + + int16_t* luma_ptr = luma[0]; + __m128i final_row_result; + // Begin first y section, covering width up to 16. + int y = 0; + do { + const uint8_t* src_next = src + stride; + const __m128i samples_row0_lo = LoadUnaligned16(src); + const __m128i samples_row00 = _mm_cvtepu8_epi16(samples_row0_lo); + const __m128i samples_row01 = (max_luma_width >= 16) + ? _mm_unpackhi_epi8(samples_row0_lo, zero) + : LastRowSamples(samples_row00); + const __m128i samples_row0_hi = LoadUnaligned16(src + 16); + const __m128i samples_row02 = (max_luma_width >= 24) + ? _mm_cvtepu8_epi16(samples_row0_hi) + : LastRowSamples(samples_row01); + const __m128i samples_row03 = (max_luma_width == 32) + ? _mm_unpackhi_epi8(samples_row0_hi, zero) + : LastRowSamples(samples_row02); + const __m128i samples_row1_lo = LoadUnaligned16(src_next); + const __m128i samples_row10 = _mm_cvtepu8_epi16(samples_row1_lo); + const __m128i samples_row11 = (max_luma_width >= 16) + ? _mm_unpackhi_epi8(samples_row1_lo, zero) + : LastRowSamples(samples_row10); + const __m128i samples_row1_hi = LoadUnaligned16(src_next + 16); + const __m128i samples_row12 = (max_luma_width >= 24) + ? _mm_cvtepu8_epi16(samples_row1_hi) + : LastRowSamples(samples_row11); + const __m128i samples_row13 = (max_luma_width == 32) + ? _mm_unpackhi_epi8(samples_row1_hi, zero) + : LastRowSamples(samples_row12); + const __m128i luma_sum0 = _mm_add_epi16(samples_row00, samples_row10); + const __m128i luma_sum1 = _mm_add_epi16(samples_row01, samples_row11); + const __m128i luma_sum2 = _mm_add_epi16(samples_row02, samples_row12); + const __m128i luma_sum3 = _mm_add_epi16(samples_row03, samples_row13); + __m128i sum = StoreLumaResults8_420(luma_sum0, luma_sum1, luma_ptr); + final_row_result = + StoreLumaResults8_420(luma_sum2, luma_sum3, luma_ptr + 8); + sum = _mm_add_epi16(sum, final_row_result); + final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum)); + final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero)); + src += stride << 1; + luma_ptr += kCflLumaBufferStride; + } while (++y < luma_height); + + // Because max_luma_width is at most 32, any values beyond x=16 will + // necessarily be duplicated. + if (block_width_log2 == 5) { + const __m128i wide_fill = LastRowResult(final_row_result); + // Multiply duplicated value by number of occurrences, height * 4, since + // there are 16 in each row and the value appears in the vector 4 times. + final_sum = _mm_add_epi32( + final_sum, + _mm_slli_epi32(_mm_cvtepi16_epi32(wide_fill), block_height_log2 + 2)); + } + + // Begin second y section. + if (y < block_height) { + const __m128i final_fill0 = + LoadUnaligned16(luma_ptr - kCflLumaBufferStride); + const __m128i final_fill1 = + LoadUnaligned16(luma_ptr - kCflLumaBufferStride + 8); + const __m128i final_inner_sum = _mm_add_epi16(final_fill0, final_fill1); + const __m128i final_inner_sum0 = _mm_cvtepu16_epi32(final_inner_sum); + const __m128i final_inner_sum1 = _mm_unpackhi_epi16(final_inner_sum, zero); + const __m128i final_fill_to_sum = + _mm_add_epi32(final_inner_sum0, final_inner_sum1); + + do { + StoreUnaligned16(luma_ptr, final_fill0); + StoreUnaligned16(luma_ptr + 8, final_fill1); + luma_ptr += kCflLumaBufferStride; + + final_sum = _mm_add_epi32(final_sum, final_fill_to_sum); + } while (++y < block_height); + } // End second y section. + + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8)); + final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4)); + + __m128i averages = RightShiftWithRounding_S32( + final_sum, block_width_log2 + block_height_log2); + averages = _mm_shufflelo_epi16(averages, 0); + averages = _mm_shuffle_epi32(averages, 0); + + luma_ptr = luma[0]; + for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) { + const __m128i samples0 = LoadUnaligned16(luma_ptr); + StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples0, averages)); + const __m128i samples1 = LoadUnaligned16(luma_ptr + 8); + final_row_result = _mm_sub_epi16(samples1, averages); + StoreUnaligned16(luma_ptr + 8, final_row_result); + } + if (block_width_log2 == 5) { + int16_t* wide_luma_ptr = luma[0] + 16; + const __m128i wide_fill = LastRowResult(final_row_result); + for (int i = 0; i < block_height; + ++i, wide_luma_ptr += kCflLumaBufferStride) { + StoreUnaligned16(wide_luma_ptr, wide_fill); + StoreUnaligned16(wide_luma_ptr + 8, wide_fill); + } + } +} + +template <int block_width_log2, int block_height_log2> +void CflSubsampler420_WxH_SSE4_1( + int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride], + const int max_luma_width, const int max_luma_height, + const void* const source, ptrdiff_t stride) { + switch (max_luma_width) { + case 8: + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 8>( + luma, max_luma_width, max_luma_height, source, stride); + return; + case 16: + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 16>( + luma, max_luma_width, max_luma_height, source, stride); + return; + case 24: + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 24>( + luma, max_luma_width, max_luma_height, source, stride); + return; + default: + assert(max_luma_width == 32); + CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 32>( + luma, max_luma_width, max_luma_height, source, stride); + return; + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] = + CflSubsampler420_4xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] = + CflSubsampler420_4xH_SSE4_1<3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] = + CflSubsampler420_4xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] = + CflSubsampler420_8xH_SSE4_1<5>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 2>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<4, 5>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<5, 3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<5, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_CflSubsampler420) + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] = + CflSubsampler420_WxH_SSE4_1<5, 5>; +#endif + +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] = + CflSubsampler444_4xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] = + CflSubsampler444_4xH_SSE4_1<3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] = + CflSubsampler444_4xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<2>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] = + CflSubsampler444_8xH_SSE4_1<5>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] = + CflSubsampler444_SSE4_1<4, 2>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] = + CflSubsampler444_SSE4_1<4, 3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] = + CflSubsampler444_SSE4_1<4, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] = + CflSubsampler444_SSE4_1<4, 5>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] = + CflSubsampler444_SSE4_1<5, 3>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] = + CflSubsampler444_SSE4_1<5, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_CflSubsampler444) + dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] = + CflSubsampler444_SSE4_1<5, 5>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor_SSE4_1<4, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize4x8] = CflIntraPredictor_SSE4_1<4, 8>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize4x16] = + CflIntraPredictor_SSE4_1<4, 16>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x4] = CflIntraPredictor_SSE4_1<8, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x8] = CflIntraPredictor_SSE4_1<8, 8>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x16] = + CflIntraPredictor_SSE4_1<8, 16>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize8x32] = + CflIntraPredictor_SSE4_1<8, 32>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x4] = + CflIntraPredictor_SSE4_1<16, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x8] = + CflIntraPredictor_SSE4_1<16, 8>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x16] = + CflIntraPredictor_SSE4_1<16, 16>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize16x32] = + CflIntraPredictor_SSE4_1<16, 32>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize32x8] = + CflIntraPredictor_SSE4_1<32, 8>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize32x16] = + CflIntraPredictor_SSE4_1<32, 16>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_CflIntraPredictor) + dsp->cfl_intra_predictors[kTransformSize32x32] = + CflIntraPredictor_SSE4_1<32, 32>; +#endif +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredCflInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void IntraPredCflInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/intrapred_smooth_sse4.cc b/src/dsp/x86/intrapred_smooth_sse4.cc new file mode 100644 index 0000000..e944ea3 --- /dev/null +++ b/src/dsp/x86/intrapred_smooth_sse4.cc @@ -0,0 +1,2662 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <xmmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> // memcpy + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Note these constants are duplicated from intrapred.cc to allow the compiler +// to have visibility of the values. This helps reduce loads and in the +// creation of the inverse weights. +constexpr uint8_t kSmoothWeights[] = { + // block dimension = 4 + 255, 149, 85, 64, + // block dimension = 8 + 255, 197, 146, 105, 73, 50, 37, 32, + // block dimension = 16 + 255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16, + // block dimension = 32 + 255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74, + 66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8, + // block dimension = 64 + 255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156, + 150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73, + 69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16, + 15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4}; + +template <int y_mask> +inline void WriteSmoothHorizontalSum4(void* const dest, const __m128i& left, + const __m128i& weights, + const __m128i& scaled_top_right, + const __m128i& round) { + const __m128i left_y = _mm_shuffle_epi32(left, y_mask); + const __m128i weighted_left_y = _mm_mullo_epi16(left_y, weights); + const __m128i pred_sum = _mm_add_epi32(scaled_top_right, weighted_left_y); + // Equivalent to RightShiftWithRounding(pred[x][y], 8). + const __m128i pred = _mm_srli_epi32(_mm_add_epi32(pred_sum, round), 8); + const __m128i cvtepi32_epi8 = _mm_set1_epi32(0x0C080400); + Store4(dest, _mm_shuffle_epi8(pred, cvtepi32_epi8)); +} + +template <int y_mask> +inline __m128i SmoothVerticalSum4(const __m128i& top, const __m128i& weights, + const __m128i& scaled_bottom_left) { + const __m128i weights_y = _mm_shuffle_epi32(weights, y_mask); + const __m128i weighted_top_y = _mm_mullo_epi16(top, weights_y); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi32(scaled_bottom_left, y_mask); + return _mm_add_epi32(scaled_bottom_left_y, weighted_top_y); +} + +template <int y_mask> +inline void WriteSmoothVerticalSum4(uint8_t* dest, const __m128i& top, + const __m128i& weights, + const __m128i& scaled_bottom_left, + const __m128i& round) { + __m128i pred_sum = + SmoothVerticalSum4<y_mask>(top, weights, scaled_bottom_left); + // Equivalent to RightShiftWithRounding(pred[x][y], 8). + pred_sum = _mm_srli_epi32(_mm_add_epi32(pred_sum, round), 8); + const __m128i cvtepi32_epi8 = _mm_set1_epi32(0x0C080400); + Store4(dest, _mm_shuffle_epi8(pred_sum, cvtepi32_epi8)); +} + +// For SMOOTH_H, |pixels| is the repeated left value for the row. For SMOOTH_V, +// |pixels| is a segment of the top row or the whole top row, and |weights| is +// repeated. +inline __m128i SmoothDirectionalSum8(const __m128i& pixels, + const __m128i& weights, + const __m128i& scaled_corner) { + const __m128i weighted_px = _mm_mullo_epi16(pixels, weights); + return _mm_add_epi16(scaled_corner, weighted_px); +} + +inline void WriteSmoothDirectionalSum8(uint8_t* dest, const __m128i& pixels, + const __m128i& weights, + const __m128i& scaled_corner, + const __m128i& round) { + const __m128i pred_sum = + SmoothDirectionalSum8(pixels, weights, scaled_corner); + // Equivalent to RightShiftWithRounding(pred[x][y], 8). + const __m128i pred = _mm_srli_epi16(_mm_add_epi16(pred_sum, round), 8); + StoreLo8(dest, _mm_packus_epi16(pred, pred)); +} + +// For Horizontal, pixels1 and pixels2 are the same repeated value. For +// Vertical, weights1 and weights2 are the same, and scaled_corner1 and +// scaled_corner2 are the same. +inline void WriteSmoothDirectionalSum16(uint8_t* dest, const __m128i& pixels1, + const __m128i& pixels2, + const __m128i& weights1, + const __m128i& weights2, + const __m128i& scaled_corner1, + const __m128i& scaled_corner2, + const __m128i& round) { + const __m128i weighted_px1 = _mm_mullo_epi16(pixels1, weights1); + const __m128i weighted_px2 = _mm_mullo_epi16(pixels2, weights2); + const __m128i pred_sum1 = _mm_add_epi16(scaled_corner1, weighted_px1); + const __m128i pred_sum2 = _mm_add_epi16(scaled_corner2, weighted_px2); + // Equivalent to RightShiftWithRounding(pred[x][y], 8). + const __m128i pred1 = _mm_srli_epi16(_mm_add_epi16(pred_sum1, round), 8); + const __m128i pred2 = _mm_srli_epi16(_mm_add_epi16(pred_sum2, round), 8); + StoreUnaligned16(dest, _mm_packus_epi16(pred1, pred2)); +} + +template <int y_mask> +inline void WriteSmoothPredSum4(uint8_t* const dest, const __m128i& top, + const __m128i& left, const __m128i& weights_x, + const __m128i& weights_y, + const __m128i& scaled_bottom_left, + const __m128i& scaled_top_right, + const __m128i& round) { + const __m128i left_y = _mm_shuffle_epi32(left, y_mask); + const __m128i weighted_left_y = _mm_mullo_epi32(left_y, weights_x); + const __m128i weight_y = _mm_shuffle_epi32(weights_y, y_mask); + const __m128i weighted_top = _mm_mullo_epi32(weight_y, top); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi32(scaled_bottom_left, y_mask); + const __m128i col_pred = _mm_add_epi32(scaled_bottom_left_y, weighted_left_y); + const __m128i row_pred = _mm_add_epi32(scaled_top_right, weighted_top); + const __m128i pred_sum = _mm_add_epi32(row_pred, col_pred); + + // Equivalent to RightShiftWithRounding(pred[x][y], 9). + const __m128i pred = _mm_srli_epi32(_mm_add_epi32(pred_sum, round), 9); + + const __m128i cvtepi32_epi8 = _mm_set1_epi32(0x0C080400); + Store4(dest, _mm_shuffle_epi8(pred, cvtepi32_epi8)); +} + +// pixels[0]: above and below_pred interleave vector +// pixels[1]: left vector +// pixels[2]: right_pred vector +inline void LoadSmoothPixels4(const uint8_t* above, const uint8_t* left, + const int height, __m128i* pixels) { + if (height == 4) { + pixels[1] = Load4(left); + } else if (height == 8) { + pixels[1] = LoadLo8(left); + } else { + pixels[1] = LoadUnaligned16(left); + } + + const __m128i bottom_left = _mm_set1_epi16(left[height - 1]); + const __m128i top = _mm_cvtepu8_epi16(Load4(above)); + pixels[0] = _mm_unpacklo_epi16(top, bottom_left); + pixels[2] = _mm_set1_epi16(above[3]); +} + +// weight_h[0]: weight_h vector +// weight_h[1]: scale - weight_h vector +// weight_h[2]: same as [0], second half for height = 16 only +// weight_h[3]: same as [1], second half for height = 16 only +// weight_w[0]: weights_w and scale - weights_w interleave vector +inline void LoadSmoothWeights4(const uint8_t* weight_array, const int height, + __m128i* weight_h, __m128i* weight_w) { + const __m128i scale = _mm_set1_epi16(256); + const __m128i x_weights = Load4(weight_array); + weight_h[0] = _mm_cvtepu8_epi16(x_weights); + weight_h[1] = _mm_sub_epi16(scale, weight_h[0]); + weight_w[0] = _mm_unpacklo_epi16(weight_h[0], weight_h[1]); + + if (height == 8) { + const __m128i y_weights = LoadLo8(weight_array + 4); + weight_h[0] = _mm_cvtepu8_epi16(y_weights); + weight_h[1] = _mm_sub_epi16(scale, weight_h[0]); + } else if (height == 16) { + const __m128i zero = _mm_setzero_si128(); + const __m128i y_weights = LoadUnaligned16(weight_array + 12); + weight_h[0] = _mm_cvtepu8_epi16(y_weights); + weight_h[1] = _mm_sub_epi16(scale, weight_h[0]); + weight_h[2] = _mm_unpackhi_epi8(y_weights, zero); + weight_h[3] = _mm_sub_epi16(scale, weight_h[2]); + } +} + +inline void WriteSmoothPred4x8(const __m128i* pixel, const __m128i* weights_y, + const __m128i* weight_x, uint8_t* dst, + const ptrdiff_t stride, + const bool use_second_half) { + const __m128i round = _mm_set1_epi32(256); + const __m128i mask_increment = _mm_set1_epi16(0x0202); + const __m128i cvtepi32_epi8 = _mm_set1_epi32(0x0C080400); + const __m128i zero = _mm_setzero_si128(); + const __m128i left = use_second_half ? _mm_unpackhi_epi8(pixel[1], zero) + : _mm_unpacklo_epi8(pixel[1], zero); + __m128i y_select = _mm_set1_epi16(0x0100); + + for (int i = 0; i < 8; ++i) { + const __m128i weight_y = _mm_shuffle_epi8(weights_y[0], y_select); + const __m128i inverted_weight_y = _mm_shuffle_epi8(weights_y[1], y_select); + const __m128i interleaved_weights = + _mm_unpacklo_epi16(weight_y, inverted_weight_y); + __m128i vertical_pred = _mm_madd_epi16(pixel[0], interleaved_weights); + + __m128i horizontal_vect = _mm_shuffle_epi8(left, y_select); + horizontal_vect = _mm_unpacklo_epi16(horizontal_vect, pixel[2]); + __m128i sum = _mm_madd_epi16(horizontal_vect, weight_x[0]); + + sum = _mm_add_epi32(vertical_pred, sum); + sum = _mm_add_epi32(sum, round); + sum = _mm_srai_epi32(sum, 9); + + sum = _mm_shuffle_epi8(sum, cvtepi32_epi8); + Store4(dst, sum); + dst += stride; + + y_select = _mm_add_epi16(y_select, mask_increment); + } +} + +// The interleaving approach has some overhead that causes it to underperform in +// the 4x4 case. +void Smooth4x4_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const __m128i top = _mm_cvtepu8_epi32(Load4(top_row)); + const __m128i left = _mm_cvtepu8_epi32(Load4(left_column)); + const __m128i weights = _mm_cvtepu8_epi32(Load4(kSmoothWeights)); + const __m128i scale = _mm_set1_epi32(256); + // Fourth short is top_row[3]. + const __m128i top_right = _mm_shuffle_epi32(top, 0xFF); + // Fourth short is left_column[3]. + const __m128i bottom_left = _mm_shuffle_epi32(left, 0xFF); + const __m128i inverted_weights = _mm_sub_epi32(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + const __m128i scaled_bottom_left = + _mm_mullo_epi16(inverted_weights, bottom_left); + auto* dst = static_cast<uint8_t*>(dest); + // AV1 spec 7.11.2.6 (3) describes the sum: + // smoothPred[y][x:x+3] = weighted_top + scaled_right + weighted_left[y] + + // scaled_bottom[y] This could be a loop, but for the immediate value in the + // shuffles. + WriteSmoothPredSum4<0>(dst, top, left, weights, weights, scaled_bottom_left, + scaled_top_right, scale); + dst += stride; + WriteSmoothPredSum4<0x55>(dst, top, left, weights, weights, + scaled_bottom_left, scaled_top_right, scale); + dst += stride; + WriteSmoothPredSum4<0xAA>(dst, top, left, weights, weights, + scaled_bottom_left, scaled_top_right, scale); + dst += stride; + WriteSmoothPredSum4<0xFF>(dst, top, left, weights, weights, + scaled_bottom_left, scaled_top_right, scale); +} + +void Smooth4x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + __m128i weights_x[1]; + __m128i weights_y[2]; + LoadSmoothWeights4(kSmoothWeights, 8, weights_y, weights_x); + __m128i pixels[3]; + LoadSmoothPixels4(top_ptr, left_ptr, 8, pixels); + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothPred4x8(pixels, weights_y, weights_x, dst, stride, false); +} + +void Smooth4x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + __m128i weights_x[1]; + __m128i weights_y[4]; + LoadSmoothWeights4(kSmoothWeights, 16, weights_y, weights_x); + __m128i pixels[3]; + LoadSmoothPixels4(top_ptr, left_ptr, 16, pixels); + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothPred4x8(pixels, weights_y, weights_x, dst, stride, false); + dst += stride << 3; + WriteSmoothPred4x8(pixels, &weights_y[2], weights_x, dst, stride, true); +} + +// pixels[0]: above and below_pred interleave vector, first half +// pixels[1]: above and below_pred interleave vector, second half +// pixels[2]: left vector +// pixels[3]: right_pred vector +// pixels[4]: above and below_pred interleave vector, first half +// pixels[5]: above and below_pred interleave vector, second half +// pixels[6]: left vector + 16 +// pixels[7]: right_pred vector +inline void LoadSmoothPixels8(const uint8_t* above, const uint8_t* left, + const int height, __m128i* pixels) { + const __m128i bottom_left = _mm_set1_epi16(left[height - 1]); + __m128i top_row = _mm_cvtepu8_epi16(LoadLo8(above)); + pixels[0] = _mm_unpacklo_epi16(top_row, bottom_left); + pixels[1] = _mm_unpackhi_epi16(top_row, bottom_left); + + pixels[3] = _mm_set1_epi16(above[7]); + + if (height == 4) { + pixels[2] = Load4(left); + } else if (height == 8) { + pixels[2] = LoadLo8(left); + } else if (height == 16) { + pixels[2] = LoadUnaligned16(left); + } else { + pixels[2] = LoadUnaligned16(left); + pixels[4] = pixels[0]; + pixels[5] = pixels[1]; + pixels[6] = LoadUnaligned16(left + 16); + pixels[7] = pixels[3]; + } +} + +// weight_h[0]: weight_h vector +// weight_h[1]: scale - weight_h vector +// weight_h[2]: same as [0], offset 8 +// weight_h[3]: same as [1], offset 8 +// weight_h[4]: same as [0], offset 16 +// weight_h[5]: same as [1], offset 16 +// weight_h[6]: same as [0], offset 24 +// weight_h[7]: same as [1], offset 24 +// weight_w[0]: weights_w and scale - weights_w interleave vector, first half +// weight_w[1]: weights_w and scale - weights_w interleave vector, second half +inline void LoadSmoothWeights8(const uint8_t* weight_array, const int height, + __m128i* weight_w, __m128i* weight_h) { + const int offset = (height < 8) ? 0 : 4; + __m128i loaded_weights = LoadUnaligned16(&weight_array[offset]); + weight_h[0] = _mm_cvtepu8_epi16(loaded_weights); + const __m128i inverter = _mm_set1_epi16(256); + weight_h[1] = _mm_sub_epi16(inverter, weight_h[0]); + + if (height == 4) { + loaded_weights = _mm_srli_si128(loaded_weights, 4); + __m128i weights_x = _mm_cvtepu8_epi16(loaded_weights); + __m128i inverted_weights_x = _mm_sub_epi16(inverter, weights_x); + weight_w[0] = _mm_unpacklo_epi16(weights_x, inverted_weights_x); + weight_w[1] = _mm_unpackhi_epi16(weights_x, inverted_weights_x); + } else { + weight_w[0] = _mm_unpacklo_epi16(weight_h[0], weight_h[1]); + weight_w[1] = _mm_unpackhi_epi16(weight_h[0], weight_h[1]); + } + + if (height == 16) { + const __m128i zero = _mm_setzero_si128(); + loaded_weights = LoadUnaligned16(weight_array + 12); + weight_h[0] = _mm_cvtepu8_epi16(loaded_weights); + weight_h[1] = _mm_sub_epi16(inverter, weight_h[0]); + weight_h[2] = _mm_unpackhi_epi8(loaded_weights, zero); + weight_h[3] = _mm_sub_epi16(inverter, weight_h[2]); + } else if (height == 32) { + const __m128i zero = _mm_setzero_si128(); + const __m128i weight_lo = LoadUnaligned16(weight_array + 28); + weight_h[0] = _mm_cvtepu8_epi16(weight_lo); + weight_h[1] = _mm_sub_epi16(inverter, weight_h[0]); + weight_h[2] = _mm_unpackhi_epi8(weight_lo, zero); + weight_h[3] = _mm_sub_epi16(inverter, weight_h[2]); + const __m128i weight_hi = LoadUnaligned16(weight_array + 44); + weight_h[4] = _mm_cvtepu8_epi16(weight_hi); + weight_h[5] = _mm_sub_epi16(inverter, weight_h[4]); + weight_h[6] = _mm_unpackhi_epi8(weight_hi, zero); + weight_h[7] = _mm_sub_epi16(inverter, weight_h[6]); + } +} + +inline void WriteSmoothPred8xH(const __m128i* pixels, const __m128i* weights_x, + const __m128i* weights_y, const int height, + uint8_t* dst, const ptrdiff_t stride, + const bool use_second_half) { + const __m128i round = _mm_set1_epi32(256); + const __m128i mask_increment = _mm_set1_epi16(0x0202); + const __m128i cvt_epu16_epi8 = _mm_set_epi32(0, 0, 0xe0c0a08, 0x6040200); + + const __m128i zero = _mm_setzero_si128(); + const __m128i left = use_second_half ? _mm_unpackhi_epi8(pixels[2], zero) + : _mm_unpacklo_epi8(pixels[2], zero); + __m128i y_select = _mm_set1_epi16(0x100); + + for (int i = 0; i < height; ++i) { + const __m128i weight_y = _mm_shuffle_epi8(weights_y[0], y_select); + const __m128i inverted_weight_y = _mm_shuffle_epi8(weights_y[1], y_select); + const __m128i interleaved_weights = + _mm_unpacklo_epi16(weight_y, inverted_weight_y); + const __m128i vertical_sum0 = + _mm_madd_epi16(pixels[0], interleaved_weights); + const __m128i vertical_sum1 = + _mm_madd_epi16(pixels[1], interleaved_weights); + + __m128i horizontal_pixels = _mm_shuffle_epi8(left, y_select); + horizontal_pixels = _mm_unpacklo_epi16(horizontal_pixels, pixels[3]); + const __m128i horizontal_sum0 = + _mm_madd_epi16(horizontal_pixels, weights_x[0]); + const __m128i horizontal_sum1 = + _mm_madd_epi16(horizontal_pixels, weights_x[1]); + + __m128i sum0 = _mm_add_epi32(vertical_sum0, horizontal_sum0); + sum0 = _mm_add_epi32(sum0, round); + sum0 = _mm_srai_epi32(sum0, 9); + + __m128i sum1 = _mm_add_epi32(vertical_sum1, horizontal_sum1); + sum1 = _mm_add_epi32(sum1, round); + sum1 = _mm_srai_epi32(sum1, 9); + + sum0 = _mm_packus_epi16(sum0, sum1); + sum0 = _mm_shuffle_epi8(sum0, cvt_epu16_epi8); + StoreLo8(dst, sum0); + dst += stride; + + y_select = _mm_add_epi16(y_select, mask_increment); + } +} + +void Smooth8x4_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + __m128i pixels[4]; + LoadSmoothPixels8(top_ptr, left_ptr, 4, pixels); + + __m128i weights_x[2], weights_y[2]; + LoadSmoothWeights8(kSmoothWeights, 4, weights_x, weights_y); + + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothPred8xH(pixels, weights_x, weights_y, 4, dst, stride, false); +} + +void Smooth8x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + + __m128i pixels[4]; + LoadSmoothPixels8(top_ptr, left_ptr, 8, pixels); + + __m128i weights_x[2], weights_y[2]; + LoadSmoothWeights8(kSmoothWeights, 8, weights_x, weights_y); + + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothPred8xH(pixels, weights_x, weights_y, 8, dst, stride, false); +} + +void Smooth8x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + __m128i pixels[4]; + LoadSmoothPixels8(top_ptr, left_ptr, 16, pixels); + + __m128i weights_x[2], weights_y[4]; + LoadSmoothWeights8(kSmoothWeights, 16, weights_x, weights_y); + + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothPred8xH(pixels, weights_x, weights_y, 8, dst, stride, false); + dst += stride << 3; + WriteSmoothPred8xH(pixels, weights_x, &weights_y[2], 8, dst, stride, true); +} + +void Smooth8x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + __m128i pixels[8]; + LoadSmoothPixels8(top_ptr, left_ptr, 32, pixels); + + __m128i weights_x[2], weights_y[8]; + LoadSmoothWeights8(kSmoothWeights, 32, weights_x, weights_y); + + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothPred8xH(pixels, weights_x, weights_y, 8, dst, stride, false); + dst += stride << 3; + WriteSmoothPred8xH(pixels, weights_x, &weights_y[2], 8, dst, stride, true); + dst += stride << 3; + WriteSmoothPred8xH(&pixels[4], weights_x, &weights_y[4], 8, dst, stride, + false); + dst += stride << 3; + WriteSmoothPred8xH(&pixels[4], weights_x, &weights_y[6], 8, dst, stride, + true); +} + +template <int width, int height> +void SmoothWxH(void* const dest, const ptrdiff_t stride, + const void* const top_row, const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const uint8_t* const sm_weights_h = kSmoothWeights + height - 4; + const uint8_t* const sm_weights_w = kSmoothWeights + width - 4; + const __m128i zero = _mm_setzero_si128(); + const __m128i scale_value = _mm_set1_epi16(256); + const __m128i bottom_left = _mm_cvtsi32_si128(left_ptr[height - 1]); + const __m128i top_right = _mm_set1_epi16(top_ptr[width - 1]); + const __m128i round = _mm_set1_epi32(256); + auto* dst = static_cast<uint8_t*>(dest); + for (int y = 0; y < height; ++y) { + const __m128i weights_y = _mm_cvtsi32_si128(sm_weights_h[y]); + const __m128i left_y = _mm_cvtsi32_si128(left_ptr[y]); + const __m128i scale_m_weights_y = _mm_sub_epi16(scale_value, weights_y); + __m128i scaled_bottom_left = + _mm_mullo_epi16(scale_m_weights_y, bottom_left); + const __m128i weight_left_y = + _mm_shuffle_epi32(_mm_unpacklo_epi16(weights_y, left_y), 0); + scaled_bottom_left = _mm_add_epi32(scaled_bottom_left, round); + scaled_bottom_left = _mm_shuffle_epi32(scaled_bottom_left, 0); + for (int x = 0; x < width; x += 8) { + const __m128i top_x = LoadLo8(top_ptr + x); + const __m128i weights_x = LoadLo8(sm_weights_w + x); + const __m128i top_weights_x = _mm_unpacklo_epi8(top_x, weights_x); + const __m128i top_weights_x_lo = _mm_cvtepu8_epi16(top_weights_x); + const __m128i top_weights_x_hi = _mm_unpackhi_epi8(top_weights_x, zero); + + // Here opposite weights and pixels are multiplied, where the order of + // interleaving is indicated in the names. + __m128i pred_lo = _mm_madd_epi16(top_weights_x_lo, weight_left_y); + __m128i pred_hi = _mm_madd_epi16(top_weights_x_hi, weight_left_y); + + // |scaled_bottom_left| is always scaled by the same weight each row, so + // we only derive |scaled_top_right| values here. + const __m128i inverted_weights_x = + _mm_sub_epi16(scale_value, _mm_cvtepu8_epi16(weights_x)); + const __m128i scaled_top_right = + _mm_mullo_epi16(inverted_weights_x, top_right); + const __m128i scaled_top_right_lo = _mm_cvtepu16_epi32(scaled_top_right); + const __m128i scaled_top_right_hi = + _mm_unpackhi_epi16(scaled_top_right, zero); + pred_lo = _mm_add_epi32(pred_lo, scaled_bottom_left); + pred_hi = _mm_add_epi32(pred_hi, scaled_bottom_left); + pred_lo = _mm_add_epi32(pred_lo, scaled_top_right_lo); + pred_hi = _mm_add_epi32(pred_hi, scaled_top_right_hi); + + // The round value for RightShiftWithRounding was added with + // |scaled_bottom_left|. + pred_lo = _mm_srli_epi32(pred_lo, 9); + pred_hi = _mm_srli_epi32(pred_hi, 9); + const __m128i pred = _mm_packus_epi16(pred_lo, pred_hi); + StoreLo8(dst + x, _mm_packus_epi16(pred, pred)); + } + dst += stride; + } +} + +void SmoothHorizontal4x4_SSE4_1(void* dest, const ptrdiff_t stride, + const void* top_row, const void* left_column) { + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi32(top_ptr[3]); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left = _mm_cvtepu8_epi32(Load4(left_ptr)); + const __m128i weights = _mm_cvtepu8_epi32(Load4(kSmoothWeights)); + __m128i scale = _mm_set1_epi32(256); + const __m128i inverted_weights = _mm_sub_epi32(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + scale = _mm_set1_epi32(128); + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothHorizontalSum4<0>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0x55>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xAA>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale); +} + +void SmoothHorizontal4x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi32(top[3]); + const __m128i weights = _mm_cvtepu8_epi32(Load4(kSmoothWeights)); + __m128i scale = _mm_set1_epi32(256); + const __m128i inverted_weights = _mm_sub_epi32(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + scale = _mm_set1_epi32(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + __m128i left = _mm_cvtepu8_epi32(Load4(left_column)); + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothHorizontalSum4<0>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0x55>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xAA>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale); + dst += stride; + + left = _mm_cvtepu8_epi32(Load4(left_ptr + 4)); + WriteSmoothHorizontalSum4<0>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0x55>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xAA>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale); +} + +void SmoothHorizontal4x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi32(top[3]); + const __m128i weights = _mm_cvtepu8_epi32(Load4(kSmoothWeights)); + __m128i scale = _mm_set1_epi32(256); + const __m128i inverted_weights = _mm_sub_epi32(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + scale = _mm_set1_epi32(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + __m128i left = _mm_cvtepu8_epi32(Load4(left_column)); + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothHorizontalSum4<0>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0x55>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xAA>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale); + dst += stride; + + left = _mm_cvtepu8_epi32(Load4(left_ptr + 4)); + WriteSmoothHorizontalSum4<0>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0x55>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xAA>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale); + dst += stride; + + left = _mm_cvtepu8_epi32(Load4(left_ptr + 8)); + WriteSmoothHorizontalSum4<0>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0x55>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xAA>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale); + dst += stride; + + left = _mm_cvtepu8_epi32(Load4(left_ptr + 12)); + WriteSmoothHorizontalSum4<0>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0x55>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xAA>(dst, left, weights, scaled_top_right, scale); + dst += stride; + WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale); +} + +void SmoothHorizontal8x4_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[7]); + const __m128i left = _mm_cvtepu8_epi16(Load4(left_column)); + const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + scale = _mm_set1_epi16(128); + __m128i y_select = _mm_set1_epi32(0x01000100); + __m128i left_y = _mm_shuffle_epi8(left, y_select); + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + y_select = _mm_set1_epi32(0x03020302); + left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + y_select = _mm_set1_epi32(0x05040504); + left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + y_select = _mm_set1_epi32(0x07060706); + left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); +} + +void SmoothHorizontal8x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[7]); + const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + scale = _mm_set1_epi16(128); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + } +} + +void SmoothHorizontal8x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[7]); + const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + } +} + +void SmoothHorizontal8x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[7]); + const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_top_right = _mm_mullo_epi16(inverted_weights, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 16)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 24)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale); + dst += stride; + } +} + +void SmoothHorizontal16x4_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[15]); + const __m128i left = _mm_cvtepu8_epi16(Load4(left_column)); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + scale = _mm_set1_epi16(128); + __m128i y_mask = _mm_set1_epi32(0x01000100); + __m128i left_y = _mm_shuffle_epi8(left, y_mask); + auto* dst = static_cast<uint8_t*>(dest); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + y_mask = _mm_set1_epi32(0x03020302); + left_y = _mm_shuffle_epi8(left, y_mask); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + y_mask = _mm_set1_epi32(0x05040504); + left_y = _mm_shuffle_epi8(left, y_mask); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + y_mask = _mm_set1_epi32(0x07060706); + left_y = _mm_shuffle_epi8(left, y_mask); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); +} + +void SmoothHorizontal16x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[15]); + const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + scale = _mm_set1_epi16(128); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } +} + +void SmoothHorizontal16x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[15]); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } +} + +void SmoothHorizontal16x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[15]); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 16)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 24)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } +} + +void SmoothHorizontal16x64_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[15]); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + for (int left_offset = 0; left_offset < 64; left_offset += 8) { + const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + left_offset)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + dst += stride; + } + } +} + +void SmoothHorizontal32x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[31]); + const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lo, 8)); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hi, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + const __m128i scaled_top_right3 = + _mm_mullo_epi16(inverted_weights3, top_right); + const __m128i scaled_top_right4 = + _mm_mullo_epi16(inverted_weights4, top_right); + scale = _mm_set1_epi16(128); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } +} + +void SmoothHorizontal32x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[31]); + const __m128i left1 = _mm_cvtepu8_epi16(LoadLo8(left_column)); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lo, 8)); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hi, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + const __m128i scaled_top_right3 = + _mm_mullo_epi16(inverted_weights3, top_right); + const __m128i scaled_top_right4 = + _mm_mullo_epi16(inverted_weights4, top_right); + scale = _mm_set1_epi16(128); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left1, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } + const __m128i left2 = + _mm_cvtepu8_epi16(LoadLo8(static_cast<const uint8_t*>(left_column) + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left2, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } +} + +void SmoothHorizontal32x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[31]); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lo, 8)); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hi, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + const __m128i scaled_top_right3 = + _mm_mullo_epi16(inverted_weights3, top_right); + const __m128i scaled_top_right4 = + _mm_mullo_epi16(inverted_weights4, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 16)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } + left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 24)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } +} + +void SmoothHorizontal32x64_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[31]); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lo, 8)); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hi, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + const __m128i scaled_top_right3 = + _mm_mullo_epi16(inverted_weights3, top_right); + const __m128i scaled_top_right4 = + _mm_mullo_epi16(inverted_weights4, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + for (int left_offset = 0; left_offset < 64; left_offset += 8) { + const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + left_offset)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + dst += stride; + } + } +} + +void SmoothHorizontal64x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[63]); + const __m128i left1 = _mm_cvtepu8_epi16(LoadLo8(left_column)); + const __m128i weights_lolo = LoadUnaligned16(kSmoothWeights + 60); + const __m128i weights_lohi = LoadUnaligned16(kSmoothWeights + 76); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lolo); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lolo, 8)); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_lohi); + const __m128i weights4 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lohi, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + const __m128i scaled_top_right3 = + _mm_mullo_epi16(inverted_weights3, top_right); + const __m128i scaled_top_right4 = + _mm_mullo_epi16(inverted_weights4, top_right); + const __m128i weights_hilo = LoadUnaligned16(kSmoothWeights + 92); + const __m128i weights_hihi = LoadUnaligned16(kSmoothWeights + 108); + const __m128i weights5 = _mm_cvtepu8_epi16(weights_hilo); + const __m128i weights6 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hilo, 8)); + const __m128i weights7 = _mm_cvtepu8_epi16(weights_hihi); + const __m128i weights8 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hihi, 8)); + const __m128i inverted_weights5 = _mm_sub_epi16(scale, weights5); + const __m128i inverted_weights6 = _mm_sub_epi16(scale, weights6); + const __m128i inverted_weights7 = _mm_sub_epi16(scale, weights7); + const __m128i inverted_weights8 = _mm_sub_epi16(scale, weights8); + const __m128i scaled_top_right5 = + _mm_mullo_epi16(inverted_weights5, top_right); + const __m128i scaled_top_right6 = + _mm_mullo_epi16(inverted_weights6, top_right); + const __m128i scaled_top_right7 = + _mm_mullo_epi16(inverted_weights7, top_right); + const __m128i scaled_top_right8 = + _mm_mullo_epi16(inverted_weights8, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left1, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + WriteSmoothDirectionalSum16(dst + 32, left_y, left_y, weights5, weights6, + scaled_top_right5, scaled_top_right6, scale); + WriteSmoothDirectionalSum16(dst + 48, left_y, left_y, weights7, weights8, + scaled_top_right7, scaled_top_right8, scale); + dst += stride; + } + const __m128i left2 = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + __m128i left_y = _mm_shuffle_epi8(left2, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + WriteSmoothDirectionalSum16(dst + 32, left_y, left_y, weights5, weights6, + scaled_top_right5, scaled_top_right6, scale); + WriteSmoothDirectionalSum16(dst + 48, left_y, left_y, weights7, weights8, + scaled_top_right7, scaled_top_right8, scale); + dst += stride; + } +} + +void SmoothHorizontal64x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[63]); + const __m128i left1 = _mm_cvtepu8_epi16(LoadLo8(left_column)); + const __m128i weights_lolo = LoadUnaligned16(kSmoothWeights + 60); + const __m128i weights_lohi = LoadUnaligned16(kSmoothWeights + 76); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lolo); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lolo, 8)); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_lohi); + const __m128i weights4 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lohi, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + const __m128i scaled_top_right3 = + _mm_mullo_epi16(inverted_weights3, top_right); + const __m128i scaled_top_right4 = + _mm_mullo_epi16(inverted_weights4, top_right); + const __m128i weights_hilo = LoadUnaligned16(kSmoothWeights + 92); + const __m128i weights_hihi = LoadUnaligned16(kSmoothWeights + 108); + const __m128i weights5 = _mm_cvtepu8_epi16(weights_hilo); + const __m128i weights6 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hilo, 8)); + const __m128i weights7 = _mm_cvtepu8_epi16(weights_hihi); + const __m128i weights8 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hihi, 8)); + const __m128i inverted_weights5 = _mm_sub_epi16(scale, weights5); + const __m128i inverted_weights6 = _mm_sub_epi16(scale, weights6); + const __m128i inverted_weights7 = _mm_sub_epi16(scale, weights7); + const __m128i inverted_weights8 = _mm_sub_epi16(scale, weights8); + const __m128i scaled_top_right5 = + _mm_mullo_epi16(inverted_weights5, top_right); + const __m128i scaled_top_right6 = + _mm_mullo_epi16(inverted_weights6, top_right); + const __m128i scaled_top_right7 = + _mm_mullo_epi16(inverted_weights7, top_right); + const __m128i scaled_top_right8 = + _mm_mullo_epi16(inverted_weights8, top_right); + scale = _mm_set1_epi16(128); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left1, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + WriteSmoothDirectionalSum16(dst + 32, left_y, left_y, weights5, weights6, + scaled_top_right5, scaled_top_right6, scale); + WriteSmoothDirectionalSum16(dst + 48, left_y, left_y, weights7, weights8, + scaled_top_right7, scaled_top_right8, scale); + dst += stride; + } + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left2 = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left2, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + WriteSmoothDirectionalSum16(dst + 32, left_y, left_y, weights5, weights6, + scaled_top_right5, scaled_top_right6, scale); + WriteSmoothDirectionalSum16(dst + 48, left_y, left_y, weights7, weights8, + scaled_top_right7, scaled_top_right8, scale); + dst += stride; + } + const __m128i left3 = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 16)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left3, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + WriteSmoothDirectionalSum16(dst + 32, left_y, left_y, weights5, weights6, + scaled_top_right5, scaled_top_right6, scale); + WriteSmoothDirectionalSum16(dst + 48, left_y, left_y, weights7, weights8, + scaled_top_right7, scaled_top_right8, scale); + dst += stride; + } + const __m128i left4 = _mm_cvtepu8_epi16(LoadLo8(left_ptr + 24)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left4, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + WriteSmoothDirectionalSum16(dst + 32, left_y, left_y, weights5, weights6, + scaled_top_right5, scaled_top_right6, scale); + WriteSmoothDirectionalSum16(dst + 48, left_y, left_y, weights7, weights8, + scaled_top_right7, scaled_top_right8, scale); + dst += stride; + } +} + +void SmoothHorizontal64x64_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const top = static_cast<const uint8_t*>(top_row); + const __m128i top_right = _mm_set1_epi16(top[63]); + const __m128i weights_lolo = LoadUnaligned16(kSmoothWeights + 60); + const __m128i weights_lohi = LoadUnaligned16(kSmoothWeights + 76); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lolo); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lolo, 8)); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_lohi); + const __m128i weights4 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_lohi, 8)); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_top_right1 = + _mm_mullo_epi16(inverted_weights1, top_right); + const __m128i scaled_top_right2 = + _mm_mullo_epi16(inverted_weights2, top_right); + const __m128i scaled_top_right3 = + _mm_mullo_epi16(inverted_weights3, top_right); + const __m128i scaled_top_right4 = + _mm_mullo_epi16(inverted_weights4, top_right); + const __m128i weights_hilo = LoadUnaligned16(kSmoothWeights + 92); + const __m128i weights_hihi = LoadUnaligned16(kSmoothWeights + 108); + const __m128i weights5 = _mm_cvtepu8_epi16(weights_hilo); + const __m128i weights6 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hilo, 8)); + const __m128i weights7 = _mm_cvtepu8_epi16(weights_hihi); + const __m128i weights8 = _mm_cvtepu8_epi16(_mm_srli_si128(weights_hihi, 8)); + const __m128i inverted_weights5 = _mm_sub_epi16(scale, weights5); + const __m128i inverted_weights6 = _mm_sub_epi16(scale, weights6); + const __m128i inverted_weights7 = _mm_sub_epi16(scale, weights7); + const __m128i inverted_weights8 = _mm_sub_epi16(scale, weights8); + const __m128i scaled_top_right5 = + _mm_mullo_epi16(inverted_weights5, top_right); + const __m128i scaled_top_right6 = + _mm_mullo_epi16(inverted_weights6, top_right); + const __m128i scaled_top_right7 = + _mm_mullo_epi16(inverted_weights7, top_right); + const __m128i scaled_top_right8 = + _mm_mullo_epi16(inverted_weights8, top_right); + scale = _mm_set1_epi16(128); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + for (int left_offset = 0; left_offset < 64; left_offset += 8) { + const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_ptr + left_offset)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i left_y = _mm_shuffle_epi8(left, y_select); + WriteSmoothDirectionalSum16(dst, left_y, left_y, weights1, weights2, + scaled_top_right1, scaled_top_right2, scale); + WriteSmoothDirectionalSum16(dst + 16, left_y, left_y, weights3, weights4, + scaled_top_right3, scaled_top_right4, scale); + WriteSmoothDirectionalSum16(dst + 32, left_y, left_y, weights5, weights6, + scaled_top_right5, scaled_top_right6, scale); + WriteSmoothDirectionalSum16(dst + 48, left_y, left_y, weights7, weights8, + scaled_top_right7, scaled_top_right8, scale); + dst += stride; + } + } +} + +inline void LoadSmoothVerticalPixels4(const uint8_t* above, const uint8_t* left, + const int height, __m128i* pixels) { + __m128i top = Load4(above); + const __m128i bottom_left = _mm_set1_epi16(left[height - 1]); + top = _mm_cvtepu8_epi16(top); + pixels[0] = _mm_unpacklo_epi16(top, bottom_left); +} + +// |weight_array| alternates weight vectors from the table with their inverted +// (256-w) counterparts. This is precomputed by the compiler when the weights +// table is visible to this module. Removing this visibility can cut speed by up +// to half in both 4xH and 8xH transforms. +inline void LoadSmoothVerticalWeights4(const uint8_t* weight_array, + const int height, __m128i* weights) { + const __m128i inverter = _mm_set1_epi16(256); + + if (height == 4) { + const __m128i weight = Load4(weight_array); + weights[0] = _mm_cvtepu8_epi16(weight); + weights[1] = _mm_sub_epi16(inverter, weights[0]); + } else if (height == 8) { + const __m128i weight = LoadLo8(weight_array + 4); + weights[0] = _mm_cvtepu8_epi16(weight); + weights[1] = _mm_sub_epi16(inverter, weights[0]); + } else { + const __m128i weight = LoadUnaligned16(weight_array + 12); + const __m128i zero = _mm_setzero_si128(); + weights[0] = _mm_cvtepu8_epi16(weight); + weights[1] = _mm_sub_epi16(inverter, weights[0]); + weights[2] = _mm_unpackhi_epi8(weight, zero); + weights[3] = _mm_sub_epi16(inverter, weights[2]); + } +} + +inline void WriteSmoothVertical4xH(const __m128i* pixel, const __m128i* weight, + const int height, uint8_t* dst, + const ptrdiff_t stride) { + const __m128i pred_round = _mm_set1_epi32(128); + const __m128i mask_increment = _mm_set1_epi16(0x0202); + const __m128i cvtepu8_epi32 = _mm_set1_epi32(0xC080400); + __m128i y_select = _mm_set1_epi16(0x0100); + + for (int y = 0; y < height; ++y) { + const __m128i weight_y = _mm_shuffle_epi8(weight[0], y_select); + const __m128i inverted_weight_y = _mm_shuffle_epi8(weight[1], y_select); + const __m128i alternate_weights = + _mm_unpacklo_epi16(weight_y, inverted_weight_y); + // Here the pixel vector is top_row[0], corner, top_row[1], corner, ... + // The madd instruction yields four results of the form: + // (top_row[x] * weight[y] + corner * inverted_weight[y]) + __m128i sum = _mm_madd_epi16(pixel[0], alternate_weights); + sum = _mm_add_epi32(sum, pred_round); + sum = _mm_srai_epi32(sum, 8); + sum = _mm_shuffle_epi8(sum, cvtepu8_epi32); + Store4(dst, sum); + dst += stride; + y_select = _mm_add_epi16(y_select, mask_increment); + } +} + +void SmoothVertical4x4_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left = static_cast<const uint8_t*>(left_column); + const auto* const above = static_cast<const uint8_t*>(top_row); + auto* dst = static_cast<uint8_t*>(dest); + __m128i pixels; + LoadSmoothVerticalPixels4(above, left, 4, &pixels); + + __m128i weights[2]; + LoadSmoothVerticalWeights4(kSmoothWeights, 4, weights); + + WriteSmoothVertical4xH(&pixels, weights, 4, dst, stride); +} + +void SmoothVertical4x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left = static_cast<const uint8_t*>(left_column); + const auto* const above = static_cast<const uint8_t*>(top_row); + auto* dst = static_cast<uint8_t*>(dest); + __m128i pixels; + LoadSmoothVerticalPixels4(above, left, 8, &pixels); + + __m128i weights[2]; + LoadSmoothVerticalWeights4(kSmoothWeights, 8, weights); + + WriteSmoothVertical4xH(&pixels, weights, 8, dst, stride); +} + +void SmoothVertical4x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left = static_cast<const uint8_t*>(left_column); + const auto* const above = static_cast<const uint8_t*>(top_row); + auto* dst = static_cast<uint8_t*>(dest); + __m128i pixels; + LoadSmoothVerticalPixels4(above, left, 16, &pixels); + + __m128i weights[4]; + LoadSmoothVerticalWeights4(kSmoothWeights, 16, weights); + + WriteSmoothVertical4xH(&pixels, weights, 8, dst, stride); + dst += stride << 3; + WriteSmoothVertical4xH(&pixels, &weights[2], 8, dst, stride); +} + +void SmoothVertical8x4_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[3]); + const __m128i weights = _mm_cvtepu8_epi16(Load4(kSmoothWeights)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_bottom_left = + _mm_mullo_epi16(inverted_weights, bottom_left); + scale = _mm_set1_epi16(128); + + auto* dst = static_cast<uint8_t*>(dest); + __m128i y_select = _mm_set1_epi32(0x01000100); + const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row)); + __m128i weights_y = _mm_shuffle_epi8(weights, y_select); + __m128i scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, scale); + dst += stride; + y_select = _mm_set1_epi32(0x03020302); + weights_y = _mm_shuffle_epi8(weights, y_select); + scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, scale); + dst += stride; + y_select = _mm_set1_epi32(0x05040504); + weights_y = _mm_shuffle_epi8(weights, y_select); + scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, scale); + dst += stride; + y_select = _mm_set1_epi32(0x07060706); + weights_y = _mm_shuffle_epi8(weights, y_select); + scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, scale); +} + +void SmoothVertical8x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[7]); + const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_bottom_left = + _mm_mullo_epi16(inverted_weights, bottom_left); + scale = _mm_set1_epi16(128); + const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row)); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical8x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[15]); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_cvtepu8_epi16(_mm_srli_si128(weights, 8)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i scaled_bottom_left1 = + _mm_mullo_epi16(inverted_weights1, bottom_left); + const __m128i scaled_bottom_left2 = + _mm_mullo_epi16(inverted_weights2, bottom_left); + scale = _mm_set1_epi16(128); + const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row)); + auto* dst = static_cast<uint8_t*>(dest); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights1, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left1, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights2, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left2, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical8x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i zero = _mm_setzero_si128(); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[31]); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_unpackhi_epi8(weights_lo, zero); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_unpackhi_epi8(weights_hi, zero); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_bottom_left1 = + _mm_mullo_epi16(inverted_weights1, bottom_left); + const __m128i scaled_bottom_left2 = + _mm_mullo_epi16(inverted_weights2, bottom_left); + const __m128i scaled_bottom_left3 = + _mm_mullo_epi16(inverted_weights3, bottom_left); + const __m128i scaled_bottom_left4 = + _mm_mullo_epi16(inverted_weights4, bottom_left); + scale = _mm_set1_epi16(128); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights1, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left1, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights2, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left2, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights3, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left3, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights4, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left4, y_select); + WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical16x4_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[3]); + const __m128i weights = _mm_cvtepu8_epi16(Load4(kSmoothWeights)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_bottom_left = + _mm_mullo_epi16(inverted_weights, bottom_left); + scale = _mm_set1_epi16(128); + const __m128i top = LoadUnaligned16(top_row); + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_cvtepu8_epi16(_mm_srli_si128(top, 8)); + + __m128i y_select = _mm_set1_epi32(0x01000100); + __m128i weights_y = _mm_shuffle_epi8(weights, y_select); + __m128i scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + y_select = _mm_set1_epi32(0x03020302); + weights_y = _mm_shuffle_epi8(weights, y_select); + scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + y_select = _mm_set1_epi32(0x05040504); + weights_y = _mm_shuffle_epi8(weights, y_select); + scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + y_select = _mm_set1_epi32(0x07060706); + weights_y = _mm_shuffle_epi8(weights, y_select); + scaled_bottom_left_y = _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); +} + +void SmoothVertical16x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[7]); + const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4)); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_bottom_left = + _mm_mullo_epi16(inverted_weights, bottom_left); + scale = _mm_set1_epi16(128); + + const __m128i top = LoadUnaligned16(top_row); + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_cvtepu8_epi16(_mm_srli_si128(top, 8)); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical16x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[15]); + const __m128i zero = _mm_setzero_si128(); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + const __m128i weights_lo = _mm_cvtepu8_epi16(weights); + const __m128i weights_hi = _mm_unpackhi_epi8(weights, zero); + const __m128i inverted_weights_lo = _mm_sub_epi16(scale, weights_lo); + const __m128i inverted_weights_hi = _mm_sub_epi16(scale, weights_hi); + const __m128i scaled_bottom_left_lo = + _mm_mullo_epi16(inverted_weights_lo, bottom_left); + const __m128i scaled_bottom_left_hi = + _mm_mullo_epi16(inverted_weights_hi, bottom_left); + scale = _mm_set1_epi16(128); + + const __m128i top = LoadUnaligned16(top_row); + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_unpackhi_epi8(top, zero); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_lo, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_lo, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_hi, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_hi, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical16x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[31]); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + __m128i scale = _mm_set1_epi16(256); + const __m128i zero = _mm_setzero_si128(); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_unpackhi_epi8(weights_lo, zero); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_unpackhi_epi8(weights_hi, zero); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_bottom_left1 = + _mm_mullo_epi16(inverted_weights1, bottom_left); + const __m128i scaled_bottom_left2 = + _mm_mullo_epi16(inverted_weights2, bottom_left); + const __m128i scaled_bottom_left3 = + _mm_mullo_epi16(inverted_weights3, bottom_left); + const __m128i scaled_bottom_left4 = + _mm_mullo_epi16(inverted_weights4, bottom_left); + scale = _mm_set1_epi16(128); + + const __m128i top = LoadUnaligned16(top_row); + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_unpackhi_epi8(top, zero); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights1, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left1, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights2, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left2, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights3, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left3, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights4, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left4, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical16x64_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[63]); + const __m128i scale = _mm_set1_epi16(256); + const __m128i round = _mm_set1_epi16(128); + const __m128i zero = _mm_setzero_si128(); + + const __m128i top = LoadUnaligned16(top_row); + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_unpackhi_epi8(top, zero); + const uint8_t* weights_base_ptr = kSmoothWeights + 60; + for (int left_offset = 0; left_offset < 64; left_offset += 16) { + const __m128i weights = LoadUnaligned16(weights_base_ptr + left_offset); + const __m128i weights_lo = _mm_cvtepu8_epi16(weights); + const __m128i weights_hi = _mm_unpackhi_epi8(weights, zero); + const __m128i inverted_weights_lo = _mm_sub_epi16(scale, weights_lo); + const __m128i inverted_weights_hi = _mm_sub_epi16(scale, weights_hi); + const __m128i scaled_bottom_left_lo = + _mm_mullo_epi16(inverted_weights_lo, bottom_left); + const __m128i scaled_bottom_left_hi = + _mm_mullo_epi16(inverted_weights_hi, bottom_left); + + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_lo, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_lo, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_hi, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_hi, y_select); + WriteSmoothDirectionalSum16(dst, top_lo, top_hi, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + dst += stride; + } + } +} + +void SmoothVertical32x8_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[7]); + const __m128i top_lo = LoadUnaligned16(top_ptr); + const __m128i top_hi = LoadUnaligned16(top_ptr + 16); + const __m128i top1 = _mm_cvtepu8_epi16(top_lo); + const __m128i top2 = _mm_unpackhi_epi8(top_lo, zero); + const __m128i top3 = _mm_cvtepu8_epi16(top_hi); + const __m128i top4 = _mm_unpackhi_epi8(top_hi, zero); + __m128i scale = _mm_set1_epi16(256); + const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4)); + const __m128i inverted_weights = _mm_sub_epi16(scale, weights); + const __m128i scaled_bottom_left = + _mm_mullo_epi16(inverted_weights, bottom_left); + scale = _mm_set1_epi16(128); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical32x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + auto* dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[15]); + const __m128i top_lo = LoadUnaligned16(top_ptr); + const __m128i top_hi = LoadUnaligned16(top_ptr + 16); + const __m128i top1 = _mm_cvtepu8_epi16(top_lo); + const __m128i top2 = _mm_unpackhi_epi8(top_lo, zero); + const __m128i top3 = _mm_cvtepu8_epi16(top_hi); + const __m128i top4 = _mm_unpackhi_epi8(top_hi, zero); + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_unpackhi_epi8(weights, zero); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i scaled_bottom_left1 = + _mm_mullo_epi16(inverted_weights1, bottom_left); + const __m128i scaled_bottom_left2 = + _mm_mullo_epi16(inverted_weights2, bottom_left); + scale = _mm_set1_epi16(128); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights1, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left1, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights2, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left2, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical32x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[31]); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + const __m128i zero = _mm_setzero_si128(); + __m128i scale = _mm_set1_epi16(256); + const __m128i top_lo = LoadUnaligned16(top_ptr); + const __m128i top_hi = LoadUnaligned16(top_ptr + 16); + const __m128i top1 = _mm_cvtepu8_epi16(top_lo); + const __m128i top2 = _mm_unpackhi_epi8(top_lo, zero); + const __m128i top3 = _mm_cvtepu8_epi16(top_hi); + const __m128i top4 = _mm_unpackhi_epi8(top_hi, zero); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_unpackhi_epi8(weights_lo, zero); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_unpackhi_epi8(weights_hi, zero); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_bottom_left1 = + _mm_mullo_epi16(inverted_weights1, bottom_left); + const __m128i scaled_bottom_left2 = + _mm_mullo_epi16(inverted_weights2, bottom_left); + const __m128i scaled_bottom_left3 = + _mm_mullo_epi16(inverted_weights3, bottom_left); + const __m128i scaled_bottom_left4 = + _mm_mullo_epi16(inverted_weights4, bottom_left); + scale = _mm_set1_epi16(128); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights1, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left1, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights2, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left2, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights3, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left3, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights4, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left4, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical32x64_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i zero = _mm_setzero_si128(); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[63]); + const __m128i top_lo = LoadUnaligned16(top_ptr); + const __m128i top_hi = LoadUnaligned16(top_ptr + 16); + const __m128i top1 = _mm_cvtepu8_epi16(top_lo); + const __m128i top2 = _mm_unpackhi_epi8(top_lo, zero); + const __m128i top3 = _mm_cvtepu8_epi16(top_hi); + const __m128i top4 = _mm_unpackhi_epi8(top_hi, zero); + const __m128i scale = _mm_set1_epi16(256); + const __m128i round = _mm_set1_epi16(128); + const uint8_t* weights_base_ptr = kSmoothWeights + 60; + for (int left_offset = 0; left_offset < 64; left_offset += 16) { + const __m128i weights = LoadUnaligned16(weights_base_ptr + left_offset); + const __m128i weights_lo = _mm_cvtepu8_epi16(weights); + const __m128i weights_hi = _mm_unpackhi_epi8(weights, zero); + const __m128i inverted_weights_lo = _mm_sub_epi16(scale, weights_lo); + const __m128i inverted_weights_hi = _mm_sub_epi16(scale, weights_hi); + const __m128i scaled_bottom_left_lo = + _mm_mullo_epi16(inverted_weights_lo, bottom_left); + const __m128i scaled_bottom_left_hi = + _mm_mullo_epi16(inverted_weights_hi, bottom_left); + + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_lo, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_lo, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_hi, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_hi, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + dst += stride; + } + } +} + +void SmoothVertical64x16_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[15]); + __m128i scale = _mm_set1_epi16(256); + const __m128i zero = _mm_setzero_si128(); + const __m128i top_lolo = LoadUnaligned16(top_ptr); + const __m128i top_lohi = LoadUnaligned16(top_ptr + 16); + const __m128i top1 = _mm_cvtepu8_epi16(top_lolo); + const __m128i top2 = _mm_unpackhi_epi8(top_lolo, zero); + const __m128i top3 = _mm_cvtepu8_epi16(top_lohi); + const __m128i top4 = _mm_unpackhi_epi8(top_lohi, zero); + + const __m128i weights = LoadUnaligned16(kSmoothWeights + 12); + const __m128i weights1 = _mm_cvtepu8_epi16(weights); + const __m128i weights2 = _mm_unpackhi_epi8(weights, zero); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i top_hilo = LoadUnaligned16(top_ptr + 32); + const __m128i top_hihi = LoadUnaligned16(top_ptr + 48); + const __m128i top5 = _mm_cvtepu8_epi16(top_hilo); + const __m128i top6 = _mm_unpackhi_epi8(top_hilo, zero); + const __m128i top7 = _mm_cvtepu8_epi16(top_hihi); + const __m128i top8 = _mm_unpackhi_epi8(top_hihi, zero); + const __m128i scaled_bottom_left1 = + _mm_mullo_epi16(inverted_weights1, bottom_left); + const __m128i scaled_bottom_left2 = + _mm_mullo_epi16(inverted_weights2, bottom_left); + scale = _mm_set1_epi16(128); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights1, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left1, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights2, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left2, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical64x32_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i zero = _mm_setzero_si128(); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[31]); + const __m128i top_lolo = LoadUnaligned16(top_ptr); + const __m128i top_lohi = LoadUnaligned16(top_ptr + 16); + const __m128i top1 = _mm_cvtepu8_epi16(top_lolo); + const __m128i top2 = _mm_unpackhi_epi8(top_lolo, zero); + const __m128i top3 = _mm_cvtepu8_epi16(top_lohi); + const __m128i top4 = _mm_unpackhi_epi8(top_lohi, zero); + const __m128i top_hilo = LoadUnaligned16(top_ptr + 32); + const __m128i top_hihi = LoadUnaligned16(top_ptr + 48); + const __m128i top5 = _mm_cvtepu8_epi16(top_hilo); + const __m128i top6 = _mm_unpackhi_epi8(top_hilo, zero); + const __m128i top7 = _mm_cvtepu8_epi16(top_hihi); + const __m128i top8 = _mm_unpackhi_epi8(top_hihi, zero); + const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28); + const __m128i weights_hi = LoadUnaligned16(kSmoothWeights + 44); + const __m128i weights1 = _mm_cvtepu8_epi16(weights_lo); + const __m128i weights2 = _mm_unpackhi_epi8(weights_lo, zero); + const __m128i weights3 = _mm_cvtepu8_epi16(weights_hi); + const __m128i weights4 = _mm_unpackhi_epi8(weights_hi, zero); + __m128i scale = _mm_set1_epi16(256); + const __m128i inverted_weights1 = _mm_sub_epi16(scale, weights1); + const __m128i inverted_weights2 = _mm_sub_epi16(scale, weights2); + const __m128i inverted_weights3 = _mm_sub_epi16(scale, weights3); + const __m128i inverted_weights4 = _mm_sub_epi16(scale, weights4); + const __m128i scaled_bottom_left1 = + _mm_mullo_epi16(inverted_weights1, bottom_left); + const __m128i scaled_bottom_left2 = + _mm_mullo_epi16(inverted_weights2, bottom_left); + const __m128i scaled_bottom_left3 = + _mm_mullo_epi16(inverted_weights3, bottom_left); + const __m128i scaled_bottom_left4 = + _mm_mullo_epi16(inverted_weights4, bottom_left); + scale = _mm_set1_epi16(128); + + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights1, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left1, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights2, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left2, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights3, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left3, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights4, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left4, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + scale); + dst += stride; + } +} + +void SmoothVertical64x64_SSE4_1(void* const dest, const ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i zero = _mm_setzero_si128(); + const __m128i bottom_left = _mm_set1_epi16(left_ptr[63]); + const __m128i top_lolo = LoadUnaligned16(top_ptr); + const __m128i top_lohi = LoadUnaligned16(top_ptr + 16); + const __m128i top1 = _mm_cvtepu8_epi16(top_lolo); + const __m128i top2 = _mm_unpackhi_epi8(top_lolo, zero); + const __m128i top3 = _mm_cvtepu8_epi16(top_lohi); + const __m128i top4 = _mm_unpackhi_epi8(top_lohi, zero); + const __m128i top_hilo = LoadUnaligned16(top_ptr + 32); + const __m128i top_hihi = LoadUnaligned16(top_ptr + 48); + const __m128i top5 = _mm_cvtepu8_epi16(top_hilo); + const __m128i top6 = _mm_unpackhi_epi8(top_hilo, zero); + const __m128i top7 = _mm_cvtepu8_epi16(top_hihi); + const __m128i top8 = _mm_unpackhi_epi8(top_hihi, zero); + const __m128i scale = _mm_set1_epi16(256); + const __m128i round = _mm_set1_epi16(128); + const uint8_t* weights_base_ptr = kSmoothWeights + 60; + for (int left_offset = 0; left_offset < 64; left_offset += 16) { + const __m128i weights = LoadUnaligned16(weights_base_ptr + left_offset); + const __m128i weights_lo = _mm_cvtepu8_epi16(weights); + const __m128i weights_hi = _mm_unpackhi_epi8(weights, zero); + const __m128i inverted_weights_lo = _mm_sub_epi16(scale, weights_lo); + const __m128i inverted_weights_hi = _mm_sub_epi16(scale, weights_hi); + const __m128i scaled_bottom_left_lo = + _mm_mullo_epi16(inverted_weights_lo, bottom_left); + const __m128i scaled_bottom_left_hi = + _mm_mullo_epi16(inverted_weights_hi, bottom_left); + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_lo, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_lo, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + dst += stride; + } + for (int y_mask = 0x01000100; y_mask < 0x0F0E0F0F; y_mask += 0x02020202) { + const __m128i y_select = _mm_set1_epi32(y_mask); + const __m128i weights_y = _mm_shuffle_epi8(weights_hi, y_select); + const __m128i scaled_bottom_left_y = + _mm_shuffle_epi8(scaled_bottom_left_hi, y_select); + WriteSmoothDirectionalSum16(dst, top1, top2, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 16, top3, top4, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 32, top5, top6, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + WriteSmoothDirectionalSum16(dst + 48, top7, top8, weights_y, weights_y, + scaled_bottom_left_y, scaled_bottom_left_y, + round); + dst += stride; + } + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] = + Smooth4x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] = + Smooth4x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] = + Smooth4x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] = + Smooth8x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] = + Smooth8x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] = + Smooth8x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] = + Smooth8x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmooth] = + SmoothWxH<16, 4>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmooth] = + SmoothWxH<16, 8>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmooth] = + SmoothWxH<16, 16>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmooth] = + SmoothWxH<16, 32>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmooth] = + SmoothWxH<16, 64>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmooth] = + SmoothWxH<32, 8>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmooth] = + SmoothWxH<32, 16>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmooth] = + SmoothWxH<32, 32>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmooth] = + SmoothWxH<32, 64>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmooth] = + SmoothWxH<64, 16>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmooth] = + SmoothWxH<64, 32>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorSmooth) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmooth] = + SmoothWxH<64, 64>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] = + SmoothVertical4x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] = + SmoothVertical4x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] = + SmoothVertical4x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] = + SmoothVertical8x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] = + SmoothVertical8x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] = + SmoothVertical8x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] = + SmoothVertical8x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothVertical] = + SmoothVertical16x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothVertical] = + SmoothVertical16x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothVertical] = + SmoothVertical16x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothVertical] = + SmoothVertical16x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothVertical] = + SmoothVertical16x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothVertical] = + SmoothVertical32x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothVertical] = + SmoothVertical32x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothVertical] = + SmoothVertical32x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothVertical] = + SmoothVertical32x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothVertical] = + SmoothVertical64x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothVertical] = + SmoothVertical64x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorSmoothVertical) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothVertical] = + SmoothVertical64x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal4x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal8x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal8x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal8x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal8x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal16x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal32x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal32x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal32x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal32x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal64x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal64x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorSmoothHorizontal) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothHorizontal] = + SmoothHorizontal64x64_SSE4_1; +#endif +} + +} // namespace +} // namespace low_bitdepth + +void IntraPredSmoothInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void IntraPredSmoothInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/intrapred_sse4.cc b/src/dsp/x86/intrapred_sse4.cc new file mode 100644 index 0000000..9938dfe --- /dev/null +++ b/src/dsp/x86/intrapred_sse4.cc @@ -0,0 +1,3535 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/intrapred.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <xmmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> // memcpy + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/dsp/x86/transpose_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace { + +//------------------------------------------------------------------------------ +// Utility Functions + +// This is a fast way to divide by a number of the form 2^n + 2^k, n > k. +// Divide by 2^k by right shifting by k, leaving the denominator 2^m + 1. In the +// block size cases, n - k is 1 or 2 (block is proportional to 1x2 or 1x4), so +// we use a multiplier that reflects division by 2+1=3 or 4+1=5 in the high +// bits. +constexpr int kThreeInverse = 0x5556; +constexpr int kFiveInverse = 0x3334; +template <int shiftk, int multiplier> +inline __m128i DivideByMultiplyShift_U32(const __m128i dividend) { + const __m128i interm = _mm_srli_epi32(dividend, shiftk); + return _mm_mulhi_epi16(interm, _mm_cvtsi32_si128(multiplier)); +} + +// This shuffle mask selects 32-bit blocks in the order 0, 1, 0, 1, which +// duplicates the first 8 bytes of a 128-bit vector into the second 8 bytes. +constexpr int kDuplicateFirstHalf = 0x44; + +//------------------------------------------------------------------------------ +// DcPredFuncs_SSE4_1 + +using DcSumFunc = __m128i (*)(const void* ref); +using DcStoreFunc = void (*)(void* dest, ptrdiff_t stride, const __m128i dc); +using WriteDuplicateFunc = void (*)(void* dest, ptrdiff_t stride, + const __m128i column); +// For copying an entire column across a block. +using ColumnStoreFunc = void (*)(void* dest, ptrdiff_t stride, + const void* column); + +// DC intra-predictors for non-square blocks. +template <int width_log2, int height_log2, DcSumFunc top_sumfn, + DcSumFunc left_sumfn, DcStoreFunc storefn, int shiftk, int dc_mult> +struct DcPredFuncs_SSE4_1 { + DcPredFuncs_SSE4_1() = delete; + + static void DcTop(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void DcLeft(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Dc(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); +}; + +// Directional intra-predictors for square blocks. +template <ColumnStoreFunc col_storefn> +struct DirectionalPredFuncs_SSE4_1 { + DirectionalPredFuncs_SSE4_1() = delete; + + static void Vertical(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); + static void Horizontal(void* dest, ptrdiff_t stride, const void* top_row, + const void* left_column); +}; + +template <int width_log2, int height_log2, DcSumFunc top_sumfn, + DcSumFunc left_sumfn, DcStoreFunc storefn, int shiftk, int dc_mult> +void DcPredFuncs_SSE4_1<width_log2, height_log2, top_sumfn, left_sumfn, storefn, + shiftk, dc_mult>::DcTop(void* const dest, + ptrdiff_t stride, + const void* const top_row, + const void* /*left_column*/) { + const __m128i rounder = _mm_set1_epi32(1 << (width_log2 - 1)); + const __m128i sum = top_sumfn(top_row); + const __m128i dc = _mm_srli_epi32(_mm_add_epi32(sum, rounder), width_log2); + storefn(dest, stride, dc); +} + +template <int width_log2, int height_log2, DcSumFunc top_sumfn, + DcSumFunc left_sumfn, DcStoreFunc storefn, int shiftk, int dc_mult> +void DcPredFuncs_SSE4_1<width_log2, height_log2, top_sumfn, left_sumfn, storefn, + shiftk, + dc_mult>::DcLeft(void* const dest, ptrdiff_t stride, + const void* /*top_row*/, + const void* const left_column) { + const __m128i rounder = _mm_set1_epi32(1 << (height_log2 - 1)); + const __m128i sum = left_sumfn(left_column); + const __m128i dc = _mm_srli_epi32(_mm_add_epi32(sum, rounder), height_log2); + storefn(dest, stride, dc); +} + +template <int width_log2, int height_log2, DcSumFunc top_sumfn, + DcSumFunc left_sumfn, DcStoreFunc storefn, int shiftk, int dc_mult> +void DcPredFuncs_SSE4_1<width_log2, height_log2, top_sumfn, left_sumfn, storefn, + shiftk, dc_mult>::Dc(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i rounder = + _mm_set1_epi32((1 << (width_log2 - 1)) + (1 << (height_log2 - 1))); + const __m128i sum_top = top_sumfn(top_row); + const __m128i sum_left = left_sumfn(left_column); + const __m128i sum = _mm_add_epi32(sum_top, sum_left); + if (width_log2 == height_log2) { + const __m128i dc = + _mm_srli_epi32(_mm_add_epi32(sum, rounder), width_log2 + 1); + storefn(dest, stride, dc); + } else { + const __m128i dc = + DivideByMultiplyShift_U32<shiftk, dc_mult>(_mm_add_epi32(sum, rounder)); + storefn(dest, stride, dc); + } +} + +//------------------------------------------------------------------------------ +// DcPredFuncs_SSE4_1 directional predictors + +template <ColumnStoreFunc col_storefn> +void DirectionalPredFuncs_SSE4_1<col_storefn>::Horizontal( + void* const dest, ptrdiff_t stride, const void* /*top_row*/, + const void* const left_column) { + col_storefn(dest, stride, left_column); +} + +} // namespace + +//------------------------------------------------------------------------------ +namespace low_bitdepth { +namespace { + +// |ref| points to 4 bytes containing 4 packed ints. +inline __m128i DcSum4_SSE4_1(const void* const ref) { + const __m128i vals = Load4(ref); + const __m128i zero = _mm_setzero_si128(); + return _mm_sad_epu8(vals, zero); +} + +inline __m128i DcSum8_SSE4_1(const void* const ref) { + const __m128i vals = LoadLo8(ref); + const __m128i zero = _mm_setzero_si128(); + return _mm_sad_epu8(vals, zero); +} + +inline __m128i DcSum16_SSE4_1(const void* const ref) { + const __m128i zero = _mm_setzero_si128(); + const __m128i vals = LoadUnaligned16(ref); + const __m128i partial_sum = _mm_sad_epu8(vals, zero); + return _mm_add_epi16(partial_sum, _mm_srli_si128(partial_sum, 8)); +} + +inline __m128i DcSum32_SSE4_1(const void* const ref) { + const __m128i zero = _mm_setzero_si128(); + const __m128i vals1 = LoadUnaligned16(ref); + const __m128i vals2 = LoadUnaligned16(static_cast<const uint8_t*>(ref) + 16); + const __m128i partial_sum1 = _mm_sad_epu8(vals1, zero); + const __m128i partial_sum2 = _mm_sad_epu8(vals2, zero); + const __m128i partial_sum = _mm_add_epi16(partial_sum1, partial_sum2); + return _mm_add_epi16(partial_sum, _mm_srli_si128(partial_sum, 8)); +} + +inline __m128i DcSum64_SSE4_1(const void* const ref) { + const auto* const ref_ptr = static_cast<const uint8_t*>(ref); + const __m128i zero = _mm_setzero_si128(); + const __m128i vals1 = LoadUnaligned16(ref_ptr); + const __m128i vals2 = LoadUnaligned16(ref_ptr + 16); + const __m128i vals3 = LoadUnaligned16(ref_ptr + 32); + const __m128i vals4 = LoadUnaligned16(ref_ptr + 48); + const __m128i partial_sum1 = _mm_sad_epu8(vals1, zero); + const __m128i partial_sum2 = _mm_sad_epu8(vals2, zero); + __m128i partial_sum = _mm_add_epi16(partial_sum1, partial_sum2); + const __m128i partial_sum3 = _mm_sad_epu8(vals3, zero); + partial_sum = _mm_add_epi16(partial_sum, partial_sum3); + const __m128i partial_sum4 = _mm_sad_epu8(vals4, zero); + partial_sum = _mm_add_epi16(partial_sum, partial_sum4); + return _mm_add_epi16(partial_sum, _mm_srli_si128(partial_sum, 8)); +} + +template <int height> +inline void DcStore4xH_SSE4_1(void* const dest, ptrdiff_t stride, + const __m128i dc) { + const __m128i zero = _mm_setzero_si128(); + const __m128i dc_dup = _mm_shuffle_epi8(dc, zero); + int y = height - 1; + auto* dst = static_cast<uint8_t*>(dest); + do { + Store4(dst, dc_dup); + dst += stride; + } while (--y != 0); + Store4(dst, dc_dup); +} + +template <int height> +inline void DcStore8xH_SSE4_1(void* const dest, ptrdiff_t stride, + const __m128i dc) { + const __m128i zero = _mm_setzero_si128(); + const __m128i dc_dup = _mm_shuffle_epi8(dc, zero); + int y = height - 1; + auto* dst = static_cast<uint8_t*>(dest); + do { + StoreLo8(dst, dc_dup); + dst += stride; + } while (--y != 0); + StoreLo8(dst, dc_dup); +} + +template <int height> +inline void DcStore16xH_SSE4_1(void* const dest, ptrdiff_t stride, + const __m128i dc) { + const __m128i zero = _mm_setzero_si128(); + const __m128i dc_dup = _mm_shuffle_epi8(dc, zero); + int y = height - 1; + auto* dst = static_cast<uint8_t*>(dest); + do { + StoreUnaligned16(dst, dc_dup); + dst += stride; + } while (--y != 0); + StoreUnaligned16(dst, dc_dup); +} + +template <int height> +inline void DcStore32xH_SSE4_1(void* const dest, ptrdiff_t stride, + const __m128i dc) { + const __m128i zero = _mm_setzero_si128(); + const __m128i dc_dup = _mm_shuffle_epi8(dc, zero); + int y = height - 1; + auto* dst = static_cast<uint8_t*>(dest); + do { + StoreUnaligned16(dst, dc_dup); + StoreUnaligned16(dst + 16, dc_dup); + dst += stride; + } while (--y != 0); + StoreUnaligned16(dst, dc_dup); + StoreUnaligned16(dst + 16, dc_dup); +} + +template <int height> +inline void DcStore64xH_SSE4_1(void* const dest, ptrdiff_t stride, + const __m128i dc) { + const __m128i zero = _mm_setzero_si128(); + const __m128i dc_dup = _mm_shuffle_epi8(dc, zero); + int y = height - 1; + auto* dst = static_cast<uint8_t*>(dest); + do { + StoreUnaligned16(dst, dc_dup); + StoreUnaligned16(dst + 16, dc_dup); + StoreUnaligned16(dst + 32, dc_dup); + StoreUnaligned16(dst + 48, dc_dup); + dst += stride; + } while (--y != 0); + StoreUnaligned16(dst, dc_dup); + StoreUnaligned16(dst + 16, dc_dup); + StoreUnaligned16(dst + 32, dc_dup); + StoreUnaligned16(dst + 48, dc_dup); +} + +// WriteDuplicateN assumes dup has 4 sets of 4 identical bytes that are meant to +// be copied for width N into dest. +inline void WriteDuplicate4x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + auto* dst = static_cast<uint8_t*>(dest); + Store4(dst, dup32); + dst += stride; + const int row1 = _mm_extract_epi32(dup32, 1); + memcpy(dst, &row1, 4); + dst += stride; + const int row2 = _mm_extract_epi32(dup32, 2); + memcpy(dst, &row2, 4); + dst += stride; + const int row3 = _mm_extract_epi32(dup32, 3); + memcpy(dst, &row3, 4); +} + +inline void WriteDuplicate8x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + auto* dst = static_cast<uint8_t*>(dest); + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst), dup64_lo); + dst += stride; + _mm_storeh_pi(reinterpret_cast<__m64*>(dst), _mm_castsi128_ps(dup64_lo)); + dst += stride; + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst), dup64_hi); + dst += stride; + _mm_storeh_pi(reinterpret_cast<__m64*>(dst), _mm_castsi128_ps(dup64_hi)); +} + +inline void WriteDuplicate16x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + + auto* dst = static_cast<uint8_t*>(dest); + const __m128i dup128_0 = _mm_unpacklo_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_0); + dst += stride; + const __m128i dup128_1 = _mm_unpackhi_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_1); + dst += stride; + const __m128i dup128_2 = _mm_unpacklo_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_2); + dst += stride; + const __m128i dup128_3 = _mm_unpackhi_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_3); +} + +inline void WriteDuplicate32x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + + auto* dst = static_cast<uint8_t*>(dest); + const __m128i dup128_0 = _mm_unpacklo_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_0); + dst += stride; + const __m128i dup128_1 = _mm_unpackhi_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_1); + dst += stride; + const __m128i dup128_2 = _mm_unpacklo_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_2); + dst += stride; + const __m128i dup128_3 = _mm_unpackhi_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_3); +} + +inline void WriteDuplicate64x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + + auto* dst = static_cast<uint8_t*>(dest); + const __m128i dup128_0 = _mm_unpacklo_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_0); + dst += stride; + const __m128i dup128_1 = _mm_unpackhi_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_1); + dst += stride; + const __m128i dup128_2 = _mm_unpacklo_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_2); + dst += stride; + const __m128i dup128_3 = _mm_unpackhi_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_3); +} + +// ColStoreN<height> copies each of the |height| values in |column| across its +// corresponding in dest. +template <WriteDuplicateFunc writefn> +inline void ColStore4_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const __m128i col_data = Load4(column); + const __m128i col_dup16 = _mm_unpacklo_epi8(col_data, col_data); + const __m128i col_dup32 = _mm_unpacklo_epi16(col_dup16, col_dup16); + writefn(dest, stride, col_dup32); +} + +template <WriteDuplicateFunc writefn> +inline void ColStore8_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const ptrdiff_t stride4 = stride << 2; + const __m128i col_data = LoadLo8(column); + const __m128i col_dup16 = _mm_unpacklo_epi8(col_data, col_data); + const __m128i col_dup32_lo = _mm_unpacklo_epi16(col_dup16, col_dup16); + auto* dst = static_cast<uint8_t*>(dest); + writefn(dst, stride, col_dup32_lo); + dst += stride4; + const __m128i col_dup32_hi = _mm_unpackhi_epi16(col_dup16, col_dup16); + writefn(dst, stride, col_dup32_hi); +} + +template <WriteDuplicateFunc writefn> +inline void ColStore16_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const ptrdiff_t stride4 = stride << 2; + const __m128i col_data = _mm_loadu_si128(static_cast<const __m128i*>(column)); + const __m128i col_dup16_lo = _mm_unpacklo_epi8(col_data, col_data); + const __m128i col_dup16_hi = _mm_unpackhi_epi8(col_data, col_data); + const __m128i col_dup32_lolo = _mm_unpacklo_epi16(col_dup16_lo, col_dup16_lo); + auto* dst = static_cast<uint8_t*>(dest); + writefn(dst, stride, col_dup32_lolo); + dst += stride4; + const __m128i col_dup32_lohi = _mm_unpackhi_epi16(col_dup16_lo, col_dup16_lo); + writefn(dst, stride, col_dup32_lohi); + dst += stride4; + const __m128i col_dup32_hilo = _mm_unpacklo_epi16(col_dup16_hi, col_dup16_hi); + writefn(dst, stride, col_dup32_hilo); + dst += stride4; + const __m128i col_dup32_hihi = _mm_unpackhi_epi16(col_dup16_hi, col_dup16_hi); + writefn(dst, stride, col_dup32_hihi); +} + +template <WriteDuplicateFunc writefn> +inline void ColStore32_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const ptrdiff_t stride4 = stride << 2; + auto* dst = static_cast<uint8_t*>(dest); + for (int y = 0; y < 32; y += 16) { + const __m128i col_data = + LoadUnaligned16(static_cast<const uint8_t*>(column) + y); + const __m128i col_dup16_lo = _mm_unpacklo_epi8(col_data, col_data); + const __m128i col_dup16_hi = _mm_unpackhi_epi8(col_data, col_data); + const __m128i col_dup32_lolo = + _mm_unpacklo_epi16(col_dup16_lo, col_dup16_lo); + writefn(dst, stride, col_dup32_lolo); + dst += stride4; + const __m128i col_dup32_lohi = + _mm_unpackhi_epi16(col_dup16_lo, col_dup16_lo); + writefn(dst, stride, col_dup32_lohi); + dst += stride4; + const __m128i col_dup32_hilo = + _mm_unpacklo_epi16(col_dup16_hi, col_dup16_hi); + writefn(dst, stride, col_dup32_hilo); + dst += stride4; + const __m128i col_dup32_hihi = + _mm_unpackhi_epi16(col_dup16_hi, col_dup16_hi); + writefn(dst, stride, col_dup32_hihi); + dst += stride4; + } +} + +template <WriteDuplicateFunc writefn> +inline void ColStore64_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const ptrdiff_t stride4 = stride << 2; + auto* dst = static_cast<uint8_t*>(dest); + for (int y = 0; y < 64; y += 16) { + const __m128i col_data = + LoadUnaligned16(static_cast<const uint8_t*>(column) + y); + const __m128i col_dup16_lo = _mm_unpacklo_epi8(col_data, col_data); + const __m128i col_dup16_hi = _mm_unpackhi_epi8(col_data, col_data); + const __m128i col_dup32_lolo = + _mm_unpacklo_epi16(col_dup16_lo, col_dup16_lo); + writefn(dst, stride, col_dup32_lolo); + dst += stride4; + const __m128i col_dup32_lohi = + _mm_unpackhi_epi16(col_dup16_lo, col_dup16_lo); + writefn(dst, stride, col_dup32_lohi); + dst += stride4; + const __m128i col_dup32_hilo = + _mm_unpacklo_epi16(col_dup16_hi, col_dup16_hi); + writefn(dst, stride, col_dup32_hilo); + dst += stride4; + const __m128i col_dup32_hihi = + _mm_unpackhi_epi16(col_dup16_hi, col_dup16_hi); + writefn(dst, stride, col_dup32_hihi); + dst += stride4; + } +} + +struct DcDefs { + DcDefs() = delete; + + using _4x4 = DcPredFuncs_SSE4_1<2, 2, DcSum4_SSE4_1, DcSum4_SSE4_1, + DcStore4xH_SSE4_1<4>, 0, 0>; + // shiftk is the smaller of width_log2 and height_log2. + // dc_mult corresponds to the ratio of the smaller block size to the larger. + using _4x8 = DcPredFuncs_SSE4_1<2, 3, DcSum4_SSE4_1, DcSum8_SSE4_1, + DcStore4xH_SSE4_1<8>, 2, kThreeInverse>; + using _4x16 = DcPredFuncs_SSE4_1<2, 4, DcSum4_SSE4_1, DcSum16_SSE4_1, + DcStore4xH_SSE4_1<16>, 2, kFiveInverse>; + + using _8x4 = DcPredFuncs_SSE4_1<3, 2, DcSum8_SSE4_1, DcSum4_SSE4_1, + DcStore8xH_SSE4_1<4>, 2, kThreeInverse>; + using _8x8 = DcPredFuncs_SSE4_1<3, 3, DcSum8_SSE4_1, DcSum8_SSE4_1, + DcStore8xH_SSE4_1<8>, 0, 0>; + using _8x16 = DcPredFuncs_SSE4_1<3, 4, DcSum8_SSE4_1, DcSum16_SSE4_1, + DcStore8xH_SSE4_1<16>, 3, kThreeInverse>; + using _8x32 = DcPredFuncs_SSE4_1<3, 5, DcSum8_SSE4_1, DcSum32_SSE4_1, + DcStore8xH_SSE4_1<32>, 3, kFiveInverse>; + + using _16x4 = DcPredFuncs_SSE4_1<4, 2, DcSum16_SSE4_1, DcSum4_SSE4_1, + DcStore16xH_SSE4_1<4>, 2, kFiveInverse>; + using _16x8 = DcPredFuncs_SSE4_1<4, 3, DcSum16_SSE4_1, DcSum8_SSE4_1, + DcStore16xH_SSE4_1<8>, 3, kThreeInverse>; + using _16x16 = DcPredFuncs_SSE4_1<4, 4, DcSum16_SSE4_1, DcSum16_SSE4_1, + DcStore16xH_SSE4_1<16>, 0, 0>; + using _16x32 = DcPredFuncs_SSE4_1<4, 5, DcSum16_SSE4_1, DcSum32_SSE4_1, + DcStore16xH_SSE4_1<32>, 4, kThreeInverse>; + using _16x64 = DcPredFuncs_SSE4_1<4, 6, DcSum16_SSE4_1, DcSum64_SSE4_1, + DcStore16xH_SSE4_1<64>, 4, kFiveInverse>; + + using _32x8 = DcPredFuncs_SSE4_1<5, 3, DcSum32_SSE4_1, DcSum8_SSE4_1, + DcStore32xH_SSE4_1<8>, 3, kFiveInverse>; + using _32x16 = DcPredFuncs_SSE4_1<5, 4, DcSum32_SSE4_1, DcSum16_SSE4_1, + DcStore32xH_SSE4_1<16>, 4, kThreeInverse>; + using _32x32 = DcPredFuncs_SSE4_1<5, 5, DcSum32_SSE4_1, DcSum32_SSE4_1, + DcStore32xH_SSE4_1<32>, 0, 0>; + using _32x64 = DcPredFuncs_SSE4_1<5, 6, DcSum32_SSE4_1, DcSum64_SSE4_1, + DcStore32xH_SSE4_1<64>, 5, kThreeInverse>; + + using _64x16 = DcPredFuncs_SSE4_1<6, 4, DcSum64_SSE4_1, DcSum16_SSE4_1, + DcStore64xH_SSE4_1<16>, 4, kFiveInverse>; + using _64x32 = DcPredFuncs_SSE4_1<6, 5, DcSum64_SSE4_1, DcSum32_SSE4_1, + DcStore64xH_SSE4_1<32>, 5, kThreeInverse>; + using _64x64 = DcPredFuncs_SSE4_1<6, 6, DcSum64_SSE4_1, DcSum64_SSE4_1, + DcStore64xH_SSE4_1<64>, 0, 0>; +}; + +struct DirDefs { + DirDefs() = delete; + + using _4x4 = DirectionalPredFuncs_SSE4_1<ColStore4_SSE4_1<WriteDuplicate4x4>>; + using _4x8 = DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate4x4>>; + using _4x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate4x4>>; + using _8x4 = DirectionalPredFuncs_SSE4_1<ColStore4_SSE4_1<WriteDuplicate8x4>>; + using _8x8 = DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate8x4>>; + using _8x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate8x4>>; + using _8x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate8x4>>; + using _16x4 = + DirectionalPredFuncs_SSE4_1<ColStore4_SSE4_1<WriteDuplicate16x4>>; + using _16x8 = + DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate16x4>>; + using _16x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate16x4>>; + using _16x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate16x4>>; + using _16x64 = + DirectionalPredFuncs_SSE4_1<ColStore64_SSE4_1<WriteDuplicate16x4>>; + using _32x8 = + DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate32x4>>; + using _32x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate32x4>>; + using _32x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate32x4>>; + using _32x64 = + DirectionalPredFuncs_SSE4_1<ColStore64_SSE4_1<WriteDuplicate32x4>>; + using _64x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate64x4>>; + using _64x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate64x4>>; + using _64x64 = + DirectionalPredFuncs_SSE4_1<ColStore64_SSE4_1<WriteDuplicate64x4>>; +}; + +template <int y_mask> +inline void WritePaethLine4(uint8_t* dst, const __m128i& top, + const __m128i& left, const __m128i& top_lefts, + const __m128i& top_dists, const __m128i& left_dists, + const __m128i& top_left_diffs) { + const __m128i top_dists_y = _mm_shuffle_epi32(top_dists, y_mask); + + const __m128i lefts_y = _mm_shuffle_epi32(left, y_mask); + const __m128i top_left_dists = + _mm_abs_epi32(_mm_add_epi32(lefts_y, top_left_diffs)); + + // Section 7.11.2.2 specifies the logic and terms here. The less-or-equal + // operation is unavailable, so the logic for selecting top, left, or + // top_left is inverted. + __m128i not_select_left = _mm_cmpgt_epi32(left_dists, top_left_dists); + not_select_left = + _mm_or_si128(not_select_left, _mm_cmpgt_epi32(left_dists, top_dists_y)); + const __m128i not_select_top = _mm_cmpgt_epi32(top_dists_y, top_left_dists); + + const __m128i left_out = _mm_andnot_si128(not_select_left, lefts_y); + + const __m128i top_left_out = _mm_and_si128(not_select_top, top_lefts); + __m128i top_or_top_left_out = _mm_andnot_si128(not_select_top, top); + top_or_top_left_out = _mm_or_si128(top_or_top_left_out, top_left_out); + top_or_top_left_out = _mm_and_si128(not_select_left, top_or_top_left_out); + + // The sequence of 32-bit packed operations was found (see CL via blame) to + // outperform 16-bit operations, despite the availability of the packus + // function, when tested on a Xeon E7 v3. + const __m128i cvtepi32_epi8 = _mm_set1_epi32(0x0C080400); + const __m128i pred = _mm_shuffle_epi8( + _mm_or_si128(left_out, top_or_top_left_out), cvtepi32_epi8); + Store4(dst, pred); +} + +// top_left_diffs is the only variable whose ints may exceed 8 bits. Otherwise +// we would be able to do all of these operations as epi8 for a 16-pixel version +// of this function. Still, since lefts_y is just a vector of duplicates, it +// could pay off to accommodate top_left_dists for cmpgt, and repack into epi8 +// for the blends. +template <int y_mask> +inline void WritePaethLine8(uint8_t* dst, const __m128i& top, + const __m128i& left, const __m128i& top_lefts, + const __m128i& top_dists, const __m128i& left_dists, + const __m128i& top_left_diffs) { + const __m128i select_y = _mm_set1_epi32(y_mask); + const __m128i top_dists_y = _mm_shuffle_epi8(top_dists, select_y); + + const __m128i lefts_y = _mm_shuffle_epi8(left, select_y); + const __m128i top_left_dists = + _mm_abs_epi16(_mm_add_epi16(lefts_y, top_left_diffs)); + + // Section 7.11.2.2 specifies the logic and terms here. The less-or-equal + // operation is unavailable, so the logic for selecting top, left, or + // top_left is inverted. + __m128i not_select_left = _mm_cmpgt_epi16(left_dists, top_left_dists); + not_select_left = + _mm_or_si128(not_select_left, _mm_cmpgt_epi16(left_dists, top_dists_y)); + const __m128i not_select_top = _mm_cmpgt_epi16(top_dists_y, top_left_dists); + + const __m128i left_out = _mm_andnot_si128(not_select_left, lefts_y); + + const __m128i top_left_out = _mm_and_si128(not_select_top, top_lefts); + __m128i top_or_top_left_out = _mm_andnot_si128(not_select_top, top); + top_or_top_left_out = _mm_or_si128(top_or_top_left_out, top_left_out); + top_or_top_left_out = _mm_and_si128(not_select_left, top_or_top_left_out); + + const __m128i pred = _mm_packus_epi16( + _mm_or_si128(left_out, top_or_top_left_out), /* unused */ left_out); + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst), pred); +} + +// |top| is an epi8 of length 16 +// |left| is epi8 of unknown length, as y_mask specifies access +// |top_lefts| is an epi8 of 16 duplicates +// |top_dists| is an epi8 of unknown length, as y_mask specifies access +// |left_dists| is an epi8 of length 16 +// |left_dists_lo| is an epi16 of length 8 +// |left_dists_hi| is an epi16 of length 8 +// |top_left_diffs_lo| is an epi16 of length 8 +// |top_left_diffs_hi| is an epi16 of length 8 +// The latter two vectors are epi16 because their values may reach -510. +// |left_dists| is provided alongside its spread out version because it doesn't +// change between calls and interacts with both kinds of packing. +template <int y_mask> +inline void WritePaethLine16(uint8_t* dst, const __m128i& top, + const __m128i& left, const __m128i& top_lefts, + const __m128i& top_dists, + const __m128i& left_dists, + const __m128i& left_dists_lo, + const __m128i& left_dists_hi, + const __m128i& top_left_diffs_lo, + const __m128i& top_left_diffs_hi) { + const __m128i select_y = _mm_set1_epi32(y_mask); + const __m128i top_dists_y8 = _mm_shuffle_epi8(top_dists, select_y); + const __m128i top_dists_y16 = _mm_cvtepu8_epi16(top_dists_y8); + const __m128i lefts_y8 = _mm_shuffle_epi8(left, select_y); + const __m128i lefts_y16 = _mm_cvtepu8_epi16(lefts_y8); + + const __m128i top_left_dists_lo = + _mm_abs_epi16(_mm_add_epi16(lefts_y16, top_left_diffs_lo)); + const __m128i top_left_dists_hi = + _mm_abs_epi16(_mm_add_epi16(lefts_y16, top_left_diffs_hi)); + + const __m128i left_gt_top_left_lo = _mm_packs_epi16( + _mm_cmpgt_epi16(left_dists_lo, top_left_dists_lo), left_dists_lo); + const __m128i left_gt_top_left_hi = + _mm_packs_epi16(_mm_cmpgt_epi16(left_dists_hi, top_left_dists_hi), + /* unused second arg for pack */ left_dists_hi); + const __m128i left_gt_top_left = _mm_alignr_epi8( + left_gt_top_left_hi, _mm_slli_si128(left_gt_top_left_lo, 8), 8); + + const __m128i not_select_top_lo = + _mm_packs_epi16(_mm_cmpgt_epi16(top_dists_y16, top_left_dists_lo), + /* unused second arg for pack */ top_dists_y16); + const __m128i not_select_top_hi = + _mm_packs_epi16(_mm_cmpgt_epi16(top_dists_y16, top_left_dists_hi), + /* unused second arg for pack */ top_dists_y16); + const __m128i not_select_top = _mm_alignr_epi8( + not_select_top_hi, _mm_slli_si128(not_select_top_lo, 8), 8); + + const __m128i left_leq_top = + _mm_cmpeq_epi8(left_dists, _mm_min_epu8(top_dists_y8, left_dists)); + const __m128i select_left = _mm_andnot_si128(left_gt_top_left, left_leq_top); + + // Section 7.11.2.2 specifies the logic and terms here. The less-or-equal + // operation is unavailable, so the logic for selecting top, left, or + // top_left is inverted. + const __m128i left_out = _mm_and_si128(select_left, lefts_y8); + + const __m128i top_left_out = _mm_and_si128(not_select_top, top_lefts); + __m128i top_or_top_left_out = _mm_andnot_si128(not_select_top, top); + top_or_top_left_out = _mm_or_si128(top_or_top_left_out, top_left_out); + top_or_top_left_out = _mm_andnot_si128(select_left, top_or_top_left_out); + const __m128i pred = _mm_or_si128(left_out, top_or_top_left_out); + + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), pred); +} + +void Paeth4x4_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, const void* const left_column) { + const __m128i left = _mm_cvtepu8_epi32(Load4(left_column)); + const __m128i top = _mm_cvtepu8_epi32(Load4(top_row)); + + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_lefts = _mm_set1_epi32(top_ptr[-1]); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + const __m128i left_dists = _mm_abs_epi32(_mm_sub_epi32(top, top_lefts)); + const __m128i top_dists = _mm_abs_epi32(_mm_sub_epi32(left, top_lefts)); + + const __m128i top_left_x2 = _mm_add_epi32(top_lefts, top_lefts); + const __m128i top_left_diff = _mm_sub_epi32(top, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine4<0>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0x55>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xAA>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xFF>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); +} + +void Paeth4x8_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, const void* const left_column) { + const __m128i left = LoadLo8(left_column); + const __m128i left_lo = _mm_cvtepu8_epi32(left); + const __m128i left_hi = _mm_cvtepu8_epi32(_mm_srli_si128(left, 4)); + + const __m128i top = _mm_cvtepu8_epi32(Load4(top_row)); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_lefts = _mm_set1_epi32(top_ptr[-1]); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + const __m128i left_dists = _mm_abs_epi32(_mm_sub_epi32(top, top_lefts)); + const __m128i top_dists_lo = _mm_abs_epi32(_mm_sub_epi32(left_lo, top_lefts)); + const __m128i top_dists_hi = _mm_abs_epi32(_mm_sub_epi32(left_hi, top_lefts)); + + const __m128i top_left_x2 = _mm_add_epi32(top_lefts, top_lefts); + const __m128i top_left_diff = _mm_sub_epi32(top, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine4<0>(dst, top, left_lo, top_lefts, top_dists_lo, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0x55>(dst, top, left_lo, top_lefts, top_dists_lo, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xAA>(dst, top, left_lo, top_lefts, top_dists_lo, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xFF>(dst, top, left_lo, top_lefts, top_dists_lo, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0>(dst, top, left_hi, top_lefts, top_dists_hi, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0x55>(dst, top, left_hi, top_lefts, top_dists_hi, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xAA>(dst, top, left_hi, top_lefts, top_dists_hi, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xFF>(dst, top, left_hi, top_lefts, top_dists_hi, left_dists, + top_left_diff); +} + +void Paeth4x16_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left = LoadUnaligned16(left_column); + const __m128i left_0 = _mm_cvtepu8_epi32(left); + const __m128i left_1 = _mm_cvtepu8_epi32(_mm_srli_si128(left, 4)); + const __m128i left_2 = _mm_cvtepu8_epi32(_mm_srli_si128(left, 8)); + const __m128i left_3 = _mm_cvtepu8_epi32(_mm_srli_si128(left, 12)); + + const __m128i top = _mm_cvtepu8_epi32(Load4(top_row)); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_lefts = _mm_set1_epi32(top_ptr[-1]); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + const __m128i left_dists = _mm_abs_epi32(_mm_sub_epi32(top, top_lefts)); + const __m128i top_dists_0 = _mm_abs_epi32(_mm_sub_epi32(left_0, top_lefts)); + const __m128i top_dists_1 = _mm_abs_epi32(_mm_sub_epi32(left_1, top_lefts)); + const __m128i top_dists_2 = _mm_abs_epi32(_mm_sub_epi32(left_2, top_lefts)); + const __m128i top_dists_3 = _mm_abs_epi32(_mm_sub_epi32(left_3, top_lefts)); + + const __m128i top_left_x2 = _mm_add_epi32(top_lefts, top_lefts); + const __m128i top_left_diff = _mm_sub_epi32(top, top_left_x2); + + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine4<0>(dst, top, left_0, top_lefts, top_dists_0, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0x55>(dst, top, left_0, top_lefts, top_dists_0, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xAA>(dst, top, left_0, top_lefts, top_dists_0, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xFF>(dst, top, left_0, top_lefts, top_dists_0, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0>(dst, top, left_1, top_lefts, top_dists_1, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0x55>(dst, top, left_1, top_lefts, top_dists_1, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xAA>(dst, top, left_1, top_lefts, top_dists_1, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xFF>(dst, top, left_1, top_lefts, top_dists_1, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0>(dst, top, left_2, top_lefts, top_dists_2, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0x55>(dst, top, left_2, top_lefts, top_dists_2, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xAA>(dst, top, left_2, top_lefts, top_dists_2, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xFF>(dst, top, left_2, top_lefts, top_dists_2, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0>(dst, top, left_3, top_lefts, top_dists_3, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0x55>(dst, top, left_3, top_lefts, top_dists_3, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xAA>(dst, top, left_3, top_lefts, top_dists_3, left_dists, + top_left_diff); + dst += stride; + WritePaethLine4<0xFF>(dst, top, left_3, top_lefts, top_dists_3, left_dists, + top_left_diff); +} + +void Paeth8x4_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, const void* const left_column) { + const __m128i left = _mm_cvtepu8_epi16(Load4(left_column)); + const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row)); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_lefts = _mm_set1_epi16(top_ptr[-1]); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + const __m128i left_dists = _mm_abs_epi16(_mm_sub_epi16(top, top_lefts)); + const __m128i top_dists = _mm_abs_epi16(_mm_sub_epi16(left, top_lefts)); + + const __m128i top_left_x2 = _mm_add_epi16(top_lefts, top_lefts); + const __m128i top_left_diff = _mm_sub_epi16(top, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine8<0x01000100>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x03020302>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x05040504>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x07060706>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); +} + +void Paeth8x8_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, const void* const left_column) { + const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column)); + const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row)); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_lefts = _mm_set1_epi16(top_ptr[-1]); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + const __m128i left_dists = _mm_abs_epi16(_mm_sub_epi16(top, top_lefts)); + const __m128i top_dists = _mm_abs_epi16(_mm_sub_epi16(left, top_lefts)); + + const __m128i top_left_x2 = _mm_add_epi16(top_lefts, top_lefts); + const __m128i top_left_diff = _mm_sub_epi16(top, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine8<0x01000100>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x03020302>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x05040504>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x07060706>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x09080908>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x0B0A0B0A>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x0D0C0D0C>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); + dst += stride; + WritePaethLine8<0x0F0E0F0E>(dst, top, left, top_lefts, top_dists, left_dists, + top_left_diff); +} + +void Paeth8x16_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left = LoadUnaligned16(left_column); + const __m128i left_lo = _mm_cvtepu8_epi16(left); + const __m128i left_hi = _mm_cvtepu8_epi16(_mm_srli_si128(left, 8)); + const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row)); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_lefts = _mm_set1_epi16(top_ptr[-1]); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + const __m128i left_dists = _mm_abs_epi16(_mm_sub_epi16(top, top_lefts)); + const __m128i top_dists_lo = _mm_abs_epi16(_mm_sub_epi16(left_lo, top_lefts)); + const __m128i top_dists_hi = _mm_abs_epi16(_mm_sub_epi16(left_hi, top_lefts)); + + const __m128i top_left_x2 = _mm_add_epi16(top_lefts, top_lefts); + const __m128i top_left_diff = _mm_sub_epi16(top, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine8<0x01000100>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x03020302>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x05040504>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x07060706>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x09080908>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x0B0A0B0A>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x0D0C0D0C>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x0F0E0F0E>(dst, top, left_lo, top_lefts, top_dists_lo, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x01000100>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x03020302>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x05040504>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x07060706>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x09080908>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x0B0A0B0A>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x0D0C0D0C>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); + dst += stride; + WritePaethLine8<0x0F0E0F0E>(dst, top, left_hi, top_lefts, top_dists_hi, + left_dists, top_left_diff); +} + +void Paeth8x32_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* const dst = static_cast<uint8_t*>(dest); + Paeth8x16_SSE4_1(dst, stride, top_row, left_column); + Paeth8x16_SSE4_1(dst + (stride << 4), stride, top_row, left_ptr + 16); +} + +void Paeth16x4_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left = Load4(left_column); + const __m128i top = LoadUnaligned16(top_row); + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_cvtepu8_epi16(_mm_srli_si128(top, 8)); + + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_lefts16 = _mm_set1_epi16(top_ptr[-1]); + const __m128i top_lefts8 = _mm_set1_epi8(static_cast<int8_t>(top_ptr[-1])); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + + const __m128i left_dists = _mm_or_si128(_mm_subs_epu8(top, top_lefts8), + _mm_subs_epu8(top_lefts8, top)); + const __m128i left_dists_lo = _mm_cvtepu8_epi16(left_dists); + const __m128i left_dists_hi = + _mm_cvtepu8_epi16(_mm_srli_si128(left_dists, 8)); + const __m128i top_dists = _mm_or_si128(_mm_subs_epu8(left, top_lefts8), + _mm_subs_epu8(top_lefts8, left)); + + const __m128i top_left_x2 = _mm_add_epi16(top_lefts16, top_lefts16); + const __m128i top_left_diff_lo = _mm_sub_epi16(top_lo, top_left_x2); + const __m128i top_left_diff_hi = _mm_sub_epi16(top_hi, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine16<0>(dst, top, left, top_lefts8, top_dists, left_dists, + left_dists_lo, left_dists_hi, top_left_diff_lo, + top_left_diff_hi); + dst += stride; + WritePaethLine16<0x01010101>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x02020202>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x03030303>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); +} + +// Inlined for calling with offsets in larger transform sizes, mainly to +// preserve top_left. +inline void WritePaeth16x8(void* const dest, ptrdiff_t stride, + const uint8_t top_left, const __m128i top, + const __m128i left) { + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_cvtepu8_epi16(_mm_srli_si128(top, 8)); + + const __m128i top_lefts16 = _mm_set1_epi16(top_left); + const __m128i top_lefts8 = _mm_set1_epi8(static_cast<int8_t>(top_left)); + + // Given that the spec defines "base" as top[x] + left[y] - top_left, + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + + const __m128i left_dists = _mm_or_si128(_mm_subs_epu8(top, top_lefts8), + _mm_subs_epu8(top_lefts8, top)); + const __m128i left_dists_lo = _mm_cvtepu8_epi16(left_dists); + const __m128i left_dists_hi = + _mm_cvtepu8_epi16(_mm_srli_si128(left_dists, 8)); + const __m128i top_dists = _mm_or_si128(_mm_subs_epu8(left, top_lefts8), + _mm_subs_epu8(top_lefts8, left)); + + const __m128i top_left_x2 = _mm_add_epi16(top_lefts16, top_lefts16); + const __m128i top_left_diff_lo = _mm_sub_epi16(top_lo, top_left_x2); + const __m128i top_left_diff_hi = _mm_sub_epi16(top_hi, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine16<0>(dst, top, left, top_lefts8, top_dists, left_dists, + left_dists_lo, left_dists_hi, top_left_diff_lo, + top_left_diff_hi); + dst += stride; + WritePaethLine16<0x01010101>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x02020202>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x03030303>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x04040404>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x05050505>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x06060606>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x07070707>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); +} + +void Paeth16x8_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i top = LoadUnaligned16(top_row); + const __m128i left = LoadLo8(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + WritePaeth16x8(static_cast<uint8_t*>(dest), stride, top_ptr[-1], top, left); +} + +void WritePaeth16x16(void* const dest, ptrdiff_t stride, const uint8_t top_left, + const __m128i top, const __m128i left) { + const __m128i top_lo = _mm_cvtepu8_epi16(top); + const __m128i top_hi = _mm_cvtepu8_epi16(_mm_srli_si128(top, 8)); + + const __m128i top_lefts16 = _mm_set1_epi16(top_left); + const __m128i top_lefts8 = _mm_set1_epi8(static_cast<int8_t>(top_left)); + + // Given that the spec defines "base" as top[x] + left[y] - top[-1], + // pLeft = abs(base - left[y]) = abs(top[x] - top[-1]) + // pTop = abs(base - top[x]) = abs(left[y] - top[-1]) + + const __m128i left_dists = _mm_or_si128(_mm_subs_epu8(top, top_lefts8), + _mm_subs_epu8(top_lefts8, top)); + const __m128i left_dists_lo = _mm_cvtepu8_epi16(left_dists); + const __m128i left_dists_hi = + _mm_cvtepu8_epi16(_mm_srli_si128(left_dists, 8)); + const __m128i top_dists = _mm_or_si128(_mm_subs_epu8(left, top_lefts8), + _mm_subs_epu8(top_lefts8, left)); + + const __m128i top_left_x2 = _mm_add_epi16(top_lefts16, top_lefts16); + const __m128i top_left_diff_lo = _mm_sub_epi16(top_lo, top_left_x2); + const __m128i top_left_diff_hi = _mm_sub_epi16(top_hi, top_left_x2); + auto* dst = static_cast<uint8_t*>(dest); + WritePaethLine16<0>(dst, top, left, top_lefts8, top_dists, left_dists, + left_dists_lo, left_dists_hi, top_left_diff_lo, + top_left_diff_hi); + dst += stride; + WritePaethLine16<0x01010101>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x02020202>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x03030303>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x04040404>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x05050505>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x06060606>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x07070707>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x08080808>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x09090909>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x0A0A0A0A>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x0B0B0B0B>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x0C0C0C0C>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x0D0D0D0D>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x0E0E0E0E>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); + dst += stride; + WritePaethLine16<0x0F0F0F0F>(dst, top, left, top_lefts8, top_dists, + left_dists, left_dists_lo, left_dists_hi, + top_left_diff_lo, top_left_diff_hi); +} + +void Paeth16x16_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left = LoadUnaligned16(left_column); + const __m128i top = LoadUnaligned16(top_row); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + WritePaeth16x16(static_cast<uint8_t*>(dest), stride, top_ptr[-1], top, left); +} + +void Paeth16x32_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left_0 = LoadUnaligned16(left_column); + const __m128i top = LoadUnaligned16(top_row); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const uint8_t top_left = top_ptr[-1]; + auto* const dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top, left_0); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left_1 = LoadUnaligned16(left_ptr + 16); + WritePaeth16x16(dst + (stride << 4), stride, top_left, top, left_1); +} + +void Paeth16x64_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const ptrdiff_t stride16 = stride << 4; + const __m128i left_0 = LoadUnaligned16(left_column); + const __m128i top = LoadUnaligned16(top_row); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const uint8_t top_left = top_ptr[-1]; + auto* dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top, left_0); + dst += stride16; + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left_1 = LoadUnaligned16(left_ptr + 16); + WritePaeth16x16(dst, stride, top_left, top, left_1); + dst += stride16; + const __m128i left_2 = LoadUnaligned16(left_ptr + 32); + WritePaeth16x16(dst, stride, top_left, top, left_2); + dst += stride16; + const __m128i left_3 = LoadUnaligned16(left_ptr + 48); + WritePaeth16x16(dst, stride, top_left, top, left_3); +} + +void Paeth32x8_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left = LoadLo8(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_0 = LoadUnaligned16(top_row); + const uint8_t top_left = top_ptr[-1]; + auto* const dst = static_cast<uint8_t*>(dest); + WritePaeth16x8(dst, stride, top_left, top_0, left); + const __m128i top_1 = LoadUnaligned16(top_ptr + 16); + WritePaeth16x8(dst + 16, stride, top_left, top_1, left); +} + +void Paeth32x16_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left = LoadUnaligned16(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_0 = LoadUnaligned16(top_row); + const uint8_t top_left = top_ptr[-1]; + auto* const dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top_0, left); + const __m128i top_1 = LoadUnaligned16(top_ptr + 16); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left); +} + +void Paeth32x32_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left_0 = LoadUnaligned16(left_ptr); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_0 = LoadUnaligned16(top_ptr); + const __m128i left_1 = LoadUnaligned16(left_ptr + 16); + const __m128i top_1 = LoadUnaligned16(top_ptr + 16); + const uint8_t top_left = top_ptr[-1]; + auto* dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top_0, left_0); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_0); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_1); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_1); +} + +void Paeth32x64_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left_0 = LoadUnaligned16(left_ptr); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_0 = LoadUnaligned16(top_ptr); + const __m128i left_1 = LoadUnaligned16(left_ptr + 16); + const __m128i top_1 = LoadUnaligned16(top_ptr + 16); + const __m128i left_2 = LoadUnaligned16(left_ptr + 32); + const __m128i left_3 = LoadUnaligned16(left_ptr + 48); + const uint8_t top_left = top_ptr[-1]; + auto* dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top_0, left_0); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_0); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_1); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_1); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_2); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_2); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_3); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_3); +} + +void Paeth64x16_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const __m128i left = LoadUnaligned16(left_column); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_0 = LoadUnaligned16(top_ptr); + const __m128i top_1 = LoadUnaligned16(top_ptr + 16); + const __m128i top_2 = LoadUnaligned16(top_ptr + 32); + const __m128i top_3 = LoadUnaligned16(top_ptr + 48); + const uint8_t top_left = top_ptr[-1]; + auto* dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top_0, left); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left); + WritePaeth16x16(dst + 32, stride, top_left, top_2, left); + WritePaeth16x16(dst + 48, stride, top_left, top_3, left); +} + +void Paeth64x32_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left_0 = LoadUnaligned16(left_ptr); + const __m128i left_1 = LoadUnaligned16(left_ptr + 16); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_0 = LoadUnaligned16(top_ptr); + const __m128i top_1 = LoadUnaligned16(top_ptr + 16); + const __m128i top_2 = LoadUnaligned16(top_ptr + 32); + const __m128i top_3 = LoadUnaligned16(top_ptr + 48); + const uint8_t top_left = top_ptr[-1]; + auto* dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top_0, left_0); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_0); + WritePaeth16x16(dst + 32, stride, top_left, top_2, left_0); + WritePaeth16x16(dst + 48, stride, top_left, top_3, left_0); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_1); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_1); + WritePaeth16x16(dst + 32, stride, top_left, top_2, left_1); + WritePaeth16x16(dst + 48, stride, top_left, top_3, left_1); +} + +void Paeth64x64_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + const __m128i left_0 = LoadUnaligned16(left_ptr); + const __m128i left_1 = LoadUnaligned16(left_ptr + 16); + const __m128i left_2 = LoadUnaligned16(left_ptr + 32); + const __m128i left_3 = LoadUnaligned16(left_ptr + 48); + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const __m128i top_0 = LoadUnaligned16(top_ptr); + const __m128i top_1 = LoadUnaligned16(top_ptr + 16); + const __m128i top_2 = LoadUnaligned16(top_ptr + 32); + const __m128i top_3 = LoadUnaligned16(top_ptr + 48); + const uint8_t top_left = top_ptr[-1]; + auto* dst = static_cast<uint8_t*>(dest); + WritePaeth16x16(dst, stride, top_left, top_0, left_0); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_0); + WritePaeth16x16(dst + 32, stride, top_left, top_2, left_0); + WritePaeth16x16(dst + 48, stride, top_left, top_3, left_0); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_1); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_1); + WritePaeth16x16(dst + 32, stride, top_left, top_2, left_1); + WritePaeth16x16(dst + 48, stride, top_left, top_3, left_1); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_2); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_2); + WritePaeth16x16(dst + 32, stride, top_left, top_2, left_2); + WritePaeth16x16(dst + 48, stride, top_left, top_3, left_2); + dst += (stride << 4); + WritePaeth16x16(dst, stride, top_left, top_0, left_3); + WritePaeth16x16(dst + 16, stride, top_left, top_1, left_3); + WritePaeth16x16(dst + 32, stride, top_left, top_2, left_3); + WritePaeth16x16(dst + 48, stride, top_left, top_3, left_3); +} + +//------------------------------------------------------------------------------ +// 7.11.2.4. Directional intra prediction process + +// Special case: An |xstep| of 64 corresponds to an angle delta of 45, meaning +// upsampling is ruled out. In addition, the bits masked by 0x3F for +// |shift_val| are 0 for all multiples of 64, so the formula +// val = top[top_base_x]*shift + top[top_base_x+1]*(32-shift), reduces to +// val = top[top_base_x+1] << 5, meaning only the second set of pixels is +// involved in the output. Hence |top| is offset by 1. +inline void DirectionalZone1_Step64(uint8_t* dst, ptrdiff_t stride, + const uint8_t* const top, const int width, + const int height) { + ptrdiff_t offset = 1; + if (height == 4) { + memcpy(dst, top + offset, width); + dst += stride; + memcpy(dst, top + offset + 1, width); + dst += stride; + memcpy(dst, top + offset + 2, width); + dst += stride; + memcpy(dst, top + offset + 3, width); + return; + } + int y = 0; + do { + memcpy(dst, top + offset, width); + dst += stride; + memcpy(dst, top + offset + 1, width); + dst += stride; + memcpy(dst, top + offset + 2, width); + dst += stride; + memcpy(dst, top + offset + 3, width); + dst += stride; + memcpy(dst, top + offset + 4, width); + dst += stride; + memcpy(dst, top + offset + 5, width); + dst += stride; + memcpy(dst, top + offset + 6, width); + dst += stride; + memcpy(dst, top + offset + 7, width); + dst += stride; + + offset += 8; + y += 8; + } while (y < height); +} + +inline void DirectionalZone1_4xH(uint8_t* dst, ptrdiff_t stride, + const uint8_t* const top, const int height, + const int xstep, const bool upsampled) { + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const int rounding_bits = 5; + const int max_base_x = (height + 3 /* width - 1 */) << upsample_shift; + const __m128i final_top_val = _mm_set1_epi16(top[max_base_x]); + const __m128i sampler = upsampled ? _mm_set_epi64x(0, 0x0706050403020100) + : _mm_set_epi64x(0, 0x0403030202010100); + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is + // not supported for packed integers. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + // All rows from |min_corner_only_y| down will simply use memcpy. |max_base_x| + // is always greater than |height|, so clipping to 1 is enough to make the + // logic work. + const int xstep_units = std::max(xstep >> scale_bits, 1); + const int min_corner_only_y = std::min(max_base_x / xstep_units, height); + + // Rows up to this y-value can be computed without checking for bounds. + int y = 0; + int top_x = xstep; + + for (; y < min_corner_only_y; ++y, dst += stride, top_x += xstep) { + const int top_base_x = top_x >> scale_bits; + + // Permit negative values of |top_x|. + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i max_shift = _mm_set1_epi8(32); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + + // Load 8 values because we will select the sampled values based on + // |upsampled|. + const __m128i values = LoadLo8(top + top_base_x); + const __m128i sampled_values = _mm_shuffle_epi8(values, sampler); + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + __m128i prod = _mm_maddubs_epi16(sampled_values, shifts); + prod = RightShiftWithRounding_U16(prod, rounding_bits); + // Replace pixels from invalid range with top-right corner. + prod = _mm_blendv_epi8(prod, final_top_val, past_max); + Store4(dst, _mm_packus_epi16(prod, prod)); + } + + // Fill in corner-only rows. + for (; y < height; ++y) { + memset(dst, top[max_base_x], /* width */ 4); + dst += stride; + } +} + +// 7.11.2.4 (7) angle < 90 +inline void DirectionalZone1_Large(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const int width, const int height, + const int xstep, const bool upsampled) { + const int upsample_shift = static_cast<int>(upsampled); + const __m128i sampler = + upsampled ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + const int scale_bits = 6 - upsample_shift; + const int max_base_x = ((width + height) - 1) << upsample_shift; + + const __m128i max_shift = _mm_set1_epi8(32); + const int rounding_bits = 5; + const int base_step = 1 << upsample_shift; + const int base_step8 = base_step << 3; + + // All rows from |min_corner_only_y| down will simply use memcpy. |max_base_x| + // is always greater than |height|, so clipping to 1 is enough to make the + // logic work. + const int xstep_units = std::max(xstep >> scale_bits, 1); + const int min_corner_only_y = std::min(max_base_x / xstep_units, height); + + // Rows up to this y-value can be computed without checking for bounds. + const int max_no_corner_y = std::min( + LeftShift((max_base_x - (base_step * width)), scale_bits) / xstep, + height); + // No need to check for exceeding |max_base_x| in the first loop. + int y = 0; + int top_x = xstep; + for (; y < max_no_corner_y; ++y, dest += stride, top_x += xstep) { + int top_base_x = top_x >> scale_bits; + // Permit negative values of |top_x|. + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + int x = 0; + do { + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + top_base_x += base_step8; + x += 8; + } while (x < width); + } + + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is + // not supported for packed integers. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); + const __m128i base_step8_vect = _mm_set1_epi16(base_step8); + for (; y < min_corner_only_y; ++y, dest += stride, top_x += xstep) { + int top_base_x = top_x >> scale_bits; + + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + + int x = 0; + const int min_corner_only_x = + std::min(width, ((max_base_x - top_base_x) >> upsample_shift) + 7) & ~7; + for (; x < min_corner_only_x; + x += 8, top_base_x += base_step8, + top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + // Assuming a buffer zone of 8 bytes at the end of top_row, this prevents + // reading out of bounds. If all indices are past max and we don't need to + // use the loaded bytes at all, |top_base_x| becomes 0. |top_base_x| will + // reset for the next |y|. + top_base_x &= ~_mm_cvtsi128_si32(past_max); + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + vals = _mm_blendv_epi8(vals, final_top_val, past_max); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + } + // Corner-only section of the row. + memset(dest + x, top_row[max_base_x], width - x); + } + // Fill in corner-only rows. + for (; y < height; ++y) { + memset(dest, top_row[max_base_x], width); + dest += stride; + } +} + +// 7.11.2.4 (7) angle < 90 +inline void DirectionalZone1_SSE4_1(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const int width, const int height, + const int xstep, const bool upsampled) { + const int upsample_shift = static_cast<int>(upsampled); + if (xstep == 64) { + DirectionalZone1_Step64(dest, stride, top_row, width, height); + return; + } + if (width == 4) { + DirectionalZone1_4xH(dest, stride, top_row, height, xstep, upsampled); + return; + } + if (width >= 32) { + DirectionalZone1_Large(dest, stride, top_row, width, height, xstep, + upsampled); + return; + } + const __m128i sampler = + upsampled ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + const int scale_bits = 6 - upsample_shift; + const int max_base_x = ((width + height) - 1) << upsample_shift; + + const __m128i max_shift = _mm_set1_epi8(32); + const int rounding_bits = 5; + const int base_step = 1 << upsample_shift; + const int base_step8 = base_step << 3; + + // No need to check for exceeding |max_base_x| in the loops. + if (((xstep * height) >> scale_bits) + base_step * width < max_base_x) { + int top_x = xstep; + int y = 0; + do { + int top_base_x = top_x >> scale_bits; + // Permit negative values of |top_x|. + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + int x = 0; + do { + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + top_base_x += base_step8; + x += 8; + } while (x < width); + dest += stride; + top_x += xstep; + } while (++y < height); + return; + } + + // Each 16-bit value here corresponds to a position that may exceed + // |max_base_x|. When added to the top_base_x, it is used to mask values + // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is + // not supported for packed integers. + const __m128i offsets = + _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001); + + const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x); + const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]); + const __m128i base_step8_vect = _mm_set1_epi16(base_step8); + int top_x = xstep; + int y = 0; + do { + int top_base_x = top_x >> scale_bits; + + if (top_base_x >= max_base_x) { + for (int i = y; i < height; ++i) { + memset(dest, top_row[max_base_x], width); + dest += stride; + } + return; + } + + const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i top_index_vect = _mm_set1_epi16(top_base_x); + top_index_vect = _mm_add_epi16(top_index_vect, offsets); + + int x = 0; + for (; x < width - 8; + x += 8, top_base_x += base_step8, + top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) { + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + // Assuming a buffer zone of 8 bytes at the end of top_row, this prevents + // reading out of bounds. If all indices are past max and we don't need to + // use the loaded bytes at all, |top_base_x| becomes 0. |top_base_x| will + // reset for the next |y|. + top_base_x &= ~_mm_cvtsi128_si32(past_max); + const __m128i top_vals = LoadUnaligned16(top_row + top_base_x); + __m128i vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + vals = _mm_blendv_epi8(vals, final_top_val, past_max); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + } + const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect); + __m128i vals; + if (upsampled) { + vals = LoadUnaligned16(top_row + top_base_x); + } else { + const __m128i top_vals = LoadLo8(top_row + top_base_x); + vals = _mm_shuffle_epi8(top_vals, sampler); + vals = _mm_insert_epi8(vals, top_row[top_base_x + 8], 15); + } + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + vals = _mm_blendv_epi8(vals, final_top_val, past_max); + StoreLo8(dest + x, _mm_packus_epi16(vals, vals)); + dest += stride; + top_x += xstep; + } while (++y < height); +} + +void DirectionalIntraPredictorZone1_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const int width, const int height, + const int xstep, + const bool upsampled_top) { + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + auto* dst = static_cast<uint8_t*>(dest); + DirectionalZone1_SSE4_1(dst, stride, top_ptr, width, height, xstep, + upsampled_top); +} + +template <bool upsampled> +inline void DirectionalZone3_4x4(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const left_column, + const int base_left_y, const int ystep) { + // For use in the non-upsampled case. + const __m128i sampler = _mm_set_epi64x(0, 0x0403030202010100); + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shift = _mm_set1_epi8(32); + const int rounding_bits = 5; + + __m128i result_block[4]; + for (int x = 0, left_y = base_left_y; x < 4; x++, left_y += ystep) { + const int left_base_y = left_y >> scale_bits; + const int shift_val = ((left_y << upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i vals; + if (upsampled) { + vals = LoadLo8(left_column + left_base_y); + } else { + const __m128i top_vals = LoadLo8(left_column + left_base_y); + vals = _mm_shuffle_epi8(top_vals, sampler); + } + vals = _mm_maddubs_epi16(vals, shifts); + vals = RightShiftWithRounding_U16(vals, rounding_bits); + result_block[x] = _mm_packus_epi16(vals, vals); + } + const __m128i result = Transpose4x4_U8(result_block); + // This is result_row0. + Store4(dest, result); + dest += stride; + const int result_row1 = _mm_extract_epi32(result, 1); + memcpy(dest, &result_row1, sizeof(result_row1)); + dest += stride; + const int result_row2 = _mm_extract_epi32(result, 2); + memcpy(dest, &result_row2, sizeof(result_row2)); + dest += stride; + const int result_row3 = _mm_extract_epi32(result, 3); + memcpy(dest, &result_row3, sizeof(result_row3)); +} + +template <bool upsampled, int height> +inline void DirectionalZone3_8xH(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const left_column, + const int base_left_y, const int ystep) { + // For use in the non-upsampled case. + const __m128i sampler = + _mm_set_epi64x(0x0807070606050504, 0x0403030202010100); + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shift = _mm_set1_epi8(32); + const int rounding_bits = 5; + + __m128i result_block[8]; + for (int x = 0, left_y = base_left_y; x < 8; x++, left_y += ystep) { + const int left_base_y = left_y >> scale_bits; + const int shift_val = (LeftShift(left_y, upsample_shift) & 0x3F) >> 1; + const __m128i shift = _mm_set1_epi8(shift_val); + const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift); + const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift); + __m128i vals; + if (upsampled) { + vals = LoadUnaligned16(left_column + left_base_y); + } else { + const __m128i top_vals = LoadUnaligned16(left_column + left_base_y); + vals = _mm_shuffle_epi8(top_vals, sampler); + } + vals = _mm_maddubs_epi16(vals, shifts); + result_block[x] = RightShiftWithRounding_U16(vals, rounding_bits); + } + Transpose8x8_U16(result_block, result_block); + for (int y = 0; y < height; ++y) { + StoreLo8(dest, _mm_packus_epi16(result_block[y], result_block[y])); + dest += stride; + } +} + +// 7.11.2.4 (9) angle > 180 +void DirectionalIntraPredictorZone3_SSE4_1(void* dest, ptrdiff_t stride, + const void* const left_column, + const int width, const int height, + const int ystep, + const bool upsampled) { + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + const int upsample_shift = static_cast<int>(upsampled); + if (width == 4 || height == 4) { + const ptrdiff_t stride4 = stride << 2; + if (upsampled) { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_4x4<true>( + dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); + dst_x += stride4; + y += 4; + } while (y < height); + left_y += ystep << 2; + x += 4; + } while (x < width); + } else { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_4x4<false>(dst_x, stride, left_ptr + y, left_y, + ystep); + dst_x += stride4; + y += 4; + } while (y < height); + left_y += ystep << 2; + x += 4; + } while (x < width); + } + return; + } + + const ptrdiff_t stride8 = stride << 3; + if (upsampled) { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_8xH<true, 8>( + dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); + dst_x += stride8; + y += 8; + } while (y < height); + left_y += ystep << 3; + x += 8; + } while (x < width); + } else { + int left_y = ystep; + int x = 0; + do { + uint8_t* dst_x = dst + x; + int y = 0; + do { + DirectionalZone3_8xH<false, 8>( + dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep); + dst_x += stride8; + y += 8; + } while (y < height); + left_y += ystep << 3; + x += 8; + } while (x < width); + } +} + +//------------------------------------------------------------------------------ +// Directional Zone 2 Functions +// 7.11.2.4 (8) + +// DirectionalBlend* selectively overwrites the values written by +// DirectionalZone2FromLeftCol*. |zone_bounds| has one 16-bit index for each +// row. +template <int y_selector> +inline void DirectionalBlend4_SSE4_1(uint8_t* dest, + const __m128i& dest_index_vect, + const __m128i& vals, + const __m128i& zone_bounds) { + const __m128i max_dest_x_vect = _mm_shufflelo_epi16(zone_bounds, y_selector); + const __m128i use_left = _mm_cmplt_epi16(dest_index_vect, max_dest_x_vect); + const __m128i original_vals = _mm_cvtepu8_epi16(Load4(dest)); + const __m128i blended_vals = _mm_blendv_epi8(vals, original_vals, use_left); + Store4(dest, _mm_packus_epi16(blended_vals, blended_vals)); +} + +inline void DirectionalBlend8_SSE4_1(uint8_t* dest, + const __m128i& dest_index_vect, + const __m128i& vals, + const __m128i& zone_bounds, + const __m128i& bounds_selector) { + const __m128i max_dest_x_vect = + _mm_shuffle_epi8(zone_bounds, bounds_selector); + const __m128i use_left = _mm_cmplt_epi16(dest_index_vect, max_dest_x_vect); + const __m128i original_vals = _mm_cvtepu8_epi16(LoadLo8(dest)); + const __m128i blended_vals = _mm_blendv_epi8(vals, original_vals, use_left); + StoreLo8(dest, _mm_packus_epi16(blended_vals, blended_vals)); +} + +constexpr int kDirectionalWeightBits = 5; +// |source| is packed with 4 or 8 pairs of 8-bit values from left or top. +// |shifts| is named to match the specification, with 4 or 8 pairs of (32 - +// shift) and shift. Shift is guaranteed to be between 0 and 32. +inline __m128i DirectionalZone2FromSource_SSE4_1(const uint8_t* const source, + const __m128i& shifts, + const __m128i& sampler) { + const __m128i src_vals = LoadUnaligned16(source); + __m128i vals = _mm_shuffle_epi8(src_vals, sampler); + vals = _mm_maddubs_epi16(vals, shifts); + return RightShiftWithRounding_U16(vals, kDirectionalWeightBits); +} + +// Because the source values "move backwards" as the row index increases, the +// indices derived from ystep are generally negative. This is accommodated by +// making sure the relative indices are within [-15, 0] when the function is +// called, and sliding them into the inclusive range [0, 15], relative to a +// lower base address. +constexpr int kPositiveIndexOffset = 15; + +template <bool upsampled> +inline void DirectionalZone2FromLeftCol_4x4_SSE4_1( + uint8_t* dst, ptrdiff_t stride, const uint8_t* const left_column_base, + __m128i left_y) { + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shifts = _mm_set1_epi8(32); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + const __m128i index_increment = _mm_cvtsi32_si128(0x01010101); + const __m128i positive_offset = _mm_set1_epi8(kPositiveIndexOffset); + // Left_column and sampler are both offset by 15 so the indices are always + // positive. + const uint8_t* left_column = left_column_base - kPositiveIndexOffset; + for (int y = 0; y < 4; dst += stride, ++y) { + __m128i offset_y = _mm_srai_epi16(left_y, scale_bits); + offset_y = _mm_packs_epi16(offset_y, offset_y); + + const __m128i adjacent = _mm_add_epi8(offset_y, index_increment); + __m128i sampler = _mm_unpacklo_epi8(offset_y, adjacent); + // Slide valid |offset_y| indices from range [-15, 0] to [0, 15] so they + // can work as shuffle indices. Some values may be out of bounds, but their + // pred results will be masked over by top prediction. + sampler = _mm_add_epi8(sampler, positive_offset); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(left_y, upsample_shift), shift_mask), 1); + shifts = _mm_packus_epi16(shifts, shifts); + const __m128i opposite_shifts = _mm_sub_epi8(max_shifts, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + const __m128i vals = DirectionalZone2FromSource_SSE4_1( + left_column + (y << upsample_shift), shifts, sampler); + Store4(dst, _mm_packus_epi16(vals, vals)); + } +} + +// The height at which a load of 16 bytes will not contain enough source pixels +// from |left_column| to supply an accurate row when computing 8 pixels at a +// time. The values are found by inspection. By coincidence, all angles that +// satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up +// by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15. +constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = { + 1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40}; + +template <bool upsampled> +inline void DirectionalZone2FromLeftCol_8x8_SSE4_1( + uint8_t* dst, ptrdiff_t stride, const uint8_t* const left_column, + __m128i left_y) { + const int upsample_shift = static_cast<int>(upsampled); + const int scale_bits = 6 - upsample_shift; + const __m128i max_shifts = _mm_set1_epi8(32); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + const __m128i index_increment = _mm_set1_epi8(1); + const __m128i denegation = _mm_set1_epi8(kPositiveIndexOffset); + for (int y = 0; y < 8; dst += stride, ++y) { + __m128i offset_y = _mm_srai_epi16(left_y, scale_bits); + offset_y = _mm_packs_epi16(offset_y, offset_y); + const __m128i adjacent = _mm_add_epi8(offset_y, index_increment); + + // Offset the relative index because ystep is negative in Zone 2 and shuffle + // indices must be nonnegative. + __m128i sampler = _mm_unpacklo_epi8(offset_y, adjacent); + sampler = _mm_add_epi8(sampler, denegation); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(left_y, upsample_shift), shift_mask), 1); + shifts = _mm_packus_epi16(shifts, shifts); + const __m128i opposite_shifts = _mm_sub_epi8(max_shifts, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + + // The specification adds (y << 6) to left_y, which is subject to + // upsampling, but this puts sampler indices out of the 0-15 range. It is + // equivalent to offset the source address by (y << upsample_shift) instead. + const __m128i vals = DirectionalZone2FromSource_SSE4_1( + left_column - kPositiveIndexOffset + (y << upsample_shift), shifts, + sampler); + StoreLo8(dst, _mm_packus_epi16(vals, vals)); + } +} + +// |zone_bounds| is an epi16 of the relative x index at which base >= -(1 << +// upsampled_top), for each row. When there are 4 values, they can be duplicated +// with a non-register shuffle mask. +// |shifts| is one pair of weights that applies throughout a given row. +template <bool upsampled_top> +inline void DirectionalZone1Blend_4x4( + uint8_t* dest, const uint8_t* const top_row, ptrdiff_t stride, + __m128i sampler, const __m128i& zone_bounds, const __m128i& shifts, + const __m128i& dest_index_x, int top_x, const int xstep) { + const int upsample_shift = static_cast<int>(upsampled_top); + const int scale_bits_x = 6 - upsample_shift; + top_x -= xstep; + + int top_base_x = (top_x >> scale_bits_x); + const __m128i vals0 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0x00), sampler); + DirectionalBlend4_SSE4_1<0x00>(dest, dest_index_x, vals0, zone_bounds); + top_x -= xstep; + dest += stride; + + top_base_x = (top_x >> scale_bits_x); + const __m128i vals1 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0x55), sampler); + DirectionalBlend4_SSE4_1<0x55>(dest, dest_index_x, vals1, zone_bounds); + top_x -= xstep; + dest += stride; + + top_base_x = (top_x >> scale_bits_x); + const __m128i vals2 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0xAA), sampler); + DirectionalBlend4_SSE4_1<0xAA>(dest, dest_index_x, vals2, zone_bounds); + top_x -= xstep; + dest += stride; + + top_base_x = (top_x >> scale_bits_x); + const __m128i vals3 = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shufflelo_epi16(shifts, 0xFF), sampler); + DirectionalBlend4_SSE4_1<0xFF>(dest, dest_index_x, vals3, zone_bounds); +} + +template <bool upsampled_top, int height> +inline void DirectionalZone1Blend_8xH( + uint8_t* dest, const uint8_t* const top_row, ptrdiff_t stride, + __m128i sampler, const __m128i& zone_bounds, const __m128i& shifts, + const __m128i& dest_index_x, int top_x, const int xstep) { + const int upsample_shift = static_cast<int>(upsampled_top); + const int scale_bits_x = 6 - upsample_shift; + + __m128i y_selector = _mm_set1_epi32(0x01000100); + const __m128i index_increment = _mm_set1_epi32(0x02020202); + for (int y = 0; y < height; ++y, + y_selector = _mm_add_epi8(y_selector, index_increment), + dest += stride) { + top_x -= xstep; + const int top_base_x = top_x >> scale_bits_x; + const __m128i vals = DirectionalZone2FromSource_SSE4_1( + top_row + top_base_x, _mm_shuffle_epi8(shifts, y_selector), sampler); + DirectionalBlend8_SSE4_1(dest, dest_index_x, vals, zone_bounds, y_selector); + } +} + +// 7.11.2.4 (8) 90 < angle > 180 +// The strategy for this function is to know how many blocks can be processed +// with just pixels from |top_ptr|, then handle mixed blocks, then handle only +// blocks that take from |left_ptr|. Additionally, a fast index-shuffle +// approach is used for pred values from |left_column| in sections that permit +// it. +template <bool upsampled_left, bool upsampled_top> +inline void DirectionalZone2_SSE4_1(void* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int width, const int height, + const int xstep, const int ystep) { + auto* dst = static_cast<uint8_t*>(dest); + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + const __m128i max_shift = _mm_set1_epi8(32); + const ptrdiff_t stride8 = stride << 3; + const __m128i dest_index_x = + _mm_set_epi32(0x00070006, 0x00050004, 0x00030002, 0x00010000); + const __m128i sampler_top = + upsampled_top + ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute. This assumes minimum |xstep| is 3. + const int min_top_only_x = std::min((height * xstep) >> 6, width); + + // For steep angles, the source pixels from left_column may not fit in a + // 16-byte load for shuffling. + // TODO(petersonab): Find a more precise formula for this subject to x. + const int max_shuffle_height = + std::min(height, kDirectionalZone2ShuffleInvalidHeight[ystep >> 6]); + + const int xstep8 = xstep << 3; + const __m128i xstep8_vect = _mm_set1_epi16(xstep8); + // Accumulate xstep across 8 rows. + const __m128i xstep_dup = _mm_set1_epi16(-xstep); + const __m128i increments = _mm_set_epi16(8, 7, 6, 5, 4, 3, 2, 1); + const __m128i xstep_for_shift = _mm_mullo_epi16(xstep_dup, increments); + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + const __m128i scaled_one = _mm_set1_epi16(-64); + __m128i xstep_bounds_base = + (xstep == 64) ? _mm_sub_epi16(scaled_one, xstep_for_shift) + : _mm_sub_epi16(_mm_set1_epi16(-1), xstep_for_shift); + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + const int ystep8 = ystep << 3; + const int left_base_increment8 = ystep8 >> 6; + const int ystep_remainder8 = ystep8 & 0x3F; + const __m128i increment_left8 = _mm_set1_epi16(-ystep_remainder8); + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which is covered under the left_column + // offset. Following values need the full ystep as a relative offset. + const __m128i ystep_init = _mm_set1_epi16(-ystep_remainder); + const __m128i ystep_dup = _mm_set1_epi16(-ystep); + __m128i left_y = _mm_mullo_epi16(ystep_dup, dest_index_x); + left_y = _mm_add_epi16(ystep_init, left_y); + + const __m128i increment_top8 = _mm_set1_epi16(8 << 6); + int x = 0; + + // This loop treats each set of 4 columns in 3 stages with y-value boundaries. + // The first stage, before the first y-loop, covers blocks that are only + // computed from the top row. The second stage, comprising two y-loops, covers + // blocks that have a mixture of values computed from top or left. The final + // stage covers blocks that are only computed from the left. + for (int left_offset = -left_base_increment; x < min_top_only_x; + x += 8, + xstep_bounds_base = _mm_sub_epi16(xstep_bounds_base, increment_top8), + // Watch left_y because it can still get big. + left_y = _mm_add_epi16(left_y, increment_left8), + left_offset -= left_base_increment8) { + uint8_t* dst_x = dst + x; + + // Round down to the nearest multiple of 8. + const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7; + DirectionalZone1_4xH(dst_x, stride, top_row + (x << upsample_top_shift), + max_top_only_y, -xstep, upsampled_top); + DirectionalZone1_4xH(dst_x + 4, stride, + top_row + ((x + 4) << upsample_top_shift), + max_top_only_y, -xstep, upsampled_top); + + int y = max_top_only_y; + dst_x += stride * y; + const int xstep_y = xstep * y; + const __m128i xstep_y_vect = _mm_set1_epi16(xstep_y); + // All rows from |min_left_only_y| down for this set of columns, only need + // |left_column| to compute. + const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height); + // At high angles such that min_left_only_y < 8, ystep is low and xstep is + // high. This means that max_shuffle_height is unbounded and xstep_bounds + // will overflow in 16 bits. This is prevented by stopping the first + // blending loop at min_left_only_y for such cases, which means we skip over + // the second blending loop as well. + const int left_shuffle_stop_y = + std::min(max_shuffle_height, min_left_only_y); + __m128i xstep_bounds = _mm_add_epi16(xstep_bounds_base, xstep_y_vect); + __m128i xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift, xstep_y_vect); + int top_x = -xstep_y; + + for (; y < left_shuffle_stop_y; + y += 8, dst_x += stride8, + xstep_bounds = _mm_add_epi16(xstep_bounds, xstep8_vect), + xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep8_vect), + top_x -= xstep8) { + DirectionalZone2FromLeftCol_8x8_SSE4_1<upsampled_left>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), left_y); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), + shift_mask), + 1); + shifts = _mm_packus_epi16(shifts, shifts); + __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); + DirectionalZone1Blend_8xH<upsampled_top, 8>( + dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, + xstep_bounds_off, shifts, dest_index_x, top_x, xstep); + } + // Pick up from the last y-value, using the 10% slower but secure method for + // left prediction. + const auto base_left_y = static_cast<int16_t>(_mm_extract_epi16(left_y, 0)); + for (; y < min_left_only_y; + y += 8, dst_x += stride8, + xstep_bounds = _mm_add_epi16(xstep_bounds, xstep8_vect), + xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep8_vect), + top_x -= xstep8) { + const __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); + + DirectionalZone3_8xH<upsampled_left, 8>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), + shift_mask), + 1); + shifts = _mm_packus_epi16(shifts, shifts); + __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + DirectionalZone1Blend_8xH<upsampled_top, 8>( + dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, + xstep_bounds_off, shifts, dest_index_x, top_x, xstep); + } + // Loop over y for left_only rows. + for (; y < height; y += 8, dst_x += stride8) { + DirectionalZone3_8xH<upsampled_left, 8>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), base_left_y, + -ystep); + } + } + for (; x < width; x += 4) { + DirectionalZone1_4xH(dst + x, stride, top_row + (x << upsample_top_shift), + height, -xstep, upsampled_top); + } +} + +template <bool upsampled_left, bool upsampled_top> +inline void DirectionalZone2_4_SSE4_1(void* dest, ptrdiff_t stride, + const uint8_t* const top_row, + const uint8_t* const left_column, + const int width, const int height, + const int xstep, const int ystep) { + auto* dst = static_cast<uint8_t*>(dest); + const int upsample_left_shift = static_cast<int>(upsampled_left); + const int upsample_top_shift = static_cast<int>(upsampled_top); + const __m128i max_shift = _mm_set1_epi8(32); + const ptrdiff_t stride4 = stride << 2; + const __m128i dest_index_x = _mm_set_epi32(0, 0, 0x00030002, 0x00010000); + const __m128i sampler_top = + upsampled_top + ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100) + : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100); + // All columns from |min_top_only_x| to the right will only need |top_row| to + // compute. + assert(xstep >= 3); + const int min_top_only_x = std::min((height * xstep) >> 6, width); + + const int xstep4 = xstep << 2; + const __m128i xstep4_vect = _mm_set1_epi16(xstep4); + const __m128i xstep_dup = _mm_set1_epi16(-xstep); + const __m128i increments = _mm_set_epi32(0, 0, 0x00040003, 0x00020001); + __m128i xstep_for_shift = _mm_mullo_epi16(xstep_dup, increments); + const __m128i scaled_one = _mm_set1_epi16(-64); + // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1 + __m128i xstep_bounds_base = + (xstep == 64) ? _mm_sub_epi16(scaled_one, xstep_for_shift) + : _mm_sub_epi16(_mm_set1_epi16(-1), xstep_for_shift); + + const int left_base_increment = ystep >> 6; + const int ystep_remainder = ystep & 0x3F; + const int ystep4 = ystep << 2; + const int left_base_increment4 = ystep4 >> 6; + // This is guaranteed to be less than 64, but accumulation may bring it past + // 64 for higher x values. + const int ystep_remainder4 = ystep4 & 0x3F; + const __m128i increment_left4 = _mm_set1_epi16(-ystep_remainder4); + const __m128i increment_top4 = _mm_set1_epi16(4 << 6); + + // If the 64 scaling is regarded as a decimal point, the first value of the + // left_y vector omits the portion which will go into the left_column offset. + // Following values need the full ystep as a relative offset. + const __m128i ystep_init = _mm_set1_epi16(-ystep_remainder); + const __m128i ystep_dup = _mm_set1_epi16(-ystep); + __m128i left_y = _mm_mullo_epi16(ystep_dup, dest_index_x); + left_y = _mm_add_epi16(ystep_init, left_y); + const __m128i shift_mask = _mm_set1_epi32(0x003F003F); + + int x = 0; + // Loop over x for columns with a mixture of sources. + for (int left_offset = -left_base_increment; x < min_top_only_x; x += 4, + xstep_bounds_base = _mm_sub_epi16(xstep_bounds_base, increment_top4), + left_y = _mm_add_epi16(left_y, increment_left4), + left_offset -= left_base_increment4) { + uint8_t* dst_x = dst + x; + + // Round down to the nearest multiple of 8. + const int max_top_only_y = std::min((x << 6) / xstep, height) & 0xFFFFFFF4; + DirectionalZone1_4xH(dst_x, stride, top_row + (x << upsample_top_shift), + max_top_only_y, -xstep, upsampled_top); + int y = max_top_only_y; + dst_x += stride * y; + const int xstep_y = xstep * y; + const __m128i xstep_y_vect = _mm_set1_epi16(xstep_y); + // All rows from |min_left_only_y| down for this set of columns, only need + // |left_column| to compute. Rounded up to the nearest multiple of 4. + const int min_left_only_y = std::min(((x + 4) << 6) / xstep, height); + + __m128i xstep_bounds = _mm_add_epi16(xstep_bounds_base, xstep_y_vect); + __m128i xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift, xstep_y_vect); + int top_x = -xstep_y; + + // Loop over y for mixed rows. + for (; y < min_left_only_y; + y += 4, dst_x += stride4, + xstep_bounds = _mm_add_epi16(xstep_bounds, xstep4_vect), + xstep_for_shift_y = _mm_sub_epi16(xstep_for_shift_y, xstep4_vect), + top_x -= xstep4) { + DirectionalZone2FromLeftCol_4x4_SSE4_1<upsampled_left>( + dst_x, stride, + left_column + ((left_offset + y) * (1 << upsample_left_shift)), + left_y); + + __m128i shifts = _mm_srli_epi16( + _mm_and_si128(_mm_slli_epi16(xstep_for_shift_y, upsample_top_shift), + shift_mask), + 1); + shifts = _mm_packus_epi16(shifts, shifts); + const __m128i opposite_shifts = _mm_sub_epi8(max_shift, shifts); + shifts = _mm_unpacklo_epi8(opposite_shifts, shifts); + const __m128i xstep_bounds_off = _mm_srai_epi16(xstep_bounds, 6); + DirectionalZone1Blend_4x4<upsampled_top>( + dst_x, top_row + (x << upsample_top_shift), stride, sampler_top, + xstep_bounds_off, shifts, dest_index_x, top_x, xstep); + } + // Loop over y for left-only rows, if any. + for (; y < height; y += 4, dst_x += stride4) { + DirectionalZone2FromLeftCol_4x4_SSE4_1<upsampled_left>( + dst_x, stride, + left_column + ((left_offset + y) << upsample_left_shift), left_y); + } + } + // Loop over top-only columns, if any. + for (; x < width; x += 4) { + DirectionalZone1_4xH(dst + x, stride, top_row + (x << upsample_top_shift), + height, -xstep, upsampled_top); + } +} + +void DirectionalIntraPredictorZone2_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + const int width, const int height, + const int xstep, const int ystep, + const bool upsampled_top, + const bool upsampled_left) { + // Increasing the negative buffer for this function allows more rows to be + // processed at a time without branching in an inner loop to check the base. + uint8_t top_buffer[288]; + uint8_t left_buffer[288]; + memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160); + memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160); + const uint8_t* top_ptr = top_buffer + 144; + const uint8_t* left_ptr = left_buffer + 144; + if (width == 4 || height == 4) { + if (upsampled_left) { + if (upsampled_top) { + DirectionalZone2_4_SSE4_1<true, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_4_SSE4_1<true, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } else { + if (upsampled_top) { + DirectionalZone2_4_SSE4_1<false, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_4_SSE4_1<false, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } + return; + } + if (upsampled_left) { + if (upsampled_top) { + DirectionalZone2_SSE4_1<true, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_SSE4_1<true, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } else { + if (upsampled_top) { + DirectionalZone2_SSE4_1<false, true>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } else { + DirectionalZone2_SSE4_1<false, false>(dest, stride, top_ptr, left_ptr, + width, height, xstep, ystep); + } + } +} + +//------------------------------------------------------------------------------ +// FilterIntraPredictor_SSE4_1 + +// Apply all filter taps to the given 7 packed 16-bit values, keeping the 8th +// at zero to preserve the sum. +inline void Filter4x2_SSE4_1(uint8_t* dst, const ptrdiff_t stride, + const __m128i& pixels, const __m128i& taps_0_1, + const __m128i& taps_2_3, const __m128i& taps_4_5, + const __m128i& taps_6_7) { + const __m128i mul_0_01 = _mm_maddubs_epi16(pixels, taps_0_1); + const __m128i mul_0_23 = _mm_maddubs_epi16(pixels, taps_2_3); + // |output_half| contains 8 partial sums. + __m128i output_half = _mm_hadd_epi16(mul_0_01, mul_0_23); + __m128i output = _mm_hadd_epi16(output_half, output_half); + const __m128i output_row0 = + _mm_packus_epi16(RightShiftWithRounding_S16(output, 4), + /* arbitrary pack arg */ output); + Store4(dst, output_row0); + const __m128i mul_1_01 = _mm_maddubs_epi16(pixels, taps_4_5); + const __m128i mul_1_23 = _mm_maddubs_epi16(pixels, taps_6_7); + output_half = _mm_hadd_epi16(mul_1_01, mul_1_23); + output = _mm_hadd_epi16(output_half, output_half); + const __m128i output_row1 = + _mm_packus_epi16(RightShiftWithRounding_S16(output, 4), + /* arbitrary pack arg */ output); + Store4(dst + stride, output_row1); +} + +// 4xH transform sizes are given special treatment because LoadLo8 goes out +// of bounds and every block involves the left column. This implementation +// loads TL from the top row for the first block, so it is not +inline void Filter4xH(uint8_t* dest, ptrdiff_t stride, + const uint8_t* const top_ptr, + const uint8_t* const left_ptr, FilterIntraPredictor pred, + const int height) { + const __m128i taps_0_1 = LoadUnaligned16(kFilterIntraTaps[pred][0]); + const __m128i taps_2_3 = LoadUnaligned16(kFilterIntraTaps[pred][2]); + const __m128i taps_4_5 = LoadUnaligned16(kFilterIntraTaps[pred][4]); + const __m128i taps_6_7 = LoadUnaligned16(kFilterIntraTaps[pred][6]); + __m128i top = Load4(top_ptr - 1); + __m128i pixels = _mm_insert_epi8(top, top_ptr[3], 4); + __m128i left = (height == 4 ? Load4(left_ptr) : LoadLo8(left_ptr)); + left = _mm_slli_si128(left, 5); + + // Relative pixels: top[-1], top[0], top[1], top[2], top[3], left[0], left[1], + // left[2], left[3], left[4], left[5], left[6], left[7] + pixels = _mm_or_si128(left, pixels); + + // Duplicate first 8 bytes. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 1. + pixels = Load4(dest); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, left[-2], left[-1], + // left[0], left[1], ... + pixels = _mm_or_si128(left, pixels); + + // This mask rearranges bytes in the order: 6, 0, 1, 2, 3, 7, 8, 15. The last + // byte is an unused value, which shall be multiplied by 0 when we apply the + // filter. + constexpr int64_t kInsertTopLeftFirstMask = 0x0F08070302010006; + + // Insert left[-1] in front as TL and put left[0] and left[1] at the end. + const __m128i pixel_order1 = _mm_set1_epi64x(kInsertTopLeftFirstMask); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 2. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 3. + + // Compute the middle 8 rows before using common code for the final 4 rows. + // Because the common code below this block assumes that + if (height == 16) { + // This shift allows us to use pixel_order2 twice after shifting by 2 later. + left = _mm_slli_si128(left, 1); + pixels = Load4(dest); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, left[-4], + // left[-3], left[-2], left[-1], left[0], left[1], left[2], left[3] + pixels = _mm_or_si128(left, pixels); + + // This mask rearranges bytes in the order: 9, 0, 1, 2, 3, 7, 8, 15. The + // last byte is an unused value, as above. The top-left was shifted to + // position nine to keep two empty spaces after the top pixels. + constexpr int64_t kInsertTopLeftSecondMask = 0x0F0B0A0302010009; + + // Insert (relative) left[-1] in front as TL and put left[0] and left[1] at + // the end. + const __m128i pixel_order2 = _mm_set1_epi64x(kInsertTopLeftSecondMask); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + dest += stride; // Move to y = 4. + + // First 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + + // Clear all but final pixel in the first 8 of left column. + __m128i keep_top_left = _mm_srli_si128(left, 13); + dest += stride; // Move to y = 5. + pixels = Load4(dest); + left = _mm_srli_si128(left, 2); + + // Relative pixels: top[0], top[1], top[2], top[3], left[-6], + // left[-5], left[-4], left[-3], left[-2], left[-1], left[0], left[1] + pixels = _mm_or_si128(left, pixels); + left = LoadLo8(left_ptr + 8); + + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + dest += stride; // Move to y = 6. + + // Second 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + + // Position TL value so we can use pixel_order1. + keep_top_left = _mm_slli_si128(keep_top_left, 6); + dest += stride; // Move to y = 7. + pixels = Load4(dest); + left = _mm_slli_si128(left, 7); + left = _mm_or_si128(left, keep_top_left); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, + // left[-1], left[0], left[1], left[2], left[3], ... + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 8. + + // Third 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 9. + + // Prepare final inputs. + pixels = Load4(dest); + left = _mm_srli_si128(left, 2); + + // Relative pixels: top[0], top[1], top[2], top[3], left[-3], left[-2] + // left[-1], left[0], left[1], left[2], left[3], ... + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 10. + + // Fourth 4x2 in the if body. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 11. + } + + // In both the 8 and 16 case, we assume that the left vector has the next TL + // at position 8. + if (height > 4) { + // Erase prior left pixels by shifting TL to position 0. + left = _mm_srli_si128(left, 8); + left = _mm_slli_si128(left, 6); + pixels = Load4(dest); + + // Relative pixels: top[0], top[1], top[2], top[3], empty, empty, + // left[-1], left[0], left[1], left[2], left[3], ... + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 12 or 4. + + // First of final two 4x2 blocks. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dest += stride; // Move to y = 13 or 5. + pixels = Load4(dest); + left = _mm_srli_si128(left, 2); + + // Relative pixels: top[0], top[1], top[2], top[3], left[-3], left[-2] + // left[-1], left[0], left[1], left[2], left[3], ... + pixels = _mm_or_si128(left, pixels); + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + dest += stride; // Move to y = 14 or 6. + + // Last of final two 4x2 blocks. + Filter4x2_SSE4_1(dest, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + } +} + +void FilterIntraPredictor_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const top_row, + const void* const left_column, + FilterIntraPredictor pred, const int width, + const int height) { + const auto* const top_ptr = static_cast<const uint8_t*>(top_row); + const auto* const left_ptr = static_cast<const uint8_t*>(left_column); + auto* dst = static_cast<uint8_t*>(dest); + if (width == 4) { + Filter4xH(dst, stride, top_ptr, left_ptr, pred, height); + return; + } + + // There is one set of 7 taps for each of the 4x2 output pixels. + const __m128i taps_0_1 = LoadUnaligned16(kFilterIntraTaps[pred][0]); + const __m128i taps_2_3 = LoadUnaligned16(kFilterIntraTaps[pred][2]); + const __m128i taps_4_5 = LoadUnaligned16(kFilterIntraTaps[pred][4]); + const __m128i taps_6_7 = LoadUnaligned16(kFilterIntraTaps[pred][6]); + + // This mask rearranges bytes in the order: 0, 1, 2, 3, 4, 8, 9, 15. The 15 at + // the end is an unused value, which shall be multiplied by 0 when we apply + // the filter. + constexpr int64_t kCondenseLeftMask = 0x0F09080403020100; + + // Takes the "left section" and puts it right after p0-p4. + const __m128i pixel_order1 = _mm_set1_epi64x(kCondenseLeftMask); + + // This mask rearranges bytes in the order: 8, 0, 1, 2, 3, 9, 10, 15. The last + // byte is unused as above. + constexpr int64_t kInsertTopLeftMask = 0x0F0A090302010008; + + // Shuffles the "top left" from the left section, to the front. Used when + // grabbing data from left_column and not top_row. + const __m128i pixel_order2 = _mm_set1_epi64x(kInsertTopLeftMask); + + // This first pass takes care of the cases where the top left pixel comes from + // top_row. + __m128i pixels = LoadLo8(top_ptr - 1); + __m128i left = _mm_slli_si128(Load4(left_column), 8); + pixels = _mm_or_si128(pixels, left); + + // Two sets of the same pixels to multiply with two sets of taps. + pixels = _mm_shuffle_epi8(pixels, pixel_order1); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, taps_6_7); + left = _mm_srli_si128(left, 1); + + // Load + pixels = Load4(dst + stride); + + // Because of the above shift, this OR 'invades' the final of the first 8 + // bytes of |pixels|. This is acceptable because the 8th filter tap is always + // a padded 0. + pixels = _mm_or_si128(pixels, left); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + const ptrdiff_t stride2 = stride << 1; + const ptrdiff_t stride4 = stride << 2; + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + dst += 4; + for (int x = 3; x < width - 4; x += 4) { + pixels = Load4(top_ptr + x); + pixels = _mm_insert_epi8(pixels, top_ptr[x + 4], 4); + pixels = _mm_insert_epi8(pixels, dst[-1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride - 1], 6); + + // Duplicate bottom half into upper half. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + pixels = Load4(dst + stride - 1); + pixels = _mm_insert_epi8(pixels, dst[stride + 3], 4); + pixels = _mm_insert_epi8(pixels, dst[stride2 - 1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride + stride2 - 1], 6); + + // Duplicate bottom half into upper half. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, + taps_4_5, taps_6_7); + dst += 4; + } + + // Now we handle heights that reference previous blocks rather than top_row. + for (int y = 4; y < height; y += 4) { + // Leftmost 4x4 block for this height. + dst -= width; + dst += stride4; + + // Top Left is not available by offset in these leftmost blocks. + pixels = Load4(dst - stride); + left = _mm_slli_si128(Load4(left_ptr + y - 1), 8); + left = _mm_insert_epi8(left, left_ptr[y + 3], 12); + pixels = _mm_or_si128(pixels, left); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + + // The bytes shifted into positions 6 and 7 will be ignored by the shuffle. + left = _mm_srli_si128(left, 2); + pixels = Load4(dst + stride); + pixels = _mm_or_si128(pixels, left); + pixels = _mm_shuffle_epi8(pixels, pixel_order2); + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, + taps_4_5, taps_6_7); + + dst += 4; + + // Remaining 4x4 blocks for this height. + for (int x = 4; x < width; x += 4) { + pixels = Load4(dst - stride - 1); + pixels = _mm_insert_epi8(pixels, dst[-stride + 3], 4); + pixels = _mm_insert_epi8(pixels, dst[-1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride - 1], 6); + + // Duplicate bottom half into upper half. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst, stride, pixels, taps_0_1, taps_2_3, taps_4_5, + taps_6_7); + pixels = Load4(dst + stride - 1); + pixels = _mm_insert_epi8(pixels, dst[stride + 3], 4); + pixels = _mm_insert_epi8(pixels, dst[stride2 - 1], 5); + pixels = _mm_insert_epi8(pixels, dst[stride2 + stride - 1], 6); + + // Duplicate bottom half into upper half. + pixels = _mm_shuffle_epi32(pixels, kDuplicateFirstHalf); + Filter4x2_SSE4_1(dst + stride2, stride, pixels, taps_0_1, taps_2_3, + taps_4_5, taps_6_7); + dst += 4; + } + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + static_cast<void>(dsp); +// These guards check if this version of the function was not superseded by +// a higher optimization level, such as AVX. The corresponding #define also +// prevents the C version from being added to the table. +#if DSP_ENABLED_8BPP_SSE4_1(FilterIntraPredictor) + dsp->filter_intra_predictor = FilterIntraPredictor_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone1) + dsp->directional_intra_predictor_zone1 = + DirectionalIntraPredictorZone1_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone2) + dsp->directional_intra_predictor_zone2 = + DirectionalIntraPredictorZone2_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(DirectionalIntraPredictorZone3) + dsp->directional_intra_predictor_zone3 = + DirectionalIntraPredictorZone3_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + DcDefs::_4x4::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] = + DcDefs::_4x8::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] = + DcDefs::_4x16::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] = + DcDefs::_8x4::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] = + DcDefs::_8x8::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] = + DcDefs::_8x16::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] = + DcDefs::_8x32::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] = + DcDefs::_16x4::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] = + DcDefs::_16x8::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] = + DcDefs::_16x16::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] = + DcDefs::_16x32::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] = + DcDefs::_16x64::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] = + DcDefs::_32x8::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] = + DcDefs::_32x16::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] = + DcDefs::_32x32::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] = + DcDefs::_32x64::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] = + DcDefs::_64x16::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] = + DcDefs::_64x32::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] = + DcDefs::_64x64::DcTop; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + DcDefs::_4x4::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] = + DcDefs::_4x8::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] = + DcDefs::_4x16::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] = + DcDefs::_8x4::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] = + DcDefs::_8x8::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] = + DcDefs::_8x16::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] = + DcDefs::_8x32::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] = + DcDefs::_16x4::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] = + DcDefs::_16x8::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] = + DcDefs::_16x16::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] = + DcDefs::_16x32::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] = + DcDefs::_16x64::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] = + DcDefs::_32x8::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] = + DcDefs::_32x16::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] = + DcDefs::_32x32::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] = + DcDefs::_32x64::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] = + DcDefs::_64x16::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] = + DcDefs::_64x32::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] = + DcDefs::_64x64::DcLeft; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorDc) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = + DcDefs::_4x4::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorDc) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] = + DcDefs::_4x8::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorDc) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] = + DcDefs::_4x16::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorDc) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] = + DcDefs::_8x4::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorDc) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] = + DcDefs::_8x8::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorDc) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] = + DcDefs::_8x16::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorDc) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] = + DcDefs::_8x32::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorDc) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] = + DcDefs::_16x4::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorDc) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] = + DcDefs::_16x8::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorDc) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] = + DcDefs::_16x16::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorDc) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] = + DcDefs::_16x32::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorDc) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] = + DcDefs::_16x64::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorDc) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] = + DcDefs::_32x8::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorDc) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] = + DcDefs::_32x16::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorDc) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] = + DcDefs::_32x32::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorDc) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] = + DcDefs::_32x64::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorDc) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] = + DcDefs::_64x16::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorDc) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] = + DcDefs::_64x32::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorDc) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] = + DcDefs::_64x64::Dc; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorPaeth] = + Paeth4x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorPaeth] = + Paeth4x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorPaeth] = + Paeth4x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorPaeth] = + Paeth8x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorPaeth] = + Paeth8x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorPaeth] = + Paeth8x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorPaeth] = + Paeth8x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorPaeth] = + Paeth16x4_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorPaeth] = + Paeth16x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorPaeth] = + Paeth16x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorPaeth] = + Paeth16x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorPaeth] = + Paeth16x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorPaeth] = + Paeth32x8_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorPaeth] = + Paeth32x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorPaeth] = + Paeth32x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorPaeth] = + Paeth32x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorPaeth] = + Paeth64x16_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorPaeth] = + Paeth64x32_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorPaeth) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorPaeth] = + Paeth64x64_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorHorizontal] = + DirDefs::_4x4::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorHorizontal] = + DirDefs::_4x8::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorHorizontal] = + DirDefs::_4x16::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorHorizontal] = + DirDefs::_8x4::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorHorizontal] = + DirDefs::_8x8::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorHorizontal] = + DirDefs::_8x16::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorHorizontal] = + DirDefs::_8x32::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorHorizontal] = + DirDefs::_16x4::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorHorizontal] = + DirDefs::_16x8::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorHorizontal] = + DirDefs::_16x16::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorHorizontal] = + DirDefs::_16x32::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x64_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorHorizontal] = + DirDefs::_16x64::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorHorizontal] = + DirDefs::_32x8::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorHorizontal] = + DirDefs::_32x16::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorHorizontal] = + DirDefs::_32x32::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x64_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorHorizontal] = + DirDefs::_32x64::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorHorizontal] = + DirDefs::_64x16::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorHorizontal] = + DirDefs::_64x32::Horizontal; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(TransformSize64x64_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorHorizontal] = + DirDefs::_64x64::Horizontal; +#endif +} // NOLINT(readability/fn_size) +// TODO(petersonab): Split Init8bpp function into family-specific files. + +} // namespace +} // namespace low_bitdepth + +//------------------------------------------------------------------------------ +#if LIBGAV1_MAX_BITDEPTH >= 10 +namespace high_bitdepth { +namespace { + +template <int height> +inline void DcStore4xH_SSE4_1(void* const dest, ptrdiff_t stride, + const __m128i dc) { + const __m128i dc_dup = _mm_shufflelo_epi16(dc, 0); + int y = height - 1; + auto* dst = static_cast<uint8_t*>(dest); + do { + StoreLo8(dst, dc_dup); + dst += stride; + } while (--y != 0); + StoreLo8(dst, dc_dup); +} + +// WriteDuplicateN assumes dup has 4 32-bit "units," each of which comprises 2 +// identical shorts that need N total copies written into dest. The unpacking +// works the same as in the 8bpp case, except that each 32-bit unit needs twice +// as many copies. +inline void WriteDuplicate4x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + auto* dst = static_cast<uint8_t*>(dest); + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst), dup64_lo); + dst += stride; + _mm_storeh_pi(reinterpret_cast<__m64*>(dst), _mm_castsi128_ps(dup64_lo)); + dst += stride; + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst), dup64_hi); + dst += stride; + _mm_storeh_pi(reinterpret_cast<__m64*>(dst), _mm_castsi128_ps(dup64_hi)); +} + +inline void WriteDuplicate8x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + + auto* dst = static_cast<uint8_t*>(dest); + const __m128i dup128_0 = _mm_unpacklo_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_0); + dst += stride; + const __m128i dup128_1 = _mm_unpackhi_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_1); + dst += stride; + const __m128i dup128_2 = _mm_unpacklo_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_2); + dst += stride; + const __m128i dup128_3 = _mm_unpackhi_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_3); +} + +inline void WriteDuplicate16x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + + auto* dst = static_cast<uint8_t*>(dest); + const __m128i dup128_0 = _mm_unpacklo_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_0); + dst += stride; + const __m128i dup128_1 = _mm_unpackhi_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_1); + dst += stride; + const __m128i dup128_2 = _mm_unpacklo_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_2); + dst += stride; + const __m128i dup128_3 = _mm_unpackhi_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_3); +} + +inline void WriteDuplicate32x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + + auto* dst = static_cast<uint8_t*>(dest); + const __m128i dup128_0 = _mm_unpacklo_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_0); + dst += stride; + const __m128i dup128_1 = _mm_unpackhi_epi64(dup64_lo, dup64_lo); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_1); + dst += stride; + const __m128i dup128_2 = _mm_unpacklo_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_2); + dst += stride; + const __m128i dup128_3 = _mm_unpackhi_epi64(dup64_hi, dup64_hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 16), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 32), dup128_3); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + 48), dup128_3); +} + +inline void WriteDuplicate64x4(void* const dest, ptrdiff_t stride, + const __m128i dup32) { + const __m128i dup64_lo = _mm_unpacklo_epi32(dup32, dup32); + const __m128i dup64_hi = _mm_unpackhi_epi32(dup32, dup32); + + auto* dst = static_cast<uint8_t*>(dest); + const __m128i dup128_0 = _mm_unpacklo_epi64(dup64_lo, dup64_lo); + for (int x = 0; x < 128; x += 16) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + x), dup128_0); + } + dst += stride; + const __m128i dup128_1 = _mm_unpackhi_epi64(dup64_lo, dup64_lo); + for (int x = 0; x < 128; x += 16) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + x), dup128_1); + } + dst += stride; + const __m128i dup128_2 = _mm_unpacklo_epi64(dup64_hi, dup64_hi); + for (int x = 0; x < 128; x += 16) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + x), dup128_2); + } + dst += stride; + const __m128i dup128_3 = _mm_unpackhi_epi64(dup64_hi, dup64_hi); + for (int x = 0; x < 128; x += 16) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + x), dup128_3); + } +} + +// ColStoreN<height> copies each of the |height| values in |column| across its +// corresponding row in dest. +template <WriteDuplicateFunc writefn> +inline void ColStore4_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const __m128i col_data = LoadLo8(column); + const __m128i col_dup32 = _mm_unpacklo_epi16(col_data, col_data); + writefn(dest, stride, col_dup32); +} + +template <WriteDuplicateFunc writefn> +inline void ColStore8_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const __m128i col_data = LoadUnaligned16(column); + const __m128i col_dup32_lo = _mm_unpacklo_epi16(col_data, col_data); + const __m128i col_dup32_hi = _mm_unpackhi_epi16(col_data, col_data); + auto* dst = static_cast<uint8_t*>(dest); + writefn(dst, stride, col_dup32_lo); + const ptrdiff_t stride4 = stride << 2; + dst += stride4; + writefn(dst, stride, col_dup32_hi); +} + +template <WriteDuplicateFunc writefn> +inline void ColStore16_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const ptrdiff_t stride4 = stride << 2; + auto* dst = static_cast<uint8_t*>(dest); + for (int y = 0; y < 32; y += 16) { + const __m128i col_data = + LoadUnaligned16(static_cast<const uint8_t*>(column) + y); + const __m128i col_dup32_lo = _mm_unpacklo_epi16(col_data, col_data); + const __m128i col_dup32_hi = _mm_unpackhi_epi16(col_data, col_data); + writefn(dst, stride, col_dup32_lo); + dst += stride4; + writefn(dst, stride, col_dup32_hi); + dst += stride4; + } +} + +template <WriteDuplicateFunc writefn> +inline void ColStore32_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const ptrdiff_t stride4 = stride << 2; + auto* dst = static_cast<uint8_t*>(dest); + for (int y = 0; y < 64; y += 16) { + const __m128i col_data = + LoadUnaligned16(static_cast<const uint8_t*>(column) + y); + const __m128i col_dup32_lo = _mm_unpacklo_epi16(col_data, col_data); + const __m128i col_dup32_hi = _mm_unpackhi_epi16(col_data, col_data); + writefn(dst, stride, col_dup32_lo); + dst += stride4; + writefn(dst, stride, col_dup32_hi); + dst += stride4; + } +} + +template <WriteDuplicateFunc writefn> +inline void ColStore64_SSE4_1(void* const dest, ptrdiff_t stride, + const void* const column) { + const ptrdiff_t stride4 = stride << 2; + auto* dst = static_cast<uint8_t*>(dest); + for (int y = 0; y < 128; y += 16) { + const __m128i col_data = + LoadUnaligned16(static_cast<const uint8_t*>(column) + y); + const __m128i col_dup32_lo = _mm_unpacklo_epi16(col_data, col_data); + const __m128i col_dup32_hi = _mm_unpackhi_epi16(col_data, col_data); + writefn(dst, stride, col_dup32_lo); + dst += stride4; + writefn(dst, stride, col_dup32_hi); + dst += stride4; + } +} + +// |ref| points to 8 bytes containing 4 packed int16 values. +inline __m128i DcSum4_SSE4_1(const void* ref) { + const __m128i vals = _mm_loadl_epi64(static_cast<const __m128i*>(ref)); + const __m128i ones = _mm_set1_epi16(1); + + // half_sum[31:0] = a1+a2 + // half_sum[63:32] = a3+a4 + const __m128i half_sum = _mm_madd_epi16(vals, ones); + // Place half_sum[63:32] in shift_sum[31:0]. + const __m128i shift_sum = _mm_srli_si128(half_sum, 4); + return _mm_add_epi32(half_sum, shift_sum); +} + +struct DcDefs { + DcDefs() = delete; + + using _4x4 = DcPredFuncs_SSE4_1<2, 2, DcSum4_SSE4_1, DcSum4_SSE4_1, + DcStore4xH_SSE4_1<4>, 0, 0>; +}; + +struct DirDefs { + DirDefs() = delete; + + using _4x4 = DirectionalPredFuncs_SSE4_1<ColStore4_SSE4_1<WriteDuplicate4x4>>; + using _4x8 = DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate4x4>>; + using _4x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate4x4>>; + using _8x4 = DirectionalPredFuncs_SSE4_1<ColStore4_SSE4_1<WriteDuplicate8x4>>; + using _8x8 = DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate8x4>>; + using _8x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate8x4>>; + using _8x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate8x4>>; + using _16x4 = + DirectionalPredFuncs_SSE4_1<ColStore4_SSE4_1<WriteDuplicate16x4>>; + using _16x8 = + DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate16x4>>; + using _16x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate16x4>>; + using _16x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate16x4>>; + using _16x64 = + DirectionalPredFuncs_SSE4_1<ColStore64_SSE4_1<WriteDuplicate16x4>>; + using _32x8 = + DirectionalPredFuncs_SSE4_1<ColStore8_SSE4_1<WriteDuplicate32x4>>; + using _32x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate32x4>>; + using _32x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate32x4>>; + using _32x64 = + DirectionalPredFuncs_SSE4_1<ColStore64_SSE4_1<WriteDuplicate32x4>>; + using _64x16 = + DirectionalPredFuncs_SSE4_1<ColStore16_SSE4_1<WriteDuplicate64x4>>; + using _64x32 = + DirectionalPredFuncs_SSE4_1<ColStore32_SSE4_1<WriteDuplicate64x4>>; + using _64x64 = + DirectionalPredFuncs_SSE4_1<ColStore64_SSE4_1<WriteDuplicate64x4>>; +}; + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(10); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x4_IntraPredictorDcTop) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] = + DcDefs::_4x4::DcTop; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x4_IntraPredictorDcLeft) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] = + DcDefs::_4x4::DcLeft; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x4_IntraPredictorDc) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] = + DcDefs::_4x4::Dc; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x4_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize4x4][kIntraPredictorHorizontal] = + DirDefs::_4x4::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize4x8][kIntraPredictorHorizontal] = + DirDefs::_4x8::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize4x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize4x16][kIntraPredictorHorizontal] = + DirDefs::_4x16::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x4_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x4][kIntraPredictorHorizontal] = + DirDefs::_8x4::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x8][kIntraPredictorHorizontal] = + DirDefs::_8x8::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x16][kIntraPredictorHorizontal] = + DirDefs::_8x16::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize8x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize8x32][kIntraPredictorHorizontal] = + DirDefs::_8x32::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x4_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x4][kIntraPredictorHorizontal] = + DirDefs::_16x4::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x8][kIntraPredictorHorizontal] = + DirDefs::_16x8::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x16][kIntraPredictorHorizontal] = + DirDefs::_16x16::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x32][kIntraPredictorHorizontal] = + DirDefs::_16x32::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize16x64_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize16x64][kIntraPredictorHorizontal] = + DirDefs::_16x64::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x8_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x8][kIntraPredictorHorizontal] = + DirDefs::_32x8::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x16][kIntraPredictorHorizontal] = + DirDefs::_32x16::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x32][kIntraPredictorHorizontal] = + DirDefs::_32x32::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize32x64_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize32x64][kIntraPredictorHorizontal] = + DirDefs::_32x64::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize64x16_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize64x16][kIntraPredictorHorizontal] = + DirDefs::_64x16::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize64x32_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize64x32][kIntraPredictorHorizontal] = + DirDefs::_64x32::Horizontal; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(TransformSize64x64_IntraPredictorHorizontal) + dsp->intra_predictors[kTransformSize64x64][kIntraPredictorHorizontal] = + DirDefs::_64x64::Horizontal; +#endif +} + +} // namespace +} // namespace high_bitdepth +#endif // LIBGAV1_MAX_BITDEPTH >= 10 + +void IntraPredInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void IntraPredInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/intrapred_sse4.h b/src/dsp/x86/intrapred_sse4.h new file mode 100644 index 0000000..7f4fcd7 --- /dev/null +++ b/src/dsp/x86/intrapred_sse4.h @@ -0,0 +1,1060 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_INTRAPRED_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_INTRAPRED_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*, +// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and +// Dsp::filter_intra_predictor, see the defines below for specifics. These +// functions are not thread-safe. +void IntraPredInit_SSE4_1(); +void IntraPredCflInit_SSE4_1(); +void IntraPredSmoothInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_FilterIntraPredictor +#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 +#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor +#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal +#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +//------------------------------------------------------------------------------ +// 10bpp + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorHorizontal +#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_INTRAPRED_SSE4_H_ diff --git a/src/dsp/x86/inverse_transform_sse4.cc b/src/dsp/x86/inverse_transform_sse4.cc new file mode 100644 index 0000000..787d706 --- /dev/null +++ b/src/dsp/x86/inverse_transform_sse4.cc @@ -0,0 +1,3086 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/inverse_transform.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/dsp/x86/transpose_sse4.h" +#include "src/utils/array_2d.h" +#include "src/utils/common.h" +#include "src/utils/compiler_attributes.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Include the constants and utility functions inside the anonymous namespace. +#include "src/dsp/inverse_transform.inc" + +template <int store_width, int store_count> +LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx, + const __m128i* s) { + // NOTE: It is expected that the compiler will unroll these loops. + if (store_width == 16) { + for (int i = 0; i < store_count; i += 4) { + StoreUnaligned16(&dst[i * stride + idx], s[i]); + StoreUnaligned16(&dst[(i + 1) * stride + idx], s[i + 1]); + StoreUnaligned16(&dst[(i + 2) * stride + idx], s[i + 2]); + StoreUnaligned16(&dst[(i + 3) * stride + idx], s[i + 3]); + } + } + if (store_width == 8) { + for (int i = 0; i < store_count; i += 4) { + StoreLo8(&dst[i * stride + idx], s[i]); + StoreLo8(&dst[(i + 1) * stride + idx], s[i + 1]); + StoreLo8(&dst[(i + 2) * stride + idx], s[i + 2]); + StoreLo8(&dst[(i + 3) * stride + idx], s[i + 3]); + } + } +} + +template <int load_width, int load_count> +LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* src, int32_t stride, + int32_t idx, __m128i* x) { + // NOTE: It is expected that the compiler will unroll these loops. + if (load_width == 16) { + for (int i = 0; i < load_count; i += 4) { + x[i] = LoadUnaligned16(&src[i * stride + idx]); + x[i + 1] = LoadUnaligned16(&src[(i + 1) * stride + idx]); + x[i + 2] = LoadUnaligned16(&src[(i + 2) * stride + idx]); + x[i + 3] = LoadUnaligned16(&src[(i + 3) * stride + idx]); + } + } + if (load_width == 8) { + for (int i = 0; i < load_count; i += 4) { + x[i] = LoadLo8(&src[i * stride + idx]); + x[i + 1] = LoadLo8(&src[(i + 1) * stride + idx]); + x[i + 2] = LoadLo8(&src[(i + 2) * stride + idx]); + x[i + 3] = LoadLo8(&src[(i + 3) * stride + idx]); + } + } +} + +// Butterfly rotate 4 values. +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(__m128i* a, __m128i* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const __m128i psin_pcos = _mm_set1_epi32( + static_cast<uint16_t>(cos128) | (static_cast<uint32_t>(sin128) << 16)); + const __m128i ba = _mm_unpacklo_epi16(*a, *b); + const __m128i ab = _mm_unpacklo_epi16(*b, *a); + const __m128i sign = + _mm_set_epi32(0x80000001, 0x80000001, 0x80000001, 0x80000001); + // -sin cos, -sin cos, -sin cos, -sin cos + const __m128i msin_pcos = _mm_sign_epi16(psin_pcos, sign); + const __m128i x0 = _mm_madd_epi16(ba, msin_pcos); + const __m128i y0 = _mm_madd_epi16(ab, psin_pcos); + const __m128i x1 = RightShiftWithRounding_S32(x0, 12); + const __m128i y1 = RightShiftWithRounding_S32(y0, 12); + const __m128i x = _mm_packs_epi32(x1, x1); + const __m128i y = _mm_packs_epi32(y1, y1); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +// Butterfly rotate 8 values. +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(__m128i* a, __m128i* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const __m128i psin_pcos = _mm_set1_epi32( + static_cast<uint16_t>(cos128) | (static_cast<uint32_t>(sin128) << 16)); + const __m128i sign = + _mm_set_epi32(0x80000001, 0x80000001, 0x80000001, 0x80000001); + // -sin cos, -sin cos, -sin cos, -sin cos + const __m128i msin_pcos = _mm_sign_epi16(psin_pcos, sign); + const __m128i ba = _mm_unpacklo_epi16(*a, *b); + const __m128i ab = _mm_unpacklo_epi16(*b, *a); + const __m128i ba_hi = _mm_unpackhi_epi16(*a, *b); + const __m128i ab_hi = _mm_unpackhi_epi16(*b, *a); + const __m128i x0 = _mm_madd_epi16(ba, msin_pcos); + const __m128i y0 = _mm_madd_epi16(ab, psin_pcos); + const __m128i x0_hi = _mm_madd_epi16(ba_hi, msin_pcos); + const __m128i y0_hi = _mm_madd_epi16(ab_hi, psin_pcos); + const __m128i x1 = RightShiftWithRounding_S32(x0, 12); + const __m128i y1 = RightShiftWithRounding_S32(y0, 12); + const __m128i x1_hi = RightShiftWithRounding_S32(x0_hi, 12); + const __m128i y1_hi = RightShiftWithRounding_S32(y0_hi, 12); + const __m128i x = _mm_packs_epi32(x1, x1_hi); + const __m128i y = _mm_packs_epi32(y1, y1_hi); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(__m128i* a, __m128i* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const __m128i pcos = _mm_set1_epi16(cos128 << 3); + const __m128i psin = _mm_set1_epi16(-(sin128 << 3)); + const __m128i x = _mm_mulhrs_epi16(*b, psin); + const __m128i y = _mm_mulhrs_epi16(*b, pcos); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(__m128i* a, + __m128i* b, + const int angle, + const bool flip) { + const int16_t cos128 = Cos128(angle); + const int16_t sin128 = Sin128(angle); + const __m128i pcos = _mm_set1_epi16(cos128 << 3); + const __m128i psin = _mm_set1_epi16(sin128 << 3); + const __m128i x = _mm_mulhrs_epi16(*a, pcos); + const __m128i y = _mm_mulhrs_epi16(*a, psin); + if (flip) { + *a = y; + *b = x; + } else { + *a = x; + *b = y; + } +} + +LIBGAV1_ALWAYS_INLINE void HadamardRotation(__m128i* a, __m128i* b, bool flip) { + __m128i x, y; + if (flip) { + y = _mm_adds_epi16(*b, *a); + x = _mm_subs_epi16(*b, *a); + } else { + x = _mm_adds_epi16(*a, *b); + y = _mm_subs_epi16(*a, *b); + } + *a = x; + *b = y; +} + +using ButterflyRotationFunc = void (*)(__m128i* a, __m128i* b, int angle, + bool flip); + +LIBGAV1_ALWAYS_INLINE __m128i ShiftResidual(const __m128i residual, + const __m128i v_row_shift_add, + const __m128i v_row_shift) { + const __m128i k7ffd = _mm_set1_epi16(0x7ffd); + // The max row_shift is 2, so int16_t values greater than 0x7ffd may + // overflow. Generate a mask for this case. + const __m128i mask = _mm_cmpgt_epi16(residual, k7ffd); + const __m128i x = _mm_add_epi16(residual, v_row_shift_add); + // Assume int16_t values. + const __m128i a = _mm_sra_epi16(x, v_row_shift); + // Assume uint16_t values. + const __m128i b = _mm_srl_epi16(x, v_row_shift); + // Select the correct shifted value. + return _mm_blendv_epi8(a, b, mask); +} + +//------------------------------------------------------------------------------ +// Discrete Cosine Transforms (DCT). + +template <int width> +LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const __m128i v_src_lo = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0); + const __m128i v_src = + (width == 4) ? v_src_lo : _mm_shuffle_epi32(v_src_lo, 0); + const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src_round = + _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier); + const __m128i s0 = _mm_blendv_epi8(v_src, v_src_round, v_mask); + const int16_t cos128 = Cos128(32); + const __m128i xy = _mm_mulhrs_epi16(s0, _mm_set1_epi16(cos128 << 3)); + + // Expand to 32 bits to prevent int16_t overflows during the shift add. + const __m128i v_row_shift_add = _mm_set1_epi32(row_shift); + const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add); + const __m128i a = _mm_cvtepi16_epi32(xy); + const __m128i a1 = _mm_cvtepi16_epi32(_mm_srli_si128(xy, 8)); + const __m128i b = _mm_add_epi32(a, v_row_shift_add); + const __m128i b1 = _mm_add_epi32(a1, v_row_shift_add); + const __m128i c = _mm_sra_epi32(b, v_row_shift); + const __m128i c1 = _mm_sra_epi32(b1, v_row_shift); + const __m128i xy_shifted = _mm_packs_epi32(c, c1); + + if (width == 4) { + StoreLo8(dst, xy_shifted); + } else { + for (int i = 0; i < width; i += 8) { + StoreUnaligned16(dst, xy_shifted); + dst += 8; + } + } + return true; +} + +template <int height> +LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const int16_t cos128 = Cos128(32); + + // Calculate dc values for first row. + if (width == 4) { + const __m128i v_src = LoadLo8(dst); + const __m128i xy = _mm_mulhrs_epi16(v_src, _mm_set1_epi16(cos128 << 3)); + StoreLo8(dst, xy); + } else { + int i = 0; + do { + const __m128i v_src = LoadUnaligned16(&dst[i]); + const __m128i xy = _mm_mulhrs_epi16(v_src, _mm_set1_epi16(cos128 << 3)); + StoreUnaligned16(&dst[i], xy); + i += 8; + } while (i < width); + } + + // Copy first row to the rest of the block. + for (int y = 1; y < height; ++y) { + memcpy(&dst[y * width], dst, width * sizeof(dst[0])); + } + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct4Stages(__m128i* s) { + // stage 12. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true); + ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false); + } else { + butterfly_rotation(&s[0], &s[1], 32, true); + butterfly_rotation(&s[2], &s[3], 48, false); + } + + // stage 17. + HadamardRotation(&s[0], &s[3], false); + HadamardRotation(&s[1], &s[2], false); +} + +// Process 4 dct4 rows or columns, depending on the transpose flag. +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct4_SSE4_1(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[4], x[4]; + + if (stage_is_rectangular) { + if (transpose) { + __m128i input[8]; + LoadSrc<8, 8>(dst, step, 0, input); + Transpose4x8To8x4_U16(input, x); + } else { + LoadSrc<16, 4>(dst, step, 0, x); + } + } else { + LoadSrc<8, 4>(dst, step, 0, x); + if (transpose) { + Transpose4x4_U16(x, x); + } + } + // stage 1. + // kBitReverseLookup 0, 2, 1, 3 + s[0] = x[0]; + s[1] = x[2]; + s[2] = x[1]; + s[3] = x[3]; + + Dct4Stages<butterfly_rotation>(s); + + if (stage_is_rectangular) { + if (transpose) { + __m128i output[8]; + Transpose8x4To4x8_U16(s, output); + StoreDst<8, 8>(dst, step, 0, output); + } else { + StoreDst<16, 4>(dst, step, 0, s); + } + } else { + if (transpose) { + Transpose4x4_U16(s, s); + } + StoreDst<8, 4>(dst, step, 0, s); + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct8Stages(__m128i* s) { + // stage 8. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false); + ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false); + } else { + butterfly_rotation(&s[4], &s[7], 56, false); + butterfly_rotation(&s[5], &s[6], 24, false); + } + + // stage 13. + HadamardRotation(&s[4], &s[5], false); + HadamardRotation(&s[6], &s[7], true); + + // stage 18. + butterfly_rotation(&s[6], &s[5], 32, true); + + // stage 22. + HadamardRotation(&s[0], &s[7], false); + HadamardRotation(&s[1], &s[6], false); + HadamardRotation(&s[2], &s[5], false); + HadamardRotation(&s[3], &s[4], false); +} + +// Process dct8 rows or columns, depending on the transpose flag. +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct8_SSE4_1(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[8], x[8]; + + if (stage_is_rectangular) { + if (transpose) { + __m128i input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8_U16(input, x); + } else { + LoadSrc<8, 8>(dst, step, 0, x); + } + } else { + if (transpose) { + __m128i input[8]; + LoadSrc<16, 8>(dst, step, 0, input); + Transpose8x8_U16(input, x); + } else { + LoadSrc<16, 8>(dst, step, 0, x); + } + } + + // stage 1. + // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7, + s[0] = x[0]; + s[1] = x[4]; + s[2] = x[2]; + s[3] = x[6]; + s[4] = x[1]; + s[5] = x[5]; + s[6] = x[3]; + s[7] = x[7]; + + Dct4Stages<butterfly_rotation>(s); + Dct8Stages<butterfly_rotation>(s); + + if (stage_is_rectangular) { + if (transpose) { + __m128i output[4]; + Transpose4x8To8x4_U16(s, output); + StoreDst<16, 4>(dst, step, 0, output); + } else { + StoreDst<8, 8>(dst, step, 0, s); + } + } else { + if (transpose) { + __m128i output[8]; + Transpose8x8_U16(s, output); + StoreDst<16, 8>(dst, step, 0, output); + } else { + StoreDst<16, 8>(dst, step, 0, s); + } + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct16Stages(__m128i* s) { + // stage 5. + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false); + ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false); + ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false); + ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false); + } else { + butterfly_rotation(&s[8], &s[15], 60, false); + butterfly_rotation(&s[9], &s[14], 28, false); + butterfly_rotation(&s[10], &s[13], 44, false); + butterfly_rotation(&s[11], &s[12], 12, false); + } + + // stage 9. + HadamardRotation(&s[8], &s[9], false); + HadamardRotation(&s[10], &s[11], true); + HadamardRotation(&s[12], &s[13], false); + HadamardRotation(&s[14], &s[15], true); + + // stage 14. + butterfly_rotation(&s[14], &s[9], 48, true); + butterfly_rotation(&s[13], &s[10], 112, true); + + // stage 19. + HadamardRotation(&s[8], &s[11], false); + HadamardRotation(&s[9], &s[10], false); + HadamardRotation(&s[12], &s[15], true); + HadamardRotation(&s[13], &s[14], true); + + // stage 23. + butterfly_rotation(&s[13], &s[10], 32, true); + butterfly_rotation(&s[12], &s[11], 32, true); + + // stage 26. + HadamardRotation(&s[0], &s[15], false); + HadamardRotation(&s[1], &s[14], false); + HadamardRotation(&s[2], &s[13], false); + HadamardRotation(&s[3], &s[12], false); + HadamardRotation(&s[4], &s[11], false); + HadamardRotation(&s[5], &s[10], false); + HadamardRotation(&s[6], &s[9], false); + HadamardRotation(&s[7], &s[8], false); +} + +// Process dct16 rows or columns, depending on the transpose flag. +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Dct16_SSE4_1(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[16], x[16]; + + if (stage_is_rectangular) { + if (transpose) { + __m128i input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8_U16(input, x); + LoadSrc<16, 4>(dst, step, 8, input); + Transpose8x4To4x8_U16(input, &x[8]); + } else { + LoadSrc<8, 16>(dst, step, 0, x); + } + } else { + if (transpose) { + for (int idx = 0; idx < 16; idx += 8) { + __m128i input[8]; + LoadSrc<16, 8>(dst, step, idx, input); + Transpose8x8_U16(input, &x[idx]); + } + } else { + LoadSrc<16, 16>(dst, step, 0, x); + } + } + + // stage 1 + // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, + s[0] = x[0]; + s[1] = x[8]; + s[2] = x[4]; + s[3] = x[12]; + s[4] = x[2]; + s[5] = x[10]; + s[6] = x[6]; + s[7] = x[14]; + s[8] = x[1]; + s[9] = x[9]; + s[10] = x[5]; + s[11] = x[13]; + s[12] = x[3]; + s[13] = x[11]; + s[14] = x[7]; + s[15] = x[15]; + + Dct4Stages<butterfly_rotation>(s); + Dct8Stages<butterfly_rotation>(s); + Dct16Stages<butterfly_rotation>(s); + + if (stage_is_rectangular) { + if (transpose) { + __m128i output[4]; + Transpose4x8To8x4_U16(s, output); + StoreDst<16, 4>(dst, step, 0, output); + Transpose4x8To8x4_U16(&s[8], output); + StoreDst<16, 4>(dst, step, 8, output); + } else { + StoreDst<8, 16>(dst, step, 0, s); + } + } else { + if (transpose) { + for (int idx = 0; idx < 16; idx += 8) { + __m128i output[8]; + Transpose8x8_U16(&s[idx], output); + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 16>(dst, step, 0, s); + } + } +} + +template <ButterflyRotationFunc butterfly_rotation, + bool is_fast_butterfly = false> +LIBGAV1_ALWAYS_INLINE void Dct32Stages(__m128i* s) { + // stage 3 + if (is_fast_butterfly) { + ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false); + ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false); + ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false); + ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false); + ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false); + ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false); + ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false); + ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false); + } else { + butterfly_rotation(&s[16], &s[31], 62, false); + butterfly_rotation(&s[17], &s[30], 30, false); + butterfly_rotation(&s[18], &s[29], 46, false); + butterfly_rotation(&s[19], &s[28], 14, false); + butterfly_rotation(&s[20], &s[27], 54, false); + butterfly_rotation(&s[21], &s[26], 22, false); + butterfly_rotation(&s[22], &s[25], 38, false); + butterfly_rotation(&s[23], &s[24], 6, false); + } + // stage 6. + HadamardRotation(&s[16], &s[17], false); + HadamardRotation(&s[18], &s[19], true); + HadamardRotation(&s[20], &s[21], false); + HadamardRotation(&s[22], &s[23], true); + HadamardRotation(&s[24], &s[25], false); + HadamardRotation(&s[26], &s[27], true); + HadamardRotation(&s[28], &s[29], false); + HadamardRotation(&s[30], &s[31], true); + + // stage 10. + butterfly_rotation(&s[30], &s[17], 24 + 32, true); + butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true); + butterfly_rotation(&s[26], &s[21], 24, true); + butterfly_rotation(&s[25], &s[22], 24 + 64, true); + + // stage 15. + HadamardRotation(&s[16], &s[19], false); + HadamardRotation(&s[17], &s[18], false); + HadamardRotation(&s[20], &s[23], true); + HadamardRotation(&s[21], &s[22], true); + HadamardRotation(&s[24], &s[27], false); + HadamardRotation(&s[25], &s[26], false); + HadamardRotation(&s[28], &s[31], true); + HadamardRotation(&s[29], &s[30], true); + + // stage 20. + butterfly_rotation(&s[29], &s[18], 48, true); + butterfly_rotation(&s[28], &s[19], 48, true); + butterfly_rotation(&s[27], &s[20], 48 + 64, true); + butterfly_rotation(&s[26], &s[21], 48 + 64, true); + + // stage 24. + HadamardRotation(&s[16], &s[23], false); + HadamardRotation(&s[17], &s[22], false); + HadamardRotation(&s[18], &s[21], false); + HadamardRotation(&s[19], &s[20], false); + HadamardRotation(&s[24], &s[31], true); + HadamardRotation(&s[25], &s[30], true); + HadamardRotation(&s[26], &s[29], true); + HadamardRotation(&s[27], &s[28], true); + + // stage 27. + butterfly_rotation(&s[27], &s[20], 32, true); + butterfly_rotation(&s[26], &s[21], 32, true); + butterfly_rotation(&s[25], &s[22], 32, true); + butterfly_rotation(&s[24], &s[23], 32, true); + + // stage 29. + HadamardRotation(&s[0], &s[31], false); + HadamardRotation(&s[1], &s[30], false); + HadamardRotation(&s[2], &s[29], false); + HadamardRotation(&s[3], &s[28], false); + HadamardRotation(&s[4], &s[27], false); + HadamardRotation(&s[5], &s[26], false); + HadamardRotation(&s[6], &s[25], false); + HadamardRotation(&s[7], &s[24], false); + HadamardRotation(&s[8], &s[23], false); + HadamardRotation(&s[9], &s[22], false); + HadamardRotation(&s[10], &s[21], false); + HadamardRotation(&s[11], &s[20], false); + HadamardRotation(&s[12], &s[19], false); + HadamardRotation(&s[13], &s[18], false); + HadamardRotation(&s[14], &s[17], false); + HadamardRotation(&s[15], &s[16], false); +} + +// Process dct32 rows or columns, depending on the transpose flag. +LIBGAV1_ALWAYS_INLINE void Dct32_SSE4_1(void* dest, const int32_t step, + const bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[32], x[32]; + + if (transpose) { + for (int idx = 0; idx < 32; idx += 8) { + __m128i input[8]; + LoadSrc<16, 8>(dst, step, idx, input); + Transpose8x8_U16(input, &x[idx]); + } + } else { + LoadSrc<16, 32>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup + // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, + s[0] = x[0]; + s[1] = x[16]; + s[2] = x[8]; + s[3] = x[24]; + s[4] = x[4]; + s[5] = x[20]; + s[6] = x[12]; + s[7] = x[28]; + s[8] = x[2]; + s[9] = x[18]; + s[10] = x[10]; + s[11] = x[26]; + s[12] = x[6]; + s[13] = x[22]; + s[14] = x[14]; + s[15] = x[30]; + + // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31, + s[16] = x[1]; + s[17] = x[17]; + s[18] = x[9]; + s[19] = x[25]; + s[20] = x[5]; + s[21] = x[21]; + s[22] = x[13]; + s[23] = x[29]; + s[24] = x[3]; + s[25] = x[19]; + s[26] = x[11]; + s[27] = x[27]; + s[28] = x[7]; + s[29] = x[23]; + s[30] = x[15]; + s[31] = x[31]; + + Dct4Stages<ButterflyRotation_8>(s); + Dct8Stages<ButterflyRotation_8>(s); + Dct16Stages<ButterflyRotation_8>(s); + Dct32Stages<ButterflyRotation_8>(s); + + if (transpose) { + for (int idx = 0; idx < 32; idx += 8) { + __m128i output[8]; + Transpose8x8_U16(&s[idx], output); + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 32>(dst, step, 0, s); + } +} + +// Allow the compiler to call this function instead of force inlining. Tests +// show the performance is slightly faster. +void Dct64_SSE4_1(void* dest, int32_t step, bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[64], x[32]; + + if (transpose) { + // The last 32 values of every row are always zero if the |tx_width| is + // 64. + for (int idx = 0; idx < 32; idx += 8) { + __m128i input[8]; + LoadSrc<16, 8>(dst, step, idx, input); + Transpose8x8_U16(input, &x[idx]); + } + } else { + // The last 32 values of every column are always zero if the |tx_height| is + // 64. + LoadSrc<16, 32>(dst, step, 0, x); + } + + // stage 1 + // kBitReverseLookup + // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60, + s[0] = x[0]; + s[2] = x[16]; + s[4] = x[8]; + s[6] = x[24]; + s[8] = x[4]; + s[10] = x[20]; + s[12] = x[12]; + s[14] = x[28]; + + // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62, + s[16] = x[2]; + s[18] = x[18]; + s[20] = x[10]; + s[22] = x[26]; + s[24] = x[6]; + s[26] = x[22]; + s[28] = x[14]; + s[30] = x[30]; + + // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61, + s[32] = x[1]; + s[34] = x[17]; + s[36] = x[9]; + s[38] = x[25]; + s[40] = x[5]; + s[42] = x[21]; + s[44] = x[13]; + s[46] = x[29]; + + // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63 + s[48] = x[3]; + s[50] = x[19]; + s[52] = x[11]; + s[54] = x[27]; + s[56] = x[7]; + s[58] = x[23]; + s[60] = x[15]; + s[62] = x[31]; + + Dct4Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct8Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct16Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + Dct32Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s); + + //-- start dct 64 stages + // stage 2. + ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false); + ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false); + ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false); + ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false); + ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false); + ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false); + ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false); + ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false); + ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false); + ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false); + ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false); + ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false); + ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false); + ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false); + ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false); + ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false); + + // stage 4. + HadamardRotation(&s[32], &s[33], false); + HadamardRotation(&s[34], &s[35], true); + HadamardRotation(&s[36], &s[37], false); + HadamardRotation(&s[38], &s[39], true); + HadamardRotation(&s[40], &s[41], false); + HadamardRotation(&s[42], &s[43], true); + HadamardRotation(&s[44], &s[45], false); + HadamardRotation(&s[46], &s[47], true); + HadamardRotation(&s[48], &s[49], false); + HadamardRotation(&s[50], &s[51], true); + HadamardRotation(&s[52], &s[53], false); + HadamardRotation(&s[54], &s[55], true); + HadamardRotation(&s[56], &s[57], false); + HadamardRotation(&s[58], &s[59], true); + HadamardRotation(&s[60], &s[61], false); + HadamardRotation(&s[62], &s[63], true); + + // stage 7. + ButterflyRotation_8(&s[62], &s[33], 60 - 0, true); + ButterflyRotation_8(&s[61], &s[34], 60 - 0 + 64, true); + ButterflyRotation_8(&s[58], &s[37], 60 - 32, true); + ButterflyRotation_8(&s[57], &s[38], 60 - 32 + 64, true); + ButterflyRotation_8(&s[54], &s[41], 60 - 16, true); + ButterflyRotation_8(&s[53], &s[42], 60 - 16 + 64, true); + ButterflyRotation_8(&s[50], &s[45], 60 - 48, true); + ButterflyRotation_8(&s[49], &s[46], 60 - 48 + 64, true); + + // stage 11. + HadamardRotation(&s[32], &s[35], false); + HadamardRotation(&s[33], &s[34], false); + HadamardRotation(&s[36], &s[39], true); + HadamardRotation(&s[37], &s[38], true); + HadamardRotation(&s[40], &s[43], false); + HadamardRotation(&s[41], &s[42], false); + HadamardRotation(&s[44], &s[47], true); + HadamardRotation(&s[45], &s[46], true); + HadamardRotation(&s[48], &s[51], false); + HadamardRotation(&s[49], &s[50], false); + HadamardRotation(&s[52], &s[55], true); + HadamardRotation(&s[53], &s[54], true); + HadamardRotation(&s[56], &s[59], false); + HadamardRotation(&s[57], &s[58], false); + HadamardRotation(&s[60], &s[63], true); + HadamardRotation(&s[61], &s[62], true); + + // stage 16. + ButterflyRotation_8(&s[61], &s[34], 56, true); + ButterflyRotation_8(&s[60], &s[35], 56, true); + ButterflyRotation_8(&s[59], &s[36], 56 + 64, true); + ButterflyRotation_8(&s[58], &s[37], 56 + 64, true); + ButterflyRotation_8(&s[53], &s[42], 56 - 32, true); + ButterflyRotation_8(&s[52], &s[43], 56 - 32, true); + ButterflyRotation_8(&s[51], &s[44], 56 - 32 + 64, true); + ButterflyRotation_8(&s[50], &s[45], 56 - 32 + 64, true); + + // stage 21. + HadamardRotation(&s[32], &s[39], false); + HadamardRotation(&s[33], &s[38], false); + HadamardRotation(&s[34], &s[37], false); + HadamardRotation(&s[35], &s[36], false); + HadamardRotation(&s[40], &s[47], true); + HadamardRotation(&s[41], &s[46], true); + HadamardRotation(&s[42], &s[45], true); + HadamardRotation(&s[43], &s[44], true); + HadamardRotation(&s[48], &s[55], false); + HadamardRotation(&s[49], &s[54], false); + HadamardRotation(&s[50], &s[53], false); + HadamardRotation(&s[51], &s[52], false); + HadamardRotation(&s[56], &s[63], true); + HadamardRotation(&s[57], &s[62], true); + HadamardRotation(&s[58], &s[61], true); + HadamardRotation(&s[59], &s[60], true); + + // stage 25. + ButterflyRotation_8(&s[59], &s[36], 48, true); + ButterflyRotation_8(&s[58], &s[37], 48, true); + ButterflyRotation_8(&s[57], &s[38], 48, true); + ButterflyRotation_8(&s[56], &s[39], 48, true); + ButterflyRotation_8(&s[55], &s[40], 112, true); + ButterflyRotation_8(&s[54], &s[41], 112, true); + ButterflyRotation_8(&s[53], &s[42], 112, true); + ButterflyRotation_8(&s[52], &s[43], 112, true); + + // stage 28. + HadamardRotation(&s[32], &s[47], false); + HadamardRotation(&s[33], &s[46], false); + HadamardRotation(&s[34], &s[45], false); + HadamardRotation(&s[35], &s[44], false); + HadamardRotation(&s[36], &s[43], false); + HadamardRotation(&s[37], &s[42], false); + HadamardRotation(&s[38], &s[41], false); + HadamardRotation(&s[39], &s[40], false); + HadamardRotation(&s[48], &s[63], true); + HadamardRotation(&s[49], &s[62], true); + HadamardRotation(&s[50], &s[61], true); + HadamardRotation(&s[51], &s[60], true); + HadamardRotation(&s[52], &s[59], true); + HadamardRotation(&s[53], &s[58], true); + HadamardRotation(&s[54], &s[57], true); + HadamardRotation(&s[55], &s[56], true); + + // stage 30. + ButterflyRotation_8(&s[55], &s[40], 32, true); + ButterflyRotation_8(&s[54], &s[41], 32, true); + ButterflyRotation_8(&s[53], &s[42], 32, true); + ButterflyRotation_8(&s[52], &s[43], 32, true); + ButterflyRotation_8(&s[51], &s[44], 32, true); + ButterflyRotation_8(&s[50], &s[45], 32, true); + ButterflyRotation_8(&s[49], &s[46], 32, true); + ButterflyRotation_8(&s[48], &s[47], 32, true); + + // stage 31. + for (int i = 0; i < 32; i += 4) { + HadamardRotation(&s[i], &s[63 - i], false); + HadamardRotation(&s[i + 1], &s[63 - i - 1], false); + HadamardRotation(&s[i + 2], &s[63 - i - 2], false); + HadamardRotation(&s[i + 3], &s[63 - i - 3], false); + } + //-- end dct 64 stages + + if (transpose) { + for (int idx = 0; idx < 64; idx += 8) { + __m128i output[8]; + Transpose8x8_U16(&s[idx], output); + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 64>(dst, step, 0, s); + } +} + +//------------------------------------------------------------------------------ +// Asymmetric Discrete Sine Transforms (ADST). + +template <bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst4_SSE4_1(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[8], x[4]; + + if (stage_is_rectangular) { + if (transpose) { + __m128i input[8]; + LoadSrc<8, 8>(dst, step, 0, input); + Transpose4x8To8x4_U16(input, x); + } else { + LoadSrc<16, 4>(dst, step, 0, x); + } + } else { + LoadSrc<8, 4>(dst, step, 0, x); + if (transpose) { + Transpose4x4_U16(x, x); + } + } + + const __m128i kAdst4Multiplier_1 = _mm_set1_epi16(kAdst4Multiplier[1]); + const __m128i kAdst4Multiplier_2 = _mm_set1_epi16(kAdst4Multiplier[2]); + const __m128i kAdst4Multiplier_3 = _mm_set1_epi16(kAdst4Multiplier[3]); + const __m128i kAdst4Multiplier_m0_1 = + _mm_set1_epi32(static_cast<uint16_t>(kAdst4Multiplier[1]) | + (static_cast<uint32_t>(-kAdst4Multiplier[0]) << 16)); + const __m128i kAdst4Multiplier_3_0 = + _mm_set1_epi32(static_cast<uint16_t>(kAdst4Multiplier[0]) | + (static_cast<uint32_t>(kAdst4Multiplier[3]) << 16)); + + // stage 1. + const __m128i x3_x0 = _mm_unpacklo_epi16(x[0], x[3]); + const __m128i x2_x0 = _mm_unpacklo_epi16(x[0], x[2]); + const __m128i zero_x1 = _mm_cvtepu16_epi32(x[1]); + const __m128i zero_x2 = _mm_cvtepu16_epi32(x[2]); + const __m128i zero_x3 = _mm_cvtepu16_epi32(x[3]); + + s[5] = _mm_madd_epi16(zero_x3, kAdst4Multiplier_1); + s[6] = _mm_madd_epi16(zero_x3, kAdst4Multiplier_3); + + // stage 2. + // ((src[0] - src[2]) + src[3]) * kAdst4Multiplier[2] + const __m128i k2_x3_x0 = _mm_madd_epi16(x3_x0, kAdst4Multiplier_2); + const __m128i k2_zero_x2 = _mm_madd_epi16(zero_x2, kAdst4Multiplier_2); + const __m128i b7 = _mm_sub_epi32(k2_x3_x0, k2_zero_x2); + + // stage 3. + s[0] = _mm_madd_epi16(x2_x0, kAdst4Multiplier_3_0); + s[1] = _mm_madd_epi16(x2_x0, kAdst4Multiplier_m0_1); + s[2] = b7; + s[3] = _mm_madd_epi16(zero_x1, kAdst4Multiplier_2); + + // stage 4. + s[0] = _mm_add_epi32(s[0], s[5]); + s[1] = _mm_sub_epi32(s[1], s[6]); + + // stages 5 and 6. + x[0] = _mm_add_epi32(s[0], s[3]); + x[1] = _mm_add_epi32(s[1], s[3]); + x[2] = _mm_add_epi32(s[0], s[1]); + x[3] = _mm_sub_epi32(x[2], s[3]); + + x[0] = RightShiftWithRounding_S32(x[0], 12); + x[1] = RightShiftWithRounding_S32(x[1], 12); + x[2] = RightShiftWithRounding_S32(s[2], 12); + x[3] = RightShiftWithRounding_S32(x[3], 12); + + x[0] = _mm_packs_epi32(x[0], x[1]); + x[2] = _mm_packs_epi32(x[2], x[3]); + x[1] = _mm_srli_si128(x[0], 8); + x[3] = _mm_srli_si128(x[2], 8); + + if (stage_is_rectangular) { + if (transpose) { + __m128i output[8]; + Transpose8x4To4x8_U16(x, output); + StoreDst<8, 8>(dst, step, 0, output); + } else { + StoreDst<16, 4>(dst, step, 0, x); + } + } else { + if (transpose) { + Transpose4x4_U16(x, x); + } + StoreDst<8, 4>(dst, step, 0, x); + } +} + +constexpr int16_t kAdst4DcOnlyMultiplier[8] = {1321, 0, 2482, 0, + 3344, 0, 2482, 1321}; + +LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const __m128i v_src = + _mm_shuffle_epi32(_mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0), 0); + const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src_round = + _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier); + const __m128i s0 = _mm_blendv_epi8(v_src, v_src_round, v_mask); + const __m128i v_kAdst4DcOnlyMultipliers = + LoadUnaligned16(kAdst4DcOnlyMultiplier); + // s0*k0 s0*k1 s0*k2 s0*k1 + // + + // s0*0 s0*0 s0*0 s0*k0 + const __m128i x3 = _mm_madd_epi16(s0, v_kAdst4DcOnlyMultipliers); + const __m128i dst_0 = RightShiftWithRounding_S32(x3, 12); + const __m128i v_row_shift_add = _mm_set1_epi32(row_shift); + const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add); + const __m128i a = _mm_add_epi32(dst_0, v_row_shift_add); + const __m128i b = _mm_sra_epi32(a, v_row_shift); + const __m128i c = _mm_packs_epi32(b, b); + StoreLo8(dst, c); + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int i = 0; + do { + const __m128i v_src = _mm_cvtepi16_epi32(LoadLo8(&dst[i])); + const __m128i kAdst4Multiplier_0 = _mm_set1_epi32(kAdst4Multiplier[0]); + const __m128i kAdst4Multiplier_1 = _mm_set1_epi32(kAdst4Multiplier[1]); + const __m128i kAdst4Multiplier_2 = _mm_set1_epi32(kAdst4Multiplier[2]); + const __m128i s0 = _mm_mullo_epi32(kAdst4Multiplier_0, v_src); + const __m128i s1 = _mm_mullo_epi32(kAdst4Multiplier_1, v_src); + const __m128i s2 = _mm_mullo_epi32(kAdst4Multiplier_2, v_src); + const __m128i x0 = s0; + const __m128i x1 = s1; + const __m128i x2 = s2; + const __m128i x3 = _mm_add_epi32(s0, s1); + const __m128i dst_0 = RightShiftWithRounding_S32(x0, 12); + const __m128i dst_1 = RightShiftWithRounding_S32(x1, 12); + const __m128i dst_2 = RightShiftWithRounding_S32(x2, 12); + const __m128i dst_3 = RightShiftWithRounding_S32(x3, 12); + const __m128i dst_0_1 = _mm_packs_epi32(dst_0, dst_1); + const __m128i dst_2_3 = _mm_packs_epi32(dst_2, dst_3); + StoreLo8(&dst[i], dst_0_1); + StoreHi8(&dst[i + width * 1], dst_0_1); + StoreLo8(&dst[i + width * 2], dst_2_3); + StoreHi8(&dst[i + width * 3], dst_2_3); + i += 4; + } while (i < width); + + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst8_SSE4_1(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[8], x[8]; + + if (stage_is_rectangular) { + if (transpose) { + __m128i input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8_U16(input, x); + } else { + LoadSrc<8, 8>(dst, step, 0, x); + } + } else { + if (transpose) { + __m128i input[8]; + LoadSrc<16, 8>(dst, step, 0, input); + Transpose8x8_U16(input, x); + } else { + LoadSrc<16, 8>(dst, step, 0, x); + } + } + + // stage 1. + s[0] = x[7]; + s[1] = x[0]; + s[2] = x[5]; + s[3] = x[2]; + s[4] = x[3]; + s[5] = x[4]; + s[6] = x[1]; + s[7] = x[6]; + + // stage 2. + butterfly_rotation(&s[0], &s[1], 60 - 0, true); + butterfly_rotation(&s[2], &s[3], 60 - 16, true); + butterfly_rotation(&s[4], &s[5], 60 - 32, true); + butterfly_rotation(&s[6], &s[7], 60 - 48, true); + + // stage 3. + HadamardRotation(&s[0], &s[4], false); + HadamardRotation(&s[1], &s[5], false); + HadamardRotation(&s[2], &s[6], false); + HadamardRotation(&s[3], &s[7], false); + + // stage 4. + butterfly_rotation(&s[4], &s[5], 48 - 0, true); + butterfly_rotation(&s[7], &s[6], 48 - 32, true); + + // stage 5. + HadamardRotation(&s[0], &s[2], false); + HadamardRotation(&s[4], &s[6], false); + HadamardRotation(&s[1], &s[3], false); + HadamardRotation(&s[5], &s[7], false); + + // stage 6. + butterfly_rotation(&s[2], &s[3], 32, true); + butterfly_rotation(&s[6], &s[7], 32, true); + + // stage 7. + const __m128i v_zero = _mm_setzero_si128(); + x[0] = s[0]; + x[1] = _mm_subs_epi16(v_zero, s[4]); + x[2] = s[6]; + x[3] = _mm_subs_epi16(v_zero, s[2]); + x[4] = s[3]; + x[5] = _mm_subs_epi16(v_zero, s[7]); + x[6] = s[5]; + x[7] = _mm_subs_epi16(v_zero, s[1]); + + if (stage_is_rectangular) { + if (transpose) { + __m128i output[4]; + Transpose4x8To8x4_U16(x, output); + StoreDst<16, 4>(dst, step, 0, output); + } else { + StoreDst<8, 8>(dst, step, 0, x); + } + } else { + if (transpose) { + __m128i output[8]; + Transpose8x8_U16(x, output); + StoreDst<16, 8>(dst, step, 0, output); + } else { + StoreDst<16, 8>(dst, step, 0, x); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + __m128i s[8]; + + const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0); + const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src_round = + _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier); + // stage 1. + s[1] = _mm_blendv_epi8(v_src, v_src_round, v_mask); + + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true); + + // stage 3. + s[4] = s[0]; + s[5] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[4], &s[5], 48, true); + + // stage 5. + s[2] = s[0]; + s[3] = s[1]; + s[6] = s[4]; + s[7] = s[5]; + + // stage 6. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + + // stage 7. + __m128i x[8]; + const __m128i v_zero = _mm_setzero_si128(); + x[0] = s[0]; + x[1] = _mm_subs_epi16(v_zero, s[4]); + x[2] = s[6]; + x[3] = _mm_subs_epi16(v_zero, s[2]); + x[4] = s[3]; + x[5] = _mm_subs_epi16(v_zero, s[7]); + x[6] = s[5]; + x[7] = _mm_subs_epi16(v_zero, s[1]); + + const __m128i x1_x0 = _mm_unpacklo_epi16(x[0], x[1]); + const __m128i x3_x2 = _mm_unpacklo_epi16(x[2], x[3]); + const __m128i x5_x4 = _mm_unpacklo_epi16(x[4], x[5]); + const __m128i x7_x6 = _mm_unpacklo_epi16(x[6], x[7]); + const __m128i x3_x0 = _mm_unpacklo_epi32(x1_x0, x3_x2); + const __m128i x7_x4 = _mm_unpacklo_epi32(x5_x4, x7_x6); + + const __m128i v_row_shift_add = _mm_set1_epi32(row_shift); + const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add); + const __m128i a = _mm_add_epi32(_mm_cvtepi16_epi32(x3_x0), v_row_shift_add); + const __m128i a1 = _mm_add_epi32(_mm_cvtepi16_epi32(x7_x4), v_row_shift_add); + const __m128i b = _mm_sra_epi32(a, v_row_shift); + const __m128i b1 = _mm_sra_epi32(a1, v_row_shift); + StoreUnaligned16(dst, _mm_packs_epi32(b, b1)); + + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + __m128i s[8]; + + int i = 0; + do { + const __m128i v_src = LoadLo8(dst); + // stage 1. + s[1] = v_src; + + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true); + + // stage 3. + s[4] = s[0]; + s[5] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[4], &s[5], 48, true); + + // stage 5. + s[2] = s[0]; + s[3] = s[1]; + s[6] = s[4]; + s[7] = s[5]; + + // stage 6. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + + // stage 7. + __m128i x[8]; + const __m128i v_zero = _mm_setzero_si128(); + x[0] = s[0]; + x[1] = _mm_subs_epi16(v_zero, s[4]); + x[2] = s[6]; + x[3] = _mm_subs_epi16(v_zero, s[2]); + x[4] = s[3]; + x[5] = _mm_subs_epi16(v_zero, s[7]); + x[6] = s[5]; + x[7] = _mm_subs_epi16(v_zero, s[1]); + + for (int j = 0; j < 8; ++j) { + StoreLo8(&dst[j * width], x[j]); + } + i += 4; + dst += 4; + } while (i < width); + + return true; +} + +template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular> +LIBGAV1_ALWAYS_INLINE void Adst16_SSE4_1(void* dest, int32_t step, + bool transpose) { + auto* const dst = static_cast<int16_t*>(dest); + __m128i s[16], x[16]; + + if (stage_is_rectangular) { + if (transpose) { + __m128i input[4]; + LoadSrc<16, 4>(dst, step, 0, input); + Transpose8x4To4x8_U16(input, x); + LoadSrc<16, 4>(dst, step, 8, input); + Transpose8x4To4x8_U16(input, &x[8]); + } else { + LoadSrc<8, 16>(dst, step, 0, x); + } + } else { + if (transpose) { + for (int idx = 0; idx < 16; idx += 8) { + __m128i input[8]; + LoadSrc<16, 8>(dst, step, idx, input); + Transpose8x8_U16(input, &x[idx]); + } + } else { + LoadSrc<16, 16>(dst, step, 0, x); + } + } + + // stage 1. + s[0] = x[15]; + s[1] = x[0]; + s[2] = x[13]; + s[3] = x[2]; + s[4] = x[11]; + s[5] = x[4]; + s[6] = x[9]; + s[7] = x[6]; + s[8] = x[7]; + s[9] = x[8]; + s[10] = x[5]; + s[11] = x[10]; + s[12] = x[3]; + s[13] = x[12]; + s[14] = x[1]; + s[15] = x[14]; + + // stage 2. + butterfly_rotation(&s[0], &s[1], 62 - 0, true); + butterfly_rotation(&s[2], &s[3], 62 - 8, true); + butterfly_rotation(&s[4], &s[5], 62 - 16, true); + butterfly_rotation(&s[6], &s[7], 62 - 24, true); + butterfly_rotation(&s[8], &s[9], 62 - 32, true); + butterfly_rotation(&s[10], &s[11], 62 - 40, true); + butterfly_rotation(&s[12], &s[13], 62 - 48, true); + butterfly_rotation(&s[14], &s[15], 62 - 56, true); + + // stage 3. + HadamardRotation(&s[0], &s[8], false); + HadamardRotation(&s[1], &s[9], false); + HadamardRotation(&s[2], &s[10], false); + HadamardRotation(&s[3], &s[11], false); + HadamardRotation(&s[4], &s[12], false); + HadamardRotation(&s[5], &s[13], false); + HadamardRotation(&s[6], &s[14], false); + HadamardRotation(&s[7], &s[15], false); + + // stage 4. + butterfly_rotation(&s[8], &s[9], 56 - 0, true); + butterfly_rotation(&s[13], &s[12], 8 + 0, true); + butterfly_rotation(&s[10], &s[11], 56 - 32, true); + butterfly_rotation(&s[15], &s[14], 8 + 32, true); + + // stage 5. + HadamardRotation(&s[0], &s[4], false); + HadamardRotation(&s[8], &s[12], false); + HadamardRotation(&s[1], &s[5], false); + HadamardRotation(&s[9], &s[13], false); + HadamardRotation(&s[2], &s[6], false); + HadamardRotation(&s[10], &s[14], false); + HadamardRotation(&s[3], &s[7], false); + HadamardRotation(&s[11], &s[15], false); + + // stage 6. + butterfly_rotation(&s[4], &s[5], 48 - 0, true); + butterfly_rotation(&s[12], &s[13], 48 - 0, true); + butterfly_rotation(&s[7], &s[6], 48 - 32, true); + butterfly_rotation(&s[15], &s[14], 48 - 32, true); + + // stage 7. + HadamardRotation(&s[0], &s[2], false); + HadamardRotation(&s[4], &s[6], false); + HadamardRotation(&s[8], &s[10], false); + HadamardRotation(&s[12], &s[14], false); + HadamardRotation(&s[1], &s[3], false); + HadamardRotation(&s[5], &s[7], false); + HadamardRotation(&s[9], &s[11], false); + HadamardRotation(&s[13], &s[15], false); + + // stage 8. + butterfly_rotation(&s[2], &s[3], 32, true); + butterfly_rotation(&s[6], &s[7], 32, true); + butterfly_rotation(&s[10], &s[11], 32, true); + butterfly_rotation(&s[14], &s[15], 32, true); + + // stage 9. + const __m128i v_zero = _mm_setzero_si128(); + x[0] = s[0]; + x[1] = _mm_subs_epi16(v_zero, s[8]); + x[2] = s[12]; + x[3] = _mm_subs_epi16(v_zero, s[4]); + x[4] = s[6]; + x[5] = _mm_subs_epi16(v_zero, s[14]); + x[6] = s[10]; + x[7] = _mm_subs_epi16(v_zero, s[2]); + x[8] = s[3]; + x[9] = _mm_subs_epi16(v_zero, s[11]); + x[10] = s[15]; + x[11] = _mm_subs_epi16(v_zero, s[7]); + x[12] = s[5]; + x[13] = _mm_subs_epi16(v_zero, s[13]); + x[14] = s[9]; + x[15] = _mm_subs_epi16(v_zero, s[1]); + + if (stage_is_rectangular) { + if (transpose) { + __m128i output[4]; + Transpose4x8To8x4_U16(x, output); + StoreDst<16, 4>(dst, step, 0, output); + Transpose4x8To8x4_U16(&x[8], output); + StoreDst<16, 4>(dst, step, 8, output); + } else { + StoreDst<8, 16>(dst, step, 0, x); + } + } else { + if (transpose) { + for (int idx = 0; idx < 16; idx += 8) { + __m128i output[8]; + Transpose8x8_U16(&x[idx], output); + StoreDst<16, 8>(dst, step, idx, output); + } + } else { + StoreDst<16, 16>(dst, step, 0, x); + } + } +} + +LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(__m128i* s, __m128i* x) { + // stage 2. + ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true); + + // stage 3. + s[8] = s[0]; + s[9] = s[1]; + + // stage 4. + ButterflyRotation_4(&s[8], &s[9], 56, true); + + // stage 5. + s[4] = s[0]; + s[12] = s[8]; + s[5] = s[1]; + s[13] = s[9]; + + // stage 6. + ButterflyRotation_4(&s[4], &s[5], 48, true); + ButterflyRotation_4(&s[12], &s[13], 48, true); + + // stage 7. + s[2] = s[0]; + s[6] = s[4]; + s[10] = s[8]; + s[14] = s[12]; + s[3] = s[1]; + s[7] = s[5]; + s[11] = s[9]; + s[15] = s[13]; + + // stage 8. + ButterflyRotation_4(&s[2], &s[3], 32, true); + ButterflyRotation_4(&s[6], &s[7], 32, true); + ButterflyRotation_4(&s[10], &s[11], 32, true); + ButterflyRotation_4(&s[14], &s[15], 32, true); + + // stage 9. + const __m128i v_zero = _mm_setzero_si128(); + x[0] = s[0]; + x[1] = _mm_subs_epi16(v_zero, s[8]); + x[2] = s[12]; + x[3] = _mm_subs_epi16(v_zero, s[4]); + x[4] = s[6]; + x[5] = _mm_subs_epi16(v_zero, s[14]); + x[6] = s[10]; + x[7] = _mm_subs_epi16(v_zero, s[2]); + x[8] = s[3]; + x[9] = _mm_subs_epi16(v_zero, s[11]); + x[10] = s[15]; + x[11] = _mm_subs_epi16(v_zero, s[7]); + x[12] = s[5]; + x[13] = _mm_subs_epi16(v_zero, s[13]); + x[14] = s[9]; + x[15] = _mm_subs_epi16(v_zero, s[1]); +} + +LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + __m128i s[16]; + __m128i x[16]; + + const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0); + const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src_round = + _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier); + // stage 1. + s[1] = _mm_blendv_epi8(v_src, v_src_round, v_mask); + + Adst16DcOnlyInternal(s, x); + + for (int i = 0; i < 2; ++i) { + const __m128i x1_x0 = _mm_unpacklo_epi16(x[0 + i * 8], x[1 + i * 8]); + const __m128i x3_x2 = _mm_unpacklo_epi16(x[2 + i * 8], x[3 + i * 8]); + const __m128i x5_x4 = _mm_unpacklo_epi16(x[4 + i * 8], x[5 + i * 8]); + const __m128i x7_x6 = _mm_unpacklo_epi16(x[6 + i * 8], x[7 + i * 8]); + const __m128i x3_x0 = _mm_unpacklo_epi32(x1_x0, x3_x2); + const __m128i x7_x4 = _mm_unpacklo_epi32(x5_x4, x7_x6); + + const __m128i v_row_shift_add = _mm_set1_epi32(row_shift); + const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add); + const __m128i a = _mm_add_epi32(_mm_cvtepi16_epi32(x3_x0), v_row_shift_add); + const __m128i a1 = + _mm_add_epi32(_mm_cvtepi16_epi32(x7_x4), v_row_shift_add); + const __m128i b = _mm_sra_epi32(a, v_row_shift); + const __m128i b1 = _mm_sra_epi32(a1, v_row_shift); + StoreUnaligned16(&dst[i * 8], _mm_packs_epi32(b, b1)); + } + return true; +} + +LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest, + int adjusted_tx_height, + int width) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + int i = 0; + do { + __m128i s[16]; + __m128i x[16]; + const __m128i v_src = LoadUnaligned16(dst); + // stage 1. + s[1] = v_src; + + Adst16DcOnlyInternal(s, x); + + for (int j = 0; j < 16; ++j) { + StoreLo8(&dst[j * width], x[j]); + } + i += 4; + dst += 4; + } while (i < width); + + return true; +} + +//------------------------------------------------------------------------------ +// Identity Transforms. + +template <bool is_row_shift> +LIBGAV1_ALWAYS_INLINE void Identity4_SSE4_1(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + if (is_row_shift) { + const int shift = 1; + const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11); + const __m128i v_multiplier_one = + _mm_set1_epi32((kIdentity4Multiplier << 16) | 0x0001); + for (int i = 0; i < 4; i += 2) { + const __m128i v_src = LoadUnaligned16(&dst[i * step]); + const __m128i v_src_round = _mm_unpacklo_epi16(v_dual_round, v_src); + const __m128i v_src_round_hi = _mm_unpackhi_epi16(v_dual_round, v_src); + const __m128i a = _mm_madd_epi16(v_src_round, v_multiplier_one); + const __m128i a_hi = _mm_madd_epi16(v_src_round_hi, v_multiplier_one); + const __m128i b = _mm_srai_epi32(a, 12 + shift); + const __m128i b_hi = _mm_srai_epi32(a_hi, 12 + shift); + StoreUnaligned16(&dst[i * step], _mm_packs_epi32(b, b_hi)); + } + } else { + const __m128i v_multiplier = + _mm_set1_epi16(kIdentity4MultiplierFraction << 3); + for (int i = 0; i < 4; i += 2) { + const __m128i v_src = LoadUnaligned16(&dst[i * step]); + const __m128i a = _mm_mulhrs_epi16(v_src, v_multiplier); + const __m128i b = _mm_adds_epi16(a, v_src); + StoreUnaligned16(&dst[i * step], b); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int tx_height) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]); + const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src_round = + _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier); + const __m128i v_src = _mm_blendv_epi8(v_src0, v_src_round, v_mask); + + const int shift = (tx_height < 16) ? 0 : 1; + const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11); + const __m128i v_multiplier_one = + _mm_set1_epi32((kIdentity4Multiplier << 16) | 0x0001); + const __m128i v_src_round_lo = _mm_unpacklo_epi16(v_dual_round, v_src); + const __m128i a = _mm_madd_epi16(v_src_round_lo, v_multiplier_one); + const __m128i b = _mm_srai_epi32(a, 12 + shift); + dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity4ColumnStoreToFrame( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + const __m128i v_multiplier_fraction = + _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 3)); + const __m128i v_eight = _mm_set1_epi16(8); + + if (tx_width == 4) { + int i = 0; + do { + const __m128i v_src = LoadLo8(&source[i * tx_width]); + const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier_fraction); + const __m128i frame_data = Load4(dst); + const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_src); + const __m128i a = _mm_adds_epi16(v_dst_i, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + Store4(dst, _mm_packus_epi16(d, d)); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const __m128i v_src = LoadUnaligned16(&source[row + j]); + const __m128i v_src_mult = + _mm_mulhrs_epi16(v_src, v_multiplier_fraction); + const __m128i frame_data = LoadLo8(dst + j); + const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_src); + const __m128i a = _mm_adds_epi16(v_dst_i, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + StoreLo8(dst + j, _mm_packus_epi16(d, d)); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + + const __m128i v_multiplier_fraction = + _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 3)); + const __m128i v_eight = _mm_set1_epi16(8); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + + if (tx_width == 4) { + int i = 0; + do { + const __m128i v_src = LoadLo8(&source[i * tx_width]); + const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier_fraction); + const __m128i frame_data = Load4(dst); + const __m128i v_dst_row = _mm_adds_epi16(v_src_mult, v_src); + const __m128i v_src_mult2 = + _mm_mulhrs_epi16(v_dst_row, v_multiplier_fraction); + const __m128i frame_data16 = _mm_cvtepu8_epi16(frame_data); + const __m128i v_dst_col = _mm_adds_epi16(v_src_mult2, v_dst_row); + const __m128i a = _mm_adds_epi16(v_dst_col, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_adds_epi16(frame_data16, b); + Store4(dst, _mm_packus_epi16(c, c)); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const __m128i v_src = LoadUnaligned16(&source[row + j]); + const __m128i v_src_round = + _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier); + const __m128i v_dst_row = _mm_adds_epi16(v_src_round, v_src_round); + const __m128i v_src_mult2 = + _mm_mulhrs_epi16(v_dst_row, v_multiplier_fraction); + const __m128i frame_data = LoadLo8(dst + j); + const __m128i frame_data16 = _mm_cvtepu8_epi16(frame_data); + const __m128i v_dst_col = _mm_adds_epi16(v_src_mult2, v_dst_row); + const __m128i a = _mm_adds_epi16(v_dst_col, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_adds_epi16(frame_data16, b); + StoreLo8(dst + j, _mm_packus_epi16(c, c)); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity8Row32_SSE4_1(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + // When combining the identity8 multiplier with the row shift, the + // calculations for tx_height equal to 32 can be simplified from + // ((A * 2) + 2) >> 2) to ((A + 1) >> 1). + const __m128i v_row_multiplier = _mm_set1_epi16(1 << 14); + for (int h = 0; h < 4; ++h) { + const __m128i v_src = LoadUnaligned16(&dst[h * step]); + const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_row_multiplier); + StoreUnaligned16(&dst[h * step], v_src_mult); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity8Row4_SSE4_1(void* dest, int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + for (int h = 0; h < 4; ++h) { + const __m128i v_src = LoadUnaligned16(&dst[h * step]); + // For bitdepth == 8, the identity row clamps to a signed 16bit value, so + // saturating add here is ok. + const __m128i a = _mm_adds_epi16(v_src, v_src); + StoreUnaligned16(&dst[h * step], a); + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int row_shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]); + const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src_round = + _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier); + const __m128i v_src = + _mm_cvtepi16_epi32(_mm_blendv_epi8(v_src0, v_src_round, v_mask)); + const __m128i v_srcx2 = _mm_add_epi32(v_src, v_src); + const __m128i v_row_shift_add = _mm_set1_epi32(row_shift); + const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add); + const __m128i a = _mm_add_epi32(v_srcx2, v_row_shift_add); + const __m128i b = _mm_sra_epi32(a, v_row_shift); + dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity8ColumnStoreToFrame_SSE4_1( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + const __m128i v_eight = _mm_set1_epi16(8); + if (tx_width == 4) { + int i = 0; + do { + const int row = i * tx_width; + const __m128i v_src = LoadLo8(&source[row]); + const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src); + const __m128i frame_data = Load4(dst); + const __m128i a = _mm_adds_epi16(v_dst_i, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + Store4(dst, _mm_packus_epi16(d, d)); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const __m128i v_src = LoadUnaligned16(&source[row + j]); + const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src); + const __m128i frame_data = LoadLo8(dst + j); + const __m128i a = _mm_adds_epi16(v_dst_i, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + StoreLo8(dst + j, _mm_packus_epi16(d, d)); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity16Row_SSE4_1(void* dest, int32_t step, + int shift) { + auto* const dst = static_cast<int16_t*>(dest); + + const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11); + const __m128i v_multiplier_one = + _mm_set1_epi32((kIdentity16Multiplier << 16) | 0x0001); + const __m128i v_shift = _mm_set_epi64x(0, 12 + shift); + + for (int h = 0; h < 4; ++h) { + const __m128i v_src = LoadUnaligned16(&dst[h * step]); + const __m128i v_src2 = LoadUnaligned16(&dst[h * step + 8]); + const __m128i v_src_round0 = _mm_unpacklo_epi16(v_dual_round, v_src); + const __m128i v_src_round1 = _mm_unpackhi_epi16(v_dual_round, v_src); + const __m128i v_src2_round0 = _mm_unpacklo_epi16(v_dual_round, v_src2); + const __m128i v_src2_round1 = _mm_unpackhi_epi16(v_dual_round, v_src2); + const __m128i madd0 = _mm_madd_epi16(v_src_round0, v_multiplier_one); + const __m128i madd1 = _mm_madd_epi16(v_src_round1, v_multiplier_one); + const __m128i madd20 = _mm_madd_epi16(v_src2_round0, v_multiplier_one); + const __m128i madd21 = _mm_madd_epi16(v_src2_round1, v_multiplier_one); + const __m128i shift0 = _mm_sra_epi32(madd0, v_shift); + const __m128i shift1 = _mm_sra_epi32(madd1, v_shift); + const __m128i shift20 = _mm_sra_epi32(madd20, v_shift); + const __m128i shift21 = _mm_sra_epi32(madd21, v_shift); + StoreUnaligned16(&dst[h * step], _mm_packs_epi32(shift0, shift1)); + StoreUnaligned16(&dst[h * step + 8], _mm_packs_epi32(shift20, shift21)); + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height, + bool should_round, int shift) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]); + const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src_round0 = + _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier); + const __m128i v_src = _mm_blendv_epi8(v_src0, v_src_round0, v_mask); + const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11); + const __m128i v_multiplier_one = + _mm_set1_epi32((kIdentity16Multiplier << 16) | 0x0001); + const __m128i v_shift = _mm_set_epi64x(0, 12 + shift); + const __m128i v_src_round = _mm_unpacklo_epi16(v_dual_round, v_src); + const __m128i a = _mm_madd_epi16(v_src_round, v_multiplier_one); + const __m128i b = _mm_sra_epi32(a, v_shift); + dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity16ColumnStoreToFrame_SSE4_1( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + const __m128i v_eight = _mm_set1_epi16(8); + const __m128i v_multiplier = + _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 4)); + + if (tx_width == 4) { + int i = 0; + do { + const __m128i v_src = LoadLo8(&source[i * tx_width]); + const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier); + const __m128i frame_data = Load4(dst); + const __m128i v_srcx2 = _mm_adds_epi16(v_src, v_src); + const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_srcx2); + const __m128i a = _mm_adds_epi16(v_dst_i, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + Store4(dst, _mm_packus_epi16(d, d)); + dst += stride; + } while (++i < tx_height); + } else { + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const __m128i v_src = LoadUnaligned16(&source[row + j]); + const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier); + const __m128i frame_data = LoadLo8(dst + j); + const __m128i v_srcx2 = _mm_adds_epi16(v_src, v_src); + const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_srcx2); + const __m128i a = _mm_adds_epi16(v_dst_i, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + StoreLo8(dst + j, _mm_packus_epi16(d, d)); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); + } +} + +LIBGAV1_ALWAYS_INLINE void Identity32Row16_SSE4_1(void* dest, + const int32_t step) { + auto* const dst = static_cast<int16_t*>(dest); + + // When combining the identity32 multiplier with the row shift, the + // calculation for tx_height equal to 16 can be simplified from + // ((A * 4) + 1) >> 1) to (A * 2). + for (int h = 0; h < 4; ++h) { + for (int i = 0; i < 32; i += 8) { + const __m128i v_src = LoadUnaligned16(&dst[h * step + i]); + // For bitdepth == 8, the identity row clamps to a signed 16bit value, so + // saturating add here is ok. + const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src); + StoreUnaligned16(&dst[h * step + i], v_dst_i); + } + } +} + +LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest, + int adjusted_tx_height) { + if (adjusted_tx_height > 1) return false; + + auto* dst = static_cast<int16_t*>(dest); + const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]); + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + const __m128i v_src = _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier); + + // When combining the identity32 multiplier with the row shift, the + // calculation for tx_height equal to 16 can be simplified from + // ((A * 4) + 1) >> 1) to (A * 2). + const __m128i v_dst_0 = _mm_adds_epi16(v_src, v_src); + dst[0] = _mm_extract_epi16(v_dst_0, 0); + return true; +} + +LIBGAV1_ALWAYS_INLINE void Identity32ColumnStoreToFrame( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source) { + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + const __m128i v_two = _mm_set1_epi16(2); + + int i = 0; + do { + const int row = i * tx_width; + int j = 0; + do { + const __m128i v_dst_i = LoadUnaligned16(&source[row + j]); + const __m128i frame_data = LoadLo8(dst + j); + const __m128i a = _mm_adds_epi16(v_dst_i, v_two); + const __m128i b = _mm_srai_epi16(a, 2); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + StoreLo8(dst + j, _mm_packus_epi16(d, d)); + j += 8; + } while (j < tx_width); + dst += stride; + } while (++i < tx_height); +} + +//------------------------------------------------------------------------------ +// Walsh Hadamard Transform. + +// Process 4 wht4 rows and columns. +LIBGAV1_ALWAYS_INLINE void Wht4_SSE4_1(Array2DView<uint8_t> frame, + const int start_x, const int start_y, + const void* source, + const int adjusted_tx_height) { + const auto* const src = static_cast<const int16_t*>(source); + __m128i s[4], x[4]; + + if (adjusted_tx_height == 1) { + // Special case: only src[0] is nonzero. + // src[0] 0 0 0 + // 0 0 0 0 + // 0 0 0 0 + // 0 0 0 0 + // + // After the row and column transforms are applied, we have: + // f h h h + // g i i i + // g i i i + // g i i i + // where f, g, h, i are computed as follows. + int16_t f = (src[0] >> 2) - (src[0] >> 3); + const int16_t g = f >> 1; + f = f - (f >> 1); + const int16_t h = (src[0] >> 3) - (src[0] >> 4); + const int16_t i = (src[0] >> 4); + s[0] = _mm_set1_epi16(h); + s[0] = _mm_insert_epi16(s[0], f, 0); + s[1] = _mm_set1_epi16(i); + s[1] = _mm_insert_epi16(s[1], g, 0); + s[2] = s[3] = s[1]; + } else { + x[0] = LoadLo8(&src[0 * 4]); + x[2] = LoadLo8(&src[1 * 4]); + x[3] = LoadLo8(&src[2 * 4]); + x[1] = LoadLo8(&src[3 * 4]); + + // Row transforms. + Transpose4x4_U16(x, x); + s[0] = _mm_srai_epi16(x[0], 2); + s[2] = _mm_srai_epi16(x[1], 2); + s[3] = _mm_srai_epi16(x[2], 2); + s[1] = _mm_srai_epi16(x[3], 2); + s[0] = _mm_add_epi16(s[0], s[2]); + s[3] = _mm_sub_epi16(s[3], s[1]); + __m128i e = _mm_sub_epi16(s[0], s[3]); + e = _mm_srai_epi16(e, 1); + s[1] = _mm_sub_epi16(e, s[1]); + s[2] = _mm_sub_epi16(e, s[2]); + s[0] = _mm_sub_epi16(s[0], s[1]); + s[3] = _mm_add_epi16(s[3], s[2]); + Transpose4x4_U16(s, s); + + // Column transforms. + s[0] = _mm_add_epi16(s[0], s[2]); + s[3] = _mm_sub_epi16(s[3], s[1]); + e = _mm_sub_epi16(s[0], s[3]); + e = _mm_srai_epi16(e, 1); + s[1] = _mm_sub_epi16(e, s[1]); + s[2] = _mm_sub_epi16(e, s[2]); + s[0] = _mm_sub_epi16(s[0], s[1]); + s[3] = _mm_add_epi16(s[3], s[2]); + } + + // Store to frame. + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + for (int row = 0; row < 4; ++row) { + const __m128i frame_data = Load4(dst); + const __m128i a = _mm_cvtepu8_epi16(frame_data); + // Saturate to prevent overflowing int16_t + const __m128i b = _mm_adds_epi16(a, s[row]); + Store4(dst, _mm_packus_epi16(b, b)); + dst += stride; + } +} + +//------------------------------------------------------------------------------ +// row/column transform loops + +template <bool enable_flip_rows = false> +LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound( + Array2DView<uint8_t> frame, const int start_x, const int start_y, + const int tx_width, const int tx_height, const int16_t* source, + TransformType tx_type) { + const bool flip_rows = + enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false; + const __m128i v_eight = _mm_set1_epi16(8); + const int stride = frame.columns(); + uint8_t* dst = frame[start_y] + start_x; + if (tx_width == 4) { + for (int i = 0; i < tx_height; ++i) { + const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4; + const __m128i residual = LoadLo8(&source[row]); + const __m128i frame_data = Load4(dst); + // Saturate to prevent overflowing int16_t + const __m128i a = _mm_adds_epi16(residual, v_eight); + const __m128i b = _mm_srai_epi16(a, 4); + const __m128i c = _mm_cvtepu8_epi16(frame_data); + const __m128i d = _mm_adds_epi16(c, b); + Store4(dst, _mm_packus_epi16(d, d)); + dst += stride; + } + } else if (tx_width == 8) { + for (int i = 0; i < tx_height; ++i) { + const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8; + const __m128i residual = LoadUnaligned16(&source[row]); + const __m128i frame_data = LoadLo8(dst); + // Saturate to prevent overflowing int16_t + const __m128i b = _mm_adds_epi16(residual, v_eight); + const __m128i c = _mm_srai_epi16(b, 4); + const __m128i d = _mm_cvtepu8_epi16(frame_data); + const __m128i e = _mm_adds_epi16(d, c); + StoreLo8(dst, _mm_packus_epi16(e, e)); + dst += stride; + } + } else { + for (int i = 0; i < tx_height; ++i) { + const int y = start_y + i; + const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width; + int j = 0; + do { + const int x = start_x + j; + const __m128i residual = LoadUnaligned16(&source[row + j]); + const __m128i residual_hi = LoadUnaligned16(&source[row + j + 8]); + const __m128i frame_data = LoadUnaligned16(frame[y] + x); + const __m128i b = _mm_adds_epi16(residual, v_eight); + const __m128i b_hi = _mm_adds_epi16(residual_hi, v_eight); + const __m128i c = _mm_srai_epi16(b, 4); + const __m128i c_hi = _mm_srai_epi16(b_hi, 4); + const __m128i d = _mm_cvtepu8_epi16(frame_data); + const __m128i d_hi = _mm_cvtepu8_epi16(_mm_srli_si128(frame_data, 8)); + const __m128i e = _mm_adds_epi16(d, c); + const __m128i e_hi = _mm_adds_epi16(d_hi, c_hi); + StoreUnaligned16(frame[y] + x, _mm_packus_epi16(e, e_hi)); + j += 16; + } while (j < tx_width); + } + } +} + +template <int tx_height> +LIBGAV1_ALWAYS_INLINE void FlipColumns(int16_t* source, int tx_width) { + const __m128i word_reverse_8 = + _mm_set_epi32(0x01000302, 0x05040706, 0x09080b0a, 0x0d0c0f0e); + if (tx_width >= 16) { + int i = 0; + do { + // read 16 shorts + const __m128i v3210 = LoadUnaligned16(&source[i]); + const __m128i v7654 = LoadUnaligned16(&source[i + 8]); + const __m128i v0123 = _mm_shuffle_epi8(v3210, word_reverse_8); + const __m128i v4567 = _mm_shuffle_epi8(v7654, word_reverse_8); + StoreUnaligned16(&source[i], v4567); + StoreUnaligned16(&source[i + 8], v0123); + i += 16; + } while (i < tx_width * tx_height); + } else if (tx_width == 8) { + for (int i = 0; i < 8 * tx_height; i += 8) { + const __m128i a = LoadUnaligned16(&source[i]); + const __m128i b = _mm_shuffle_epi8(a, word_reverse_8); + StoreUnaligned16(&source[i], b); + } + } else { + const __m128i dual_word_reverse_4 = + _mm_set_epi32(0x09080b0a, 0x0d0c0f0e, 0x01000302, 0x05040706); + // Process two rows per iteration. + for (int i = 0; i < 4 * tx_height; i += 8) { + const __m128i a = LoadUnaligned16(&source[i]); + const __m128i b = _mm_shuffle_epi8(a, dual_word_reverse_4); + StoreUnaligned16(&source[i], b); + } + } +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) { + const __m128i v_kTransformRowMultiplier = + _mm_set1_epi16(kTransformRowMultiplier << 3); + if (tx_width == 4) { + // Process two rows per iteration. + int i = 0; + do { + const __m128i a = LoadUnaligned16(&source[i]); + const __m128i b = _mm_mulhrs_epi16(a, v_kTransformRowMultiplier); + StoreUnaligned16(&source[i], b); + i += 8; + } while (i < tx_width * num_rows); + } else { + int i = 0; + do { + // The last 32 values of every row are always zero if the |tx_width| is + // 64. + const int non_zero_width = (tx_width < 64) ? tx_width : 32; + int j = 0; + do { + const __m128i a = LoadUnaligned16(&source[i * tx_width + j]); + const __m128i b = _mm_mulhrs_epi16(a, v_kTransformRowMultiplier); + StoreUnaligned16(&source[i * tx_width + j], b); + j += 8; + } while (j < non_zero_width); + } while (++i < num_rows); + } +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows, + int row_shift) { + const __m128i v_row_shift_add = _mm_set1_epi16(row_shift); + const __m128i v_row_shift = _mm_cvtepu16_epi64(v_row_shift_add); + if (tx_width == 4) { + // Process two rows per iteration. + int i = 0; + do { + const __m128i residual = LoadUnaligned16(&source[i]); + const __m128i shifted_residual = + ShiftResidual(residual, v_row_shift_add, v_row_shift); + StoreUnaligned16(&source[i], shifted_residual); + i += 8; + } while (i < tx_width * num_rows); + } else { + int i = 0; + do { + for (int j = 0; j < tx_width; j += 8) { + const __m128i residual = LoadUnaligned16(&source[i * tx_width + j]); + const __m128i shifted_residual = + ShiftResidual(residual, v_row_shift_add, v_row_shift); + StoreUnaligned16(&source[i * tx_width + j], shifted_residual); + } + } while (++i < num_rows); + } +} + +void Dct4TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = (tx_height == 8); + const int row_shift = static_cast<int>(tx_height == 16); + + if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + if (adjusted_tx_height <= 4) { + // Process 4 1d dct4 rows in parallel. + Dct4_SSE4_1<ButterflyRotation_4, false>(src, /*step=*/4, + /*transpose=*/true); + } else { + // Process 8 1d dct4 rows in parallel per iteration. + int i = 0; + do { + Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i * 4], /*step=*/4, + /*transpose=*/true); + i += 8; + } while (i < adjusted_tx_height); + } + if (tx_height == 16) { + RowShift<4>(src, adjusted_tx_height, 1); + } +} + +void Dct4TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct4 columns in parallel. + Dct4_SSE4_1<ButterflyRotation_4, false>(src, tx_width, + /*transpose=*/false); + } else { + // Process 8 1d dct4 columns in parallel per iteration. + int i = 0; + do { + Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i], tx_width, + /*transpose=*/false); + i += 8; + } while (i < tx_width); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound(frame, start_x, start_y, tx_width, 4, src, tx_type); +} + +void Dct8TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + if (adjusted_tx_height <= 4) { + // Process 4 1d dct8 rows in parallel. + Dct8_SSE4_1<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true); + } else { + // Process 8 1d dct8 rows in parallel per iteration. + int i = 0; + do { + Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i * 8], /*step=*/8, + /*transpose=*/true); + i += 8; + } while (i < adjusted_tx_height); + } + if (row_shift > 0) { + RowShift<8>(src, adjusted_tx_height, row_shift); + } +} + +void Dct8TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct8 columns in parallel. + Dct8_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + // Process 8 1d dct8 columns in parallel per iteration. + int i = 0; + do { + Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width, + /*transpose=*/false); + i += 8; + } while (i < tx_width); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound(frame, start_x, start_y, tx_width, 8, src, tx_type); +} + +void Dct16TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + if (adjusted_tx_height <= 4) { + // Process 4 1d dct16 rows in parallel. + Dct16_SSE4_1<ButterflyRotation_4, true>(src, 16, /*transpose=*/true); + } else { + int i = 0; + do { + // Process 8 1d dct16 rows in parallel per iteration. + Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i * 16], 16, + /*transpose=*/true); + i += 8; + } while (i < adjusted_tx_height); + } + // row_shift is always non zero here. + RowShift<16>(src, adjusted_tx_height, row_shift); +} + +void Dct16TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d dct16 columns in parallel. + Dct16_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + int i = 0; + do { + // Process 8 1d dct16 columns in parallel per iteration. + Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width, + /*transpose=*/false); + i += 8; + } while (i < tx_width); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound(frame, start_x, start_y, tx_width, 16, src, tx_type); +} + +void Dct32TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<32>(src, adjusted_tx_height); + } + // Process 8 1d dct32 rows in parallel per iteration. + int i = 0; + do { + Dct32_SSE4_1(&src[i * 32], 32, /*transpose=*/true); + i += 8; + } while (i < adjusted_tx_height); + // row_shift is always non zero here. + RowShift<32>(src, adjusted_tx_height, row_shift); +} + +void Dct32TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) { + // Process 8 1d dct32 columns in parallel per iteration. + int i = 0; + do { + Dct32_SSE4_1(&src[i], tx_width, /*transpose=*/false); + i += 8; + } while (i < tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound(frame, start_x, start_y, tx_width, 32, src, tx_type); +} + +void Dct64TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<64>(src, adjusted_tx_height); + } + // Process 8 1d dct64 rows in parallel per iteration. + int i = 0; + do { + Dct64_SSE4_1(&src[i * 64], 64, /*transpose=*/true); + i += 8; + } while (i < adjusted_tx_height); + // row_shift is always non zero here. + RowShift<64>(src, adjusted_tx_height, row_shift); +} + +void Dct64TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) { + // Process 8 1d dct64 columns in parallel per iteration. + int i = 0; + do { + Dct64_SSE4_1(&src[i], tx_width, /*transpose=*/false); + i += 8; + } while (i < tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound(frame, start_x, start_y, tx_width, 64, src, tx_type); +} + +void Adst4TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const int row_shift = static_cast<int>(tx_height == 16); + const bool should_round = (tx_height == 8); + + if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + // Process 4 1d adst4 rows in parallel per iteration. + int i = 0; + do { + Adst4_SSE4_1<false>(&src[i * 4], /*step=*/4, /*transpose=*/true); + i += 4; + } while (i < adjusted_tx_height); + + if (row_shift != 0) { + RowShift<4>(src, adjusted_tx_height, 1); + } +} + +void Adst4TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + // Process 4 1d adst4 columns in parallel per iteration. + int i = 0; + do { + Adst4_SSE4_1<false>(&src[i], tx_width, /*transpose=*/false); + i += 4; + } while (i < tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, 4, src, tx_type); +} + +void Adst8TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + if (adjusted_tx_height <= 4) { + // Process 4 1d adst8 rows in parallel. + Adst8_SSE4_1<ButterflyRotation_4, true>(src, /*step=*/8, + /*transpose=*/true); + } else { + // Process 8 1d adst8 rows in parallel per iteration. + int i = 0; + do { + Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i * 8], /*step=*/8, + /*transpose=*/true); + i += 8; + } while (i < adjusted_tx_height); + } + if (row_shift > 0) { + RowShift<8>(src, adjusted_tx_height, row_shift); + } +} + +void Adst8TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d adst8 columns in parallel. + Adst8_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + // Process 8 1d adst8 columns in parallel per iteration. + int i = 0; + do { + Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width, + /*transpose=*/false); + i += 8; + } while (i < tx_width); + } + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, 8, src, tx_type); +} + +void Adst16TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + if (adjusted_tx_height <= 4) { + // Process 4 1d adst16 rows in parallel. + Adst16_SSE4_1<ButterflyRotation_4, true>(src, 16, /*transpose=*/true); + } else { + int i = 0; + do { + // Process 8 1d adst16 rows in parallel per iteration. + Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i * 16], 16, + /*transpose=*/true); + i += 8; + } while (i < adjusted_tx_height); + } + // row_shift is always non zero here. + RowShift<16>(src, adjusted_tx_height, row_shift); +} + +void Adst16TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + if (tx_width == 4) { + // Process 4 1d adst16 columns in parallel. + Adst16_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false); + } else { + int i = 0; + do { + // Process 8 1d adst16 columns in parallel per iteration. + Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width, + /*transpose=*/false); + i += 8; + } while (i < tx_width); + } + } + StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, 16, src, tx_type); +} + +void Identity4TransformLoopRow_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. + if (tx_type == kTransformTypeIdentityIdentity && + tx_size == kTransformSize4x4) { + return; + } + + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = (tx_height == 8); + if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + if (tx_height < 16) { + int i = 0; + do { + Identity4_SSE4_1<false>(&src[i * 4], /*step=*/4); + i += 4; + } while (i < adjusted_tx_height); + } else { + int i = 0; + do { + Identity4_SSE4_1<true>(&src[i * 4], /*step=*/4); + i += 4; + } while (i < adjusted_tx_height); + } +} + +void Identity4TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + // Special case: Process row calculations during column transform call. + if (tx_type == kTransformTypeIdentityIdentity && + (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) { + Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); + return; + } + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + Identity4ColumnStoreToFrame(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity8TransformLoopRow_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. + if (tx_type == kTransformTypeIdentityIdentity && + tx_size == kTransformSize8x4) { + return; + } + + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + // When combining the identity8 multiplier with the row shift, the + // calculations for tx_height == 8 and tx_height == 16 can be simplified + // from ((A * 2) + 1) >> 1) to A. + if ((tx_height & 0x18) != 0) { + return; + } + if (tx_height == 32) { + int i = 0; + do { + Identity8Row32_SSE4_1(&src[i * 8], /*step=*/8); + i += 4; + } while (i < adjusted_tx_height); + return; + } + + assert(tx_size == kTransformSize8x4); + int i = 0; + do { + Identity8Row4_SSE4_1(&src[i * 8], /*step=*/8); + i += 4; + } while (i < adjusted_tx_height); +} + +void Identity8TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + Identity8ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity16TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + auto* src = static_cast<int16_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + int i = 0; + do { + Identity16Row_SSE4_1(&src[i * 16], /*step=*/16, + kTransformRowShift[tx_size]); + i += 4; + } while (i < adjusted_tx_height); +} + +void Identity16TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + Identity16ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity32TransformLoopRow_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + const int tx_height = kTransformHeight[tx_size]; + // When combining the identity32 multiplier with the row shift, the + // calculations for tx_height == 8 and tx_height == 32 can be simplified + // from ((A * 4) + 2) >> 2) to A. + if ((tx_height & 0x28) != 0) { + return; + } + + // Process kTransformSize32x16. The src is always rounded before the + // identity transform and shifted by 1 afterwards. + auto* src = static_cast<int16_t*>(src_buffer); + if (Identity32DcOnly(src, adjusted_tx_height)) { + return; + } + + assert(tx_size == kTransformSize32x16); + ApplyRounding<32>(src, adjusted_tx_height); + int i = 0; + do { + Identity32Row16_SSE4_1(&src[i * 32], /*step=*/32); + i += 4; + } while (i < adjusted_tx_height); +} + +void Identity32TransformLoopColumn_SSE4_1(TransformType /*tx_type*/, + TransformSize tx_size, + int adjusted_tx_height, + void* src_buffer, int start_x, + int start_y, void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + auto* src = static_cast<int16_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + Identity32ColumnStoreToFrame(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Wht4TransformLoopRow_SSE4_1(TransformType tx_type, TransformSize tx_size, + int /*adjusted_tx_height*/, + void* /*src_buffer*/, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + assert(tx_type == kTransformTypeDctDct); + assert(tx_size == kTransformSize4x4); + static_cast<void>(tx_type); + static_cast<void>(tx_size); + // Do both row and column transforms in the column-transform pass. +} + +void Wht4TransformLoopColumn_SSE4_1(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + assert(tx_type == kTransformTypeDctDct); + assert(tx_size == kTransformSize4x4); + static_cast<void>(tx_type); + static_cast<void>(tx_size); + + // Do both row and column transforms in the column-transform pass. + // Process 4 1d wht4 rows and columns in parallel. + const auto* src = static_cast<int16_t*>(src_buffer); + auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame); + Wht4_SSE4_1(frame, start_x, start_y, src, adjusted_tx_height); +} + +//------------------------------------------------------------------------------ + +template <typename Residual, typename Pixel> +void InitAll(Dsp* const dsp) { + // Maximum transform size for Dct is 64. + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + Dct4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + Dct4TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + Dct8TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + Dct8TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + Dct16TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + Dct16TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + Dct32TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + Dct32TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + Dct64TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + Dct64TransformLoopColumn_SSE4_1; + + // Maximum transform size for Adst is 16. + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + Adst4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + Adst4TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + Adst8TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + Adst8TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + Adst16TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + Adst16TransformLoopColumn_SSE4_1; + + // Maximum transform size for Identity transform is 32. + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + Identity4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + Identity4TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + Identity8TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + Identity8TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + Identity16TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + Identity16TransformLoopColumn_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = + Identity32TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = + Identity32TransformLoopColumn_SSE4_1; + + // Maximum transform size for Wht is 4. + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = + Wht4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = + Wht4TransformLoopColumn_SSE4_1; +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + InitAll<int16_t, uint8_t>(dsp); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformDct) + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + Dct4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + Dct4TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize8_1DTransformDct) + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + Dct8TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + Dct8TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize16_1DTransformDct) + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + Dct16TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + Dct16TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize32_1DTransformDct) + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + Dct32TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + Dct32TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize64_1DTransformDct) + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + Dct64TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + Dct64TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformAdst) + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + Adst4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + Adst4TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize8_1DTransformAdst) + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + Adst8TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + Adst8TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize16_1DTransformAdst) + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + Adst16TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + Adst16TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformIdentity) + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + Identity4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + Identity4TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize8_1DTransformIdentity) + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + Identity8TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + Identity8TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize16_1DTransformIdentity) + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + Identity16TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + Identity16TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize32_1DTransformIdentity) + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] = + Identity32TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] = + Identity32TransformLoopColumn_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformWht) + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] = + Wht4TransformLoopRow_SSE4_1; + dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] = + Wht4TransformLoopColumn_SSE4_1; +#endif +#endif +} + +} // namespace +} // namespace low_bitdepth + +void InverseTransformInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void InverseTransformInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/inverse_transform_sse4.h b/src/dsp/x86/inverse_transform_sse4.h new file mode 100644 index 0000000..106084b --- /dev/null +++ b/src/dsp/x86/inverse_transform_sse4.h @@ -0,0 +1,89 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::inverse_transforms, see the defines below for specifics. +// This function is not thread-safe. +void InverseTransformInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct +#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct +#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity +#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity +#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity +#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht +#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_SSE4_1 +#endif +#endif // LIBGAV1_TARGETING_SSE4_1 +#endif // LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_ diff --git a/src/dsp/x86/loop_filter_sse4.cc b/src/dsp/x86/loop_filter_sse4.cc new file mode 100644 index 0000000..d67b450 --- /dev/null +++ b/src/dsp/x86/loop_filter_sse4.cc @@ -0,0 +1,2256 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_filter.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline __m128i FilterAdd2Sub2(const __m128i& total, const __m128i& a1, + const __m128i& a2, const __m128i& s1, + const __m128i& s2) { + __m128i x = _mm_add_epi16(a1, total); + x = _mm_add_epi16(_mm_sub_epi16(x, _mm_add_epi16(s1, s2)), a2); + return x; +} + +} // namespace + +namespace low_bitdepth { +namespace { + +inline __m128i AbsDiff(const __m128i& a, const __m128i& b) { + return _mm_or_si128(_mm_subs_epu8(a, b), _mm_subs_epu8(b, a)); +} + +inline __m128i CheckOuterThreshF4(const __m128i& q1q0, const __m128i& p1p0, + const __m128i& outer_thresh) { + const __m128i fe = _mm_set1_epi8(static_cast<int8_t>(0xfe)); + // abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh; + const __m128i abs_pmq = AbsDiff(p1p0, q1q0); + const __m128i a = _mm_adds_epu8(abs_pmq, abs_pmq); + const __m128i b = _mm_srli_epi16(_mm_and_si128(abs_pmq, fe), 1); + const __m128i c = _mm_adds_epu8(a, _mm_srli_si128(b, 4)); + return _mm_subs_epu8(c, outer_thresh); +} + +inline __m128i Hev(const __m128i& qp1, const __m128i& qp0, + const __m128i& hev_thresh) { + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq = + _mm_max_epu8(abs_qp1mqp0, _mm_srli_si128(abs_qp1mqp0, 4)); + const __m128i hev_mask0 = _mm_cvtepu8_epi16(max_pq); + const __m128i hev_mask1 = _mm_cmpgt_epi16(hev_mask0, hev_thresh); + const __m128i hev_mask = _mm_packs_epi16(hev_mask1, hev_mask1); + return hev_mask; +} + +inline __m128i AddShift3(const __m128i& a, const __m128i& b) { + const __m128i c = _mm_adds_epi8(a, b); + const __m128i d = _mm_unpacklo_epi8(c, c); + const __m128i e = _mm_srai_epi16(d, 11); /* >> 3 */ + return _mm_packs_epi16(e, e); +} + +inline __m128i AddShift1(const __m128i& a, const __m128i& b) { + const __m128i c = _mm_adds_epi8(a, b); + const __m128i d = _mm_unpacklo_epi8(c, c); + const __m128i e = _mm_srai_epi16(d, 9); /* >> 1 */ + return _mm_packs_epi16(e, e); +} + +//------------------------------------------------------------------------------ +// 4-tap filters + +inline __m128i NeedsFilter4(const __m128i& q1q0, const __m128i& p1p0, + const __m128i& qp1, const __m128i& qp0, + const __m128i& outer_thresh, + const __m128i& inner_thresh) { + const __m128i outer_mask = CheckOuterThreshF4(q1q0, p1p0, outer_thresh); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i inner_mask = _mm_subs_epu8( + _mm_max_epu8(abs_qp1mqp0, _mm_srli_si128(abs_qp1mqp0, 4)), inner_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_or_si128(outer_mask, inner_mask); + const __m128i b = _mm_cmpeq_epi8(a, zero); + return b; +} + +inline void Filter4(const __m128i& qp1, const __m128i& qp0, __m128i* oqp1, + __m128i* oqp0, const __m128i& mask, const __m128i& hev) { + const __m128i t80 = _mm_set1_epi8(static_cast<int8_t>(0x80)); + const __m128i t1 = _mm_set1_epi8(0x1); + const __m128i qp1qp0 = _mm_unpacklo_epi64(qp0, qp1); + const __m128i qps1qps0 = _mm_xor_si128(qp1qp0, t80); + const __m128i ps1qs0 = _mm_shuffle_epi32(qps1qps0, 0x09); + const __m128i qs1ps0 = _mm_shuffle_epi32(qps1qps0, 0x0c); + const __m128i _hev = _mm_unpacklo_epi32(hev, hev); + const __m128i x = _mm_subs_epi8(ps1qs0, qs1ps0); + __m128i a = _mm_and_si128(_mm_srli_si128(x, 4), _hev); + + a = _mm_adds_epi8(a, x); + a = _mm_adds_epi8(a, x); + a = _mm_adds_epi8(a, x); + a = _mm_and_si128(a, mask); + a = _mm_unpacklo_epi32(a, a); + + const __m128i t4t3 = _mm_set_epi32(0x0, 0x0, 0x04040404, 0x03030303); + const __m128i a1a2 = AddShift3(a, t4t3); + const __m128i a1a1 = _mm_shuffle_epi32(a1a2, 0x55); + const __m128i a3a3 = _mm_andnot_si128(_hev, AddShift1(a1a1, t1)); + // -1 -1 -1 -1 1 1 1 1 -1 -1 -1 -1 1 1 1 1 + const __m128i adjust_sign_for_add = + _mm_unpacklo_epi32(t1, _mm_cmpeq_epi8(t1, t1)); + + const __m128i a3a3a1a2 = _mm_unpacklo_epi64(a1a2, a3a3); + const __m128i ma3a3ma1a2 = _mm_sign_epi8(a3a3a1a2, adjust_sign_for_add); + + const __m128i b = _mm_adds_epi8(qps1qps0, ma3a3ma1a2); + const __m128i c = _mm_xor_si128(b, t80); + + *oqp0 = c; + *oqp1 = _mm_srli_si128(c, 8); +} + +void Horizontal4(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh), 0); + + const __m128i p1 = Load4(dst - 2 * stride); + const __m128i p0 = Load4(dst - 1 * stride); + const __m128i q0 = Load4(dst + 0 * stride); + const __m128i q1 = Load4(dst + 1 * stride); + const __m128i qp1 = _mm_unpacklo_epi32(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi32(p0, q0); + const __m128i q1q0 = _mm_unpacklo_epi32(q0, q1); + const __m128i p1p0 = _mm_unpacklo_epi32(p0, p1); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter4(q1q0, p1p0, qp1, qp0, v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + Store4(dst - 2 * stride, oqp1); + Store4(dst - 1 * stride, oqp0); + Store4(dst + 0 * stride, _mm_srli_si128(oqp0, 4)); + Store4(dst + 1 * stride, _mm_srli_si128(oqp1, 4)); +} + +inline void Transpose4x4(const __m128i& x0, const __m128i& x1, + const __m128i& x2, const __m128i& x3, __m128i* d0, + __m128i* d1, __m128i* d2, __m128i* d3) { + // input + // x0 00 01 02 03 xx xx xx xx xx xx xx xx xx xx xx xx + // x1 10 11 12 13 xx xx xx xx xx xx xx xx xx xx xx xx + // x2 20 21 22 23 xx xx xx xx xx xx xx xx xx xx xx xx + // x3 30 31 32 33 xx xx xx xx xx xx xx xx xx xx xx xx + // output + // d0 00 10 20 30 xx xx xx xx xx xx xx xx xx xx xx xx + // d1 01 11 21 31 xx xx xx xx xx xx xx xx xx xx xx xx + // d2 02 12 22 32 xx xx xx xx xx xx xx xx xx xx xx xx + // d3 03 13 23 33 xx xx xx xx xx xx xx xx xx xx xx xx + + // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17 + const __m128i w0 = _mm_unpacklo_epi8(x0, x1); + // 20 30 21 31 22 32 23 33 24 34 25 35 26 36 27 37 + const __m128i w1 = _mm_unpacklo_epi8(x2, x3); + + // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + *d0 = _mm_unpacklo_epi16(w0, w1); + // 01 11 21 31 xx xx xx xx xx xx xx xx xx xx xx xx + *d1 = _mm_srli_si128(*d0, 4); + // 02 12 22 32 xx xx xx xx xx xx xx xx xx xx xx xx + *d2 = _mm_srli_si128(*d0, 8); + // 03 13 23 33 xx xx xx xx xx xx xx xx xx xx xx xx + *d3 = _mm_srli_si128(*d0, 12); +} + +void Vertical4(void* dest, ptrdiff_t stride, int outer_thresh, int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh0 = + _mm_shuffle_epi8(_mm_cvtsi32_si128(hev_thresh), zero); + const __m128i v_hev_thresh = _mm_unpacklo_epi8(v_hev_thresh0, zero); + + __m128i x0 = Load4(dst - 2 + 0 * stride); + __m128i x1 = Load4(dst - 2 + 1 * stride); + __m128i x2 = Load4(dst - 2 + 2 * stride); + __m128i x3 = Load4(dst - 2 + 3 * stride); + + // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17 + const __m128i w0 = _mm_unpacklo_epi8(x0, x1); + // 20 30 21 31 22 32 23 33 24 34 25 35 26 36 27 37 + const __m128i w1 = _mm_unpacklo_epi8(x2, x3); + // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + const __m128i d0 = _mm_unpacklo_epi16(w0, w1); + const __m128i qp1 = _mm_shuffle_epi32(d0, 0xc); + const __m128i qp0 = _mm_srli_si128(d0, 4); + const __m128i q1q0 = _mm_srli_si128(d0, 8); + const __m128i p1p0 = _mm_shuffle_epi32(d0, 0x1); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter4(q1q0, p1p0, qp1, qp0, v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + const __m128i p1 = oqp1; + const __m128i p0 = oqp0; + const __m128i q0 = _mm_srli_si128(oqp0, 4); + const __m128i q1 = _mm_srli_si128(oqp1, 4); + + Transpose4x4(p1, p0, q0, q1, &x0, &x1, &x2, &x3); + + Store4(dst - 2 + 0 * stride, x0); + Store4(dst - 2 + 1 * stride, x1); + Store4(dst - 2 + 2 * stride, x2); + Store4(dst - 2 + 3 * stride, x3); +} + +//------------------------------------------------------------------------------ +// 5-tap (chroma) filters + +inline __m128i NeedsFilter6(const __m128i& q1q0, const __m128i& p1p0, + const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, const __m128i& outer_thresh, + const __m128i& inner_thresh) { + const __m128i outer_mask = CheckOuterThreshF4(q1q0, p1p0, outer_thresh); + const __m128i abs_qp2mqp1 = AbsDiff(qp2, qp1); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq = _mm_max_epu8(abs_qp2mqp1, abs_qp1mqp0); + const __m128i inner_mask = _mm_subs_epu8( + _mm_max_epu8(max_pq, _mm_srli_si128(max_pq, 4)), inner_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_or_si128(outer_mask, inner_mask); + const __m128i b = _mm_cmpeq_epi8(a, zero); + return b; +} + +inline __m128i IsFlat3(const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, const __m128i& flat_thresh) { + const __m128i abs_pq2mpq0 = AbsDiff(qp2, qp0); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq = _mm_max_epu8(abs_pq2mpq0, abs_qp1mqp0); + const __m128i flat_mask = _mm_subs_epu8( + _mm_max_epu8(max_pq, _mm_srli_si128(max_pq, 4)), flat_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_cmpeq_epi8(flat_mask, zero); + return a; +} + +inline void Filter6(const __m128i& qp2, const __m128i& qp1, const __m128i& qp0, + __m128i* oqp1, __m128i* oqp0) { + const __m128i four = _mm_set1_epi16(4); + const __m128i qp2_lo = _mm_cvtepu8_epi16(qp2); + const __m128i qp1_lo = _mm_cvtepu8_epi16(qp1); + const __m128i qp0_lo = _mm_cvtepu8_epi16(qp0); + const __m128i pq1_lo = _mm_shuffle_epi32(qp1_lo, 0x4e); + const __m128i pq0_lo = _mm_shuffle_epi32(qp0_lo, 0x4e); + + __m128i f6_lo = + _mm_add_epi16(_mm_add_epi16(qp2_lo, four), _mm_add_epi16(qp2_lo, qp2_lo)); + + f6_lo = _mm_add_epi16(_mm_add_epi16(f6_lo, qp1_lo), qp1_lo); + + f6_lo = _mm_add_epi16(_mm_add_epi16(f6_lo, qp0_lo), + _mm_add_epi16(qp0_lo, pq0_lo)); + + // p2 * 3 + p1 * 2 + p0 * 2 + q0 + // q2 * 3 + q1 * 2 + q0 * 2 + p0 + *oqp1 = _mm_srli_epi16(f6_lo, 3); + *oqp1 = _mm_packus_epi16(*oqp1, *oqp1); + + // p2 + p1 * 2 + p0 * 2 + q0 * 2 + q1 + // q2 + q1 * 2 + q0 * 2 + p0 * 2 + p1 + f6_lo = FilterAdd2Sub2(f6_lo, pq0_lo, pq1_lo, qp2_lo, qp2_lo); + *oqp0 = _mm_srli_epi16(f6_lo, 3); + *oqp0 = _mm_packus_epi16(*oqp0, *oqp0); +} + +void Horizontal6(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_flat_thresh = _mm_set1_epi8(1); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh0 = + _mm_shuffle_epi8(_mm_cvtsi32_si128(hev_thresh), zero); + const __m128i v_hev_thresh = _mm_unpacklo_epi8(v_hev_thresh0, zero); + + const __m128i p2 = Load4(dst - 3 * stride); + const __m128i p1 = Load4(dst - 2 * stride); + const __m128i p0 = Load4(dst - 1 * stride); + const __m128i q0 = Load4(dst + 0 * stride); + const __m128i q1 = Load4(dst + 1 * stride); + const __m128i q2 = Load4(dst + 2 * stride); + const __m128i qp2 = _mm_unpacklo_epi32(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi32(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi32(p0, q0); + const __m128i q1q0 = _mm_unpacklo_epi32(q0, q1); + const __m128i p1p0 = _mm_unpacklo_epi32(p0, p1); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter6(q1q0, p1p0, qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + const __m128i v_isflat3_mask = IsFlat3(qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask = + _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat3_mask), 0); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + __m128i oqp1_f6; + __m128i oqp0_f6; + + Filter6(qp2, qp1, qp0, &oqp1_f6, &oqp0_f6); + + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f6, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f6, v_mask); + } + + Store4(dst - 2 * stride, oqp1); + Store4(dst - 1 * stride, oqp0); + Store4(dst + 0 * stride, _mm_srli_si128(oqp0, 4)); + Store4(dst + 1 * stride, _mm_srli_si128(oqp1, 4)); +} + +inline void Transpose8x4To4x8(const __m128i& x0, const __m128i& x1, + const __m128i& x2, const __m128i& x3, __m128i* d0, + __m128i* d1, __m128i* d2, __m128i* d3, + __m128i* d4, __m128i* d5, __m128i* d6, + __m128i* d7) { + // input + // x0 00 01 02 03 04 05 06 07 xx xx xx xx xx xx xx xx + // x1 10 11 12 13 14 15 16 17 xx xx xx xx xx xx xx xx + // x2 20 21 22 23 24 25 26 27 xx xx xx xx xx xx xx xx + // x3 30 31 32 33 34 35 36 37 xx xx xx xx xx xx xx xx + // output + // 00 10 20 30 xx xx xx xx xx xx xx xx xx xx xx xx + // 01 11 21 31 xx xx xx xx xx xx xx xx xx xx xx xx + // 02 12 22 32 xx xx xx xx xx xx xx xx xx xx xx xx + // 03 13 23 33 xx xx xx xx xx xx xx xx xx xx xx xx + // 04 14 24 34 xx xx xx xx xx xx xx xx xx xx xx xx + // 05 15 25 35 xx xx xx xx xx xx xx xx xx xx xx xx + // 06 16 26 36 xx xx xx xx xx xx xx xx xx xx xx xx + // 07 17 27 37 xx xx xx xx xx xx xx xx xx xx xx xx + + // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17 + const __m128i w0 = _mm_unpacklo_epi8(x0, x1); + // 20 30 21 31 22 32 23 33 24 34 25 35 26 36 27 37 + const __m128i w1 = _mm_unpacklo_epi8(x2, x3); + // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + const __m128i ww0 = _mm_unpacklo_epi16(w0, w1); + // 04 14 24 34 05 15 25 35 06 16 26 36 07 17 27 37 + const __m128i ww1 = _mm_unpackhi_epi16(w0, w1); + + // 00 10 20 30 xx xx xx xx xx xx xx xx xx xx xx xx + *d0 = ww0; + // 01 11 21 31 xx xx xx xx xx xx xx xx xx xx xx xx + *d1 = _mm_srli_si128(ww0, 4); + // 02 12 22 32 xx xx xx xx xx xx xx xx xx xx xx xx + *d2 = _mm_srli_si128(ww0, 8); + // 03 13 23 33 xx xx xx xx xx xx xx xx xx xx xx xx + *d3 = _mm_srli_si128(ww0, 12); + // 04 14 24 34 xx xx xx xx xx xx xx xx xx xx xx xx + *d4 = ww1; + // 05 15 25 35 xx xx xx xx xx xx xx xx xx xx xx xx + *d5 = _mm_srli_si128(ww1, 4); + // 06 16 26 36 xx xx xx xx xx xx xx xx xx xx xx xx + *d6 = _mm_srli_si128(ww1, 8); + // 07 17 27 37 xx xx xx xx xx xx xx xx xx xx xx xx + *d7 = _mm_srli_si128(ww1, 12); +} + +void Vertical6(void* dest, ptrdiff_t stride, int outer_thresh, int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_flat_thresh = _mm_set1_epi8(1); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh0 = + _mm_shuffle_epi8(_mm_cvtsi32_si128(hev_thresh), zero); + const __m128i v_hev_thresh = _mm_unpacklo_epi8(v_hev_thresh0, zero); + + __m128i x0 = LoadLo8(dst - 3 + 0 * stride); + __m128i x1 = LoadLo8(dst - 3 + 1 * stride); + __m128i x2 = LoadLo8(dst - 3 + 2 * stride); + __m128i x3 = LoadLo8(dst - 3 + 3 * stride); + + __m128i p2, p1, p0, q0, q1, q2; + __m128i z0, z1; // not used + + Transpose8x4To4x8(x0, x1, x2, x3, &p2, &p1, &p0, &q0, &q1, &q2, &z0, &z1); + + const __m128i qp2 = _mm_unpacklo_epi32(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi32(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi32(p0, q0); + const __m128i q1q0 = _mm_unpacklo_epi32(q0, q1); + const __m128i p1p0 = _mm_unpacklo_epi32(p0, p1); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter6(q1q0, p1p0, qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + const __m128i v_isflat3_mask = IsFlat3(qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask = + _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat3_mask), 0); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + __m128i oqp1_f6; + __m128i oqp0_f6; + + Filter6(qp2, qp1, qp0, &oqp1_f6, &oqp0_f6); + + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f6, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f6, v_mask); + } + + p1 = oqp1; + p0 = oqp0; + q0 = _mm_srli_si128(oqp0, 4); + q1 = _mm_srli_si128(oqp1, 4); + + Transpose4x4(p1, p0, q0, q1, &x0, &x1, &x2, &x3); + + Store4(dst - 2 + 0 * stride, x0); + Store4(dst - 2 + 1 * stride, x1); + Store4(dst - 2 + 2 * stride, x2); + Store4(dst - 2 + 3 * stride, x3); +} + +//------------------------------------------------------------------------------ +// 7-tap filters + +inline __m128i NeedsFilter8(const __m128i& q1q0, const __m128i& p1p0, + const __m128i& qp3, const __m128i& qp2, + const __m128i& qp1, const __m128i& qp0, + const __m128i& outer_thresh, + const __m128i& inner_thresh) { + const __m128i outer_mask = CheckOuterThreshF4(q1q0, p1p0, outer_thresh); + const __m128i abs_qp2mqp1 = AbsDiff(qp2, qp1); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq_a = _mm_max_epu8(abs_qp2mqp1, abs_qp1mqp0); + const __m128i abs_pq3mpq2 = AbsDiff(qp3, qp2); + const __m128i max_pq = _mm_max_epu8(max_pq_a, abs_pq3mpq2); + const __m128i inner_mask = _mm_subs_epu8( + _mm_max_epu8(max_pq, _mm_srli_si128(max_pq, 4)), inner_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_or_si128(outer_mask, inner_mask); + const __m128i b = _mm_cmpeq_epi8(a, zero); + return b; +} + +inline __m128i IsFlat4(const __m128i& qp3, const __m128i& qp2, + const __m128i& qp1, const __m128i& qp0, + const __m128i& flat_thresh) { + const __m128i abs_pq2mpq0 = AbsDiff(qp2, qp0); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq_a = _mm_max_epu8(abs_pq2mpq0, abs_qp1mqp0); + const __m128i abs_pq3mpq0 = AbsDiff(qp3, qp0); + const __m128i max_pq = _mm_max_epu8(max_pq_a, abs_pq3mpq0); + const __m128i flat_mask = _mm_subs_epu8( + _mm_max_epu8(max_pq, _mm_srli_si128(max_pq, 4)), flat_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_cmpeq_epi8(flat_mask, zero); + return a; +} + +inline void Filter8(const __m128i& qp3, const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, __m128i* oqp2, __m128i* oqp1, + __m128i* oqp0) { + const __m128i four = _mm_set1_epi16(4); + const __m128i qp3_lo = _mm_cvtepu8_epi16(qp3); + const __m128i qp2_lo = _mm_cvtepu8_epi16(qp2); + const __m128i qp1_lo = _mm_cvtepu8_epi16(qp1); + const __m128i qp0_lo = _mm_cvtepu8_epi16(qp0); + const __m128i pq2_lo = _mm_shuffle_epi32(qp2_lo, 0x4e); + const __m128i pq1_lo = _mm_shuffle_epi32(qp1_lo, 0x4e); + const __m128i pq0_lo = _mm_shuffle_epi32(qp0_lo, 0x4e); + + __m128i f8_lo = + _mm_add_epi16(_mm_add_epi16(qp3_lo, four), _mm_add_epi16(qp3_lo, qp3_lo)); + + f8_lo = _mm_add_epi16(_mm_add_epi16(f8_lo, qp2_lo), qp2_lo); + + f8_lo = _mm_add_epi16(_mm_add_epi16(f8_lo, qp1_lo), + _mm_add_epi16(qp0_lo, pq0_lo)); + + // p3 + p3 + p3 + 2 * p2 + p1 + p0 + q0 + // q3 + q3 + q3 + 2 * q2 + q1 + q0 + p0 + *oqp2 = _mm_srli_epi16(f8_lo, 3); + *oqp2 = _mm_packus_epi16(*oqp2, *oqp2); + + // p3 + p3 + p2 + 2 * p1 + p0 + q0 + q1 + // q3 + q3 + q2 + 2 * q1 + q0 + p0 + p1 + f8_lo = FilterAdd2Sub2(f8_lo, qp1_lo, pq1_lo, qp3_lo, qp2_lo); + *oqp1 = _mm_srli_epi16(f8_lo, 3); + *oqp1 = _mm_packus_epi16(*oqp1, *oqp1); + + // p3 + p2 + p1 + 2 * p0 + q0 + q1 + q2 + // q3 + q2 + q1 + 2 * q0 + p0 + p1 + p2 + f8_lo = FilterAdd2Sub2(f8_lo, qp0_lo, pq2_lo, qp3_lo, qp1_lo); + *oqp0 = _mm_srli_epi16(f8_lo, 3); + *oqp0 = _mm_packus_epi16(*oqp0, *oqp0); +} + +void Horizontal8(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_flat_thresh = _mm_set1_epi8(1); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh0 = + _mm_shuffle_epi8(_mm_cvtsi32_si128(hev_thresh), zero); + const __m128i v_hev_thresh = _mm_unpacklo_epi8(v_hev_thresh0, zero); + + const __m128i p3 = Load4(dst - 4 * stride); + const __m128i p2 = Load4(dst - 3 * stride); + const __m128i p1 = Load4(dst - 2 * stride); + const __m128i p0 = Load4(dst - 1 * stride); + const __m128i q0 = Load4(dst + 0 * stride); + const __m128i q1 = Load4(dst + 1 * stride); + const __m128i q2 = Load4(dst + 2 * stride); + const __m128i q3 = Load4(dst + 3 * stride); + + const __m128i qp3 = _mm_unpacklo_epi32(p3, q3); + const __m128i qp2 = _mm_unpacklo_epi32(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi32(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi32(p0, q0); + const __m128i q1q0 = _mm_unpacklo_epi32(q0, q1); + const __m128i p1p0 = _mm_unpacklo_epi32(p0, p1); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = NeedsFilter8(q1q0, p1p0, qp3, qp2, qp1, qp0, + v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask = + _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + Store4(dst - 3 * stride, oqp2_f8); + Store4(dst + 2 * stride, _mm_srli_si128(oqp2_f8, 4)); + } + + Store4(dst - 2 * stride, oqp1); + Store4(dst - 1 * stride, oqp0); + Store4(dst + 0 * stride, _mm_srli_si128(oqp0, 4)); + Store4(dst + 1 * stride, _mm_srli_si128(oqp1, 4)); +} + +inline void Transpose8x8To8x4(const __m128i& x0, const __m128i& x1, + const __m128i& x2, const __m128i& x3, + const __m128i& x4, const __m128i& x5, + const __m128i& x6, const __m128i& x7, __m128i* d0, + __m128i* d1, __m128i* d2, __m128i* d3) { + // input + // x0 00 01 02 03 04 05 06 07 + // x1 10 11 12 13 14 15 16 17 + // x2 20 21 22 23 24 25 26 27 + // x3 30 31 32 33 34 35 36 37 + // x4 40 41 42 43 44 45 46 47 + // x5 50 51 52 53 54 55 56 57 + // x6 60 61 62 63 64 65 66 67 + // x7 70 71 72 73 74 75 76 77 + // output + // d0 00 10 20 30 40 50 60 70 xx xx xx xx xx xx xx xx + // d1 01 11 21 31 41 51 61 71 xx xx xx xx xx xx xx xx + // d2 02 12 22 32 42 52 62 72 xx xx xx xx xx xx xx xx + // d3 03 13 23 33 43 53 63 73 xx xx xx xx xx xx xx xx + + // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17 + const __m128i w0 = _mm_unpacklo_epi8(x0, x1); + // 20 30 21 31 22 32 23 33 24 34 25 35 26 36 27 37 + const __m128i w1 = _mm_unpacklo_epi8(x2, x3); + // 40 50 41 51 42 52 43 53 44 54 45 55 46 56 47 57 + const __m128i w2 = _mm_unpacklo_epi8(x4, x5); + // 60 70 61 71 62 72 63 73 64 74 65 75 66 76 67 77 + const __m128i w3 = _mm_unpacklo_epi8(x6, x7); + + // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + const __m128i w4 = _mm_unpacklo_epi16(w0, w1); + // 40 50 60 70 41 51 61 71 42 52 62 72 43 53 63 73 + const __m128i w5 = _mm_unpacklo_epi16(w2, w3); + + // 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71 + *d0 = _mm_unpacklo_epi32(w4, w5); + *d1 = _mm_srli_si128(*d0, 8); + // 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73 + *d2 = _mm_unpackhi_epi32(w4, w5); + *d3 = _mm_srli_si128(*d2, 8); +} + +void Vertical8(void* dest, ptrdiff_t stride, int outer_thresh, int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_flat_thresh = _mm_set1_epi8(1); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh0 = + _mm_shuffle_epi8(_mm_cvtsi32_si128(hev_thresh), zero); + const __m128i v_hev_thresh = _mm_unpacklo_epi8(v_hev_thresh0, zero); + + __m128i x0 = LoadLo8(dst - 4 + 0 * stride); + __m128i x1 = LoadLo8(dst - 4 + 1 * stride); + __m128i x2 = LoadLo8(dst - 4 + 2 * stride); + __m128i x3 = LoadLo8(dst - 4 + 3 * stride); + + __m128i p3, p2, p1, p0, q0, q1, q2, q3; + Transpose8x4To4x8(x0, x1, x2, x3, &p3, &p2, &p1, &p0, &q0, &q1, &q2, &q3); + + const __m128i qp3 = _mm_unpacklo_epi32(p3, q3); + const __m128i qp2 = _mm_unpacklo_epi32(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi32(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi32(p0, q0); + const __m128i q1q0 = _mm_unpacklo_epi32(q0, q1); + const __m128i p1p0 = _mm_unpacklo_epi32(p0, p1); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = NeedsFilter8(q1q0, p1p0, qp3, qp2, qp1, qp0, + v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask = + _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + + p2 = oqp2_f8; + q2 = _mm_srli_si128(oqp2_f8, 4); + } + + p1 = oqp1; + p0 = oqp0; + q0 = _mm_srli_si128(oqp0, 4); + q1 = _mm_srli_si128(oqp1, 4); + + Transpose8x8To8x4(p3, p2, p1, p0, q0, q1, q2, q3, &x0, &x1, &x2, &x3); + + StoreLo8(dst - 4 + 0 * stride, x0); + StoreLo8(dst - 4 + 1 * stride, x1); + StoreLo8(dst - 4 + 2 * stride, x2); + StoreLo8(dst - 4 + 3 * stride, x3); +} + +//------------------------------------------------------------------------------ +// 13-tap filters + +inline void Filter14(const __m128i& qp6, const __m128i& qp5, const __m128i& qp4, + const __m128i& qp3, const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, __m128i* oqp5, __m128i* oqp4, + __m128i* oqp3, __m128i* oqp2, __m128i* oqp1, + __m128i* oqp0) { + const __m128i eight = _mm_set1_epi16(8); + const __m128i qp6_lo = _mm_cvtepu8_epi16(qp6); + const __m128i qp5_lo = _mm_cvtepu8_epi16(qp5); + const __m128i qp4_lo = _mm_cvtepu8_epi16(qp4); + const __m128i qp3_lo = _mm_cvtepu8_epi16(qp3); + const __m128i qp2_lo = _mm_cvtepu8_epi16(qp2); + const __m128i qp1_lo = _mm_cvtepu8_epi16(qp1); + const __m128i qp0_lo = _mm_cvtepu8_epi16(qp0); + const __m128i pq5_lo = _mm_shuffle_epi32(qp5_lo, 0x4e); + const __m128i pq4_lo = _mm_shuffle_epi32(qp4_lo, 0x4e); + const __m128i pq3_lo = _mm_shuffle_epi32(qp3_lo, 0x4e); + const __m128i pq2_lo = _mm_shuffle_epi32(qp2_lo, 0x4e); + const __m128i pq1_lo = _mm_shuffle_epi32(qp1_lo, 0x4e); + const __m128i pq0_lo = _mm_shuffle_epi32(qp0_lo, 0x4e); + + __m128i f14_lo = + _mm_add_epi16(eight, _mm_sub_epi16(_mm_slli_epi16(qp6_lo, 3), qp6_lo)); + + f14_lo = _mm_add_epi16(_mm_add_epi16(f14_lo, qp5_lo), + _mm_add_epi16(qp5_lo, qp4_lo)); + + f14_lo = _mm_add_epi16(_mm_add_epi16(f14_lo, qp4_lo), + _mm_add_epi16(qp3_lo, qp2_lo)); + + f14_lo = _mm_add_epi16(_mm_add_epi16(f14_lo, qp1_lo), + _mm_add_epi16(qp0_lo, pq0_lo)); + + // p6 * 7 + p5 * 2 + p4 * 2 + p3 + p2 + p1 + p0 + q0 + // q6 * 7 + q5 * 2 + q4 * 2 + q3 + q2 + q1 + q0 + p0 + *oqp5 = _mm_srli_epi16(f14_lo, 4); + *oqp5 = _mm_packus_epi16(*oqp5, *oqp5); + + // p6 * 5 + p5 * 2 + p4 * 2 + p3 * 2 + p2 + p1 + p0 + q0 + q1 + // q6 * 5 + q5 * 2 + q4 * 2 + q3 * 2 + q2 + q1 + q0 + p0 + p1 + f14_lo = FilterAdd2Sub2(f14_lo, qp3_lo, pq1_lo, qp6_lo, qp6_lo); + *oqp4 = _mm_srli_epi16(f14_lo, 4); + *oqp4 = _mm_packus_epi16(*oqp4, *oqp4); + + // p6 * 4 + p5 + p4 * 2 + p3 * 2 + p2 * 2 + p1 + p0 + q0 + q1 + q2 + // q6 * 4 + q5 + q4 * 2 + q3 * 2 + q2 * 2 + q1 + q0 + p0 + p1 + p2 + f14_lo = FilterAdd2Sub2(f14_lo, qp2_lo, pq2_lo, qp6_lo, qp5_lo); + *oqp3 = _mm_srli_epi16(f14_lo, 4); + *oqp3 = _mm_packus_epi16(*oqp3, *oqp3); + + // p6 * 3 + p5 + p4 + p3 * 2 + p2 * 2 + p1 * 2 + p0 + q0 + q1 + q2 + q3 + // q6 * 3 + q5 + q4 + q3 * 2 + q2 * 2 + q1 * 2 + q0 + p0 + p1 + p2 + p3 + f14_lo = FilterAdd2Sub2(f14_lo, qp1_lo, pq3_lo, qp6_lo, qp4_lo); + *oqp2 = _mm_srli_epi16(f14_lo, 4); + *oqp2 = _mm_packus_epi16(*oqp2, *oqp2); + + // p6 * 2 + p5 + p4 + p3 + p2 * 2 + p1 * 2 + p0 * 2 + q0 + q1 + q2 + q3 + q4 + // q6 * 2 + q5 + q4 + q3 + q2 * 2 + q1 * 2 + q0 * 2 + p0 + p1 + p2 + p3 + p4 + f14_lo = FilterAdd2Sub2(f14_lo, qp0_lo, pq4_lo, qp6_lo, qp3_lo); + *oqp1 = _mm_srli_epi16(f14_lo, 4); + *oqp1 = _mm_packus_epi16(*oqp1, *oqp1); + + // p6 + p5 + p4 + p3 + p2 + p1 * 2 + p0 * 2 + q0 * 2 + q1 + q2 + q3 + q4 + q5 + // q6 + q5 + q4 + q3 + q2 + q1 * 2 + q0 * 2 + p0 * 2 + p1 + p2 + p3 + p4 + p5 + f14_lo = FilterAdd2Sub2(f14_lo, pq0_lo, pq5_lo, qp6_lo, qp2_lo); + *oqp0 = _mm_srli_epi16(f14_lo, 4); + *oqp0 = _mm_packus_epi16(*oqp0, *oqp0); +} + +void Horizontal14(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_flat_thresh = _mm_set1_epi8(1); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh0 = + _mm_shuffle_epi8(_mm_cvtsi32_si128(hev_thresh), zero); + const __m128i v_hev_thresh = _mm_unpacklo_epi8(v_hev_thresh0, zero); + + const __m128i p3 = Load4(dst - 4 * stride); + const __m128i p2 = Load4(dst - 3 * stride); + const __m128i p1 = Load4(dst - 2 * stride); + const __m128i p0 = Load4(dst - 1 * stride); + const __m128i q0 = Load4(dst + 0 * stride); + const __m128i q1 = Load4(dst + 1 * stride); + const __m128i q2 = Load4(dst + 2 * stride); + const __m128i q3 = Load4(dst + 3 * stride); + + const __m128i qp3 = _mm_unpacklo_epi32(p3, q3); + const __m128i qp2 = _mm_unpacklo_epi32(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi32(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi32(p0, q0); + const __m128i q1q0 = _mm_unpacklo_epi32(q0, q1); + const __m128i p1p0 = _mm_unpacklo_epi32(p0, p1); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = NeedsFilter8(q1q0, p1p0, qp3, qp2, qp1, qp0, + v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask = + _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + const __m128i p6 = Load4(dst - 7 * stride); + const __m128i p5 = Load4(dst - 6 * stride); + const __m128i p4 = Load4(dst - 5 * stride); + const __m128i q4 = Load4(dst + 4 * stride); + const __m128i q5 = Load4(dst + 5 * stride); + const __m128i q6 = Load4(dst + 6 * stride); + const __m128i qp6 = _mm_unpacklo_epi32(p6, q6); + const __m128i qp5 = _mm_unpacklo_epi32(p5, q5); + const __m128i qp4 = _mm_unpacklo_epi32(p4, q4); + + const __m128i v_isflatouter4_mask = + IsFlat4(qp6, qp5, qp4, qp0, v_flat_thresh); + const __m128i v_flat4_mask = + _mm_shuffle_epi32(_mm_and_si128(v_mask, v_isflatouter4_mask), 0); + + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + + if (_mm_test_all_zeros(v_flat4_mask, + _mm_cmpeq_epi8(v_flat4_mask, v_flat4_mask)) == 0) { + __m128i oqp5_f14; + __m128i oqp4_f14; + __m128i oqp3_f14; + __m128i oqp2_f14; + __m128i oqp1_f14; + __m128i oqp0_f14; + + Filter14(qp6, qp5, qp4, qp3, qp2, qp1, qp0, &oqp5_f14, &oqp4_f14, + &oqp3_f14, &oqp2_f14, &oqp1_f14, &oqp0_f14); + + oqp5_f14 = _mm_blendv_epi8(qp5, oqp5_f14, v_flat4_mask); + oqp4_f14 = _mm_blendv_epi8(qp4, oqp4_f14, v_flat4_mask); + oqp3_f14 = _mm_blendv_epi8(qp3, oqp3_f14, v_flat4_mask); + oqp2_f8 = _mm_blendv_epi8(oqp2_f8, oqp2_f14, v_flat4_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f14, v_flat4_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f14, v_flat4_mask); + + Store4(dst - 6 * stride, oqp5_f14); + Store4(dst - 5 * stride, oqp4_f14); + Store4(dst - 4 * stride, oqp3_f14); + Store4(dst + 3 * stride, _mm_srli_si128(oqp3_f14, 4)); + Store4(dst + 4 * stride, _mm_srli_si128(oqp4_f14, 4)); + Store4(dst + 5 * stride, _mm_srli_si128(oqp5_f14, 4)); + } + + Store4(dst - 3 * stride, oqp2_f8); + Store4(dst + 2 * stride, _mm_srli_si128(oqp2_f8, 4)); + } + + Store4(dst - 2 * stride, oqp1); + Store4(dst - 1 * stride, oqp0); + Store4(dst + 0 * stride, _mm_srli_si128(oqp0, 4)); + Store4(dst + 1 * stride, _mm_srli_si128(oqp1, 4)); +} + +// Each of the 8x4 blocks of input data (p7-p0 and q0-q7) are transposed to 4x8, +// then unpacked to the correct qp register. (qp7 - qp0) +// +// p7 p6 p5 p4 p3 p2 p1 p0 q0 q1 q2 q3 q4 q5 q6 q7 +// +// 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f +// 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f +// 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f +// 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f + +inline void DualTranspose8x4To4x8(const __m128i& x0, const __m128i& x1, + const __m128i& x2, const __m128i& x3, + __m128i* q0p0, __m128i* q1p1, __m128i* q2p2, + __m128i* q3p3, __m128i* q4p4, __m128i* q5p5, + __m128i* q6p6, __m128i* q7p7) { + // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17 + const __m128i w0 = _mm_unpacklo_epi8(x0, x1); + // 20 30 21 31 22 32 23 33 24 34 25 35 26 36 27 37 + const __m128i w1 = _mm_unpacklo_epi8(x2, x3); + // 08 18 09 19 0a 1a 0b 1b 0c 1c 0d 1d 0e 1e 0f 1f + const __m128i w2 = _mm_unpackhi_epi8(x0, x1); + // 28 38 29 39 2a 3a 2b 3b 2c 3c 2d 3d 2e 3e 2f 3f + const __m128i w3 = _mm_unpackhi_epi8(x2, x3); + // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + const __m128i ww0 = _mm_unpacklo_epi16(w0, w1); + // 04 14 24 34 05 15 25 35 06 16 26 36 07 17 27 37 + const __m128i ww1 = _mm_unpackhi_epi16(w0, w1); + // 08 18 28 38 09 19 29 39 0a 1a 2a 3a 0b 1b 2b 3b + const __m128i ww2 = _mm_unpacklo_epi16(w2, w3); + // 0c 1c 2c 3c 0d 1d 2d 3d 0e 1e 2e 3e 0f 1f 2f 3f + const __m128i ww3 = _mm_unpackhi_epi16(w2, w3); + // 00 10 20 30 0f 1f 2f 3f xx xx xx xx xx xx xx xx + *q7p7 = _mm_unpacklo_epi32(ww0, _mm_srli_si128(ww3, 12)); + // 01 11 21 31 0e 1e 2e 3e xx xx xx xx xx xx xx xx + *q6p6 = _mm_unpackhi_epi32(_mm_slli_si128(ww0, 4), ww3); + // 02 12 22 32 0d 1d 2d 3d xx xx xx xx xx xx xx xx + *q5p5 = _mm_unpackhi_epi32(ww0, _mm_slli_si128(ww3, 4)); + // 03 13 23 33 0c 1c 2c 3c xx xx xx xx xx xx xx xx + *q4p4 = _mm_unpacklo_epi32(_mm_srli_si128(ww0, 12), ww3); + // 04 14 24 34 0b 1b 2b 3b xx xx xx xx xx xx xx xx + *q3p3 = _mm_unpacklo_epi32(ww1, _mm_srli_si128(ww2, 12)); + // 05 15 25 35 0a 1a 2a 3a xx xx xx xx xx xx xx xx + *q2p2 = _mm_unpackhi_epi32(_mm_slli_si128(ww1, 4), ww2); + // 06 16 26 36 09 19 29 39 xx xx xx xx xx xx xx xx + *q1p1 = _mm_unpackhi_epi32(ww1, _mm_slli_si128(ww2, 4)); + // 07 17 27 37 08 18 28 38 xx xx xx xx xx xx xx xx + *q0p0 = _mm_unpacklo_epi32(_mm_srli_si128(ww1, 12), ww2); +} + +inline void DualTranspose4x8To8x4(const __m128i& qp7, const __m128i& qp6, + const __m128i& qp5, const __m128i& qp4, + const __m128i& qp3, const __m128i& qp2, + const __m128i& qp1, const __m128i& qp0, + __m128i* x0, __m128i* x1, __m128i* x2, + __m128i* x3) { + // qp7: 00 10 20 30 0f 1f 2f 3f xx xx xx xx xx xx xx xx + // qp6: 01 11 21 31 0e 1e 2e 3e xx xx xx xx xx xx xx xx + // qp5: 02 12 22 32 0d 1d 2d 3d xx xx xx xx xx xx xx xx + // qp4: 03 13 23 33 0c 1c 2c 3c xx xx xx xx xx xx xx xx + // qp3: 04 14 24 34 0b 1b 2b 3b xx xx xx xx xx xx xx xx + // qp2: 05 15 25 35 0a 1a 2a 3a xx xx xx xx xx xx xx xx + // qp1: 06 16 26 36 09 19 29 39 xx xx xx xx xx xx xx xx + // qp0: 07 17 27 37 08 18 28 38 xx xx xx xx xx xx xx xx + + // 00 01 10 11 20 21 30 31 0f 0e 1f 1e 2f 2e 3f 3e + const __m128i w0 = _mm_unpacklo_epi8(qp7, qp6); + // 02 03 12 13 22 23 32 33 xx xx xx xx xx xx xx xx + const __m128i w1 = _mm_unpacklo_epi8(qp5, qp4); + // 04 05 14 15 24 25 34 35 xx xx xx xx xx xx xx xx + const __m128i w2 = _mm_unpacklo_epi8(qp3, qp2); + // 06 07 16 17 26 27 36 37 xx xx xx xx xx xx xx xx + const __m128i w3 = _mm_unpacklo_epi8(qp1, qp0); + // 00 01 02 03 10 11 12 13 20 21 22 23 30 31 32 33 + const __m128i w4 = _mm_unpacklo_epi16(w0, w1); + // 04 05 06 07 14 15 16 17 24 25 26 27 34 35 36 37 + const __m128i w5 = _mm_unpacklo_epi16(w2, w3); + // 00 01 02 03 04 05 06 07 10 11 12 13 14 15 16 17 + const __m128i d0 = _mm_unpacklo_epi32(w4, w5); + // 20 21 22 23 24 25 26 27 30 31 32 33 34 35 36 37 + const __m128i d2 = _mm_unpackhi_epi32(w4, w5); + // xx xx xx xx xx xx xx xx 08 09 18 19 28 29 38 39 + const __m128i w10 = _mm_unpacklo_epi8(qp0, qp1); + // xx xx xx xx xx xx xx xx 0a 0b 1a 1b 2a 2b 3a 3b + const __m128i w11 = _mm_unpacklo_epi8(qp2, qp3); + // xx xx xx xx xx xx xx xx 0c 0d 1c 1d 2c 2d 3c 3d + const __m128i w12 = _mm_unpacklo_epi8(qp4, qp5); + // xx xx xx xx xx xx xx xx 0e 0f 1e 1f 2e 2f 3e 3f + const __m128i w13 = _mm_unpacklo_epi8(qp6, qp7); + // 08 09 0a 0b 18 19 1a 1b 28 29 2a 2b 38 39 3a 3b + const __m128i w14 = _mm_unpackhi_epi16(w10, w11); + // 0c 0d 0e 0f 1c 1d 1e 1f 2c 2d 2e 2f 3c 3d 3e 3f + const __m128i w15 = _mm_unpackhi_epi16(w12, w13); + // 08 09 0a 0b 0c 0d 0e 0f 18 19 1a 1b 1c 1d 1e 1f + const __m128i d1 = _mm_unpacklo_epi32(w14, w15); + // 28 29 2a 2b 2c 2d 2e 2f 38 39 3a 3b 3c 3d 3e 3f + const __m128i d3 = _mm_unpackhi_epi32(w14, w15); + + // p7 p6 p5 p4 p3 p2 p1 p0 q0 q1 q2 q3 q4 q5 q6 q7 + // + // 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f + *x0 = _mm_unpacklo_epi64(d0, d1); + // 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + *x1 = _mm_unpackhi_epi64(d0, d1); + // 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f + *x2 = _mm_unpacklo_epi64(d2, d3); + // 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f + *x3 = _mm_unpackhi_epi64(d2, d3); +} + +void Vertical14(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh) { + auto* const dst = static_cast<uint8_t*>(dest); + const __m128i zero = _mm_setzero_si128(); + const __m128i v_flat_thresh = _mm_set1_epi8(1); + const __m128i v_outer_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(outer_thresh), zero); + const __m128i v_inner_thresh = + _mm_shuffle_epi8(_mm_cvtsi32_si128(inner_thresh), zero); + const __m128i v_hev_thresh0 = + _mm_shuffle_epi8(_mm_cvtsi32_si128(hev_thresh), zero); + const __m128i v_hev_thresh = _mm_unpacklo_epi8(v_hev_thresh0, zero); + + __m128i x0 = LoadUnaligned16(dst - 8 + 0 * stride); + __m128i x1 = LoadUnaligned16(dst - 8 + 1 * stride); + __m128i x2 = LoadUnaligned16(dst - 8 + 2 * stride); + __m128i x3 = LoadUnaligned16(dst - 8 + 3 * stride); + + __m128i qp7, qp6, qp5, qp4, qp3, qp2, qp1, qp0; + + DualTranspose8x4To4x8(x0, x1, x2, x3, &qp0, &qp1, &qp2, &qp3, &qp4, &qp5, + &qp6, &qp7); + + const __m128i qp1qp0 = _mm_unpacklo_epi64(qp0, qp1); + const __m128i q1q0 = _mm_shuffle_epi32(qp1qp0, 0x0d); + const __m128i p1p0 = _mm_shuffle_epi32(qp1qp0, 0x08); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = NeedsFilter8(q1q0, p1p0, qp3, qp2, qp1, qp0, + v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask = + _mm_shuffle_epi32(_mm_and_si128(v_needs_mask, v_isflat4_mask), 0); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi8(v_mask, v_mask)) == 0) { + const __m128i v_isflatouter4_mask = + IsFlat4(qp6, qp5, qp4, qp0, v_flat_thresh); + const __m128i v_flat4_mask = + _mm_shuffle_epi32(_mm_and_si128(v_mask, v_isflatouter4_mask), 0); + + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + + if (_mm_test_all_zeros(v_flat4_mask, + _mm_cmpeq_epi8(v_flat4_mask, v_flat4_mask)) == 0) { + __m128i oqp5_f14; + __m128i oqp4_f14; + __m128i oqp3_f14; + __m128i oqp2_f14; + __m128i oqp1_f14; + __m128i oqp0_f14; + + Filter14(qp6, qp5, qp4, qp3, qp2, qp1, qp0, &oqp5_f14, &oqp4_f14, + &oqp3_f14, &oqp2_f14, &oqp1_f14, &oqp0_f14); + + oqp5_f14 = _mm_blendv_epi8(qp5, oqp5_f14, v_flat4_mask); + oqp4_f14 = _mm_blendv_epi8(qp4, oqp4_f14, v_flat4_mask); + oqp3_f14 = _mm_blendv_epi8(qp3, oqp3_f14, v_flat4_mask); + oqp2_f8 = _mm_blendv_epi8(oqp2_f8, oqp2_f14, v_flat4_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f14, v_flat4_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f14, v_flat4_mask); + qp3 = oqp3_f14; + qp4 = oqp4_f14; + qp5 = oqp5_f14; + } + qp2 = oqp2_f8; + } + + DualTranspose4x8To8x4(qp7, qp6, qp5, qp4, qp3, qp2, oqp1, oqp0, &x0, &x1, &x2, + &x3); + + StoreUnaligned16(dst - 8 + 0 * stride, x0); + StoreUnaligned16(dst - 8 + 1 * stride, x1); + StoreUnaligned16(dst - 8 + 2 * stride, x2); + StoreUnaligned16(dst - 8 + 3 * stride, x3); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize4_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = Horizontal4; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize6_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = Horizontal6; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize8_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = Horizontal8; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize14_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Horizontal14; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize4_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = Vertical4; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize6_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = Vertical6; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize8_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = Vertical8; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize14_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = Vertical14; +#endif +} +} // namespace +} // namespace low_bitdepth + +//------------------------------------------------------------------------------ +namespace high_bitdepth { +namespace { + +#if LIBGAV1_MAX_BITDEPTH >= 10 + +template <int bitdepth> +struct LoopFilterFuncs_SSE4_1 { + LoopFilterFuncs_SSE4_1() = delete; + + static constexpr int kThreshShift = bitdepth - 8; + + static void Vertical4(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal4(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Vertical6(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal6(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Vertical8(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal8(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Vertical14(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); + static void Horizontal14(void* dest, ptrdiff_t stride, int outer_thresh, + int inner_thresh, int hev_thresh); +}; + +inline __m128i Clamp(const __m128i& min, const __m128i& max, + const __m128i& val) { + const __m128i a = _mm_min_epi16(val, max); + const __m128i b = _mm_max_epi16(a, min); + return b; +} + +inline __m128i AddShift3(const __m128i& a, const __m128i& b, + const __m128i& vmin, const __m128i& vmax) { + const __m128i c = _mm_adds_epi16(a, b); + const __m128i d = Clamp(vmin, vmax, c); + const __m128i e = _mm_srai_epi16(d, 3); /* >> 3 */ + return e; +} + +inline __m128i AddShift1(const __m128i& a, const __m128i& b) { + const __m128i c = _mm_adds_epi16(a, b); + const __m128i e = _mm_srai_epi16(c, 1); /* >> 1 */ + return e; +} + +inline __m128i AbsDiff(const __m128i& a, const __m128i& b) { + return _mm_or_si128(_mm_subs_epu16(a, b), _mm_subs_epu16(b, a)); +} + +inline __m128i Hev(const __m128i& qp1, const __m128i& qp0, + const __m128i& hev_thresh) { + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq = + _mm_max_epu16(abs_qp1mqp0, _mm_srli_si128(abs_qp1mqp0, 8)); + const __m128i hev_mask = _mm_cmpgt_epi16(max_pq, hev_thresh); + return hev_mask; +} + +inline __m128i CheckOuterThreshF4(const __m128i& q1q0, const __m128i& p1p0, + const __m128i& outer_thresh) { + // abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh; + const __m128i abs_pmq = AbsDiff(p1p0, q1q0); + const __m128i a = _mm_adds_epu16(abs_pmq, abs_pmq); + const __m128i b = _mm_srli_epi16(abs_pmq, 1); + const __m128i c = _mm_adds_epu16(a, _mm_srli_si128(b, 8)); + return _mm_subs_epu16(c, outer_thresh); +} + +inline __m128i NeedsFilter4(const __m128i& q1q0, const __m128i& p1p0, + const __m128i& qp1, const __m128i& qp0, + const __m128i& outer_thresh, + const __m128i& inner_thresh) { + const __m128i outer_mask = CheckOuterThreshF4(q1q0, p1p0, outer_thresh); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_abs_qp1mqp = + _mm_max_epu16(abs_qp1mqp0, _mm_srli_si128(abs_qp1mqp0, 8)); + const __m128i inner_mask = _mm_subs_epu16(max_abs_qp1mqp, inner_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_or_si128(outer_mask, inner_mask); + const __m128i b = _mm_cmpeq_epi16(a, zero); + return b; +} + +inline void Filter4(const __m128i& qp1, const __m128i& qp0, __m128i* oqp1, + __m128i* oqp0, const __m128i& mask, const __m128i& hev, + int bitdepth) { + const __m128i t4 = _mm_set1_epi16(4); + const __m128i t3 = _mm_set1_epi16(3); + const __m128i t80 = _mm_set1_epi16(static_cast<int16_t>(1 << (bitdepth - 1))); + const __m128i t1 = _mm_set1_epi16(0x1); + const __m128i vmin = _mm_subs_epi16(_mm_setzero_si128(), t80); + const __m128i vmax = _mm_subs_epi16(t80, t1); + const __m128i ps1 = _mm_subs_epi16(qp1, t80); + const __m128i ps0 = _mm_subs_epi16(qp0, t80); + const __m128i qs0 = _mm_srli_si128(ps0, 8); + const __m128i qs1 = _mm_srli_si128(ps1, 8); + + __m128i a = _mm_subs_epi16(ps1, qs1); + a = _mm_and_si128(Clamp(vmin, vmax, a), hev); + + const __m128i x = _mm_subs_epi16(qs0, ps0); + a = _mm_adds_epi16(a, x); + a = _mm_adds_epi16(a, x); + a = _mm_adds_epi16(a, x); + a = _mm_and_si128(Clamp(vmin, vmax, a), mask); + + const __m128i a1 = AddShift3(a, t4, vmin, vmax); + const __m128i a2 = AddShift3(a, t3, vmin, vmax); + const __m128i a3 = _mm_andnot_si128(hev, AddShift1(a1, t1)); + + const __m128i ops1 = _mm_adds_epi16(ps1, a3); + const __m128i ops0 = _mm_adds_epi16(ps0, a2); + const __m128i oqs0 = _mm_subs_epi16(qs0, a1); + const __m128i oqs1 = _mm_subs_epi16(qs1, a3); + + __m128i oqps1 = _mm_unpacklo_epi64(ops1, oqs1); + __m128i oqps0 = _mm_unpacklo_epi64(ops0, oqs0); + + oqps1 = Clamp(vmin, vmax, oqps1); + oqps0 = Clamp(vmin, vmax, oqps0); + + *oqp1 = _mm_adds_epi16(oqps1, t80); + *oqp0 = _mm_adds_epi16(oqps0, t80); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal4(void* dest, + ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + const __m128i p1 = LoadLo8(dst - 2 * stride); + const __m128i p0 = LoadLo8(dst - 1 * stride); + const __m128i qp0 = LoadHi8(p0, dst + 0 * stride); + const __m128i qp1 = LoadHi8(p1, dst + 1 * stride); + const __m128i q1q0 = _mm_unpackhi_epi64(qp0, qp1); + const __m128i p1p0 = _mm_unpacklo_epi64(qp0, qp1); + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter4(q1q0, p1p0, qp1, qp0, v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + StoreLo8(dst - 2 * stride, oqp1); + StoreLo8(dst - 1 * stride, oqp0); + StoreHi8(dst + 0 * stride, oqp0); + StoreHi8(dst + 1 * stride, oqp1); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical4(void* dest, ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + const __m128i x0 = LoadLo8(dst - 2 + 0 * stride); + const __m128i x1 = LoadLo8(dst - 2 + 1 * stride); + const __m128i x2 = LoadLo8(dst - 2 + 2 * stride); + const __m128i x3 = LoadLo8(dst - 2 + 3 * stride); + // 00 10 01 11 02 12 03 13 + const __m128i w0 = _mm_unpacklo_epi16(x0, x1); + // 20 30 21 31 22 32 23 33 + const __m128i w1 = _mm_unpacklo_epi16(x2, x3); + // 00 10 20 30 01 11 21 31 p0p1 + const __m128i a = _mm_unpacklo_epi32(w0, w1); + const __m128i p1p0 = _mm_shuffle_epi32(a, 0x4e); + // 02 12 22 32 03 13 23 33 q1q0 + const __m128i q1q0 = _mm_unpackhi_epi32(w0, w1); + const __m128i qp1 = _mm_unpackhi_epi64(p1p0, q1q0); + const __m128i qp0 = _mm_unpacklo_epi64(p1p0, q1q0); + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter4(q1q0, p1p0, qp1, qp0, v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + // 00 10 01 11 02 12 03 13 + const __m128i w2 = _mm_unpacklo_epi16(oqp1, oqp0); + // 20 30 21 31 22 32 23 33 + const __m128i w3 = _mm_unpackhi_epi16(oqp0, oqp1); + // 00 10 20 30 01 11 21 31 + const __m128i op0p1 = _mm_unpacklo_epi32(w2, w3); + // 02 12 22 32 03 13 23 33 + const __m128i oq1q0 = _mm_unpackhi_epi32(w2, w3); + + StoreLo8(dst - 2 + 0 * stride, op0p1); + StoreHi8(dst - 2 + 1 * stride, op0p1); + StoreLo8(dst - 2 + 2 * stride, oq1q0); + StoreHi8(dst - 2 + 3 * stride, oq1q0); +} + +//------------------------------------------------------------------------------ +// 5-tap (chroma) filters + +inline __m128i CheckOuterThreshF6(const __m128i& qp1, const __m128i& qp0, + const __m128i& outer_thresh) { + // abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh; + const __m128i q1q0 = _mm_unpackhi_epi64(qp0, qp1); + const __m128i p1p0 = _mm_unpacklo_epi64(qp0, qp1); + return CheckOuterThreshF4(q1q0, p1p0, outer_thresh); +} + +inline __m128i NeedsFilter6(const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, const __m128i& outer_thresh, + const __m128i& inner_thresh) { + const __m128i outer_mask = CheckOuterThreshF6(qp1, qp0, outer_thresh); + const __m128i abs_qp2mqp1 = AbsDiff(qp2, qp1); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq = _mm_max_epu16(abs_qp2mqp1, abs_qp1mqp0); + const __m128i inner_mask = _mm_subs_epu16( + _mm_max_epu16(max_pq, _mm_srli_si128(max_pq, 8)), inner_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_or_si128(outer_mask, inner_mask); + const __m128i b = _mm_cmpeq_epi16(a, zero); + return b; +} + +inline __m128i IsFlat3(const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, const __m128i& flat_thresh) { + const __m128i abs_pq2mpq0 = AbsDiff(qp2, qp0); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq = _mm_max_epu16(abs_pq2mpq0, abs_qp1mqp0); + const __m128i flat_mask = _mm_subs_epu16( + _mm_max_epu16(max_pq, _mm_srli_si128(max_pq, 8)), flat_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_cmpeq_epi16(flat_mask, zero); + return a; +} + +inline void Filter6(const __m128i& qp2, const __m128i& qp1, const __m128i& qp0, + __m128i* oqp1, __m128i* oqp0) { + const __m128i four = _mm_set1_epi16(4); + const __m128i qp2_lo = qp2; + const __m128i qp1_lo = qp1; + const __m128i qp0_lo = qp0; + const __m128i pq1_lo = _mm_shuffle_epi32(qp1_lo, 0x4e); + const __m128i pq0_lo = _mm_shuffle_epi32(qp0_lo, 0x4e); + + __m128i f6_lo; + f6_lo = + _mm_add_epi16(_mm_add_epi16(qp2_lo, four), _mm_add_epi16(qp2_lo, qp2_lo)); + + f6_lo = _mm_add_epi16(_mm_add_epi16(f6_lo, qp1_lo), qp1_lo); + + f6_lo = _mm_add_epi16(_mm_add_epi16(f6_lo, qp0_lo), + _mm_add_epi16(qp0_lo, pq0_lo)); + + // p2 * 3 + p1 * 2 + p0 * 2 + q0 + // q2 * 3 + q1 * 2 + q0 * 2 + p0 + *oqp1 = _mm_srli_epi16(f6_lo, 3); + + // p2 + p1 * 2 + p0 * 2 + q0 * 2 + q1 + // q2 + q1 * 2 + q0 * 2 + p0 * 2 + p1 + f6_lo = FilterAdd2Sub2(f6_lo, pq0_lo, pq1_lo, qp2_lo, qp2_lo); + *oqp0 = _mm_srli_epi16(f6_lo, 3); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal6(void* dest, + ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_flat_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(1 << kThreshShift), 0); + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + + const __m128i p2 = LoadLo8(dst - 3 * stride); + const __m128i p1 = LoadLo8(dst - 2 * stride); + const __m128i p0 = LoadLo8(dst - 1 * stride); + const __m128i q0 = LoadLo8(dst + 0 * stride); + const __m128i q1 = LoadLo8(dst + 1 * stride); + const __m128i q2 = LoadLo8(dst + 2 * stride); + + const __m128i qp2 = _mm_unpacklo_epi64(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi64(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi64(p0, q0); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter6(qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + const __m128i v_isflat3_mask = IsFlat3(qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat3_mask); + const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + __m128i oqp1_f6; + __m128i oqp0_f6; + + Filter6(qp2, qp1, qp0, &oqp1_f6, &oqp0_f6); + + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f6, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f6, v_mask); + } + + StoreLo8(dst - 2 * stride, oqp1); + StoreLo8(dst - 1 * stride, oqp0); + StoreHi8(dst + 0 * stride, oqp0); + StoreHi8(dst + 1 * stride, oqp1); +} + +inline void Transpose8x4To4x8(const __m128i& x0, const __m128i& x1, + const __m128i& x2, const __m128i& x3, __m128i* d0, + __m128i* d1, __m128i* d2, __m128i* d3, + __m128i* d4, __m128i* d5, __m128i* d6, + __m128i* d7) { + // input + // x0 00 01 02 03 04 05 06 07 + // x1 10 11 12 13 14 15 16 17 + // x2 20 21 22 23 24 25 26 27 + // x3 30 31 32 33 34 35 36 37 + // output + // 00 10 20 30 xx xx xx xx + // 01 11 21 31 xx xx xx xx + // 02 12 22 32 xx xx xx xx + // 03 13 23 33 xx xx xx xx + // 04 14 24 34 xx xx xx xx + // 05 15 25 35 xx xx xx xx + // 06 16 26 36 xx xx xx xx + // 07 17 27 37 xx xx xx xx + + // 00 10 01 11 02 12 03 13 + const __m128i w0 = _mm_unpacklo_epi16(x0, x1); + // 20 30 21 31 22 32 23 33 + const __m128i w1 = _mm_unpacklo_epi16(x2, x3); + // 04 14 05 15 06 16 07 17 + const __m128i w2 = _mm_unpackhi_epi16(x0, x1); + // 24 34 25 35 26 36 27 37 + const __m128i w3 = _mm_unpackhi_epi16(x2, x3); + + // 00 10 20 30 01 11 21 31 + const __m128i ww0 = _mm_unpacklo_epi32(w0, w1); + // 04 14 24 34 05 15 25 35 + const __m128i ww1 = _mm_unpacklo_epi32(w2, w3); + // 02 12 22 32 03 13 23 33 + const __m128i ww2 = _mm_unpackhi_epi32(w0, w1); + // 06 16 26 36 07 17 27 37 + const __m128i ww3 = _mm_unpackhi_epi32(w2, w3); + + // 00 10 20 30 xx xx xx xx + *d0 = ww0; + // 01 11 21 31 xx xx xx xx + *d1 = _mm_srli_si128(ww0, 8); + // 02 12 22 32 xx xx xx xx + *d2 = ww2; + // 03 13 23 33 xx xx xx xx + *d3 = _mm_srli_si128(ww2, 8); + // 04 14 24 34 xx xx xx xx + *d4 = ww1; + // 05 15 25 35 xx xx xx xx + *d5 = _mm_srli_si128(ww1, 8); + // 06 16 26 36 xx xx xx xx + *d6 = ww3; + // 07 17 27 37 xx xx xx xx + *d7 = _mm_srli_si128(ww3, 8); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical6(void* dest, ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_flat_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(1 << kThreshShift), 0); + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + + __m128i x0 = LoadUnaligned16(dst - 3 + 0 * stride); + __m128i x1 = LoadUnaligned16(dst - 3 + 1 * stride); + __m128i x2 = LoadUnaligned16(dst - 3 + 2 * stride); + __m128i x3 = LoadUnaligned16(dst - 3 + 3 * stride); + + __m128i p2, p1, p0, q0, q1, q2; + __m128i z0, z1; // not used + + Transpose8x4To4x8(x0, x1, x2, x3, &p2, &p1, &p0, &q0, &q1, &q2, &z0, &z1); + + const __m128i qp2 = _mm_unpacklo_epi64(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi64(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi64(p0, q0); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter6(qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + const __m128i v_isflat3_mask = IsFlat3(qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat3_mask); + const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + __m128i oqp1_f6; + __m128i oqp0_f6; + + Filter6(qp2, qp1, qp0, &oqp1_f6, &oqp0_f6); + + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f6, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f6, v_mask); + } + + // 00 10 01 11 02 12 03 13 + const __m128i w2 = _mm_unpacklo_epi16(oqp1, oqp0); + // 20 30 21 31 22 32 23 33 + const __m128i w3 = _mm_unpackhi_epi16(oqp0, oqp1); + // 00 10 20 30 01 11 21 31 + const __m128i op0p1 = _mm_unpacklo_epi32(w2, w3); + // 02 12 22 32 03 13 23 33 + const __m128i oq1q0 = _mm_unpackhi_epi32(w2, w3); + + StoreLo8(dst - 2 + 0 * stride, op0p1); + StoreHi8(dst - 2 + 1 * stride, op0p1); + StoreLo8(dst - 2 + 2 * stride, oq1q0); + StoreHi8(dst - 2 + 3 * stride, oq1q0); +} + +//------------------------------------------------------------------------------ +// 7-tap filters +inline __m128i NeedsFilter8(const __m128i& qp3, const __m128i& qp2, + const __m128i& qp1, const __m128i& qp0, + const __m128i& outer_thresh, + const __m128i& inner_thresh) { + const __m128i outer_mask = CheckOuterThreshF6(qp1, qp0, outer_thresh); + const __m128i abs_qp2mqp1 = AbsDiff(qp2, qp1); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq_a = _mm_max_epu16(abs_qp2mqp1, abs_qp1mqp0); + const __m128i abs_pq3mpq2 = AbsDiff(qp3, qp2); + const __m128i max_pq = _mm_max_epu16(max_pq_a, abs_pq3mpq2); + const __m128i inner_mask = _mm_subs_epu16( + _mm_max_epu16(max_pq, _mm_srli_si128(max_pq, 8)), inner_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_or_si128(outer_mask, inner_mask); + const __m128i b = _mm_cmpeq_epi16(a, zero); + return b; +} + +inline __m128i IsFlat4(const __m128i& qp3, const __m128i& qp2, + const __m128i& qp1, const __m128i& qp0, + const __m128i& flat_thresh) { + const __m128i abs_pq2mpq0 = AbsDiff(qp2, qp0); + const __m128i abs_qp1mqp0 = AbsDiff(qp1, qp0); + const __m128i max_pq_a = _mm_max_epu16(abs_pq2mpq0, abs_qp1mqp0); + const __m128i abs_pq3mpq0 = AbsDiff(qp3, qp0); + const __m128i max_pq = _mm_max_epu16(max_pq_a, abs_pq3mpq0); + const __m128i flat_mask = _mm_subs_epu16( + _mm_max_epu16(max_pq, _mm_srli_si128(max_pq, 8)), flat_thresh); + // ~mask + const __m128i zero = _mm_setzero_si128(); + const __m128i a = _mm_cmpeq_epi16(flat_mask, zero); + return a; +} + +inline void Filter8(const __m128i& qp3, const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, __m128i* oqp2, __m128i* oqp1, + __m128i* oqp0) { + const __m128i four = _mm_set1_epi16(4); + const __m128i qp3_lo = qp3; + const __m128i qp2_lo = qp2; + const __m128i qp1_lo = qp1; + const __m128i qp0_lo = qp0; + const __m128i pq2_lo = _mm_shuffle_epi32(qp2_lo, 0x4e); + const __m128i pq1_lo = _mm_shuffle_epi32(qp1_lo, 0x4e); + const __m128i pq0_lo = _mm_shuffle_epi32(qp0_lo, 0x4e); + + __m128i f8_lo = + _mm_add_epi16(_mm_add_epi16(qp3_lo, four), _mm_add_epi16(qp3_lo, qp3_lo)); + + f8_lo = _mm_add_epi16(_mm_add_epi16(f8_lo, qp2_lo), qp2_lo); + + f8_lo = _mm_add_epi16(_mm_add_epi16(f8_lo, qp1_lo), + _mm_add_epi16(qp0_lo, pq0_lo)); + + // p3 + p3 + p3 + 2 * p2 + p1 + p0 + q0 + // q3 + q3 + q3 + 2 * q2 + q1 + q0 + p0 + *oqp2 = _mm_srli_epi16(f8_lo, 3); + + // p3 + p3 + p2 + 2 * p1 + p0 + q0 + q1 + // q3 + q3 + q2 + 2 * q1 + q0 + p0 + p1 + f8_lo = FilterAdd2Sub2(f8_lo, qp1_lo, pq1_lo, qp3_lo, qp2_lo); + *oqp1 = _mm_srli_epi16(f8_lo, 3); + + // p3 + p2 + p1 + 2 * p0 + q0 + q1 + q2 + // q3 + q2 + q1 + 2 * q0 + p0 + p1 + p2 + f8_lo = FilterAdd2Sub2(f8_lo, qp0_lo, pq2_lo, qp3_lo, qp1_lo); + *oqp0 = _mm_srli_epi16(f8_lo, 3); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal8(void* dest, + ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_flat_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(1 << kThreshShift), 0); + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + + const __m128i p3 = LoadLo8(dst - 4 * stride); + const __m128i p2 = LoadLo8(dst - 3 * stride); + const __m128i p1 = LoadLo8(dst - 2 * stride); + const __m128i p0 = LoadLo8(dst - 1 * stride); + const __m128i q0 = LoadLo8(dst + 0 * stride); + const __m128i q1 = LoadLo8(dst + 1 * stride); + const __m128i q2 = LoadLo8(dst + 2 * stride); + const __m128i q3 = LoadLo8(dst + 3 * stride); + const __m128i qp3 = _mm_unpacklo_epi64(p3, q3); + const __m128i qp2 = _mm_unpacklo_epi64(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi64(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi64(p0, q0); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter8(qp3, qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); + const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + StoreLo8(dst - 3 * stride, oqp2_f8); + StoreHi8(dst + 2 * stride, oqp2_f8); + } + + StoreLo8(dst - 2 * stride, oqp1); + StoreLo8(dst - 1 * stride, oqp0); + StoreHi8(dst + 0 * stride, oqp0); + StoreHi8(dst + 1 * stride, oqp1); +} + +inline void TransposeLower4x8To8x4(const __m128i& x0, const __m128i& x1, + const __m128i& x2, const __m128i& x3, + const __m128i& x4, const __m128i& x5, + const __m128i& x6, const __m128i& x7, + __m128i* d0, __m128i* d1, __m128i* d2, + __m128i* d3) { + // input + // x0 00 01 02 03 04 05 06 07 + // x1 10 11 12 13 14 15 16 17 + // x2 20 21 22 23 24 25 26 27 + // x3 30 31 32 33 34 35 36 37 + // x4 40 41 42 43 44 45 46 47 + // x5 50 51 52 53 54 55 56 57 + // x6 60 61 62 63 64 65 66 67 + // x7 70 71 72 73 74 75 76 77 + // output + // d0 00 10 20 30 40 50 60 70 + // d1 01 11 21 31 41 51 61 71 + // d2 02 12 22 32 42 52 62 72 + // d3 03 13 23 33 43 53 63 73 + + // 00 10 01 11 02 12 03 13 + const __m128i w0 = _mm_unpacklo_epi16(x0, x1); + // 20 30 21 31 22 32 23 33 + const __m128i w1 = _mm_unpacklo_epi16(x2, x3); + // 40 50 41 51 42 52 43 53 + const __m128i w2 = _mm_unpacklo_epi16(x4, x5); + // 60 70 61 71 62 72 63 73 + const __m128i w3 = _mm_unpacklo_epi16(x6, x7); + + // 00 10 20 30 01 11 21 31 + const __m128i w4 = _mm_unpacklo_epi32(w0, w1); + // 40 50 60 70 41 51 61 71 + const __m128i w5 = _mm_unpacklo_epi32(w2, w3); + // 02 12 22 32 03 13 23 33 + const __m128i w6 = _mm_unpackhi_epi32(w0, w1); + // 42 52 62 72 43 53 63 73 + const __m128i w7 = _mm_unpackhi_epi32(w2, w3); + + // 00 10 20 30 40 50 60 70 + *d0 = _mm_unpacklo_epi64(w4, w5); + // 01 11 21 31 41 51 61 71 + *d1 = _mm_unpackhi_epi64(w4, w5); + // 02 12 22 32 42 52 62 72 + *d2 = _mm_unpacklo_epi64(w6, w7); + // 03 13 23 33 43 53 63 73 + *d3 = _mm_unpackhi_epi64(w6, w7); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical8(void* dest, ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_flat_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(1 << kThreshShift), 0); + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + + __m128i x0 = LoadUnaligned16(dst - 4 + 0 * stride); + __m128i x1 = LoadUnaligned16(dst - 4 + 1 * stride); + __m128i x2 = LoadUnaligned16(dst - 4 + 2 * stride); + __m128i x3 = LoadUnaligned16(dst - 4 + 3 * stride); + + __m128i p3, p2, p1, p0, q0, q1, q2, q3; + Transpose8x4To4x8(x0, x1, x2, x3, &p3, &p2, &p1, &p0, &q0, &q1, &q2, &q3); + + const __m128i qp3 = _mm_unpacklo_epi64(p3, q3); + const __m128i qp2 = _mm_unpacklo_epi64(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi64(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi64(p0, q0); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter8(qp3, qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); + const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + + p2 = oqp2_f8; + q2 = _mm_srli_si128(oqp2_f8, 8); + } + + p1 = oqp1; + p0 = oqp0; + q0 = _mm_srli_si128(oqp0, 8); + q1 = _mm_srli_si128(oqp1, 8); + + TransposeLower4x8To8x4(p3, p2, p1, p0, q0, q1, q2, q3, &x0, &x1, &x2, &x3); + + StoreUnaligned16(dst - 4 + 0 * stride, x0); + StoreUnaligned16(dst - 4 + 1 * stride, x1); + StoreUnaligned16(dst - 4 + 2 * stride, x2); + StoreUnaligned16(dst - 4 + 3 * stride, x3); +} + +//------------------------------------------------------------------------------ +// 13-tap filters + +inline void Filter14(const __m128i& qp6, const __m128i& qp5, const __m128i& qp4, + const __m128i& qp3, const __m128i& qp2, const __m128i& qp1, + const __m128i& qp0, __m128i* oqp5, __m128i* oqp4, + __m128i* oqp3, __m128i* oqp2, __m128i* oqp1, + __m128i* oqp0) { + const __m128i eight = _mm_set1_epi16(8); + const __m128i qp6_lo = qp6; + const __m128i qp5_lo = qp5; + const __m128i qp4_lo = qp4; + const __m128i qp3_lo = qp3; + const __m128i qp2_lo = qp2; + const __m128i qp1_lo = qp1; + const __m128i qp0_lo = qp0; + const __m128i pq5_lo = _mm_shuffle_epi32(qp5_lo, 0x4e); + const __m128i pq4_lo = _mm_shuffle_epi32(qp4_lo, 0x4e); + const __m128i pq3_lo = _mm_shuffle_epi32(qp3_lo, 0x4e); + const __m128i pq2_lo = _mm_shuffle_epi32(qp2_lo, 0x4e); + const __m128i pq1_lo = _mm_shuffle_epi32(qp1_lo, 0x4e); + const __m128i pq0_lo = _mm_shuffle_epi32(qp0_lo, 0x4e); + + __m128i f14_lo = + _mm_add_epi16(eight, _mm_sub_epi16(_mm_slli_epi16(qp6_lo, 3), qp6_lo)); + + f14_lo = _mm_add_epi16(_mm_add_epi16(f14_lo, qp5_lo), + _mm_add_epi16(qp5_lo, qp4_lo)); + + f14_lo = _mm_add_epi16(_mm_add_epi16(f14_lo, qp4_lo), + _mm_add_epi16(qp3_lo, qp2_lo)); + + f14_lo = _mm_add_epi16(_mm_add_epi16(f14_lo, qp1_lo), + _mm_add_epi16(qp0_lo, pq0_lo)); + + // p6 * 7 + p5 * 2 + p4 * 2 + p3 + p2 + p1 + p0 + q0 + // q6 * 7 + q5 * 2 + q4 * 2 + q3 + q2 + q1 + q0 + p0 + *oqp5 = _mm_srli_epi16(f14_lo, 4); + + // p6 * 5 + p5 * 2 + p4 * 2 + p3 * 2 + p2 + p1 + p0 + q0 + q1 + // q6 * 5 + q5 * 2 + q4 * 2 + q3 * 2 + q2 + q1 + q0 + p0 + p1 + f14_lo = FilterAdd2Sub2(f14_lo, qp3_lo, pq1_lo, qp6_lo, qp6_lo); + *oqp4 = _mm_srli_epi16(f14_lo, 4); + + // p6 * 4 + p5 + p4 * 2 + p3 * 2 + p2 * 2 + p1 + p0 + q0 + q1 + q2 + // q6 * 4 + q5 + q4 * 2 + q3 * 2 + q2 * 2 + q1 + q0 + p0 + p1 + p2 + f14_lo = FilterAdd2Sub2(f14_lo, qp2_lo, pq2_lo, qp6_lo, qp5_lo); + *oqp3 = _mm_srli_epi16(f14_lo, 4); + + // p6 * 3 + p5 + p4 + p3 * 2 + p2 * 2 + p1 * 2 + p0 + q0 + q1 + q2 + q3 + // q6 * 3 + q5 + q4 + q3 * 2 + q2 * 2 + q1 * 2 + q0 + p0 + p1 + p2 + p3 + f14_lo = FilterAdd2Sub2(f14_lo, qp1_lo, pq3_lo, qp6_lo, qp4_lo); + *oqp2 = _mm_srli_epi16(f14_lo, 4); + + // p6 * 2 + p5 + p4 + p3 + p2 * 2 + p1 * 2 + p0 * 2 + q0 + q1 + q2 + q3 + q4 + // q6 * 2 + q5 + q4 + q3 + q2 * 2 + q1 * 2 + q0 * 2 + p0 + p1 + p2 + p3 + p4 + f14_lo = FilterAdd2Sub2(f14_lo, qp0_lo, pq4_lo, qp6_lo, qp3_lo); + *oqp1 = _mm_srli_epi16(f14_lo, 4); + + // p6 + p5 + p4 + p3 + p2 + p1 * 2 + p0 * 2 + q0 * 2 + q1 + q2 + q3 + q4 + q5 + // q6 + q5 + q4 + q3 + q2 + q1 * 2 + q0 * 2 + p0 * 2 + p1 + p2 + p3 + p4 + p5 + f14_lo = FilterAdd2Sub2(f14_lo, pq0_lo, pq5_lo, qp6_lo, qp2_lo); + *oqp0 = _mm_srli_epi16(f14_lo, 4); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Horizontal14(void* dest, + ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_flat_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(1 << kThreshShift), 0); + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + + const __m128i p3 = LoadLo8(dst - 4 * stride); + const __m128i p2 = LoadLo8(dst - 3 * stride); + const __m128i p1 = LoadLo8(dst - 2 * stride); + const __m128i p0 = LoadLo8(dst - 1 * stride); + const __m128i q0 = LoadLo8(dst + 0 * stride); + const __m128i q1 = LoadLo8(dst + 1 * stride); + const __m128i q2 = LoadLo8(dst + 2 * stride); + const __m128i q3 = LoadLo8(dst + 3 * stride); + const __m128i qp3 = _mm_unpacklo_epi64(p3, q3); + const __m128i qp2 = _mm_unpacklo_epi64(p2, q2); + const __m128i qp1 = _mm_unpacklo_epi64(p1, q1); + const __m128i qp0 = _mm_unpacklo_epi64(p0, q0); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter8(qp3, qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); + const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + const __m128i p6 = LoadLo8(dst - 7 * stride); + const __m128i p5 = LoadLo8(dst - 6 * stride); + const __m128i p4 = LoadLo8(dst - 5 * stride); + const __m128i q4 = LoadLo8(dst + 4 * stride); + const __m128i q5 = LoadLo8(dst + 5 * stride); + const __m128i q6 = LoadLo8(dst + 6 * stride); + const __m128i qp6 = _mm_unpacklo_epi64(p6, q6); + const __m128i qp5 = _mm_unpacklo_epi64(p5, q5); + const __m128i qp4 = _mm_unpacklo_epi64(p4, q4); + + const __m128i v_isflatouter4_mask = + IsFlat4(qp6, qp5, qp4, qp0, v_flat_thresh); + const __m128i v_flat4_mask_lo = _mm_and_si128(v_mask, v_isflatouter4_mask); + const __m128i v_flat4_mask = + _mm_unpacklo_epi64(v_flat4_mask_lo, v_flat4_mask_lo); + + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + + if (_mm_test_all_zeros(v_flat4_mask, + _mm_cmpeq_epi16(v_flat4_mask, v_flat4_mask)) == 0) { + __m128i oqp5_f14; + __m128i oqp4_f14; + __m128i oqp3_f14; + __m128i oqp2_f14; + __m128i oqp1_f14; + __m128i oqp0_f14; + + Filter14(qp6, qp5, qp4, qp3, qp2, qp1, qp0, &oqp5_f14, &oqp4_f14, + &oqp3_f14, &oqp2_f14, &oqp1_f14, &oqp0_f14); + + oqp5_f14 = _mm_blendv_epi8(qp5, oqp5_f14, v_flat4_mask); + oqp4_f14 = _mm_blendv_epi8(qp4, oqp4_f14, v_flat4_mask); + oqp3_f14 = _mm_blendv_epi8(qp3, oqp3_f14, v_flat4_mask); + oqp2_f8 = _mm_blendv_epi8(oqp2_f8, oqp2_f14, v_flat4_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f14, v_flat4_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f14, v_flat4_mask); + + StoreLo8(dst - 6 * stride, oqp5_f14); + StoreLo8(dst - 5 * stride, oqp4_f14); + StoreLo8(dst - 4 * stride, oqp3_f14); + + StoreHi8(dst + 3 * stride, oqp3_f14); + StoreHi8(dst + 4 * stride, oqp4_f14); + StoreHi8(dst + 5 * stride, oqp5_f14); + } + + StoreLo8(dst - 3 * stride, oqp2_f8); + StoreHi8(dst + 2 * stride, oqp2_f8); + } + + StoreLo8(dst - 2 * stride, oqp1); + StoreLo8(dst - 1 * stride, oqp0); + StoreHi8(dst + 0 * stride, oqp0); + StoreHi8(dst + 1 * stride, oqp1); +} + +inline void TransposeUpper4x8To8x4(const __m128i& x0, const __m128i& x1, + const __m128i& x2, const __m128i& x3, + const __m128i& x4, const __m128i& x5, + const __m128i& x6, const __m128i& x7, + __m128i* d0, __m128i* d1, __m128i* d2, + __m128i* d3) { + // input + // x0 00 01 02 03 xx xx xx xx + // x1 10 11 12 13 xx xx xx xx + // x2 20 21 22 23 xx xx xx xx + // x3 30 31 32 33 xx xx xx xx + // x4 40 41 42 43 xx xx xx xx + // x5 50 51 52 53 xx xx xx xx + // x6 60 61 62 63 xx xx xx xx + // x7 70 71 72 73 xx xx xx xx + // output + // d0 00 10 20 30 40 50 60 70 + // d1 01 11 21 31 41 51 61 71 + // d2 02 12 22 32 42 52 62 72 + // d3 03 13 23 33 43 53 63 73 + + // 00 10 01 11 02 12 03 13 + const __m128i w0 = _mm_unpackhi_epi16(x0, x1); + // 20 30 21 31 22 32 23 33 + const __m128i w1 = _mm_unpackhi_epi16(x2, x3); + // 40 50 41 51 42 52 43 53 + const __m128i w2 = _mm_unpackhi_epi16(x4, x5); + // 60 70 61 71 62 72 63 73 + const __m128i w3 = _mm_unpackhi_epi16(x6, x7); + + // 00 10 20 30 01 11 21 31 + const __m128i w4 = _mm_unpacklo_epi32(w0, w1); + // 40 50 60 70 41 51 61 71 + const __m128i w5 = _mm_unpacklo_epi32(w2, w3); + // 02 12 22 32 03 13 23 33 + const __m128i w6 = _mm_unpackhi_epi32(w0, w1); + // 42 52 62 72 43 53 63 73 + const __m128i w7 = _mm_unpackhi_epi32(w2, w3); + + // 00 10 20 30 40 50 60 70 + *d0 = _mm_unpacklo_epi64(w4, w5); + // 01 11 21 31 41 51 61 71 + *d1 = _mm_unpackhi_epi64(w4, w5); + // 02 12 22 32 42 52 62 72 + *d2 = _mm_unpacklo_epi64(w6, w7); + // 03 13 23 33 43 53 63 73 + *d3 = _mm_unpackhi_epi64(w6, w7); +} + +template <int bitdepth> +void LoopFilterFuncs_SSE4_1<bitdepth>::Vertical14(void* dest, ptrdiff_t stride8, + int outer_thresh, + int inner_thresh, + int hev_thresh) { + auto* const dst = static_cast<uint16_t*>(dest); + const ptrdiff_t stride = stride8 / 2; + const __m128i v_flat_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(1 << kThreshShift), 0); + const __m128i v_outer_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(outer_thresh << kThreshShift), 0); + const __m128i v_inner_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(inner_thresh << kThreshShift), 0); + const __m128i v_hev_thresh = + _mm_shufflelo_epi16(_mm_cvtsi32_si128(hev_thresh << kThreshShift), 0); + + // p7 p6 p5 p4 p3 p2 p1 p0 q0 q1 q2 q3 q4 q5 q6 q7 + // + // 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f + // 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + // 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f + // 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f + + __m128i x0 = LoadUnaligned16(dst - 8 + 0 * stride); + __m128i x1 = LoadUnaligned16(dst - 8 + 1 * stride); + __m128i x2 = LoadUnaligned16(dst - 8 + 2 * stride); + __m128i x3 = LoadUnaligned16(dst - 8 + 3 * stride); + + __m128i p7, p6, p5, p4, p3, p2, p1, p0; + __m128i q7, q6, q5, q4, q3, q2, q1, q0; + + Transpose8x4To4x8(x0, x1, x2, x3, &p7, &p6, &p5, &p4, &p3, &p2, &p1, &p0); + + x0 = LoadUnaligned16(dst - 8 + 8 + 0 * stride); + x1 = LoadUnaligned16(dst - 8 + 8 + 1 * stride); + x2 = LoadUnaligned16(dst - 8 + 8 + 2 * stride); + x3 = LoadUnaligned16(dst - 8 + 8 + 3 * stride); + + Transpose8x4To4x8(x0, x1, x2, x3, &q0, &q1, &q2, &q3, &q4, &q5, &q6, &q7); + + __m128i qp7 = _mm_unpacklo_epi64(p7, q7); + __m128i qp6 = _mm_unpacklo_epi64(p6, q6); + __m128i qp5 = _mm_unpacklo_epi64(p5, q5); + __m128i qp4 = _mm_unpacklo_epi64(p4, q4); + __m128i qp3 = _mm_unpacklo_epi64(p3, q3); + __m128i qp2 = _mm_unpacklo_epi64(p2, q2); + __m128i qp1 = _mm_unpacklo_epi64(p1, q1); + __m128i qp0 = _mm_unpacklo_epi64(p0, q0); + + const __m128i v_hev_mask = Hev(qp1, qp0, v_hev_thresh); + const __m128i v_needs_mask = + NeedsFilter8(qp3, qp2, qp1, qp0, v_outer_thresh, v_inner_thresh); + + __m128i oqp1; + __m128i oqp0; + + Filter4(qp1, qp0, &oqp1, &oqp0, v_needs_mask, v_hev_mask, bitdepth); + + const __m128i v_isflat4_mask = IsFlat4(qp3, qp2, qp1, qp0, v_flat_thresh); + const __m128i v_mask_lo = _mm_and_si128(v_needs_mask, v_isflat4_mask); + const __m128i v_mask = _mm_unpacklo_epi64(v_mask_lo, v_mask_lo); + + if (_mm_test_all_zeros(v_mask, _mm_cmpeq_epi16(v_mask, v_mask)) == 0) { + const __m128i v_isflatouter4_mask = + IsFlat4(qp6, qp5, qp4, qp0, v_flat_thresh); + const __m128i v_flat4_mask_lo = _mm_and_si128(v_mask, v_isflatouter4_mask); + const __m128i v_flat4_mask = + _mm_unpacklo_epi64(v_flat4_mask_lo, v_flat4_mask_lo); + + __m128i oqp2_f8; + __m128i oqp1_f8; + __m128i oqp0_f8; + + Filter8(qp3, qp2, qp1, qp0, &oqp2_f8, &oqp1_f8, &oqp0_f8); + + oqp2_f8 = _mm_blendv_epi8(qp2, oqp2_f8, v_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f8, v_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f8, v_mask); + + if (_mm_test_all_zeros(v_flat4_mask, + _mm_cmpeq_epi16(v_flat4_mask, v_flat4_mask)) == 0) { + __m128i oqp5_f14; + __m128i oqp4_f14; + __m128i oqp3_f14; + __m128i oqp2_f14; + __m128i oqp1_f14; + __m128i oqp0_f14; + + Filter14(qp6, qp5, qp4, qp3, qp2, qp1, qp0, &oqp5_f14, &oqp4_f14, + &oqp3_f14, &oqp2_f14, &oqp1_f14, &oqp0_f14); + + oqp5_f14 = _mm_blendv_epi8(qp5, oqp5_f14, v_flat4_mask); + oqp4_f14 = _mm_blendv_epi8(qp4, oqp4_f14, v_flat4_mask); + oqp3_f14 = _mm_blendv_epi8(qp3, oqp3_f14, v_flat4_mask); + oqp2_f8 = _mm_blendv_epi8(oqp2_f8, oqp2_f14, v_flat4_mask); + oqp1 = _mm_blendv_epi8(oqp1, oqp1_f14, v_flat4_mask); + oqp0 = _mm_blendv_epi8(oqp0, oqp0_f14, v_flat4_mask); + qp3 = oqp3_f14; + qp4 = oqp4_f14; + qp5 = oqp5_f14; + } + qp2 = oqp2_f8; + } + + TransposeLower4x8To8x4(qp7, qp6, qp5, qp4, qp3, qp2, oqp1, oqp0, &x0, &x1, + &x2, &x3); + + StoreUnaligned16(dst - 8 + 0 * stride, x0); + StoreUnaligned16(dst - 8 + 1 * stride, x1); + StoreUnaligned16(dst - 8 + 2 * stride, x2); + StoreUnaligned16(dst - 8 + 3 * stride, x3); + + TransposeUpper4x8To8x4(oqp0, oqp1, qp2, qp3, qp4, qp5, qp6, qp7, &x0, &x1, + &x2, &x3); + + StoreUnaligned16(dst - 8 + 8 + 0 * stride, x0); + StoreUnaligned16(dst - 8 + 8 + 1 * stride, x1); + StoreUnaligned16(dst - 8 + 8 + 2 * stride, x2); + StoreUnaligned16(dst - 8 + 8 + 3 * stride, x3); +} + +using Defs10bpp = LoopFilterFuncs_SSE4_1<kBitdepth10>; + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize4_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal4; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize6_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal6; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize8_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal8; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize14_LoopFilterTypeHorizontal) + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = + Defs10bpp::Horizontal14; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize4_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = + Defs10bpp::Vertical4; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize6_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = + Defs10bpp::Vertical6; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize8_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = + Defs10bpp::Vertical8; +#endif +#if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize14_LoopFilterTypeVertical) + dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] = + Defs10bpp::Vertical14; +#endif +} +#endif +} // namespace +} // namespace high_bitdepth + +void LoopFilterInit_SSE4_1() { + low_bitdepth::Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + high_bitdepth::Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void LoopFilterInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/loop_filter_sse4.h b/src/dsp/x86/loop_filter_sse4.h new file mode 100644 index 0000000..4795d8b --- /dev/null +++ b/src/dsp/x86/loop_filter_sse4.h @@ -0,0 +1,119 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_LOOP_FILTER_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_LOOP_FILTER_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_filters, see the defines below for specifics. This +// function is not thread-safe. +void LoopFilterInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical +#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical +#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical +#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical +#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeHorizontal +#define LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeHorizontal \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeVertical +#define LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeVertical +#define LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeVertical +#define LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeVertical +#define LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeVertical \ + LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_LOOP_FILTER_SSE4_H_ diff --git a/src/dsp/x86/loop_restoration_10bit_avx2.cc b/src/dsp/x86/loop_restoration_10bit_avx2.cc new file mode 100644 index 0000000..702bdea --- /dev/null +++ b/src/dsp/x86/loop_restoration_10bit_avx2.cc @@ -0,0 +1,592 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_restoration.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10 +#include <immintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_avx2.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline void WienerHorizontalClip(const __m256i s[2], + int16_t* const wiener_buffer) { + constexpr int offset = + 1 << (10 + kWienerFilterBits - kInterRoundBitsHorizontal - 1); + constexpr int limit = (offset << 2) - 1; + const __m256i offsets = _mm256_set1_epi16(-offset); + const __m256i limits = _mm256_set1_epi16(limit - offset); + const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsHorizontal - 1)); + const __m256i sum0 = _mm256_add_epi32(s[0], round); + const __m256i sum1 = _mm256_add_epi32(s[1], round); + const __m256i rounded_sum0 = + _mm256_srai_epi32(sum0, kInterRoundBitsHorizontal); + const __m256i rounded_sum1 = + _mm256_srai_epi32(sum1, kInterRoundBitsHorizontal); + const __m256i rounded_sum = _mm256_packs_epi32(rounded_sum0, rounded_sum1); + const __m256i d0 = _mm256_max_epi16(rounded_sum, offsets); + const __m256i d1 = _mm256_min_epi16(d0, limits); + StoreAligned32(wiener_buffer, d1); +} + +inline void WienerHorizontalTap7Kernel(const __m256i s[7], + const __m256i filter[2], + int16_t* const wiener_buffer) { + const __m256i s06 = _mm256_add_epi16(s[0], s[6]); + const __m256i s15 = _mm256_add_epi16(s[1], s[5]); + const __m256i s24 = _mm256_add_epi16(s[2], s[4]); + const __m256i ss0 = _mm256_unpacklo_epi16(s06, s15); + const __m256i ss1 = _mm256_unpackhi_epi16(s06, s15); + const __m256i ss2 = _mm256_unpacklo_epi16(s24, s[3]); + const __m256i ss3 = _mm256_unpackhi_epi16(s24, s[3]); + __m256i madds[4]; + madds[0] = _mm256_madd_epi16(ss0, filter[0]); + madds[1] = _mm256_madd_epi16(ss1, filter[0]); + madds[2] = _mm256_madd_epi16(ss2, filter[1]); + madds[3] = _mm256_madd_epi16(ss3, filter[1]); + madds[0] = _mm256_add_epi32(madds[0], madds[2]); + madds[1] = _mm256_add_epi32(madds[1], madds[3]); + WienerHorizontalClip(madds, wiener_buffer); +} + +inline void WienerHorizontalTap5Kernel(const __m256i s[5], const __m256i filter, + int16_t* const wiener_buffer) { + const __m256i s04 = _mm256_add_epi16(s[0], s[4]); + const __m256i s13 = _mm256_add_epi16(s[1], s[3]); + const __m256i s2d = _mm256_add_epi16(s[2], s[2]); + const __m256i s0m = _mm256_sub_epi16(s04, s2d); + const __m256i s1m = _mm256_sub_epi16(s13, s2d); + const __m256i ss0 = _mm256_unpacklo_epi16(s0m, s1m); + const __m256i ss1 = _mm256_unpackhi_epi16(s0m, s1m); + __m256i madds[2]; + madds[0] = _mm256_madd_epi16(ss0, filter); + madds[1] = _mm256_madd_epi16(ss1, filter); + const __m256i s2_lo = _mm256_unpacklo_epi16(s[2], _mm256_setzero_si256()); + const __m256i s2_hi = _mm256_unpackhi_epi16(s[2], _mm256_setzero_si256()); + const __m256i s2x128_lo = _mm256_slli_epi32(s2_lo, 7); + const __m256i s2x128_hi = _mm256_slli_epi32(s2_hi, 7); + madds[0] = _mm256_add_epi32(madds[0], s2x128_lo); + madds[1] = _mm256_add_epi32(madds[1], s2x128_hi); + WienerHorizontalClip(madds, wiener_buffer); +} + +inline void WienerHorizontalTap3Kernel(const __m256i s[3], const __m256i filter, + int16_t* const wiener_buffer) { + const __m256i s02 = _mm256_add_epi16(s[0], s[2]); + const __m256i ss0 = _mm256_unpacklo_epi16(s02, s[1]); + const __m256i ss1 = _mm256_unpackhi_epi16(s02, s[1]); + __m256i madds[2]; + madds[0] = _mm256_madd_epi16(ss0, filter); + madds[1] = _mm256_madd_epi16(ss1, filter); + WienerHorizontalClip(madds, wiener_buffer); +} + +inline void WienerHorizontalTap7(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m256i* const coefficients, + int16_t** const wiener_buffer) { + __m256i filter[2]; + filter[0] = _mm256_shuffle_epi32(*coefficients, 0x0); + filter[1] = _mm256_shuffle_epi32(*coefficients, 0x55); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m256i s[7]; + s[0] = LoadUnaligned32(src + x + 0); + s[1] = LoadUnaligned32(src + x + 1); + s[2] = LoadUnaligned32(src + x + 2); + s[3] = LoadUnaligned32(src + x + 3); + s[4] = LoadUnaligned32(src + x + 4); + s[5] = LoadUnaligned32(src + x + 5); + s[6] = LoadUnaligned32(src + x + 6); + WienerHorizontalTap7Kernel(s, filter, *wiener_buffer + x); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap5(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m256i* const coefficients, + int16_t** const wiener_buffer) { + const __m256i filter = + _mm256_shuffle_epi8(*coefficients, _mm256_set1_epi32(0x05040302)); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m256i s[5]; + s[0] = LoadUnaligned32(src + x + 0); + s[1] = LoadUnaligned32(src + x + 1); + s[2] = LoadUnaligned32(src + x + 2); + s[3] = LoadUnaligned32(src + x + 3); + s[4] = LoadUnaligned32(src + x + 4); + WienerHorizontalTap5Kernel(s, filter, *wiener_buffer + x); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap3(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m256i* const coefficients, + int16_t** const wiener_buffer) { + const auto filter = _mm256_shuffle_epi32(*coefficients, 0x55); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m256i s[3]; + s[0] = LoadUnaligned32(src + x + 0); + s[1] = LoadUnaligned32(src + x + 1); + s[2] = LoadUnaligned32(src + x + 2); + WienerHorizontalTap3Kernel(s, filter, *wiener_buffer + x); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap1(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + const __m256i s0 = LoadUnaligned32(src + x); + const __m256i d0 = _mm256_slli_epi16(s0, 4); + StoreAligned32(*wiener_buffer + x, d0); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline __m256i WienerVertical7(const __m256i a[4], const __m256i filter[4]) { + const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]); + const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]); + const __m256i madd2 = _mm256_madd_epi16(a[2], filter[2]); + const __m256i madd3 = _mm256_madd_epi16(a[3], filter[3]); + const __m256i madd01 = _mm256_add_epi32(madd0, madd1); + const __m256i madd23 = _mm256_add_epi32(madd2, madd3); + const __m256i sum = _mm256_add_epi32(madd01, madd23); + return _mm256_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m256i WienerVertical5(const __m256i a[3], const __m256i filter[3]) { + const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]); + const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]); + const __m256i madd2 = _mm256_madd_epi16(a[2], filter[2]); + const __m256i madd01 = _mm256_add_epi32(madd0, madd1); + const __m256i sum = _mm256_add_epi32(madd01, madd2); + return _mm256_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m256i WienerVertical3(const __m256i a[2], const __m256i filter[2]) { + const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]); + const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]); + const __m256i sum = _mm256_add_epi32(madd0, madd1); + return _mm256_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m256i WienerVerticalClip(const __m256i s[2]) { + const __m256i d = _mm256_packus_epi32(s[0], s[1]); + return _mm256_min_epu16(d, _mm256_set1_epi16(1023)); +} + +inline __m256i WienerVerticalFilter7(const __m256i a[7], + const __m256i filter[2]) { + const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m256i b[4], c[2]; + b[0] = _mm256_unpacklo_epi16(a[0], a[1]); + b[1] = _mm256_unpacklo_epi16(a[2], a[3]); + b[2] = _mm256_unpacklo_epi16(a[4], a[5]); + b[3] = _mm256_unpacklo_epi16(a[6], round); + c[0] = WienerVertical7(b, filter); + b[0] = _mm256_unpackhi_epi16(a[0], a[1]); + b[1] = _mm256_unpackhi_epi16(a[2], a[3]); + b[2] = _mm256_unpackhi_epi16(a[4], a[5]); + b[3] = _mm256_unpackhi_epi16(a[6], round); + c[1] = WienerVertical7(b, filter); + return WienerVerticalClip(c); +} + +inline __m256i WienerVerticalFilter5(const __m256i a[5], + const __m256i filter[3]) { + const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m256i b[3], c[2]; + b[0] = _mm256_unpacklo_epi16(a[0], a[1]); + b[1] = _mm256_unpacklo_epi16(a[2], a[3]); + b[2] = _mm256_unpacklo_epi16(a[4], round); + c[0] = WienerVertical5(b, filter); + b[0] = _mm256_unpackhi_epi16(a[0], a[1]); + b[1] = _mm256_unpackhi_epi16(a[2], a[3]); + b[2] = _mm256_unpackhi_epi16(a[4], round); + c[1] = WienerVertical5(b, filter); + return WienerVerticalClip(c); +} + +inline __m256i WienerVerticalFilter3(const __m256i a[3], + const __m256i filter[2]) { + const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m256i b[2], c[2]; + b[0] = _mm256_unpacklo_epi16(a[0], a[1]); + b[1] = _mm256_unpacklo_epi16(a[2], round); + c[0] = WienerVertical3(b, filter); + b[0] = _mm256_unpackhi_epi16(a[0], a[1]); + b[1] = _mm256_unpackhi_epi16(a[2], round); + c[1] = WienerVertical3(b, filter); + return WienerVerticalClip(c); +} + +inline __m256i WienerVerticalTap7Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i a[7]) { + a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride); + a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride); + a[6] = LoadAligned32(wiener_buffer + 6 * wiener_stride); + return WienerVerticalFilter7(a, filter); +} + +inline __m256i WienerVerticalTap5Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[3], __m256i a[5]) { + a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride); + return WienerVerticalFilter5(a, filter); +} + +inline __m256i WienerVerticalTap3Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i a[3]) { + a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride); + return WienerVerticalFilter3(a, filter); +} + +inline void WienerVerticalTap7Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i d[2]) { + __m256i a[8]; + d[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a); + a[7] = LoadAligned32(wiener_buffer + 7 * wiener_stride); + d[1] = WienerVerticalFilter7(a + 1, filter); +} + +inline void WienerVerticalTap5Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[3], __m256i d[2]) { + __m256i a[6]; + d[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a); + a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride); + d[1] = WienerVerticalFilter5(a + 1, filter); +} + +inline void WienerVerticalTap3Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i d[2]) { + __m256i a[4]; + d[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a); + a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride); + d[1] = WienerVerticalFilter3(a + 1, filter); +} + +inline void WienerVerticalTap7(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[4], uint16_t* dst, + const ptrdiff_t dst_stride) { + const __m256i c = _mm256_broadcastq_epi64(LoadLo8(coefficients)); + __m256i filter[4]; + filter[0] = _mm256_shuffle_epi32(c, 0x0); + filter[1] = _mm256_shuffle_epi32(c, 0x55); + filter[2] = _mm256_shuffle_epi8(c, _mm256_set1_epi32(0x03020504)); + filter[3] = + _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m256i d[2]; + WienerVerticalTap7Kernel2(wiener_buffer + x, width, filter, d); + StoreUnaligned32(dst + x, d[0]); + StoreUnaligned32(dst + dst_stride + x, d[1]); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m256i a[7]; + const __m256i d = + WienerVerticalTap7Kernel(wiener_buffer + x, width, filter, a); + StoreUnaligned32(dst + x, d); + x += 16; + } while (x < width); + } +} + +inline void WienerVerticalTap5(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[3], uint16_t* dst, + const ptrdiff_t dst_stride) { + const __m256i c = _mm256_broadcastq_epi64(LoadLo8(coefficients)); + __m256i filter[3]; + filter[0] = _mm256_shuffle_epi32(c, 0x0); + filter[1] = _mm256_shuffle_epi8(c, _mm256_set1_epi32(0x03020504)); + filter[2] = + _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m256i d[2]; + WienerVerticalTap5Kernel2(wiener_buffer + x, width, filter, d); + StoreUnaligned32(dst + x, d[0]); + StoreUnaligned32(dst + dst_stride + x, d[1]); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m256i a[5]; + const __m256i d = + WienerVerticalTap5Kernel(wiener_buffer + x, width, filter, a); + StoreUnaligned32(dst + x, d); + x += 16; + } while (x < width); + } +} + +inline void WienerVerticalTap3(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[2], uint16_t* dst, + const ptrdiff_t dst_stride) { + __m256i filter[2]; + filter[0] = + _mm256_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients)); + filter[1] = + _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m256i d[2][2]; + WienerVerticalTap3Kernel2(wiener_buffer + x, width, filter, d[0]); + StoreUnaligned32(dst + x, d[0][0]); + StoreUnaligned32(dst + dst_stride + x, d[0][1]); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m256i a[3]; + const __m256i d = + WienerVerticalTap3Kernel(wiener_buffer + x, width, filter, a); + StoreUnaligned32(dst + x, d); + x += 16; + } while (x < width); + } +} + +inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer, + uint16_t* const dst) { + const __m256i a = LoadAligned32(wiener_buffer); + const __m256i b = _mm256_add_epi16(a, _mm256_set1_epi16(8)); + const __m256i c = _mm256_srai_epi16(b, 4); + const __m256i d = _mm256_max_epi16(c, _mm256_setzero_si256()); + const __m256i e = _mm256_min_epi16(d, _mm256_set1_epi16(1023)); + StoreUnaligned32(dst, e); +} + +inline void WienerVerticalTap1(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + uint16_t* dst, const ptrdiff_t dst_stride) { + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + x += 16; + } while (x < width); + } +} + +void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, const ptrdiff_t stride, + const int width, const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + const int16_t* const number_leading_zero_coefficients = + restoration_info.wiener_info.number_leading_zero_coefficients; + const int number_rows_to_skip = std::max( + static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]), + 1); + const ptrdiff_t wiener_stride = Align(width, 16); + int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer; + // The values are saturated to 13 bits before storing. + int16_t* wiener_buffer_horizontal = + wiener_buffer_vertical + number_rows_to_skip * wiener_stride; + + // horizontal filtering. + // Over-reads up to 15 - |kRestorationHorizontalBorder| values. + const int height_horizontal = + height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip; + const int height_extra = (height_horizontal - height) >> 1; + assert(height_extra <= 2); + const auto* const src = static_cast<const uint16_t*>(source); + const auto* const top = static_cast<const uint16_t*>(top_border); + const auto* const bottom = static_cast<const uint16_t*>(bottom_border); + const __m128i c = + LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]); + const __m256i coefficients_horizontal = _mm256_broadcastq_epi64(c); + if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { + WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, + wiener_stride, height_extra, &coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + &coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + &coefficients_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, + wiener_stride, height_extra, &coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + &coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + &coefficients_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { + // The maximum over-reads happen here. + WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, + wiener_stride, height_extra, &coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + &coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + &coefficients_horizontal, &wiener_buffer_horizontal); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); + WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, + wiener_stride, height_extra, + &wiener_buffer_horizontal); + WienerHorizontalTap1(src, stride, wiener_stride, height, + &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, + &wiener_buffer_horizontal); + } + + // vertical filtering. + // Over-writes up to 15 values. + const int16_t* const filter_vertical = + restoration_info.wiener_info.filter[WienerInfo::kVertical]; + auto* dst = static_cast<uint16_t*>(dest); + if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) { + // Because the top row of |source| is a duplicate of the second row, and the + // bottom row of |source| is a duplicate of its above row, we can duplicate + // the top and bottom row of |wiener_buffer| accordingly. + memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride, + sizeof(*wiener_buffer_horizontal) * wiener_stride); + memcpy(restoration_buffer->wiener_buffer, + restoration_buffer->wiener_buffer + wiener_stride, + sizeof(*restoration_buffer->wiener_buffer) * wiener_stride); + WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height, + filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) { + WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride, + height, filter_vertical + 1, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) { + WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride, + wiener_stride, height, filter_vertical + 2, dst, stride); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3); + WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride, + wiener_stride, height, dst, stride); + } +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); +#if DSP_ENABLED_10BPP_AVX2(WienerFilter) + dsp->loop_restorations[0] = WienerFilter_AVX2; +#endif +} + +} // namespace + +void LoopRestorationInit10bpp_AVX2() { Init10bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !(LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10) +namespace libgav1 { +namespace dsp { + +void LoopRestorationInit10bpp_AVX2() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10 diff --git a/src/dsp/x86/loop_restoration_10bit_sse4.cc b/src/dsp/x86/loop_restoration_10bit_sse4.cc new file mode 100644 index 0000000..0598435 --- /dev/null +++ b/src/dsp/x86/loop_restoration_10bit_sse4.cc @@ -0,0 +1,551 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_restoration.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 && LIBGAV1_MAX_BITDEPTH >= 10 +#include <smmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline void WienerHorizontalClip(const __m128i s[2], + int16_t* const wiener_buffer) { + constexpr int offset = + 1 << (10 + kWienerFilterBits - kInterRoundBitsHorizontal - 1); + constexpr int limit = (offset << 2) - 1; + const __m128i offsets = _mm_set1_epi16(-offset); + const __m128i limits = _mm_set1_epi16(limit - offset); + const __m128i round = _mm_set1_epi32(1 << (kInterRoundBitsHorizontal - 1)); + const __m128i sum0 = _mm_add_epi32(s[0], round); + const __m128i sum1 = _mm_add_epi32(s[1], round); + const __m128i rounded_sum0 = _mm_srai_epi32(sum0, kInterRoundBitsHorizontal); + const __m128i rounded_sum1 = _mm_srai_epi32(sum1, kInterRoundBitsHorizontal); + const __m128i rounded_sum = _mm_packs_epi32(rounded_sum0, rounded_sum1); + const __m128i d0 = _mm_max_epi16(rounded_sum, offsets); + const __m128i d1 = _mm_min_epi16(d0, limits); + StoreAligned16(wiener_buffer, d1); +} + +inline void WienerHorizontalTap7(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m128i coefficients, + int16_t** const wiener_buffer) { + __m128i filter[2]; + filter[0] = _mm_shuffle_epi32(coefficients, 0x0); + filter[1] = _mm_shuffle_epi32(coefficients, 0x55); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m128i s[7], madds[4]; + s[0] = LoadUnaligned16(src + x + 0); + s[1] = LoadUnaligned16(src + x + 1); + s[2] = LoadUnaligned16(src + x + 2); + s[3] = LoadUnaligned16(src + x + 3); + s[4] = LoadUnaligned16(src + x + 4); + s[5] = LoadUnaligned16(src + x + 5); + s[6] = LoadUnaligned16(src + x + 6); + const __m128i s06 = _mm_add_epi16(s[0], s[6]); + const __m128i s15 = _mm_add_epi16(s[1], s[5]); + const __m128i s24 = _mm_add_epi16(s[2], s[4]); + const __m128i ss0 = _mm_unpacklo_epi16(s06, s15); + const __m128i ss1 = _mm_unpackhi_epi16(s06, s15); + const __m128i ss2 = _mm_unpacklo_epi16(s24, s[3]); + const __m128i ss3 = _mm_unpackhi_epi16(s24, s[3]); + madds[0] = _mm_madd_epi16(ss0, filter[0]); + madds[1] = _mm_madd_epi16(ss1, filter[0]); + madds[2] = _mm_madd_epi16(ss2, filter[1]); + madds[3] = _mm_madd_epi16(ss3, filter[1]); + madds[0] = _mm_add_epi32(madds[0], madds[2]); + madds[1] = _mm_add_epi32(madds[1], madds[3]); + WienerHorizontalClip(madds, *wiener_buffer + x); + x += 8; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap5(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m128i coefficients, + int16_t** const wiener_buffer) { + const __m128i filter = + _mm_shuffle_epi8(coefficients, _mm_set1_epi32(0x05040302)); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m128i s[5], madds[2]; + s[0] = LoadUnaligned16(src + x + 0); + s[1] = LoadUnaligned16(src + x + 1); + s[2] = LoadUnaligned16(src + x + 2); + s[3] = LoadUnaligned16(src + x + 3); + s[4] = LoadUnaligned16(src + x + 4); + const __m128i s04 = _mm_add_epi16(s[0], s[4]); + const __m128i s13 = _mm_add_epi16(s[1], s[3]); + const __m128i s2d = _mm_add_epi16(s[2], s[2]); + const __m128i s0m = _mm_sub_epi16(s04, s2d); + const __m128i s1m = _mm_sub_epi16(s13, s2d); + const __m128i ss0 = _mm_unpacklo_epi16(s0m, s1m); + const __m128i ss1 = _mm_unpackhi_epi16(s0m, s1m); + madds[0] = _mm_madd_epi16(ss0, filter); + madds[1] = _mm_madd_epi16(ss1, filter); + const __m128i s2_lo = _mm_unpacklo_epi16(s[2], _mm_setzero_si128()); + const __m128i s2_hi = _mm_unpackhi_epi16(s[2], _mm_setzero_si128()); + const __m128i s2x128_lo = _mm_slli_epi32(s2_lo, 7); + const __m128i s2x128_hi = _mm_slli_epi32(s2_hi, 7); + madds[0] = _mm_add_epi32(madds[0], s2x128_lo); + madds[1] = _mm_add_epi32(madds[1], s2x128_hi); + WienerHorizontalClip(madds, *wiener_buffer + x); + x += 8; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap3(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m128i coefficients, + int16_t** const wiener_buffer) { + const auto filter = _mm_shuffle_epi32(coefficients, 0x55); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m128i s[3], madds[2]; + s[0] = LoadUnaligned16(src + x + 0); + s[1] = LoadUnaligned16(src + x + 1); + s[2] = LoadUnaligned16(src + x + 2); + const __m128i s02 = _mm_add_epi16(s[0], s[2]); + const __m128i ss0 = _mm_unpacklo_epi16(s02, s[1]); + const __m128i ss1 = _mm_unpackhi_epi16(s02, s[1]); + madds[0] = _mm_madd_epi16(ss0, filter); + madds[1] = _mm_madd_epi16(ss1, filter); + WienerHorizontalClip(madds, *wiener_buffer + x); + x += 8; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap1(const uint16_t* src, + const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + const __m128i s = LoadUnaligned16(src + x); + const __m128i d = _mm_slli_epi16(s, 4); + StoreAligned16(*wiener_buffer + x, d); + x += 8; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline __m128i WienerVertical7(const __m128i a[4], const __m128i filter[4]) { + const __m128i madd0 = _mm_madd_epi16(a[0], filter[0]); + const __m128i madd1 = _mm_madd_epi16(a[1], filter[1]); + const __m128i madd2 = _mm_madd_epi16(a[2], filter[2]); + const __m128i madd3 = _mm_madd_epi16(a[3], filter[3]); + const __m128i madd01 = _mm_add_epi32(madd0, madd1); + const __m128i madd23 = _mm_add_epi32(madd2, madd3); + const __m128i sum = _mm_add_epi32(madd01, madd23); + return _mm_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m128i WienerVertical5(const __m128i a[3], const __m128i filter[3]) { + const __m128i madd0 = _mm_madd_epi16(a[0], filter[0]); + const __m128i madd1 = _mm_madd_epi16(a[1], filter[1]); + const __m128i madd2 = _mm_madd_epi16(a[2], filter[2]); + const __m128i madd01 = _mm_add_epi32(madd0, madd1); + const __m128i sum = _mm_add_epi32(madd01, madd2); + return _mm_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m128i WienerVertical3(const __m128i a[2], const __m128i filter[2]) { + const __m128i madd0 = _mm_madd_epi16(a[0], filter[0]); + const __m128i madd1 = _mm_madd_epi16(a[1], filter[1]); + const __m128i sum = _mm_add_epi32(madd0, madd1); + return _mm_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m128i WienerVerticalClip(const __m128i s[2]) { + const __m128i d = _mm_packus_epi32(s[0], s[1]); + return _mm_min_epu16(d, _mm_set1_epi16(1023)); +} + +inline __m128i WienerVerticalFilter7(const __m128i a[7], + const __m128i filter[2]) { + const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m128i b[4], c[2]; + b[0] = _mm_unpacklo_epi16(a[0], a[1]); + b[1] = _mm_unpacklo_epi16(a[2], a[3]); + b[2] = _mm_unpacklo_epi16(a[4], a[5]); + b[3] = _mm_unpacklo_epi16(a[6], round); + c[0] = WienerVertical7(b, filter); + b[0] = _mm_unpackhi_epi16(a[0], a[1]); + b[1] = _mm_unpackhi_epi16(a[2], a[3]); + b[2] = _mm_unpackhi_epi16(a[4], a[5]); + b[3] = _mm_unpackhi_epi16(a[6], round); + c[1] = WienerVertical7(b, filter); + return WienerVerticalClip(c); +} + +inline __m128i WienerVerticalFilter5(const __m128i a[5], + const __m128i filter[3]) { + const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m128i b[3], c[2]; + b[0] = _mm_unpacklo_epi16(a[0], a[1]); + b[1] = _mm_unpacklo_epi16(a[2], a[3]); + b[2] = _mm_unpacklo_epi16(a[4], round); + c[0] = WienerVertical5(b, filter); + b[0] = _mm_unpackhi_epi16(a[0], a[1]); + b[1] = _mm_unpackhi_epi16(a[2], a[3]); + b[2] = _mm_unpackhi_epi16(a[4], round); + c[1] = WienerVertical5(b, filter); + return WienerVerticalClip(c); +} + +inline __m128i WienerVerticalFilter3(const __m128i a[3], + const __m128i filter[2]) { + const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m128i b[2], c[2]; + b[0] = _mm_unpacklo_epi16(a[0], a[1]); + b[1] = _mm_unpacklo_epi16(a[2], round); + c[0] = WienerVertical3(b, filter); + b[0] = _mm_unpackhi_epi16(a[0], a[1]); + b[1] = _mm_unpackhi_epi16(a[2], round); + c[1] = WienerVertical3(b, filter); + return WienerVerticalClip(c); +} + +inline __m128i WienerVerticalTap7Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter[2], __m128i a[7]) { + a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned16(wiener_buffer + 4 * wiener_stride); + a[5] = LoadAligned16(wiener_buffer + 5 * wiener_stride); + a[6] = LoadAligned16(wiener_buffer + 6 * wiener_stride); + return WienerVerticalFilter7(a, filter); +} + +inline __m128i WienerVerticalTap5Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter[3], __m128i a[5]) { + a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned16(wiener_buffer + 4 * wiener_stride); + return WienerVerticalFilter5(a, filter); +} + +inline __m128i WienerVerticalTap3Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter[2], __m128i a[3]) { + a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride); + return WienerVerticalFilter3(a, filter); +} + +inline void WienerVerticalTap7(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[4], uint16_t* dst, + const ptrdiff_t dst_stride) { + const __m128i c = LoadLo8(coefficients); + __m128i filter[4]; + filter[0] = _mm_shuffle_epi32(c, 0x0); + filter[1] = _mm_shuffle_epi32(c, 0x55); + filter[2] = _mm_shuffle_epi8(c, _mm_set1_epi32(0x03020504)); + filter[3] = + _mm_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m128i a[8], d[2]; + d[0] = WienerVerticalTap7Kernel(wiener_buffer + x, width, filter, a); + a[7] = LoadAligned16(wiener_buffer + x + 7 * width); + d[1] = WienerVerticalFilter7(a + 1, filter); + StoreAligned16(dst + x, d[0]); + StoreAligned16(dst + dst_stride + x, d[1]); + x += 8; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m128i a[7]; + const __m128i d = + WienerVerticalTap7Kernel(wiener_buffer + x, width, filter, a); + StoreAligned16(dst + x, d); + x += 8; + } while (x < width); + } +} + +inline void WienerVerticalTap5(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[3], uint16_t* dst, + const ptrdiff_t dst_stride) { + const __m128i c = LoadLo8(coefficients); + __m128i filter[3]; + filter[0] = _mm_shuffle_epi32(c, 0x0); + filter[1] = _mm_shuffle_epi8(c, _mm_set1_epi32(0x03020504)); + filter[2] = + _mm_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m128i a[6], d[2]; + d[0] = WienerVerticalTap5Kernel(wiener_buffer + x, width, filter, a); + a[5] = LoadAligned16(wiener_buffer + x + 5 * width); + d[1] = WienerVerticalFilter5(a + 1, filter); + StoreAligned16(dst + x, d[0]); + StoreAligned16(dst + dst_stride + x, d[1]); + x += 8; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m128i a[5]; + const __m128i d = + WienerVerticalTap5Kernel(wiener_buffer + x, width, filter, a); + StoreAligned16(dst + x, d); + x += 8; + } while (x < width); + } +} + +inline void WienerVerticalTap3(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[2], uint16_t* dst, + const ptrdiff_t dst_stride) { + __m128i filter[2]; + filter[0] = _mm_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients)); + filter[1] = + _mm_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m128i a[4], d[2]; + d[0] = WienerVerticalTap3Kernel(wiener_buffer + x, width, filter, a); + a[3] = LoadAligned16(wiener_buffer + x + 3 * width); + d[1] = WienerVerticalFilter3(a + 1, filter); + StoreAligned16(dst + x, d[0]); + StoreAligned16(dst + dst_stride + x, d[1]); + x += 8; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m128i a[3]; + const __m128i d = + WienerVerticalTap3Kernel(wiener_buffer + x, width, filter, a); + StoreAligned16(dst + x, d); + x += 8; + } while (x < width); + } +} + +inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer, + uint16_t* const dst) { + const __m128i a = LoadAligned16(wiener_buffer); + const __m128i b = _mm_add_epi16(a, _mm_set1_epi16(8)); + const __m128i c = _mm_srai_epi16(b, 4); + const __m128i d = _mm_max_epi16(c, _mm_setzero_si128()); + const __m128i e = _mm_min_epi16(d, _mm_set1_epi16(1023)); + StoreAligned16(dst, e); +} + +inline void WienerVerticalTap1(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + uint16_t* dst, const ptrdiff_t dst_stride) { + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x); + x += 8; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + x += 8; + } while (x < width); + } +} + +void WienerFilter_SSE4_1(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, + const ptrdiff_t stride, const int width, + const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + const int16_t* const number_leading_zero_coefficients = + restoration_info.wiener_info.number_leading_zero_coefficients; + const int number_rows_to_skip = std::max( + static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]), + 1); + const ptrdiff_t wiener_stride = Align(width, 16); + int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer; + // The values are saturated to 13 bits before storing. + int16_t* wiener_buffer_horizontal = + wiener_buffer_vertical + number_rows_to_skip * wiener_stride; + + // horizontal filtering. + // Over-reads up to 15 - |kRestorationHorizontalBorder| values. + const int height_horizontal = + height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip; + const int height_extra = (height_horizontal - height) >> 1; + assert(height_extra <= 2); + const auto* const src = static_cast<const uint16_t*>(source); + const auto* const top = static_cast<const uint16_t*>(top_border); + const auto* const bottom = static_cast<const uint16_t*>(bottom_border); + const __m128i coefficients_horizontal = + LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]); + if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { + WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, + wiener_stride, height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + coefficients_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, + wiener_stride, height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + coefficients_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { + // The maximum over-reads happen here. + WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, + wiener_stride, height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + coefficients_horizontal, &wiener_buffer_horizontal); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); + WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, + wiener_stride, height_extra, + &wiener_buffer_horizontal); + WienerHorizontalTap1(src, stride, wiener_stride, height, + &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, + &wiener_buffer_horizontal); + } + + // vertical filtering. + // Over-writes up to 15 values. + const int16_t* const filter_vertical = + restoration_info.wiener_info.filter[WienerInfo::kVertical]; + auto* dst = static_cast<uint16_t*>(dest); + if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) { + // Because the top row of |source| is a duplicate of the second row, and the + // bottom row of |source| is a duplicate of its above row, we can duplicate + // the top and bottom row of |wiener_buffer| accordingly. + memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride, + sizeof(*wiener_buffer_horizontal) * wiener_stride); + memcpy(restoration_buffer->wiener_buffer, + restoration_buffer->wiener_buffer + wiener_stride, + sizeof(*restoration_buffer->wiener_buffer) * wiener_stride); + WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height, + filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) { + WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride, + height, filter_vertical + 1, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) { + WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride, + wiener_stride, height, filter_vertical + 2, dst, stride); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3); + WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride, + wiener_stride, height, dst, stride); + } +} + +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_10BPP_SSE4_1(WienerFilter) + dsp->loop_restorations[0] = WienerFilter_SSE4_1; +#else + static_cast<void>(WienerFilter_SSE4_1); +#endif +} + +} // namespace + +void LoopRestorationInit10bpp_SSE4_1() { Init10bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !(LIBGAV1_TARGETING_SSE4_1 && LIBGAV1_MAX_BITDEPTH >= 10) +namespace libgav1 { +namespace dsp { + +void LoopRestorationInit10bpp_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 && LIBGAV1_MAX_BITDEPTH >= 10 diff --git a/src/dsp/x86/loop_restoration_avx2.cc b/src/dsp/x86/loop_restoration_avx2.cc new file mode 100644 index 0000000..7ae7c90 --- /dev/null +++ b/src/dsp/x86/loop_restoration_avx2.cc @@ -0,0 +1,2902 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_restoration.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_AVX2 +#include <immintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_avx2.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +inline void WienerHorizontalClip(const __m256i s[2], const __m256i s_3x128, + int16_t* const wiener_buffer) { + constexpr int offset = + 1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1); + constexpr int limit = + (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1; + const __m256i offsets = _mm256_set1_epi16(-offset); + const __m256i limits = _mm256_set1_epi16(limit - offset); + const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsHorizontal - 1)); + // The sum range here is [-128 * 255, 90 * 255]. + const __m256i madd = _mm256_add_epi16(s[0], s[1]); + const __m256i sum = _mm256_add_epi16(madd, round); + const __m256i rounded_sum0 = + _mm256_srai_epi16(sum, kInterRoundBitsHorizontal); + // Add back scaled down offset correction. + const __m256i rounded_sum1 = _mm256_add_epi16(rounded_sum0, s_3x128); + const __m256i d0 = _mm256_max_epi16(rounded_sum1, offsets); + const __m256i d1 = _mm256_min_epi16(d0, limits); + StoreAligned32(wiener_buffer, d1); +} + +// Using _mm256_alignr_epi8() is about 8% faster than loading all and unpacking, +// because the compiler generates redundant code when loading all and unpacking. +inline void WienerHorizontalTap7Kernel(const __m256i s[2], + const __m256i filter[4], + int16_t* const wiener_buffer) { + const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1); + const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5); + const auto s45 = _mm256_alignr_epi8(s[1], s[0], 9); + const auto s67 = _mm256_alignr_epi8(s[1], s[0], 13); + __m256i madds[4]; + madds[0] = _mm256_maddubs_epi16(s01, filter[0]); + madds[1] = _mm256_maddubs_epi16(s23, filter[1]); + madds[2] = _mm256_maddubs_epi16(s45, filter[2]); + madds[3] = _mm256_maddubs_epi16(s67, filter[3]); + madds[0] = _mm256_add_epi16(madds[0], madds[2]); + madds[1] = _mm256_add_epi16(madds[1], madds[3]); + const __m256i s_3x128 = _mm256_slli_epi16(_mm256_srli_epi16(s23, 8), + 7 - kInterRoundBitsHorizontal); + WienerHorizontalClip(madds, s_3x128, wiener_buffer); +} + +inline void WienerHorizontalTap5Kernel(const __m256i s[2], + const __m256i filter[3], + int16_t* const wiener_buffer) { + const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1); + const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5); + const auto s45 = _mm256_alignr_epi8(s[1], s[0], 9); + __m256i madds[3]; + madds[0] = _mm256_maddubs_epi16(s01, filter[0]); + madds[1] = _mm256_maddubs_epi16(s23, filter[1]); + madds[2] = _mm256_maddubs_epi16(s45, filter[2]); + madds[0] = _mm256_add_epi16(madds[0], madds[2]); + const __m256i s_3x128 = _mm256_srli_epi16(_mm256_slli_epi16(s23, 8), + kInterRoundBitsHorizontal + 1); + WienerHorizontalClip(madds, s_3x128, wiener_buffer); +} + +inline void WienerHorizontalTap3Kernel(const __m256i s[2], + const __m256i filter[2], + int16_t* const wiener_buffer) { + const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1); + const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5); + __m256i madds[2]; + madds[0] = _mm256_maddubs_epi16(s01, filter[0]); + madds[1] = _mm256_maddubs_epi16(s23, filter[1]); + const __m256i s_3x128 = _mm256_slli_epi16(_mm256_srli_epi16(s01, 8), + 7 - kInterRoundBitsHorizontal); + WienerHorizontalClip(madds, s_3x128, wiener_buffer); +} + +inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m256i coefficients, + int16_t** const wiener_buffer) { + __m256i filter[4]; + filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0100)); + filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302)); + filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0102)); + filter[3] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8000)); + for (int y = height; y != 0; --y) { + __m256i s = LoadUnaligned32(src); + __m256i ss[4]; + ss[0] = _mm256_unpacklo_epi8(s, s); + ptrdiff_t x = 0; + do { + ss[1] = _mm256_unpackhi_epi8(s, s); + s = LoadUnaligned32(src + x + 32); + ss[3] = _mm256_unpacklo_epi8(s, s); + ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21); + WienerHorizontalTap7Kernel(ss + 0, filter, *wiener_buffer + x + 0); + WienerHorizontalTap7Kernel(ss + 1, filter, *wiener_buffer + x + 16); + ss[0] = ss[3]; + x += 32; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m256i coefficients, + int16_t** const wiener_buffer) { + __m256i filter[3]; + filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0201)); + filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0203)); + filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8001)); + for (int y = height; y != 0; --y) { + __m256i s = LoadUnaligned32(src); + __m256i ss[4]; + ss[0] = _mm256_unpacklo_epi8(s, s); + ptrdiff_t x = 0; + do { + ss[1] = _mm256_unpackhi_epi8(s, s); + s = LoadUnaligned32(src + x + 32); + ss[3] = _mm256_unpacklo_epi8(s, s); + ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21); + WienerHorizontalTap5Kernel(ss + 0, filter, *wiener_buffer + x + 0); + WienerHorizontalTap5Kernel(ss + 1, filter, *wiener_buffer + x + 16); + ss[0] = ss[3]; + x += 32; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const __m256i coefficients, + int16_t** const wiener_buffer) { + __m256i filter[2]; + filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302)); + filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x8002)); + for (int y = height; y != 0; --y) { + __m256i s = LoadUnaligned32(src); + __m256i ss[4]; + ss[0] = _mm256_unpacklo_epi8(s, s); + ptrdiff_t x = 0; + do { + ss[1] = _mm256_unpackhi_epi8(s, s); + s = LoadUnaligned32(src + x + 32); + ss[3] = _mm256_unpacklo_epi8(s, s); + ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21); + WienerHorizontalTap3Kernel(ss + 0, filter, *wiener_buffer + x + 0); + WienerHorizontalTap3Kernel(ss + 1, filter, *wiener_buffer + x + 16); + ss[0] = ss[3]; + x += 32; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + const __m256i s = LoadUnaligned32(src + x); + const __m256i s0 = _mm256_unpacklo_epi8(s, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpackhi_epi8(s, _mm256_setzero_si256()); + __m256i d[2]; + d[0] = _mm256_slli_epi16(s0, 4); + d[1] = _mm256_slli_epi16(s1, 4); + StoreAligned64(*wiener_buffer + x, d); + x += 32; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline __m256i WienerVertical7(const __m256i a[2], const __m256i filter[2]) { + const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsVertical - 1)); + const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]); + const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]); + const __m256i sum0 = _mm256_add_epi32(round, madd0); + const __m256i sum1 = _mm256_add_epi32(sum0, madd1); + return _mm256_srai_epi32(sum1, kInterRoundBitsVertical); +} + +inline __m256i WienerVertical5(const __m256i a[2], const __m256i filter[2]) { + const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]); + const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]); + const __m256i sum = _mm256_add_epi32(madd0, madd1); + return _mm256_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m256i WienerVertical3(const __m256i a, const __m256i filter) { + const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsVertical - 1)); + const __m256i madd = _mm256_madd_epi16(a, filter); + const __m256i sum = _mm256_add_epi32(round, madd); + return _mm256_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m256i WienerVerticalFilter7(const __m256i a[7], + const __m256i filter[2]) { + __m256i b[2]; + const __m256i a06 = _mm256_add_epi16(a[0], a[6]); + const __m256i a15 = _mm256_add_epi16(a[1], a[5]); + const __m256i a24 = _mm256_add_epi16(a[2], a[4]); + b[0] = _mm256_unpacklo_epi16(a06, a15); + b[1] = _mm256_unpacklo_epi16(a24, a[3]); + const __m256i sum0 = WienerVertical7(b, filter); + b[0] = _mm256_unpackhi_epi16(a06, a15); + b[1] = _mm256_unpackhi_epi16(a24, a[3]); + const __m256i sum1 = WienerVertical7(b, filter); + return _mm256_packs_epi32(sum0, sum1); +} + +inline __m256i WienerVerticalFilter5(const __m256i a[5], + const __m256i filter[2]) { + const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m256i b[2]; + const __m256i a04 = _mm256_add_epi16(a[0], a[4]); + const __m256i a13 = _mm256_add_epi16(a[1], a[3]); + b[0] = _mm256_unpacklo_epi16(a04, a13); + b[1] = _mm256_unpacklo_epi16(a[2], round); + const __m256i sum0 = WienerVertical5(b, filter); + b[0] = _mm256_unpackhi_epi16(a04, a13); + b[1] = _mm256_unpackhi_epi16(a[2], round); + const __m256i sum1 = WienerVertical5(b, filter); + return _mm256_packs_epi32(sum0, sum1); +} + +inline __m256i WienerVerticalFilter3(const __m256i a[3], const __m256i filter) { + __m256i b; + const __m256i a02 = _mm256_add_epi16(a[0], a[2]); + b = _mm256_unpacklo_epi16(a02, a[1]); + const __m256i sum0 = WienerVertical3(b, filter); + b = _mm256_unpackhi_epi16(a02, a[1]); + const __m256i sum1 = WienerVertical3(b, filter); + return _mm256_packs_epi32(sum0, sum1); +} + +inline __m256i WienerVerticalTap7Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i a[7]) { + a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride); + a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride); + a[6] = LoadAligned32(wiener_buffer + 6 * wiener_stride); + return WienerVerticalFilter7(a, filter); +} + +inline __m256i WienerVerticalTap5Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i a[5]) { + a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride); + return WienerVerticalFilter5(a, filter); +} + +inline __m256i WienerVerticalTap3Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter, __m256i a[3]) { + a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride); + return WienerVerticalFilter3(a, filter); +} + +inline void WienerVerticalTap7Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i d[2]) { + __m256i a[8]; + d[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a); + a[7] = LoadAligned32(wiener_buffer + 7 * wiener_stride); + d[1] = WienerVerticalFilter7(a + 1, filter); +} + +inline void WienerVerticalTap5Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter[2], __m256i d[2]) { + __m256i a[6]; + d[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a); + a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride); + d[1] = WienerVerticalFilter5(a + 1, filter); +} + +inline void WienerVerticalTap3Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m256i filter, __m256i d[2]) { + __m256i a[4]; + d[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a); + a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride); + d[1] = WienerVerticalFilter3(a + 1, filter); +} + +inline void WienerVerticalTap7(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m256i c = _mm256_broadcastq_epi64(LoadLo8(coefficients)); + __m256i filter[2]; + filter[0] = _mm256_shuffle_epi32(c, 0x0); + filter[1] = _mm256_shuffle_epi32(c, 0x55); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m256i d[2][2]; + WienerVerticalTap7Kernel2(wiener_buffer + x + 0, width, filter, d[0]); + WienerVerticalTap7Kernel2(wiener_buffer + x + 16, width, filter, d[1]); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0])); + StoreUnaligned32(dst + dst_stride + x, + _mm256_packus_epi16(d[0][1], d[1][1])); + x += 32; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m256i a[7]; + const __m256i d0 = + WienerVerticalTap7Kernel(wiener_buffer + x + 0, width, filter, a); + const __m256i d1 = + WienerVerticalTap7Kernel(wiener_buffer + x + 16, width, filter, a); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); + x += 32; + } while (x < width); + } +} + +inline void WienerVerticalTap5(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[3], uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m256i c = _mm256_broadcastd_epi32(Load4(coefficients)); + __m256i filter[2]; + filter[0] = _mm256_shuffle_epi32(c, 0); + filter[1] = + _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[2])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m256i d[2][2]; + WienerVerticalTap5Kernel2(wiener_buffer + x + 0, width, filter, d[0]); + WienerVerticalTap5Kernel2(wiener_buffer + x + 16, width, filter, d[1]); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0])); + StoreUnaligned32(dst + dst_stride + x, + _mm256_packus_epi16(d[0][1], d[1][1])); + x += 32; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m256i a[5]; + const __m256i d0 = + WienerVerticalTap5Kernel(wiener_buffer + x + 0, width, filter, a); + const __m256i d1 = + WienerVerticalTap5Kernel(wiener_buffer + x + 16, width, filter, a); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); + x += 32; + } while (x < width); + } +} + +inline void WienerVerticalTap3(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[2], uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m256i filter = + _mm256_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients)); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m256i d[2][2]; + WienerVerticalTap3Kernel2(wiener_buffer + x + 0, width, filter, d[0]); + WienerVerticalTap3Kernel2(wiener_buffer + x + 16, width, filter, d[1]); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0])); + StoreUnaligned32(dst + dst_stride + x, + _mm256_packus_epi16(d[0][1], d[1][1])); + x += 32; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m256i a[3]; + const __m256i d0 = + WienerVerticalTap3Kernel(wiener_buffer + x + 0, width, filter, a); + const __m256i d1 = + WienerVerticalTap3Kernel(wiener_buffer + x + 16, width, filter, a); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); + x += 32; + } while (x < width); + } +} + +inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer, + uint8_t* const dst) { + const __m256i a0 = LoadAligned32(wiener_buffer + 0); + const __m256i a1 = LoadAligned32(wiener_buffer + 16); + const __m256i b0 = _mm256_add_epi16(a0, _mm256_set1_epi16(8)); + const __m256i b1 = _mm256_add_epi16(a1, _mm256_set1_epi16(8)); + const __m256i c0 = _mm256_srai_epi16(b0, 4); + const __m256i c1 = _mm256_srai_epi16(b1, 4); + const __m256i d = _mm256_packus_epi16(c0, c1); + StoreUnaligned32(dst, d); +} + +inline void WienerVerticalTap1(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x); + x += 32; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + x += 32; + } while (x < width); + } +} + +void WienerFilter_AVX2(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, const ptrdiff_t stride, + const int width, const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + const int16_t* const number_leading_zero_coefficients = + restoration_info.wiener_info.number_leading_zero_coefficients; + const int number_rows_to_skip = std::max( + static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]), + 1); + const ptrdiff_t wiener_stride = Align(width, 32); + int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer; + // The values are saturated to 13 bits before storing. + int16_t* wiener_buffer_horizontal = + wiener_buffer_vertical + number_rows_to_skip * wiener_stride; + + // horizontal filtering. + // Over-reads up to 15 - |kRestorationHorizontalBorder| values. + const int height_horizontal = + height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip; + const int height_extra = (height_horizontal - height) >> 1; + assert(height_extra <= 2); + const auto* const src = static_cast<const uint8_t*>(source); + const auto* const top = static_cast<const uint8_t*>(top_border); + const auto* const bottom = static_cast<const uint8_t*>(bottom_border); + const __m128i c = + LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]); + // In order to keep the horizontal pass intermediate values within 16 bits we + // offset |filter[3]| by 128. The 128 offset will be added back in the loop. + __m128i c_horizontal = + _mm_sub_epi16(c, _mm_setr_epi16(0, 0, 0, 128, 0, 0, 0, 0)); + c_horizontal = _mm_packs_epi16(c_horizontal, c_horizontal); + const __m256i coefficients_horizontal = _mm256_broadcastd_epi32(c_horizontal); + if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { + WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, + wiener_stride, height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + coefficients_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, + wiener_stride, height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + coefficients_horizontal, &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { + // The maximum over-reads happen here. + WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, + wiener_stride, height_extra, coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + coefficients_horizontal, &wiener_buffer_horizontal); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); + WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, + wiener_stride, height_extra, + &wiener_buffer_horizontal); + WienerHorizontalTap1(src, stride, wiener_stride, height, + &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, + &wiener_buffer_horizontal); + } + + // vertical filtering. + // Over-writes up to 15 values. + const int16_t* const filter_vertical = + restoration_info.wiener_info.filter[WienerInfo::kVertical]; + auto* dst = static_cast<uint8_t*>(dest); + if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) { + // Because the top row of |source| is a duplicate of the second row, and the + // bottom row of |source| is a duplicate of its above row, we can duplicate + // the top and bottom row of |wiener_buffer| accordingly. + memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride, + sizeof(*wiener_buffer_horizontal) * wiener_stride); + memcpy(restoration_buffer->wiener_buffer, + restoration_buffer->wiener_buffer + wiener_stride, + sizeof(*restoration_buffer->wiener_buffer) * wiener_stride); + WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height, + filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) { + WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride, + height, filter_vertical + 1, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) { + WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride, + wiener_stride, height, filter_vertical + 2, dst, stride); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3); + WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride, + wiener_stride, height, dst, stride); + } +} + +//------------------------------------------------------------------------------ +// SGR + +constexpr int kSumOffset = 24; + +// SIMD overreads the number of bytes in SIMD registers - (width % 16) - 2 * +// padding pixels, where padding is 3 for Pass 1 and 2 for Pass 2. The number of +// bytes in SIMD registers is 16 for SSE4.1 and 32 for AVX2. +constexpr int kOverreadInBytesPass1_128 = 10; +constexpr int kOverreadInBytesPass2_128 = 12; +constexpr int kOverreadInBytesPass1_256 = kOverreadInBytesPass1_128 + 16; +constexpr int kOverreadInBytesPass2_256 = kOverreadInBytesPass2_128 + 16; + +inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x, + __m128i dst[2]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); +} + +inline void LoadAligned32x2U16(const uint16_t* const src[2], const ptrdiff_t x, + __m256i dst[2]) { + dst[0] = LoadAligned32(src[0] + x); + dst[1] = LoadAligned32(src[1] + x); +} + +inline void LoadAligned32x2U16Msan(const uint16_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[2]) { + dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border)); + dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border)); +} + +inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x, + __m128i dst[3]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); + dst[2] = LoadAligned16(src[2] + x); +} + +inline void LoadAligned32x3U16(const uint16_t* const src[3], const ptrdiff_t x, + __m256i dst[3]) { + dst[0] = LoadAligned32(src[0] + x); + dst[1] = LoadAligned32(src[1] + x); + dst[2] = LoadAligned32(src[2] + x); +} + +inline void LoadAligned32x3U16Msan(const uint16_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[3]) { + dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border)); + dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border)); + dst[2] = LoadAligned32Msan(src[2] + x, sizeof(**src) * (x + 16 - border)); +} + +inline void LoadAligned32U32(const uint32_t* const src, __m128i dst[2]) { + dst[0] = LoadAligned16(src + 0); + dst[1] = LoadAligned16(src + 4); +} + +inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x, + __m128i dst[2][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); +} + +inline void LoadAligned64x2U32(const uint32_t* const src[2], const ptrdiff_t x, + __m256i dst[2][2]) { + LoadAligned64(src[0] + x, dst[0]); + LoadAligned64(src[1] + x, dst[1]); +} + +inline void LoadAligned64x2U32Msan(const uint32_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[2][2]) { + LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]); + LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]); +} + +inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x, + __m128i dst[3][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); + LoadAligned32U32(src[2] + x, dst[2]); +} + +inline void LoadAligned64x3U32(const uint32_t* const src[3], const ptrdiff_t x, + __m256i dst[3][2]) { + LoadAligned64(src[0] + x, dst[0]); + LoadAligned64(src[1] + x, dst[1]); + LoadAligned64(src[2] + x, dst[2]); +} + +inline void LoadAligned64x3U32Msan(const uint32_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m256i dst[3][2]) { + LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]); + LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]); + LoadAligned64Msan(src[2] + x, sizeof(**src) * (x + 16 - border), dst[2]); +} + +inline void StoreAligned32U32(uint32_t* const dst, const __m128i src[2]) { + StoreAligned16(dst + 0, src[0]); + StoreAligned16(dst + 4, src[1]); +} + +// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following +// functions. Some compilers may generate super inefficient code and the whole +// decoder could be 15% slower. + +inline __m128i VaddlLo8(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi8(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(s0, s1); +} + +inline __m256i VaddlLo8(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpacklo_epi8(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(s0, s1); +} + +inline __m256i VaddlHi8(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpackhi_epi8(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(s0, s1); +} + +inline __m128i VaddlLo16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(s0, s1); +} + +inline __m256i VaddlLo16(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256()); + return _mm256_add_epi32(s0, s1); +} + +inline __m128i VaddlHi16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(s0, s1); +} + +inline __m256i VaddlHi16(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256()); + return _mm256_add_epi32(s0, s1); +} + +inline __m128i VaddwLo8(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(src0, s1); +} + +inline __m256i VaddwLo8(const __m256i src0, const __m256i src1) { + const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(src0, s1); +} + +inline __m256i VaddwHi8(const __m256i src0, const __m256i src1) { + const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256()); + return _mm256_add_epi16(src0, s1); +} + +inline __m128i VaddwLo16(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(src0, s1); +} + +inline __m256i VaddwLo16(const __m256i src0, const __m256i src1) { + const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256()); + return _mm256_add_epi32(src0, s1); +} + +inline __m128i VaddwHi16(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(src0, s1); +} + +inline __m256i VaddwHi16(const __m256i src0, const __m256i src1) { + const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256()); + return _mm256_add_epi32(src0, s1); +} + +// Using VgetLane16() can save a sign extension instruction. +template <int n> +inline int VgetLane16(__m256i src) { + return _mm256_extract_epi16(src, n); +} + +template <int n> +inline int VgetLane8(__m256i src) { + return _mm256_extract_epi8(src, n); +} + +inline __m256i VmullNLo8(const __m256i src0, const int src1) { + const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1)); +} + +inline __m256i VmullNHi8(const __m256i src0, const int src1) { + const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1)); +} + +inline __m128i VmullLo16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m256i VmullLo16(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, s1); +} + +inline __m128i VmullHi16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m256i VmullHi16(const __m256i src0, const __m256i src1) { + const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256()); + const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256()); + return _mm256_madd_epi16(s0, s1); +} + +inline __m256i VrshrS32(const __m256i src0, const int src1) { + const __m256i sum = + _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1))); + return _mm256_srai_epi32(sum, src1); +} + +inline __m128i VrshrU32(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1))); + return _mm_srli_epi32(sum, src1); +} + +inline __m256i VrshrU32(const __m256i src0, const int src1) { + const __m256i sum = + _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1))); + return _mm256_srli_epi32(sum, src1); +} + +inline __m128i SquareLo8(const __m128i src) { + const __m128i s = _mm_unpacklo_epi8(src, _mm_setzero_si128()); + return _mm_mullo_epi16(s, s); +} + +inline __m256i SquareLo8(const __m256i src) { + const __m256i s = _mm256_unpacklo_epi8(src, _mm256_setzero_si256()); + return _mm256_mullo_epi16(s, s); +} + +inline __m128i SquareHi8(const __m128i src) { + const __m128i s = _mm_unpackhi_epi8(src, _mm_setzero_si128()); + return _mm_mullo_epi16(s, s); +} + +inline __m256i SquareHi8(const __m256i src) { + const __m256i s = _mm256_unpackhi_epi8(src, _mm256_setzero_si256()); + return _mm256_mullo_epi16(s, s); +} + +inline void Prepare3Lo8(const __m128i src, __m128i dst[3]) { + dst[0] = src; + dst[1] = _mm_srli_si128(src, 1); + dst[2] = _mm_srli_si128(src, 2); +} + +inline void Prepare3_8(const __m256i src[2], __m256i dst[3]) { + dst[0] = _mm256_alignr_epi8(src[1], src[0], 0); + dst[1] = _mm256_alignr_epi8(src[1], src[0], 1); + dst[2] = _mm256_alignr_epi8(src[1], src[0], 2); +} + +inline void Prepare3_16(const __m128i src[2], __m128i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm_alignr_epi8(src[1], src[0], 2); + dst[2] = _mm_alignr_epi8(src[1], src[0], 4); +} + +inline void Prepare3_16(const __m256i src[2], __m256i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm256_alignr_epi8(src[1], src[0], 2); + dst[2] = _mm256_alignr_epi8(src[1], src[0], 4); +} + +inline void Prepare5Lo8(const __m128i src, __m128i dst[5]) { + dst[0] = src; + dst[1] = _mm_srli_si128(src, 1); + dst[2] = _mm_srli_si128(src, 2); + dst[3] = _mm_srli_si128(src, 3); + dst[4] = _mm_srli_si128(src, 4); +} + +inline void Prepare5_16(const __m128i src[2], __m128i dst[5]) { + Prepare3_16(src, dst); + dst[3] = _mm_alignr_epi8(src[1], src[0], 6); + dst[4] = _mm_alignr_epi8(src[1], src[0], 8); +} + +inline void Prepare5_16(const __m256i src[2], __m256i dst[5]) { + Prepare3_16(src, dst); + dst[3] = _mm256_alignr_epi8(src[1], src[0], 6); + dst[4] = _mm256_alignr_epi8(src[1], src[0], 8); +} + +inline __m128i Sum3_16(const __m128i src0, const __m128i src1, + const __m128i src2) { + const __m128i sum = _mm_add_epi16(src0, src1); + return _mm_add_epi16(sum, src2); +} + +inline __m256i Sum3_16(const __m256i src0, const __m256i src1, + const __m256i src2) { + const __m256i sum = _mm256_add_epi16(src0, src1); + return _mm256_add_epi16(sum, src2); +} + +inline __m128i Sum3_16(const __m128i src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline __m256i Sum3_16(const __m256i src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline __m128i Sum3_32(const __m128i src0, const __m128i src1, + const __m128i src2) { + const __m128i sum = _mm_add_epi32(src0, src1); + return _mm_add_epi32(sum, src2); +} + +inline __m256i Sum3_32(const __m256i src0, const __m256i src1, + const __m256i src2) { + const __m256i sum = _mm256_add_epi32(src0, src1); + return _mm256_add_epi32(sum, src2); +} + +inline void Sum3_32(const __m128i src[3][2], __m128i dst[2]) { + dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]); + dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]); +} + +inline void Sum3_32(const __m256i src[3][2], __m256i dst[2]) { + dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]); + dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]); +} + +inline __m128i Sum3WLo16(const __m128i src[3]) { + const __m128i sum = VaddlLo8(src[0], src[1]); + return VaddwLo8(sum, src[2]); +} + +inline __m256i Sum3WLo16(const __m256i src[3]) { + const __m256i sum = VaddlLo8(src[0], src[1]); + return VaddwLo8(sum, src[2]); +} + +inline __m256i Sum3WHi16(const __m256i src[3]) { + const __m256i sum = VaddlHi8(src[0], src[1]); + return VaddwHi8(sum, src[2]); +} + +inline __m128i Sum3WLo32(const __m128i src[3]) { + const __m128i sum = VaddlLo16(src[0], src[1]); + return VaddwLo16(sum, src[2]); +} + +inline __m256i Sum3WLo32(const __m256i src[3]) { + const __m256i sum = VaddlLo16(src[0], src[1]); + return VaddwLo16(sum, src[2]); +} + +inline __m128i Sum3WHi32(const __m128i src[3]) { + const __m128i sum = VaddlHi16(src[0], src[1]); + return VaddwHi16(sum, src[2]); +} + +inline __m256i Sum3WHi32(const __m256i src[3]) { + const __m256i sum = VaddlHi16(src[0], src[1]); + return VaddwHi16(sum, src[2]); +} + +inline __m128i Sum5_16(const __m128i src[5]) { + const __m128i sum01 = _mm_add_epi16(src[0], src[1]); + const __m128i sum23 = _mm_add_epi16(src[2], src[3]); + const __m128i sum = _mm_add_epi16(sum01, sum23); + return _mm_add_epi16(sum, src[4]); +} + +inline __m256i Sum5_16(const __m256i src[5]) { + const __m256i sum01 = _mm256_add_epi16(src[0], src[1]); + const __m256i sum23 = _mm256_add_epi16(src[2], src[3]); + const __m256i sum = _mm256_add_epi16(sum01, sum23); + return _mm256_add_epi16(sum, src[4]); +} + +inline __m128i Sum5_32(const __m128i* const src0, const __m128i* const src1, + const __m128i* const src2, const __m128i* const src3, + const __m128i* const src4) { + const __m128i sum01 = _mm_add_epi32(*src0, *src1); + const __m128i sum23 = _mm_add_epi32(*src2, *src3); + const __m128i sum = _mm_add_epi32(sum01, sum23); + return _mm_add_epi32(sum, *src4); +} + +inline __m256i Sum5_32(const __m256i* const src0, const __m256i* const src1, + const __m256i* const src2, const __m256i* const src3, + const __m256i* const src4) { + const __m256i sum01 = _mm256_add_epi32(*src0, *src1); + const __m256i sum23 = _mm256_add_epi32(*src2, *src3); + const __m256i sum = _mm256_add_epi32(sum01, sum23); + return _mm256_add_epi32(sum, *src4); +} + +inline void Sum5_32(const __m128i src[5][2], __m128i dst[2]) { + dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]); + dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]); +} + +inline void Sum5_32(const __m256i src[5][2], __m256i dst[2]) { + dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]); + dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]); +} + +inline __m128i Sum5WLo16(const __m128i src[5]) { + const __m128i sum01 = VaddlLo8(src[0], src[1]); + const __m128i sum23 = VaddlLo8(src[2], src[3]); + const __m128i sum = _mm_add_epi16(sum01, sum23); + return VaddwLo8(sum, src[4]); +} + +inline __m256i Sum5WLo16(const __m256i src[5]) { + const __m256i sum01 = VaddlLo8(src[0], src[1]); + const __m256i sum23 = VaddlLo8(src[2], src[3]); + const __m256i sum = _mm256_add_epi16(sum01, sum23); + return VaddwLo8(sum, src[4]); +} + +inline __m256i Sum5WHi16(const __m256i src[5]) { + const __m256i sum01 = VaddlHi8(src[0], src[1]); + const __m256i sum23 = VaddlHi8(src[2], src[3]); + const __m256i sum = _mm256_add_epi16(sum01, sum23); + return VaddwHi8(sum, src[4]); +} + +inline __m128i Sum3Horizontal(const __m128i src) { + __m128i s[3]; + Prepare3Lo8(src, s); + return Sum3WLo16(s); +} + +inline void Sum3Horizontal(const uint8_t* const src, + const ptrdiff_t over_read_in_bytes, __m256i dst[2]) { + __m256i s[3]; + s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0); + s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1); + s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2); + dst[0] = Sum3WLo16(s); + dst[1] = Sum3WHi16(s); +} + +inline void Sum3WHorizontal(const __m128i src[2], __m128i dst[2]) { + __m128i s[3]; + Prepare3_16(src, s); + dst[0] = Sum3WLo32(s); + dst[1] = Sum3WHi32(s); +} + +inline void Sum3WHorizontal(const __m256i src[2], __m256i dst[2]) { + __m256i s[3]; + Prepare3_16(src, s); + dst[0] = Sum3WLo32(s); + dst[1] = Sum3WHi32(s); +} + +inline __m128i Sum5Horizontal(const __m128i src) { + __m128i s[5]; + Prepare5Lo8(src, s); + return Sum5WLo16(s); +} + +inline void Sum5Horizontal(const uint8_t* const src, + const ptrdiff_t over_read_in_bytes, + __m256i* const dst0, __m256i* const dst1) { + __m256i s[5]; + s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0); + s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1); + s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2); + s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 3); + s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 4); + *dst0 = Sum5WLo16(s); + *dst1 = Sum5WHi16(s); +} + +inline void Sum5WHorizontal(const __m128i src[2], __m128i dst[2]) { + __m128i s[5]; + Prepare5_16(src, s); + const __m128i sum01_lo = VaddlLo16(s[0], s[1]); + const __m128i sum23_lo = VaddlLo16(s[2], s[3]); + const __m128i sum0123_lo = _mm_add_epi32(sum01_lo, sum23_lo); + dst[0] = VaddwLo16(sum0123_lo, s[4]); + const __m128i sum01_hi = VaddlHi16(s[0], s[1]); + const __m128i sum23_hi = VaddlHi16(s[2], s[3]); + const __m128i sum0123_hi = _mm_add_epi32(sum01_hi, sum23_hi); + dst[1] = VaddwHi16(sum0123_hi, s[4]); +} + +inline void Sum5WHorizontal(const __m256i src[2], __m256i dst[2]) { + __m256i s[5]; + Prepare5_16(src, s); + const __m256i sum01_lo = VaddlLo16(s[0], s[1]); + const __m256i sum23_lo = VaddlLo16(s[2], s[3]); + const __m256i sum0123_lo = _mm256_add_epi32(sum01_lo, sum23_lo); + dst[0] = VaddwLo16(sum0123_lo, s[4]); + const __m256i sum01_hi = VaddlHi16(s[0], s[1]); + const __m256i sum23_hi = VaddlHi16(s[2], s[3]); + const __m256i sum0123_hi = _mm256_add_epi32(sum01_hi, sum23_hi); + dst[1] = VaddwHi16(sum0123_hi, s[4]); +} + +void SumHorizontalLo(const __m128i src[5], __m128i* const row_sq3, + __m128i* const row_sq5) { + const __m128i sum04 = VaddlLo16(src[0], src[4]); + *row_sq3 = Sum3WLo32(src + 1); + *row_sq5 = _mm_add_epi32(sum04, *row_sq3); +} + +void SumHorizontalLo(const __m256i src[5], __m256i* const row_sq3, + __m256i* const row_sq5) { + const __m256i sum04 = VaddlLo16(src[0], src[4]); + *row_sq3 = Sum3WLo32(src + 1); + *row_sq5 = _mm256_add_epi32(sum04, *row_sq3); +} + +void SumHorizontalHi(const __m128i src[5], __m128i* const row_sq3, + __m128i* const row_sq5) { + const __m128i sum04 = VaddlHi16(src[0], src[4]); + *row_sq3 = Sum3WHi32(src + 1); + *row_sq5 = _mm_add_epi32(sum04, *row_sq3); +} + +void SumHorizontalHi(const __m256i src[5], __m256i* const row_sq3, + __m256i* const row_sq5) { + const __m256i sum04 = VaddlHi16(src[0], src[4]); + *row_sq3 = Sum3WHi32(src + 1); + *row_sq5 = _mm256_add_epi32(sum04, *row_sq3); +} + +void SumHorizontalLo(const __m128i src, __m128i* const row3, + __m128i* const row5) { + __m128i s[5]; + Prepare5Lo8(src, s); + const __m128i sum04 = VaddlLo8(s[0], s[4]); + *row3 = Sum3WLo16(s + 1); + *row5 = _mm_add_epi16(sum04, *row3); +} + +inline void SumHorizontal(const uint8_t* const src, + const ptrdiff_t over_read_in_bytes, + __m256i* const row3_0, __m256i* const row3_1, + __m256i* const row5_0, __m256i* const row5_1) { + __m256i s[5]; + s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0); + s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1); + s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2); + s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 3); + s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 4); + const __m256i sum04_lo = VaddlLo8(s[0], s[4]); + const __m256i sum04_hi = VaddlHi8(s[0], s[4]); + *row3_0 = Sum3WLo16(s + 1); + *row3_1 = Sum3WHi16(s + 1); + *row5_0 = _mm256_add_epi16(sum04_lo, *row3_0); + *row5_1 = _mm256_add_epi16(sum04_hi, *row3_1); +} + +inline void SumHorizontal(const __m128i src[2], __m128i* const row_sq3_0, + __m128i* const row_sq3_1, __m128i* const row_sq5_0, + __m128i* const row_sq5_1) { + __m128i s[5]; + Prepare5_16(src, s); + SumHorizontalLo(s, row_sq3_0, row_sq5_0); + SumHorizontalHi(s, row_sq3_1, row_sq5_1); +} + +inline void SumHorizontal(const __m256i src[2], __m256i* const row_sq3_0, + __m256i* const row_sq3_1, __m256i* const row_sq5_0, + __m256i* const row_sq5_1) { + __m256i s[5]; + Prepare5_16(src, s); + SumHorizontalLo(s, row_sq3_0, row_sq5_0); + SumHorizontalHi(s, row_sq3_1, row_sq5_1); +} + +inline __m256i Sum343Lo(const __m256i ma3[3]) { + const __m256i sum = Sum3WLo16(ma3); + const __m256i sum3 = Sum3_16(sum, sum, sum); + return VaddwLo8(sum3, ma3[1]); +} + +inline __m256i Sum343Hi(const __m256i ma3[3]) { + const __m256i sum = Sum3WHi16(ma3); + const __m256i sum3 = Sum3_16(sum, sum, sum); + return VaddwHi8(sum3, ma3[1]); +} + +inline __m256i Sum343WLo(const __m256i src[3]) { + const __m256i sum = Sum3WLo32(src); + const __m256i sum3 = Sum3_32(sum, sum, sum); + return VaddwLo16(sum3, src[1]); +} + +inline __m256i Sum343WHi(const __m256i src[3]) { + const __m256i sum = Sum3WHi32(src); + const __m256i sum3 = Sum3_32(sum, sum, sum); + return VaddwHi16(sum3, src[1]); +} + +inline void Sum343W(const __m256i src[2], __m256i dst[2]) { + __m256i s[3]; + Prepare3_16(src, s); + dst[0] = Sum343WLo(s); + dst[1] = Sum343WHi(s); +} + +inline __m256i Sum565Lo(const __m256i src[3]) { + const __m256i sum = Sum3WLo16(src); + const __m256i sum4 = _mm256_slli_epi16(sum, 2); + const __m256i sum5 = _mm256_add_epi16(sum4, sum); + return VaddwLo8(sum5, src[1]); +} + +inline __m256i Sum565Hi(const __m256i src[3]) { + const __m256i sum = Sum3WHi16(src); + const __m256i sum4 = _mm256_slli_epi16(sum, 2); + const __m256i sum5 = _mm256_add_epi16(sum4, sum); + return VaddwHi8(sum5, src[1]); +} + +inline __m256i Sum565WLo(const __m256i src[3]) { + const __m256i sum = Sum3WLo32(src); + const __m256i sum4 = _mm256_slli_epi32(sum, 2); + const __m256i sum5 = _mm256_add_epi32(sum4, sum); + return VaddwLo16(sum5, src[1]); +} + +inline __m256i Sum565WHi(const __m256i src[3]) { + const __m256i sum = Sum3WHi32(src); + const __m256i sum4 = _mm256_slli_epi32(sum, 2); + const __m256i sum5 = _mm256_add_epi32(sum4, sum); + return VaddwHi16(sum5, src[1]); +} + +inline void Sum565W(const __m256i src[2], __m256i dst[2]) { + __m256i s[3]; + Prepare3_16(src, s); + dst[0] = Sum565WLo(s); + dst[1] = Sum565WHi(s); +} + +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5, + uint32_t* square_sum3, uint32_t* square_sum5) { + int y = 2; + do { + const __m128i s0 = + LoadUnaligned16Msan(src, kOverreadInBytesPass1_128 - width); + __m128i sq_128[2]; + __m256i sq[3]; + __m128i s3, s5, sq3[2], sq5[2]; + sq_128[0] = SquareLo8(s0); + sq_128[1] = SquareHi8(s0); + SumHorizontalLo(s0, &s3, &s5); + StoreAligned16(sum3, s3); + StoreAligned16(sum5, s5); + SumHorizontal(sq_128, &sq3[0], &sq3[1], &sq5[0], &sq5[1]); + StoreAligned32U32(square_sum3, sq3); + StoreAligned32U32(square_sum5, sq5); + src += 8; + sum3 += 8; + sum5 += 8; + square_sum3 += 8; + square_sum5 += 8; + sq[0] = SetrM128i(sq_128[1], sq_128[1]); + ptrdiff_t x = sum_width; + do { + __m256i row3[2], row5[2], row_sq3[2], row_sq5[2]; + const __m256i s = LoadUnaligned32Msan( + src + 8, sum_width - x + 16 + kOverreadInBytesPass1_256 - width); + sq[1] = SquareLo8(s); + sq[2] = SquareHi8(s); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + SumHorizontal(src, sum_width - x + 8 + kOverreadInBytesPass1_256 - width, + &row3[0], &row3[1], &row5[0], &row5[1]); + StoreAligned64(sum3, row3); + StoreAligned64(sum5, row5); + SumHorizontal(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]); + StoreAligned64(square_sum3 + 0, row_sq3); + StoreAligned64(square_sum5 + 0, row_sq5); + SumHorizontal(sq + 1, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]); + StoreAligned64(square_sum3 + 16, row_sq3); + StoreAligned64(square_sum5 + 16, row_sq5); + sq[0] = sq[2]; + src += 32; + sum3 += 32; + sum5 += 32; + square_sum3 += 32; + square_sum5 += 32; + x -= 32; + } while (x != 0); + src += src_stride - sum_width - 8; + sum3 += sum_stride - sum_width - 8; + sum5 += sum_stride - sum_width - 8; + square_sum3 += sum_stride - sum_width - 8; + square_sum5 += sum_stride - sum_width - 8; + } while (--y != 0); +} + +template <int size> +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sums, + uint32_t* square_sums) { + static_assert(size == 3 || size == 5, ""); + int kOverreadInBytes_128, kOverreadInBytes_256; + if (size == 3) { + kOverreadInBytes_128 = kOverreadInBytesPass2_128; + kOverreadInBytes_256 = kOverreadInBytesPass2_256; + } else { + kOverreadInBytes_128 = kOverreadInBytesPass1_128; + kOverreadInBytes_256 = kOverreadInBytesPass1_256; + } + int y = 2; + do { + const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytes_128 - width); + __m128i ss, sq_128[2], sqs[2]; + __m256i sq[3]; + sq_128[0] = SquareLo8(s); + sq_128[1] = SquareHi8(s); + if (size == 3) { + ss = Sum3Horizontal(s); + Sum3WHorizontal(sq_128, sqs); + } else { + ss = Sum5Horizontal(s); + Sum5WHorizontal(sq_128, sqs); + } + StoreAligned16(sums, ss); + StoreAligned32U32(square_sums, sqs); + src += 8; + sums += 8; + square_sums += 8; + sq[0] = SetrM128i(sq_128[1], sq_128[1]); + ptrdiff_t x = sum_width; + do { + __m256i row[2], row_sq[4]; + const __m256i s = LoadUnaligned32Msan( + src + 8, sum_width - x + 16 + kOverreadInBytes_256 - width); + sq[1] = SquareLo8(s); + sq[2] = SquareHi8(s); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + if (size == 3) { + Sum3Horizontal(src, sum_width - x + 8 + kOverreadInBytes_256 - width, + row); + Sum3WHorizontal(sq + 0, row_sq + 0); + Sum3WHorizontal(sq + 1, row_sq + 2); + } else { + Sum5Horizontal(src, sum_width - x + 8 + kOverreadInBytes_256 - width, + &row[0], &row[1]); + Sum5WHorizontal(sq + 0, row_sq + 0); + Sum5WHorizontal(sq + 1, row_sq + 2); + } + StoreAligned64(sums, row); + StoreAligned64(square_sums + 0, row_sq + 0); + StoreAligned64(square_sums + 16, row_sq + 2); + sq[0] = sq[2]; + src += 32; + sums += 32; + square_sums += 32; + x -= 32; + } while (x != 0); + src += src_stride - sum_width - 8; + sums += sum_stride - sum_width - 8; + square_sums += sum_stride - sum_width - 8; + } while (--y != 0); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq, + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const __m128i dxd = _mm_madd_epi16(sum, sum); + // _mm_mullo_epi32() has high latency. Using shifts and additions instead. + // Some compilers could do this for us but we make this explicit. + // return _mm_mullo_epi32(sum_sq, _mm_set1_epi32(n)); + __m128i axn = _mm_add_epi32(sum_sq, _mm_slli_epi32(sum_sq, 3)); + if (n == 25) axn = _mm_add_epi32(axn, _mm_slli_epi32(sum_sq, 4)); + const __m128i sub = _mm_sub_epi32(axn, dxd); + const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128()); + const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(scale)); + return VrshrU32(pxs, kSgrProjScaleBits); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2], + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + const __m128i sum_lo = _mm_unpacklo_epi16(sum, _mm_setzero_si128()); + const __m128i sum_hi = _mm_unpackhi_epi16(sum, _mm_setzero_si128()); + const __m128i z0 = CalculateMa<n>(sum_lo, sum_sq[0], scale); + const __m128i z1 = CalculateMa<n>(sum_hi, sum_sq[1], scale); + return _mm_packus_epi32(z0, z1); +} + +template <int n> +inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq, + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const __m256i dxd = _mm256_madd_epi16(sum, sum); + // _mm256_mullo_epi32() has high latency. Using shifts and additions instead. + // Some compilers could do this for us but we make this explicit. + // return _mm256_mullo_epi32(sum_sq, _mm256_set1_epi32(n)); + __m256i axn = _mm256_add_epi32(sum_sq, _mm256_slli_epi32(sum_sq, 3)); + if (n == 25) axn = _mm256_add_epi32(axn, _mm256_slli_epi32(sum_sq, 4)); + const __m256i sub = _mm256_sub_epi32(axn, dxd); + const __m256i p = _mm256_max_epi32(sub, _mm256_setzero_si256()); + const __m256i pxs = _mm256_mullo_epi32(p, _mm256_set1_epi32(scale)); + return VrshrU32(pxs, kSgrProjScaleBits); +} + +template <int n> +inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq[2], + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + const __m256i sum_lo = _mm256_unpacklo_epi16(sum, _mm256_setzero_si256()); + const __m256i sum_hi = _mm256_unpackhi_epi16(sum, _mm256_setzero_si256()); + const __m256i z0 = CalculateMa<n>(sum_lo, sum_sq[0], scale); + const __m256i z1 = CalculateMa<n>(sum_hi, sum_sq[1], scale); + return _mm256_packus_epi32(z0, z1); +} + +template <int n> +inline __m128i CalculateB(const __m128i sum, const __m128i ma) { + static_assert(n == 9 || n == 25, ""); + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + const __m128i m0 = VmullLo16(ma, sum); + const __m128i m1 = VmullHi16(ma, sum); + const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n)); + const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n)); + const __m128i b_lo = VrshrU32(m2, kSgrProjReciprocalBits); + const __m128i b_hi = VrshrU32(m3, kSgrProjReciprocalBits); + return _mm_packus_epi32(b_lo, b_hi); +} + +template <int n> +inline __m256i CalculateB(const __m256i sum, const __m256i ma) { + static_assert(n == 9 || n == 25, ""); + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + const __m256i m0 = VmullLo16(ma, sum); + const __m256i m1 = VmullHi16(ma, sum); + const __m256i m2 = _mm256_mullo_epi32(m0, _mm256_set1_epi32(one_over_n)); + const __m256i m3 = _mm256_mullo_epi32(m1, _mm256_set1_epi32(one_over_n)); + const __m256i b_lo = VrshrU32(m2, kSgrProjReciprocalBits); + const __m256i b_hi = VrshrU32(m3, kSgrProjReciprocalBits); + return _mm256_packus_epi32(b_lo, b_hi); +} + +inline void CalculateSumAndIndex5(const __m128i s5[5], const __m128i sq5[5][2], + const uint32_t scale, __m128i* const sum, + __m128i* const index) { + __m128i sum_sq[2]; + *sum = Sum5_16(s5); + Sum5_32(sq5, sum_sq); + *index = CalculateMa<25>(*sum, sum_sq, scale); +} + +inline void CalculateSumAndIndex5(const __m256i s5[5], const __m256i sq5[5][2], + const uint32_t scale, __m256i* const sum, + __m256i* const index) { + __m256i sum_sq[2]; + *sum = Sum5_16(s5); + Sum5_32(sq5, sum_sq); + *index = CalculateMa<25>(*sum, sum_sq, scale); +} + +inline void CalculateSumAndIndex3(const __m128i s3[3], const __m128i sq3[3][2], + const uint32_t scale, __m128i* const sum, + __m128i* const index) { + __m128i sum_sq[2]; + *sum = Sum3_16(s3); + Sum3_32(sq3, sum_sq); + *index = CalculateMa<9>(*sum, sum_sq, scale); +} + +inline void CalculateSumAndIndex3(const __m256i s3[3], const __m256i sq3[3][2], + const uint32_t scale, __m256i* const sum, + __m256i* const index) { + __m256i sum_sq[2]; + *sum = Sum3_16(s3); + Sum3_32(sq3, sum_sq); + *index = CalculateMa<9>(*sum, sum_sq, scale); +} + +template <int n> +inline void LookupIntermediate(const __m128i sum, const __m128i index, + __m128i* const ma, __m128i* const b) { + static_assert(n == 9 || n == 25, ""); + const __m128i idx = _mm_packus_epi16(index, index); + // Actually it's not stored and loaded. The compiler will use a 64-bit + // general-purpose register to process. Faster than using _mm_extract_epi8(). + uint8_t temp[8]; + StoreLo8(temp, idx); + *ma = _mm_cvtsi32_si128(kSgrMaLookup[temp[0]]); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[1]], 1); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[2]], 2); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[3]], 3); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[4]], 4); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[5]], 5); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[6]], 6); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[7]], 7); + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const __m128i maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); + *b = CalculateB<n>(sum, maq); +} + +// Repeat the first 48 elements in kSgrMaLookup with a period of 16. +alignas(32) constexpr uint8_t kSgrMaLookupAvx2[96] = { + 255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16, + 255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16, + 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 9, 9, 8, 8, + 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 9, 9, 8, 8, + 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 5, 5, + 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 5, 5}; + +// Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b +// to get value 0 as the shuffle result. The most significiant bit 1 comes +// either from the comparision instruction, or from the sign bit of the index. +inline __m256i ShuffleIndex(const __m256i table, const __m256i index) { + __m256i mask; + mask = _mm256_cmpgt_epi8(index, _mm256_set1_epi8(15)); + mask = _mm256_or_si256(mask, index); + return _mm256_shuffle_epi8(table, mask); +} + +inline __m256i AdjustValue(const __m256i value, const __m256i index, + const int threshold) { + const __m256i thresholds = _mm256_set1_epi8(threshold - 128); + const __m256i offset = _mm256_cmpgt_epi8(index, thresholds); + return _mm256_add_epi8(value, offset); +} + +template <int n> +inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2], + __m256i ma[3], __m256i b[2]) { + static_assert(n == 9 || n == 25, ""); + // Use table lookup to read elements which indices are less than 48. + const __m256i c0 = LoadAligned32(kSgrMaLookupAvx2 + 0 * 32); + const __m256i c1 = LoadAligned32(kSgrMaLookupAvx2 + 1 * 32); + const __m256i c2 = LoadAligned32(kSgrMaLookupAvx2 + 2 * 32); + const __m256i indices = _mm256_packus_epi16(index[0], index[1]); + __m256i idx, mas; + // Clip idx to 127 to apply signed comparision instructions. + idx = _mm256_min_epu8(indices, _mm256_set1_epi8(127)); + // All elements which indices are less than 48 are set to 0. + // Get shuffle results for indices in range [0, 15]. + mas = ShuffleIndex(c0, idx); + // Get shuffle results for indices in range [16, 31]. + // Subtract 16 to utilize the sign bit of the index. + idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16)); + const __m256i res1 = ShuffleIndex(c1, idx); + // Use OR instruction to combine shuffle results together. + mas = _mm256_or_si256(mas, res1); + // Get shuffle results for indices in range [32, 47]. + // Subtract 16 to utilize the sign bit of the index. + idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16)); + const __m256i res2 = ShuffleIndex(c2, idx); + mas = _mm256_or_si256(mas, res2); + + // For elements which indices are larger than 47, since they seldom change + // values with the increase of the index, we use comparison and arithmetic + // operations to calculate their values. + // Add -128 to apply signed comparision instructions. + idx = _mm256_add_epi8(indices, _mm256_set1_epi8(-128)); + // Elements which indices are larger than 47 (with value 0) are set to 5. + mas = _mm256_max_epu8(mas, _mm256_set1_epi8(5)); + mas = AdjustValue(mas, idx, 55); // 55 is the last index which value is 5. + mas = AdjustValue(mas, idx, 72); // 72 is the last index which value is 4. + mas = AdjustValue(mas, idx, 101); // 101 is the last index which value is 3. + mas = AdjustValue(mas, idx, 169); // 169 is the last index which value is 2. + mas = AdjustValue(mas, idx, 254); // 254 is the last index which value is 1. + + ma[2] = _mm256_permute4x64_epi64(mas, 0x93); // 32-39 8-15 16-23 24-31 + ma[0] = _mm256_blend_epi32(ma[0], ma[2], 0xfc); // 0-7 8-15 16-23 24-31 + ma[1] = _mm256_permute2x128_si256(ma[0], ma[2], 0x21); + + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const __m256i maq0 = _mm256_unpackhi_epi8(ma[0], _mm256_setzero_si256()); + const __m256i maq1 = _mm256_unpacklo_epi8(ma[1], _mm256_setzero_si256()); + b[0] = CalculateB<n>(sum[0], maq0); + b[1] = CalculateB<n>(sum[1], maq1); +} + +inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2], + const uint32_t scale, __m128i* const ma, + __m128i* const b) { + __m128i sum, index; + CalculateSumAndIndex5(s5, sq5, scale, &sum, &index); + LookupIntermediate<25>(sum, index, ma, b); +} + +inline void CalculateIntermediate3(const __m128i s3[3], const __m128i sq3[3][2], + const uint32_t scale, __m128i* const ma, + __m128i* const b) { + __m128i sum, index; + CalculateSumAndIndex3(s3, sq3, scale, &sum, &index); + LookupIntermediate<9>(sum, index, ma, b); +} + +inline void Store343_444(const __m256i b3[2], const ptrdiff_t x, + __m256i sum_b343[2], __m256i sum_b444[2], + uint32_t* const b343, uint32_t* const b444) { + __m256i b[3], sum_b111[2]; + Prepare3_16(b3, b); + sum_b111[0] = Sum3WLo32(b); + sum_b111[1] = Sum3WHi32(b); + sum_b444[0] = _mm256_slli_epi32(sum_b111[0], 2); + sum_b444[1] = _mm256_slli_epi32(sum_b111[1], 2); + StoreAligned64(b444 + x, sum_b444); + sum_b343[0] = _mm256_sub_epi32(sum_b444[0], sum_b111[0]); + sum_b343[1] = _mm256_sub_epi32(sum_b444[1], sum_b111[1]); + sum_b343[0] = VaddwLo16(sum_b343[0], b[1]); + sum_b343[1] = VaddwHi16(sum_b343[1], b[1]); + StoreAligned64(b343 + x, sum_b343); +} + +inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i* const sum_ma444, __m256i sum_b343[2], + __m256i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m256i sum_ma111 = Sum3WLo16(ma3); + *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2); + StoreAligned32(ma444 + x, *sum_ma444); + const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwLo8(sum333, ma3[1]); + StoreAligned32(ma343 + x, *sum_ma343); + Store343_444(b3, x, sum_b343, sum_b444, b343, b444); +} + +inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i* const sum_ma444, __m256i sum_b343[2], + __m256i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m256i sum_ma111 = Sum3WHi16(ma3); + *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2); + StoreAligned32(ma444 + x, *sum_ma444); + const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwHi8(sum333, ma3[1]); + StoreAligned32(ma343 + x, *sum_ma343); + Store343_444(b3, x, sum_b343, sum_b444, b343, b444); +} + +inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma444, sum_b444[2]; + Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, __m256i* const sum_ma343, + __m256i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma444, sum_b444[2]; + Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma343, sum_b343[2]; + Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m256i sum_ma343, sum_b343[2]; + Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo( + const __m128i s[2][3], const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m128i sq[2][2], __m128i* const ma, + __m128i* const b) { + __m128i s5[2][5], sq5[5][2]; + sq[0][1] = SquareHi8(s[0][0]); + sq[1][1] = SquareHi8(s[1][0]); + s5[0][3] = Sum5Horizontal(s[0][0]); + StoreAligned16(sum5[3], s5[0][3]); + s5[0][4] = Sum5Horizontal(s[1][0]); + StoreAligned16(sum5[4], s5[0][4]); + Sum5WHorizontal(sq[0], sq5[3]); + StoreAligned32U32(square_sum5[3], sq5[3]); + Sum5WHorizontal(sq[1], sq5[4]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x3U16(sum5, 0, s5[0]); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5(s5[0], sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const uint8_t* const src0, const uint8_t* const src1, + const ptrdiff_t over_read_in_bytes, const ptrdiff_t sum_width, + const ptrdiff_t x, const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m256i sq[2][3], __m256i ma[3], + __m256i b[3]) { + const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8); + const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8); + __m256i s5[2][5], sq5[5][2], sum[2], index[2]; + sq[0][1] = SquareLo8(s0); + sq[0][2] = SquareHi8(s0); + sq[1][1] = SquareLo8(s1); + sq[1][2] = SquareHi8(s1); + sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21); + sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21); + Sum5Horizontal(src0, over_read_in_bytes, &s5[0][3], &s5[1][3]); + Sum5Horizontal(src1, over_read_in_bytes, &s5[0][4], &s5[1][4]); + StoreAligned32(sum5[3] + x + 0, s5[0][3]); + StoreAligned32(sum5[3] + x + 16, s5[1][3]); + StoreAligned32(sum5[4] + x + 0, s5[0][4]); + StoreAligned32(sum5[4] + x + 16, s5[1][4]); + Sum5WHorizontal(sq[0], sq5[3]); + StoreAligned64(square_sum5[3] + x, sq5[3]); + Sum5WHorizontal(sq[1], sq5[4]); + StoreAligned64(square_sum5[4] + x, sq5[4]); + LoadAligned32x3U16(sum5, x, s5[0]); + LoadAligned64x3U32(square_sum5, x, sq5); + CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]); + + Sum5WHorizontal(sq[0] + 1, sq5[3]); + StoreAligned64(square_sum5[3] + x + 16, sq5[3]); + Sum5WHorizontal(sq[1] + 1, sq5[4]); + StoreAligned64(square_sum5[4] + x + 16, sq5[4]); + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]); + CalculateIntermediate<25>(sum, index, ma, b + 1); + b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo( + const __m128i s, const uint32_t scale, const uint16_t* const sum5[5], + const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma, + __m128i* const b) { + __m128i s5[5], sq5[5][2]; + sq[1] = SquareHi8(s); + s5[3] = s5[4] = Sum5Horizontal(s); + Sum5WHorizontal(sq, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow( + const uint8_t* const src, const ptrdiff_t over_read_in_bytes, + const ptrdiff_t sum_width, const ptrdiff_t x, const uint32_t scale, + const uint16_t* const sum5[5], const uint32_t* const square_sum5[5], + __m256i sq[3], __m256i ma[3], __m256i b[3]) { + const __m256i s = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8); + __m256i s5[2][5], sq5[5][2], sum[2], index[2]; + sq[1] = SquareLo8(s); + sq[2] = SquareHi8(s); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + Sum5Horizontal(src, over_read_in_bytes, &s5[0][3], &s5[1][3]); + s5[0][4] = s5[0][3]; + s5[1][4] = s5[1][3]; + Sum5WHorizontal(sq, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned32x3U16(sum5, x, s5[0]); + LoadAligned64x3U32(square_sum5, x, sq5); + CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]); + + Sum5WHorizontal(sq + 1, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5); + CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]); + CalculateIntermediate<25>(sum, index, ma, b + 1); + b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo( + const __m128i s, const uint32_t scale, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], __m128i sq[2], __m128i* const ma, + __m128i* const b) { + __m128i s3[3], sq3[3][2]; + sq[1] = SquareHi8(s); + s3[2] = Sum3Horizontal(s); + StoreAligned16(sum3[2], s3[2]); + Sum3WHorizontal(sq, sq3[2]); + StoreAligned32U32(square_sum3[2], sq3[2]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( + const uint8_t* const src, const ptrdiff_t over_read_in_bytes, + const ptrdiff_t x, const ptrdiff_t sum_width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], __m256i sq[3], + __m256i ma[3], __m256i b[3]) { + const __m256i s = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8); + __m256i s3[4], sq3[3][2], sum[2], index[2]; + sq[1] = SquareLo8(s); + sq[2] = SquareHi8(s); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + Sum3Horizontal(src, over_read_in_bytes, s3 + 2); + StoreAligned64(sum3[2] + x, s3 + 2); + Sum3WHorizontal(sq + 0, sq3[2]); + StoreAligned64(square_sum3[2] + x, sq3[2]); + LoadAligned32x2U16(sum3, x, s3); + LoadAligned64x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]); + + Sum3WHorizontal(sq + 1, sq3[2]); + StoreAligned64(square_sum3[2] + x + 16, sq3[2]); + LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3 + 1); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3); + CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]); + CalculateIntermediate<9>(sum, index, ma, b + 1); + b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo( + const __m128i s[2], const uint16_t scales[2], uint16_t* const sum3[4], + uint16_t* const sum5[5], uint32_t* const square_sum3[4], + uint32_t* const square_sum5[5], __m128i sq[2][2], __m128i ma3[2], + __m128i b3[2], __m128i* const ma5, __m128i* const b5) { + __m128i s3[4], s5[5], sq3[4][2], sq5[5][2]; + sq[0][1] = SquareHi8(s[0]); + sq[1][1] = SquareHi8(s[1]); + SumHorizontalLo(s[0], &s3[2], &s5[3]); + SumHorizontalLo(s[1], &s3[3], &s5[4]); + StoreAligned16(sum3[2], s3[2]); + StoreAligned16(sum3[3], s3[3]); + StoreAligned16(sum5[3], s5[3]); + StoreAligned16(sum5[4], s5[4]); + SumHorizontal(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2], sq3[2]); + StoreAligned32U32(square_sum5[3], sq5[3]); + SumHorizontal(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3], sq3[3]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + // Note: in the SSE4_1 version, CalculateIntermediate() is called + // to replace the slow LookupIntermediate() when calculating 16 intermediate + // data points. However, the AVX2 compiler generates even slower code. So we + // keep using CalculateIntermediate3(). + CalculateIntermediate3(s3 + 0, sq3 + 0, scales[1], &ma3[0], &b3[0]); + CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], &ma3[1], &b3[1]); + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( + const uint8_t* const src0, const uint8_t* const src1, + const ptrdiff_t over_read_in_bytes, const ptrdiff_t x, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, __m256i sq[2][3], __m256i ma3[2][3], + __m256i b3[2][5], __m256i ma5[3], __m256i b5[5]) { + const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8); + const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8); + __m256i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sq3t[4][2], sq5t[5][2], + sum_3[2][2], index_3[2][2], sum_5[2], index_5[2]; + sq[0][1] = SquareLo8(s0); + sq[0][2] = SquareHi8(s0); + sq[1][1] = SquareLo8(s1); + sq[1][2] = SquareHi8(s1); + sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21); + sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21); + SumHorizontal(src0, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3], + &s5[1][3]); + SumHorizontal(src1, over_read_in_bytes, &s3[0][3], &s3[1][3], &s5[0][4], + &s5[1][4]); + StoreAligned32(sum3[2] + x + 0, s3[0][2]); + StoreAligned32(sum3[2] + x + 16, s3[1][2]); + StoreAligned32(sum3[3] + x + 0, s3[0][3]); + StoreAligned32(sum3[3] + x + 16, s3[1][3]); + StoreAligned32(sum5[3] + x + 0, s5[0][3]); + StoreAligned32(sum5[3] + x + 16, s5[1][3]); + StoreAligned32(sum5[4] + x + 0, s5[0][4]); + StoreAligned32(sum5[4] + x + 16, s5[1][4]); + SumHorizontal(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + SumHorizontal(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned64(square_sum3[2] + x, sq3[2]); + StoreAligned64(square_sum5[3] + x, sq5[3]); + StoreAligned64(square_sum3[3] + x, sq3[3]); + StoreAligned64(square_sum5[4] + x, sq5[4]); + LoadAligned32x2U16(sum3, x, s3[0]); + LoadAligned64x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0][0], &index_3[0][0]); + CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum_3[1][0], + &index_3[1][0]); + LoadAligned32x3U16(sum5, x, s5[0]); + LoadAligned64x3U32(square_sum5, x, sq5); + CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); + + SumHorizontal(sq[0] + 1, &sq3t[2][0], &sq3t[2][1], &sq5t[3][0], &sq5t[3][1]); + SumHorizontal(sq[1] + 1, &sq3t[3][0], &sq3t[3][1], &sq5t[4][0], &sq5t[4][1]); + StoreAligned64(square_sum3[2] + x + 16, sq3t[2]); + StoreAligned64(square_sum5[3] + x + 16, sq5t[3]); + StoreAligned64(square_sum3[3] + x + 16, sq3t[3]); + StoreAligned64(square_sum5[4] + x + 16, sq5t[4]); + LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3t); + CalculateSumAndIndex3(s3[1], sq3t, scales[1], &sum_3[0][1], &index_3[0][1]); + CalculateSumAndIndex3(s3[1] + 1, sq3t + 1, scales[1], &sum_3[1][1], + &index_3[1][1]); + CalculateIntermediate<9>(sum_3[0], index_3[0], ma3[0], b3[0] + 1); + CalculateIntermediate<9>(sum_3[1], index_3[1], ma3[1], b3[1] + 1); + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5t); + CalculateSumAndIndex5(s5[1], sq5t, scales[0], &sum_5[1], &index_5[1]); + CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1); + b3[0][0] = _mm256_permute2x128_si256(b3[0][0], b3[0][2], 0x21); + b3[1][0] = _mm256_permute2x128_si256(b3[1][0], b3[1][2], 0x21); + b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo( + const __m128i s, const uint16_t scales[2], const uint16_t* const sum3[4], + const uint16_t* const sum5[5], const uint32_t* const square_sum3[4], + const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma3, + __m128i* const ma5, __m128i* const b3, __m128i* const b5) { + __m128i s3[3], s5[5], sq3[3][2], sq5[5][2]; + sq[1] = SquareHi8(s); + SumHorizontalLo(s, &s3[2], &s5[3]); + SumHorizontal(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16(sum5, 0, s5); + s5[4] = s5[3]; + LoadAligned32x3U32(square_sum5, 0, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5(s5, sq5, scales[0], ma5, b5); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scales[1], ma3, b3); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( + const uint8_t* const src, const ptrdiff_t over_read_in_bytes, + const ptrdiff_t sum_width, const ptrdiff_t x, const uint16_t scales[2], + const uint16_t* const sum3[4], const uint16_t* const sum5[5], + const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5], + __m256i sq[6], __m256i ma3[2], __m256i ma5[2], __m256i b3[5], + __m256i b5[5]) { + const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8); + __m256i s3[2][3], s5[2][5], sq3[4][2], sq3t[4][2], sq5[5][2], sq5t[5][2], + sum_3[2], index_3[2], sum_5[2], index_5[2]; + sq[1] = SquareLo8(s0); + sq[2] = SquareHi8(s0); + sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21); + SumHorizontal(src, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3], + &s5[1][3]); + SumHorizontal(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned32x2U16(sum3, x, s3[0]); + LoadAligned64x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0], &index_3[0]); + LoadAligned32x3U16(sum5, x, s5[0]); + s5[0][4] = s5[0][3]; + LoadAligned64x3U32(square_sum5, x, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]); + + SumHorizontal(sq + 1, &sq3t[2][0], &sq3t[2][1], &sq5t[3][0], &sq5t[3][1]); + LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]); + LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3t); + CalculateSumAndIndex3(s3[1], sq3t, scales[1], &sum_3[1], &index_3[1]); + CalculateIntermediate<9>(sum_3, index_3, ma3, b3 + 1); + LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]); + s5[1][4] = s5[1][3]; + LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5t); + sq5t[4][0] = sq5t[3][0]; + sq5t[4][1] = sq5t[3][1]; + CalculateSumAndIndex5(s5[1], sq5t, scales[0], &sum_5[1], &index_5[1]); + CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1); + b3[0] = _mm256_permute2x128_si256(b3[0], b3[2], 0x21); + b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21); +} + +inline void BoxSumFilterPreProcess5(const uint8_t* const src0, + const uint8_t* const src1, const int width, + const uint32_t scale, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* ma565, + uint32_t* b565) { + __m128i ma0, b0, s[2][3], sq_128[2][2]; + __m256i mas[3], sq[2][3], bs[3]; + s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); + s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width); + sq_128[0][0] = SquareLo8(s[0][0]); + sq_128[1][0] = SquareLo8(s[1][0]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, &b0); + sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]); + sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0, b0); + + int x = 0; + do { + __m256i ma5[3], ma[2], b[4]; + BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8, + x + 8 + kOverreadInBytesPass1_256 - width, sum_width, + x + 8, scale, sum5, square_sum5, sq, mas, bs); + Prepare3_8(mas, ma5); + ma[0] = Sum565Lo(ma5); + ma[1] = Sum565Hi(ma5); + StoreAligned64(ma565, ma); + Sum565W(bs + 0, b + 0); + Sum565W(bs + 1, b + 2); + StoreAligned64(b565, b + 0); + StoreAligned64(b565 + 16, b + 2); + sq[0][0] = sq[0][2]; + sq[1][0] = sq[1][2]; + mas[0] = mas[2]; + bs[0] = bs[2]; + ma565 += 32; + b565 += 32; + x += 32; + } while (x < width); +} + +template <bool calculate444> +LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( + const uint8_t* const src, const int width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343, + uint32_t* b444) { + __m128i ma0, sq_128[2], b0; + __m256i mas[3], sq[3], bs[3]; + const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytesPass2_128 - width); + sq_128[0] = SquareLo8(s); + BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq_128, &ma0, &b0); + sq[0] = SetrM128i(sq_128[0], sq_128[1]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0, b0); + + int x = 0; + do { + __m256i ma3[3]; + BoxFilterPreProcess3(src + x + 8, x + 8 + kOverreadInBytesPass2_256 - width, + x + 8, sum_width, scale, sum3, square_sum3, sq, mas, + bs); + Prepare3_8(mas, ma3); + if (calculate444) { // NOLINT(readability-simplify-boolean-expr) + Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444); + Store343_444Hi(ma3, bs + 1, 16, ma343, ma444, b343, b444); + ma444 += 32; + b444 += 32; + } else { + __m256i ma[2], b[4]; + ma[0] = Sum343Lo(ma3); + ma[1] = Sum343Hi(ma3); + StoreAligned64(ma343, ma); + Sum343W(bs + 0, b + 0); + Sum343W(bs + 1, b + 2); + StoreAligned64(b343 + 0, b + 0); + StoreAligned64(b343 + 16, b + 2); + } + sq[0] = sq[2]; + mas[0] = mas[2]; + bs[0] = bs[2]; + ma343 += 32; + b343 += 32; + x += 32; + } while (x < width); +} + +inline void BoxSumFilterPreProcess( + const uint8_t* const src0, const uint8_t* const src1, const int width, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], + uint16_t* const ma444[2], uint16_t* ma565, uint32_t* const b343[4], + uint32_t* const b444[2], uint32_t* b565) { + __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0; + __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5]; + s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); + s[1] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width); + sq_128[0][0] = SquareLo8(s[0]); + sq_128[1][0] = SquareLo8(s[1]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128, + ma3_128, b3_128, &ma5_0, &b5_0); + sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]); + sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]); + ma3[0][0] = SetrM128i(ma3_128[0], ma3_128[0]); + ma3[1][0] = SetrM128i(ma3_128[1], ma3_128[1]); + ma5[0] = SetrM128i(ma5_0, ma5_0); + b3[0][0] = SetrM128i(b3_128[0], b3_128[0]); + b3[1][0] = SetrM128i(b3_128[1], b3_128[1]); + b5[0] = SetrM128i(b5_0, b5_0); + + int x = 0; + do { + __m256i ma[2], b[4], ma3x[3], ma5x[3]; + BoxFilterPreProcess(src0 + x + 8, src1 + x + 8, + x + 8 + kOverreadInBytesPass1_256 - width, x + 8, + scales, sum3, sum5, square_sum3, square_sum5, sum_width, + sq, ma3, b3, ma5, b5); + Prepare3_8(ma3[0], ma3x); + ma[0] = Sum343Lo(ma3x); + ma[1] = Sum343Hi(ma3x); + StoreAligned64(ma343[0] + x, ma); + Sum343W(b3[0], b); + StoreAligned64(b343[0] + x, b); + Sum565W(b5, b); + StoreAligned64(b565, b); + Prepare3_8(ma3[1], ma3x); + Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); + Store343_444Hi(ma3x, b3[1] + 1, x + 16, ma343[1], ma444[0], b343[1], + b444[0]); + Prepare3_8(ma5, ma5x); + ma[0] = Sum565Lo(ma5x); + ma[1] = Sum565Hi(ma5x); + StoreAligned64(ma565, ma); + Sum343W(b3[0] + 1, b); + StoreAligned64(b343[0] + x + 16, b); + Sum565W(b5 + 1, b); + StoreAligned64(b565 + 16, b); + sq[0][0] = sq[0][2]; + sq[1][0] = sq[1][2]; + ma3[0][0] = ma3[0][2]; + ma3[1][0] = ma3[1][2]; + ma5[0] = ma5[2]; + b3[0][0] = b3[0][2]; + b3[1][0] = b3[1][2]; + b5[0] = b5[2]; + ma565 += 32; + b565 += 32; + x += 32; + } while (x < width); +} + +template <int shift> +inline __m256i FilterOutput(const __m256i ma_x_src, const __m256i b) { + // ma: 255 * 32 = 8160 (13 bits) + // b: 65088 * 32 = 2082816 (21 bits) + // v: b - ma * 255 (22 bits) + const __m256i v = _mm256_sub_epi32(b, ma_x_src); + // kSgrProjSgrBits = 8 + // kSgrProjRestoreBits = 4 + // shift = 4 or 5 + // v >> 8 or 9 (13 bits) + return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits); +} + +template <int shift> +inline __m256i CalculateFilteredOutput(const __m256i src, const __m256i ma, + const __m256i b[2]) { + const __m256i ma_x_src_lo = VmullLo16(ma, src); + const __m256i ma_x_src_hi = VmullHi16(ma, src); + const __m256i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]); + const __m256i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]); + return _mm256_packs_epi32(dst_lo, dst_hi); // 13 bits +} + +inline __m256i CalculateFilteredOutputPass1(const __m256i src, __m256i ma[2], + __m256i b[2][2]) { + const __m256i ma_sum = _mm256_add_epi16(ma[0], ma[1]); + __m256i b_sum[2]; + b_sum[0] = _mm256_add_epi32(b[0][0], b[1][0]); + b_sum[1] = _mm256_add_epi32(b[0][1], b[1][1]); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m256i CalculateFilteredOutputPass2(const __m256i src, __m256i ma[3], + __m256i b[3][2]) { + const __m256i ma_sum = Sum3_16(ma); + __m256i b_sum[2]; + Sum3_32(b, b_sum); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m256i SelfGuidedFinal(const __m256i src, const __m256i v[2]) { + const __m256i v_lo = + VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m256i v_hi = + VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m256i vv = _mm256_packs_epi32(v_lo, v_hi); + return _mm256_add_epi16(src, vv); +} + +inline __m256i SelfGuidedDoubleMultiplier(const __m256i src, + const __m256i filter[2], const int w0, + const int w2) { + __m256i v[2]; + const __m256i w0_w2 = + _mm256_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0)); + const __m256i f_lo = _mm256_unpacklo_epi16(filter[0], filter[1]); + const __m256i f_hi = _mm256_unpackhi_epi16(filter[0], filter[1]); + v[0] = _mm256_madd_epi16(w0_w2, f_lo); + v[1] = _mm256_madd_epi16(w0_w2, f_hi); + return SelfGuidedFinal(src, v); +} + +inline __m256i SelfGuidedSingleMultiplier(const __m256i src, + const __m256i filter, const int w0) { + // weight: -96 to 96 (Sgrproj_Xqd_Min/Max) + __m256i v[2]; + v[0] = VmullNLo8(filter, w0); + v[1] = VmullNHi8(filter, w0); + return SelfGuidedFinal(src, v); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width, + const uint32_t scale, const int16_t w0, uint16_t* const ma565[2], + uint32_t* const b565[2], uint8_t* const dst) { + __m128i ma0, b0, s[2][3], sq_128[2][2]; + __m256i mas[3], sq[2][3], bs[3]; + s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); + s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width); + sq_128[0][0] = SquareLo8(s[0][0]); + sq_128[1][0] = SquareLo8(s[1][0]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, &b0); + sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]); + sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0, b0); + + int x = 0; + do { + __m256i ma[3], ma3[3], b[2][2][2]; + BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8, + x + 8 + kOverreadInBytesPass1_256 - width, sum_width, + x + 8, scale, sum5, square_sum5, sq, mas, bs); + Prepare3_8(mas, ma3); + ma[1] = Sum565Lo(ma3); + ma[2] = Sum565Hi(ma3); + StoreAligned64(ma565[1] + x, ma + 1); + Sum565W(bs + 0, b[0][1]); + Sum565W(bs + 1, b[1][1]); + StoreAligned64(b565[1] + x + 0, b[0][1]); + StoreAligned64(b565[1] + x + 16, b[1][1]); + const __m256i sr0 = LoadUnaligned32(src + x); + const __m256i sr1 = LoadUnaligned32(src + stride + x); + const __m256i sr0_lo = _mm256_unpacklo_epi8(sr0, _mm256_setzero_si256()); + const __m256i sr1_lo = _mm256_unpacklo_epi8(sr1, _mm256_setzero_si256()); + ma[0] = LoadAligned32(ma565[0] + x); + LoadAligned64(b565[0] + x, b[0][0]); + const __m256i p00 = CalculateFilteredOutputPass1(sr0_lo, ma, b[0]); + const __m256i p01 = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[0][1]); + const __m256i d00 = SelfGuidedSingleMultiplier(sr0_lo, p00, w0); + const __m256i d10 = SelfGuidedSingleMultiplier(sr1_lo, p01, w0); + const __m256i sr0_hi = _mm256_unpackhi_epi8(sr0, _mm256_setzero_si256()); + const __m256i sr1_hi = _mm256_unpackhi_epi8(sr1, _mm256_setzero_si256()); + ma[1] = LoadAligned32(ma565[0] + x + 16); + LoadAligned64(b565[0] + x + 16, b[1][0]); + const __m256i p10 = CalculateFilteredOutputPass1(sr0_hi, ma + 1, b[1]); + const __m256i p11 = CalculateFilteredOutput<4>(sr1_hi, ma[2], b[1][1]); + const __m256i d01 = SelfGuidedSingleMultiplier(sr0_hi, p10, w0); + const __m256i d11 = SelfGuidedSingleMultiplier(sr1_hi, p11, w0); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d00, d01)); + StoreUnaligned32(dst + stride + x, _mm256_packus_epi16(d10, d11)); + sq[0][0] = sq[0][2]; + sq[1][0] = sq[1][2]; + mas[0] = mas[2]; + bs[0] = bs[2]; + x += 32; + } while (x < width); +} + +inline void BoxFilterPass1LastRow( + const uint8_t* const src, const uint8_t* const src0, const int width, + const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565, + uint32_t* b565, uint8_t* const dst) { + const __m128i s0 = + LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); + __m128i ma0, b0, sq_128[2]; + __m256i mas[3], sq[3], bs[3]; + sq_128[0] = SquareLo8(s0); + BoxFilterPreProcess5LastRowLo(s0, scale, sum5, square_sum5, sq_128, &ma0, + &b0); + sq[0] = SetrM128i(sq_128[0], sq_128[1]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0, b0); + + int x = 0; + do { + __m256i ma[3], ma5[3], b[2][2]; + BoxFilterPreProcess5LastRow( + src0 + x + 8, x + 8 + kOverreadInBytesPass1_256 - width, sum_width, + x + 8, scale, sum5, square_sum5, sq, mas, bs); + Prepare3_8(mas, ma5); + ma[1] = Sum565Lo(ma5); + ma[2] = Sum565Hi(ma5); + Sum565W(bs + 0, b[1]); + const __m256i sr = LoadUnaligned32(src + x); + const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256()); + const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256()); + ma[0] = LoadAligned32(ma565); + LoadAligned64(b565 + 0, b[0]); + const __m256i p0 = CalculateFilteredOutputPass1(sr_lo, ma, b); + ma[1] = LoadAligned32(ma565 + 16); + LoadAligned64(b565 + 16, b[0]); + Sum565W(bs + 1, b[1]); + const __m256i p1 = CalculateFilteredOutputPass1(sr_hi, ma + 1, b); + const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0); + const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); + sq[0] = sq[2]; + mas[0] = mas[2]; + bs[0] = bs[2]; + ma565 += 32; + b565 += 32; + x += 32; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass2( + const uint8_t* const src, const uint8_t* const src0, const int width, + const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3], + uint32_t* const b444[2], uint8_t* const dst) { + const __m128i s0 = + LoadUnaligned16Msan(src0, kOverreadInBytesPass2_128 - width); + __m128i ma0, b0, sq_128[2]; + __m256i mas[3], sq[3], bs[3]; + sq_128[0] = SquareLo8(s0); + BoxFilterPreProcess3Lo(s0, scale, sum3, square_sum3, sq_128, &ma0, &b0); + sq[0] = SetrM128i(sq_128[0], sq_128[1]); + mas[0] = SetrM128i(ma0, ma0); + bs[0] = SetrM128i(b0, b0); + + int x = 0; + do { + __m256i ma[4], b[4][2], ma3[3]; + BoxFilterPreProcess3(src0 + x + 8, + x + 8 + kOverreadInBytesPass2_256 - width, x + 8, + sum_width, scale, sum3, square_sum3, sq, mas, bs); + Prepare3_8(mas, ma3); + Store343_444Lo(ma3, bs + 0, x + 0, &ma[2], b[2], ma343[2], ma444[1], + b343[2], b444[1]); + Store343_444Hi(ma3, bs + 1, x + 16, &ma[3], b[3], ma343[2], ma444[1], + b343[2], b444[1]); + const __m256i sr = LoadUnaligned32(src + x); + const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256()); + const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256()); + ma[0] = LoadAligned32(ma343[0] + x); + ma[1] = LoadAligned32(ma444[0] + x); + LoadAligned64(b343[0] + x, b[0]); + LoadAligned64(b444[0] + x, b[1]); + const __m256i p0 = CalculateFilteredOutputPass2(sr_lo, ma, b); + ma[1] = LoadAligned32(ma343[0] + x + 16); + ma[2] = LoadAligned32(ma444[0] + x + 16); + LoadAligned64(b343[0] + x + 16, b[1]); + LoadAligned64(b444[0] + x + 16, b[2]); + const __m256i p1 = CalculateFilteredOutputPass2(sr_hi, ma + 1, b + 1); + const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0); + const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); + sq[0] = sq[2]; + mas[0] = mas[2]; + bs[0] = bs[2]; + x += 32; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilter( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], + uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4], + uint32_t* const b444[3], uint32_t* const b565[2], uint8_t* const dst) { + __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0; + __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5]; + s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); + s[1] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width); + sq_128[0][0] = SquareLo8(s[0]); + sq_128[1][0] = SquareLo8(s[1]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128, + ma3_128, b3_128, &ma5_0, &b5_0); + sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]); + sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]); + ma3[0][0] = SetrM128i(ma3_128[0], ma3_128[0]); + ma3[1][0] = SetrM128i(ma3_128[1], ma3_128[1]); + ma5[0] = SetrM128i(ma5_0, ma5_0); + b3[0][0] = SetrM128i(b3_128[0], b3_128[0]); + b3[1][0] = SetrM128i(b3_128[1], b3_128[1]); + b5[0] = SetrM128i(b5_0, b5_0); + + int x = 0; + do { + __m256i ma[3][3], mat[3][3], b[3][3][2], p[2][2], ma3x[2][3], ma5x[3]; + BoxFilterPreProcess(src0 + x + 8, src1 + x + 8, + x + 8 + kOverreadInBytesPass1_256 - width, x + 8, + scales, sum3, sum5, square_sum3, square_sum5, sum_width, + sq, ma3, b3, ma5, b5); + Prepare3_8(ma3[0], ma3x[0]); + Prepare3_8(ma3[1], ma3x[1]); + Prepare3_8(ma5, ma5x); + Store343_444Lo(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], b[1][2], b[2][1], + ma343[2], ma444[1], b343[2], b444[1]); + Store343_444Lo(ma3x[1], b3[1], x, &ma[2][2], b[2][2], ma343[3], ma444[2], + b343[3], b444[2]); + ma[0][1] = Sum565Lo(ma5x); + ma[0][2] = Sum565Hi(ma5x); + mat[0][1] = ma[0][2]; + StoreAligned64(ma565[1] + x, ma[0] + 1); + Sum565W(b5, b[0][1]); + StoreAligned64(b565[1] + x, b[0][1]); + const __m256i sr0 = LoadUnaligned32(src + x); + const __m256i sr1 = LoadUnaligned32(src + stride + x); + const __m256i sr0_lo = _mm256_unpacklo_epi8(sr0, _mm256_setzero_si256()); + const __m256i sr1_lo = _mm256_unpacklo_epi8(sr1, _mm256_setzero_si256()); + ma[0][0] = LoadAligned32(ma565[0] + x); + LoadAligned64(b565[0] + x, b[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]); + ma[1][0] = LoadAligned32(ma343[0] + x); + ma[1][1] = LoadAligned32(ma444[0] + x); + LoadAligned64(b343[0] + x, b[1][0]); + LoadAligned64(b444[0] + x, b[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]); + const __m256i d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2); + ma[2][0] = LoadAligned32(ma343[1] + x); + LoadAligned64(b343[1] + x, b[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]); + const __m256i d10 = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2); + + Sum565W(b5 + 1, b[0][1]); + StoreAligned64(b565[1] + x + 16, b[0][1]); + Store343_444Hi(ma3x[0], b3[0] + 1, x + 16, &mat[1][2], &mat[2][1], b[1][2], + b[2][1], ma343[2], ma444[1], b343[2], b444[1]); + Store343_444Hi(ma3x[1], b3[1] + 1, x + 16, &mat[2][2], b[2][2], ma343[3], + ma444[2], b343[3], b444[2]); + const __m256i sr0_hi = _mm256_unpackhi_epi8(sr0, _mm256_setzero_si256()); + const __m256i sr1_hi = _mm256_unpackhi_epi8(sr1, _mm256_setzero_si256()); + mat[0][0] = LoadAligned32(ma565[0] + x + 16); + LoadAligned64(b565[0] + x + 16, b[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_hi, mat[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_hi, mat[0][1], b[0][1]); + mat[1][0] = LoadAligned32(ma343[0] + x + 16); + mat[1][1] = LoadAligned32(ma444[0] + x + 16); + LoadAligned64(b343[0] + x + 16, b[1][0]); + LoadAligned64(b444[0] + x + 16, b[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_hi, mat[1], b[1]); + const __m256i d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2); + mat[2][0] = LoadAligned32(ma343[1] + x + 16); + LoadAligned64(b343[1] + x + 16, b[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_hi, mat[2], b[2]); + const __m256i d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d00, d01)); + StoreUnaligned32(dst + stride + x, _mm256_packus_epi16(d10, d11)); + sq[0][0] = sq[0][2]; + sq[1][0] = sq[1][2]; + ma3[0][0] = ma3[0][2]; + ma3[1][0] = ma3[1][2]; + ma5[0] = ma5[2]; + b3[0][0] = b3[0][2]; + b3[1][0] = b3[1][2]; + b5[0] = b5[2]; + x += 32; + } while (x < width); +} + +inline void BoxFilterLastRow( + const uint8_t* const src, const uint8_t* const src0, const int width, + const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0, + const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], + uint32_t* const b565[2], uint8_t* const dst) { + const __m128i s0 = + LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width); + __m128i ma3_0, ma5_0, b3_0, b5_0, sq_128[2]; + __m256i ma3[3], ma5[3], sq[3], b3[3], b5[3]; + sq_128[0] = SquareLo8(s0); + BoxFilterPreProcessLastRowLo(s0, scales, sum3, sum5, square_sum3, square_sum5, + sq_128, &ma3_0, &ma5_0, &b3_0, &b5_0); + sq[0] = SetrM128i(sq_128[0], sq_128[1]); + ma3[0] = SetrM128i(ma3_0, ma3_0); + ma5[0] = SetrM128i(ma5_0, ma5_0); + b3[0] = SetrM128i(b3_0, b3_0); + b5[0] = SetrM128i(b5_0, b5_0); + + int x = 0; + do { + __m256i ma[3], mat[3], b[3][2], p[2], ma3x[3], ma5x[3]; + BoxFilterPreProcessLastRow(src0 + x + 8, + x + 8 + kOverreadInBytesPass1_256 - width, + sum_width, x + 8, scales, sum3, sum5, + square_sum3, square_sum5, sq, ma3, ma5, b3, b5); + Prepare3_8(ma3, ma3x); + Prepare3_8(ma5, ma5x); + ma[1] = Sum565Lo(ma5x); + Sum565W(b5, b[1]); + ma[2] = Sum343Lo(ma3x); + Sum343W(b3, b[2]); + const __m256i sr = LoadUnaligned32(src + x); + const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256()); + ma[0] = LoadAligned32(ma565[0] + x); + LoadAligned64(b565[0] + x, b[0]); + p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b); + ma[0] = LoadAligned32(ma343[0] + x); + ma[1] = LoadAligned32(ma444[0] + x); + LoadAligned64(b343[0] + x, b[0]); + LoadAligned64(b444[0] + x, b[1]); + p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b); + const __m256i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2); + + mat[1] = Sum565Hi(ma5x); + Sum565W(b5 + 1, b[1]); + mat[2] = Sum343Hi(ma3x); + Sum343W(b3 + 1, b[2]); + const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256()); + mat[0] = LoadAligned32(ma565[0] + x + 16); + LoadAligned64(b565[0] + x + 16, b[0]); + p[0] = CalculateFilteredOutputPass1(sr_hi, mat, b); + mat[0] = LoadAligned32(ma343[0] + x + 16); + mat[1] = LoadAligned32(ma444[0] + x + 16); + LoadAligned64(b343[0] + x + 16, b[0]); + LoadAligned64(b444[0] + x + 16, b[1]); + p[1] = CalculateFilteredOutputPass2(sr_hi, mat, b); + const __m256i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2); + StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1)); + sq[0] = sq[2]; + ma3[0] = ma3[2]; + ma5[0] = ma5[2]; + b3[0] = b3[2]; + b5[0] = b5[2]; + x += 32; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( + const RestorationUnitInfo& restoration_info, const uint8_t* src, + const uint8_t* const top_border, const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 32); + const auto sum_width = temp_stride + 8; + const auto sum_stride = temp_stride + 32; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1; + uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2]; + uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2]; + sum3[0] = sgr_buffer->sum3 + kSumOffset; + square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 3; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma444[0] = sgr_buffer->ma444; + b444[0] = sgr_buffer->b444; + for (int i = 1; i <= 2; ++i) { + ma444[i] = ma444[i - 1] + temp_stride; + b444[i] = b444[i - 1] + temp_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scales[0] != 0); + assert(scales[1] != 0); + BoxSum(top_border, stride, width, sum_stride, temp_stride, sum3[0], sum5[1], + square_sum3[0], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, + square_sum5, sum_width, ma343, ma444, ma565[0], b343, + b444, b565[0]); + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width, + scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width, + ma343, ma444, ma565, b343, b444, b565, dst); + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5, + square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343, + b444, b565, dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + BoxFilterLastRow(src + 3, bottom_border + stride, width, sum_width, scales, + w0, w2, sum3, sum5, square_sum3, square_sum5, ma343, ma444, + ma565, b343, b444, b565, dst); + } +} + +inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 32); + const auto sum_width = temp_stride + 8; + const auto sum_stride = temp_stride + 32; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + uint16_t *sum5[5], *ma565[2]; + uint32_t *square_sum5[5], *b565[2]; + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scale != 0); + BoxSum<5>(top_border, stride, width, sum_stride, temp_stride, sum5[1], + square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width, + ma565[0], b565[0]); + sum5[0] = sgr_buffer->sum5 + kSumOffset; + square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5, + square_sum5, width, sum_width, scale, w0, ma565, b565, dst); + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width, + sum_width, scale, w0, ma565, b565, dst); + } + if ((height & 1) != 0) { + src += 3; + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + } + BoxFilterPass1LastRow(src, bottom_border + stride, width, sum_width, scale, + w0, sum5, square_sum5, ma565[0], b565[0], dst); + } +} + +inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + assert(restoration_info.sgr_proj_info.multiplier[0] == 0); + const auto temp_stride = Align<ptrdiff_t>(width, 32); + const auto sum_width = temp_stride + 8; + const auto sum_stride = temp_stride + 32; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12. + uint16_t *sum3[3], *ma343[3], *ma444[2]; + uint32_t *square_sum3[3], *b343[3], *b444[2]; + sum3[0] = sgr_buffer->sum3 + kSumOffset; + square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 2; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + ma444[0] = sgr_buffer->ma444; + ma444[1] = ma444[0] + temp_stride; + b444[0] = sgr_buffer->b444; + b444[1] = b444[0] + temp_stride; + assert(scale != 0); + BoxSum<3>(top_border, stride, width, sum_stride, temp_stride, sum3[0], + square_sum3[0]); + BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, + sum_width, ma343[0], nullptr, b343[0], + nullptr); + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + const uint8_t* s; + if (height > 1) { + s = src + stride; + } else { + s = bottom_border; + bottom_border += stride; + } + BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width, + ma343[1], ma444[0], b343[1], b444[0]); + + for (int y = height - 2; y > 0; --y) { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } + + int y = std::min(height, 2); + src += 2; + do { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + bottom_border += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } while (--y != 0); +} + +// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in +// the end of each row. It is safe to overwrite the output as it will not be +// part of the visible frame. +void SelfGuidedFilter_AVX2( + const RestorationUnitInfo& restoration_info, const void* const source, + const void* const top_border, const void* const bottom_border, + const ptrdiff_t stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { + const int index = restoration_info.sgr_proj_info.index; + const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 + const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0 + const auto* const src = static_cast<const uint8_t*>(source); + const auto* top = static_cast<const uint8_t*>(top_border); + const auto* bottom = static_cast<const uint8_t*>(bottom_border); + auto* const dst = static_cast<uint8_t*>(dest); + SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer; + if (radius_pass_1 == 0) { + // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the + // following assertion. + assert(radius_pass_0 != 0); + BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, + stride, width, height, sgr_buffer, dst); + } else if (radius_pass_0 == 0) { + BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, + stride, width, height, sgr_buffer, dst); + } else { + BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, + width, height, sgr_buffer, dst); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_AVX2(WienerFilter) + dsp->loop_restorations[0] = WienerFilter_AVX2; +#endif +#if DSP_ENABLED_8BPP_AVX2(SelfGuidedFilter) + dsp->loop_restorations[1] = SelfGuidedFilter_AVX2; +#endif +} + +} // namespace +} // namespace low_bitdepth + +void LoopRestorationInit_AVX2() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_AVX2 +namespace libgav1 { +namespace dsp { + +void LoopRestorationInit_AVX2() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_AVX2 diff --git a/src/dsp/x86/loop_restoration_avx2.h b/src/dsp/x86/loop_restoration_avx2.h new file mode 100644 index 0000000..d80227c --- /dev/null +++ b/src/dsp/x86/loop_restoration_avx2.h @@ -0,0 +1,52 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_AVX2_H_ +#define LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_AVX2_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_restorations, see the defines below for specifics. +// These functions are not thread-safe. +void LoopRestorationInit_AVX2(); +void LoopRestorationInit10bpp_AVX2(); + +} // namespace dsp +} // namespace libgav1 + +// If avx2 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the avx2 implementation should be used. +#if LIBGAV1_TARGETING_AVX2 + +#ifndef LIBGAV1_Dsp8bpp_WienerFilter +#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_AVX2 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WienerFilter +#define LIBGAV1_Dsp10bpp_WienerFilter LIBGAV1_CPU_AVX2 +#endif + +#ifndef LIBGAV1_Dsp8bpp_SelfGuidedFilter +#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_AVX2 +#endif + +#endif // LIBGAV1_TARGETING_AVX2 + +#endif // LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_AVX2_H_ diff --git a/src/dsp/x86/loop_restoration_sse4.cc b/src/dsp/x86/loop_restoration_sse4.cc new file mode 100644 index 0000000..24f5ad2 --- /dev/null +++ b/src/dsp/x86/loop_restoration_sse4.cc @@ -0,0 +1,2549 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/loop_restoration.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 +#include <smmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "src/dsp/common.h" +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +inline void WienerHorizontalClip(const __m128i s[2], const __m128i s_3x128, + int16_t* const wiener_buffer) { + constexpr int offset = + 1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1); + constexpr int limit = + (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1; + const __m128i offsets = _mm_set1_epi16(-offset); + const __m128i limits = _mm_set1_epi16(limit - offset); + // The sum range here is [-128 * 255 + 4, 90 * 255 + 4]. + const __m128i sum = _mm_add_epi16(s[0], s[1]); + const __m128i rounded_sum0 = _mm_srai_epi16(sum, kInterRoundBitsHorizontal); + // Add back scaled down offset correction. + const __m128i rounded_sum1 = _mm_add_epi16(rounded_sum0, s_3x128); + const __m128i d0 = _mm_max_epi16(rounded_sum1, offsets); + const __m128i d1 = _mm_min_epi16(d0, limits); + StoreAligned16(wiener_buffer, d1); +} + +inline void WienerHorizontalTap7Kernel(const __m128i s[4], + const __m128i filter[4], + int16_t* const wiener_buffer) { + __m128i madds[4]; + madds[0] = _mm_maddubs_epi16(s[0], filter[0]); + madds[1] = _mm_maddubs_epi16(s[1], filter[1]); + madds[2] = _mm_maddubs_epi16(s[2], filter[2]); + madds[3] = _mm_maddubs_epi16(s[3], filter[3]); + madds[0] = _mm_add_epi16(madds[0], madds[2]); + madds[1] = _mm_add_epi16(madds[1], madds[3]); + const __m128i s_3x128 = + _mm_slli_epi16(_mm_srli_epi16(s[1], 8), 7 - kInterRoundBitsHorizontal); + WienerHorizontalClip(madds, s_3x128, wiener_buffer); +} + +inline void WienerHorizontalTap5Kernel(const __m128i s[5], + const __m128i filter[3], + int16_t* const wiener_buffer) { + __m128i madds[3]; + madds[0] = _mm_maddubs_epi16(s[0], filter[0]); + madds[1] = _mm_maddubs_epi16(s[1], filter[1]); + madds[2] = _mm_maddubs_epi16(s[2], filter[2]); + madds[0] = _mm_add_epi16(madds[0], madds[2]); + const __m128i s_3x128 = + _mm_srli_epi16(_mm_slli_epi16(s[1], 8), kInterRoundBitsHorizontal + 1); + WienerHorizontalClip(madds, s_3x128, wiener_buffer); +} + +inline void WienerHorizontalTap3Kernel(const __m128i s[2], + const __m128i filter[2], + int16_t* const wiener_buffer) { + __m128i madds[2]; + madds[0] = _mm_maddubs_epi16(s[0], filter[0]); + madds[1] = _mm_maddubs_epi16(s[1], filter[1]); + const __m128i s_3x128 = + _mm_slli_epi16(_mm_srli_epi16(s[0], 8), 7 - kInterRoundBitsHorizontal); + WienerHorizontalClip(madds, s_3x128, wiener_buffer); +} + +// loading all and unpacking is about 7% faster than using _mm_alignr_epi8(). +inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int coefficient0, + const __m128i coefficients, + int16_t** const wiener_buffer) { + const __m128i round = _mm_set1_epi8(1 << (kInterRoundBitsHorizontal - 1)); + __m128i filter[4]; + filter[0] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0200)); + filter[1] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0604)); + filter[2] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0204)); + filter[3] = _mm_set1_epi16((1 << 8) | static_cast<uint8_t>(coefficient0)); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m128i s[7], ss[4]; + s[0] = LoadUnaligned16(src + x + 0); + s[1] = LoadUnaligned16(src + x + 1); + s[2] = LoadUnaligned16(src + x + 2); + s[3] = LoadUnaligned16(src + x + 3); + s[4] = LoadUnaligned16(src + x + 4); + s[5] = LoadUnaligned16(src + x + 5); + s[6] = LoadUnaligned16(src + x + 6); + ss[0] = _mm_unpacklo_epi8(s[0], s[1]); + ss[1] = _mm_unpacklo_epi8(s[2], s[3]); + ss[2] = _mm_unpacklo_epi8(s[4], s[5]); + ss[3] = _mm_unpacklo_epi8(s[6], round); + WienerHorizontalTap7Kernel(ss, filter, *wiener_buffer + x + 0); + ss[0] = _mm_unpackhi_epi8(s[0], s[1]); + ss[1] = _mm_unpackhi_epi8(s[2], s[3]); + ss[2] = _mm_unpackhi_epi8(s[4], s[5]); + ss[3] = _mm_unpackhi_epi8(s[6], round); + WienerHorizontalTap7Kernel(ss, filter, *wiener_buffer + x + 8); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int coefficient1, + const __m128i coefficients, + int16_t** const wiener_buffer) { + const __m128i round = _mm_set1_epi8(1 << (kInterRoundBitsHorizontal - 1)); + __m128i filter[3]; + filter[0] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0402)); + filter[1] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0406)); + filter[2] = _mm_set1_epi16((1 << 8) | static_cast<uint8_t>(coefficient1)); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m128i s[5], ss[3]; + s[0] = LoadUnaligned16(src + x + 0); + s[1] = LoadUnaligned16(src + x + 1); + s[2] = LoadUnaligned16(src + x + 2); + s[3] = LoadUnaligned16(src + x + 3); + s[4] = LoadUnaligned16(src + x + 4); + ss[0] = _mm_unpacklo_epi8(s[0], s[1]); + ss[1] = _mm_unpacklo_epi8(s[2], s[3]); + ss[2] = _mm_unpacklo_epi8(s[4], round); + WienerHorizontalTap5Kernel(ss, filter, *wiener_buffer + x + 0); + ss[0] = _mm_unpackhi_epi8(s[0], s[1]); + ss[1] = _mm_unpackhi_epi8(s[2], s[3]); + ss[2] = _mm_unpackhi_epi8(s[4], round); + WienerHorizontalTap5Kernel(ss, filter, *wiener_buffer + x + 8); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + const int coefficient2, + const __m128i coefficients, + int16_t** const wiener_buffer) { + const __m128i round = _mm_set1_epi8(1 << (kInterRoundBitsHorizontal - 1)); + __m128i filter[2]; + filter[0] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0604)); + filter[1] = _mm_set1_epi16((1 << 8) | static_cast<uint8_t>(coefficient2)); + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + __m128i s[3], ss[2]; + s[0] = LoadUnaligned16(src + x + 0); + s[1] = LoadUnaligned16(src + x + 1); + s[2] = LoadUnaligned16(src + x + 2); + ss[0] = _mm_unpacklo_epi8(s[0], s[1]); + ss[1] = _mm_unpacklo_epi8(s[2], round); + WienerHorizontalTap3Kernel(ss, filter, *wiener_buffer + x + 0); + ss[0] = _mm_unpackhi_epi8(s[0], s[1]); + ss[1] = _mm_unpackhi_epi8(s[2], round); + WienerHorizontalTap3Kernel(ss, filter, *wiener_buffer + x + 8); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const int height, + int16_t** const wiener_buffer) { + for (int y = height; y != 0; --y) { + ptrdiff_t x = 0; + do { + const __m128i s = LoadUnaligned16(src + x); + const __m128i s0 = _mm_unpacklo_epi8(s, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi8(s, _mm_setzero_si128()); + const __m128i d0 = _mm_slli_epi16(s0, 4); + const __m128i d1 = _mm_slli_epi16(s1, 4); + StoreAligned16(*wiener_buffer + x + 0, d0); + StoreAligned16(*wiener_buffer + x + 8, d1); + x += 16; + } while (x < width); + src += src_stride; + *wiener_buffer += width; + } +} + +inline __m128i WienerVertical7(const __m128i a[2], const __m128i filter[2]) { + const __m128i round = _mm_set1_epi32(1 << (kInterRoundBitsVertical - 1)); + const __m128i madd0 = _mm_madd_epi16(a[0], filter[0]); + const __m128i madd1 = _mm_madd_epi16(a[1], filter[1]); + const __m128i sum0 = _mm_add_epi32(round, madd0); + const __m128i sum1 = _mm_add_epi32(sum0, madd1); + return _mm_srai_epi32(sum1, kInterRoundBitsVertical); +} + +inline __m128i WienerVertical5(const __m128i a[2], const __m128i filter[2]) { + const __m128i madd0 = _mm_madd_epi16(a[0], filter[0]); + const __m128i madd1 = _mm_madd_epi16(a[1], filter[1]); + const __m128i sum = _mm_add_epi32(madd0, madd1); + return _mm_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m128i WienerVertical3(const __m128i a, const __m128i filter) { + const __m128i round = _mm_set1_epi32(1 << (kInterRoundBitsVertical - 1)); + const __m128i madd = _mm_madd_epi16(a, filter); + const __m128i sum = _mm_add_epi32(round, madd); + return _mm_srai_epi32(sum, kInterRoundBitsVertical); +} + +inline __m128i WienerVerticalFilter7(const __m128i a[7], + const __m128i filter[2]) { + __m128i b[2]; + const __m128i a06 = _mm_add_epi16(a[0], a[6]); + const __m128i a15 = _mm_add_epi16(a[1], a[5]); + const __m128i a24 = _mm_add_epi16(a[2], a[4]); + b[0] = _mm_unpacklo_epi16(a06, a15); + b[1] = _mm_unpacklo_epi16(a24, a[3]); + const __m128i sum0 = WienerVertical7(b, filter); + b[0] = _mm_unpackhi_epi16(a06, a15); + b[1] = _mm_unpackhi_epi16(a24, a[3]); + const __m128i sum1 = WienerVertical7(b, filter); + return _mm_packs_epi32(sum0, sum1); +} + +inline __m128i WienerVerticalFilter5(const __m128i a[5], + const __m128i filter[2]) { + const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsVertical - 1)); + __m128i b[2]; + const __m128i a04 = _mm_add_epi16(a[0], a[4]); + const __m128i a13 = _mm_add_epi16(a[1], a[3]); + b[0] = _mm_unpacklo_epi16(a04, a13); + b[1] = _mm_unpacklo_epi16(a[2], round); + const __m128i sum0 = WienerVertical5(b, filter); + b[0] = _mm_unpackhi_epi16(a04, a13); + b[1] = _mm_unpackhi_epi16(a[2], round); + const __m128i sum1 = WienerVertical5(b, filter); + return _mm_packs_epi32(sum0, sum1); +} + +inline __m128i WienerVerticalFilter3(const __m128i a[3], const __m128i filter) { + __m128i b; + const __m128i a02 = _mm_add_epi16(a[0], a[2]); + b = _mm_unpacklo_epi16(a02, a[1]); + const __m128i sum0 = WienerVertical3(b, filter); + b = _mm_unpackhi_epi16(a02, a[1]); + const __m128i sum1 = WienerVertical3(b, filter); + return _mm_packs_epi32(sum0, sum1); +} + +inline __m128i WienerVerticalTap7Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter[2], __m128i a[7]) { + a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned16(wiener_buffer + 4 * wiener_stride); + a[5] = LoadAligned16(wiener_buffer + 5 * wiener_stride); + a[6] = LoadAligned16(wiener_buffer + 6 * wiener_stride); + return WienerVerticalFilter7(a, filter); +} + +inline __m128i WienerVerticalTap5Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter[2], __m128i a[5]) { + a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride); + a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride); + a[4] = LoadAligned16(wiener_buffer + 4 * wiener_stride); + return WienerVerticalFilter5(a, filter); +} + +inline __m128i WienerVerticalTap3Kernel(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter, __m128i a[3]) { + a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride); + a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride); + a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride); + return WienerVerticalFilter3(a, filter); +} + +inline void WienerVerticalTap7Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter[2], __m128i d[2]) { + __m128i a[8]; + d[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a); + a[7] = LoadAligned16(wiener_buffer + 7 * wiener_stride); + d[1] = WienerVerticalFilter7(a + 1, filter); +} + +inline void WienerVerticalTap5Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter[2], __m128i d[2]) { + __m128i a[6]; + d[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a); + a[5] = LoadAligned16(wiener_buffer + 5 * wiener_stride); + d[1] = WienerVerticalFilter5(a + 1, filter); +} + +inline void WienerVerticalTap3Kernel2(const int16_t* wiener_buffer, + const ptrdiff_t wiener_stride, + const __m128i filter, __m128i d[2]) { + __m128i a[4]; + d[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a); + a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride); + d[1] = WienerVerticalFilter3(a + 1, filter); +} + +inline void WienerVerticalTap7(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[4], uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m128i c = LoadLo8(coefficients); + __m128i filter[2]; + filter[0] = _mm_shuffle_epi32(c, 0x0); + filter[1] = _mm_shuffle_epi32(c, 0x55); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m128i d[2][2]; + WienerVerticalTap7Kernel2(wiener_buffer + x + 0, width, filter, d[0]); + WienerVerticalTap7Kernel2(wiener_buffer + x + 8, width, filter, d[1]); + StoreAligned16(dst + x, _mm_packus_epi16(d[0][0], d[1][0])); + StoreAligned16(dst + dst_stride + x, _mm_packus_epi16(d[0][1], d[1][1])); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m128i a[7]; + const __m128i d0 = + WienerVerticalTap7Kernel(wiener_buffer + x + 0, width, filter, a); + const __m128i d1 = + WienerVerticalTap7Kernel(wiener_buffer + x + 8, width, filter, a); + StoreAligned16(dst + x, _mm_packus_epi16(d0, d1)); + x += 16; + } while (x < width); + } +} + +inline void WienerVerticalTap5(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[3], uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m128i c = Load4(coefficients); + __m128i filter[2]; + filter[0] = _mm_shuffle_epi32(c, 0); + filter[1] = + _mm_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[2])); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m128i d[2][2]; + WienerVerticalTap5Kernel2(wiener_buffer + x + 0, width, filter, d[0]); + WienerVerticalTap5Kernel2(wiener_buffer + x + 8, width, filter, d[1]); + StoreAligned16(dst + x, _mm_packus_epi16(d[0][0], d[1][0])); + StoreAligned16(dst + dst_stride + x, _mm_packus_epi16(d[0][1], d[1][1])); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m128i a[5]; + const __m128i d0 = + WienerVerticalTap5Kernel(wiener_buffer + x + 0, width, filter, a); + const __m128i d1 = + WienerVerticalTap5Kernel(wiener_buffer + x + 8, width, filter, a); + StoreAligned16(dst + x, _mm_packus_epi16(d0, d1)); + x += 16; + } while (x < width); + } +} + +inline void WienerVerticalTap3(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + const int16_t coefficients[2], uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m128i filter = + _mm_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients)); + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + __m128i d[2][2]; + WienerVerticalTap3Kernel2(wiener_buffer + x + 0, width, filter, d[0]); + WienerVerticalTap3Kernel2(wiener_buffer + x + 8, width, filter, d[1]); + StoreAligned16(dst + x, _mm_packus_epi16(d[0][0], d[1][0])); + StoreAligned16(dst + dst_stride + x, _mm_packus_epi16(d[0][1], d[1][1])); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + __m128i a[3]; + const __m128i d0 = + WienerVerticalTap3Kernel(wiener_buffer + x + 0, width, filter, a); + const __m128i d1 = + WienerVerticalTap3Kernel(wiener_buffer + x + 8, width, filter, a); + StoreAligned16(dst + x, _mm_packus_epi16(d0, d1)); + x += 16; + } while (x < width); + } +} + +inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer, + uint8_t* const dst) { + const __m128i a0 = LoadAligned16(wiener_buffer + 0); + const __m128i a1 = LoadAligned16(wiener_buffer + 8); + const __m128i b0 = _mm_add_epi16(a0, _mm_set1_epi16(8)); + const __m128i b1 = _mm_add_epi16(a1, _mm_set1_epi16(8)); + const __m128i c0 = _mm_srai_epi16(b0, 4); + const __m128i c1 = _mm_srai_epi16(b1, 4); + const __m128i d = _mm_packus_epi16(c0, c1); + StoreAligned16(dst, d); +} + +inline void WienerVerticalTap1(const int16_t* wiener_buffer, + const ptrdiff_t width, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + for (int y = height >> 1; y > 0; --y) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x); + x += 16; + } while (x < width); + dst += 2 * dst_stride; + wiener_buffer += 2 * width; + } + + if ((height & 1) != 0) { + ptrdiff_t x = 0; + do { + WienerVerticalTap1Kernel(wiener_buffer + x, dst + x); + x += 16; + } while (x < width); + } +} + +void WienerFilter_SSE4_1(const RestorationUnitInfo& restoration_info, + const void* const source, const void* const top_border, + const void* const bottom_border, + const ptrdiff_t stride, const int width, + const int height, + RestorationBuffer* const restoration_buffer, + void* const dest) { + const int16_t* const number_leading_zero_coefficients = + restoration_info.wiener_info.number_leading_zero_coefficients; + const int number_rows_to_skip = std::max( + static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]), + 1); + const ptrdiff_t wiener_stride = Align(width, 16); + int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer; + // The values are saturated to 13 bits before storing. + int16_t* wiener_buffer_horizontal = + wiener_buffer_vertical + number_rows_to_skip * wiener_stride; + + // horizontal filtering. + // Over-reads up to 15 - |kRestorationHorizontalBorder| values. + const int height_horizontal = + height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip; + const int height_extra = (height_horizontal - height) >> 1; + assert(height_extra <= 2); + const auto* const src = static_cast<const uint8_t*>(source); + const auto* const top = static_cast<const uint8_t*>(top_border); + const auto* const bottom = static_cast<const uint8_t*>(bottom_border); + const int16_t* const filter_horizontal = + restoration_info.wiener_info.filter[WienerInfo::kHorizontal]; + const __m128i c = LoadLo8(filter_horizontal); + // In order to keep the horizontal pass intermediate values within 16 bits we + // offset |filter[3]| by 128. The 128 offset will be added back in the loop. + const __m128i coefficients_horizontal = + _mm_sub_epi16(c, _mm_setr_epi16(0, 0, 0, 128, 0, 0, 0, 0)); + if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) { + WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride, + wiener_stride, height_extra, filter_horizontal[0], + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap7(src - 3, stride, wiener_stride, height, + filter_horizontal[0], coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra, + filter_horizontal[0], coefficients_horizontal, + &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) { + WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride, + wiener_stride, height_extra, filter_horizontal[1], + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap5(src - 2, stride, wiener_stride, height, + filter_horizontal[1], coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra, + filter_horizontal[1], coefficients_horizontal, + &wiener_buffer_horizontal); + } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) { + // The maximum over-reads happen here. + WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride, + wiener_stride, height_extra, filter_horizontal[2], + coefficients_horizontal, &wiener_buffer_horizontal); + WienerHorizontalTap3(src - 1, stride, wiener_stride, height, + filter_horizontal[2], coefficients_horizontal, + &wiener_buffer_horizontal); + WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra, + filter_horizontal[2], coefficients_horizontal, + &wiener_buffer_horizontal); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3); + WienerHorizontalTap1(top + (2 - height_extra) * stride, stride, + wiener_stride, height_extra, + &wiener_buffer_horizontal); + WienerHorizontalTap1(src, stride, wiener_stride, height, + &wiener_buffer_horizontal); + WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra, + &wiener_buffer_horizontal); + } + + // vertical filtering. + // Over-writes up to 15 values. + const int16_t* const filter_vertical = + restoration_info.wiener_info.filter[WienerInfo::kVertical]; + auto* dst = static_cast<uint8_t*>(dest); + if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) { + // Because the top row of |source| is a duplicate of the second row, and the + // bottom row of |source| is a duplicate of its above row, we can duplicate + // the top and bottom row of |wiener_buffer| accordingly. + memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride, + sizeof(*wiener_buffer_horizontal) * wiener_stride); + memcpy(restoration_buffer->wiener_buffer, + restoration_buffer->wiener_buffer + wiener_stride, + sizeof(*restoration_buffer->wiener_buffer) * wiener_stride); + WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height, + filter_vertical, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) { + WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride, + height, filter_vertical + 1, dst, stride); + } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) { + WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride, + wiener_stride, height, filter_vertical + 2, dst, stride); + } else { + assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3); + WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride, + wiener_stride, height, dst, stride); + } +} + +//------------------------------------------------------------------------------ +// SGR + +// SIMD overreads 16 - (width % 16) - 2 * padding pixels, where padding is 3 for +// Pass 1 and 2 for Pass 2. +constexpr int kOverreadInBytesPass1 = 10; +constexpr int kOverreadInBytesPass2 = 12; + +inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x, + __m128i dst[2]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); +} + +inline void LoadAligned16x2U16Msan(const uint16_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[2]) { + dst[0] = LoadAligned16Msan(src[0] + x, sizeof(**src) * (x + 8 - border)); + dst[1] = LoadAligned16Msan(src[1] + x, sizeof(**src) * (x + 8 - border)); +} + +inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x, + __m128i dst[3]) { + dst[0] = LoadAligned16(src[0] + x); + dst[1] = LoadAligned16(src[1] + x); + dst[2] = LoadAligned16(src[2] + x); +} + +inline void LoadAligned16x3U16Msan(const uint16_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[3]) { + dst[0] = LoadAligned16Msan(src[0] + x, sizeof(**src) * (x + 8 - border)); + dst[1] = LoadAligned16Msan(src[1] + x, sizeof(**src) * (x + 8 - border)); + dst[2] = LoadAligned16Msan(src[2] + x, sizeof(**src) * (x + 8 - border)); +} + +inline void LoadAligned32U32(const uint32_t* const src, __m128i dst[2]) { + dst[0] = LoadAligned16(src + 0); + dst[1] = LoadAligned16(src + 4); +} + +inline void LoadAligned32U32Msan(const uint32_t* const src, const ptrdiff_t x, + const ptrdiff_t border, __m128i dst[2]) { + dst[0] = LoadAligned16Msan(src + x + 0, sizeof(*src) * (x + 4 - border)); + dst[1] = LoadAligned16Msan(src + x + 4, sizeof(*src) * (x + 8 - border)); +} + +inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x, + __m128i dst[2][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); +} + +inline void LoadAligned32x2U32Msan(const uint32_t* const src[2], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[2][2]) { + LoadAligned32U32Msan(src[0], x, border, dst[0]); + LoadAligned32U32Msan(src[1], x, border, dst[1]); +} + +inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x, + __m128i dst[3][2]) { + LoadAligned32U32(src[0] + x, dst[0]); + LoadAligned32U32(src[1] + x, dst[1]); + LoadAligned32U32(src[2] + x, dst[2]); +} + +inline void LoadAligned32x3U32Msan(const uint32_t* const src[3], + const ptrdiff_t x, const ptrdiff_t border, + __m128i dst[3][2]) { + LoadAligned32U32Msan(src[0], x, border, dst[0]); + LoadAligned32U32Msan(src[1], x, border, dst[1]); + LoadAligned32U32Msan(src[2], x, border, dst[2]); +} + +inline void StoreAligned32U16(uint16_t* const dst, const __m128i src[2]) { + StoreAligned16(dst + 0, src[0]); + StoreAligned16(dst + 8, src[1]); +} + +inline void StoreAligned32U32(uint32_t* const dst, const __m128i src[2]) { + StoreAligned16(dst + 0, src[0]); + StoreAligned16(dst + 4, src[1]); +} + +inline void StoreAligned64U32(uint32_t* const dst, const __m128i src[4]) { + StoreAligned32U32(dst + 0, src + 0); + StoreAligned32U32(dst + 8, src + 2); +} + +// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following +// functions. Some compilers may generate super inefficient code and the whole +// decoder could be 15% slower. + +inline __m128i VaddlLo8(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi8(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(s0, s1); +} + +inline __m128i VaddlHi8(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi8(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(s0, s1); +} + +inline __m128i VaddlLo16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(s0, s1); +} + +inline __m128i VaddlHi16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(s0, s1); +} + +inline __m128i VaddwLo8(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(src0, s1); +} + +inline __m128i VaddwHi8(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpackhi_epi8(src1, _mm_setzero_si128()); + return _mm_add_epi16(src0, s1); +} + +inline __m128i VaddwLo16(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(src0, s1); +} + +inline __m128i VaddwHi16(const __m128i src0, const __m128i src1) { + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_add_epi32(src0, s1); +} + +inline __m128i VmullNLo8(const __m128i src0, const int src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + return _mm_madd_epi16(s0, _mm_set1_epi32(src1)); +} + +inline __m128i VmullNHi8(const __m128i src0, const int src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + return _mm_madd_epi16(s0, _mm_set1_epi32(src1)); +} + +inline __m128i VmullLo16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m128i VmullHi16(const __m128i src0, const __m128i src1) { + const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128()); + const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128()); + return _mm_madd_epi16(s0, s1); +} + +inline __m128i VrshrS32(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1))); + return _mm_srai_epi32(sum, src1); +} + +inline __m128i VrshrU32(const __m128i src0, const int src1) { + const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1))); + return _mm_srli_epi32(sum, src1); +} + +inline __m128i SquareLo8(const __m128i src) { + const __m128i s = _mm_unpacklo_epi8(src, _mm_setzero_si128()); + return _mm_mullo_epi16(s, s); +} + +inline __m128i SquareHi8(const __m128i src) { + const __m128i s = _mm_unpackhi_epi8(src, _mm_setzero_si128()); + return _mm_mullo_epi16(s, s); +} + +inline void Prepare3Lo8(const __m128i src, __m128i dst[3]) { + dst[0] = src; + dst[1] = _mm_srli_si128(src, 1); + dst[2] = _mm_srli_si128(src, 2); +} + +template <int offset> +inline void Prepare3_8(const __m128i src[2], __m128i dst[3]) { + dst[0] = _mm_alignr_epi8(src[1], src[0], offset + 0); + dst[1] = _mm_alignr_epi8(src[1], src[0], offset + 1); + dst[2] = _mm_alignr_epi8(src[1], src[0], offset + 2); +} + +inline void Prepare3_16(const __m128i src[2], __m128i dst[3]) { + dst[0] = src[0]; + dst[1] = _mm_alignr_epi8(src[1], src[0], 2); + dst[2] = _mm_alignr_epi8(src[1], src[0], 4); +} + +inline void Prepare5Lo8(const __m128i src, __m128i dst[5]) { + dst[0] = src; + dst[1] = _mm_srli_si128(src, 1); + dst[2] = _mm_srli_si128(src, 2); + dst[3] = _mm_srli_si128(src, 3); + dst[4] = _mm_srli_si128(src, 4); +} + +template <int offset> +inline void Prepare5_8(const __m128i src[2], __m128i dst[5]) { + dst[0] = _mm_alignr_epi8(src[1], src[0], offset + 0); + dst[1] = _mm_alignr_epi8(src[1], src[0], offset + 1); + dst[2] = _mm_alignr_epi8(src[1], src[0], offset + 2); + dst[3] = _mm_alignr_epi8(src[1], src[0], offset + 3); + dst[4] = _mm_alignr_epi8(src[1], src[0], offset + 4); +} + +inline void Prepare5_16(const __m128i src[2], __m128i dst[5]) { + Prepare3_16(src, dst); + dst[3] = _mm_alignr_epi8(src[1], src[0], 6); + dst[4] = _mm_alignr_epi8(src[1], src[0], 8); +} + +inline __m128i Sum3_16(const __m128i src0, const __m128i src1, + const __m128i src2) { + const __m128i sum = _mm_add_epi16(src0, src1); + return _mm_add_epi16(sum, src2); +} + +inline __m128i Sum3_16(const __m128i src[3]) { + return Sum3_16(src[0], src[1], src[2]); +} + +inline __m128i Sum3_32(const __m128i src0, const __m128i src1, + const __m128i src2) { + const __m128i sum = _mm_add_epi32(src0, src1); + return _mm_add_epi32(sum, src2); +} + +inline void Sum3_32(const __m128i src[3][2], __m128i dst[2]) { + dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]); + dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]); +} + +inline __m128i Sum3WLo16(const __m128i src[3]) { + const __m128i sum = VaddlLo8(src[0], src[1]); + return VaddwLo8(sum, src[2]); +} + +inline __m128i Sum3WHi16(const __m128i src[3]) { + const __m128i sum = VaddlHi8(src[0], src[1]); + return VaddwHi8(sum, src[2]); +} + +inline __m128i Sum3WLo32(const __m128i src[3]) { + const __m128i sum = VaddlLo16(src[0], src[1]); + return VaddwLo16(sum, src[2]); +} + +inline __m128i Sum3WHi32(const __m128i src[3]) { + const __m128i sum = VaddlHi16(src[0], src[1]); + return VaddwHi16(sum, src[2]); +} + +inline __m128i Sum5_16(const __m128i src[5]) { + const __m128i sum01 = _mm_add_epi16(src[0], src[1]); + const __m128i sum23 = _mm_add_epi16(src[2], src[3]); + const __m128i sum = _mm_add_epi16(sum01, sum23); + return _mm_add_epi16(sum, src[4]); +} + +inline __m128i Sum5_32(const __m128i* const src0, const __m128i* const src1, + const __m128i* const src2, const __m128i* const src3, + const __m128i* const src4) { + const __m128i sum01 = _mm_add_epi32(*src0, *src1); + const __m128i sum23 = _mm_add_epi32(*src2, *src3); + const __m128i sum = _mm_add_epi32(sum01, sum23); + return _mm_add_epi32(sum, *src4); +} + +inline void Sum5_32(const __m128i src[5][2], __m128i dst[2]) { + dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]); + dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]); +} + +inline __m128i Sum5WLo16(const __m128i src[5]) { + const __m128i sum01 = VaddlLo8(src[0], src[1]); + const __m128i sum23 = VaddlLo8(src[2], src[3]); + const __m128i sum = _mm_add_epi16(sum01, sum23); + return VaddwLo8(sum, src[4]); +} + +inline __m128i Sum5WHi16(const __m128i src[5]) { + const __m128i sum01 = VaddlHi8(src[0], src[1]); + const __m128i sum23 = VaddlHi8(src[2], src[3]); + const __m128i sum = _mm_add_epi16(sum01, sum23); + return VaddwHi8(sum, src[4]); +} + +inline __m128i Sum3Horizontal(const __m128i src) { + __m128i s[3]; + Prepare3Lo8(src, s); + return Sum3WLo16(s); +} + +template <int offset> +inline void Sum3Horizontal(const __m128i src[2], __m128i dst[2]) { + __m128i s[3]; + Prepare3_8<offset>(src, s); + dst[0] = Sum3WLo16(s); + dst[1] = Sum3WHi16(s); +} + +inline void Sum3WHorizontal(const __m128i src[2], __m128i dst[2]) { + __m128i s[3]; + Prepare3_16(src, s); + dst[0] = Sum3WLo32(s); + dst[1] = Sum3WHi32(s); +} + +inline __m128i Sum5Horizontal(const __m128i src) { + __m128i s[5]; + Prepare5Lo8(src, s); + return Sum5WLo16(s); +} + +template <int offset> +inline void Sum5Horizontal(const __m128i src[2], __m128i* const dst0, + __m128i* const dst1) { + __m128i s[5]; + Prepare5_8<offset>(src, s); + *dst0 = Sum5WLo16(s); + *dst1 = Sum5WHi16(s); +} + +inline void Sum5WHorizontal(const __m128i src[2], __m128i dst[2]) { + __m128i s[5]; + Prepare5_16(src, s); + const __m128i sum01_lo = VaddlLo16(s[0], s[1]); + const __m128i sum23_lo = VaddlLo16(s[2], s[3]); + const __m128i sum0123_lo = _mm_add_epi32(sum01_lo, sum23_lo); + dst[0] = VaddwLo16(sum0123_lo, s[4]); + const __m128i sum01_hi = VaddlHi16(s[0], s[1]); + const __m128i sum23_hi = VaddlHi16(s[2], s[3]); + const __m128i sum0123_hi = _mm_add_epi32(sum01_hi, sum23_hi); + dst[1] = VaddwHi16(sum0123_hi, s[4]); +} + +void SumHorizontalLo(const __m128i src[5], __m128i* const row_sq3, + __m128i* const row_sq5) { + const __m128i sum04 = VaddlLo16(src[0], src[4]); + *row_sq3 = Sum3WLo32(src + 1); + *row_sq5 = _mm_add_epi32(sum04, *row_sq3); +} + +void SumHorizontalHi(const __m128i src[5], __m128i* const row_sq3, + __m128i* const row_sq5) { + const __m128i sum04 = VaddlHi16(src[0], src[4]); + *row_sq3 = Sum3WHi32(src + 1); + *row_sq5 = _mm_add_epi32(sum04, *row_sq3); +} + +void SumHorizontalLo(const __m128i src, __m128i* const row3, + __m128i* const row5) { + __m128i s[5]; + Prepare5Lo8(src, s); + const __m128i sum04 = VaddlLo8(s[0], s[4]); + *row3 = Sum3WLo16(s + 1); + *row5 = _mm_add_epi16(sum04, *row3); +} + +template <int offset> +void SumHorizontal(const __m128i src[2], __m128i* const row3_0, + __m128i* const row3_1, __m128i* const row5_0, + __m128i* const row5_1) { + __m128i s[5]; + Prepare5_8<offset>(src, s); + const __m128i sum04_lo = VaddlLo8(s[0], s[4]); + const __m128i sum04_hi = VaddlHi8(s[0], s[4]); + *row3_0 = Sum3WLo16(s + 1); + *row3_1 = Sum3WHi16(s + 1); + *row5_0 = _mm_add_epi16(sum04_lo, *row3_0); + *row5_1 = _mm_add_epi16(sum04_hi, *row3_1); +} + +inline void SumHorizontal(const __m128i src[2], __m128i* const row_sq3_0, + __m128i* const row_sq3_1, __m128i* const row_sq5_0, + __m128i* const row_sq5_1) { + __m128i s[5]; + Prepare5_16(src, s); + SumHorizontalLo(s, row_sq3_0, row_sq5_0); + SumHorizontalHi(s, row_sq3_1, row_sq5_1); +} + +inline __m128i Sum343Lo(const __m128i ma3[3]) { + const __m128i sum = Sum3WLo16(ma3); + const __m128i sum3 = Sum3_16(sum, sum, sum); + return VaddwLo8(sum3, ma3[1]); +} + +inline __m128i Sum343Hi(const __m128i ma3[3]) { + const __m128i sum = Sum3WHi16(ma3); + const __m128i sum3 = Sum3_16(sum, sum, sum); + return VaddwHi8(sum3, ma3[1]); +} + +inline __m128i Sum343WLo(const __m128i src[3]) { + const __m128i sum = Sum3WLo32(src); + const __m128i sum3 = Sum3_32(sum, sum, sum); + return VaddwLo16(sum3, src[1]); +} + +inline __m128i Sum343WHi(const __m128i src[3]) { + const __m128i sum = Sum3WHi32(src); + const __m128i sum3 = Sum3_32(sum, sum, sum); + return VaddwHi16(sum3, src[1]); +} + +inline void Sum343W(const __m128i src[2], __m128i dst[2]) { + __m128i s[3]; + Prepare3_16(src, s); + dst[0] = Sum343WLo(s); + dst[1] = Sum343WHi(s); +} + +inline __m128i Sum565Lo(const __m128i src[3]) { + const __m128i sum = Sum3WLo16(src); + const __m128i sum4 = _mm_slli_epi16(sum, 2); + const __m128i sum5 = _mm_add_epi16(sum4, sum); + return VaddwLo8(sum5, src[1]); +} + +inline __m128i Sum565Hi(const __m128i src[3]) { + const __m128i sum = Sum3WHi16(src); + const __m128i sum4 = _mm_slli_epi16(sum, 2); + const __m128i sum5 = _mm_add_epi16(sum4, sum); + return VaddwHi8(sum5, src[1]); +} + +inline __m128i Sum565WLo(const __m128i src[3]) { + const __m128i sum = Sum3WLo32(src); + const __m128i sum4 = _mm_slli_epi32(sum, 2); + const __m128i sum5 = _mm_add_epi32(sum4, sum); + return VaddwLo16(sum5, src[1]); +} + +inline __m128i Sum565WHi(const __m128i src[3]) { + const __m128i sum = Sum3WHi32(src); + const __m128i sum4 = _mm_slli_epi32(sum, 2); + const __m128i sum5 = _mm_add_epi32(sum4, sum); + return VaddwHi16(sum5, src[1]); +} + +inline void Sum565W(const __m128i src[2], __m128i dst[2]) { + __m128i s[3]; + Prepare3_16(src, s); + dst[0] = Sum565WLo(s); + dst[1] = Sum565WHi(s); +} + +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5, + uint32_t* square_sum3, uint32_t* square_sum5) { + int y = 2; + do { + __m128i s[2], sq[3]; + s[0] = LoadUnaligned16Msan(src, kOverreadInBytesPass1 - width); + sq[0] = SquareLo8(s[0]); + ptrdiff_t x = sum_width; + do { + __m128i row3[2], row5[2], row_sq3[2], row_sq5[2]; + x -= 16; + src += 16; + s[1] = LoadUnaligned16Msan(src, + sum_width - x + kOverreadInBytesPass1 - width); + sq[1] = SquareHi8(s[0]); + sq[2] = SquareLo8(s[1]); + SumHorizontal<0>(s, &row3[0], &row3[1], &row5[0], &row5[1]); + StoreAligned32U16(sum3, row3); + StoreAligned32U16(sum5, row5); + SumHorizontal(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]); + StoreAligned32U32(square_sum3 + 0, row_sq3); + StoreAligned32U32(square_sum5 + 0, row_sq5); + SumHorizontal(sq + 1, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]); + StoreAligned32U32(square_sum3 + 8, row_sq3); + StoreAligned32U32(square_sum5 + 8, row_sq5); + s[0] = s[1]; + sq[0] = sq[2]; + sum3 += 16; + sum5 += 16; + square_sum3 += 16; + square_sum5 += 16; + } while (x != 0); + src += src_stride - sum_width; + sum3 += sum_stride - sum_width; + sum5 += sum_stride - sum_width; + square_sum3 += sum_stride - sum_width; + square_sum5 += sum_stride - sum_width; + } while (--y != 0); +} + +template <int size> +inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride, + const ptrdiff_t width, const ptrdiff_t sum_stride, + const ptrdiff_t sum_width, uint16_t* sums, + uint32_t* square_sums) { + static_assert(size == 3 || size == 5, ""); + constexpr int kOverreadInBytes = + (size == 5) ? kOverreadInBytesPass1 : kOverreadInBytesPass2; + int y = 2; + do { + __m128i s[2], sq[3]; + s[0] = LoadUnaligned16Msan(src, kOverreadInBytes - width); + sq[0] = SquareLo8(s[0]); + ptrdiff_t x = sum_width; + do { + __m128i row[2], row_sq[4]; + x -= 16; + src += 16; + s[1] = LoadUnaligned16Msan(src, sum_width - x + kOverreadInBytes - width); + sq[1] = SquareHi8(s[0]); + sq[2] = SquareLo8(s[1]); + if (size == 3) { + Sum3Horizontal<0>(s, row); + Sum3WHorizontal(sq + 0, row_sq + 0); + Sum3WHorizontal(sq + 1, row_sq + 2); + } else { + Sum5Horizontal<0>(s, &row[0], &row[1]); + Sum5WHorizontal(sq + 0, row_sq + 0); + Sum5WHorizontal(sq + 1, row_sq + 2); + } + StoreAligned32U16(sums, row); + StoreAligned64U32(square_sums, row_sq); + s[0] = s[1]; + sq[0] = sq[2]; + sums += 16; + square_sums += 16; + } while (x != 0); + src += src_stride - sum_width; + sums += sum_stride - sum_width; + square_sums += sum_stride - sum_width; + } while (--y != 0); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq, + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + // a = |sum_sq| + // d = |sum| + // p = (a * n < d * d) ? 0 : a * n - d * d; + const __m128i dxd = _mm_madd_epi16(sum, sum); + // _mm_mullo_epi32() has high latency. Using shifts and additions instead. + // Some compilers could do this for us but we make this explicit. + // return _mm_mullo_epi32(sum_sq, _mm_set1_epi32(n)); + __m128i axn = _mm_add_epi32(sum_sq, _mm_slli_epi32(sum_sq, 3)); + if (n == 25) axn = _mm_add_epi32(axn, _mm_slli_epi32(sum_sq, 4)); + const __m128i sub = _mm_sub_epi32(axn, dxd); + const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128()); + const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(scale)); + return VrshrU32(pxs, kSgrProjScaleBits); +} + +template <int n> +inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2], + const uint32_t scale) { + static_assert(n == 9 || n == 25, ""); + const __m128i sum_lo = _mm_unpacklo_epi16(sum, _mm_setzero_si128()); + const __m128i sum_hi = _mm_unpackhi_epi16(sum, _mm_setzero_si128()); + const __m128i z0 = CalculateMa<n>(sum_lo, sum_sq[0], scale); + const __m128i z1 = CalculateMa<n>(sum_hi, sum_sq[1], scale); + return _mm_packus_epi32(z0, z1); +} + +template <int n> +inline __m128i CalculateB(const __m128i sum, const __m128i ma) { + static_assert(n == 9 || n == 25, ""); + constexpr uint32_t one_over_n = + ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n; + const __m128i m0 = VmullLo16(ma, sum); + const __m128i m1 = VmullHi16(ma, sum); + const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n)); + const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n)); + const __m128i b_lo = VrshrU32(m2, kSgrProjReciprocalBits); + const __m128i b_hi = VrshrU32(m3, kSgrProjReciprocalBits); + return _mm_packus_epi32(b_lo, b_hi); +} + +inline void CalculateSumAndIndex5(const __m128i s5[5], const __m128i sq5[5][2], + const uint32_t scale, __m128i* const sum, + __m128i* const index) { + __m128i sum_sq[2]; + *sum = Sum5_16(s5); + Sum5_32(sq5, sum_sq); + *index = CalculateMa<25>(*sum, sum_sq, scale); +} + +inline void CalculateSumAndIndex3(const __m128i s3[3], const __m128i sq3[3][2], + const uint32_t scale, __m128i* const sum, + __m128i* const index) { + __m128i sum_sq[2]; + *sum = Sum3_16(s3); + Sum3_32(sq3, sum_sq); + *index = CalculateMa<9>(*sum, sum_sq, scale); +} + +template <int n, int offset> +inline void LookupIntermediate(const __m128i sum, const __m128i index, + __m128i* const ma, __m128i* const b) { + static_assert(n == 9 || n == 25, ""); + static_assert(offset == 0 || offset == 8, ""); + const __m128i idx = _mm_packus_epi16(index, index); + // Actually it's not stored and loaded. The compiler will use a 64-bit + // general-purpose register to process. Faster than using _mm_extract_epi8(). + uint8_t temp[8]; + StoreLo8(temp, idx); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[0]], offset + 0); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[1]], offset + 1); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[2]], offset + 2); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[3]], offset + 3); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[4]], offset + 4); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[5]], offset + 5); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[6]], offset + 6); + *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[7]], offset + 7); + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + __m128i maq; + if (offset == 0) { + maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); + } else { + maq = _mm_unpackhi_epi8(*ma, _mm_setzero_si128()); + } + *b = CalculateB<n>(sum, maq); +} + +// Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b +// to get value 0 as the shuffle result. The most significiant bit 1 comes +// either from the comparision instruction, or from the sign bit of the index. +inline __m128i ShuffleIndex(const __m128i table, const __m128i index) { + __m128i mask; + mask = _mm_cmpgt_epi8(index, _mm_set1_epi8(15)); + mask = _mm_or_si128(mask, index); + return _mm_shuffle_epi8(table, mask); +} + +inline __m128i AdjustValue(const __m128i value, const __m128i index, + const int threshold) { + const __m128i thresholds = _mm_set1_epi8(threshold - 128); + const __m128i offset = _mm_cmpgt_epi8(index, thresholds); + return _mm_add_epi8(value, offset); +} + +inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], + __m128i* const ma, __m128i* const b0, + __m128i* const b1) { + // Use table lookup to read elements which indices are less than 48. + const __m128i c0 = LoadAligned16(kSgrMaLookup + 0 * 16); + const __m128i c1 = LoadAligned16(kSgrMaLookup + 1 * 16); + const __m128i c2 = LoadAligned16(kSgrMaLookup + 2 * 16); + const __m128i indices = _mm_packus_epi16(index[0], index[1]); + __m128i idx; + // Clip idx to 127 to apply signed comparision instructions. + idx = _mm_min_epu8(indices, _mm_set1_epi8(127)); + // All elements which indices are less than 48 are set to 0. + // Get shuffle results for indices in range [0, 15]. + *ma = ShuffleIndex(c0, idx); + // Get shuffle results for indices in range [16, 31]. + // Subtract 16 to utilize the sign bit of the index. + idx = _mm_sub_epi8(idx, _mm_set1_epi8(16)); + const __m128i res1 = ShuffleIndex(c1, idx); + // Use OR instruction to combine shuffle results together. + *ma = _mm_or_si128(*ma, res1); + // Get shuffle results for indices in range [32, 47]. + // Subtract 16 to utilize the sign bit of the index. + idx = _mm_sub_epi8(idx, _mm_set1_epi8(16)); + const __m128i res2 = ShuffleIndex(c2, idx); + *ma = _mm_or_si128(*ma, res2); + + // For elements which indices are larger than 47, since they seldom change + // values with the increase of the index, we use comparison and arithmetic + // operations to calculate their values. + // Add -128 to apply signed comparision instructions. + idx = _mm_add_epi8(indices, _mm_set1_epi8(-128)); + // Elements which indices are larger than 47 (with value 0) are set to 5. + *ma = _mm_max_epu8(*ma, _mm_set1_epi8(5)); + *ma = AdjustValue(*ma, idx, 55); // 55 is the last index which value is 5. + *ma = AdjustValue(*ma, idx, 72); // 72 is the last index which value is 4. + *ma = AdjustValue(*ma, idx, 101); // 101 is the last index which value is 3. + *ma = AdjustValue(*ma, idx, 169); // 169 is the last index which value is 2. + *ma = AdjustValue(*ma, idx, 254); // 254 is the last index which value is 1. + + // b = ma * b * one_over_n + // |ma| = [0, 255] + // |sum| is a box sum with radius 1 or 2. + // For the first pass radius is 2. Maximum value is 5x5x255 = 6375. + // For the second pass radius is 1. Maximum value is 3x3x255 = 2295. + // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n + // When radius is 2 |n| is 25. |one_over_n| is 164. + // When radius is 1 |n| is 9. |one_over_n| is 455. + // |kSgrProjReciprocalBits| is 12. + // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits). + // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits). + const __m128i maq0 = _mm_unpacklo_epi8(*ma, _mm_setzero_si128()); + *b0 = CalculateB<9>(sum[0], maq0); + const __m128i maq1 = _mm_unpackhi_epi8(*ma, _mm_setzero_si128()); + *b1 = CalculateB<9>(sum[1], maq1); +} + +inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2], + __m128i ma[2], __m128i b[2]) { + __m128i mas; + CalculateIntermediate(sum, index, &mas, &b[0], &b[1]); + ma[0] = _mm_unpacklo_epi64(ma[0], mas); + ma[1] = _mm_srli_si128(mas, 8); +} + +// Note: It has been tried to call CalculateIntermediate() to replace the slow +// LookupIntermediate() when calculating 16 intermediate data points. However, +// the compiler generates even slower code. +template <int offset> +inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2], + const uint32_t scale, __m128i* const ma, + __m128i* const b) { + static_assert(offset == 0 || offset == 8, ""); + __m128i sum, index; + CalculateSumAndIndex5(s5, sq5, scale, &sum, &index); + LookupIntermediate<25, offset>(sum, index, ma, b); +} + +inline void CalculateIntermediate3(const __m128i s3[3], const __m128i sq3[3][2], + const uint32_t scale, __m128i* const ma, + __m128i* const b) { + __m128i sum, index; + CalculateSumAndIndex3(s3, sq3, scale, &sum, &index); + LookupIntermediate<9, 0>(sum, index, ma, b); +} + +inline void Store343_444(const __m128i b3[2], const ptrdiff_t x, + __m128i sum_b343[2], __m128i sum_b444[2], + uint32_t* const b343, uint32_t* const b444) { + __m128i b[3], sum_b111[2]; + Prepare3_16(b3, b); + sum_b111[0] = Sum3WLo32(b); + sum_b111[1] = Sum3WHi32(b); + sum_b444[0] = _mm_slli_epi32(sum_b111[0], 2); + sum_b444[1] = _mm_slli_epi32(sum_b111[1], 2); + StoreAligned32U32(b444 + x, sum_b444); + sum_b343[0] = _mm_sub_epi32(sum_b444[0], sum_b111[0]); + sum_b343[1] = _mm_sub_epi32(sum_b444[1], sum_b111[1]); + sum_b343[0] = VaddwLo16(sum_b343[0], b[1]); + sum_b343[1] = VaddwHi16(sum_b343[1], b[1]); + StoreAligned32U32(b343 + x, sum_b343); +} + +inline void Store343_444Lo(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i* const sum_ma444, __m128i sum_b343[2], + __m128i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m128i sum_ma111 = Sum3WLo16(ma3); + *sum_ma444 = _mm_slli_epi16(sum_ma111, 2); + StoreAligned16(ma444 + x, *sum_ma444); + const __m128i sum333 = _mm_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwLo8(sum333, ma3[1]); + StoreAligned16(ma343 + x, *sum_ma343); + Store343_444(b3, x, sum_b343, sum_b444, b343, b444); +} + +inline void Store343_444Hi(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i* const sum_ma444, __m128i sum_b343[2], + __m128i sum_b444[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + const __m128i sum_ma111 = Sum3WHi16(ma3); + *sum_ma444 = _mm_slli_epi16(sum_ma111, 2); + StoreAligned16(ma444 + x, *sum_ma444); + const __m128i sum333 = _mm_sub_epi16(*sum_ma444, sum_ma111); + *sum_ma343 = VaddwHi8(sum333, ma3[1]); + StoreAligned16(ma343 + x, *sum_ma343); + Store343_444(b3, x, sum_b343, sum_b444, b343, b444); +} + +inline void Store343_444Lo(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma444, sum_b444[2]; + Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Hi(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, __m128i* const sum_ma343, + __m128i sum_b343[2], uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma444, sum_b444[2]; + Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343, + ma444, b343, b444); +} + +inline void Store343_444Lo(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma343, sum_b343[2]; + Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +inline void Store343_444Hi(const __m128i ma3[3], const __m128i b3[2], + const ptrdiff_t x, uint16_t* const ma343, + uint16_t* const ma444, uint32_t* const b343, + uint32_t* const b444) { + __m128i sum_ma343, sum_b343[2]; + Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo( + const __m128i s[2][2], const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m128i sq[2][4], __m128i* const ma, + __m128i* const b) { + __m128i s5[2][5], sq5[5][2]; + sq[0][1] = SquareHi8(s[0][0]); + sq[1][1] = SquareHi8(s[1][0]); + s5[0][3] = Sum5Horizontal(s[0][0]); + StoreAligned16(sum5[3], s5[0][3]); + s5[0][4] = Sum5Horizontal(s[1][0]); + StoreAligned16(sum5[4], s5[0][4]); + Sum5WHorizontal(sq[0], sq5[3]); + StoreAligned32U32(square_sum5[3], sq5[3]); + Sum5WHorizontal(sq[1], sq5[4]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x3U16(sum5, 0, s5[0]); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5<0>(s5[0], sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5( + const __m128i s[2][2], const ptrdiff_t sum_width, const ptrdiff_t x, + const uint32_t scale, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], __m128i sq[2][4], __m128i ma[2], + __m128i b[3]) { + __m128i s5[2][5], sq5[5][2]; + sq[0][2] = SquareLo8(s[0][1]); + sq[1][2] = SquareLo8(s[1][1]); + Sum5Horizontal<8>(s[0], &s5[0][3], &s5[1][3]); + StoreAligned16(sum5[3] + x + 0, s5[0][3]); + StoreAligned16(sum5[3] + x + 8, s5[1][3]); + Sum5Horizontal<8>(s[1], &s5[0][4], &s5[1][4]); + StoreAligned16(sum5[4] + x + 0, s5[0][4]); + StoreAligned16(sum5[4] + x + 8, s5[1][4]); + Sum5WHorizontal(sq[0] + 1, sq5[3]); + StoreAligned32U32(square_sum5[3] + x, sq5[3]); + Sum5WHorizontal(sq[1] + 1, sq5[4]); + StoreAligned32U32(square_sum5[4] + x, sq5[4]); + LoadAligned16x3U16(sum5, x, s5[0]); + LoadAligned32x3U32(square_sum5, x, sq5); + CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[1]); + + sq[0][3] = SquareHi8(s[0][1]); + sq[1][3] = SquareHi8(s[1][1]); + Sum5WHorizontal(sq[0] + 2, sq5[3]); + StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]); + Sum5WHorizontal(sq[1] + 2, sq5[4]); + StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]); + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[2]); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo( + const __m128i s, const uint32_t scale, const uint16_t* const sum5[5], + const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma, + __m128i* const b) { + __m128i s5[5], sq5[5][2]; + sq[1] = SquareHi8(s); + s5[3] = s5[4] = Sum5Horizontal(s); + Sum5WHorizontal(sq, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateIntermediate5<0>(s5, sq5, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow( + const __m128i s[2], const ptrdiff_t sum_width, const ptrdiff_t x, + const uint32_t scale, const uint16_t* const sum5[5], + const uint32_t* const square_sum5[5], __m128i sq[4], __m128i ma[2], + __m128i b[3]) { + __m128i s5[2][5], sq5[5][2]; + sq[2] = SquareLo8(s[1]); + Sum5Horizontal<8>(s, &s5[0][3], &s5[1][3]); + s5[0][4] = s5[0][3]; + s5[1][4] = s5[1][3]; + Sum5WHorizontal(sq + 1, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16(sum5, x, s5[0]); + LoadAligned32x3U32(square_sum5, x, sq5); + CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[1]); + + sq[3] = SquareHi8(s[1]); + Sum5WHorizontal(sq + 2, sq5[3]); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[2]); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo( + const __m128i s, const uint32_t scale, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], __m128i sq[2], __m128i* const ma, + __m128i* const b) { + __m128i s3[3], sq3[3][2]; + sq[1] = SquareHi8(s); + s3[2] = Sum3Horizontal(s); + StoreAligned16(sum3[2], s3[2]); + Sum3WHorizontal(sq, sq3[2]); + StoreAligned32U32(square_sum3[2], sq3[2]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scale, ma, b); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3( + const __m128i s[2], const ptrdiff_t x, const ptrdiff_t sum_width, + const uint32_t scale, uint16_t* const sum3[3], + uint32_t* const square_sum3[3], __m128i sq[4], __m128i ma[2], + __m128i b[3]) { + __m128i s3[4], sq3[3][2], sum[2], index[2]; + sq[2] = SquareLo8(s[1]); + Sum3Horizontal<8>(s, s3 + 2); + StoreAligned32U16(sum3[2] + x, s3 + 2); + Sum3WHorizontal(sq + 1, sq3[2]); + StoreAligned32U32(square_sum3[2] + x + 0, sq3[2]); + LoadAligned16x2U16(sum3, x, s3); + LoadAligned32x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]); + + sq[3] = SquareHi8(s[1]); + Sum3WHorizontal(sq + 2, sq3[2]); + StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]); + LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3 + 1); + LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3); + CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]); + CalculateIntermediate(sum, index, ma, b + 1); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo( + const __m128i s[2][2], const uint16_t scales[2], uint16_t* const sum3[4], + uint16_t* const sum5[5], uint32_t* const square_sum3[4], + uint32_t* const square_sum5[5], __m128i sq[2][4], __m128i ma3[2][2], + __m128i b3[2][3], __m128i* const ma5, __m128i* const b5) { + __m128i s3[4], s5[5], sq3[4][2], sq5[5][2], sum[2], index[2]; + sq[0][1] = SquareHi8(s[0][0]); + sq[1][1] = SquareHi8(s[1][0]); + SumHorizontalLo(s[0][0], &s3[2], &s5[3]); + SumHorizontalLo(s[1][0], &s3[3], &s5[4]); + StoreAligned16(sum3[2], s3[2]); + StoreAligned16(sum3[3], s3[3]); + StoreAligned16(sum5[3], s5[3]); + StoreAligned16(sum5[4], s5[4]); + SumHorizontal(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2], sq3[2]); + StoreAligned32U32(square_sum5[3], sq5[3]); + SumHorizontal(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3], sq3[3]); + StoreAligned32U32(square_sum5[4], sq5[4]); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + LoadAligned16x3U16(sum5, 0, s5); + LoadAligned32x3U32(square_sum5, 0, sq5); + CalculateSumAndIndex3(s3 + 0, sq3 + 0, scales[1], &sum[0], &index[0]); + CalculateSumAndIndex3(s3 + 1, sq3 + 1, scales[1], &sum[1], &index[1]); + CalculateIntermediate(sum, index, &ma3[0][0], &b3[0][0], &b3[1][0]); + ma3[1][0] = _mm_srli_si128(ma3[0][0], 8); + CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess( + const __m128i s[2][2], const ptrdiff_t x, const uint16_t scales[2], + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, __m128i sq[2][4], __m128i ma3[2][2], + __m128i b3[2][3], __m128i ma5[2], __m128i b5[3]) { + __m128i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum[2][2], index[2][2]; + SumHorizontal<8>(s[0], &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]); + StoreAligned16(sum3[2] + x + 0, s3[0][2]); + StoreAligned16(sum3[2] + x + 8, s3[1][2]); + StoreAligned16(sum5[3] + x + 0, s5[0][3]); + StoreAligned16(sum5[3] + x + 8, s5[1][3]); + SumHorizontal<8>(s[1], &s3[0][3], &s3[1][3], &s5[0][4], &s5[1][4]); + StoreAligned16(sum3[3] + x + 0, s3[0][3]); + StoreAligned16(sum3[3] + x + 8, s3[1][3]); + StoreAligned16(sum5[4] + x + 0, s5[0][4]); + StoreAligned16(sum5[4] + x + 8, s5[1][4]); + sq[0][2] = SquareLo8(s[0][1]); + sq[1][2] = SquareLo8(s[1][1]); + SumHorizontal(sq[0] + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2] + x, sq3[2]); + StoreAligned32U32(square_sum5[3] + x, sq5[3]); + SumHorizontal(sq[1] + 1, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3] + x, sq3[3]); + StoreAligned32U32(square_sum5[4] + x, sq5[4]); + LoadAligned16x2U16(sum3, x, s3[0]); + LoadAligned32x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0][0], &index[0][0]); + CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum[1][0], + &index[1][0]); + LoadAligned16x3U16(sum5, x, s5[0]); + LoadAligned32x3U32(square_sum5, x, sq5); + CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], &b5[1]); + + sq[0][3] = SquareHi8(s[0][1]); + sq[1][3] = SquareHi8(s[1][1]); + SumHorizontal(sq[0] + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]); + StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]); + SumHorizontal(sq[1] + 2, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]); + StoreAligned32U32(square_sum3[3] + x + 8, sq3[3]); + StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]); + LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]); + LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[0][1], &index[0][1]); + CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum[1][1], + &index[1][1]); + CalculateIntermediate(sum[0], index[0], ma3[0], b3[0] + 1); + CalculateIntermediate(sum[1], index[1], ma3[1], b3[1] + 1); + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], &b5[2]); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo( + const __m128i s, const uint16_t scales[2], const uint16_t* const sum3[4], + const uint16_t* const sum5[5], const uint32_t* const square_sum3[4], + const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma3, + __m128i* const ma5, __m128i* const b3, __m128i* const b5) { + __m128i s3[3], s5[5], sq3[3][2], sq5[5][2]; + sq[1] = SquareHi8(s); + SumHorizontalLo(s, &s3[2], &s5[3]); + SumHorizontal(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16(sum5, 0, s5); + s5[4] = s5[3]; + LoadAligned32x3U32(square_sum5, 0, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5); + LoadAligned16x2U16(sum3, 0, s3); + LoadAligned32x2U32(square_sum3, 0, sq3); + CalculateIntermediate3(s3, sq3, scales[1], ma3, b3); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow( + const __m128i s[2], const ptrdiff_t sum_width, const ptrdiff_t x, + const uint16_t scales[2], const uint16_t* const sum3[4], + const uint16_t* const sum5[5], const uint32_t* const square_sum3[4], + const uint32_t* const square_sum5[5], __m128i sq[4], __m128i ma3[2], + __m128i ma5[2], __m128i b3[3], __m128i b5[3]) { + __m128i s3[2][3], s5[2][5], sq3[3][2], sq5[5][2], sum[2], index[2]; + sq[2] = SquareLo8(s[1]); + SumHorizontal<8>(s, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]); + SumHorizontal(sq + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16(sum5, x, s5[0]); + s5[0][4] = s5[0][3]; + LoadAligned32x3U32(square_sum5, x, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5<8>(s5[0], sq5, scales[0], ma5, b5 + 1); + LoadAligned16x2U16(sum3, x, s3[0]); + LoadAligned32x2U32(square_sum3, x, sq3); + CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0], &index[0]); + + sq[3] = SquareHi8(s[1]); + SumHorizontal(sq + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]); + LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]); + s5[1][4] = s5[1][3]; + LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5); + sq5[4][0] = sq5[3][0]; + sq5[4][1] = sq5[3][1]; + CalculateIntermediate5<0>(s5[1], sq5, scales[0], ma5 + 1, b5 + 2); + LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]); + LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3); + CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[1], &index[1]); + CalculateIntermediate(sum, index, ma3, b3 + 1); +} + +inline void BoxSumFilterPreProcess5(const uint8_t* const src0, + const uint8_t* const src1, const int width, + const uint32_t scale, + uint16_t* const sum5[5], + uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* ma565, + uint32_t* b565) { + __m128i s[2][2], mas[2], sq[2][4], bs[3]; + s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); + s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1 - width); + sq[0][0] = SquareLo8(s[0][0]); + sq[1][0] = SquareLo8(s[1][0]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], &bs[0]); + + int x = 0; + do { + __m128i ma5[3], ma[2], b[4]; + s[0][1] = LoadUnaligned16Msan(src0 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + s[1][1] = LoadUnaligned16Msan(src1 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas, + bs); + Prepare3_8<0>(mas, ma5); + ma[0] = Sum565Lo(ma5); + ma[1] = Sum565Hi(ma5); + StoreAligned32U16(ma565, ma); + Sum565W(bs + 0, b + 0); + Sum565W(bs + 1, b + 2); + StoreAligned64U32(b565, b); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + ma565 += 16; + b565 += 16; + x += 16; + } while (x < width); +} + +template <bool calculate444> +LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3( + const uint8_t* const src, const int width, const uint32_t scale, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343, + uint32_t* b444) { + __m128i s[2], mas[2], sq[4], bs[3]; + s[0] = LoadUnaligned16Msan(src, kOverreadInBytesPass2 - width); + sq[0] = SquareLo8(s[0]); + BoxFilterPreProcess3Lo(s[0], scale, sum3, square_sum3, sq, &mas[0], &bs[0]); + + int x = 0; + do { + s[1] = LoadUnaligned16Msan(src + x + 16, + x + 16 + kOverreadInBytesPass2 - width); + BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas, + bs); + __m128i ma3[3]; + Prepare3_8<0>(mas, ma3); + if (calculate444) { // NOLINT(readability-simplify-boolean-expr) + Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444); + Store343_444Hi(ma3, bs + 1, 8, ma343, ma444, b343, b444); + ma444 += 16; + b444 += 16; + } else { + __m128i ma[2], b[4]; + ma[0] = Sum343Lo(ma3); + ma[1] = Sum343Hi(ma3); + StoreAligned32U16(ma343, ma); + Sum343W(bs + 0, b + 0); + Sum343W(bs + 1, b + 2); + StoreAligned64U32(b343, b); + } + s[0] = s[1]; + sq[1] = sq[3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + ma343 += 16; + b343 += 16; + x += 16; + } while (x < width); +} + +inline void BoxSumFilterPreProcess( + const uint8_t* const src0, const uint8_t* const src1, const int width, + const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], + uint16_t* const ma444[2], uint16_t* ma565, uint32_t* const b343[4], + uint32_t* const b444[2], uint32_t* b565) { + __m128i s[2][2], ma3[2][2], ma5[2], sq[2][4], b3[2][3], b5[3]; + s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); + s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1 - width); + sq[0][0] = SquareLo8(s[0][0]); + sq[1][0] = SquareLo8(s[1][0]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq, + ma3, b3, &ma5[0], &b5[0]); + + int x = 0; + do { + __m128i ma[2], b[4], ma3x[3], ma5x[3]; + s[0][1] = LoadUnaligned16Msan(src0 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + s[1][1] = LoadUnaligned16Msan(src1 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5, + sum_width, sq, ma3, b3, ma5, b5); + + Prepare3_8<0>(ma3[0], ma3x); + ma[0] = Sum343Lo(ma3x); + ma[1] = Sum343Hi(ma3x); + StoreAligned32U16(ma343[0] + x, ma); + Sum343W(b3[0] + 0, b + 0); + Sum343W(b3[0] + 1, b + 2); + StoreAligned64U32(b343[0] + x, b); + Sum565W(b5 + 0, b + 0); + Sum565W(b5 + 1, b + 2); + StoreAligned64U32(b565, b); + Prepare3_8<0>(ma3[1], ma3x); + Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444[0], b343[1], b444[0]); + Store343_444Hi(ma3x, b3[1] + 1, x + 8, ma343[1], ma444[0], b343[1], + b444[0]); + Prepare3_8<0>(ma5, ma5x); + ma[0] = Sum565Lo(ma5x); + ma[1] = Sum565Hi(ma5x); + StoreAligned32U16(ma565, ma); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + ma3[0][0] = ma3[0][1]; + ma3[1][0] = ma3[1][1]; + ma5[0] = ma5[1]; + b3[0][0] = b3[0][2]; + b3[1][0] = b3[1][2]; + b5[0] = b5[2]; + ma565 += 16; + b565 += 16; + x += 16; + } while (x < width); +} + +template <int shift> +inline __m128i FilterOutput(const __m128i ma_x_src, const __m128i b) { + // ma: 255 * 32 = 8160 (13 bits) + // b: 65088 * 32 = 2082816 (21 bits) + // v: b - ma * 255 (22 bits) + const __m128i v = _mm_sub_epi32(b, ma_x_src); + // kSgrProjSgrBits = 8 + // kSgrProjRestoreBits = 4 + // shift = 4 or 5 + // v >> 8 or 9 (13 bits) + return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits); +} + +template <int shift> +inline __m128i CalculateFilteredOutput(const __m128i src, const __m128i ma, + const __m128i b[2]) { + const __m128i ma_x_src_lo = VmullLo16(ma, src); + const __m128i ma_x_src_hi = VmullHi16(ma, src); + const __m128i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]); + const __m128i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]); + return _mm_packs_epi32(dst_lo, dst_hi); // 13 bits +} + +inline __m128i CalculateFilteredOutputPass1(const __m128i src, __m128i ma[2], + __m128i b[2][2]) { + const __m128i ma_sum = _mm_add_epi16(ma[0], ma[1]); + __m128i b_sum[2]; + b_sum[0] = _mm_add_epi32(b[0][0], b[1][0]); + b_sum[1] = _mm_add_epi32(b[0][1], b[1][1]); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m128i CalculateFilteredOutputPass2(const __m128i src, __m128i ma[3], + __m128i b[3][2]) { + const __m128i ma_sum = Sum3_16(ma); + __m128i b_sum[2]; + Sum3_32(b, b_sum); + return CalculateFilteredOutput<5>(src, ma_sum, b_sum); +} + +inline __m128i SelfGuidedFinal(const __m128i src, const __m128i v[2]) { + const __m128i v_lo = + VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m128i v_hi = + VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits); + const __m128i vv = _mm_packs_epi32(v_lo, v_hi); + return _mm_add_epi16(src, vv); +} + +inline __m128i SelfGuidedDoubleMultiplier(const __m128i src, + const __m128i filter[2], const int w0, + const int w2) { + __m128i v[2]; + const __m128i w0_w2 = _mm_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0)); + const __m128i f_lo = _mm_unpacklo_epi16(filter[0], filter[1]); + const __m128i f_hi = _mm_unpackhi_epi16(filter[0], filter[1]); + v[0] = _mm_madd_epi16(w0_w2, f_lo); + v[1] = _mm_madd_epi16(w0_w2, f_hi); + return SelfGuidedFinal(src, v); +} + +inline __m128i SelfGuidedSingleMultiplier(const __m128i src, + const __m128i filter, const int w0) { + // weight: -96 to 96 (Sgrproj_Xqd_Min/Max) + __m128i v[2]; + v[0] = VmullNLo8(filter, w0); + v[1] = VmullNHi8(filter, w0); + return SelfGuidedFinal(src, v); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass1( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5], + uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width, + const uint32_t scale, const int16_t w0, uint16_t* const ma565[2], + uint32_t* const b565[2], uint8_t* const dst) { + __m128i s[2][2], mas[2], sq[2][4], bs[3]; + s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); + s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1 - width); + sq[0][0] = SquareLo8(s[0][0]); + sq[1][0] = SquareLo8(s[1][0]); + BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], &bs[0]); + + int x = 0; + do { + __m128i ma[2], ma3[3], b[2][2], sr[2], p[2]; + s[0][1] = LoadUnaligned16Msan(src0 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + s[1][1] = LoadUnaligned16Msan(src1 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas, + bs); + Prepare3_8<0>(mas, ma3); + ma[1] = Sum565Lo(ma3); + StoreAligned16(ma565[1] + x, ma[1]); + Sum565W(bs, b[1]); + StoreAligned32U32(b565[1] + x, b[1]); + sr[0] = LoadAligned16(src + x); + sr[1] = LoadAligned16(src + stride + x); + const __m128i sr0_lo = _mm_unpacklo_epi8(sr[0], _mm_setzero_si128()); + const __m128i sr1_lo = _mm_unpacklo_epi8(sr[1], _mm_setzero_si128()); + ma[0] = LoadAligned16(ma565[0] + x); + LoadAligned32U32(b565[0] + x, b[0]); + p[0] = CalculateFilteredOutputPass1(sr0_lo, ma, b); + p[1] = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[1]); + const __m128i d00 = SelfGuidedSingleMultiplier(sr0_lo, p[0], w0); + const __m128i d10 = SelfGuidedSingleMultiplier(sr1_lo, p[1], w0); + + ma[1] = Sum565Hi(ma3); + StoreAligned16(ma565[1] + x + 8, ma[1]); + Sum565W(bs + 1, b[1]); + StoreAligned32U32(b565[1] + x + 8, b[1]); + const __m128i sr0_hi = _mm_unpackhi_epi8(sr[0], _mm_setzero_si128()); + const __m128i sr1_hi = _mm_unpackhi_epi8(sr[1], _mm_setzero_si128()); + ma[0] = LoadAligned16(ma565[0] + x + 8); + LoadAligned32U32(b565[0] + x + 8, b[0]); + p[0] = CalculateFilteredOutputPass1(sr0_hi, ma, b); + p[1] = CalculateFilteredOutput<4>(sr1_hi, ma[1], b[1]); + const __m128i d01 = SelfGuidedSingleMultiplier(sr0_hi, p[0], w0); + StoreAligned16(dst + x, _mm_packus_epi16(d00, d01)); + const __m128i d11 = SelfGuidedSingleMultiplier(sr1_hi, p[1], w0); + StoreAligned16(dst + stride + x, _mm_packus_epi16(d10, d11)); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + x += 16; + } while (x < width); +} + +inline void BoxFilterPass1LastRow( + const uint8_t* const src, const uint8_t* const src0, const int width, + const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565, + uint32_t* b565, uint8_t* const dst) { + __m128i s[2], mas[2], sq[4], bs[3]; + s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); + sq[0] = SquareLo8(s[0]); + BoxFilterPreProcess5LastRowLo(s[0], scale, sum5, square_sum5, sq, &mas[0], + &bs[0]); + + int x = 0; + do { + __m128i ma[2], ma5[3], b[2][2]; + s[1] = LoadUnaligned16Msan(src0 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + BoxFilterPreProcess5LastRow(s, sum_width, x + 8, scale, sum5, square_sum5, + sq, mas, bs); + Prepare3_8<0>(mas, ma5); + ma[1] = Sum565Lo(ma5); + Sum565W(bs, b[1]); + ma[0] = LoadAligned16(ma565); + LoadAligned32U32(b565, b[0]); + const __m128i sr = LoadAligned16(src + x); + const __m128i sr_lo = _mm_unpacklo_epi8(sr, _mm_setzero_si128()); + __m128i p = CalculateFilteredOutputPass1(sr_lo, ma, b); + const __m128i d0 = SelfGuidedSingleMultiplier(sr_lo, p, w0); + + ma[1] = Sum565Hi(ma5); + Sum565W(bs + 1, b[1]); + ma[0] = LoadAligned16(ma565 + 8); + LoadAligned32U32(b565 + 8, b[0]); + const __m128i sr_hi = _mm_unpackhi_epi8(sr, _mm_setzero_si128()); + p = CalculateFilteredOutputPass1(sr_hi, ma, b); + const __m128i d1 = SelfGuidedSingleMultiplier(sr_hi, p, w0); + StoreAligned16(dst + x, _mm_packus_epi16(d0, d1)); + s[0] = s[1]; + sq[1] = sq[3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + ma565 += 16; + b565 += 16; + x += 16; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterPass2( + const uint8_t* const src, const uint8_t* const src0, const int width, + const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0, + uint16_t* const sum3[3], uint32_t* const square_sum3[3], + uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3], + uint32_t* const b444[2], uint8_t* const dst) { + __m128i s[2], mas[2], sq[4], bs[3]; + s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass2 - width); + sq[0] = SquareLo8(s[0]); + BoxFilterPreProcess3Lo(s[0], scale, sum3, square_sum3, sq, &mas[0], &bs[0]); + + int x = 0; + do { + s[1] = LoadUnaligned16Msan(src0 + x + 16, + x + 16 + kOverreadInBytesPass2 - width); + BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas, + bs); + __m128i ma[3], b[3][2], ma3[3]; + Prepare3_8<0>(mas, ma3); + Store343_444Lo(ma3, bs + 0, x, &ma[2], b[2], ma343[2], ma444[1], b343[2], + b444[1]); + const __m128i sr = LoadAligned16(src + x); + const __m128i sr_lo = _mm_unpacklo_epi8(sr, _mm_setzero_si128()); + ma[0] = LoadAligned16(ma343[0] + x); + ma[1] = LoadAligned16(ma444[0] + x); + LoadAligned32U32(b343[0] + x, b[0]); + LoadAligned32U32(b444[0] + x, b[1]); + const __m128i p0 = CalculateFilteredOutputPass2(sr_lo, ma, b); + + Store343_444Hi(ma3, bs + 1, x + 8, &ma[2], b[2], ma343[2], ma444[1], + b343[2], b444[1]); + const __m128i sr_hi = _mm_unpackhi_epi8(sr, _mm_setzero_si128()); + ma[0] = LoadAligned16(ma343[0] + x + 8); + ma[1] = LoadAligned16(ma444[0] + x + 8); + LoadAligned32U32(b343[0] + x + 8, b[0]); + LoadAligned32U32(b444[0] + x + 8, b[1]); + const __m128i p1 = CalculateFilteredOutputPass2(sr_hi, ma, b); + const __m128i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0); + const __m128i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0); + StoreAligned16(dst + x, _mm_packus_epi16(d0, d1)); + s[0] = s[1]; + sq[1] = sq[3]; + mas[0] = mas[1]; + bs[0] = bs[2]; + x += 16; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilter( + const uint8_t* const src, const uint8_t* const src0, + const uint8_t* const src1, const ptrdiff_t stride, const int width, + const uint16_t scales[2], const int16_t w0, const int16_t w2, + uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + const ptrdiff_t sum_width, uint16_t* const ma343[4], + uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4], + uint32_t* const b444[3], uint32_t* const b565[2], uint8_t* const dst) { + __m128i s[2][2], ma3[2][2], ma5[2], sq[2][4], b3[2][3], b5[3]; + s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); + s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1 - width); + sq[0][0] = SquareLo8(s[0][0]); + sq[1][0] = SquareLo8(s[1][0]); + BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq, + ma3, b3, &ma5[0], &b5[0]); + + int x = 0; + do { + __m128i ma[3][3], b[3][3][2], p[2][2], ma3x[2][3], ma5x[3]; + s[0][1] = LoadUnaligned16Msan(src0 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + s[1][1] = LoadUnaligned16Msan(src1 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5, + sum_width, sq, ma3, b3, ma5, b5); + Prepare3_8<0>(ma3[0], ma3x[0]); + Prepare3_8<0>(ma3[1], ma3x[1]); + Prepare3_8<0>(ma5, ma5x); + Store343_444Lo(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], b[1][2], b[2][1], + ma343[2], ma444[1], b343[2], b444[1]); + Store343_444Lo(ma3x[1], b3[1], x, &ma[2][2], b[2][2], ma343[3], ma444[2], + b343[3], b444[2]); + ma[0][1] = Sum565Lo(ma5x); + StoreAligned16(ma565[1] + x, ma[0][1]); + Sum565W(b5, b[0][1]); + StoreAligned32U32(b565[1] + x, b[0][1]); + const __m128i sr0 = LoadAligned16(src + x); + const __m128i sr1 = LoadAligned16(src + stride + x); + const __m128i sr0_lo = _mm_unpacklo_epi8(sr0, _mm_setzero_si128()); + const __m128i sr1_lo = _mm_unpacklo_epi8(sr1, _mm_setzero_si128()); + ma[0][0] = LoadAligned16(ma565[0] + x); + LoadAligned32U32(b565[0] + x, b[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]); + ma[1][0] = LoadAligned16(ma343[0] + x); + ma[1][1] = LoadAligned16(ma444[0] + x); + LoadAligned32U32(b343[0] + x, b[1][0]); + LoadAligned32U32(b444[0] + x, b[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]); + const __m128i d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2); + ma[2][0] = LoadAligned16(ma343[1] + x); + LoadAligned32U32(b343[1] + x, b[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]); + const __m128i d10 = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2); + + Store343_444Hi(ma3x[0], b3[0] + 1, x + 8, &ma[1][2], &ma[2][1], b[1][2], + b[2][1], ma343[2], ma444[1], b343[2], b444[1]); + Store343_444Hi(ma3x[1], b3[1] + 1, x + 8, &ma[2][2], b[2][2], ma343[3], + ma444[2], b343[3], b444[2]); + ma[0][1] = Sum565Hi(ma5x); + StoreAligned16(ma565[1] + x + 8, ma[0][1]); + Sum565W(b5 + 1, b[0][1]); + StoreAligned32U32(b565[1] + x + 8, b[0][1]); + const __m128i sr0_hi = _mm_unpackhi_epi8(sr0, _mm_setzero_si128()); + const __m128i sr1_hi = _mm_unpackhi_epi8(sr1, _mm_setzero_si128()); + ma[0][0] = LoadAligned16(ma565[0] + x + 8); + LoadAligned32U32(b565[0] + x + 8, b[0][0]); + p[0][0] = CalculateFilteredOutputPass1(sr0_hi, ma[0], b[0]); + p[1][0] = CalculateFilteredOutput<4>(sr1_hi, ma[0][1], b[0][1]); + ma[1][0] = LoadAligned16(ma343[0] + x + 8); + ma[1][1] = LoadAligned16(ma444[0] + x + 8); + LoadAligned32U32(b343[0] + x + 8, b[1][0]); + LoadAligned32U32(b444[0] + x + 8, b[1][1]); + p[0][1] = CalculateFilteredOutputPass2(sr0_hi, ma[1], b[1]); + const __m128i d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2); + StoreAligned16(dst + x, _mm_packus_epi16(d00, d01)); + ma[2][0] = LoadAligned16(ma343[1] + x + 8); + LoadAligned32U32(b343[1] + x + 8, b[2][0]); + p[1][1] = CalculateFilteredOutputPass2(sr1_hi, ma[2], b[2]); + const __m128i d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2); + StoreAligned16(dst + stride + x, _mm_packus_epi16(d10, d11)); + s[0][0] = s[0][1]; + s[1][0] = s[1][1]; + sq[0][1] = sq[0][3]; + sq[1][1] = sq[1][3]; + ma3[0][0] = ma3[0][1]; + ma3[1][0] = ma3[1][1]; + ma5[0] = ma5[1]; + b3[0][0] = b3[0][2]; + b3[1][0] = b3[1][2]; + b5[0] = b5[2]; + x += 16; + } while (x < width); +} + +inline void BoxFilterLastRow( + const uint8_t* const src, const uint8_t* const src0, const int width, + const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0, + const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5], + uint32_t* const square_sum3[4], uint32_t* const square_sum5[5], + uint16_t* const ma343[4], uint16_t* const ma444[3], + uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3], + uint32_t* const b565[2], uint8_t* const dst) { + __m128i s[2], ma3[2], ma5[2], sq[4], b3[3], b5[3], ma[3], b[3][2]; + s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1 - width); + sq[0] = SquareLo8(s[0]); + BoxFilterPreProcessLastRowLo(s[0], scales, sum3, sum5, square_sum3, + square_sum5, sq, &ma3[0], &ma5[0], &b3[0], + &b5[0]); + + int x = 0; + do { + __m128i ma3x[3], ma5x[3], p[2]; + s[1] = LoadUnaligned16Msan(src0 + x + 16, + x + 16 + kOverreadInBytesPass1 - width); + BoxFilterPreProcessLastRow(s, sum_width, x + 8, scales, sum3, sum5, + square_sum3, square_sum5, sq, ma3, ma5, b3, b5); + Prepare3_8<0>(ma3, ma3x); + Prepare3_8<0>(ma5, ma5x); + ma[1] = Sum565Lo(ma5x); + Sum565W(b5, b[1]); + ma[2] = Sum343Lo(ma3x); + Sum343W(b3, b[2]); + const __m128i sr = LoadAligned16(src + x); + const __m128i sr_lo = _mm_unpacklo_epi8(sr, _mm_setzero_si128()); + ma[0] = LoadAligned16(ma565[0] + x); + LoadAligned32U32(b565[0] + x, b[0]); + p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b); + ma[0] = LoadAligned16(ma343[0] + x); + ma[1] = LoadAligned16(ma444[0] + x); + LoadAligned32U32(b343[0] + x, b[0]); + LoadAligned32U32(b444[0] + x, b[1]); + p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b); + const __m128i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2); + + ma[1] = Sum565Hi(ma5x); + Sum565W(b5 + 1, b[1]); + ma[2] = Sum343Hi(ma3x); + Sum343W(b3 + 1, b[2]); + const __m128i sr_hi = _mm_unpackhi_epi8(sr, _mm_setzero_si128()); + ma[0] = LoadAligned16(ma565[0] + x + 8); + LoadAligned32U32(b565[0] + x + 8, b[0]); + p[0] = CalculateFilteredOutputPass1(sr_hi, ma, b); + ma[0] = LoadAligned16(ma343[0] + x + 8); + ma[1] = LoadAligned16(ma444[0] + x + 8); + LoadAligned32U32(b343[0] + x + 8, b[0]); + LoadAligned32U32(b444[0] + x + 8, b[1]); + p[1] = CalculateFilteredOutputPass2(sr_hi, ma, b); + const __m128i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2); + StoreAligned16(dst + x, _mm_packus_epi16(d0, d1)); + s[0] = s[1]; + sq[1] = sq[3]; + ma3[0] = ma3[1]; + ma5[0] = ma5[1]; + b3[0] = b3[2]; + b5[0] = b5[2]; + x += 16; + } while (x < width); +} + +LIBGAV1_ALWAYS_INLINE void BoxFilterProcess( + const RestorationUnitInfo& restoration_info, const uint8_t* src, + const uint8_t* const top_border, const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, const int height, + SgrBuffer* const sgr_buffer, uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 16); + const auto sum_width = Align<ptrdiff_t>(width + 8, 16); + const auto sum_stride = temp_stride + 16; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1; + uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2]; + uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 3; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma444[0] = sgr_buffer->ma444; + b444[0] = sgr_buffer->b444; + for (int i = 1; i <= 2; ++i) { + ma444[i] = ma444[i - 1] + temp_stride; + b444[i] = b444[i - 1] + temp_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scales[0] != 0); + assert(scales[1] != 0); + BoxSum(top_border, stride, width, sum_stride, sum_width, sum3[0], sum5[1], + square_sum3[0], square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3, + square_sum5, sum_width, ma343, ma444, ma565[0], b343, + b444, b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width, + scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width, + ma343, ma444, ma565, b343, b444, b565, dst); + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5, + square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343, + b444, b565, dst); + } + if ((height & 1) != 0) { + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + Circulate4PointersBy2<uint16_t>(sum3); + Circulate4PointersBy2<uint32_t>(square_sum3); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + Circulate4PointersBy2<uint16_t>(ma343); + Circulate4PointersBy2<uint32_t>(b343); + std::swap(ma444[0], ma444[2]); + std::swap(b444[0], b444[2]); + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + BoxFilterLastRow(src + 3, bottom_border + stride, width, sum_width, scales, + w0, w2, sum3, sum5, square_sum3, square_sum5, ma343, ma444, + ma565, b343, b444, b565, dst); + } +} + +inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + const auto temp_stride = Align<ptrdiff_t>(width, 16); + const auto sum_width = Align<ptrdiff_t>(width + 8, 16); + const auto sum_stride = temp_stride + 16; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12. + const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0]; + uint16_t *sum5[5], *ma565[2]; + uint32_t *square_sum5[5], *b565[2]; + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + for (int i = 1; i <= 4; ++i) { + sum5[i] = sum5[i - 1] + sum_stride; + square_sum5[i] = square_sum5[i - 1] + sum_stride; + } + ma565[0] = sgr_buffer->ma565; + ma565[1] = ma565[0] + temp_stride; + b565[0] = sgr_buffer->b565; + b565[1] = b565[0] + temp_stride; + assert(scale != 0); + BoxSum<5>(top_border, stride, width, sum_stride, sum_width, sum5[1], + square_sum5[1]); + sum5[0] = sum5[1]; + square_sum5[0] = square_sum5[1]; + const uint8_t* const s = (height > 1) ? src + stride : bottom_border; + BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width, + ma565[0], b565[0]); + sum5[0] = sgr_buffer->sum5; + square_sum5[0] = sgr_buffer->square_sum5; + + for (int y = (height >> 1) - 1; y > 0; --y) { + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5, + square_sum5, width, sum_width, scale, w0, ma565, b565, dst); + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + } + + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + if ((height & 1) == 0 || height > 1) { + const uint8_t* sr[2]; + if ((height & 1) == 0) { + sr[0] = bottom_border; + sr[1] = bottom_border + stride; + } else { + sr[0] = src + 2 * stride; + sr[1] = bottom_border; + } + BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width, + sum_width, scale, w0, ma565, b565, dst); + } + if ((height & 1) != 0) { + src += 3; + if (height > 1) { + src += 2 * stride; + dst += 2 * stride; + std::swap(ma565[0], ma565[1]); + std::swap(b565[0], b565[1]); + Circulate5PointersBy2<uint16_t>(sum5); + Circulate5PointersBy2<uint32_t>(square_sum5); + } + BoxFilterPass1LastRow(src, bottom_border + stride, width, sum_width, scale, + w0, sum5, square_sum5, ma565[0], b565[0], dst); + } +} + +inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info, + const uint8_t* src, + const uint8_t* const top_border, + const uint8_t* bottom_border, + const ptrdiff_t stride, const int width, + const int height, SgrBuffer* const sgr_buffer, + uint8_t* dst) { + assert(restoration_info.sgr_proj_info.multiplier[0] == 0); + const auto temp_stride = Align<ptrdiff_t>(width, 16); + const auto sum_width = Align<ptrdiff_t>(width + 8, 16); + const auto sum_stride = temp_stride + 16; + const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1]; + const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1; + const int sgr_proj_index = restoration_info.sgr_proj_info.index; + const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12. + uint16_t *sum3[3], *ma343[3], *ma444[2]; + uint32_t *square_sum3[3], *b343[3], *b444[2]; + sum3[0] = sgr_buffer->sum3; + square_sum3[0] = sgr_buffer->square_sum3; + ma343[0] = sgr_buffer->ma343; + b343[0] = sgr_buffer->b343; + for (int i = 1; i <= 2; ++i) { + sum3[i] = sum3[i - 1] + sum_stride; + square_sum3[i] = square_sum3[i - 1] + sum_stride; + ma343[i] = ma343[i - 1] + temp_stride; + b343[i] = b343[i - 1] + temp_stride; + } + ma444[0] = sgr_buffer->ma444; + ma444[1] = ma444[0] + temp_stride; + b444[0] = sgr_buffer->b444; + b444[1] = b444[0] + temp_stride; + assert(scale != 0); + BoxSum<3>(top_border, stride, width, sum_stride, sum_width, sum3[0], + square_sum3[0]); + BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, + sum_width, ma343[0], nullptr, b343[0], + nullptr); + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + const uint8_t* s; + if (height > 1) { + s = src + stride; + } else { + s = bottom_border; + bottom_border += stride; + } + BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width, + ma343[1], ma444[0], b343[1], b444[0]); + + for (int y = height - 2; y > 0; --y) { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } + + int y = std::min(height, 2); + src += 2; + do { + Circulate3PointersBy1<uint16_t>(sum3); + Circulate3PointersBy1<uint32_t>(square_sum3); + BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3, + square_sum3, ma343, ma444, b343, b444, dst); + src += stride; + dst += stride; + bottom_border += stride; + Circulate3PointersBy1<uint16_t>(ma343); + Circulate3PointersBy1<uint32_t>(b343); + std::swap(ma444[0], ma444[1]); + std::swap(b444[0], b444[1]); + } while (--y != 0); +} + +// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in +// the end of each row. It is safe to overwrite the output as it will not be +// part of the visible frame. +void SelfGuidedFilter_SSE4_1( + const RestorationUnitInfo& restoration_info, const void* const source, + const void* const top_border, const void* const bottom_border, + const ptrdiff_t stride, const int width, const int height, + RestorationBuffer* const restoration_buffer, void* const dest) { + const int index = restoration_info.sgr_proj_info.index; + const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0 + const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0 + const auto* const src = static_cast<const uint8_t*>(source); + const auto* top = static_cast<const uint8_t*>(top_border); + const auto* bottom = static_cast<const uint8_t*>(bottom_border); + auto* const dst = static_cast<uint8_t*>(dest); + SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer; + if (radius_pass_1 == 0) { + // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the + // following assertion. + assert(radius_pass_0 != 0); + BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3, + stride, width, height, sgr_buffer, dst); + } else if (radius_pass_0 == 0) { + BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2, + stride, width, height, sgr_buffer, dst); + } else { + BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride, + width, height, sgr_buffer, dst); + } +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if DSP_ENABLED_8BPP_SSE4_1(WienerFilter) + dsp->loop_restorations[0] = WienerFilter_SSE4_1; +#else + static_cast<void>(WienerFilter_SSE4_1); +#endif +#if DSP_ENABLED_8BPP_SSE4_1(SelfGuidedFilter) + dsp->loop_restorations[1] = SelfGuidedFilter_SSE4_1; +#else + static_cast<void>(SelfGuidedFilter_SSE4_1); +#endif +} + +} // namespace +} // namespace low_bitdepth + +void LoopRestorationInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void LoopRestorationInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/loop_restoration_sse4.h b/src/dsp/x86/loop_restoration_sse4.h new file mode 100644 index 0000000..65b2b11 --- /dev/null +++ b/src/dsp/x86/loop_restoration_sse4.h @@ -0,0 +1,52 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::loop_restorations, see the defines below for specifics. +// These functions are not thread-safe. +void LoopRestorationInit_SSE4_1(); +void LoopRestorationInit10bpp_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_WienerFilter +#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_SelfGuidedFilter +#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp10bpp_WienerFilter +#define LIBGAV1_Dsp10bpp_WienerFilter LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_SSE4_H_ diff --git a/src/dsp/x86/mask_blend_sse4.cc b/src/dsp/x86/mask_blend_sse4.cc new file mode 100644 index 0000000..d8036be --- /dev/null +++ b/src/dsp/x86/mask_blend_sse4.cc @@ -0,0 +1,447 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/mask_blend.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Width can only be 4 when it is subsampled from a block of width 8, hence +// subsampling_x is always 1 when this function is called. +template <int subsampling_x, int subsampling_y> +inline __m128i GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) { + if (subsampling_x == 1) { + const __m128i mask_val_0 = _mm_cvtepu8_epi16(LoadLo8(mask)); + const __m128i mask_val_1 = + _mm_cvtepu8_epi16(LoadLo8(mask + (mask_stride << subsampling_y))); + __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1); + if (subsampling_y == 1) { + const __m128i next_mask_val_0 = + _mm_cvtepu8_epi16(LoadLo8(mask + mask_stride)); + const __m128i next_mask_val_1 = + _mm_cvtepu8_epi16(LoadLo8(mask + mask_stride * 3)); + subsampled_mask = _mm_add_epi16( + subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1)); + } + return RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y); + } + const __m128i mask_val_0 = Load4(mask); + const __m128i mask_val_1 = Load4(mask + mask_stride); + return _mm_cvtepu8_epi16( + _mm_or_si128(mask_val_0, _mm_slli_si128(mask_val_1, 4))); +} + +// This function returns a 16-bit packed mask to fit in _mm_madd_epi16. +// 16-bit is also the lowest packing for hadd, but without subsampling there is +// an unfortunate conversion required. +template <int subsampling_x, int subsampling_y> +inline __m128i GetMask8(const uint8_t* mask, ptrdiff_t stride) { + if (subsampling_x == 1) { + const __m128i row_vals = LoadUnaligned16(mask); + + const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals); + const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8)); + __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1); + + if (subsampling_y == 1) { + const __m128i next_row_vals = LoadUnaligned16(mask + stride); + const __m128i next_mask_val_0 = _mm_cvtepu8_epi16(next_row_vals); + const __m128i next_mask_val_1 = + _mm_cvtepu8_epi16(_mm_srli_si128(next_row_vals, 8)); + subsampled_mask = _mm_add_epi16( + subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1)); + } + return RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y); + } + assert(subsampling_y == 0 && subsampling_x == 0); + const __m128i mask_val = LoadLo8(mask); + return _mm_cvtepu8_epi16(mask_val); +} + +// This version returns 8-bit packed values to fit in _mm_maddubs_epi16 because, +// when is_inter_intra is true, the prediction values are brought to 8-bit +// packing as well. +template <int subsampling_x, int subsampling_y> +inline __m128i GetInterIntraMask8(const uint8_t* mask, ptrdiff_t stride) { + if (subsampling_x == 1) { + const __m128i row_vals = LoadUnaligned16(mask); + + const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals); + const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8)); + __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1); + + if (subsampling_y == 1) { + const __m128i next_row_vals = LoadUnaligned16(mask + stride); + const __m128i next_mask_val_0 = _mm_cvtepu8_epi16(next_row_vals); + const __m128i next_mask_val_1 = + _mm_cvtepu8_epi16(_mm_srli_si128(next_row_vals, 8)); + subsampled_mask = _mm_add_epi16( + subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1)); + } + const __m128i ret = + RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y); + return _mm_packus_epi16(ret, ret); + } + assert(subsampling_y == 0 && subsampling_x == 0); + // Unfortunately there is no shift operation for 8-bit packing, or else we + // could return everything with 8-bit packing. + const __m128i mask_val = LoadLo8(mask); + return mask_val; +} + +inline void WriteMaskBlendLine4x2(const int16_t* const pred_0, + const int16_t* const pred_1, + const __m128i pred_mask_0, + const __m128i pred_mask_1, uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m128i pred_val_0 = LoadAligned16(pred_0); + const __m128i pred_val_1 = LoadAligned16(pred_1); + const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1); + const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1); + const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1); + const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1); + + // int res = (mask_value * prediction_0[x] + + // (64 - mask_value) * prediction_1[x]) >> 6; + const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo); + const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi); + const __m128i compound_pred = _mm_packus_epi32( + _mm_srli_epi32(compound_pred_lo, 6), _mm_srli_epi32(compound_pred_hi, 6)); + + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + const __m128i result = RightShiftWithRounding_S16(compound_pred, 4); + const __m128i res = _mm_packus_epi16(result, result); + Store4(dst, res); + Store4(dst + dst_stride, _mm_srli_si128(res, 4)); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlending4x4_SSE4(const int16_t* pred_0, const int16_t* pred_1, + const uint8_t* mask, + const ptrdiff_t mask_stride, uint8_t* dst, + const ptrdiff_t dst_stride) { + const __m128i mask_inverter = _mm_set1_epi16(64); + __m128i pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlending4xH_SSE4(const int16_t* pred_0, const int16_t* pred_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int height, + uint8_t* dst, const ptrdiff_t dst_stride) { + const uint8_t* mask = mask_ptr; + if (height == 4) { + MaskBlending4x4_SSE4<subsampling_x, subsampling_y>( + pred_0, pred_1, mask, mask_stride, dst, dst_stride); + return; + } + const __m128i mask_inverter = _mm_set1_epi16(64); + int y = 0; + do { + __m128i pred_mask_0 = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + + pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst, + dst_stride); + pred_0 += 4 << 1; + pred_1 += 4 << 1; + mask += mask_stride << (1 + subsampling_y); + dst += dst_stride << 1; + y += 8; + } while (y < height); +} + +template <int subsampling_x, int subsampling_y> +inline void MaskBlend_SSE4(const void* prediction_0, const void* prediction_1, + const ptrdiff_t /*prediction_stride_1*/, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int width, + const int height, void* dest, + const ptrdiff_t dst_stride) { + auto* dst = static_cast<uint8_t*>(dest); + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + const ptrdiff_t pred_stride_0 = width; + const ptrdiff_t pred_stride_1 = width; + if (width == 4) { + MaskBlending4xH_SSE4<subsampling_x, subsampling_y>( + pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride); + return; + } + const uint8_t* mask = mask_ptr; + const __m128i mask_inverter = _mm_set1_epi16(64); + int y = 0; + do { + int x = 0; + do { + const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride); + // 64 - mask + const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0); + const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1); + const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1); + + const __m128i pred_val_0 = LoadAligned16(pred_0 + x); + const __m128i pred_val_1 = LoadAligned16(pred_1 + x); + const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1); + const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1); + // int res = (mask_value * prediction_0[x] + + // (64 - mask_value) * prediction_1[x]) >> 6; + const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo); + const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi); + + const __m128i res = _mm_packus_epi32(_mm_srli_epi32(compound_pred_lo, 6), + _mm_srli_epi32(compound_pred_hi, 6)); + // dst[x] = static_cast<Pixel>( + // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0, + // (1 << kBitdepth8) - 1)); + const __m128i result = RightShiftWithRounding_S16(res, 4); + StoreLo8(dst + x, _mm_packus_epi16(result, result)); + + x += 8; + } while (x < width); + dst += dst_stride; + pred_0 += pred_stride_0; + pred_1 += pred_stride_1; + mask += mask_stride << subsampling_y; + } while (++y < height); +} + +inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0, + uint8_t* const pred_1, + const ptrdiff_t pred_stride_1, + const __m128i pred_mask_0, + const __m128i pred_mask_1) { + const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1); + + const __m128i pred_val_0 = LoadLo8(pred_0); + // TODO(b/150326556): One load. + __m128i pred_val_1 = Load4(pred_1); + pred_val_1 = _mm_or_si128(_mm_slli_si128(Load4(pred_1 + pred_stride_1), 4), + pred_val_1); + const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1); + // int res = (mask_value * prediction_1[x] + + // (64 - mask_value) * prediction_0[x]) >> 6; + const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask); + const __m128i result = RightShiftWithRounding_U16(compound_pred, 6); + const __m128i res = _mm_packus_epi16(result, result); + + Store4(pred_1, res); + Store4(pred_1 + pred_stride_1, _mm_srli_si128(res, 4)); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlending8bpp4x4_SSE4(const uint8_t* pred_0, + uint8_t* pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* mask, + const ptrdiff_t mask_stride) { + const __m128i mask_inverter = _mm_set1_epi8(64); + const __m128i pred_mask_u16_first = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + mask += mask_stride << (1 + subsampling_y); + const __m128i pred_mask_u16_second = + GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride); + mask += mask_stride << (1 + subsampling_y); + __m128i pred_mask_1 = + _mm_packus_epi16(pred_mask_u16_first, pred_mask_u16_second); + __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1); + InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1); + pred_0 += 4 << 1; + pred_1 += pred_stride_1 << 1; + + pred_mask_1 = _mm_srli_si128(pred_mask_1, 8); + pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1); + InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1, + pred_mask_0, pred_mask_1); +} + +template <int subsampling_x, int subsampling_y> +inline void InterIntraMaskBlending8bpp4xH_SSE4(const uint8_t* pred_0, + uint8_t* pred_1, + const ptrdiff_t pred_stride_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, + const int height) { + const uint8_t* mask = mask_ptr; + if (height == 4) { + InterIntraMaskBlending8bpp4x4_SSE4<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + return; + } + int y = 0; + do { + InterIntraMaskBlending8bpp4x4_SSE4<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + pred_0 += 4 << 2; + pred_1 += pred_stride_1 << 2; + mask += mask_stride << (2 + subsampling_y); + + InterIntraMaskBlending8bpp4x4_SSE4<subsampling_x, subsampling_y>( + pred_0, pred_1, pred_stride_1, mask, mask_stride); + pred_0 += 4 << 2; + pred_1 += pred_stride_1 << 2; + mask += mask_stride << (2 + subsampling_y); + y += 8; + } while (y < height); +} + +template <int subsampling_x, int subsampling_y> +void InterIntraMaskBlend8bpp_SSE4(const uint8_t* prediction_0, + uint8_t* prediction_1, + const ptrdiff_t prediction_stride_1, + const uint8_t* const mask_ptr, + const ptrdiff_t mask_stride, const int width, + const int height) { + if (width == 4) { + InterIntraMaskBlending8bpp4xH_SSE4<subsampling_x, subsampling_y>( + prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride, + height); + return; + } + const uint8_t* mask = mask_ptr; + const __m128i mask_inverter = _mm_set1_epi8(64); + int y = 0; + do { + int x = 0; + do { + const __m128i pred_mask_1 = + GetInterIntraMask8<subsampling_x, subsampling_y>( + mask + (x << subsampling_x), mask_stride); + // 64 - mask + const __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1); + const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1); + + const __m128i pred_val_0 = LoadLo8(prediction_0 + x); + const __m128i pred_val_1 = LoadLo8(prediction_1 + x); + const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1); + // int res = (mask_value * prediction_1[x] + + // (64 - mask_value) * prediction_0[x]) >> 6; + const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask); + const __m128i result = RightShiftWithRounding_U16(compound_pred, 6); + const __m128i res = _mm_packus_epi16(result, result); + + StoreLo8(prediction_1 + x, res); + + x += 8; + } while (x < width); + prediction_0 += width; + prediction_1 += prediction_stride_1; + mask += mask_stride << subsampling_y; + } while (++y < height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend444) + dsp->mask_blend[0][0] = MaskBlend_SSE4<0, 0>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend422) + dsp->mask_blend[1][0] = MaskBlend_SSE4<1, 0>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend420) + dsp->mask_blend[2][0] = MaskBlend_SSE4<1, 1>; +#endif + // The is_inter_intra index of mask_blend[][] is replaced by + // inter_intra_mask_blend_8bpp[] in 8-bit. +#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp444) + dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_SSE4<0, 0>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp422) + dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_SSE4<1, 0>; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp420) + dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_SSE4<1, 1>; +#endif +} + +} // namespace +} // namespace low_bitdepth + +void MaskBlendInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void MaskBlendInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/mask_blend_sse4.h b/src/dsp/x86/mask_blend_sse4.h new file mode 100644 index 0000000..52b0b5c --- /dev/null +++ b/src/dsp/x86/mask_blend_sse4.h @@ -0,0 +1,60 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_MASK_BLEND_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_MASK_BLEND_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mask_blend. This function is not thread-safe. +void MaskBlendInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_MaskBlend444 +#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_MaskBlend422 +#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_MaskBlend420 +#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 +#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_MASK_BLEND_SSE4_H_ diff --git a/src/dsp/x86/motion_field_projection_sse4.cc b/src/dsp/x86/motion_field_projection_sse4.cc new file mode 100644 index 0000000..c506941 --- /dev/null +++ b/src/dsp/x86/motion_field_projection_sse4.cc @@ -0,0 +1,397 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_field_projection.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +inline __m128i LoadDivision(const __m128i division_table, + const __m128i reference_offset) { + const __m128i kOne = _mm_set1_epi16(0x0100); + const __m128i t = _mm_add_epi8(reference_offset, reference_offset); + const __m128i tt = _mm_unpacklo_epi8(t, t); + const __m128i idx = _mm_add_epi8(tt, kOne); + return _mm_shuffle_epi8(division_table, idx); +} + +inline __m128i MvProjection(const __m128i mv, const __m128i denominator, + const int numerator) { + const __m128i m0 = _mm_madd_epi16(mv, denominator); + const __m128i m = _mm_mullo_epi32(m0, _mm_set1_epi32(numerator)); + // Add the sign (0 or -1) to round towards zero. + const __m128i sign = _mm_srai_epi32(m, 31); + const __m128i add_sign = _mm_add_epi32(m, sign); + const __m128i sum = _mm_add_epi32(add_sign, _mm_set1_epi32(1 << 13)); + return _mm_srai_epi32(sum, 14); +} + +inline __m128i MvProjectionClip(const __m128i mv, const __m128i denominator, + const int numerator) { + const __m128i mv0 = _mm_unpacklo_epi16(mv, _mm_setzero_si128()); + const __m128i mv1 = _mm_unpackhi_epi16(mv, _mm_setzero_si128()); + const __m128i denorm0 = _mm_unpacklo_epi16(denominator, _mm_setzero_si128()); + const __m128i denorm1 = _mm_unpackhi_epi16(denominator, _mm_setzero_si128()); + const __m128i s0 = MvProjection(mv0, denorm0, numerator); + const __m128i s1 = MvProjection(mv1, denorm1, numerator); + const __m128i projection = _mm_packs_epi32(s0, s1); + const __m128i projection_mv_clamp = _mm_set1_epi16(kProjectionMvClamp); + const __m128i projection_mv_clamp_negative = + _mm_set1_epi16(-kProjectionMvClamp); + const __m128i clamp = _mm_min_epi16(projection, projection_mv_clamp); + return _mm_max_epi16(clamp, projection_mv_clamp_negative); +} + +inline __m128i Project_SSE4_1(const __m128i delta, const __m128i dst_sign) { + // Add 63 to negative delta so that it shifts towards zero. + const __m128i delta_sign = _mm_srai_epi16(delta, 15); + const __m128i delta_sign_63 = _mm_srli_epi16(delta_sign, 10); + const __m128i delta_adjust = _mm_add_epi16(delta, delta_sign_63); + const __m128i offset0 = _mm_srai_epi16(delta_adjust, 6); + const __m128i offset1 = _mm_xor_si128(offset0, dst_sign); + return _mm_sub_epi16(offset1, dst_sign); +} + +inline void GetPosition( + const __m128i division_table, const MotionVector* const mv, + const int numerator, const int x8_start, const int x8_end, const int x8, + const __m128i& r_offsets, const __m128i& source_reference_type8, + const __m128i& skip_r, const __m128i& y8_floor8, const __m128i& y8_ceiling8, + const __m128i& d_sign, const int delta, __m128i* const r, + __m128i* const position_xy, int64_t* const skip_64, __m128i mvs[2]) { + const auto* const mv_int = reinterpret_cast<const int32_t*>(mv + x8); + *r = _mm_shuffle_epi8(r_offsets, source_reference_type8); + const __m128i denorm = LoadDivision(division_table, source_reference_type8); + __m128i projection_mv[2]; + mvs[0] = LoadUnaligned16(mv_int + 0); + mvs[1] = LoadUnaligned16(mv_int + 4); + // Deinterlace x and y components + const __m128i kShuffle = + _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); + const __m128i mv0 = _mm_shuffle_epi8(mvs[0], kShuffle); + const __m128i mv1 = _mm_shuffle_epi8(mvs[1], kShuffle); + const __m128i mv_y = _mm_unpacklo_epi64(mv0, mv1); + const __m128i mv_x = _mm_unpackhi_epi64(mv0, mv1); + // numerator could be 0. + projection_mv[0] = MvProjectionClip(mv_y, denorm, numerator); + projection_mv[1] = MvProjectionClip(mv_x, denorm, numerator); + // Do not update the motion vector if the block position is not valid or + // if position_x8 is outside the current range of x8_start and x8_end. + // Note that position_y8 will always be within the range of y8_start and + // y8_end. + // After subtracting the base, valid projections are within 8-bit. + const __m128i position_y = Project_SSE4_1(projection_mv[0], d_sign); + const __m128i position_x = Project_SSE4_1(projection_mv[1], d_sign); + const __m128i positions = _mm_packs_epi16(position_x, position_y); + const __m128i k01234567 = + _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0); + *position_xy = _mm_add_epi8(positions, k01234567); + const int x8_floor = std::max( + x8_start - x8, delta - kProjectionMvMaxHorizontalOffset); // [-8, 8] + const int x8_ceiling = + std::min(x8_end - x8, delta + 8 + kProjectionMvMaxHorizontalOffset) - + 1; // [-1, 15] + const __m128i x8_floor8 = _mm_set1_epi8(x8_floor); + const __m128i x8_ceiling8 = _mm_set1_epi8(x8_ceiling); + const __m128i floor_xy = _mm_unpacklo_epi64(x8_floor8, y8_floor8); + const __m128i ceiling_xy = _mm_unpacklo_epi64(x8_ceiling8, y8_ceiling8); + const __m128i underflow = _mm_cmplt_epi8(*position_xy, floor_xy); + const __m128i overflow = _mm_cmpgt_epi8(*position_xy, ceiling_xy); + const __m128i out = _mm_or_si128(underflow, overflow); + const __m128i skip_low = _mm_or_si128(skip_r, out); + const __m128i skip = _mm_or_si128(skip_low, _mm_srli_si128(out, 8)); + StoreLo8(skip_64, skip); +} + +template <int idx> +inline void Store(const __m128i position, const __m128i reference_offset, + const __m128i mv, int8_t* dst_reference_offset, + MotionVector* dst_mv) { + const ptrdiff_t offset = + static_cast<int16_t>(_mm_extract_epi16(position, idx)); + if ((idx & 3) == 0) { + dst_mv[offset].mv32 = _mm_cvtsi128_si32(mv); + } else { + dst_mv[offset].mv32 = _mm_extract_epi32(mv, idx & 3); + } + dst_reference_offset[offset] = _mm_extract_epi8(reference_offset, idx); +} + +template <int idx> +inline void CheckStore(const int8_t* skips, const __m128i position, + const __m128i reference_offset, const __m128i mv, + int8_t* dst_reference_offset, MotionVector* dst_mv) { + if (skips[idx] == 0) { + Store<idx>(position, reference_offset, mv, dst_reference_offset, dst_mv); + } +} + +// 7.9.2. +void MotionFieldProjectionKernel_SSE4_1( + const ReferenceInfo& reference_info, + const int reference_to_current_with_sign, const int dst_sign, + const int y8_start, const int y8_end, const int x8_start, const int x8_end, + TemporalMotionField* const motion_field) { + const ptrdiff_t stride = motion_field->mv.columns(); + // The column range has to be offset by kProjectionMvMaxHorizontalOffset since + // coordinates in that range could end up being position_x8 because of + // projection. + const int adjusted_x8_start = + std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0); + const int adjusted_x8_end = std::min( + x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride)); + const int adjusted_x8_end8 = adjusted_x8_end & ~7; + const int leftover = adjusted_x8_end - adjusted_x8_end8; + const int8_t* const reference_offsets = + reference_info.relative_distance_to.data(); + const bool* const skip_references = reference_info.skip_references.data(); + const int16_t* const projection_divisions = + reference_info.projection_divisions.data(); + const ReferenceFrameType* source_reference_types = + &reference_info.motion_field_reference_frame[y8_start][0]; + const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0]; + int8_t* dst_reference_offset = motion_field->reference_offset[y8_start]; + MotionVector* dst_mv = motion_field->mv[y8_start]; + const __m128i d_sign = _mm_set1_epi16(dst_sign); + + static_assert(sizeof(int8_t) == sizeof(bool), ""); + static_assert(sizeof(int8_t) == sizeof(ReferenceFrameType), ""); + static_assert(sizeof(int32_t) == sizeof(MotionVector), ""); + assert(dst_sign == 0 || dst_sign == -1); + assert(stride == motion_field->reference_offset.columns()); + assert((y8_start & 7) == 0); + assert((adjusted_x8_start & 7) == 0); + // The final position calculation is represented with int16_t. Valid + // position_y8 from its base is at most 7. After considering the horizontal + // offset which is at most |stride - 1|, we have the following assertion, + // which means this optimization works for frame width up to 32K (each + // position is a 8x8 block). + assert(8 * stride <= 32768); + const __m128i skip_reference = LoadLo8(skip_references); + const __m128i r_offsets = LoadLo8(reference_offsets); + const __m128i division_table = LoadUnaligned16(projection_divisions); + + int y8 = y8_start; + do { + const int y8_floor = (y8 & ~7) - y8; // [-7, 0] + const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8) - 1; // [0, 7] + const __m128i y8_floor8 = _mm_set1_epi8(y8_floor); + const __m128i y8_ceiling8 = _mm_set1_epi8(y8_ceiling); + int x8; + + for (x8 = adjusted_x8_start; x8 < adjusted_x8_end8; x8 += 8) { + const __m128i source_reference_type8 = + LoadLo8(source_reference_types + x8); + const __m128i skip_r = + _mm_shuffle_epi8(skip_reference, source_reference_type8); + int64_t early_skip; + StoreLo8(&early_skip, skip_r); + // Early termination #1 if all are skips. Chance is typically ~30-40%. + if (early_skip == -1) continue; + int64_t skip_64; + __m128i r, position_xy, mvs[2]; + GetPosition(division_table, mv, reference_to_current_with_sign, x8_start, + x8_end, x8, r_offsets, source_reference_type8, skip_r, + y8_floor8, y8_ceiling8, d_sign, 0, &r, &position_xy, &skip_64, + mvs); + // Early termination #2 if all are skips. + // Chance is typically ~15-25% after Early termination #1. + if (skip_64 == -1) continue; + const __m128i p_y = _mm_cvtepi8_epi16(_mm_srli_si128(position_xy, 8)); + const __m128i p_x = _mm_cvtepi8_epi16(position_xy); + const __m128i p_y_offset = _mm_mullo_epi16(p_y, _mm_set1_epi16(stride)); + const __m128i pos = _mm_add_epi16(p_y_offset, p_x); + const __m128i position = _mm_add_epi16(pos, _mm_set1_epi16(x8)); + if (skip_64 == 0) { + // Store all. Chance is typically ~70-85% after Early termination #2. + Store<0>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv); + } else { + // Check and store each. + // Chance is typically ~15-30% after Early termination #2. + // The compiler is smart enough to not create the local buffer skips[]. + int8_t skips[8]; + memcpy(skips, &skip_64, sizeof(skips)); + CheckStore<0>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, dst_mv); + CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, dst_mv); + } + } + + // The following leftover processing cannot be moved out of the do...while + // loop. Doing so may change the result storing orders of the same position. + if (leftover > 0) { + // Use SIMD only when leftover is at least 4, and there are at least 8 + // elements in a row. + if (leftover >= 4 && adjusted_x8_start < adjusted_x8_end8) { + // Process the last 8 elements to avoid loading invalid memory. Some + // elements may have been processed in the above loop, which is OK. + const int delta = 8 - leftover; + x8 = adjusted_x8_end - 8; + const __m128i source_reference_type8 = + LoadLo8(source_reference_types + x8); + const __m128i skip_r = + _mm_shuffle_epi8(skip_reference, source_reference_type8); + int64_t early_skip; + StoreLo8(&early_skip, skip_r); + // Early termination #1 if all are skips. + if (early_skip != -1) { + int64_t skip_64; + __m128i r, position_xy, mvs[2]; + GetPosition(division_table, mv, reference_to_current_with_sign, + x8_start, x8_end, x8, r_offsets, source_reference_type8, + skip_r, y8_floor8, y8_ceiling8, d_sign, delta, &r, + &position_xy, &skip_64, mvs); + // Early termination #2 if all are skips. + if (skip_64 != -1) { + const __m128i p_y = + _mm_cvtepi8_epi16(_mm_srli_si128(position_xy, 8)); + const __m128i p_x = _mm_cvtepi8_epi16(position_xy); + const __m128i p_y_offset = + _mm_mullo_epi16(p_y, _mm_set1_epi16(stride)); + const __m128i pos = _mm_add_epi16(p_y_offset, p_x); + const __m128i position = _mm_add_epi16(pos, _mm_set1_epi16(x8)); + // Store up to 7 elements since leftover is at most 7. + if (skip_64 == 0) { + // Store all. + Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv); + Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv); + Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv); + } else { + // Check and store each. + // The compiler is smart enough to not create the local buffer + // skips[]. + int8_t skips[8]; + memcpy(skips, &skip_64, sizeof(skips)); + CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, + dst_mv); + CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, + dst_mv); + } + } + } + } else { + for (; x8 < adjusted_x8_end; ++x8) { + const int source_reference_type = source_reference_types[x8]; + if (skip_references[source_reference_type]) continue; + MotionVector projection_mv; + // reference_to_current_with_sign could be 0. + GetMvProjection(mv[x8], reference_to_current_with_sign, + projection_divisions[source_reference_type], + &projection_mv); + // Do not update the motion vector if the block position is not valid + // or if position_x8 is outside the current range of x8_start and + // x8_end. Note that position_y8 will always be within the range of + // y8_start and y8_end. + const int position_y8 = Project(0, projection_mv.mv[0], dst_sign); + if (position_y8 < y8_floor || position_y8 > y8_ceiling) continue; + const int x8_base = x8 & ~7; + const int x8_floor = + std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset); + const int x8_ceiling = + std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset); + const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign); + if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue; + dst_mv[position_y8 * stride + position_x8] = mv[x8]; + dst_reference_offset[position_y8 * stride + position_x8] = + reference_offsets[source_reference_type]; + } + } + } + + source_reference_types += stride; + mv += stride; + dst_reference_offset += stride; + dst_mv += stride; + } while (++y8 < y8_end); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_SSE4_1; +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_SSE4_1; +} +#endif + +} // namespace + +void MotionFieldProjectionInit_SSE4_1() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void MotionFieldProjectionInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/motion_field_projection_sse4.h b/src/dsp/x86/motion_field_projection_sse4.h new file mode 100644 index 0000000..c05422c --- /dev/null +++ b/src/dsp/x86/motion_field_projection_sse4.h @@ -0,0 +1,41 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_MOTION_FIELD_PROJECTION_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_MOTION_FIELD_PROJECTION_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::motion_field_projection_kernel. This function is not +// thread-safe. +void MotionFieldProjectionInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel +#define LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_MOTION_FIELD_PROJECTION_SSE4_H_ diff --git a/src/dsp/x86/motion_vector_search_sse4.cc b/src/dsp/x86/motion_vector_search_sse4.cc new file mode 100644 index 0000000..e9cdd4c --- /dev/null +++ b/src/dsp/x86/motion_vector_search_sse4.cc @@ -0,0 +1,262 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/motion_vector_search.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace dsp { +namespace { + +constexpr int kProjectionMvDivisionLookup_32bit[kMaxFrameDistance + 1] = { + 0, 16384, 8192, 5461, 4096, 3276, 2730, 2340, 2048, 1820, 1638, + 1489, 1365, 1260, 1170, 1092, 1024, 963, 910, 862, 819, 780, + 744, 712, 682, 655, 630, 606, 585, 564, 546, 528}; + +inline __m128i MvProjection(const __m128i mv, const __m128i denominator, + const __m128i numerator) { + const __m128i m0 = _mm_madd_epi16(mv, denominator); + const __m128i m = _mm_mullo_epi32(m0, numerator); + // Add the sign (0 or -1) to round towards zero. + const __m128i sign = _mm_srai_epi32(m, 31); + const __m128i add_sign = _mm_add_epi32(m, sign); + const __m128i sum = _mm_add_epi32(add_sign, _mm_set1_epi32(1 << 13)); + return _mm_srai_epi32(sum, 14); +} + +inline __m128i MvProjectionClip(const __m128i mvs[2], + const __m128i denominators[2], + const __m128i numerator) { + const __m128i s0 = MvProjection(mvs[0], denominators[0], numerator); + const __m128i s1 = MvProjection(mvs[1], denominators[1], numerator); + const __m128i mv = _mm_packs_epi32(s0, s1); + const __m128i projection_mv_clamp = _mm_set1_epi16(kProjectionMvClamp); + const __m128i projection_mv_clamp_negative = + _mm_set1_epi16(-kProjectionMvClamp); + const __m128i clamp = _mm_min_epi16(mv, projection_mv_clamp); + return _mm_max_epi16(clamp, projection_mv_clamp_negative); +} + +inline __m128i MvProjectionCompoundClip( + const MotionVector* const temporal_mvs, + const int8_t temporal_reference_offsets[2], + const int reference_offsets[2]) { + const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs); + const __m128i temporal_mv = LoadLo8(tmvs); + const __m128i temporal_mv_0 = _mm_cvtepu16_epi32(temporal_mv); + __m128i mvs[2], denominators[2]; + mvs[0] = _mm_unpacklo_epi64(temporal_mv_0, temporal_mv_0); + mvs[1] = _mm_unpackhi_epi64(temporal_mv_0, temporal_mv_0); + denominators[0] = _mm_set1_epi32( + kProjectionMvDivisionLookup[temporal_reference_offsets[0]]); + denominators[1] = _mm_set1_epi32( + kProjectionMvDivisionLookup[temporal_reference_offsets[1]]); + const __m128i offsets = LoadLo8(reference_offsets); + const __m128i numerator = _mm_unpacklo_epi32(offsets, offsets); + return MvProjectionClip(mvs, denominators, numerator); +} + +inline __m128i MvProjectionSingleClip( + const MotionVector* const temporal_mvs, + const int8_t* const temporal_reference_offsets, + const int reference_offset) { + const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs); + const __m128i temporal_mv = LoadAligned16(tmvs); + __m128i lookup = _mm_cvtsi32_si128( + kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[0]]); + lookup = _mm_insert_epi32( + lookup, kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[1]], + 1); + lookup = _mm_insert_epi32( + lookup, kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[2]], + 2); + lookup = _mm_insert_epi32( + lookup, kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[3]], + 3); + __m128i mvs[2], denominators[2]; + mvs[0] = _mm_unpacklo_epi16(temporal_mv, _mm_setzero_si128()); + mvs[1] = _mm_unpackhi_epi16(temporal_mv, _mm_setzero_si128()); + denominators[0] = _mm_unpacklo_epi32(lookup, lookup); + denominators[1] = _mm_unpackhi_epi32(lookup, lookup); + const __m128i numerator = _mm_set1_epi32(reference_offset); + return MvProjectionClip(mvs, denominators, numerator); +} + +inline void LowPrecision(const __m128i mv, void* const candidate_mvs) { + const __m128i kRoundDownMask = _mm_set1_epi16(~1); + const __m128i sign = _mm_srai_epi16(mv, 15); + const __m128i sub_sign = _mm_sub_epi16(mv, sign); + const __m128i d = _mm_and_si128(sub_sign, kRoundDownMask); + StoreAligned16(candidate_mvs, d); +} + +inline void ForceInteger(const __m128i mv, void* const candidate_mvs) { + const __m128i kRoundDownMask = _mm_set1_epi16(~7); + const __m128i sign = _mm_srai_epi16(mv, 15); + const __m128i mv1 = _mm_add_epi16(mv, _mm_set1_epi16(3)); + const __m128i mv2 = _mm_sub_epi16(mv1, sign); + const __m128i mv3 = _mm_and_si128(mv2, kRoundDownMask); + StoreAligned16(candidate_mvs, mv3); +} + +void MvProjectionCompoundLowPrecision_SSE4_1( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int i = 0; + do { + const __m128i mv = MvProjectionCompoundClip( + temporal_mvs + i, temporal_reference_offsets + i, offsets); + LowPrecision(mv, candidate_mvs + i); + i += 2; + } while (i < count); +} + +void MvProjectionCompoundForceInteger_SSE4_1( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int i = 0; + do { + const __m128i mv = MvProjectionCompoundClip( + temporal_mvs + i, temporal_reference_offsets + i, offsets); + ForceInteger(mv, candidate_mvs + i); + i += 2; + } while (i < count); +} + +void MvProjectionCompoundHighPrecision_SSE4_1( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offsets[2], const int count, + CompoundMotionVector* candidate_mvs) { + // |reference_offsets| non-zero check usually equals true and is ignored. + // To facilitate the compilers, make a local copy of |reference_offsets|. + const int offsets[2] = {reference_offsets[0], reference_offsets[1]}; + // One more element could be calculated. + int i = 0; + do { + const __m128i mv = MvProjectionCompoundClip( + temporal_mvs + i, temporal_reference_offsets + i, offsets); + StoreAligned16(candidate_mvs + i, mv); + i += 2; + } while (i < count); +} + +void MvProjectionSingleLowPrecision_SSE4_1( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int i = 0; + do { + const __m128i mv = MvProjectionSingleClip( + temporal_mvs + i, temporal_reference_offsets + i, reference_offset); + LowPrecision(mv, candidate_mvs + i); + i += 4; + } while (i < count); +} + +void MvProjectionSingleForceInteger_SSE4_1( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int i = 0; + do { + const __m128i mv = MvProjectionSingleClip( + temporal_mvs + i, temporal_reference_offsets + i, reference_offset); + ForceInteger(mv, candidate_mvs + i); + i += 4; + } while (i < count); +} + +void MvProjectionSingleHighPrecision_SSE4_1( + const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets, + const int reference_offset, const int count, MotionVector* candidate_mvs) { + // Up to three more elements could be calculated. + int i = 0; + do { + const __m128i mv = MvProjectionSingleClip( + temporal_mvs + i, temporal_reference_offsets + i, reference_offset); + StoreAligned16(candidate_mvs + i, mv); + i += 4; + } while (i < count); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_SSE4_1; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_SSE4_1; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_SSE4_1; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_SSE4_1; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_SSE4_1; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_SSE4_1; +} + +#if LIBGAV1_MAX_BITDEPTH >= 10 +void Init10bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10); + assert(dsp != nullptr); + dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_SSE4_1; + dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_SSE4_1; + dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_SSE4_1; + dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_SSE4_1; + dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_SSE4_1; + dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_SSE4_1; +} +#endif + +} // namespace + +void MotionVectorSearchInit_SSE4_1() { + Init8bpp(); +#if LIBGAV1_MAX_BITDEPTH >= 10 + Init10bpp(); +#endif +} + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 +namespace libgav1 { +namespace dsp { + +void MotionVectorSearchInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/motion_vector_search_sse4.h b/src/dsp/x86/motion_vector_search_sse4.h new file mode 100644 index 0000000..d65b392 --- /dev/null +++ b/src/dsp/x86/motion_vector_search_sse4.h @@ -0,0 +1,41 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_MOTION_VECTOR_SEARCH_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_MOTION_VECTOR_SEARCH_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This +// function is not thread-safe. +void MotionVectorSearchInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_MotionVectorSearch +#define LIBGAV1_Dsp8bpp_MotionVectorSearch LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_MOTION_VECTOR_SEARCH_SSE4_H_ diff --git a/src/dsp/x86/obmc_sse4.cc b/src/dsp/x86/obmc_sse4.cc new file mode 100644 index 0000000..3a1d1fd --- /dev/null +++ b/src/dsp/x86/obmc_sse4.cc @@ -0,0 +1,329 @@ +// Copyright 2019 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/obmc.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <xmmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace { + +#include "src/dsp/obmc.inc" + +inline void OverlapBlendFromLeft2xH_SSE4_1( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040); + const __m128i mask_val = _mm_shufflelo_epi16(Load4(kObmcMask), 0); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val); + int y = height; + do { + const __m128i pred_val = Load2x2(pred, pred + prediction_stride); + const __m128i obmc_pred_val = + Load2x2(obmc_pred, obmc_pred + obmc_prediction_stride); + + const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val); + const __m128i result = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6); + const __m128i packed_result = _mm_packus_epi16(result, result); + Store2(pred, packed_result); + pred += prediction_stride; + const int16_t second_row_result = _mm_extract_epi16(packed_result, 1); + memcpy(pred, &second_row_result, sizeof(second_row_result)); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride << 1; + y -= 2; + } while (y != 0); +} + +inline void OverlapBlendFromLeft4xH_SSE4_1( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040); + const __m128i mask_val = Load4(kObmcMask + 2); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + // Duplicate first half of vector. + const __m128i masks = + _mm_shuffle_epi32(_mm_unpacklo_epi8(mask_val, obmc_mask_val), 0x44); + int y = height; + do { + const __m128i pred_val0 = Load4(pred); + const __m128i obmc_pred_val0 = Load4(obmc_pred); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + + // Place the second row of each source in the second four bytes. + const __m128i pred_val = + _mm_alignr_epi8(Load4(pred), _mm_slli_si128(pred_val0, 12), 12); + const __m128i obmc_pred_val = _mm_alignr_epi8( + Load4(obmc_pred), _mm_slli_si128(obmc_pred_val0, 12), 12); + const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val); + const __m128i result = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6); + const __m128i packed_result = _mm_packus_epi16(result, result); + Store4(pred - prediction_stride, packed_result); + const int second_row_result = _mm_extract_epi32(packed_result, 1); + memcpy(pred, &second_row_result, sizeof(second_row_result)); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + y -= 2; + } while (y != 0); +} + +inline void OverlapBlendFromLeft8xH_SSE4_1( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const __m128i mask_inverter = _mm_set1_epi8(64); + const __m128i mask_val = LoadLo8(kObmcMask + 6); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val); + int y = height; + do { + const __m128i pred_val = LoadLo8(pred); + const __m128i obmc_pred_val = LoadLo8(obmc_pred); + const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val); + const __m128i result = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6); + + StoreLo8(pred, _mm_packus_epi16(result, result)); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (--y != 0); +} + +void OverlapBlendFromLeft_SSE4_1(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint8_t*>(prediction); + const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction); + + if (width == 2) { + OverlapBlendFromLeft2xH_SSE4_1(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 4) { + OverlapBlendFromLeft4xH_SSE4_1(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 8) { + OverlapBlendFromLeft8xH_SSE4_1(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + const __m128i mask_inverter = _mm_set1_epi8(64); + const uint8_t* mask = kObmcMask + width - 2; + int x = 0; + do { + pred = static_cast<uint8_t*>(prediction) + x; + obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x; + const __m128i mask_val = LoadUnaligned16(mask + x); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks_lo = _mm_unpacklo_epi8(mask_val, obmc_mask_val); + const __m128i masks_hi = _mm_unpackhi_epi8(mask_val, obmc_mask_val); + + int y = 0; + do { + const __m128i pred_val = LoadUnaligned16(pred); + const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred); + const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val); + const __m128i result_lo = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks_lo), 6); + const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val); + const __m128i result_hi = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks_hi), 6); + StoreUnaligned16(pred, _mm_packus_epi16(result_lo, result_hi)); + + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y < height); + x += 16; + } while (x < width); +} + +inline void OverlapBlendFromTop4xH_SSE4_1( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const __m128i mask_inverter = _mm_set1_epi16(64); + const __m128i mask_shuffler = _mm_set_epi32(0x01010101, 0x01010101, 0, 0); + const __m128i mask_preinverter = _mm_set1_epi16(-256 | 1); + + const uint8_t* mask = kObmcMask + height - 2; + const int compute_height = height - (height >> 2); + int y = 0; + do { + // First mask in the first half, second mask in the second half. + const __m128i mask_val = _mm_shuffle_epi8( + _mm_cvtsi32_si128(*reinterpret_cast<const uint16_t*>(mask + y)), + mask_shuffler); + const __m128i masks = + _mm_sub_epi8(mask_inverter, _mm_sign_epi8(mask_val, mask_preinverter)); + const __m128i pred_val0 = Load4(pred); + + const __m128i obmc_pred_val0 = Load4(obmc_pred); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + const __m128i pred_val = + _mm_alignr_epi8(Load4(pred), _mm_slli_si128(pred_val0, 12), 12); + const __m128i obmc_pred_val = _mm_alignr_epi8( + Load4(obmc_pred), _mm_slli_si128(obmc_pred_val0, 12), 12); + const __m128i terms = _mm_unpacklo_epi8(obmc_pred_val, pred_val); + const __m128i result = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6); + + const __m128i packed_result = _mm_packus_epi16(result, result); + Store4(pred - prediction_stride, packed_result); + Store4(pred, _mm_srli_si128(packed_result, 4)); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + y += 2; + } while (y < compute_height); +} + +inline void OverlapBlendFromTop8xH_SSE4_1( + uint8_t* const prediction, const ptrdiff_t prediction_stride, + const int height, const uint8_t* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + uint8_t* pred = prediction; + const uint8_t* obmc_pred = obmc_prediction; + const uint8_t* mask = kObmcMask + height - 2; + const __m128i mask_inverter = _mm_set1_epi8(64); + const int compute_height = height - (height >> 2); + int y = compute_height; + do { + const __m128i mask_val = _mm_set1_epi8(mask[compute_height - y]); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val); + const __m128i pred_val = LoadLo8(pred); + const __m128i obmc_pred_val = LoadLo8(obmc_pred); + const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val); + const __m128i result = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6); + + StoreLo8(pred, _mm_packus_epi16(result, result)); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (--y != 0); +} + +void OverlapBlendFromTop_SSE4_1(void* const prediction, + const ptrdiff_t prediction_stride, + const int width, const int height, + const void* const obmc_prediction, + const ptrdiff_t obmc_prediction_stride) { + auto* pred = static_cast<uint8_t*>(prediction); + const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction); + + if (width <= 4) { + OverlapBlendFromTop4xH_SSE4_1(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + if (width == 8) { + OverlapBlendFromTop8xH_SSE4_1(pred, prediction_stride, height, obmc_pred, + obmc_prediction_stride); + return; + } + + // Stop when mask value becomes 64. + const int compute_height = height - (height >> 2); + const __m128i mask_inverter = _mm_set1_epi8(64); + int y = 0; + const uint8_t* mask = kObmcMask + height - 2; + do { + const __m128i mask_val = _mm_set1_epi8(mask[y]); + // 64 - mask + const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val); + const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val); + int x = 0; + do { + const __m128i pred_val = LoadUnaligned16(pred + x); + const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred + x); + const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val); + const __m128i result_lo = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks), 6); + const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val); + const __m128i result_hi = + RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks), 6); + StoreUnaligned16(pred + x, _mm_packus_epi16(result_lo, result_hi)); + x += 16; + } while (x < width); + pred += prediction_stride; + obmc_pred += obmc_prediction_stride; + } while (++y < compute_height); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); +#if DSP_ENABLED_8BPP_SSE4_1(ObmcVertical) + dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_SSE4_1; +#endif +#if DSP_ENABLED_8BPP_SSE4_1(ObmcHorizontal) + dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_SSE4_1; +#endif +} + +} // namespace + +void ObmcInit_SSE4_1() { Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void ObmcInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/obmc_sse4.h b/src/dsp/x86/obmc_sse4.h new file mode 100644 index 0000000..bd8b416 --- /dev/null +++ b/src/dsp/x86/obmc_sse4.h @@ -0,0 +1,43 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::obmc_blend[]. This function is not thread-safe. +void ObmcInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +// If sse4 is enabled and the baseline isn't set due to a higher level of +// optimization being enabled, signal the sse4 implementation should be used. +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_ObmcVertical +#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_CPU_SSE4_1 +#endif +#ifndef LIBGAV1_Dsp8bpp_ObmcHorizontal +#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_SSE4_1 +#endif +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_ diff --git a/src/dsp/x86/super_res_sse4.cc b/src/dsp/x86/super_res_sse4.cc new file mode 100644 index 0000000..b2bdfd2 --- /dev/null +++ b/src/dsp/x86/super_res_sse4.cc @@ -0,0 +1,166 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/super_res.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/dsp/x86/transpose_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Upscale_Filter as defined in AV1 Section 7.16 +// Negative to make them fit in 8-bit. +alignas(16) const int8_t + kNegativeUpscaleFilter[kSuperResFilterShifts][kSuperResFilterTaps] = { + {0, 0, 0, -128, 0, 0, 0, 0}, {0, 0, 1, -128, -2, 1, 0, 0}, + {0, -1, 3, -127, -4, 2, -1, 0}, {0, -1, 4, -127, -6, 3, -1, 0}, + {0, -2, 6, -126, -8, 3, -1, 0}, {0, -2, 7, -125, -11, 4, -1, 0}, + {1, -2, 8, -125, -13, 5, -2, 0}, {1, -3, 9, -124, -15, 6, -2, 0}, + {1, -3, 10, -123, -18, 6, -2, 1}, {1, -3, 11, -122, -20, 7, -3, 1}, + {1, -4, 12, -121, -22, 8, -3, 1}, {1, -4, 13, -120, -25, 9, -3, 1}, + {1, -4, 14, -118, -28, 9, -3, 1}, {1, -4, 15, -117, -30, 10, -4, 1}, + {1, -5, 16, -116, -32, 11, -4, 1}, {1, -5, 16, -114, -35, 12, -4, 1}, + {1, -5, 17, -112, -38, 12, -4, 1}, {1, -5, 18, -111, -40, 13, -5, 1}, + {1, -5, 18, -109, -43, 14, -5, 1}, {1, -6, 19, -107, -45, 14, -5, 1}, + {1, -6, 19, -105, -48, 15, -5, 1}, {1, -6, 19, -103, -51, 16, -5, 1}, + {1, -6, 20, -101, -53, 16, -6, 1}, {1, -6, 20, -99, -56, 17, -6, 1}, + {1, -6, 20, -97, -58, 17, -6, 1}, {1, -6, 20, -95, -61, 18, -6, 1}, + {2, -7, 20, -93, -64, 18, -6, 2}, {2, -7, 20, -91, -66, 19, -6, 1}, + {2, -7, 20, -88, -69, 19, -6, 1}, {2, -7, 20, -86, -71, 19, -6, 1}, + {2, -7, 20, -84, -74, 20, -7, 2}, {2, -7, 20, -81, -76, 20, -7, 1}, + {2, -7, 20, -79, -79, 20, -7, 2}, {1, -7, 20, -76, -81, 20, -7, 2}, + {2, -7, 20, -74, -84, 20, -7, 2}, {1, -6, 19, -71, -86, 20, -7, 2}, + {1, -6, 19, -69, -88, 20, -7, 2}, {1, -6, 19, -66, -91, 20, -7, 2}, + {2, -6, 18, -64, -93, 20, -7, 2}, {1, -6, 18, -61, -95, 20, -6, 1}, + {1, -6, 17, -58, -97, 20, -6, 1}, {1, -6, 17, -56, -99, 20, -6, 1}, + {1, -6, 16, -53, -101, 20, -6, 1}, {1, -5, 16, -51, -103, 19, -6, 1}, + {1, -5, 15, -48, -105, 19, -6, 1}, {1, -5, 14, -45, -107, 19, -6, 1}, + {1, -5, 14, -43, -109, 18, -5, 1}, {1, -5, 13, -40, -111, 18, -5, 1}, + {1, -4, 12, -38, -112, 17, -5, 1}, {1, -4, 12, -35, -114, 16, -5, 1}, + {1, -4, 11, -32, -116, 16, -5, 1}, {1, -4, 10, -30, -117, 15, -4, 1}, + {1, -3, 9, -28, -118, 14, -4, 1}, {1, -3, 9, -25, -120, 13, -4, 1}, + {1, -3, 8, -22, -121, 12, -4, 1}, {1, -3, 7, -20, -122, 11, -3, 1}, + {1, -2, 6, -18, -123, 10, -3, 1}, {0, -2, 6, -15, -124, 9, -3, 1}, + {0, -2, 5, -13, -125, 8, -2, 1}, {0, -1, 4, -11, -125, 7, -2, 0}, + {0, -1, 3, -8, -126, 6, -2, 0}, {0, -1, 3, -6, -127, 4, -1, 0}, + {0, -1, 2, -4, -127, 3, -1, 0}, {0, 0, 1, -2, -128, 1, 0, 0}, +}; + +void SuperResCoefficients_SSE4_1(const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const coefficients) { + auto* dst = static_cast<uint8_t*>(coefficients); + int subpixel_x = initial_subpixel_x; + int x = RightShiftWithCeiling(upscaled_width, 4); + do { + for (int i = 0; i < 8; ++i, dst += 16) { + int remainder = subpixel_x & kSuperResScaleMask; + __m128i filter = + LoadLo8(kNegativeUpscaleFilter[remainder >> kSuperResExtraBits]); + subpixel_x += step; + remainder = subpixel_x & kSuperResScaleMask; + filter = LoadHi8(filter, + kNegativeUpscaleFilter[remainder >> kSuperResExtraBits]); + subpixel_x += step; + StoreAligned16(dst, filter); + } + } while (--x != 0); +} + +void SuperRes_SSE4_1(const void* const coefficients, void* const source, + const ptrdiff_t stride, const int height, + const int downscaled_width, const int upscaled_width, + const int initial_subpixel_x, const int step, + void* const dest) { + auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps); + auto* dst = static_cast<uint8_t*>(dest); + int y = height; + do { + const auto* filter = static_cast<const uint8_t*>(coefficients); + uint8_t* dst_ptr = dst; + ExtendLine<uint8_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width, + kSuperResHorizontalBorder, kSuperResHorizontalBorder); + int subpixel_x = initial_subpixel_x; + // The below code calculates up to 15 extra upscaled + // pixels which will over-read up to 15 downscaled pixels in the end of each + // row. kSuperResHorizontalBorder accounts for this. + int x = RightShiftWithCeiling(upscaled_width, 4); + do { + __m128i weighted_src[8]; + for (int i = 0; i < 8; ++i, filter += 16) { + __m128i s = LoadLo8(&src[subpixel_x >> kSuperResScaleBits]); + subpixel_x += step; + s = LoadHi8(s, &src[subpixel_x >> kSuperResScaleBits]); + subpixel_x += step; + const __m128i f = LoadAligned16(filter); + weighted_src[i] = _mm_maddubs_epi16(s, f); + } + + __m128i a[4]; + a[0] = _mm_hadd_epi16(weighted_src[0], weighted_src[1]); + a[1] = _mm_hadd_epi16(weighted_src[2], weighted_src[3]); + a[2] = _mm_hadd_epi16(weighted_src[4], weighted_src[5]); + a[3] = _mm_hadd_epi16(weighted_src[6], weighted_src[7]); + Transpose2x16_U16(a, a); + a[0] = _mm_adds_epi16(a[0], a[1]); + a[1] = _mm_adds_epi16(a[2], a[3]); + const __m128i rounding = _mm_set1_epi16(1 << (kFilterBits - 1)); + a[0] = _mm_subs_epi16(rounding, a[0]); + a[1] = _mm_subs_epi16(rounding, a[1]); + a[0] = _mm_srai_epi16(a[0], kFilterBits); + a[1] = _mm_srai_epi16(a[1], kFilterBits); + StoreAligned16(dst_ptr, _mm_packus_epi16(a[0], a[1])); + dst_ptr += 16; + } while (--x != 0); + src += stride; + dst += stride; + } while (--y != 0); +} + +void Init8bpp() { + Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + dsp->super_res_coefficients = SuperResCoefficients_SSE4_1; + dsp->super_res = SuperRes_SSE4_1; +} + +} // namespace +} // namespace low_bitdepth + +void SuperResInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void SuperResInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/super_res_sse4.h b/src/dsp/x86/super_res_sse4.h new file mode 100644 index 0000000..aef5147 --- /dev/null +++ b/src/dsp/x86/super_res_sse4.h @@ -0,0 +1,38 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_SUPER_RES_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_SUPER_RES_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::super_res_row. This function is not thread-safe. +void SuperResInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 +#ifndef LIBGAV1_Dsp8bpp_SuperRes +#define LIBGAV1_Dsp8bpp_SuperRes LIBGAV1_CPU_SSE4_1 +#endif +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_SUPER_RES_SSE4_H_ diff --git a/src/dsp/x86/transpose_sse4.h b/src/dsp/x86/transpose_sse4.h new file mode 100644 index 0000000..208b301 --- /dev/null +++ b/src/dsp/x86/transpose_sse4.h @@ -0,0 +1,307 @@ +/* + * Copyright 2019 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_TRANSPOSE_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_TRANSPOSE_SSE4_H_ + +#include "src/utils/compiler_attributes.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 +#include <emmintrin.h> + +namespace libgav1 { +namespace dsp { + +LIBGAV1_ALWAYS_INLINE void Transpose2x16_U16(const __m128i* const in, + __m128i* const out) { + // Unpack 16 bit elements. Goes from: + // in[0]: 00 01 10 11 20 21 30 31 + // in[0]: 40 41 50 51 60 61 70 71 + // in[0]: 80 81 90 91 a0 a1 b0 b1 + // in[0]: c0 c1 d0 d1 e0 e1 f0 f1 + // to: + // a0: 00 40 01 41 10 50 11 51 + // a1: 20 60 21 61 30 70 31 71 + // a2: 80 c0 81 c1 90 d0 91 d1 + // a3: a0 e0 a1 e1 b0 f0 b1 f1 + const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]); + const __m128i a1 = _mm_unpackhi_epi16(in[0], in[1]); + const __m128i a2 = _mm_unpacklo_epi16(in[2], in[3]); + const __m128i a3 = _mm_unpackhi_epi16(in[2], in[3]); + // b0: 00 20 40 60 01 21 41 61 + // b1: 10 30 50 70 11 31 51 71 + // b2: 80 a0 c0 e0 81 a1 c1 e1 + // b3: 90 b0 d0 f0 91 b1 d1 f1 + const __m128i b0 = _mm_unpacklo_epi16(a0, a1); + const __m128i b1 = _mm_unpackhi_epi16(a0, a1); + const __m128i b2 = _mm_unpacklo_epi16(a2, a3); + const __m128i b3 = _mm_unpackhi_epi16(a2, a3); + // out[0]: 00 10 20 30 40 50 60 70 + // out[1]: 01 11 21 31 41 51 61 71 + // out[2]: 80 90 a0 b0 c0 d0 e0 f0 + // out[3]: 81 91 a1 b1 c1 d1 e1 f1 + out[0] = _mm_unpacklo_epi16(b0, b1); + out[1] = _mm_unpackhi_epi16(b0, b1); + out[2] = _mm_unpacklo_epi16(b2, b3); + out[3] = _mm_unpackhi_epi16(b2, b3); +} + +LIBGAV1_ALWAYS_INLINE __m128i Transpose4x4_U8(const __m128i* const in) { + // Unpack 8 bit elements. Goes from: + // in[0]: 00 01 02 03 + // in[1]: 10 11 12 13 + // in[2]: 20 21 22 23 + // in[3]: 30 31 32 33 + // to: + // a0: 00 10 01 11 02 12 03 13 + // a1: 20 30 21 31 22 32 23 33 + const __m128i a0 = _mm_unpacklo_epi8(in[0], in[1]); + const __m128i a1 = _mm_unpacklo_epi8(in[2], in[3]); + + // Unpack 32 bit elements resulting in: + // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + return _mm_unpacklo_epi16(a0, a1); +} + +LIBGAV1_ALWAYS_INLINE void Transpose8x8To4x16_U8(const __m128i* const in, + __m128i* out) { + // Unpack 8 bit elements. Goes from: + // in[0]: 00 01 02 03 04 05 06 07 + // in[1]: 10 11 12 13 14 15 16 17 + // in[2]: 20 21 22 23 24 25 26 27 + // in[3]: 30 31 32 33 34 35 36 37 + // in[4]: 40 41 42 43 44 45 46 47 + // in[5]: 50 51 52 53 54 55 56 57 + // in[6]: 60 61 62 63 64 65 66 67 + // in[7]: 70 71 72 73 74 75 76 77 + // to: + // a0: 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17 + // a1: 20 30 21 31 22 32 23 33 24 34 25 35 26 36 27 37 + // a2: 40 50 41 51 42 52 43 53 44 54 45 55 46 56 47 57 + // a3: 60 70 61 71 62 72 63 73 64 74 65 75 66 76 67 77 + const __m128i a0 = _mm_unpacklo_epi8(in[0], in[1]); + const __m128i a1 = _mm_unpacklo_epi8(in[2], in[3]); + const __m128i a2 = _mm_unpacklo_epi8(in[4], in[5]); + const __m128i a3 = _mm_unpacklo_epi8(in[6], in[7]); + + // b0: 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + // b1: 40 50 60 70 41 51 61 71 42 52 62 72 43 53 63 73 + // b2: 04 14 24 34 05 15 25 35 06 16 26 36 07 17 27 37 + // b3: 44 54 64 74 45 55 65 75 46 56 66 76 47 57 67 77 + const __m128i b0 = _mm_unpacklo_epi16(a0, a1); + const __m128i b1 = _mm_unpacklo_epi16(a2, a3); + const __m128i b2 = _mm_unpackhi_epi16(a0, a1); + const __m128i b3 = _mm_unpackhi_epi16(a2, a3); + + // out[0]: 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71 + // out[1]: 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73 + // out[2]: 04 14 24 34 44 54 64 74 05 15 25 35 45 55 65 75 + // out[3]: 06 16 26 36 46 56 66 76 07 17 27 37 47 57 67 77 + out[0] = _mm_unpacklo_epi32(b0, b1); + out[1] = _mm_unpackhi_epi32(b0, b1); + out[2] = _mm_unpacklo_epi32(b2, b3); + out[3] = _mm_unpackhi_epi32(b2, b3); +} + +LIBGAV1_ALWAYS_INLINE void Transpose4x4_U16(const __m128i* in, __m128i* out) { + // Unpack 16 bit elements. Goes from: + // in[0]: 00 01 02 03 XX XX XX XX + // in[1]: 10 11 12 13 XX XX XX XX + // in[2]: 20 21 22 23 XX XX XX XX + // in[3]: 30 31 32 33 XX XX XX XX + // to: + // ba: 00 10 01 11 02 12 03 13 + // dc: 20 30 21 31 22 32 23 33 + const __m128i ba = _mm_unpacklo_epi16(in[0], in[1]); + const __m128i dc = _mm_unpacklo_epi16(in[2], in[3]); + // Unpack 32 bit elements resulting in: + // dcba_lo: 00 10 20 30 01 11 21 31 + // dcba_hi: 02 12 22 32 03 13 23 33 + const __m128i dcba_lo = _mm_unpacklo_epi32(ba, dc); + const __m128i dcba_hi = _mm_unpackhi_epi32(ba, dc); + // Assign or shift right by 8 bytes resulting in: + // out[0]: 00 10 20 30 01 11 21 31 + // out[1]: 01 11 21 31 XX XX XX XX + // out[2]: 02 12 22 32 03 13 23 33 + // out[3]: 03 13 23 33 XX XX XX XX + out[0] = dcba_lo; + out[1] = _mm_srli_si128(dcba_lo, 8); + out[2] = dcba_hi; + out[3] = _mm_srli_si128(dcba_hi, 8); +} + +LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4_U16(const __m128i* in, + __m128i* out) { + // Unpack 16 bit elements. Goes from: + // in[0]: 00 01 02 03 XX XX XX XX + // in[1]: 10 11 12 13 XX XX XX XX + // in[2]: 20 21 22 23 XX XX XX XX + // in[3]: 30 31 32 33 XX XX XX XX + // in[4]: 40 41 42 43 XX XX XX XX + // in[5]: 50 51 52 53 XX XX XX XX + // in[6]: 60 61 62 63 XX XX XX XX + // in[7]: 70 71 72 73 XX XX XX XX + // to: + // a0: 00 10 01 11 02 12 03 13 + // a1: 20 30 21 31 22 32 23 33 + // a2: 40 50 41 51 42 52 43 53 + // a3: 60 70 61 71 62 72 63 73 + const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]); + const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]); + const __m128i a2 = _mm_unpacklo_epi16(in[4], in[5]); + const __m128i a3 = _mm_unpacklo_epi16(in[6], in[7]); + + // Unpack 32 bit elements resulting in: + // b0: 00 10 20 30 01 11 21 31 + // b1: 40 50 60 70 41 51 61 71 + // b2: 02 12 22 32 03 13 23 33 + // b3: 42 52 62 72 43 53 63 73 + const __m128i b0 = _mm_unpacklo_epi32(a0, a1); + const __m128i b1 = _mm_unpacklo_epi32(a2, a3); + const __m128i b2 = _mm_unpackhi_epi32(a0, a1); + const __m128i b3 = _mm_unpackhi_epi32(a2, a3); + + // Unpack 64 bit elements resulting in: + // out[0]: 00 10 20 30 40 50 60 70 + // out[1]: 01 11 21 31 41 51 61 71 + // out[2]: 02 12 22 32 42 52 62 72 + // out[3]: 03 13 23 33 43 53 63 73 + out[0] = _mm_unpacklo_epi64(b0, b1); + out[1] = _mm_unpackhi_epi64(b0, b1); + out[2] = _mm_unpacklo_epi64(b2, b3); + out[3] = _mm_unpackhi_epi64(b2, b3); +} + +LIBGAV1_ALWAYS_INLINE void Transpose8x4To4x8_U16(const __m128i* in, + __m128i* out) { + // Unpack 16 bit elements. Goes from: + // in[0]: 00 01 02 03 04 05 06 07 + // in[1]: 10 11 12 13 14 15 16 17 + // in[2]: 20 21 22 23 24 25 26 27 + // in[3]: 30 31 32 33 34 35 36 37 + + // to: + // a0: 00 10 01 11 02 12 03 13 + // a1: 20 30 21 31 22 32 23 33 + // a4: 04 14 05 15 06 16 07 17 + // a5: 24 34 25 35 26 36 27 37 + const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]); + const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]); + const __m128i a4 = _mm_unpackhi_epi16(in[0], in[1]); + const __m128i a5 = _mm_unpackhi_epi16(in[2], in[3]); + + // Unpack 32 bit elements resulting in: + // b0: 00 10 20 30 01 11 21 31 + // b2: 04 14 24 34 05 15 25 35 + // b4: 02 12 22 32 03 13 23 33 + // b6: 06 16 26 36 07 17 27 37 + const __m128i b0 = _mm_unpacklo_epi32(a0, a1); + const __m128i b2 = _mm_unpacklo_epi32(a4, a5); + const __m128i b4 = _mm_unpackhi_epi32(a0, a1); + const __m128i b6 = _mm_unpackhi_epi32(a4, a5); + + // Unpack 64 bit elements resulting in: + // out[0]: 00 10 20 30 XX XX XX XX + // out[1]: 01 11 21 31 XX XX XX XX + // out[2]: 02 12 22 32 XX XX XX XX + // out[3]: 03 13 23 33 XX XX XX XX + // out[4]: 04 14 24 34 XX XX XX XX + // out[5]: 05 15 25 35 XX XX XX XX + // out[6]: 06 16 26 36 XX XX XX XX + // out[7]: 07 17 27 37 XX XX XX XX + const __m128i zeros = _mm_setzero_si128(); + out[0] = _mm_unpacklo_epi64(b0, zeros); + out[1] = _mm_unpackhi_epi64(b0, zeros); + out[2] = _mm_unpacklo_epi64(b4, zeros); + out[3] = _mm_unpackhi_epi64(b4, zeros); + out[4] = _mm_unpacklo_epi64(b2, zeros); + out[5] = _mm_unpackhi_epi64(b2, zeros); + out[6] = _mm_unpacklo_epi64(b6, zeros); + out[7] = _mm_unpackhi_epi64(b6, zeros); +} + +LIBGAV1_ALWAYS_INLINE void Transpose8x8_U16(const __m128i* const in, + __m128i* const out) { + // Unpack 16 bit elements. Goes from: + // in[0]: 00 01 02 03 04 05 06 07 + // in[1]: 10 11 12 13 14 15 16 17 + // in[2]: 20 21 22 23 24 25 26 27 + // in[3]: 30 31 32 33 34 35 36 37 + // in[4]: 40 41 42 43 44 45 46 47 + // in[5]: 50 51 52 53 54 55 56 57 + // in[6]: 60 61 62 63 64 65 66 67 + // in[7]: 70 71 72 73 74 75 76 77 + // to: + // a0: 00 10 01 11 02 12 03 13 + // a1: 20 30 21 31 22 32 23 33 + // a2: 40 50 41 51 42 52 43 53 + // a3: 60 70 61 71 62 72 63 73 + // a4: 04 14 05 15 06 16 07 17 + // a5: 24 34 25 35 26 36 27 37 + // a6: 44 54 45 55 46 56 47 57 + // a7: 64 74 65 75 66 76 67 77 + const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]); + const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]); + const __m128i a2 = _mm_unpacklo_epi16(in[4], in[5]); + const __m128i a3 = _mm_unpacklo_epi16(in[6], in[7]); + const __m128i a4 = _mm_unpackhi_epi16(in[0], in[1]); + const __m128i a5 = _mm_unpackhi_epi16(in[2], in[3]); + const __m128i a6 = _mm_unpackhi_epi16(in[4], in[5]); + const __m128i a7 = _mm_unpackhi_epi16(in[6], in[7]); + + // Unpack 32 bit elements resulting in: + // b0: 00 10 20 30 01 11 21 31 + // b1: 40 50 60 70 41 51 61 71 + // b2: 04 14 24 34 05 15 25 35 + // b3: 44 54 64 74 45 55 65 75 + // b4: 02 12 22 32 03 13 23 33 + // b5: 42 52 62 72 43 53 63 73 + // b6: 06 16 26 36 07 17 27 37 + // b7: 46 56 66 76 47 57 67 77 + const __m128i b0 = _mm_unpacklo_epi32(a0, a1); + const __m128i b1 = _mm_unpacklo_epi32(a2, a3); + const __m128i b2 = _mm_unpacklo_epi32(a4, a5); + const __m128i b3 = _mm_unpacklo_epi32(a6, a7); + const __m128i b4 = _mm_unpackhi_epi32(a0, a1); + const __m128i b5 = _mm_unpackhi_epi32(a2, a3); + const __m128i b6 = _mm_unpackhi_epi32(a4, a5); + const __m128i b7 = _mm_unpackhi_epi32(a6, a7); + + // Unpack 64 bit elements resulting in: + // out[0]: 00 10 20 30 40 50 60 70 + // out[1]: 01 11 21 31 41 51 61 71 + // out[2]: 02 12 22 32 42 52 62 72 + // out[3]: 03 13 23 33 43 53 63 73 + // out[4]: 04 14 24 34 44 54 64 74 + // out[5]: 05 15 25 35 45 55 65 75 + // out[6]: 06 16 26 36 46 56 66 76 + // out[7]: 07 17 27 37 47 57 67 77 + out[0] = _mm_unpacklo_epi64(b0, b1); + out[1] = _mm_unpackhi_epi64(b0, b1); + out[2] = _mm_unpacklo_epi64(b4, b5); + out[3] = _mm_unpackhi_epi64(b4, b5); + out[4] = _mm_unpacklo_epi64(b2, b3); + out[5] = _mm_unpackhi_epi64(b2, b3); + out[6] = _mm_unpacklo_epi64(b6, b7); + out[7] = _mm_unpackhi_epi64(b6, b7); +} + +} // namespace dsp +} // namespace libgav1 + +#endif // LIBGAV1_TARGETING_SSE4_1 +#endif // LIBGAV1_SRC_DSP_X86_TRANSPOSE_SSE4_H_ diff --git a/src/dsp/x86/warp_sse4.cc b/src/dsp/x86/warp_sse4.cc new file mode 100644 index 0000000..43279ab --- /dev/null +++ b/src/dsp/x86/warp_sse4.cc @@ -0,0 +1,525 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/warp.h" +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <type_traits> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/dsp/x86/transpose_sse4.h" +#include "src/utils/common.h" +#include "src/utils/constants.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +// Number of extra bits of precision in warped filtering. +constexpr int kWarpedDiffPrecisionBits = 10; + +// This assumes the two filters contain filter[x] and filter[x+2]. +inline __m128i AccumulateFilter(const __m128i sum, const __m128i filter_0, + const __m128i filter_1, + const __m128i& src_window) { + const __m128i filter_taps = _mm_unpacklo_epi8(filter_0, filter_1); + const __m128i src = + _mm_unpacklo_epi8(src_window, _mm_srli_si128(src_window, 2)); + return _mm_add_epi16(sum, _mm_maddubs_epi16(src, filter_taps)); +} + +constexpr int kFirstPassOffset = 1 << 14; +constexpr int kOffsetRemoval = + (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128; + +// Applies the horizontal filter to one source row and stores the result in +// |intermediate_result_row|. |intermediate_result_row| is a row in the 15x8 +// |intermediate_result| two-dimensional array. +inline void HorizontalFilter(const int sx4, const int16_t alpha, + const __m128i src_row, + int16_t intermediate_result_row[8]) { + int sx = sx4 - MultiplyBy4(alpha); + __m128i filter[8]; + for (__m128i& f : filter) { + const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + f = LoadLo8(kWarpedFilters8[offset]); + sx += alpha; + } + Transpose8x8To4x16_U8(filter, filter); + // |filter| now contains two filters per register. + // Staggered combinations allow us to take advantage of _mm_maddubs_epi16 + // without overflowing the sign bit. The sign bit is hit only where two taps + // paired in a single madd add up to more than 128. This is only possible with + // two adjacent "inner" taps. Therefore, pairing odd with odd and even with + // even guarantees safety. |sum| is given a negative offset to allow for large + // intermediate values. + // k = 0, 2. + __m128i src_row_window = src_row; + __m128i sum = _mm_set1_epi16(-kFirstPassOffset); + sum = AccumulateFilter(sum, filter[0], filter[1], src_row_window); + + // k = 1, 3. + src_row_window = _mm_srli_si128(src_row_window, 1); + sum = AccumulateFilter(sum, _mm_srli_si128(filter[0], 8), + _mm_srli_si128(filter[1], 8), src_row_window); + // k = 4, 6. + src_row_window = _mm_srli_si128(src_row_window, 3); + sum = AccumulateFilter(sum, filter[2], filter[3], src_row_window); + + // k = 5, 7. + src_row_window = _mm_srli_si128(src_row_window, 1); + sum = AccumulateFilter(sum, _mm_srli_si128(filter[2], 8), + _mm_srli_si128(filter[3], 8), src_row_window); + + sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal); + StoreUnaligned16(intermediate_result_row, sum); +} + +template <bool is_compound> +inline void WriteVerticalFilter(const __m128i filter[8], + const int16_t intermediate_result[15][8], int y, + void* dst_row) { + constexpr int kRoundBitsVertical = + is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical; + __m128i sum_low = _mm_set1_epi32(kOffsetRemoval); + __m128i sum_high = sum_low; + for (int k = 0; k < 8; k += 2) { + const __m128i filters_low = _mm_unpacklo_epi16(filter[k], filter[k + 1]); + const __m128i filters_high = _mm_unpackhi_epi16(filter[k], filter[k + 1]); + const __m128i intermediate_0 = LoadUnaligned16(intermediate_result[y + k]); + const __m128i intermediate_1 = + LoadUnaligned16(intermediate_result[y + k + 1]); + const __m128i intermediate_low = + _mm_unpacklo_epi16(intermediate_0, intermediate_1); + const __m128i intermediate_high = + _mm_unpackhi_epi16(intermediate_0, intermediate_1); + + const __m128i product_low = _mm_madd_epi16(filters_low, intermediate_low); + const __m128i product_high = + _mm_madd_epi16(filters_high, intermediate_high); + sum_low = _mm_add_epi32(sum_low, product_low); + sum_high = _mm_add_epi32(sum_high, product_high); + } + sum_low = RightShiftWithRounding_S32(sum_low, kRoundBitsVertical); + sum_high = RightShiftWithRounding_S32(sum_high, kRoundBitsVertical); + if (is_compound) { + const __m128i sum = _mm_packs_epi32(sum_low, sum_high); + StoreUnaligned16(static_cast<int16_t*>(dst_row), sum); + } else { + const __m128i sum = _mm_packus_epi32(sum_low, sum_high); + StoreLo8(static_cast<uint8_t*>(dst_row), _mm_packus_epi16(sum, sum)); + } +} + +template <bool is_compound> +inline void WriteVerticalFilter(const __m128i filter[8], + const int16_t* intermediate_result_column, + void* dst_row) { + constexpr int kRoundBitsVertical = + is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical; + __m128i sum_low = _mm_setzero_si128(); + __m128i sum_high = _mm_setzero_si128(); + for (int k = 0; k < 8; k += 2) { + const __m128i filters_low = _mm_unpacklo_epi16(filter[k], filter[k + 1]); + const __m128i filters_high = _mm_unpackhi_epi16(filter[k], filter[k + 1]); + // Equivalent to unpacking two vectors made by duplicating int16_t values. + const __m128i intermediate = + _mm_set1_epi32((intermediate_result_column[k + 1] << 16) | + intermediate_result_column[k]); + const __m128i product_low = _mm_madd_epi16(filters_low, intermediate); + const __m128i product_high = _mm_madd_epi16(filters_high, intermediate); + sum_low = _mm_add_epi32(sum_low, product_low); + sum_high = _mm_add_epi32(sum_high, product_high); + } + sum_low = RightShiftWithRounding_S32(sum_low, kRoundBitsVertical); + sum_high = RightShiftWithRounding_S32(sum_high, kRoundBitsVertical); + if (is_compound) { + const __m128i sum = _mm_packs_epi32(sum_low, sum_high); + StoreUnaligned16(static_cast<int16_t*>(dst_row), sum); + } else { + const __m128i sum = _mm_packus_epi32(sum_low, sum_high); + StoreLo8(static_cast<uint8_t*>(dst_row), _mm_packus_epi16(sum, sum)); + } +} + +template <bool is_compound, typename DestType> +inline void VerticalFilter(const int16_t source[15][8], int y4, int gamma, + int delta, DestType* dest_row, + ptrdiff_t dest_stride) { + int sy4 = (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); + __m128i filter[8]; + for (__m128i& f : filter) { + const int offset = RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + f = LoadUnaligned16(kWarpedFilters[offset]); + sy += gamma; + } + Transpose8x8_U16(filter, filter); + WriteVerticalFilter<is_compound>(filter, source, y, dest_row); + dest_row += dest_stride; + sy4 += delta; + } +} + +template <bool is_compound, typename DestType> +inline void VerticalFilter(const int16_t* source_cols, int y4, int gamma, + int delta, DestType* dest_row, + ptrdiff_t dest_stride) { + int sy4 = (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta); + for (int y = 0; y < 8; ++y) { + int sy = sy4 - MultiplyBy4(gamma); + __m128i filter[8]; + for (__m128i& f : filter) { + const int offset = RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) + + kWarpedPixelPrecisionShifts; + f = LoadUnaligned16(kWarpedFilters[offset]); + sy += gamma; + } + Transpose8x8_U16(filter, filter); + WriteVerticalFilter<is_compound>(filter, &source_cols[y], dest_row); + dest_row += dest_stride; + sy4 += delta; + } +} + +template <bool is_compound, typename DestType> +inline void WarpRegion1(const uint8_t* src, ptrdiff_t source_stride, + int source_width, int source_height, int ix4, int iy4, + DestType* dst_row, ptrdiff_t dest_stride) { + // Region 1 + // Points to the left or right border of the first row of |src|. + const uint8_t* first_row_border = + (ix4 + 7 <= 0) ? src : src + source_width - 1; + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + // Region 1. + // Every sample used to calculate the prediction block has the same + // value. So the whole prediction block has the same value. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const uint8_t row_border_pixel = first_row_border[row * source_stride]; + + if (is_compound) { + const __m128i sum = + _mm_set1_epi16(row_border_pixel << (kInterRoundBitsVertical - + kInterRoundBitsCompoundVertical)); + StoreUnaligned16(dst_row, sum); + } else { + memset(dst_row, row_border_pixel, 8); + } + const DestType* const first_dst_row = dst_row; + dst_row += dest_stride; + for (int y = 1; y < 8; ++y) { + memcpy(dst_row, first_dst_row, 8 * sizeof(*dst_row)); + dst_row += dest_stride; + } +} + +template <bool is_compound, typename DestType> +inline void WarpRegion2(const uint8_t* src, ptrdiff_t source_stride, + int source_width, int y4, int ix4, int iy4, int gamma, + int delta, int16_t intermediate_result_column[15], + DestType* dst_row, ptrdiff_t dest_stride) { + // Region 2. + // Points to the left or right border of the first row of |src|. + const uint8_t* first_row_border = + (ix4 + 7 <= 0) ? src : src + source_width - 1; + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + + // Region 2. + // Horizontal filter. + // The input values in this region are generated by extending the border + // which makes them identical in the horizontal direction. This + // computation could be inlined in the vertical pass but most + // implementations will need a transpose of some sort. + // It is not necessary to use the offset values here because the + // horizontal pass is a simple shift and the vertical pass will always + // require using 32 bits. + for (int y = -7; y < 8; ++y) { + // We may over-read up to 13 pixels above the top source row, or up + // to 13 pixels below the bottom source row. This is proved in + // warp.cc. + const int row = iy4 + y; + int sum = first_row_border[row * source_stride]; + sum <<= (kFilterBits - kInterRoundBitsHorizontal); + intermediate_result_column[y + 7] = sum; + } + // Region 2 vertical filter. + VerticalFilter<is_compound, DestType>(intermediate_result_column, y4, gamma, + delta, dst_row, dest_stride); +} + +template <bool is_compound, typename DestType> +inline void WarpRegion3(const uint8_t* src, ptrdiff_t source_stride, + int source_height, int alpha, int beta, int x4, int ix4, + int iy4, int16_t intermediate_result[15][8]) { + // Region 3 + // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0. + + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + // Horizontal filter. + const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1; + const uint8_t* const src_row = src + row * source_stride; + // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also + // read but is ignored. + // + // NOTE: This may read up to 13 bytes before src_row[0] or up to 14 + // bytes after src_row[source_width - 1]. We assume the source frame + // has left and right borders of at least 13 bytes that extend the + // frame boundary pixels. We also assume there is at least one extra + // padding byte after the right border of the last source row. + const __m128i src_row_v = LoadUnaligned16(&src_row[ix4 - 7]); + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]); + sx4 += beta; + } +} + +template <bool is_compound, typename DestType> +inline void WarpRegion4(const uint8_t* src, ptrdiff_t source_stride, int alpha, + int beta, int x4, int ix4, int iy4, + int16_t intermediate_result[15][8]) { + // Region 4. + // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0. + + // In general, for y in [-7, 8), the row number iy4 + y is clipped: + // const int row = Clip3(iy4 + y, 0, source_height - 1); + // In two special cases, iy4 + y is clipped to either 0 or + // source_height - 1 for all y. In the rest of the cases, iy4 + y is + // bounded and we can avoid clipping iy4 + y by relying on a reference + // frame's boundary extension on the top and bottom. + // Horizontal filter. + int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7; + for (int y = -7; y < 8; ++y) { + // We may over-read up to 13 pixels above the top source row, or up + // to 13 pixels below the bottom source row. This is proved in + // warp.cc. + const int row = iy4 + y; + const uint8_t* const src_row = src + row * source_stride; + // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also + // read but is ignored. + // + // NOTE: This may read up to 13 bytes before src_row[0] or up to 14 + // bytes after src_row[source_width - 1]. We assume the source frame + // has left and right borders of at least 13 bytes that extend the + // frame boundary pixels. We also assume there is at least one extra + // padding byte after the right border of the last source row. + const __m128i src_row_v = LoadUnaligned16(&src_row[ix4 - 7]); + // Convert src_row_v to int8 (subtract 128). + HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]); + sx4 += beta; + } +} + +template <bool is_compound, typename DestType> +inline void HandleWarpBlock(const uint8_t* src, ptrdiff_t source_stride, + int source_width, int source_height, + const int* warp_params, int subsampling_x, + int subsampling_y, int src_x, int src_y, + int16_t alpha, int16_t beta, int16_t gamma, + int16_t delta, DestType* dst_row, + ptrdiff_t dest_stride) { + union { + // Intermediate_result is the output of the horizontal filtering and + // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 - + // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t + // type so that we can start with a negative offset and restore it on the + // final filter sum. + int16_t intermediate_result[15][8]; // 15 rows, 8 columns. + // In the simple special cases where the samples in each row are all the + // same, store one sample per row in a column vector. + int16_t intermediate_result_column[15]; + }; + + const int dst_x = + src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0]; + const int dst_y = + src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1]; + const int x4 = dst_x >> subsampling_x; + const int y4 = dst_y >> subsampling_y; + const int ix4 = x4 >> kWarpedModelPrecisionBits; + const int iy4 = y4 >> kWarpedModelPrecisionBits; + // A prediction block may fall outside the frame's boundaries. If a + // prediction block is calculated using only samples outside the frame's + // boundary, the filtering can be simplified. We can divide the plane + // into several regions and handle them differently. + // + // | | + // 1 | 3 | 1 + // | | + // -------+-----------+------- + // |***********| + // 2 |*****4*****| 2 + // |***********| + // -------+-----------+------- + // | | + // 1 | 3 | 1 + // | | + // + // At the center, region 4 represents the frame and is the general case. + // + // In regions 1 and 2, the prediction block is outside the frame's + // boundary horizontally. Therefore the horizontal filtering can be + // simplified. Furthermore, in the region 1 (at the four corners), the + // prediction is outside the frame's boundary both horizontally and + // vertically, so we get a constant prediction block. + // + // In region 3, the prediction block is outside the frame's boundary + // vertically. Unfortunately because we apply the horizontal filters + // first, by the time we apply the vertical filters, they no longer see + // simple inputs. So the only simplification is that all the rows are + // the same, but we still need to apply all the horizontal and vertical + // filters. + + // Check for two simple special cases, where the horizontal filter can + // be significantly simplified. + // + // In general, for each row, the horizontal filter is calculated as + // follows: + // for (int x = -4; x < 4; ++x) { + // const int offset = ...; + // int sum = first_pass_offset; + // for (int k = 0; k < 8; ++k) { + // const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1); + // sum += kWarpedFilters[offset][k] * src_row[column]; + // } + // ... + // } + // The column index before clipping, ix4 + x + k - 3, varies in the range + // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1 + // or ix4 + 7 <= 0, then all the column indexes are clipped to the same + // border index (source_width - 1 or 0, respectively). Then for each x, + // the inner for loop of the horizontal filter is reduced to multiplying + // the border pixel by the sum of the filter coefficients. + if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) { + if ((iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0)) { + // Outside the frame in both directions. One repeated value. + WarpRegion1<is_compound, DestType>(src, source_stride, source_width, + source_height, ix4, iy4, dst_row, + dest_stride); + return; + } + // Outside the frame horizontally. Rows repeated. + WarpRegion2<is_compound, DestType>( + src, source_stride, source_width, y4, ix4, iy4, gamma, delta, + intermediate_result_column, dst_row, dest_stride); + return; + } + + if ((iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0)) { + // Outside the frame vertically. + WarpRegion3<is_compound, DestType>(src, source_stride, source_height, alpha, + beta, x4, ix4, iy4, intermediate_result); + } else { + // Inside the frame. + WarpRegion4<is_compound, DestType>(src, source_stride, alpha, beta, x4, ix4, + iy4, intermediate_result); + } + // Region 3 and 4 vertical filter. + VerticalFilter<is_compound, DestType>(intermediate_result, y4, gamma, delta, + dst_row, dest_stride); +} + +template <bool is_compound> +void Warp_SSE4_1(const void* source, ptrdiff_t source_stride, int source_width, + int source_height, const int* warp_params, int subsampling_x, + int subsampling_y, int block_start_x, int block_start_y, + int block_width, int block_height, int16_t alpha, int16_t beta, + int16_t gamma, int16_t delta, void* dest, + ptrdiff_t dest_stride) { + const auto* const src = static_cast<const uint8_t*>(source); + using DestType = + typename std::conditional<is_compound, int16_t, uint8_t>::type; + auto* dst = static_cast<DestType*>(dest); + + // Warp process applies for each 8x8 block. + assert(block_width >= 8); + assert(block_height >= 8); + const int block_end_x = block_start_x + block_width; + const int block_end_y = block_start_y + block_height; + + const int start_x = block_start_x; + const int start_y = block_start_y; + int src_x = (start_x + 4) << subsampling_x; + int src_y = (start_y + 4) << subsampling_y; + const int end_x = (block_end_x + 4) << subsampling_x; + const int end_y = (block_end_y + 4) << subsampling_y; + do { + DestType* dst_row = dst; + src_x = (start_x + 4) << subsampling_x; + do { + HandleWarpBlock<is_compound, DestType>( + src, source_stride, source_width, source_height, warp_params, + subsampling_x, subsampling_y, src_x, src_y, alpha, beta, gamma, delta, + dst_row, dest_stride); + src_x += (8 << subsampling_x); + dst_row += 8; + } while (src_x < end_x); + dst += 8 * dest_stride; + src_y += (8 << subsampling_y); + } while (src_y < end_y); +} + +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + dsp->warp = Warp_SSE4_1</*is_compound=*/false>; + dsp->warp_compound = Warp_SSE4_1</*is_compound=*/true>; +} + +} // namespace +} // namespace low_bitdepth + +void WarpInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void WarpInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/warp_sse4.h b/src/dsp/x86/warp_sse4.h new file mode 100644 index 0000000..a2dc5ca --- /dev/null +++ b/src/dsp/x86/warp_sse4.h @@ -0,0 +1,44 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_X86_WARP_SSE4_H_ +#define LIBGAV1_SRC_DSP_X86_WARP_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::warp. This function is not thread-safe. +void WarpInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_Warp +#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WarpCompound +#define LIBGAV1_Dsp8bpp_WarpCompound LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_X86_WARP_SSE4_H_ diff --git a/src/dsp/x86/weight_mask_sse4.cc b/src/dsp/x86/weight_mask_sse4.cc new file mode 100644 index 0000000..dfd5662 --- /dev/null +++ b/src/dsp/x86/weight_mask_sse4.cc @@ -0,0 +1,464 @@ +// Copyright 2020 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/dsp/x86/weight_mask_sse4.h" + +#include "src/utils/cpu.h" + +#if LIBGAV1_TARGETING_SSE4_1 + +#include <smmintrin.h> + +#include <cassert> +#include <cstddef> +#include <cstdint> + +#include "src/dsp/constants.h" +#include "src/dsp/dsp.h" +#include "src/dsp/x86/common_sse4.h" +#include "src/utils/common.h" + +namespace libgav1 { +namespace dsp { +namespace low_bitdepth { +namespace { + +constexpr int kRoundingBits8bpp = 4; + +template <bool mask_is_inverse> +inline void WeightMask8_SSE4(const int16_t* prediction_0, + const int16_t* prediction_1, uint8_t* mask) { + const __m128i pred_0 = LoadAligned16(prediction_0); + const __m128i pred_1 = LoadAligned16(prediction_1); + const __m128i difference = RightShiftWithRounding_U16( + _mm_abs_epi16(_mm_sub_epi16(pred_0, pred_1)), kRoundingBits8bpp); + const __m128i scaled_difference = _mm_srli_epi16(difference, 4); + const __m128i difference_offset = _mm_set1_epi8(38); + const __m128i adjusted_difference = + _mm_adds_epu8(_mm_packus_epi16(scaled_difference, scaled_difference), + difference_offset); + const __m128i mask_ceiling = _mm_set1_epi8(64); + const __m128i mask_value = _mm_min_epi8(adjusted_difference, mask_ceiling); + if (mask_is_inverse) { + const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value); + StoreLo8(mask, inverted_mask_value); + } else { + StoreLo8(mask, mask_value); + } +} + +#define WEIGHT8_WITHOUT_STRIDE \ + WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask) + +#define WEIGHT8_AND_STRIDE \ + WEIGHT8_WITHOUT_STRIDE; \ + pred_0 += 8; \ + pred_1 += 8; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask8x8_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = 0; + do { + WEIGHT8_AND_STRIDE; + } while (++y < 7); + WEIGHT8_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask8x16_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + } while (++y3 < 5); + WEIGHT8_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask8x32_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + WEIGHT8_AND_STRIDE; + } while (++y5 < 6); + WEIGHT8_AND_STRIDE; + WEIGHT8_WITHOUT_STRIDE; +} + +#define WEIGHT16_WITHOUT_STRIDE \ + WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8) + +#define WEIGHT16_AND_STRIDE \ + WEIGHT16_WITHOUT_STRIDE; \ + pred_0 += 16; \ + pred_1 += 16; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask16x8_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y = 0; + do { + WEIGHT16_AND_STRIDE; + } while (++y < 7); + WEIGHT16_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask16x16_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + } while (++y3 < 5); + WEIGHT16_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask16x32_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + } while (++y5 < 6); + WEIGHT16_AND_STRIDE; + WEIGHT16_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask16x64_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + WEIGHT16_AND_STRIDE; + } while (++y3 < 21); + WEIGHT16_WITHOUT_STRIDE; +} + +#define WEIGHT32_WITHOUT_STRIDE \ + WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24) + +#define WEIGHT32_AND_STRIDE \ + WEIGHT32_WITHOUT_STRIDE; \ + pred_0 += 32; \ + pred_1 += 32; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask32x8_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask32x16_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + } while (++y3 < 5); + WEIGHT32_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask32x32_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + } while (++y5 < 6); + WEIGHT32_AND_STRIDE; + WEIGHT32_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask32x64_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + WEIGHT32_AND_STRIDE; + } while (++y3 < 21); + WEIGHT32_WITHOUT_STRIDE; +} + +#define WEIGHT64_WITHOUT_STRIDE \ + WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \ + WeightMask8_SSE4<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56) + +#define WEIGHT64_AND_STRIDE \ + WEIGHT64_WITHOUT_STRIDE; \ + pred_0 += 64; \ + pred_1 += 64; \ + mask += mask_stride + +template <bool mask_is_inverse> +void WeightMask64x16_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y3 < 5); + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask64x32_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y5 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y5 < 6); + WEIGHT64_AND_STRIDE; + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask64x64_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y3 < 21); + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask64x128_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + do { + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + WEIGHT64_AND_STRIDE; + } while (++y3 < 42); + WEIGHT64_AND_STRIDE; + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask128x64_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + const ptrdiff_t adjusted_mask_stride = mask_stride - 64; + do { + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + } while (++y3 < 21); + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; +} + +template <bool mask_is_inverse> +void WeightMask128x128_SSE4(const void* prediction_0, const void* prediction_1, + uint8_t* mask, ptrdiff_t mask_stride) { + const auto* pred_0 = static_cast<const int16_t*>(prediction_0); + const auto* pred_1 = static_cast<const int16_t*>(prediction_1); + int y3 = 0; + const ptrdiff_t adjusted_mask_stride = mask_stride - 64; + do { + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + } while (++y3 < 42); + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += adjusted_mask_stride; + + WEIGHT64_WITHOUT_STRIDE; + pred_0 += 64; + pred_1 += 64; + mask += 64; + WEIGHT64_WITHOUT_STRIDE; +} + +#define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \ + dsp->weight_mask[w_index][h_index][0] = \ + WeightMask##width##x##height##_SSE4<0>; \ + dsp->weight_mask[w_index][h_index][1] = WeightMask##width##x##height##_SSE4<1> +void Init8bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8); + assert(dsp != nullptr); + INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0); + INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1); + INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2); + INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0); + INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1); + INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2); + INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3); + INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0); + INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1); + INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2); + INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3); + INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1); + INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2); + INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3); + INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4); + INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3); + INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4); +} + +} // namespace +} // namespace low_bitdepth + +void WeightMaskInit_SSE4_1() { low_bitdepth::Init8bpp(); } + +} // namespace dsp +} // namespace libgav1 + +#else // !LIBGAV1_TARGETING_SSE4_1 + +namespace libgav1 { +namespace dsp { + +void WeightMaskInit_SSE4_1() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_TARGETING_SSE4_1 diff --git a/src/dsp/x86/weight_mask_sse4.h b/src/dsp/x86/weight_mask_sse4.h new file mode 100644 index 0000000..07636b7 --- /dev/null +++ b/src/dsp/x86/weight_mask_sse4.h @@ -0,0 +1,104 @@ +/* + * Copyright 2020 The libgav1 Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_SSE4_H_ +#define LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_SSE4_H_ + +#include "src/dsp/dsp.h" +#include "src/utils/cpu.h" + +namespace libgav1 { +namespace dsp { + +// Initializes Dsp::weight_mask. This function is not thread-safe. +void WeightMaskInit_SSE4_1(); + +} // namespace dsp +} // namespace libgav1 + +#if LIBGAV1_TARGETING_SSE4_1 + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x8 +#define LIBGAV1_Dsp8bpp_WeightMask_8x8 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x16 +#define LIBGAV1_Dsp8bpp_WeightMask_8x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x32 +#define LIBGAV1_Dsp8bpp_WeightMask_8x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x8 +#define LIBGAV1_Dsp8bpp_WeightMask_16x8 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x16 +#define LIBGAV1_Dsp8bpp_WeightMask_16x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x32 +#define LIBGAV1_Dsp8bpp_WeightMask_16x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x64 +#define LIBGAV1_Dsp8bpp_WeightMask_16x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x8 +#define LIBGAV1_Dsp8bpp_WeightMask_32x8 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x16 +#define LIBGAV1_Dsp8bpp_WeightMask_32x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x32 +#define LIBGAV1_Dsp8bpp_WeightMask_32x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x64 +#define LIBGAV1_Dsp8bpp_WeightMask_32x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x16 +#define LIBGAV1_Dsp8bpp_WeightMask_64x16 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x32 +#define LIBGAV1_Dsp8bpp_WeightMask_64x32 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x64 +#define LIBGAV1_Dsp8bpp_WeightMask_64x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x128 +#define LIBGAV1_Dsp8bpp_WeightMask_64x128 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_128x64 +#define LIBGAV1_Dsp8bpp_WeightMask_128x64 LIBGAV1_CPU_SSE4_1 +#endif + +#ifndef LIBGAV1_Dsp8bpp_WeightMask_128x128 +#define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_SSE4_1 +#endif + +#endif // LIBGAV1_TARGETING_SSE4_1 + +#endif // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_SSE4_H_ |