Diffstat (limited to 'src/dsp/arm/inverse_transform_10bit_neon.cc')
-rw-r--r-- | src/dsp/arm/inverse_transform_10bit_neon.cc | 2543 |
1 file changed, 2543 insertions(+), 0 deletions(-)
diff --git a/src/dsp/arm/inverse_transform_10bit_neon.cc b/src/dsp/arm/inverse_transform_10bit_neon.cc
new file mode 100644
index 0000000..ff184a1
--- /dev/null
+++ b/src/dsp/arm/inverse_transform_10bit_neon.cc
@@ -0,0 +1,2543 @@
+// Copyright 2021 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/inverse_transform.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+// Include the constants and utility functions inside the anonymous namespace.
+#include "src/dsp/inverse_transform.inc"
+
+//------------------------------------------------------------------------------
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x4(const int32x4_t in[4],
+                                        int32x4_t out[4]) {
+  // in:
+  // 00 01 02 03
+  // 10 11 12 13
+  // 20 21 22 23
+  // 30 31 32 33
+
+  // 00 10 02 12   a.val[0]
+  // 01 11 03 13   a.val[1]
+  // 20 30 22 32   b.val[0]
+  // 21 31 23 33   b.val[1]
+  const int32x4x2_t a = vtrnq_s32(in[0], in[1]);
+  const int32x4x2_t b = vtrnq_s32(in[2], in[3]);
+  out[0] = vextq_s32(vextq_s32(a.val[0], a.val[0], 2), b.val[0], 2);
+  out[1] = vextq_s32(vextq_s32(a.val[1], a.val[1], 2), b.val[1], 2);
+  out[2] = vextq_s32(a.val[0], vextq_s32(b.val[0], b.val[0], 2), 2);
+  out[3] = vextq_s32(a.val[1], vextq_s32(b.val[1], b.val[1], 2), 2);
+  // out:
+  // 00 10 20 30
+  // 01 11 21 31
+  // 02 12 22 32
+  // 03 13 23 33
+}
+
+//------------------------------------------------------------------------------
+template <int store_count>
+LIBGAV1_ALWAYS_INLINE void StoreDst(int32_t* dst, int32_t stride, int32_t idx,
+                                    const int32x4_t* const s) {
+  assert(store_count % 4 == 0);
+  for (int i = 0; i < store_count; i += 4) {
+    vst1q_s32(&dst[i * stride + idx], s[i]);
+    vst1q_s32(&dst[(i + 1) * stride + idx], s[i + 1]);
+    vst1q_s32(&dst[(i + 2) * stride + idx], s[i + 2]);
+    vst1q_s32(&dst[(i + 3) * stride + idx], s[i + 3]);
+  }
+}
+
+template <int load_count>
+LIBGAV1_ALWAYS_INLINE void LoadSrc(const int32_t* src, int32_t stride,
+                                   int32_t idx, int32x4_t* x) {
+  assert(load_count % 4 == 0);
+  for (int i = 0; i < load_count; i += 4) {
+    x[i] = vld1q_s32(&src[i * stride + idx]);
+    x[i + 1] = vld1q_s32(&src[(i + 1) * stride + idx]);
+    x[i + 2] = vld1q_s32(&src[(i + 2) * stride + idx]);
+    x[i + 3] = vld1q_s32(&src[(i + 3) * stride + idx]);
+  }
+}
+
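+// The ButterflyRotation_* helpers below implement the planar rotation
+//   x = (a * cos128 - b * sin128 + 2048) >> 12
+//   y = (a * sin128 + b * cos128 + 2048) >> 12
+// where cos128/sin128 are cos(angle * pi / 128) and sin(angle * pi / 128)
+// in Q12, as produced by Cos128()/Sin128() from inverse_transform.inc. An
+// equivalent scalar sketch, for illustration only:
+//   x = RightShiftWithRounding(a * cos128 - b * sin128, 12);
+//   y = RightShiftWithRounding(a * sin128 + b * cos128, 12);
+// The _FirstIsZero/_SecondIsZero variants drop the products with a
+// known-zero input, saving one multiply-accumulate per output.
+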
+// Butterfly rotate 4 values.
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int32x4_t* a, int32x4_t* b,
+                                               const int angle,
+                                               const bool flip) {
+  const int32_t cos128 = Cos128(angle);
+  const int32_t sin128 = Sin128(angle);
+  const int32x4_t acc_x = vmulq_n_s32(*a, cos128);
+  const int32x4_t acc_y = vmulq_n_s32(*a, sin128);
+  // The max range for the input is 18 bits. The cos128/sin128 is 13 bits,
+  // which leaves 1 bit for the add/subtract. For 10bpp, x/y will fit in a 32
+  // bit lane.
+  const int32x4_t x0 = vmlsq_n_s32(acc_x, *b, sin128);
+  const int32x4_t y0 = vmlaq_n_s32(acc_y, *b, cos128);
+  const int32x4_t x = vrshrq_n_s32(x0, 12);
+  const int32x4_t y = vrshrq_n_s32(y0, 12);
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int32x4_t* a,
+                                                         int32x4_t* b,
+                                                         const int angle,
+                                                         const bool flip) {
+  const int32_t cos128 = Cos128(angle);
+  const int32_t sin128 = Sin128(angle);
+  assert(sin128 <= 0xfff);
+  const int32x4_t x0 = vmulq_n_s32(*b, -sin128);
+  const int32x4_t y0 = vmulq_n_s32(*b, cos128);
+  const int32x4_t x = vrshrq_n_s32(x0, 12);
+  const int32x4_t y = vrshrq_n_s32(y0, 12);
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int32x4_t* a,
+                                                          int32x4_t* b,
+                                                          const int angle,
+                                                          const bool flip) {
+  const int32_t cos128 = Cos128(angle);
+  const int32_t sin128 = Sin128(angle);
+  const int32x4_t x0 = vmulq_n_s32(*a, cos128);
+  const int32x4_t y0 = vmulq_n_s32(*a, sin128);
+  const int32x4_t x = vrshrq_n_s32(x0, 12);
+  const int32x4_t y = vrshrq_n_s32(y0, 12);
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void HadamardRotation(int32x4_t* a, int32x4_t* b,
+                                            bool flip) {
+  int32x4_t x, y;
+  if (flip) {
+    y = vqaddq_s32(*b, *a);
+    x = vqsubq_s32(*b, *a);
+  } else {
+    x = vqaddq_s32(*a, *b);
+    y = vqsubq_s32(*a, *b);
+  }
+  *a = x;
+  *b = y;
+}
+
+LIBGAV1_ALWAYS_INLINE void HadamardRotation(int32x4_t* a, int32x4_t* b,
+                                            bool flip, const int32x4_t* min,
+                                            const int32x4_t* max) {
+  int32x4_t x, y;
+  if (flip) {
+    y = vqaddq_s32(*b, *a);
+    x = vqsubq_s32(*b, *a);
+  } else {
+    x = vqaddq_s32(*a, *b);
+    y = vqsubq_s32(*a, *b);
+  }
+  *a = vmaxq_s32(vminq_s32(x, *max), *min);
+  *b = vmaxq_s32(vminq_s32(y, *max), *min);
+}
+
+using ButterflyRotationFunc = void (*)(int32x4_t* a, int32x4_t* b, int angle,
+                                       bool flip);
+
+//------------------------------------------------------------------------------
+// Discrete Cosine Transforms (DCT).
+
+template <int width>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height,
+                                     bool should_round, int row_shift) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  const int32x4_t v_src = vdupq_n_s32(dst[0]);
+  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
+  const int32x4_t v_src_round =
+      vqrdmulhq_n_s32(v_src, kTransformRowMultiplier << (31 - 12));
+  const int32x4_t s0 = vbslq_s32(v_mask, v_src_round, v_src);
+  const int32_t cos128 = Cos128(32);
+  const int32x4_t xy = vqrdmulhq_n_s32(s0, cos128 << (31 - 12));
+  // vqrshlq_s32 will shift right if shift value is negative.
+  const int32x4_t xy_shifted = vqrshlq_s32(xy, vdupq_n_s32(-row_shift));
+  // Clamp result to signed 16 bits.
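+  // vqmovn_s32 saturates each lane to [-32768, 32767] and vmovl_s16 widens
+  // the result back to 32 bits, so the value stays in a 32-bit lane but is
+  // clamped to the signed-16-bit range. The same narrow/widen idiom is used
+  // throughout this file.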
+  const int32x4_t result = vmovl_s16(vqmovn_s32(xy_shifted));
+  if (width == 4) {
+    vst1q_s32(dst, result);
+  } else {
+    for (int i = 0; i < width; i += 4) {
+      vst1q_s32(dst, result);
+      dst += 4;
+    }
+  }
+  return true;
+}
+
+template <int height>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height,
+                                           int width) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  const int32_t cos128 = Cos128(32);
+
+  // Calculate dc values for first row.
+  if (width == 4) {
+    const int32x4_t v_src = vld1q_s32(dst);
+    const int32x4_t xy = vqrdmulhq_n_s32(v_src, cos128 << (31 - 12));
+    vst1q_s32(dst, xy);
+  } else {
+    int i = 0;
+    do {
+      const int32x4_t v_src = vld1q_s32(&dst[i]);
+      const int32x4_t xy = vqrdmulhq_n_s32(v_src, cos128 << (31 - 12));
+      vst1q_s32(&dst[i], xy);
+      i += 4;
+    } while (i < width);
+  }
+
+  // Copy first row to the rest of the block.
+  for (int y = 1; y < height; ++y) {
+    memcpy(&dst[y * width], dst, width * sizeof(dst[0]));
+  }
+  return true;
+}
+
+template <ButterflyRotationFunc butterfly_rotation,
+          bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct4Stages(int32x4_t* s, const int32x4_t* min,
+                                      const int32x4_t* max,
+                                      const bool is_last_stage) {
+  // stage 12.
+  if (is_fast_butterfly) {
+    ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true);
+    ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false);
+  } else {
+    butterfly_rotation(&s[0], &s[1], 32, true);
+    butterfly_rotation(&s[2], &s[3], 48, false);
+  }
+
+  // stage 17.
+  if (is_last_stage) {
+    HadamardRotation(&s[0], &s[3], false);
+    HadamardRotation(&s[1], &s[2], false);
+  } else {
+    HadamardRotation(&s[0], &s[3], false, min, max);
+    HadamardRotation(&s[1], &s[2], false, min, max);
+  }
+}
+
+template <ButterflyRotationFunc butterfly_rotation>
+LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, int32_t step, bool is_row,
+                                     int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  // When |is_row| is true, set range to the row range, otherwise, set to the
+  // column range.
+  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
+  const int32x4_t min = vdupq_n_s32(-(1 << range));
+  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
+  int32x4_t s[4], x[4];
+
+  LoadSrc<4>(dst, step, 0, x);
+  if (is_row) {
+    Transpose4x4(x, x);
+  }
+
+  // stage 1.
+  // kBitReverseLookup 0, 2, 1, 3
+  s[0] = x[0];
+  s[1] = x[2];
+  s[2] = x[1];
+  s[3] = x[3];
+
+  Dct4Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/true);
+
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    for (int i = 0; i < 4; ++i) {
+      s[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(s[i], v_row_shift)));
+    }
+    Transpose4x4(s, s);
+  }
+  StoreDst<4>(dst, step, 0, s);
+}
+
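+// Row passes clamp intermediate values to an 18-bit signed range
+// (1 << (kBitdepth10 + 7)), column passes to a 16-bit signed range
+// (1 << 15); this matches the "max range for the input is 18 bits" note in
+// ButterflyRotation_4 above. The last stage of each transform skips the
+// min/max clamp since its outputs are immediately narrowed or stored.
+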
+template <ButterflyRotationFunc butterfly_rotation,
+          bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct8Stages(int32x4_t* s, const int32x4_t* min,
+                                      const int32x4_t* max,
+                                      const bool is_last_stage) {
+  // stage 8.
+  if (is_fast_butterfly) {
+    ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false);
+    ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false);
+  } else {
+    butterfly_rotation(&s[4], &s[7], 56, false);
+    butterfly_rotation(&s[5], &s[6], 24, false);
+  }
+
+  // stage 13.
+  HadamardRotation(&s[4], &s[5], false, min, max);
+  HadamardRotation(&s[6], &s[7], true, min, max);
+
+  // stage 18.
+  butterfly_rotation(&s[6], &s[5], 32, true);
+
+  // stage 22.
+  if (is_last_stage) {
+    HadamardRotation(&s[0], &s[7], false);
+    HadamardRotation(&s[1], &s[6], false);
+    HadamardRotation(&s[2], &s[5], false);
+    HadamardRotation(&s[3], &s[4], false);
+  } else {
+    HadamardRotation(&s[0], &s[7], false, min, max);
+    HadamardRotation(&s[1], &s[6], false, min, max);
+    HadamardRotation(&s[2], &s[5], false, min, max);
+    HadamardRotation(&s[3], &s[4], false, min, max);
+  }
+}
+
+// Process dct8 rows or columns, depending on the |is_row| flag.
+template <ButterflyRotationFunc butterfly_rotation>
+LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, int32_t step, bool is_row,
+                                     int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
+  const int32x4_t min = vdupq_n_s32(-(1 << range));
+  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
+  int32x4_t s[8], x[8];
+
+  if (is_row) {
+    LoadSrc<4>(dst, step, 0, &x[0]);
+    LoadSrc<4>(dst, step, 4, &x[4]);
+    Transpose4x4(&x[0], &x[0]);
+    Transpose4x4(&x[4], &x[4]);
+  } else {
+    LoadSrc<8>(dst, step, 0, &x[0]);
+  }
+
+  // stage 1.
+  // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7,
+  s[0] = x[0];
+  s[1] = x[4];
+  s[2] = x[2];
+  s[3] = x[6];
+  s[4] = x[1];
+  s[5] = x[5];
+  s[6] = x[3];
+  s[7] = x[7];
+
+  Dct4Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/false);
+  Dct8Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/true);
+
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    for (int i = 0; i < 8; ++i) {
+      s[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(s[i], v_row_shift)));
+    }
+    Transpose4x4(&s[0], &s[0]);
+    Transpose4x4(&s[4], &s[4]);
+    StoreDst<4>(dst, step, 0, &s[0]);
+    StoreDst<4>(dst, step, 4, &s[4]);
+  } else {
+    StoreDst<8>(dst, step, 0, &s[0]);
+  }
+}
+
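+// The wider DCTs are built recursively from these stage helpers: after the
+// bit-reverse permutation, the first half of |s| runs through the narrower
+// Dct*Stages and the second half through the next DctNStages, with
+// is_last_stage set only for the widest stage. Dct16_NEON, Dct32_NEON and
+// Dct64_NEON below all follow this pattern.
+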
+template <ButterflyRotationFunc butterfly_rotation,
+          bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct16Stages(int32x4_t* s, const int32x4_t* min,
+                                       const int32x4_t* max,
+                                       const bool is_last_stage) {
+  // stage 5.
+  if (is_fast_butterfly) {
+    ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false);
+    ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false);
+    ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false);
+    ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false);
+  } else {
+    butterfly_rotation(&s[8], &s[15], 60, false);
+    butterfly_rotation(&s[9], &s[14], 28, false);
+    butterfly_rotation(&s[10], &s[13], 44, false);
+    butterfly_rotation(&s[11], &s[12], 12, false);
+  }
+
+  // stage 9.
+  HadamardRotation(&s[8], &s[9], false, min, max);
+  HadamardRotation(&s[10], &s[11], true, min, max);
+  HadamardRotation(&s[12], &s[13], false, min, max);
+  HadamardRotation(&s[14], &s[15], true, min, max);
+
+  // stage 14.
+  butterfly_rotation(&s[14], &s[9], 48, true);
+  butterfly_rotation(&s[13], &s[10], 112, true);
+
+  // stage 19.
+  HadamardRotation(&s[8], &s[11], false, min, max);
+  HadamardRotation(&s[9], &s[10], false, min, max);
+  HadamardRotation(&s[12], &s[15], true, min, max);
+  HadamardRotation(&s[13], &s[14], true, min, max);
+
+  // stage 23.
+  butterfly_rotation(&s[13], &s[10], 32, true);
+  butterfly_rotation(&s[12], &s[11], 32, true);
+
+  // stage 26.
+  if (is_last_stage) {
+    HadamardRotation(&s[0], &s[15], false);
+    HadamardRotation(&s[1], &s[14], false);
+    HadamardRotation(&s[2], &s[13], false);
+    HadamardRotation(&s[3], &s[12], false);
+    HadamardRotation(&s[4], &s[11], false);
+    HadamardRotation(&s[5], &s[10], false);
+    HadamardRotation(&s[6], &s[9], false);
+    HadamardRotation(&s[7], &s[8], false);
+  } else {
+    HadamardRotation(&s[0], &s[15], false, min, max);
+    HadamardRotation(&s[1], &s[14], false, min, max);
+    HadamardRotation(&s[2], &s[13], false, min, max);
+    HadamardRotation(&s[3], &s[12], false, min, max);
+    HadamardRotation(&s[4], &s[11], false, min, max);
+    HadamardRotation(&s[5], &s[10], false, min, max);
+    HadamardRotation(&s[6], &s[9], false, min, max);
+    HadamardRotation(&s[7], &s[8], false, min, max);
+  }
+}
+
+// Process dct16 rows or columns, depending on the |is_row| flag.
+template <ButterflyRotationFunc butterfly_rotation>
+LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, int32_t step, bool is_row,
+                                      int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
+  const int32x4_t min = vdupq_n_s32(-(1 << range));
+  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
+  int32x4_t s[16], x[16];
+
+  if (is_row) {
+    for (int idx = 0; idx < 16; idx += 8) {
+      LoadSrc<4>(dst, step, idx, &x[idx]);
+      LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
+      Transpose4x4(&x[idx], &x[idx]);
+      Transpose4x4(&x[idx + 4], &x[idx + 4]);
+    }
+  } else {
+    LoadSrc<16>(dst, step, 0, &x[0]);
+  }
+
+  // stage 1
+  // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15,
+  s[0] = x[0];
+  s[1] = x[8];
+  s[2] = x[4];
+  s[3] = x[12];
+  s[4] = x[2];
+  s[5] = x[10];
+  s[6] = x[6];
+  s[7] = x[14];
+  s[8] = x[1];
+  s[9] = x[9];
+  s[10] = x[5];
+  s[11] = x[13];
+  s[12] = x[3];
+  s[13] = x[11];
+  s[14] = x[7];
+  s[15] = x[15];
+
+  Dct4Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/false);
+  Dct8Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/false);
+  Dct16Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/true);
+
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    for (int i = 0; i < 16; ++i) {
+      s[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(s[i], v_row_shift)));
+    }
+    for (int idx = 0; idx < 16; idx += 8) {
+      Transpose4x4(&s[idx], &s[idx]);
+      Transpose4x4(&s[idx + 4], &s[idx + 4]);
+      StoreDst<4>(dst, step, idx, &s[idx]);
+      StoreDst<4>(dst, step, idx + 4, &s[idx + 4]);
+    }
+  } else {
+    StoreDst<16>(dst, step, 0, &s[0]);
+  }
+}
+
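+// Row passes operate on 4x4 tiles: a NEON q register holds four 32-bit
+// coefficients, so four rows of four columns are loaded, transposed with
+// Transpose4x4, run through the 1-D transform, and transposed back before
+// the store.
+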
+template <ButterflyRotationFunc butterfly_rotation,
+          bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct32Stages(int32x4_t* s, const int32x4_t* min,
+                                       const int32x4_t* max,
+                                       const bool is_last_stage) {
+  // stage 3
+  if (is_fast_butterfly) {
+    ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false);
+    ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false);
+    ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false);
+    ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false);
+    ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false);
+    ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false);
+    ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false);
+    ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false);
+  } else {
+    butterfly_rotation(&s[16], &s[31], 62, false);
+    butterfly_rotation(&s[17], &s[30], 30, false);
+    butterfly_rotation(&s[18], &s[29], 46, false);
+    butterfly_rotation(&s[19], &s[28], 14, false);
+    butterfly_rotation(&s[20], &s[27], 54, false);
+    butterfly_rotation(&s[21], &s[26], 22, false);
+    butterfly_rotation(&s[22], &s[25], 38, false);
+    butterfly_rotation(&s[23], &s[24], 6, false);
+  }
+
+  // stage 6.
+  HadamardRotation(&s[16], &s[17], false, min, max);
+  HadamardRotation(&s[18], &s[19], true, min, max);
+  HadamardRotation(&s[20], &s[21], false, min, max);
+  HadamardRotation(&s[22], &s[23], true, min, max);
+  HadamardRotation(&s[24], &s[25], false, min, max);
+  HadamardRotation(&s[26], &s[27], true, min, max);
+  HadamardRotation(&s[28], &s[29], false, min, max);
+  HadamardRotation(&s[30], &s[31], true, min, max);
+
+  // stage 10.
+  butterfly_rotation(&s[30], &s[17], 24 + 32, true);
+  butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true);
+  butterfly_rotation(&s[26], &s[21], 24, true);
+  butterfly_rotation(&s[25], &s[22], 24 + 64, true);
+
+  // stage 15.
+  HadamardRotation(&s[16], &s[19], false, min, max);
+  HadamardRotation(&s[17], &s[18], false, min, max);
+  HadamardRotation(&s[20], &s[23], true, min, max);
+  HadamardRotation(&s[21], &s[22], true, min, max);
+  HadamardRotation(&s[24], &s[27], false, min, max);
+  HadamardRotation(&s[25], &s[26], false, min, max);
+  HadamardRotation(&s[28], &s[31], true, min, max);
+  HadamardRotation(&s[29], &s[30], true, min, max);
+
+  // stage 20.
+  butterfly_rotation(&s[29], &s[18], 48, true);
+  butterfly_rotation(&s[28], &s[19], 48, true);
+  butterfly_rotation(&s[27], &s[20], 48 + 64, true);
+  butterfly_rotation(&s[26], &s[21], 48 + 64, true);
+
+  // stage 24.
+  HadamardRotation(&s[16], &s[23], false, min, max);
+  HadamardRotation(&s[17], &s[22], false, min, max);
+  HadamardRotation(&s[18], &s[21], false, min, max);
+  HadamardRotation(&s[19], &s[20], false, min, max);
+  HadamardRotation(&s[24], &s[31], true, min, max);
+  HadamardRotation(&s[25], &s[30], true, min, max);
+  HadamardRotation(&s[26], &s[29], true, min, max);
+  HadamardRotation(&s[27], &s[28], true, min, max);
+
+  // stage 27.
+  butterfly_rotation(&s[27], &s[20], 32, true);
+  butterfly_rotation(&s[26], &s[21], 32, true);
+  butterfly_rotation(&s[25], &s[22], 32, true);
+  butterfly_rotation(&s[24], &s[23], 32, true);
+
+  // stage 29.
+  if (is_last_stage) {
+    HadamardRotation(&s[0], &s[31], false);
+    HadamardRotation(&s[1], &s[30], false);
+    HadamardRotation(&s[2], &s[29], false);
+    HadamardRotation(&s[3], &s[28], false);
+    HadamardRotation(&s[4], &s[27], false);
+    HadamardRotation(&s[5], &s[26], false);
+    HadamardRotation(&s[6], &s[25], false);
+    HadamardRotation(&s[7], &s[24], false);
+    HadamardRotation(&s[8], &s[23], false);
+    HadamardRotation(&s[9], &s[22], false);
+    HadamardRotation(&s[10], &s[21], false);
+    HadamardRotation(&s[11], &s[20], false);
+    HadamardRotation(&s[12], &s[19], false);
+    HadamardRotation(&s[13], &s[18], false);
+    HadamardRotation(&s[14], &s[17], false);
+    HadamardRotation(&s[15], &s[16], false);
+  } else {
+    HadamardRotation(&s[0], &s[31], false, min, max);
+    HadamardRotation(&s[1], &s[30], false, min, max);
+    HadamardRotation(&s[2], &s[29], false, min, max);
+    HadamardRotation(&s[3], &s[28], false, min, max);
+    HadamardRotation(&s[4], &s[27], false, min, max);
+    HadamardRotation(&s[5], &s[26], false, min, max);
+    HadamardRotation(&s[6], &s[25], false, min, max);
+    HadamardRotation(&s[7], &s[24], false, min, max);
+    HadamardRotation(&s[8], &s[23], false, min, max);
+    HadamardRotation(&s[9], &s[22], false, min, max);
+    HadamardRotation(&s[10], &s[21], false, min, max);
+    HadamardRotation(&s[11], &s[20], false, min, max);
+    HadamardRotation(&s[12], &s[19], false, min, max);
+    HadamardRotation(&s[13], &s[18], false, min, max);
+    HadamardRotation(&s[14], &s[17], false, min, max);
+    HadamardRotation(&s[15], &s[16], false, min, max);
+  }
+}
+
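+// |row_shift| is a runtime value, so the row passes shift with
+// vqrshlq_s32(v, vdupq_n_s32(-row_shift)) -- vqrshlq_s32 shifts right with
+// rounding when the count is negative -- rather than vrshrq_n_s32, whose
+// count must be a compile-time constant.
+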
+// Process dct32 rows or columns, depending on the |is_row| flag.
+LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const int32_t step,
+                                      const bool is_row, int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
+  const int32x4_t min = vdupq_n_s32(-(1 << range));
+  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
+  int32x4_t s[32], x[32];
+
+  if (is_row) {
+    for (int idx = 0; idx < 32; idx += 8) {
+      LoadSrc<4>(dst, step, idx, &x[idx]);
+      LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
+      Transpose4x4(&x[idx], &x[idx]);
+      Transpose4x4(&x[idx + 4], &x[idx + 4]);
+    }
+  } else {
+    LoadSrc<32>(dst, step, 0, &x[0]);
+  }
+
+  // stage 1
+  // kBitReverseLookup
+  // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30,
+  s[0] = x[0];
+  s[1] = x[16];
+  s[2] = x[8];
+  s[3] = x[24];
+  s[4] = x[4];
+  s[5] = x[20];
+  s[6] = x[12];
+  s[7] = x[28];
+  s[8] = x[2];
+  s[9] = x[18];
+  s[10] = x[10];
+  s[11] = x[26];
+  s[12] = x[6];
+  s[13] = x[22];
+  s[14] = x[14];
+  s[15] = x[30];
+
+  // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31,
+  s[16] = x[1];
+  s[17] = x[17];
+  s[18] = x[9];
+  s[19] = x[25];
+  s[20] = x[5];
+  s[21] = x[21];
+  s[22] = x[13];
+  s[23] = x[29];
+  s[24] = x[3];
+  s[25] = x[19];
+  s[26] = x[11];
+  s[27] = x[27];
+  s[28] = x[7];
+  s[29] = x[23];
+  s[30] = x[15];
+  s[31] = x[31];
+
+  Dct4Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/false);
+  Dct8Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/false);
+  Dct16Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/false);
+  Dct32Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/true);
+
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    for (int idx = 0; idx < 32; idx += 8) {
+      int32x4_t output[8];
+      Transpose4x4(&s[idx], &output[0]);
+      Transpose4x4(&s[idx + 4], &output[4]);
+      for (int i = 0; i < 8; ++i) {
+        output[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(output[i], v_row_shift)));
+      }
+      StoreDst<4>(dst, step, idx, &output[0]);
+      StoreDst<4>(dst, step, idx + 4, &output[4]);
+    }
+  } else {
+    StoreDst<32>(dst, step, 0, &s[0]);
+  }
+}
+
+void Dct64_NEON(void* dest, int32_t step, bool is_row, int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
+  const int32x4_t min = vdupq_n_s32(-(1 << range));
+  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
+  int32x4_t s[64], x[32];
+
+  if (is_row) {
+    // The last 32 values of every row are always zero if the |tx_width| is
+    // 64.
+    for (int idx = 0; idx < 32; idx += 8) {
+      LoadSrc<4>(dst, step, idx, &x[idx]);
+      LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
+      Transpose4x4(&x[idx], &x[idx]);
+      Transpose4x4(&x[idx + 4], &x[idx + 4]);
+    }
+  } else {
+    // The last 32 values of every column are always zero if the |tx_height| is
+    // 64.
+    LoadSrc<32>(dst, step, 0, &x[0]);
+  }
+
+  // stage 1
+  // kBitReverseLookup
+  // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60,
+  s[0] = x[0];
+  s[2] = x[16];
+  s[4] = x[8];
+  s[6] = x[24];
+  s[8] = x[4];
+  s[10] = x[20];
+  s[12] = x[12];
+  s[14] = x[28];
+
+  // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62,
+  s[16] = x[2];
+  s[18] = x[18];
+  s[20] = x[10];
+  s[22] = x[26];
+  s[24] = x[6];
+  s[26] = x[22];
+  s[28] = x[14];
+  s[30] = x[30];
+
+  // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
+  s[32] = x[1];
+  s[34] = x[17];
+  s[36] = x[9];
+  s[38] = x[25];
+  s[40] = x[5];
+  s[42] = x[21];
+  s[44] = x[13];
+  s[46] = x[29];
+
+  // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63
+  s[48] = x[3];
+  s[50] = x[19];
+  s[52] = x[11];
+  s[54] = x[27];
+  s[56] = x[7];
+  s[58] = x[23];
+  s[60] = x[15];
+  s[62] = x[31];
+
+  Dct4Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
+      s, &min, &max, /*is_last_stage=*/false);
+  Dct8Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
+      s, &min, &max, /*is_last_stage=*/false);
+  Dct16Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
+      s, &min, &max, /*is_last_stage=*/false);
+  Dct32Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
+      s, &min, &max, /*is_last_stage=*/false);
+
+  //-- start dct 64 stages
+  // stage 2.
+  ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false);
+  ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false);
+  ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false);
+  ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false);
+  ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false);
+  ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false);
+  ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false);
+  ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false);
+  ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false);
+  ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false);
+  ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false);
+  ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false);
+  ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false);
+  ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false);
+  ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false);
+  ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
+
+  // stage 4.
+  HadamardRotation(&s[32], &s[33], false, &min, &max);
+  HadamardRotation(&s[34], &s[35], true, &min, &max);
+  HadamardRotation(&s[36], &s[37], false, &min, &max);
+  HadamardRotation(&s[38], &s[39], true, &min, &max);
+  HadamardRotation(&s[40], &s[41], false, &min, &max);
+  HadamardRotation(&s[42], &s[43], true, &min, &max);
+  HadamardRotation(&s[44], &s[45], false, &min, &max);
+  HadamardRotation(&s[46], &s[47], true, &min, &max);
+  HadamardRotation(&s[48], &s[49], false, &min, &max);
+  HadamardRotation(&s[50], &s[51], true, &min, &max);
+  HadamardRotation(&s[52], &s[53], false, &min, &max);
+  HadamardRotation(&s[54], &s[55], true, &min, &max);
+  HadamardRotation(&s[56], &s[57], false, &min, &max);
+  HadamardRotation(&s[58], &s[59], true, &min, &max);
+  HadamardRotation(&s[60], &s[61], false, &min, &max);
+  HadamardRotation(&s[62], &s[63], true, &min, &max);
+
+  // stage 7.
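+  // From here on both inputs of each butterfly carry nonzero data, so the
+  // full ButterflyRotation_4 is used instead of the _FirstIsZero and
+  // _SecondIsZero shortcuts that applied at stage 2.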
+  ButterflyRotation_4(&s[62], &s[33], 60 - 0, true);
+  ButterflyRotation_4(&s[61], &s[34], 60 - 0 + 64, true);
+  ButterflyRotation_4(&s[58], &s[37], 60 - 32, true);
+  ButterflyRotation_4(&s[57], &s[38], 60 - 32 + 64, true);
+  ButterflyRotation_4(&s[54], &s[41], 60 - 16, true);
+  ButterflyRotation_4(&s[53], &s[42], 60 - 16 + 64, true);
+  ButterflyRotation_4(&s[50], &s[45], 60 - 48, true);
+  ButterflyRotation_4(&s[49], &s[46], 60 - 48 + 64, true);
+
+  // stage 11.
+  HadamardRotation(&s[32], &s[35], false, &min, &max);
+  HadamardRotation(&s[33], &s[34], false, &min, &max);
+  HadamardRotation(&s[36], &s[39], true, &min, &max);
+  HadamardRotation(&s[37], &s[38], true, &min, &max);
+  HadamardRotation(&s[40], &s[43], false, &min, &max);
+  HadamardRotation(&s[41], &s[42], false, &min, &max);
+  HadamardRotation(&s[44], &s[47], true, &min, &max);
+  HadamardRotation(&s[45], &s[46], true, &min, &max);
+  HadamardRotation(&s[48], &s[51], false, &min, &max);
+  HadamardRotation(&s[49], &s[50], false, &min, &max);
+  HadamardRotation(&s[52], &s[55], true, &min, &max);
+  HadamardRotation(&s[53], &s[54], true, &min, &max);
+  HadamardRotation(&s[56], &s[59], false, &min, &max);
+  HadamardRotation(&s[57], &s[58], false, &min, &max);
+  HadamardRotation(&s[60], &s[63], true, &min, &max);
+  HadamardRotation(&s[61], &s[62], true, &min, &max);
+
+  // stage 16.
+  ButterflyRotation_4(&s[61], &s[34], 56, true);
+  ButterflyRotation_4(&s[60], &s[35], 56, true);
+  ButterflyRotation_4(&s[59], &s[36], 56 + 64, true);
+  ButterflyRotation_4(&s[58], &s[37], 56 + 64, true);
+  ButterflyRotation_4(&s[53], &s[42], 56 - 32, true);
+  ButterflyRotation_4(&s[52], &s[43], 56 - 32, true);
+  ButterflyRotation_4(&s[51], &s[44], 56 - 32 + 64, true);
+  ButterflyRotation_4(&s[50], &s[45], 56 - 32 + 64, true);
+
+  // stage 21.
+  HadamardRotation(&s[32], &s[39], false, &min, &max);
+  HadamardRotation(&s[33], &s[38], false, &min, &max);
+  HadamardRotation(&s[34], &s[37], false, &min, &max);
+  HadamardRotation(&s[35], &s[36], false, &min, &max);
+  HadamardRotation(&s[40], &s[47], true, &min, &max);
+  HadamardRotation(&s[41], &s[46], true, &min, &max);
+  HadamardRotation(&s[42], &s[45], true, &min, &max);
+  HadamardRotation(&s[43], &s[44], true, &min, &max);
+  HadamardRotation(&s[48], &s[55], false, &min, &max);
+  HadamardRotation(&s[49], &s[54], false, &min, &max);
+  HadamardRotation(&s[50], &s[53], false, &min, &max);
+  HadamardRotation(&s[51], &s[52], false, &min, &max);
+  HadamardRotation(&s[56], &s[63], true, &min, &max);
+  HadamardRotation(&s[57], &s[62], true, &min, &max);
+  HadamardRotation(&s[58], &s[61], true, &min, &max);
+  HadamardRotation(&s[59], &s[60], true, &min, &max);
+
+  // stage 25.
+  ButterflyRotation_4(&s[59], &s[36], 48, true);
+  ButterflyRotation_4(&s[58], &s[37], 48, true);
+  ButterflyRotation_4(&s[57], &s[38], 48, true);
+  ButterflyRotation_4(&s[56], &s[39], 48, true);
+  ButterflyRotation_4(&s[55], &s[40], 112, true);
+  ButterflyRotation_4(&s[54], &s[41], 112, true);
+  ButterflyRotation_4(&s[53], &s[42], 112, true);
+  ButterflyRotation_4(&s[52], &s[43], 112, true);
+
+  // stage 28.
+  HadamardRotation(&s[32], &s[47], false, &min, &max);
+  HadamardRotation(&s[33], &s[46], false, &min, &max);
+  HadamardRotation(&s[34], &s[45], false, &min, &max);
+  HadamardRotation(&s[35], &s[44], false, &min, &max);
+  HadamardRotation(&s[36], &s[43], false, &min, &max);
+  HadamardRotation(&s[37], &s[42], false, &min, &max);
+  HadamardRotation(&s[38], &s[41], false, &min, &max);
+  HadamardRotation(&s[39], &s[40], false, &min, &max);
+  HadamardRotation(&s[48], &s[63], true, &min, &max);
+  HadamardRotation(&s[49], &s[62], true, &min, &max);
+  HadamardRotation(&s[50], &s[61], true, &min, &max);
+  HadamardRotation(&s[51], &s[60], true, &min, &max);
+  HadamardRotation(&s[52], &s[59], true, &min, &max);
+  HadamardRotation(&s[53], &s[58], true, &min, &max);
+  HadamardRotation(&s[54], &s[57], true, &min, &max);
+  HadamardRotation(&s[55], &s[56], true, &min, &max);
+
+  // stage 30.
+  ButterflyRotation_4(&s[55], &s[40], 32, true);
+  ButterflyRotation_4(&s[54], &s[41], 32, true);
+  ButterflyRotation_4(&s[53], &s[42], 32, true);
+  ButterflyRotation_4(&s[52], &s[43], 32, true);
+  ButterflyRotation_4(&s[51], &s[44], 32, true);
+  ButterflyRotation_4(&s[50], &s[45], 32, true);
+  ButterflyRotation_4(&s[49], &s[46], 32, true);
+  ButterflyRotation_4(&s[48], &s[47], 32, true);
+
+  // stage 31.
+  for (int i = 0; i < 32; i += 4) {
+    HadamardRotation(&s[i], &s[63 - i], false, &min, &max);
+    HadamardRotation(&s[i + 1], &s[63 - i - 1], false, &min, &max);
+    HadamardRotation(&s[i + 2], &s[63 - i - 2], false, &min, &max);
+    HadamardRotation(&s[i + 3], &s[63 - i - 3], false, &min, &max);
+  }
+  //-- end dct 64 stages
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    for (int idx = 0; idx < 64; idx += 8) {
+      int32x4_t output[8];
+      Transpose4x4(&s[idx], &output[0]);
+      Transpose4x4(&s[idx + 4], &output[4]);
+      for (int i = 0; i < 8; ++i) {
+        output[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(output[i], v_row_shift)));
+      }
+      StoreDst<4>(dst, step, idx, &output[0]);
+      StoreDst<4>(dst, step, idx + 4, &output[4]);
+    }
+  } else {
+    StoreDst<64>(dst, step, 0, &s[0]);
+  }
+}
+
+//------------------------------------------------------------------------------
+// Asymmetric Discrete Sine Transforms (ADST).
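+// The 4-point ADST is a sum of sinpi products rather than a butterfly
+// network. kAdst4Multiplier (from inverse_transform.inc) appears to hold
+// round(sin(k * pi / 9) * 2 / 3 * sqrt(2) * (1 << 12)) for k = 1..4, i.e.
+// 1321, 2482, 3344, 3803 in Q12.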
+LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, int32_t step, bool is_row,
+                                      int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  int32x4_t s[8];
+  int32x4_t x[4];
+
+  LoadSrc<4>(dst, step, 0, x);
+  if (is_row) {
+    Transpose4x4(x, x);
+  }
+
+  // stage 1.
+  s[5] = vmulq_n_s32(x[3], kAdst4Multiplier[1]);
+  s[6] = vmulq_n_s32(x[3], kAdst4Multiplier[3]);
+
+  // stage 2.
+  const int32x4_t a7 = vsubq_s32(x[0], x[2]);
+  const int32x4_t b7 = vaddq_s32(a7, x[3]);
+
+  // stage 3.
+  s[0] = vmulq_n_s32(x[0], kAdst4Multiplier[0]);
+  s[1] = vmulq_n_s32(x[0], kAdst4Multiplier[1]);
+  // s[0] = s[0] + s[3]
+  s[0] = vmlaq_n_s32(s[0], x[2], kAdst4Multiplier[3]);
+  // s[1] = s[1] - s[4]
+  s[1] = vmlsq_n_s32(s[1], x[2], kAdst4Multiplier[0]);
+
+  s[3] = vmulq_n_s32(x[1], kAdst4Multiplier[2]);
+  s[2] = vmulq_n_s32(b7, kAdst4Multiplier[2]);
+
+  // stage 4.
+  s[0] = vaddq_s32(s[0], s[5]);
+  s[1] = vsubq_s32(s[1], s[6]);
+
+  // stages 5 and 6.
+  const int32x4_t x0 = vaddq_s32(s[0], s[3]);
+  const int32x4_t x1 = vaddq_s32(s[1], s[3]);
+  const int32x4_t x3_a = vaddq_s32(s[0], s[1]);
+  const int32x4_t x3 = vsubq_s32(x3_a, s[3]);
+  x[0] = vrshrq_n_s32(x0, 12);
+  x[1] = vrshrq_n_s32(x1, 12);
+  x[2] = vrshrq_n_s32(s[2], 12);
+  x[3] = vrshrq_n_s32(x3, 12);
+
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    x[0] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[0], v_row_shift)));
+    x[1] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[1], v_row_shift)));
+    x[2] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[2], v_row_shift)));
+    x[3] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[3], v_row_shift)));
+    Transpose4x4(x, x);
+  }
+  StoreDst<4>(dst, step, 0, x);
+}
+
+alignas(16) constexpr int32_t kAdst4DcOnlyMultiplier[4] = {1321, 2482, 3344,
+                                                           2482};
+
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height,
+                                       bool should_round, int row_shift) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  int32x4_t s[2];
+
+  const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
+  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
+  const int32x4_t v_src0_round =
+      vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
+
+  const int32x4_t v_src = vbslq_s32(v_mask, v_src0_round, v_src0);
+  const int32x4_t kAdst4DcOnlyMultipliers = vld1q_s32(kAdst4DcOnlyMultiplier);
+  s[1] = vdupq_n_s32(0);
+
+  // s0*k0 s0*k1 s0*k2 s0*k1
+  s[0] = vmulq_s32(kAdst4DcOnlyMultipliers, v_src);
+  // 0 0 0 s0*k0
+  s[1] = vextq_s32(s[1], s[0], 1);
+
+  const int32x4_t x3 = vaddq_s32(s[0], s[1]);
+  const int32x4_t dst_0 = vrshrq_n_s32(x3, 12);
+
+  // vqrshlq_s32 will shift right if shift value is negative.
+  vst1q_s32(dst,
+            vmovl_s16(vqmovn_s32(vqrshlq_s32(dst_0, vdupq_n_s32(-row_shift)))));
+
+  return true;
+}
+
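+// The DC-only path above uses the identity sin(4 * pi / 9) =
+// sin(pi / 9) + sin(2 * pi / 9) (3803 = 1321 + 2482 in the Q12 constants):
+// vextq_s32 shifts s0 * 1321 into the last lane so that one vector add
+// produces output 3, and all four outputs come from a single multiply.
+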
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest,
+                                             int adjusted_tx_height,
+                                             int width) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  int32x4_t s[4];
+
+  int i = 0;
+  do {
+    const int32x4_t v_src = vld1q_s32(&dst[i]);
+
+    s[0] = vmulq_n_s32(v_src, kAdst4Multiplier[0]);
+    s[1] = vmulq_n_s32(v_src, kAdst4Multiplier[1]);
+    s[2] = vmulq_n_s32(v_src, kAdst4Multiplier[2]);
+
+    const int32x4_t x0 = s[0];
+    const int32x4_t x1 = s[1];
+    const int32x4_t x2 = s[2];
+    const int32x4_t x3 = vaddq_s32(s[0], s[1]);
+    const int32x4_t dst_0 = vrshrq_n_s32(x0, 12);
+    const int32x4_t dst_1 = vrshrq_n_s32(x1, 12);
+    const int32x4_t dst_2 = vrshrq_n_s32(x2, 12);
+    const int32x4_t dst_3 = vrshrq_n_s32(x3, 12);
+
+    vst1q_s32(&dst[i], dst_0);
+    vst1q_s32(&dst[i + width * 1], dst_1);
+    vst1q_s32(&dst[i + width * 2], dst_2);
+    vst1q_s32(&dst[i + width * 3], dst_3);
+
+    i += 4;
+  } while (i < width);
+
+  return true;
+}
+
+template <ButterflyRotationFunc butterfly_rotation>
+LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, int32_t step, bool is_row,
+                                      int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
+  const int32x4_t min = vdupq_n_s32(-(1 << range));
+  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
+  int32x4_t s[8], x[8];
+
+  if (is_row) {
+    LoadSrc<4>(dst, step, 0, &x[0]);
+    LoadSrc<4>(dst, step, 4, &x[4]);
+    Transpose4x4(&x[0], &x[0]);
+    Transpose4x4(&x[4], &x[4]);
+  } else {
+    LoadSrc<8>(dst, step, 0, &x[0]);
+  }
+
+  // stage 1.
+  s[0] = x[7];
+  s[1] = x[0];
+  s[2] = x[5];
+  s[3] = x[2];
+  s[4] = x[3];
+  s[5] = x[4];
+  s[6] = x[1];
+  s[7] = x[6];
+
+  // stage 2.
+  butterfly_rotation(&s[0], &s[1], 60 - 0, true);
+  butterfly_rotation(&s[2], &s[3], 60 - 16, true);
+  butterfly_rotation(&s[4], &s[5], 60 - 32, true);
+  butterfly_rotation(&s[6], &s[7], 60 - 48, true);
+
+  // stage 3.
+  HadamardRotation(&s[0], &s[4], false, &min, &max);
+  HadamardRotation(&s[1], &s[5], false, &min, &max);
+  HadamardRotation(&s[2], &s[6], false, &min, &max);
+  HadamardRotation(&s[3], &s[7], false, &min, &max);
+
+  // stage 4.
+  butterfly_rotation(&s[4], &s[5], 48 - 0, true);
+  butterfly_rotation(&s[7], &s[6], 48 - 32, true);
+
+  // stage 5.
+  HadamardRotation(&s[0], &s[2], false, &min, &max);
+  HadamardRotation(&s[4], &s[6], false, &min, &max);
+  HadamardRotation(&s[1], &s[3], false, &min, &max);
+  HadamardRotation(&s[5], &s[7], false, &min, &max);
+
+  // stage 6.
+  butterfly_rotation(&s[2], &s[3], 32, true);
+  butterfly_rotation(&s[6], &s[7], 32, true);
+
+  // stage 7.
+  x[0] = s[0];
+  x[1] = vqnegq_s32(s[4]);
+  x[2] = s[6];
+  x[3] = vqnegq_s32(s[2]);
+  x[4] = s[3];
+  x[5] = vqnegq_s32(s[7]);
+  x[6] = s[5];
+  x[7] = vqnegq_s32(s[1]);
+
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    for (int i = 0; i < 8; ++i) {
+      x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], v_row_shift)));
+    }
+    Transpose4x4(&x[0], &x[0]);
+    Transpose4x4(&x[4], &x[4]);
+    StoreDst<4>(dst, step, 0, &x[0]);
+    StoreDst<4>(dst, step, 4, &x[4]);
+  } else {
+    StoreDst<8>(dst, step, 0, &x[0]);
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height,
+                                       bool should_round, int row_shift) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  int32x4_t s[8];
+
+  const int32x4_t v_src = vdupq_n_s32(dst[0]);
+  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
+  const int32x4_t v_src_round =
+      vqrdmulhq_n_s32(v_src, kTransformRowMultiplier << (31 - 12));
+  // stage 1.
+  s[1] = vbslq_s32(v_mask, v_src_round, v_src);
+
+  // stage 2.
+  ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+  // stage 3.
+  s[4] = s[0];
+  s[5] = s[1];
+
+  // stage 4.
+  ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+  // stage 5.
+  s[2] = s[0];
+  s[3] = s[1];
+  s[6] = s[4];
+  s[7] = s[5];
+
+  // stage 6.
+  ButterflyRotation_4(&s[2], &s[3], 32, true);
+  ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+  // stage 7.
+  int32x4_t x[8];
+  x[0] = s[0];
+  x[1] = vqnegq_s32(s[4]);
+  x[2] = s[6];
+  x[3] = vqnegq_s32(s[2]);
+  x[4] = s[3];
+  x[5] = vqnegq_s32(s[7]);
+  x[6] = s[5];
+  x[7] = vqnegq_s32(s[1]);
+
+  for (int i = 0; i < 8; ++i) {
+    // vqrshlq_s32 will shift right if shift value is negative.
+    x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], vdupq_n_s32(-row_shift))));
+    vst1q_lane_s32(&dst[i], x[i], 0);
+  }
+
+  return true;
+}
+
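+// The column variant below runs the same dataflow for each group of four
+// columns, but without the row rounding and row shift: column passes keep
+// full 32-bit results and leave the final rounding to the frame store.
+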
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest,
+                                             int adjusted_tx_height,
+                                             int width) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  int32x4_t s[8];
+
+  int i = 0;
+  do {
+    const int32x4_t v_src = vld1q_s32(dst);
+    // stage 1.
+    s[1] = v_src;
+
+    // stage 2.
+    ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+    // stage 3.
+    s[4] = s[0];
+    s[5] = s[1];
+
+    // stage 4.
+    ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+    // stage 5.
+    s[2] = s[0];
+    s[3] = s[1];
+    s[6] = s[4];
+    s[7] = s[5];
+
+    // stage 6.
+    ButterflyRotation_4(&s[2], &s[3], 32, true);
+    ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+    // stage 7.
+    int32x4_t x[8];
+    x[0] = s[0];
+    x[1] = vqnegq_s32(s[4]);
+    x[2] = s[6];
+    x[3] = vqnegq_s32(s[2]);
+    x[4] = s[3];
+    x[5] = vqnegq_s32(s[7]);
+    x[6] = s[5];
+    x[7] = vqnegq_s32(s[1]);
+
+    for (int j = 0; j < 8; ++j) {
+      vst1q_s32(&dst[j * width], x[j]);
+    }
+    i += 4;
+    dst += 4;
+  } while (i < width);
+
+  return true;
+}
+
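+// Adst16 repeats the pattern at twice the width: a chain of butterfly
+// rotations (stage 2 uses the odd angles 62 - 8k) interleaved with three
+// rounds of clamped Hadamard mixing, ending in a sign-alternating output
+// permutation done with vqnegq_s32.
+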
+template <ButterflyRotationFunc butterfly_rotation>
+LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, int32_t step, bool is_row,
+                                       int row_shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
+  const int32x4_t min = vdupq_n_s32(-(1 << range));
+  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
+  int32x4_t s[16], x[16];
+
+  if (is_row) {
+    for (int idx = 0; idx < 16; idx += 8) {
+      LoadSrc<4>(dst, step, idx, &x[idx]);
+      LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
+      Transpose4x4(&x[idx], &x[idx]);
+      Transpose4x4(&x[idx + 4], &x[idx + 4]);
+    }
+  } else {
+    LoadSrc<16>(dst, step, 0, &x[0]);
+  }
+
+  // stage 1.
+  s[0] = x[15];
+  s[1] = x[0];
+  s[2] = x[13];
+  s[3] = x[2];
+  s[4] = x[11];
+  s[5] = x[4];
+  s[6] = x[9];
+  s[7] = x[6];
+  s[8] = x[7];
+  s[9] = x[8];
+  s[10] = x[5];
+  s[11] = x[10];
+  s[12] = x[3];
+  s[13] = x[12];
+  s[14] = x[1];
+  s[15] = x[14];
+
+  // stage 2.
+  butterfly_rotation(&s[0], &s[1], 62 - 0, true);
+  butterfly_rotation(&s[2], &s[3], 62 - 8, true);
+  butterfly_rotation(&s[4], &s[5], 62 - 16, true);
+  butterfly_rotation(&s[6], &s[7], 62 - 24, true);
+  butterfly_rotation(&s[8], &s[9], 62 - 32, true);
+  butterfly_rotation(&s[10], &s[11], 62 - 40, true);
+  butterfly_rotation(&s[12], &s[13], 62 - 48, true);
+  butterfly_rotation(&s[14], &s[15], 62 - 56, true);
+
+  // stage 3.
+  HadamardRotation(&s[0], &s[8], false, &min, &max);
+  HadamardRotation(&s[1], &s[9], false, &min, &max);
+  HadamardRotation(&s[2], &s[10], false, &min, &max);
+  HadamardRotation(&s[3], &s[11], false, &min, &max);
+  HadamardRotation(&s[4], &s[12], false, &min, &max);
+  HadamardRotation(&s[5], &s[13], false, &min, &max);
+  HadamardRotation(&s[6], &s[14], false, &min, &max);
+  HadamardRotation(&s[7], &s[15], false, &min, &max);
+
+  // stage 4.
+  butterfly_rotation(&s[8], &s[9], 56 - 0, true);
+  butterfly_rotation(&s[13], &s[12], 8 + 0, true);
+  butterfly_rotation(&s[10], &s[11], 56 - 32, true);
+  butterfly_rotation(&s[15], &s[14], 8 + 32, true);
+
+  // stage 5.
+  HadamardRotation(&s[0], &s[4], false, &min, &max);
+  HadamardRotation(&s[8], &s[12], false, &min, &max);
+  HadamardRotation(&s[1], &s[5], false, &min, &max);
+  HadamardRotation(&s[9], &s[13], false, &min, &max);
+  HadamardRotation(&s[2], &s[6], false, &min, &max);
+  HadamardRotation(&s[10], &s[14], false, &min, &max);
+  HadamardRotation(&s[3], &s[7], false, &min, &max);
+  HadamardRotation(&s[11], &s[15], false, &min, &max);
+
+  // stage 6.
+  butterfly_rotation(&s[4], &s[5], 48 - 0, true);
+  butterfly_rotation(&s[12], &s[13], 48 - 0, true);
+  butterfly_rotation(&s[7], &s[6], 48 - 32, true);
+  butterfly_rotation(&s[15], &s[14], 48 - 32, true);
+
+  // stage 7.
+  HadamardRotation(&s[0], &s[2], false, &min, &max);
+  HadamardRotation(&s[4], &s[6], false, &min, &max);
+  HadamardRotation(&s[8], &s[10], false, &min, &max);
+  HadamardRotation(&s[12], &s[14], false, &min, &max);
+  HadamardRotation(&s[1], &s[3], false, &min, &max);
+  HadamardRotation(&s[5], &s[7], false, &min, &max);
+  HadamardRotation(&s[9], &s[11], false, &min, &max);
+  HadamardRotation(&s[13], &s[15], false, &min, &max);
+
+  // stage 8.
+  butterfly_rotation(&s[2], &s[3], 32, true);
+  butterfly_rotation(&s[6], &s[7], 32, true);
+  butterfly_rotation(&s[10], &s[11], 32, true);
+  butterfly_rotation(&s[14], &s[15], 32, true);
+
+  // stage 9.
+  x[0] = s[0];
+  x[1] = vqnegq_s32(s[8]);
+  x[2] = s[12];
+  x[3] = vqnegq_s32(s[4]);
+  x[4] = s[6];
+  x[5] = vqnegq_s32(s[14]);
+  x[6] = s[10];
+  x[7] = vqnegq_s32(s[2]);
+  x[8] = s[3];
+  x[9] = vqnegq_s32(s[11]);
+  x[10] = s[15];
+  x[11] = vqnegq_s32(s[7]);
+  x[12] = s[5];
+  x[13] = vqnegq_s32(s[13]);
+  x[14] = s[9];
+  x[15] = vqnegq_s32(s[1]);
+
+  if (is_row) {
+    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
+    for (int i = 0; i < 16; ++i) {
+      x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], v_row_shift)));
+    }
+    for (int idx = 0; idx < 16; idx += 8) {
+      Transpose4x4(&x[idx], &x[idx]);
+      Transpose4x4(&x[idx + 4], &x[idx + 4]);
+      StoreDst<4>(dst, step, idx, &x[idx]);
+      StoreDst<4>(dst, step, idx + 4, &x[idx + 4]);
+    }
+  } else {
+    StoreDst<16>(dst, step, 0, &x[0]);
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int32x4_t* s, int32x4_t* x) {
+  // stage 2.
+  ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);
+
+  // stage 3.
+  s[8] = s[0];
+  s[9] = s[1];
+
+  // stage 4.
+  ButterflyRotation_4(&s[8], &s[9], 56, true);
+
+  // stage 5.
+  s[4] = s[0];
+  s[12] = s[8];
+  s[5] = s[1];
+  s[13] = s[9];
+
+  // stage 6.
+  ButterflyRotation_4(&s[4], &s[5], 48, true);
+  ButterflyRotation_4(&s[12], &s[13], 48, true);
+
+  // stage 7.
+  s[2] = s[0];
+  s[6] = s[4];
+  s[10] = s[8];
+  s[14] = s[12];
+  s[3] = s[1];
+  s[7] = s[5];
+  s[11] = s[9];
+  s[15] = s[13];
+
+  // stage 8.
+  ButterflyRotation_4(&s[2], &s[3], 32, true);
+  ButterflyRotation_4(&s[6], &s[7], 32, true);
+  ButterflyRotation_4(&s[10], &s[11], 32, true);
+  ButterflyRotation_4(&s[14], &s[15], 32, true);
+
+  // stage 9.
+  x[0] = s[0];
+  x[1] = vqnegq_s32(s[8]);
+  x[2] = s[12];
+  x[3] = vqnegq_s32(s[4]);
+  x[4] = s[6];
+  x[5] = vqnegq_s32(s[14]);
+  x[6] = s[10];
+  x[7] = vqnegq_s32(s[2]);
+  x[8] = s[3];
+  x[9] = vqnegq_s32(s[11]);
+  x[10] = s[15];
+  x[11] = vqnegq_s32(s[7]);
+  x[12] = s[5];
+  x[13] = vqnegq_s32(s[13]);
+  x[14] = s[9];
+  x[15] = vqnegq_s32(s[1]);
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height,
+                                        bool should_round, int row_shift) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  int32x4_t s[16];
+  int32x4_t x[16];
+  const int32x4_t v_src = vdupq_n_s32(dst[0]);
+  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
+  const int32x4_t v_src_round =
+      vqrdmulhq_n_s32(v_src, kTransformRowMultiplier << (31 - 12));
+  // stage 1.
+  s[1] = vbslq_s32(v_mask, v_src_round, v_src);
+
+  Adst16DcOnlyInternal(s, x);
+
+  for (int i = 0; i < 16; ++i) {
+    // vqrshlq_s32 will shift right if shift value is negative.
+    x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], vdupq_n_s32(-row_shift))));
+    vst1q_lane_s32(&dst[i], x[i], 0);
+  }
+
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest,
+                                              int adjusted_tx_height,
+                                              int width) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  int i = 0;
+  do {
+    int32x4_t s[16];
+    int32x4_t x[16];
+    const int32x4_t v_src = vld1q_s32(dst);
+    // stage 1.
+    s[1] = v_src;
+
+    Adst16DcOnlyInternal(s, x);
+
+    for (int j = 0; j < 16; ++j) {
+      vst1q_s32(&dst[j * width], x[j]);
+    }
+    i += 4;
+    dst += 4;
+  } while (i < width);
+
+  return true;
+}
+
+//------------------------------------------------------------------------------
+// Identity Transforms.
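+// The identity transforms only rescale. Per inverse_transform.inc,
+// Identity4 multiplies by kIdentity4Multiplier (5793 ~= sqrt(2) * (1 << 12)),
+// Identity8 doubles the input, and Identity16 multiplies by
+// kIdentity16Multiplier (11586 = 2 * 5793, i.e. 2 * sqrt(2) in Q12).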
+
+LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, int32_t step, int shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+  const int32x4_t v_multiplier = vdupq_n_s32(kIdentity4Multiplier);
+  const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+  for (int i = 0; i < 4; ++i) {
+    const int32x4_t v_src = vld1q_s32(&dst[i * step]);
+    const int32x4_t v_src_mult_lo =
+        vmlaq_s32(v_dual_round, v_src, v_multiplier);
+    const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
+    vst1q_s32(&dst[i * step], vmovl_s16(vqmovn_s32(shift_lo)));
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height,
+                                           bool should_round, int tx_height) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
+  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
+  const int32x4_t v_src_round =
+      vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
+  const int32x4_t v_src = vbslq_s32(v_mask, v_src_round, v_src0);
+  const int shift = tx_height < 16 ? 0 : 1;
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+  const int32x4_t v_multiplier = vdupq_n_s32(kIdentity4Multiplier);
+  const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+  const int32x4_t v_src_mult_lo = vmlaq_s32(v_dual_round, v_src, v_multiplier);
+  const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
+  vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(dst_0)), 0);
+  return true;
+}
+
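+// IdentityColumnStoreToFrame fuses the column pass with the write-back to
+// the frame: scale, shift down by 4 + 12 (or a rounding shift by 4 for
+// identity8's doubled values), add the 10-bit prediction, then saturate
+// with vqmovun_s32 and clamp to (1 << kBitdepth10) - 1.
+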
+template <int identity_size>
+LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
+    Array2DView<uint16_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int32_t* source) {
+  static_assert(identity_size == 4 || identity_size == 8 ||
+                    identity_size == 16,
+                "Invalid identity_size.");
+  const int stride = frame.columns();
+  uint16_t* dst = frame[start_y] + start_x;
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (1 << 4)) << 11);
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+
+  if (tx_width == 4) {
+    int i = 0;
+    do {
+      int32x4x2_t v_src, v_dst_i, a, b;
+      v_src.val[0] = vld1q_s32(&source[i * 4]);
+      v_src.val[1] = vld1q_s32(&source[(i * 4) + 4]);
+      if (identity_size == 4) {
+        v_dst_i.val[0] =
+            vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
+        v_dst_i.val[1] =
+            vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
+        a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
+        a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
+      } else if (identity_size == 8) {
+        v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
+        v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
+        a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
+        a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
+      } else {  // identity_size == 16
+        v_dst_i.val[0] =
+            vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
+        v_dst_i.val[1] =
+            vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
+        a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
+        a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
+      }
+      uint16x4x2_t frame_data;
+      frame_data.val[0] = vld1_u16(dst);
+      frame_data.val[1] = vld1_u16(dst + stride);
+      b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
+      b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
+      vst1_u16(dst, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
+      vst1_u16(dst + stride, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
+      dst += stride << 1;
+      i += 2;
+    } while (i < tx_height);
+  } else {
+    int i = 0;
+    do {
+      const int row = i * tx_width;
+      int j = 0;
+      do {
+        int32x4x2_t v_src, v_dst_i, a, b;
+        v_src.val[0] = vld1q_s32(&source[row + j]);
+        v_src.val[1] = vld1q_s32(&source[row + j + 4]);
+        if (identity_size == 4) {
+          v_dst_i.val[0] =
+              vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
+          v_dst_i.val[1] =
+              vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
+          a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
+          a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
+        } else if (identity_size == 8) {
+          v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
+          v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
+          a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
+          a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
+        } else {  // identity_size == 16
+          v_dst_i.val[0] =
+              vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
+          v_dst_i.val[1] =
+              vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
+          a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
+          a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
+        }
+        uint16x4x2_t frame_data;
+        frame_data.val[0] = vld1_u16(dst + j);
+        frame_data.val[1] = vld1_u16(dst + j + 4);
+        b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
+        b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
+        vst1_u16(dst + j, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
+        vst1_u16(dst + j + 4, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
+        j += 8;
+      } while (j < tx_width);
+      dst += stride;
+    } while (++i < tx_height);
+  }
+}
+
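+// Identity4RowColumnStoreToFrame handles the case where both the row and
+// the column transform are identities, applying the two scalings back to
+// back in a single pass before adding the result to the frame.
+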
+LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
+    Array2DView<uint16_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int32_t* source) {
+  const int stride = frame.columns();
+  uint16_t* dst = frame[start_y] + start_x;
+  const int32x4_t v_round = vdupq_n_s32((1 + (0)) << 11);
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+
+  if (tx_width == 4) {
+    int i = 0;
+    do {
+      const int32x4_t v_src = vld1q_s32(&source[i * 4]);
+      const int32x4_t v_dst_row =
+          vshrq_n_s32(vmlaq_n_s32(v_round, v_src, kIdentity4Multiplier), 12);
+      const int32x4_t v_dst_col =
+          vmlaq_n_s32(v_round, v_dst_row, kIdentity4Multiplier);
+      const uint16x4_t frame_data = vld1_u16(dst);
+      const int32x4_t a = vrshrq_n_s32(v_dst_col, 4 + 12);
+      const int32x4_t b = vaddw_s16(a, vreinterpret_s16_u16(frame_data));
+      vst1_u16(dst, vmin_u16(vqmovun_s32(b), v_max_bitdepth));
+      dst += stride;
+    } while (++i < tx_height);
+  } else {
+    int i = 0;
+    do {
+      const int row = i * tx_width;
+      int j = 0;
+      do {
+        int32x4x2_t v_src, v_src_round, v_dst_row, v_dst_col, a, b;
+        v_src.val[0] = vld1q_s32(&source[row + j]);
+        v_src.val[1] = vld1q_s32(&source[row + j + 4]);
+        v_src_round.val[0] = vshrq_n_s32(
+            vmlaq_n_s32(v_round, v_src.val[0], kTransformRowMultiplier), 12);
+        v_src_round.val[1] = vshrq_n_s32(
+            vmlaq_n_s32(v_round, v_src.val[1], kTransformRowMultiplier), 12);
+        v_dst_row.val[0] = vqaddq_s32(v_src_round.val[0], v_src_round.val[0]);
+        v_dst_row.val[1] = vqaddq_s32(v_src_round.val[1], v_src_round.val[1]);
+        v_dst_col.val[0] =
+            vmlaq_n_s32(v_round, v_dst_row.val[0], kIdentity4Multiplier);
+        v_dst_col.val[1] =
+            vmlaq_n_s32(v_round, v_dst_row.val[1], kIdentity4Multiplier);
+        uint16x4x2_t frame_data;
+        frame_data.val[0] = vld1_u16(dst + j);
+        frame_data.val[1] = vld1_u16(dst + j + 4);
+        a.val[0] = vrshrq_n_s32(v_dst_col.val[0], 4 + 12);
+        a.val[1] = vrshrq_n_s32(v_dst_col.val[1], 4 + 12);
+        b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
+        b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
+        vst1_u16(dst + j, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
+        vst1_u16(dst + j + 4, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
+        j += 8;
+      } while (j < tx_width);
+      dst += stride;
+    } while (++i < tx_height);
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, int32_t step) {
+  auto* const dst = static_cast<int32_t*>(dest);
+
+  // When combining the identity8 multiplier with the row shift, the
+  // calculations for tx_height equal to 32 can be simplified from
+  // (((A * 2) + 2) >> 2) to ((A + 1) >> 1).
+  for (int i = 0; i < 4; ++i) {
+    const int32x4_t v_src_lo = vld1q_s32(&dst[i * step]);
+    const int32x4_t v_src_hi = vld1q_s32(&dst[(i * step) + 4]);
+    const int32x4_t a_lo = vrshrq_n_s32(v_src_lo, 1);
+    const int32x4_t a_hi = vrshrq_n_s32(v_src_hi, 1);
+    vst1q_s32(&dst[i * step], vmovl_s16(vqmovn_s32(a_lo)));
+    vst1q_s32(&dst[(i * step) + 4], vmovl_s16(vqmovn_s32(a_hi)));
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, int32_t step) {
+  auto* const dst = static_cast<int32_t*>(dest);
+
+  for (int i = 0; i < 4; ++i) {
+    const int32x4_t v_src_lo = vld1q_s32(&dst[i * step]);
+    const int32x4_t v_src_hi = vld1q_s32(&dst[(i * step) + 4]);
+    const int32x4_t v_srcx2_lo = vqaddq_s32(v_src_lo, v_src_lo);
+    const int32x4_t v_srcx2_hi = vqaddq_s32(v_src_hi, v_src_hi);
+    vst1q_s32(&dst[i * step], vmovl_s16(vqmovn_s32(v_srcx2_lo)));
+    vst1q_s32(&dst[(i * step) + 4], vmovl_s16(vqmovn_s32(v_srcx2_hi)));
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height,
+                                           bool should_round, int row_shift) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
+  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
+  const int32x4_t v_src_round =
+      vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
+  const int32x4_t v_src = vbslq_s32(v_mask, v_src_round, v_src0);
+  const int32x4_t v_srcx2 = vaddq_s32(v_src, v_src);
+  const int32x4_t dst_0 = vqrshlq_s32(v_srcx2, vdupq_n_s32(-row_shift));
+  vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(dst_0)), 0);
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, int32_t step,
+                                              int shift) {
+  auto* const dst = static_cast<int32_t*>(dest);
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+  const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+
+  for (int i = 0; i < 4; ++i) {
+    for (int j = 0; j < 2; ++j) {
+      int32x4x2_t v_src;
+      v_src.val[0] = vld1q_s32(&dst[i * step + j * 8]);
+      v_src.val[1] = vld1q_s32(&dst[i * step + j * 8 + 4]);
+      const int32x4_t v_src_mult_lo =
+          vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
+      const int32x4_t v_src_mult_hi =
+          vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
+      const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
+      const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
+      vst1q_s32(&dst[i * step + j * 8], vmovl_s16(vqmovn_s32(shift_lo)));
+      vst1q_s32(&dst[i * step + j * 8 + 4], vmovl_s16(vqmovn_s32(shift_hi)));
+    }
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height,
+                                            bool should_round, int shift) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
+  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
+  const int32x4_t v_src_round =
+      vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
+  const int32x4_t v_src = vbslq_s32(v_mask, v_src_round, v_src0);
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+  const int32x4_t v_src_mult_lo =
+      vmlaq_n_s32(v_dual_round, v_src, kIdentity16Multiplier);
+  const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, vdupq_n_s32(-(12 + shift)));
+  vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(dst_0)), 0);
+  return true;
+}
+
+//------------------------------------------------------------------------------
+// row/column transform loops
+
+template <int tx_height>
+LIBGAV1_ALWAYS_INLINE void FlipColumns(int32_t* source, int tx_width) {
+  if (tx_width >= 16) {
+    int i = 0;
+    do {
+      // 00 01 02 03
+      const int32x4_t a = vld1q_s32(&source[i]);
+      const int32x4_t b = vld1q_s32(&source[i + 4]);
+      const int32x4_t c = vld1q_s32(&source[i + 8]);
+      const int32x4_t d = vld1q_s32(&source[i + 12]);
+      // 01 00 03 02
+      const int32x4_t a_rev = vrev64q_s32(a);
+      const int32x4_t b_rev = vrev64q_s32(b);
+      const int32x4_t c_rev = vrev64q_s32(c);
+      const int32x4_t d_rev = vrev64q_s32(d);
+      // 03 02 01 00
+      vst1q_s32(&source[i], vextq_s32(d_rev, d_rev, 2));
+      vst1q_s32(&source[i + 4], vextq_s32(c_rev, c_rev, 2));
+      vst1q_s32(&source[i + 8], vextq_s32(b_rev, b_rev, 2));
+      vst1q_s32(&source[i + 12], vextq_s32(a_rev, a_rev, 2));
+      i += 16;
+    } while (i < tx_width * tx_height);
+  } else if (tx_width == 8) {
+    for (int i = 0; i < 8 * tx_height; i += 8) {
+      // 00 01 02 03
+      const int32x4_t a = vld1q_s32(&source[i]);
+      const int32x4_t b = vld1q_s32(&source[i + 4]);
+      // 01 00 03 02
+      const int32x4_t a_rev = vrev64q_s32(a);
+      const int32x4_t b_rev = vrev64q_s32(b);
+      // 03 02 01 00
+      vst1q_s32(&source[i], vextq_s32(b_rev, b_rev, 2));
+      vst1q_s32(&source[i + 4], vextq_s32(a_rev, a_rev, 2));
+    }
+  } else {
+    // Process two rows per iteration.
+ for (int i = 0; i < 4 * tx_height; i += 8) { + // 00 01 02 03 + const int32x4_t a = vld1q_s32(&source[i]); + const int32x4_t b = vld1q_s32(&source[i + 4]); + // 01 00 03 02 + const int32x4_t a_rev = vrev64q_s32(a); + const int32x4_t b_rev = vrev64q_s32(b); + // 03 02 01 00 + vst1q_s32(&source[i], vextq_s32(a_rev, a_rev, 2)); + vst1q_s32(&source[i + 4], vextq_s32(b_rev, b_rev, 2)); + } + } +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void ApplyRounding(int32_t* source, int num_rows) { + // Process two rows per iteration. + int i = 0; + do { + const int32x4_t a_lo = vld1q_s32(&source[i]); + const int32x4_t a_hi = vld1q_s32(&source[i + 4]); + const int32x4_t b_lo = + vqrdmulhq_n_s32(a_lo, kTransformRowMultiplier << (31 - 12)); + const int32x4_t b_hi = + vqrdmulhq_n_s32(a_hi, kTransformRowMultiplier << (31 - 12)); + vst1q_s32(&source[i], b_lo); + vst1q_s32(&source[i + 4], b_hi); + i += 8; + } while (i < tx_width * num_rows); +} + +template <int tx_width> +LIBGAV1_ALWAYS_INLINE void RowShift(int32_t* source, int num_rows, + int row_shift) { + // vqrshlq_s32 will shift right if shift value is negative. + row_shift = -row_shift; + + // Process two rows per iteration. + int i = 0; + do { + const int32x4_t residual0 = vld1q_s32(&source[i]); + const int32x4_t residual1 = vld1q_s32(&source[i + 4]); + vst1q_s32(&source[i], vqrshlq_s32(residual0, vdupq_n_s32(row_shift))); + vst1q_s32(&source[i + 4], vqrshlq_s32(residual1, vdupq_n_s32(row_shift))); + i += 8; + } while (i < tx_width * num_rows); +} + +template <int tx_height, bool enable_flip_rows = false> +LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound( + Array2DView<uint16_t> frame, const int start_x, const int start_y, + const int tx_width, const int32_t* source, TransformType tx_type) { + const bool flip_rows = + enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false; + const int stride = frame.columns(); + uint16_t* dst = frame[start_y] + start_x; + + if (tx_width == 4) { + for (int i = 0; i < tx_height; ++i) { + const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4; + const int32x4_t residual = vld1q_s32(&source[row]); + const uint16x4_t frame_data = vld1_u16(dst); + const int32x4_t a = vrshrq_n_s32(residual, 4); + const uint32x4_t b = vaddw_u16(vreinterpretq_u32_s32(a), frame_data); + const uint16x4_t d = vqmovun_s32(vreinterpretq_s32_u32(b)); + vst1_u16(dst, vmin_u16(d, vdup_n_u16((1 << kBitdepth10) - 1))); + dst += stride; + } + } else { + for (int i = 0; i < tx_height; ++i) { + const int y = start_y + i; + const int row = flip_rows ? 
(tx_height - i - 1) * tx_width : i * tx_width;
+      int j = 0;
+      do {
+        const int x = start_x + j;
+        const int32x4_t residual = vld1q_s32(&source[row + j]);
+        const int32x4_t residual_hi = vld1q_s32(&source[row + j + 4]);
+        const uint16x8_t frame_data = vld1q_u16(frame[y] + x);
+        const int32x4_t a = vrshrq_n_s32(residual, 4);
+        const int32x4_t a_hi = vrshrq_n_s32(residual_hi, 4);
+        const uint32x4_t b =
+            vaddw_u16(vreinterpretq_u32_s32(a), vget_low_u16(frame_data));
+        const uint32x4_t b_hi =
+            vaddw_u16(vreinterpretq_u32_s32(a_hi), vget_high_u16(frame_data));
+        const uint16x4_t d = vqmovun_s32(vreinterpretq_s32_u32(b));
+        const uint16x4_t d_hi = vqmovun_s32(vreinterpretq_s32_u32(b_hi));
+        vst1q_u16(frame[y] + x,
+                  vminq_u16(vcombine_u16(d, d_hi),
+                            vdupq_n_u16((1 << kBitdepth10) - 1)));
+        j += 8;
+      } while (j < tx_width);
+    }
+  }
+}
+
+void Dct4TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
+                               int adjusted_tx_height, void* src_buffer,
+                               int /*start_x*/, int /*start_y*/,
+                               void* /*dst_frame*/) {
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const int tx_height = kTransformHeight[tx_size];
+  const bool should_round = (tx_height == 8);
+  const int row_shift = static_cast<int>(tx_height == 16);
+
+  if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
+    return;
+  }
+
+  if (should_round) {
+    ApplyRounding<4>(src, adjusted_tx_height);
+  }
+
+  // Process 4 1d dct4 rows in parallel per iteration.
+  int i = adjusted_tx_height;
+  auto* data = src;
+  do {
+    Dct4_NEON<ButterflyRotation_4>(data, /*step=*/4, /*is_row=*/true,
+                                   row_shift);
+    data += 16;
+    i -= 4;
+  } while (i != 0);
+}
+
+void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+                                  int adjusted_tx_height, void* src_buffer,
+                                  int start_x, int start_y, void* dst_frame) {
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<4>(src, tx_width);
+  }
+
+  if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) {
+    // Process 4 1d dct4 columns in parallel per iteration.
+    int i = tx_width;
+    auto* data = src;
+    do {
+      Dct4_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false,
+                                     /*row_shift=*/0);
+      data += 4;
+      i -= 4;
+    } while (i != 0);
+  }
+
+  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
+  StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type);
+}
+
+void Dct8TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
+                               int adjusted_tx_height, void* src_buffer,
+                               int /*start_x*/, int /*start_y*/,
+                               void* /*dst_frame*/) {
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const bool should_round = kShouldRound[tx_size];
+  const uint8_t row_shift = kTransformRowShift[tx_size];
+
+  if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) {
+    return;
+  }
+
+  if (should_round) {
+    ApplyRounding<8>(src, adjusted_tx_height);
+  }
+
+  // Process 4 1d dct8 rows in parallel per iteration.
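+  // Each pass below consumes 4 rows x 8 columns = 32 contiguous values,
+  // which is why `data` advances by 32. A hypothetical scalar outline of
+  // the same traversal (Dct8Rows is illustrative, not a real helper):
+  //
+  //   for (int n = 0; n < adjusted_tx_height; n += 4) {
+  //     Dct8Rows(&src[n * 8]);  // transforms rows [n, n + 4)
+  //   }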
+ int i = adjusted_tx_height; + auto* data = src; + do { + Dct8_NEON<ButterflyRotation_4>(data, /*step=*/8, /*is_row=*/true, + row_shift); + data += 32; + i -= 4; + } while (i != 0); +} + +void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) { + // Process 4 1d dct8 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct8_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false, + /*row_shift=*/0); + data += 4; + i -= 4; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int32_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + assert(adjusted_tx_height % 4 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + // Process 4 1d dct16 rows in parallel per iteration. + Dct16_NEON<ButterflyRotation_4>(data, 16, /*is_row=*/true, row_shift); + data += 64; + i -= 4; + } while (i != 0); +} + +void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) { + // Process 4 1d dct16 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct16_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false, + /*row_shift=*/0); + data += 4; + i -= 4; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct32TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int32_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<32>(src, adjusted_tx_height); + } + + assert(adjusted_tx_height % 4 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + // Process 4 1d dct32 rows in parallel per iteration. 
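+    // If only the DC coefficient were present, DctDcOnly<32>() above would
+    // have short-circuited the whole row pass; ignoring the optional
+    // rounding step, that early-out is conceptually just (a sketch using
+    // the libgav1 helpers seen elsewhere in this file):
+    //
+    //   dst[0] = RightShiftWithRounding(dst[0] * Cos128(32), 12) >> row_shift;
+    //
+    // so this loop always sees full 4-row batches.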
+ Dct32_NEON(data, 32, /*is_row=*/true, row_shift); + data += 128; + i -= 4; + } while (i != 0); +} + +void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<32>(src, tx_width); + } + + if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) { + // Process 4 1d dct32 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct32_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 4; + i -= 4; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Dct64TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int32_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<64>(src, adjusted_tx_height); + } + + assert(adjusted_tx_height % 4 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + // Process 4 1d dct64 rows in parallel per iteration. + Dct64_NEON(data, 64, /*is_row=*/true, row_shift); + data += 128 * 2; + i -= 4; + } while (i != 0); +} + +void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<64>(src, tx_width); + } + + if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) { + // Process 4 1d dct64 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Dct64_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 4; + i -= 4; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type); +} + +void Adst4TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const int row_shift = static_cast<int>(tx_height == 16); + const bool should_round = (tx_height == 8); + + if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + // Process 4 1d adst4 rows in parallel per iteration. 
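+  // Among the 4-wide sizes (4x4, 4x8, 4x16), only 4x8 applies the
+  // pre-transform rounding and only 4x16 applies a row shift, which is what
+  // the two tx_height comparisons above encode. Per coefficient, the two
+  // steps are conceptually (sketch only; the vector code uses
+  // vqrdmulhq_n_s32 and vqrshlq_s32):
+  //
+  //   if (should_round)
+  //     x = RightShiftWithRounding(x * kTransformRowMultiplier, 12);
+  //   x = RightShiftWithRounding(x, row_shift);  // after the 1d transform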
+ int i = adjusted_tx_height; + auto* data = src; + do { + Adst4_NEON(data, /*step=*/4, /*is_row=*/true, row_shift); + data += 16; + i -= 4; + } while (i != 0); +} + +void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + // Process 4 1d adst4 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Adst4_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0); + data += 4; + i -= 4; + } while (i != 0); + } + + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Adst8TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int32_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<8>(src, adjusted_tx_height); + } + + // Process 4 1d adst8 rows in parallel per iteration. + assert(adjusted_tx_height % 4 == 0); + int i = adjusted_tx_height; + auto* data = src; + do { + Adst8_NEON<ButterflyRotation_4>(data, /*step=*/8, + /*transpose=*/true, row_shift); + data += 32; + i -= 4; + } while (i != 0); +} + +void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<8>(src, tx_width); + } + + if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + // Process 4 1d adst8 columns in parallel per iteration. + int i = tx_width; + auto* data = src; + do { + Adst8_NEON<ButterflyRotation_4>(data, tx_width, /*transpose=*/false, + /*row_shift=*/0); + data += 4; + i -= 4; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Adst16TransformLoopRow_NEON(TransformType /*tx_type*/, + TransformSize tx_size, int adjusted_tx_height, + void* src_buffer, int /*start_x*/, + int /*start_y*/, void* /*dst_frame*/) { + auto* src = static_cast<int32_t*>(src_buffer); + const bool should_round = kShouldRound[tx_size]; + const uint8_t row_shift = kTransformRowShift[tx_size]; + + if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) { + return; + } + + if (should_round) { + ApplyRounding<16>(src, adjusted_tx_height); + } + + assert(adjusted_tx_height % 4 == 0); + int i = adjusted_tx_height; + do { + // Process 4 1d adst16 rows in parallel per iteration. 
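+    // Note that this loop advances src itself rather than a separate
+    // cursor: illustratively, pass n operates on &src_origin[64 * n],
+    // i.e. rows [4 * n, 4 * n + 4) of the 16-wide block (4 rows x 16
+    // columns = 64 values per pass).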
+ Adst16_NEON<ButterflyRotation_4>(src, 16, /*is_row=*/true, row_shift); + src += 64; + i -= 4; + } while (i != 0); +} + +void Adst16TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, void* dst_frame) { + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<16>(src, tx_width); + } + + if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) { + int i = tx_width; + auto* data = src; + do { + // Process 4 1d adst16 columns in parallel per iteration. + Adst16_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false, + /*row_shift=*/0); + data += 4; + i -= 4; + } while (i != 0); + } + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y, + tx_width, src, tx_type); +} + +void Identity4TransformLoopRow_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. + if (tx_type == kTransformTypeIdentityIdentity && + tx_size == kTransformSize4x4) { + return; + } + + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_height = kTransformHeight[tx_size]; + const bool should_round = (tx_height == 8); + + if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) { + return; + } + + if (should_round) { + ApplyRounding<4>(src, adjusted_tx_height); + } + + const int shift = tx_height > 8 ? 1 : 0; + int i = adjusted_tx_height; + do { + Identity4_NEON(src, /*step=*/4, shift); + src += 16; + i -= 4; + } while (i != 0); +} + +void Identity4TransformLoopColumn_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int start_x, int start_y, + void* dst_frame) { + auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame); + auto* src = static_cast<int32_t*>(src_buffer); + const int tx_width = kTransformWidth[tx_size]; + + // Special case: Process row calculations during column transform call. + if (tx_type == kTransformTypeIdentityIdentity && + (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) { + Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); + return; + } + + if (kTransformFlipColumnsMask.Contains(tx_type)) { + FlipColumns<4>(src, tx_width); + } + + IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width, + adjusted_tx_height, src); +} + +void Identity8TransformLoopRow_NEON(TransformType tx_type, + TransformSize tx_size, + int adjusted_tx_height, void* src_buffer, + int /*start_x*/, int /*start_y*/, + void* /*dst_frame*/) { + // Special case: Process row calculations during column transform call. + // Improves performance. 
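+  // (For those sizes the row scaling is folded into the column pass: see
+  // Identity4TransformLoopColumn_NEON above, which routes kTransformSize4x4
+  // and kTransformSize8x4 through Identity4RowColumnStoreToFrame so the row
+  // and column identity multiplies hit the frame in a single pass. Roughly,
+  // per pixel:
+  //
+  //   frame[y][x] = Clip(frame[y][x] + FusedIdentityScale(residual));
+  //
+  // where FusedIdentityScale stands in for the combined multiply/round and
+  // is illustrative, not a real helper.)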
+  if (tx_type == kTransformTypeIdentityIdentity &&
+      tx_size == kTransformSize8x4) {
+    return;
+  }
+
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const int tx_height = kTransformHeight[tx_size];
+  const bool should_round = kShouldRound[tx_size];
+  const uint8_t row_shift = kTransformRowShift[tx_size];
+
+  if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
+    return;
+  }
+  if (should_round) {
+    ApplyRounding<8>(src, adjusted_tx_height);
+  }
+
+  // When combining the identity8 multiplier with the row shift, the
+  // calculations for tx_height == 8 and tx_height == 16 can be simplified
+  // from (((A * 2) + 1) >> 1) to A. For 10bpp, A must be clamped to a
+  // signed 16-bit value.
+  if ((tx_height & 0x18) != 0) {
+    for (int i = 0; i < tx_height; ++i) {
+      const int32x4_t v_src_lo = vld1q_s32(&src[i * 8]);
+      const int32x4_t v_src_hi = vld1q_s32(&src[(i * 8) + 4]);
+      vst1q_s32(&src[i * 8], vmovl_s16(vqmovn_s32(v_src_lo)));
+      vst1q_s32(&src[(i * 8) + 4], vmovl_s16(vqmovn_s32(v_src_hi)));
+    }
+    return;
+  }
+  if (tx_height == 32) {
+    int i = adjusted_tx_height;
+    do {
+      Identity8Row32_NEON(src, /*step=*/8);
+      src += 32;
+      i -= 4;
+    } while (i != 0);
+    return;
+  }
+
+  assert(tx_size == kTransformSize8x4);
+  int i = adjusted_tx_height;
+  do {
+    Identity8Row4_NEON(src, /*step=*/8);
+    src += 32;
+    i -= 4;
+  } while (i != 0);
+}
+
+void Identity8TransformLoopColumn_NEON(TransformType tx_type,
+                                       TransformSize tx_size,
+                                       int adjusted_tx_height, void* src_buffer,
+                                       int start_x, int start_y,
+                                       void* dst_frame) {
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<8>(src, tx_width);
+  }
+  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
+  IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width,
+                                adjusted_tx_height, src);
+}
+
+void Identity16TransformLoopRow_NEON(TransformType /*tx_type*/,
+                                     TransformSize tx_size,
+                                     int adjusted_tx_height, void* src_buffer,
+                                     int /*start_x*/, int /*start_y*/,
+                                     void* /*dst_frame*/) {
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const bool should_round = kShouldRound[tx_size];
+  const uint8_t row_shift = kTransformRowShift[tx_size];
+
+  if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
+    return;
+  }
+
+  if (should_round) {
+    ApplyRounding<16>(src, adjusted_tx_height);
+  }
+  int i = adjusted_tx_height;
+  do {
+    Identity16Row_NEON(src, /*step=*/16, row_shift);
+    src += 64;
+    i -= 4;
+  } while (i != 0);
+}
+
+void Identity16TransformLoopColumn_NEON(TransformType tx_type,
+                                        TransformSize tx_size,
+                                        int adjusted_tx_height,
+                                        void* src_buffer, int start_x,
+                                        int start_y, void* dst_frame) {
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<16>(src, tx_width);
+  }
+  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
+  IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width,
+                                 adjusted_tx_height, src);
+}
+
+//------------------------------------------------------------------------------
+
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  // Maximum transform size for Dct is 64.
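+  // (The table below is consumed through the dsp dispatch layer; a
+  // hypothetical caller-side sketch, with GetDspTable() from src/dsp/dsp.h:
+  //
+  //   const Dsp* const d = GetDspTable(kBitdepth10);
+  //   d->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow](
+  //       tx_type, tx_size, adjusted_tx_height, src_buffer,
+  //       /*start_x=*/0, /*start_y=*/0, /*dst_frame=*/nullptr);
+  //
+  // Row entries ignore the trailing three arguments, as their signatures
+  // above indicate.)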
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] = + Dct4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] = + Dct4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] = + Dct8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] = + Dct8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] = + Dct16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] = + Dct16TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] = + Dct32TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] = + Dct32TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] = + Dct64TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] = + Dct64TransformLoopColumn_NEON; + + // Maximum transform size for Adst is 16. + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] = + Adst4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] = + Adst4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] = + Adst8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] = + Adst8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] = + Adst16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] = + Adst16TransformLoopColumn_NEON; + + // Maximum transform size for Identity transform is 32. + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] = + Identity4TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] = + Identity4TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] = + Identity8TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] = + Identity8TransformLoopColumn_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] = + Identity16TransformLoopRow_NEON; + dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] = + Identity16TransformLoopColumn_NEON; +} + +} // namespace + +void InverseTransformInit10bpp_NEON() { Init10bpp(); } + +} // namespace dsp +} // namespace libgav1 +#else // !LIBGAV1_ENABLE_NEON || LIBGAV1_MAX_BITDEPTH < 10 +namespace libgav1 { +namespace dsp { + +void InverseTransformInit10bpp_NEON() {} + +} // namespace dsp +} // namespace libgav1 +#endif // LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10 |