author     qinxialei <xialeiqin@gmail.com>  2020-10-29 11:26:59 +0800
committer  qinxialei <xialeiqin@gmail.com>  2020-10-29 11:26:59 +0800
commit     e8d277081293b6fb2a5d469616baaa7a06f52496 (patch)
tree       1179bb07d3927d1837d4a90bd81b2034c4c696a9 /src/dsp/arm
download   libgav1-e8d277081293b6fb2a5d469616baaa7a06f52496.tar.gz
           libgav1-e8d277081293b6fb2a5d469616baaa7a06f52496.tar.bz2
           libgav1-e8d277081293b6fb2a5d469616baaa7a06f52496.zip
Import Upstream version 0.16.0
Diffstat (limited to 'src/dsp/arm')
-rw-r--r--  src/dsp/arm/average_blend_neon.cc             146
-rw-r--r--  src/dsp/arm/average_blend_neon.h               36
-rw-r--r--  src/dsp/arm/cdef_neon.cc                      697
-rw-r--r--  src/dsp/arm/cdef_neon.h                        38
-rw-r--r--  src/dsp/arm/common_neon.h                     777
-rw-r--r--  src/dsp/arm/convolve_neon.cc                 3105
-rw-r--r--  src/dsp/arm/convolve_neon.h                    50
-rw-r--r--  src/dsp/arm/distance_weighted_blend_neon.cc   203
-rw-r--r--  src/dsp/arm/distance_weighted_blend_neon.h     39
-rw-r--r--  src/dsp/arm/film_grain_neon.cc               1188
-rw-r--r--  src/dsp/arm/film_grain_neon.h                  47
-rw-r--r--  src/dsp/arm/intra_edge_neon.cc                301
-rw-r--r--  src/dsp/arm/intra_edge_neon.h                  39
-rw-r--r--  src/dsp/arm/intrapred_cfl_neon.cc             479
-rw-r--r--  src/dsp/arm/intrapred_directional_neon.cc     926
-rw-r--r--  src/dsp/arm/intrapred_filter_intra_neon.cc    176
-rw-r--r--  src/dsp/arm/intrapred_neon.cc                1144
-rw-r--r--  src/dsp/arm/intrapred_neon.h                  418
-rw-r--r--  src/dsp/arm/intrapred_smooth_neon.cc          616
-rw-r--r--  src/dsp/arm/inverse_transform_neon.cc        3128
-rw-r--r--  src/dsp/arm/inverse_transform_neon.h           52
-rw-r--r--  src/dsp/arm/loop_filter_neon.cc              1190
-rw-r--r--  src/dsp/arm/loop_filter_neon.h                 53
-rw-r--r--  src/dsp/arm/loop_restoration_neon.cc         1901
-rw-r--r--  src/dsp/arm/loop_restoration_neon.h            40
-rw-r--r--  src/dsp/arm/mask_blend_neon.cc                444
-rw-r--r--  src/dsp/arm/mask_blend_neon.h                  41
-rw-r--r--  src/dsp/arm/motion_field_projection_neon.cc   393
-rw-r--r--  src/dsp/arm/motion_field_projection_neon.h     39
-rw-r--r--  src/dsp/arm/motion_vector_search_neon.cc      267
-rw-r--r--  src/dsp/arm/motion_vector_search_neon.h        39
-rw-r--r--  src/dsp/arm/obmc_neon.cc                      392
-rw-r--r--  src/dsp/arm/obmc_neon.h                        38
-rw-r--r--  src/dsp/arm/super_res_neon.cc                 166
-rw-r--r--  src/dsp/arm/super_res_neon.h                   37
-rw-r--r--  src/dsp/arm/warp_neon.cc                      453
-rw-r--r--  src/dsp/arm/warp_neon.h                        37
-rw-r--r--  src/dsp/arm/weight_mask_neon.cc               463
-rw-r--r--  src/dsp/arm/weight_mask_neon.h                 52
39 files changed, 19650 insertions, 0 deletions
diff --git a/src/dsp/arm/average_blend_neon.cc b/src/dsp/arm/average_blend_neon.cc
new file mode 100644
index 0000000..834e8b4
--- /dev/null
+++ b/src/dsp/arm/average_blend_neon.cc
@@ -0,0 +1,146 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/average_blend.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kInterPostRoundBit =
+ kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
+
+inline uint8x8_t AverageBlend8Row(const int16_t* prediction_0,
+ const int16_t* prediction_1) {
+ const int16x8_t pred0 = vld1q_s16(prediction_0);
+ const int16x8_t pred1 = vld1q_s16(prediction_1);
+ const int16x8_t res = vaddq_s16(pred0, pred1);
+ return vqrshrun_n_s16(res, kInterPostRoundBit + 1);
+}
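+
+// A rough scalar sketch of the operation above (illustrative only; Clip3 and
+// RightShiftWithRounding are the helpers from src/utils/common.h):
+//   dest[x] = Clip3(RightShiftWithRounding(prediction_0[x] + prediction_1[x],
+//                                          kInterPostRoundBit + 1), 0, 255);
+// vqrshrun_n_s16() performs the rounding shift and the clamp to [0, 255] in a
+// single instruction.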
+
+inline void AverageBlendLargeRow(const int16_t* prediction_0,
+ const int16_t* prediction_1, const int width,
+ uint8_t* dest) {
+ int x = width;
+ do {
+ const int16x8_t pred_00 = vld1q_s16(prediction_0);
+ const int16x8_t pred_01 = vld1q_s16(prediction_1);
+ prediction_0 += 8;
+ prediction_1 += 8;
+ const int16x8_t res0 = vaddq_s16(pred_00, pred_01);
+ const uint8x8_t res_out0 = vqrshrun_n_s16(res0, kInterPostRoundBit + 1);
+ const int16x8_t pred_10 = vld1q_s16(prediction_0);
+ const int16x8_t pred_11 = vld1q_s16(prediction_1);
+ prediction_0 += 8;
+ prediction_1 += 8;
+ const int16x8_t res1 = vaddq_s16(pred_10, pred_11);
+ const uint8x8_t res_out1 = vqrshrun_n_s16(res1, kInterPostRoundBit + 1);
+ vst1q_u8(dest, vcombine_u8(res_out0, res_out1));
+ dest += 16;
+ x -= 16;
+ } while (x != 0);
+}
+
+void AverageBlend_NEON(const void* prediction_0, const void* prediction_1,
+ const int width, const int height, void* const dest,
+ const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y = height;
+
+ if (width == 4) {
+ do {
+ const uint8x8_t result = AverageBlend8Row(pred_0, pred_1);
+ pred_0 += 8;
+ pred_1 += 8;
+
+ StoreLo4(dst, result);
+ dst += dest_stride;
+ StoreHi4(dst, result);
+ dst += dest_stride;
+ y -= 2;
+ } while (y != 0);
+ return;
+ }
+
+ if (width == 8) {
+ do {
+ vst1_u8(dst, AverageBlend8Row(pred_0, pred_1));
+ dst += dest_stride;
+ pred_0 += 8;
+ pred_1 += 8;
+
+ vst1_u8(dst, AverageBlend8Row(pred_0, pred_1));
+ dst += dest_stride;
+ pred_0 += 8;
+ pred_1 += 8;
+
+ y -= 2;
+ } while (y != 0);
+ return;
+ }
+
+ do {
+ AverageBlendLargeRow(pred_0, pred_1, width, dst);
+ dst += dest_stride;
+ pred_0 += width;
+ pred_1 += width;
+
+ AverageBlendLargeRow(pred_0, pred_1, width, dst);
+ dst += dest_stride;
+ pred_0 += width;
+ pred_1 += width;
+
+ y -= 2;
+ } while (y != 0);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->average_blend = AverageBlend_NEON;
+}
+
+} // namespace
+
+void AverageBlendInit_NEON() { Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void AverageBlendInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/average_blend_neon.h b/src/dsp/arm/average_blend_neon.h
new file mode 100644
index 0000000..d13bcd6
--- /dev/null
+++ b/src/dsp/arm/average_blend_neon.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::average_blend. This function is not thread-safe.
+void AverageBlendInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
diff --git a/src/dsp/arm/cdef_neon.cc b/src/dsp/arm/cdef_neon.cc
new file mode 100644
index 0000000..4d0e76f
--- /dev/null
+++ b/src/dsp/arm/cdef_neon.cc
@@ -0,0 +1,697 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/cdef.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+#include "src/dsp/cdef.inc"
+
+// ----------------------------------------------------------------------------
+// Refer to CdefDirection_C().
+//
+// int32_t partial[8][15] = {};
+// for (int i = 0; i < 8; ++i) {
+// for (int j = 0; j < 8; ++j) {
+// const int x = 1;
+// partial[0][i + j] += x;
+// partial[1][i + j / 2] += x;
+// partial[2][i] += x;
+// partial[3][3 + i - j / 2] += x;
+// partial[4][7 + i - j] += x;
+// partial[5][3 - i / 2 + j] += x;
+// partial[6][j] += x;
+// partial[7][i / 2 + j] += x;
+// }
+// }
+//
+// Using the code above, generate the position count for partial[8][15].
+//
+// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
+// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
+// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
+// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
+// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+//
+// The SIMD code shifts the input horizontally, then adds vertically to get the
+// correct partial value for the given position.
+// ----------------------------------------------------------------------------
+
+// ----------------------------------------------------------------------------
+// partial[0][i + j] += x;
+//
+// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00
+// 00 10 11 12 13 14 15 16 17 00 00 00 00 00 00
+// 00 00 20 21 22 23 24 25 26 27 00 00 00 00 00
+// 00 00 00 30 31 32 33 34 35 36 37 00 00 00 00
+// 00 00 00 00 40 41 42 43 44 45 46 47 00 00 00
+// 00 00 00 00 00 50 51 52 53 54 55 56 57 00 00
+// 00 00 00 00 00 00 60 61 62 63 64 65 66 67 00
+// 00 00 00 00 00 00 00 70 71 72 73 74 75 76 77
+//
+// partial[4] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(uint8x8_t* v_src,
+ uint16x8_t* partial_lo,
+ uint16x8_t* partial_hi) {
+ const uint8x8_t v_zero = vdup_n_u8(0);
+ // 00 01 02 03 04 05 06 07
+ // 00 10 11 12 13 14 15 16
+ *partial_lo = vaddl_u8(v_src[0], vext_u8(v_zero, v_src[1], 7));
+
+ // 00 00 20 21 22 23 24 25
+ *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[2], 6));
+ // 17 00 00 00 00 00 00 00
+ // 26 27 00 00 00 00 00 00
+ *partial_hi =
+ vaddl_u8(vext_u8(v_src[1], v_zero, 7), vext_u8(v_src[2], v_zero, 6));
+
+ // 00 00 00 30 31 32 33 34
+ *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[3], 5));
+ // 35 36 37 00 00 00 00 00
+ *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[3], v_zero, 5));
+
+ // 00 00 00 00 40 41 42 43
+ *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[4], 4));
+ // 44 45 46 47 00 00 00 00
+ *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[4], v_zero, 4));
+
+ // 00 00 00 00 00 50 51 52
+ *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[5], 3));
+ // 53 54 55 56 57 00 00 00
+ *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[5], v_zero, 3));
+
+ // 00 00 00 00 00 00 60 61
+ *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[6], 2));
+ // 62 63 64 65 66 67 00 00
+ *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[6], v_zero, 2));
+
+ // 00 00 00 00 00 00 00 70
+ *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[7], 1));
+ // 71 72 73 74 75 76 77 00
+ *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[7], v_zero, 1));
+}
+
+// ----------------------------------------------------------------------------
+// partial[1][i + j / 2] += x;
+//
+// A0 = src[0] + src[1], A1 = src[2] + src[3], ...
+//
+// A0 A1 A2 A3 00 00 00 00 00 00 00 00 00 00 00
+// 00 B0 B1 B2 B3 00 00 00 00 00 00 00 00 00 00
+// 00 00 C0 C1 C2 C3 00 00 00 00 00 00 00 00 00
+// 00 00 00 D0 D1 D2 D3 00 00 00 00 00 00 00 00
+// 00 00 00 00 E0 E1 E2 E3 00 00 00 00 00 00 00
+// 00 00 00 00 00 F0 F1 F2 F3 00 00 00 00 00 00
+// 00 00 00 00 00 00 G0 G1 G2 G3 00 00 00 00 00
+// 00 00 00 00 00 00 00 H0 H1 H2 H3 00 00 00 00
+//
+// partial[3] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(uint8x8_t* v_src,
+ uint16x8_t* partial_lo,
+ uint16x8_t* partial_hi) {
+ uint8x16_t v_d1_temp[8];
+ const uint8x8_t v_zero = vdup_n_u8(0);
+ const uint8x16_t v_zero_16 = vdupq_n_u8(0);
+
+ for (int i = 0; i < 8; ++i) {
+ v_d1_temp[i] = vcombine_u8(v_src[i], v_zero);
+ }
+
+ *partial_lo = *partial_hi = vdupq_n_u16(0);
+ // A0 A1 A2 A3 00 00 00 00
+ *partial_lo = vpadalq_u8(*partial_lo, v_d1_temp[0]);
+
+ // 00 B0 B1 B2 B3 00 00 00
+ *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[1], 14));
+
+ // 00 00 C0 C1 C2 C3 00 00
+ *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[2], 12));
+ // 00 00 00 D0 D1 D2 D3 00
+ *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[3], 10));
+ // 00 00 00 00 E0 E1 E2 E3
+ *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[4], 8));
+
+ // 00 00 00 00 00 F0 F1 F2
+ *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[5], 6));
+ // F3 00 00 00 00 00 00 00
+ *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[5], v_zero_16, 6));
+
+ // 00 00 00 00 00 00 G0 G1
+ *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[6], 4));
+ // G2 G3 00 00 00 00 00 00
+ *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[6], v_zero_16, 4));
+
+ // 00 00 00 00 00 00 00 H0
+ *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[7], 2));
+ // H1 H2 H3 00 00 00 00 00
+ *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[7], v_zero_16, 2));
+}
+
+// ----------------------------------------------------------------------------
+// partial[7][i / 2 + j] += x;
+//
+// 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00
+// 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00
+// 00 20 21 22 23 24 25 26 27 00 00 00 00 00 00
+// 00 30 31 32 33 34 35 36 37 00 00 00 00 00 00
+// 00 00 40 41 42 43 44 45 46 47 00 00 00 00 00
+// 00 00 50 51 52 53 54 55 56 57 00 00 00 00 00
+// 00 00 00 60 61 62 63 64 65 66 67 00 00 00 00
+// 00 00 00 70 71 72 73 74 75 76 77 00 00 00 00
+//
+// partial[5] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(uint8x8_t* v_src,
+ uint16x8_t* partial_lo,
+ uint16x8_t* partial_hi) {
+ const uint16x8_t v_zero = vdupq_n_u16(0);
+ uint16x8_t v_pair_add[4];
+ // Add vertical source pairs.
+ v_pair_add[0] = vaddl_u8(v_src[0], v_src[1]);
+ v_pair_add[1] = vaddl_u8(v_src[2], v_src[3]);
+ v_pair_add[2] = vaddl_u8(v_src[4], v_src[5]);
+ v_pair_add[3] = vaddl_u8(v_src[6], v_src[7]);
+
+ // 00 01 02 03 04 05 06 07
+ // 10 11 12 13 14 15 16 17
+ *partial_lo = v_pair_add[0];
+ // 00 00 00 00 00 00 00 00
+ // 00 00 00 00 00 00 00 00
+ *partial_hi = vdupq_n_u16(0);
+
+ // 00 20 21 22 23 24 25 26
+ // 00 30 31 32 33 34 35 36
+ *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[1], 7));
+ // 27 00 00 00 00 00 00 00
+ // 37 00 00 00 00 00 00 00
+ *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[1], v_zero, 7));
+
+ // 00 00 40 41 42 43 44 45
+ // 00 00 50 51 52 53 54 55
+ *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[2], 6));
+ // 46 47 00 00 00 00 00 00
+ // 56 57 00 00 00 00 00 00
+ *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[2], v_zero, 6));
+
+ // 00 00 00 60 61 62 63 64
+ // 00 00 00 70 71 72 73 74
+ *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[3], 5));
+ // 65 66 67 00 00 00 00 00
+ // 75 76 77 00 00 00 00 00
+ *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[3], v_zero, 5));
+}
+
+LIBGAV1_ALWAYS_INLINE void AddPartial(const void* const source,
+ ptrdiff_t stride, uint16x8_t* partial_lo,
+ uint16x8_t* partial_hi) {
+ const auto* src = static_cast<const uint8_t*>(source);
+
+ // 8x8 input
+ // 00 01 02 03 04 05 06 07
+ // 10 11 12 13 14 15 16 17
+ // 20 21 22 23 24 25 26 27
+ // 30 31 32 33 34 35 36 37
+ // 40 41 42 43 44 45 46 47
+ // 50 51 52 53 54 55 56 57
+ // 60 61 62 63 64 65 66 67
+ // 70 71 72 73 74 75 76 77
+ uint8x8_t v_src[8];
+ for (int i = 0; i < 8; ++i) {
+ v_src[i] = vld1_u8(src);
+ src += stride;
+ }
+
+ // partial for direction 2
+ // --------------------------------------------------------------------------
+ // partial[2][i] += x;
+ // 00 10 20 30 40 50 60 70 00 00 00 00 00 00 00 00
+ // 01 11 21 31 41 51 61 71 00 00 00 00 00 00 00 00
+ // 02 12 22 32 42 52 62 72 00 00 00 00 00 00 00 00
+ // 03 13 23 33 43 53 63 73 00 00 00 00 00 00 00 00
+ // 04 14 24 34 44 54 64 74 00 00 00 00 00 00 00 00
+ // 05 15 25 35 45 55 65 75 00 00 00 00 00 00 00 00
+ // 06 16 26 36 46 56 66 76 00 00 00 00 00 00 00 00
+ // 07 17 27 37 47 57 67 77 00 00 00 00 00 00 00 00
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[0]), partial_lo[2], 0);
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[1]), partial_lo[2], 1);
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[2]), partial_lo[2], 2);
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[3]), partial_lo[2], 3);
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[4]), partial_lo[2], 4);
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[5]), partial_lo[2], 5);
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[6]), partial_lo[2], 6);
+ partial_lo[2] = vsetq_lane_u16(SumVector(v_src[7]), partial_lo[2], 7);
+
+ // partial for direction 6
+ // --------------------------------------------------------------------------
+ // partial[6][j] += x;
+ // 00 01 02 03 04 05 06 07 00 00 00 00 00 00 00 00
+ // 10 11 12 13 14 15 16 17 00 00 00 00 00 00 00 00
+ // 20 21 22 23 24 25 26 27 00 00 00 00 00 00 00 00
+ // 30 31 32 33 34 35 36 37 00 00 00 00 00 00 00 00
+ // 40 41 42 43 44 45 46 47 00 00 00 00 00 00 00 00
+ // 50 51 52 53 54 55 56 57 00 00 00 00 00 00 00 00
+ // 60 61 62 63 64 65 66 67 00 00 00 00 00 00 00 00
+ // 70 71 72 73 74 75 76 77 00 00 00 00 00 00 00 00
+ const uint8x8_t v_zero = vdup_n_u8(0);
+ partial_lo[6] = vaddl_u8(v_zero, v_src[0]);
+ for (int i = 1; i < 8; ++i) {
+ partial_lo[6] = vaddw_u8(partial_lo[6], v_src[i]);
+ }
+
+ // partial for direction 0
+ AddPartial_D0_D4(v_src, &partial_lo[0], &partial_hi[0]);
+
+ // partial for direction 1
+ AddPartial_D1_D3(v_src, &partial_lo[1], &partial_hi[1]);
+
+ // partial for direction 7
+ AddPartial_D5_D7(v_src, &partial_lo[7], &partial_hi[7]);
+
+ uint8x8_t v_src_reverse[8];
+ for (int i = 0; i < 8; ++i) {
+ v_src_reverse[i] = vrev64_u8(v_src[i]);
+ }
+
+ // partial for direction 4
+ AddPartial_D0_D4(v_src_reverse, &partial_lo[4], &partial_hi[4]);
+
+ // partial for direction 3
+ AddPartial_D1_D3(v_src_reverse, &partial_lo[3], &partial_hi[3]);
+
+ // partial for direction 5
+ AddPartial_D5_D7(v_src_reverse, &partial_lo[5], &partial_hi[5]);
+}
+
+uint32x4_t Square(uint16x4_t a) { return vmull_u16(a, a); }
+
+uint32x4_t SquareAccumulate(uint32x4_t a, uint16x4_t b) {
+ return vmlal_u16(a, b, b);
+}
+
+// |cost[0]| and |cost[4]| square the input and sum with the corresponding
+// element from the other end of the vector, then multiply by the matching
+// |kCdefDivisionTable[]| element:
+// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) *
+// kCdefDivisionTable[i + 1];
+// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8];
+// Because everything is being summed into a single value the distributive
+// property allows us to mirror the division table and accumulate once.
+uint32_t Cost0Or4(const uint16x8_t a, const uint16x8_t b,
+ const uint32x4_t division_table[4]) {
+ uint32x4_t c = vmulq_u32(Square(vget_low_u16(a)), division_table[0]);
+ c = vmlaq_u32(c, Square(vget_high_u16(a)), division_table[1]);
+ c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[2]);
+ c = vmlaq_u32(c, Square(vget_high_u16(b)), division_table[3]);
+ return SumVector(c);
+}
+
+// |cost[2]| and |cost[6]| square the input and accumulate:
+// cost[2] += Square(partial[2][i])
+// The accumulated sum is then scaled by kCdefDivisionTable[7].
+uint32_t SquareAccumulate(const uint16x8_t a) {
+ uint32x4_t c = Square(vget_low_u16(a));
+ c = SquareAccumulate(c, vget_high_u16(a));
+ c = vmulq_n_u32(c, kCdefDivisionTable[7]);
+ return SumVector(c);
+}
+
+uint32_t CostOdd(const uint16x8_t a, const uint16x8_t b, const uint32x4_t mask,
+ const uint32x4_t division_table[2]) {
+ // Remove elements 0-2.
+ uint32x4_t c = vandq_u32(mask, Square(vget_low_u16(a)));
+ c = vaddq_u32(c, Square(vget_high_u16(a)));
+ c = vmulq_n_u32(c, kCdefDivisionTable[7]);
+
+ c = vmlaq_u32(c, Square(vget_low_u16(a)), division_table[0]);
+ c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[1]);
+ return SumVector(c);
+}
+
+void CdefDirection_NEON(const void* const source, ptrdiff_t stride,
+ uint8_t* const direction, int* const variance) {
+ assert(direction != nullptr);
+ assert(variance != nullptr);
+ const auto* src = static_cast<const uint8_t*>(source);
+ uint32_t cost[8];
+ uint16x8_t partial_lo[8], partial_hi[8];
+
+ AddPartial(src, stride, partial_lo, partial_hi);
+
+ cost[2] = SquareAccumulate(partial_lo[2]);
+ cost[6] = SquareAccumulate(partial_lo[6]);
+
+ const uint32x4_t division_table[4] = {
+ vld1q_u32(kCdefDivisionTable), vld1q_u32(kCdefDivisionTable + 4),
+ vld1q_u32(kCdefDivisionTable + 8), vld1q_u32(kCdefDivisionTable + 12)};
+
+ cost[0] = Cost0Or4(partial_lo[0], partial_hi[0], division_table);
+ cost[4] = Cost0Or4(partial_lo[4], partial_hi[4], division_table);
+
+ const uint32x4_t division_table_odd[2] = {
+ vld1q_u32(kCdefDivisionTableOdd), vld1q_u32(kCdefDivisionTableOdd + 4)};
+
+ const uint32x4_t element_3_mask = {0, 0, 0, static_cast<uint32_t>(-1)};
+
+ cost[1] =
+ CostOdd(partial_lo[1], partial_hi[1], element_3_mask, division_table_odd);
+ cost[3] =
+ CostOdd(partial_lo[3], partial_hi[3], element_3_mask, division_table_odd);
+ cost[5] =
+ CostOdd(partial_lo[5], partial_hi[5], element_3_mask, division_table_odd);
+ cost[7] =
+ CostOdd(partial_lo[7], partial_hi[7], element_3_mask, division_table_odd);
+
+ uint32_t best_cost = 0;
+ *direction = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (cost[i] > best_cost) {
+ best_cost = cost[i];
+ *direction = i;
+ }
+ }
+ *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10;
+}
+
+// -------------------------------------------------------------------------
+// CdefFilter
+
+// Load 4 vectors based on the given |direction|.
+void LoadDirection(const uint16_t* const src, const ptrdiff_t stride,
+ uint16x8_t* output, const int direction) {
+ // Each |direction| describes a different set of source values. Expand this
+ // set by negating each set. For |direction| == 0 this gives a diagonal line
+ // from top right to bottom left. The first value is y, the second x. Negative
+ // y values move up.
+ // a b c d
+ // {-1, 1}, {1, -1}, {-2, 2}, {2, -2}
+ // c
+ // a
+ // 0
+ // b
+ // d
+ const int y_0 = kCdefDirections[direction][0][0];
+ const int x_0 = kCdefDirections[direction][0][1];
+ const int y_1 = kCdefDirections[direction][1][0];
+ const int x_1 = kCdefDirections[direction][1][1];
+ output[0] = vld1q_u16(src + y_0 * stride + x_0);
+ output[1] = vld1q_u16(src - y_0 * stride - x_0);
+ output[2] = vld1q_u16(src + y_1 * stride + x_1);
+ output[3] = vld1q_u16(src - y_1 * stride - x_1);
+}
+
+// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to
+// do 2 rows at a time.
+void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride,
+ uint16x8_t* output, const int direction) {
+ const int y_0 = kCdefDirections[direction][0][0];
+ const int x_0 = kCdefDirections[direction][0][1];
+ const int y_1 = kCdefDirections[direction][1][0];
+ const int x_1 = kCdefDirections[direction][1][1];
+ output[0] = vcombine_u16(vld1_u16(src + y_0 * stride + x_0),
+ vld1_u16(src + y_0 * stride + stride + x_0));
+ output[1] = vcombine_u16(vld1_u16(src - y_0 * stride - x_0),
+ vld1_u16(src - y_0 * stride + stride - x_0));
+ output[2] = vcombine_u16(vld1_u16(src + y_1 * stride + x_1),
+ vld1_u16(src + y_1 * stride + stride + x_1));
+ output[3] = vcombine_u16(vld1_u16(src - y_1 * stride - x_1),
+ vld1_u16(src - y_1 * stride + stride - x_1));
+}
+
+int16x8_t Constrain(const uint16x8_t pixel, const uint16x8_t reference,
+ const uint16x8_t threshold, const int16x8_t damping) {
+ // If reference > pixel, the difference will be negative, so convert the
+ // comparison result to 0 or -1.
+ const uint16x8_t sign = vcgtq_u16(reference, pixel);
+ const uint16x8_t abs_diff = vabdq_u16(pixel, reference);
+ const uint16x8_t shifted_diff = vshlq_u16(abs_diff, damping);
+ // For bitdepth == 8, the threshold range is [0, 15] and the damping range is
+ // [3, 6]. If pixel == kCdefLargeValue(0x4000), shifted_diff will always be
+ // larger than threshold. Subtracting with saturation therefore returns 0 when
+ // pixel == kCdefLargeValue.
+ static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue");
+ const uint16x8_t thresh_minus_shifted_diff =
+ vqsubq_u16(threshold, shifted_diff);
+ const uint16x8_t clamp_abs_diff =
+ vminq_u16(thresh_minus_shifted_diff, abs_diff);
+ // Restore the sign.
+ return vreinterpretq_s16_u16(
+ vsubq_u16(veorq_u16(clamp_abs_diff, sign), sign));
+}
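+
+// A rough scalar sketch of the constraint above (illustrative only):
+//   diff = pixel - reference;
+//   constrained = sign(diff) * std::min(abs(diff),
+//       std::max(0, threshold - (abs(diff) >> shift)));
+// where |damping| holds -shift so that vshlq_u16() performs the right shift.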
+
+template <int width, bool enable_primary = true, bool enable_secondary = true>
+void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride,
+ const int height, const int primary_strength,
+ const int secondary_strength, const int damping,
+ const int direction, void* dest,
+ const ptrdiff_t dst_stride) {
+ static_assert(width == 8 || width == 4, "");
+ static_assert(enable_primary || enable_secondary, "");
+ constexpr bool clipping_required = enable_primary && enable_secondary;
+ auto* dst = static_cast<uint8_t*>(dest);
+ const uint16x8_t cdef_large_value_mask =
+ vdupq_n_u16(static_cast<uint16_t>(~kCdefLargeValue));
+ const uint16x8_t primary_threshold = vdupq_n_u16(primary_strength);
+ const uint16x8_t secondary_threshold = vdupq_n_u16(secondary_strength);
+
+ int16x8_t primary_damping_shift, secondary_damping_shift;
+
+ // FloorLog2() requires input to be > 0.
+ // 8-bit damping range: Y: [3, 6], UV: [2, 5].
+ if (enable_primary) {
+ // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary
+ // for UV filtering.
+ primary_damping_shift =
+ vdupq_n_s16(-std::max(0, damping - FloorLog2(primary_strength)));
+ }
+ if (enable_secondary) {
+ // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is
+ // necessary.
+ assert(damping - FloorLog2(secondary_strength) >= 0);
+ secondary_damping_shift =
+ vdupq_n_s16(-(damping - FloorLog2(secondary_strength)));
+ }
+
+ const int primary_tap_0 = kCdefPrimaryTaps[primary_strength & 1][0];
+ const int primary_tap_1 = kCdefPrimaryTaps[primary_strength & 1][1];
+
+ int y = height;
+ do {
+ uint16x8_t pixel;
+ if (width == 8) {
+ pixel = vld1q_u16(src);
+ } else {
+ pixel = vcombine_u16(vld1_u16(src), vld1_u16(src + src_stride));
+ }
+
+ uint16x8_t min = pixel;
+ uint16x8_t max = pixel;
+ int16x8_t sum;
+
+ if (enable_primary) {
+ // Primary |direction|.
+ uint16x8_t primary_val[4];
+ if (width == 8) {
+ LoadDirection(src, src_stride, primary_val, direction);
+ } else {
+ LoadDirection4(src, src_stride, primary_val, direction);
+ }
+
+ if (clipping_required) {
+ min = vminq_u16(min, primary_val[0]);
+ min = vminq_u16(min, primary_val[1]);
+ min = vminq_u16(min, primary_val[2]);
+ min = vminq_u16(min, primary_val[3]);
+
+ // The source is 16 bits, however, we only really care about the lower
+ // 8 bits. The upper 8 bits contain the "large" flag. After the final
+ // primary max has been calculated, zero out the upper 8 bits. Use this
+ // to find the "16 bit" max.
+ const uint8x16_t max_p01 =
+ vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]),
+ vreinterpretq_u8_u16(primary_val[1]));
+ const uint8x16_t max_p23 =
+ vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]),
+ vreinterpretq_u8_u16(primary_val[3]));
+ const uint16x8_t max_p =
+ vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23));
+ max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask));
+ }
+
+ sum = Constrain(primary_val[0], pixel, primary_threshold,
+ primary_damping_shift);
+ sum = vmulq_n_s16(sum, primary_tap_0);
+ sum = vmlaq_n_s16(sum,
+ Constrain(primary_val[1], pixel, primary_threshold,
+ primary_damping_shift),
+ primary_tap_0);
+ sum = vmlaq_n_s16(sum,
+ Constrain(primary_val[2], pixel, primary_threshold,
+ primary_damping_shift),
+ primary_tap_1);
+ sum = vmlaq_n_s16(sum,
+ Constrain(primary_val[3], pixel, primary_threshold,
+ primary_damping_shift),
+ primary_tap_1);
+ } else {
+ sum = vdupq_n_s16(0);
+ }
+
+ if (enable_secondary) {
+ // Secondary |direction| values (+/- 2). Clamp |direction|.
+ uint16x8_t secondary_val[8];
+ if (width == 8) {
+ LoadDirection(src, src_stride, secondary_val, direction + 2);
+ LoadDirection(src, src_stride, secondary_val + 4, direction - 2);
+ } else {
+ LoadDirection4(src, src_stride, secondary_val, direction + 2);
+ LoadDirection4(src, src_stride, secondary_val + 4, direction - 2);
+ }
+
+ if (clipping_required) {
+ min = vminq_u16(min, secondary_val[0]);
+ min = vminq_u16(min, secondary_val[1]);
+ min = vminq_u16(min, secondary_val[2]);
+ min = vminq_u16(min, secondary_val[3]);
+ min = vminq_u16(min, secondary_val[4]);
+ min = vminq_u16(min, secondary_val[5]);
+ min = vminq_u16(min, secondary_val[6]);
+ min = vminq_u16(min, secondary_val[7]);
+
+ const uint8x16_t max_s01 =
+ vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]),
+ vreinterpretq_u8_u16(secondary_val[1]));
+ const uint8x16_t max_s23 =
+ vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]),
+ vreinterpretq_u8_u16(secondary_val[3]));
+ const uint8x16_t max_s45 =
+ vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]),
+ vreinterpretq_u8_u16(secondary_val[5]));
+ const uint8x16_t max_s67 =
+ vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]),
+ vreinterpretq_u8_u16(secondary_val[7]));
+ const uint16x8_t max_s = vreinterpretq_u16_u8(
+ vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67)));
+ max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask));
+ }
+
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[0], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap0);
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[1], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap0);
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[2], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap1);
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[3], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap1);
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[4], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap0);
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[5], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap0);
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[6], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap1);
+ sum = vmlaq_n_s16(sum,
+ Constrain(secondary_val[7], pixel, secondary_threshold,
+ secondary_damping_shift),
+ kCdefSecondaryTap1);
+ }
+ // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max)
+ const int16x8_t sum_lt_0 = vshrq_n_s16(sum, 15);
+ sum = vaddq_s16(sum, sum_lt_0);
+ int16x8_t result = vrsraq_n_s16(vreinterpretq_s16_u16(pixel), sum, 4);
+ if (clipping_required) {
+ result = vminq_s16(result, vreinterpretq_s16_u16(max));
+ result = vmaxq_s16(result, vreinterpretq_s16_u16(min));
+ }
+
+ const uint8x8_t dst_pixel = vqmovun_s16(result);
+ if (width == 8) {
+ src += src_stride;
+ vst1_u8(dst, dst_pixel);
+ dst += dst_stride;
+ --y;
+ } else {
+ src += src_stride << 1;
+ StoreLo4(dst, dst_pixel);
+ dst += dst_stride;
+ StoreHi4(dst, dst_pixel);
+ dst += dst_stride;
+ y -= 2;
+ }
+ } while (y != 0);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->cdef_direction = CdefDirection_NEON;
+ dsp->cdef_filters[0][0] = CdefFilter_NEON<4>;
+ dsp->cdef_filters[0][1] =
+ CdefFilter_NEON<4, /*enable_primary=*/true, /*enable_secondary=*/false>;
+ dsp->cdef_filters[0][2] = CdefFilter_NEON<4, /*enable_primary=*/false>;
+ dsp->cdef_filters[1][0] = CdefFilter_NEON<8>;
+ dsp->cdef_filters[1][1] =
+ CdefFilter_NEON<8, /*enable_primary=*/true, /*enable_secondary=*/false>;
+ dsp->cdef_filters[1][2] = CdefFilter_NEON<8, /*enable_primary=*/false>;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void CdefInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void CdefInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/cdef_neon.h b/src/dsp/arm/cdef_neon.h
new file mode 100644
index 0000000..53d5f86
--- /dev/null
+++ b/src/dsp/arm/cdef_neon.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not
+// thread-safe.
+void CdefInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h
new file mode 100644
index 0000000..dcb7567
--- /dev/null
+++ b/src/dsp/arm/common_neon.h
@@ -0,0 +1,777 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_
+
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cstdint>
+#include <cstring>
+
+#if 0
+#include <cstdio>
+
+#include "absl/strings/str_cat.h"
+
+constexpr bool kEnablePrintRegs = true;
+
+union DebugRegister {
+ int8_t i8[8];
+ int16_t i16[4];
+ int32_t i32[2];
+ uint8_t u8[8];
+ uint16_t u16[4];
+ uint32_t u32[2];
+};
+
+union DebugRegisterQ {
+ int8_t i8[16];
+ int16_t i16[8];
+ int32_t i32[4];
+ uint8_t u8[16];
+ uint16_t u16[8];
+ uint32_t u32[4];
+};
+
+// Quite useful helper for debugging. Left here for convenience.
+inline void PrintVect(const DebugRegister r, const char* const name, int size) {
+ int n;
+ if (kEnablePrintRegs) {
+ fprintf(stderr, "%s\t: ", name);
+ if (size == 8) {
+ for (n = 0; n < 8; ++n) fprintf(stderr, "%.2x ", r.u8[n]);
+ } else if (size == 16) {
+ for (n = 0; n < 4; ++n) fprintf(stderr, "%.4x ", r.u16[n]);
+ } else if (size == 32) {
+ for (n = 0; n < 2; ++n) fprintf(stderr, "%.8x ", r.u32[n]);
+ }
+ fprintf(stderr, "\n");
+ }
+}
+
+// Debugging macro for 128-bit types.
+inline void PrintVectQ(const DebugRegisterQ r, const char* const name,
+ int size) {
+ int n;
+ if (kEnablePrintRegs) {
+ fprintf(stderr, "%s\t: ", name);
+ if (size == 8) {
+ for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", r.u8[n]);
+ } else if (size == 16) {
+ for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", r.u16[n]);
+ } else if (size == 32) {
+ for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", r.u32[n]);
+ }
+ fprintf(stderr, "\n");
+ }
+}
+
+inline void PrintReg(const int32x4x2_t val, const std::string& name) {
+ DebugRegisterQ r;
+ vst1q_s32(r.i32, val.val[0]);
+ const std::string name0 = absl::StrCat(name, ".val[0]");
+ PrintVectQ(r, name0.c_str(), 32);
+ vst1q_s32(r.i32, val.val[1]);
+ const std::string name1 = absl::StrCat(name, ".val[1]");
+ PrintVectQ(r, name1.c_str(), 32);
+}
+
+inline void PrintReg(const uint32x4_t val, const char* name) {
+ DebugRegisterQ r;
+ vst1q_u32(r.u32, val);
+ PrintVectQ(r, name, 32);
+}
+
+inline void PrintReg(const uint32x2_t val, const char* name) {
+ DebugRegister r;
+ vst1_u32(r.u32, val);
+ PrintVect(r, name, 32);
+}
+
+inline void PrintReg(const uint16x8_t val, const char* name) {
+ DebugRegisterQ r;
+ vst1q_u16(r.u16, val);
+ PrintVectQ(r, name, 16);
+}
+
+inline void PrintReg(const uint16x4_t val, const char* name) {
+ DebugRegister r;
+ vst1_u16(r.u16, val);
+ PrintVect(r, name, 16);
+}
+
+inline void PrintReg(const uint8x16_t val, const char* name) {
+ DebugRegisterQ r;
+ vst1q_u8(r.u8, val);
+ PrintVectQ(r, name, 8);
+}
+
+inline void PrintReg(const uint8x8_t val, const char* name) {
+ DebugRegister r;
+ vst1_u8(r.u8, val);
+ PrintVect(r, name, 8);
+}
+
+inline void PrintReg(const int32x4_t val, const char* name) {
+ DebugRegisterQ r;
+ vst1q_s32(r.i32, val);
+ PrintVectQ(r, name, 32);
+}
+
+inline void PrintReg(const int32x2_t val, const char* name) {
+ DebugRegister r;
+ vst1_s32(r.i32, val);
+ PrintVect(r, name, 32);
+}
+
+inline void PrintReg(const int16x8_t val, const char* name) {
+ DebugRegisterQ r;
+ vst1q_s16(r.i16, val);
+ PrintVectQ(r, name, 16);
+}
+
+inline void PrintReg(const int16x4_t val, const char* name) {
+ DebugRegister r;
+ vst1_s16(r.i16, val);
+ PrintVect(r, name, 16);
+}
+
+inline void PrintReg(const int8x16_t val, const char* name) {
+ DebugRegisterQ r;
+ vst1q_s8(r.i8, val);
+ PrintVectQ(r, name, 8);
+}
+
+inline void PrintReg(const int8x8_t val, const char* name) {
+ DebugRegister r;
+ vst1_s8(r.i8, val);
+ PrintVect(r, name, 8);
+}
+
+// Print an individual (non-vector) value in decimal format.
+inline void PrintReg(const int x, const char* name) {
+ if (kEnablePrintRegs) {
+ printf("%s: %d\n", name, x);
+ }
+}
+
+// Print an individual (non-vector) value in hexadecimal format.
+inline void PrintHex(const int x, const char* name) {
+ if (kEnablePrintRegs) {
+ printf("%s: %x\n", name, x);
+ }
+}
+
+#define PR(x) PrintReg(x, #x)
+#define PD(x) PrintReg(x, #x)
+#define PX(x) PrintHex(x, #x)
+
+#endif // 0
+
+namespace libgav1 {
+namespace dsp {
+
+//------------------------------------------------------------------------------
+// Load functions.
+
+// Load 2 uint8_t values into lanes 0 and 1. Zeros the register before loading
+// the values. Use caution when using this in loops because it will re-zero the
+// register before loading on every iteration.
+inline uint8x8_t Load2(const void* const buf) {
+ const uint16x4_t zero = vdup_n_u16(0);
+ uint16_t temp;
+ memcpy(&temp, buf, 2);
+ return vreinterpret_u8_u16(vld1_lane_u16(&temp, zero, 0));
+}
+
+// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1.
+template <int lane>
+inline uint8x8_t Load2(const void* const buf, uint8x8_t val) {
+ uint16_t temp;
+ memcpy(&temp, buf, 2);
+ return vreinterpret_u8_u16(
+ vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
+}
+
+// Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
+// register before loading the values. Use caution when using this in loops
+// because it will re-zero the register before loading on every iteration.
+inline uint8x8_t Load4(const void* const buf) {
+ const uint32x2_t zero = vdup_n_u32(0);
+ uint32_t temp;
+ memcpy(&temp, buf, 4);
+ return vreinterpret_u8_u32(vld1_lane_u32(&temp, zero, 0));
+}
+
+// Load 4 uint8_t values into 4 lanes starting with |lane| * 4.
+template <int lane>
+inline uint8x8_t Load4(const void* const buf, uint8x8_t val) {
+ uint32_t temp;
+ memcpy(&temp, buf, 4);
+ return vreinterpret_u8_u32(
+ vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane));
+}
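+
+// Usage sketch (|src| and |stride| are illustrative names): gather two 4-byte
+// rows into a single uint8x8_t.
+//   uint8x8_t rows = Load4(src);          // bytes 0-3 -> lanes 0-3
+//   rows = Load4<1>(src + stride, rows);  // next row  -> lanes 4-7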
+
+//------------------------------------------------------------------------------
+// Store functions.
+
+// Propagate type information to the compiler. Without this the compiler may
+// assume the required alignment of the type (4 bytes in the case of uint32_t)
+// and add alignment hints to the memory access.
+template <typename T>
+inline void ValueToMem(void* const buf, T val) {
+ memcpy(buf, &val, sizeof(val));
+}
+
+// Store 4 int8_t values from the low half of an int8x8_t register.
+inline void StoreLo4(void* const buf, const int8x8_t val) {
+ ValueToMem<int32_t>(buf, vget_lane_s32(vreinterpret_s32_s8(val), 0));
+}
+
+// Store 4 uint8_t values from the low half of a uint8x8_t register.
+inline void StoreLo4(void* const buf, const uint8x8_t val) {
+ ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 0));
+}
+
+// Store 4 uint8_t values from the high half of a uint8x8_t register.
+inline void StoreHi4(void* const buf, const uint8x8_t val) {
+ ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 1));
+}
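+
+// Usage sketch (|dst| and |stride| are illustrative names): write an 8-lane
+// result as two 4-pixel rows, as the width == 4 paths in this directory do.
+//   StoreLo4(dst, result);
+//   StoreHi4(dst + stride, result);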
+
+// Store 2 uint8_t values from |lane| * 2 and |lane| * 2 + 1 of a uint8x8_t
+// register.
+template <int lane>
+inline void Store2(void* const buf, const uint8x8_t val) {
+ ValueToMem<uint16_t>(buf, vget_lane_u16(vreinterpret_u16_u8(val), lane));
+}
+
+// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x8_t
+// register.
+template <int lane>
+inline void Store2(void* const buf, const uint16x8_t val) {
+ ValueToMem<uint32_t>(buf, vgetq_lane_u32(vreinterpretq_u32_u16(val), lane));
+}
+
+// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t
+// register.
+template <int lane>
+inline void Store2(uint16_t* const buf, const uint16x4_t val) {
+ ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane));
+}
+
+//------------------------------------------------------------------------------
+// Bit manipulation.
+
+// vshXX_n_XX() requires an immediate.
+template <int shift>
+inline uint8x8_t LeftShift(const uint8x8_t vector) {
+ return vreinterpret_u8_u64(vshl_n_u64(vreinterpret_u64_u8(vector), shift));
+}
+
+template <int shift>
+inline uint8x8_t RightShift(const uint8x8_t vector) {
+ return vreinterpret_u8_u64(vshr_n_u64(vreinterpret_u64_u8(vector), shift));
+}
+
+template <int shift>
+inline int8x8_t RightShift(const int8x8_t vector) {
+ return vreinterpret_s8_u64(vshr_n_u64(vreinterpret_u64_s8(vector), shift));
+}
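+
+// Usage sketch: RightShift<8 * 2>(v) drops the two lowest-indexed bytes and
+// moves the rest toward lane 0, zero-filling the upper lanes (the shift is
+// applied to the 64-bit reinterpretation of the vector).
+//   const uint8x8_t shifted = RightShift<8 * 2>(v);  // v2 v3 v4 v5 v6 v7 0 0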
+
+// Shim vqtbl1_u8 for armv7.
+inline uint8x8_t VQTbl1U8(const uint8x16_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+ return vqtbl1_u8(a, index);
+#else
+ const uint8x8x2_t b = {vget_low_u8(a), vget_high_u8(a)};
+ return vtbl2_u8(b, index);
+#endif
+}
+
+// Shim vqtbl1_s8 for armv7.
+inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+ return vqtbl1_s8(a, index);
+#else
+ const int8x8x2_t b = {vget_low_s8(a), vget_high_s8(a)};
+ return vtbl2_s8(b, vreinterpret_s8_u8(index));
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Interleave.
+
+// vzipN is exclusive to A64.
+inline uint8x8_t InterleaveLow8(const uint8x8_t a, const uint8x8_t b) {
+#if defined(__aarch64__)
+ return vzip1_u8(a, b);
+#else
+ // Discard |.val[1]|
+ return vzip_u8(a, b).val[0];
+#endif
+}
+
+inline uint8x8_t InterleaveLow32(const uint8x8_t a, const uint8x8_t b) {
+#if defined(__aarch64__)
+ return vreinterpret_u8_u32(
+ vzip1_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
+#else
+ // Discard |.val[1]|
+ return vreinterpret_u8_u32(
+ vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[0]);
+#endif
+}
+
+inline int8x8_t InterleaveLow32(const int8x8_t a, const int8x8_t b) {
+#if defined(__aarch64__)
+ return vreinterpret_s8_u32(
+ vzip1_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
+#else
+ // Discard |.val[1]|
+ return vreinterpret_s8_u32(
+ vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[0]);
+#endif
+}
+
+inline uint8x8_t InterleaveHigh32(const uint8x8_t a, const uint8x8_t b) {
+#if defined(__aarch64__)
+ return vreinterpret_u8_u32(
+ vzip2_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
+#else
+ // Discard |.val[0]|
+ return vreinterpret_u8_u32(
+ vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[1]);
+#endif
+}
+
+inline int8x8_t InterleaveHigh32(const int8x8_t a, const int8x8_t b) {
+#if defined(__aarch64__)
+ return vreinterpret_s8_u32(
+ vzip2_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
+#else
+ // Discard |.val[0]|
+ return vreinterpret_s8_u32(
+ vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[1]);
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Sum.
+
+inline uint16_t SumVector(const uint8x8_t a) {
+#if defined(__aarch64__)
+ return vaddlv_u8(a);
+#else
+ const uint16x4_t c = vpaddl_u8(a);
+ const uint32x2_t d = vpaddl_u16(c);
+ const uint64x1_t e = vpaddl_u32(d);
+ return static_cast<uint16_t>(vget_lane_u64(e, 0));
+#endif // defined(__aarch64__)
+}
+
+inline uint32_t SumVector(const uint32x4_t a) {
+#if defined(__aarch64__)
+ return vaddvq_u32(a);
+#else
+ const uint64x2_t b = vpaddlq_u32(a);
+ const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b));
+ return static_cast<uint32_t>(vget_lane_u64(c, 0));
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Transpose.
+
+// Transpose 32 bit elements such that:
+// a: 00 01
+// b: 02 03
+// returns
+// val[0]: 00 02
+// val[1]: 01 03
+inline uint8x8x2_t Interleave32(const uint8x8_t a, const uint8x8_t b) {
+ const uint32x2_t a_32 = vreinterpret_u32_u8(a);
+ const uint32x2_t b_32 = vreinterpret_u32_u8(b);
+ const uint32x2x2_t c = vtrn_u32(a_32, b_32);
+ const uint8x8x2_t d = {vreinterpret_u8_u32(c.val[0]),
+ vreinterpret_u8_u32(c.val[1])};
+ return d;
+}
+
+// Swap high and low 32 bit elements.
+inline uint8x8_t Transpose32(const uint8x8_t a) {
+ const uint32x2_t b = vrev64_u32(vreinterpret_u32_u8(a));
+ return vreinterpret_u8_u32(b);
+}
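+
+// Example:
+//   input:  00 01 02 03 04 05 06 07
+//   output: 04 05 06 07 00 01 02 03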
+
+// Implement vtrnq_s64().
+// Input:
+// a0: 00 01 02 03 04 05 06 07
+// a1: 16 17 18 19 20 21 22 23
+// Output:
+// b0.val[0]: 00 01 02 03 16 17 18 19
+// b0.val[1]: 04 05 06 07 20 21 22 23
+inline int16x8x2_t VtrnqS64(int32x4_t a0, int32x4_t a1) {
+ int16x8x2_t b0;
+ b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)),
+ vreinterpret_s16_s32(vget_low_s32(a1)));
+ b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)),
+ vreinterpret_s16_s32(vget_high_s32(a1)));
+ return b0;
+}
+
+inline uint16x8x2_t VtrnqU64(uint32x4_t a0, uint32x4_t a1) {
+ uint16x8x2_t b0;
+ b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)),
+ vreinterpret_u16_u32(vget_low_u32(a1)));
+ b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)),
+ vreinterpret_u16_u32(vget_high_u32(a1)));
+ return b0;
+}
+
+// Input:
+// a: 00 01 02 03 10 11 12 13
+// b: 20 21 22 23 30 31 32 33
+// Output:
+// Note that columns [1] and [2] are transposed.
+// a: 00 10 20 30 02 12 22 32
+// b: 01 11 21 31 03 13 23 33
+inline void Transpose4x4(uint8x8_t* a, uint8x8_t* b) {
+ const uint16x4x2_t c =
+ vtrn_u16(vreinterpret_u16_u8(*a), vreinterpret_u16_u8(*b));
+ const uint32x2x2_t d =
+ vtrn_u32(vreinterpret_u32_u16(c.val[0]), vreinterpret_u32_u16(c.val[1]));
+ const uint8x8x2_t e =
+ vtrn_u8(vreinterpret_u8_u32(d.val[0]), vreinterpret_u8_u32(d.val[1]));
+ *a = e.val[0];
+ *b = e.val[1];
+}
+
+// Reversible if the x4 values are packed next to each other.
+// x4 input / x8 output:
+// a0: 00 01 02 03 40 41 42 43
+// a1: 10 11 12 13 50 51 52 53
+// a2: 20 21 22 23 60 61 62 63
+// a3: 30 31 32 33 70 71 72 73
+// x8 input / x4 output:
+// a0: 00 10 20 30 40 50 60 70
+// a1: 01 11 21 31 41 51 61 71
+// a2: 02 12 22 32 42 52 62 72
+// a3: 03 13 23 33 43 53 63 73
+inline void Transpose8x4(uint8x8_t* a0, uint8x8_t* a1, uint8x8_t* a2,
+ uint8x8_t* a3) {
+ const uint8x8x2_t b0 = vtrn_u8(*a0, *a1);
+ const uint8x8x2_t b1 = vtrn_u8(*a2, *a3);
+
+ const uint16x4x2_t c0 =
+ vtrn_u16(vreinterpret_u16_u8(b0.val[0]), vreinterpret_u16_u8(b1.val[0]));
+ const uint16x4x2_t c1 =
+ vtrn_u16(vreinterpret_u16_u8(b0.val[1]), vreinterpret_u16_u8(b1.val[1]));
+
+ *a0 = vreinterpret_u8_u16(c0.val[0]);
+ *a1 = vreinterpret_u8_u16(c1.val[0]);
+ *a2 = vreinterpret_u8_u16(c0.val[1]);
+ *a3 = vreinterpret_u8_u16(c1.val[1]);
+}
+
+// Input:
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// a[4]: 40 41 42 43 44 45 46 47
+// a[5]: 50 51 52 53 54 55 56 57
+// a[6]: 60 61 62 63 64 65 66 67
+// a[7]: 70 71 72 73 74 75 76 77
+
+// Output:
+// a[0]: 00 10 20 30 40 50 60 70
+// a[1]: 01 11 21 31 41 51 61 71
+// a[2]: 02 12 22 32 42 52 62 72
+// a[3]: 03 13 23 33 43 53 63 73
+// a[4]: 04 14 24 34 44 54 64 74
+// a[5]: 05 15 25 35 45 55 65 75
+// a[6]: 06 16 26 36 46 56 66 76
+// a[7]: 07 17 27 37 47 57 67 77
+inline void Transpose8x8(int8x8_t a[8]) {
+ // Swap 8 bit elements. Goes from:
+ // a[0]: 00 01 02 03 04 05 06 07
+ // a[1]: 10 11 12 13 14 15 16 17
+ // a[2]: 20 21 22 23 24 25 26 27
+ // a[3]: 30 31 32 33 34 35 36 37
+ // a[4]: 40 41 42 43 44 45 46 47
+ // a[5]: 50 51 52 53 54 55 56 57
+ // a[6]: 60 61 62 63 64 65 66 67
+ // a[7]: 70 71 72 73 74 75 76 77
+ // to:
+ // b0.val[0]: 00 10 02 12 04 14 06 16 40 50 42 52 44 54 46 56
+ // b0.val[1]: 01 11 03 13 05 15 07 17 41 51 43 53 45 55 47 57
+ // b1.val[0]: 20 30 22 32 24 34 26 36 60 70 62 72 64 74 66 76
+ // b1.val[1]: 21 31 23 33 25 35 27 37 61 71 63 73 65 75 67 77
+ const int8x16x2_t b0 =
+ vtrnq_s8(vcombine_s8(a[0], a[4]), vcombine_s8(a[1], a[5]));
+ const int8x16x2_t b1 =
+ vtrnq_s8(vcombine_s8(a[2], a[6]), vcombine_s8(a[3], a[7]));
+
+ // Swap 16 bit elements resulting in:
+ // c0.val[0]: 00 10 20 30 04 14 24 34 40 50 60 70 44 54 64 74
+ // c0.val[1]: 02 12 22 32 06 16 26 36 42 52 62 72 46 56 66 76
+ // c1.val[0]: 01 11 21 31 05 15 25 35 41 51 61 71 45 55 65 75
+ // c1.val[1]: 03 13 23 33 07 17 27 37 43 53 63 73 47 57 67 77
+ const int16x8x2_t c0 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[0]),
+ vreinterpretq_s16_s8(b1.val[0]));
+ const int16x8x2_t c1 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[1]),
+ vreinterpretq_s16_s8(b1.val[1]));
+
+ // Unzip 32 bit elements resulting in:
+ // d0.val[0]: 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71
+ // d0.val[1]: 04 14 24 34 44 54 64 74 05 15 25 35 45 55 65 75
+ // d1.val[0]: 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73
+ // d1.val[1]: 06 16 26 36 46 56 66 76 07 17 27 37 47 57 67 77
+ const int32x4x2_t d0 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[0]),
+ vreinterpretq_s32_s16(c1.val[0]));
+ const int32x4x2_t d1 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[1]),
+ vreinterpretq_s32_s16(c1.val[1]));
+
+ a[0] = vreinterpret_s8_s32(vget_low_s32(d0.val[0]));
+ a[1] = vreinterpret_s8_s32(vget_high_s32(d0.val[0]));
+ a[2] = vreinterpret_s8_s32(vget_low_s32(d1.val[0]));
+ a[3] = vreinterpret_s8_s32(vget_high_s32(d1.val[0]));
+ a[4] = vreinterpret_s8_s32(vget_low_s32(d0.val[1]));
+ a[5] = vreinterpret_s8_s32(vget_high_s32(d0.val[1]));
+ a[6] = vreinterpret_s8_s32(vget_low_s32(d1.val[1]));
+ a[7] = vreinterpret_s8_s32(vget_high_s32(d1.val[1]));
+}
+
+// Unsigned.
+inline void Transpose8x8(uint8x8_t a[8]) {
+ const uint8x16x2_t b0 =
+ vtrnq_u8(vcombine_u8(a[0], a[4]), vcombine_u8(a[1], a[5]));
+ const uint8x16x2_t b1 =
+ vtrnq_u8(vcombine_u8(a[2], a[6]), vcombine_u8(a[3], a[7]));
+
+ const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
+ vreinterpretq_u16_u8(b1.val[0]));
+ const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
+ vreinterpretq_u16_u8(b1.val[1]));
+
+ const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]),
+ vreinterpretq_u32_u16(c1.val[0]));
+ const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]),
+ vreinterpretq_u32_u16(c1.val[1]));
+
+ a[0] = vreinterpret_u8_u32(vget_low_u32(d0.val[0]));
+ a[1] = vreinterpret_u8_u32(vget_high_u32(d0.val[0]));
+ a[2] = vreinterpret_u8_u32(vget_low_u32(d1.val[0]));
+ a[3] = vreinterpret_u8_u32(vget_high_u32(d1.val[0]));
+ a[4] = vreinterpret_u8_u32(vget_low_u32(d0.val[1]));
+ a[5] = vreinterpret_u8_u32(vget_high_u32(d0.val[1]));
+ a[6] = vreinterpret_u8_u32(vget_low_u32(d1.val[1]));
+ a[7] = vreinterpret_u8_u32(vget_high_u32(d1.val[1]));
+}
+
+inline void Transpose8x8(uint8x8_t in[8], uint8x16_t out[4]) {
+ const uint8x16x2_t a0 =
+ vtrnq_u8(vcombine_u8(in[0], in[4]), vcombine_u8(in[1], in[5]));
+ const uint8x16x2_t a1 =
+ vtrnq_u8(vcombine_u8(in[2], in[6]), vcombine_u8(in[3], in[7]));
+
+ const uint16x8x2_t b0 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[0]),
+ vreinterpretq_u16_u8(a1.val[0]));
+ const uint16x8x2_t b1 = vtrnq_u16(vreinterpretq_u16_u8(a0.val[1]),
+ vreinterpretq_u16_u8(a1.val[1]));
+
+ const uint32x4x2_t c0 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[0]),
+ vreinterpretq_u32_u16(b1.val[0]));
+ const uint32x4x2_t c1 = vuzpq_u32(vreinterpretq_u32_u16(b0.val[1]),
+ vreinterpretq_u32_u16(b1.val[1]));
+
+ out[0] = vreinterpretq_u8_u32(c0.val[0]);
+ out[1] = vreinterpretq_u8_u32(c1.val[0]);
+ out[2] = vreinterpretq_u8_u32(c0.val[1]);
+ out[3] = vreinterpretq_u8_u32(c1.val[1]);
+}
+
+// Input:
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// a[4]: 40 41 42 43 44 45 46 47
+// a[5]: 50 51 52 53 54 55 56 57
+// a[6]: 60 61 62 63 64 65 66 67
+// a[7]: 70 71 72 73 74 75 76 77
+
+// Output:
+// a[0]: 00 10 20 30 40 50 60 70
+// a[1]: 01 11 21 31 41 51 61 71
+// a[2]: 02 12 22 32 42 52 62 72
+// a[3]: 03 13 23 33 43 53 63 73
+// a[4]: 04 14 24 34 44 54 64 74
+// a[5]: 05 15 25 35 45 55 65 75
+// a[6]: 06 16 26 36 46 56 66 76
+// a[7]: 07 17 27 37 47 57 67 77
+inline void Transpose8x8(int16x8_t a[8]) {
+ const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]);
+ const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]);
+ const int16x8x2_t b2 = vtrnq_s16(a[4], a[5]);
+ const int16x8x2_t b3 = vtrnq_s16(a[6], a[7]);
+
+ const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
+ vreinterpretq_s32_s16(b1.val[0]));
+ const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
+ vreinterpretq_s32_s16(b1.val[1]));
+ const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
+ vreinterpretq_s32_s16(b3.val[0]));
+ const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
+ vreinterpretq_s32_s16(b3.val[1]));
+
+ const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]);
+ const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]);
+ const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]);
+ const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]);
+
+ a[0] = d0.val[0];
+ a[1] = d1.val[0];
+ a[2] = d2.val[0];
+ a[3] = d3.val[0];
+ a[4] = d0.val[1];
+ a[5] = d1.val[1];
+ a[6] = d2.val[1];
+ a[7] = d3.val[1];
+}
+
+// Unsigned.
+inline void Transpose8x8(uint16x8_t a[8]) {
+ const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+ const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+ const uint16x8x2_t b2 = vtrnq_u16(a[4], a[5]);
+ const uint16x8x2_t b3 = vtrnq_u16(a[6], a[7]);
+
+ const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]),
+ vreinterpretq_u32_u16(b1.val[0]));
+ const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]),
+ vreinterpretq_u32_u16(b1.val[1]));
+ const uint32x4x2_t c2 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[0]),
+ vreinterpretq_u32_u16(b3.val[0]));
+ const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]),
+ vreinterpretq_u32_u16(b3.val[1]));
+
+ const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c2.val[0]);
+ const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c3.val[0]);
+ const uint16x8x2_t d2 = VtrnqU64(c0.val[1], c2.val[1]);
+ const uint16x8x2_t d3 = VtrnqU64(c1.val[1], c3.val[1]);
+
+ a[0] = d0.val[0];
+ a[1] = d1.val[0];
+ a[2] = d2.val[0];
+ a[3] = d3.val[0];
+ a[4] = d0.val[1];
+ a[5] = d1.val[1];
+ a[6] = d2.val[1];
+ a[7] = d3.val[1];
+}
+
+// Input:
+// a[0]: 00 01 02 03 04 05 06 07 80 81 82 83 84 85 86 87
+// a[1]: 10 11 12 13 14 15 16 17 90 91 92 93 94 95 96 97
+// a[2]: 20 21 22 23 24 25 26 27 a0 a1 a2 a3 a4 a5 a6 a7
+// a[3]: 30 31 32 33 34 35 36 37 b0 b1 b2 b3 b4 b5 b6 b7
+// a[4]: 40 41 42 43 44 45 46 47 c0 c1 c2 c3 c4 c5 c6 c7
+// a[5]: 50 51 52 53 54 55 56 57 d0 d1 d2 d3 d4 d5 d6 d7
+// a[6]: 60 61 62 63 64 65 66 67 e0 e1 e2 e3 e4 e5 e6 e7
+// a[7]: 70 71 72 73 74 75 76 77 f0 f1 f2 f3 f4 f5 f6 f7
+
+// Output:
+// a[0]: 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0
+// a[1]: 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1
+// a[2]: 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2
+// a[3]: 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3
+// a[4]: 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4
+// a[5]: 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5
+// a[6]: 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6
+// a[7]: 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7
+inline void Transpose8x16(uint8x16_t a[8]) {
+ // b0.val[0]: 00 10 02 12 04 14 06 16 80 90 82 92 84 94 86 96
+ // b0.val[1]: 01 11 03 13 05 15 07 17 81 91 83 93 85 95 87 97
+ // b1.val[0]: 20 30 22 32 24 34 26 36 a0 b0 a2 b2 a4 b4 a6 b6
+ // b1.val[1]: 21 31 23 33 25 35 27 37 a1 b1 a3 b3 a5 b5 a7 b7
+ // b2.val[0]: 40 50 42 52 44 54 46 56 c0 d0 c2 d2 c4 d4 c6 d6
+ // b2.val[1]: 41 51 43 53 45 55 47 57 c1 d1 c3 d3 c5 d5 c7 d7
+ // b3.val[0]: 60 70 62 72 64 74 66 76 e0 f0 e2 f2 e4 f4 e6 f6
+ // b3.val[1]: 61 71 63 73 65 75 67 77 e1 f1 e3 f3 e5 f5 e7 f7
+ const uint8x16x2_t b0 = vtrnq_u8(a[0], a[1]);
+ const uint8x16x2_t b1 = vtrnq_u8(a[2], a[3]);
+ const uint8x16x2_t b2 = vtrnq_u8(a[4], a[5]);
+ const uint8x16x2_t b3 = vtrnq_u8(a[6], a[7]);
+
+ // c0.val[0]: 00 10 20 30 04 14 24 34 80 90 a0 b0 84 94 a4 b4
+ // c0.val[1]: 02 12 22 32 06 16 26 36 82 92 a2 b2 86 96 a6 b6
+ // c1.val[0]: 01 11 21 31 05 15 25 35 81 91 a1 b1 85 95 a5 b5
+ // c1.val[1]: 03 13 23 33 07 17 27 37 83 93 a3 b3 87 97 a7 b7
+ // c2.val[0]: 40 50 60 70 44 54 64 74 c0 d0 e0 f0 c4 d4 e4 f4
+ // c2.val[1]: 42 52 62 72 46 56 66 76 c2 d2 e2 f2 c6 d6 e6 f6
+ // c3.val[0]: 41 51 61 71 45 55 65 75 c1 d1 e1 f1 c5 d5 e5 f5
+ // c3.val[1]: 43 53 63 73 47 57 67 77 c3 d3 e3 f3 c7 d7 e7 f7
+ const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
+ vreinterpretq_u16_u8(b1.val[0]));
+ const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
+ vreinterpretq_u16_u8(b1.val[1]));
+ const uint16x8x2_t c2 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[0]),
+ vreinterpretq_u16_u8(b3.val[0]));
+ const uint16x8x2_t c3 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[1]),
+ vreinterpretq_u16_u8(b3.val[1]));
+
+ // d0.val[0]: 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0
+ // d0.val[1]: 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4
+ // d1.val[0]: 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1
+ // d1.val[1]: 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5
+ // d2.val[0]: 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2
+ // d2.val[1]: 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6
+ // d3.val[0]: 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3
+ // d3.val[1]: 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7
+ const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]),
+ vreinterpretq_u32_u16(c2.val[0]));
+ const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[0]),
+ vreinterpretq_u32_u16(c3.val[0]));
+ const uint32x4x2_t d2 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]),
+ vreinterpretq_u32_u16(c2.val[1]));
+ const uint32x4x2_t d3 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[1]),
+ vreinterpretq_u32_u16(c3.val[1]));
+
+ a[0] = vreinterpretq_u8_u32(d0.val[0]);
+ a[1] = vreinterpretq_u8_u32(d1.val[0]);
+ a[2] = vreinterpretq_u8_u32(d2.val[0]);
+ a[3] = vreinterpretq_u8_u32(d3.val[0]);
+ a[4] = vreinterpretq_u8_u32(d0.val[1]);
+ a[5] = vreinterpretq_u8_u32(d1.val[1]);
+ a[6] = vreinterpretq_u8_u32(d2.val[1]);
+ a[7] = vreinterpretq_u8_u32(d3.val[1]);
+}
+
+inline int16x8_t ZeroExtend(const uint8x8_t in) {
+ return vreinterpretq_s16_u16(vmovl_u8(in));
+}
+
+} // namespace dsp
+} // namespace libgav1
+
+#endif // LIBGAV1_ENABLE_NEON
+#endif // LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_
diff --git a/src/dsp/arm/convolve_neon.cc b/src/dsp/arm/convolve_neon.cc
new file mode 100644
index 0000000..fd9b912
--- /dev/null
+++ b/src/dsp/arm/convolve_neon.cc
@@ -0,0 +1,3105 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/convolve.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Include the constants and utility functions inside the anonymous namespace.
+#include "src/dsp/convolve.inc"
+
+// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and
+// sum. The filters in |taps[]| are pre-shifted right by 1 (halved). This
+// prevents the final sum from exceeding the range of int16_t.
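+// For example, the sharp 8-tap filter at filter_id 8 has absolute half taps
+// summing to 2 + 6 + 12 + 40 + 40 + 12 + 6 + 2 = 120, so the true sum is
+// bounded by 255 * 120 = 30600, which fits in int16_t. With unhalved taps
+// (sum 240) it could reach 61200 and overflow.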
+template <int filter_index, bool negative_outside_taps = false>
+int16x8_t SumOnePassTaps(const uint8x8_t* const src,
+ const uint8x8_t* const taps) {
+ uint16x8_t sum;
+ if (filter_index == 0) {
+ // 6 taps. + - + + - +
+ sum = vmull_u8(src[0], taps[0]);
+ // Unsigned overflow will result in a valid int16_t value.
+ sum = vmlsl_u8(sum, src[1], taps[1]);
+ sum = vmlal_u8(sum, src[2], taps[2]);
+ sum = vmlal_u8(sum, src[3], taps[3]);
+ sum = vmlsl_u8(sum, src[4], taps[4]);
+ sum = vmlal_u8(sum, src[5], taps[5]);
+ } else if (filter_index == 1 && negative_outside_taps) {
+ // 6 taps. - + + + + -
+ // Set a base we can subtract from.
+ sum = vmull_u8(src[1], taps[1]);
+ sum = vmlsl_u8(sum, src[0], taps[0]);
+ sum = vmlal_u8(sum, src[2], taps[2]);
+ sum = vmlal_u8(sum, src[3], taps[3]);
+ sum = vmlal_u8(sum, src[4], taps[4]);
+ sum = vmlsl_u8(sum, src[5], taps[5]);
+ } else if (filter_index == 1) {
+ // 6 taps. All are positive.
+ sum = vmull_u8(src[0], taps[0]);
+ sum = vmlal_u8(sum, src[1], taps[1]);
+ sum = vmlal_u8(sum, src[2], taps[2]);
+ sum = vmlal_u8(sum, src[3], taps[3]);
+ sum = vmlal_u8(sum, src[4], taps[4]);
+ sum = vmlal_u8(sum, src[5], taps[5]);
+ } else if (filter_index == 2) {
+ // 8 taps. - + - + + - + -
+ sum = vmull_u8(src[1], taps[1]);
+ sum = vmlsl_u8(sum, src[0], taps[0]);
+ sum = vmlsl_u8(sum, src[2], taps[2]);
+ sum = vmlal_u8(sum, src[3], taps[3]);
+ sum = vmlal_u8(sum, src[4], taps[4]);
+ sum = vmlsl_u8(sum, src[5], taps[5]);
+ sum = vmlal_u8(sum, src[6], taps[6]);
+ sum = vmlsl_u8(sum, src[7], taps[7]);
+ } else if (filter_index == 3) {
+ // 2 taps. All are positive.
+ sum = vmull_u8(src[0], taps[0]);
+ sum = vmlal_u8(sum, src[1], taps[1]);
+ } else if (filter_index == 4) {
+ // 4 taps. - + + -
+ sum = vmull_u8(src[1], taps[1]);
+ sum = vmlsl_u8(sum, src[0], taps[0]);
+ sum = vmlal_u8(sum, src[2], taps[2]);
+ sum = vmlsl_u8(sum, src[3], taps[3]);
+ } else if (filter_index == 5) {
+ // 4 taps. All are positive.
+ sum = vmull_u8(src[0], taps[0]);
+ sum = vmlal_u8(sum, src[1], taps[1]);
+ sum = vmlal_u8(sum, src[2], taps[2]);
+ sum = vmlal_u8(sum, src[3], taps[3]);
+ }
+ return vreinterpretq_s16_u16(sum);
+}
+
+template <int filter_index, bool negative_outside_taps>
+int16x8_t SumHorizontalTaps(const uint8_t* const src,
+ const uint8x8_t* const v_tap) {
+ uint8x8_t v_src[8];
+ const uint8x16_t src_long = vld1q_u8(src);
+ int16x8_t sum;
+
+ if (filter_index < 2) {
+ v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 1));
+ v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 2));
+ v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+ v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+ v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 5));
+ v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 6));
+ sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 1);
+ } else if (filter_index == 2) {
+ v_src[0] = vget_low_u8(src_long);
+ v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
+ v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
+ v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+ v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+ v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5));
+ v_src[6] = vget_low_u8(vextq_u8(src_long, src_long, 6));
+ v_src[7] = vget_low_u8(vextq_u8(src_long, src_long, 7));
+ sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap);
+ } else if (filter_index == 3) {
+ v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+ v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+ sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 3);
+ } else if (filter_index > 3) {
+ v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 2));
+ v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+ v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+ v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 5));
+ sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 2);
+ }
+ return sum;
+}
+
+template <int filter_index, bool negative_outside_taps>
+uint8x8_t SimpleHorizontalTaps(const uint8_t* const src,
+ const uint8x8_t* const v_tap) {
+ int16x8_t sum =
+ SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap);
+
+ // Normally the Horizontal pass does the downshift in two passes:
+ // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+ // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
+ // requires adding the rounding offset from the skipped shift.
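+  // For example, assuming kInterRoundBitsHorizontal == 3 and kFilterBits == 7
+  // (the 8bpp values), the two-stage form (((sum + 2) >> 2) + 8) >> 4 equals
+  // the single rounding shift (sum + 2 + 32) >> 6, so adding
+  // first_shift_rounding_bit before vqrshrun_n_s16(sum, 6) reproduces it.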
+ constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
+
+ sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit));
+ return vqrshrun_n_s16(sum, kFilterBits - 1);
+}
+
+template <int filter_index, bool negative_outside_taps>
+uint16x8_t HorizontalTaps8To16(const uint8_t* const src,
+ const uint8x8_t* const v_tap) {
+ const int16x8_t sum =
+ SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap);
+
+ return vreinterpretq_u16_s16(
+ vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
+}
+
+template <int filter_index>
+int16x8_t SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride,
+ const uint8x8_t* const v_tap) {
+ uint16x8_t sum;
+ const uint8x8_t input0 = vld1_u8(src);
+ src += src_stride;
+ const uint8x8_t input1 = vld1_u8(src);
+ uint8x8x2_t input = vzip_u8(input0, input1);
+
+ if (filter_index == 3) {
+ // tap signs : + +
+ sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
+ sum = vmlal_u8(sum, input.val[1], v_tap[4]);
+ } else if (filter_index == 4) {
+ // tap signs : - + + -
+ sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
+ sum = vmlsl_u8(sum, RightShift<4 * 8>(input.val[0]), v_tap[2]);
+ sum = vmlal_u8(sum, input.val[1], v_tap[4]);
+ sum = vmlsl_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]);
+ } else {
+ // tap signs : + + + +
+ sum = vmull_u8(RightShift<4 * 8>(input.val[0]), v_tap[2]);
+ sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
+ sum = vmlal_u8(sum, input.val[1], v_tap[4]);
+ sum = vmlal_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]);
+ }
+
+ return vreinterpretq_s16_u16(sum);
+}
+
+template <int filter_index>
+uint8x8_t SimpleHorizontalTaps2x2(const uint8_t* src,
+ const ptrdiff_t src_stride,
+ const uint8x8_t* const v_tap) {
+ int16x8_t sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+ // Normally the Horizontal pass does the downshift in two passes:
+ // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+ // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
+ // requires adding the rounding offset from the skipped shift.
+ constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
+
+ sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit));
+ return vqrshrun_n_s16(sum, kFilterBits - 1);
+}
+
+template <int filter_index>
+uint16x8_t HorizontalTaps8To16_2x2(const uint8_t* src,
+ const ptrdiff_t src_stride,
+ const uint8x8_t* const v_tap) {
+ const int16x8_t sum =
+ SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+ return vreinterpretq_u16_s16(
+ vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
+}
+
+template <int num_taps, int step, int filter_index,
+ bool negative_outside_taps = true, bool is_2d = false,
+ bool is_compound = false>
+void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride,
+ void* const dest, const ptrdiff_t pred_stride,
+ const int width, const int height,
+ const uint8x8_t* const v_tap) {
+ auto* dest8 = static_cast<uint8_t*>(dest);
+ auto* dest16 = static_cast<uint16_t*>(dest);
+
+ // 4 tap filters are never used when width > 4.
+ if (num_taps != 4 && width > 4) {
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ if (is_2d || is_compound) {
+ const uint16x8_t v_sum =
+ HorizontalTaps8To16<filter_index, negative_outside_taps>(&src[x],
+ v_tap);
+ vst1q_u16(&dest16[x], v_sum);
+ } else {
+ const uint8x8_t result =
+ SimpleHorizontalTaps<filter_index, negative_outside_taps>(&src[x],
+ v_tap);
+ vst1_u8(&dest8[x], result);
+ }
+ x += step;
+ } while (x < width);
+ src += src_stride;
+ dest8 += pred_stride;
+ dest16 += pred_stride;
+ } while (++y < height);
+ return;
+ }
+
+  // Horizontal passes only need to account for |num_taps| 2 and 4 when
+  // |width| <= 4.
+ assert(width <= 4);
+ assert(num_taps <= 4);
+ if (num_taps <= 4) {
+ if (width == 4) {
+ int y = 0;
+ do {
+ if (is_2d || is_compound) {
+ const uint16x8_t v_sum =
+ HorizontalTaps8To16<filter_index, negative_outside_taps>(src,
+ v_tap);
+ vst1_u16(dest16, vget_low_u16(v_sum));
+ } else {
+ const uint8x8_t result =
+ SimpleHorizontalTaps<filter_index, negative_outside_taps>(src,
+ v_tap);
+ StoreLo4(&dest8[0], result);
+ }
+ src += src_stride;
+ dest8 += pred_stride;
+ dest16 += pred_stride;
+ } while (++y < height);
+ return;
+ }
+
+ if (!is_compound) {
+ int y = 0;
+ do {
+ if (is_2d) {
+ const uint16x8_t sum =
+ HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap);
+ dest16[0] = vgetq_lane_u16(sum, 0);
+ dest16[1] = vgetq_lane_u16(sum, 2);
+ dest16 += pred_stride;
+ dest16[0] = vgetq_lane_u16(sum, 1);
+ dest16[1] = vgetq_lane_u16(sum, 3);
+ dest16 += pred_stride;
+ } else {
+ const uint8x8_t sum =
+ SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+ dest8[0] = vget_lane_u8(sum, 0);
+ dest8[1] = vget_lane_u8(sum, 2);
+ dest8 += pred_stride;
+
+ dest8[0] = vget_lane_u8(sum, 1);
+ dest8[1] = vget_lane_u8(sum, 3);
+ dest8 += pred_stride;
+ }
+
+ src += src_stride << 1;
+ y += 2;
+ } while (y < height - 1);
+
+ // The 2d filters have an odd |height| because the horizontal pass
+ // generates context for the vertical pass.
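+      // (Here |height| is the intermediate height: the output height plus
+      // |vertical_taps| - 1. Both are even, so their sum minus 1 is odd, e.g.
+      // an 8-row block with a 6-tap vertical filter arrives as 8 + 6 - 1 = 13.)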
+ if (is_2d) {
+ assert(height % 2 == 1);
+ uint16x8_t sum;
+ const uint8x8_t input = vld1_u8(src);
+ if (filter_index == 3) { // |num_taps| == 2
+ sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]);
+ sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+ } else if (filter_index == 4) {
+ sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]);
+ sum = vmlsl_u8(sum, RightShift<2 * 8>(input), v_tap[2]);
+ sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+ sum = vmlsl_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
+ } else {
+ assert(filter_index == 5);
+ sum = vmull_u8(RightShift<2 * 8>(input), v_tap[2]);
+ sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
+ sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+ sum = vmlal_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
+ }
+ // |sum| contains an int16_t value.
+ sum = vreinterpretq_u16_s16(vrshrq_n_s16(
+ vreinterpretq_s16_u16(sum), kInterRoundBitsHorizontal - 1));
+ Store2<0>(dest16, sum);
+ }
+ }
+ }
+}
+
+// Process 16-bit inputs, accumulate in 32 bits, and narrow the result back to
+// 16 bits.
+template <int num_taps, bool is_compound>
+inline int16x4_t Sum2DVerticalTaps4(const int16x4_t* const src,
+ const int16x8_t taps) {
+ const int16x4_t taps_lo = vget_low_s16(taps);
+ const int16x4_t taps_hi = vget_high_s16(taps);
+ int32x4_t sum;
+ if (num_taps == 8) {
+ sum = vmull_lane_s16(src[0], taps_lo, 0);
+ sum = vmlal_lane_s16(sum, src[1], taps_lo, 1);
+ sum = vmlal_lane_s16(sum, src[2], taps_lo, 2);
+ sum = vmlal_lane_s16(sum, src[3], taps_lo, 3);
+ sum = vmlal_lane_s16(sum, src[4], taps_hi, 0);
+ sum = vmlal_lane_s16(sum, src[5], taps_hi, 1);
+ sum = vmlal_lane_s16(sum, src[6], taps_hi, 2);
+ sum = vmlal_lane_s16(sum, src[7], taps_hi, 3);
+ } else if (num_taps == 6) {
+ sum = vmull_lane_s16(src[0], taps_lo, 1);
+ sum = vmlal_lane_s16(sum, src[1], taps_lo, 2);
+ sum = vmlal_lane_s16(sum, src[2], taps_lo, 3);
+ sum = vmlal_lane_s16(sum, src[3], taps_hi, 0);
+ sum = vmlal_lane_s16(sum, src[4], taps_hi, 1);
+ sum = vmlal_lane_s16(sum, src[5], taps_hi, 2);
+ } else if (num_taps == 4) {
+ sum = vmull_lane_s16(src[0], taps_lo, 2);
+ sum = vmlal_lane_s16(sum, src[1], taps_lo, 3);
+ sum = vmlal_lane_s16(sum, src[2], taps_hi, 0);
+ sum = vmlal_lane_s16(sum, src[3], taps_hi, 1);
+ } else if (num_taps == 2) {
+ sum = vmull_lane_s16(src[0], taps_lo, 3);
+ sum = vmlal_lane_s16(sum, src[1], taps_hi, 0);
+ }
+
+ if (is_compound) {
+ return vqrshrn_n_s32(sum, kInterRoundBitsCompoundVertical - 1);
+ }
+
+ return vqrshrn_n_s32(sum, kInterRoundBitsVertical - 1);
+}
+
+template <int num_taps, bool is_compound>
+int16x8_t SimpleSum2DVerticalTaps(const int16x8_t* const src,
+ const int16x8_t taps) {
+ const int16x4_t taps_lo = vget_low_s16(taps);
+ const int16x4_t taps_hi = vget_high_s16(taps);
+ int32x4_t sum_lo, sum_hi;
+ if (num_taps == 8) {
+ sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 0);
+ sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 0);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 2);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_lo, 3);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_lo, 3);
+
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 0);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 0);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 1);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 1);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[6]), taps_hi, 2);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[6]), taps_hi, 2);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3);
+ } else if (num_taps == 6) {
+ sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 1);
+ sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 1);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 3);
+
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 0);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 0);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 1);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 1);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2);
+ } else if (num_taps == 4) {
+ sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 2);
+ sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 2);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3);
+
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_hi, 0);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_hi, 0);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1);
+ } else if (num_taps == 2) {
+ sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 3);
+ sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 3);
+
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0);
+ }
+
+ if (is_compound) {
+ return vcombine_s16(
+ vqrshrn_n_s32(sum_lo, kInterRoundBitsCompoundVertical - 1),
+ vqrshrn_n_s32(sum_hi, kInterRoundBitsCompoundVertical - 1));
+ }
+
+ return vcombine_s16(vqrshrn_n_s32(sum_lo, kInterRoundBitsVertical - 1),
+ vqrshrn_n_s32(sum_hi, kInterRoundBitsVertical - 1));
+}
+
+template <int num_taps, bool is_compound = false>
+void Filter2DVertical(const uint16_t* src, void* const dst,
+ const ptrdiff_t dst_stride, const int width,
+ const int height, const int16x8_t taps) {
+ assert(width >= 8);
+ constexpr int next_row = num_taps - 1;
+ // The Horizontal pass uses |width| as |stride| for the intermediate buffer.
+ const ptrdiff_t src_stride = width;
+
+ auto* dst8 = static_cast<uint8_t*>(dst);
+ auto* dst16 = static_cast<uint16_t*>(dst);
+
+ int x = 0;
+ do {
+ int16x8_t srcs[8];
+ const uint16_t* src_x = src + x;
+ srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+ if (num_taps >= 4) {
+ srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+ srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+ if (num_taps >= 6) {
+ srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+ srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+ if (num_taps == 8) {
+ srcs[5] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+ srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+ }
+ }
+ }
+
+ int y = 0;
+ do {
+ srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+ src_x += src_stride;
+
+ const int16x8_t sum =
+ SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
+ if (is_compound) {
+ vst1q_u16(dst16 + x + y * dst_stride, vreinterpretq_u16_s16(sum));
+ } else {
+ vst1_u8(dst8 + x + y * dst_stride, vqmovun_s16(sum));
+ }
+
+ srcs[0] = srcs[1];
+ if (num_taps >= 4) {
+ srcs[1] = srcs[2];
+ srcs[2] = srcs[3];
+ if (num_taps >= 6) {
+ srcs[3] = srcs[4];
+ srcs[4] = srcs[5];
+ if (num_taps == 8) {
+ srcs[5] = srcs[6];
+ srcs[6] = srcs[7];
+ }
+ }
+ }
+ } while (++y < height);
+ x += 8;
+ } while (x < width);
+}
+
+// Take advantage of |src_stride| == |width| to process two rows at a time.
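+// With a stride of 4, one vld1q_u16 of 8 values covers two rows: srcs[0] holds
+// rows 0-1 and srcs[2] holds rows 2-3, while the odd-offset srcs[1] is stitched
+// together from the high half of srcs[0] and the low half of srcs[2].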
+template <int num_taps, bool is_compound = false>
+void Filter2DVertical4xH(const uint16_t* src, void* const dst,
+ const ptrdiff_t dst_stride, const int height,
+ const int16x8_t taps) {
+ auto* dst8 = static_cast<uint8_t*>(dst);
+ auto* dst16 = static_cast<uint16_t*>(dst);
+
+ int16x8_t srcs[9];
+ srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ if (num_taps >= 4) {
+ srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2]));
+ if (num_taps >= 6) {
+ srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ srcs[3] = vcombine_s16(vget_high_s16(srcs[2]), vget_low_s16(srcs[4]));
+ if (num_taps == 8) {
+ srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ srcs[5] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[6]));
+ }
+ }
+ }
+
+ int y = 0;
+ do {
+ srcs[num_taps] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ srcs[num_taps - 1] = vcombine_s16(vget_high_s16(srcs[num_taps - 2]),
+ vget_low_s16(srcs[num_taps]));
+
+ const int16x8_t sum =
+ SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
+ if (is_compound) {
+ const uint16x8_t results = vreinterpretq_u16_s16(sum);
+ vst1q_u16(dst16, results);
+ dst16 += 4 << 1;
+ } else {
+ const uint8x8_t results = vqmovun_s16(sum);
+
+ StoreLo4(dst8, results);
+ dst8 += dst_stride;
+ StoreHi4(dst8, results);
+ dst8 += dst_stride;
+ }
+
+ srcs[0] = srcs[2];
+ if (num_taps >= 4) {
+ srcs[1] = srcs[3];
+ srcs[2] = srcs[4];
+ if (num_taps >= 6) {
+ srcs[3] = srcs[5];
+ srcs[4] = srcs[6];
+ if (num_taps == 8) {
+ srcs[5] = srcs[7];
+ srcs[6] = srcs[8];
+ }
+ }
+ }
+ y += 2;
+ } while (y < height);
+}
+
+// Take advantage of |src_stride| == |width| to process four rows at a time.
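+// With a stride of 2, one vld1q_u16 of 8 values covers rows 0-3 in srcs[0] and
+// rows 4-7 in srcs[4]; the in-between offsets are built with vextq_s16 and
+// vcombine_s16, e.g. vextq_s16(srcs[0], srcs[4], 2) yields rows 1-4.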
+template <int num_taps>
+void Filter2DVertical2xH(const uint16_t* src, void* const dst,
+ const ptrdiff_t dst_stride, const int height,
+ const int16x8_t taps) {
+ constexpr int next_row = (num_taps < 6) ? 4 : 8;
+
+ auto* dst8 = static_cast<uint8_t*>(dst);
+
+ int16x8_t srcs[9];
+ srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ if (num_taps >= 6) {
+ srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+ if (num_taps == 8) {
+ srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+ srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+ }
+ }
+
+ int y = 0;
+ do {
+ srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src));
+ src += 8;
+ if (num_taps == 2) {
+ srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+ } else if (num_taps == 4) {
+ srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+ srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+ srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+ } else if (num_taps == 6) {
+ srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+ srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+ srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
+ } else if (num_taps == 8) {
+ srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
+ srcs[6] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[8]));
+ srcs[7] = vextq_s16(srcs[4], srcs[8], 6);
+ }
+
+ const int16x8_t sum =
+ SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps);
+ const uint8x8_t results = vqmovun_s16(sum);
+
+ Store2<0>(dst8, results);
+ dst8 += dst_stride;
+ Store2<1>(dst8, results);
+ // When |height| <= 4 the taps are restricted to 2 and 4 tap variants.
+ // Therefore we don't need to check this condition when |height| > 4.
+ if (num_taps <= 4 && height == 2) return;
+ dst8 += dst_stride;
+ Store2<2>(dst8, results);
+ dst8 += dst_stride;
+ Store2<3>(dst8, results);
+ dst8 += dst_stride;
+
+ srcs[0] = srcs[4];
+ if (num_taps == 6) {
+ srcs[1] = srcs[5];
+ srcs[4] = srcs[8];
+ } else if (num_taps == 8) {
+ srcs[1] = srcs[5];
+ srcs[2] = srcs[6];
+ srcs[3] = srcs[7];
+ srcs[4] = srcs[8];
+ }
+
+ y += 4;
+ } while (y < height);
+}
+
+template <bool is_2d = false, bool is_compound = false>
+LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
+ const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
+ const ptrdiff_t dst_stride, const int width, const int height,
+ const int filter_id, const int filter_index) {
+ // Duplicate the absolute value for each tap. Negative taps are corrected
+ // by using the vmlsl_u8 instruction. Positive taps use vmlal_u8.
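+  // For example, the 8-tap filter (filter_index 2) has the sign pattern
+  // - + - + + - + -, so SumOnePassTaps() subtracts taps 0, 2, 5 and 7 with
+  // vmlsl_u8 and accumulates the remaining taps with vmlal_u8.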
+ uint8x8_t v_tap[kSubPixelTaps];
+ assert(filter_id != 0);
+
+ for (int k = 0; k < kSubPixelTaps; ++k) {
+ v_tap[k] = vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][filter_id][k]);
+ }
+
+ if (filter_index == 2) { // 8 tap.
+ FilterHorizontal<8, 8, 2, true, is_2d, is_compound>(
+ src, src_stride, dst, dst_stride, width, height, v_tap);
+ } else if (filter_index == 1) { // 6 tap.
+ // Check if outside taps are positive.
+ if ((filter_id == 1) | (filter_id == 15)) {
+ FilterHorizontal<6, 8, 1, false, is_2d, is_compound>(
+ src, src_stride, dst, dst_stride, width, height, v_tap);
+ } else {
+ FilterHorizontal<6, 8, 1, true, is_2d, is_compound>(
+ src, src_stride, dst, dst_stride, width, height, v_tap);
+ }
+ } else if (filter_index == 0) { // 6 tap.
+ FilterHorizontal<6, 8, 0, true, is_2d, is_compound>(
+ src, src_stride, dst, dst_stride, width, height, v_tap);
+ } else if (filter_index == 4) { // 4 tap.
+ FilterHorizontal<4, 8, 4, true, is_2d, is_compound>(
+ src, src_stride, dst, dst_stride, width, height, v_tap);
+ } else if (filter_index == 5) { // 4 tap.
+ FilterHorizontal<4, 8, 5, true, is_2d, is_compound>(
+ src, src_stride, dst, dst_stride, width, height, v_tap);
+ } else { // 2 tap.
+ FilterHorizontal<2, 8, 3, true, is_2d, is_compound>(
+ src, src_stride, dst, dst_stride, width, height, v_tap);
+ }
+}
+
+void Convolve2D_NEON(const void* const reference,
+ const ptrdiff_t reference_stride,
+ const int horizontal_filter_index,
+ const int vertical_filter_index,
+ const int horizontal_filter_id,
+ const int vertical_filter_id, const int width,
+ const int height, void* prediction,
+ const ptrdiff_t pred_stride) {
+ const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+ const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+ const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
+
+ // The output of the horizontal filter is guaranteed to fit in 16 bits.
+ uint16_t
+ intermediate_result[kMaxSuperBlockSizeInPixels *
+ (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
+ const int intermediate_height = height + vertical_taps - 1;
+
+ const ptrdiff_t src_stride = reference_stride;
+ const auto* src = static_cast<const uint8_t*>(reference) -
+ (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
+
+ DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width,
+ width, intermediate_height,
+ horizontal_filter_id, horiz_filter_index);
+
+ // Vertical filter.
+ auto* dest = static_cast<uint8_t*>(prediction);
+ const ptrdiff_t dest_stride = pred_stride;
+ assert(vertical_filter_id != 0);
+
+ const int16x8_t taps = vmovl_s8(
+ vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
+
+ if (vertical_taps == 8) {
+ if (width == 2) {
+ Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else if (width == 4) {
+ Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else {
+ Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height,
+ taps);
+ }
+ } else if (vertical_taps == 6) {
+ if (width == 2) {
+ Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else if (width == 4) {
+ Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else {
+ Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height,
+ taps);
+ }
+ } else if (vertical_taps == 4) {
+ if (width == 2) {
+ Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else if (width == 4) {
+ Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else {
+ Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height,
+ taps);
+ }
+ } else { // |vertical_taps| == 2
+ if (width == 2) {
+ Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else if (width == 4) {
+ Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height,
+ taps);
+ } else {
+ Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height,
+ taps);
+ }
+ }
+}
+
+// There are many opportunities for overreading in scaled convolve, because the
+// range of starting points for filter windows is anywhere from 0 to 16 for 8
+// destination pixels, and the window sizes range from 2 to 8. To accommodate
+// this range concisely, |grade_x| is the maximum number of whole source steps
+// that can be traversed in a single |step_x| increment, i.e. 1 or 2. When
+// grade_x is 2, every 8 |step_x| increments are guaranteed to traverse more
+// than 8 whole steps in src. The first load covers the initial elements of
+// src_x, while the final load covers the taps.
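+// For example, with kScaleSubPixelBits == 10 and an 8-tap filter, step_x ==
+// 1536 places the last of the 8 windows at source offset (7 * 1536) >> 10 = 10,
+// so its final tap lands at index 10 + 7 = 17 and the extra 8-byte load in
+// LoadSrcVals() is needed (grade_x == 2). At step_x == 1024 the final tap lands
+// at index 7 + 7 = 14 and a single 16-byte load suffices (grade_x == 1).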
+template <int grade_x>
+inline uint8x8x3_t LoadSrcVals(const uint8_t* src_x) {
+ uint8x8x3_t ret;
+ const uint8x16_t src_val = vld1q_u8(src_x);
+ ret.val[0] = vget_low_u8(src_val);
+ ret.val[1] = vget_high_u8(src_val);
+ if (grade_x > 1) {
+ ret.val[2] = vld1_u8(src_x + 16);
+ }
+ return ret;
+}
+
+// Pre-transpose the 2 tap filters in |kAbsHalfSubPixelFilters|[3].
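+// Column k of the table below holds tap k for all 16 filter ids (the bilinear
+// half taps are {64 - 4 * id, 4 * id}), so a single
+// VQTbl1U8(filter_taps0, filter_indices) gathers one tap for eight x positions
+// at once.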
+inline uint8x16_t GetPositive2TapFilter(const int tap_index) {
+ assert(tap_index < 2);
+ alignas(
+ 16) static constexpr uint8_t kAbsHalfSubPixel2TapFilterColumns[2][16] = {
+ {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
+ {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};
+
+ return vld1q_u8(kAbsHalfSubPixel2TapFilterColumns[tap_index]);
+}
+
+template <int grade_x>
+inline void ConvolveKernelHorizontal2Tap(const uint8_t* src,
+ const ptrdiff_t src_stride,
+ const int width, const int subpixel_x,
+ const int step_x,
+ const int intermediate_height,
+ int16_t* intermediate) {
+ // Account for the 0-taps that precede the 2 nonzero taps.
+ const int kernel_offset = 3;
+ const int ref_x = subpixel_x >> kScaleSubPixelBits;
+ const int step_x8 = step_x << 3;
+ const uint8x16_t filter_taps0 = GetPositive2TapFilter(0);
+ const uint8x16_t filter_taps1 = GetPositive2TapFilter(1);
+ const uint16x8_t index_steps = vmulq_n_u16(
+ vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+ const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+
+ int p = subpixel_x;
+ if (width <= 4) {
+ const uint8_t* src_x =
+ &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+ // Only add steps to the 10-bit truncated p to avoid overflow.
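+    // For example, a window starting at source column 100 has p on the order
+    // of 100 << 10 = 102400, which does not fit in a uint16_t lane, whereas
+    // (p & 1023) + 7 * step_x <= 1023 + 7 * 2048 = 15359 does.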
+ const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+ const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+ const uint8x8_t filter_indices =
+ vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
+ // This is a special case. The 2-tap filter has no negative taps, so we
+ // can use unsigned values.
+ // For each x, a lane of tapsK has
+ // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+ // on x.
+ const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices),
+ VQTbl1U8(filter_taps1, filter_indices)};
+ int y = 0;
+ do {
+ // Load a pool of samples to select from using stepped indices.
+ const uint8x16_t src_vals = vld1q_u8(src_x);
+ const uint8x8_t src_indices =
+ vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+
+ // For each x, a lane of srcK contains src_x[k].
+ const uint8x8_t src[2] = {
+ VQTbl1U8(src_vals, src_indices),
+ VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))};
+
+ vst1q_s16(intermediate,
+ vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps),
+ kInterRoundBitsHorizontal - 1));
+ src_x += src_stride;
+ intermediate += kIntermediateStride;
+ } while (++y < intermediate_height);
+ return;
+ }
+
+ // |width| >= 8
+ int x = 0;
+ do {
+ const uint8_t* src_x =
+ &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+ int16_t* intermediate_x = intermediate + x;
+ // Only add steps to the 10-bit truncated p to avoid overflow.
+ const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+ const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+ const uint8x8_t filter_indices =
+ vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+ filter_index_mask);
+ // This is a special case. The 2-tap filter has no negative taps, so we
+ // can use unsigned values.
+ // For each x, a lane of tapsK has
+ // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+ // on x.
+ const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices),
+ VQTbl1U8(filter_taps1, filter_indices)};
+ int y = 0;
+ do {
+ // Load a pool of samples to select from using stepped indices.
+ const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+ const uint8x8_t src_indices =
+ vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+
+ // For each x, a lane of srcK contains src_x[k].
+ const uint8x8_t src[2] = {
+ vtbl3_u8(src_vals, src_indices),
+ vtbl3_u8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))};
+
+ vst1q_s16(intermediate_x,
+ vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps),
+ kInterRoundBitsHorizontal - 1));
+ src_x += src_stride;
+ intermediate_x += kIntermediateStride;
+ } while (++y < intermediate_height);
+ x += 8;
+ p += step_x8;
+ } while (x < width);
+}
+
+// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[5].
+inline uint8x16_t GetPositive4TapFilter(const int tap_index) {
+ assert(tap_index < 4);
+ alignas(
+ 16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
+ {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
+ {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+ {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+ {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};
+
+ return vld1q_u8(kSubPixel4TapPositiveFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width <= 4.
+void ConvolveKernelHorizontalPositive4Tap(
+ const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x,
+ const int step_x, const int intermediate_height, int16_t* intermediate) {
+ const int kernel_offset = 2;
+ const int ref_x = subpixel_x >> kScaleSubPixelBits;
+ const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+ const uint8x16_t filter_taps0 = GetPositive4TapFilter(0);
+ const uint8x16_t filter_taps1 = GetPositive4TapFilter(1);
+ const uint8x16_t filter_taps2 = GetPositive4TapFilter(2);
+ const uint8x16_t filter_taps3 = GetPositive4TapFilter(3);
+ const uint16x8_t index_steps = vmulq_n_u16(
+ vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+ const int p = subpixel_x;
+  // The first filter (filter_id 0) is special: just a single 128 tap (64 after
+  // halving) on the center.
+ const uint8_t* src_x =
+ &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+ // Only add steps to the 10-bit truncated p to avoid overflow.
+ const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+ const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+ const uint8x8_t filter_indices = vand_u8(
+ vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), filter_index_mask);
+ // Note that filter_id depends on x.
+ // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
+ const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices),
+ VQTbl1U8(filter_taps1, filter_indices),
+ VQTbl1U8(filter_taps2, filter_indices),
+ VQTbl1U8(filter_taps3, filter_indices)};
+
+ const uint8x8_t src_indices =
+ vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+ int y = 0;
+ do {
+ // Load a pool of samples to select from using stepped index vectors.
+ const uint8x16_t src_vals = vld1q_u8(src_x);
+
+    // For each x, src[k] contains src_x[k].
+    // While the taps come from different arrays, the src pixels are drawn from
+    // the same contiguous line.
+ const uint8x8_t src[4] = {
+ VQTbl1U8(src_vals, src_indices),
+ VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1))),
+ VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2))),
+ VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)))};
+
+ vst1q_s16(intermediate,
+ vrshrq_n_s16(SumOnePassTaps</*filter_index=*/5>(src, taps),
+ kInterRoundBitsHorizontal - 1));
+
+ src_x += src_stride;
+ intermediate += kIntermediateStride;
+ } while (++y < intermediate_height);
+}
+
+// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[4].
+inline uint8x16_t GetSigned4TapFilter(const int tap_index) {
+ assert(tap_index < 4);
+ alignas(16) static constexpr uint8_t
+ kAbsHalfSubPixel4TapSignedFilterColumns[4][16] = {
+ {0, 2, 4, 5, 6, 6, 7, 6, 6, 5, 5, 5, 4, 3, 2, 1},
+ {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+ {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+ {0, 1, 2, 3, 4, 5, 5, 5, 6, 6, 7, 6, 6, 5, 4, 2}};
+
+ return vld1q_u8(kAbsHalfSubPixel4TapSignedFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width <= 4.
+inline void ConvolveKernelHorizontalSigned4Tap(
+ const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x,
+ const int step_x, const int intermediate_height, int16_t* intermediate) {
+ const int kernel_offset = 2;
+ const int ref_x = subpixel_x >> kScaleSubPixelBits;
+ const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+ const uint8x16_t filter_taps0 = GetSigned4TapFilter(0);
+ const uint8x16_t filter_taps1 = GetSigned4TapFilter(1);
+ const uint8x16_t filter_taps2 = GetSigned4TapFilter(2);
+ const uint8x16_t filter_taps3 = GetSigned4TapFilter(3);
+ const uint16x4_t index_steps = vmul_n_u16(vcreate_u16(0x0003000200010000),
+ static_cast<uint16_t>(step_x));
+
+ const int p = subpixel_x;
+ const uint8_t* src_x =
+ &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+ // Only add steps to the 10-bit truncated p to avoid overflow.
+ const uint16x4_t p_fraction = vdup_n_u16(p & 1023);
+ const uint16x4_t subpel_index_offsets = vadd_u16(index_steps, p_fraction);
+ const uint8x8_t filter_index_offsets = vshrn_n_u16(
+ vcombine_u16(subpel_index_offsets, vdup_n_u16(0)), kFilterIndexShift);
+ const uint8x8_t filter_indices =
+ vand_u8(filter_index_offsets, filter_index_mask);
+ // Note that filter_id depends on x.
+ // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
+ const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices),
+ VQTbl1U8(filter_taps1, filter_indices),
+ VQTbl1U8(filter_taps2, filter_indices),
+ VQTbl1U8(filter_taps3, filter_indices)};
+
+ const uint8x8_t src_indices_base =
+ vshr_n_u8(filter_index_offsets, kScaleSubPixelBits - kFilterIndexShift);
+
+ const uint8x8_t src_indices[4] = {src_indices_base,
+ vadd_u8(src_indices_base, vdup_n_u8(1)),
+ vadd_u8(src_indices_base, vdup_n_u8(2)),
+ vadd_u8(src_indices_base, vdup_n_u8(3))};
+
+ int y = 0;
+ do {
+ // Load a pool of samples to select from using stepped indices.
+ const uint8x16_t src_vals = vld1q_u8(src_x);
+
+    // For each x, src[k] contains src_x[k].
+    // While the taps come from different arrays, the src pixels are drawn from
+    // the same contiguous line.
+ const uint8x8_t src[4] = {
+ VQTbl1U8(src_vals, src_indices[0]), VQTbl1U8(src_vals, src_indices[1]),
+ VQTbl1U8(src_vals, src_indices[2]), VQTbl1U8(src_vals, src_indices[3])};
+
+ vst1q_s16(intermediate,
+ vrshrq_n_s16(SumOnePassTaps</*filter_index=*/4>(src, taps),
+ kInterRoundBitsHorizontal - 1));
+ src_x += src_stride;
+ intermediate += kIntermediateStride;
+ } while (++y < intermediate_height);
+}
+
+// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[0].
+inline uint8x16_t GetSigned6TapFilter(const int tap_index) {
+ assert(tap_index < 6);
+ alignas(16) static constexpr uint8_t
+ kAbsHalfSubPixel6TapSignedFilterColumns[6][16] = {
+ {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
+ {0, 3, 5, 6, 7, 7, 8, 7, 7, 6, 6, 6, 5, 4, 2, 1},
+ {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+ {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+ {0, 1, 2, 4, 5, 6, 6, 6, 7, 7, 8, 7, 7, 6, 5, 3},
+ {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
+
+ return vld1q_u8(kAbsHalfSubPixel6TapSignedFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width >= 8.
+template <int grade_x>
+inline void ConvolveKernelHorizontalSigned6Tap(
+ const uint8_t* src, const ptrdiff_t src_stride, const int width,
+ const int subpixel_x, const int step_x, const int intermediate_height,
+ int16_t* intermediate) {
+ const int kernel_offset = 1;
+ const uint8x8_t one = vdup_n_u8(1);
+ const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+ const int ref_x = subpixel_x >> kScaleSubPixelBits;
+ const int step_x8 = step_x << 3;
+ uint8x16_t filter_taps[6];
+ for (int i = 0; i < 6; ++i) {
+ filter_taps[i] = GetSigned6TapFilter(i);
+ }
+ const uint16x8_t index_steps = vmulq_n_u16(
+ vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+
+ int x = 0;
+ int p = subpixel_x;
+ do {
+    // Avoid overreading past the reference boundaries. This means
+    // |trailing_width| can be up to 24.
+ const uint8_t* src_x =
+ &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+ int16_t* intermediate_x = intermediate + x;
+ // Only add steps to the 10-bit truncated p to avoid overflow.
+ const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+ const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+ const uint8x8_t src_indices =
+ vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+ uint8x8_t src_lookup[6];
+ src_lookup[0] = src_indices;
+ for (int i = 1; i < 6; ++i) {
+ src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
+ }
+
+ const uint8x8_t filter_indices =
+ vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+ filter_index_mask);
+ // For each x, a lane of taps[k] has
+ // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+ // on x.
+ uint8x8_t taps[6];
+ for (int i = 0; i < 6; ++i) {
+ taps[i] = VQTbl1U8(filter_taps[i], filter_indices);
+ }
+ int y = 0;
+ do {
+ // Load a pool of samples to select from using stepped indices.
+ const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+
+ const uint8x8_t src[6] = {
+ vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]),
+ vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]),
+ vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5])};
+
+ vst1q_s16(intermediate_x,
+ vrshrq_n_s16(SumOnePassTaps</*filter_index=*/0>(src, taps),
+ kInterRoundBitsHorizontal - 1));
+ src_x += src_stride;
+ intermediate_x += kIntermediateStride;
+ } while (++y < intermediate_height);
+ x += 8;
+ p += step_x8;
+ } while (x < width);
+}
+
+// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[1]. This filter
+// has mixed positive and negative outer taps which are handled in
+// GetMixed6TapFilter().
+inline uint8x16_t GetPositive6TapFilter(const int tap_index) {
+  assert(tap_index < 4);
+ alignas(16) static constexpr uint8_t
+ kAbsHalfSubPixel6TapPositiveFilterColumns[4][16] = {
+ {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
+ {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+ {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+ {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14}};
+
+ return vld1q_u8(kAbsHalfSubPixel6TapPositiveFilterColumns[tap_index]);
+}
+
+inline int8x16_t GetMixed6TapFilter(const int tap_index) {
+ assert(tap_index < 2);
+ alignas(
+ 16) static constexpr int8_t kHalfSubPixel6TapMixedFilterColumns[2][16] = {
+ {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
+ {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};
+
+ return vld1q_s8(kHalfSubPixel6TapMixedFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width >= 8.
+template <int grade_x>
+inline void ConvolveKernelHorizontalMixed6Tap(
+ const uint8_t* src, const ptrdiff_t src_stride, const int width,
+ const int subpixel_x, const int step_x, const int intermediate_height,
+ int16_t* intermediate) {
+ const int kernel_offset = 1;
+ const uint8x8_t one = vdup_n_u8(1);
+ const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+ const int ref_x = subpixel_x >> kScaleSubPixelBits;
+ const int step_x8 = step_x << 3;
+ uint8x8_t taps[4];
+ int16x8_t mixed_taps[2];
+ uint8x16_t positive_filter_taps[4];
+ for (int i = 0; i < 4; ++i) {
+ positive_filter_taps[i] = GetPositive6TapFilter(i);
+ }
+ int8x16_t mixed_filter_taps[2];
+ mixed_filter_taps[0] = GetMixed6TapFilter(0);
+ mixed_filter_taps[1] = GetMixed6TapFilter(1);
+ const uint16x8_t index_steps = vmulq_n_u16(
+ vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+
+ int x = 0;
+ int p = subpixel_x;
+ do {
+ const uint8_t* src_x =
+ &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+ int16_t* intermediate_x = intermediate + x;
+ // Only add steps to the 10-bit truncated p to avoid overflow.
+ const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+ const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+ const uint8x8_t src_indices =
+ vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+ uint8x8_t src_lookup[6];
+ src_lookup[0] = src_indices;
+ for (int i = 1; i < 6; ++i) {
+ src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
+ }
+
+ const uint8x8_t filter_indices =
+ vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+ filter_index_mask);
+ // For each x, a lane of taps[k] has
+ // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+ // on x.
+ for (int i = 0; i < 4; ++i) {
+ taps[i] = VQTbl1U8(positive_filter_taps[i], filter_indices);
+ }
+ mixed_taps[0] = vmovl_s8(VQTbl1S8(mixed_filter_taps[0], filter_indices));
+ mixed_taps[1] = vmovl_s8(VQTbl1S8(mixed_filter_taps[1], filter_indices));
+
+ int y = 0;
+ do {
+ // Load a pool of samples to select from using stepped indices.
+ const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+
+ int16x8_t sum_mixed = vmulq_s16(
+ mixed_taps[0], ZeroExtend(vtbl3_u8(src_vals, src_lookup[0])));
+ sum_mixed = vmlaq_s16(sum_mixed, mixed_taps[1],
+ ZeroExtend(vtbl3_u8(src_vals, src_lookup[5])));
+ uint16x8_t sum = vreinterpretq_u16_s16(sum_mixed);
+ sum = vmlal_u8(sum, taps[0], vtbl3_u8(src_vals, src_lookup[1]));
+ sum = vmlal_u8(sum, taps[1], vtbl3_u8(src_vals, src_lookup[2]));
+ sum = vmlal_u8(sum, taps[2], vtbl3_u8(src_vals, src_lookup[3]));
+ sum = vmlal_u8(sum, taps[3], vtbl3_u8(src_vals, src_lookup[4]));
+
+ vst1q_s16(intermediate_x, vrshrq_n_s16(vreinterpretq_s16_u16(sum),
+ kInterRoundBitsHorizontal - 1));
+ src_x += src_stride;
+ intermediate_x += kIntermediateStride;
+ } while (++y < intermediate_height);
+ x += 8;
+ p += step_x8;
+ } while (x < width);
+}
+
+// Pre-transpose the 8 tap filters in |kAbsHalfSubPixelFilters|[2].
+inline uint8x16_t GetSigned8TapFilter(const int tap_index) {
+ assert(tap_index < 8);
+ alignas(16) static constexpr uint8_t
+ kAbsHalfSubPixel8TapSignedFilterColumns[8][16] = {
+ {0, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0},
+ {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
+ {0, 3, 6, 9, 11, 11, 12, 12, 12, 11, 10, 9, 7, 5, 3, 1},
+ {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
+ {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
+ {0, 1, 3, 5, 7, 9, 10, 11, 12, 12, 12, 11, 11, 9, 6, 3},
+ {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
+ {0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1}};
+
+ return vld1q_u8(kAbsHalfSubPixel8TapSignedFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width >= 8.
+template <int grade_x>
+inline void ConvolveKernelHorizontalSigned8Tap(
+ const uint8_t* src, const ptrdiff_t src_stride, const int width,
+ const int subpixel_x, const int step_x, const int intermediate_height,
+ int16_t* intermediate) {
+ const uint8x8_t one = vdup_n_u8(1);
+ const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+ const int ref_x = subpixel_x >> kScaleSubPixelBits;
+ const int step_x8 = step_x << 3;
+ uint8x8_t taps[8];
+ uint8x16_t filter_taps[8];
+ for (int i = 0; i < 8; ++i) {
+ filter_taps[i] = GetSigned8TapFilter(i);
+ }
+ const uint16x8_t index_steps = vmulq_n_u16(
+ vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+ int x = 0;
+ int p = subpixel_x;
+ do {
+ const uint8_t* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
+ int16_t* intermediate_x = intermediate + x;
+ // Only add steps to the 10-bit truncated p to avoid overflow.
+ const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+ const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+ const uint8x8_t src_indices =
+ vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+ uint8x8_t src_lookup[8];
+ src_lookup[0] = src_indices;
+ for (int i = 1; i < 8; ++i) {
+ src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
+ }
+
+ const uint8x8_t filter_indices =
+ vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+ filter_index_mask);
+ // For each x, a lane of taps[k] has
+ // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+ // on x.
+ for (int i = 0; i < 8; ++i) {
+ taps[i] = VQTbl1U8(filter_taps[i], filter_indices);
+ }
+
+ int y = 0;
+ do {
+ // Load a pool of samples to select from using stepped indices.
+ const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+
+ const uint8x8_t src[8] = {
+ vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]),
+ vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]),
+ vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5]),
+ vtbl3_u8(src_vals, src_lookup[6]), vtbl3_u8(src_vals, src_lookup[7])};
+
+ vst1q_s16(intermediate_x,
+ vrshrq_n_s16(SumOnePassTaps</*filter_index=*/2>(src, taps),
+ kInterRoundBitsHorizontal - 1));
+ src_x += src_stride;
+ intermediate_x += kIntermediateStride;
+ } while (++y < intermediate_height);
+ x += 8;
+ p += step_x8;
+ } while (x < width);
+}
+
+// This function handles blocks of width 2 or 4.
+template <int num_taps, int grade_y, int width, bool is_compound>
+void ConvolveVerticalScale4xH(const int16_t* src, const int subpixel_y,
+ const int filter_index, const int step_y,
+ const int height, void* dest,
+ const ptrdiff_t dest_stride) {
+ constexpr ptrdiff_t src_stride = kIntermediateStride;
+ const int16_t* src_y = src;
+ // |dest| is 16-bit in compound mode, Pixel otherwise.
+ uint16_t* dest16_y = static_cast<uint16_t*>(dest);
+ uint8_t* dest_y = static_cast<uint8_t*>(dest);
+ int16x4_t s[num_taps + grade_y];
+
+ int p = subpixel_y & 1023;
+ int prev_p = p;
+ int y = 0;
+ do { // y < height
+ for (int i = 0; i < num_taps; ++i) {
+ s[i] = vld1_s16(src_y + i * src_stride);
+ }
+ int filter_id = (p >> 6) & kSubPixelMask;
+ int16x8_t filter =
+ vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+ int16x4_t sums = Sum2DVerticalTaps4<num_taps, is_compound>(s, filter);
+ if (is_compound) {
+ assert(width != 2);
+ const uint16x4_t result = vreinterpret_u16_s16(sums);
+ vst1_u16(dest16_y, result);
+ } else {
+ const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums));
+ if (width == 2) {
+ Store2<0>(dest_y, result);
+ } else {
+ StoreLo4(dest_y, result);
+ }
+ }
+ p += step_y;
+ const int p_diff =
+ (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
+ prev_p = p;
+ // Here we load extra source in case it is needed. If |p_diff| == 0, these
+ // values will be unused, but it's faster to load than to branch.
+ s[num_taps] = vld1_s16(src_y + num_taps * src_stride);
+ if (grade_y > 1) {
+ s[num_taps + 1] = vld1_s16(src_y + (num_taps + 1) * src_stride);
+ }
+ dest16_y += dest_stride;
+ dest_y += dest_stride;
+
+ filter_id = (p >> 6) & kSubPixelMask;
+ filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+ sums = Sum2DVerticalTaps4<num_taps, is_compound>(&s[p_diff], filter);
+ if (is_compound) {
+ assert(width != 2);
+ const uint16x4_t result = vreinterpret_u16_s16(sums);
+ vst1_u16(dest16_y, result);
+ } else {
+ const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums));
+ if (width == 2) {
+ Store2<0>(dest_y, result);
+ } else {
+ StoreLo4(dest_y, result);
+ }
+ }
+ p += step_y;
+ src_y = src + (p >> kScaleSubPixelBits) * src_stride;
+ prev_p = p;
+ dest16_y += dest_stride;
+ dest_y += dest_stride;
+
+ y += 2;
+ } while (y < height);
+}
+
+template <int num_taps, int grade_y, bool is_compound>
+inline void ConvolveVerticalScale(const int16_t* src, const int width,
+ const int subpixel_y, const int filter_index,
+ const int step_y, const int height,
+ void* dest, const ptrdiff_t dest_stride) {
+ constexpr ptrdiff_t src_stride = kIntermediateStride;
+  // A possible improvement is to use arithmetic to decide how many times to
+  // apply filters to the same source rows before checking whether to load new
+  // srcs. However, this will only improve performance with very small step
+  // sizes.
+ int16x8_t s[num_taps + grade_y];
+ // |dest| is 16-bit in compound mode, Pixel otherwise.
+ uint16_t* dest16_y;
+ uint8_t* dest_y;
+
+ int x = 0;
+ do { // x < width
+ const int16_t* src_x = src + x;
+ const int16_t* src_y = src_x;
+ dest16_y = static_cast<uint16_t*>(dest) + x;
+ dest_y = static_cast<uint8_t*>(dest) + x;
+ int p = subpixel_y & 1023;
+ int prev_p = p;
+ int y = 0;
+ do { // y < height
+ for (int i = 0; i < num_taps; ++i) {
+ s[i] = vld1q_s16(src_y + i * src_stride);
+ }
+ int filter_id = (p >> 6) & kSubPixelMask;
+ int16x8_t filter =
+ vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+ int16x8_t sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(s, filter);
+ if (is_compound) {
+ vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum));
+ } else {
+ vst1_u8(dest_y, vqmovun_s16(sum));
+ }
+ p += step_y;
+ const int p_diff =
+ (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
+ // |grade_y| > 1 always means p_diff > 0, so load vectors that may be
+ // needed. Otherwise, we only need to load one vector because |p_diff|
+ // can't exceed 1.
+ s[num_taps] = vld1q_s16(src_y + num_taps * src_stride);
+ if (grade_y > 1) {
+ s[num_taps + 1] = vld1q_s16(src_y + (num_taps + 1) * src_stride);
+ }
+ dest16_y += dest_stride;
+ dest_y += dest_stride;
+
+ filter_id = (p >> 6) & kSubPixelMask;
+ filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+ sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(&s[p_diff], filter);
+ if (is_compound) {
+ vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum));
+ } else {
+ vst1_u8(dest_y, vqmovun_s16(sum));
+ }
+ p += step_y;
+ src_y = src_x + (p >> kScaleSubPixelBits) * src_stride;
+ prev_p = p;
+ dest16_y += dest_stride;
+ dest_y += dest_stride;
+
+ y += 2;
+ } while (y < height);
+ x += 8;
+ } while (x < width);
+}
+
+template <bool is_compound>
+void ConvolveScale2D_NEON(const void* const reference,
+ const ptrdiff_t reference_stride,
+ const int horizontal_filter_index,
+ const int vertical_filter_index, const int subpixel_x,
+ const int subpixel_y, const int step_x,
+ const int step_y, const int width, const int height,
+ void* prediction, const ptrdiff_t pred_stride) {
+ const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+ const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+ assert(step_x <= 2048);
+ const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
+ const int intermediate_height =
+ (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
+ kScaleSubPixelBits) +
+ num_vert_taps;
+ // The output of the horizontal filter, i.e. the intermediate_result, is
+ // guaranteed to fit in int16_t.
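+ // The buffer is sized for the worst case of 2x vertical downscaling with an
+ // 8-tap filter, i.e. |intermediate_height| up to
+ // 2 * kMaxSuperBlockSizeInPixels + 8.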
+ int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
+ (2 * kMaxSuperBlockSizeInPixels + 8)];
+
+ // Horizontal filter.
+ // Filter types used for width <= 4 are different from those for width > 4.
+ // When width > 4, the valid filter index range is always [0, 3].
+ // When width <= 4, the valid filter index range is always [3, 5].
+ // Similarly for height.
+ int filter_index = GetFilterIndex(horizontal_filter_index, width);
+ int16_t* intermediate = intermediate_result;
+ const ptrdiff_t src_stride = reference_stride;
+ const auto* src = static_cast<const uint8_t*>(reference);
+ const int vert_kernel_offset = (8 - num_vert_taps) / 2;
+ src += vert_kernel_offset * src_stride;
+
+ // Derive the maximum value of |step_x| at which all source values fit in one
+ // 16-byte load. Final index is src_x + |num_taps| - 1 < 16
+ // step_x*7 is the final base subpel index for the shuffle mask for filter
+ // inputs in each iteration on large blocks. When step_x is large, we need a
+ // larger structure and use a larger table lookup in order to gather all
+ // filter inputs.
+ // |num_taps| - 1 is the shuffle index of the final filter input.
+ const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index);
+ const int kernel_start_ceiling = 16 - num_horiz_taps;
+ // This truncated quotient |grade_x_threshold| selects |step_x| such that:
+ // (step_x * 7) >> kScaleSubPixelBits < single load limit
+ const int grade_x_threshold =
+ (kernel_start_ceiling << kScaleSubPixelBits) / 7;
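+ // For example, an 8-tap filter gives |kernel_start_ceiling| == 8 and a
+ // threshold of (8 << kScaleSubPixelBits) / 7, i.e. 1170 when
+ // kScaleSubPixelBits is 10.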
+ switch (filter_index) {
+ case 0:
+ if (step_x > grade_x_threshold) {
+ ConvolveKernelHorizontalSigned6Tap<2>(
+ src, src_stride, width, subpixel_x, step_x, intermediate_height,
+ intermediate);
+ } else {
+ ConvolveKernelHorizontalSigned6Tap<1>(
+ src, src_stride, width, subpixel_x, step_x, intermediate_height,
+ intermediate);
+ }
+ break;
+ case 1:
+ if (step_x > grade_x_threshold) {
+ ConvolveKernelHorizontalMixed6Tap<2>(src, src_stride, width, subpixel_x,
+ step_x, intermediate_height,
+ intermediate);
+
+ } else {
+ ConvolveKernelHorizontalMixed6Tap<1>(src, src_stride, width, subpixel_x,
+ step_x, intermediate_height,
+ intermediate);
+ }
+ break;
+ case 2:
+ if (step_x > grade_x_threshold) {
+ ConvolveKernelHorizontalSigned8Tap<2>(
+ src, src_stride, width, subpixel_x, step_x, intermediate_height,
+ intermediate);
+ } else {
+ ConvolveKernelHorizontalSigned8Tap<1>(
+ src, src_stride, width, subpixel_x, step_x, intermediate_height,
+ intermediate);
+ }
+ break;
+ case 3:
+ if (step_x > grade_x_threshold) {
+ ConvolveKernelHorizontal2Tap<2>(src, src_stride, width, subpixel_x,
+ step_x, intermediate_height,
+ intermediate);
+ } else {
+ ConvolveKernelHorizontal2Tap<1>(src, src_stride, width, subpixel_x,
+ step_x, intermediate_height,
+ intermediate);
+ }
+ break;
+ case 4:
+ assert(width <= 4);
+ ConvolveKernelHorizontalSigned4Tap(src, src_stride, subpixel_x, step_x,
+ intermediate_height, intermediate);
+ break;
+ default:
+ assert(filter_index == 5);
+ ConvolveKernelHorizontalPositive4Tap(src, src_stride, subpixel_x, step_x,
+ intermediate_height, intermediate);
+ }
+ // Vertical filter.
+ filter_index = GetFilterIndex(vertical_filter_index, height);
+ intermediate = intermediate_result;
+
+ switch (filter_index) {
+ case 0:
+ case 1:
+ if (step_y <= 1024) {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<6, 1, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<6, 1, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<6, 1, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ } else {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<6, 2, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<6, 2, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<6, 2, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ }
+ break;
+ case 2:
+ if (step_y <= 1024) {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<8, 1, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<8, 1, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<8, 1, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ } else {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<8, 2, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<8, 2, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<8, 2, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ }
+ break;
+ case 3:
+ if (step_y <= 1024) {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<2, 1, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<2, 1, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<2, 1, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ } else {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<2, 2, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<2, 2, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<2, 2, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ }
+ break;
+ case 4:
+ default:
+ assert(filter_index == 4 || filter_index == 5);
+ assert(height <= 4);
+ if (step_y <= 1024) {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<4, 1, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<4, 1, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<4, 1, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ } else {
+ if (!is_compound && width == 2) {
+ ConvolveVerticalScale4xH<4, 2, 2, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else if (width == 4) {
+ ConvolveVerticalScale4xH<4, 2, 4, is_compound>(
+ intermediate, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ } else {
+ ConvolveVerticalScale<4, 2, is_compound>(
+ intermediate, width, subpixel_y, filter_index, step_y, height,
+ prediction, pred_stride);
+ }
+ }
+ }
+}
+
+void ConvolveHorizontal_NEON(const void* const reference,
+ const ptrdiff_t reference_stride,
+ const int horizontal_filter_index,
+ const int /*vertical_filter_index*/,
+ const int horizontal_filter_id,
+ const int /*vertical_filter_id*/, const int width,
+ const int height, void* prediction,
+ const ptrdiff_t pred_stride) {
+ const int filter_index = GetFilterIndex(horizontal_filter_index, width);
+ // Set |src| to the outermost tap.
+ const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
+ auto* dest = static_cast<uint8_t*>(prediction);
+
+ DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height,
+ horizontal_filter_id, filter_index);
+}
+
+// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D
+// vertical calculations.
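+// The taps in |kHalfSubPixelFilters| are halved, so one bit of the rounding
+// is already accounted for, hence the shift by kInterRoundBitsHorizontal - 1.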
+uint16x8_t Compound1DShift(const int16x8_t sum) {
+ return vreinterpretq_u16_s16(
+ vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
+}
+
+template <int filter_index, bool is_compound = false,
+ bool negative_outside_taps = false>
+void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride,
+ void* const dst, const ptrdiff_t dst_stride,
+ const int width, const int height,
+ const uint8x8_t* const taps) {
+ const int num_taps = GetNumTapsInFilter(filter_index);
+ const int next_row = num_taps - 1;
+ auto* dst8 = static_cast<uint8_t*>(dst);
+ auto* dst16 = static_cast<uint16_t*>(dst);
+ assert(width >= 8);
+
+ int x = 0;
+ do {
+ const uint8_t* src_x = src + x;
+ uint8x8_t srcs[8];
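+ // Load the first |num_taps| - 1 rows; the final row of each filter window
+ // is loaded inside the y loop as |srcs[next_row]|.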
+ srcs[0] = vld1_u8(src_x);
+ src_x += src_stride;
+ if (num_taps >= 4) {
+ srcs[1] = vld1_u8(src_x);
+ src_x += src_stride;
+ srcs[2] = vld1_u8(src_x);
+ src_x += src_stride;
+ if (num_taps >= 6) {
+ srcs[3] = vld1_u8(src_x);
+ src_x += src_stride;
+ srcs[4] = vld1_u8(src_x);
+ src_x += src_stride;
+ if (num_taps == 8) {
+ srcs[5] = vld1_u8(src_x);
+ src_x += src_stride;
+ srcs[6] = vld1_u8(src_x);
+ src_x += src_stride;
+ }
+ }
+ }
+
+ int y = 0;
+ do {
+ srcs[next_row] = vld1_u8(src_x);
+ src_x += src_stride;
+
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ if (is_compound) {
+ const uint16x8_t results = Compound1DShift(sums);
+ vst1q_u16(dst16 + x + y * dst_stride, results);
+ } else {
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+ vst1_u8(dst8 + x + y * dst_stride, results);
+ }
+
+ srcs[0] = srcs[1];
+ if (num_taps >= 4) {
+ srcs[1] = srcs[2];
+ srcs[2] = srcs[3];
+ if (num_taps >= 6) {
+ srcs[3] = srcs[4];
+ srcs[4] = srcs[5];
+ if (num_taps == 8) {
+ srcs[5] = srcs[6];
+ srcs[6] = srcs[7];
+ }
+ }
+ }
+ } while (++y < height);
+ x += 8;
+ } while (x < width);
+}
+
+template <int filter_index, bool is_compound = false,
+ bool negative_outside_taps = false>
+void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride,
+ void* const dst, const ptrdiff_t dst_stride,
+ const int height, const uint8x8_t* const taps) {
+ const int num_taps = GetNumTapsInFilter(filter_index);
+ auto* dst8 = static_cast<uint8_t*>(dst);
+ auto* dst16 = static_cast<uint16_t*>(dst);
+
+ uint8x8_t srcs[9];
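+ // Each uint8x8_t holds two 4-pixel rows, so vext_u8(..., 4) forms the
+ // vector that starts one row later.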
+
+ if (num_taps == 2) {
+ srcs[2] = vdup_n_u8(0);
+
+ srcs[0] = Load4(src);
+ src += src_stride;
+
+ int y = 0;
+ do {
+ srcs[0] = Load4<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[2] = Load4<0>(src, srcs[2]);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ if (is_compound) {
+ const uint16x8_t results = Compound1DShift(sums);
+
+ vst1q_u16(dst16, results);
+ dst16 += 4 << 1;
+ } else {
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ StoreLo4(dst8, results);
+ dst8 += dst_stride;
+ StoreHi4(dst8, results);
+ dst8 += dst_stride;
+ }
+
+ srcs[0] = srcs[2];
+ y += 2;
+ } while (y < height);
+ } else if (num_taps == 4) {
+ srcs[4] = vdup_n_u8(0);
+
+ srcs[0] = Load4(src);
+ src += src_stride;
+ srcs[0] = Load4<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[2] = Load4(src);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+
+ int y = 0;
+ do {
+ srcs[2] = Load4<1>(src, srcs[2]);
+ src += src_stride;
+ srcs[4] = Load4<0>(src, srcs[4]);
+ src += src_stride;
+ srcs[3] = vext_u8(srcs[2], srcs[4], 4);
+
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ if (is_compound) {
+ const uint16x8_t results = Compound1DShift(sums);
+
+ vst1q_u16(dst16, results);
+ dst16 += 4 << 1;
+ } else {
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ StoreLo4(dst8, results);
+ dst8 += dst_stride;
+ StoreHi4(dst8, results);
+ dst8 += dst_stride;
+ }
+
+ srcs[0] = srcs[2];
+ srcs[1] = srcs[3];
+ srcs[2] = srcs[4];
+ y += 2;
+ } while (y < height);
+ } else if (num_taps == 6) {
+ srcs[6] = vdup_n_u8(0);
+
+ srcs[0] = Load4(src);
+ src += src_stride;
+ srcs[0] = Load4<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[2] = Load4(src);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+ srcs[2] = Load4<1>(src, srcs[2]);
+ src += src_stride;
+ srcs[4] = Load4(src);
+ src += src_stride;
+ srcs[3] = vext_u8(srcs[2], srcs[4], 4);
+
+ int y = 0;
+ do {
+ srcs[4] = Load4<1>(src, srcs[4]);
+ src += src_stride;
+ srcs[6] = Load4<0>(src, srcs[6]);
+ src += src_stride;
+ srcs[5] = vext_u8(srcs[4], srcs[6], 4);
+
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ if (is_compound) {
+ const uint16x8_t results = Compound1DShift(sums);
+
+ vst1q_u16(dst16, results);
+ dst16 += 4 << 1;
+ } else {
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ StoreLo4(dst8, results);
+ dst8 += dst_stride;
+ StoreHi4(dst8, results);
+ dst8 += dst_stride;
+ }
+
+ srcs[0] = srcs[2];
+ srcs[1] = srcs[3];
+ srcs[2] = srcs[4];
+ srcs[3] = srcs[5];
+ srcs[4] = srcs[6];
+ y += 2;
+ } while (y < height);
+ } else if (num_taps == 8) {
+ srcs[8] = vdup_n_u8(0);
+
+ srcs[0] = Load4(src);
+ src += src_stride;
+ srcs[0] = Load4<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[2] = Load4(src);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+ srcs[2] = Load4<1>(src, srcs[2]);
+ src += src_stride;
+ srcs[4] = Load4(src);
+ src += src_stride;
+ srcs[3] = vext_u8(srcs[2], srcs[4], 4);
+ srcs[4] = Load4<1>(src, srcs[4]);
+ src += src_stride;
+ srcs[6] = Load4(src);
+ src += src_stride;
+ srcs[5] = vext_u8(srcs[4], srcs[6], 4);
+
+ int y = 0;
+ do {
+ srcs[6] = Load4<1>(src, srcs[6]);
+ src += src_stride;
+ srcs[8] = Load4<0>(src, srcs[8]);
+ src += src_stride;
+ srcs[7] = vext_u8(srcs[6], srcs[8], 4);
+
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ if (is_compound) {
+ const uint16x8_t results = Compound1DShift(sums);
+
+ vst1q_u16(dst16, results);
+ dst16 += 4 << 1;
+ } else {
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ StoreLo4(dst8, results);
+ dst8 += dst_stride;
+ StoreHi4(dst8, results);
+ dst8 += dst_stride;
+ }
+
+ srcs[0] = srcs[2];
+ srcs[1] = srcs[3];
+ srcs[2] = srcs[4];
+ srcs[3] = srcs[5];
+ srcs[4] = srcs[6];
+ srcs[5] = srcs[7];
+ srcs[6] = srcs[8];
+ y += 2;
+ } while (y < height);
+ }
+}
+
+template <int filter_index, bool negative_outside_taps = false>
+void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride,
+ void* const dst, const ptrdiff_t dst_stride,
+ const int height, const uint8x8_t* const taps) {
+ const int num_taps = GetNumTapsInFilter(filter_index);
+ auto* dst8 = static_cast<uint8_t*>(dst);
+
+ uint8x8_t srcs[9];
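+ // Each uint8x8_t holds four 2-pixel rows; vext_u8 by 2, 4 or 6 bytes
+ // realigns the window by one, two or three rows.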
+
+ if (num_taps == 2) {
+ srcs[2] = vdup_n_u8(0);
+
+ srcs[0] = Load2(src);
+ src += src_stride;
+
+ int y = 0;
+ do {
+ srcs[0] = Load2<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[0] = Load2<2>(src, srcs[0]);
+ src += src_stride;
+ srcs[0] = Load2<3>(src, srcs[0]);
+ src += src_stride;
+ srcs[2] = Load2<0>(src, srcs[2]);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[2], 2);
+
+ // This uses srcs[0]..srcs[1].
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ Store2<0>(dst8, results);
+ dst8 += dst_stride;
+ Store2<1>(dst8, results);
+ if (height == 2) return;
+ dst8 += dst_stride;
+ Store2<2>(dst8, results);
+ dst8 += dst_stride;
+ Store2<3>(dst8, results);
+ dst8 += dst_stride;
+
+ srcs[0] = srcs[2];
+ y += 4;
+ } while (y < height);
+ } else if (num_taps == 4) {
+ srcs[4] = vdup_n_u8(0);
+
+ srcs[0] = Load2(src);
+ src += src_stride;
+ srcs[0] = Load2<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[0] = Load2<2>(src, srcs[0]);
+ src += src_stride;
+
+ int y = 0;
+ do {
+ srcs[0] = Load2<3>(src, srcs[0]);
+ src += src_stride;
+ srcs[4] = Load2<0>(src, srcs[4]);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[4], 2);
+ srcs[4] = Load2<1>(src, srcs[4]);
+ src += src_stride;
+ srcs[2] = vext_u8(srcs[0], srcs[4], 4);
+ srcs[4] = Load2<2>(src, srcs[4]);
+ src += src_stride;
+ srcs[3] = vext_u8(srcs[0], srcs[4], 6);
+
+ // This uses srcs[0]..srcs[3].
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ Store2<0>(dst8, results);
+ dst8 += dst_stride;
+ Store2<1>(dst8, results);
+ if (height == 2) return;
+ dst8 += dst_stride;
+ Store2<2>(dst8, results);
+ dst8 += dst_stride;
+ Store2<3>(dst8, results);
+ dst8 += dst_stride;
+
+ srcs[0] = srcs[4];
+ y += 4;
+ } while (y < height);
+ } else if (num_taps == 6) {
+ // During the vertical pass the number of taps is restricted when
+ // |height| <= 4.
+ assert(height > 4);
+ srcs[8] = vdup_n_u8(0);
+
+ srcs[0] = Load2(src);
+ src += src_stride;
+ srcs[0] = Load2<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[0] = Load2<2>(src, srcs[0]);
+ src += src_stride;
+ srcs[0] = Load2<3>(src, srcs[0]);
+ src += src_stride;
+ srcs[4] = Load2(src);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[4], 2);
+
+ int y = 0;
+ do {
+ srcs[4] = Load2<1>(src, srcs[4]);
+ src += src_stride;
+ srcs[2] = vext_u8(srcs[0], srcs[4], 4);
+ srcs[4] = Load2<2>(src, srcs[4]);
+ src += src_stride;
+ srcs[3] = vext_u8(srcs[0], srcs[4], 6);
+ srcs[4] = Load2<3>(src, srcs[4]);
+ src += src_stride;
+ srcs[8] = Load2<0>(src, srcs[8]);
+ src += src_stride;
+ srcs[5] = vext_u8(srcs[4], srcs[8], 2);
+
+ // This uses srcs[0]..srcs[5].
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ Store2<0>(dst8, results);
+ dst8 += dst_stride;
+ Store2<1>(dst8, results);
+ dst8 += dst_stride;
+ Store2<2>(dst8, results);
+ dst8 += dst_stride;
+ Store2<3>(dst8, results);
+ dst8 += dst_stride;
+
+ srcs[0] = srcs[4];
+ srcs[1] = srcs[5];
+ srcs[4] = srcs[8];
+ y += 4;
+ } while (y < height);
+ } else if (num_taps == 8) {
+ // During the vertical pass the number of taps is restricted when
+ // |height| <= 4.
+ assert(height > 4);
+ srcs[8] = vdup_n_u8(0);
+
+ srcs[0] = Load2(src);
+ src += src_stride;
+ srcs[0] = Load2<1>(src, srcs[0]);
+ src += src_stride;
+ srcs[0] = Load2<2>(src, srcs[0]);
+ src += src_stride;
+ srcs[0] = Load2<3>(src, srcs[0]);
+ src += src_stride;
+ srcs[4] = Load2(src);
+ src += src_stride;
+ srcs[1] = vext_u8(srcs[0], srcs[4], 2);
+ srcs[4] = Load2<1>(src, srcs[4]);
+ src += src_stride;
+ srcs[2] = vext_u8(srcs[0], srcs[4], 4);
+ srcs[4] = Load2<2>(src, srcs[4]);
+ src += src_stride;
+ srcs[3] = vext_u8(srcs[0], srcs[4], 6);
+
+ int y = 0;
+ do {
+ srcs[4] = Load2<3>(src, srcs[4]);
+ src += src_stride;
+ srcs[8] = Load2<0>(src, srcs[8]);
+ src += src_stride;
+ srcs[5] = vext_u8(srcs[4], srcs[8], 2);
+ srcs[8] = Load2<1>(src, srcs[8]);
+ src += src_stride;
+ srcs[6] = vext_u8(srcs[4], srcs[8], 4);
+ srcs[8] = Load2<2>(src, srcs[8]);
+ src += src_stride;
+ srcs[7] = vext_u8(srcs[4], srcs[8], 6);
+
+ // This uses srcs[0]..srcs[7].
+ const int16x8_t sums =
+ SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+ const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+ Store2<0>(dst8, results);
+ dst8 += dst_stride;
+ Store2<1>(dst8, results);
+ dst8 += dst_stride;
+ Store2<2>(dst8, results);
+ dst8 += dst_stride;
+ Store2<3>(dst8, results);
+ dst8 += dst_stride;
+
+ srcs[0] = srcs[4];
+ srcs[1] = srcs[5];
+ srcs[2] = srcs[6];
+ srcs[3] = srcs[7];
+ srcs[4] = srcs[8];
+ y += 4;
+ } while (y < height);
+ }
+}
+
+// This function is a simplified version of Convolve2D_C.
+// It is used in single prediction mode, where only vertical filtering is
+// required.
+// The output is the single prediction of the block, clipped to the valid
+// pixel range.
+void ConvolveVertical_NEON(const void* const reference,
+ const ptrdiff_t reference_stride,
+ const int /*horizontal_filter_index*/,
+ const int vertical_filter_index,
+ const int /*horizontal_filter_id*/,
+ const int vertical_filter_id, const int width,
+ const int height, void* prediction,
+ const ptrdiff_t pred_stride) {
+ const int filter_index = GetFilterIndex(vertical_filter_index, height);
+ const int vertical_taps = GetNumTapsInFilter(filter_index);
+ const ptrdiff_t src_stride = reference_stride;
+ const auto* src = static_cast<const uint8_t*>(reference) -
+ (vertical_taps / 2 - 1) * src_stride;
+ auto* dest = static_cast<uint8_t*>(prediction);
+ const ptrdiff_t dest_stride = pred_stride;
+ assert(vertical_filter_id != 0);
+
+ uint8x8_t taps[8];
+ for (int k = 0; k < kSubPixelTaps; ++k) {
+ taps[k] =
+ vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+ }
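+ // |kAbsHalfSubPixelFilters| stores the magnitudes of the halved taps; the
+ // signs are reapplied inside SumOnePassTaps() based on |filter_index| and
+ // |negative_outside_taps|.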
+
+ if (filter_index == 0) { // 6 tap.
+ if (width == 2) {
+ FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else if (width == 4) {
+ FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else {
+ FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
+ taps + 1);
+ }
+ } else if ((filter_index == 1) & ((vertical_filter_id == 1) |
+ (vertical_filter_id == 15))) { // 5 tap.
+ if (width == 2) {
+ FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else if (width == 4) {
+ FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height,
+ taps + 1);
+ } else {
+ FilterVertical<1>(src, src_stride, dest, dest_stride, width, height,
+ taps + 1);
+ }
+ } else if ((filter_index == 1) &
+ ((vertical_filter_id == 7) | (vertical_filter_id == 8) |
+ (vertical_filter_id == 9))) { // 6 tap with weird negative taps.
+ if (width == 2) {
+ FilterVertical2xH<1,
+ /*negative_outside_taps=*/true>(
+ src, src_stride, dest, dest_stride, height, taps + 1);
+ } else if (width == 4) {
+ FilterVertical4xH<1, /*is_compound=*/false,
+ /*negative_outside_taps=*/true>(
+ src, src_stride, dest, dest_stride, height, taps + 1);
+ } else {
+ FilterVertical<1, /*is_compound=*/false, /*negative_outside_taps=*/true>(
+ src, src_stride, dest, dest_stride, width, height, taps + 1);
+ }
+ } else if (filter_index == 2) { // 8 tap.
+ if (width == 2) {
+ FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
+ } else if (width == 4) {
+ FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
+ } else {
+ FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
+ taps);
+ }
+ } else if (filter_index == 3) { // 2 tap.
+ if (width == 2) {
+ FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height,
+ taps + 3);
+ } else if (width == 4) {
+ FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height,
+ taps + 3);
+ } else {
+ FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
+ taps + 3);
+ }
+ } else if (filter_index == 4) { // 4 tap.
+ // Outside taps are negative.
+ if (width == 2) {
+ FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height,
+ taps + 2);
+ } else if (width == 4) {
+ FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height,
+ taps + 2);
+ } else {
+ FilterVertical<4>(src, src_stride, dest, dest_stride, width, height,
+ taps + 2);
+ }
+ } else {
+ // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
+ // below map to 4 tap filters.
+ assert(filter_index == 5 ||
+ (filter_index == 1 &&
+ (vertical_filter_id == 2 || vertical_filter_id == 3 ||
+ vertical_filter_id == 4 || vertical_filter_id == 5 ||
+ vertical_filter_id == 6 || vertical_filter_id == 10 ||
+ vertical_filter_id == 11 || vertical_filter_id == 12 ||
+ vertical_filter_id == 13 || vertical_filter_id == 14)));
+ // According to GetNumTapsInFilter() this has 6 taps but here we are
+ // treating it as though it has 4.
+ if (filter_index == 1) src += src_stride;
+ if (width == 2) {
+ FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height,
+ taps + 2);
+ } else if (width == 4) {
+ FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height,
+ taps + 2);
+ } else {
+ FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
+ taps + 2);
+ }
+ }
+}
+
+void ConvolveCompoundCopy_NEON(
+ const void* const reference, const ptrdiff_t reference_stride,
+ const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+ const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
+ const int width, const int height, void* prediction,
+ const ptrdiff_t /*pred_stride*/) {
+ const auto* src = static_cast<const uint8_t*>(reference);
+ const ptrdiff_t src_stride = reference_stride;
+ auto* dest = static_cast<uint16_t*>(prediction);
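+ // Shift the copied pixels up into the compound prediction range used by the
+ // blending functions.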
+ constexpr int final_shift =
+ kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
+
+ if (width >= 16) {
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ const uint8x16_t v_src = vld1q_u8(&src[x]);
+ const uint16x8_t v_dest_lo =
+ vshll_n_u8(vget_low_u8(v_src), final_shift);
+ const uint16x8_t v_dest_hi =
+ vshll_n_u8(vget_high_u8(v_src), final_shift);
+ vst1q_u16(&dest[x], v_dest_lo);
+ x += 8;
+ vst1q_u16(&dest[x], v_dest_hi);
+ x += 8;
+ } while (x < width);
+ src += src_stride;
+ dest += width;
+ } while (++y < height);
+ } else if (width == 8) {
+ int y = 0;
+ do {
+ const uint8x8_t v_src = vld1_u8(&src[0]);
+ const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift);
+ vst1q_u16(&dest[0], v_dest);
+ src += src_stride;
+ dest += width;
+ } while (++y < height);
+ } else { /* width == 4 */
+ uint8x8_t v_src = vdup_n_u8(0);
+
+ int y = 0;
+ do {
+ v_src = Load4<0>(&src[0], v_src);
+ src += src_stride;
+ v_src = Load4<1>(&src[0], v_src);
+ src += src_stride;
+ const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift);
+ vst1q_u16(&dest[0], v_dest);
+ dest += 4 << 1;
+ y += 2;
+ } while (y < height);
+ }
+}
+
+void ConvolveCompoundVertical_NEON(
+ const void* const reference, const ptrdiff_t reference_stride,
+ const int /*horizontal_filter_index*/, const int vertical_filter_index,
+ const int /*horizontal_filter_id*/, const int vertical_filter_id,
+ const int width, const int height, void* prediction,
+ const ptrdiff_t /*pred_stride*/) {
+ const int filter_index = GetFilterIndex(vertical_filter_index, height);
+ const int vertical_taps = GetNumTapsInFilter(filter_index);
+ const ptrdiff_t src_stride = reference_stride;
+ const auto* src = static_cast<const uint8_t*>(reference) -
+ (vertical_taps / 2 - 1) * src_stride;
+ auto* dest = static_cast<uint16_t*>(prediction);
+ assert(vertical_filter_id != 0);
+
+ uint8x8_t taps[8];
+ for (int k = 0; k < kSubPixelTaps; ++k) {
+ taps[k] =
+ vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+ }
+
+ if (filter_index == 0) { // 6 tap.
+ if (width == 4) {
+ FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 1);
+ } else {
+ FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 1);
+ }
+ } else if ((filter_index == 1) & ((vertical_filter_id == 1) |
+ (vertical_filter_id == 15))) { // 5 tap.
+ if (width == 4) {
+ FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 1);
+ } else {
+ FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 1);
+ }
+ } else if ((filter_index == 1) &
+ ((vertical_filter_id == 7) | (vertical_filter_id == 8) |
+ (vertical_filter_id == 9))) { // 6 tap with weird negative taps.
+ if (width == 4) {
+ FilterVertical4xH<1, /*is_compound=*/true,
+ /*negative_outside_taps=*/true>(src, src_stride, dest,
+ 4, height, taps + 1);
+ } else {
+ FilterVertical<1, /*is_compound=*/true, /*negative_outside_taps=*/true>(
+ src, src_stride, dest, width, width, height, taps + 1);
+ }
+ } else if (filter_index == 2) { // 8 tap.
+ if (width == 4) {
+ FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps);
+ } else {
+ FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps);
+ }
+ } else if (filter_index == 3) { // 2 tap.
+ if (width == 4) {
+ FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 3);
+ } else {
+ FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 3);
+ }
+ } else if (filter_index == 4) { // 4 tap.
+ if (width == 4) {
+ FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 2);
+ } else {
+ FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 2);
+ }
+ } else {
+ // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
+ // below map to 4 tap filters.
+ assert(filter_index == 5 ||
+ (filter_index == 1 &&
+ (vertical_filter_id == 2 || vertical_filter_id == 3 ||
+ vertical_filter_id == 4 || vertical_filter_id == 5 ||
+ vertical_filter_id == 6 || vertical_filter_id == 10 ||
+ vertical_filter_id == 11 || vertical_filter_id == 12 ||
+ vertical_filter_id == 13 || vertical_filter_id == 14)));
+ // According to GetNumTapsInFilter() this has 6 taps but here we are
+ // treating it as though it has 4.
+ if (filter_index == 1) src += src_stride;
+ if (width == 4) {
+ FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
+ height, taps + 2);
+ } else {
+ FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
+ width, height, taps + 2);
+ }
+ }
+}
+
+void ConvolveCompoundHorizontal_NEON(
+ const void* const reference, const ptrdiff_t reference_stride,
+ const int horizontal_filter_index, const int /*vertical_filter_index*/,
+ const int horizontal_filter_id, const int /*vertical_filter_id*/,
+ const int width, const int height, void* prediction,
+ const ptrdiff_t /*pred_stride*/) {
+ const int filter_index = GetFilterIndex(horizontal_filter_index, width);
+ const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
+ auto* dest = static_cast<uint16_t*>(prediction);
+
+ DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>(
+ src, reference_stride, dest, width, width, height, horizontal_filter_id,
+ filter_index);
+}
+
+void ConvolveCompound2D_NEON(const void* const reference,
+ const ptrdiff_t reference_stride,
+ const int horizontal_filter_index,
+ const int vertical_filter_index,
+ const int horizontal_filter_id,
+ const int vertical_filter_id, const int width,
+ const int height, void* prediction,
+ const ptrdiff_t /*pred_stride*/) {
+ // The output of the horizontal filter, i.e. the intermediate_result, is
+ // guaranteed to fit in int16_t.
+ uint16_t
+ intermediate_result[kMaxSuperBlockSizeInPixels *
+ (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
+
+ // Horizontal filter.
+ // Filter types used for width <= 4 are different from those for width > 4.
+ // When width > 4, the valid filter index range is always [0, 3].
+ // When width <= 4, the valid filter index range is always [4, 5].
+ // Similarly for height.
+ const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+ const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+ const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
+ const int intermediate_height = height + vertical_taps - 1;
+ const ptrdiff_t src_stride = reference_stride;
+ const auto* const src = static_cast<const uint8_t*>(reference) -
+ (vertical_taps / 2 - 1) * src_stride -
+ kHorizontalOffset;
+
+ DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
+ src, src_stride, intermediate_result, width, width, intermediate_height,
+ horizontal_filter_id, horiz_filter_index);
+
+ // Vertical filter.
+ auto* dest = static_cast<uint16_t*>(prediction);
+ assert(vertical_filter_id != 0);
+
+ const ptrdiff_t dest_stride = width;
+ const int16x8_t taps = vmovl_s8(
+ vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
+
+ if (vertical_taps == 8) {
+ if (width == 4) {
+ Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest,
+ dest_stride, height, taps);
+ } else {
+ Filter2DVertical<8, /*is_compound=*/true>(
+ intermediate_result, dest, dest_stride, width, height, taps);
+ }
+ } else if (vertical_taps == 6) {
+ if (width == 4) {
+ Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest,
+ dest_stride, height, taps);
+ } else {
+ Filter2DVertical<6, /*is_compound=*/true>(
+ intermediate_result, dest, dest_stride, width, height, taps);
+ }
+ } else if (vertical_taps == 4) {
+ if (width == 4) {
+ Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest,
+ dest_stride, height, taps);
+ } else {
+ Filter2DVertical<4, /*is_compound=*/true>(
+ intermediate_result, dest, dest_stride, width, height, taps);
+ }
+ } else { // |vertical_taps| == 2
+ if (width == 4) {
+ Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest,
+ dest_stride, height, taps);
+ } else {
+ Filter2DVertical<2, /*is_compound=*/true>(
+ intermediate_result, dest, dest_stride, width, height, taps);
+ }
+ }
+}
+
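+// Averages each pixel with its right-hand neighbor using a rounding halving
+// add. This implements the half-pel horizontal case of intra block copy.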
+inline void HalfAddHorizontal(const uint8_t* src, uint8_t* dst) {
+ const uint8x16_t left = vld1q_u8(src);
+ const uint8x16_t right = vld1q_u8(src + 1);
+ vst1q_u8(dst, vrhaddq_u8(left, right));
+}
+
+template <int width>
+inline void IntraBlockCopyHorizontal(const uint8_t* src,
+ const ptrdiff_t src_stride,
+ const int height, uint8_t* dst,
+ const ptrdiff_t dst_stride) {
+ const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
+ const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
+
+ int y = 0;
+ do {
+ HalfAddHorizontal(src, dst);
+ if (width >= 32) {
+ src += 16;
+ dst += 16;
+ HalfAddHorizontal(src, dst);
+ if (width >= 64) {
+ src += 16;
+ dst += 16;
+ HalfAddHorizontal(src, dst);
+ src += 16;
+ dst += 16;
+ HalfAddHorizontal(src, dst);
+ if (width == 128) {
+ src += 16;
+ dst += 16;
+ HalfAddHorizontal(src, dst);
+ src += 16;
+ dst += 16;
+ HalfAddHorizontal(src, dst);
+ src += 16;
+ dst += 16;
+ HalfAddHorizontal(src, dst);
+ src += 16;
+ dst += 16;
+ HalfAddHorizontal(src, dst);
+ }
+ }
+ }
+ src += src_remainder_stride;
+ dst += dst_remainder_stride;
+ } while (++y < height);
+}
+
+void ConvolveIntraBlockCopyHorizontal_NEON(
+ const void* const reference, const ptrdiff_t reference_stride,
+ const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+ const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+ const int height, void* const prediction, const ptrdiff_t pred_stride) {
+ const auto* src = static_cast<const uint8_t*>(reference);
+ auto* dest = static_cast<uint8_t*>(prediction);
+
+ if (width == 128) {
+ IntraBlockCopyHorizontal<128>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 64) {
+ IntraBlockCopyHorizontal<64>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 32) {
+ IntraBlockCopyHorizontal<32>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 16) {
+ IntraBlockCopyHorizontal<16>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 8) {
+ int y = 0;
+ do {
+ const uint8x8_t left = vld1_u8(src);
+ const uint8x8_t right = vld1_u8(src + 1);
+ vst1_u8(dest, vrhadd_u8(left, right));
+
+ src += reference_stride;
+ dest += pred_stride;
+ } while (++y < height);
+ } else if (width == 4) {
+ uint8x8_t left = vdup_n_u8(0);
+ uint8x8_t right = vdup_n_u8(0);
+ int y = 0;
+ do {
+ left = Load4<0>(src, left);
+ right = Load4<0>(src + 1, right);
+ src += reference_stride;
+ left = Load4<1>(src, left);
+ right = Load4<1>(src + 1, right);
+ src += reference_stride;
+
+ const uint8x8_t result = vrhadd_u8(left, right);
+
+ StoreLo4(dest, result);
+ dest += pred_stride;
+ StoreHi4(dest, result);
+ dest += pred_stride;
+ y += 2;
+ } while (y < height);
+ } else {
+ assert(width == 2);
+ uint8x8_t left = vdup_n_u8(0);
+ uint8x8_t right = vdup_n_u8(0);
+ int y = 0;
+ do {
+ left = Load2<0>(src, left);
+ right = Load2<0>(src + 1, right);
+ src += reference_stride;
+ left = Load2<1>(src, left);
+ right = Load2<1>(src + 1, right);
+ src += reference_stride;
+
+ const uint8x8_t result = vrhadd_u8(left, right);
+
+ Store2<0>(dest, result);
+ dest += pred_stride;
+ Store2<1>(dest, result);
+ dest += pred_stride;
+ y += 2;
+ } while (y < height);
+ }
+}
+
+template <int width>
+inline void IntraBlockCopyVertical(const uint8_t* src,
+ const ptrdiff_t src_stride, const int height,
+ uint8_t* dst, const ptrdiff_t dst_stride) {
+ const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
+ const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
+ uint8x16_t row[8], below[8];
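+ // |row| caches the current source rows so that each row is loaded only
+ // once; the output is the rounded average of each row and the row below it.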
+
+ row[0] = vld1q_u8(src);
+ if (width >= 32) {
+ src += 16;
+ row[1] = vld1q_u8(src);
+ if (width >= 64) {
+ src += 16;
+ row[2] = vld1q_u8(src);
+ src += 16;
+ row[3] = vld1q_u8(src);
+ if (width == 128) {
+ src += 16;
+ row[4] = vld1q_u8(src);
+ src += 16;
+ row[5] = vld1q_u8(src);
+ src += 16;
+ row[6] = vld1q_u8(src);
+ src += 16;
+ row[7] = vld1q_u8(src);
+ }
+ }
+ }
+ src += src_remainder_stride;
+
+ int y = 0;
+ do {
+ below[0] = vld1q_u8(src);
+ if (width >= 32) {
+ src += 16;
+ below[1] = vld1q_u8(src);
+ if (width >= 64) {
+ src += 16;
+ below[2] = vld1q_u8(src);
+ src += 16;
+ below[3] = vld1q_u8(src);
+ if (width == 128) {
+ src += 16;
+ below[4] = vld1q_u8(src);
+ src += 16;
+ below[5] = vld1q_u8(src);
+ src += 16;
+ below[6] = vld1q_u8(src);
+ src += 16;
+ below[7] = vld1q_u8(src);
+ }
+ }
+ }
+ src += src_remainder_stride;
+
+ vst1q_u8(dst, vrhaddq_u8(row[0], below[0]));
+ row[0] = below[0];
+ if (width >= 32) {
+ dst += 16;
+ vst1q_u8(dst, vrhaddq_u8(row[1], below[1]));
+ row[1] = below[1];
+ if (width >= 64) {
+ dst += 16;
+ vst1q_u8(dst, vrhaddq_u8(row[2], below[2]));
+ row[2] = below[2];
+ dst += 16;
+ vst1q_u8(dst, vrhaddq_u8(row[3], below[3]));
+ row[3] = below[3];
+ if (width == 128) {
+ dst += 16;
+ vst1q_u8(dst, vrhaddq_u8(row[4], below[4]));
+ row[4] = below[4];
+ dst += 16;
+ vst1q_u8(dst, vrhaddq_u8(row[5], below[5]));
+ row[5] = below[5];
+ dst += 16;
+ vst1q_u8(dst, vrhaddq_u8(row[6], below[6]));
+ row[6] = below[6];
+ dst += 16;
+ vst1q_u8(dst, vrhaddq_u8(row[7], below[7]));
+ row[7] = below[7];
+ }
+ }
+ }
+ dst += dst_remainder_stride;
+ } while (++y < height);
+}
+
+void ConvolveIntraBlockCopyVertical_NEON(
+ const void* const reference, const ptrdiff_t reference_stride,
+ const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+ const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
+ const int width, const int height, void* const prediction,
+ const ptrdiff_t pred_stride) {
+ const auto* src = static_cast<const uint8_t*>(reference);
+ auto* dest = static_cast<uint8_t*>(prediction);
+
+ if (width == 128) {
+ IntraBlockCopyVertical<128>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 64) {
+ IntraBlockCopyVertical<64>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 32) {
+ IntraBlockCopyVertical<32>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 16) {
+ IntraBlockCopyVertical<16>(src, reference_stride, height, dest,
+ pred_stride);
+ } else if (width == 8) {
+ uint8x8_t row, below;
+ row = vld1_u8(src);
+ src += reference_stride;
+
+ int y = 0;
+ do {
+ below = vld1_u8(src);
+ src += reference_stride;
+
+ vst1_u8(dest, vrhadd_u8(row, below));
+ dest += pred_stride;
+
+ row = below;
+ } while (++y < height);
+ } else if (width == 4) {
+ uint8x8_t row = Load4(src);
+ uint8x8_t below = vdup_n_u8(0);
+ src += reference_stride;
+
+ int y = 0;
+ do {
+ below = Load4<0>(src, below);
+ src += reference_stride;
+
+ StoreLo4(dest, vrhadd_u8(row, below));
+ dest += pred_stride;
+
+ row = below;
+ } while (++y < height);
+ } else {
+ assert(width == 2);
+ uint8x8_t row = Load2(src);
+ uint8x8_t below = vdup_n_u8(0);
+ src += reference_stride;
+
+ int y = 0;
+ do {
+ below = Load2<0>(src, below);
+ src += reference_stride;
+
+ Store2<0>(dest, vrhadd_u8(row, below));
+ dest += pred_stride;
+
+ row = below;
+ } while (++y < height);
+ }
+}
+
+template <int width>
+inline void IntraBlockCopy2D(const uint8_t* src, const ptrdiff_t src_stride,
+ const int height, uint8_t* dst,
+ const ptrdiff_t dst_stride) {
+ const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
+ const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
+ uint16x8_t row[16];
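+ // Each |row[i]| holds the 16-bit horizontal pair sums (left + right) of the
+ // current source row. The vertical pass adds the matching sums from the row
+ // below and rounds, producing (tl + tr + bl + br + 2) >> 2.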
+ row[0] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ if (width >= 16) {
+ src += 8;
+ row[1] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ if (width >= 32) {
+ src += 8;
+ row[2] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[3] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ if (width >= 64) {
+ src += 8;
+ row[4] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[5] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[6] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[7] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ if (width == 128) {
+ src += 8;
+ row[8] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[9] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[10] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[11] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[12] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[13] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[14] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ src += 8;
+ row[15] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ }
+ }
+ }
+ }
+ src += src_remainder_stride;
+
+ int y = 0;
+ do {
+ const uint16x8_t below_0 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[0], below_0), 2));
+ row[0] = below_0;
+ if (width >= 16) {
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_1 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[1], below_1), 2));
+ row[1] = below_1;
+ if (width >= 32) {
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_2 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[2], below_2), 2));
+ row[2] = below_2;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_3 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[3], below_3), 2));
+ row[3] = below_3;
+ if (width >= 64) {
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_4 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[4], below_4), 2));
+ row[4] = below_4;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_5 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[5], below_5), 2));
+ row[5] = below_5;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_6 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[6], below_6), 2));
+ row[6] = below_6;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_7 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[7], below_7), 2));
+ row[7] = below_7;
+ if (width == 128) {
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_8 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[8], below_8), 2));
+ row[8] = below_8;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_9 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[9], below_9), 2));
+ row[9] = below_9;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_10 =
+ vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[10], below_10), 2));
+ row[10] = below_10;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_11 =
+ vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[11], below_11), 2));
+ row[11] = below_11;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_12 =
+ vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[12], below_12), 2));
+ row[12] = below_12;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_13 =
+ vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[13], below_13), 2));
+ row[13] = below_13;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_14 =
+ vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[14], below_14), 2));
+ row[14] = below_14;
+ src += 8;
+ dst += 8;
+
+ const uint16x8_t below_15 =
+ vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+ vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[15], below_15), 2));
+ row[15] = below_15;
+ }
+ }
+ }
+ }
+ src += src_remainder_stride;
+ dst += dst_remainder_stride;
+ } while (++y < height);
+}
+
+void ConvolveIntraBlockCopy2D_NEON(
+ const void* const reference, const ptrdiff_t reference_stride,
+ const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+ const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
+ const int width, const int height, void* const prediction,
+ const ptrdiff_t pred_stride) {
+ const auto* src = static_cast<const uint8_t*>(reference);
+ auto* dest = static_cast<uint8_t*>(prediction);
+ // Note: allow vertical access to height + 1. Because this function is only
+ // for u/v plane of intra block copy, such access is guaranteed to be within
+ // the prediction block.
+
+ if (width == 128) {
+ IntraBlockCopy2D<128>(src, reference_stride, height, dest, pred_stride);
+ } else if (width == 64) {
+ IntraBlockCopy2D<64>(src, reference_stride, height, dest, pred_stride);
+ } else if (width == 32) {
+ IntraBlockCopy2D<32>(src, reference_stride, height, dest, pred_stride);
+ } else if (width == 16) {
+ IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride);
+ } else if (width == 8) {
+ IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride);
+ } else if (width == 4) {
+ uint8x8_t left = Load4(src);
+ uint8x8_t right = Load4(src + 1);
+ src += reference_stride;
+
+ uint16x4_t row = vget_low_u16(vaddl_u8(left, right));
+
+ int y = 0;
+ do {
+ left = Load4<0>(src, left);
+ right = Load4<0>(src + 1, right);
+ src += reference_stride;
+ left = Load4<1>(src, left);
+ right = Load4<1>(src + 1, right);
+ src += reference_stride;
+
+ const uint16x8_t below = vaddl_u8(left, right);
+
+ const uint8x8_t result = vrshrn_n_u16(
+ vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2);
+ StoreLo4(dest, result);
+ dest += pred_stride;
+ StoreHi4(dest, result);
+ dest += pred_stride;
+
+ row = vget_high_u16(below);
+ y += 2;
+ } while (y < height);
+ } else {
+ uint8x8_t left = Load2(src);
+ uint8x8_t right = Load2(src + 1);
+ src += reference_stride;
+
+ uint16x4_t row = vget_low_u16(vaddl_u8(left, right));
+
+ int y = 0;
+ do {
+ left = Load2<0>(src, left);
+ right = Load2<0>(src + 1, right);
+ src += reference_stride;
+ left = Load2<2>(src, left);
+ right = Load2<2>(src + 1, right);
+ src += reference_stride;
+
+ const uint16x8_t below = vaddl_u8(left, right);
+
+ const uint8x8_t result = vrshrn_n_u16(
+ vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2);
+ Store2<0>(dest, result);
+ dest += pred_stride;
+ Store2<2>(dest, result);
+ dest += pred_stride;
+
+ row = vget_high_u16(below);
+ y += 2;
+ } while (y < height);
+ }
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
+ dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
+ dsp->convolve[0][0][1][1] = Convolve2D_NEON;
+
+ dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
+ dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
+ dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
+ dsp->convolve[0][1][1][1] = ConvolveCompound2D_NEON;
+
+ dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_NEON;
+ dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_NEON;
+ dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_NEON;
+
+ dsp->convolve_scale[0] = ConvolveScale2D_NEON<false>;
+ dsp->convolve_scale[1] = ConvolveScale2D_NEON<true>;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void ConvolveInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void ConvolveInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/convolve_neon.h b/src/dsp/arm/convolve_neon.h
new file mode 100644
index 0000000..948ef4d
--- /dev/null
+++ b/src/dsp/arm/convolve_neon.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::convolve. This function is not thread-safe.
+void ConvolveInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_ConvolveScale2D LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
diff --git a/src/dsp/arm/distance_weighted_blend_neon.cc b/src/dsp/arm/distance_weighted_blend_neon.cc
new file mode 100644
index 0000000..04952ab
--- /dev/null
+++ b/src/dsp/arm/distance_weighted_blend_neon.cc
@@ -0,0 +1,203 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/distance_weighted_blend.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kInterPostRoundBit = 4;
+
+inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0,
+ const int16x8_t pred1,
+ const int16x4_t weights[2]) {
+ // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
+ const int32x4_t wpred0_lo = vmull_s16(weights[0], vget_low_s16(pred0));
+ const int32x4_t wpred0_hi = vmull_s16(weights[0], vget_high_s16(pred0));
+ const int32x4_t blended_lo =
+ vmlal_s16(wpred0_lo, weights[1], vget_low_s16(pred1));
+ const int32x4_t blended_hi =
+ vmlal_s16(wpred0_hi, weights[1], vget_high_s16(pred1));
+
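+ // The two distance weights sum to 16, so narrowing by kInterPostRoundBit + 4
+ // removes the weighting and applies the compound post-rounding in one step.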
+ return vcombine_s16(vqrshrn_n_s32(blended_lo, kInterPostRoundBit + 4),
+ vqrshrn_n_s32(blended_hi, kInterPostRoundBit + 4));
+}
+
+template <int width, int height>
+inline void DistanceWeightedBlendSmall_NEON(const int16_t* prediction_0,
+ const int16_t* prediction_1,
+ const int16x4_t weights[2],
+ void* const dest,
+ const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ constexpr int step = 16 / width;
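+ // Each iteration filters 16 values: four rows when |width| == 4 or two rows
+ // when |width| == 8.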
+
+ for (int y = 0; y < height; y += step) {
+ const int16x8_t src_00 = vld1q_s16(prediction_0);
+ const int16x8_t src_10 = vld1q_s16(prediction_1);
+ prediction_0 += 8;
+ prediction_1 += 8;
+ const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights);
+
+ const int16x8_t src_01 = vld1q_s16(prediction_0);
+ const int16x8_t src_11 = vld1q_s16(prediction_1);
+ prediction_0 += 8;
+ prediction_1 += 8;
+ const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights);
+
+ const uint8x8_t result0 = vqmovun_s16(res0);
+ const uint8x8_t result1 = vqmovun_s16(res1);
+ if (width == 4) {
+ StoreLo4(dst, result0);
+ dst += dest_stride;
+ StoreHi4(dst, result0);
+ dst += dest_stride;
+ StoreLo4(dst, result1);
+ dst += dest_stride;
+ StoreHi4(dst, result1);
+ dst += dest_stride;
+ } else {
+ assert(width == 8);
+ vst1_u8(dst, result0);
+ dst += dest_stride;
+ vst1_u8(dst, result1);
+ dst += dest_stride;
+ }
+ }
+}
+
+inline void DistanceWeightedBlendLarge_NEON(const int16_t* prediction_0,
+ const int16_t* prediction_1,
+ const int16x4_t weights[2],
+ const int width, const int height,
+ void* const dest,
+ const ptrdiff_t dest_stride) {
+ auto* dst = static_cast<uint8_t*>(dest);
+
+ int y = height;
+ do {
+ int x = 0;
+ do {
+ const int16x8_t src0_lo = vld1q_s16(prediction_0 + x);
+ const int16x8_t src1_lo = vld1q_s16(prediction_1 + x);
+ const int16x8_t res_lo =
+ ComputeWeightedAverage8(src0_lo, src1_lo, weights);
+
+ const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8);
+ const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8);
+ const int16x8_t res_hi =
+ ComputeWeightedAverage8(src0_hi, src1_hi, weights);
+
+ const uint8x16_t result =
+ vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi));
+ vst1q_u8(dst + x, result);
+ x += 16;
+ } while (x < width);
+ dst += dest_stride;
+ prediction_0 += width;
+ prediction_1 += width;
+ } while (--y != 0);
+}
+
+inline void DistanceWeightedBlend_NEON(const void* prediction_0,
+ const void* prediction_1,
+ const uint8_t weight_0,
+ const uint8_t weight_1, const int width,
+ const int height, void* const dest,
+ const ptrdiff_t dest_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)};
+ // TODO(johannkoenig): Investigate the branching. May be fine to call with a
+ // variable height.
+ if (width == 4) {
+ if (height == 4) {
+ DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest,
+ dest_stride);
+ } else if (height == 8) {
+ DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest,
+ dest_stride);
+ } else {
+ assert(height == 16);
+ DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest,
+ dest_stride);
+ }
+ return;
+ }
+
+ if (width == 8) {
+ switch (height) {
+ case 4:
+ DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest,
+ dest_stride);
+ return;
+ case 8:
+ DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest,
+ dest_stride);
+ return;
+ case 16:
+ DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest,
+ dest_stride);
+ return;
+ default:
+ assert(height == 32);
+ DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest,
+ dest_stride);
+ return;
+ }
+ }
+
+ DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest,
+ dest_stride);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
+}
+
+} // namespace
+
+void DistanceWeightedBlendInit_NEON() { Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void DistanceWeightedBlendInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/distance_weighted_blend_neon.h b/src/dsp/arm/distance_weighted_blend_neon.h
new file mode 100644
index 0000000..4d8824c
--- /dev/null
+++ b/src/dsp/arm/distance_weighted_blend_neon.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::distance_weighted_blend. This function is not thread-safe.
+void DistanceWeightedBlendInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+// If NEON is enabled, signal that the NEON implementation should be used
+// instead of normal C.
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_CPU_NEON
+
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
diff --git a/src/dsp/arm/film_grain_neon.cc b/src/dsp/arm/film_grain_neon.cc
new file mode 100644
index 0000000..2612466
--- /dev/null
+++ b/src/dsp/arm/film_grain_neon.cc
@@ -0,0 +1,1188 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/film_grain.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <new>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/arm/film_grain_neon.h"
+#include "src/dsp/common.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/film_grain_common.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/logging.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace film_grain {
+namespace {
+
+// These functions are overloaded for both possible sample sizes (8-bit and
+// 16-bit grain and pixel types) so that template functions can load and store
+// intermediate values without branching on the type.
+inline int16x8_t GetSignedSource8(const int8_t* src) {
+ return vmovl_s8(vld1_s8(src));
+}
+
+inline int16x8_t GetSignedSource8(const uint8_t* src) {
+ return ZeroExtend(vld1_u8(src));
+}
+
+inline void StoreUnsigned8(uint8_t* dest, const uint16x8_t data) {
+ vst1_u8(dest, vmovn_u16(data));
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+inline int16x8_t GetSignedSource8(const int16_t* src) { return vld1q_s16(src); }
+
+inline int16x8_t GetSignedSource8(const uint16_t* src) {
+ return vreinterpretq_s16_u16(vld1q_u16(src));
+}
+
+inline void StoreUnsigned8(uint16_t* dest, const uint16x8_t data) {
+ vst1q_u16(dest, data);
+}
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+// Each element in |sum| represents one destination value's running
+// autoregression formula. The fixed source values in |grain_lo| and |grain_hi|
+// allow for a sliding window in successive calls to this function.
+template <int position_offset>
+inline int32x4x2_t AccumulateWeightedGrain(const int16x8_t grain_lo,
+ const int16x8_t grain_hi,
+ int16_t coeff, int32x4x2_t sum) {
+ const int16x8_t grain = vextq_s16(grain_lo, grain_hi, position_offset);
+ sum.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(grain), coeff);
+ sum.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(grain), coeff);
+ return sum;
+}
+
+// Because the autoregressive filter requires the output of each pixel to
+// compute pixels that come after in the row, we have to finish the calculations
+// one at a time.
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegression(int8_t* grain_cursor, int32x4x2_t sum,
+ const int8_t* coeffs, int pos, int shift) {
+ int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3);
+
+ for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) {
+ result += grain_cursor[lane + delta_col] * coeffs[pos];
+ ++pos;
+ }
+ grain_cursor[lane] =
+ Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift),
+ GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>());
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegression(int16_t* grain_cursor, int32x4x2_t sum,
+ const int8_t* coeffs, int pos, int shift) {
+ int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3);
+
+ for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) {
+ result += grain_cursor[lane + delta_col] * coeffs[pos];
+ ++pos;
+ }
+ grain_cursor[lane] =
+ Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift),
+ GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>());
+}
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+// Because the autoregressive filter requires the output of each pixel to
+// compute pixels that come after in the row, we have to finish the calculations
+// one at a time.
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegressionChroma(int8_t* u_grain_cursor,
+ int8_t* v_grain_cursor,
+ int32x4x2_t sum_u, int32x4x2_t sum_v,
+ const int8_t* coeffs_u,
+ const int8_t* coeffs_v, int pos,
+ int shift) {
+ WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+ u_grain_cursor, sum_u, coeffs_u, pos, shift);
+ WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+ v_grain_cursor, sum_v, coeffs_v, pos, shift);
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegressionChroma(int16_t* u_grain_cursor,
+ int16_t* v_grain_cursor,
+ int32x4x2_t sum_u, int32x4x2_t sum_v,
+ const int8_t* coeffs_u,
+ const int8_t* coeffs_v, int pos,
+ int shift) {
+ WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+ u_grain_cursor, sum_u, coeffs_u, pos, shift);
+ WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+ v_grain_cursor, sum_v, coeffs_v, pos, shift);
+}
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+inline void SetZero(int32x4x2_t* v) {
+ v->val[0] = vdupq_n_s32(0);
+ v->val[1] = vdupq_n_s32(0);
+}
+
+// Computes subsampled luma for use with chroma, by averaging in the x direction
+// or y direction when applicable.
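+// With both subsamplings set, horizontal pairs are summed with vpaddl_s8 and
+// the two-row total is divided by 4 with rounding; with only |subsampling_x|
+// set, the pair sums are halved with rounding.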
+int16x8_t GetSubsampledLuma(const int8_t* const luma, int subsampling_x,
+ int subsampling_y, ptrdiff_t stride) {
+ if (subsampling_y != 0) {
+ assert(subsampling_x != 0);
+ const int8x16_t src0 = vld1q_s8(luma);
+ const int8x16_t src1 = vld1q_s8(luma + stride);
+ const int16x8_t ret0 = vcombine_s16(vpaddl_s8(vget_low_s8(src0)),
+ vpaddl_s8(vget_high_s8(src0)));
+ const int16x8_t ret1 = vcombine_s16(vpaddl_s8(vget_low_s8(src1)),
+ vpaddl_s8(vget_high_s8(src1)));
+ return vrshrq_n_s16(vaddq_s16(ret0, ret1), 2);
+ }
+ if (subsampling_x != 0) {
+ const int8x16_t src = vld1q_s8(luma);
+ return vrshrq_n_s16(
+ vcombine_s16(vpaddl_s8(vget_low_s8(src)), vpaddl_s8(vget_high_s8(src))),
+ 1);
+ }
+ return vmovl_s8(vld1_s8(luma));
+}
+
+// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed.
+inline uint16x8_t GetAverageLuma(const uint8_t* const luma, int subsampling_x) {
+ if (subsampling_x != 0) {
+ const uint8x16_t src = vld1q_u8(luma);
+ return vrshrq_n_u16(vpaddlq_u8(src), 1);
+ }
+ return vmovl_u8(vld1_u8(luma));
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+// Computes subsampled luma for use with chroma, by averaging in the x direction
+// or y direction when applicable.
+int16x8_t GetSubsampledLuma(const int16_t* const luma, int subsampling_x,
+ int subsampling_y, ptrdiff_t stride) {
+ if (subsampling_y != 0) {
+ assert(subsampling_x != 0);
+ int16x8_t src0_lo = vld1q_s16(luma);
+ int16x8_t src0_hi = vld1q_s16(luma + 8);
+ const int16x8_t src1_lo = vld1q_s16(luma + stride);
+ const int16x8_t src1_hi = vld1q_s16(luma + stride + 8);
+ const int16x8_t src0 =
+ vcombine_s16(vpadd_s16(vget_low_s16(src0_lo), vget_high_s16(src0_lo)),
+ vpadd_s16(vget_low_s16(src0_hi), vget_high_s16(src0_hi)));
+ const int16x8_t src1 =
+ vcombine_s16(vpadd_s16(vget_low_s16(src1_lo), vget_high_s16(src1_lo)),
+ vpadd_s16(vget_low_s16(src1_hi), vget_high_s16(src1_hi)));
+ return vrshrq_n_s16(vaddq_s16(src0, src1), 2);
+ }
+ if (subsampling_x != 0) {
+ const int16x8_t src_lo = vld1q_s16(luma);
+ const int16x8_t src_hi = vld1q_s16(luma + 8);
+ const int16x8_t ret =
+ vcombine_s16(vpadd_s16(vget_low_s16(src_lo), vget_high_s16(src_lo)),
+ vpadd_s16(vget_low_s16(src_hi), vget_high_s16(src_hi)));
+ return vrshrq_n_s16(ret, 1);
+ }
+ return vld1q_s16(luma);
+}
+
+// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed.
+inline uint16x8_t GetAverageLuma(const uint16_t* const luma,
+ int subsampling_x) {
+ if (subsampling_x != 0) {
+ const uint16x8x2_t src = vld2q_u16(luma);
+ return vrhaddq_u16(src.val[0], src.val[1]);
+ }
+ return vld1q_u16(luma);
+}
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+template <int bitdepth, typename GrainType, int auto_regression_coeff_lag,
+ bool use_luma>
+void ApplyAutoRegressiveFilterToChromaGrains_NEON(const FilmGrainParams& params,
+ const void* luma_grain_buffer,
+ int subsampling_x,
+ int subsampling_y,
+ void* u_grain_buffer,
+ void* v_grain_buffer) {
+ static_assert(auto_regression_coeff_lag <= 3, "Invalid autoregression lag.");
+ const auto* luma_grain = static_cast<const GrainType*>(luma_grain_buffer);
+ auto* u_grain = static_cast<GrainType*>(u_grain_buffer);
+ auto* v_grain = static_cast<GrainType*>(v_grain_buffer);
+ const int auto_regression_shift = params.auto_regression_shift;
+ const int chroma_width =
+ (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth;
+ const int chroma_height =
+ (subsampling_y == 0) ? kMaxChromaHeight : kMinChromaHeight;
+ // When |chroma_width| == 44, we write 8 at a time from x in [3, 34],
+ // leaving [35, 40] to write at the end.
+ const int chroma_width_remainder =
+ (chroma_width - 2 * kAutoRegressionBorder) & 7;
+
+ int y = kAutoRegressionBorder;
+ luma_grain += kLumaWidth * y;
+ u_grain += chroma_width * y;
+ v_grain += chroma_width * y;
+ do {
+ // Each row is computed 8 values at a time in the following loop. At the
+ // end of the loop, 4 values (6 when |chroma_width| == 44) remain to write.
+ // They are given a special reduced iteration at the end.
+ int x = kAutoRegressionBorder;
+ int luma_x = kAutoRegressionBorder;
+ do {
+ int pos = 0;
+ int32x4x2_t sum_u;
+ int32x4x2_t sum_v;
+ SetZero(&sum_u);
+ SetZero(&sum_v);
+
+ if (auto_regression_coeff_lag > 0) {
+ for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+ ++delta_row) {
+ // These loads may overflow to the next row, but they are never called
+ // on the final row of a grain block. Therefore, they will never
+ // exceed the block boundaries.
+ // Note: this could be slightly optimized to a single load in 8bpp,
+ // but requires making a special first iteration and accumulate
+ // function that takes an int8x16_t.
+ const int16x8_t u_grain_lo =
+ GetSignedSource8(u_grain + x + delta_row * chroma_width -
+ auto_regression_coeff_lag);
+ const int16x8_t u_grain_hi =
+ GetSignedSource8(u_grain + x + delta_row * chroma_width -
+ auto_regression_coeff_lag + 8);
+ const int16x8_t v_grain_lo =
+ GetSignedSource8(v_grain + x + delta_row * chroma_width -
+ auto_regression_coeff_lag);
+ const int16x8_t v_grain_hi =
+ GetSignedSource8(v_grain + x + delta_row * chroma_width -
+ auto_regression_coeff_lag + 8);
+#define ACCUMULATE_WEIGHTED_GRAIN(offset) \
+ sum_u = AccumulateWeightedGrain<offset>( \
+ u_grain_lo, u_grain_hi, params.auto_regression_coeff_u[pos], sum_u); \
+ sum_v = AccumulateWeightedGrain<offset>( \
+ v_grain_lo, v_grain_hi, params.auto_regression_coeff_v[pos++], sum_v)
+
+ ACCUMULATE_WEIGHTED_GRAIN(0);
+ ACCUMULATE_WEIGHTED_GRAIN(1);
+ ACCUMULATE_WEIGHTED_GRAIN(2);
+ // The horizontal |auto_regression_coeff_lag| loop is replaced with
+ // if-statements to give vextq_s16 an immediate param.
+ if (auto_regression_coeff_lag > 1) {
+ ACCUMULATE_WEIGHTED_GRAIN(3);
+ ACCUMULATE_WEIGHTED_GRAIN(4);
+ }
+ if (auto_regression_coeff_lag > 2) {
+ assert(auto_regression_coeff_lag == 3);
+ ACCUMULATE_WEIGHTED_GRAIN(5);
+ ACCUMULATE_WEIGHTED_GRAIN(6);
+ }
+ }
+ }
+
+ if (use_luma) {
+ const int16x8_t luma = GetSubsampledLuma(
+ luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth);
+
+ // Luma samples get the final coefficient in the formula. They are best
+ // accumulated for all eight lanes at once here, before the lane-by-lane
+ // pass over the delta_row == 0 row.
+ const int coeff_u =
+ params.auto_regression_coeff_u[pos + auto_regression_coeff_lag];
+ const int coeff_v =
+ params.auto_regression_coeff_v[pos + auto_regression_coeff_lag];
+
+ sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u);
+ sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u);
+ sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v);
+ sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v);
+ }
+ // At this point in the filter, the source addresses and destination
+ // addresses overlap. Because this is an auto-regressive filter, the
+ // higher lanes cannot be computed without the results of the lower lanes.
+ // Each call to WriteFinalAutoRegression incorporates preceding values
+ // on the final row, and writes a single sample. This allows the next
+ // pixel's value to be computed in the next call.
+#define WRITE_AUTO_REGRESSION_RESULT(lane) \
+ WriteFinalAutoRegressionChroma<bitdepth, auto_regression_coeff_lag, lane>( \
+ u_grain + x, v_grain + x, sum_u, sum_v, params.auto_regression_coeff_u, \
+ params.auto_regression_coeff_v, pos, auto_regression_shift)
+
+ WRITE_AUTO_REGRESSION_RESULT(0);
+ WRITE_AUTO_REGRESSION_RESULT(1);
+ WRITE_AUTO_REGRESSION_RESULT(2);
+ WRITE_AUTO_REGRESSION_RESULT(3);
+ WRITE_AUTO_REGRESSION_RESULT(4);
+ WRITE_AUTO_REGRESSION_RESULT(5);
+ WRITE_AUTO_REGRESSION_RESULT(6);
+ WRITE_AUTO_REGRESSION_RESULT(7);
+
+ x += 8;
+ luma_x += 8 << subsampling_x;
+ } while (x < chroma_width - kAutoRegressionBorder - chroma_width_remainder);
+
+ // This is the "final iteration" of the above loop over width. We fill in
+ // the remainder of the width, which is less than 8.
+ int pos = 0;
+ int32x4x2_t sum_u;
+ int32x4x2_t sum_v;
+ SetZero(&sum_u);
+ SetZero(&sum_v);
+
+ for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+ ++delta_row) {
+ // These loads may overflow to the next row, but they are never called on
+ // the final row of a grain block. Therefore, they will never exceed the
+ // block boundaries.
+ const int16x8_t u_grain_lo = GetSignedSource8(
+ u_grain + x + delta_row * chroma_width - auto_regression_coeff_lag);
+ const int16x8_t u_grain_hi =
+ GetSignedSource8(u_grain + x + delta_row * chroma_width -
+ auto_regression_coeff_lag + 8);
+ const int16x8_t v_grain_lo = GetSignedSource8(
+ v_grain + x + delta_row * chroma_width - auto_regression_coeff_lag);
+ const int16x8_t v_grain_hi =
+ GetSignedSource8(v_grain + x + delta_row * chroma_width -
+ auto_regression_coeff_lag + 8);
+
+ ACCUMULATE_WEIGHTED_GRAIN(0);
+ ACCUMULATE_WEIGHTED_GRAIN(1);
+ ACCUMULATE_WEIGHTED_GRAIN(2);
+ // The horizontal |auto_regression_coeff_lag| loop is replaced with
+ // if-statements to give vextq_s16 an immediate param.
+ if (auto_regression_coeff_lag > 1) {
+ ACCUMULATE_WEIGHTED_GRAIN(3);
+ ACCUMULATE_WEIGHTED_GRAIN(4);
+ }
+ if (auto_regression_coeff_lag > 2) {
+ assert(auto_regression_coeff_lag == 3);
+ ACCUMULATE_WEIGHTED_GRAIN(5);
+ ACCUMULATE_WEIGHTED_GRAIN(6);
+ }
+ }
+
+ if (use_luma) {
+ const int16x8_t luma = GetSubsampledLuma(
+ luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth);
+
+ // Luma samples get the final coefficient in the formula. They are best
+ // accumulated for all eight lanes at once here, before the lane-by-lane
+ // pass over the delta_row == 0 row.
+ const int coeff_u =
+ params.auto_regression_coeff_u[pos + auto_regression_coeff_lag];
+ const int coeff_v =
+ params.auto_regression_coeff_v[pos + auto_regression_coeff_lag];
+
+ sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u);
+ sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u);
+ sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v);
+ sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v);
+ }
+
+ WRITE_AUTO_REGRESSION_RESULT(0);
+ WRITE_AUTO_REGRESSION_RESULT(1);
+ WRITE_AUTO_REGRESSION_RESULT(2);
+ WRITE_AUTO_REGRESSION_RESULT(3);
+ if (chroma_width_remainder == 6) {
+ WRITE_AUTO_REGRESSION_RESULT(4);
+ WRITE_AUTO_REGRESSION_RESULT(5);
+ }
+
+ luma_grain += kLumaWidth << subsampling_y;
+ u_grain += chroma_width;
+ v_grain += chroma_width;
+ } while (++y < chroma_height);
+#undef ACCUMULATE_WEIGHTED_GRAIN
+#undef WRITE_AUTO_REGRESSION_RESULT
+}
+
+// Applies an auto-regressive filter to the white noise in luma_grain.
+template <int bitdepth, typename GrainType, int auto_regression_coeff_lag>
+void ApplyAutoRegressiveFilterToLumaGrain_NEON(const FilmGrainParams& params,
+ void* luma_grain_buffer) {
+ static_assert(auto_regression_coeff_lag > 0, "");
+ const int8_t* const auto_regression_coeff_y = params.auto_regression_coeff_y;
+ const uint8_t auto_regression_shift = params.auto_regression_shift;
+
+ int y = kAutoRegressionBorder;
+ auto* luma_grain =
+ static_cast<GrainType*>(luma_grain_buffer) + kLumaWidth * y;
+ do {
+ // Each row is computed 8 values at a time in the following loop. At the
+ // end of the loop, 4 values remain to write. They are given a special
+ // reduced iteration at the end.
+ int x = kAutoRegressionBorder;
+ do {
+ int pos = 0;
+ int32x4x2_t sum;
+ SetZero(&sum);
+ for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+ ++delta_row) {
+ // These loads may overflow to the next row, but they are never called
+ // on the final row of a grain block. Therefore, they will never exceed
+ // the block boundaries.
+ const int16x8_t src_grain_lo =
+ GetSignedSource8(luma_grain + x + delta_row * kLumaWidth -
+ auto_regression_coeff_lag);
+ const int16x8_t src_grain_hi =
+ GetSignedSource8(luma_grain + x + delta_row * kLumaWidth -
+ auto_regression_coeff_lag + 8);
+
+ // A pictorial representation of the auto-regressive filter for
+ // various values of params.auto_regression_coeff_lag. The letter 'O'
+ // represents the current sample. (The filter always operates on the
+ // current sample with filter coefficient 1.) The letters 'X'
+ // represent the neighboring samples that the filter operates on, below
+ // their corresponding "offset" number.
+ //
+ // params.auto_regression_coeff_lag == 3:
+ // 0 1 2 3 4 5 6
+ // X X X X X X X
+ // X X X X X X X
+ // X X X X X X X
+ // X X X O
+ // params.auto_regression_coeff_lag == 2:
+ // 0 1 2 3 4
+ // X X X X X
+ // X X X X X
+ // X X O
+ // params.auto_regression_coeff_lag == 1:
+ // 0 1 2
+ // X X X
+ // X O
+ // params.auto_regression_coeff_lag == 0:
+ // O
+ // The function relies on the caller to skip the call in the 0 lag
+ // case.
+
+#define ACCUMULATE_WEIGHTED_GRAIN(offset) \
+ sum = AccumulateWeightedGrain<offset>(src_grain_lo, src_grain_hi, \
+ auto_regression_coeff_y[pos++], sum)
+ ACCUMULATE_WEIGHTED_GRAIN(0);
+ ACCUMULATE_WEIGHTED_GRAIN(1);
+ ACCUMULATE_WEIGHTED_GRAIN(2);
+ // The horizontal |auto_regression_coeff_lag| loop is replaced with
+ // if-statements to give vextq_s16 an immediate param.
+ if (auto_regression_coeff_lag > 1) {
+ ACCUMULATE_WEIGHTED_GRAIN(3);
+ ACCUMULATE_WEIGHTED_GRAIN(4);
+ }
+ if (auto_regression_coeff_lag > 2) {
+ assert(auto_regression_coeff_lag == 3);
+ ACCUMULATE_WEIGHTED_GRAIN(5);
+ ACCUMULATE_WEIGHTED_GRAIN(6);
+ }
+ }
+ // At this point in the filter, the source addresses and destination
+ // addresses overlap. Because this is an auto-regressive filter, the
+ // higher lanes cannot be computed without the results of the lower lanes.
+ // Each call to WriteFinalAutoRegression incorporates preceding values
+ // on the final row, and writes a single sample. This allows the next
+ // pixel's value to be computed in the next call.
+#define WRITE_AUTO_REGRESSION_RESULT(lane) \
+ WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( \
+ luma_grain + x, sum, auto_regression_coeff_y, pos, \
+ auto_regression_shift)
+
+ WRITE_AUTO_REGRESSION_RESULT(0);
+ WRITE_AUTO_REGRESSION_RESULT(1);
+ WRITE_AUTO_REGRESSION_RESULT(2);
+ WRITE_AUTO_REGRESSION_RESULT(3);
+ WRITE_AUTO_REGRESSION_RESULT(4);
+ WRITE_AUTO_REGRESSION_RESULT(5);
+ WRITE_AUTO_REGRESSION_RESULT(6);
+ WRITE_AUTO_REGRESSION_RESULT(7);
+ x += 8;
+ // Leave the final four pixels for the special iteration below.
+ } while (x < kLumaWidth - kAutoRegressionBorder - 4);
+
+ // Final 4 pixels in the row.
+ int pos = 0;
+ int32x4x2_t sum;
+ SetZero(&sum);
+ for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+ ++delta_row) {
+ const int16x8_t src_grain_lo = GetSignedSource8(
+ luma_grain + x + delta_row * kLumaWidth - auto_regression_coeff_lag);
+ const int16x8_t src_grain_hi =
+ GetSignedSource8(luma_grain + x + delta_row * kLumaWidth -
+ auto_regression_coeff_lag + 8);
+
+ ACCUMULATE_WEIGHTED_GRAIN(0);
+ ACCUMULATE_WEIGHTED_GRAIN(1);
+ ACCUMULATE_WEIGHTED_GRAIN(2);
+ // The horizontal |auto_regression_coeff_lag| loop is replaced with
+ // if-statements to give vextq_s16 an immediate param.
+ if (auto_regression_coeff_lag > 1) {
+ ACCUMULATE_WEIGHTED_GRAIN(3);
+ ACCUMULATE_WEIGHTED_GRAIN(4);
+ }
+ if (auto_regression_coeff_lag > 2) {
+ assert(auto_regression_coeff_lag == 3);
+ ACCUMULATE_WEIGHTED_GRAIN(5);
+ ACCUMULATE_WEIGHTED_GRAIN(6);
+ }
+ }
+ // delta_row == 0
+ WRITE_AUTO_REGRESSION_RESULT(0);
+ WRITE_AUTO_REGRESSION_RESULT(1);
+ WRITE_AUTO_REGRESSION_RESULT(2);
+ WRITE_AUTO_REGRESSION_RESULT(3);
+ luma_grain += kLumaWidth;
+ } while (++y < kLumaHeight);
+
+#undef WRITE_AUTO_REGRESSION_RESULT
+#undef ACCUMULATE_WEIGHTED_GRAIN
+}
+
+void InitializeScalingLookupTable_NEON(
+ int num_points, const uint8_t point_value[], const uint8_t point_scaling[],
+ uint8_t scaling_lut[kScalingLookupTableSize]) {
+ if (num_points == 0) {
+ memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize);
+ return;
+ }
+ static_assert(sizeof(scaling_lut[0]) == 1, "");
+ memset(scaling_lut, point_scaling[0], point_value[0]);
+ const uint32x4_t steps = vmovl_u16(vcreate_u16(0x0003000200010000));
+ const uint32x4_t offset = vdupq_n_u32(32768);
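+ // Each segment of the piecewise linear scaling function is interpolated in
+ // 16.16 fixed point: |steps| holds {0, 1, 2, 3}, |delta| is the rounded
+ // per-step slope, and |offset| adds 0.5 so the final >> 16 narrowing rounds
+ // to nearest.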
+ for (int i = 0; i < num_points - 1; ++i) {
+ const int delta_y = point_scaling[i + 1] - point_scaling[i];
+ const int delta_x = point_value[i + 1] - point_value[i];
+ const int delta = delta_y * ((65536 + (delta_x >> 1)) / delta_x);
+ const int delta4 = delta << 2;
+ const uint8x8_t base_point = vdup_n_u8(point_scaling[i]);
+ uint32x4_t upscaled_points0 = vmlaq_n_u32(offset, steps, delta);
+ const uint32x4_t line_increment4 = vdupq_n_u32(delta4);
+ // Get the second set of 4 points by adding 4 steps to the first set.
+ uint32x4_t upscaled_points1 = vaddq_u32(upscaled_points0, line_increment4);
+ // We obtain the next set of 8 points by adding 8 steps to each of the
+ // current 8 points.
+ const uint32x4_t line_increment8 = vshlq_n_u32(line_increment4, 1);
+ int x = 0;
+ do {
+ const uint16x4_t interp_points0 = vshrn_n_u32(upscaled_points0, 16);
+ const uint16x4_t interp_points1 = vshrn_n_u32(upscaled_points1, 16);
+ const uint8x8_t interp_points =
+ vmovn_u16(vcombine_u16(interp_points0, interp_points1));
+ // The spec guarantees that |point_value[i]| + x is at most 255. Writing 8
+ // bytes starting at that final table index therefore relies on 7 bytes of
+ // padding past the end of the table.
+ vst1_u8(&scaling_lut[point_value[i] + x],
+ vadd_u8(interp_points, base_point));
+ upscaled_points0 = vaddq_u32(upscaled_points0, line_increment8);
+ upscaled_points1 = vaddq_u32(upscaled_points1, line_increment8);
+ x += 8;
+ } while (x < delta_x);
+ }
+ const uint8_t last_point_value = point_value[num_points - 1];
+ memset(&scaling_lut[last_point_value], point_scaling[num_points - 1],
+ kScalingLookupTableSize - last_point_value);
+}
+
+inline int16x8_t Clip3(const int16x8_t value, const int16x8_t low,
+ const int16x8_t high) {
+ const int16x8_t clipped_to_ceiling = vminq_s16(high, value);
+ return vmaxq_s16(low, clipped_to_ceiling);
+}
+
+template <int bitdepth, typename Pixel>
+inline int16x8_t GetScalingFactors(
+ const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* source) {
+ int16_t start_vals[8];
+ if (bitdepth == 8) {
+ start_vals[0] = scaling_lut[source[0]];
+ start_vals[1] = scaling_lut[source[1]];
+ start_vals[2] = scaling_lut[source[2]];
+ start_vals[3] = scaling_lut[source[3]];
+ start_vals[4] = scaling_lut[source[4]];
+ start_vals[5] = scaling_lut[source[5]];
+ start_vals[6] = scaling_lut[source[6]];
+ start_vals[7] = scaling_lut[source[7]];
+ return vld1q_s16(start_vals);
+ }
+ int16_t end_vals[8];
+ // TODO(petersonab): Precompute this into a larger table for direct lookups.
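+ // The upper 8 bits of each 10-bit source value select a pair of adjacent
+ // table entries; the low 2 bits interpolate between them below, with
+ // vrshrq_n_s16 rounding the interpolation.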
+ int index = source[0] >> 2;
+ start_vals[0] = scaling_lut[index];
+ end_vals[0] = scaling_lut[index + 1];
+ index = source[1] >> 2;
+ start_vals[1] = scaling_lut[index];
+ end_vals[1] = scaling_lut[index + 1];
+ index = source[2] >> 2;
+ start_vals[2] = scaling_lut[index];
+ end_vals[2] = scaling_lut[index + 1];
+ index = source[3] >> 2;
+ start_vals[3] = scaling_lut[index];
+ end_vals[3] = scaling_lut[index + 1];
+ index = source[4] >> 2;
+ start_vals[4] = scaling_lut[index];
+ end_vals[4] = scaling_lut[index + 1];
+ index = source[5] >> 2;
+ start_vals[5] = scaling_lut[index];
+ end_vals[5] = scaling_lut[index + 1];
+ index = source[6] >> 2;
+ start_vals[6] = scaling_lut[index];
+ end_vals[6] = scaling_lut[index + 1];
+ index = source[7] >> 2;
+ start_vals[7] = scaling_lut[index];
+ end_vals[7] = scaling_lut[index + 1];
+ const int16x8_t start = vld1q_s16(start_vals);
+ const int16x8_t end = vld1q_s16(end_vals);
+ int16x8_t remainder = GetSignedSource8(source);
+ remainder = vandq_s16(remainder, vdupq_n_s16(3));
+ const int16x8_t delta = vmulq_s16(vsubq_s16(end, start), remainder);
+ return vaddq_s16(start, vrshrq_n_s16(delta, 2));
+}
+
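+// Scales |noise| by |scaling| and then shifts right by |scaling_shift|.
+// vrshlq_s16 with a negative shift amount performs a rounding right shift.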
+inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling,
+ const int16x8_t scaling_shift_vect) {
+ const int16x8_t upscaled_noise = vmulq_s16(noise, scaling);
+ return vrshlq_s16(upscaled_noise, scaling_shift_vect);
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling,
+ const int32x4_t scaling_shift_vect) {
+ // TODO(petersonab): Try refactoring scaling lookup table to int16_t and
+ // upscaling by 7 bits to permit high half multiply. This would eliminate
+ // the intermediate 32x4 registers. Also write the averaged values directly
+ // into the table so it doesn't have to be done for every pixel in
+ // the frame.
+ const int32x4_t upscaled_noise_lo =
+ vmull_s16(vget_low_s16(noise), vget_low_s16(scaling));
+ const int32x4_t upscaled_noise_hi =
+ vmull_s16(vget_high_s16(noise), vget_high_s16(scaling));
+ const int16x4_t noise_lo =
+ vmovn_s32(vrshlq_s32(upscaled_noise_lo, scaling_shift_vect));
+ const int16x4_t noise_hi =
+ vmovn_s32(vrshlq_s32(upscaled_noise_hi, scaling_shift_vect));
+ return vcombine_s16(noise_lo, noise_hi);
+}
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+template <int bitdepth, typename GrainType, typename Pixel>
+void BlendNoiseWithImageLuma_NEON(
+ const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift,
+ int width, int height, int start_height,
+ const uint8_t scaling_lut_y[kScalingLookupTableSize],
+ const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y,
+ ptrdiff_t dest_stride_y) {
+ const auto* noise_image =
+ static_cast<const Array2D<GrainType>*>(noise_image_ptr);
+ const auto* in_y_row = static_cast<const Pixel*>(source_plane_y);
+ source_stride_y /= sizeof(Pixel);
+ auto* out_y_row = static_cast<Pixel*>(dest_plane_y);
+ dest_stride_y /= sizeof(Pixel);
+ const int16x8_t floor = vdupq_n_s16(min_value);
+ const int16x8_t ceiling = vdupq_n_s16(max_luma);
+ // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
+ // for 16 bit signed integers. In higher bitdepths, however, we have to
+ // expand to 32 to protect the sign bit.
+ const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift);
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift);
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ // This operation on the unsigned input is safe in 8bpp because the vector
+ // is widened before it is reinterpreted.
+ const int16x8_t orig = GetSignedSource8(&in_y_row[x]);
+ const int16x8_t scaling =
+ GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, &in_y_row[x]);
+ int16x8_t noise =
+ GetSignedSource8(&(noise_image[kPlaneY][y + start_height][x]));
+
+ if (bitdepth == 8) {
+ noise = ScaleNoise(noise, scaling, scaling_shift_vect16);
+ } else {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ noise = ScaleNoise(noise, scaling, scaling_shift_vect32);
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+ }
+ const int16x8_t combined = vaddq_s16(orig, noise);
+ // In 8bpp, when params_.clip_to_restricted_range == false, we could replace
+ // clipping with vqmovun_s16, but the gain would be very small and is
+ // unlikely to be worth duplicating the function for just that case.
+ StoreUnsigned8(&out_y_row[x],
+ vreinterpretq_u16_s16(Clip3(combined, floor, ceiling)));
+ x += 8;
+ } while (x < width);
+ in_y_row += source_stride_y;
+ out_y_row += dest_stride_y;
+ } while (++y < height);
+}
+
+template <int bitdepth, typename GrainType, typename Pixel>
+inline int16x8_t BlendChromaValsWithCfl(
+ const Pixel* average_luma_buffer,
+ const uint8_t scaling_lut[kScalingLookupTableSize],
+ const Pixel* chroma_cursor, const GrainType* noise_image_cursor,
+ const int16x8_t scaling_shift_vect16,
+ const int32x4_t scaling_shift_vect32) {
+ const int16x8_t scaling =
+ GetScalingFactors<bitdepth, Pixel>(scaling_lut, average_luma_buffer);
+ const int16x8_t orig = GetSignedSource8(chroma_cursor);
+ int16x8_t noise = GetSignedSource8(noise_image_cursor);
+ if (bitdepth == 8) {
+ noise = ScaleNoise(noise, scaling, scaling_shift_vect16);
+ } else {
+ noise = ScaleNoise(noise, scaling, scaling_shift_vect32);
+ }
+ return vaddq_s16(orig, noise);
+}
+
+template <int bitdepth, typename GrainType, typename Pixel>
+LIBGAV1_ALWAYS_INLINE void BlendChromaPlaneWithCfl_NEON(
+ const Array2D<GrainType>& noise_image, int min_value, int max_chroma,
+ int width, int height, int start_height, int subsampling_x,
+ int subsampling_y, int scaling_shift,
+ const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* in_y_row,
+ ptrdiff_t source_stride_y, const Pixel* in_chroma_row,
+ ptrdiff_t source_stride_chroma, Pixel* out_chroma_row,
+ ptrdiff_t dest_stride) {
+ const int16x8_t floor = vdupq_n_s16(min_value);
+ const int16x8_t ceiling = vdupq_n_s16(max_chroma);
+ Pixel luma_buffer[16];
+ memset(luma_buffer, 0, sizeof(luma_buffer));
+ // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
+ // for 16 bit signed integers. In higher bitdepths, however, we have to
+ // expand to 32 to protect the sign bit.
+ const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift);
+ const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift);
+
+ const int chroma_height = (height + subsampling_y) >> subsampling_y;
+ const int chroma_width = (width + subsampling_x) >> subsampling_x;
+ const int safe_chroma_width = chroma_width & ~7;
+
+ // Writing the averaged luma to this buffer lets GetScalingFactors read its
+ // eight table indices from memory rather than extracting them from a vector
+ // one lane at a time.
+ Pixel average_luma_buffer[8];
+ assert(start_height % 2 == 0);
+ start_height >>= subsampling_y;
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ const int luma_x = x << subsampling_x;
+ // TODO(petersonab): Consider specializing by subsampling_x. In the 444
+ // case &in_y_row[x] can be passed to GetScalingFactors directly.
+ const uint16x8_t average_luma =
+ GetAverageLuma(&in_y_row[luma_x], subsampling_x);
+ StoreUnsigned8(average_luma_buffer, average_luma);
+
+ const int16x8_t blended =
+ BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>(
+ average_luma_buffer, scaling_lut, &in_chroma_row[x],
+ &(noise_image[y + start_height][x]), scaling_shift_vect16,
+ scaling_shift_vect32);
+
+ // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
+ // clipping with vqmovun_s16, but it's not likely to be worth copying the
+ // function for just that case.
+ StoreUnsigned8(&out_chroma_row[x],
+ vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+ x += 8;
+ } while (x < safe_chroma_width);
+
+ if (x < chroma_width) {
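+ // Right edge iteration: same as above, except the remaining luma is
+ // copied into |luma_buffer| with the final value duplicated so that the
+ // subsampled average does not read past |width|.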
+ const int luma_x = x << subsampling_x;
+ const int valid_range = width - luma_x;
+ memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
+ luma_buffer[valid_range] = in_y_row[width - 1];
+ const uint16x8_t average_luma =
+ GetAverageLuma(luma_buffer, subsampling_x);
+ StoreUnsigned8(average_luma_buffer, average_luma);
+
+ const int16x8_t blended =
+ BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>(
+ average_luma_buffer, scaling_lut, &in_chroma_row[x],
+ &(noise_image[y + start_height][x]), scaling_shift_vect16,
+ scaling_shift_vect32);
+ // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
+ // clipping with vqmovun_s16, but it's not likely to be worth copying the
+ // function for just that case.
+ StoreUnsigned8(&out_chroma_row[x],
+ vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+ }
+
+ in_y_row += source_stride_y << subsampling_y;
+ in_chroma_row += source_stride_chroma;
+ out_chroma_row += dest_stride;
+ } while (++y < chroma_height);
+}
+
+// This function is for the case params_.chroma_scaling_from_luma == true.
+// This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y.
+template <int bitdepth, typename GrainType, typename Pixel>
+void BlendNoiseWithImageChromaWithCfl_NEON(
+ Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
+ int min_value, int max_chroma, int width, int height, int start_height,
+ int subsampling_x, int subsampling_y,
+ const uint8_t scaling_lut[kScalingLookupTableSize],
+ const void* source_plane_y, ptrdiff_t source_stride_y,
+ const void* source_plane_uv, ptrdiff_t source_stride_uv,
+ void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+ const auto* noise_image =
+ static_cast<const Array2D<GrainType>*>(noise_image_ptr);
+ const auto* in_y = static_cast<const Pixel*>(source_plane_y);
+ source_stride_y /= sizeof(Pixel);
+
+ const auto* in_uv = static_cast<const Pixel*>(source_plane_uv);
+ source_stride_uv /= sizeof(Pixel);
+ auto* out_uv = static_cast<Pixel*>(dest_plane_uv);
+ dest_stride_uv /= sizeof(Pixel);
+ // Looping over one plane at a time is faster in higher resolutions, despite
+ // re-computing luma.
+ BlendChromaPlaneWithCfl_NEON<bitdepth, GrainType, Pixel>(
+ noise_image[plane], min_value, max_chroma, width, height, start_height,
+ subsampling_x, subsampling_y, params.chroma_scaling, scaling_lut, in_y,
+ source_stride_y, in_uv, source_stride_uv, out_uv, dest_stride_uv);
+}
+
+} // namespace
+
+namespace low_bitdepth {
+namespace {
+
+inline int16x8_t BlendChromaValsNoCfl(
+ const uint8_t scaling_lut[kScalingLookupTableSize],
+ const uint8_t* chroma_cursor, const int8_t* noise_image_cursor,
+ const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect,
+ const int16x8_t& offset, int luma_multiplier, int chroma_multiplier) {
+ uint8_t merged_buffer[8];
+ const int16x8_t orig = GetSignedSource8(chroma_cursor);
+ const int16x8_t weighted_luma = vmulq_n_s16(average_luma, luma_multiplier);
+ const int16x8_t weighted_chroma = vmulq_n_s16(orig, chroma_multiplier);
+ // Maximum value of |combined| is 127*255 = 0x7E81.
+ const int16x8_t combined = vhaddq_s16(weighted_luma, weighted_chroma);
+ // Maximum value of |offset| is (255 << 5) = 0x1FE0.
+ // 0x7E81 + 0x1FE0 = 0x9E61, therefore another halving add is required.
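+ // The net result is approximately
+ // ((weighted_luma + weighted_chroma) >> 6) + chroma_offset, computed with
+ // halving adds so that every intermediate value stays within int16 range.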
+ const uint8x8_t merged = vqshrun_n_s16(vhaddq_s16(offset, combined), 4);
+ vst1_u8(merged_buffer, merged);
+ const int16x8_t scaling =
+ GetScalingFactors<8, uint8_t>(scaling_lut, merged_buffer);
+ int16x8_t noise = GetSignedSource8(noise_image_cursor);
+ noise = ScaleNoise(noise, scaling, scaling_shift_vect);
+ return vaddq_s16(orig, noise);
+}
+
+LIBGAV1_ALWAYS_INLINE void BlendChromaPlane8bpp_NEON(
+ const Array2D<int8_t>& noise_image, int min_value, int max_chroma,
+ int width, int height, int start_height, int subsampling_x,
+ int subsampling_y, int scaling_shift, int chroma_offset,
+ int chroma_multiplier, int luma_multiplier,
+ const uint8_t scaling_lut[kScalingLookupTableSize], const uint8_t* in_y_row,
+ ptrdiff_t source_stride_y, const uint8_t* in_chroma_row,
+ ptrdiff_t source_stride_chroma, uint8_t* out_chroma_row,
+ ptrdiff_t dest_stride) {
+ const int16x8_t floor = vdupq_n_s16(min_value);
+ const int16x8_t ceiling = vdupq_n_s16(max_chroma);
+ // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
+ // for 16 bit signed integers. In higher bitdepths, however, we have to
+ // expand to 32 to protect the sign bit.
+ const int16x8_t scaling_shift_vect = vdupq_n_s16(-scaling_shift);
+
+ const int chroma_height = (height + subsampling_y) >> subsampling_y;
+ const int chroma_width = (width + subsampling_x) >> subsampling_x;
+ const int safe_chroma_width = chroma_width & ~7;
+ uint8_t luma_buffer[16];
+ const int16x8_t offset = vdupq_n_s16(chroma_offset << 5);
+
+ start_height >>= subsampling_y;
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ const int luma_x = x << subsampling_x;
+ const int16x8_t average_luma = vreinterpretq_s16_u16(
+ GetAverageLuma(&in_y_row[luma_x], subsampling_x));
+ const int16x8_t blended = BlendChromaValsNoCfl(
+ scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+ average_luma, scaling_shift_vect, offset, luma_multiplier,
+ chroma_multiplier);
+ // In 8bpp, when params_.clip_to_restricted_range == false, we can
+ // replace clipping with vqmovun_s16, but the gain would be small.
+ StoreUnsigned8(&out_chroma_row[x],
+ vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+
+ x += 8;
+ } while (x < safe_chroma_width);
+
+ if (x < chroma_width) {
+ // Begin right edge iteration. Same as the normal iterations, but the
+ // |average_luma| computation requires a duplicated luma value at the
+ // end.
+ const int luma_x = x << subsampling_x;
+ const int valid_range = width - luma_x;
+ memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
+ luma_buffer[valid_range] = in_y_row[width - 1];
+
+ const int16x8_t average_luma =
+ vreinterpretq_s16_u16(GetAverageLuma(luma_buffer, subsampling_x));
+ const int16x8_t blended = BlendChromaValsNoCfl(
+ scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+ average_luma, scaling_shift_vect, offset, luma_multiplier,
+ chroma_multiplier);
+ StoreUnsigned8(&out_chroma_row[x],
+ vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+ // End of right edge iteration.
+ }
+
+ in_y_row += source_stride_y << subsampling_y;
+ in_chroma_row += source_stride_chroma;
+ out_chroma_row += dest_stride;
+ } while (++y < chroma_height);
+}
+
+// This function is for the case params_.chroma_scaling_from_luma == false.
+void BlendNoiseWithImageChroma8bpp_NEON(
+ Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
+ int min_value, int max_chroma, int width, int height, int start_height,
+ int subsampling_x, int subsampling_y,
+ const uint8_t scaling_lut[kScalingLookupTableSize],
+ const void* source_plane_y, ptrdiff_t source_stride_y,
+ const void* source_plane_uv, ptrdiff_t source_stride_uv,
+ void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+ assert(plane == kPlaneU || plane == kPlaneV);
+ const auto* noise_image =
+ static_cast<const Array2D<int8_t>*>(noise_image_ptr);
+ const auto* in_y = static_cast<const uint8_t*>(source_plane_y);
+ const auto* in_uv = static_cast<const uint8_t*>(source_plane_uv);
+ auto* out_uv = static_cast<uint8_t*>(dest_plane_uv);
+
+ const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset;
+ const int luma_multiplier =
+ (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier;
+ const int multiplier =
+ (plane == kPlaneU) ? params.u_multiplier : params.v_multiplier;
+ BlendChromaPlane8bpp_NEON(noise_image[plane], min_value, max_chroma, width,
+ height, start_height, subsampling_x, subsampling_y,
+ params.chroma_scaling, offset, multiplier,
+ luma_multiplier, scaling_lut, in_y, source_stride_y,
+ in_uv, source_stride_uv, out_uv, dest_stride_uv);
+}
+
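+// Blends one row of grain from the current noise stripe with the overlapping
+// row of the previous stripe:
+//   result = (grain_coeff * new + old_coeff * old + 16) >> 5
+// saturated to the int8_t grain range.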
+inline void WriteOverlapLine8bpp_NEON(const int8_t* noise_stripe_row,
+ const int8_t* noise_stripe_row_prev,
+ int plane_width,
+ const int8x8_t grain_coeff,
+ const int8x8_t old_coeff,
+ int8_t* noise_image_row) {
+ int x = 0;
+ do {
+ // Note that these reads may exceed noise_stripe_row's width by up to 7
+ // bytes.
+ const int8x8_t source_grain = vld1_s8(noise_stripe_row + x);
+ const int8x8_t source_old = vld1_s8(noise_stripe_row_prev + x);
+ const int16x8_t weighted_grain = vmull_s8(grain_coeff, source_grain);
+ const int16x8_t grain = vmlal_s8(weighted_grain, old_coeff, source_old);
+ // Note that this write may exceed noise_image_row's width by up to 7 bytes.
+ vst1_s8(noise_image_row + x, vqrshrn_n_s16(grain, 5));
+ x += 8;
+ } while (x < plane_width);
+}
+
+void ConstructNoiseImageOverlap8bpp_NEON(const void* noise_stripes_buffer,
+ int width, int height,
+ int subsampling_x, int subsampling_y,
+ void* noise_image_buffer) {
+ const auto* noise_stripes =
+ static_cast<const Array2DView<int8_t>*>(noise_stripes_buffer);
+ auto* noise_image = static_cast<Array2D<int8_t>*>(noise_image_buffer);
+ const int plane_width = (width + subsampling_x) >> subsampling_x;
+ const int plane_height = (height + subsampling_y) >> subsampling_y;
+ const int stripe_height = 32 >> subsampling_y;
+ const int stripe_mask = stripe_height - 1;
+ int y = stripe_height;
+ int luma_num = 1;
+ if (subsampling_y == 0) {
+ const int8x8_t first_row_grain_coeff = vdup_n_s8(17);
+ const int8x8_t first_row_old_coeff = vdup_n_s8(27);
+ const int8x8_t second_row_grain_coeff = first_row_old_coeff;
+ const int8x8_t second_row_old_coeff = first_row_grain_coeff;
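+ // Without vertical subsampling, each 32-row stripe overlaps the previous
+ // stripe by two rows: the first overlap row uses (grain, old) weights of
+ // (17, 27) and the second uses (27, 17).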
+ for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) {
+ const int8_t* noise_stripe = (*noise_stripes)[luma_num];
+ const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+ WriteOverlapLine8bpp_NEON(
+ noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width,
+ first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+
+ WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width],
+ &noise_stripe_prev[(32 + 1) * plane_width],
+ plane_width, second_row_grain_coeff,
+ second_row_old_coeff, (*noise_image)[y + 1]);
+ }
+ // Either one partial stripe remains (remaining_height > 0),
+ // OR the image is less than one stripe high (remaining_height < 0),
+ // OR all stripes are completed (remaining_height == 0).
+ const int remaining_height = plane_height - y;
+ if (remaining_height <= 0) {
+ return;
+ }
+ const int8_t* noise_stripe = (*noise_stripes)[luma_num];
+ const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+ WriteOverlapLine8bpp_NEON(
+ noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width,
+ first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+
+ if (remaining_height > 1) {
+ WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width],
+ &noise_stripe_prev[(32 + 1) * plane_width],
+ plane_width, second_row_grain_coeff,
+ second_row_old_coeff, (*noise_image)[y + 1]);
+ }
+ } else { // subsampling_y == 1
+ const int8x8_t first_row_grain_coeff = vdup_n_s8(22);
+ const int8x8_t first_row_old_coeff = vdup_n_s8(23);
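+ // With vertical subsampling, each 16-row stripe overlaps the previous
+ // stripe by a single row, blended with (grain, old) weights of (22, 23).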
+ for (; y < plane_height; ++luma_num, y += stripe_height) {
+ const int8_t* noise_stripe = (*noise_stripes)[luma_num];
+ const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+ WriteOverlapLine8bpp_NEON(
+ noise_stripe, &noise_stripe_prev[16 * plane_width], plane_width,
+ first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+ }
+ }
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+
+ // LumaAutoRegressionFunc
+ dsp->film_grain.luma_auto_regression[0] =
+ ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 1>;
+ dsp->film_grain.luma_auto_regression[1] =
+ ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 2>;
+ dsp->film_grain.luma_auto_regression[2] =
+ ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 3>;
+
+ // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag]
+ // Chroma autoregression should never be called when lag is 0 and use_luma
+ // is false.
+ dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+ dsp->film_grain.chroma_auto_regression[0][1] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, false>;
+ dsp->film_grain.chroma_auto_regression[0][2] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, false>;
+ dsp->film_grain.chroma_auto_regression[0][3] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, false>;
+ dsp->film_grain.chroma_auto_regression[1][0] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 0, true>;
+ dsp->film_grain.chroma_auto_regression[1][1] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, true>;
+ dsp->film_grain.chroma_auto_regression[1][2] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, true>;
+ dsp->film_grain.chroma_auto_regression[1][3] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, true>;
+
+ dsp->film_grain.construct_noise_image_overlap =
+ ConstructNoiseImageOverlap8bpp_NEON;
+
+ dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON;
+
+ dsp->film_grain.blend_noise_luma =
+ BlendNoiseWithImageLuma_NEON<8, int8_t, uint8_t>;
+ dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma8bpp_NEON;
+ dsp->film_grain.blend_noise_chroma[1] =
+ BlendNoiseWithImageChromaWithCfl_NEON<8, int8_t, uint8_t>;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+void Init10bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+ assert(dsp != nullptr);
+
+ // LumaAutoRegressionFunc
+ dsp->film_grain.luma_auto_regression[0] =
+ ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 1>;
+ dsp->film_grain.luma_auto_regression[1] =
+ ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 2>;
+ dsp->film_grain.luma_auto_regression[2] =
+ ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 3>;
+
+ // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag]
+ // Chroma autoregression should never be called when lag is 0 and use_luma
+ // is false.
+ dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+ dsp->film_grain.chroma_auto_regression[0][1] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, false>;
+ dsp->film_grain.chroma_auto_regression[0][2] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, false>;
+ dsp->film_grain.chroma_auto_regression[0][3] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, false>;
+ dsp->film_grain.chroma_auto_regression[1][0] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 0, true>;
+ dsp->film_grain.chroma_auto_regression[1][1] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, true>;
+ dsp->film_grain.chroma_auto_regression[1][2] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, true>;
+ dsp->film_grain.chroma_auto_regression[1][3] =
+ ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, true>;
+
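+ // Note: construct_noise_image_overlap and blend_noise_chroma[0] are not
+ // overridden for 10bpp, so the existing table entries remain in use.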
+ dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON;
+
+ dsp->film_grain.blend_noise_luma =
+ BlendNoiseWithImageLuma_NEON<10, int16_t, uint16_t>;
+ dsp->film_grain.blend_noise_chroma[1] =
+ BlendNoiseWithImageChromaWithCfl_NEON<10, int16_t, uint16_t>;
+}
+
+} // namespace
+} // namespace high_bitdepth
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+} // namespace film_grain
+
+void FilmGrainInit_NEON() {
+ film_grain::low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ film_grain::high_bitdepth::Init10bpp();
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+}
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void FilmGrainInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/film_grain_neon.h b/src/dsp/arm/film_grain_neon.h
new file mode 100644
index 0000000..44b3d1d
--- /dev/null
+++ b/src/dsp/arm/film_grain_neon.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initialize members of Dsp::film_grain. This function is not thread-safe.
+void FilmGrainInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_
diff --git a/src/dsp/arm/intra_edge_neon.cc b/src/dsp/arm/intra_edge_neon.cc
new file mode 100644
index 0000000..00b186a
--- /dev/null
+++ b/src/dsp/arm/intra_edge_neon.cc
@@ -0,0 +1,301 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/intra_edge.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h" // RightShiftWithRounding()
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+// Simplified version of intra_edge.cc:kKernels[][]. Only the 3-tap kernels
+// for |strength| 1 and 2 are required; the 5-tap |strength| 3 kernel is
+// handled separately below.
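+// The symmetric 3-tap filters are (4, 8, 4) / 16 for |strength| 1 and
+// (5, 6, 5) / 16 for |strength| 2; each row stores (outer_tap, center_tap).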
+constexpr int kKernelsNEON[3][2] = {{4, 8}, {5, 6}};
+
+void IntraEdgeFilter_NEON(void* buffer, const int size, const int strength) {
+ assert(strength == 1 || strength == 2 || strength == 3);
+ const int kernel_index = strength - 1;
+ auto* const dst_buffer = static_cast<uint8_t*>(buffer);
+
+ // The first element is not written out (but it is input) so the number of
+ // elements written is |size| - 1.
+ if (size == 1) return;
+
+ // |strength| 1 and 2 use a 3 tap filter.
+ if (strength < 3) {
+ // The last value requires extending the buffer (duplicating
+ // |dst_buffer[size - 1]|). Calculate it here to avoid extra processing in
+ // NEON.
+ const uint8_t last_val = RightShiftWithRounding(
+ kKernelsNEON[kernel_index][0] * dst_buffer[size - 2] +
+ kKernelsNEON[kernel_index][1] * dst_buffer[size - 1] +
+ kKernelsNEON[kernel_index][0] * dst_buffer[size - 1],
+ 4);
+
+ const uint8x8_t krn1 = vdup_n_u8(kKernelsNEON[kernel_index][1]);
+
+ // The first value we need gets overwritten by the output from the
+ // previous iteration.
+ uint8x16_t src_0 = vld1q_u8(dst_buffer);
+ int i = 1;
+
+ // Process blocks until fewer than 16 values remain.
+ for (; i < size - 15; i += 16) {
+ // Loading these at the end of the block with |src_0| will read past the
+ // end of |top_row_data[160]|, the source of |buffer|.
+ const uint8x16_t src_1 = vld1q_u8(dst_buffer + i);
+ const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1);
+ uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2));
+ sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]);
+ sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1);
+ uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2));
+ sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]);
+ sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1);
+
+ const uint8x16_t result =
+ vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+ // Load the next block before overwriting. This loads an extra 15 values
+ // past |size| on the trailing iteration.
+ src_0 = vld1q_u8(dst_buffer + i + 15);
+
+ vst1q_u8(dst_buffer + i, result);
+ }
+
+ // The last output value, |last_val|, was already calculated, so if
+ // |remainder| == 1 there is nothing left to do.
+ if (remainder > 1) {
+ uint8_t temp[16];
+ const uint8x16_t src_1 = vld1q_u8(dst_buffer + i);
+ const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1);
+
+ uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2));
+ sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]);
+ sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1);
+ uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2));
+ sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]);
+ sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1);
+
+ const uint8x16_t result =
+ vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+ vst1q_u8(temp, result);
+ memcpy(dst_buffer + i, temp, remainder);
+ }
+
+ dst_buffer[size - 1] = last_val;
+ return;
+ }
+
+ assert(strength == 3);
+ // 5 tap filter. The first element requires duplicating |buffer[0]| and the
+ // last two elements require duplicating |buffer[size - 1]|.
+ uint8_t special_vals[3];
+ special_vals[0] = RightShiftWithRounding(
+ (dst_buffer[0] << 1) + (dst_buffer[0] << 2) + (dst_buffer[1] << 2) +
+ (dst_buffer[2] << 2) + (dst_buffer[3] << 1),
+ 4);
+ // Clamp index for very small |size| values.
+ const int first_index_min = std::max(size - 4, 0);
+ const int second_index_min = std::max(size - 3, 0);
+ const int third_index_min = std::max(size - 2, 0);
+ special_vals[1] = RightShiftWithRounding(
+ (dst_buffer[first_index_min] << 1) + (dst_buffer[second_index_min] << 2) +
+ (dst_buffer[third_index_min] << 2) + (dst_buffer[size - 1] << 2) +
+ (dst_buffer[size - 1] << 1),
+ 4);
+ special_vals[2] = RightShiftWithRounding(
+ (dst_buffer[second_index_min] << 1) + (dst_buffer[third_index_min] << 2) +
+      // (x << 2) + (x << 2) == x << 3
+ (dst_buffer[size - 1] << 3) + (dst_buffer[size - 1] << 1),
+ 4);
+
+ // The first two values we need get overwritten by the output from the
+ // previous iteration.
+ uint8x16_t src_0 = vld1q_u8(dst_buffer - 1);
+ uint8x16_t src_1 = vld1q_u8(dst_buffer);
+ int i = 1;
+
+ for (; i < size - 15; i += 16) {
+ // Loading these at the end of the block with |src_[01]| will read past
+ // the end of |top_row_data[160]|, the source of |buffer|.
+ const uint8x16_t src_2 = vld1q_u8(dst_buffer + i);
+ const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1);
+ const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2);
+
+ uint16x8_t sum_lo =
+ vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1);
+ const uint16x8_t sum_123_lo = vaddw_u8(
+ vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3));
+ sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2));
+
+ uint16x8_t sum_hi =
+ vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1);
+ const uint16x8_t sum_123_hi =
+ vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)),
+ vget_high_u8(src_3));
+ sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2));
+
+ const uint8x16_t result =
+ vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+ src_0 = vld1q_u8(dst_buffer + i + 14);
+ src_1 = vld1q_u8(dst_buffer + i + 15);
+
+ vst1q_u8(dst_buffer + i, result);
+ }
+
+ const int remainder = (size - 1) & 0xf;
+ // Like the 3 tap but if there are two remaining values we have already
+ // calculated them.
+ if (remainder > 2) {
+ uint8_t temp[16];
+ const uint8x16_t src_2 = vld1q_u8(dst_buffer + i);
+ const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1);
+ const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2);
+
+ uint16x8_t sum_lo =
+ vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1);
+ const uint16x8_t sum_123_lo = vaddw_u8(
+ vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3));
+ sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2));
+
+ uint16x8_t sum_hi =
+ vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1);
+ const uint16x8_t sum_123_hi =
+ vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)),
+ vget_high_u8(src_3));
+ sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2));
+
+ const uint8x16_t result =
+ vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+ vst1q_u8(temp, result);
+ memcpy(dst_buffer + i, temp, remainder);
+ }
+
+ dst_buffer[1] = special_vals[0];
+ // Avoid overwriting |dst_buffer[0]|.
+ if (size > 2) dst_buffer[size - 2] = special_vals[1];
+ dst_buffer[size - 1] = special_vals[2];
+}
+
+// (-|src0| + |src1| * 9 + |src2| * 9 - |src3|) >> 4
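+// For a flat edge (all four inputs equal to x) this evaluates to
+// (16 * x) >> 4 == x, so the 2x upsampling filter preserves constant regions.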
+uint8x8_t Upsample(const uint8x8_t src0, const uint8x8_t src1,
+ const uint8x8_t src2, const uint8x8_t src3) {
+ const uint16x8_t middle = vmulq_n_u16(vaddl_u8(src1, src2), 9);
+ const uint16x8_t ends = vaddl_u8(src0, src3);
+ const int16x8_t sum =
+ vsubq_s16(vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(ends));
+ return vqrshrun_n_s16(sum, 4);
+}
+
+void IntraEdgeUpsampler_NEON(void* buffer, const int size) {
+ assert(size % 4 == 0 && size <= 16);
+ auto* const pixel_buffer = static_cast<uint8_t*>(buffer);
+ // This is OK because we don't read this value for |size| 4 or 8 but if we
+ // write |pixel_buffer[size]| and then vld() it, that seems to introduce
+ // some latency.
+ pixel_buffer[-2] = pixel_buffer[-1];
+ if (size == 4) {
+ // This uses one load and two vtbl() which is better than 4x Load{Lo,Hi}4().
+ const uint8x8_t src = vld1_u8(pixel_buffer - 1);
+ // The outside values are negated so put those in the same vector.
+ const uint8x8_t src03 = vtbl1_u8(src, vcreate_u8(0x0404030202010000));
+ // Reverse |src1| and |src2| so we can use |src2| for the interleave at the
+ // end.
+ const uint8x8_t src21 = vtbl1_u8(src, vcreate_u8(0x0302010004030201));
+
+ const uint16x8_t middle = vmull_u8(src21, vdup_n_u8(9));
+ const int16x8_t half_sum = vsubq_s16(
+ vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(vmovl_u8(src03)));
+ const int16x4_t sum =
+ vadd_s16(vget_low_s16(half_sum), vget_high_s16(half_sum));
+ const uint8x8_t result = vqrshrun_n_s16(vcombine_s16(sum, sum), 4);
+
+ vst1_u8(pixel_buffer - 1, InterleaveLow8(result, src21));
+ return;
+ } else if (size == 8) {
+ // Likewise, one load + multiple vtbls seems preferred to multiple loads.
+ const uint8x16_t src = vld1q_u8(pixel_buffer - 1);
+ const uint8x8_t src0 = VQTbl1U8(src, vcreate_u8(0x0605040302010000));
+ const uint8x8_t src1 = vget_low_u8(src);
+ const uint8x8_t src2 = VQTbl1U8(src, vcreate_u8(0x0807060504030201));
+ const uint8x8_t src3 = VQTbl1U8(src, vcreate_u8(0x0808070605040302));
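+    // The final lane of |src3| repeats index 8 so the last input sample is
+    // duplicated past the edge, mirroring the |pixel_buffer[size]| extension
+    // used for the larger sizes below.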
+
+ const uint8x8x2_t output = {Upsample(src0, src1, src2, src3), src2};
+ vst2_u8(pixel_buffer - 1, output);
+ return;
+ }
+ assert(size == 12 || size == 16);
+ // Extend the input borders to avoid branching later.
+ pixel_buffer[size] = pixel_buffer[size - 1];
+ const uint8x16_t src0 = vld1q_u8(pixel_buffer - 2);
+ const uint8x16_t src1 = vld1q_u8(pixel_buffer - 1);
+ const uint8x16_t src2 = vld1q_u8(pixel_buffer);
+ const uint8x16_t src3 = vld1q_u8(pixel_buffer + 1);
+
+ const uint8x8_t result_lo = Upsample(vget_low_u8(src0), vget_low_u8(src1),
+ vget_low_u8(src2), vget_low_u8(src3));
+
+ const uint8x8x2_t output_lo = {result_lo, vget_low_u8(src2)};
+ vst2_u8(pixel_buffer - 1, output_lo);
+
+ const uint8x8_t result_hi = Upsample(vget_high_u8(src0), vget_high_u8(src1),
+ vget_high_u8(src2), vget_high_u8(src3));
+
+ if (size == 12) {
+ vst1_u8(pixel_buffer + 15, InterleaveLow8(result_hi, vget_high_u8(src2)));
+ } else /* size == 16 */ {
+ const uint8x8x2_t output_hi = {result_hi, vget_high_u8(src2)};
+ vst2_u8(pixel_buffer + 15, output_hi);
+ }
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->intra_edge_filter = IntraEdgeFilter_NEON;
+ dsp->intra_edge_upsampler = IntraEdgeUpsampler_NEON;
+}
+
+} // namespace
+
+void IntraEdgeInit_NEON() { Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void IntraEdgeInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/intra_edge_neon.h b/src/dsp/arm/intra_edge_neon.h
new file mode 100644
index 0000000..d3bb243
--- /dev/null
+++ b/src/dsp/arm/intra_edge_neon.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::intra_edge_filter and Dsp::intra_edge_upsampler. This
+// function is not thread-safe.
+void IntraEdgeInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_CPU_NEON
+
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
diff --git a/src/dsp/arm/intrapred_cfl_neon.cc b/src/dsp/arm/intrapred_cfl_neon.cc
new file mode 100644
index 0000000..45fe33b
--- /dev/null
+++ b/src/dsp/arm/intrapred_cfl_neon.cc
@@ -0,0 +1,479 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
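+// Duplicate the byte pair |a[0]|, |a[1]| across all lanes of a Q register. On
+// a little endian target |a[0]| lands in the even byte lanes and |a[1]| in
+// the odd byte lanes.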
+uint8x16_t Set2ValuesQ(const uint8_t* a) {
+ uint16_t combined_values = a[0] | a[1] << 8;
+ return vreinterpretq_u8_u16(vdupq_n_u16(combined_values));
+}
+
+uint32_t SumVector(uint32x2_t a) {
+#if defined(__aarch64__)
+ return vaddv_u32(a);
+#else
+ const uint64x1_t b = vpaddl_u32(a);
+ return vget_lane_u32(vreinterpret_u32_u64(b), 0);
+#endif // defined(__aarch64__)
+}
+
+uint32_t SumVector(uint32x4_t a) {
+#if defined(__aarch64__)
+ return vaddvq_u32(a);
+#else
+ const uint64x2_t b = vpaddlq_u32(a);
+ const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b));
+ return vget_lane_u32(vreinterpret_u32_u64(c), 0);
+#endif // defined(__aarch64__)
+}
+
+// Divide by the number of elements.
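+// Width and height are powers of two, so the division reduces to a rounded
+// right shift; e.g. a 16x8 block shifts |sum| right by 4 + 3 = 7.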
+uint32_t Average(const uint32_t sum, const int width, const int height) {
+ return RightShiftWithRounding(sum, FloorLog2(width) + FloorLog2(height));
+}
+
+// Subtract |val| from every element in |a|.
+void BlockSubtract(const uint32_t val,
+ int16_t a[kCflLumaBufferStride][kCflLumaBufferStride],
+ const int width, const int height) {
+ assert(val <= INT16_MAX);
+ const int16x8_t val_v = vdupq_n_s16(static_cast<int16_t>(val));
+
+ for (int y = 0; y < height; ++y) {
+ if (width == 4) {
+ const int16x4_t b = vld1_s16(a[y]);
+ vst1_s16(a[y], vsub_s16(b, vget_low_s16(val_v)));
+ } else if (width == 8) {
+ const int16x8_t b = vld1q_s16(a[y]);
+ vst1q_s16(a[y], vsubq_s16(b, val_v));
+ } else if (width == 16) {
+ const int16x8_t b = vld1q_s16(a[y]);
+ const int16x8_t c = vld1q_s16(a[y] + 8);
+ vst1q_s16(a[y], vsubq_s16(b, val_v));
+ vst1q_s16(a[y] + 8, vsubq_s16(c, val_v));
+ } else /* block_width == 32 */ {
+ const int16x8_t b = vld1q_s16(a[y]);
+ const int16x8_t c = vld1q_s16(a[y] + 8);
+ const int16x8_t d = vld1q_s16(a[y] + 16);
+ const int16x8_t e = vld1q_s16(a[y] + 24);
+ vst1q_s16(a[y], vsubq_s16(b, val_v));
+ vst1q_s16(a[y] + 8, vsubq_s16(c, val_v));
+ vst1q_s16(a[y] + 16, vsubq_s16(d, val_v));
+ vst1q_s16(a[y] + 24, vsubq_s16(e, val_v));
+ }
+ }
+}
+
+template <int block_width, int block_height>
+void CflSubsampler420_NEON(
+ int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+ const int max_luma_width, const int max_luma_height,
+ const void* const source, const ptrdiff_t stride) {
+ const auto* src = static_cast<const uint8_t*>(source);
+ uint32_t sum;
+ if (block_width == 4) {
+ assert(max_luma_width >= 8);
+ uint32x2_t running_sum = vdup_n_u32(0);
+
+ for (int y = 0; y < block_height; ++y) {
+ const uint8x8_t row0 = vld1_u8(src);
+ const uint8x8_t row1 = vld1_u8(src + stride);
+
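+      // Each output is the 2x2 luma sum doubled, i.e. 8x the average, matching
+      // the x8 scaling used by the 444 subsampler.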
+ uint16x4_t sum_row = vpadal_u8(vpaddl_u8(row0), row1);
+ sum_row = vshl_n_u16(sum_row, 1);
+ running_sum = vpadal_u16(running_sum, sum_row);
+ vst1_s16(luma[y], vreinterpret_s16_u16(sum_row));
+
+ if (y << 1 < max_luma_height - 2) {
+ // Once this threshold is reached the loop could be simplified.
+ src += stride << 1;
+ }
+ }
+
+ sum = SumVector(running_sum);
+ } else if (block_width == 8) {
+ const uint8x16_t x_index = {0, 0, 2, 2, 4, 4, 6, 6,
+ 8, 8, 10, 10, 12, 12, 14, 14};
+ const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2);
+ const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index);
+
+ uint32x4_t running_sum = vdupq_n_u32(0);
+
+ for (int y = 0; y < block_height; ++y) {
+ const uint8x16_t x_max0 = Set2ValuesQ(src + max_luma_width - 2);
+ const uint8x16_t x_max1 = Set2ValuesQ(src + max_luma_width - 2 + stride);
+
+ uint8x16_t row0 = vld1q_u8(src);
+ row0 = vbslq_u8(x_mask, row0, x_max0);
+ uint8x16_t row1 = vld1q_u8(src + stride);
+ row1 = vbslq_u8(x_mask, row1, x_max1);
+
+ uint16x8_t sum_row = vpadalq_u8(vpaddlq_u8(row0), row1);
+ sum_row = vshlq_n_u16(sum_row, 1);
+ running_sum = vpadalq_u16(running_sum, sum_row);
+ vst1q_s16(luma[y], vreinterpretq_s16_u16(sum_row));
+
+ if (y << 1 < max_luma_height - 2) {
+ src += stride << 1;
+ }
+ }
+
+ sum = SumVector(running_sum);
+ } else /* block_width >= 16 */ {
+ const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2);
+ uint32x4_t running_sum = vdupq_n_u32(0);
+
+ for (int y = 0; y < block_height; ++y) {
+ uint8x16_t x_index = {0, 2, 4, 6, 8, 10, 12, 14,
+ 16, 18, 20, 22, 24, 26, 28, 30};
+ const uint8x16_t x_max00 = vdupq_n_u8(src[max_luma_width - 2]);
+ const uint8x16_t x_max01 = vdupq_n_u8(src[max_luma_width - 2 + 1]);
+ const uint8x16_t x_max10 = vdupq_n_u8(src[stride + max_luma_width - 2]);
+ const uint8x16_t x_max11 =
+ vdupq_n_u8(src[stride + max_luma_width - 2 + 1]);
+ for (int x = 0; x < block_width; x += 16) {
+ const ptrdiff_t src_x_offset = x << 1;
+ const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index);
+ const uint8x16x2_t row0 = vld2q_u8(src + src_x_offset);
+ const uint8x16x2_t row1 = vld2q_u8(src + src_x_offset + stride);
+ const uint8x16_t row_masked_00 = vbslq_u8(x_mask, row0.val[0], x_max00);
+ const uint8x16_t row_masked_01 = vbslq_u8(x_mask, row0.val[1], x_max01);
+ const uint8x16_t row_masked_10 = vbslq_u8(x_mask, row1.val[0], x_max10);
+ const uint8x16_t row_masked_11 = vbslq_u8(x_mask, row1.val[1], x_max11);
+
+ uint16x8_t sum_row_lo =
+ vaddl_u8(vget_low_u8(row_masked_00), vget_low_u8(row_masked_01));
+ sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_10));
+ sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_11));
+ sum_row_lo = vshlq_n_u16(sum_row_lo, 1);
+ running_sum = vpadalq_u16(running_sum, sum_row_lo);
+ vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(sum_row_lo));
+
+ uint16x8_t sum_row_hi =
+ vaddl_u8(vget_high_u8(row_masked_00), vget_high_u8(row_masked_01));
+ sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_10));
+ sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_11));
+ sum_row_hi = vshlq_n_u16(sum_row_hi, 1);
+ running_sum = vpadalq_u16(running_sum, sum_row_hi);
+ vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(sum_row_hi));
+
+ x_index = vaddq_u8(x_index, vdupq_n_u8(32));
+ }
+ if (y << 1 < max_luma_height - 2) {
+ src += stride << 1;
+ }
+ }
+ sum = SumVector(running_sum);
+ }
+
+ const uint32_t average = Average(sum, block_width, block_height);
+ BlockSubtract(average, luma, block_width, block_height);
+}
+
+template <int block_width, int block_height>
+void CflSubsampler444_NEON(
+ int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+ const int max_luma_width, const int max_luma_height,
+ const void* const source, const ptrdiff_t stride) {
+ const auto* src = static_cast<const uint8_t*>(source);
+ uint32_t sum;
+ if (block_width == 4) {
+ assert(max_luma_width >= 4);
+ uint32x4_t running_sum = vdupq_n_u32(0);
+ uint8x8_t row = vdup_n_u8(0);
+
+ for (int y = 0; y < block_height; y += 2) {
+ row = Load4<0>(src, row);
+ row = Load4<1>(src + stride, row);
+ if (y < (max_luma_height - 1)) {
+ src += stride << 1;
+ }
+
+ const uint16x8_t row_shifted = vshll_n_u8(row, 3);
+ running_sum = vpadalq_u16(running_sum, row_shifted);
+ vst1_s16(luma[y], vreinterpret_s16_u16(vget_low_u16(row_shifted)));
+ vst1_s16(luma[y + 1], vreinterpret_s16_u16(vget_high_u16(row_shifted)));
+ }
+
+ sum = SumVector(running_sum);
+ } else if (block_width == 8) {
+ const uint8x8_t x_index = {0, 1, 2, 3, 4, 5, 6, 7};
+ const uint8x8_t x_max_index = vdup_n_u8(max_luma_width - 1);
+ const uint8x8_t x_mask = vclt_u8(x_index, x_max_index);
+
+ uint32x4_t running_sum = vdupq_n_u32(0);
+
+ for (int y = 0; y < block_height; ++y) {
+ const uint8x8_t x_max = vdup_n_u8(src[max_luma_width - 1]);
+ const uint8x8_t row = vbsl_u8(x_mask, vld1_u8(src), x_max);
+
+ const uint16x8_t row_shifted = vshll_n_u8(row, 3);
+ running_sum = vpadalq_u16(running_sum, row_shifted);
+ vst1q_s16(luma[y], vreinterpretq_s16_u16(row_shifted));
+
+ if (y < max_luma_height - 1) {
+ src += stride;
+ }
+ }
+
+ sum = SumVector(running_sum);
+ } else /* block_width >= 16 */ {
+ const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 1);
+ uint32x4_t running_sum = vdupq_n_u32(0);
+
+ for (int y = 0; y < block_height; ++y) {
+ uint8x16_t x_index = {0, 1, 2, 3, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15};
+ const uint8x16_t x_max = vdupq_n_u8(src[max_luma_width - 1]);
+ for (int x = 0; x < block_width; x += 16) {
+ const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index);
+ const uint8x16_t row = vbslq_u8(x_mask, vld1q_u8(src + x), x_max);
+
+ const uint16x8_t row_shifted_low = vshll_n_u8(vget_low_u8(row), 3);
+ const uint16x8_t row_shifted_high = vshll_n_u8(vget_high_u8(row), 3);
+ running_sum = vpadalq_u16(running_sum, row_shifted_low);
+ running_sum = vpadalq_u16(running_sum, row_shifted_high);
+ vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(row_shifted_low));
+ vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(row_shifted_high));
+
+ x_index = vaddq_u8(x_index, vdupq_n_u8(16));
+ }
+ if (y < max_luma_height - 1) {
+ src += stride;
+ }
+ }
+ sum = SumVector(running_sum);
+ }
+
+ const uint32_t average = Average(sum, block_width, block_height);
+ BlockSubtract(average, luma, block_width, block_height);
+}
+
+// Saturate |dc + ((alpha * luma) >> 6)| to uint8_t.
+inline uint8x8_t Combine8(const int16x8_t luma, const int alpha,
+ const int16x8_t dc) {
+ const int16x8_t la = vmulq_n_s16(luma, alpha);
+ // Subtract the sign bit to round towards zero.
+ const int16x8_t sub_sign = vsraq_n_s16(la, la, 15);
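+  // vsraq_n_s16(la, la, 15) computes la + (la >> 15): negative products are
+  // decremented by 1 before the rounded shift below.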
+ // Shift and accumulate.
+ const int16x8_t result = vrsraq_n_s16(dc, sub_sign, 6);
+ return vqmovun_s16(result);
+}
+
+// The exact range of |luma| and |alpha| is not important because the result is
+// saturated to uint8_t; even a saturated int16_t shifted right by 6 exceeds
+// the uint8_t range.
+template <int block_height>
+inline void CflIntraPredictor4xN_NEON(
+ void* const dest, const ptrdiff_t stride,
+ const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+ const int alpha) {
+ auto* dst = static_cast<uint8_t*>(dest);
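+  // The DC prediction is assumed to already occupy |dest| when this is
+  // called, so the top-left pixel of the block supplies |dc|.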
+ const int16x8_t dc = vdupq_n_s16(dst[0]);
+ for (int y = 0; y < block_height; y += 2) {
+ const int16x4_t luma_row0 = vld1_s16(luma[y]);
+ const int16x4_t luma_row1 = vld1_s16(luma[y + 1]);
+ const uint8x8_t sum =
+ Combine8(vcombine_s16(luma_row0, luma_row1), alpha, dc);
+ StoreLo4(dst, sum);
+ dst += stride;
+ StoreHi4(dst, sum);
+ dst += stride;
+ }
+}
+
+template <int block_height>
+inline void CflIntraPredictor8xN_NEON(
+ void* const dest, const ptrdiff_t stride,
+ const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+ const int alpha) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const int16x8_t dc = vdupq_n_s16(dst[0]);
+ for (int y = 0; y < block_height; ++y) {
+ const int16x8_t luma_row = vld1q_s16(luma[y]);
+ const uint8x8_t sum = Combine8(luma_row, alpha, dc);
+ vst1_u8(dst, sum);
+ dst += stride;
+ }
+}
+
+template <int block_height>
+inline void CflIntraPredictor16xN_NEON(
+ void* const dest, const ptrdiff_t stride,
+ const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+ const int alpha) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const int16x8_t dc = vdupq_n_s16(dst[0]);
+ for (int y = 0; y < block_height; ++y) {
+ const int16x8_t luma_row_0 = vld1q_s16(luma[y]);
+ const int16x8_t luma_row_1 = vld1q_s16(luma[y] + 8);
+ const uint8x8_t sum_0 = Combine8(luma_row_0, alpha, dc);
+ const uint8x8_t sum_1 = Combine8(luma_row_1, alpha, dc);
+ vst1_u8(dst, sum_0);
+ vst1_u8(dst + 8, sum_1);
+ dst += stride;
+ }
+}
+
+template <int block_height>
+inline void CflIntraPredictor32xN_NEON(
+ void* const dest, const ptrdiff_t stride,
+ const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+ const int alpha) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const int16x8_t dc = vdupq_n_s16(dst[0]);
+ for (int y = 0; y < block_height; ++y) {
+ const int16x8_t luma_row_0 = vld1q_s16(luma[y]);
+ const int16x8_t luma_row_1 = vld1q_s16(luma[y] + 8);
+ const int16x8_t luma_row_2 = vld1q_s16(luma[y] + 16);
+ const int16x8_t luma_row_3 = vld1q_s16(luma[y] + 24);
+ const uint8x8_t sum_0 = Combine8(luma_row_0, alpha, dc);
+ const uint8x8_t sum_1 = Combine8(luma_row_1, alpha, dc);
+ const uint8x8_t sum_2 = Combine8(luma_row_2, alpha, dc);
+ const uint8x8_t sum_3 = Combine8(luma_row_3, alpha, dc);
+ vst1_u8(dst, sum_0);
+ vst1_u8(dst + 8, sum_1);
+ vst1_u8(dst + 16, sum_2);
+ vst1_u8(dst + 24, sum_3);
+ dst += stride;
+ }
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+
+ dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] =
+ CflSubsampler420_NEON<4, 4>;
+ dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] =
+ CflSubsampler420_NEON<4, 8>;
+ dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] =
+ CflSubsampler420_NEON<4, 16>;
+
+ dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] =
+ CflSubsampler420_NEON<8, 4>;
+ dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] =
+ CflSubsampler420_NEON<8, 8>;
+ dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] =
+ CflSubsampler420_NEON<8, 16>;
+ dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] =
+ CflSubsampler420_NEON<8, 32>;
+
+ dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] =
+ CflSubsampler420_NEON<16, 4>;
+ dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] =
+ CflSubsampler420_NEON<16, 8>;
+ dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] =
+ CflSubsampler420_NEON<16, 16>;
+ dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] =
+ CflSubsampler420_NEON<16, 32>;
+
+ dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] =
+ CflSubsampler420_NEON<32, 8>;
+ dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] =
+ CflSubsampler420_NEON<32, 16>;
+ dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] =
+ CflSubsampler420_NEON<32, 32>;
+
+ dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] =
+ CflSubsampler444_NEON<4, 4>;
+ dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] =
+ CflSubsampler444_NEON<4, 8>;
+ dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] =
+ CflSubsampler444_NEON<4, 16>;
+
+ dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] =
+ CflSubsampler444_NEON<8, 4>;
+ dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] =
+ CflSubsampler444_NEON<8, 8>;
+ dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] =
+ CflSubsampler444_NEON<8, 16>;
+ dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] =
+ CflSubsampler444_NEON<8, 32>;
+
+ dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] =
+ CflSubsampler444_NEON<16, 4>;
+ dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] =
+ CflSubsampler444_NEON<16, 8>;
+ dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] =
+ CflSubsampler444_NEON<16, 16>;
+ dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] =
+ CflSubsampler444_NEON<16, 32>;
+
+ dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] =
+ CflSubsampler444_NEON<32, 8>;
+ dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] =
+ CflSubsampler444_NEON<32, 16>;
+ dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] =
+ CflSubsampler444_NEON<32, 32>;
+
+ dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor4xN_NEON<4>;
+ dsp->cfl_intra_predictors[kTransformSize4x8] = CflIntraPredictor4xN_NEON<8>;
+ dsp->cfl_intra_predictors[kTransformSize4x16] = CflIntraPredictor4xN_NEON<16>;
+
+ dsp->cfl_intra_predictors[kTransformSize8x4] = CflIntraPredictor8xN_NEON<4>;
+ dsp->cfl_intra_predictors[kTransformSize8x8] = CflIntraPredictor8xN_NEON<8>;
+ dsp->cfl_intra_predictors[kTransformSize8x16] = CflIntraPredictor8xN_NEON<16>;
+ dsp->cfl_intra_predictors[kTransformSize8x32] = CflIntraPredictor8xN_NEON<32>;
+
+ dsp->cfl_intra_predictors[kTransformSize16x4] = CflIntraPredictor16xN_NEON<4>;
+ dsp->cfl_intra_predictors[kTransformSize16x8] = CflIntraPredictor16xN_NEON<8>;
+ dsp->cfl_intra_predictors[kTransformSize16x16] =
+ CflIntraPredictor16xN_NEON<16>;
+ dsp->cfl_intra_predictors[kTransformSize16x32] =
+ CflIntraPredictor16xN_NEON<32>;
+
+ dsp->cfl_intra_predictors[kTransformSize32x8] = CflIntraPredictor32xN_NEON<8>;
+ dsp->cfl_intra_predictors[kTransformSize32x16] =
+ CflIntraPredictor32xN_NEON<16>;
+ dsp->cfl_intra_predictors[kTransformSize32x32] =
+ CflIntraPredictor32xN_NEON<32>;
+ // Max Cfl predictor size is 32x32.
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void IntraPredCflInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void IntraPredCflInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/intrapred_directional_neon.cc b/src/dsp/arm/intrapred_directional_neon.cc
new file mode 100644
index 0000000..805ba81
--- /dev/null
+++ b/src/dsp/arm/intrapred_directional_neon.cc
@@ -0,0 +1,926 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm> // std::min
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring> // memset
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Blend two values with weights that sum to 32. The result is
+// (|a| * |a_weight| + |b| * |b_weight| + 16) >> 5.
+inline uint8x8_t WeightedBlend(const uint8x8_t a, const uint8x8_t b,
+ const uint8x8_t a_weight,
+ const uint8x8_t b_weight) {
+ const uint16x8_t a_product = vmull_u8(a, a_weight);
+ const uint16x8_t b_product = vmull_u8(b, b_weight);
+
+ return vrshrn_n_u16(vaddq_u16(a_product, b_product), 5);
+}
+
+// For vertical operations the weights are one constant value.
+inline uint8x8_t WeightedBlend(const uint8x8_t a, const uint8x8_t b,
+ const uint8_t weight) {
+ return WeightedBlend(a, b, vdup_n_u8(32 - weight), vdup_n_u8(weight));
+}
+
+// Fill |left| and |right| with the appropriate values for a given |base_step|.
+inline void LoadStepwise(const uint8_t* const source, const uint8x8_t left_step,
+ const uint8x8_t right_step, uint8x8_t* left,
+ uint8x8_t* right) {
+ const uint8x16_t mixed = vld1q_u8(source);
+ *left = VQTbl1U8(mixed, left_step);
+ *right = VQTbl1U8(mixed, right_step);
+}
+
+// Handle signed step arguments by ignoring the sign. Negative values are
+// considered out of range and overwritten later.
+inline void LoadStepwise(const uint8_t* const source, const int8x8_t left_step,
+ const int8x8_t right_step, uint8x8_t* left,
+ uint8x8_t* right) {
+ LoadStepwise(source, vreinterpret_u8_s8(left_step),
+ vreinterpret_u8_s8(right_step), left, right);
+}
+
+// Process 4 or 8 |width| by any |height|.
+template <int width>
+inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride,
+ const int height, const uint8_t* const top,
+ const int xstep, const bool upsampled) {
+ assert(width == 4 || width == 8);
+
+ const int upsample_shift = static_cast<int>(upsampled);
+ const int scale_bits = 6 - upsample_shift;
+
+ const int max_base_x = (width + height - 1) << upsample_shift;
+ const int8x8_t max_base = vdup_n_s8(max_base_x);
+ const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
+
+ const int8x8_t all = vcreate_s8(0x0706050403020100);
+ const int8x8_t even = vcreate_s8(0x0e0c0a0806040200);
+ const int8x8_t base_step = upsampled ? even : all;
+ const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1));
+
+ int top_x = xstep;
+ int y = 0;
+ do {
+ const int top_base_x = top_x >> scale_bits;
+
+ if (top_base_x >= max_base_x) {
+ for (int i = y; i < height; ++i) {
+ memset(dst, top[max_base_x], 4 /* width */);
+ dst += stride;
+ }
+ return;
+ }
+
+ const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
+
+ // Zone2 uses negative values for xstep. Use signed values to compare
+ // |top_base_x| to |max_base_x|.
+ const int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step);
+
+ const uint8x8_t max_base_mask = vclt_s8(base_v, max_base);
+
+ // 4 wide subsamples the output. 8 wide subsamples the input.
+ if (width == 4) {
+ const uint8x8_t left_values = vld1_u8(top + top_base_x);
+ const uint8x8_t right_values = RightShift<8>(left_values);
+ const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
+
+ // If |upsampled| is true then extract every other value for output.
+ const uint8x8_t value_stepped =
+ vtbl1_u8(value, vreinterpret_u8_s8(base_step));
+ const uint8x8_t masked_value =
+ vbsl_u8(max_base_mask, value_stepped, top_max_base);
+
+ StoreLo4(dst, masked_value);
+ } else /* width == 8 */ {
+ uint8x8_t left_values, right_values;
+ // WeightedBlend() steps up to Q registers. Downsample the input to avoid
+ // doing extra calculations.
+ LoadStepwise(top + top_base_x, base_step, right_step, &left_values,
+ &right_values);
+
+ const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
+ const uint8x8_t masked_value =
+ vbsl_u8(max_base_mask, value, top_max_base);
+
+ vst1_u8(dst, masked_value);
+ }
+ dst += stride;
+ top_x += xstep;
+ } while (++y < height);
+}
+
+// Process a multiple of 8 |width| by any |height|. Processes horizontally
+// before vertically in the hopes of being a little more cache friendly.
+inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride,
+ const int width, const int height,
+ const uint8_t* const top, const int xstep,
+ const bool upsampled) {
+ assert(width % 8 == 0);
+ const int upsample_shift = static_cast<int>(upsampled);
+ const int scale_bits = 6 - upsample_shift;
+
+ const int max_base_x = (width + height - 1) << upsample_shift;
+ const int8x8_t max_base = vdup_n_s8(max_base_x);
+ const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
+
+ const int8x8_t all = vcreate_s8(0x0706050403020100);
+ const int8x8_t even = vcreate_s8(0x0e0c0a0806040200);
+ const int8x8_t base_step = upsampled ? even : all;
+ const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1));
+ const int8x8_t block_step = vdup_n_s8(8 << upsample_shift);
+
+ int top_x = xstep;
+ int y = 0;
+ do {
+ const int top_base_x = top_x >> scale_bits;
+
+ if (top_base_x >= max_base_x) {
+ for (int i = y; i < height; ++i) {
+ memset(dst, top[max_base_x], 4 /* width */);
+ dst += stride;
+ }
+ return;
+ }
+
+ const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
+
+ // Zone2 uses negative values for xstep. Use signed values to compare
+ // |top_base_x| to |max_base_x|.
+ int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step);
+
+ int x = 0;
+ do {
+ const uint8x8_t max_base_mask = vclt_s8(base_v, max_base);
+
+ // Extract the input values based on |upsampled| here to avoid doing twice
+ // as many calculations.
+ uint8x8_t left_values, right_values;
+ LoadStepwise(top + top_base_x + x, base_step, right_step, &left_values,
+ &right_values);
+
+ const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
+ const uint8x8_t masked_value =
+ vbsl_u8(max_base_mask, value, top_max_base);
+
+ vst1_u8(dst + x, masked_value);
+
+ base_v = vadd_s8(base_v, block_step);
+ x += 8;
+ } while (x < width);
+ top_x += xstep;
+ dst += stride;
+ } while (++y < height);
+}
+
+void DirectionalIntraPredictorZone1_NEON(void* const dest,
+ const ptrdiff_t stride,
+ const void* const top_row,
+ const int width, const int height,
+ const int xstep,
+ const bool upsampled_top) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ assert(xstep > 0);
+
+ const int upsample_shift = static_cast<int>(upsampled_top);
+
+ const uint8x8_t all = vcreate_u8(0x0706050403020100);
+
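+  // An |xstep| of 64 advances the base index by exactly one pixel per row
+  // with no fractional remainder (the 45 degree prediction), so each output
+  // row is a copy of the top row shifted one more sample to the right.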
+ if (xstep == 64) {
+ assert(!upsampled_top);
+ const uint8_t* top_ptr = top + 1;
+ int y = 0;
+ do {
+ memcpy(dst, top_ptr, width);
+ memcpy(dst + stride, top_ptr + 1, width);
+ memcpy(dst + 2 * stride, top_ptr + 2, width);
+ memcpy(dst + 3 * stride, top_ptr + 3, width);
+ dst += 4 * stride;
+ top_ptr += 4;
+ y += 4;
+ } while (y < height);
+ } else if (width == 4) {
+ DirectionalZone1_WxH<4>(dst, stride, height, top, xstep, upsampled_top);
+ } else if (xstep > 51) {
+ // 7.11.2.10. Intra edge upsample selection process
+ // if ( d <= 0 || d >= 40 ) useUpsample = 0
+ // For |upsample_top| the delta is from vertical so |prediction_angle - 90|.
+    // Angles less than 51 meet this criterion; the |xstep| value in
+    // |kDirectionalIntraPredictorDerivative[]| for angle 51 happens to be 51
+    // as well.
+ // Shallower angles have greater xstep values.
+ assert(!upsampled_top);
+ const int max_base_x = ((width + height) - 1);
+ const uint8x8_t max_base = vdup_n_u8(max_base_x);
+ const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
+ const uint8x8_t block_step = vdup_n_u8(8);
+
+ int top_x = xstep;
+ int y = 0;
+ do {
+ const int top_base_x = top_x >> 6;
+ const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
+ uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), all);
+ int x = 0;
+ // Only calculate a block of 8 when at least one of the output values is
+ // within range. Otherwise it can read off the end of |top|.
+ const int must_calculate_width =
+ std::min(width, max_base_x - top_base_x + 7) & ~7;
+ for (; x < must_calculate_width; x += 8) {
+ const uint8x8_t max_base_mask = vclt_u8(base_v, max_base);
+
+        // Since these |xstep| values cannot be upsampled, the load is
+        // simplified.
+ const uint8x8_t left_values = vld1_u8(top + top_base_x + x);
+ const uint8x8_t right_values = vld1_u8(top + top_base_x + x + 1);
+ const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
+ const uint8x8_t masked_value =
+ vbsl_u8(max_base_mask, value, top_max_base);
+
+ vst1_u8(dst + x, masked_value);
+ base_v = vadd_u8(base_v, block_step);
+ }
+ memset(dst + x, top[max_base_x], width - x);
+ dst += stride;
+ top_x += xstep;
+ } while (++y < height);
+ } else {
+ DirectionalZone1_WxH(dst, stride, width, height, top, xstep, upsampled_top);
+ }
+}
+
+// Process 4 or 8 |width| by 4 or 8 |height|.
+template <int width>
+inline void DirectionalZone3_WxH(uint8_t* dest, const ptrdiff_t stride,
+ const int height,
+ const uint8_t* const left_column,
+ const int base_left_y, const int ystep,
+ const int upsample_shift) {
+ assert(width == 4 || width == 8);
+ assert(height == 4 || height == 8);
+ const int scale_bits = 6 - upsample_shift;
+
+ // Zone3 never runs out of left_column values.
+ assert((width + height - 1) << upsample_shift > // max_base_y
+ ((ystep * width) >> scale_bits) +
+ (/* base_step */ 1 << upsample_shift) *
+ (height - 1)); // left_base_y
+
+ // Limited improvement for 8x8. ~20% faster for 64x64.
+ const uint8x8_t all = vcreate_u8(0x0706050403020100);
+ const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200);
+ const uint8x8_t base_step = upsample_shift ? even : all;
+ const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1));
+
+ uint8_t* dst = dest;
+ uint8x8_t left_v[8], right_v[8], value_v[8];
+ const uint8_t* const left = left_column;
+
+ const int index_0 = base_left_y;
+ LoadStepwise(left + (index_0 >> scale_bits), base_step, right_step,
+ &left_v[0], &right_v[0]);
+ value_v[0] = WeightedBlend(left_v[0], right_v[0],
+ ((index_0 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_1 = base_left_y + ystep;
+ LoadStepwise(left + (index_1 >> scale_bits), base_step, right_step,
+ &left_v[1], &right_v[1]);
+ value_v[1] = WeightedBlend(left_v[1], right_v[1],
+ ((index_1 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_2 = base_left_y + ystep * 2;
+ LoadStepwise(left + (index_2 >> scale_bits), base_step, right_step,
+ &left_v[2], &right_v[2]);
+ value_v[2] = WeightedBlend(left_v[2], right_v[2],
+ ((index_2 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_3 = base_left_y + ystep * 3;
+ LoadStepwise(left + (index_3 >> scale_bits), base_step, right_step,
+ &left_v[3], &right_v[3]);
+ value_v[3] = WeightedBlend(left_v[3], right_v[3],
+ ((index_3 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_4 = base_left_y + ystep * 4;
+ LoadStepwise(left + (index_4 >> scale_bits), base_step, right_step,
+ &left_v[4], &right_v[4]);
+ value_v[4] = WeightedBlend(left_v[4], right_v[4],
+ ((index_4 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_5 = base_left_y + ystep * 5;
+ LoadStepwise(left + (index_5 >> scale_bits), base_step, right_step,
+ &left_v[5], &right_v[5]);
+ value_v[5] = WeightedBlend(left_v[5], right_v[5],
+ ((index_5 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_6 = base_left_y + ystep * 6;
+ LoadStepwise(left + (index_6 >> scale_bits), base_step, right_step,
+ &left_v[6], &right_v[6]);
+ value_v[6] = WeightedBlend(left_v[6], right_v[6],
+ ((index_6 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_7 = base_left_y + ystep * 7;
+ LoadStepwise(left + (index_7 >> scale_bits), base_step, right_step,
+ &left_v[7], &right_v[7]);
+ value_v[7] = WeightedBlend(left_v[7], right_v[7],
+ ((index_7 << upsample_shift) & 0x3F) >> 1);
+
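+  // Each |value_v[i]| holds one output column (fixed x, successive y
+  // values); transpose so the block can be stored row by row.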
+ // 8x8 transpose.
+ const uint8x16x2_t b0 = vtrnq_u8(vcombine_u8(value_v[0], value_v[4]),
+ vcombine_u8(value_v[1], value_v[5]));
+ const uint8x16x2_t b1 = vtrnq_u8(vcombine_u8(value_v[2], value_v[6]),
+ vcombine_u8(value_v[3], value_v[7]));
+
+ const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
+ vreinterpretq_u16_u8(b1.val[0]));
+ const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
+ vreinterpretq_u16_u8(b1.val[1]));
+
+ const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]),
+ vreinterpretq_u32_u16(c1.val[0]));
+ const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]),
+ vreinterpretq_u32_u16(c1.val[1]));
+
+ if (width == 4) {
+ StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0])));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0])));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0])));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0])));
+ if (height == 4) return;
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1])));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1])));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1])));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1])));
+ } else {
+ vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0])));
+ dst += stride;
+ vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0])));
+ dst += stride;
+ vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0])));
+ dst += stride;
+ vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0])));
+ if (height == 4) return;
+ dst += stride;
+ vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1])));
+ dst += stride;
+ vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1])));
+ dst += stride;
+ vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1])));
+ dst += stride;
+ vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1])));
+ }
+}
+
+// Because the source values "move backwards" as the row index increases, the
+// indices derived from ystep are generally negative. This is accommodated by
+// making sure the relative indices are within [-15, 0] when the function is
+// called, and sliding them into the inclusive range [0, 15], relative to a
+// lower base address.
+constexpr int kPositiveIndexOffset = 15;
+
+// Process 4 or 8 |width| by any |height|.
+template <int width>
+inline void DirectionalZone2FromLeftCol_WxH(uint8_t* dst,
+ const ptrdiff_t stride,
+ const int height,
+ const uint8_t* const left_column,
+ const int16x8_t left_y,
+ const int upsample_shift) {
+ assert(width == 4 || width == 8);
+
+ // The shift argument must be a constant.
+ int16x8_t offset_y, shift_upsampled = left_y;
+ if (upsample_shift) {
+ offset_y = vshrq_n_s16(left_y, 5);
+ shift_upsampled = vshlq_n_s16(shift_upsampled, 1);
+ } else {
+ offset_y = vshrq_n_s16(left_y, 6);
+ }
+
+ // Select values to the left of the starting point.
+ // The 15th element (and 16th) will be all the way at the end, to the right.
+ // With a negative ystep everything else will be "left" of them.
+ // This supports cumulative steps up to 15. We could support up to 16 by doing
+ // separate loads for |left_values| and |right_values|. vtbl supports 2 Q
+ // registers as input which would allow for cumulative offsets of 32.
+ const int16x8_t sampler =
+ vaddq_s16(offset_y, vdupq_n_s16(kPositiveIndexOffset));
+ const uint8x8_t left_values = vqmovun_s16(sampler);
+ const uint8x8_t right_values = vadd_u8(left_values, vdup_n_u8(1));
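+  // After rebasing the source pointer by -kPositiveIndexOffset below, lane
+  // i of |left_values| selects
+  // left_column[(y << upsample_shift) + offset_y[i]].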
+
+ const int16x8_t shift_masked = vandq_s16(shift_upsampled, vdupq_n_s16(0x3f));
+ const uint8x8_t shift_mul = vreinterpret_u8_s8(vshrn_n_s16(shift_masked, 1));
+ const uint8x8_t inv_shift_mul = vsub_u8(vdup_n_u8(32), shift_mul);
+
+ int y = 0;
+ do {
+ uint8x8_t src_left, src_right;
+ LoadStepwise(left_column - kPositiveIndexOffset + (y << upsample_shift),
+ left_values, right_values, &src_left, &src_right);
+ const uint8x8_t val =
+ WeightedBlend(src_left, src_right, inv_shift_mul, shift_mul);
+
+ if (width == 4) {
+ StoreLo4(dst, val);
+ } else {
+ vst1_u8(dst, val);
+ }
+ dst += stride;
+ } while (++y < height);
+}
+
+// Process 4 or 8 |width| by any |height|.
+template <int width>
+inline void DirectionalZone1Blend_WxH(uint8_t* dest, const ptrdiff_t stride,
+ const int height,
+ const uint8_t* const top_row,
+ int zone_bounds, int top_x,
+ const int xstep,
+ const int upsample_shift) {
+ assert(width == 4 || width == 8);
+
+ const int scale_bits_x = 6 - upsample_shift;
+
+ const uint8x8_t all = vcreate_u8(0x0706050403020100);
+ const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200);
+ const uint8x8_t base_step = upsample_shift ? even : all;
+ const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1));
+
+ int y = 0;
+ do {
+ const uint8_t* const src = top_row + (top_x >> scale_bits_x);
+ uint8x8_t left, right;
+ LoadStepwise(src, base_step, right_step, &left, &right);
+
+ const uint8_t shift = ((top_x << upsample_shift) & 0x3f) >> 1;
+ const uint8x8_t val = WeightedBlend(left, right, shift);
+
+ uint8x8_t dst_blend = vld1_u8(dest);
+ // |zone_bounds| values can be negative.
+ uint8x8_t blend =
+ vcge_s8(vreinterpret_s8_u8(all), vdup_n_s8((zone_bounds >> 6)));
+ uint8x8_t output = vbsl_u8(blend, val, dst_blend);
+
+ if (width == 4) {
+ StoreLo4(dest, output);
+ } else {
+ vst1_u8(dest, output);
+ }
+ dest += stride;
+ zone_bounds += xstep;
+ top_x -= xstep;
+ } while (++y < height);
+}
+
+// The height at which a load of 16 bytes will not contain enough source pixels
+// from |left_column| to supply an accurate row when computing 8 pixels at a
+// time. The values are found by inspection. By coincidence, all angles that
+// satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up
+// by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15.
+constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = {
+ 1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40};
+
+// 7.11.2.4 (8) 90 < angle < 180
+// The strategy for these functions (4xH and 8+xH) is to know how many blocks
+// can be processed with just pixels from |top_ptr|, then handle mixed blocks,
+// then handle only blocks that take from |left_ptr|. Additionally, a fast
+// index-shuffle approach is used for pred values from |left_column| in sections
+// that permit it.
+inline void DirectionalZone2_4xH(uint8_t* dst, const ptrdiff_t stride,
+ const uint8_t* const top_row,
+ const uint8_t* const left_column,
+ const int height, const int xstep,
+ const int ystep, const bool upsampled_top,
+ const bool upsampled_left) {
+ const int upsample_left_shift = static_cast<int>(upsampled_left);
+ const int upsample_top_shift = static_cast<int>(upsampled_top);
+
+ // Helper vector.
+ const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7};
+
+ // Loop incrementers for moving by block (4xN). Vertical still steps by 8. If
+ // it's only 4, it will be finished in the first iteration.
+ const ptrdiff_t stride8 = stride << 3;
+ const int xstep8 = xstep << 3;
+
+ const int min_height = (height == 4) ? 4 : 8;
+
+ // All columns from |min_top_only_x| to the right will only need |top_row| to
+ // compute and can therefore call the Zone1 functions. This assumes |xstep| is
+ // at least 3.
+ assert(xstep >= 3);
+ const int min_top_only_x = std::min((height * xstep) >> 6, /* width */ 4);
+
+ // For steep angles, the source pixels from |left_column| may not fit in a
+ // 16-byte load for shuffling.
+ // TODO(petersonab): Find a more precise formula for this subject to x.
+ // TODO(johannkoenig): Revisit this for |width| == 4.
+ const int max_shuffle_height =
+ std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height);
+
+  // Offsets the original zone bound value to simplify the comparison
+  // x < ((y + 1) * xstep / 64) - 1.
+ int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
+
+ const int left_base_increment = ystep >> 6;
+ const int ystep_remainder = ystep & 0x3F;
+
+ // If the 64 scaling is regarded as a decimal point, the first value of the
+ // left_y vector omits the portion which is covered under the left_column
+ // offset. The following values need the full ystep as a relative offset.
+ int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep);
+ left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder));
+
+ // This loop treats each set of 4 columns in 3 stages with y-value boundaries.
+ // The first stage, before the first y-loop, covers blocks that are only
+ // computed from the top row. The second stage, comprising two y-loops, covers
+ // blocks that have a mixture of values computed from top or left. The final
+ // stage covers blocks that are only computed from the left.
+ if (min_top_only_x > 0) {
+ // Round down to the nearest multiple of 8.
+ // TODO(johannkoenig): This never hits for Wx4 blocks but maybe it should.
+ const int max_top_only_y = std::min((1 << 6) / xstep, height) & ~7;
+ DirectionalZone1_WxH<4>(dst, stride, max_top_only_y, top_row, -xstep,
+ upsampled_top);
+
+ if (max_top_only_y == height) return;
+
+ int y = max_top_only_y;
+ dst += stride * y;
+ const int xstep_y = xstep * y;
+
+ // All rows from |min_left_only_y| down for this set of columns only need
+ // |left_column| to compute.
+ const int min_left_only_y = std::min((4 << 6) / xstep, height);
+ // At high angles such that min_left_only_y < 8, ystep is low and xstep is
+ // high. This means that max_shuffle_height is unbounded and xstep_bounds
+ // will overflow in 16 bits. This is prevented by stopping the first
+ // blending loop at min_left_only_y for such cases, which means we skip over
+ // the second blending loop as well.
+ const int left_shuffle_stop_y =
+ std::min(max_shuffle_height, min_left_only_y);
+ int xstep_bounds = xstep_bounds_base + xstep_y;
+ int top_x = -xstep - xstep_y;
+
+ // +8 increment is OK because if height is 4 this only goes once.
+ for (; y < left_shuffle_stop_y;
+ y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+ DirectionalZone2FromLeftCol_WxH<4>(
+ dst, stride, min_height,
+ left_column + ((y - left_base_increment) << upsample_left_shift),
+ left_y, upsample_left_shift);
+
+ DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row,
+ xstep_bounds, top_x, xstep,
+ upsample_top_shift);
+ }
+
+ // Pick up from the last y-value, using the slower but secure method for
+ // left prediction.
+ const int16_t base_left_y = vgetq_lane_s16(left_y, 0);
+ for (; y < min_left_only_y;
+ y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+ DirectionalZone3_WxH<4>(
+ dst, stride, min_height,
+ left_column + ((y - left_base_increment) << upsample_left_shift),
+ base_left_y, -ystep, upsample_left_shift);
+
+ DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row,
+ xstep_bounds, top_x, xstep,
+ upsample_top_shift);
+ }
+ // Loop over y for left_only rows.
+ for (; y < height; y += 8, dst += stride8) {
+ DirectionalZone3_WxH<4>(
+ dst, stride, min_height,
+ left_column + ((y - left_base_increment) << upsample_left_shift),
+ base_left_y, -ystep, upsample_left_shift);
+ }
+ } else {
+ DirectionalZone1_WxH<4>(dst, stride, height, top_row, -xstep,
+ upsampled_top);
+ }
+}
+
+// Process a multiple of 8 |width|.
+inline void DirectionalZone2_8(uint8_t* const dst, const ptrdiff_t stride,
+ const uint8_t* const top_row,
+ const uint8_t* const left_column,
+ const int width, const int height,
+ const int xstep, const int ystep,
+ const bool upsampled_top,
+ const bool upsampled_left) {
+ const int upsample_left_shift = static_cast<int>(upsampled_left);
+ const int upsample_top_shift = static_cast<int>(upsampled_top);
+
+ // Helper vector.
+ const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7};
+
+ // Loop incrementers for moving by block (8x8). This function handles blocks
+ // with height 4 as well. They are calculated in one pass so these variables
+ // do not get used.
+ const ptrdiff_t stride8 = stride << 3;
+ const int xstep8 = xstep << 3;
+ const int ystep8 = ystep << 3;
+
+ // Process Wx4 blocks.
+ const int min_height = (height == 4) ? 4 : 8;
+
+ // All columns from |min_top_only_x| to the right will only need |top_row| to
+ // compute and can therefore call the Zone1 functions. This assumes |xstep| is
+ // at least 3.
+ assert(xstep >= 3);
+ const int min_top_only_x = std::min((height * xstep) >> 6, width);
+
+ // For steep angles, the source pixels from |left_column| may not fit in a
+ // 16-byte load for shuffling.
+ // TODO(petersonab): Find a more precise formula for this subject to x.
+ const int max_shuffle_height =
+ std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height);
+
+  // Offsets the original zone bound value to simplify the comparison
+  // x < ((y + 1) * xstep / 64) - 1.
+ int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
+
+ const int left_base_increment = ystep >> 6;
+ const int ystep_remainder = ystep & 0x3F;
+
+ const int left_base_increment8 = ystep8 >> 6;
+ const int ystep_remainder8 = ystep8 & 0x3F;
+ const int16x8_t increment_left8 = vdupq_n_s16(ystep_remainder8);
+
+ // If the 64 scaling is regarded as a decimal point, the first value of the
+ // left_y vector omits the portion which is covered under the left_column
+ // offset. Following values need the full ystep as a relative offset.
+ int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep);
+ left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder));
+
+  // This loop treats each set of 8 columns in 3 stages with y-value boundaries.
+ // The first stage, before the first y-loop, covers blocks that are only
+ // computed from the top row. The second stage, comprising two y-loops, covers
+ // blocks that have a mixture of values computed from top or left. The final
+ // stage covers blocks that are only computed from the left.
+ int x = 0;
+ for (int left_offset = -left_base_increment; x < min_top_only_x; x += 8,
+ xstep_bounds_base -= (8 << 6),
+ left_y = vsubq_s16(left_y, increment_left8),
+ left_offset -= left_base_increment8) {
+ uint8_t* dst_x = dst + x;
+
+ // Round down to the nearest multiple of 8.
+ const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7;
+ DirectionalZone1_WxH<8>(dst_x, stride, max_top_only_y,
+ top_row + (x << upsample_top_shift), -xstep,
+ upsampled_top);
+
+ if (max_top_only_y == height) continue;
+
+ int y = max_top_only_y;
+ dst_x += stride * y;
+ const int xstep_y = xstep * y;
+
+ // All rows from |min_left_only_y| down for this set of columns only need
+ // |left_column| to compute.
+ const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height);
+ // At high angles such that min_left_only_y < 8, ystep is low and xstep is
+ // high. This means that max_shuffle_height is unbounded and xstep_bounds
+ // will overflow in 16 bits. This is prevented by stopping the first
+ // blending loop at min_left_only_y for such cases, which means we skip over
+ // the second blending loop as well.
+ const int left_shuffle_stop_y =
+ std::min(max_shuffle_height, min_left_only_y);
+ int xstep_bounds = xstep_bounds_base + xstep_y;
+ int top_x = -xstep - xstep_y;
+
+ for (; y < left_shuffle_stop_y;
+ y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+ DirectionalZone2FromLeftCol_WxH<8>(
+ dst_x, stride, min_height,
+ left_column + ((left_offset + y) << upsample_left_shift), left_y,
+ upsample_left_shift);
+
+ DirectionalZone1Blend_WxH<8>(
+ dst_x, stride, min_height, top_row + (x << upsample_top_shift),
+ xstep_bounds, top_x, xstep, upsample_top_shift);
+ }
+
+ // Pick up from the last y-value, using the slower but secure method for
+ // left prediction.
+ const int16_t base_left_y = vgetq_lane_s16(left_y, 0);
+ for (; y < min_left_only_y;
+ y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+ DirectionalZone3_WxH<8>(
+ dst_x, stride, min_height,
+ left_column + ((left_offset + y) << upsample_left_shift), base_left_y,
+ -ystep, upsample_left_shift);
+
+ DirectionalZone1Blend_WxH<8>(
+ dst_x, stride, min_height, top_row + (x << upsample_top_shift),
+ xstep_bounds, top_x, xstep, upsample_top_shift);
+ }
+ // Loop over y for left_only rows.
+ for (; y < height; y += 8, dst_x += stride8) {
+ DirectionalZone3_WxH<8>(
+ dst_x, stride, min_height,
+ left_column + ((left_offset + y) << upsample_left_shift), base_left_y,
+ -ystep, upsample_left_shift);
+ }
+ }
+ // TODO(johannkoenig): May be able to remove this branch.
+ if (x < width) {
+ DirectionalZone1_WxH(dst + x, stride, width - x, height,
+ top_row + (x << upsample_top_shift), -xstep,
+ upsampled_top);
+ }
+}
+
+void DirectionalIntraPredictorZone2_NEON(
+ void* const dest, const ptrdiff_t stride, const void* const top_row,
+ const void* const left_column, const int width, const int height,
+ const int xstep, const int ystep, const bool upsampled_top,
+ const bool upsampled_left) {
+ // Increasing the negative buffer for this function allows more rows to be
+ // processed at a time without branching in an inner loop to check the base.
+ uint8_t top_buffer[288];
+ uint8_t left_buffer[288];
+ memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160);
+ memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160);
+ const uint8_t* top_ptr = top_buffer + 144;
+ const uint8_t* left_ptr = left_buffer + 144;
+ auto* dst = static_cast<uint8_t*>(dest);
+
+ if (width == 4) {
+ DirectionalZone2_4xH(dst, stride, top_ptr, left_ptr, height, xstep, ystep,
+ upsampled_top, upsampled_left);
+ } else {
+ DirectionalZone2_8(dst, stride, top_ptr, left_ptr, width, height, xstep,
+ ystep, upsampled_top, upsampled_left);
+ }
+}
+
+void DirectionalIntraPredictorZone3_NEON(void* const dest,
+ const ptrdiff_t stride,
+ const void* const left_column,
+ const int width, const int height,
+ const int ystep,
+ const bool upsampled_left) {
+ const auto* const left = static_cast<const uint8_t*>(left_column);
+
+ assert(ystep > 0);
+
+ const int upsample_shift = static_cast<int>(upsampled_left);
+ const int scale_bits = 6 - upsample_shift;
+ const int base_step = 1 << upsample_shift;
+
+ if (width == 4 || height == 4) {
+ // This block can handle all sizes but the specializations for other sizes
+ // are faster.
+ const uint8x8_t all = vcreate_u8(0x0706050403020100);
+ const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200);
+ const uint8x8_t base_step_v = upsampled_left ? even : all;
+ const uint8x8_t right_step = vadd_u8(base_step_v, vdup_n_u8(1));
+
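+    // In the loop below, left + offset + (index >> scale_bits) selects the
+    // base pixel pair loaded by LoadStepwise, and the last argument to
+    // WeightedBlend, ((index << upsample_shift) & 0x3F) >> 1, is the 6-bit
+    // sub-pixel fraction of the position halved to a 5-bit blend weight.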
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+ dst += y * stride + x;
+ uint8x8_t left_v[4], right_v[4], value_v[4];
+ const int ystep_base = ystep * x;
+ const int offset = y * base_step;
+
+ const int index_0 = ystep_base + ystep * 1;
+ LoadStepwise(left + offset + (index_0 >> scale_bits), base_step_v,
+ right_step, &left_v[0], &right_v[0]);
+ value_v[0] = WeightedBlend(left_v[0], right_v[0],
+ ((index_0 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_1 = ystep_base + ystep * 2;
+ LoadStepwise(left + offset + (index_1 >> scale_bits), base_step_v,
+ right_step, &left_v[1], &right_v[1]);
+ value_v[1] = WeightedBlend(left_v[1], right_v[1],
+ ((index_1 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_2 = ystep_base + ystep * 3;
+ LoadStepwise(left + offset + (index_2 >> scale_bits), base_step_v,
+ right_step, &left_v[2], &right_v[2]);
+ value_v[2] = WeightedBlend(left_v[2], right_v[2],
+ ((index_2 << upsample_shift) & 0x3F) >> 1);
+
+ const int index_3 = ystep_base + ystep * 4;
+ LoadStepwise(left + offset + (index_3 >> scale_bits), base_step_v,
+ right_step, &left_v[3], &right_v[3]);
+ value_v[3] = WeightedBlend(left_v[3], right_v[3],
+ ((index_3 << upsample_shift) & 0x3F) >> 1);
+
+ // 8x4 transpose.
+ const uint8x8x2_t b0 = vtrn_u8(value_v[0], value_v[1]);
+ const uint8x8x2_t b1 = vtrn_u8(value_v[2], value_v[3]);
+
+ const uint16x4x2_t c0 = vtrn_u16(vreinterpret_u16_u8(b0.val[0]),
+ vreinterpret_u16_u8(b1.val[0]));
+ const uint16x4x2_t c1 = vtrn_u16(vreinterpret_u16_u8(b0.val[1]),
+ vreinterpret_u16_u8(b1.val[1]));
+
+ StoreLo4(dst, vreinterpret_u8_u16(c0.val[0]));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u16(c1.val[0]));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u16(c0.val[1]));
+ dst += stride;
+ StoreLo4(dst, vreinterpret_u8_u16(c1.val[1]));
+
+ if (height > 4) {
+ dst += stride;
+ StoreHi4(dst, vreinterpret_u8_u16(c0.val[0]));
+ dst += stride;
+ StoreHi4(dst, vreinterpret_u8_u16(c1.val[0]));
+ dst += stride;
+ StoreHi4(dst, vreinterpret_u8_u16(c0.val[1]));
+ dst += stride;
+ StoreHi4(dst, vreinterpret_u8_u16(c1.val[1]));
+ }
+ x += 4;
+ } while (x < width);
+ y += 8;
+ } while (y < height);
+ } else { // 8x8 at a time.
+ // Limited improvement for 8x8. ~20% faster for 64x64.
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+ dst += y * stride + x;
+ const int ystep_base = ystep * (x + 1);
+
+ DirectionalZone3_WxH<8>(dst, stride, 8, left + (y << upsample_shift),
+ ystep_base, ystep, upsample_shift);
+ x += 8;
+ } while (x < width);
+ y += 8;
+ } while (y < height);
+ }
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->directional_intra_predictor_zone1 = DirectionalIntraPredictorZone1_NEON;
+ dsp->directional_intra_predictor_zone2 = DirectionalIntraPredictorZone2_NEON;
+ dsp->directional_intra_predictor_zone3 = DirectionalIntraPredictorZone3_NEON;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void IntraPredDirectionalInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void IntraPredDirectionalInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/intrapred_filter_intra_neon.cc b/src/dsp/arm/intrapred_filter_intra_neon.cc
new file mode 100644
index 0000000..411708e
--- /dev/null
+++ b/src/dsp/arm/intrapred_filter_intra_neon.cc
@@ -0,0 +1,176 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+
+namespace low_bitdepth {
+namespace {
+
+// Transpose kFilterIntraTaps and convert the first row to unsigned values.
+//
+// With the previous orientation we were able to multiply all the input values
+// by a single tap. This required that all the input values be in one vector,
+// which requires expensive setup operations (shifts, vext, vtbl). All the
+// elements of the result needed to be summed (easy on A64 - vaddvq_s16) but
+// then the shifting, rounding, and clamping was done in GP registers.
+//
+// Switching to unsigned values allows multiplying the 8 bit inputs directly.
+// When one value was negative we needed to vmovl_u8 first so that the results
+// maintained the proper sign.
+//
+// We take this into account when summing the values by subtracting the product
+// of the first row.
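+//
+// As a rough scalar sketch of what the code below computes per pixel:
+//   Clip((taps[1..4] . top[0..3] + taps[5..6] . left[0..1]
+//         - taps[0] * top_left + 8) >> 4)
+// where taps[0] holds the magnitude of the negative top-left tap, which is
+// why its product is subtracted rather than accumulated.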
+alignas(8) constexpr uint8_t kTransposedTaps[kNumFilterIntraPredictors][7][8] =
+ {{{6, 5, 3, 3, 4, 3, 3, 3}, // Original values are negative.
+ {10, 2, 1, 1, 6, 2, 2, 1},
+ {0, 10, 1, 1, 0, 6, 2, 2},
+ {0, 0, 10, 2, 0, 0, 6, 2},
+ {0, 0, 0, 10, 0, 0, 0, 6},
+ {12, 9, 7, 5, 2, 2, 2, 3},
+ {0, 0, 0, 0, 12, 9, 7, 5}},
+ {{10, 6, 4, 2, 10, 6, 4, 2}, // Original values are negative.
+ {16, 0, 0, 0, 16, 0, 0, 0},
+ {0, 16, 0, 0, 0, 16, 0, 0},
+ {0, 0, 16, 0, 0, 0, 16, 0},
+ {0, 0, 0, 16, 0, 0, 0, 16},
+ {10, 6, 4, 2, 0, 0, 0, 0},
+ {0, 0, 0, 0, 10, 6, 4, 2}},
+ {{8, 8, 8, 8, 4, 4, 4, 4}, // Original values are negative.
+ {8, 0, 0, 0, 4, 0, 0, 0},
+ {0, 8, 0, 0, 0, 4, 0, 0},
+ {0, 0, 8, 0, 0, 0, 4, 0},
+ {0, 0, 0, 8, 0, 0, 0, 4},
+ {16, 16, 16, 16, 0, 0, 0, 0},
+ {0, 0, 0, 0, 16, 16, 16, 16}},
+ {{2, 1, 1, 0, 1, 1, 1, 1}, // Original values are negative.
+ {8, 3, 2, 1, 4, 3, 2, 2},
+ {0, 8, 3, 2, 0, 4, 3, 2},
+ {0, 0, 8, 3, 0, 0, 4, 3},
+ {0, 0, 0, 8, 0, 0, 0, 4},
+ {10, 6, 4, 2, 3, 4, 4, 3},
+ {0, 0, 0, 0, 10, 6, 4, 3}},
+ {{12, 10, 9, 8, 10, 9, 8, 7}, // Original values are negative.
+ {14, 0, 0, 0, 12, 1, 0, 0},
+ {0, 14, 0, 0, 0, 12, 0, 0},
+ {0, 0, 14, 0, 0, 0, 12, 1},
+ {0, 0, 0, 14, 0, 0, 0, 12},
+ {14, 12, 11, 10, 0, 0, 1, 1},
+ {0, 0, 0, 0, 14, 12, 11, 9}}};
+
+void FilterIntraPredictor_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column,
+ FilterIntraPredictor pred, int width,
+ int height) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+
+ assert(width <= 32 && height <= 32);
+
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ uint8x8_t transposed_taps[7];
+ for (int i = 0; i < 7; ++i) {
+ transposed_taps[i] = vld1_u8(kTransposedTaps[pred][i]);
+ }
+
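+  // Each pass of the inner loop below produces a 4x2 block: the low four
+  // lanes of |sum_saturated| are written to the current row and the high four
+  // lanes to the row below.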
+ uint8_t relative_top_left = top[-1];
+ const uint8_t* relative_top = top;
+ uint8_t relative_left[2] = {left[0], left[1]};
+
+ int y = 0;
+ do {
+ uint8_t* row_dst = dst;
+ int x = 0;
+ do {
+ uint16x8_t sum = vdupq_n_u16(0);
+ const uint16x8_t subtrahend =
+ vmull_u8(transposed_taps[0], vdup_n_u8(relative_top_left));
+ for (int i = 1; i < 5; ++i) {
+ sum = vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_top[i - 1]));
+ }
+ for (int i = 5; i < 7; ++i) {
+ sum =
+ vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_left[i - 5]));
+ }
+
+ const int16x8_t sum_signed =
+ vreinterpretq_s16_u16(vsubq_u16(sum, subtrahend));
+ const int16x8_t sum_shifted = vrshrq_n_s16(sum_signed, 4);
+
+ uint8x8_t sum_saturated = vqmovun_s16(sum_shifted);
+
+ StoreLo4(row_dst, sum_saturated);
+ StoreHi4(row_dst + stride, sum_saturated);
+
+ // Progress across
+ relative_top_left = relative_top[3];
+ relative_top += 4;
+ relative_left[0] = row_dst[3];
+ relative_left[1] = row_dst[3 + stride];
+ row_dst += 4;
+ x += 4;
+ } while (x < width);
+
+ // Progress down.
+ relative_top_left = left[y + 1];
+ relative_top = dst + stride;
+ relative_left[0] = left[y + 2];
+ relative_left[1] = left[y + 3];
+
+ dst += 2 * stride;
+ y += 2;
+ } while (y < height);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->filter_intra_predictor = FilterIntraPredictor_NEON;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void IntraPredFilterIntraInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void IntraPredFilterIntraInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/intrapred_neon.cc b/src/dsp/arm/intrapred_neon.cc
new file mode 100644
index 0000000..c967d82
--- /dev/null
+++ b/src/dsp/arm/intrapred_neon.cc
@@ -0,0 +1,1144 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+//------------------------------------------------------------------------------
+// DcPredFuncs_NEON
+
+using DcSumFunc = uint32x2_t (*)(const void* ref_0, const int ref_0_size_log2,
+ const bool use_ref_1, const void* ref_1,
+ const int ref_1_size_log2);
+using DcStoreFunc = void (*)(void* dest, ptrdiff_t stride, const uint32x2_t dc);
+
+// DC intra-predictors for square blocks.
+template <int block_width_log2, int block_height_log2, DcSumFunc sumfn,
+ DcStoreFunc storefn>
+struct DcPredFuncs_NEON {
+ DcPredFuncs_NEON() = delete;
+
+ static void DcTop(void* dest, ptrdiff_t stride, const void* top_row,
+ const void* left_column);
+ static void DcLeft(void* dest, ptrdiff_t stride, const void* top_row,
+ const void* left_column);
+ static void Dc(void* dest, ptrdiff_t stride, const void* top_row,
+ const void* left_column);
+};
+
+template <int block_width_log2, int block_height_log2, DcSumFunc sumfn,
+ DcStoreFunc storefn>
+void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn,
+ storefn>::DcTop(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* /*left_column*/) {
+ const uint32x2_t sum = sumfn(top_row, block_width_log2, false, nullptr, 0);
+ const uint32x2_t dc = vrshr_n_u32(sum, block_width_log2);
+ storefn(dest, stride, dc);
+}
+
+template <int block_width_log2, int block_height_log2, DcSumFunc sumfn,
+ DcStoreFunc storefn>
+void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn,
+ storefn>::DcLeft(void* const dest, ptrdiff_t stride,
+ const void* /*top_row*/,
+ const void* const left_column) {
+ const uint32x2_t sum =
+ sumfn(left_column, block_height_log2, false, nullptr, 0);
+ const uint32x2_t dc = vrshr_n_u32(sum, block_height_log2);
+ storefn(dest, stride, dc);
+}
+
+template <int block_width_log2, int block_height_log2, DcSumFunc sumfn,
+ DcStoreFunc storefn>
+void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, storefn>::Dc(
+ void* const dest, ptrdiff_t stride, const void* const top_row,
+ const void* const left_column) {
+ const uint32x2_t sum =
+ sumfn(top_row, block_width_log2, true, left_column, block_height_log2);
+ if (block_width_log2 == block_height_log2) {
+ const uint32x2_t dc = vrshr_n_u32(sum, block_width_log2 + 1);
+ storefn(dest, stride, dc);
+ } else {
+ // TODO(johannkoenig): Compare this to mul/shift in vectors.
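+    // For example, for a 16x8 block |divisor| is 16 + 8 = 24 and
+    // dc = (sum + 12) / 24, a rounded average of the 24 edge pixels.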
+ const int divisor = (1 << block_width_log2) + (1 << block_height_log2);
+ uint32_t dc = vget_lane_u32(sum, 0);
+ dc += divisor >> 1;
+ dc /= divisor;
+ storefn(dest, stride, vdup_n_u32(dc));
+ }
+}
+
+// Sum all the elements in the vector into the low 32 bits.
+inline uint32x2_t Sum(const uint16x4_t val) {
+ const uint32x2_t sum = vpaddl_u16(val);
+ return vpadd_u32(sum, sum);
+}
+
+// Sum all the elements in the vector into the low 32 bits.
+inline uint32x2_t Sum(const uint16x8_t val) {
+ const uint32x4_t sum_0 = vpaddlq_u16(val);
+ const uint64x2_t sum_1 = vpaddlq_u32(sum_0);
+ return vadd_u32(vget_low_u32(vreinterpretq_u32_u64(sum_1)),
+ vget_high_u32(vreinterpretq_u32_u64(sum_1)));
+}
+
+} // namespace
+
+//------------------------------------------------------------------------------
+namespace low_bitdepth {
+namespace {
+
+// Add and expand the elements of |val_0| and |val_1| to uint16_t, but do not
+// sum the entire vector.
+inline uint16x8_t Add(const uint8x16_t val_0, const uint8x16_t val_1) {
+ const uint16x8_t sum_0 = vpaddlq_u8(val_0);
+ const uint16x8_t sum_1 = vpaddlq_u8(val_1);
+ return vaddq_u16(sum_0, sum_1);
+}
+
+// Add and expand the elements of |val_0| through |val_3| to uint16_t, but do
+// not sum the entire vector.
+inline uint16x8_t Add(const uint8x16_t val_0, const uint8x16_t val_1,
+ const uint8x16_t val_2, const uint8x16_t val_3) {
+ const uint16x8_t sum_0 = Add(val_0, val_1);
+ const uint16x8_t sum_1 = Add(val_2, val_3);
+ return vaddq_u16(sum_0, sum_1);
+}
+
+// Load and combine 32 uint8_t values.
+inline uint16x8_t LoadAndAdd32(const uint8_t* buf) {
+ const uint8x16_t val_0 = vld1q_u8(buf);
+ const uint8x16_t val_1 = vld1q_u8(buf + 16);
+ return Add(val_0, val_1);
+}
+
+// Load and combine 64 uint8_t values.
+inline uint16x8_t LoadAndAdd64(const uint8_t* buf) {
+ const uint8x16_t val_0 = vld1q_u8(buf);
+ const uint8x16_t val_1 = vld1q_u8(buf + 16);
+ const uint8x16_t val_2 = vld1q_u8(buf + 32);
+ const uint8x16_t val_3 = vld1q_u8(buf + 48);
+ return Add(val_0, val_1, val_2, val_3);
+}
+
+// |ref_[01]| each point to 1 << |ref_[01]_size_log2| packed uint8_t values.
+// If |use_ref_1| is false then only sum |ref_0|.
+// For 4-value references (|ref_[01]_size_log2| == 2) this relies on |ref_[01]|
+// being aligned to uint32_t, since Load4() performs a 32-bit load.
+inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2,
+ const bool use_ref_1, const void* ref_1,
+ const int ref_1_size_log2) {
+ const auto* const ref_0_u8 = static_cast<const uint8_t*>(ref_0);
+ const auto* const ref_1_u8 = static_cast<const uint8_t*>(ref_1);
+ if (ref_0_size_log2 == 2) {
+ uint8x8_t val = Load4(ref_0_u8);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 2) { // 4x4
+ val = Load4<1>(ref_1_u8, val);
+ return Sum(vpaddl_u8(val));
+ } else if (ref_1_size_log2 == 3) { // 4x8
+ const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+ const uint16x4_t sum_0 = vpaddl_u8(val);
+ const uint16x4_t sum_1 = vpaddl_u8(val_1);
+ return Sum(vadd_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 4) { // 4x16
+ const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+ return Sum(vaddw_u8(vpaddlq_u8(val_1), val));
+ }
+ }
+ // 4x1
+ const uint16x4_t sum = vpaddl_u8(val);
+ return vpaddl_u16(sum);
+ } else if (ref_0_size_log2 == 3) {
+ const uint8x8_t val_0 = vld1_u8(ref_0_u8);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 2) { // 8x4
+ const uint8x8_t val_1 = Load4(ref_1_u8);
+ const uint16x4_t sum_0 = vpaddl_u8(val_0);
+ const uint16x4_t sum_1 = vpaddl_u8(val_1);
+ return Sum(vadd_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 3) { // 8x8
+ const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+ const uint16x4_t sum_0 = vpaddl_u8(val_0);
+ const uint16x4_t sum_1 = vpaddl_u8(val_1);
+ return Sum(vadd_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 4) { // 8x16
+ const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+ return Sum(vaddw_u8(vpaddlq_u8(val_1), val_0));
+ } else if (ref_1_size_log2 == 5) { // 8x32
+ return Sum(vaddw_u8(LoadAndAdd32(ref_1_u8), val_0));
+ }
+ }
+ // 8x1
+ return Sum(vpaddl_u8(val_0));
+ } else if (ref_0_size_log2 == 4) {
+ const uint8x16_t val_0 = vld1q_u8(ref_0_u8);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 2) { // 16x4
+ const uint8x8_t val_1 = Load4(ref_1_u8);
+ return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
+ } else if (ref_1_size_log2 == 3) { // 16x8
+ const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+ return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
+ } else if (ref_1_size_log2 == 4) { // 16x16
+ const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+ return Sum(Add(val_0, val_1));
+ } else if (ref_1_size_log2 == 5) { // 16x32
+ const uint16x8_t sum_0 = vpaddlq_u8(val_0);
+ const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 6) { // 16x64
+ const uint16x8_t sum_0 = vpaddlq_u8(val_0);
+ const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ }
+ }
+ // 16x1
+ return Sum(vpaddlq_u8(val_0));
+ } else if (ref_0_size_log2 == 5) {
+ const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u8);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 3) { // 32x8
+ const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+ return Sum(vaddw_u8(sum_0, val_1));
+ } else if (ref_1_size_log2 == 4) { // 32x16
+ const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+ const uint16x8_t sum_1 = vpaddlq_u8(val_1);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 5) { // 32x32
+ const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 6) { // 32x64
+ const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ }
+ }
+ // 32x1
+ return Sum(sum_0);
+ }
+
+ assert(ref_0_size_log2 == 6);
+ const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u8);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 4) { // 64x16
+ const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+ const uint16x8_t sum_1 = vpaddlq_u8(val_1);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 5) { // 64x32
+ const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 6) { // 64x64
+ const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ }
+ }
+ // 64x1
+ return Sum(sum_0);
+}
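+// For example, the 16x8 instantiation (DcDefs::_16x8::Dc below) reaches the
+// "16x8" branch above via DcSum_NEON(top_row, 4, true, left_column, 3).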
+
+template <int width, int height>
+inline void DcStore_NEON(void* const dest, ptrdiff_t stride,
+ const uint32x2_t dc) {
+ const uint8x16_t dc_dup = vdupq_lane_u8(vreinterpret_u8_u32(dc), 0);
+ auto* dst = static_cast<uint8_t*>(dest);
+ if (width == 4) {
+ int i = height - 1;
+ do {
+ StoreLo4(dst, vget_low_u8(dc_dup));
+ dst += stride;
+ } while (--i != 0);
+ StoreLo4(dst, vget_low_u8(dc_dup));
+ } else if (width == 8) {
+ int i = height - 1;
+ do {
+ vst1_u8(dst, vget_low_u8(dc_dup));
+ dst += stride;
+ } while (--i != 0);
+ vst1_u8(dst, vget_low_u8(dc_dup));
+ } else if (width == 16) {
+ int i = height - 1;
+ do {
+ vst1q_u8(dst, dc_dup);
+ dst += stride;
+ } while (--i != 0);
+ vst1q_u8(dst, dc_dup);
+ } else if (width == 32) {
+ int i = height - 1;
+ do {
+ vst1q_u8(dst, dc_dup);
+ vst1q_u8(dst + 16, dc_dup);
+ dst += stride;
+ } while (--i != 0);
+ vst1q_u8(dst, dc_dup);
+ vst1q_u8(dst + 16, dc_dup);
+ } else {
+ assert(width == 64);
+ int i = height - 1;
+ do {
+ vst1q_u8(dst, dc_dup);
+ vst1q_u8(dst + 16, dc_dup);
+ vst1q_u8(dst + 32, dc_dup);
+ vst1q_u8(dst + 48, dc_dup);
+ dst += stride;
+ } while (--i != 0);
+ vst1q_u8(dst, dc_dup);
+ vst1q_u8(dst + 16, dc_dup);
+ vst1q_u8(dst + 32, dc_dup);
+ vst1q_u8(dst + 48, dc_dup);
+ }
+}
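+// Note: each width branch above runs its do/while for |height| - 1 rows and
+// stores the final row afterwards, which avoids a trailing |dst| increment.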
+
+template <int width, int height>
+inline void Paeth4Or8xN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ auto* dest_u8 = static_cast<uint8_t*>(dest);
+ const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row);
+ const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column);
+
+ const uint8x8_t top_left = vdup_n_u8(top_row_u8[-1]);
+ const uint16x8_t top_left_x2 = vdupq_n_u16(top_row_u8[-1] + top_row_u8[-1]);
+ uint8x8_t top;
+ if (width == 4) {
+ top = Load4(top_row_u8);
+ } else { // width == 8
+ top = vld1_u8(top_row_u8);
+ }
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x8_t left = vdup_n_u8(left_col_u8[y]);
+
+ const uint8x8_t left_dist = vabd_u8(top, top_left);
+ const uint8x8_t top_dist = vabd_u8(left, top_left);
+ const uint16x8_t top_left_dist =
+ vabdq_u16(vaddl_u8(top, left), top_left_x2);
+
+ const uint8x8_t left_le_top = vcle_u8(left_dist, top_dist);
+ const uint8x8_t left_le_top_left =
+ vmovn_u16(vcleq_u16(vmovl_u8(left_dist), top_left_dist));
+ const uint8x8_t top_le_top_left =
+ vmovn_u16(vcleq_u16(vmovl_u8(top_dist), top_left_dist));
+
+ // if (left_dist <= top_dist && left_dist <= top_left_dist)
+ const uint8x8_t left_mask = vand_u8(left_le_top, left_le_top_left);
+ // dest[x] = left_column[y];
+ // Fill all the unused spaces with 'top'. They will be overwritten when
+ // the positions for top_left are known.
+ uint8x8_t result = vbsl_u8(left_mask, left, top);
+ // else if (top_dist <= top_left_dist)
+ // dest[x] = top_row[x];
+ // Add these values to the mask. They were already set.
+ const uint8x8_t left_or_top_mask = vorr_u8(left_mask, top_le_top_left);
+ // else
+ // dest[x] = top_left;
+ result = vbsl_u8(left_or_top_mask, result, top_left);
+
+ if (width == 4) {
+ StoreLo4(dest_u8, result);
+ } else { // width == 8
+ vst1_u8(dest_u8, result);
+ }
+ dest_u8 += stride;
+ }
+}
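+
+// For reference, the scalar Paeth rule the predictors above and below
+// implement: with base = left + top - top_left, choose whichever of left, top
+// and top_left is closest to base, preferring them in that order. Since
+// |base - left| == |top - top_left|, |base - top| == |left - top_left| and
+// |base - top_left| == |top + left - 2 * top_left|, only absolute differences
+// of the neighbors are needed, which is what the vabd/vabdq calls compute.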
+
+// Calculate X distance <= TopLeft distance and pack the resulting masks into
+// a uint8x16_t.
+inline uint8x16_t XLeTopLeft(const uint8x16_t x_dist,
+ const uint16x8_t top_left_dist_low,
+ const uint16x8_t top_left_dist_high) {
+ // TODO(johannkoenig): cle() should work with vmovn(top_left_dist) instead of
+ // using movl(x_dist).
+ const uint8x8_t x_le_top_left_low =
+ vmovn_u16(vcleq_u16(vmovl_u8(vget_low_u8(x_dist)), top_left_dist_low));
+ const uint8x8_t x_le_top_left_high =
+ vmovn_u16(vcleq_u16(vmovl_u8(vget_high_u8(x_dist)), top_left_dist_high));
+ return vcombine_u8(x_le_top_left_low, x_le_top_left_high);
+}
+
+// Select the closest values and collect them.
+inline uint8x16_t SelectPaeth(const uint8x16_t top, const uint8x16_t left,
+ const uint8x16_t top_left,
+ const uint8x16_t left_le_top,
+ const uint8x16_t left_le_top_left,
+ const uint8x16_t top_le_top_left) {
+ // if (left_dist <= top_dist && left_dist <= top_left_dist)
+ const uint8x16_t left_mask = vandq_u8(left_le_top, left_le_top_left);
+ // dest[x] = left_column[y];
+ // Fill all the unused spaces with 'top'. They will be overwritten when
+ // the positions for top_left are known.
+ uint8x16_t result = vbslq_u8(left_mask, left, top);
+ // else if (top_dist <= top_left_dist)
+ // dest[x] = top_row[x];
+ // Add these values to the mask. They were already set.
+ const uint8x16_t left_or_top_mask = vorrq_u8(left_mask, top_le_top_left);
+ // else
+ // dest[x] = top_left;
+ return vbslq_u8(left_or_top_mask, result, top_left);
+}
+
+// Generate numbered and high/low versions of top_left_dist.
+#define TOP_LEFT_DIST(num) \
+ const uint16x8_t top_left_##num##_dist_low = vabdq_u16( \
+ vaddl_u8(vget_low_u8(top[num]), vget_low_u8(left)), top_left_x2); \
+ const uint16x8_t top_left_##num##_dist_high = vabdq_u16( \
+ vaddl_u8(vget_high_u8(top[num]), vget_low_u8(left)), top_left_x2)
+
+// Generate numbered versions of XLeTopLeft with x = left.
+#define LEFT_LE_TOP_LEFT(num) \
+ const uint8x16_t left_le_top_left_##num = \
+ XLeTopLeft(left_##num##_dist, top_left_##num##_dist_low, \
+ top_left_##num##_dist_high)
+
+// Generate numbered versions of XLeTopLeft with x = top.
+#define TOP_LE_TOP_LEFT(num) \
+ const uint8x16_t top_le_top_left_##num = XLeTopLeft( \
+ top_dist, top_left_##num##_dist_low, top_left_##num##_dist_high)
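+
+// The macros above are instantiated once per 16-pixel slice of |top| in
+// Paeth16PlusxN_NEON below; e.g. TOP_LEFT_DIST(1) defines
+// |top_left_1_dist_low| and |top_left_1_dist_high| for top[1].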
+
+template <int width, int height>
+inline void Paeth16PlusxN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ auto* dest_u8 = static_cast<uint8_t*>(dest);
+ const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row);
+ const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column);
+
+ const uint8x16_t top_left = vdupq_n_u8(top_row_u8[-1]);
+ const uint16x8_t top_left_x2 = vdupq_n_u16(top_row_u8[-1] + top_row_u8[-1]);
+ uint8x16_t top[4];
+ top[0] = vld1q_u8(top_row_u8);
+ if (width > 16) {
+ top[1] = vld1q_u8(top_row_u8 + 16);
+ if (width == 64) {
+ top[2] = vld1q_u8(top_row_u8 + 32);
+ top[3] = vld1q_u8(top_row_u8 + 48);
+ }
+ }
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x16_t left = vdupq_n_u8(left_col_u8[y]);
+
+ const uint8x16_t top_dist = vabdq_u8(left, top_left);
+
+ const uint8x16_t left_0_dist = vabdq_u8(top[0], top_left);
+ TOP_LEFT_DIST(0);
+ const uint8x16_t left_0_le_top = vcleq_u8(left_0_dist, top_dist);
+ LEFT_LE_TOP_LEFT(0);
+ TOP_LE_TOP_LEFT(0);
+
+ const uint8x16_t result_0 =
+ SelectPaeth(top[0], left, top_left, left_0_le_top, left_le_top_left_0,
+ top_le_top_left_0);
+ vst1q_u8(dest_u8, result_0);
+
+ if (width > 16) {
+ const uint8x16_t left_1_dist = vabdq_u8(top[1], top_left);
+ TOP_LEFT_DIST(1);
+ const uint8x16_t left_1_le_top = vcleq_u8(left_1_dist, top_dist);
+ LEFT_LE_TOP_LEFT(1);
+ TOP_LE_TOP_LEFT(1);
+
+ const uint8x16_t result_1 =
+ SelectPaeth(top[1], left, top_left, left_1_le_top, left_le_top_left_1,
+ top_le_top_left_1);
+ vst1q_u8(dest_u8 + 16, result_1);
+
+ if (width == 64) {
+ const uint8x16_t left_2_dist = vabdq_u8(top[2], top_left);
+ TOP_LEFT_DIST(2);
+ const uint8x16_t left_2_le_top = vcleq_u8(left_2_dist, top_dist);
+ LEFT_LE_TOP_LEFT(2);
+ TOP_LE_TOP_LEFT(2);
+
+ const uint8x16_t result_2 =
+ SelectPaeth(top[2], left, top_left, left_2_le_top,
+ left_le_top_left_2, top_le_top_left_2);
+ vst1q_u8(dest_u8 + 32, result_2);
+
+ const uint8x16_t left_3_dist = vabdq_u8(top[3], top_left);
+ TOP_LEFT_DIST(3);
+ const uint8x16_t left_3_le_top = vcleq_u8(left_3_dist, top_dist);
+ LEFT_LE_TOP_LEFT(3);
+ TOP_LE_TOP_LEFT(3);
+
+ const uint8x16_t result_3 =
+ SelectPaeth(top[3], left, top_left, left_3_le_top,
+ left_le_top_left_3, top_le_top_left_3);
+ vst1q_u8(dest_u8 + 48, result_3);
+ }
+ }
+
+ dest_u8 += stride;
+ }
+}
+
+struct DcDefs {
+ DcDefs() = delete;
+
+ using _4x4 = DcPredFuncs_NEON<2, 2, DcSum_NEON, DcStore_NEON<4, 4>>;
+ using _4x8 = DcPredFuncs_NEON<2, 3, DcSum_NEON, DcStore_NEON<4, 8>>;
+ using _4x16 = DcPredFuncs_NEON<2, 4, DcSum_NEON, DcStore_NEON<4, 16>>;
+ using _8x4 = DcPredFuncs_NEON<3, 2, DcSum_NEON, DcStore_NEON<8, 4>>;
+ using _8x8 = DcPredFuncs_NEON<3, 3, DcSum_NEON, DcStore_NEON<8, 8>>;
+ using _8x16 = DcPredFuncs_NEON<3, 4, DcSum_NEON, DcStore_NEON<8, 16>>;
+ using _8x32 = DcPredFuncs_NEON<3, 5, DcSum_NEON, DcStore_NEON<8, 32>>;
+ using _16x4 = DcPredFuncs_NEON<4, 2, DcSum_NEON, DcStore_NEON<16, 4>>;
+ using _16x8 = DcPredFuncs_NEON<4, 3, DcSum_NEON, DcStore_NEON<16, 8>>;
+ using _16x16 = DcPredFuncs_NEON<4, 4, DcSum_NEON, DcStore_NEON<16, 16>>;
+ using _16x32 = DcPredFuncs_NEON<4, 5, DcSum_NEON, DcStore_NEON<16, 32>>;
+ using _16x64 = DcPredFuncs_NEON<4, 6, DcSum_NEON, DcStore_NEON<16, 64>>;
+ using _32x8 = DcPredFuncs_NEON<5, 3, DcSum_NEON, DcStore_NEON<32, 8>>;
+ using _32x16 = DcPredFuncs_NEON<5, 4, DcSum_NEON, DcStore_NEON<32, 16>>;
+ using _32x32 = DcPredFuncs_NEON<5, 5, DcSum_NEON, DcStore_NEON<32, 32>>;
+ using _32x64 = DcPredFuncs_NEON<5, 6, DcSum_NEON, DcStore_NEON<32, 64>>;
+ using _64x16 = DcPredFuncs_NEON<6, 4, DcSum_NEON, DcStore_NEON<64, 16>>;
+ using _64x32 = DcPredFuncs_NEON<6, 5, DcSum_NEON, DcStore_NEON<64, 32>>;
+ using _64x64 = DcPredFuncs_NEON<6, 6, DcSum_NEON, DcStore_NEON<64, 64>>;
+};
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ // 4x4
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] =
+ DcDefs::_4x4::DcTop;
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] =
+ DcDefs::_4x4::DcLeft;
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] =
+ DcDefs::_4x4::Dc;
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorPaeth] =
+ Paeth4Or8xN_NEON<4, 4>;
+
+ // 4x8
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] =
+ DcDefs::_4x8::DcTop;
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] =
+ DcDefs::_4x8::DcLeft;
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] =
+ DcDefs::_4x8::Dc;
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorPaeth] =
+ Paeth4Or8xN_NEON<4, 8>;
+
+ // 4x16
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] =
+ DcDefs::_4x16::DcTop;
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] =
+ DcDefs::_4x16::DcLeft;
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] =
+ DcDefs::_4x16::Dc;
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorPaeth] =
+ Paeth4Or8xN_NEON<4, 16>;
+
+ // 8x4
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] =
+ DcDefs::_8x4::DcTop;
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] =
+ DcDefs::_8x4::DcLeft;
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] =
+ DcDefs::_8x4::Dc;
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorPaeth] =
+ Paeth4Or8xN_NEON<8, 4>;
+
+ // 8x8
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] =
+ DcDefs::_8x8::DcTop;
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] =
+ DcDefs::_8x8::DcLeft;
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] =
+ DcDefs::_8x8::Dc;
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorPaeth] =
+ Paeth4Or8xN_NEON<8, 8>;
+
+ // 8x16
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] =
+ DcDefs::_8x16::DcTop;
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] =
+ DcDefs::_8x16::DcLeft;
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] =
+ DcDefs::_8x16::Dc;
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorPaeth] =
+ Paeth4Or8xN_NEON<8, 16>;
+
+ // 8x32
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] =
+ DcDefs::_8x32::DcTop;
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] =
+ DcDefs::_8x32::DcLeft;
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] =
+ DcDefs::_8x32::Dc;
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorPaeth] =
+ Paeth4Or8xN_NEON<8, 32>;
+
+ // 16x4
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] =
+ DcDefs::_16x4::DcTop;
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] =
+ DcDefs::_16x4::DcLeft;
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] =
+ DcDefs::_16x4::Dc;
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<16, 4>;
+
+ // 16x8
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] =
+ DcDefs::_16x8::DcTop;
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] =
+ DcDefs::_16x8::DcLeft;
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] =
+ DcDefs::_16x8::Dc;
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<16, 8>;
+
+ // 16x16
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] =
+ DcDefs::_16x16::DcTop;
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] =
+ DcDefs::_16x16::DcLeft;
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] =
+ DcDefs::_16x16::Dc;
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<16, 16>;
+
+ // 16x32
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] =
+ DcDefs::_16x32::DcTop;
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] =
+ DcDefs::_16x32::DcLeft;
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] =
+ DcDefs::_16x32::Dc;
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<16, 32>;
+
+ // 16x64
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] =
+ DcDefs::_16x64::DcTop;
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] =
+ DcDefs::_16x64::DcLeft;
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] =
+ DcDefs::_16x64::Dc;
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<16, 64>;
+
+ // 32x8
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] =
+ DcDefs::_32x8::DcTop;
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] =
+ DcDefs::_32x8::DcLeft;
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] =
+ DcDefs::_32x8::Dc;
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<32, 8>;
+
+ // 32x16
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] =
+ DcDefs::_32x16::DcTop;
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] =
+ DcDefs::_32x16::DcLeft;
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] =
+ DcDefs::_32x16::Dc;
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<32, 16>;
+
+ // 32x32
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] =
+ DcDefs::_32x32::DcTop;
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] =
+ DcDefs::_32x32::DcLeft;
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] =
+ DcDefs::_32x32::Dc;
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<32, 32>;
+
+ // 32x64
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] =
+ DcDefs::_32x64::DcTop;
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] =
+ DcDefs::_32x64::DcLeft;
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] =
+ DcDefs::_32x64::Dc;
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<32, 64>;
+
+ // 64x16
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] =
+ DcDefs::_64x16::DcTop;
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] =
+ DcDefs::_64x16::DcLeft;
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] =
+ DcDefs::_64x16::Dc;
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<64, 16>;
+
+ // 64x32
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] =
+ DcDefs::_64x32::DcTop;
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] =
+ DcDefs::_64x32::DcLeft;
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] =
+ DcDefs::_64x32::Dc;
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<64, 32>;
+
+ // 64x64
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] =
+ DcDefs::_64x64::DcTop;
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] =
+ DcDefs::_64x64::DcLeft;
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] =
+ DcDefs::_64x64::Dc;
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorPaeth] =
+ Paeth16PlusxN_NEON<64, 64>;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+//------------------------------------------------------------------------------
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+// Add the elements in the given vectors together but do not sum the entire
+// vector.
+inline uint16x8_t Add(const uint16x8_t val_0, const uint16x8_t val_1,
+ const uint16x8_t val_2, const uint16x8_t val_3) {
+ const uint16x8_t sum_0 = vaddq_u16(val_0, val_1);
+ const uint16x8_t sum_1 = vaddq_u16(val_2, val_3);
+ return vaddq_u16(sum_0, sum_1);
+}
+
+// Load and combine 16 uint16_t values.
+inline uint16x8_t LoadAndAdd16(const uint16_t* buf) {
+ const uint16x8_t val_0 = vld1q_u16(buf);
+ const uint16x8_t val_1 = vld1q_u16(buf + 8);
+ return vaddq_u16(val_0, val_1);
+}
+
+// Load and combine 32 uint16_t values.
+inline uint16x8_t LoadAndAdd32(const uint16_t* buf) {
+ const uint16x8_t val_0 = vld1q_u16(buf);
+ const uint16x8_t val_1 = vld1q_u16(buf + 8);
+ const uint16x8_t val_2 = vld1q_u16(buf + 16);
+ const uint16x8_t val_3 = vld1q_u16(buf + 24);
+ return Add(val_0, val_1, val_2, val_3);
+}
+
+// Load and combine 64 uint16_t values.
+inline uint16x8_t LoadAndAdd64(const uint16_t* buf) {
+ const uint16x8_t val_0 = vld1q_u16(buf);
+ const uint16x8_t val_1 = vld1q_u16(buf + 8);
+ const uint16x8_t val_2 = vld1q_u16(buf + 16);
+ const uint16x8_t val_3 = vld1q_u16(buf + 24);
+ const uint16x8_t val_4 = vld1q_u16(buf + 32);
+ const uint16x8_t val_5 = vld1q_u16(buf + 40);
+ const uint16x8_t val_6 = vld1q_u16(buf + 48);
+ const uint16x8_t val_7 = vld1q_u16(buf + 56);
+ const uint16x8_t sum_0 = Add(val_0, val_1, val_2, val_3);
+ const uint16x8_t sum_1 = Add(val_4, val_5, val_6, val_7);
+ return vaddq_u16(sum_0, sum_1);
+}
+
+// |ref_[01]| each point to 1 << |ref_[01]_size_log2| packed uint16_t values.
+// If |use_ref_1| is false then only sum |ref_0|.
+inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2,
+ const bool use_ref_1, const void* ref_1,
+ const int ref_1_size_log2) {
+ const auto* ref_0_u16 = static_cast<const uint16_t*>(ref_0);
+ const auto* ref_1_u16 = static_cast<const uint16_t*>(ref_1);
+ if (ref_0_size_log2 == 2) {
+ const uint16x4_t val_0 = vld1_u16(ref_0_u16);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 2) { // 4x4
+ const uint16x4_t val_1 = vld1_u16(ref_1_u16);
+ return Sum(vadd_u16(val_0, val_1));
+ } else if (ref_1_size_log2 == 3) { // 4x8
+ const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+ const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0);
+ return Sum(vaddq_u16(sum_0, val_1));
+ } else if (ref_1_size_log2 == 4) { // 4x16
+ const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0);
+ const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ }
+ }
+ // 4x1
+ return Sum(val_0);
+ } else if (ref_0_size_log2 == 3) {
+ const uint16x8_t val_0 = vld1q_u16(ref_0_u16);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 2) { // 8x4
+ const uint16x4_t val_1 = vld1_u16(ref_1_u16);
+ const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1);
+ return Sum(vaddq_u16(val_0, sum_1));
+ } else if (ref_1_size_log2 == 3) { // 8x8
+ const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+ return Sum(vaddq_u16(val_0, val_1));
+ } else if (ref_1_size_log2 == 4) { // 8x16
+ const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+ return Sum(vaddq_u16(val_0, sum_1));
+ } else if (ref_1_size_log2 == 5) { // 8x32
+ const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+ return Sum(vaddq_u16(val_0, sum_1));
+ }
+ }
+ // 8x1
+ return Sum(val_0);
+ } else if (ref_0_size_log2 == 4) {
+ const uint16x8_t sum_0 = LoadAndAdd16(ref_0_u16);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 2) { // 16x4
+ const uint16x4_t val_1 = vld1_u16(ref_1_u16);
+ const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 3) { // 16x8
+ const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, val_1));
+ } else if (ref_1_size_log2 == 4) { // 16x16
+ const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 5) { // 16x32
+ const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 6) { // 16x64
+ const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ }
+ }
+ // 16x1
+ return Sum(sum_0);
+ } else if (ref_0_size_log2 == 5) {
+ const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u16);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 3) { // 32x8
+ const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, val_1));
+ } else if (ref_1_size_log2 == 4) { // 32x16
+ const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 5) { // 32x32
+ const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 6) { // 32x64
+ const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ }
+ }
+ // 32x1
+ return Sum(sum_0);
+ }
+
+ assert(ref_0_size_log2 == 6);
+ const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u16);
+ if (use_ref_1) {
+ if (ref_1_size_log2 == 4) { // 64x16
+ const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 5) { // 64x32
+ const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ } else if (ref_1_size_log2 == 6) { // 64x64
+ const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
+ return Sum(vaddq_u16(sum_0, sum_1));
+ }
+ }
+ // 64x1
+ return Sum(sum_0);
+}
+
+template <int width, int height>
+inline void DcStore_NEON(void* const dest, ptrdiff_t stride,
+ const uint32x2_t dc) {
+ auto* dest_u16 = static_cast<uint16_t*>(dest);
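+  // |stride| is given in bytes; convert it to uint16_t elements for
+  // addressing.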
+ ptrdiff_t stride_u16 = stride >> 1;
+ const uint16x8_t dc_dup = vdupq_lane_u16(vreinterpret_u16_u32(dc), 0);
+ if (width == 4) {
+ int i = height - 1;
+ do {
+ vst1_u16(dest_u16, vget_low_u16(dc_dup));
+ dest_u16 += stride_u16;
+ } while (--i != 0);
+ vst1_u16(dest_u16, vget_low_u16(dc_dup));
+ } else if (width == 8) {
+ int i = height - 1;
+ do {
+ vst1q_u16(dest_u16, dc_dup);
+ dest_u16 += stride_u16;
+ } while (--i != 0);
+ vst1q_u16(dest_u16, dc_dup);
+ } else if (width == 16) {
+ int i = height - 1;
+ do {
+ vst1q_u16(dest_u16, dc_dup);
+ vst1q_u16(dest_u16 + 8, dc_dup);
+ dest_u16 += stride_u16;
+ } while (--i != 0);
+ vst1q_u16(dest_u16, dc_dup);
+ vst1q_u16(dest_u16 + 8, dc_dup);
+ } else if (width == 32) {
+ int i = height - 1;
+ do {
+ vst1q_u16(dest_u16, dc_dup);
+ vst1q_u16(dest_u16 + 8, dc_dup);
+ vst1q_u16(dest_u16 + 16, dc_dup);
+ vst1q_u16(dest_u16 + 24, dc_dup);
+ dest_u16 += stride_u16;
+ } while (--i != 0);
+ vst1q_u16(dest_u16, dc_dup);
+ vst1q_u16(dest_u16 + 8, dc_dup);
+ vst1q_u16(dest_u16 + 16, dc_dup);
+ vst1q_u16(dest_u16 + 24, dc_dup);
+ } else {
+ assert(width == 64);
+ int i = height - 1;
+ do {
+ vst1q_u16(dest_u16, dc_dup);
+ vst1q_u16(dest_u16 + 8, dc_dup);
+ vst1q_u16(dest_u16 + 16, dc_dup);
+ vst1q_u16(dest_u16 + 24, dc_dup);
+ vst1q_u16(dest_u16 + 32, dc_dup);
+ vst1q_u16(dest_u16 + 40, dc_dup);
+ vst1q_u16(dest_u16 + 48, dc_dup);
+ vst1q_u16(dest_u16 + 56, dc_dup);
+ dest_u16 += stride_u16;
+ } while (--i != 0);
+ vst1q_u16(dest_u16, dc_dup);
+ vst1q_u16(dest_u16 + 8, dc_dup);
+ vst1q_u16(dest_u16 + 16, dc_dup);
+ vst1q_u16(dest_u16 + 24, dc_dup);
+ vst1q_u16(dest_u16 + 32, dc_dup);
+ vst1q_u16(dest_u16 + 40, dc_dup);
+ vst1q_u16(dest_u16 + 48, dc_dup);
+ vst1q_u16(dest_u16 + 56, dc_dup);
+ }
+}
+
+struct DcDefs {
+ DcDefs() = delete;
+
+ using _4x4 = DcPredFuncs_NEON<2, 2, DcSum_NEON, DcStore_NEON<4, 4>>;
+ using _4x8 = DcPredFuncs_NEON<2, 3, DcSum_NEON, DcStore_NEON<4, 8>>;
+ using _4x16 = DcPredFuncs_NEON<2, 4, DcSum_NEON, DcStore_NEON<4, 16>>;
+ using _8x4 = DcPredFuncs_NEON<3, 2, DcSum_NEON, DcStore_NEON<8, 4>>;
+ using _8x8 = DcPredFuncs_NEON<3, 3, DcSum_NEON, DcStore_NEON<8, 8>>;
+ using _8x16 = DcPredFuncs_NEON<3, 4, DcSum_NEON, DcStore_NEON<8, 16>>;
+ using _8x32 = DcPredFuncs_NEON<3, 5, DcSum_NEON, DcStore_NEON<8, 32>>;
+ using _16x4 = DcPredFuncs_NEON<4, 2, DcSum_NEON, DcStore_NEON<16, 4>>;
+ using _16x8 = DcPredFuncs_NEON<4, 3, DcSum_NEON, DcStore_NEON<16, 8>>;
+ using _16x16 = DcPredFuncs_NEON<4, 4, DcSum_NEON, DcStore_NEON<16, 16>>;
+ using _16x32 = DcPredFuncs_NEON<4, 5, DcSum_NEON, DcStore_NEON<16, 32>>;
+ using _16x64 = DcPredFuncs_NEON<4, 6, DcSum_NEON, DcStore_NEON<16, 64>>;
+ using _32x8 = DcPredFuncs_NEON<5, 3, DcSum_NEON, DcStore_NEON<32, 8>>;
+ using _32x16 = DcPredFuncs_NEON<5, 4, DcSum_NEON, DcStore_NEON<32, 16>>;
+ using _32x32 = DcPredFuncs_NEON<5, 5, DcSum_NEON, DcStore_NEON<32, 32>>;
+ using _32x64 = DcPredFuncs_NEON<5, 6, DcSum_NEON, DcStore_NEON<32, 64>>;
+ using _64x16 = DcPredFuncs_NEON<6, 4, DcSum_NEON, DcStore_NEON<64, 16>>;
+ using _64x32 = DcPredFuncs_NEON<6, 5, DcSum_NEON, DcStore_NEON<64, 32>>;
+ using _64x64 = DcPredFuncs_NEON<6, 6, DcSum_NEON, DcStore_NEON<64, 64>>;
+};
+
+void Init10bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+ assert(dsp != nullptr);
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] =
+ DcDefs::_4x4::DcTop;
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcLeft] =
+ DcDefs::_4x4::DcLeft;
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDc] =
+ DcDefs::_4x4::Dc;
+
+ // 4x8
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] =
+ DcDefs::_4x8::DcTop;
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcLeft] =
+ DcDefs::_4x8::DcLeft;
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDc] =
+ DcDefs::_4x8::Dc;
+
+ // 4x16
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] =
+ DcDefs::_4x16::DcTop;
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcLeft] =
+ DcDefs::_4x16::DcLeft;
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDc] =
+ DcDefs::_4x16::Dc;
+
+ // 8x4
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] =
+ DcDefs::_8x4::DcTop;
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcLeft] =
+ DcDefs::_8x4::DcLeft;
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDc] =
+ DcDefs::_8x4::Dc;
+
+ // 8x8
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] =
+ DcDefs::_8x8::DcTop;
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcLeft] =
+ DcDefs::_8x8::DcLeft;
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDc] =
+ DcDefs::_8x8::Dc;
+
+ // 8x16
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] =
+ DcDefs::_8x16::DcTop;
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcLeft] =
+ DcDefs::_8x16::DcLeft;
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDc] =
+ DcDefs::_8x16::Dc;
+
+ // 8x32
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] =
+ DcDefs::_8x32::DcTop;
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcLeft] =
+ DcDefs::_8x32::DcLeft;
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDc] =
+ DcDefs::_8x32::Dc;
+
+ // 16x4
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] =
+ DcDefs::_16x4::DcTop;
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcLeft] =
+ DcDefs::_16x4::DcLeft;
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDc] =
+ DcDefs::_16x4::Dc;
+
+ // 16x8
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] =
+ DcDefs::_16x8::DcTop;
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcLeft] =
+ DcDefs::_16x8::DcLeft;
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDc] =
+ DcDefs::_16x8::Dc;
+
+ // 16x16
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] =
+ DcDefs::_16x16::DcTop;
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcLeft] =
+ DcDefs::_16x16::DcLeft;
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDc] =
+ DcDefs::_16x16::Dc;
+
+ // 16x32
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] =
+ DcDefs::_16x32::DcTop;
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcLeft] =
+ DcDefs::_16x32::DcLeft;
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDc] =
+ DcDefs::_16x32::Dc;
+
+ // 16x64
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] =
+ DcDefs::_16x64::DcTop;
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcLeft] =
+ DcDefs::_16x64::DcLeft;
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDc] =
+ DcDefs::_16x64::Dc;
+
+ // 32x8
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] =
+ DcDefs::_32x8::DcTop;
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcLeft] =
+ DcDefs::_32x8::DcLeft;
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDc] =
+ DcDefs::_32x8::Dc;
+
+ // 32x16
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] =
+ DcDefs::_32x16::DcTop;
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcLeft] =
+ DcDefs::_32x16::DcLeft;
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDc] =
+ DcDefs::_32x16::Dc;
+
+ // 32x32
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] =
+ DcDefs::_32x32::DcTop;
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcLeft] =
+ DcDefs::_32x32::DcLeft;
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDc] =
+ DcDefs::_32x32::Dc;
+
+ // 32x64
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] =
+ DcDefs::_32x64::DcTop;
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcLeft] =
+ DcDefs::_32x64::DcLeft;
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDc] =
+ DcDefs::_32x64::Dc;
+
+ // 64x16
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] =
+ DcDefs::_64x16::DcTop;
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcLeft] =
+ DcDefs::_64x16::DcLeft;
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDc] =
+ DcDefs::_64x16::Dc;
+
+ // 64x32
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] =
+ DcDefs::_64x32::DcTop;
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcLeft] =
+ DcDefs::_64x32::DcLeft;
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDc] =
+ DcDefs::_64x32::Dc;
+
+ // 64x64
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] =
+ DcDefs::_64x64::DcTop;
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcLeft] =
+ DcDefs::_64x64::DcLeft;
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDc] =
+ DcDefs::_64x64::Dc;
+}
+
+} // namespace
+} // namespace high_bitdepth
+#endif // LIBGAV1_MAX_BITDEPTH >= 10
+
+void IntraPredInit_NEON() {
+ low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ high_bitdepth::Init10bpp();
+#endif
+}
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void IntraPredInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/intrapred_neon.h b/src/dsp/arm/intrapred_neon.h
new file mode 100644
index 0000000..16f858c
--- /dev/null
+++ b/src/dsp/arm/intrapred_neon.h
@@ -0,0 +1,418 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*,
+// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and
+// Dsp::filter_intra_predictor; see the defines below for specifics. These
+// functions are not thread-safe.
+void IntraPredCflInit_NEON();
+void IntraPredDirectionalInit_NEON();
+void IntraPredFilterIntraInit_NEON();
+void IntraPredInit_NEON();
+void IntraPredSmoothInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+// 8 bit
+#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_NEON
+
+// 4x4
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 4x8
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 4x16
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 8x4
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 8x8
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 8x16
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 8x32
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 16x4
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 16x8
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 16x16
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 16x32
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 16x64
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+// 32x8
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 32x16
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 32x32
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_NEON
+
+// 32x64
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+// 64x16
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+// 64x32
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+// 64x64
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \
+ LIBGAV1_CPU_NEON
+
+// 10 bit
+// 4x4
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 4x8
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 4x16
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 8x4
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 8x8
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 8x16
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 8x32
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 16x4
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 16x8
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 16x16
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 16x32
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 16x64
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 32x8
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 32x16
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 32x32
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 32x64
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 64x16
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 64x32
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON
+
+// 64x64
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcLeft \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_
diff --git a/src/dsp/arm/intrapred_smooth_neon.cc b/src/dsp/arm/intrapred_smooth_neon.cc
new file mode 100644
index 0000000..abc93e8
--- /dev/null
+++ b/src/dsp/arm/intrapred_smooth_neon.cc
@@ -0,0 +1,616 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+namespace low_bitdepth {
+namespace {
+
+// Note these constants are duplicated from intrapred.cc to allow the compiler
+// to have visibility of the values. This helps reduce loads and aids in the
+// creation of the inverse weights.
+constexpr uint8_t kSmoothWeights[] = {
+ // block dimension = 4
+ 255, 149, 85, 64,
+ // block dimension = 8
+ 255, 197, 146, 105, 73, 50, 37, 32,
+ // block dimension = 16
+ 255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16,
+ // block dimension = 32
+ 255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74,
+ 66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8,
+ // block dimension = 64
+ 255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156,
+ 150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73,
+ 69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16,
+ 15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4};
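+// The weights for block dimension d begin at offset d - 4 (4 + 8 + 16 + 32 =
+// 60 = 64 - 4), which is why the functions below index the table with
+// kSmoothWeights + width - 4 and kSmoothWeights + height - 4.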
+
+// TODO(b/150459137): Keeping the intermediate values in uint16_t would allow
+// processing more values at once. At the high end, it could do 4x4 or 8x2 at a
+// time.
+inline uint16x4_t CalculatePred(const uint16x4_t weighted_top,
+ const uint16x4_t weighted_left,
+ const uint16x4_t weighted_bl,
+ const uint16x4_t weighted_tr) {
+ const uint32x4_t pred_0 = vaddl_u16(weighted_top, weighted_left);
+ const uint32x4_t pred_1 = vaddl_u16(weighted_bl, weighted_tr);
+ const uint32x4_t pred_2 = vaddq_u32(pred_0, pred_1);
+ return vrshrn_n_u32(pred_2, kSmoothWeightScale + 1);
+}
+
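+// The SMOOTH predictor blends a vertical and a horizontal interpolation. In
+// effect, with w_y = weights_y[y] and w_x = weights_x[x]:
+//   pred(x, y) = RightShiftWithRounding(
+//       w_y * top[x] + (256 - w_y) * bottom_left +
+//       w_x * left[y] + (256 - w_x) * top_right, kSmoothWeightScale + 1)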
+template <int width, int height>
+inline void Smooth4Or8xN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+ const uint8_t top_right = top[width - 1];
+ const uint8_t bottom_left = left[height - 1];
+ const uint8_t* const weights_y = kSmoothWeights + height - 4;
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ uint8x8_t top_v;
+ if (width == 4) {
+ top_v = Load4(top);
+ } else { // width == 8
+ top_v = vld1_u8(top);
+ }
+ const uint8x8_t top_right_v = vdup_n_u8(top_right);
+ const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);
+ // Over-reads for 4xN but still within the array.
+ const uint8x8_t weights_x_v = vld1_u8(kSmoothWeights + width - 4);
+ // 256 - weights = vneg_s8(weights): the 8-bit negation wraps around to
+ // 256 - w because the weights are never 0.
+ const uint8x8_t scaled_weights_x =
+ vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_x_v)));
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x8_t left_v = vdup_n_u8(left[y]);
+ const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]);
+ const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]);
+ const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v);
+
+ const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v);
+ const uint16x8_t weighted_left = vmull_u8(weights_x_v, left_v);
+ const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v);
+ const uint16x4_t dest_0 =
+ CalculatePred(vget_low_u16(weighted_top), vget_low_u16(weighted_left),
+ vget_low_u16(weighted_tr), vget_low_u16(weighted_bl));
+
+ if (width == 4) {
+ StoreLo4(dst, vmovn_u16(vcombine_u16(dest_0, dest_0)));
+ } else { // width == 8
+ const uint16x4_t dest_1 = CalculatePred(
+ vget_high_u16(weighted_top), vget_high_u16(weighted_left),
+ vget_high_u16(weighted_tr), vget_high_u16(weighted_bl));
+ vst1_u8(dst, vmovn_u16(vcombine_u16(dest_0, dest_1)));
+ }
+ dst += stride;
+ }
+}
+
+inline uint8x16_t CalculateWeightsAndPred(
+ const uint8x16_t top, const uint8x8_t left, const uint8x8_t top_right,
+ const uint8x8_t weights_y, const uint8x16_t weights_x,
+ const uint8x16_t scaled_weights_x, const uint16x8_t weighted_bl) {
+ const uint16x8_t weighted_top_low = vmull_u8(weights_y, vget_low_u8(top));
+ const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left);
+ const uint16x8_t weighted_tr_low =
+ vmull_u8(vget_low_u8(scaled_weights_x), top_right);
+ const uint16x4_t dest_0 = CalculatePred(
+ vget_low_u16(weighted_top_low), vget_low_u16(weighted_left_low),
+ vget_low_u16(weighted_tr_low), vget_low_u16(weighted_bl));
+ const uint16x4_t dest_1 = CalculatePred(
+ vget_high_u16(weighted_top_low), vget_high_u16(weighted_left_low),
+ vget_high_u16(weighted_tr_low), vget_high_u16(weighted_bl));
+ const uint8x8_t dest_0_u8 = vmovn_u16(vcombine_u16(dest_0, dest_1));
+
+ const uint16x8_t weighted_top_high = vmull_u8(weights_y, vget_high_u8(top));
+ const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left);
+ const uint16x8_t weighted_tr_high =
+ vmull_u8(vget_high_u8(scaled_weights_x), top_right);
+ const uint16x4_t dest_2 = CalculatePred(
+ vget_low_u16(weighted_top_high), vget_low_u16(weighted_left_high),
+ vget_low_u16(weighted_tr_high), vget_low_u16(weighted_bl));
+ const uint16x4_t dest_3 = CalculatePred(
+ vget_high_u16(weighted_top_high), vget_high_u16(weighted_left_high),
+ vget_high_u16(weighted_tr_high), vget_high_u16(weighted_bl));
+ const uint8x8_t dest_1_u8 = vmovn_u16(vcombine_u16(dest_2, dest_3));
+
+ return vcombine_u8(dest_0_u8, dest_1_u8);
+}
+
+template <int width, int height>
+inline void Smooth16PlusxN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+ const uint8_t top_right = top[width - 1];
+ const uint8_t bottom_left = left[height - 1];
+ const uint8_t* const weights_y = kSmoothWeights + height - 4;
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ uint8x16_t top_v[4];
+ top_v[0] = vld1q_u8(top);
+ if (width > 16) {
+ top_v[1] = vld1q_u8(top + 16);
+ if (width == 64) {
+ top_v[2] = vld1q_u8(top + 32);
+ top_v[3] = vld1q_u8(top + 48);
+ }
+ }
+
+ const uint8x8_t top_right_v = vdup_n_u8(top_right);
+ const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);
+
+ // TODO(johannkoenig): Consider re-reading top_v and weights_x_v in the loop.
+ // This currently has a performance slope similar to Paeth so it does not
+ // appear to be register bound for arm64.
+ uint8x16_t weights_x_v[4];
+ weights_x_v[0] = vld1q_u8(kSmoothWeights + width - 4);
+ if (width > 16) {
+ weights_x_v[1] = vld1q_u8(kSmoothWeights + width + 16 - 4);
+ if (width == 64) {
+ weights_x_v[2] = vld1q_u8(kSmoothWeights + width + 32 - 4);
+ weights_x_v[3] = vld1q_u8(kSmoothWeights + width + 48 - 4);
+ }
+ }
+
+ uint8x16_t scaled_weights_x[4];
+ scaled_weights_x[0] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[0])));
+ if (width > 16) {
+ scaled_weights_x[1] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[1])));
+ if (width == 64) {
+ scaled_weights_x[2] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[2])));
+ scaled_weights_x[3] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x_v[3])));
+ }
+ }
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x8_t left_v = vdup_n_u8(left[y]);
+ const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]);
+ const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]);
+ const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v);
+
+ vst1q_u8(dst, CalculateWeightsAndPred(top_v[0], left_v, top_right_v,
+ weights_y_v, weights_x_v[0],
+ scaled_weights_x[0], weighted_bl));
+
+ if (width > 16) {
+ vst1q_u8(dst + 16, CalculateWeightsAndPred(
+ top_v[1], left_v, top_right_v, weights_y_v,
+ weights_x_v[1], scaled_weights_x[1], weighted_bl));
+ if (width == 64) {
+ vst1q_u8(dst + 32,
+ CalculateWeightsAndPred(top_v[2], left_v, top_right_v,
+ weights_y_v, weights_x_v[2],
+ scaled_weights_x[2], weighted_bl));
+ vst1q_u8(dst + 48,
+ CalculateWeightsAndPred(top_v[3], left_v, top_right_v,
+ weights_y_v, weights_x_v[3],
+ scaled_weights_x[3], weighted_bl));
+ }
+ }
+
+ dst += stride;
+ }
+}
+
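+// The SMOOTH_V predictor keeps only the vertical interpolation. In effect:
+//   pred(x, y) = RightShiftWithRounding(
+//       weights_y[y] * top[x] + (256 - weights_y[y]) * bottom_left,
+//       kSmoothWeightScale)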
+template <int width, int height>
+inline void SmoothVertical4Or8xN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+ const uint8_t bottom_left = left[height - 1];
+ const uint8_t* const weights_y = kSmoothWeights + height - 4;
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ uint8x8_t top_v;
+ if (width == 4) {
+ top_v = Load4(top);
+ } else { // width == 8
+ top_v = vld1_u8(top);
+ }
+
+ const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]);
+ const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]);
+
+ const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v);
+ const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v);
+ const uint16x8_t pred = vaddq_u16(weighted_top, weighted_bl);
+ const uint8x8_t pred_scaled = vrshrn_n_u16(pred, kSmoothWeightScale);
+
+ if (width == 4) {
+ StoreLo4(dst, pred_scaled);
+ } else { // width == 8
+ vst1_u8(dst, pred_scaled);
+ }
+ dst += stride;
+ }
+}
+
+inline uint8x16_t CalculateVerticalWeightsAndPred(
+ const uint8x16_t top, const uint8x8_t weights_y,
+ const uint16x8_t weighted_bl) {
+ const uint16x8_t weighted_top_low = vmull_u8(weights_y, vget_low_u8(top));
+ const uint16x8_t weighted_top_high = vmull_u8(weights_y, vget_high_u8(top));
+ const uint16x8_t pred_low = vaddq_u16(weighted_top_low, weighted_bl);
+ const uint16x8_t pred_high = vaddq_u16(weighted_top_high, weighted_bl);
+ const uint8x8_t pred_scaled_low = vrshrn_n_u16(pred_low, kSmoothWeightScale);
+ const uint8x8_t pred_scaled_high =
+ vrshrn_n_u16(pred_high, kSmoothWeightScale);
+ return vcombine_u8(pred_scaled_low, pred_scaled_high);
+}
+
+template <int width, int height>
+inline void SmoothVertical16PlusxN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+ const uint8_t bottom_left = left[height - 1];
+ const uint8_t* const weights_y = kSmoothWeights + height - 4;
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ uint8x16_t top_v[4];
+ top_v[0] = vld1q_u8(top);
+ if (width > 16) {
+ top_v[1] = vld1q_u8(top + 16);
+ if (width == 64) {
+ top_v[2] = vld1q_u8(top + 32);
+ top_v[3] = vld1q_u8(top + 48);
+ }
+ }
+
+ const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]);
+ const uint8x8_t scaled_weights_y = vdup_n_u8(256 - weights_y[y]);
+ const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v);
+
+ const uint8x16_t pred_0 =
+ CalculateVerticalWeightsAndPred(top_v[0], weights_y_v, weighted_bl);
+ vst1q_u8(dst, pred_0);
+
+ if (width > 16) {
+ const uint8x16_t pred_1 =
+ CalculateVerticalWeightsAndPred(top_v[1], weights_y_v, weighted_bl);
+ vst1q_u8(dst + 16, pred_1);
+
+ if (width == 64) {
+ const uint8x16_t pred_2 =
+ CalculateVerticalWeightsAndPred(top_v[2], weights_y_v, weighted_bl);
+ vst1q_u8(dst + 32, pred_2);
+
+ const uint8x16_t pred_3 =
+ CalculateVerticalWeightsAndPred(top_v[3], weights_y_v, weighted_bl);
+ vst1q_u8(dst + 48, pred_3);
+ }
+ }
+
+ dst += stride;
+ }
+}
+
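+// The SMOOTH_H predictor keeps only the horizontal interpolation. In effect:
+//   pred(x, y) = RightShiftWithRounding(
+//       weights_x[x] * left[y] + (256 - weights_x[x]) * top_right,
+//       kSmoothWeightScale)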
+template <int width, int height>
+inline void SmoothHorizontal4Or8xN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+ const uint8_t top_right = top[width - 1];
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ const uint8x8_t top_right_v = vdup_n_u8(top_right);
+ // Over-reads for 4xN but still within the array.
+ const uint8x8_t weights_x = vld1_u8(kSmoothWeights + width - 4);
+ // 256 - weights = vneg_s8(weights): the 8-bit negation wraps around to
+ // 256 - w because the weights are never 0.
+ const uint8x8_t scaled_weights_x =
+ vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_x)));
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x8_t left_v = vdup_n_u8(left[y]);
+
+ const uint16x8_t weighted_left = vmull_u8(weights_x, left_v);
+ const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v);
+ const uint16x8_t pred = vaddq_u16(weighted_left, weighted_tr);
+ const uint8x8_t pred_scaled = vrshrn_n_u16(pred, kSmoothWeightScale);
+
+ if (width == 4) {
+ StoreLo4(dst, pred_scaled);
+ } else { // width == 8
+ vst1_u8(dst, pred_scaled);
+ }
+ dst += stride;
+ }
+}
+
+inline uint8x16_t CalculateHorizontalWeightsAndPred(
+ const uint8x8_t left, const uint8x8_t top_right, const uint8x16_t weights_x,
+ const uint8x16_t scaled_weights_x) {
+ const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left);
+ const uint16x8_t weighted_tr_low =
+ vmull_u8(vget_low_u8(scaled_weights_x), top_right);
+ const uint16x8_t pred_low = vaddq_u16(weighted_left_low, weighted_tr_low);
+ const uint8x8_t pred_scaled_low = vrshrn_n_u16(pred_low, kSmoothWeightScale);
+
+ const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left);
+ const uint16x8_t weighted_tr_high =
+ vmull_u8(vget_high_u8(scaled_weights_x), top_right);
+ const uint16x8_t pred_high = vaddq_u16(weighted_left_high, weighted_tr_high);
+ const uint8x8_t pred_scaled_high =
+ vrshrn_n_u16(pred_high, kSmoothWeightScale);
+
+ return vcombine_u8(pred_scaled_low, pred_scaled_high);
+}
+
+template <int width, int height>
+inline void SmoothHorizontal16PlusxN_NEON(void* const dest, ptrdiff_t stride,
+ const void* const top_row,
+ const void* const left_column) {
+ const uint8_t* const top = static_cast<const uint8_t*>(top_row);
+ const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+ const uint8_t top_right = top[width - 1];
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ const uint8x8_t top_right_v = vdup_n_u8(top_right);
+
+ uint8x16_t weights_x[4];
+ weights_x[0] = vld1q_u8(kSmoothWeights + width - 4);
+ if (width > 16) {
+ weights_x[1] = vld1q_u8(kSmoothWeights + width + 16 - 4);
+ if (width == 64) {
+ weights_x[2] = vld1q_u8(kSmoothWeights + width + 32 - 4);
+ weights_x[3] = vld1q_u8(kSmoothWeights + width + 48 - 4);
+ }
+ }
+
+ uint8x16_t scaled_weights_x[4];
+ scaled_weights_x[0] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[0])));
+ if (width > 16) {
+ scaled_weights_x[1] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[1])));
+ if (width == 64) {
+ scaled_weights_x[2] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[2])));
+ scaled_weights_x[3] =
+ vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(weights_x[3])));
+ }
+ }
+
+ for (int y = 0; y < height; ++y) {
+ const uint8x8_t left_v = vdup_n_u8(left[y]);
+
+ const uint8x16_t pred_0 = CalculateHorizontalWeightsAndPred(
+ left_v, top_right_v, weights_x[0], scaled_weights_x[0]);
+ vst1q_u8(dst, pred_0);
+
+ if (width > 16) {
+ const uint8x16_t pred_1 = CalculateHorizontalWeightsAndPred(
+ left_v, top_right_v, weights_x[1], scaled_weights_x[1]);
+ vst1q_u8(dst + 16, pred_1);
+
+ if (width == 64) {
+ const uint8x16_t pred_2 = CalculateHorizontalWeightsAndPred(
+ left_v, top_right_v, weights_x[2], scaled_weights_x[2]);
+ vst1q_u8(dst + 32, pred_2);
+
+ const uint8x16_t pred_3 = CalculateHorizontalWeightsAndPred(
+ left_v, top_right_v, weights_x[3], scaled_weights_x[3]);
+ vst1q_u8(dst + 48, pred_3);
+ }
+ }
+ dst += stride;
+ }
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ // 4x4
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] =
+ Smooth4Or8xN_NEON<4, 4>;
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] =
+ SmoothVertical4Or8xN_NEON<4, 4>;
+ dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal4Or8xN_NEON<4, 4>;
+
+ // 4x8
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] =
+ Smooth4Or8xN_NEON<4, 8>;
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] =
+ SmoothVertical4Or8xN_NEON<4, 8>;
+ dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal4Or8xN_NEON<4, 8>;
+
+ // 4x16
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] =
+ Smooth4Or8xN_NEON<4, 16>;
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] =
+ SmoothVertical4Or8xN_NEON<4, 16>;
+ dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal4Or8xN_NEON<4, 16>;
+
+ // 8x4
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] =
+ Smooth4Or8xN_NEON<8, 4>;
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] =
+ SmoothVertical4Or8xN_NEON<8, 4>;
+ dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal4Or8xN_NEON<8, 4>;
+
+ // 8x8
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] =
+ Smooth4Or8xN_NEON<8, 8>;
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] =
+ SmoothVertical4Or8xN_NEON<8, 8>;
+ dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal4Or8xN_NEON<8, 8>;
+
+ // 8x16
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] =
+ Smooth4Or8xN_NEON<8, 16>;
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] =
+ SmoothVertical4Or8xN_NEON<8, 16>;
+ dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal4Or8xN_NEON<8, 16>;
+
+ // 8x32
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] =
+ Smooth4Or8xN_NEON<8, 32>;
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] =
+ SmoothVertical4Or8xN_NEON<8, 32>;
+ dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal4Or8xN_NEON<8, 32>;
+
+ // 16x4
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<16, 4>;
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<16, 4>;
+ dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<16, 4>;
+
+ // 16x8
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<16, 8>;
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<16, 8>;
+ dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<16, 8>;
+
+ // 16x16
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<16, 16>;
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<16, 16>;
+ dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<16, 16>;
+
+ // 16x32
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<16, 32>;
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<16, 32>;
+ dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<16, 32>;
+
+ // 16x64
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<16, 64>;
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<16, 64>;
+ dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<16, 64>;
+
+ // 32x8
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<32, 8>;
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<32, 8>;
+ dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<32, 8>;
+
+ // 32x16
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<32, 16>;
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<32, 16>;
+ dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<32, 16>;
+
+ // 32x32
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<32, 32>;
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<32, 32>;
+ dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<32, 32>;
+
+ // 32x64
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<32, 64>;
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<32, 64>;
+ dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<32, 64>;
+
+ // 64x16
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<64, 16>;
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<64, 16>;
+ dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<64, 16>;
+
+ // 64x32
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<64, 32>;
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<64, 32>;
+ dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<64, 32>;
+
+ // 64x64
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmooth] =
+ Smooth16PlusxN_NEON<64, 64>;
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothVertical] =
+ SmoothVertical16PlusxN_NEON<64, 64>;
+ dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothHorizontal] =
+ SmoothHorizontal16PlusxN_NEON<64, 64>;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void IntraPredSmoothInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void IntraPredSmoothInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/inverse_transform_neon.cc b/src/dsp/arm/inverse_transform_neon.cc
new file mode 100644
index 0000000..072991a
--- /dev/null
+++ b/src/dsp/arm/inverse_transform_neon.cc
@@ -0,0 +1,3128 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/inverse_transform.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Include the constants and utility functions inside the anonymous namespace.
+#include "src/dsp/inverse_transform.inc"
+
+//------------------------------------------------------------------------------
+
+// TODO(slavarnway): Move transpose functions to transpose_neon.h or
+// common_neon.h.
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x4(const int16x8_t in[4],
+ int16x8_t out[4]) {
+ // Swap 16 bit elements. Goes from:
+ // a0: 00 01 02 03
+ // a1: 10 11 12 13
+ // a2: 20 21 22 23
+ // a3: 30 31 32 33
+ // to:
+ // b0.val[0]: 00 10 02 12
+ // b0.val[1]: 01 11 03 13
+ // b1.val[0]: 20 30 22 32
+ // b1.val[1]: 21 31 23 33
+ const int16x4_t a0 = vget_low_s16(in[0]);
+ const int16x4_t a1 = vget_low_s16(in[1]);
+ const int16x4_t a2 = vget_low_s16(in[2]);
+ const int16x4_t a3 = vget_low_s16(in[3]);
+
+ const int16x4x2_t b0 = vtrn_s16(a0, a1);
+ const int16x4x2_t b1 = vtrn_s16(a2, a3);
+
+ // Swap 32 bit elements resulting in:
+ // c0.val[0]: 00 10 20 30 04 14 24 34
+ // c0.val[1]: 02 12 22 32 06 16 26 36
+ // c1.val[0]: 01 11 21 31 05 15 25 35
+ // c1.val[1]: 03 13 23 33 07 17 27 37
+ const int32x2x2_t c0 = vtrn_s32(vreinterpret_s32_s16(b0.val[0]),
+ vreinterpret_s32_s16(b1.val[0]));
+ const int32x2x2_t c1 = vtrn_s32(vreinterpret_s32_s16(b0.val[1]),
+ vreinterpret_s32_s16(b1.val[1]));
+
+ const int16x4_t d0 = vreinterpret_s16_s32(c0.val[0]);
+ const int16x4_t d1 = vreinterpret_s16_s32(c1.val[0]);
+ const int16x4_t d2 = vreinterpret_s16_s32(c0.val[1]);
+ const int16x4_t d3 = vreinterpret_s16_s32(c1.val[1]);
+
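+ // Only the low 64 bits of each output register are meaningful; the high
+ // halves duplicate the low halves so the values stay in 128-bit registers.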
+ out[0] = vcombine_s16(d0, d0);
+ out[1] = vcombine_s16(d1, d1);
+ out[2] = vcombine_s16(d2, d2);
+ out[3] = vcombine_s16(d3, d3);
+}
+
+// Note this is only used in the final stage of Dct32/64 and Adst16 as the
+// in-place version causes additional stack usage with clang.
+LIBGAV1_ALWAYS_INLINE void Transpose8x8(const int16x8_t in[8],
+ int16x8_t out[8]) {
+ // Swap 16 bit elements. Goes from:
+ // a0: 00 01 02 03 04 05 06 07
+ // a1: 10 11 12 13 14 15 16 17
+ // a2: 20 21 22 23 24 25 26 27
+ // a3: 30 31 32 33 34 35 36 37
+ // a4: 40 41 42 43 44 45 46 47
+ // a5: 50 51 52 53 54 55 56 57
+ // a6: 60 61 62 63 64 65 66 67
+ // a7: 70 71 72 73 74 75 76 77
+ // to:
+ // b0.val[0]: 00 10 02 12 04 14 06 16
+ // b0.val[1]: 01 11 03 13 05 15 07 17
+ // b1.val[0]: 20 30 22 32 24 34 26 36
+ // b1.val[1]: 21 31 23 33 25 35 27 37
+ // b2.val[0]: 40 50 42 52 44 54 46 56
+ // b2.val[1]: 41 51 43 53 45 55 47 57
+ // b3.val[0]: 60 70 62 72 64 74 66 76
+ // b3.val[1]: 61 71 63 73 65 75 67 77
+
+ const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]);
+ const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]);
+ const int16x8x2_t b2 = vtrnq_s16(in[4], in[5]);
+ const int16x8x2_t b3 = vtrnq_s16(in[6], in[7]);
+
+ // Swap 32 bit elements resulting in:
+ // c0.val[0]: 00 10 20 30 04 14 24 34
+ // c0.val[1]: 02 12 22 32 06 16 26 36
+ // c1.val[0]: 01 11 21 31 05 15 25 35
+ // c1.val[1]: 03 13 23 33 07 17 27 37
+ // c2.val[0]: 40 50 60 70 44 54 64 74
+ // c2.val[1]: 42 52 62 72 46 56 66 76
+ // c3.val[0]: 41 51 61 71 45 55 65 75
+ // c3.val[1]: 43 53 63 73 47 57 67 77
+
+ const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
+ vreinterpretq_s32_s16(b1.val[0]));
+ const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
+ vreinterpretq_s32_s16(b1.val[1]));
+ const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
+ vreinterpretq_s32_s16(b3.val[0]));
+ const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
+ vreinterpretq_s32_s16(b3.val[1]));
+
+ // Swap 64 bit elements resulting in:
+ // d0.val[0]: 00 10 20 30 40 50 60 70
+ // d0.val[1]: 04 14 24 34 44 54 64 74
+ // d1.val[0]: 01 11 21 31 41 51 61 71
+ // d1.val[1]: 05 15 25 35 45 55 65 75
+ // d2.val[0]: 02 12 22 32 42 52 62 72
+ // d2.val[1]: 06 16 26 36 46 56 66 76
+ // d3.val[0]: 03 13 23 33 43 53 63 73
+ // d3.val[1]: 07 17 27 37 47 57 67 77
+ const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]);
+ const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]);
+ const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]);
+ const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]);
+
+ out[0] = d0.val[0];
+ out[1] = d1.val[0];
+ out[2] = d2.val[0];
+ out[3] = d3.val[0];
+ out[4] = d0.val[1];
+ out[5] = d1.val[1];
+ out[6] = d2.val[1];
+ out[7] = d3.val[1];
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const uint16x8_t in[8],
+ uint16x8_t out[4]) {
+ // Swap 16 bit elements. Goes from:
+ // a0: 00 01 02 03
+ // a1: 10 11 12 13
+ // a2: 20 21 22 23
+ // a3: 30 31 32 33
+ // a4: 40 41 42 43
+ // a5: 50 51 52 53
+ // a6: 60 61 62 63
+ // a7: 70 71 72 73
+ // to:
+ // b0.val[0]: 00 10 02 12
+ // b0.val[1]: 01 11 03 13
+ // b1.val[0]: 20 30 22 32
+ // b1.val[1]: 21 31 23 33
+ // b2.val[0]: 40 50 42 52
+ // b2.val[1]: 41 51 43 53
+ // b3.val[0]: 60 70 62 72
+ // b3.val[1]: 61 71 63 73
+
+ uint16x4x2_t b0 = vtrn_u16(vget_low_u16(in[0]), vget_low_u16(in[1]));
+ uint16x4x2_t b1 = vtrn_u16(vget_low_u16(in[2]), vget_low_u16(in[3]));
+ uint16x4x2_t b2 = vtrn_u16(vget_low_u16(in[4]), vget_low_u16(in[5]));
+ uint16x4x2_t b3 = vtrn_u16(vget_low_u16(in[6]), vget_low_u16(in[7]));
+
+ // Swap 32 bit elements resulting in:
+ // c0.val[0]: 00 10 20 30
+ // c0.val[1]: 02 12 22 32
+ // c1.val[0]: 01 11 21 31
+ // c1.val[1]: 03 13 23 33
+ // c2.val[0]: 40 50 60 70
+ // c2.val[1]: 42 52 62 72
+ // c3.val[0]: 41 51 61 71
+ // c3.val[1]: 43 53 63 73
+
+ uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]),
+ vreinterpret_u32_u16(b1.val[0]));
+ uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]),
+ vreinterpret_u32_u16(b1.val[1]));
+ uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b2.val[0]),
+ vreinterpret_u32_u16(b3.val[0]));
+ uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b2.val[1]),
+ vreinterpret_u32_u16(b3.val[1]));
+
+ // Swap 64 bit elements resulting in:
+ // o0: 00 10 20 30 40 50 60 70
+ // o1: 01 11 21 31 41 51 61 71
+ // o2: 02 12 22 32 42 52 62 72
+ // o3: 03 13 23 33 43 53 63 73
+
+ out[0] = vcombine_u16(vreinterpret_u16_u32(c0.val[0]),
+ vreinterpret_u16_u32(c2.val[0]));
+ out[1] = vcombine_u16(vreinterpret_u16_u32(c1.val[0]),
+ vreinterpret_u16_u32(c3.val[0]));
+ out[2] = vcombine_u16(vreinterpret_u16_u32(c0.val[1]),
+ vreinterpret_u16_u32(c2.val[1]));
+ out[3] = vcombine_u16(vreinterpret_u16_u32(c1.val[1]),
+ vreinterpret_u16_u32(c3.val[1]));
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const int16x8_t in[8],
+ int16x8_t out[4]) {
+ Transpose4x8To8x4(reinterpret_cast<const uint16x8_t*>(in),
+ reinterpret_cast<uint16x8_t*>(out));
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose8x4To4x8(const int16x8_t in[4],
+ int16x8_t out[8]) {
+ // Swap 16 bit elements. Goes from:
+ // a0: 00 01 02 03 04 05 06 07
+ // a1: 10 11 12 13 14 15 16 17
+ // a2: 20 21 22 23 24 25 26 27
+ // a3: 30 31 32 33 34 35 36 37
+ // to:
+ // b0.val[0]: 00 10 02 12 04 14 06 16
+ // b0.val[1]: 01 11 03 13 05 15 07 17
+ // b1.val[0]: 20 30 22 32 24 34 26 36
+ // b1.val[1]: 21 31 23 33 25 35 27 37
+ const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]);
+ const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]);
+
+ // Swap 32 bit elements resulting in:
+ // c0.val[0]: 00 10 20 30 04 14 24 34
+ // c0.val[1]: 02 12 22 32 06 16 26 36
+ // c1.val[0]: 01 11 21 31 05 15 25 35
+ // c1.val[1]: 03 13 23 33 07 17 27 37
+ const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
+ vreinterpretq_s32_s16(b1.val[0]));
+ const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
+ vreinterpretq_s32_s16(b1.val[1]));
+
+ // The upper 8 bytes are don't cares.
+ // out[0]: 00 10 20 30 04 14 24 34
+ // out[1]: 01 11 21 31 05 15 25 35
+ // out[2]: 02 12 22 32 06 16 26 36
+ // out[3]: 03 13 23 33 07 17 27 37
+ // out[4]: 04 14 24 34 04 14 24 34
+ // out[5]: 05 15 25 35 05 15 25 35
+ // out[6]: 06 16 26 36 06 16 26 36
+ // out[7]: 07 17 27 37 07 17 27 37
+ out[0] = vreinterpretq_s16_s32(c0.val[0]);
+ out[1] = vreinterpretq_s16_s32(c1.val[0]);
+ out[2] = vreinterpretq_s16_s32(c0.val[1]);
+ out[3] = vreinterpretq_s16_s32(c1.val[1]);
+ out[4] = vreinterpretq_s16_s32(
+ vcombine_s32(vget_high_s32(c0.val[0]), vget_high_s32(c0.val[0])));
+ out[5] = vreinterpretq_s16_s32(
+ vcombine_s32(vget_high_s32(c1.val[0]), vget_high_s32(c1.val[0])));
+ out[6] = vreinterpretq_s16_s32(
+ vcombine_s32(vget_high_s32(c0.val[1]), vget_high_s32(c0.val[1])));
+ out[7] = vreinterpretq_s16_s32(
+ vcombine_s32(vget_high_s32(c1.val[1]), vget_high_s32(c1.val[1])));
+}
+
+//------------------------------------------------------------------------------
+template <int store_width, int store_count>
+LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx,
+ const int16x8_t* const s) {
+ assert(store_count % 4 == 0);
+ assert(store_width == 8 || store_width == 16);
+ // NOTE: It is expected that the compiler will unroll these loops.
+ if (store_width == 16) {
+ for (int i = 0; i < store_count; i += 4) {
+ vst1q_s16(&dst[i * stride + idx], (s[i]));
+ vst1q_s16(&dst[(i + 1) * stride + idx], (s[i + 1]));
+ vst1q_s16(&dst[(i + 2) * stride + idx], (s[i + 2]));
+ vst1q_s16(&dst[(i + 3) * stride + idx], (s[i + 3]));
+ }
+ } else {
+ // store_width == 8
+ for (int i = 0; i < store_count; i += 4) {
+ vst1_s16(&dst[i * stride + idx], vget_low_s16(s[i]));
+ vst1_s16(&dst[(i + 1) * stride + idx], vget_low_s16(s[i + 1]));
+ vst1_s16(&dst[(i + 2) * stride + idx], vget_low_s16(s[i + 2]));
+ vst1_s16(&dst[(i + 3) * stride + idx], vget_low_s16(s[i + 3]));
+ }
+ }
+}
+
+template <int load_width, int load_count>
+LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* src, int32_t stride,
+ int32_t idx, int16x8_t* x) {
+ assert(load_count % 4 == 0);
+ assert(load_width == 8 || load_width == 16);
+ // NOTE: It is expected that the compiler will unroll these loops.
+ if (load_width == 16) {
+ for (int i = 0; i < load_count; i += 4) {
+ x[i] = vld1q_s16(&src[i * stride + idx]);
+ x[i + 1] = vld1q_s16(&src[(i + 1) * stride + idx]);
+ x[i + 2] = vld1q_s16(&src[(i + 2) * stride + idx]);
+ x[i + 3] = vld1q_s16(&src[(i + 3) * stride + idx]);
+ }
+ } else {
+ // load_width == 8
+ const int64x2_t zero = vdupq_n_s64(0);
+ for (int i = 0; i < load_count; i += 4) {
+ // The src buffer is aligned to 32 bytes. Each load will always be
+ // 8-byte aligned.
+ x[i] = vreinterpretq_s16_s64(vld1q_lane_s64(
+ reinterpret_cast<const int64_t*>(&src[i * stride + idx]), zero, 0));
+ x[i + 1] = vreinterpretq_s16_s64(vld1q_lane_s64(
+ reinterpret_cast<const int64_t*>(&src[(i + 1) * stride + idx]), zero,
+ 0));
+ x[i + 2] = vreinterpretq_s16_s64(vld1q_lane_s64(
+ reinterpret_cast<const int64_t*>(&src[(i + 2) * stride + idx]), zero,
+ 0));
+ x[i + 3] = vreinterpretq_s16_s64(vld1q_lane_s64(
+ reinterpret_cast<const int64_t*>(&src[(i + 3) * stride + idx]), zero,
+ 0));
+ }
+ }
+}
+
+// Butterfly rotate 4 values.
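+// The rotation is computed in 12-bit fixed point (Cos128() and Sin128() return
+// the cosine and sine scaled by 2^12):
+//   x = (a * cos128 - b * sin128 + 2^11) >> 12
+//   y = (a * sin128 + b * cos128 + 2^11) >> 12
+// When |flip| is set the two outputs are swapped.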
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int16x8_t* a, int16x8_t* b,
+ const int angle,
+ const bool flip) {
+ const int16_t cos128 = Cos128(angle);
+ const int16_t sin128 = Sin128(angle);
+ const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128);
+ const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128);
+ const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128);
+ const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128);
+ const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
+ const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
+ const int16x8_t x = vcombine_s16(x1, x1);
+ const int16x8_t y = vcombine_s16(y1, y1);
+ if (flip) {
+ *a = y;
+ *b = x;
+ } else {
+ *a = x;
+ *b = y;
+ }
+}
+
+// Butterfly rotate 8 values.
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(int16x8_t* a, int16x8_t* b,
+ const int angle,
+ const bool flip) {
+ const int16_t cos128 = Cos128(angle);
+ const int16_t sin128 = Sin128(angle);
+ const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128);
+ const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128);
+ const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128);
+ const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128);
+ const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
+ const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
+
+ const int32x4_t acc_x_hi = vmull_n_s16(vget_high_s16(*a), cos128);
+ const int32x4_t acc_y_hi = vmull_n_s16(vget_high_s16(*a), sin128);
+ const int32x4_t x0_hi = vmlsl_n_s16(acc_x_hi, vget_high_s16(*b), sin128);
+ const int32x4_t y0_hi = vmlal_n_s16(acc_y_hi, vget_high_s16(*b), cos128);
+ const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
+ const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12);
+
+ const int16x8_t x = vcombine_s16(x1, x1_hi);
+ const int16x8_t y = vcombine_s16(y1, y1_hi);
+ if (flip) {
+ *a = y;
+ *b = x;
+ } else {
+ *a = x;
+ *b = y;
+ }
+}
+
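+// Specializations for the case where one of the two inputs is known to be
+// zero, so only the remaining input has to be scaled. Pre-shifting the
+// constant left by 3 makes vqrdmulhq_n_s16, which computes
+// (2 * a * b + 2^15) >> 16, equivalent to the (value + 2^11) >> 12 rounding
+// used by the general butterflies above.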
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int16x8_t* a,
+ int16x8_t* b,
+ const int angle,
+ const bool flip) {
+ const int16_t cos128 = Cos128(angle);
+ const int16_t sin128 = Sin128(angle);
+ // For this function, the max value returned by Sin128() is 4091, which fits
+ // inside 12 bits. This leaves room for the sign bit and the 3-bit left
+ // shift.
+ assert(sin128 <= 0xfff);
+ const int16x8_t x = vqrdmulhq_n_s16(*b, -sin128 << 3);
+ const int16x8_t y = vqrdmulhq_n_s16(*b, cos128 << 3);
+ if (flip) {
+ *a = y;
+ *b = x;
+ } else {
+ *a = x;
+ *b = y;
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int16x8_t* a,
+ int16x8_t* b,
+ const int angle,
+ const bool flip) {
+ const int16_t cos128 = Cos128(angle);
+ const int16_t sin128 = Sin128(angle);
+ const int16x8_t x = vqrdmulhq_n_s16(*a, cos128 << 3);
+ const int16x8_t y = vqrdmulhq_n_s16(*a, sin128 << 3);
+ if (flip) {
+ *a = y;
+ *b = x;
+ } else {
+ *a = x;
+ *b = y;
+ }
+}
+
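+// In-place saturating Hadamard rotation: (a, b) becomes (a + b, a - b), or
+// (b - a, b + a) when |flip| is set.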
+LIBGAV1_ALWAYS_INLINE void HadamardRotation(int16x8_t* a, int16x8_t* b,
+ bool flip) {
+ int16x8_t x, y;
+ if (flip) {
+ y = vqaddq_s16(*b, *a);
+ x = vqsubq_s16(*b, *a);
+ } else {
+ x = vqaddq_s16(*a, *b);
+ y = vqsubq_s16(*a, *b);
+ }
+ *a = x;
+ *b = y;
+}
+
+using ButterflyRotationFunc = void (*)(int16x8_t* a, int16x8_t* b, int angle,
+ bool flip);
+
+//------------------------------------------------------------------------------
+// Discrete Cosine Transforms (DCT).
+
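+// Fast path for blocks where only the DC coefficient needs to be processed:
+// the 1-D DCT then reduces to scaling the DC value by cos128(32) (1/sqrt(2) in
+// 12-bit fixed point) and replicating the result. DctDcOnly handles the row
+// pass, including the optional rounding and the row shift; DctDcOnlyColumn
+// handles the column pass.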
+template <int width>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height,
+ bool should_round, int row_shift) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ const int16x8_t v_src = vdupq_n_s16(dst[0]);
+ const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
+ const int16x8_t v_src_round =
+ vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+ const int16x8_t s0 = vbslq_s16(v_mask, v_src_round, v_src);
+ const int16_t cos128 = Cos128(32);
+ const int16x8_t xy = vqrdmulhq_n_s16(s0, cos128 << 3);
+ // vqrshlq_s16 will shift right if the shift value is negative.
+ const int16x8_t xy_shifted = vqrshlq_s16(xy, vdupq_n_s16(-row_shift));
+
+ if (width == 4) {
+ vst1_s16(dst, vget_low_s16(xy_shifted));
+ } else {
+ for (int i = 0; i < width; i += 8) {
+ vst1q_s16(dst, xy_shifted);
+ dst += 8;
+ }
+ }
+ return true;
+}
+
+template <int height>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height,
+ int width) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ const int16_t cos128 = Cos128(32);
+
+ // Calculate dc values for first row.
+ if (width == 4) {
+ const int16x4_t v_src = vld1_s16(dst);
+ const int16x4_t xy = vqrdmulh_n_s16(v_src, cos128 << 3);
+ vst1_s16(dst, xy);
+ } else {
+ int i = 0;
+ do {
+ const int16x8_t v_src = vld1q_s16(&dst[i]);
+ const int16x8_t xy = vqrdmulhq_n_s16(v_src, cos128 << 3);
+ vst1q_s16(&dst[i], xy);
+ i += 8;
+ } while (i < width);
+ }
+
+ // Copy first row to the rest of the block.
+ for (int y = 1; y < height; ++y) {
+ memcpy(&dst[y * width], dst, width * sizeof(dst[0]));
+ }
+ return true;
+}
+
+template <ButterflyRotationFunc butterfly_rotation,
+ bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct4Stages(int16x8_t* s) {
+ // stage 12.
+ if (is_fast_butterfly) {
+ ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true);
+ ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false);
+ } else {
+ butterfly_rotation(&s[0], &s[1], 32, true);
+ butterfly_rotation(&s[2], &s[3], 48, false);
+ }
+
+ // stage 17.
+ HadamardRotation(&s[0], &s[3], false);
+ HadamardRotation(&s[1], &s[2], false);
+}
+
+template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, int32_t step, bool transpose) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int16x8_t s[4], x[4];
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t input[8];
+ LoadSrc<8, 8>(dst, step, 0, input);
+ Transpose4x8To8x4(input, x);
+ } else {
+ LoadSrc<16, 4>(dst, step, 0, x);
+ }
+ } else {
+ LoadSrc<8, 4>(dst, step, 0, x);
+ if (transpose) {
+ Transpose4x4(x, x);
+ }
+ }
+
+ // stage 1.
+ // kBitReverseLookup 0, 2, 1, 3
+ s[0] = x[0];
+ s[1] = x[2];
+ s[2] = x[1];
+ s[3] = x[3];
+
+ Dct4Stages<butterfly_rotation>(s);
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t output[8];
+ Transpose8x4To4x8(s, output);
+ StoreDst<8, 8>(dst, step, 0, output);
+ } else {
+ StoreDst<16, 4>(dst, step, 0, s);
+ }
+ } else {
+ if (transpose) {
+ Transpose4x4(s, s);
+ }
+ StoreDst<8, 4>(dst, step, 0, s);
+ }
+}
+
+template <ButterflyRotationFunc butterfly_rotation,
+ bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct8Stages(int16x8_t* s) {
+ // stage 8.
+ if (is_fast_butterfly) {
+ ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false);
+ ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false);
+ } else {
+ butterfly_rotation(&s[4], &s[7], 56, false);
+ butterfly_rotation(&s[5], &s[6], 24, false);
+ }
+
+ // stage 13.
+ HadamardRotation(&s[4], &s[5], false);
+ HadamardRotation(&s[6], &s[7], true);
+
+ // stage 18.
+ butterfly_rotation(&s[6], &s[5], 32, true);
+
+ // stage 22.
+ HadamardRotation(&s[0], &s[7], false);
+ HadamardRotation(&s[1], &s[6], false);
+ HadamardRotation(&s[2], &s[5], false);
+ HadamardRotation(&s[3], &s[4], false);
+}
+
+// Process dct8 rows or columns, depending on the transpose flag.
+template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, int32_t step, bool transpose) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int16x8_t s[8], x[8];
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t input[4];
+ LoadSrc<16, 4>(dst, step, 0, input);
+ Transpose8x4To4x8(input, x);
+ } else {
+ LoadSrc<8, 8>(dst, step, 0, x);
+ }
+ } else if (transpose) {
+ LoadSrc<16, 8>(dst, step, 0, x);
+ dsp::Transpose8x8(x);
+ } else {
+ LoadSrc<16, 8>(dst, step, 0, x);
+ }
+
+ // stage 1.
+ // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7,
+ s[0] = x[0];
+ s[1] = x[4];
+ s[2] = x[2];
+ s[3] = x[6];
+ s[4] = x[1];
+ s[5] = x[5];
+ s[6] = x[3];
+ s[7] = x[7];
+
+ Dct4Stages<butterfly_rotation>(s);
+ Dct8Stages<butterfly_rotation>(s);
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t output[4];
+ Transpose4x8To8x4(s, output);
+ StoreDst<16, 4>(dst, step, 0, output);
+ } else {
+ StoreDst<8, 8>(dst, step, 0, s);
+ }
+ } else if (transpose) {
+ dsp::Transpose8x8(s);
+ StoreDst<16, 8>(dst, step, 0, s);
+ } else {
+ StoreDst<16, 8>(dst, step, 0, s);
+ }
+}
+
+template <ButterflyRotationFunc butterfly_rotation,
+ bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct16Stages(int16x8_t* s) {
+ // stage 5.
+ if (is_fast_butterfly) {
+ ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false);
+ ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false);
+ ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false);
+ ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false);
+ } else {
+ butterfly_rotation(&s[8], &s[15], 60, false);
+ butterfly_rotation(&s[9], &s[14], 28, false);
+ butterfly_rotation(&s[10], &s[13], 44, false);
+ butterfly_rotation(&s[11], &s[12], 12, false);
+ }
+
+ // stage 9.
+ HadamardRotation(&s[8], &s[9], false);
+ HadamardRotation(&s[10], &s[11], true);
+ HadamardRotation(&s[12], &s[13], false);
+ HadamardRotation(&s[14], &s[15], true);
+
+ // stage 14.
+ butterfly_rotation(&s[14], &s[9], 48, true);
+ butterfly_rotation(&s[13], &s[10], 112, true);
+
+ // stage 19.
+ HadamardRotation(&s[8], &s[11], false);
+ HadamardRotation(&s[9], &s[10], false);
+ HadamardRotation(&s[12], &s[15], true);
+ HadamardRotation(&s[13], &s[14], true);
+
+ // stage 23.
+ butterfly_rotation(&s[13], &s[10], 32, true);
+ butterfly_rotation(&s[12], &s[11], 32, true);
+
+ // stage 26.
+ HadamardRotation(&s[0], &s[15], false);
+ HadamardRotation(&s[1], &s[14], false);
+ HadamardRotation(&s[2], &s[13], false);
+ HadamardRotation(&s[3], &s[12], false);
+ HadamardRotation(&s[4], &s[11], false);
+ HadamardRotation(&s[5], &s[10], false);
+ HadamardRotation(&s[6], &s[9], false);
+ HadamardRotation(&s[7], &s[8], false);
+}
+
+// Process dct16 rows or columns, depending on the |is_row| flag.
+template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, int32_t step, bool is_row,
+ int row_shift) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int16x8_t s[16], x[16];
+
+ if (stage_is_rectangular) {
+ if (is_row) {
+ int16x8_t input[4];
+ LoadSrc<16, 4>(dst, step, 0, input);
+ Transpose8x4To4x8(input, x);
+ LoadSrc<16, 4>(dst, step, 8, input);
+ Transpose8x4To4x8(input, &x[8]);
+ } else {
+ LoadSrc<8, 16>(dst, step, 0, x);
+ }
+ } else if (is_row) {
+ for (int idx = 0; idx < 16; idx += 8) {
+ LoadSrc<16, 8>(dst, step, idx, &x[idx]);
+ dsp::Transpose8x8(&x[idx]);
+ }
+ } else {
+ LoadSrc<16, 16>(dst, step, 0, x);
+ }
+
+ // stage 1
+ // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15,
+ s[0] = x[0];
+ s[1] = x[8];
+ s[2] = x[4];
+ s[3] = x[12];
+ s[4] = x[2];
+ s[5] = x[10];
+ s[6] = x[6];
+ s[7] = x[14];
+ s[8] = x[1];
+ s[9] = x[9];
+ s[10] = x[5];
+ s[11] = x[13];
+ s[12] = x[3];
+ s[13] = x[11];
+ s[14] = x[7];
+ s[15] = x[15];
+
+ Dct4Stages<butterfly_rotation>(s);
+ Dct8Stages<butterfly_rotation>(s);
+ Dct16Stages<butterfly_rotation>(s);
+
+ if (is_row) {
+ const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
+ for (int i = 0; i < 16; ++i) {
+ s[i] = vqrshlq_s16(s[i], v_row_shift);
+ }
+ }
+
+ if (stage_is_rectangular) {
+ if (is_row) {
+ int16x8_t output[4];
+ Transpose4x8To8x4(s, output);
+ StoreDst<16, 4>(dst, step, 0, output);
+ Transpose4x8To8x4(&s[8], output);
+ StoreDst<16, 4>(dst, step, 8, output);
+ } else {
+ StoreDst<8, 16>(dst, step, 0, s);
+ }
+ } else if (is_row) {
+ for (int idx = 0; idx < 16; idx += 8) {
+ dsp::Transpose8x8(&s[idx]);
+ StoreDst<16, 8>(dst, step, idx, &s[idx]);
+ }
+ } else {
+ StoreDst<16, 16>(dst, step, 0, s);
+ }
+}
+
+template <ButterflyRotationFunc butterfly_rotation,
+ bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct32Stages(int16x8_t* s) {
+ // stage 3
+ if (is_fast_butterfly) {
+ ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false);
+ ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false);
+ ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false);
+ ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false);
+ ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false);
+ ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false);
+ ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false);
+ ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false);
+ } else {
+ butterfly_rotation(&s[16], &s[31], 62, false);
+ butterfly_rotation(&s[17], &s[30], 30, false);
+ butterfly_rotation(&s[18], &s[29], 46, false);
+ butterfly_rotation(&s[19], &s[28], 14, false);
+ butterfly_rotation(&s[20], &s[27], 54, false);
+ butterfly_rotation(&s[21], &s[26], 22, false);
+ butterfly_rotation(&s[22], &s[25], 38, false);
+ butterfly_rotation(&s[23], &s[24], 6, false);
+ }
+ // stage 6.
+ HadamardRotation(&s[16], &s[17], false);
+ HadamardRotation(&s[18], &s[19], true);
+ HadamardRotation(&s[20], &s[21], false);
+ HadamardRotation(&s[22], &s[23], true);
+ HadamardRotation(&s[24], &s[25], false);
+ HadamardRotation(&s[26], &s[27], true);
+ HadamardRotation(&s[28], &s[29], false);
+ HadamardRotation(&s[30], &s[31], true);
+
+ // stage 10.
+ butterfly_rotation(&s[30], &s[17], 24 + 32, true);
+ butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true);
+ butterfly_rotation(&s[26], &s[21], 24, true);
+ butterfly_rotation(&s[25], &s[22], 24 + 64, true);
+
+ // stage 15.
+ HadamardRotation(&s[16], &s[19], false);
+ HadamardRotation(&s[17], &s[18], false);
+ HadamardRotation(&s[20], &s[23], true);
+ HadamardRotation(&s[21], &s[22], true);
+ HadamardRotation(&s[24], &s[27], false);
+ HadamardRotation(&s[25], &s[26], false);
+ HadamardRotation(&s[28], &s[31], true);
+ HadamardRotation(&s[29], &s[30], true);
+
+ // stage 20.
+ butterfly_rotation(&s[29], &s[18], 48, true);
+ butterfly_rotation(&s[28], &s[19], 48, true);
+ butterfly_rotation(&s[27], &s[20], 48 + 64, true);
+ butterfly_rotation(&s[26], &s[21], 48 + 64, true);
+
+ // stage 24.
+ HadamardRotation(&s[16], &s[23], false);
+ HadamardRotation(&s[17], &s[22], false);
+ HadamardRotation(&s[18], &s[21], false);
+ HadamardRotation(&s[19], &s[20], false);
+ HadamardRotation(&s[24], &s[31], true);
+ HadamardRotation(&s[25], &s[30], true);
+ HadamardRotation(&s[26], &s[29], true);
+ HadamardRotation(&s[27], &s[28], true);
+
+ // stage 27.
+ butterfly_rotation(&s[27], &s[20], 32, true);
+ butterfly_rotation(&s[26], &s[21], 32, true);
+ butterfly_rotation(&s[25], &s[22], 32, true);
+ butterfly_rotation(&s[24], &s[23], 32, true);
+
+ // stage 29.
+ HadamardRotation(&s[0], &s[31], false);
+ HadamardRotation(&s[1], &s[30], false);
+ HadamardRotation(&s[2], &s[29], false);
+ HadamardRotation(&s[3], &s[28], false);
+ HadamardRotation(&s[4], &s[27], false);
+ HadamardRotation(&s[5], &s[26], false);
+ HadamardRotation(&s[6], &s[25], false);
+ HadamardRotation(&s[7], &s[24], false);
+ HadamardRotation(&s[8], &s[23], false);
+ HadamardRotation(&s[9], &s[22], false);
+ HadamardRotation(&s[10], &s[21], false);
+ HadamardRotation(&s[11], &s[20], false);
+ HadamardRotation(&s[12], &s[19], false);
+ HadamardRotation(&s[13], &s[18], false);
+ HadamardRotation(&s[14], &s[17], false);
+ HadamardRotation(&s[15], &s[16], false);
+}
+
+// Process dct32 rows or columns, depending on the |is_row| flag.
+LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const int32_t step,
+ const bool is_row, int row_shift) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int16x8_t s[32], x[32];
+
+ if (is_row) {
+ for (int idx = 0; idx < 32; idx += 8) {
+ LoadSrc<16, 8>(dst, step, idx, &x[idx]);
+ dsp::Transpose8x8(&x[idx]);
+ }
+ } else {
+ LoadSrc<16, 32>(dst, step, 0, x);
+ }
+
+ // stage 1
+ // kBitReverseLookup
+ // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30,
+ s[0] = x[0];
+ s[1] = x[16];
+ s[2] = x[8];
+ s[3] = x[24];
+ s[4] = x[4];
+ s[5] = x[20];
+ s[6] = x[12];
+ s[7] = x[28];
+ s[8] = x[2];
+ s[9] = x[18];
+ s[10] = x[10];
+ s[11] = x[26];
+ s[12] = x[6];
+ s[13] = x[22];
+ s[14] = x[14];
+ s[15] = x[30];
+
+ // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31,
+ s[16] = x[1];
+ s[17] = x[17];
+ s[18] = x[9];
+ s[19] = x[25];
+ s[20] = x[5];
+ s[21] = x[21];
+ s[22] = x[13];
+ s[23] = x[29];
+ s[24] = x[3];
+ s[25] = x[19];
+ s[26] = x[11];
+ s[27] = x[27];
+ s[28] = x[7];
+ s[29] = x[23];
+ s[30] = x[15];
+ s[31] = x[31];
+
+ Dct4Stages<ButterflyRotation_8>(s);
+ Dct8Stages<ButterflyRotation_8>(s);
+ Dct16Stages<ButterflyRotation_8>(s);
+ Dct32Stages<ButterflyRotation_8>(s);
+
+ if (is_row) {
+ const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
+ for (int idx = 0; idx < 32; idx += 8) {
+ int16x8_t output[8];
+ Transpose8x8(&s[idx], output);
+ for (int i = 0; i < 8; ++i) {
+ output[i] = vqrshlq_s16(output[i], v_row_shift);
+ }
+ StoreDst<16, 8>(dst, step, idx, output);
+ }
+ } else {
+ StoreDst<16, 32>(dst, step, 0, s);
+ }
+}
+
+// Allow the compiler to call this function instead of forcing it to be
+// inlined. Tests show this is slightly faster.
+void Dct64_NEON(void* dest, int32_t step, bool is_row, int row_shift) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int16x8_t s[64], x[32];
+
+ if (is_row) {
+ // The last 32 values of every row are always zero if the |tx_width| is
+ // 64.
+ for (int idx = 0; idx < 32; idx += 8) {
+ LoadSrc<16, 8>(dst, step, idx, &x[idx]);
+ dsp::Transpose8x8(&x[idx]);
+ }
+ } else {
+ // The last 32 values of every column are always zero if the |tx_height| is
+ // 64.
+ LoadSrc<16, 32>(dst, step, 0, x);
+ }
+
+ // stage 1
+ // kBitReverseLookup
+ // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60,
+ s[0] = x[0];
+ s[2] = x[16];
+ s[4] = x[8];
+ s[6] = x[24];
+ s[8] = x[4];
+ s[10] = x[20];
+ s[12] = x[12];
+ s[14] = x[28];
+
+ // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62,
+ s[16] = x[2];
+ s[18] = x[18];
+ s[20] = x[10];
+ s[22] = x[26];
+ s[24] = x[6];
+ s[26] = x[22];
+ s[28] = x[14];
+ s[30] = x[30];
+
+ // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
+ s[32] = x[1];
+ s[34] = x[17];
+ s[36] = x[9];
+ s[38] = x[25];
+ s[40] = x[5];
+ s[42] = x[21];
+ s[44] = x[13];
+ s[46] = x[29];
+
+ // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63
+ s[48] = x[3];
+ s[50] = x[19];
+ s[52] = x[11];
+ s[54] = x[27];
+ s[56] = x[7];
+ s[58] = x[23];
+ s[60] = x[15];
+ s[62] = x[31];
+
+ Dct4Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+ Dct8Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+ Dct16Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+ Dct32Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+
+ //-- start dct 64 stages
+ // stage 2.
+ ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false);
+ ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false);
+ ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false);
+ ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false);
+ ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false);
+ ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false);
+ ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false);
+ ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false);
+ ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false);
+ ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false);
+ ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false);
+ ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false);
+ ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false);
+ ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false);
+ ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false);
+ ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
+
+ // stage 4.
+ HadamardRotation(&s[32], &s[33], false);
+ HadamardRotation(&s[34], &s[35], true);
+ HadamardRotation(&s[36], &s[37], false);
+ HadamardRotation(&s[38], &s[39], true);
+ HadamardRotation(&s[40], &s[41], false);
+ HadamardRotation(&s[42], &s[43], true);
+ HadamardRotation(&s[44], &s[45], false);
+ HadamardRotation(&s[46], &s[47], true);
+ HadamardRotation(&s[48], &s[49], false);
+ HadamardRotation(&s[50], &s[51], true);
+ HadamardRotation(&s[52], &s[53], false);
+ HadamardRotation(&s[54], &s[55], true);
+ HadamardRotation(&s[56], &s[57], false);
+ HadamardRotation(&s[58], &s[59], true);
+ HadamardRotation(&s[60], &s[61], false);
+ HadamardRotation(&s[62], &s[63], true);
+
+ // stage 7.
+ ButterflyRotation_8(&s[62], &s[33], 60 - 0, true);
+ ButterflyRotation_8(&s[61], &s[34], 60 - 0 + 64, true);
+ ButterflyRotation_8(&s[58], &s[37], 60 - 32, true);
+ ButterflyRotation_8(&s[57], &s[38], 60 - 32 + 64, true);
+ ButterflyRotation_8(&s[54], &s[41], 60 - 16, true);
+ ButterflyRotation_8(&s[53], &s[42], 60 - 16 + 64, true);
+ ButterflyRotation_8(&s[50], &s[45], 60 - 48, true);
+ ButterflyRotation_8(&s[49], &s[46], 60 - 48 + 64, true);
+
+ // stage 11.
+ HadamardRotation(&s[32], &s[35], false);
+ HadamardRotation(&s[33], &s[34], false);
+ HadamardRotation(&s[36], &s[39], true);
+ HadamardRotation(&s[37], &s[38], true);
+ HadamardRotation(&s[40], &s[43], false);
+ HadamardRotation(&s[41], &s[42], false);
+ HadamardRotation(&s[44], &s[47], true);
+ HadamardRotation(&s[45], &s[46], true);
+ HadamardRotation(&s[48], &s[51], false);
+ HadamardRotation(&s[49], &s[50], false);
+ HadamardRotation(&s[52], &s[55], true);
+ HadamardRotation(&s[53], &s[54], true);
+ HadamardRotation(&s[56], &s[59], false);
+ HadamardRotation(&s[57], &s[58], false);
+ HadamardRotation(&s[60], &s[63], true);
+ HadamardRotation(&s[61], &s[62], true);
+
+ // stage 16.
+ ButterflyRotation_8(&s[61], &s[34], 56, true);
+ ButterflyRotation_8(&s[60], &s[35], 56, true);
+ ButterflyRotation_8(&s[59], &s[36], 56 + 64, true);
+ ButterflyRotation_8(&s[58], &s[37], 56 + 64, true);
+ ButterflyRotation_8(&s[53], &s[42], 56 - 32, true);
+ ButterflyRotation_8(&s[52], &s[43], 56 - 32, true);
+ ButterflyRotation_8(&s[51], &s[44], 56 - 32 + 64, true);
+ ButterflyRotation_8(&s[50], &s[45], 56 - 32 + 64, true);
+
+ // stage 21.
+ HadamardRotation(&s[32], &s[39], false);
+ HadamardRotation(&s[33], &s[38], false);
+ HadamardRotation(&s[34], &s[37], false);
+ HadamardRotation(&s[35], &s[36], false);
+ HadamardRotation(&s[40], &s[47], true);
+ HadamardRotation(&s[41], &s[46], true);
+ HadamardRotation(&s[42], &s[45], true);
+ HadamardRotation(&s[43], &s[44], true);
+ HadamardRotation(&s[48], &s[55], false);
+ HadamardRotation(&s[49], &s[54], false);
+ HadamardRotation(&s[50], &s[53], false);
+ HadamardRotation(&s[51], &s[52], false);
+ HadamardRotation(&s[56], &s[63], true);
+ HadamardRotation(&s[57], &s[62], true);
+ HadamardRotation(&s[58], &s[61], true);
+ HadamardRotation(&s[59], &s[60], true);
+
+ // stage 25.
+ ButterflyRotation_8(&s[59], &s[36], 48, true);
+ ButterflyRotation_8(&s[58], &s[37], 48, true);
+ ButterflyRotation_8(&s[57], &s[38], 48, true);
+ ButterflyRotation_8(&s[56], &s[39], 48, true);
+ ButterflyRotation_8(&s[55], &s[40], 112, true);
+ ButterflyRotation_8(&s[54], &s[41], 112, true);
+ ButterflyRotation_8(&s[53], &s[42], 112, true);
+ ButterflyRotation_8(&s[52], &s[43], 112, true);
+
+ // stage 28.
+ HadamardRotation(&s[32], &s[47], false);
+ HadamardRotation(&s[33], &s[46], false);
+ HadamardRotation(&s[34], &s[45], false);
+ HadamardRotation(&s[35], &s[44], false);
+ HadamardRotation(&s[36], &s[43], false);
+ HadamardRotation(&s[37], &s[42], false);
+ HadamardRotation(&s[38], &s[41], false);
+ HadamardRotation(&s[39], &s[40], false);
+ HadamardRotation(&s[48], &s[63], true);
+ HadamardRotation(&s[49], &s[62], true);
+ HadamardRotation(&s[50], &s[61], true);
+ HadamardRotation(&s[51], &s[60], true);
+ HadamardRotation(&s[52], &s[59], true);
+ HadamardRotation(&s[53], &s[58], true);
+ HadamardRotation(&s[54], &s[57], true);
+ HadamardRotation(&s[55], &s[56], true);
+
+ // stage 30.
+ ButterflyRotation_8(&s[55], &s[40], 32, true);
+ ButterflyRotation_8(&s[54], &s[41], 32, true);
+ ButterflyRotation_8(&s[53], &s[42], 32, true);
+ ButterflyRotation_8(&s[52], &s[43], 32, true);
+ ButterflyRotation_8(&s[51], &s[44], 32, true);
+ ButterflyRotation_8(&s[50], &s[45], 32, true);
+ ButterflyRotation_8(&s[49], &s[46], 32, true);
+ ButterflyRotation_8(&s[48], &s[47], 32, true);
+
+ // stage 31.
+ for (int i = 0; i < 32; i += 4) {
+ HadamardRotation(&s[i], &s[63 - i], false);
+ HadamardRotation(&s[i + 1], &s[63 - i - 1], false);
+ HadamardRotation(&s[i + 2], &s[63 - i - 2], false);
+ HadamardRotation(&s[i + 3], &s[63 - i - 3], false);
+ }
+ //-- end dct 64 stages
+
+ if (is_row) {
+ const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
+ for (int idx = 0; idx < 64; idx += 8) {
+ int16x8_t output[8];
+ Transpose8x8(&s[idx], output);
+ for (int i = 0; i < 8; ++i) {
+ output[i] = vqrshlq_s16(output[i], v_row_shift);
+ }
+ StoreDst<16, 8>(dst, step, idx, output);
+ }
+ } else {
+ StoreDst<16, 64>(dst, step, 0, s);
+ }
+}
+
+//------------------------------------------------------------------------------
+// Asymmetric Discrete Sine Transforms (ADST).
+template <bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, int32_t step,
+ bool transpose) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int32x4_t s[8];
+ int16x8_t x[4];
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t input[8];
+ LoadSrc<8, 8>(dst, step, 0, input);
+ Transpose4x8To8x4(input, x);
+ } else {
+ LoadSrc<16, 4>(dst, step, 0, x);
+ }
+ } else {
+ LoadSrc<8, 4>(dst, step, 0, x);
+ if (transpose) {
+ Transpose4x4(x, x);
+ }
+ }
+
+ // stage 1.
+ s[5] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[1]);
+ s[6] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[3]);
+
+ // stage 2.
+ const int32x4_t a7 = vsubl_s16(vget_low_s16(x[0]), vget_low_s16(x[2]));
+ const int32x4_t b7 = vaddw_s16(a7, vget_low_s16(x[3]));
+
+ // stage 3.
+ s[0] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[0]);
+ s[1] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[1]);
+ // s[0] = s[0] + s[3]
+ s[0] = vmlal_n_s16(s[0], vget_low_s16(x[2]), kAdst4Multiplier[3]);
+ // s[1] = s[1] - s[4]
+ s[1] = vmlsl_n_s16(s[1], vget_low_s16(x[2]), kAdst4Multiplier[0]);
+
+ s[3] = vmull_n_s16(vget_low_s16(x[1]), kAdst4Multiplier[2]);
+ s[2] = vmulq_n_s32(b7, kAdst4Multiplier[2]);
+
+ // stage 4.
+ s[0] = vaddq_s32(s[0], s[5]);
+ s[1] = vsubq_s32(s[1], s[6]);
+
+ // stages 5 and 6.
+ const int32x4_t x0 = vaddq_s32(s[0], s[3]);
+ const int32x4_t x1 = vaddq_s32(s[1], s[3]);
+ const int32x4_t x3_a = vaddq_s32(s[0], s[1]);
+ const int32x4_t x3 = vsubq_s32(x3_a, s[3]);
+ const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12);
+ const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12);
+ const int16x4_t dst_2 = vqrshrn_n_s32(s[2], 12);
+ const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12);
+
+ x[0] = vcombine_s16(dst_0, dst_0);
+ x[1] = vcombine_s16(dst_1, dst_1);
+ x[2] = vcombine_s16(dst_2, dst_2);
+ x[3] = vcombine_s16(dst_3, dst_3);
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t output[8];
+ Transpose8x4To4x8(x, output);
+ StoreDst<8, 8>(dst, step, 0, output);
+ } else {
+ StoreDst<16, 4>(dst, step, 0, x);
+ }
+ } else {
+ if (transpose) {
+ Transpose4x4(x, x);
+ }
+ StoreDst<8, 4>(dst, step, 0, x);
+ }
+}
+
+alignas(8) constexpr int16_t kAdst4DcOnlyMultiplier[4] = {1321, 2482, 3344,
+ 2482};
+
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height,
+ bool should_round, int row_shift) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ int32x4_t s[2];
+
+ const int16x4_t v_src0 = vdup_n_s16(dst[0]);
+ const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+ const int16x4_t v_src_round =
+ vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+ const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+ const int16x4_t kAdst4DcOnlyMultipliers = vld1_s16(kAdst4DcOnlyMultiplier);
+ s[1] = vdupq_n_s32(0);
+
+ // s0*k0 s0*k1 s0*k2 s0*k1
+ s[0] = vmull_s16(kAdst4DcOnlyMultipliers, v_src);
+ // 0 0 0 s0*k0
+ s[1] = vextq_s32(s[1], s[0], 1);
+
+ const int32x4_t x3 = vaddq_s32(s[0], s[1]);
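+ // Lane-wise, x3 is { s0*k0, s0*k1, s0*k2, s0*k1 + s0*k0 }, the same values
+ // the full Adst4 above produces when only the dc coefficient is nonzero.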
+ const int16x4_t dst_0 = vqrshrn_n_s32(x3, 12);
+
+ // vqrshlq_s16 will shift right if shift value is negative.
+ vst1_s16(dst, vqrshl_s16(dst_0, vdup_n_s16(-row_shift)));
+
+ return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height,
+ int width) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ int32x4_t s[4];
+
+ int i = 0;
+ do {
+ const int16x4_t v_src = vld1_s16(&dst[i]);
+
+ s[0] = vmull_n_s16(v_src, kAdst4Multiplier[0]);
+ s[1] = vmull_n_s16(v_src, kAdst4Multiplier[1]);
+ s[2] = vmull_n_s16(v_src, kAdst4Multiplier[2]);
+
+ const int32x4_t x0 = s[0];
+ const int32x4_t x1 = s[1];
+ const int32x4_t x2 = s[2];
+ const int32x4_t x3 = vaddq_s32(s[0], s[1]);
+ const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12);
+ const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12);
+ const int16x4_t dst_2 = vqrshrn_n_s32(x2, 12);
+ const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12);
+
+ vst1_s16(&dst[i], dst_0);
+ vst1_s16(&dst[i + width * 1], dst_1);
+ vst1_s16(&dst[i + width * 2], dst_2);
+ vst1_s16(&dst[i + width * 3], dst_3);
+
+ i += 4;
+ } while (i < width);
+
+ return true;
+}
+
+template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, int32_t step,
+ bool transpose) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int16x8_t s[8], x[8];
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t input[4];
+ LoadSrc<16, 4>(dst, step, 0, input);
+ Transpose8x4To4x8(input, x);
+ } else {
+ LoadSrc<8, 8>(dst, step, 0, x);
+ }
+ } else {
+ if (transpose) {
+ LoadSrc<16, 8>(dst, step, 0, x);
+ dsp::Transpose8x8(x);
+ } else {
+ LoadSrc<16, 8>(dst, step, 0, x);
+ }
+ }
+
+ // stage 1.
+ s[0] = x[7];
+ s[1] = x[0];
+ s[2] = x[5];
+ s[3] = x[2];
+ s[4] = x[3];
+ s[5] = x[4];
+ s[6] = x[1];
+ s[7] = x[6];
+
+ // stage 2.
+ butterfly_rotation(&s[0], &s[1], 60 - 0, true);
+ butterfly_rotation(&s[2], &s[3], 60 - 16, true);
+ butterfly_rotation(&s[4], &s[5], 60 - 32, true);
+ butterfly_rotation(&s[6], &s[7], 60 - 48, true);
+
+ // stage 3.
+ HadamardRotation(&s[0], &s[4], false);
+ HadamardRotation(&s[1], &s[5], false);
+ HadamardRotation(&s[2], &s[6], false);
+ HadamardRotation(&s[3], &s[7], false);
+
+ // stage 4.
+ butterfly_rotation(&s[4], &s[5], 48 - 0, true);
+ butterfly_rotation(&s[7], &s[6], 48 - 32, true);
+
+ // stage 5.
+ HadamardRotation(&s[0], &s[2], false);
+ HadamardRotation(&s[4], &s[6], false);
+ HadamardRotation(&s[1], &s[3], false);
+ HadamardRotation(&s[5], &s[7], false);
+
+ // stage 6.
+ butterfly_rotation(&s[2], &s[3], 32, true);
+ butterfly_rotation(&s[6], &s[7], 32, true);
+
+ // stage 7.
+ x[0] = s[0];
+ x[1] = vqnegq_s16(s[4]);
+ x[2] = s[6];
+ x[3] = vqnegq_s16(s[2]);
+ x[4] = s[3];
+ x[5] = vqnegq_s16(s[7]);
+ x[6] = s[5];
+ x[7] = vqnegq_s16(s[1]);
+
+ if (stage_is_rectangular) {
+ if (transpose) {
+ int16x8_t output[4];
+ Transpose4x8To8x4(x, output);
+ StoreDst<16, 4>(dst, step, 0, output);
+ } else {
+ StoreDst<8, 8>(dst, step, 0, x);
+ }
+ } else {
+ if (transpose) {
+ dsp::Transpose8x8(x);
+ StoreDst<16, 8>(dst, step, 0, x);
+ } else {
+ StoreDst<16, 8>(dst, step, 0, x);
+ }
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height,
+ bool should_round, int row_shift) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ int16x8_t s[8];
+
+ const int16x8_t v_src = vdupq_n_s16(dst[0]);
+ const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
+ const int16x8_t v_src_round =
+ vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+ // stage 1.
+ s[1] = vbslq_s16(v_mask, v_src_round, v_src);
+
+ // stage 2.
+ ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+ // stage 3.
+ s[4] = s[0];
+ s[5] = s[1];
+
+ // stage 4.
+ ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+ // stage 5.
+ s[2] = s[0];
+ s[3] = s[1];
+ s[6] = s[4];
+ s[7] = s[5];
+
+ // stage 6.
+ ButterflyRotation_4(&s[2], &s[3], 32, true);
+ ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+ // stage 7.
+ int16x8_t x[8];
+ x[0] = s[0];
+ x[1] = vqnegq_s16(s[4]);
+ x[2] = s[6];
+ x[3] = vqnegq_s16(s[2]);
+ x[4] = s[3];
+ x[5] = vqnegq_s16(s[7]);
+ x[6] = s[5];
+ x[7] = vqnegq_s16(s[1]);
+
+ for (int i = 0; i < 8; ++i) {
+ // vqrshlq_s16 will shift right if shift value is negative.
+ x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift));
+ vst1q_lane_s16(&dst[i], x[i], 0);
+ }
+
+ return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height,
+ int width) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ int16x8_t s[8];
+
+ int i = 0;
+ do {
+ const int16x8_t v_src = vld1q_s16(dst);
+ // stage 1.
+ s[1] = v_src;
+
+ // stage 2.
+ ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+ // stage 3.
+ s[4] = s[0];
+ s[5] = s[1];
+
+ // stage 4.
+ ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+ // stage 5.
+ s[2] = s[0];
+ s[3] = s[1];
+ s[6] = s[4];
+ s[7] = s[5];
+
+ // stage 6.
+ ButterflyRotation_4(&s[2], &s[3], 32, true);
+ ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+ // stage 7.
+ int16x8_t x[8];
+ x[0] = s[0];
+ x[1] = vqnegq_s16(s[4]);
+ x[2] = s[6];
+ x[3] = vqnegq_s16(s[2]);
+ x[4] = s[3];
+ x[5] = vqnegq_s16(s[7]);
+ x[6] = s[5];
+ x[7] = vqnegq_s16(s[1]);
+
+ for (int j = 0; j < 8; ++j) {
+ vst1_s16(&dst[j * width], vget_low_s16(x[j]));
+ }
+ i += 4;
+ dst += 4;
+ } while (i < width);
+
+ return true;
+}
+
+template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, int32_t step, bool is_row,
+ int row_shift) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ int16x8_t s[16], x[16];
+
+ if (stage_is_rectangular) {
+ if (is_row) {
+ int16x8_t input[4];
+ LoadSrc<16, 4>(dst, step, 0, input);
+ Transpose8x4To4x8(input, x);
+ LoadSrc<16, 4>(dst, step, 8, input);
+ Transpose8x4To4x8(input, &x[8]);
+ } else {
+ LoadSrc<8, 16>(dst, step, 0, x);
+ }
+ } else {
+ if (is_row) {
+ for (int idx = 0; idx < 16; idx += 8) {
+ LoadSrc<16, 8>(dst, step, idx, &x[idx]);
+ dsp::Transpose8x8(&x[idx]);
+ }
+ } else {
+ LoadSrc<16, 16>(dst, step, 0, x);
+ }
+ }
+
+ // stage 1.
+ s[0] = x[15];
+ s[1] = x[0];
+ s[2] = x[13];
+ s[3] = x[2];
+ s[4] = x[11];
+ s[5] = x[4];
+ s[6] = x[9];
+ s[7] = x[6];
+ s[8] = x[7];
+ s[9] = x[8];
+ s[10] = x[5];
+ s[11] = x[10];
+ s[12] = x[3];
+ s[13] = x[12];
+ s[14] = x[1];
+ s[15] = x[14];
+
+ // stage 2.
+ butterfly_rotation(&s[0], &s[1], 62 - 0, true);
+ butterfly_rotation(&s[2], &s[3], 62 - 8, true);
+ butterfly_rotation(&s[4], &s[5], 62 - 16, true);
+ butterfly_rotation(&s[6], &s[7], 62 - 24, true);
+ butterfly_rotation(&s[8], &s[9], 62 - 32, true);
+ butterfly_rotation(&s[10], &s[11], 62 - 40, true);
+ butterfly_rotation(&s[12], &s[13], 62 - 48, true);
+ butterfly_rotation(&s[14], &s[15], 62 - 56, true);
+
+ // stage 3.
+ HadamardRotation(&s[0], &s[8], false);
+ HadamardRotation(&s[1], &s[9], false);
+ HadamardRotation(&s[2], &s[10], false);
+ HadamardRotation(&s[3], &s[11], false);
+ HadamardRotation(&s[4], &s[12], false);
+ HadamardRotation(&s[5], &s[13], false);
+ HadamardRotation(&s[6], &s[14], false);
+ HadamardRotation(&s[7], &s[15], false);
+
+ // stage 4.
+ butterfly_rotation(&s[8], &s[9], 56 - 0, true);
+ butterfly_rotation(&s[13], &s[12], 8 + 0, true);
+ butterfly_rotation(&s[10], &s[11], 56 - 32, true);
+ butterfly_rotation(&s[15], &s[14], 8 + 32, true);
+
+ // stage 5.
+ HadamardRotation(&s[0], &s[4], false);
+ HadamardRotation(&s[8], &s[12], false);
+ HadamardRotation(&s[1], &s[5], false);
+ HadamardRotation(&s[9], &s[13], false);
+ HadamardRotation(&s[2], &s[6], false);
+ HadamardRotation(&s[10], &s[14], false);
+ HadamardRotation(&s[3], &s[7], false);
+ HadamardRotation(&s[11], &s[15], false);
+
+ // stage 6.
+ butterfly_rotation(&s[4], &s[5], 48 - 0, true);
+ butterfly_rotation(&s[12], &s[13], 48 - 0, true);
+ butterfly_rotation(&s[7], &s[6], 48 - 32, true);
+ butterfly_rotation(&s[15], &s[14], 48 - 32, true);
+
+ // stage 7.
+ HadamardRotation(&s[0], &s[2], false);
+ HadamardRotation(&s[4], &s[6], false);
+ HadamardRotation(&s[8], &s[10], false);
+ HadamardRotation(&s[12], &s[14], false);
+ HadamardRotation(&s[1], &s[3], false);
+ HadamardRotation(&s[5], &s[7], false);
+ HadamardRotation(&s[9], &s[11], false);
+ HadamardRotation(&s[13], &s[15], false);
+
+ // stage 8.
+ butterfly_rotation(&s[2], &s[3], 32, true);
+ butterfly_rotation(&s[6], &s[7], 32, true);
+ butterfly_rotation(&s[10], &s[11], 32, true);
+ butterfly_rotation(&s[14], &s[15], 32, true);
+
+ // stage 9.
+ x[0] = s[0];
+ x[1] = vqnegq_s16(s[8]);
+ x[2] = s[12];
+ x[3] = vqnegq_s16(s[4]);
+ x[4] = s[6];
+ x[5] = vqnegq_s16(s[14]);
+ x[6] = s[10];
+ x[7] = vqnegq_s16(s[2]);
+ x[8] = s[3];
+ x[9] = vqnegq_s16(s[11]);
+ x[10] = s[15];
+ x[11] = vqnegq_s16(s[7]);
+ x[12] = s[5];
+ x[13] = vqnegq_s16(s[13]);
+ x[14] = s[9];
+ x[15] = vqnegq_s16(s[1]);
+
+ if (stage_is_rectangular) {
+ if (is_row) {
+ const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
+ int16x8_t output[4];
+ Transpose4x8To8x4(x, output);
+ for (int i = 0; i < 4; ++i) {
+ output[i] = vqrshlq_s16(output[i], v_row_shift);
+ }
+ StoreDst<16, 4>(dst, step, 0, output);
+ Transpose4x8To8x4(&x[8], output);
+ for (int i = 0; i < 4; ++i) {
+ output[i] = vqrshlq_s16(output[i], v_row_shift);
+ }
+ StoreDst<16, 4>(dst, step, 8, output);
+ } else {
+ StoreDst<8, 16>(dst, step, 0, x);
+ }
+ } else {
+ if (is_row) {
+ const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
+ for (int idx = 0; idx < 16; idx += 8) {
+ int16x8_t output[8];
+ Transpose8x8(&x[idx], output);
+ for (int i = 0; i < 8; ++i) {
+ output[i] = vqrshlq_s16(output[i], v_row_shift);
+ }
+ StoreDst<16, 8>(dst, step, idx, output);
+ }
+ } else {
+ StoreDst<16, 16>(dst, step, 0, x);
+ }
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int16x8_t* s, int16x8_t* x) {
+ // stage 2.
+ ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);
+
+ // stage 3.
+ s[8] = s[0];
+ s[9] = s[1];
+
+ // stage 4.
+ ButterflyRotation_4(&s[8], &s[9], 56, true);
+
+ // stage 5.
+ s[4] = s[0];
+ s[12] = s[8];
+ s[5] = s[1];
+ s[13] = s[9];
+
+ // stage 6.
+ ButterflyRotation_4(&s[4], &s[5], 48, true);
+ ButterflyRotation_4(&s[12], &s[13], 48, true);
+
+ // stage 7.
+ s[2] = s[0];
+ s[6] = s[4];
+ s[10] = s[8];
+ s[14] = s[12];
+ s[3] = s[1];
+ s[7] = s[5];
+ s[11] = s[9];
+ s[15] = s[13];
+
+ // stage 8.
+ ButterflyRotation_4(&s[2], &s[3], 32, true);
+ ButterflyRotation_4(&s[6], &s[7], 32, true);
+ ButterflyRotation_4(&s[10], &s[11], 32, true);
+ ButterflyRotation_4(&s[14], &s[15], 32, true);
+
+ // stage 9.
+ x[0] = s[0];
+ x[1] = vqnegq_s16(s[8]);
+ x[2] = s[12];
+ x[3] = vqnegq_s16(s[4]);
+ x[4] = s[6];
+ x[5] = vqnegq_s16(s[14]);
+ x[6] = s[10];
+ x[7] = vqnegq_s16(s[2]);
+ x[8] = s[3];
+ x[9] = vqnegq_s16(s[11]);
+ x[10] = s[15];
+ x[11] = vqnegq_s16(s[7]);
+ x[12] = s[5];
+ x[13] = vqnegq_s16(s[13]);
+ x[14] = s[9];
+ x[15] = vqnegq_s16(s[1]);
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height,
+ bool should_round, int row_shift) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ int16x8_t s[16];
+ int16x8_t x[16];
+
+ const int16x8_t v_src = vdupq_n_s16(dst[0]);
+ const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
+ const int16x8_t v_src_round =
+ vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+ // stage 1.
+ s[1] = vbslq_s16(v_mask, v_src_round, v_src);
+
+ Adst16DcOnlyInternal(s, x);
+
+ for (int i = 0; i < 16; ++i) {
+ // vqrshlq_s16 will shift right if shift value is negative.
+ x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift));
+ vst1q_lane_s16(&dst[i], x[i], 0);
+ }
+
+ return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest,
+ int adjusted_tx_height,
+ int width) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ int i = 0;
+ do {
+ int16x8_t s[16];
+ int16x8_t x[16];
+ const int16x8_t v_src = vld1q_s16(dst);
+ // stage 1.
+ s[1] = v_src;
+
+ Adst16DcOnlyInternal(s, x);
+
+ for (int j = 0; j < 16; ++j) {
+ vst1_s16(&dst[j * width], vget_low_s16(x[j]));
+ }
+ i += 4;
+ dst += 4;
+ } while (i < width);
+
+ return true;
+}
+
+//------------------------------------------------------------------------------
+// Identity Transforms.
+
+template <bool is_row_shift>
+LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, int32_t step) {
+ auto* const dst = static_cast<int16_t*>(dest);
+
+ if (is_row_shift) {
+ const int shift = 1;
+ const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+ const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier);
+ const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
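+ // The dual round term appears to fold the +2048 rounding of the >> 12 for
+ // the identity multiply together with the rounding of the row shift: for
+ // shift == 1 it is 2048 + 4096 == (1 + (1 << 1)) << 11.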
+ for (int i = 0; i < 4; i += 2) {
+ const int16x8_t v_src = vld1q_s16(&dst[i * step]);
+ const int32x4_t v_src_mult_lo =
+ vmlal_s16(v_dual_round, vget_low_s16(v_src), v_multiplier);
+ const int32x4_t v_src_mult_hi =
+ vmlal_s16(v_dual_round, vget_high_s16(v_src), v_multiplier);
+ const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
+ const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
+ vst1q_s16(&dst[i * step],
+ vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi)));
+ }
+ } else {
+ for (int i = 0; i < 4; i += 2) {
+ const int16x8_t v_src = vld1q_s16(&dst[i * step]);
+ const int16x8_t a =
+ vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+ const int16x8_t b = vqaddq_s16(v_src, a);
+ vst1q_s16(&dst[i * step], b);
+ }
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height,
+ bool should_round, int tx_height) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ const int16x4_t v_src0 = vdup_n_s16(dst[0]);
+ const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+ const int16x4_t v_src_round =
+ vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+ const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+ const int shift = tx_height < 16 ? 0 : 1;
+ const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+ const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier);
+ const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+ const int32x4_t v_src_mult_lo = vmlal_s16(v_dual_round, v_src, v_multiplier);
+ const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
+ vst1_lane_s16(dst, vqmovn_s32(dst_0), 0);
+ return true;
+}
+
+template <int identity_size>
+LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
+ Array2DView<uint8_t> frame, const int start_x, const int start_y,
+ const int tx_width, const int tx_height, const int16_t* source) {
+ const int stride = frame.columns();
+ uint8_t* dst = frame[start_y] + start_x;
+
+ if (identity_size < 32) {
+ if (tx_width == 4) {
+ uint8x8_t frame_data = vdup_n_u8(0);
+ int i = 0;
+ do {
+ const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
+
+ int16x4_t v_dst_i;
+ if (identity_size == 4) {
+ const int16x4_t v_src_fraction =
+ vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+ v_dst_i = vqadd_s16(v_src, v_src_fraction);
+ } else if (identity_size == 8) {
+ v_dst_i = vqadd_s16(v_src, v_src);
+ } else { // identity_size == 16
+ const int16x4_t v_src_mult =
+ vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 4);
+ const int16x4_t v_srcx2 = vqadd_s16(v_src, v_src);
+ v_dst_i = vqadd_s16(v_srcx2, v_src_mult);
+ }
+
+ frame_data = Load4<0>(dst, frame_data);
+ const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
+ const uint16x8_t b =
+ vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
+ const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+ StoreLo4(dst, d);
+ dst += stride;
+ } while (++i < tx_height);
+ } else {
+ int i = 0;
+ do {
+ const int row = i * tx_width;
+ int j = 0;
+ do {
+ const int16x8_t v_src = vld1q_s16(&source[row + j]);
+
+ int16x8_t v_dst_i;
+ if (identity_size == 4) {
+ const int16x8_t v_src_fraction =
+ vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+ v_dst_i = vqaddq_s16(v_src, v_src_fraction);
+ } else if (identity_size == 8) {
+ v_dst_i = vqaddq_s16(v_src, v_src);
+ } else { // identity_size == 16
+ const int16x8_t v_src_mult =
+ vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 4);
+ const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
+ v_dst_i = vqaddq_s16(v_src_mult, v_srcx2);
+ }
+
+ const uint8x8_t frame_data = vld1_u8(dst + j);
+ const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
+ const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+ const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+ vst1_u8(dst + j, d);
+ j += 8;
+ } while (j < tx_width);
+ dst += stride;
+ } while (++i < tx_height);
+ }
+ } else {
+ int i = 0;
+ do {
+ const int row = i * tx_width;
+ int j = 0;
+ do {
+ const int16x8_t v_dst_i = vld1q_s16(&source[row + j]);
+ const uint8x8_t frame_data = vld1_u8(dst + j);
+ const int16x8_t a = vrshrq_n_s16(v_dst_i, 2);
+ const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+ const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+ vst1_u8(dst + j, d);
+ j += 8;
+ } while (j < tx_width);
+ dst += stride;
+ } while (++i < tx_height);
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
+ Array2DView<uint8_t> frame, const int start_x, const int start_y,
+ const int tx_width, const int tx_height, const int16_t* source) {
+ const int stride = frame.columns();
+ uint8_t* dst = frame[start_y] + start_x;
+
+ if (tx_width == 4) {
+ uint8x8_t frame_data = vdup_n_u8(0);
+ int i = 0;
+ do {
+ const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
+ const int16x4_t v_src_mult =
+ vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+ const int16x4_t v_dst_row = vqadd_s16(v_src, v_src_mult);
+ const int16x4_t v_src_mult2 =
+ vqrdmulh_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
+ const int16x4_t v_dst_col = vqadd_s16(v_dst_row, v_src_mult2);
+ frame_data = Load4<0>(dst, frame_data);
+ const int16x4_t a = vrshr_n_s16(v_dst_col, 4);
+ const uint16x8_t b =
+ vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
+ const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+ StoreLo4(dst, d);
+ dst += stride;
+ } while (++i < tx_height);
+ } else {
+ int i = 0;
+ do {
+ const int row = i * tx_width;
+ int j = 0;
+ do {
+ const int16x8_t v_src = vld1q_s16(&source[row + j]);
+ const int16x8_t v_src_round =
+ vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+ const int16x8_t v_dst_row = vqaddq_s16(v_src_round, v_src_round);
+ const int16x8_t v_src_mult2 =
+ vqrdmulhq_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
+ const int16x8_t v_dst_col = vqaddq_s16(v_dst_row, v_src_mult2);
+ const uint8x8_t frame_data = vld1_u8(dst + j);
+ const int16x8_t a = vrshrq_n_s16(v_dst_col, 4);
+ const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+ const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+ vst1_u8(dst + j, d);
+ j += 8;
+ } while (j < tx_width);
+ dst += stride;
+ } while (++i < tx_height);
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, int32_t step) {
+ auto* const dst = static_cast<int16_t*>(dest);
+
+ // When combining the identity8 multiplier with the row shift, the
+ // calculations for tx_height equal to 32 can be simplified from
+ // (((A * 2) + 2) >> 2) to ((A + 1) >> 1).
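+ // For example, A == 5: ((5 * 2) + 2) >> 2 == 3 and (5 + 1) >> 1 == 3, which
+ // is what vrshrq_n_s16(v_src, 1) below computes.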
+ for (int i = 0; i < 4; ++i) {
+ const int16x8_t v_src = vld1q_s16(&dst[i * step]);
+ const int16x8_t a = vrshrq_n_s16(v_src, 1);
+ vst1q_s16(&dst[i * step], a);
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, int32_t step) {
+ auto* const dst = static_cast<int16_t*>(dest);
+
+ for (int i = 0; i < 4; ++i) {
+ const int16x8_t v_src = vld1q_s16(&dst[i * step]);
+ // For bitdepth == 8, the identity row clamps to a signed 16bit value, so
+ // saturating add here is ok.
+ const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
+ vst1q_s16(&dst[i * step], v_srcx2);
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height,
+ bool should_round, int row_shift) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ const int16x4_t v_src0 = vdup_n_s16(dst[0]);
+ const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+ const int16x4_t v_src_round =
+ vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+ const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+ const int32x4_t v_srcx2 = vaddl_s16(v_src, v_src);
+ const int32x4_t dst_0 = vqrshlq_s32(v_srcx2, vdupq_n_s32(-row_shift));
+ vst1_lane_s16(dst, vqmovn_s32(dst_0), 0);
+ return true;
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, int32_t step,
+ int shift) {
+ auto* const dst = static_cast<int16_t*>(dest);
+ const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+ const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ const int16x8_t v_src = vld1q_s16(&dst[i * step + j * 8]);
+ const int32x4_t v_src_mult_lo =
+ vmlal_n_s16(v_dual_round, vget_low_s16(v_src), kIdentity16Multiplier);
+ const int32x4_t v_src_mult_hi = vmlal_n_s16(
+ v_dual_round, vget_high_s16(v_src), kIdentity16Multiplier);
+ const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
+ const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
+ vst1q_s16(&dst[i * step + j * 8],
+ vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi)));
+ }
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height,
+ bool should_round, int shift) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ const int16x4_t v_src0 = vdup_n_s16(dst[0]);
+ const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+ const int16x4_t v_src_round =
+ vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+ const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+ const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+ const int16x4_t v_multiplier = vdup_n_s16(kIdentity16Multiplier);
+ const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+ const int32x4_t v_src_mult_lo =
+ vmlal_s16(v_dual_round, v_src, v_multiplier);
+ const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
+ vst1_lane_s16(dst, vqmovn_s32(dst_0), 0);
+ return true;
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest,
+ const int32_t step) {
+ auto* const dst = static_cast<int16_t*>(dest);
+
+ // When combining the identity32 multiplier with the row shift, the
+ // calculation for tx_height equal to 16 can be simplified from
+ // (((A * 4) + 1) >> 1) to (A * 2).
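+ // For example, A == 7: ((7 * 4) + 1) >> 1 == 14 == 7 * 2; since A * 4 is
+ // always even, the +1 is discarded by the shift.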
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 32; j += 8) {
+ const int16x8_t v_src = vld1q_s16(&dst[i * step + j]);
+ // For bitdepth == 8, the identity row clamps to a signed 16bit value, so
+ // saturating add here is ok.
+ const int16x8_t v_dst_i = vqaddq_s16(v_src, v_src);
+ vst1q_s16(&dst[i * step + j], v_dst_i);
+ }
+ }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest,
+ int adjusted_tx_height) {
+ if (adjusted_tx_height > 1) return false;
+
+ auto* dst = static_cast<int16_t*>(dest);
+ const int16x4_t v_src0 = vdup_n_s16(dst[0]);
+ const int16x4_t v_src = vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+ // When combining the identity32 multiplier with the row shift, the
+ // calculation for tx_height equal to 16 can be simplified from
+ // (((A * 4) + 1) >> 1) to (A * 2).
+ const int16x4_t v_dst_0 = vqadd_s16(v_src, v_src);
+ vst1_lane_s16(dst, v_dst_0, 0);
+ return true;
+}
+
+//------------------------------------------------------------------------------
+// Walsh Hadamard Transform.
+
+// Transposes a 4x4 matrix and then permutes the rows of the transposed matrix
+// for the WHT. The input matrix is in two "wide" int16x8_t variables. The
+// output matrix is in four int16x4_t variables.
+//
+// Input:
+// in[0]: 00 01 02 03 10 11 12 13
+// in[1]: 20 21 22 23 30 31 32 33
+// Output:
+// out[0]: 00 10 20 30
+// out[1]: 03 13 23 33
+// out[2]: 01 11 21 31
+// out[3]: 02 12 22 32
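+//
+// In other words, the output rows are the transposed rows in the order
+// 0, 3, 1, 2, matching the permutation applied to the vld4_s16() load in
+// Wht4_NEON() below.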
+LIBGAV1_ALWAYS_INLINE void TransposeAndPermute4x4WideInput(
+ const int16x8_t in[2], int16x4_t out[4]) {
+ // Swap 32 bit elements. Goes from:
+ // in[0]: 00 01 02 03 10 11 12 13
+ // in[1]: 20 21 22 23 30 31 32 33
+ // to:
+ // b0.val[0]: 00 01 20 21 10 11 30 31
+ // b0.val[1]: 02 03 22 23 12 13 32 33
+
+ const int32x4x2_t b0 =
+ vtrnq_s32(vreinterpretq_s32_s16(in[0]), vreinterpretq_s32_s16(in[1]));
+
+ // Swap 16 bit elements. Goes from:
+ // vget_low_s32(b0.val[0]): 00 01 20 21
+ // vget_high_s32(b0.val[0]): 10 11 30 31
+ // vget_low_s32(b0.val[1]): 02 03 22 23
+ // vget_high_s32(b0.val[1]): 12 13 32 33
+ // to:
+ // c0.val[0]: 00 10 20 30
+ // c0.val[1]: 01 11 21 31
+ // c1.val[0]: 02 12 22 32
+ // c1.val[1]: 03 13 23 33
+
+ const int16x4x2_t c0 =
+ vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[0])),
+ vreinterpret_s16_s32(vget_high_s32(b0.val[0])));
+ const int16x4x2_t c1 =
+ vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[1])),
+ vreinterpret_s16_s32(vget_high_s32(b0.val[1])));
+
+ out[0] = c0.val[0];
+ out[1] = c1.val[1];
+ out[2] = c0.val[1];
+ out[3] = c1.val[0];
+}
+
+// Process 4 wht4 rows and columns.
+LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint8_t* dst, const int dst_stride,
+ const void* source,
+ const int adjusted_tx_height) {
+ const auto* const src = static_cast<const int16_t*>(source);
+ int16x4_t s[4];
+
+ if (adjusted_tx_height == 1) {
+ // Special case: only src[0] is nonzero.
+ // src[0] 0 0 0
+ // 0 0 0 0
+ // 0 0 0 0
+ // 0 0 0 0
+ //
+ // After the row and column transforms are applied, we have:
+ // f h h h
+ // g i i i
+ // g i i i
+ // g i i i
+ // where f, g, h, i are computed as follows.
+ int16_t f = (src[0] >> 2) - (src[0] >> 3);
+ const int16_t g = f >> 1;
+ f = f - (f >> 1);
+ const int16_t h = (src[0] >> 3) - (src[0] >> 4);
+ const int16_t i = (src[0] >> 4);
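+ // For example, src[0] == 100: f starts at 25 - 12 == 13, so g == 6 and f
+ // becomes 13 - 6 == 7, while h == 12 - 6 == 6 and i == 100 >> 4 == 6.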
+ s[0] = vdup_n_s16(h);
+ s[0] = vset_lane_s16(f, s[0], 0);
+ s[1] = vdup_n_s16(i);
+ s[1] = vset_lane_s16(g, s[1], 0);
+ s[2] = s[3] = s[1];
+ } else {
+ // Load the 4x4 source in transposed form.
+ int16x4x4_t columns = vld4_s16(src);
+ // Shift right and permute the columns for the WHT.
+ s[0] = vshr_n_s16(columns.val[0], 2);
+ s[2] = vshr_n_s16(columns.val[1], 2);
+ s[3] = vshr_n_s16(columns.val[2], 2);
+ s[1] = vshr_n_s16(columns.val[3], 2);
+
+ // Row transforms.
+ s[0] = vadd_s16(s[0], s[2]);
+ s[3] = vsub_s16(s[3], s[1]);
+ int16x4_t e = vhsub_s16(s[0], s[3]); // e = (s[0] - s[3]) >> 1
+ s[1] = vsub_s16(e, s[1]);
+ s[2] = vsub_s16(e, s[2]);
+ s[0] = vsub_s16(s[0], s[1]);
+ s[3] = vadd_s16(s[3], s[2]);
+
+ int16x8_t x[2];
+ x[0] = vcombine_s16(s[0], s[1]);
+ x[1] = vcombine_s16(s[2], s[3]);
+ TransposeAndPermute4x4WideInput(x, s);
+
+ // Column transforms.
+ s[0] = vadd_s16(s[0], s[2]);
+ s[3] = vsub_s16(s[3], s[1]);
+ e = vhsub_s16(s[0], s[3]); // e = (s[0] - s[3]) >> 1
+ s[1] = vsub_s16(e, s[1]);
+ s[2] = vsub_s16(e, s[2]);
+ s[0] = vsub_s16(s[0], s[1]);
+ s[3] = vadd_s16(s[3], s[2]);
+ }
+
+ // Store to frame.
+ uint8x8_t frame_data = vdup_n_u8(0);
+ for (int row = 0; row < 4; row += 2) {
+ frame_data = Load4<0>(dst, frame_data);
+ frame_data = Load4<1>(dst + dst_stride, frame_data);
+ const int16x8_t residual = vcombine_s16(s[row], s[row + 1]);
+ const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(residual), frame_data);
+ frame_data = vqmovun_s16(vreinterpretq_s16_u16(b));
+ StoreLo4(dst, frame_data);
+ dst += dst_stride;
+ StoreHi4(dst, frame_data);
+ dst += dst_stride;
+ }
+}
+
+//------------------------------------------------------------------------------
+// row/column transform loops
+
+template <int tx_height>
+LIBGAV1_ALWAYS_INLINE void FlipColumns(int16_t* source, int tx_width) {
+ if (tx_width >= 16) {
+ int i = 0;
+ do {
+ const int16x8_t a = vld1q_s16(&source[i]);
+ const int16x8_t b = vld1q_s16(&source[i + 8]);
+ const int16x8_t c = vrev64q_s16(a);
+ const int16x8_t d = vrev64q_s16(b);
+ vst1q_s16(&source[i], vcombine_s16(vget_high_s16(d), vget_low_s16(d)));
+ vst1q_s16(&source[i + 8],
+ vcombine_s16(vget_high_s16(c), vget_low_s16(c)));
+ i += 16;
+ } while (i < tx_width * tx_height);
+ } else if (tx_width == 8) {
+ for (int i = 0; i < 8 * tx_height; i += 8) {
+ const int16x8_t a = vld1q_s16(&source[i]);
+ const int16x8_t b = vrev64q_s16(a);
+ vst1q_s16(&source[i], vcombine_s16(vget_high_s16(b), vget_low_s16(b)));
+ }
+ } else {
+ // Process two rows per iteration.
+ for (int i = 0; i < 4 * tx_height; i += 8) {
+ const int16x8_t a = vld1q_s16(&source[i]);
+ vst1q_s16(&source[i], vrev64q_s16(a));
+ }
+ }
+}
+
+template <int tx_width>
+LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) {
+ if (tx_width == 4) {
+ // Process two rows per iteration.
+ int i = 0;
+ do {
+ const int16x8_t a = vld1q_s16(&source[i]);
+ const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3);
+ vst1q_s16(&source[i], b);
+ i += 8;
+ } while (i < tx_width * num_rows);
+ } else {
+ int i = 0;
+ do {
+ // The last 32 values of every row are always zero if the |tx_width| is
+ // 64.
+ const int non_zero_width = (tx_width < 64) ? tx_width : 32;
+ int j = 0;
+ do {
+ const int16x8_t a = vld1q_s16(&source[i * tx_width + j]);
+ const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3);
+ vst1q_s16(&source[i * tx_width + j], b);
+ j += 8;
+ } while (j < non_zero_width);
+ } while (++i < num_rows);
+ }
+}
+
+template <int tx_width>
+LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows,
+ int row_shift) {
+ // vqrshlq_s16 will shift right if shift value is negative.
+ row_shift = -row_shift;
+
+ if (tx_width == 4) {
+ // Process two rows per iteration.
+ int i = 0;
+ do {
+ const int16x8_t residual = vld1q_s16(&source[i]);
+ vst1q_s16(&source[i], vqrshlq_s16(residual, vdupq_n_s16(row_shift)));
+ i += 8;
+ } while (i < tx_width * num_rows);
+ } else {
+ int i = 0;
+ do {
+ for (int j = 0; j < tx_width; j += 8) {
+ const int16x8_t residual = vld1q_s16(&source[i * tx_width + j]);
+ const int16x8_t residual_shifted =
+ vqrshlq_s16(residual, vdupq_n_s16(row_shift));
+ vst1q_s16(&source[i * tx_width + j], residual_shifted);
+ }
+ } while (++i < num_rows);
+ }
+}
+
+template <int tx_height, bool enable_flip_rows = false>
+LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
+ Array2DView<uint8_t> frame, const int start_x, const int start_y,
+ const int tx_width, const int16_t* source, TransformType tx_type) {
+ const bool flip_rows =
+ enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
+ const int stride = frame.columns();
+ uint8_t* dst = frame[start_y] + start_x;
+
+ // Covers 4x4, 4x8, 4x16.
+ if (tx_height < 32 && tx_width == 4) {
+ uint8x8_t frame_data = vdup_n_u8(0);
+ for (int i = 0; i < tx_height; ++i) {
+ const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4;
+ const int16x4_t residual = vld1_s16(&source[row]);
+ frame_data = Load4<0>(dst, frame_data);
+ const int16x4_t a = vrshr_n_s16(residual, 4);
+ const uint16x8_t b =
+ vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
+ const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+ StoreLo4(dst, d);
+ dst += stride;
+ }
+ // Covers 8x4, 8x8, 8x16, 8x32.
+ } else if (tx_height < 64 && tx_width == 8) {
+ for (int i = 0; i < tx_height; ++i) {
+ const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8;
+ const int16x8_t residual = vld1q_s16(&source[row]);
+ const uint8x8_t frame_data = vld1_u8(dst);
+ const int16x8_t a = vrshrq_n_s16(residual, 4);
+ const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+ const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+ vst1_u8(dst, d);
+ dst += stride;
+ }
+ // Remaining widths >= 16.
+ } else {
+ for (int i = 0; i < tx_height; ++i) {
+ const int y = start_y + i;
+ const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width;
+ int j = 0;
+ do {
+ const int x = start_x + j;
+ const int16x8_t residual = vld1q_s16(&source[row + j]);
+ const int16x8_t residual_hi = vld1q_s16(&source[row + j + 8]);
+ const uint8x16_t frame_data = vld1q_u8(frame[y] + x);
+ const int16x8_t a = vrshrq_n_s16(residual, 4);
+ const int16x8_t a_hi = vrshrq_n_s16(residual_hi, 4);
+ const uint16x8_t b =
+ vaddw_u8(vreinterpretq_u16_s16(a), vget_low_u8(frame_data));
+ const uint16x8_t b_hi =
+ vaddw_u8(vreinterpretq_u16_s16(a_hi), vget_high_u8(frame_data));
+ vst1q_u8(frame[y] + x,
+ vcombine_u8(vqmovun_s16(vreinterpretq_s16_u16(b)),
+ vqmovun_s16(vreinterpretq_s16_u16(b_hi))));
+ j += 16;
+ } while (j < tx_width);
+ }
+ }
+}
+
+void Dct4TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int /*start_x*/, int /*start_y*/,
+ void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_height = kTransformHeight[tx_size];
+ const bool should_round = (tx_height == 8);
+ const int row_shift = (tx_height == 16);
+
+ if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<4>(src, adjusted_tx_height);
+ }
+
+ if (adjusted_tx_height == 4) {
+ // Process 4 1d dct4 rows in parallel.
+ Dct4_NEON<ButterflyRotation_4, false>(src, /*step=*/4, /*transpose=*/true);
+ } else {
+ // Process 8 1d dct4 rows in parallel per iteration.
+ int i = adjusted_tx_height;
+ auto* data = src;
+ do {
+ Dct4_NEON<ButterflyRotation_8, true>(data, /*step=*/4,
+ /*transpose=*/true);
+ data += 32;
+ i -= 8;
+ } while (i != 0);
+ }
+ if (tx_height == 16) {
+ RowShift<4>(src, adjusted_tx_height, 1);
+ }
+}
+
+void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<4>(src, tx_width);
+ }
+
+ if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) {
+ if (tx_width == 4) {
+ // Process 4 1d dct4 columns in parallel.
+ Dct4_NEON<ButterflyRotation_4, false>(src, tx_width, /*transpose=*/false);
+ } else {
+ // Process 8 1d dct4 columns in parallel per iteration.
+ int i = tx_width;
+ auto* data = src;
+ do {
+ Dct4_NEON<ButterflyRotation_8, true>(data, tx_width,
+ /*transpose=*/false);
+ data += 8;
+ i -= 8;
+ } while (i != 0);
+ }
+ }
+
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type);
+}
+
+void Dct8TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int /*start_x*/, int /*start_y*/,
+ void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<8>(src, adjusted_tx_height);
+ }
+
+ if (adjusted_tx_height == 4) {
+ // Process 4 1d dct8 rows in parallel.
+ Dct8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true);
+ } else {
+ // Process 8 1d dct8 rows in parallel per iteration.
+ assert(adjusted_tx_height % 8 == 0);
+ int i = adjusted_tx_height;
+ auto* data = src;
+ do {
+ Dct8_NEON<ButterflyRotation_8, false>(data, /*step=*/8,
+ /*transpose=*/true);
+ data += 64;
+ i -= 8;
+ } while (i != 0);
+ }
+ if (row_shift > 0) {
+ RowShift<8>(src, adjusted_tx_height, row_shift);
+ }
+}
+
+void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<8>(src, tx_width);
+ }
+
+ if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) {
+ if (tx_width == 4) {
+ // Process 4 1d dct8 columns in parallel.
+ Dct8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
+ } else {
+ // Process 8 1d dct8 columns in parallel per iteration.
+ int i = tx_width;
+ auto* data = src;
+ do {
+ Dct8_NEON<ButterflyRotation_8, false>(data, tx_width,
+ /*transpose=*/false);
+ data += 8;
+ i -= 8;
+ } while (i != 0);
+ }
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type);
+}
+
+void Dct16TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size, int adjusted_tx_height,
+ void* src_buffer, int /*start_x*/,
+ int /*start_y*/, void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<16>(src, adjusted_tx_height);
+ }
+
+ if (adjusted_tx_height == 4) {
+ // Process 4 1d dct16 rows in parallel.
+ Dct16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift);
+ } else {
+ assert(adjusted_tx_height % 8 == 0);
+ int i = adjusted_tx_height;
+ do {
+ // Process 8 1d dct16 rows in parallel per iteration.
+ Dct16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true,
+ row_shift);
+ src += 128;
+ i -= 8;
+ } while (i != 0);
+ }
+}
+
+void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<16>(src, tx_width);
+ }
+
+ if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) {
+ if (tx_width == 4) {
+ // Process 4 1d dct16 columns in parallel.
+ Dct16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false,
+ /*row_shift=*/0);
+ } else {
+ int i = tx_width;
+ auto* data = src;
+ do {
+ // Process 8 1d dct16 columns in parallel per iteration.
+ Dct16_NEON<ButterflyRotation_8, false>(data, tx_width, /*is_row=*/false,
+ /*row_shift=*/0);
+ data += 8;
+ i -= 8;
+ } while (i != 0);
+ }
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type);
+}
+
+void Dct32TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size, int adjusted_tx_height,
+ void* src_buffer, int /*start_x*/,
+ int /*start_y*/, void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<32>(src, adjusted_tx_height);
+ }
+ // Process 8 1d dct32 rows in parallel per iteration.
+ int i = 0;
+ do {
+ Dct32_NEON(&src[i * 32], 32, /*is_row=*/true, row_shift);
+ i += 8;
+ } while (i < adjusted_tx_height);
+}
+
+void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) {
+ // Process 8 1d dct32 columns in parallel per iteration.
+ int i = tx_width;
+ auto* data = src;
+ do {
+ Dct32_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
+ data += 8;
+ i -= 8;
+ } while (i != 0);
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type);
+}
+
+void Dct64TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size, int adjusted_tx_height,
+ void* src_buffer, int /*start_x*/,
+ int /*start_y*/, void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<64>(src, adjusted_tx_height);
+ }
+ // Process 8 1d dct64 rows in parallel per iteration.
+ int i = 0;
+ do {
+ Dct64_NEON(&src[i * 64], 64, /*is_row=*/true, row_shift);
+ i += 8;
+ } while (i < adjusted_tx_height);
+}
+
+void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) {
+ // Process 8 1d dct64 columns in parallel per iteration.
+ int i = tx_width;
+ auto* data = src;
+ do {
+ Dct64_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
+ data += 8;
+ i -= 8;
+ } while (i != 0);
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type);
+}
+
+void Adst4TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size, int adjusted_tx_height,
+ void* src_buffer, int /*start_x*/,
+ int /*start_y*/, void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_height = kTransformHeight[tx_size];
+ const int row_shift = static_cast<int>(tx_height == 16);
+ const bool should_round = (tx_height == 8);
+
+ if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<4>(src, adjusted_tx_height);
+ }
+
+ // Process 4 1d adst4 rows in parallel per iteration.
+ int i = adjusted_tx_height;
+ auto* data = src;
+ do {
+ Adst4_NEON<false>(data, /*step=*/4, /*transpose=*/true);
+ data += 16;
+ i -= 4;
+ } while (i != 0);
+
+ if (tx_height == 16) {
+ RowShift<4>(src, adjusted_tx_height, 1);
+ }
+}
+
+void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<4>(src, tx_width);
+ }
+
+ if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
+ // Process 4 1d adst4 columns in parallel per iteration.
+ int i = tx_width;
+ auto* data = src;
+ do {
+ Adst4_NEON<false>(data, tx_width, /*transpose=*/false);
+ data += 4;
+ i -= 4;
+ } while (i != 0);
+ }
+
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y,
+ tx_width, src, tx_type);
+}
+
+void Adst8TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size, int adjusted_tx_height,
+ void* src_buffer, int /*start_x*/,
+ int /*start_y*/, void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<8>(src, adjusted_tx_height);
+ }
+
+ if (adjusted_tx_height == 4) {
+ // Process 4 1d adst8 rows in parallel.
+ Adst8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true);
+ } else {
+ // Process 8 1d adst8 rows in parallel per iteration.
+ assert(adjusted_tx_height % 8 == 0);
+ int i = adjusted_tx_height;
+ auto* data = src;
+ do {
+ Adst8_NEON<ButterflyRotation_8, false>(data, /*step=*/8,
+ /*transpose=*/true);
+ data += 64;
+ i -= 8;
+ } while (i != 0);
+ }
+ if (row_shift > 0) {
+ RowShift<8>(src, adjusted_tx_height, row_shift);
+ }
+}
+
+void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<8>(src, tx_width);
+ }
+
+ if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
+ if (tx_width == 4) {
+ // Process 4 1d adst8 columns in parallel.
+ Adst8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
+ } else {
+ // Process 8 1d adst8 columns in parallel per iteration.
+ int i = tx_width;
+ auto* data = src;
+ do {
+ Adst8_NEON<ButterflyRotation_8, false>(data, tx_width,
+ /*transpose=*/false);
+ data += 8;
+ i -= 8;
+ } while (i != 0);
+ }
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y,
+ tx_width, src, tx_type);
+}
+
+void Adst16TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size, int adjusted_tx_height,
+ void* src_buffer, int /*start_x*/,
+ int /*start_y*/, void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<16>(src, adjusted_tx_height);
+ }
+
+ if (adjusted_tx_height == 4) {
+ // Process 4 1d adst16 rows in parallel.
+ Adst16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift);
+ } else {
+ assert(adjusted_tx_height % 8 == 0);
+ int i = adjusted_tx_height;
+ do {
+ // Process 8 1d adst16 rows in parallel per iteration.
+ Adst16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true,
+ row_shift);
+ src += 128;
+ i -= 8;
+ } while (i != 0);
+ }
+}
+
+void Adst16TransformLoopColumn_NEON(TransformType tx_type,
+ TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<16>(src, tx_width);
+ }
+
+ if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
+ if (tx_width == 4) {
+ // Process 4 1d adst16 columns in parallel.
+ Adst16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false,
+ /*row_shift=*/0);
+ } else {
+ int i = tx_width;
+ auto* data = src;
+ do {
+ // Process 8 1d adst16 columns in parallel per iteration.
+ Adst16_NEON<ButterflyRotation_8, false>(
+ data, tx_width, /*is_row=*/false, /*row_shift=*/0);
+ data += 8;
+ i -= 8;
+ } while (i != 0);
+ }
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y,
+ tx_width, src, tx_type);
+}
+
+void Identity4TransformLoopRow_NEON(TransformType tx_type,
+ TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int /*start_x*/, int /*start_y*/,
+ void* /*dst_frame*/) {
+ // Special case: Process row calculations during column transform call.
+ // Improves performance.
+ if (tx_type == kTransformTypeIdentityIdentity &&
+ tx_size == kTransformSize4x4) {
+ return;
+ }
+
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_height = kTransformHeight[tx_size];
+ const bool should_round = (tx_height == 8);
+
+ if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<4>(src, adjusted_tx_height);
+ }
+ if (tx_height < 16) {
+ int i = adjusted_tx_height;
+ do {
+ Identity4_NEON<false>(src, /*step=*/4);
+ src += 16;
+ i -= 4;
+ } while (i != 0);
+ } else {
+ int i = adjusted_tx_height;
+ do {
+ Identity4_NEON<true>(src, /*step=*/4);
+ src += 16;
+ i -= 4;
+ } while (i != 0);
+ }
+}
+
+void Identity4TransformLoopColumn_NEON(TransformType tx_type,
+ TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y,
+ void* dst_frame) {
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ // Special case: Process row calculations during column transform call.
+ if (tx_type == kTransformTypeIdentityIdentity &&
+ (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
+ Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width,
+ adjusted_tx_height, src);
+ return;
+ }
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<4>(src, tx_width);
+ }
+
+ IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width,
+ adjusted_tx_height, src);
+}
+
+void Identity8TransformLoopRow_NEON(TransformType tx_type,
+ TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int /*start_x*/, int /*start_y*/,
+ void* /*dst_frame*/) {
+ // Special case: Process row calculations during column transform call.
+ // Improves performance.
+ if (tx_type == kTransformTypeIdentityIdentity &&
+ tx_size == kTransformSize8x4) {
+ return;
+ }
+
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_height = kTransformHeight[tx_size];
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<8>(src, adjusted_tx_height);
+ }
+
+ // When combining the identity8 multiplier with the row shift, the
+ // calculations for tx_height == 8 and tx_height == 16 can be simplified
+  // from ((A * 2) + 1) >> 1 to A.
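+  // (tx_height & 0x18) != 0 selects tx_height 8 or 16, so the row pass is a
+  // no-op for those sizes.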
+ if ((tx_height & 0x18) != 0) {
+ return;
+ }
+ if (tx_height == 32) {
+ int i = adjusted_tx_height;
+ do {
+ Identity8Row32_NEON(src, /*step=*/8);
+ src += 32;
+ i -= 4;
+ } while (i != 0);
+ return;
+ }
+
+ assert(tx_size == kTransformSize8x4);
+ int i = adjusted_tx_height;
+ do {
+ Identity8Row4_NEON(src, /*step=*/8);
+ src += 32;
+ i -= 4;
+ } while (i != 0);
+}
+
+void Identity8TransformLoopColumn_NEON(TransformType tx_type,
+ TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y,
+ void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<8>(src, tx_width);
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width,
+ adjusted_tx_height, src);
+}
+
+void Identity16TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int /*start_x*/, int /*start_y*/,
+ void* /*dst_frame*/) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const bool should_round = kShouldRound[tx_size];
+ const uint8_t row_shift = kTransformRowShift[tx_size];
+
+ if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
+ return;
+ }
+
+ if (should_round) {
+ ApplyRounding<16>(src, adjusted_tx_height);
+ }
+ int i = adjusted_tx_height;
+ do {
+ Identity16Row_NEON(src, /*step=*/16, kTransformRowShift[tx_size]);
+ src += 64;
+ i -= 4;
+ } while (i != 0);
+}
+
+void Identity16TransformLoopColumn_NEON(TransformType tx_type,
+ TransformSize tx_size,
+ int adjusted_tx_height,
+ void* src_buffer, int start_x,
+ int start_y, void* dst_frame) {
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ if (kTransformFlipColumnsMask.Contains(tx_type)) {
+ FlipColumns<16>(src, tx_width);
+ }
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width,
+ adjusted_tx_height, src);
+}
+
+void Identity32TransformLoopRow_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int /*start_x*/, int /*start_y*/,
+ void* /*dst_frame*/) {
+ const int tx_height = kTransformHeight[tx_size];
+
+ // When combining the identity32 multiplier with the row shift, the
+ // calculations for tx_height == 8 and tx_height == 32 can be simplified
+  // from ((A * 4) + 2) >> 2 to A.
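+  // (tx_height & 0x28) != 0 selects tx_height 8 or 32; only 32x16 falls
+  // through to the code below.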
+ if ((tx_height & 0x28) != 0) {
+ return;
+ }
+
+ // Process kTransformSize32x16. The src is always rounded before the
+ // identity transform and shifted by 1 afterwards.
+ auto* src = static_cast<int16_t*>(src_buffer);
+ if (Identity32DcOnly(src, adjusted_tx_height)) {
+ return;
+ }
+
+ assert(tx_size == kTransformSize32x16);
+ ApplyRounding<32>(src, adjusted_tx_height);
+ int i = adjusted_tx_height;
+ do {
+ Identity32Row16_NEON(src, /*step=*/32);
+ src += 128;
+ i -= 4;
+ } while (i != 0);
+}
+
+void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/,
+ TransformSize tx_size,
+ int adjusted_tx_height,
+ void* src_buffer, int start_x,
+ int start_y, void* dst_frame) {
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ auto* src = static_cast<int16_t*>(src_buffer);
+ const int tx_width = kTransformWidth[tx_size];
+
+ IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width,
+ adjusted_tx_height, src);
+}
+
+void Wht4TransformLoopRow_NEON(TransformType tx_type, TransformSize tx_size,
+ int /*adjusted_tx_height*/, void* /*src_buffer*/,
+ int /*start_x*/, int /*start_y*/,
+ void* /*dst_frame*/) {
+ assert(tx_type == kTransformTypeDctDct);
+ assert(tx_size == kTransformSize4x4);
+ static_cast<void>(tx_type);
+ static_cast<void>(tx_size);
+ // Do both row and column transforms in the column-transform pass.
+}
+
+void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+ int adjusted_tx_height, void* src_buffer,
+ int start_x, int start_y, void* dst_frame) {
+ assert(tx_type == kTransformTypeDctDct);
+ assert(tx_size == kTransformSize4x4);
+ static_cast<void>(tx_type);
+ static_cast<void>(tx_size);
+
+ // Process 4 1d wht4 rows and columns in parallel.
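+  // Wht4_NEON() adds the result directly into |dst|, so no separate frame
+  // store step is needed.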
+ const auto* src = static_cast<int16_t*>(src_buffer);
+ auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+ uint8_t* dst = frame[start_y] + start_x;
+ const int dst_stride = frame.columns();
+ Wht4_NEON(dst, dst_stride, src, adjusted_tx_height);
+}
+
+//------------------------------------------------------------------------------
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ // Maximum transform size for Dct is 64.
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] =
+ Dct4TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] =
+ Dct4TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] =
+ Dct8TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] =
+ Dct8TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] =
+ Dct16TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] =
+ Dct16TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] =
+ Dct32TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] =
+ Dct32TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] =
+ Dct64TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] =
+ Dct64TransformLoopColumn_NEON;
+
+ // Maximum transform size for Adst is 16.
+ dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] =
+ Adst4TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] =
+ Adst4TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] =
+ Adst8TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] =
+ Adst8TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] =
+ Adst16TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] =
+ Adst16TransformLoopColumn_NEON;
+
+ // Maximum transform size for Identity transform is 32.
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] =
+ Identity4TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] =
+ Identity4TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] =
+ Identity8TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] =
+ Identity8TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] =
+ Identity16TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] =
+ Identity16TransformLoopColumn_NEON;
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] =
+ Identity32TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] =
+ Identity32TransformLoopColumn_NEON;
+
+ // Maximum transform size for Wht is 4.
+ dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] =
+ Wht4TransformLoopRow_NEON;
+ dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] =
+ Wht4TransformLoopColumn_NEON;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void InverseTransformInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void InverseTransformInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/inverse_transform_neon.h b/src/dsp/arm/inverse_transform_neon.h
new file mode 100644
index 0000000..af647e8
--- /dev/null
+++ b/src/dsp/arm/inverse_transform_neon.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::inverse_transforms, see the defines below for specifics.
+// This function is not thread-safe.
+void InverseTransformInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
diff --git a/src/dsp/arm/loop_filter_neon.cc b/src/dsp/arm/loop_filter_neon.cc
new file mode 100644
index 0000000..146c983
--- /dev/null
+++ b/src/dsp/arm/loop_filter_neon.cc
@@ -0,0 +1,1190 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/loop_filter.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
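+// Vector layout note: a uint8x8_t named |pNqN| holds four pN pixels in its
+// low 32 bits and the corresponding four qN pixels in its high 32 bits.
+// RightShift<32>() moves the q-side lanes down so both halves can be combined
+// or compared in a single operation.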
+// (abs(p1 - p0) > thresh) || (abs(q1 - q0) > thresh)
+inline uint8x8_t Hev(const uint8x8_t abd_p0p1_q0q1, const uint8_t thresh) {
+ const uint8x8_t a = vcgt_u8(abd_p0p1_q0q1, vdup_n_u8(thresh));
+ return vorr_u8(a, RightShift<32>(a));
+}
+
+// abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh
+inline uint8x8_t OuterThreshold(const uint8x8_t p0q0, const uint8x8_t p1q1,
+ const uint8_t outer_thresh) {
+ const uint8x8x2_t a = Interleave32(p0q0, p1q1);
+ const uint8x8_t b = vabd_u8(a.val[0], a.val[1]);
+ const uint8x8_t p0q0_double = vqadd_u8(b, b);
+ const uint8x8_t p1q1_half = RightShift<32>(vshr_n_u8(b, 1));
+ const uint8x8_t c = vqadd_u8(p0q0_double, p1q1_half);
+ return vcle_u8(c, vdup_n_u8(outer_thresh));
+}
+
+// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
+// OuterThreshold()
+inline uint8x8_t NeedsFilter4(const uint8x8_t abd_p0p1_q0q1,
+ const uint8x8_t p0q0, const uint8x8_t p1q1,
+ const uint8_t inner_thresh,
+ const uint8_t outer_thresh) {
+ const uint8x8_t a = vcle_u8(abd_p0p1_q0q1, vdup_n_u8(inner_thresh));
+ const uint8x8_t inner_mask = vand_u8(a, RightShift<32>(a));
+ const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh);
+ return vand_u8(inner_mask, outer_mask);
+}
+
+inline void Filter4Masks(const uint8x8_t p0q0, const uint8x8_t p1q1,
+ const uint8_t hev_thresh, const uint8_t outer_thresh,
+ const uint8_t inner_thresh, uint8x8_t* const hev_mask,
+ uint8x8_t* const needs_filter4_mask) {
+ const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1);
+ // This includes cases where NeedsFilter4() is not true and so Filter2() will
+ // not be applied.
+ const uint8x8_t hev_tmp_mask = Hev(p0p1_q0q1, hev_thresh);
+
+ *needs_filter4_mask =
+ NeedsFilter4(p0p1_q0q1, p0q0, p1q1, inner_thresh, outer_thresh);
+
+ // Filter2() will only be applied if both NeedsFilter4() and Hev() are true.
+ *hev_mask = vand_u8(hev_tmp_mask, *needs_filter4_mask);
+}
+
+// Calculate Filter4() or Filter2() based on |hev_mask|.
+inline void Filter4(const uint8x8_t q0p1, const uint8x8_t p0q1,
+ const uint8x8_t hev_mask, uint8x8_t* const p1q1_result,
+ uint8x8_t* const p0q0_result) {
+ const int16x4_t zero = vdup_n_s16(0);
+
+ // a = 3 * (q0 - p0) + Clip3(p1 - q1, min_signed_val, max_signed_val);
+ const int16x8_t q0mp0_p1mq1 = vreinterpretq_s16_u16(vsubl_u8(q0p1, p0q1));
+ const int16x4_t q0mp0_3 = vmul_n_s16(vget_low_s16(q0mp0_p1mq1), 3);
+
+ // If this is for Filter2() then include |p1mq1|. Otherwise zero it.
+ const int16x4_t p1mq1 = vget_high_s16(q0mp0_p1mq1);
+ const int8x8_t p1mq1_saturated = vqmovn_s16(vcombine_s16(p1mq1, zero));
+ const int8x8_t hev_option =
+ vand_s8(vreinterpret_s8_u8(hev_mask), p1mq1_saturated);
+
+ const int16x4_t a =
+ vget_low_s16(vaddw_s8(vcombine_s16(q0mp0_3, zero), hev_option));
+
+  // We cannot shift with rounding because the clamp comes *before* the
+  // shifting.
+  // a1 = Clip3(a + 4, min_signed_val, max_signed_val) >> 3;
+  // a2 = Clip3(a + 3, min_signed_val, max_signed_val) >> 3;
+ const int16x4_t plus_four = vadd_s16(a, vdup_n_s16(4));
+ const int16x4_t plus_three = vadd_s16(a, vdup_n_s16(3));
+ const int8x8_t a2_a1 =
+ vshr_n_s8(vqmovn_s16(vcombine_s16(plus_three, plus_four)), 3);
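+  // |a2_a1| now holds a2 (from a + 3) in its low 4 lanes and a1 (from a + 4)
+  // in its high 4 lanes.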
+
+ // a3 is in the high 4 values.
+ // a3 = (a1 + 1) >> 1;
+ const int8x8_t a3 = vrshr_n_s8(a2_a1, 1);
+
+ const int16x8_t p0q1_l = vreinterpretq_s16_u16(vmovl_u8(p0q1));
+ const int16x8_t q0p1_l = vreinterpretq_s16_u16(vmovl_u8(q0p1));
+
+ const int16x8_t p1q1_l =
+ vcombine_s16(vget_high_s16(q0p1_l), vget_high_s16(p0q1_l));
+
+ const int8x8_t a3_ma3 = InterleaveHigh32(a3, vneg_s8(a3));
+ const int16x8_t p1q1_a3 = vaddw_s8(p1q1_l, a3_ma3);
+
+ const int16x8_t p0q0_l =
+ vcombine_s16(vget_low_s16(p0q1_l), vget_low_s16(q0p1_l));
+ // Need to shift the second term or we end up with a2_ma2.
+ const int8x8_t a2_ma1 =
+ InterleaveLow32(a2_a1, RightShift<32>(vneg_s8(a2_a1)));
+ const int16x8_t p0q0_a = vaddw_s8(p0q0_l, a2_ma1);
+
+ *p1q1_result = vqmovun_s16(p1q1_a3);
+ *p0q0_result = vqmovun_s16(p0q0_a);
+}
+
+void Horizontal4_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ const uint8x8_t p1_v = Load4(dst - 2 * stride);
+ const uint8x8_t p0_v = Load4(dst - stride);
+ const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+ const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
+
+ uint8x8_t hev_mask;
+ uint8x8_t needs_filter4_mask;
+ Filter4Masks(p0q0, p1q1, hev_thresh, outer_thresh, inner_thresh, &hev_mask,
+ &needs_filter4_mask);
+
+ // Copy the masks to the high bits for packed comparisons later.
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+ needs_filter4_mask = InterleaveLow32(needs_filter4_mask, needs_filter4_mask);
+
+#if defined(__aarch64__)
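+  // Shortcut: if DctDcOnly() handles the block (only the DC coefficient is
+  // present), the full row transform below is skipped.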
+ // This provides a good speedup for the unit test. Not sure how applicable it
+ // is to valid streams though.
+ // Consider doing this on armv7 if there is a quick way to check if a vector
+ // is zero.
+ if (vaddv_u8(needs_filter4_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t f_p1q1;
+ uint8x8_t f_p0q0;
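+  // Repack as q0p1 / p0q1 so that Filter4()'s single widening subtract yields
+  // both q0 - p0 and p1 - q1.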
+ const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
+ Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
+
+ // Already integrated the Hev mask when calculating the filtered values.
+ const uint8x8_t p0q0_output = vbsl_u8(needs_filter4_mask, f_p0q0, p0q0);
+
+ // p1/q1 are unmodified if only Hev() is true. This works because it was and'd
+ // with |needs_filter4_mask| previously.
+ const uint8x8_t p1q1_mask = veor_u8(hev_mask, needs_filter4_mask);
+ const uint8x8_t p1q1_output = vbsl_u8(p1q1_mask, f_p1q1, p1q1);
+
+ StoreLo4(dst - 2 * stride, p1q1_output);
+ StoreLo4(dst - stride, p0q0_output);
+ StoreHi4(dst, p0q0_output);
+ StoreHi4(dst + stride, p1q1_output);
+}
+
+void Vertical4_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ uint8_t* dst = static_cast<uint8_t*>(dest);
+
+ // Move |dst| to the left side of the filter window.
+ dst -= 2;
+
+ // |p1q0| and |p0q1| are named for the values they will contain after the
+ // transpose.
+ const uint8x8_t row0 = Load4(dst);
+ uint8x8_t p1q0 = Load4<1>(dst + stride, row0);
+ const uint8x8_t row2 = Load4(dst + 2 * stride);
+ uint8x8_t p0q1 = Load4<1>(dst + 3 * stride, row2);
+
+ Transpose4x4(&p1q0, &p0q1);
+ // Rearrange.
+ const uint8x8x2_t p1q1xq0p0 = Interleave32(p1q0, Transpose32(p0q1));
+ const uint8x8x2_t p1q1xp0q0 = {p1q1xq0p0.val[0],
+ Transpose32(p1q1xq0p0.val[1])};
+
+ uint8x8_t hev_mask;
+ uint8x8_t needs_filter4_mask;
+ Filter4Masks(p1q1xp0q0.val[1], p1q1xp0q0.val[0], hev_thresh, outer_thresh,
+ inner_thresh, &hev_mask, &needs_filter4_mask);
+
+ // Copy the masks to the high bits for packed comparisons later.
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+ needs_filter4_mask = InterleaveLow32(needs_filter4_mask, needs_filter4_mask);
+
+#if defined(__aarch64__)
+ // This provides a good speedup for the unit test. Not sure how applicable it
+ // is to valid streams though.
+ // Consider doing this on armv7 if there is a quick way to check if a vector
+ // is zero.
+ if (vaddv_u8(needs_filter4_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t f_p1q1;
+ uint8x8_t f_p0q0;
+ Filter4(Transpose32(p1q0), p0q1, hev_mask, &f_p1q1, &f_p0q0);
+
+ // Already integrated the Hev mask when calculating the filtered values.
+ const uint8x8_t p0q0_output =
+ vbsl_u8(needs_filter4_mask, f_p0q0, p1q1xp0q0.val[1]);
+
+ // p1/q1 are unmodified if only Hev() is true. This works because it was and'd
+ // with |needs_filter4_mask| previously.
+ const uint8x8_t p1q1_mask = veor_u8(hev_mask, needs_filter4_mask);
+ const uint8x8_t p1q1_output = vbsl_u8(p1q1_mask, f_p1q1, p1q1xp0q0.val[0]);
+
+ // Put things back in order to reverse the transpose.
+ const uint8x8x2_t p1p0xq1q0 = Interleave32(p1q1_output, p0q0_output);
+ uint8x8_t output_0 = p1p0xq1q0.val[0],
+ output_1 = Transpose32(p1p0xq1q0.val[1]);
+
+ Transpose4x4(&output_0, &output_1);
+
+ StoreLo4(dst, output_0);
+ StoreLo4(dst + stride, output_1);
+ StoreHi4(dst + 2 * stride, output_0);
+ StoreHi4(dst + 3 * stride, output_1);
+}
+
+// abs(p1 - p0) <= flat_thresh && abs(q1 - q0) <= flat_thresh &&
+// abs(p2 - p0) <= flat_thresh && abs(q2 - q0) <= flat_thresh
+// |flat_thresh| == 1 for 8 bit decode.
+inline uint8x8_t IsFlat3(const uint8x8_t abd_p0p1_q0q1,
+ const uint8x8_t abd_p0p2_q0q2) {
+ const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p0p2_q0q2);
+ const uint8x8_t b = vcle_u8(a, vdup_n_u8(1));
+ return vand_u8(b, RightShift<32>(b));
+}
+
+// abs(p2 - p1) <= inner_thresh && abs(p1 - p0) <= inner_thresh &&
+// abs(q1 - q0) <= inner_thresh && abs(q2 - q1) <= inner_thresh &&
+// OuterThreshold()
+inline uint8x8_t NeedsFilter6(const uint8x8_t abd_p0p1_q0q1,
+ const uint8x8_t abd_p1p2_q1q2,
+ const uint8x8_t p0q0, const uint8x8_t p1q1,
+ const uint8_t inner_thresh,
+ const uint8_t outer_thresh) {
+ const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p1p2_q1q2);
+ const uint8x8_t b = vcle_u8(a, vdup_n_u8(inner_thresh));
+ const uint8x8_t inner_mask = vand_u8(b, RightShift<32>(b));
+ const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh);
+ return vand_u8(inner_mask, outer_mask);
+}
+
+inline void Filter6Masks(const uint8x8_t p2q2, const uint8x8_t p1q1,
+ const uint8x8_t p0q0, const uint8_t hev_thresh,
+ const uint8_t outer_thresh, const uint8_t inner_thresh,
+ uint8x8_t* const needs_filter6_mask,
+ uint8x8_t* const is_flat3_mask,
+ uint8x8_t* const hev_mask) {
+ const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1);
+ *hev_mask = Hev(p0p1_q0q1, hev_thresh);
+ *is_flat3_mask = IsFlat3(p0p1_q0q1, vabd_u8(p0q0, p2q2));
+ *needs_filter6_mask = NeedsFilter6(p0p1_q0q1, vabd_u8(p1q1, p2q2), p0q0, p1q1,
+ inner_thresh, outer_thresh);
+}
+
+inline void Filter6(const uint8x8_t p2q2, const uint8x8_t p1q1,
+ const uint8x8_t p0q0, uint8x8_t* const p1q1_output,
+ uint8x8_t* const p0q0_output) {
+ // Sum p1 and q1 output from opposite directions
+ // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
+ // ^^^^^^^^
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
+ // ^^^^^^^^
+ const uint16x8_t p2q2_double = vaddl_u8(p2q2, p2q2);
+ uint16x8_t sum = vaddw_u8(p2q2_double, p2q2);
+
+ // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
+ // ^^^^^^^^
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
+ // ^^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p1q1, p1q1), sum);
+
+ // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
+ // ^^^^^^^^
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
+ // ^^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p0q0, p0q0), sum);
+
+ // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
+ // ^^
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
+ // ^^
+ const uint8x8_t q0p0 = Transpose32(p0q0);
+ sum = vaddw_u8(sum, q0p0);
+
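+  // vrshrn_n_u16(sum, 3) performs the (x + 4) >> 3 rounding and narrows back
+  // to 8 bits.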
+ *p1q1_output = vrshrn_n_u16(sum, 3);
+
+ // Convert to p0 and q0 output:
+ // p0 = p1 - (2 * p2) + q0 + q1
+ // q0 = q1 - (2 * q2) + p0 + p1
+ sum = vsubq_u16(sum, p2q2_double);
+ const uint8x8_t q1p1 = Transpose32(p1q1);
+ sum = vaddq_u16(vaddl_u8(q0p0, q1p1), sum);
+
+ *p0q0_output = vrshrn_n_u16(sum, 3);
+}
+
+void Horizontal6_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ auto* dst = static_cast<uint8_t*>(dest);
+
+ const uint8x8_t p2_v = Load4(dst - 3 * stride);
+ const uint8x8_t p1_v = Load4(dst - 2 * stride);
+ const uint8x8_t p0_v = Load4(dst - stride);
+ const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+ const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
+ const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v);
+
+ uint8x8_t needs_filter6_mask, is_flat3_mask, hev_mask;
+ Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
+ &needs_filter6_mask, &is_flat3_mask, &hev_mask);
+
+ needs_filter6_mask = InterleaveLow32(needs_filter6_mask, needs_filter6_mask);
+ is_flat3_mask = InterleaveLow32(is_flat3_mask, is_flat3_mask);
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+
+#if defined(__aarch64__)
+ // This provides a good speedup for the unit test. Not sure how applicable it
+ // is to valid streams though.
+ // Consider doing this on armv7 if there is a quick way to check if a vector
+ // is zero.
+ if (vaddv_u8(needs_filter6_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t f_p1q1;
+ uint8x8_t f_p0q0;
+ const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
+ Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
+ // Reset the outer values if only a Hev() mask was required.
+ f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1);
+
+ uint8x8_t f6_p1q1, f6_p0q0;
+#if defined(__aarch64__)
+ if (vaddv_u8(vand_u8(is_flat3_mask, needs_filter6_mask)) == 0) {
+ // Filter6() does not apply.
+ const uint8x8_t zero = vdup_n_u8(0);
+ f6_p1q1 = zero;
+ f6_p0q0 = zero;
+ } else {
+#endif // defined(__aarch64__)
+ Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0);
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t p1q1_output = vbsl_u8(is_flat3_mask, f6_p1q1, f_p1q1);
+ p1q1_output = vbsl_u8(needs_filter6_mask, p1q1_output, p1q1);
+ StoreLo4(dst - 2 * stride, p1q1_output);
+ StoreHi4(dst + stride, p1q1_output);
+
+ uint8x8_t p0q0_output = vbsl_u8(is_flat3_mask, f6_p0q0, f_p0q0);
+ p0q0_output = vbsl_u8(needs_filter6_mask, p0q0_output, p0q0);
+ StoreLo4(dst - stride, p0q0_output);
+ StoreHi4(dst, p0q0_output);
+}
+
+void Vertical6_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ auto* dst = static_cast<uint8_t*>(dest);
+
+ // Move |dst| to the left side of the filter window.
+ dst -= 3;
+
+ // |p2q1|, |p1q2|, |p0xx| and |q0xx| are named for the values they will
+ // contain after the transpose.
+ // These over-read by 2 bytes. We only need 6.
+ uint8x8_t p2q1 = vld1_u8(dst);
+ uint8x8_t p1q2 = vld1_u8(dst + stride);
+ uint8x8_t p0xx = vld1_u8(dst + 2 * stride);
+ uint8x8_t q0xx = vld1_u8(dst + 3 * stride);
+
+ Transpose8x4(&p2q1, &p1q2, &p0xx, &q0xx);
+
+ const uint8x8x2_t p2q2xq1p1 = Interleave32(p2q1, Transpose32(p1q2));
+ const uint8x8_t p2q2 = p2q2xq1p1.val[0];
+ const uint8x8_t p1q1 = Transpose32(p2q2xq1p1.val[1]);
+ const uint8x8_t p0q0 = InterleaveLow32(p0xx, q0xx);
+
+ uint8x8_t needs_filter6_mask, is_flat3_mask, hev_mask;
+ Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
+ &needs_filter6_mask, &is_flat3_mask, &hev_mask);
+
+ needs_filter6_mask = InterleaveLow32(needs_filter6_mask, needs_filter6_mask);
+ is_flat3_mask = InterleaveLow32(is_flat3_mask, is_flat3_mask);
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+
+#if defined(__aarch64__)
+ // This provides a good speedup for the unit test. Not sure how applicable it
+ // is to valid streams though.
+ // Consider doing this on armv7 if there is a quick way to check if a vector
+ // is zero.
+ if (vaddv_u8(needs_filter6_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t f_p1q1;
+ uint8x8_t f_p0q0;
+ const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
+ Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
+ // Reset the outer values if only a Hev() mask was required.
+ f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1);
+
+ uint8x8_t f6_p1q1, f6_p0q0;
+#if defined(__aarch64__)
+ if (vaddv_u8(vand_u8(is_flat3_mask, needs_filter6_mask)) == 0) {
+ // Filter6() does not apply.
+ const uint8x8_t zero = vdup_n_u8(0);
+ f6_p1q1 = zero;
+ f6_p0q0 = zero;
+ } else {
+#endif // defined(__aarch64__)
+ Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0);
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t p1q1_output = vbsl_u8(is_flat3_mask, f6_p1q1, f_p1q1);
+ p1q1_output = vbsl_u8(needs_filter6_mask, p1q1_output, p1q1);
+
+ uint8x8_t p0q0_output = vbsl_u8(is_flat3_mask, f6_p0q0, f_p0q0);
+ p0q0_output = vbsl_u8(needs_filter6_mask, p0q0_output, p0q0);
+
+  // The six-tap filter reads p2..q2 but only writes p1..q1, so advance |dst|
+  // past the p2 column before storing.
+ dst += 1;
+ // Put things back in order to reverse the transpose.
+ const uint8x8x2_t p1p0xq1q0 = Interleave32(p1q1_output, p0q0_output);
+ uint8x8_t output_0 = p1p0xq1q0.val[0];
+ uint8x8_t output_1 = Transpose32(p1p0xq1q0.val[1]);
+
+ Transpose4x4(&output_0, &output_1);
+
+ StoreLo4(dst, output_0);
+ StoreLo4(dst + stride, output_1);
+ StoreHi4(dst + 2 * stride, output_0);
+ StoreHi4(dst + 3 * stride, output_1);
+}
+
+// IsFlat4 uses N=1, IsFlatOuter4 uses N=4.
+// abs(p[N] - p0) <= flat_thresh && abs(q[N] - q0) <= flat_thresh &&
+// abs(p[N+1] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh &&
+// abs(p[N+2] - p0) <= flat_thresh && abs(q[N+2] - q0) <= flat_thresh
+// |flat_thresh| == 1 for 8 bit decode.
+inline uint8x8_t IsFlat4(const uint8x8_t abd_p0n0_q0n0,
+ const uint8x8_t abd_p0n1_q0n1,
+ const uint8x8_t abd_p0n2_q0n2) {
+ const uint8x8_t a = vmax_u8(abd_p0n0_q0n0, abd_p0n1_q0n1);
+ const uint8x8_t b = vmax_u8(a, abd_p0n2_q0n2);
+ const uint8x8_t c = vcle_u8(b, vdup_n_u8(1));
+ return vand_u8(c, RightShift<32>(c));
+}
+
+// abs(p3 - p2) <= inner_thresh && abs(p2 - p1) <= inner_thresh &&
+// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
+// abs(q2 - q1) <= inner_thresh && abs(q3 - q2) <= inner_thresh &&
+// OuterThreshold()
+inline uint8x8_t NeedsFilter8(const uint8x8_t abd_p0p1_q0q1,
+ const uint8x8_t abd_p1p2_q1q2,
+ const uint8x8_t abd_p2p3_q2q3,
+ const uint8x8_t p0q0, const uint8x8_t p1q1,
+ const uint8_t inner_thresh,
+ const uint8_t outer_thresh) {
+ const uint8x8_t a = vmax_u8(abd_p0p1_q0q1, abd_p1p2_q1q2);
+ const uint8x8_t b = vmax_u8(a, abd_p2p3_q2q3);
+ const uint8x8_t c = vcle_u8(b, vdup_n_u8(inner_thresh));
+ const uint8x8_t inner_mask = vand_u8(c, RightShift<32>(c));
+ const uint8x8_t outer_mask = OuterThreshold(p0q0, p1q1, outer_thresh);
+ return vand_u8(inner_mask, outer_mask);
+}
+
+inline void Filter8Masks(const uint8x8_t p3q3, const uint8x8_t p2q2,
+ const uint8x8_t p1q1, const uint8x8_t p0q0,
+ const uint8_t hev_thresh, const uint8_t outer_thresh,
+ const uint8_t inner_thresh,
+ uint8x8_t* const needs_filter8_mask,
+ uint8x8_t* const is_flat4_mask,
+ uint8x8_t* const hev_mask) {
+ const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1);
+ *hev_mask = Hev(p0p1_q0q1, hev_thresh);
+ *is_flat4_mask = IsFlat4(p0p1_q0q1, vabd_u8(p0q0, p2q2), vabd_u8(p0q0, p3q3));
+ *needs_filter8_mask =
+ NeedsFilter8(p0p1_q0q1, vabd_u8(p1q1, p2q2), vabd_u8(p2q2, p3q3), p0q0,
+ p1q1, inner_thresh, outer_thresh);
+}
+
+inline void Filter8(const uint8x8_t p3q3, const uint8x8_t p2q2,
+ const uint8x8_t p1q1, const uint8x8_t p0q0,
+ uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output,
+ uint8x8_t* const p0q0_output) {
+ // Sum p2 and q2 output from opposite directions
+ // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
+ // ^^^^^^^^
+ // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
+ // ^^^^^^^^
+ uint16x8_t sum = vaddw_u8(vaddl_u8(p3q3, p3q3), p3q3);
+
+ // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
+ // ^^^^^^^^
+ // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
+ // ^^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p2q2, p2q2), sum);
+
+ // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
+ // ^^^^^^^
+ // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
+ // ^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum);
+
+ // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
+ // ^^
+ // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
+ // ^^
+ const uint8x8_t q0p0 = Transpose32(p0q0);
+ sum = vaddw_u8(sum, q0p0);
+
+ *p2q2_output = vrshrn_n_u16(sum, 3);
+
+ // Convert to p1 and q1 output:
+ // p1 = p2 - p3 - p2 + p1 + q1
+  // q1 = q2 - q3 - q2 + q1 + p1
+ sum = vsubq_u16(sum, vaddl_u8(p3q3, p2q2));
+ const uint8x8_t q1p1 = Transpose32(p1q1);
+ sum = vaddq_u16(vaddl_u8(p1q1, q1p1), sum);
+
+ *p1q1_output = vrshrn_n_u16(sum, 3);
+
+ // Convert to p0 and q0 output:
+ // p0 = p1 - p3 - p1 + p0 + q2
+ // q0 = q1 - q3 - q1 + q0 + p2
+ sum = vsubq_u16(sum, vaddl_u8(p3q3, p1q1));
+ const uint8x8_t q2p2 = Transpose32(p2q2);
+ sum = vaddq_u16(vaddl_u8(p0q0, q2p2), sum);
+
+ *p0q0_output = vrshrn_n_u16(sum, 3);
+}
+
+void Horizontal8_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ auto* dst = static_cast<uint8_t*>(dest);
+
+ const uint8x8_t p3_v = Load4(dst - 4 * stride);
+ const uint8x8_t p2_v = Load4(dst - 3 * stride);
+ const uint8x8_t p1_v = Load4(dst - 2 * stride);
+ const uint8x8_t p0_v = Load4(dst - stride);
+ const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+ const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
+ const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v);
+ const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v);
+
+ uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask;
+ Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
+ &needs_filter8_mask, &is_flat4_mask, &hev_mask);
+
+ needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask);
+ is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask);
+ is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask);
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+
+#if defined(__aarch64__)
+ // This provides a good speedup for the unit test. Not sure how applicable it
+ // is to valid streams though.
+ // Consider doing this on armv7 if there is a quick way to check if a vector
+ // is zero.
+ if (vaddv_u8(needs_filter8_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t f_p1q1;
+ uint8x8_t f_p0q0;
+ const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
+ Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
+ // Reset the outer values if only a Hev() mask was required.
+ f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1);
+
+ uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0;
+#if defined(__aarch64__)
+ if (vaddv_u8(is_flat4_mask) == 0) {
+ // Filter8() does not apply.
+ const uint8x8_t zero = vdup_n_u8(0);
+ f8_p2q2 = zero;
+ f8_p1q1 = zero;
+ f8_p0q0 = zero;
+ } else {
+#endif // defined(__aarch64__)
+ Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+
+    const uint8x8_t p2q2_output = vbsl_u8(is_flat4_mask, f8_p2q2, p2q2);
+    StoreLo4(dst - 3 * stride, p2q2_output);
+    StoreHi4(dst + 2 * stride, p2q2_output);
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t p1q1_output = vbsl_u8(is_flat4_mask, f8_p1q1, f_p1q1);
+ p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1);
+ StoreLo4(dst - 2 * stride, p1q1_output);
+ StoreHi4(dst + stride, p1q1_output);
+
+ uint8x8_t p0q0_output = vbsl_u8(is_flat4_mask, f8_p0q0, f_p0q0);
+ p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0);
+ StoreLo4(dst - stride, p0q0_output);
+ StoreHi4(dst, p0q0_output);
+}
+
+void Vertical8_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ // Move |dst| to the left side of the filter window.
+ dst -= 4;
+
+ // |p3q0|, |p2q1|, |p1q2| and |p0q3| are named for the values they will
+ // contain after the transpose.
+ uint8x8_t p3q0 = vld1_u8(dst);
+ uint8x8_t p2q1 = vld1_u8(dst + stride);
+ uint8x8_t p1q2 = vld1_u8(dst + 2 * stride);
+ uint8x8_t p0q3 = vld1_u8(dst + 3 * stride);
+
+ Transpose8x4(&p3q0, &p2q1, &p1q2, &p0q3);
+ const uint8x8x2_t p3q3xq0p0 = Interleave32(p3q0, Transpose32(p0q3));
+ const uint8x8_t p3q3 = p3q3xq0p0.val[0];
+ const uint8x8_t p0q0 = Transpose32(p3q3xq0p0.val[1]);
+ const uint8x8x2_t p2q2xq1p1 = Interleave32(p2q1, Transpose32(p1q2));
+ const uint8x8_t p2q2 = p2q2xq1p1.val[0];
+ const uint8x8_t p1q1 = Transpose32(p2q2xq1p1.val[1]);
+
+ uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask;
+ Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
+ &needs_filter8_mask, &is_flat4_mask, &hev_mask);
+
+ needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask);
+ is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask);
+ is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask);
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+
+#if defined(__aarch64__)
+ // This provides a good speedup for the unit test. Not sure how applicable it
+ // is to valid streams though.
+ // Consider doing this on armv7 if there is a quick way to check if a vector
+ // is zero.
+ if (vaddv_u8(needs_filter8_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t f_p1q1;
+ uint8x8_t f_p0q0;
+ const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
+ Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
+ // Reset the outer values if only a Hev() mask was required.
+ f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1);
+
+ uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0;
+#if defined(__aarch64__)
+ if (vaddv_u8(is_flat4_mask) == 0) {
+ // Filter8() does not apply.
+ const uint8x8_t zero = vdup_n_u8(0);
+ f8_p2q2 = zero;
+ f8_p1q1 = zero;
+ f8_p0q0 = zero;
+ } else {
+#endif // defined(__aarch64__)
+ Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+
+ // Always prepare and store p2/q2 because we need to transpose it anyway.
+ const uint8x8_t p2q2_output = vbsl_u8(is_flat4_mask, f8_p2q2, p2q2);
+
+ uint8x8_t p1q1_output = vbsl_u8(is_flat4_mask, f8_p1q1, f_p1q1);
+ p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1);
+
+ uint8x8_t p0q0_output = vbsl_u8(is_flat4_mask, f8_p0q0, f_p0q0);
+ p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0);
+
+ // Write out p3/q3 as well. There isn't a good way to write out 6 bytes.
+ // Variable names reflect the values before transposition.
+ const uint8x8x2_t p3q0xq3p0_output =
+ Interleave32(p3q3, Transpose32(p0q0_output));
+ uint8x8_t p3q0_output = p3q0xq3p0_output.val[0];
+ uint8x8_t p0q3_output = Transpose32(p3q0xq3p0_output.val[1]);
+ const uint8x8x2_t p2q1xq2p1_output =
+ Interleave32(p2q2_output, Transpose32(p1q1_output));
+ uint8x8_t p2q1_output = p2q1xq2p1_output.val[0];
+ uint8x8_t p1q2_output = Transpose32(p2q1xq2p1_output.val[1]);
+
+ Transpose8x4(&p3q0_output, &p2q1_output, &p1q2_output, &p0q3_output);
+
+ vst1_u8(dst, p3q0_output);
+ vst1_u8(dst + stride, p2q1_output);
+ vst1_u8(dst + 2 * stride, p1q2_output);
+ vst1_u8(dst + 3 * stride, p0q3_output);
+}
+
+inline void Filter14(const uint8x8_t p6q6, const uint8x8_t p5q5,
+ const uint8x8_t p4q4, const uint8x8_t p3q3,
+ const uint8x8_t p2q2, const uint8x8_t p1q1,
+ const uint8x8_t p0q0, uint8x8_t* const p5q5_output,
+ uint8x8_t* const p4q4_output, uint8x8_t* const p3q3_output,
+ uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output,
+ uint8x8_t* const p0q0_output) {
+ // Sum p5 and q5 output from opposite directions
+ // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+ // ^^^^^^^^
+ // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+ // ^^^^^^^^
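+  // 7 * p6q6 is computed as (p6q6 << 3) - p6q6 while widening to 16 bits.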
+ uint16x8_t sum = vsubw_u8(vshll_n_u8(p6q6, 3), p6q6);
+
+ // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+ // ^^^^^^^^
+ // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+ // ^^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p5q5, p5q5), sum);
+
+ // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+ // ^^^^^^^^
+ // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+ // ^^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p4q4, p4q4), sum);
+
+ // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+ // ^^^^^^^
+ // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+ // ^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p3q3, p2q2), sum);
+
+ // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+ // ^^^^^^^
+ // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+ // ^^^^^^^
+ sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum);
+
+ // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+ // ^^
+ // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+ // ^^
+ const uint8x8_t q0p0 = Transpose32(p0q0);
+ sum = vaddw_u8(sum, q0p0);
+
+ *p5q5_output = vrshrn_n_u16(sum, 4);
+
+ // Convert to p4 and q4 output:
+ // p4 = p5 - (2 * p6) + p3 + q1
+ // q4 = q5 - (2 * q6) + q3 + p1
+ sum = vsubq_u16(sum, vaddl_u8(p6q6, p6q6));
+ const uint8x8_t q1p1 = Transpose32(p1q1);
+ sum = vaddq_u16(vaddl_u8(p3q3, q1p1), sum);
+
+ *p4q4_output = vrshrn_n_u16(sum, 4);
+
+ // Convert to p3 and q3 output:
+ // p3 = p4 - p6 - p5 + p2 + q2
+ // q3 = q4 - q6 - q5 + q2 + p2
+ sum = vsubq_u16(sum, vaddl_u8(p6q6, p5q5));
+ const uint8x8_t q2p2 = Transpose32(p2q2);
+ sum = vaddq_u16(vaddl_u8(p2q2, q2p2), sum);
+
+ *p3q3_output = vrshrn_n_u16(sum, 4);
+
+ // Convert to p2 and q2 output:
+ // p2 = p3 - p6 - p4 + p1 + q3
+ // q2 = q3 - q6 - q4 + q1 + p3
+ sum = vsubq_u16(sum, vaddl_u8(p6q6, p4q4));
+ const uint8x8_t q3p3 = Transpose32(p3q3);
+ sum = vaddq_u16(vaddl_u8(p1q1, q3p3), sum);
+
+ *p2q2_output = vrshrn_n_u16(sum, 4);
+
+ // Convert to p1 and q1 output:
+ // p1 = p2 - p6 - p3 + p0 + q4
+ // q1 = q2 - q6 - q3 + q0 + p4
+ sum = vsubq_u16(sum, vaddl_u8(p6q6, p3q3));
+ const uint8x8_t q4p4 = Transpose32(p4q4);
+ sum = vaddq_u16(vaddl_u8(p0q0, q4p4), sum);
+
+ *p1q1_output = vrshrn_n_u16(sum, 4);
+
+ // Convert to p0 and q0 output:
+ // p0 = p1 - p6 - p2 + q0 + q5
+ // q0 = q1 - q6 - q2 + p0 + p5
+ sum = vsubq_u16(sum, vaddl_u8(p6q6, p2q2));
+ const uint8x8_t q5p5 = Transpose32(p5q5);
+ sum = vaddq_u16(vaddl_u8(q0p0, q5p5), sum);
+
+ *p0q0_output = vrshrn_n_u16(sum, 4);
+}
+
+void Horizontal14_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ auto* dst = static_cast<uint8_t*>(dest);
+
+ const uint8x8_t p6_v = Load4(dst - 7 * stride);
+ const uint8x8_t p5_v = Load4(dst - 6 * stride);
+ const uint8x8_t p4_v = Load4(dst - 5 * stride);
+ const uint8x8_t p3_v = Load4(dst - 4 * stride);
+ const uint8x8_t p2_v = Load4(dst - 3 * stride);
+ const uint8x8_t p1_v = Load4(dst - 2 * stride);
+ const uint8x8_t p0_v = Load4(dst - stride);
+ const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+ const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
+ const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v);
+ const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v);
+ const uint8x8_t p4q4 = Load4<1>(dst + 4 * stride, p4_v);
+ const uint8x8_t p5q5 = Load4<1>(dst + 5 * stride, p5_v);
+ const uint8x8_t p6q6 = Load4<1>(dst + 6 * stride, p6_v);
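+  // Each pNqN vector holds the four pN pixels in its low half and the four
+  // qN pixels in its high half.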
+
+ uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask;
+ Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
+ &needs_filter8_mask, &is_flat4_mask, &hev_mask);
+
+ needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask);
+ is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask);
+ is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask);
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+
+#if defined(__aarch64__)
+  // This early exit gives a good speedup in the unit test, though it is
+  // unclear how often it helps on valid streams. Consider doing this on
+  // armv7 if there is a quick way to check whether a vector is zero.
+ if (vaddv_u8(needs_filter8_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ // Decide between Filter8() and Filter14().
+ uint8x8_t is_flat_outer4_mask =
+ IsFlat4(vabd_u8(p0q0, p4q4), vabd_u8(p0q0, p5q5), vabd_u8(p0q0, p6q6));
+ is_flat_outer4_mask = vand_u8(is_flat4_mask, is_flat_outer4_mask);
+ is_flat_outer4_mask =
+ InterleaveLow32(is_flat_outer4_mask, is_flat_outer4_mask);
+
+ uint8x8_t f_p1q1;
+ uint8x8_t f_p0q0;
+ const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
+ Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
+ // Reset the outer values if only a Hev() mask was required.
+ f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1);
+
+ uint8x8_t f8_p1q1, f8_p0q0;
+ uint8x8_t f14_p2q2, f14_p1q1, f14_p0q0;
+#if defined(__aarch64__)
+ if (vaddv_u8(is_flat4_mask) == 0) {
+ // Filter8() and Filter14() do not apply.
+ const uint8x8_t zero = vdup_n_u8(0);
+ f8_p1q1 = zero;
+ f8_p0q0 = zero;
+ f14_p1q1 = zero;
+ f14_p0q0 = zero;
+ } else {
+#endif // defined(__aarch64__)
+ uint8x8_t f8_p2q2;
+ Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+
+#if defined(__aarch64__)
+ if (vaddv_u8(is_flat_outer4_mask) == 0) {
+ // Filter14() does not apply.
+ const uint8x8_t zero = vdup_n_u8(0);
+ f14_p2q2 = zero;
+ f14_p1q1 = zero;
+ f14_p0q0 = zero;
+ } else {
+#endif // defined(__aarch64__)
+ uint8x8_t f14_p5q5, f14_p4q4, f14_p3q3;
+ Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4,
+ &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0);
+
+ const uint8x8_t p5q5_output =
+ vbsl_u8(is_flat_outer4_mask, f14_p5q5, p5q5);
+ StoreLo4(dst - 6 * stride, p5q5_output);
+ StoreHi4(dst + 5 * stride, p5q5_output);
+
+ const uint8x8_t p4q4_output =
+ vbsl_u8(is_flat_outer4_mask, f14_p4q4, p4q4);
+ StoreLo4(dst - 5 * stride, p4q4_output);
+ StoreHi4(dst + 4 * stride, p4q4_output);
+
+ const uint8x8_t p3q3_output =
+ vbsl_u8(is_flat_outer4_mask, f14_p3q3, p3q3);
+ StoreLo4(dst - 4 * stride, p3q3_output);
+ StoreHi4(dst + 3 * stride, p3q3_output);
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t p2q2_output = vbsl_u8(is_flat_outer4_mask, f14_p2q2, f8_p2q2);
+ p2q2_output = vbsl_u8(is_flat4_mask, p2q2_output, p2q2);
+ StoreLo4(dst - 3 * stride, p2q2_output);
+ StoreHi4(dst + 2 * stride, p2q2_output);
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+
+ uint8x8_t p1q1_output = vbsl_u8(is_flat_outer4_mask, f14_p1q1, f8_p1q1);
+ p1q1_output = vbsl_u8(is_flat4_mask, p1q1_output, f_p1q1);
+ p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1);
+ StoreLo4(dst - 2 * stride, p1q1_output);
+ StoreHi4(dst + stride, p1q1_output);
+
+ uint8x8_t p0q0_output = vbsl_u8(is_flat_outer4_mask, f14_p0q0, f8_p0q0);
+ p0q0_output = vbsl_u8(is_flat4_mask, p0q0_output, f_p0q0);
+ p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0);
+ StoreLo4(dst - stride, p0q0_output);
+ StoreHi4(dst, p0q0_output);
+}
+
+void Vertical14_NEON(void* const dest, const ptrdiff_t stride,
+ const int outer_thresh, const int inner_thresh,
+ const int hev_thresh) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ dst -= 8;
+ // input
+ // p7 p6 p5 p4 p3 p2 p1 p0 q0 q1 q2 q3 q4 q5 q6 q7
+ const uint8x16_t x0 = vld1q_u8(dst);
+ dst += stride;
+ const uint8x16_t x1 = vld1q_u8(dst);
+ dst += stride;
+ const uint8x16_t x2 = vld1q_u8(dst);
+ dst += stride;
+ const uint8x16_t x3 = vld1q_u8(dst);
+ dst -= (stride * 3);
+
+ // re-order input
+#if defined(__aarch64__)
+ const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607);
+ const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203);
+ const uint8x16_t index_qp7toqp0 = vcombine_u8(index_qp3toqp0, index_qp7toqp4);
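+  // vcreate_u8() takes a little-endian 64-bit literal, so the low byte of
+  // the constant selects lane 0; indices {7, 6, 5, 4, 8, 9, 10, 11} gather
+  // p0 p1 p2 p3 q0 q1 q2 q3.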
+
+ uint8x16_t input_0 = vqtbl1q_u8(x0, index_qp7toqp0);
+ uint8x16_t input_1 = vqtbl1q_u8(x1, index_qp7toqp0);
+ uint8x16_t input_2 = vqtbl1q_u8(x2, index_qp7toqp0);
+ uint8x16_t input_3 = vqtbl1q_u8(x3, index_qp7toqp0);
+#else
+ const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607);
+ const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203);
+
+ const uint8x8_t x0_qp3qp0 = VQTbl1U8(x0, index_qp3toqp0);
+ const uint8x8_t x1_qp3qp0 = VQTbl1U8(x1, index_qp3toqp0);
+ const uint8x8_t x2_qp3qp0 = VQTbl1U8(x2, index_qp3toqp0);
+ const uint8x8_t x3_qp3qp0 = VQTbl1U8(x3, index_qp3toqp0);
+
+ const uint8x8_t x0_qp7qp4 = VQTbl1U8(x0, index_qp7toqp4);
+ const uint8x8_t x1_qp7qp4 = VQTbl1U8(x1, index_qp7toqp4);
+ const uint8x8_t x2_qp7qp4 = VQTbl1U8(x2, index_qp7toqp4);
+ const uint8x8_t x3_qp7qp4 = VQTbl1U8(x3, index_qp7toqp4);
+
+ const uint8x16_t input_0 = vcombine_u8(x0_qp3qp0, x0_qp7qp4);
+ const uint8x16_t input_1 = vcombine_u8(x1_qp3qp0, x1_qp7qp4);
+ const uint8x16_t input_2 = vcombine_u8(x2_qp3qp0, x2_qp7qp4);
+ const uint8x16_t input_3 = vcombine_u8(x3_qp3qp0, x3_qp7qp4);
+#endif
+ // input after re-order
+ // p0 p1 p2 p3 q0 q1 q2 q3 p4 p5 p6 p7 q4 q5 q6 q7
+
+ const uint8x16x2_t in01 = vtrnq_u8(input_0, input_1);
+ const uint8x16x2_t in23 = vtrnq_u8(input_2, input_3);
+ const uint16x8x2_t in02 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[0]),
+ vreinterpretq_u16_u8(in23.val[0]));
+ const uint16x8x2_t in13 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[1]),
+ vreinterpretq_u16_u8(in23.val[1]));
+
+ const uint8x8_t p0q0 = vget_low_u8(vreinterpretq_u8_u16(in02.val[0]));
+ const uint8x8_t p1q1 = vget_low_u8(vreinterpretq_u8_u16(in13.val[0]));
+
+ const uint8x8_t p2q2 = vget_low_u8(vreinterpretq_u8_u16(in02.val[1]));
+ const uint8x8_t p3q3 = vget_low_u8(vreinterpretq_u8_u16(in13.val[1]));
+
+ const uint8x8_t p4q4 = vget_high_u8(vreinterpretq_u8_u16(in02.val[0]));
+ const uint8x8_t p5q5 = vget_high_u8(vreinterpretq_u8_u16(in13.val[0]));
+
+ const uint8x8_t p6q6 = vget_high_u8(vreinterpretq_u8_u16(in02.val[1]));
+ const uint8x8_t p7q7 = vget_high_u8(vreinterpretq_u8_u16(in13.val[1]));
+
+ uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask;
+ Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
+ &needs_filter8_mask, &is_flat4_mask, &hev_mask);
+
+ needs_filter8_mask = InterleaveLow32(needs_filter8_mask, needs_filter8_mask);
+ is_flat4_mask = vand_u8(is_flat4_mask, needs_filter8_mask);
+ is_flat4_mask = InterleaveLow32(is_flat4_mask, is_flat4_mask);
+ hev_mask = InterleaveLow32(hev_mask, hev_mask);
+
+#if defined(__aarch64__)
+  // This early exit gives a good speedup in the unit test, though it is
+  // unclear how often it helps on valid streams. Consider doing this on
+  // armv7 if there is a quick way to check whether a vector is zero.
+ if (vaddv_u8(needs_filter8_mask) == 0) {
+ // None of the values will be filtered.
+ return;
+ }
+#endif // defined(__aarch64__)
+
+ // Decide between Filter8() and Filter14().
+ uint8x8_t is_flat_outer4_mask =
+ IsFlat4(vabd_u8(p0q0, p4q4), vabd_u8(p0q0, p5q5), vabd_u8(p0q0, p6q6));
+ is_flat_outer4_mask = vand_u8(is_flat4_mask, is_flat_outer4_mask);
+ is_flat_outer4_mask =
+ InterleaveLow32(is_flat_outer4_mask, is_flat_outer4_mask);
+
+ uint8x8_t f_p0q0, f_p1q1;
+ const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
+ Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
+ // Reset the outer values if only a Hev() mask was required.
+ f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1);
+
+ uint8x8_t p1q1_output, p0q0_output;
+ uint8x8_t p5q5_output, p4q4_output, p3q3_output, p2q2_output;
+
+#if defined(__aarch64__)
+ if (vaddv_u8(is_flat4_mask) == 0) {
+ // Filter8() and Filter14() do not apply.
+ p1q1_output = p1q1;
+ p0q0_output = p0q0;
+
+ p5q5_output = p5q5;
+ p4q4_output = p4q4;
+ p3q3_output = p3q3;
+ p2q2_output = p2q2;
+ } else {
+#endif // defined(__aarch64__)
+ uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0;
+ Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+
+#if defined(__aarch64__)
+ if (vaddv_u8(is_flat_outer4_mask) == 0) {
+ // Filter14() does not apply.
+ p5q5_output = p5q5;
+ p4q4_output = p4q4;
+ p3q3_output = p3q3;
+ p2q2_output = f8_p2q2;
+ p1q1_output = f8_p1q1;
+ p0q0_output = f8_p0q0;
+ } else {
+#endif // defined(__aarch64__)
+ uint8x8_t f14_p5q5, f14_p4q4, f14_p3q3, f14_p2q2, f14_p1q1, f14_p0q0;
+ Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4,
+ &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0);
+
+ p5q5_output = vbsl_u8(is_flat_outer4_mask, f14_p5q5, p5q5);
+ p4q4_output = vbsl_u8(is_flat_outer4_mask, f14_p4q4, p4q4);
+ p3q3_output = vbsl_u8(is_flat_outer4_mask, f14_p3q3, p3q3);
+ p2q2_output = vbsl_u8(is_flat_outer4_mask, f14_p2q2, f8_p2q2);
+ p1q1_output = vbsl_u8(is_flat_outer4_mask, f14_p1q1, f8_p1q1);
+ p0q0_output = vbsl_u8(is_flat_outer4_mask, f14_p0q0, f8_p0q0);
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+ p2q2_output = vbsl_u8(is_flat4_mask, p2q2_output, p2q2);
+#if defined(__aarch64__)
+ }
+#endif // defined(__aarch64__)
+
+ p1q1_output = vbsl_u8(is_flat4_mask, p1q1_output, f_p1q1);
+ p1q1_output = vbsl_u8(needs_filter8_mask, p1q1_output, p1q1);
+ p0q0_output = vbsl_u8(is_flat4_mask, p0q0_output, f_p0q0);
+ p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0);
+
+ const uint8x16_t p0q0_p4q4 = vcombine_u8(p0q0_output, p4q4_output);
+ const uint8x16_t p2q2_p6q6 = vcombine_u8(p2q2_output, p6q6);
+ const uint8x16_t p1q1_p5q5 = vcombine_u8(p1q1_output, p5q5_output);
+ const uint8x16_t p3q3_p7q7 = vcombine_u8(p3q3_output, p7q7);
+
+ const uint16x8x2_t out02 = vtrnq_u16(vreinterpretq_u16_u8(p0q0_p4q4),
+ vreinterpretq_u16_u8(p2q2_p6q6));
+ const uint16x8x2_t out13 = vtrnq_u16(vreinterpretq_u16_u8(p1q1_p5q5),
+ vreinterpretq_u16_u8(p3q3_p7q7));
+ const uint8x16x2_t out01 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[0]),
+ vreinterpretq_u8_u16(out13.val[0]));
+ const uint8x16x2_t out23 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[1]),
+ vreinterpretq_u8_u16(out13.val[1]));
+
+#if defined(__aarch64__)
+ const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b);
+ const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504);
+ const uint8x16_t index_p7toq7 = vcombine_u8(index_p7top0, index_q7toq0);
+
+ const uint8x16_t output_0 = vqtbl1q_u8(out01.val[0], index_p7toq7);
+ const uint8x16_t output_1 = vqtbl1q_u8(out01.val[1], index_p7toq7);
+ const uint8x16_t output_2 = vqtbl1q_u8(out23.val[0], index_p7toq7);
+ const uint8x16_t output_3 = vqtbl1q_u8(out23.val[1], index_p7toq7);
+#else
+ const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b);
+ const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504);
+
+ const uint8x8_t x0_p7p0 = VQTbl1U8(out01.val[0], index_p7top0);
+ const uint8x8_t x1_p7p0 = VQTbl1U8(out01.val[1], index_p7top0);
+ const uint8x8_t x2_p7p0 = VQTbl1U8(out23.val[0], index_p7top0);
+ const uint8x8_t x3_p7p0 = VQTbl1U8(out23.val[1], index_p7top0);
+
+ const uint8x8_t x0_q7q0 = VQTbl1U8(out01.val[0], index_q7toq0);
+ const uint8x8_t x1_q7q0 = VQTbl1U8(out01.val[1], index_q7toq0);
+ const uint8x8_t x2_q7q0 = VQTbl1U8(out23.val[0], index_q7toq0);
+ const uint8x8_t x3_q7q0 = VQTbl1U8(out23.val[1], index_q7toq0);
+
+ const uint8x16_t output_0 = vcombine_u8(x0_p7p0, x0_q7q0);
+ const uint8x16_t output_1 = vcombine_u8(x1_p7p0, x1_q7q0);
+ const uint8x16_t output_2 = vcombine_u8(x2_p7p0, x2_q7q0);
+ const uint8x16_t output_3 = vcombine_u8(x3_p7p0, x3_q7q0);
+#endif
+
+ vst1q_u8(dst, output_0);
+ dst += stride;
+ vst1q_u8(dst, output_1);
+ dst += stride;
+ vst1q_u8(dst, output_2);
+ dst += stride;
+ vst1q_u8(dst, output_3);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] =
+ Horizontal4_NEON;
+ dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeVertical] = Vertical4_NEON;
+
+ dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] =
+ Horizontal6_NEON;
+ dsp->loop_filters[kLoopFilterSize6][kLoopFilterTypeVertical] = Vertical6_NEON;
+
+ dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] =
+ Horizontal8_NEON;
+ dsp->loop_filters[kLoopFilterSize8][kLoopFilterTypeVertical] = Vertical8_NEON;
+
+ dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] =
+ Horizontal14_NEON;
+ dsp->loop_filters[kLoopFilterSize14][kLoopFilterTypeVertical] =
+ Vertical14_NEON;
+}
+} // namespace
+} // namespace low_bitdepth
+
+void LoopFilterInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void LoopFilterInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/loop_filter_neon.h b/src/dsp/arm/loop_filter_neon.h
new file mode 100644
index 0000000..5f79200
--- /dev/null
+++ b/src/dsp/arm/loop_filter_neon.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::loop_filters, see the defines below for specifics. This
+// function is not thread-safe.
+void LoopFilterInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+
+#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal \
+ LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical LIBGAV1_CPU_NEON
+
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_
diff --git a/src/dsp/arm/loop_restoration_neon.cc b/src/dsp/arm/loop_restoration_neon.cc
new file mode 100644
index 0000000..337c9b4
--- /dev/null
+++ b/src/dsp/arm/loop_restoration_neon.cc
@@ -0,0 +1,1901 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/loop_restoration.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
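+// Shifts the concatenation of src.val[0] (low) and src.val[1] (high) right
+// by |bytes| bytes and returns the low half.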
+template <int bytes>
+inline uint8x8_t VshrU128(const uint8x8x2_t src) {
+ return vext_u8(src.val[0], src.val[1], bytes);
+}
+
+template <int bytes>
+inline uint16x8_t VshrU128(const uint16x8x2_t src) {
+ return vextq_u16(src.val[0], src.val[1], bytes / 2);
+}
+
+// Wiener
+
+// A local copy of the coefficients is needed so the compiler knows they do
+// not overlap other buffers; the 'const' keyword alone is not enough. In
+// practice the compiler does not emit an actual copy, since there are enough
+// registers in this case.
+inline void PopulateWienerCoefficients(
+ const RestorationUnitInfo& restoration_info, const int direction,
+ int16_t filter[4]) {
+ // In order to keep the horizontal pass intermediate values within 16 bits we
+ // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
+ for (int i = 0; i < 4; ++i) {
+ filter[i] = restoration_info.wiener_info.filter[direction][i];
+ }
+ if (direction == WienerInfo::kHorizontal) {
+ filter[3] -= 128;
+ }
+}
+
+inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1,
+ const int16_t filter, const int16x8_t sum) {
+ const int16x8_t ss = vreinterpretq_s16_u16(vaddl_u8(s0, s1));
+ return vmlaq_n_s16(sum, ss, filter);
+}
+
+inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1,
+ const int16_t filter,
+ const int16x8x2_t sum) {
+ int16x8x2_t d;
+ d.val[0] =
+ WienerHorizontal2(vget_low_u8(s0), vget_low_u8(s1), filter, sum.val[0]);
+ d.val[1] =
+ WienerHorizontal2(vget_high_u8(s0), vget_high_u8(s1), filter, sum.val[1]);
+ return d;
+}
+
+inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4],
+ int16x8_t sum, int16_t* const wiener_buffer) {
+ constexpr int offset =
+ 1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
+ constexpr int limit = (offset << 2) - 1;
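+  // Assuming kWienerFilterBits == 7 and kInterRoundBitsHorizontal == 3, the
+  // clamp below keeps |sum| in [-offset, limit - offset] = [-2048, 6143].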
+ const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2]));
+ const int16x8_t s_1 = ZeroExtend(s[1]);
+ sum = vmlaq_n_s16(sum, s_0_2, filter[2]);
+ sum = vmlaq_n_s16(sum, s_1, filter[3]);
+  // Calculate the scaled-down offset correction and add it to |sum| here to
+  // keep the result within the signed 16-bit range.
+ sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum,
+ kInterRoundBitsHorizontal);
+ sum = vmaxq_s16(sum, vdupq_n_s16(-offset));
+ sum = vminq_s16(sum, vdupq_n_s16(limit - offset));
+ vst1q_s16(wiener_buffer, sum);
+}
+
+inline void WienerHorizontalSum(const uint8x16_t src[3],
+ const int16_t filter[4], int16x8x2_t sum,
+ int16_t* const wiener_buffer) {
+ uint8x8_t s[3];
+ s[0] = vget_low_u8(src[0]);
+ s[1] = vget_low_u8(src[1]);
+ s[2] = vget_low_u8(src[2]);
+ WienerHorizontalSum(s, filter, sum.val[0], wiener_buffer);
+ s[0] = vget_high_u8(src[0]);
+ s[1] = vget_high_u8(src[1]);
+ s[2] = vget_high_u8(src[2]);
+ WienerHorizontalSum(s, filter, sum.val[1], wiener_buffer + 8);
+}
+
+inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
+ const ptrdiff_t width, const int height,
+ const int16_t filter[4],
+ int16_t** const wiener_buffer) {
+ for (int y = height; y != 0; --y) {
+ const uint8_t* src_ptr = src;
+ uint8x16_t s[8];
+ s[0] = vld1q_u8(src_ptr);
+ ptrdiff_t x = width;
+ do {
+ src_ptr += 16;
+ s[7] = vld1q_u8(src_ptr);
+ s[1] = vextq_u8(s[0], s[7], 1);
+ s[2] = vextq_u8(s[0], s[7], 2);
+ s[3] = vextq_u8(s[0], s[7], 3);
+ s[4] = vextq_u8(s[0], s[7], 4);
+ s[5] = vextq_u8(s[0], s[7], 5);
+ s[6] = vextq_u8(s[0], s[7], 6);
+ int16x8x2_t sum;
+ sum.val[0] = sum.val[1] = vdupq_n_s16(0);
+ sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
+ sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
+ WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
+ s[0] = s[7];
+ *wiener_buffer += 16;
+ x -= 16;
+ } while (x != 0);
+ src += src_stride;
+ }
+}
+
+inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
+ const ptrdiff_t width, const int height,
+ const int16_t filter[4],
+ int16_t** const wiener_buffer) {
+ for (int y = height; y != 0; --y) {
+ const uint8_t* src_ptr = src;
+ uint8x16_t s[6];
+ s[0] = vld1q_u8(src_ptr);
+ ptrdiff_t x = width;
+ do {
+ src_ptr += 16;
+ s[5] = vld1q_u8(src_ptr);
+ s[1] = vextq_u8(s[0], s[5], 1);
+ s[2] = vextq_u8(s[0], s[5], 2);
+ s[3] = vextq_u8(s[0], s[5], 3);
+ s[4] = vextq_u8(s[0], s[5], 4);
+ int16x8x2_t sum;
+ sum.val[0] = sum.val[1] = vdupq_n_s16(0);
+ sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
+ WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
+ s[0] = s[5];
+ *wiener_buffer += 16;
+ x -= 16;
+ } while (x != 0);
+ src += src_stride;
+ }
+}
+
+inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
+ const ptrdiff_t width, const int height,
+ const int16_t filter[4],
+ int16_t** const wiener_buffer) {
+ for (int y = height; y != 0; --y) {
+ const uint8_t* src_ptr = src;
+ uint8x16_t s[4];
+ s[0] = vld1q_u8(src_ptr);
+ ptrdiff_t x = width;
+ do {
+ src_ptr += 16;
+ s[3] = vld1q_u8(src_ptr);
+ s[1] = vextq_u8(s[0], s[3], 1);
+ s[2] = vextq_u8(s[0], s[3], 2);
+ int16x8x2_t sum;
+ sum.val[0] = sum.val[1] = vdupq_n_s16(0);
+ WienerHorizontalSum(s, filter, sum, *wiener_buffer);
+ s[0] = s[3];
+ *wiener_buffer += 16;
+ x -= 16;
+ } while (x != 0);
+ src += src_stride;
+ }
+}
+
+inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
+ const ptrdiff_t width, const int height,
+ int16_t** const wiener_buffer) {
+ for (int y = height; y != 0; --y) {
+ const uint8_t* src_ptr = src;
+ ptrdiff_t x = width;
+ do {
+ const uint8x16_t s = vld1q_u8(src_ptr);
+ const uint8x8_t s0 = vget_low_u8(s);
+ const uint8x8_t s1 = vget_high_u8(s);
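+      // With only the center tap active, the filtered value is the source
+      // pixel scaled by 16, hence the shift left by 4.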
+ const int16x8_t d0 = vreinterpretq_s16_u16(vshll_n_u8(s0, 4));
+ const int16x8_t d1 = vreinterpretq_s16_u16(vshll_n_u8(s1, 4));
+ vst1q_s16(*wiener_buffer + 0, d0);
+ vst1q_s16(*wiener_buffer + 8, d1);
+ src_ptr += 16;
+ *wiener_buffer += 16;
+ x -= 16;
+ } while (x != 0);
+ src += src_stride;
+ }
+}
+
+inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
+ const int16_t filter,
+ const int32x4x2_t sum) {
+ const int16x8_t a = vaddq_s16(a0, a1);
+ int32x4x2_t d;
+ d.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(a), filter);
+ d.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(a), filter);
+ return d;
+}
+
+inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
+ const int32x4x2_t sum) {
+ int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum);
+ d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]);
+ d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]);
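+  // Round the 32-bit sums by 11 bits and saturate down to 8-bit pixels.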
+ const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11);
+ const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11);
+ return vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16));
+}
+
+inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
+ const ptrdiff_t wiener_stride,
+ const int16_t filter[4],
+ int16x8_t a[7]) {
+ int32x4x2_t sum;
+ a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
+ a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
+ a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
+ a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride);
+ sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+ sum = WienerVertical2(a[0], a[6], filter[0], sum);
+ sum = WienerVertical2(a[1], a[5], filter[1], sum);
+ a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
+ a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
+ a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
+ return WienerVertical(a + 2, filter, sum);
+}
+
+inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer,
+ const ptrdiff_t wiener_stride,
+ const int16_t filter[4]) {
+ int16x8_t a[8];
+ int32x4x2_t sum;
+ uint8x8x2_t d;
+ d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
+ a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
+ sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+ sum = WienerVertical2(a[1], a[7], filter[0], sum);
+ sum = WienerVertical2(a[2], a[6], filter[1], sum);
+ d.val[1] = WienerVertical(a + 3, filter, sum);
+ return d;
+}
+
+inline void WienerVerticalTap7(const int16_t* wiener_buffer,
+ const ptrdiff_t width, const int height,
+ const int16_t filter[4], uint8_t* dst,
+ const ptrdiff_t dst_stride) {
+ for (int y = height >> 1; y != 0; --y) {
+ uint8_t* dst_ptr = dst;
+ ptrdiff_t x = width;
+ do {
+ uint8x8x2_t d[2];
+ d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter);
+ d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter);
+ vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
+ vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
+ wiener_buffer += 16;
+ dst_ptr += 16;
+ x -= 16;
+ } while (x != 0);
+ wiener_buffer += width;
+ dst += 2 * dst_stride;
+ }
+
+ if ((height & 1) != 0) {
+ ptrdiff_t x = width;
+ do {
+ int16x8_t a[7];
+ const uint8x8_t d0 =
+ WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a);
+ const uint8x8_t d1 =
+ WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a);
+ vst1q_u8(dst, vcombine_u8(d0, d1));
+ wiener_buffer += 16;
+ dst += 16;
+ x -= 16;
+ } while (x != 0);
+ }
+}
+
+inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
+ const ptrdiff_t wiener_stride,
+ const int16_t filter[4],
+ int16x8_t a[5]) {
+ a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
+ a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
+ a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
+ a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
+ a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
+ int32x4x2_t sum;
+ sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+ sum = WienerVertical2(a[0], a[4], filter[1], sum);
+ return WienerVertical(a + 1, filter, sum);
+}
+
+inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer,
+ const ptrdiff_t wiener_stride,
+ const int16_t filter[4]) {
+ int16x8_t a[6];
+ int32x4x2_t sum;
+ uint8x8x2_t d;
+ d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
+ a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
+ sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+ sum = WienerVertical2(a[1], a[5], filter[1], sum);
+ d.val[1] = WienerVertical(a + 2, filter, sum);
+ return d;
+}
+
+inline void WienerVerticalTap5(const int16_t* wiener_buffer,
+ const ptrdiff_t width, const int height,
+ const int16_t filter[4], uint8_t* dst,
+ const ptrdiff_t dst_stride) {
+ for (int y = height >> 1; y != 0; --y) {
+ uint8_t* dst_ptr = dst;
+ ptrdiff_t x = width;
+ do {
+ uint8x8x2_t d[2];
+ d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter);
+ d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter);
+ vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
+ vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
+ wiener_buffer += 16;
+ dst_ptr += 16;
+ x -= 16;
+ } while (x != 0);
+ wiener_buffer += width;
+ dst += 2 * dst_stride;
+ }
+
+ if ((height & 1) != 0) {
+ ptrdiff_t x = width;
+ do {
+ int16x8_t a[5];
+ const uint8x8_t d0 =
+ WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a);
+ const uint8x8_t d1 =
+ WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a);
+ vst1q_u8(dst, vcombine_u8(d0, d1));
+ wiener_buffer += 16;
+ dst += 16;
+ x -= 16;
+ } while (x != 0);
+ }
+}
+
+inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
+ const ptrdiff_t wiener_stride,
+ const int16_t filter[4],
+ int16x8_t a[3]) {
+ a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
+ a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
+ a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
+ int32x4x2_t sum;
+ sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+ return WienerVertical(a, filter, sum);
+}
+
+inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer,
+ const ptrdiff_t wiener_stride,
+ const int16_t filter[4]) {
+ int16x8_t a[4];
+ int32x4x2_t sum;
+ uint8x8x2_t d;
+ d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
+ a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
+ sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+ d.val[1] = WienerVertical(a + 1, filter, sum);
+ return d;
+}
+
+inline void WienerVerticalTap3(const int16_t* wiener_buffer,
+ const ptrdiff_t width, const int height,
+ const int16_t filter[4], uint8_t* dst,
+ const ptrdiff_t dst_stride) {
+ for (int y = height >> 1; y != 0; --y) {
+ uint8_t* dst_ptr = dst;
+ ptrdiff_t x = width;
+ do {
+ uint8x8x2_t d[2];
+ d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter);
+ d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter);
+ vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
+ vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
+ wiener_buffer += 16;
+ dst_ptr += 16;
+ x -= 16;
+ } while (x != 0);
+ wiener_buffer += width;
+ dst += 2 * dst_stride;
+ }
+
+ if ((height & 1) != 0) {
+ ptrdiff_t x = width;
+ do {
+ int16x8_t a[3];
+ const uint8x8_t d0 =
+ WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a);
+ const uint8x8_t d1 =
+ WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a);
+ vst1q_u8(dst, vcombine_u8(d0, d1));
+ wiener_buffer += 16;
+ dst += 16;
+ x -= 16;
+ } while (x != 0);
+ }
+}
+
+inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
+ uint8_t* const dst) {
+ const int16x8_t a0 = vld1q_s16(wiener_buffer + 0);
+ const int16x8_t a1 = vld1q_s16(wiener_buffer + 8);
+ const uint8x8_t d0 = vqrshrun_n_s16(a0, 4);
+ const uint8x8_t d1 = vqrshrun_n_s16(a1, 4);
+ vst1q_u8(dst, vcombine_u8(d0, d1));
+}
+
+inline void WienerVerticalTap1(const int16_t* wiener_buffer,
+ const ptrdiff_t width, const int height,
+ uint8_t* dst, const ptrdiff_t dst_stride) {
+ for (int y = height >> 1; y != 0; --y) {
+ uint8_t* dst_ptr = dst;
+ ptrdiff_t x = width;
+ do {
+ WienerVerticalTap1Kernel(wiener_buffer, dst_ptr);
+ WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride);
+ wiener_buffer += 16;
+ dst_ptr += 16;
+ x -= 16;
+ } while (x != 0);
+ wiener_buffer += width;
+ dst += 2 * dst_stride;
+ }
+
+ if ((height & 1) != 0) {
+ ptrdiff_t x = width;
+ do {
+ WienerVerticalTap1Kernel(wiener_buffer, dst);
+ wiener_buffer += 16;
+ dst += 16;
+ x -= 16;
+ } while (x != 0);
+ }
+}
+
+// For width 16 and up, store the horizontal results and then run the
+// vertical filter row by row. This is faster than filtering column by column
+// because it is more cache friendly.
+void WienerFilter_NEON(const RestorationUnitInfo& restoration_info,
+ const void* const source, const void* const top_border,
+ const void* const bottom_border, const ptrdiff_t stride,
+ const int width, const int height,
+ RestorationBuffer* const restoration_buffer,
+ void* const dest) {
+ const int16_t* const number_leading_zero_coefficients =
+ restoration_info.wiener_info.number_leading_zero_coefficients;
+ const int number_rows_to_skip = std::max(
+ static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
+ 1);
+ const ptrdiff_t wiener_stride = Align(width, 16);
+ int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
+ // The values are saturated to 13 bits before storing.
+ int16_t* wiener_buffer_horizontal =
+ wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
+ int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
+ int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
+ PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
+ filter_horizontal);
+ PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
+ filter_vertical);
+
+ // horizontal filtering.
+ // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
+ const int height_horizontal =
+ height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
+ const int height_extra = (height_horizontal - height) >> 1;
+ assert(height_extra <= 2);
+ const auto* const src = static_cast<const uint8_t*>(source);
+ const auto* const top = static_cast<const uint8_t*>(top_border);
+ const auto* const bottom = static_cast<const uint8_t*>(bottom_border);
+ if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
+ WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride,
+ wiener_stride, height_extra, filter_horizontal,
+ &wiener_buffer_horizontal);
+ WienerHorizontalTap7(src - 3, stride, wiener_stride, height,
+ filter_horizontal, &wiener_buffer_horizontal);
+ WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra,
+ filter_horizontal, &wiener_buffer_horizontal);
+ } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
+ WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride,
+ wiener_stride, height_extra, filter_horizontal,
+ &wiener_buffer_horizontal);
+ WienerHorizontalTap5(src - 2, stride, wiener_stride, height,
+ filter_horizontal, &wiener_buffer_horizontal);
+ WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra,
+ filter_horizontal, &wiener_buffer_horizontal);
+ } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
+ // The maximum over-reads happen here.
+ WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride,
+ wiener_stride, height_extra, filter_horizontal,
+ &wiener_buffer_horizontal);
+ WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
+ filter_horizontal, &wiener_buffer_horizontal);
+ WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra,
+ filter_horizontal, &wiener_buffer_horizontal);
+ } else {
+ assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
+ WienerHorizontalTap1(top + (2 - height_extra) * stride, stride,
+ wiener_stride, height_extra,
+ &wiener_buffer_horizontal);
+ WienerHorizontalTap1(src, stride, wiener_stride, height,
+ &wiener_buffer_horizontal);
+ WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra,
+ &wiener_buffer_horizontal);
+ }
+
+ // vertical filtering.
+ // Over-writes up to 15 values.
+ auto* dst = static_cast<uint8_t*>(dest);
+ if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
+    // Because the top row of |source| duplicates the second row, and the
+    // bottom row of |source| duplicates the row above it, the top and bottom
+    // rows of |wiener_buffer| can be duplicated accordingly.
+ memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
+ sizeof(*wiener_buffer_horizontal) * wiener_stride);
+ memcpy(restoration_buffer->wiener_buffer,
+ restoration_buffer->wiener_buffer + wiener_stride,
+ sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
+ WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
+ filter_vertical, dst, stride);
+ } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
+ WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
+ height, filter_vertical, dst, stride);
+ } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
+ WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
+ wiener_stride, height, filter_vertical, dst, stride);
+ } else {
+ assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
+ WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
+ wiener_stride, height, dst, stride);
+ }
+}
+
+//------------------------------------------------------------------------------
+// SGR
+
+inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) {
+ dst[0] = VshrU128<0>(src);
+ dst[1] = VshrU128<1>(src);
+ dst[2] = VshrU128<2>(src);
+}
+
+inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3],
+ uint16x4_t high[3]) {
+ uint16x8_t s[3];
+ s[0] = VshrU128<0>(src);
+ s[1] = VshrU128<2>(src);
+ s[2] = VshrU128<4>(src);
+ low[0] = vget_low_u16(s[0]);
+ low[1] = vget_low_u16(s[1]);
+ low[2] = vget_low_u16(s[2]);
+ high[0] = vget_high_u16(s[0]);
+ high[1] = vget_high_u16(s[1]);
+ high[2] = vget_high_u16(s[2]);
+}
+
+inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) {
+ dst[0] = VshrU128<0>(src);
+ dst[1] = VshrU128<1>(src);
+ dst[2] = VshrU128<2>(src);
+ dst[3] = VshrU128<3>(src);
+ dst[4] = VshrU128<4>(src);
+}
+
+inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5],
+ uint16x4_t high[5]) {
+ Prepare3_16(src, low, high);
+ const uint16x8_t s3 = VshrU128<6>(src);
+ const uint16x8_t s4 = VshrU128<8>(src);
+ low[3] = vget_low_u16(s3);
+ low[4] = vget_low_u16(s4);
+ high[3] = vget_high_u16(s3);
+ high[4] = vget_high_u16(s4);
+}
+
+inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
+ const uint16x8_t src2) {
+ const uint16x8_t sum = vaddq_u16(src0, src1);
+ return vaddq_u16(sum, src2);
+}
+
+inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
+ return Sum3_16(src[0], src[1], src[2]);
+}
+
+inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
+ const uint32x4_t src2) {
+ const uint32x4_t sum = vaddq_u32(src0, src1);
+ return vaddq_u32(sum, src2);
+}
+
+inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) {
+ uint32x4x2_t d;
+ d.val[0] = Sum3_32(src[0].val[0], src[1].val[0], src[2].val[0]);
+ d.val[1] = Sum3_32(src[0].val[1], src[1].val[1], src[2].val[1]);
+ return d;
+}
+
+inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) {
+ const uint16x8_t sum = vaddl_u8(src[0], src[1]);
+ return vaddw_u8(sum, src[2]);
+}
+
+inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) {
+ const uint32x4_t sum = vaddl_u16(src[0], src[1]);
+ return vaddw_u16(sum, src[2]);
+}
+
+inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
+ const uint16x8_t sum01 = vaddq_u16(src[0], src[1]);
+ const uint16x8_t sum23 = vaddq_u16(src[2], src[3]);
+ const uint16x8_t sum = vaddq_u16(sum01, sum23);
+ return vaddq_u16(sum, src[4]);
+}
+
+inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1,
+ const uint32x4_t src2, const uint32x4_t src3,
+ const uint32x4_t src4) {
+ const uint32x4_t sum01 = vaddq_u32(src0, src1);
+ const uint32x4_t sum23 = vaddq_u32(src2, src3);
+ const uint32x4_t sum = vaddq_u32(sum01, sum23);
+ return vaddq_u32(sum, src4);
+}
+
+inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) {
+ uint32x4x2_t d;
+ d.val[0] = Sum5_32(src[0].val[0], src[1].val[0], src[2].val[0], src[3].val[0],
+ src[4].val[0]);
+ d.val[1] = Sum5_32(src[0].val[1], src[1].val[1], src[2].val[1], src[3].val[1],
+ src[4].val[1]);
+ return d;
+}
+
+inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) {
+ const uint32x4_t sum01 = vaddl_u16(src[0], src[1]);
+ const uint32x4_t sum23 = vaddl_u16(src[2], src[3]);
+ const uint32x4_t sum0123 = vaddq_u32(sum01, sum23);
+ return vaddw_u16(sum0123, src[4]);
+}
+
+inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) {
+ uint8x8_t s[3];
+ Prepare3_8(src, s);
+ return Sum3W_16(s);
+}
+
+inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) {
+ uint16x4_t low[3], high[3];
+ uint32x4x2_t sum;
+ Prepare3_16(src, low, high);
+ sum.val[0] = Sum3W_32(low);
+ sum.val[1] = Sum3W_32(high);
+ return sum;
+}
+
+inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) {
+ uint8x8_t s[5];
+ Prepare5_8(src, s);
+ const uint16x8_t sum01 = vaddl_u8(s[0], s[1]);
+ const uint16x8_t sum23 = vaddl_u8(s[2], s[3]);
+ const uint16x8_t sum0123 = vaddq_u16(sum01, sum23);
+ return vaddw_u8(sum0123, s[4]);
+}
+
+inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) {
+ uint16x4_t low[5], high[5];
+ Prepare5_16(src, low, high);
+ uint32x4x2_t sum;
+ sum.val[0] = Sum5W_32(low);
+ sum.val[1] = Sum5W_32(high);
+ return sum;
+}
+
+void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3,
+ uint32x4_t* const row_sq5) {
+ const uint32x4_t sum04 = vaddl_u16(src[0], src[4]);
+ const uint32x4_t sum12 = vaddl_u16(src[1], src[2]);
+ *row_sq3 = vaddw_u16(sum12, src[3]);
+ *row_sq5 = vaddq_u32(sum04, *row_sq3);
+}
+
+void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq,
+ uint16x8_t* const row3, uint16x8_t* const row5,
+ uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) {
+ uint8x8_t s[5];
+ Prepare5_8(src, s);
+ const uint16x8_t sum04 = vaddl_u8(s[0], s[4]);
+ const uint16x8_t sum12 = vaddl_u8(s[1], s[2]);
+ *row3 = vaddw_u8(sum12, s[3]);
+ *row5 = vaddq_u16(sum04, *row3);
+ uint16x4_t low[5], high[5];
+ Prepare5_16(sq, low, high);
+ SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]);
+ SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]);
+}
+
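+// Computes 3 * (s[0] + s[1] + s[2]) + s[1], i.e. the (3, 4, 3) weighting.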
+inline uint16x8_t Sum343(const uint8x8x2_t src) {
+ uint8x8_t s[3];
+ Prepare3_8(src, s);
+ const uint16x8_t sum = Sum3W_16(s);
+ const uint16x8_t sum3 = Sum3_16(sum, sum, sum);
+ return vaddw_u8(sum3, s[1]);
+}
+
+inline uint32x4_t Sum343W(const uint16x4_t src[3]) {
+ const uint32x4_t sum = Sum3W_32(src);
+ const uint32x4_t sum3 = Sum3_32(sum, sum, sum);
+ return vaddw_u16(sum3, src[1]);
+}
+
+inline uint32x4x2_t Sum343W(const uint16x8x2_t src) {
+ uint16x4_t low[3], high[3];
+ uint32x4x2_t d;
+ Prepare3_16(src, low, high);
+ d.val[0] = Sum343W(low);
+ d.val[1] = Sum343W(high);
+ return d;
+}
+
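+// Computes 5 * (s[0] + s[1] + s[2]) + s[1], i.e. the (5, 6, 5) weighting.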
+inline uint16x8_t Sum565(const uint8x8x2_t src) {
+ uint8x8_t s[3];
+ Prepare3_8(src, s);
+ const uint16x8_t sum = Sum3W_16(s);
+ const uint16x8_t sum4 = vshlq_n_u16(sum, 2);
+ const uint16x8_t sum5 = vaddq_u16(sum4, sum);
+ return vaddw_u8(sum5, s[1]);
+}
+
+inline uint32x4_t Sum565W(const uint16x4_t src[3]) {
+ const uint32x4_t sum = Sum3W_32(src);
+ const uint32x4_t sum4 = vshlq_n_u32(sum, 2);
+ const uint32x4_t sum5 = vaddq_u32(sum4, sum);
+ return vaddw_u16(sum5, src[1]);
+}
+
+inline uint32x4x2_t Sum565W(const uint16x8x2_t src) {
+ uint16x4_t low[3], high[3];
+ uint32x4x2_t d;
+ Prepare3_16(src, low, high);
+ d.val[0] = Sum565W(low);
+ d.val[1] = Sum565W(high);
+ return d;
+}
+
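+// Computes the 3-tap and 5-tap horizontal box sums, and the corresponding
+// sums of squares, in a single pass over each row.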
+inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
+ const int height, const ptrdiff_t sum_stride, uint16_t* sum3,
+ uint16_t* sum5, uint32_t* square_sum3,
+ uint32_t* square_sum5) {
+ int y = height;
+ do {
+ uint8x8x2_t s;
+ uint16x8x2_t sq;
+ s.val[0] = vld1_u8(src);
+ sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+ ptrdiff_t x = 0;
+ do {
+ uint16x8_t row3, row5;
+ uint32x4x2_t row_sq3, row_sq5;
+ s.val[1] = vld1_u8(src + x + 8);
+ sq.val[1] = vmull_u8(s.val[1], s.val[1]);
+ SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5);
+ vst1q_u16(sum3, row3);
+ vst1q_u16(sum5, row5);
+ vst1q_u32(square_sum3 + 0, row_sq3.val[0]);
+ vst1q_u32(square_sum3 + 4, row_sq3.val[1]);
+ vst1q_u32(square_sum5 + 0, row_sq5.val[0]);
+ vst1q_u32(square_sum5 + 4, row_sq5.val[1]);
+ s.val[0] = s.val[1];
+ sq.val[0] = sq.val[1];
+ sum3 += 8;
+ sum5 += 8;
+ square_sum3 += 8;
+ square_sum5 += 8;
+ x += 8;
+ } while (x < sum_stride);
+ src += src_stride;
+ } while (--y != 0);
+}
+
+template <int size>
+inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
+ const int height, const ptrdiff_t sum_stride, uint16_t* sums,
+ uint32_t* square_sums) {
+ static_assert(size == 3 || size == 5, "");
+ int y = height;
+ do {
+ uint8x8x2_t s;
+ uint16x8x2_t sq;
+ s.val[0] = vld1_u8(src);
+ sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+ ptrdiff_t x = 0;
+ do {
+ uint16x8_t row;
+ uint32x4x2_t row_sq;
+ s.val[1] = vld1_u8(src + x + 8);
+ sq.val[1] = vmull_u8(s.val[1], s.val[1]);
+ if (size == 3) {
+ row = Sum3Horizontal(s);
+ row_sq = Sum3WHorizontal(sq);
+ } else {
+ row = Sum5Horizontal(s);
+ row_sq = Sum5WHorizontal(sq);
+ }
+ vst1q_u16(sums, row);
+ vst1q_u32(square_sums + 0, row_sq.val[0]);
+ vst1q_u32(square_sums + 4, row_sq.val[1]);
+ s.val[0] = s.val[1];
+ sq.val[0] = sq.val[1];
+ sums += 8;
+ square_sums += 8;
+ x += 8;
+ } while (x < sum_stride);
+ src += src_stride;
+ } while (--y != 0);
+}
+
+template <int n>
+inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
+ const uint32_t scale) {
+ // a = |sum_sq|
+ // d = |sum|
+ // p = (a * n < d * d) ? 0 : a * n - d * d;
+ const uint32x4_t dxd = vmull_u16(sum, sum);
+ const uint32x4_t axn = vmulq_n_u32(sum_sq, n);
+ // Ensure |p| does not underflow by using saturating subtraction.
+ const uint32x4_t p = vqsubq_u32(axn, dxd);
+ const uint32x4_t pxs = vmulq_n_u32(p, scale);
+  // vrshrn_n_u32() (rounding narrowing shift) can shift by at most 16, but
+  // kSgrProjScaleBits is 20, so shift and narrow in separate steps.
+ const uint32x4_t shifted = vrshrq_n_u32(pxs, kSgrProjScaleBits);
+ return vmovn_u32(shifted);
+}
+
+template <int n>
+inline void CalculateIntermediate(const uint16x8_t sum,
+ const uint32x4x2_t sum_sq,
+ const uint32_t scale, uint8x8_t* const ma,
+ uint16x8_t* const b) {
+ constexpr uint32_t one_over_n =
+ ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
+ const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale);
+ const uint16x4_t z1 =
+ CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale);
+ const uint16x8_t z01 = vcombine_u16(z0, z1);
+  // Using vqmovn_u16() here would need an extra sign extension instruction.
+ const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255));
+  // Using vgetq_lane_s16() avoids that sign extension instruction.
+ const uint8_t lookup[8] = {
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)],
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)],
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)],
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)],
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)],
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)],
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)],
+ kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]};
+ *ma = vld1_u8(lookup);
+ // b = ma * b * one_over_n
+ // |ma| = [0, 255]
+ // |sum| is a box sum with radius 1 or 2.
+ // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
+ // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
+ // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
+ // When radius is 2 |n| is 25. |one_over_n| is 164.
+ // When radius is 1 |n| is 9. |one_over_n| is 455.
+ // |kSgrProjReciprocalBits| is 12.
+ // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
+ // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
+ const uint16x8_t maq = vmovl_u8(*ma);
+ const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum));
+ const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum));
+ const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n);
+ const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n);
+ const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits);
+ const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits);
+ *b = vcombine_u16(b_lo, b_hi);
+}
+
+inline void CalculateIntermediate5(const uint16x8_t s5[5],
+ const uint32x4x2_t sq5[5],
+ const uint32_t scale, uint8x8_t* const ma,
+ uint16x8_t* const b) {
+ const uint16x8_t sum = Sum5_16(s5);
+ const uint32x4x2_t sum_sq = Sum5_32(sq5);
+ CalculateIntermediate<25>(sum, sum_sq, scale, ma, b);
+}
+
+inline void CalculateIntermediate3(const uint16x8_t s3[3],
+ const uint32x4x2_t sq3[3],
+ const uint32_t scale, uint8x8_t* const ma,
+ uint16x8_t* const b) {
+ const uint16x8_t sum = Sum3_16(s3);
+ const uint32x4x2_t sum_sq = Sum3_32(sq3);
+ CalculateIntermediate<9>(sum, sum_sq, scale, ma, b);
+}
+
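+// One pass produces both weighted rows: 4 * (a + b + c) gives the 444 values
+// and 3 * (a + b + c) + b gives the 343 values.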
+inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
+ const ptrdiff_t x, uint16x8_t* const sum_ma343,
+ uint16x8_t* const sum_ma444,
+ uint32x4x2_t* const sum_b343,
+ uint32x4x2_t* const sum_b444, uint16_t* const ma343,
+ uint16_t* const ma444, uint32_t* const b343,
+ uint32_t* const b444) {
+ uint8x8_t s[3];
+ Prepare3_8(ma3, s);
+ const uint16x8_t sum_ma111 = Sum3W_16(s);
+ *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
+ const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
+ *sum_ma343 = vaddw_u8(sum333, s[1]);
+ uint16x4_t low[3], high[3];
+ uint32x4x2_t sum_b111;
+ Prepare3_16(b3, low, high);
+ sum_b111.val[0] = Sum3W_32(low);
+ sum_b111.val[1] = Sum3W_32(high);
+ sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2);
+ sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2);
+ sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]);
+ sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]);
+ sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]);
+ sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]);
+ vst1q_u16(ma343 + x, *sum_ma343);
+ vst1q_u16(ma444 + x, *sum_ma444);
+ vst1q_u32(b343 + x + 0, sum_b343->val[0]);
+ vst1q_u32(b343 + x + 4, sum_b343->val[1]);
+ vst1q_u32(b444 + x + 0, sum_b444->val[0]);
+ vst1q_u32(b444 + x + 4, sum_b444->val[1]);
+}
+
+inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
+ const ptrdiff_t x, uint16x8_t* const sum_ma343,
+ uint32x4x2_t* const sum_b343, uint16_t* const ma343,
+ uint16_t* const ma444, uint32_t* const b343,
+ uint32_t* const b444) {
+ uint16x8_t sum_ma444;
+ uint32x4x2_t sum_b444;
+ Store343_444(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, ma343,
+ ma444, b343, b444);
+}
+
+inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
+ const ptrdiff_t x, uint16_t* const ma343,
+ uint16_t* const ma444, uint32_t* const b343,
+ uint32_t* const b444) {
+ uint16x8_t sum_ma343;
+ uint32x4x2_t sum_b343;
+ Store343_444(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, b444);
+}
+
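+// Rows [3] and [4] of the 5-row window are computed from |src0| and |src1|
+// and stored; rows [0]..[2] are reloaded from |sum5| and |square_sum5|.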
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
+ const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
+ const uint32_t scale, uint16_t* const sum5[5],
+ uint32_t* const square_sum5[5], uint8x8x2_t s[2], uint16x8x2_t sq[2],
+ uint8x8_t* const ma, uint16x8_t* const b) {
+ uint16x8_t s5[5];
+ uint32x4x2_t sq5[5];
+ s[0].val[1] = vld1_u8(src0 + x + 8);
+ s[1].val[1] = vld1_u8(src1 + x + 8);
+ sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
+ sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
+ s5[3] = Sum5Horizontal(s[0]);
+ s5[4] = Sum5Horizontal(s[1]);
+ sq5[3] = Sum5WHorizontal(sq[0]);
+ sq5[4] = Sum5WHorizontal(sq[1]);
+ vst1q_u16(sum5[3] + x, s5[3]);
+ vst1q_u16(sum5[4] + x, s5[4]);
+ vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
+ vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
+ vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
+ vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
+ s5[0] = vld1q_u16(sum5[0] + x);
+ s5[1] = vld1q_u16(sum5[1] + x);
+ s5[2] = vld1q_u16(sum5[2] + x);
+ sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+ sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+ sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+ sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+ sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+ sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+ CalculateIntermediate5(s5, sq5, scale, ma, b);
+}
+
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
+ const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
+ const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
+ uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma,
+ uint16x8_t* const b) {
+ uint16x8_t s5[5];
+ uint32x4x2_t sq5[5];
+ s->val[1] = vld1_u8(src + x + 8);
+ sq->val[1] = vmull_u8(s->val[1], s->val[1]);
+ s5[3] = s5[4] = Sum5Horizontal(*s);
+ sq5[3] = sq5[4] = Sum5WHorizontal(*sq);
+ s5[0] = vld1q_u16(sum5[0] + x);
+ s5[1] = vld1q_u16(sum5[1] + x);
+ s5[2] = vld1q_u16(sum5[2] + x);
+ sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+ sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+ sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+ sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+ sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+ sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+ CalculateIntermediate5(s5, sq5, scale, ma, b);
+}
+
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
+ const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
+ uint16_t* const sum3[3], uint32_t* const square_sum3[3],
+ uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma,
+ uint16x8_t* const b) {
+ uint16x8_t s3[3];
+ uint32x4x2_t sq3[3];
+ s->val[1] = vld1_u8(src + x + 8);
+ sq->val[1] = vmull_u8(s->val[1], s->val[1]);
+ s3[2] = Sum3Horizontal(*s);
+ sq3[2] = Sum3WHorizontal(*sq);
+ vst1q_u16(sum3[2] + x, s3[2]);
+ vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
+ vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
+ s3[0] = vld1q_u16(sum3[0] + x);
+ s3[1] = vld1q_u16(sum3[1] + x);
+ sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
+ sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
+ sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
+ sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+ CalculateIntermediate3(s3, sq3, scale, ma, b);
+}
+
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
+ const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
+ const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
+ uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+ uint8x8x2_t s[2], uint16x8x2_t sq[2], uint8x8_t* const ma3_0,
+ uint8x8_t* const ma3_1, uint16x8_t* const b3_0, uint16x8_t* const b3_1,
+ uint8x8_t* const ma5, uint16x8_t* const b5) {
+ uint16x8_t s3[4], s5[5];
+ uint32x4x2_t sq3[4], sq5[5];
+ s[0].val[1] = vld1_u8(src0 + x + 8);
+ s[1].val[1] = vld1_u8(src1 + x + 8);
+ sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
+ sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
+ SumHorizontal(s[0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]);
+ SumHorizontal(s[1], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]);
+ vst1q_u16(sum3[2] + x, s3[2]);
+ vst1q_u16(sum3[3] + x, s3[3]);
+ vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
+ vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
+ vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]);
+ vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]);
+ vst1q_u16(sum5[3] + x, s5[3]);
+ vst1q_u16(sum5[4] + x, s5[4]);
+ vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
+ vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
+ vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
+ vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
+ s3[0] = vld1q_u16(sum3[0] + x);
+ s3[1] = vld1q_u16(sum3[1] + x);
+ sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
+ sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
+ sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
+ sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+ s5[0] = vld1q_u16(sum5[0] + x);
+ s5[1] = vld1q_u16(sum5[1] + x);
+ s5[2] = vld1q_u16(sum5[2] + x);
+ sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+ sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+ sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+ sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+ sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+ sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+ CalculateIntermediate3(s3, sq3, scales[1], ma3_0, b3_0);
+ CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], ma3_1, b3_1);
+ CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
+}
+
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
+ const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2],
+ const uint16_t* const sum3[4], const uint16_t* const sum5[5],
+ const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
+ uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3,
+ uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) {
+ uint16x8_t s3[3], s5[5];
+ uint32x4x2_t sq3[3], sq5[5];
+ s->val[1] = vld1_u8(src + x + 8);
+ sq->val[1] = vmull_u8(s->val[1], s->val[1]);
+ SumHorizontal(*s, *sq, &s3[2], &s5[3], &sq3[2], &sq5[3]);
+ s5[0] = vld1q_u16(sum5[0] + x);
+ s5[1] = vld1q_u16(sum5[1] + x);
+ s5[2] = vld1q_u16(sum5[2] + x);
+ s5[4] = s5[3];
+ sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+ sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+ sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+ sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+ sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+ sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+ sq5[4] = sq5[3];
+ CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
+ s3[0] = vld1q_u16(sum3[0] + x);
+ s3[1] = vld1q_u16(sum3[1] + x);
+ sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
+ sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
+ sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
+ sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+ CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
+}
+
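+// Primes the pass 1 buffers: walks a pair of source rows across |width| and
+// stores the 565-weighted |ma|/|b| sums needed to filter the first row pair.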
+inline void BoxSumFilterPreProcess5(const uint8_t* const src0,
+ const uint8_t* const src1, const int width,
+ const uint32_t scale,
+ uint16_t* const sum5[5],
+ uint32_t* const square_sum5[5],
+ uint16_t* ma565, uint32_t* b565) {
+ uint8x8x2_t s[2], mas;
+ uint16x8x2_t sq[2], bs;
+ s[0].val[0] = vld1_u8(src0);
+ s[1].val[0] = vld1_u8(src1);
+ sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+ sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+ BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq,
+ &mas.val[0], &bs.val[0]);
+
+ int x = 0;
+ do {
+ s[0].val[0] = s[0].val[1];
+ s[1].val[0] = s[1].val[1];
+ sq[0].val[0] = sq[0].val[1];
+ sq[1].val[0] = sq[1].val[1];
+ BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq,
+ &mas.val[1], &bs.val[1]);
+ const uint16x8_t ma = Sum565(mas);
+ const uint32x4x2_t b = Sum565W(bs);
+ vst1q_u16(ma565, ma);
+ vst1q_u32(b565 + 0, b.val[0]);
+ vst1q_u32(b565 + 4, b.val[1]);
+ mas.val[0] = mas.val[1];
+ bs.val[0] = bs.val[1];
+ ma565 += 8;
+ b565 += 8;
+ x += 8;
+ } while (x < width);
+}
+
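+// Primes the pass 2 buffers for one source row. When |calculate444| is true,
+// both the 343- and 444-weighted sums are stored; otherwise only the
+// 343-weighted sums are needed.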
+template <bool calculate444>
+LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
+ const uint8_t* const src, const int width, const uint32_t scale,
+ uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343,
+ uint16_t* ma444, uint32_t* b343, uint32_t* b444) {
+ uint8x8x2_t s, mas;
+ uint16x8x2_t sq, bs;
+ s.val[0] = vld1_u8(src);
+ sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+ BoxFilterPreProcess3(src, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0],
+ &bs.val[0]);
+
+ int x = 0;
+ do {
+ s.val[0] = s.val[1];
+ sq.val[0] = sq.val[1];
+ BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, &s, &sq,
+ &mas.val[1], &bs.val[1]);
+ if (calculate444) {
+ Store343_444(mas, bs, 0, ma343, ma444, b343, b444);
+ ma444 += 8;
+ b444 += 8;
+ } else {
+ const uint16x8_t ma = Sum343(mas);
+ const uint32x4x2_t b = Sum343W(bs);
+ vst1q_u16(ma343, ma);
+ vst1q_u32(b343 + 0, b.val[0]);
+ vst1q_u32(b343 + 4, b.val[1]);
+ }
+ mas.val[0] = mas.val[1];
+ bs.val[0] = bs.val[1];
+ ma343 += 8;
+ b343 += 8;
+ x += 8;
+ } while (x < width);
+}
+
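+// Primes the buffers for the combined filter: stores the 343-weighted sums of
+// the first row and the 343- and 444-weighted sums of the second row for
+// pass 2, plus the 565-weighted sums of the row pair for pass 1.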
+inline void BoxSumFilterPreProcess(
+ const uint8_t* const src0, const uint8_t* const src1, const int width,
+ const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
+ uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+ uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565,
+ uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) {
+ uint8x8x2_t s[2];
+ uint8x8x2_t ma3[2], ma5;
+ uint16x8x2_t sq[2], b3[2], b5;
+ s[0].val[0] = vld1_u8(src0);
+ s[1].val[0] = vld1_u8(src1);
+ sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+ sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+ BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3,
+ square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0],
+ &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);
+
+ int x = 0;
+ do {
+ s[0].val[0] = s[0].val[1];
+ s[1].val[0] = s[1].val[1];
+ sq[0].val[0] = sq[0].val[1];
+ sq[1].val[0] = sq[1].val[1];
+ BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3,
+ square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1],
+ &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]);
+ uint16x8_t ma = Sum343(ma3[0]);
+ uint32x4x2_t b = Sum343W(b3[0]);
+ vst1q_u16(ma343[0] + x, ma);
+ vst1q_u32(b343[0] + x, b.val[0]);
+ vst1q_u32(b343[0] + x + 4, b.val[1]);
+ Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]);
+ ma = Sum565(ma5);
+ b = Sum565W(b5);
+ vst1q_u16(ma565, ma);
+ vst1q_u32(b565 + 0, b.val[0]);
+ vst1q_u32(b565 + 4, b.val[1]);
+ ma3[0].val[0] = ma3[0].val[1];
+ ma3[1].val[0] = ma3[1].val[1];
+ b3[0].val[0] = b3[0].val[1];
+ b3[1].val[0] = b3[1].val[1];
+ ma5.val[0] = ma5.val[1];
+ b5.val[0] = b5.val[1];
+ ma565 += 8;
+ b565 += 8;
+ x += 8;
+ } while (x < width);
+}
+
+template <int shift>
+inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma,
+ const uint32x4_t b) {
+ // ma: 255 * 32 = 8160 (13 bits)
+ // b: 65088 * 32 = 2082816 (21 bits)
+ // v: b - ma * 255 (22 bits)
+ const int32x4_t v = vreinterpretq_s32_u32(vmlsl_u16(b, ma, src));
+ // kSgrProjSgrBits = 8
+ // kSgrProjRestoreBits = 4
+ // shift = 4 or 5
+ // v >> 8 or 9 (13 bits)
+ return vrshrn_n_s32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+}
+
+template <int shift>
+inline int16x8_t CalculateFilteredOutput(const uint8x8_t src,
+ const uint16x8_t ma,
+ const uint32x4x2_t b) {
+ const uint16x8_t src_u16 = vmovl_u8(src);
+ const int16x4_t dst_lo =
+ FilterOutput<shift>(vget_low_u16(src_u16), vget_low_u16(ma), b.val[0]);
+ const int16x4_t dst_hi =
+ FilterOutput<shift>(vget_high_u16(src_u16), vget_high_u16(ma), b.val[1]);
+ return vcombine_s16(dst_lo, dst_hi); // 13 bits
+}
+
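+// Pass 1 adds the |ma|/|b| sums of two adjacent 565 rows before filtering;
+// pass 2 (below) adds three rows of 343/444/343 sums. Both shift by 5.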
+inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s,
+ uint16x8_t ma[2],
+ uint32x4x2_t b[2]) {
+ const uint16x8_t ma_sum = vaddq_u16(ma[0], ma[1]);
+ uint32x4x2_t b_sum;
+ b_sum.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]);
+ b_sum.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]);
+ return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
+}
+
+inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s,
+ uint16x8_t ma[3],
+ uint32x4x2_t b[3]) {
+ const uint16x8_t ma_sum = Sum3_16(ma);
+ const uint32x4x2_t b_sum = Sum3_32(b);
+ return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
+}
+
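+// Adds the rounded, weighted filter residual to the source pixels and
+// saturates the result to 8 bits.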
+inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2],
+ uint8_t* const dst) {
+ const int16x4_t v_lo =
+ vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
+ const int16x4_t v_hi =
+ vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
+ const int16x8_t vv = vcombine_s16(v_lo, v_hi);
+ const int16x8_t s = ZeroExtend(src);
+ const int16x8_t d = vaddq_s16(s, vv);
+ vst1_u8(dst, vqmovun_s16(d));
+}
+
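+// Scalar reference for the weighted combination below (illustrative only):
+//   v = w0 * filter[0][x] + w2 * filter[1][x];
+//   dst[x] = Clip3(src[x] + RightShiftWithRounding(
+//       v, kSgrProjRestoreBits + kSgrProjPrecisionBits), 0, 255);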
+inline void SelfGuidedDoubleMultiplier(const uint8x8_t src,
+ const int16x8_t filter[2], const int w0,
+ const int w2, uint8_t* const dst) {
+ int32x4_t v[2];
+ v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0);
+ v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0);
+ v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2);
+ v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2);
+ SelfGuidedFinal(src, v, dst);
+}
+
+inline void SelfGuidedSingleMultiplier(const uint8x8_t src,
+ const int16x8_t filter, const int w0,
+ uint8_t* const dst) {
+ // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
+ int32x4_t v[2];
+ v[0] = vmull_n_s16(vget_low_s16(filter), w0);
+ v[1] = vmull_n_s16(vget_high_s16(filter), w0);
+ SelfGuidedFinal(src, v, dst);
+}
+
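+// Pass 1 main loop: computes the 565 sums for the incoming row pair, filters
+// the first row with the two-row sum (shift 5) and the second row with the new
+// row only (shift 4), then blends both rows with the single multiplier |w0|.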
+LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
+ const uint8_t* const src, const uint8_t* const src0,
+ const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
+ uint32_t* const square_sum5[5], const int width, const uint32_t scale,
+ const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2],
+ uint8_t* const dst) {
+ uint8x8x2_t s[2], mas;
+ uint16x8x2_t sq[2], bs;
+ s[0].val[0] = vld1_u8(src0);
+ s[1].val[0] = vld1_u8(src1);
+ sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+ sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+ BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq,
+ &mas.val[0], &bs.val[0]);
+
+ int x = 0;
+ do {
+ s[0].val[0] = s[0].val[1];
+ s[1].val[0] = s[1].val[1];
+ sq[0].val[0] = sq[0].val[1];
+ sq[1].val[0] = sq[1].val[1];
+ BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq,
+ &mas.val[1], &bs.val[1]);
+ uint16x8_t ma[2];
+ uint32x4x2_t b[2];
+ ma[1] = Sum565(mas);
+ b[1] = Sum565W(bs);
+ vst1q_u16(ma565[1] + x, ma[1]);
+ vst1q_u32(b565[1] + x + 0, b[1].val[0]);
+ vst1q_u32(b565[1] + x + 4, b[1].val[1]);
+ const uint8x8_t sr0 = vld1_u8(src + x);
+ const uint8x8_t sr1 = vld1_u8(src + stride + x);
+ int16x8_t p0, p1;
+ ma[0] = vld1q_u16(ma565[0] + x);
+ b[0].val[0] = vld1q_u32(b565[0] + x + 0);
+ b[0].val[1] = vld1q_u32(b565[0] + x + 4);
+ p0 = CalculateFilteredOutputPass1(sr0, ma, b);
+ p1 = CalculateFilteredOutput<4>(sr1, ma[1], b[1]);
+ SelfGuidedSingleMultiplier(sr0, p0, w0, dst + x);
+ SelfGuidedSingleMultiplier(sr1, p1, w0, dst + stride + x);
+ mas.val[0] = mas.val[1];
+ bs.val[0] = bs.val[1];
+ x += 8;
+ } while (x < width);
+}
+
+inline void BoxFilterPass1LastRow(const uint8_t* const src,
+ const uint8_t* const src0, const int width,
+ const uint32_t scale, const int16_t w0,
+ uint16_t* const sum5[5],
+ uint32_t* const square_sum5[5],
+ uint16_t* ma565, uint32_t* b565,
+ uint8_t* const dst) {
+ uint8x8x2_t s, mas;
+ uint16x8x2_t sq, bs;
+ s.val[0] = vld1_u8(src0);
+ sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+ BoxFilterPreProcess5LastRow(src0, 0, scale, sum5, square_sum5, &s, &sq,
+ &mas.val[0], &bs.val[0]);
+
+ int x = 0;
+ do {
+ s.val[0] = s.val[1];
+ sq.val[0] = sq.val[1];
+ BoxFilterPreProcess5LastRow(src0, x + 8, scale, sum5, square_sum5, &s, &sq,
+ &mas.val[1], &bs.val[1]);
+ uint16x8_t ma[2];
+ uint32x4x2_t b[2];
+ ma[1] = Sum565(mas);
+ b[1] = Sum565W(bs);
+ mas.val[0] = mas.val[1];
+ bs.val[0] = bs.val[1];
+ ma[0] = vld1q_u16(ma565);
+ b[0].val[0] = vld1q_u32(b565 + 0);
+ b[0].val[1] = vld1q_u32(b565 + 4);
+ const uint8x8_t sr = vld1_u8(src + x);
+ const int16x8_t p = CalculateFilteredOutputPass1(sr, ma, b);
+ SelfGuidedSingleMultiplier(sr, p, w0, dst + x);
+ ma565 += 8;
+ b565 += 8;
+ x += 8;
+ } while (x < width);
+}
+
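+// Pass 2 main loop: updates the 343/444 sums for the incoming row and filters
+// one row at a time with the three-row 343/444/343 combination before blending
+// with |w0|.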
+LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
+ const uint8_t* const src, const uint8_t* const src0, const int width,
+ const uint32_t scale, const int16_t w0, uint16_t* const sum3[3],
+ uint32_t* const square_sum3[3], uint16_t* const ma343[3],
+ uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2],
+ uint8_t* const dst) {
+ uint8x8x2_t s, mas;
+ uint16x8x2_t sq, bs;
+ s.val[0] = vld1_u8(src0);
+ sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+ BoxFilterPreProcess3(src0, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0],
+ &bs.val[0]);
+
+ int x = 0;
+ do {
+ s.val[0] = s.val[1];
+ sq.val[0] = sq.val[1];
+ BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, &s, &sq,
+ &mas.val[1], &bs.val[1]);
+ uint16x8_t ma[3];
+ uint32x4x2_t b[3];
+ Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2],
+ b444[1]);
+ const uint8x8_t sr = vld1_u8(src + x);
+ ma[0] = vld1q_u16(ma343[0] + x);
+ ma[1] = vld1q_u16(ma444[0] + x);
+ b[0].val[0] = vld1q_u32(b343[0] + x + 0);
+ b[0].val[1] = vld1q_u32(b343[0] + x + 4);
+ b[1].val[0] = vld1q_u32(b444[0] + x + 0);
+ b[1].val[1] = vld1q_u32(b444[0] + x + 4);
+ const int16x8_t p = CalculateFilteredOutputPass2(sr, ma, b);
+ SelfGuidedSingleMultiplier(sr, p, w0, dst + x);
+ mas.val[0] = mas.val[1];
+ bs.val[0] = bs.val[1];
+ x += 8;
+ } while (x < width);
+}
+
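+// Combined filter for a row pair: runs both passes, stores the updated
+// 343/444/565 sums, and blends the pass 1 and pass 2 outputs using |w0| and
+// |w2|.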
+LIBGAV1_ALWAYS_INLINE void BoxFilter(
+ const uint8_t* const src, const uint8_t* const src0,
+ const uint8_t* const src1, const ptrdiff_t stride, const int width,
+ const uint16_t scales[2], const int16_t w0, const int16_t w2,
+ uint16_t* const sum3[4], uint16_t* const sum5[5],
+ uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+ uint16_t* const ma343[4], uint16_t* const ma444[3],
+ uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
+ uint32_t* const b565[2], uint8_t* const dst) {
+ uint8x8x2_t s[2], ma3[2], ma5;
+ uint16x8x2_t sq[2], b3[2], b5;
+ s[0].val[0] = vld1_u8(src0);
+ s[1].val[0] = vld1_u8(src1);
+ sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+ sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+ BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3,
+ square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0],
+ &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);
+
+ int x = 0;
+ do {
+ s[0].val[0] = s[0].val[1];
+ s[1].val[0] = s[1].val[1];
+ sq[0].val[0] = sq[0].val[1];
+ sq[1].val[0] = sq[1].val[1];
+ BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3,
+ square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1],
+ &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]);
+ uint16x8_t ma[3][3];
+ uint32x4x2_t b[3][3];
+ Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1],
+ ma343[2], ma444[1], b343[2], b444[1]);
+ Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2],
+ b343[3], b444[2]);
+ ma[0][1] = Sum565(ma5);
+ b[0][1] = Sum565W(b5);
+ vst1q_u16(ma565[1] + x, ma[0][1]);
+ vst1q_u32(b565[1] + x, b[0][1].val[0]);
+ vst1q_u32(b565[1] + x + 4, b[0][1].val[1]);
+ ma3[0].val[0] = ma3[0].val[1];
+ ma3[1].val[0] = ma3[1].val[1];
+ b3[0].val[0] = b3[0].val[1];
+ b3[1].val[0] = b3[1].val[1];
+ ma5.val[0] = ma5.val[1];
+ b5.val[0] = b5.val[1];
+ int16x8_t p[2][2];
+ const uint8x8_t sr0 = vld1_u8(src + x);
+ const uint8x8_t sr1 = vld1_u8(src + stride + x);
+ ma[0][0] = vld1q_u16(ma565[0] + x);
+ b[0][0].val[0] = vld1q_u32(b565[0] + x);
+ b[0][0].val[1] = vld1q_u32(b565[0] + x + 4);
+ p[0][0] = CalculateFilteredOutputPass1(sr0, ma[0], b[0]);
+ p[1][0] = CalculateFilteredOutput<4>(sr1, ma[0][1], b[0][1]);
+ ma[1][0] = vld1q_u16(ma343[0] + x);
+ ma[1][1] = vld1q_u16(ma444[0] + x);
+ b[1][0].val[0] = vld1q_u32(b343[0] + x);
+ b[1][0].val[1] = vld1q_u32(b343[0] + x + 4);
+ b[1][1].val[0] = vld1q_u32(b444[0] + x);
+ b[1][1].val[1] = vld1q_u32(b444[0] + x + 4);
+ p[0][1] = CalculateFilteredOutputPass2(sr0, ma[1], b[1]);
+ ma[2][0] = vld1q_u16(ma343[1] + x);
+ b[2][0].val[0] = vld1q_u32(b343[1] + x);
+ b[2][0].val[1] = vld1q_u32(b343[1] + x + 4);
+ p[1][1] = CalculateFilteredOutputPass2(sr1, ma[2], b[2]);
+ SelfGuidedDoubleMultiplier(sr0, p[0], w0, w2, dst + x);
+ SelfGuidedDoubleMultiplier(sr1, p[1], w0, w2, dst + stride + x);
+ x += 8;
+ } while (x < width);
+}
+
+inline void BoxFilterLastRow(
+ const uint8_t* const src, const uint8_t* const src0, const int width,
+ const uint16_t scales[2], const int16_t w0, const int16_t w2,
+ uint16_t* const sum3[4], uint16_t* const sum5[5],
+ uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+ uint16_t* const ma343[4], uint16_t* const ma444[3],
+ uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
+ uint32_t* const b565[2], uint8_t* const dst) {
+ uint8x8x2_t s, ma3, ma5;
+ uint16x8x2_t sq, b3, b5;
+ uint16x8_t ma[3];
+ uint32x4x2_t b[3];
+ s.val[0] = vld1_u8(src0);
+ sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+ BoxFilterPreProcessLastRow(src0, 0, scales, sum3, sum5, square_sum3,
+ square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0],
+ &b3.val[0], &b5.val[0]);
+
+ int x = 0;
+ do {
+ s.val[0] = s.val[1];
+ sq.val[0] = sq.val[1];
+ BoxFilterPreProcessLastRow(src0, x + 8, scales, sum3, sum5, square_sum3,
+ square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1],
+ &b3.val[1], &b5.val[1]);
+ ma[1] = Sum565(ma5);
+ b[1] = Sum565W(b5);
+ ma5.val[0] = ma5.val[1];
+ b5.val[0] = b5.val[1];
+ ma[2] = Sum343(ma3);
+ b[2] = Sum343W(b3);
+ ma3.val[0] = ma3.val[1];
+ b3.val[0] = b3.val[1];
+ const uint8x8_t sr = vld1_u8(src + x);
+ int16x8_t p[2];
+ ma[0] = vld1q_u16(ma565[0] + x);
+ b[0].val[0] = vld1q_u32(b565[0] + x + 0);
+ b[0].val[1] = vld1q_u32(b565[0] + x + 4);
+ p[0] = CalculateFilteredOutputPass1(sr, ma, b);
+ ma[0] = vld1q_u16(ma343[0] + x);
+ ma[1] = vld1q_u16(ma444[0] + x);
+ b[0].val[0] = vld1q_u32(b343[0] + x + 0);
+ b[0].val[1] = vld1q_u32(b343[0] + x + 4);
+ b[1].val[0] = vld1q_u32(b444[0] + x + 0);
+ b[1].val[1] = vld1q_u32(b444[0] + x + 4);
+ p[1] = CalculateFilteredOutputPass2(sr, ma, b);
+ SelfGuidedDoubleMultiplier(sr, p, w0, w2, dst + x);
+ x += 8;
+ } while (x < width);
+}
+
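+// Top-level driver for the combined filter: lays out the circular row buffers
+// inside |sgr_buffer|, seeds them from the top border and the first source
+// rows, filters two rows per iteration while rotating the buffer pointers, and
+// finally handles the bottom border and an odd last row.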
+LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
+ const RestorationUnitInfo& restoration_info, const uint8_t* src,
+ const uint8_t* const top_border, const uint8_t* bottom_border,
+ const ptrdiff_t stride, const int width, const int height,
+ SgrBuffer* const sgr_buffer, uint8_t* dst) {
+ const auto temp_stride = Align<ptrdiff_t>(width, 8);
+ const ptrdiff_t sum_stride = temp_stride + 8;
+ const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+ const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12.
+ const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+ const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+ const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
+ uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
+ uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
+ sum3[0] = sgr_buffer->sum3;
+ square_sum3[0] = sgr_buffer->square_sum3;
+ ma343[0] = sgr_buffer->ma343;
+ b343[0] = sgr_buffer->b343;
+ for (int i = 1; i <= 3; ++i) {
+ sum3[i] = sum3[i - 1] + sum_stride;
+ square_sum3[i] = square_sum3[i - 1] + sum_stride;
+ ma343[i] = ma343[i - 1] + temp_stride;
+ b343[i] = b343[i - 1] + temp_stride;
+ }
+ sum5[0] = sgr_buffer->sum5;
+ square_sum5[0] = sgr_buffer->square_sum5;
+ for (int i = 1; i <= 4; ++i) {
+ sum5[i] = sum5[i - 1] + sum_stride;
+ square_sum5[i] = square_sum5[i - 1] + sum_stride;
+ }
+ ma444[0] = sgr_buffer->ma444;
+ b444[0] = sgr_buffer->b444;
+ for (int i = 1; i <= 2; ++i) {
+ ma444[i] = ma444[i - 1] + temp_stride;
+ b444[i] = b444[i - 1] + temp_stride;
+ }
+ ma565[0] = sgr_buffer->ma565;
+ ma565[1] = ma565[0] + temp_stride;
+ b565[0] = sgr_buffer->b565;
+ b565[1] = b565[0] + temp_stride;
+ assert(scales[0] != 0);
+ assert(scales[1] != 0);
+ BoxSum(top_border, stride, 2, sum_stride, sum3[0], sum5[1], square_sum3[0],
+ square_sum5[1]);
+ sum5[0] = sum5[1];
+ square_sum5[0] = square_sum5[1];
+ const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
+ BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
+ square_sum5, ma343, ma444, ma565[0], b343, b444,
+ b565[0]);
+ sum5[0] = sgr_buffer->sum5;
+ square_sum5[0] = sgr_buffer->square_sum5;
+
+ for (int y = (height >> 1) - 1; y > 0; --y) {
+ Circulate4PointersBy2<uint16_t>(sum3);
+ Circulate4PointersBy2<uint32_t>(square_sum3);
+ Circulate5PointersBy2<uint16_t>(sum5);
+ Circulate5PointersBy2<uint32_t>(square_sum5);
+ BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
+ scales, w0, w2, sum3, sum5, square_sum3, square_sum5, ma343,
+ ma444, ma565, b343, b444, b565, dst);
+ src += 2 * stride;
+ dst += 2 * stride;
+ Circulate4PointersBy2<uint16_t>(ma343);
+ Circulate4PointersBy2<uint32_t>(b343);
+ std::swap(ma444[0], ma444[2]);
+ std::swap(b444[0], b444[2]);
+ std::swap(ma565[0], ma565[1]);
+ std::swap(b565[0], b565[1]);
+ }
+
+ Circulate4PointersBy2<uint16_t>(sum3);
+ Circulate4PointersBy2<uint32_t>(square_sum3);
+ Circulate5PointersBy2<uint16_t>(sum5);
+ Circulate5PointersBy2<uint32_t>(square_sum5);
+ if ((height & 1) == 0 || height > 1) {
+ const uint8_t* sr[2];
+ if ((height & 1) == 0) {
+ sr[0] = bottom_border;
+ sr[1] = bottom_border + stride;
+ } else {
+ sr[0] = src + 2 * stride;
+ sr[1] = bottom_border;
+ }
+ BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
+ square_sum3, square_sum5, ma343, ma444, ma565, b343, b444, b565,
+ dst);
+ }
+ if ((height & 1) != 0) {
+ if (height > 1) {
+ src += 2 * stride;
+ dst += 2 * stride;
+ Circulate4PointersBy2<uint16_t>(sum3);
+ Circulate4PointersBy2<uint32_t>(square_sum3);
+ Circulate5PointersBy2<uint16_t>(sum5);
+ Circulate5PointersBy2<uint32_t>(square_sum5);
+ Circulate4PointersBy2<uint16_t>(ma343);
+ Circulate4PointersBy2<uint32_t>(b343);
+ std::swap(ma444[0], ma444[2]);
+ std::swap(b444[0], b444[2]);
+ std::swap(ma565[0], ma565[1]);
+ std::swap(b565[0], b565[1]);
+ }
+ BoxFilterLastRow(src + 3, bottom_border + stride, width, scales, w0, w2,
+ sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565,
+ b343, b444, b565, dst);
+ }
+}
+
+inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
+ const uint8_t* src,
+ const uint8_t* const top_border,
+ const uint8_t* bottom_border,
+ const ptrdiff_t stride, const int width,
+ const int height, SgrBuffer* const sgr_buffer,
+ uint8_t* dst) {
+ const auto temp_stride = Align<ptrdiff_t>(width, 8);
+ const ptrdiff_t sum_stride = temp_stride + 8;
+ const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+ const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0]; // < 2^12.
+ const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+ uint16_t *sum5[5], *ma565[2];
+ uint32_t *square_sum5[5], *b565[2];
+ sum5[0] = sgr_buffer->sum5;
+ square_sum5[0] = sgr_buffer->square_sum5;
+ for (int i = 1; i <= 4; ++i) {
+ sum5[i] = sum5[i - 1] + sum_stride;
+ square_sum5[i] = square_sum5[i - 1] + sum_stride;
+ }
+ ma565[0] = sgr_buffer->ma565;
+ ma565[1] = ma565[0] + temp_stride;
+ b565[0] = sgr_buffer->b565;
+ b565[1] = b565[0] + temp_stride;
+ assert(scale != 0);
+ BoxSum<5>(top_border, stride, 2, sum_stride, sum5[1], square_sum5[1]);
+ sum5[0] = sum5[1];
+ square_sum5[0] = square_sum5[1];
+ const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
+ BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, ma565[0],
+ b565[0]);
+ sum5[0] = sgr_buffer->sum5;
+ square_sum5[0] = sgr_buffer->square_sum5;
+
+ for (int y = (height >> 1) - 1; y > 0; --y) {
+ Circulate5PointersBy2<uint16_t>(sum5);
+ Circulate5PointersBy2<uint32_t>(square_sum5);
+ BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
+ square_sum5, width, scale, w0, ma565, b565, dst);
+ src += 2 * stride;
+ dst += 2 * stride;
+ std::swap(ma565[0], ma565[1]);
+ std::swap(b565[0], b565[1]);
+ }
+
+ Circulate5PointersBy2<uint16_t>(sum5);
+ Circulate5PointersBy2<uint32_t>(square_sum5);
+ if ((height & 1) == 0 || height > 1) {
+ const uint8_t* sr[2];
+ if ((height & 1) == 0) {
+ sr[0] = bottom_border;
+ sr[1] = bottom_border + stride;
+ } else {
+ sr[0] = src + 2 * stride;
+ sr[1] = bottom_border;
+ }
+ BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
+ scale, w0, ma565, b565, dst);
+ }
+ if ((height & 1) != 0) {
+ if (height > 1) {
+ src += 2 * stride;
+ dst += 2 * stride;
+ std::swap(ma565[0], ma565[1]);
+ std::swap(b565[0], b565[1]);
+ Circulate5PointersBy2<uint16_t>(sum5);
+ Circulate5PointersBy2<uint32_t>(square_sum5);
+ }
+ BoxFilterPass1LastRow(src + 3, bottom_border + stride, width, scale, w0,
+ sum5, square_sum5, ma565[0], b565[0], dst);
+ }
+}
+
+inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
+ const uint8_t* src,
+ const uint8_t* const top_border,
+ const uint8_t* bottom_border,
+ const ptrdiff_t stride, const int width,
+ const int height, SgrBuffer* const sgr_buffer,
+ uint8_t* dst) {
+ assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
+ const auto temp_stride = Align<ptrdiff_t>(width, 8);
+ const ptrdiff_t sum_stride = temp_stride + 8;
+ const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+ const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
+ const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+ const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1]; // < 2^12.
+ uint16_t *sum3[3], *ma343[3], *ma444[2];
+ uint32_t *square_sum3[3], *b343[3], *b444[2];
+ sum3[0] = sgr_buffer->sum3;
+ square_sum3[0] = sgr_buffer->square_sum3;
+ ma343[0] = sgr_buffer->ma343;
+ b343[0] = sgr_buffer->b343;
+ for (int i = 1; i <= 2; ++i) {
+ sum3[i] = sum3[i - 1] + sum_stride;
+ square_sum3[i] = square_sum3[i - 1] + sum_stride;
+ ma343[i] = ma343[i - 1] + temp_stride;
+ b343[i] = b343[i - 1] + temp_stride;
+ }
+ ma444[0] = sgr_buffer->ma444;
+ ma444[1] = ma444[0] + temp_stride;
+ b444[0] = sgr_buffer->b444;
+ b444[1] = b444[0] + temp_stride;
+ assert(scale != 0);
+ BoxSum<3>(top_border, stride, 2, sum_stride, sum3[0], square_sum3[0]);
+ BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0],
+ nullptr, b343[0], nullptr);
+ Circulate3PointersBy1<uint16_t>(sum3);
+ Circulate3PointersBy1<uint32_t>(square_sum3);
+ const uint8_t* s;
+ if (height > 1) {
+ s = src + stride;
+ } else {
+ s = bottom_border;
+ bottom_border += stride;
+ }
+ BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, ma343[1],
+ ma444[0], b343[1], b444[0]);
+
+ for (int y = height - 2; y > 0; --y) {
+ Circulate3PointersBy1<uint16_t>(sum3);
+ Circulate3PointersBy1<uint32_t>(square_sum3);
+ BoxFilterPass2(src + 2, src + 2 * stride, width, scale, w0, sum3,
+ square_sum3, ma343, ma444, b343, b444, dst);
+ src += stride;
+ dst += stride;
+ Circulate3PointersBy1<uint16_t>(ma343);
+ Circulate3PointersBy1<uint32_t>(b343);
+ std::swap(ma444[0], ma444[1]);
+ std::swap(b444[0], b444[1]);
+ }
+
+ src += 2;
+ int y = std::min(height, 2);
+ do {
+ Circulate3PointersBy1<uint16_t>(sum3);
+ Circulate3PointersBy1<uint32_t>(square_sum3);
+ BoxFilterPass2(src, bottom_border, width, scale, w0, sum3, square_sum3,
+ ma343, ma444, b343, b444, dst);
+ src += stride;
+ dst += stride;
+ bottom_border += stride;
+ Circulate3PointersBy1<uint16_t>(ma343);
+ Circulate3PointersBy1<uint32_t>(b343);
+ std::swap(ma444[0], ma444[1]);
+ std::swap(b444[0], b444[1]);
+ } while (--y != 0);
+}
+
+// If |width| is not a multiple of 8, up to 7 extra pixels are written to |dest|
+// at the end of each row. It is safe to overwrite the output because those
+// pixels are not part of the visible frame.
+void SelfGuidedFilter_NEON(
+ const RestorationUnitInfo& restoration_info, const void* const source,
+ const void* const top_border, const void* const bottom_border,
+ const ptrdiff_t stride, const int width, const int height,
+ RestorationBuffer* const restoration_buffer, void* const dest) {
+ const int index = restoration_info.sgr_proj_info.index;
+ const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0
+ const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0
+ const auto* const src = static_cast<const uint8_t*>(source);
+ const auto* top = static_cast<const uint8_t*>(top_border);
+ const auto* bottom = static_cast<const uint8_t*>(bottom_border);
+ auto* const dst = static_cast<uint8_t*>(dest);
+ SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
+ if (radius_pass_1 == 0) {
+ // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
+ // following assertion.
+ assert(radius_pass_0 != 0);
+ BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3,
+ stride, width, height, sgr_buffer, dst);
+ } else if (radius_pass_0 == 0) {
+ BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2,
+ stride, width, height, sgr_buffer, dst);
+ } else {
+ BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride,
+ width, height, sgr_buffer, dst);
+ }
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->loop_restorations[0] = WienerFilter_NEON;
+ dsp->loop_restorations[1] = SelfGuidedFilter_NEON;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void LoopRestorationInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/loop_restoration_neon.h b/src/dsp/arm/loop_restoration_neon.h
new file mode 100644
index 0000000..b551610
--- /dev/null
+++ b/src/dsp/arm/loop_restoration_neon.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::loop_restorations, see the defines below for specifics.
+// This function is not thread-safe.
+void LoopRestorationInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+
+#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_NEON
+
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_
diff --git a/src/dsp/arm/mask_blend_neon.cc b/src/dsp/arm/mask_blend_neon.cc
new file mode 100644
index 0000000..084f42f
--- /dev/null
+++ b/src/dsp/arm/mask_blend_neon.cc
@@ -0,0 +1,444 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/mask_blend.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// TODO(b/150461164): Consider combining with GetInterIntraMask4x2().
+// Compound predictors use int16_t values and need a widening multiply, because
+// the convolve output range multiplied by 64 needs 20 bits. Unfortunately
+// there is no instruction that multiplies int16_t by int8_t and accumulates
+// into int32_t.
+template <int subsampling_x, int subsampling_y>
+inline int16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+ if (subsampling_x == 1) {
+ const int16x4_t mask_val0 = vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask)));
+ const int16x4_t mask_val1 = vreinterpret_s16_u16(
+ vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y))));
+ int16x8_t final_val;
+ if (subsampling_y == 1) {
+ const int16x4_t next_mask_val0 =
+ vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride)));
+ const int16x4_t next_mask_val1 =
+ vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride * 3)));
+ final_val = vaddq_s16(vcombine_s16(mask_val0, mask_val1),
+ vcombine_s16(next_mask_val0, next_mask_val1));
+ } else {
+ final_val = vreinterpretq_s16_u16(
+ vpaddlq_u8(vreinterpretq_u8_s16(vcombine_s16(mask_val0, mask_val1))));
+ }
+ return vrshrq_n_s16(final_val, subsampling_y + 1);
+ }
+ assert(subsampling_y == 0 && subsampling_x == 0);
+ const uint8x8_t mask_val0 = Load4(mask);
+ const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
+ return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+}
+
+template <int subsampling_x, int subsampling_y>
+inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+ if (subsampling_x == 1) {
+ int16x8_t mask_val = vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask)));
+ if (subsampling_y == 1) {
+ const int16x8_t next_mask_val =
+ vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask + mask_stride)));
+ mask_val = vaddq_s16(mask_val, next_mask_val);
+ }
+ return vrshrq_n_s16(mask_val, 1 + subsampling_y);
+ }
+ assert(subsampling_y == 0 && subsampling_x == 0);
+ const uint8x8_t mask_val = vld1_u8(mask);
+ return vreinterpretq_s16_u16(vmovl_u8(mask_val));
+}
+
+inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
+ const int16_t* const pred_1,
+ const int16x8_t pred_mask_0,
+ const int16x8_t pred_mask_1, uint8_t* dst,
+ const ptrdiff_t dst_stride) {
+ const int16x8_t pred_val_0 = vld1q_s16(pred_0);
+ const int16x8_t pred_val_1 = vld1q_s16(pred_1);
+ // int res = (mask_value * prediction_0[x] +
+ // (64 - mask_value) * prediction_1[x]) >> 6;
+ const int32x4_t weighted_pred_0_lo =
+ vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
+ const int32x4_t weighted_pred_0_hi =
+ vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
+ const int32x4_t weighted_combo_lo = vmlal_s16(
+ weighted_pred_0_lo, vget_low_s16(pred_mask_1), vget_low_s16(pred_val_1));
+ const int32x4_t weighted_combo_hi =
+ vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
+ vget_high_s16(pred_val_1));
+ // dst[x] = static_cast<Pixel>(
+ // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+ // (1 << kBitdepth8) - 1));
+ const uint8x8_t result =
+ vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+ vshrn_n_s32(weighted_combo_hi, 6)),
+ 4);
+ StoreLo4(dst, result);
+ StoreHi4(dst + dst_stride, result);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1,
+ const uint8_t* mask,
+ const ptrdiff_t mask_stride, uint8_t* dst,
+ const ptrdiff_t dst_stride) {
+ const int16x8_t mask_inverter = vdupq_n_s16(64);
+ int16x8_t pred_mask_0 =
+ GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+ WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+ dst_stride);
+ // TODO(b/150461164): Arm tends to do better with load(val); val += stride
+ // It may be possible to turn this into a loop with a templated height.
+ pred_0 += 4 << 1;
+ pred_1 += 4 << 1;
+ mask += mask_stride << (1 + subsampling_y);
+ dst += dst_stride << 1;
+
+ pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+ WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+ dst_stride);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1,
+ const uint8_t* const mask_ptr,
+ const ptrdiff_t mask_stride, const int height,
+ uint8_t* dst, const ptrdiff_t dst_stride) {
+ const uint8_t* mask = mask_ptr;
+ if (height == 4) {
+ MaskBlending4x4_NEON<subsampling_x, subsampling_y>(
+ pred_0, pred_1, mask, mask_stride, dst, dst_stride);
+ return;
+ }
+ const int16x8_t mask_inverter = vdupq_n_s16(64);
+ int y = 0;
+ do {
+ int16x8_t pred_mask_0 =
+ GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+
+ WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += 4 << 1;
+ mask += mask_stride << (1 + subsampling_y);
+ dst += dst_stride << 1;
+
+ pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+ WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += 4 << 1;
+ mask += mask_stride << (1 + subsampling_y);
+ dst += dst_stride << 1;
+
+ pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+ WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += 4 << 1;
+ mask += mask_stride << (1 + subsampling_y);
+ dst += dst_stride << 1;
+
+ pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+ WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+ dst_stride);
+ pred_0 += 4 << 1;
+ pred_1 += 4 << 1;
+ mask += mask_stride << (1 + subsampling_y);
+ dst += dst_stride << 1;
+ y += 8;
+ } while (y < height);
+}
+
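+// Blends the two compound (int16_t) predictors with the 6-bit mask. Width 4
+// blocks take the 4xH path above; wider blocks are processed 8 pixels at a
+// time.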
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1,
+ const ptrdiff_t /*prediction_stride_1*/,
+ const uint8_t* const mask_ptr,
+ const ptrdiff_t mask_stride, const int width,
+ const int height, void* dest,
+ const ptrdiff_t dst_stride) {
+ auto* dst = static_cast<uint8_t*>(dest);
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ if (width == 4) {
+ MaskBlending4xH_NEON<subsampling_x, subsampling_y>(
+ pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride);
+ return;
+ }
+ const uint8_t* mask = mask_ptr;
+ const int16x8_t mask_inverter = vdupq_n_s16(64);
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ const int16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
+ mask + (x << subsampling_x), mask_stride);
+ // 64 - mask
+ const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+ const int16x8_t pred_val_0 = vld1q_s16(pred_0 + x);
+ const int16x8_t pred_val_1 = vld1q_s16(pred_1 + x);
+ uint8x8_t result;
+ // int res = (mask_value * prediction_0[x] +
+ // (64 - mask_value) * prediction_1[x]) >> 6;
+ const int32x4_t weighted_pred_0_lo =
+ vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
+ const int32x4_t weighted_pred_0_hi =
+ vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
+ const int32x4_t weighted_combo_lo =
+ vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1),
+ vget_low_s16(pred_val_1));
+ const int32x4_t weighted_combo_hi =
+ vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
+ vget_high_s16(pred_val_1));
+
+ // dst[x] = static_cast<Pixel>(
+ // Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+ // (1 << kBitdepth8) - 1));
+ result = vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+ vshrn_n_s32(weighted_combo_hi, 6)),
+ 4);
+ vst1_u8(dst + x, result);
+
+ x += 8;
+ } while (x < width);
+ dst += dst_stride;
+ pred_0 += width;
+ pred_1 += width;
+ mask += mask_stride << subsampling_y;
+ } while (++y < height);
+}
+
+// TODO(b/150461164): This is much faster for inter_intra (input is Pixel
+// values) but regresses compound versions (input is int16_t). Try to
+// consolidate these.
+template <int subsampling_x, int subsampling_y>
+inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask,
+ ptrdiff_t mask_stride) {
+ if (subsampling_x == 1) {
+ const uint8x8_t mask_val =
+ vpadd_u8(vld1_u8(mask), vld1_u8(mask + (mask_stride << subsampling_y)));
+ if (subsampling_y == 1) {
+ const uint8x8_t next_mask_val = vpadd_u8(vld1_u8(mask + mask_stride),
+ vld1_u8(mask + mask_stride * 3));
+
+ // Use a saturating add to work around the case where all |mask| values
+ // are 64. Together with the rounding shift this ensures the correct
+ // result.
+ const uint8x8_t sum = vqadd_u8(mask_val, next_mask_val);
+ return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
+ }
+
+ return vrshr_n_u8(mask_val, /*subsampling_x=*/1);
+ }
+
+ assert(subsampling_y == 0 && subsampling_x == 0);
+ const uint8x8_t mask_val0 = Load4(mask);
+ // TODO(b/150461164): Investigate the source of |mask| and see if the stride
+ // can be removed.
+ // TODO(b/150461164): The unit tests start at 8x8. Does this get run?
+ return Load4<1>(mask + mask_stride, mask_val0);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint8x8_t GetInterIntraMask8(const uint8_t* mask,
+ ptrdiff_t mask_stride) {
+ if (subsampling_x == 1) {
+ const uint8x16_t mask_val = vld1q_u8(mask);
+ const uint8x8_t mask_paired =
+ vpadd_u8(vget_low_u8(mask_val), vget_high_u8(mask_val));
+ if (subsampling_y == 1) {
+ const uint8x16_t next_mask_val = vld1q_u8(mask + mask_stride);
+ const uint8x8_t next_mask_paired =
+ vpadd_u8(vget_low_u8(next_mask_val), vget_high_u8(next_mask_val));
+
+ // Use a saturating add to work around the case where all |mask| values
+ // are 64. Together with the rounding shift this ensures the correct
+ // result.
+ const uint8x8_t sum = vqadd_u8(mask_paired, next_mask_paired);
+ return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
+ }
+
+ return vrshr_n_u8(mask_paired, /*subsampling_x=*/1);
+ }
+
+ assert(subsampling_y == 0 && subsampling_x == 0);
+ return vld1_u8(mask);
+}
+
+inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
+ uint8_t* const pred_1,
+ const ptrdiff_t pred_stride_1,
+ const uint8x8_t pred_mask_0,
+ const uint8x8_t pred_mask_1) {
+ const uint8x8_t pred_val_0 = vld1_u8(pred_0);
+ uint8x8_t pred_val_1 = Load4(pred_1);
+ pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1);
+
+ const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+ const uint16x8_t weighted_combo =
+ vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+ const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+ StoreLo4(pred_1, result);
+ StoreHi4(pred_1 + pred_stride_1, result);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0,
+ uint8_t* pred_1,
+ const ptrdiff_t pred_stride_1,
+ const uint8_t* mask,
+ const ptrdiff_t mask_stride) {
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ uint8x8_t pred_mask_1 =
+ GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+ InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+ pred_mask_0, pred_mask_1);
+ pred_0 += 4 << 1;
+ pred_1 += pred_stride_1 << 1;
+ mask += mask_stride << (1 + subsampling_y);
+
+ pred_mask_1 =
+ GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+ pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+ InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+ pred_mask_0, pred_mask_1);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4xH_NEON(
+ const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1,
+ const uint8_t* mask, const ptrdiff_t mask_stride, const int height) {
+ if (height == 4) {
+ InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+ pred_0, pred_1, pred_stride_1, mask, mask_stride);
+ return;
+ }
+ int y = 0;
+ do {
+ InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+ pred_0, pred_1, pred_stride_1, mask, mask_stride);
+ pred_0 += 4 << 2;
+ pred_1 += pred_stride_1 << 2;
+ mask += mask_stride << (2 + subsampling_y);
+
+ InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+ pred_0, pred_1, pred_stride_1, mask, mask_stride);
+ pred_0 += 4 << 2;
+ pred_1 += pred_stride_1 << 2;
+ mask += mask_stride << (2 + subsampling_y);
+ y += 8;
+ } while (y < height);
+}
+
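+// Blends the two 8-bit predictors with the 6-bit mask and writes the result in
+// place over |prediction_1|.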
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0,
+ uint8_t* prediction_1,
+ const ptrdiff_t prediction_stride_1,
+ const uint8_t* const mask_ptr,
+ const ptrdiff_t mask_stride,
+ const int width, const int height) {
+ if (width == 4) {
+ InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>(
+ prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
+ height);
+ return;
+ }
+ const uint8_t* mask = mask_ptr;
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ int y = 0;
+ do {
+ int x = 0;
+ do {
+ // TODO(b/150461164): Consider a 16 wide specialization (at least for the
+ // unsampled version) to take advantage of vld1q_u8().
+ const uint8x8_t pred_mask_1 =
+ GetInterIntraMask8<subsampling_x, subsampling_y>(
+ mask + (x << subsampling_x), mask_stride);
+ // 64 - mask
+ const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+ const uint8x8_t pred_val_0 = vld1_u8(prediction_0);
+ prediction_0 += 8;
+ const uint8x8_t pred_val_1 = vld1_u8(prediction_1 + x);
+ const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+ // weighted_pred0 + weighted_pred1
+ const uint16x8_t weighted_combo =
+ vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+ const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+ vst1_u8(prediction_1 + x, result);
+
+ x += 8;
+ } while (x < width);
+ prediction_1 += prediction_stride_1;
+ mask += mask_stride << subsampling_y;
+ } while (++y < height);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0>;
+ dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0>;
+ dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1>;
+ // In 8-bit, the is_inter_intra index of mask_blend[][] is replaced by the
+ // inter_intra_mask_blend_8bpp[] table.
+ dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_NEON<0, 0>;
+ dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_NEON<1, 0>;
+ dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_NEON<1, 1>;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void MaskBlendInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/mask_blend_neon.h b/src/dsp/arm/mask_blend_neon.h
new file mode 100644
index 0000000..3829274
--- /dev/null
+++ b/src/dsp/arm/mask_blend_neon.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mask_blend. This function is not thread-safe.
+void MaskBlendInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
diff --git a/src/dsp/arm/motion_field_projection_neon.cc b/src/dsp/arm/motion_field_projection_neon.cc
new file mode 100644
index 0000000..8caba7d
--- /dev/null
+++ b/src/dsp/arm/motion_field_projection_neon.cc
@@ -0,0 +1,393 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_field_projection.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
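+// Gathers the 16-bit projection division entries selected by
+// |reference_offset|: each offset is doubled and expanded to a pair of byte
+// indices so that vtbl2 picks up the low and high bytes of every entry.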
+inline int16x8_t LoadDivision(const int8x8x2_t division_table,
+ const int8x8_t reference_offset) {
+ const int8x8_t kOne = vcreate_s8(0x0100010001000100);
+ const int8x16_t kOneQ = vcombine_s8(kOne, kOne);
+ const int8x8_t t = vadd_s8(reference_offset, reference_offset);
+ const int8x8x2_t tt = vzip_s8(t, t);
+ const int8x16_t t1 = vcombine_s8(tt.val[0], tt.val[1]);
+ const int8x16_t idx = vaddq_s8(t1, kOneQ);
+ const int8x8_t idx_low = vget_low_s8(idx);
+ const int8x8_t idx_high = vget_high_s8(idx);
+ const int16x4_t d0 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_low));
+ const int16x4_t d1 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_high));
+ return vcombine_s16(d0, d1);
+}
+
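+// Projects four motion vector components: multiplies each by its division
+// table entry and by |numerator|, then rounds away the 14 fractional bits with
+// a bias towards zero.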
+inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator,
+ const int numerator) {
+ const int32x4_t m0 = vmull_s16(mv, denominator);
+ const int32x4_t m = vmulq_n_s32(m0, numerator);
+ // Add the sign (0 or -1) to round towards zero.
+ const int32x4_t add_sign = vsraq_n_s32(m, m, 31);
+ return vqrshrn_n_s32(add_sign, 14);
+}
+
+inline int16x8_t MvProjectionClip(const int16x8_t mv,
+ const int16x8_t denominator,
+ const int numerator) {
+ const int16x4_t mv0 = vget_low_s16(mv);
+ const int16x4_t mv1 = vget_high_s16(mv);
+ const int16x4_t s0 = MvProjection(mv0, vget_low_s16(denominator), numerator);
+ const int16x4_t s1 = MvProjection(mv1, vget_high_s16(denominator), numerator);
+ const int16x8_t projection = vcombine_s16(s0, s1);
+ const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp);
+ const int16x8_t clamp = vminq_s16(projection, projection_mv_clamp);
+ return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp));
+}
+
+inline int8x8_t Project_NEON(const int16x8_t delta, const int16x8_t dst_sign) {
+ // Add 63 to negative |delta| values so that the shift by 6 below rounds
+ // towards zero.
+ const int16x8_t delta_sign = vshrq_n_s16(delta, 15);
+ const uint16x8_t delta_u = vreinterpretq_u16_s16(delta);
+ const uint16x8_t delta_sign_u = vreinterpretq_u16_s16(delta_sign);
+ const uint16x8_t delta_adjust_u = vsraq_n_u16(delta_u, delta_sign_u, 10);
+ const int16x8_t delta_adjust = vreinterpretq_s16_u16(delta_adjust_u);
+ const int16x8_t offset0 = vshrq_n_s16(delta_adjust, 6);
+ const int16x8_t offset1 = veorq_s16(offset0, dst_sign);
+ const int16x8_t offset2 = vsubq_s16(offset1, dst_sign);
+ return vqmovn_s16(offset2);
+}
+
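+// Projects the motion vectors of eight 8x8 blocks, converts the projections to
+// block-relative y/x positions, and folds the out-of-range positions together
+// with the reference skip flags into |skip_64|.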
+inline void GetPosition(
+ const int8x8x2_t division_table, const MotionVector* const mv,
+ const int numerator, const int x8_start, const int x8_end, const int x8,
+ const int8x8_t r_offsets, const int8x8_t source_reference_type8,
+ const int8x8_t skip_r, const int8x8_t y8_floor8, const int8x8_t y8_ceiling8,
+ const int16x8_t d_sign, const int delta, int8x8_t* const r,
+ int8x8_t* const position_y8, int8x8_t* const position_x8,
+ int64_t* const skip_64, int32x4_t mvs[2]) {
+ const auto* const mv_int = reinterpret_cast<const int32_t*>(mv + x8);
+ *r = vtbl1_s8(r_offsets, source_reference_type8);
+ const int16x8_t denorm = LoadDivision(division_table, source_reference_type8);
+ int16x8_t projection_mv[2];
+ mvs[0] = vld1q_s32(mv_int + 0);
+ mvs[1] = vld1q_s32(mv_int + 4);
+ // Deinterleave the y and x components of the motion vectors.
+ const int16x8_t mv0 = vreinterpretq_s16_s32(mvs[0]);
+ const int16x8_t mv1 = vreinterpretq_s16_s32(mvs[1]);
+ const int16x8x2_t mv_yx = vuzpq_s16(mv0, mv1);
+ // numerator could be 0.
+ projection_mv[0] = MvProjectionClip(mv_yx.val[0], denorm, numerator);
+ projection_mv[1] = MvProjectionClip(mv_yx.val[1], denorm, numerator);
+ // Do not update the motion vector if the block position is not valid or
+ // if position_x8 is outside the current range of x8_start and x8_end.
+ // Note that position_y8 will always be within the range of y8_start and
+ // y8_end.
+ // After subtracting the base, valid projections are within 8-bit.
+ *position_y8 = Project_NEON(projection_mv[0], d_sign);
+ const int8x8_t position_x = Project_NEON(projection_mv[1], d_sign);
+ const int8x8_t k01234567 = vcreate_s8(uint64_t{0x0706050403020100});
+ *position_x8 = vqadd_s8(position_x, k01234567);
+ const int8x16_t position_xy = vcombine_s8(*position_x8, *position_y8);
+ const int x8_floor = std::max(
+ x8_start - x8, delta - kProjectionMvMaxHorizontalOffset); // [-8, 8]
+ const int x8_ceiling = std::min(
+ x8_end - x8, delta + 8 + kProjectionMvMaxHorizontalOffset); // [0, 16]
+ const int8x8_t x8_floor8 = vdup_n_s8(x8_floor);
+ const int8x8_t x8_ceiling8 = vdup_n_s8(x8_ceiling);
+ const int8x16_t floor_xy = vcombine_s8(x8_floor8, y8_floor8);
+ const int8x16_t ceiling_xy = vcombine_s8(x8_ceiling8, y8_ceiling8);
+ const uint8x16_t underflow = vcltq_s8(position_xy, floor_xy);
+ const uint8x16_t overflow = vcgeq_s8(position_xy, ceiling_xy);
+ const int8x16_t out = vreinterpretq_s8_u8(vorrq_u8(underflow, overflow));
+ const int8x8_t skip_low = vorr_s8(skip_r, vget_low_s8(out));
+ const int8x8_t skip = vorr_s8(skip_low, vget_high_s8(out));
+ *skip_64 = vget_lane_s64(vreinterpret_s64_s8(skip), 0);
+}
+
+template <int idx>
+inline void Store(const int16x8_t position, const int8x8_t reference_offset,
+ const int32x4_t mv, int8_t* dst_reference_offset,
+ MotionVector* dst_mv) {
+ const ptrdiff_t offset = vgetq_lane_s16(position, idx);
+ auto* const d_mv = reinterpret_cast<int32_t*>(&dst_mv[offset]);
+ vst1q_lane_s32(d_mv, mv, idx & 3);
+ vst1_lane_s8(&dst_reference_offset[offset], reference_offset, idx);
+}
+
+template <int idx>
+inline void CheckStore(const int8_t* skips, const int16x8_t position,
+ const int8x8_t reference_offset, const int32x4_t mv,
+ int8_t* dst_reference_offset, MotionVector* dst_mv) {
+ if (skips[idx] == 0) {
+ Store<idx>(position, reference_offset, mv, dst_reference_offset, dst_mv);
+ }
+}
+
+// 7.9.2.
+void MotionFieldProjectionKernel_NEON(const ReferenceInfo& reference_info,
+ const int reference_to_current_with_sign,
+ const int dst_sign, const int y8_start,
+ const int y8_end, const int x8_start,
+ const int x8_end,
+ TemporalMotionField* const motion_field) {
+ const ptrdiff_t stride = motion_field->mv.columns();
+ // The column range has to be extended by kProjectionMvMaxHorizontalOffset on
+ // both sides, since coordinates in the extended range can become valid
+ // |position_x8| values after projection.
+ const int adjusted_x8_start =
+ std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0);
+ const int adjusted_x8_end = std::min(
+ x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride));
+ const int adjusted_x8_end8 = adjusted_x8_end & ~7;
+ const int leftover = adjusted_x8_end - adjusted_x8_end8;
+ const int8_t* const reference_offsets =
+ reference_info.relative_distance_to.data();
+ const bool* const skip_references = reference_info.skip_references.data();
+ const int16_t* const projection_divisions =
+ reference_info.projection_divisions.data();
+ const ReferenceFrameType* source_reference_types =
+ &reference_info.motion_field_reference_frame[y8_start][0];
+ const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0];
+ int8_t* dst_reference_offset = motion_field->reference_offset[y8_start];
+ MotionVector* dst_mv = motion_field->mv[y8_start];
+ const int16x8_t d_sign = vdupq_n_s16(dst_sign);
+
+ static_assert(sizeof(int8_t) == sizeof(bool), "");
+ static_assert(sizeof(int8_t) == sizeof(ReferenceFrameType), "");
+ static_assert(sizeof(int32_t) == sizeof(MotionVector), "");
+ assert(dst_sign == 0 || dst_sign == -1);
+ assert(stride == motion_field->reference_offset.columns());
+ assert((y8_start & 7) == 0);
+ assert((adjusted_x8_start & 7) == 0);
+ // The final position calculation is represented with int16_t. Valid
+ // position_y8 relative to its base is at most 7. After adding the horizontal
+ // offset, which is at most |stride - 1|, we arrive at the following
+ // assertion, which means this optimization works for frame widths up to 32K
+ // pixels (each position is an 8x8 block).
+ assert(8 * stride <= 32768);
+ const int8x8_t skip_reference =
+ vld1_s8(reinterpret_cast<const int8_t*>(skip_references));
+ const int8x8_t r_offsets = vld1_s8(reference_offsets);
+ const int8x16_t table = vreinterpretq_s8_s16(vld1q_s16(projection_divisions));
+ int8x8x2_t division_table;
+ division_table.val[0] = vget_low_s8(table);
+ division_table.val[1] = vget_high_s8(table);
+
+ int y8 = y8_start;
+ do {
+ const int y8_floor = (y8 & ~7) - y8; // [-7, 0]
+ const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8); // [1, 8]
+ const int8x8_t y8_floor8 = vdup_n_s8(y8_floor);
+ const int8x8_t y8_ceiling8 = vdup_n_s8(y8_ceiling);
+ int x8;
+
+ for (x8 = adjusted_x8_start; x8 < adjusted_x8_end8; x8 += 8) {
+ const int8x8_t source_reference_type8 =
+ vld1_s8(reinterpret_cast<const int8_t*>(source_reference_types + x8));
+ const int8x8_t skip_r = vtbl1_s8(skip_reference, source_reference_type8);
+ const int64_t early_skip = vget_lane_s64(vreinterpret_s64_s8(skip_r), 0);
+ // Early termination #1 if all are skips. Chance is typically ~30-40%.
+ if (early_skip == -1) continue;
+ int64_t skip_64;
+ int8x8_t r, position_x8, position_y8;
+ int32x4_t mvs[2];
+ GetPosition(division_table, mv, reference_to_current_with_sign, x8_start,
+ x8_end, x8, r_offsets, source_reference_type8, skip_r,
+ y8_floor8, y8_ceiling8, d_sign, 0, &r, &position_y8,
+ &position_x8, &skip_64, mvs);
+ // Early termination #2 if all are skips.
+ // Chance is typically ~15-25% after Early termination #1.
+ if (skip_64 == -1) continue;
+ const int16x8_t p_y = vmovl_s8(position_y8);
+ const int16x8_t p_x = vmovl_s8(position_x8);
+ const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride);
+ const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8));
+ if (skip_64 == 0) {
+ // Store all. Chance is typically ~70-85% after Early termination #2.
+ Store<0>(position, r, mvs[0], dst_reference_offset, dst_mv);
+ Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
+ Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
+ Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
+ Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ } else {
+ // Check and store each.
+ // Chance is typically ~15-30% after Early termination #2.
+ // The compiler is smart enough to not create the local buffer skips[].
+ int8_t skips[8];
+ memcpy(skips, &skip_64, sizeof(skips));
+ CheckStore<0>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+ CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+ CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+ CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+ CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+ CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+ CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+ CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+ }
+ }
+
+ // The following leftover processing cannot be moved out of the do...while
+ // loop. Doing so may change the order in which results are stored to the
+ // same position.
+ if (leftover > 0) {
+ // Use SIMD only when leftover is at least 4, and there are at least 8
+ // elements in a row.
+ if (leftover >= 4 && adjusted_x8_start < adjusted_x8_end8) {
+ // Process the last 8 elements to avoid loading invalid memory. Some
+ // elements may have been processed in the above loop, which is OK.
+ const int delta = 8 - leftover;
+ x8 = adjusted_x8_end - 8;
+ const int8x8_t source_reference_type8 = vld1_s8(
+ reinterpret_cast<const int8_t*>(source_reference_types + x8));
+ const int8x8_t skip_r =
+ vtbl1_s8(skip_reference, source_reference_type8);
+ const int64_t early_skip =
+ vget_lane_s64(vreinterpret_s64_s8(skip_r), 0);
+ // Early termination #1 if all are skips.
+ if (early_skip != -1) {
+ int64_t skip_64;
+ int8x8_t r, position_x8, position_y8;
+ int32x4_t mvs[2];
+ GetPosition(division_table, mv, reference_to_current_with_sign,
+ x8_start, x8_end, x8, r_offsets, source_reference_type8,
+ skip_r, y8_floor8, y8_ceiling8, d_sign, delta, &r,
+ &position_y8, &position_x8, &skip_64, mvs);
+ // Early termination #2 if all are skips.
+ if (skip_64 != -1) {
+ const int16x8_t p_y = vmovl_s8(position_y8);
+ const int16x8_t p_x = vmovl_s8(position_x8);
+ const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride);
+ const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8));
+ // Store up to 7 elements since leftover is at most 7.
+ if (skip_64 == 0) {
+ // Store all.
+ Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
+ Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
+ Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
+ Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
+ } else {
+ // Check and store each.
+ // The compiler is smart enough to not create the local buffer
+ // skips[].
+ int8_t skips[8];
+ memcpy(skips, &skip_64, sizeof(skips));
+ CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset,
+ dst_mv);
+ CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset,
+ dst_mv);
+ CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset,
+ dst_mv);
+ CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset,
+ dst_mv);
+ CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset,
+ dst_mv);
+ CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset,
+ dst_mv);
+ CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset,
+ dst_mv);
+ }
+ }
+ }
+ } else {
+ for (; x8 < adjusted_x8_end; ++x8) {
+ const int source_reference_type = source_reference_types[x8];
+ if (skip_references[source_reference_type]) continue;
+ MotionVector projection_mv;
+ // reference_to_current_with_sign could be 0.
+ GetMvProjection(mv[x8], reference_to_current_with_sign,
+ projection_divisions[source_reference_type],
+ &projection_mv);
+ // Do not update the motion vector if the block position is not valid
+ // or if position_x8 is outside the current range of x8_start and
+ // x8_end. Note that position_y8 will always be within the range of
+ // y8_start and y8_end.
+ const int position_y8 = Project(0, projection_mv.mv[0], dst_sign);
+ if (position_y8 < y8_floor || position_y8 >= y8_ceiling) continue;
+ const int x8_base = x8 & ~7;
+ const int x8_floor =
+ std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset);
+ const int x8_ceiling =
+ std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset);
+ const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign);
+ if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue;
+ dst_mv[position_y8 * stride + position_x8] = mv[x8];
+ dst_reference_offset[position_y8 * stride + position_x8] =
+ reference_offsets[source_reference_type];
+ }
+ }
+ }
+
+ source_reference_types += stride;
+ mv += stride;
+ dst_reference_offset += stride;
+ dst_mv += stride;
+ } while (++y8 < y8_end);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+ assert(dsp != nullptr);
+ dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
+}
+#endif
+
+} // namespace
+
+void MotionFieldProjectionInit_NEON() {
+ Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ Init10bpp();
+#endif
+}
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void MotionFieldProjectionInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/motion_field_projection_neon.h b/src/dsp/arm/motion_field_projection_neon.h
new file mode 100644
index 0000000..41ab6a6
--- /dev/null
+++ b/src/dsp/arm/motion_field_projection_neon.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::motion_field_projection_kernel. This function is not
+// thread-safe.
+void MotionFieldProjectionInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+
+#define LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel LIBGAV1_CPU_NEON
+
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_
diff --git a/src/dsp/arm/motion_vector_search_neon.cc b/src/dsp/arm/motion_vector_search_neon.cc
new file mode 100644
index 0000000..8a403a6
--- /dev/null
+++ b/src/dsp/arm/motion_vector_search_neon.cc
@@ -0,0 +1,267 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_vector_search.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator,
+ const int32x4_t numerator) {
+ const int32x4_t m0 = vmull_s16(mv, denominator);
+ const int32x4_t m = vmulq_s32(m0, numerator);
+ // Add the sign (0 or -1) to round towards zero.
+ const int32x4_t add_sign = vsraq_n_s32(m, m, 31);
+ return vqrshrn_n_s32(add_sign, 14);
+}
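+// Scalar sketch of one MvProjection() lane above (illustration only; this is
+// not part of the upstream code; saturate_int16 is a hypothetical helper):
+//   int32_t product = mv * denominator * numerator;
+//   product += product >> 31;  // add the sign (0 or -1) to round towards zero
+//   return saturate_int16(RightShiftWithRounding(product, 14));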
+
+inline int16x4_t MvProjectionCompound(const int16x4_t mv,
+ const int temporal_reference_offsets,
+ const int reference_offsets[2]) {
+ const int16x4_t denominator =
+ vdup_n_s16(kProjectionMvDivisionLookup[temporal_reference_offsets]);
+ const int32x2_t offset = vld1_s32(reference_offsets);
+ const int32x2x2_t offsets = vzip_s32(offset, offset);
+ const int32x4_t numerator = vcombine_s32(offsets.val[0], offsets.val[1]);
+ return MvProjection(mv, denominator, numerator);
+}
+
+inline int16x8_t ProjectionClip(const int16x4_t mv0, const int16x4_t mv1) {
+ const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp);
+ const int16x8_t mv = vcombine_s16(mv0, mv1);
+ const int16x8_t clamp = vminq_s16(mv, projection_mv_clamp);
+ return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp));
+}
+
+inline int16x8_t MvProjectionCompoundClip(
+ const MotionVector* const temporal_mvs,
+ const int8_t* const temporal_reference_offsets,
+ const int reference_offsets[2]) {
+ const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs);
+ const int32x2_t temporal_mv = vld1_s32(tmvs);
+ const int16x4_t tmv0 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 0));
+ const int16x4_t tmv1 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 1));
+ const int16x4_t mv0 = MvProjectionCompound(
+ tmv0, temporal_reference_offsets[0], reference_offsets);
+ const int16x4_t mv1 = MvProjectionCompound(
+ tmv1, temporal_reference_offsets[1], reference_offsets);
+ return ProjectionClip(mv0, mv1);
+}
+
+inline int16x8_t MvProjectionSingleClip(
+ const MotionVector* const temporal_mvs,
+ const int8_t* const temporal_reference_offsets, const int reference_offset,
+ int16x4_t* const lookup) {
+ const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs);
+ const int16x8_t temporal_mv = vld1q_s16(tmvs);
+ *lookup = vld1_lane_s16(
+ &kProjectionMvDivisionLookup[temporal_reference_offsets[0]], *lookup, 0);
+ *lookup = vld1_lane_s16(
+ &kProjectionMvDivisionLookup[temporal_reference_offsets[1]], *lookup, 1);
+ *lookup = vld1_lane_s16(
+ &kProjectionMvDivisionLookup[temporal_reference_offsets[2]], *lookup, 2);
+ *lookup = vld1_lane_s16(
+ &kProjectionMvDivisionLookup[temporal_reference_offsets[3]], *lookup, 3);
+ const int16x4x2_t denominator = vzip_s16(*lookup, *lookup);
+ const int16x4_t tmv0 = vget_low_s16(temporal_mv);
+ const int16x4_t tmv1 = vget_high_s16(temporal_mv);
+ const int32x4_t numerator = vdupq_n_s32(reference_offset);
+ const int16x4_t mv0 = MvProjection(tmv0, denominator.val[0], numerator);
+ const int16x4_t mv1 = MvProjection(tmv1, denominator.val[1], numerator);
+ return ProjectionClip(mv0, mv1);
+}
+
+inline void LowPrecision(const int16x8_t mv, void* const candidate_mvs) {
+ const int16x8_t kRoundDownMask = vdupq_n_s16(1);
+ const uint16x8_t mvu = vreinterpretq_u16_s16(mv);
+ const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15));
+ const int16x8_t mv1 = vbicq_s16(mv0, kRoundDownMask);
+ vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv1);
+}
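+// Per lane, LowPrecision() above is equivalent to (scalar sketch, for
+// reference only):
+//   if (mv < 0) ++mv;  // add the sign bit to round towards zero
+//   mv &= ~1;          // clear the lowest bit (drop 1/8-pel precision)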
+
+inline void ForceInteger(const int16x8_t mv, void* const candidate_mvs) {
+ const int16x8_t kRoundDownMask = vdupq_n_s16(7);
+ const uint16x8_t mvu = vreinterpretq_u16_s16(mv);
+ const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15));
+ const int16x8_t mv1 = vaddq_s16(mv0, vdupq_n_s16(3));
+ const int16x8_t mv2 = vbicq_s16(mv1, kRoundDownMask);
+ vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv2);
+}
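+// Per lane, ForceInteger() above is equivalent to (scalar sketch, for
+// reference only):
+//   if (mv < 0) ++mv;    // add the sign bit to round towards zero
+//   mv = (mv + 3) & ~7;  // round to a multiple of 8 (a whole-pel value)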
+
+void MvProjectionCompoundLowPrecision_NEON(
+ const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+ const int reference_offsets[2], const int count,
+ CompoundMotionVector* candidate_mvs) {
+ // The |reference_offsets| non-zero check is usually true, so it is skipped
+ // here. To help the compiler, make a local copy of |reference_offsets|.
+ const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+ // One more element could be calculated.
+ int loop_count = (count + 1) >> 1;
+ do {
+ const int16x8_t mv = MvProjectionCompoundClip(
+ temporal_mvs, temporal_reference_offsets, offsets);
+ LowPrecision(mv, candidate_mvs);
+ temporal_mvs += 2;
+ temporal_reference_offsets += 2;
+ candidate_mvs += 2;
+ } while (--loop_count);
+}
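+// For example (illustration only): count == 5 gives loop_count == 3, so six
+// compound motion vectors are written and the final one is the extra element
+// the comment above allows for.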
+
+void MvProjectionCompoundForceInteger_NEON(
+ const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+ const int reference_offsets[2], const int count,
+ CompoundMotionVector* candidate_mvs) {
+ // The |reference_offsets| non-zero check is usually true, so it is skipped
+ // here. To help the compiler, make a local copy of |reference_offsets|.
+ const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+ // One more element could be calculated.
+ int loop_count = (count + 1) >> 1;
+ do {
+ const int16x8_t mv = MvProjectionCompoundClip(
+ temporal_mvs, temporal_reference_offsets, offsets);
+ ForceInteger(mv, candidate_mvs);
+ temporal_mvs += 2;
+ temporal_reference_offsets += 2;
+ candidate_mvs += 2;
+ } while (--loop_count);
+}
+
+void MvProjectionCompoundHighPrecision_NEON(
+ const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+ const int reference_offsets[2], const int count,
+ CompoundMotionVector* candidate_mvs) {
+ // The |reference_offsets| non-zero check is usually true, so it is skipped
+ // here. To help the compiler, make a local copy of |reference_offsets|.
+ const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+ // One more element could be calculated.
+ int loop_count = (count + 1) >> 1;
+ do {
+ const int16x8_t mv = MvProjectionCompoundClip(
+ temporal_mvs, temporal_reference_offsets, offsets);
+ vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv);
+ temporal_mvs += 2;
+ temporal_reference_offsets += 2;
+ candidate_mvs += 2;
+ } while (--loop_count);
+}
+
+void MvProjectionSingleLowPrecision_NEON(
+ const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+ const int reference_offset, const int count, MotionVector* candidate_mvs) {
+ // Up to three more elements could be calculated.
+ int loop_count = (count + 3) >> 2;
+ int16x4_t lookup = vdup_n_s16(0);
+ do {
+ const int16x8_t mv = MvProjectionSingleClip(
+ temporal_mvs, temporal_reference_offsets, reference_offset, &lookup);
+ LowPrecision(mv, candidate_mvs);
+ temporal_mvs += 4;
+ temporal_reference_offsets += 4;
+ candidate_mvs += 4;
+ } while (--loop_count);
+}
+
+void MvProjectionSingleForceInteger_NEON(
+ const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+ const int reference_offset, const int count, MotionVector* candidate_mvs) {
+ // Up to three more elements could be calculated.
+ int loop_count = (count + 3) >> 2;
+ int16x4_t lookup = vdup_n_s16(0);
+ do {
+ const int16x8_t mv = MvProjectionSingleClip(
+ temporal_mvs, temporal_reference_offsets, reference_offset, &lookup);
+ ForceInteger(mv, candidate_mvs);
+ temporal_mvs += 4;
+ temporal_reference_offsets += 4;
+ candidate_mvs += 4;
+ } while (--loop_count);
+}
+
+void MvProjectionSingleHighPrecision_NEON(
+ const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+ const int reference_offset, const int count, MotionVector* candidate_mvs) {
+ // Up to three more elements could be calculated.
+ int loop_count = (count + 3) >> 2;
+ int16x4_t lookup = vdup_n_s16(0);
+ do {
+ const int16x8_t mv = MvProjectionSingleClip(
+ temporal_mvs, temporal_reference_offsets, reference_offset, &lookup);
+ vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv);
+ temporal_mvs += 4;
+ temporal_reference_offsets += 4;
+ candidate_mvs += 4;
+ } while (--loop_count);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON;
+ dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON;
+ dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON;
+ dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON;
+ dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON;
+ dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON;
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+ assert(dsp != nullptr);
+ dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON;
+ dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON;
+ dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON;
+ dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON;
+ dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON;
+ dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON;
+}
+#endif
+
+} // namespace
+
+void MotionVectorSearchInit_NEON() {
+ Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+ Init10bpp();
+#endif
+}
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void MotionVectorSearchInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/motion_vector_search_neon.h b/src/dsp/arm/motion_vector_search_neon.h
new file mode 100644
index 0000000..19b4519
--- /dev/null
+++ b/src/dsp/arm/motion_vector_search_neon.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This
+// function is not thread-safe.
+void MotionVectorSearchInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+
+#define LIBGAV1_Dsp8bpp_MotionVectorSearch LIBGAV1_CPU_NEON
+
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_
diff --git a/src/dsp/arm/obmc_neon.cc b/src/dsp/arm/obmc_neon.cc
new file mode 100644
index 0000000..66ad663
--- /dev/null
+++ b/src/dsp/arm/obmc_neon.cc
@@ -0,0 +1,392 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/obmc.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+#include "src/dsp/obmc.inc"
+
+inline void WriteObmcLine4(uint8_t* const pred, const uint8_t* const obmc_pred,
+ const uint8x8_t pred_mask,
+ const uint8x8_t obmc_pred_mask) {
+ const uint8x8_t pred_val = Load4(pred);
+ const uint8x8_t obmc_pred_val = Load4(obmc_pred);
+ const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+ const uint8x8_t result =
+ vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+ StoreLo4(pred, result);
+}
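+// Each of the four blended pixels above is, in scalar form (illustration
+// only):
+//   pred[i] = (pred_mask[i] * pred[i] +
+//              obmc_pred_mask[i] * obmc_pred[i] + 32) >> 6
+// The callers pass obmc_pred_mask == 64 - pred_mask, so the weights sum to 64.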
+
+template <bool from_left>
+inline void OverlapBlend2xH_NEON(uint8_t* const prediction,
+ const ptrdiff_t prediction_stride,
+ const int height,
+ const uint8_t* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride) {
+ uint8_t* pred = prediction;
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ const uint8_t* obmc_pred = obmc_prediction;
+ uint8x8_t pred_mask;
+ uint8x8_t obmc_pred_mask;
+ int compute_height;
+ const int mask_offset = height - 2;
+ if (from_left) {
+ pred_mask = Load2(kObmcMask);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ compute_height = height;
+ } else {
+ // Weights for the last line are all 64, which is a no-op.
+ compute_height = height - 1;
+ }
+ uint8x8_t pred_val = vdup_n_u8(0);
+ uint8x8_t obmc_pred_val = vdup_n_u8(0);
+ int y = 0;
+ do {
+ if (!from_left) {
+ pred_mask = vdup_n_u8(kObmcMask[mask_offset + y]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ }
+ pred_val = Load2<0>(pred, pred_val);
+ const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+ obmc_pred_val = Load2<0>(obmc_pred, obmc_pred_val);
+ const uint8x8_t result =
+ vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+ Store2<0>(pred, result);
+
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+ } while (++y != compute_height);
+}
+
+inline void OverlapBlendFromLeft4xH_NEON(
+ uint8_t* const prediction, const ptrdiff_t prediction_stride,
+ const int height, const uint8_t* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride) {
+ uint8_t* pred = prediction;
+ const uint8_t* obmc_pred = obmc_prediction;
+
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ const uint8x8_t pred_mask = Load4(kObmcMask + 2);
+ // 64 - mask
+ const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ int y = 0;
+ do {
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ y += 2;
+ } while (y != height);
+}
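+// Indexing note (derived from the Load2()/Load4()/vld1_u8() calls in this
+// file, not an upstream comment): the OBMC blending mask for a dimension of
+// size n starts at kObmcMask + n - 2, e.g. kObmcMask + 2 for the 4xH case
+// above and kObmcMask + 6 for the 8xH case below.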
+
+inline void OverlapBlendFromLeft8xH_NEON(
+ uint8_t* const prediction, const ptrdiff_t prediction_stride,
+ const int height, const uint8_t* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride) {
+ uint8_t* pred = prediction;
+ const uint8_t* obmc_pred = obmc_prediction;
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ const uint8x8_t pred_mask = vld1_u8(kObmcMask + 6);
+ // 64 - mask
+ const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ int y = 0;
+ do {
+ const uint8x8_t pred_val = vld1_u8(pred);
+ const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+ const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred);
+ const uint8x8_t result =
+ vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+
+ vst1_u8(pred, result);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+ } while (++y != height);
+}
+
+void OverlapBlendFromLeft_NEON(void* const prediction,
+ const ptrdiff_t prediction_stride,
+ const int width, const int height,
+ const void* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride) {
+ auto* pred = static_cast<uint8_t*>(prediction);
+ const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+
+ if (width == 2) {
+ OverlapBlend2xH_NEON<true>(pred, prediction_stride, height, obmc_pred,
+ obmc_prediction_stride);
+ return;
+ }
+ if (width == 4) {
+ OverlapBlendFromLeft4xH_NEON(pred, prediction_stride, height, obmc_pred,
+ obmc_prediction_stride);
+ return;
+ }
+ if (width == 8) {
+ OverlapBlendFromLeft8xH_NEON(pred, prediction_stride, height, obmc_pred,
+ obmc_prediction_stride);
+ return;
+ }
+ const uint8x16_t mask_inverter = vdupq_n_u8(64);
+ const uint8_t* mask = kObmcMask + width - 2;
+ int x = 0;
+ do {
+ pred = static_cast<uint8_t*>(prediction) + x;
+ obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x;
+ const uint8x16_t pred_mask = vld1q_u8(mask + x);
+ // 64 - mask
+ const uint8x16_t obmc_pred_mask = vsubq_u8(mask_inverter, pred_mask);
+ int y = 0;
+ do {
+ const uint8x16_t pred_val = vld1q_u8(pred);
+ const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred);
+ const uint16x8_t weighted_pred_lo =
+ vmull_u8(vget_low_u8(pred_mask), vget_low_u8(pred_val));
+ const uint8x8_t result_lo =
+ vrshrn_n_u16(vmlal_u8(weighted_pred_lo, vget_low_u8(obmc_pred_mask),
+ vget_low_u8(obmc_pred_val)),
+ 6);
+ const uint16x8_t weighted_pred_hi =
+ vmull_u8(vget_high_u8(pred_mask), vget_high_u8(pred_val));
+ const uint8x8_t result_hi =
+ vrshrn_n_u16(vmlal_u8(weighted_pred_hi, vget_high_u8(obmc_pred_mask),
+ vget_high_u8(obmc_pred_val)),
+ 6);
+ vst1q_u8(pred, vcombine_u8(result_lo, result_hi));
+
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+ } while (++y < height);
+ x += 16;
+ } while (x < width);
+}
+
+inline void OverlapBlendFromTop4x4_NEON(uint8_t* const prediction,
+ const ptrdiff_t prediction_stride,
+ const uint8_t* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride,
+ const int height) {
+ uint8_t* pred = prediction;
+ const uint8_t* obmc_pred = obmc_prediction;
+ uint8x8_t pred_mask = vdup_n_u8(kObmcMask[height - 2]);
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ if (height == 2) {
+ return;
+ }
+
+ pred_mask = vdup_n_u8(kObmcMask[3]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ pred_mask = vdup_n_u8(kObmcMask[4]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+}
+
+inline void OverlapBlendFromTop4xH_NEON(
+ uint8_t* const prediction, const ptrdiff_t prediction_stride,
+ const int height, const uint8_t* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride) {
+ if (height < 8) {
+ OverlapBlendFromTop4x4_NEON(prediction, prediction_stride, obmc_prediction,
+ obmc_prediction_stride, height);
+ return;
+ }
+ uint8_t* pred = prediction;
+ const uint8_t* obmc_pred = obmc_prediction;
+ const uint8_t* mask = kObmcMask + height - 2;
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ int y = 0;
+ // Compute 6 lines for height 8, or 12 lines for height 16. The remaining
+ // lines are unchanged as the corresponding mask value is 64.
+ do {
+ uint8x8_t pred_mask = vdup_n_u8(mask[y]);
+ uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ pred_mask = vdup_n_u8(mask[y + 1]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ pred_mask = vdup_n_u8(mask[y + 2]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ pred_mask = vdup_n_u8(mask[y + 3]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ pred_mask = vdup_n_u8(mask[y + 4]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ pred_mask = vdup_n_u8(mask[y + 5]);
+ obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+
+ // Increment so the mask index stays correct for the next iteration.
+ y += 6;
+ } while (y < height - 4);
+}
+
+inline void OverlapBlendFromTop8xH_NEON(
+ uint8_t* const prediction, const ptrdiff_t prediction_stride,
+ const int height, const uint8_t* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride) {
+ uint8_t* pred = prediction;
+ const uint8_t* obmc_pred = obmc_prediction;
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ const uint8_t* mask = kObmcMask + height - 2;
+ const int compute_height = height - (height >> 2);
+ int y = 0;
+ do {
+ const uint8x8_t pred_mask = vdup_n_u8(mask[y]);
+ // 64 - mask
+ const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ const uint8x8_t pred_val = vld1_u8(pred);
+ const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+ const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred);
+ const uint8x8_t result =
+ vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+
+ vst1_u8(pred, result);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+ } while (++y != compute_height);
+}
+
+void OverlapBlendFromTop_NEON(void* const prediction,
+ const ptrdiff_t prediction_stride,
+ const int width, const int height,
+ const void* const obmc_prediction,
+ const ptrdiff_t obmc_prediction_stride) {
+ auto* pred = static_cast<uint8_t*>(prediction);
+ const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+
+ if (width == 2) {
+ OverlapBlend2xH_NEON<false>(pred, prediction_stride, height, obmc_pred,
+ obmc_prediction_stride);
+ return;
+ }
+ if (width == 4) {
+ OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred,
+ obmc_prediction_stride);
+ return;
+ }
+
+ if (width == 8) {
+ OverlapBlendFromTop8xH_NEON(pred, prediction_stride, height, obmc_pred,
+ obmc_prediction_stride);
+ return;
+ }
+
+ const uint8_t* mask = kObmcMask + height - 2;
+ const uint8x8_t mask_inverter = vdup_n_u8(64);
+ // Stop when the mask value becomes 64. The 4xH path applies the same
+ // cutoff implicitly.
+ const int compute_height = height - (height >> 2);
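+ // For example (arithmetic of the formula above, not an upstream note):
+ // height 16 blends the top 12 rows and height 32 the top 24; the remaining
+ // rows keep the original prediction because their weight is 64.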
+ int y = 0;
+ do {
+ const uint8x8_t pred_mask = vdup_n_u8(mask[y]);
+ // 64 - mask
+ const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+ int x = 0;
+ do {
+ const uint8x16_t pred_val = vld1q_u8(pred + x);
+ const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred + x);
+ const uint16x8_t weighted_pred_lo =
+ vmull_u8(pred_mask, vget_low_u8(pred_val));
+ const uint8x8_t result_lo =
+ vrshrn_n_u16(vmlal_u8(weighted_pred_lo, obmc_pred_mask,
+ vget_low_u8(obmc_pred_val)),
+ 6);
+ const uint16x8_t weighted_pred_hi =
+ vmull_u8(pred_mask, vget_high_u8(pred_val));
+ const uint8x8_t result_hi =
+ vrshrn_n_u16(vmlal_u8(weighted_pred_hi, obmc_pred_mask,
+ vget_high_u8(obmc_pred_val)),
+ 6);
+ vst1q_u8(pred + x, vcombine_u8(result_lo, result_hi));
+
+ x += 16;
+ } while (x < width);
+ pred += prediction_stride;
+ obmc_pred += obmc_prediction_stride;
+ } while (++y < compute_height);
+}
+
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON;
+ dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON;
+}
+
+} // namespace
+
+void ObmcInit_NEON() { Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void ObmcInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/obmc_neon.h b/src/dsp/arm/obmc_neon.h
new file mode 100644
index 0000000..d5c9d9c
--- /dev/null
+++ b/src/dsp/arm/obmc_neon.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::obmc_blend. This function is not thread-safe.
+void ObmcInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+// If NEON is enabled, signal the NEON implementation should be used.
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
diff --git a/src/dsp/arm/super_res_neon.cc b/src/dsp/arm/super_res_neon.cc
new file mode 100644
index 0000000..1680450
--- /dev/null
+++ b/src/dsp/arm/super_res_neon.cc
@@ -0,0 +1,166 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/super_res.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+
+namespace low_bitdepth {
+namespace {
+
+void SuperResCoefficients_NEON(const int upscaled_width,
+ const int initial_subpixel_x, const int step,
+ void* const coefficients) {
+ auto* dst = static_cast<uint8_t*>(coefficients);
+ int subpixel_x = initial_subpixel_x;
+ int x = RightShiftWithCeiling(upscaled_width, 3);
+ do {
+ uint8x8_t filter[8];
+ uint8x16_t d[kSuperResFilterTaps / 2];
+ for (int i = 0; i < 8; ++i, subpixel_x += step) {
+ filter[i] =
+ vld1_u8(kUpscaleFilterUnsigned[(subpixel_x & kSuperResScaleMask) >>
+ kSuperResExtraBits]);
+ }
+ Transpose8x8(filter, d);
+ vst1q_u8(dst, d[0]);
+ dst += 16;
+ vst1q_u8(dst, d[1]);
+ dst += 16;
+ vst1q_u8(dst, d[2]);
+ dst += 16;
+ vst1q_u8(dst, d[3]);
+ dst += 16;
+ } while (--x != 0);
+}
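+// Layout note (derived from the loop above, not an upstream comment): each
+// iteration writes 64 bytes holding the eight transposed 8-tap filters for
+// eight consecutive upscaled pixels; SuperRes() below reads them back 16
+// bytes (one tap pair for all eight pixels) at a time.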
+
+// Maximum sum of positive taps: 171 = 7 + 86 + 71 + 7
+// Maximum sum: 255*171 == 0xAA55
+// The sum is clipped to [0, 255], so adding all positive and then
+// subtracting all negative with saturation is sufficient.
+// 0 1 2 3 4 5 6 7
+// tap sign: - + - + + - + -
+inline uint8x8_t SuperRes(const uint8x8_t src[kSuperResFilterTaps],
+ const uint8_t** coefficients) {
+ uint8x16_t f[kSuperResFilterTaps / 2];
+ for (int i = 0; i < kSuperResFilterTaps / 2; ++i, *coefficients += 16) {
+ f[i] = vld1q_u8(*coefficients);
+ }
+ uint16x8_t res = vmull_u8(src[1], vget_high_u8(f[0]));
+ res = vmlal_u8(res, src[3], vget_high_u8(f[1]));
+ res = vmlal_u8(res, src[4], vget_low_u8(f[2]));
+ res = vmlal_u8(res, src[6], vget_low_u8(f[3]));
+ uint16x8_t temp = vmull_u8(src[0], vget_low_u8(f[0]));
+ temp = vmlal_u8(temp, src[2], vget_low_u8(f[1]));
+ temp = vmlal_u8(temp, src[5], vget_high_u8(f[2]));
+ temp = vmlal_u8(temp, src[7], vget_high_u8(f[3]));
+ res = vqsubq_u16(res, temp);
+ return vqrshrn_n_u16(res, kFilterBits);
+}
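+// Scalar sketch of SuperRes() above (illustration only): with the unsigned
+// tap magnitudes t[0..7] and the sign pattern noted above,
+//   pos = src[1]*t[1] + src[3]*t[3] + src[4]*t[4] + src[6]*t[6]
+//   neg = src[0]*t[0] + src[2]*t[2] + src[5]*t[5] + src[7]*t[7]
+//   dst = RightShiftWithRounding(SaturatingSub(pos, neg), kFilterBits)
+// where the subtraction saturates at 0 and the final narrowing saturates at
+// 255, matching the [0, 255] clip.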
+
+void SuperRes_NEON(const void* const coefficients, void* const source,
+ const ptrdiff_t stride, const int height,
+ const int downscaled_width, const int upscaled_width,
+ const int initial_subpixel_x, const int step,
+ void* const dest) {
+ auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps);
+ auto* dst = static_cast<uint8_t*>(dest);
+ int y = height;
+ do {
+ const auto* filter = static_cast<const uint8_t*>(coefficients);
+ uint8_t* dst_ptr = dst;
+ ExtendLine<uint8_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width,
+ kSuperResHorizontalBorder, kSuperResHorizontalBorder);
+ int subpixel_x = initial_subpixel_x;
+ uint8x8_t sr[8];
+ uint8x16_t s[8];
+ int x = RightShiftWithCeiling(upscaled_width, 4);
+ // The code below calculates up to 15 extra upscaled pixels, which may
+ // over-read up to 15 downscaled pixels at the end of each row.
+ // kSuperResHorizontalBorder accounts for this.
+ do {
+ for (int i = 0; i < 8; ++i, subpixel_x += step) {
+ sr[i] = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]);
+ }
+ for (int i = 0; i < 8; ++i, subpixel_x += step) {
+ const uint8x8_t s_hi = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]);
+ s[i] = vcombine_u8(sr[i], s_hi);
+ }
+ Transpose8x16(s);
+ // Do not use a loop for the following 8 instructions, since the compiler
+ // will generate redundant code.
+ sr[0] = vget_low_u8(s[0]);
+ sr[1] = vget_low_u8(s[1]);
+ sr[2] = vget_low_u8(s[2]);
+ sr[3] = vget_low_u8(s[3]);
+ sr[4] = vget_low_u8(s[4]);
+ sr[5] = vget_low_u8(s[5]);
+ sr[6] = vget_low_u8(s[6]);
+ sr[7] = vget_low_u8(s[7]);
+ const uint8x8_t d0 = SuperRes(sr, &filter);
+ // Do not use a loop for the following 8 instructions, since the compiler
+ // will generate redundant code.
+ sr[0] = vget_high_u8(s[0]);
+ sr[1] = vget_high_u8(s[1]);
+ sr[2] = vget_high_u8(s[2]);
+ sr[3] = vget_high_u8(s[3]);
+ sr[4] = vget_high_u8(s[4]);
+ sr[5] = vget_high_u8(s[5]);
+ sr[6] = vget_high_u8(s[6]);
+ sr[7] = vget_high_u8(s[7]);
+ const uint8x8_t d1 = SuperRes(sr, &filter);
+ vst1q_u8(dst_ptr, vcombine_u8(d0, d1));
+ dst_ptr += 16;
+ } while (--x != 0);
+ src += stride;
+ dst += stride;
+ } while (--y != 0);
+}
+
+void Init8bpp() {
+ Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ dsp->super_res_coefficients = SuperResCoefficients_NEON;
+ dsp->super_res = SuperRes_NEON;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void SuperResInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void SuperResInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/super_res_neon.h b/src/dsp/arm/super_res_neon.h
new file mode 100644
index 0000000..f51785d
--- /dev/null
+++ b/src/dsp/arm/super_res_neon.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::super_res_coefficients and Dsp::super_res. This function
+// is not thread-safe.
+void SuperResInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_SuperRes LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_SuperResClip LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_
diff --git a/src/dsp/arm/warp_neon.cc b/src/dsp/arm/warp_neon.cc
new file mode 100644
index 0000000..7a41998
--- /dev/null
+++ b/src/dsp/arm/warp_neon.cc
@@ -0,0 +1,453 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/warp.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <type_traits>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Number of extra bits of precision in warped filtering.
+constexpr int kWarpedDiffPrecisionBits = 10;
+constexpr int kFirstPassOffset = 1 << 14;
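+// Derivation of kOffsetRemoval (a sketch based on the code below, not an
+// upstream comment): the horizontal pass adds kFirstPassOffset before its
+// right shift by kInterRoundBitsHorizontal, so every intermediate value
+// carries a bias of kFirstPassOffset >> kInterRoundBitsHorizontal. The warp
+// filter taps sum to 128, so the vertical pass accumulates 128 times that
+// bias, which is removed by seeding its accumulators with -kOffsetRemoval.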
+constexpr int kOffsetRemoval =
+ (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128;
+
+// Applies the horizontal filter to one source row and stores the result in
+// |intermediate_result_row|. |intermediate_result_row| is a row in the 15x8
+// |intermediate_result| two-dimensional array.
+//
+// src_row_centered contains 16 "centered" samples of a source row. (We center
+// the samples by subtracting 128 from the samples.)
+void HorizontalFilter(const int sx4, const int16_t alpha,
+ const int8x16_t src_row_centered,
+ int16_t intermediate_result_row[8]) {
+ int sx = sx4 - MultiplyBy4(alpha);
+ int8x8_t filter[8];
+ for (int x = 0; x < 8; ++x) {
+ const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+ kWarpedPixelPrecisionShifts;
+ filter[x] = vld1_s8(kWarpedFilters8[offset]);
+ sx += alpha;
+ }
+ Transpose8x8(filter);
+ // Add kFirstPassOffset to ensure |sum| stays within uint16_t.
+ // Add 128 (offset) * 128 (filter sum), also 1 << 14, to account for the
+ // centering of the source samples. Combined these equal 1 << 15, which is
+ // -32768 when stored in an int16_t.
+ int16x8_t sum =
+ vdupq_n_s16(static_cast<int16_t>(kFirstPassOffset + 128 * 128));
+ // Unrolled k = 0..7 loop. We need to manually unroll the loop because the
+ // third argument (an index value) to vextq_s8() must be a constant
+ // (immediate). src_row_window is a sliding window of length 8 into
+ // src_row_centered.
+ // k = 0.
+ int8x8_t src_row_window = vget_low_s8(src_row_centered);
+ sum = vmlal_s8(sum, filter[0], src_row_window);
+ // k = 1.
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 1));
+ sum = vmlal_s8(sum, filter[1], src_row_window);
+ // k = 2.
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 2));
+ sum = vmlal_s8(sum, filter[2], src_row_window);
+ // k = 3.
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 3));
+ sum = vmlal_s8(sum, filter[3], src_row_window);
+ // k = 4.
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 4));
+ sum = vmlal_s8(sum, filter[4], src_row_window);
+ // k = 5.
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 5));
+ sum = vmlal_s8(sum, filter[5], src_row_window);
+ // k = 6.
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 6));
+ sum = vmlal_s8(sum, filter[6], src_row_window);
+ // k = 7.
+ src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 7));
+ sum = vmlal_s8(sum, filter[7], src_row_window);
+ // End of unrolled k = 0..7 loop.
+ // Due to the offset, the mathematical value of |sum| is non-negative and
+ // fits in uint16_t, so the lanes are reinterpreted as unsigned.
+ uint16x8_t sum_unsigned = vreinterpretq_u16_s16(sum);
+ sum_unsigned = vrshrq_n_u16(sum_unsigned, kInterRoundBitsHorizontal);
+ // After the shift |sum_unsigned| will fit into int16_t.
+ vst1q_s16(intermediate_result_row, vreinterpretq_s16_u16(sum_unsigned));
+}
+
+template <bool is_compound>
+void Warp_NEON(const void* const source, const ptrdiff_t source_stride,
+ const int source_width, const int source_height,
+ const int* const warp_params, const int subsampling_x,
+ const int subsampling_y, const int block_start_x,
+ const int block_start_y, const int block_width,
+ const int block_height, const int16_t alpha, const int16_t beta,
+ const int16_t gamma, const int16_t delta, void* dest,
+ const ptrdiff_t dest_stride) {
+ constexpr int kRoundBitsVertical =
+ is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
+ union {
+ // |intermediate_result| is the output of the horizontal filtering and
+ // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 -
+ // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t
+ // type so that we can multiply it by kWarpedFilters (which has signed
+ // values) using vmlal_s16().
+ int16_t intermediate_result[15][8]; // 15 rows, 8 columns.
+ // In the simple special cases where the samples in each row are all the
+ // same, store one sample per row in a column vector.
+ int16_t intermediate_result_column[15];
+ };
+
+ const auto* const src = static_cast<const uint8_t*>(source);
+ using DestType =
+ typename std::conditional<is_compound, int16_t, uint8_t>::type;
+ auto* dst = static_cast<DestType*>(dest);
+
+ assert(block_width >= 8);
+ assert(block_height >= 8);
+
+ // Warp process applies for each 8x8 block.
+ int start_y = block_start_y;
+ do {
+ int start_x = block_start_x;
+ do {
+ const int src_x = (start_x + 4) << subsampling_x;
+ const int src_y = (start_y + 4) << subsampling_y;
+ const int dst_x =
+ src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0];
+ const int dst_y =
+ src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1];
+ const int x4 = dst_x >> subsampling_x;
+ const int y4 = dst_y >> subsampling_y;
+ const int ix4 = x4 >> kWarpedModelPrecisionBits;
+ const int iy4 = y4 >> kWarpedModelPrecisionBits;
+ // A prediction block may fall outside the frame's boundaries. If a
+ // prediction block is calculated using only samples outside the frame's
+ // boundary, the filtering can be simplified. We can divide the plane
+ // into several regions and handle them differently.
+ //
+ // | |
+ // 1 | 3 | 1
+ // | |
+ // -------+-----------+-------
+ // |***********|
+ // 2 |*****4*****| 2
+ // |***********|
+ // -------+-----------+-------
+ // | |
+ // 1 | 3 | 1
+ // | |
+ //
+ // At the center, region 4 represents the frame and is the general case.
+ //
+ // In regions 1 and 2, the prediction block is outside the frame's
+ // boundary horizontally. Therefore the horizontal filtering can be
+ // simplified. Furthermore, in the region 1 (at the four corners), the
+ // prediction is outside the frame's boundary both horizontally and
+ // vertically, so we get a constant prediction block.
+ //
+ // In region 3, the prediction block is outside the frame's boundary
+ // vertically. Unfortunately, because we apply the horizontal filters
+ // first, by the time we apply the vertical filters, they no longer see
+ // simple inputs. So the only simplification is that all the rows are
+ // the same, but we still need to apply all the horizontal and vertical
+ // filters.
+
+ // Check for two simple special cases, where the horizontal filter can
+ // be significantly simplified.
+ //
+ // In general, for each row, the horizontal filter is calculated as
+ // follows:
+ // for (int x = -4; x < 4; ++x) {
+ // const int offset = ...;
+ // int sum = first_pass_offset;
+ // for (int k = 0; k < 8; ++k) {
+ // const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
+ // sum += kWarpedFilters[offset][k] * src_row[column];
+ // }
+ // ...
+ // }
+ // The column index before clipping, ix4 + x + k - 3, varies in the range
+ // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1
+ // or ix4 + 7 <= 0, then all the column indexes are clipped to the same
+ // border index (source_width - 1 or 0, respectively). Then for each x,
+ // the inner for loop of the horizontal filter is reduced to multiplying
+ // the border pixel by the sum of the filter coefficients.
+ if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) {
+ // Regions 1 and 2.
+ // Points to the left or right border of the first row of |src|.
+ const uint8_t* first_row_border =
+ (ix4 + 7 <= 0) ? src : src + source_width - 1;
+ // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+ // const int row = Clip3(iy4 + y, 0, source_height - 1);
+ // In two special cases, iy4 + y is clipped to either 0 or
+ // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+ // bounded and we can avoid clipping iy4 + y by relying on a reference
+ // frame's boundary extension on the top and bottom.
+ if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+ // Region 1.
+ // Every sample used to calculate the prediction block has the same
+ // value. So the whole prediction block has the same value.
+ const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+ const uint8_t row_border_pixel =
+ first_row_border[row * source_stride];
+
+ DestType* dst_row = dst + start_x - block_start_x;
+ for (int y = 0; y < 8; ++y) {
+ if (is_compound) {
+ const int16x8_t sum =
+ vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical -
+ kRoundBitsVertical));
+ vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
+ } else {
+ memset(dst_row, row_border_pixel, 8);
+ }
+ dst_row += dest_stride;
+ }
+ // End of region 1. Continue the |start_x| do-while loop.
+ start_x += 8;
+ continue;
+ }
+
+ // Region 2.
+ // Horizontal filter.
+ // The input values in this region are generated by extending the border,
+ // which makes them identical in the horizontal direction. This
+ // computation could be inlined in the vertical pass but most
+ // implementations will need a transpose of some sort.
+ // It is not necessary to use the offset values here because the
+ // horizontal pass is a simple shift and the vertical pass will always
+ // require using 32 bits.
+ for (int y = -7; y < 8; ++y) {
+ // We may over-read up to 13 pixels above the top source row, or up
+ // to 13 pixels below the bottom source row. This is proved in
+ // warp.cc.
+ const int row = iy4 + y;
+ int sum = first_row_border[row * source_stride];
+ sum <<= (kFilterBits - kInterRoundBitsHorizontal);
+ intermediate_result_column[y + 7] = sum;
+ }
+ // Vertical filter.
+ DestType* dst_row = dst + start_x - block_start_x;
+ int sy4 =
+ (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+ for (int y = 0; y < 8; ++y) {
+ int sy = sy4 - MultiplyBy4(gamma);
+#if defined(__aarch64__)
+ const int16x8_t intermediate =
+ vld1q_s16(&intermediate_result_column[y]);
+ int16_t tmp[8];
+ for (int x = 0; x < 8; ++x) {
+ const int offset =
+ RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+ kWarpedPixelPrecisionShifts;
+ const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]);
+ const int32x4_t product_low =
+ vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate));
+ const int32x4_t product_high =
+ vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate));
+ // vaddvq_s32 is only available on __aarch64__.
+ const int32_t sum =
+ vaddvq_s32(product_low) + vaddvq_s32(product_high);
+ const int16_t sum_descale =
+ RightShiftWithRounding(sum, kRoundBitsVertical);
+ if (is_compound) {
+ dst_row[x] = sum_descale;
+ } else {
+ tmp[x] = sum_descale;
+ }
+ sy += gamma;
+ }
+ if (!is_compound) {
+ const int16x8_t sum = vld1q_s16(tmp);
+ vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
+ }
+#else // !defined(__aarch64__)
+ int16x8_t filter[8];
+ for (int x = 0; x < 8; ++x) {
+ const int offset =
+ RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+ kWarpedPixelPrecisionShifts;
+ filter[x] = vld1q_s16(kWarpedFilters[offset]);
+ sy += gamma;
+ }
+ Transpose8x8(filter);
+ int32x4_t sum_low = vdupq_n_s32(0);
+ int32x4_t sum_high = sum_low;
+ for (int k = 0; k < 8; ++k) {
+ const int16_t intermediate = intermediate_result_column[y + k];
+ sum_low =
+ vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate);
+ sum_high =
+ vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate);
+ }
+ const int16x8_t sum =
+ vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+ vrshrn_n_s32(sum_high, kRoundBitsVertical));
+ if (is_compound) {
+ vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
+ } else {
+ vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
+ }
+#endif // defined(__aarch64__)
+ dst_row += dest_stride;
+ sy4 += delta;
+ }
+ // End of region 2. Continue the |start_x| do-while loop.
+ start_x += 8;
+ continue;
+ }
+
+ // Regions 3 and 4.
+ // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+
+ // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+ // const int row = Clip3(iy4 + y, 0, source_height - 1);
+ // In two special cases, iy4 + y is clipped to either 0 or
+ // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+ // bounded and we can avoid clipping iy4 + y by relying on a reference
+ // frame's boundary extension on the top and bottom.
+ if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+ // Region 3.
+ // Horizontal filter.
+ const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+ const uint8_t* const src_row = src + row * source_stride;
+ // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+ // read but is ignored.
+ //
+ // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
+ // bytes after src_row[source_width - 1]. We assume the source frame
+ // has left and right borders of at least 13 bytes that extend the
+ // frame boundary pixels. We also assume there is at least one extra
+ // padding byte after the right border of the last source row.
+ const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]);
+ // Convert src_row_v to int8 (subtract 128).
+ const int8x16_t src_row_centered =
+ vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128)));
+ int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+ for (int y = -7; y < 8; ++y) {
+ HorizontalFilter(sx4, alpha, src_row_centered,
+ intermediate_result[y + 7]);
+ sx4 += beta;
+ }
+ } else {
+ // Region 4.
+ // Horizontal filter.
+ int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+ for (int y = -7; y < 8; ++y) {
+ // We may over-read up to 13 pixels above the top source row, or up
+ // to 13 pixels below the bottom source row. This is proved in
+ // warp.cc.
+ const int row = iy4 + y;
+ const uint8_t* const src_row = src + row * source_stride;
+ // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+ // read but is ignored.
+ //
+ // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
+ // bytes after src_row[source_width - 1]. We assume the source frame
+ // has left and right borders of at least 13 bytes that extend the
+ // frame boundary pixels. We also assume there is at least one extra
+ // padding byte after the right border of the last source row.
+ const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]);
+ // Convert src_row_v to int8 (subtract 128).
+ const int8x16_t src_row_centered =
+ vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128)));
+ HorizontalFilter(sx4, alpha, src_row_centered,
+ intermediate_result[y + 7]);
+ sx4 += beta;
+ }
+ }
+
+ // Regions 3 and 4.
+ // Vertical filter.
+ DestType* dst_row = dst + start_x - block_start_x;
+ int sy4 =
+ (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+ for (int y = 0; y < 8; ++y) {
+ int sy = sy4 - MultiplyBy4(gamma);
+ int16x8_t filter[8];
+ for (int x = 0; x < 8; ++x) {
+ const int offset =
+ RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+ kWarpedPixelPrecisionShifts;
+ filter[x] = vld1q_s16(kWarpedFilters[offset]);
+ sy += gamma;
+ }
+ Transpose8x8(filter);
+ int32x4_t sum_low = vdupq_n_s32(-kOffsetRemoval);
+ int32x4_t sum_high = sum_low;
+ for (int k = 0; k < 8; ++k) {
+ const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]);
+ sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]),
+ vget_low_s16(intermediate));
+ sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]),
+ vget_high_s16(intermediate));
+ }
+ const int16x8_t sum =
+ vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+ vrshrn_n_s32(sum_high, kRoundBitsVertical));
+ if (is_compound) {
+ vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
+ } else {
+ vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
+ }
+ dst_row += dest_stride;
+ sy4 += delta;
+ }
+ start_x += 8;
+ } while (start_x < block_start_x + block_width);
+ dst += 8 * dest_stride;
+ start_y += 8;
+ } while (start_y < block_start_y + block_height);
+}
+
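+// Init8bpp() installs the single-prediction path as dsp->warp
+// (is_compound = false) and the compound path as dsp->warp_compound
+// (is_compound = true); see the stores above for the respective output
+// types.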
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ dsp->warp = Warp_NEON</*is_compound=*/false>;
+ dsp->warp_compound = Warp_NEON</*is_compound=*/true>;
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void WarpInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+#else // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void WarpInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/warp_neon.h b/src/dsp/arm/warp_neon.h
new file mode 100644
index 0000000..dbcaa23
--- /dev/null
+++ b/src/dsp/arm/warp_neon.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::warp. This function is not thread-safe.
+void WarpInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WarpCompound LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
diff --git a/src/dsp/arm/weight_mask_neon.cc b/src/dsp/arm/weight_mask_neon.cc
new file mode 100644
index 0000000..49d3be0
--- /dev/null
+++ b/src/dsp/arm/weight_mask_neon.cc
@@ -0,0 +1,463 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/arm/weight_mask_neon.h"
+
+#include "src/dsp/weight_mask.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+constexpr int kRoundingBits8bpp = 4;
+
+template <bool mask_is_inverse>
+inline void WeightMask8_NEON(const int16_t* prediction_0,
+ const int16_t* prediction_1, uint8_t* mask) {
+ const int16x8_t pred_0 = vld1q_s16(prediction_0);
+ const int16x8_t pred_1 = vld1q_s16(prediction_1);
+ const uint8x8_t difference_offset = vdup_n_u8(38);
+ const uint8x8_t mask_ceiling = vdup_n_u8(64);
+ const uint16x8_t difference = vrshrq_n_u16(
+ vreinterpretq_u16_s16(vabdq_s16(pred_0, pred_1)), kRoundingBits8bpp);
+ const uint8x8_t adjusted_difference =
+ vqadd_u8(vqshrn_n_u16(difference, 4), difference_offset);
+ const uint8x8_t mask_value = vmin_u8(adjusted_difference, mask_ceiling);
+ if (mask_is_inverse) {
+ const uint8x8_t inverted_mask_value = vsub_u8(mask_ceiling, mask_value);
+ vst1_u8(mask, inverted_mask_value);
+ } else {
+ vst1_u8(mask, mask_value);
+ }
+}
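+
+// For reference, a scalar sketch of what the vector code above computes per
+// lane i (saturating ops omitted):
+//   const int diff = (std::abs(pred_0[i] - pred_1[i]) + 8) >> 4;  // vrshrq
+//   const int mask_value = std::min((diff >> 4) + 38, 64);  // vqshrn/vqadd/vmin
+//   mask[i] = mask_is_inverse ? 64 - mask_value : mask_value;
+// e.g. |pred_0[i] - pred_1[i]| = 800 -> 808 >> 4 = 50 -> 50 >> 4 = 3 ->
+// 3 + 38 = 41 (inverse: 64 - 41 = 23).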
+
+#define WEIGHT8_WITHOUT_STRIDE \
+ WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask)
+
+#define WEIGHT8_AND_STRIDE \
+ WEIGHT8_WITHOUT_STRIDE; \
+ pred_0 += 8; \
+ pred_1 += 8; \
+ mask += mask_stride
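+
+// Each *_AND_STRIDE invocation below handles one row and advances to the
+// next; the final row of every block uses *_WITHOUT_STRIDE so that |mask| is
+// not advanced past the block. The WEIGHT16/32/64 macros follow the same
+// convention.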
+
+template <bool mask_is_inverse>
+void WeightMask8x8_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y = 0;
+ do {
+ WEIGHT8_AND_STRIDE;
+ } while (++y < 7);
+ WEIGHT8_WITHOUT_STRIDE;
+}
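+
+// Row accounting for the fixed-height loops below: the x16 variants run a
+// 3-row body 5 times plus a final row (3 * 5 + 1 = 16), x32 runs a 5-row
+// body 6 times plus two rows (5 * 6 + 2 = 32), x64 runs the 3-row body 21
+// times plus one (3 * 21 + 1 = 64), and 64x128 runs it 42 times plus two
+// (3 * 42 + 2 = 128).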
+
+template <bool mask_is_inverse>
+void WeightMask8x16_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT8_AND_STRIDE;
+ WEIGHT8_AND_STRIDE;
+ WEIGHT8_AND_STRIDE;
+ } while (++y3 < 5);
+ WEIGHT8_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask8x32_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y5 = 0;
+ do {
+ WEIGHT8_AND_STRIDE;
+ WEIGHT8_AND_STRIDE;
+ WEIGHT8_AND_STRIDE;
+ WEIGHT8_AND_STRIDE;
+ WEIGHT8_AND_STRIDE;
+ } while (++y5 < 6);
+ WEIGHT8_AND_STRIDE;
+ WEIGHT8_WITHOUT_STRIDE;
+}
+
+#define WEIGHT16_WITHOUT_STRIDE \
+ WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8)
+
+#define WEIGHT16_AND_STRIDE \
+ WEIGHT16_WITHOUT_STRIDE; \
+ pred_0 += 16; \
+ pred_1 += 16; \
+ mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask16x8_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y = 0;
+ do {
+ WEIGHT16_AND_STRIDE;
+ } while (++y < 7);
+ WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x16_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ } while (++y3 < 5);
+ WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x32_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y5 = 0;
+ do {
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ } while (++y5 < 6);
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x64_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ WEIGHT16_AND_STRIDE;
+ } while (++y3 < 21);
+ WEIGHT16_WITHOUT_STRIDE;
+}
+
+#define WEIGHT32_WITHOUT_STRIDE \
+ WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24)
+
+#define WEIGHT32_AND_STRIDE \
+ WEIGHT32_WITHOUT_STRIDE; \
+ pred_0 += 32; \
+ pred_1 += 32; \
+ mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask32x8_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x16_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ } while (++y3 < 5);
+ WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x32_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y5 = 0;
+ do {
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ } while (++y5 < 6);
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x64_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ WEIGHT32_AND_STRIDE;
+ } while (++y3 < 21);
+ WEIGHT32_WITHOUT_STRIDE;
+}
+
+#define WEIGHT64_WITHOUT_STRIDE \
+ WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \
+ WeightMask8_NEON<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56)
+
+#define WEIGHT64_AND_STRIDE \
+ WEIGHT64_WITHOUT_STRIDE; \
+ pred_0 += 64; \
+ pred_1 += 64; \
+ mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask64x16_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ } while (++y3 < 5);
+ WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x32_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y5 = 0;
+ do {
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ } while (++y5 < 6);
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x64_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ } while (++y3 < 21);
+ WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x128_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ do {
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_AND_STRIDE;
+ } while (++y3 < 42);
+ WEIGHT64_AND_STRIDE;
+ WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x64_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
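+  // Each 128-wide row is two 64-wide halves: advance |mask| by 64 after the
+  // left half, then by mask_stride - 64 after the right half to reach the
+  // start of the next row.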
+ do {
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += adjusted_mask_stride;
+
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += adjusted_mask_stride;
+
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += adjusted_mask_stride;
+ } while (++y3 < 21);
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x128_NEON(const void* prediction_0, const void* prediction_1,
+ uint8_t* mask, ptrdiff_t mask_stride) {
+ const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+ const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+ int y3 = 0;
+ const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
+ do {
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += adjusted_mask_stride;
+
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += adjusted_mask_stride;
+
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += adjusted_mask_stride;
+ } while (++y3 < 42);
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += adjusted_mask_stride;
+
+ WEIGHT64_WITHOUT_STRIDE;
+ pred_0 += 64;
+ pred_1 += 64;
+ mask += 64;
+ WEIGHT64_WITHOUT_STRIDE;
+}
+
+#define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
+ dsp->weight_mask[w_index][h_index][0] = \
+ WeightMask##width##x##height##_NEON<0>; \
+ dsp->weight_mask[w_index][h_index][1] = WeightMask##width##x##height##_NEON<1>
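+// The final index selects the mask direction: [0] installs the regular mask
+// (mask_is_inverse = false) and [1] the inverted mask.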
+void Init8bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+ assert(dsp != nullptr);
+ INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
+ INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
+ INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
+ INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
+ INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
+ INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
+ INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
+ INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
+ INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
+ INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
+ INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
+ INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
+ INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
+ INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
+ INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
+ INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
+ INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
+}
+
+} // namespace
+} // namespace low_bitdepth
+
+void WeightMaskInit_NEON() { low_bitdepth::Init8bpp(); }
+
+} // namespace dsp
+} // namespace libgav1
+
+#else // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void WeightMaskInit_NEON() {}
+
+} // namespace dsp
+} // namespace libgav1
+#endif // LIBGAV1_ENABLE_NEON
diff --git a/src/dsp/arm/weight_mask_neon.h b/src/dsp/arm/weight_mask_neon.h
new file mode 100644
index 0000000..b4749ec
--- /dev/null
+++ b/src/dsp/arm/weight_mask_neon.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::weight_mask. This function is not thread-safe.
+void WeightMaskInit_NEON();
+
+} // namespace dsp
+} // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x128 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_128x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_NEON
+#endif // LIBGAV1_ENABLE_NEON
+
+#endif // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_