Diffstat (limited to 'src/dsp/arm/warp_neon.cc')
-rw-r--r-- | src/dsp/arm/warp_neon.cc | 479 |
1 file changed, 466 insertions, 13 deletions
diff --git a/src/dsp/arm/warp_neon.cc b/src/dsp/arm/warp_neon.cc
index c7fb739..71e0a43 100644
--- a/src/dsp/arm/warp_neon.cc
+++ b/src/dsp/arm/warp_neon.cc
@@ -34,11 +34,16 @@
 namespace libgav1 {
 namespace dsp {
-namespace low_bitdepth {
 namespace {
 
 // Number of extra bits of precision in warped filtering.
 constexpr int kWarpedDiffPrecisionBits = 10;
+
+}  // namespace
+
+namespace low_bitdepth {
+namespace {
+
 constexpr int kFirstPassOffset = 1 << 14;
 constexpr int kOffsetRemoval =
     (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128;
@@ -54,10 +59,10 @@ void HorizontalFilter(const int sx4, const int16_t alpha,
                       int16_t intermediate_result_row[8]) {
   int sx = sx4 - MultiplyBy4(alpha);
   int8x8_t filter[8];
-  for (int x = 0; x < 8; ++x) {
+  for (auto& f : filter) {
     const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
                        kWarpedPixelPrecisionShifts;
-    filter[x] = vld1_s8(kWarpedFilters8[offset]);
+    f = vld1_s8(kWarpedFilters8[offset]);
     sx += alpha;
   }
   Transpose8x8(filter);
@@ -103,13 +108,15 @@ void HorizontalFilter(const int sx4, const int16_t alpha,
 }
 
 template <bool is_compound>
-void Warp_NEON(const void* const source, const ptrdiff_t source_stride,
-               const int source_width, const int source_height,
-               const int* const warp_params, const int subsampling_x,
-               const int subsampling_y, const int block_start_x,
-               const int block_start_y, const int block_width,
-               const int block_height, const int16_t alpha, const int16_t beta,
-               const int16_t gamma, const int16_t delta, void* dest,
+void Warp_NEON(const void* LIBGAV1_RESTRICT const source,
+               const ptrdiff_t source_stride, const int source_width,
+               const int source_height,
+               const int* LIBGAV1_RESTRICT const warp_params,
+               const int subsampling_x, const int subsampling_y,
+               const int block_start_x, const int block_start_y,
+               const int block_width, const int block_height,
+               const int16_t alpha, const int16_t beta, const int16_t gamma,
+               const int16_t delta, void* LIBGAV1_RESTRICT dest,
                const ptrdiff_t dest_stride) {
   constexpr int kRoundBitsVertical =
       is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
@@ -393,11 +400,11 @@ void Warp_NEON(const void* const source, const ptrdiff_t source_stride,
   for (int y = 0; y < 8; ++y) {
     int sy = sy4 - MultiplyBy4(gamma);
     int16x8_t filter[8];
-    for (int x = 0; x < 8; ++x) {
+    for (auto& f : filter) {
       const int offset = RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
                          kWarpedPixelPrecisionShifts;
-      filter[x] = vld1q_s16(kWarpedFilters[offset]);
+      f = vld1q_s16(kWarpedFilters[offset]);
       sy += gamma;
     }
     Transpose8x8(filter);
@@ -438,7 +445,453 @@ void Init8bpp() {
 }  // namespace
 }  // namespace low_bitdepth
 
-void WarpInit_NEON() { low_bitdepth::Init8bpp(); }
+//------------------------------------------------------------------------------
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+LIBGAV1_ALWAYS_INLINE uint16x8x2_t LoadSrcRow(uint16_t const* ptr) {
+  uint16x8x2_t x;
+  // Clang/gcc uses ldp here.
+  x.val[0] = vld1q_u16(ptr);
+  x.val[1] = vld1q_u16(ptr + 8);
+  return x;
+}
+
+LIBGAV1_ALWAYS_INLINE void HorizontalFilter(
+    const int sx4, const int16_t alpha, const uint16x8x2_t src_row,
+    int16_t intermediate_result_row[8]) {
+  int sx = sx4 - MultiplyBy4(alpha);
+  int8x8_t filter8[8];
+  for (auto& f : filter8) {
+    const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+                       kWarpedPixelPrecisionShifts;
+    f = vld1_s8(kWarpedFilters8[offset]);
+    sx += alpha;
+  }
+
+  Transpose8x8(filter8);
+
+  int16x8_t filter[8];
+  for (int i = 0; i < 8; ++i) {
+    filter[i] = vmovl_s8(filter8[i]);
+  }
+
+  int32x4x2_t sum;
+  int16x8_t src_row_window;
+  // k = 0.
+  src_row_window = vreinterpretq_s16_u16(src_row.val[0]);
+  sum.val[0] =
+      vmull_s16(vget_low_s16(filter[0]), vget_low_s16(src_row_window));
+  sum.val[1] = VMullHighS16(filter[0], src_row_window);
+  // k = 1.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 1));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[1]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[1], src_row_window);
+  // k = 2.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 2));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[2]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[2], src_row_window);
+  // k = 3.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 3));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[3]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[3], src_row_window);
+  // k = 4.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 4));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[4]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[4], src_row_window);
+  // k = 5.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 5));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[5]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[5], src_row_window);
+  // k = 6.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 6));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[6]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[6], src_row_window);
+  // k = 7.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 7));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[7]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[7], src_row_window);
+  // End of unrolled k = 0..7 loop.
+
+  vst1_s16(intermediate_result_row,
+           vrshrn_n_s32(sum.val[0], kInterRoundBitsHorizontal));
+  vst1_s16(intermediate_result_row + 4,
+           vrshrn_n_s32(sum.val[1], kInterRoundBitsHorizontal));
+}
+
+template <bool is_compound>
+void Warp_NEON(const void* LIBGAV1_RESTRICT const source,
+               const ptrdiff_t source_stride, const int source_width,
+               const int source_height,
+               const int* LIBGAV1_RESTRICT const warp_params,
+               const int subsampling_x, const int subsampling_y,
+               const int block_start_x, const int block_start_y,
+               const int block_width, const int block_height,
+               const int16_t alpha, const int16_t beta, const int16_t gamma,
+               const int16_t delta, void* LIBGAV1_RESTRICT dest,
+               const ptrdiff_t dest_stride) {
+  constexpr int kRoundBitsVertical =
+      is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
+  union {
+    // Intermediate_result is the output of the horizontal filtering and
+    // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 -
+    // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t
+    // type so that we can multiply it by kWarpedFilters (which has signed
+    // values) using vmlal_s16().
+    int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
+    // In the simple special cases where the samples in each row are all the
+    // same, store one sample per row in a column vector.
+    int16_t intermediate_result_column[15];
+  };
+
+  const auto* const src = static_cast<const uint16_t*>(source);
+  const ptrdiff_t src_stride = source_stride >> 1;
+  using DestType =
+      typename std::conditional<is_compound, int16_t, uint16_t>::type;
+  auto* dst = static_cast<DestType*>(dest);
+  const ptrdiff_t dst_stride = is_compound ? dest_stride : dest_stride >> 1;
+  assert(block_width >= 8);
+  assert(block_height >= 8);
+
+  // Warp process applies for each 8x8 block.
+  int start_y = block_start_y;
+  do {
+    int start_x = block_start_x;
+    do {
+      const int src_x = (start_x + 4) << subsampling_x;
+      const int src_y = (start_y + 4) << subsampling_y;
+      const int dst_x =
+          src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0];
+      const int dst_y =
+          src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1];
+      const int x4 = dst_x >> subsampling_x;
+      const int y4 = dst_y >> subsampling_y;
+      const int ix4 = x4 >> kWarpedModelPrecisionBits;
+      const int iy4 = y4 >> kWarpedModelPrecisionBits;
+      // A prediction block may fall outside the frame's boundaries. If a
+      // prediction block is calculated using only samples outside the frame's
+      // boundary, the filtering can be simplified. We can divide the plane
+      // into several regions and handle them differently.
+      //
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //         -------+-----------+-------
+      //                |***********|
+      //            2   |*****4*****|   2
+      //                |***********|
+      //         -------+-----------+-------
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //
+      // At the center, region 4 represents the frame and is the general case.
+      //
+      // In regions 1 and 2, the prediction block is outside the frame's
+      // boundary horizontally. Therefore the horizontal filtering can be
+      // simplified. Furthermore, in the region 1 (at the four corners), the
+      // prediction is outside the frame's boundary both horizontally and
+      // vertically, so we get a constant prediction block.
+      //
+      // In region 3, the prediction block is outside the frame's boundary
+      // vertically. Unfortunately because we apply the horizontal filters
+      // first, by the time we apply the vertical filters, they no longer see
+      // simple inputs. So the only simplification is that all the rows are
+      // the same, but we still need to apply all the horizontal and vertical
+      // filters.
+
+      // Check for two simple special cases, where the horizontal filter can
+      // be significantly simplified.
+      //
+      // In general, for each row, the horizontal filter is calculated as
+      // follows:
+      //   for (int x = -4; x < 4; ++x) {
+      //     const int offset = ...;
+      //     int sum = first_pass_offset;
+      //     for (int k = 0; k < 8; ++k) {
+      //       const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
+      //       sum += kWarpedFilters[offset][k] * src_row[column];
+      //     }
+      //     ...
+      //   }
+      // The column index before clipping, ix4 + x + k - 3, varies in the
+      // range ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7.
+      // If ix4 - 7 >= source_width - 1 or ix4 + 7 <= 0, then all the column
+      // indexes are clipped to the same border index (source_width - 1 or 0,
+      // respectively). Then for each x, the inner for loop of the horizontal
+      // filter is reduced to multiplying the border pixel by the sum of the
+      // filter coefficients.
+      if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) {
+        // Regions 1 and 2.
+        // Points to the left or right border of the first row of |src|.
+        const uint16_t* first_row_border =
+            (ix4 + 7 <= 0) ? src : src + source_width - 1;
+        // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+        //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+        // In two special cases, iy4 + y is clipped to either 0 or
+        // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+        // bounded and we can avoid clipping iy4 + y by relying on a reference
+        // frame's boundary extension on the top and bottom.
+        if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+          // Region 1.
+          // Every sample used to calculate the prediction block has the same
+          // value. So the whole prediction block has the same value.
+          const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+          const uint16_t row_border_pixel = first_row_border[row * src_stride];
+
+          DestType* dst_row = dst + start_x - block_start_x;
+          for (int y = 0; y < 8; ++y) {
+            if (is_compound) {
+              const int16x8_t sum =
+                  vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical -
+                                                   kRoundBitsVertical));
+              vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
+                        vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
+            } else {
+              vst1q_u16(reinterpret_cast<uint16_t*>(dst_row),
+                        vdupq_n_u16(row_border_pixel));
+            }
+            dst_row += dst_stride;
+          }
+          // End of region 1. Continue the |start_x| do-while loop.
+          start_x += 8;
+          continue;
+        }
+
+        // Region 2.
+        // Horizontal filter.
+        // The input values in this region are generated by extending the
+        // border which makes them identical in the horizontal direction. This
+        // computation could be inlined in the vertical pass but most
+        // implementations will need a transpose of some sort.
+        // It is not necessary to use the offset values here because the
+        // horizontal pass is a simple shift and the vertical pass will always
+        // require using 32 bits.
+        for (int y = -7; y < 8; ++y) {
+          // We may over-read up to 13 pixels above the top source row, or up
+          // to 13 pixels below the bottom source row. This is proved in
+          // warp.cc.
+          const int row = iy4 + y;
+          int sum = first_row_border[row * src_stride];
+          sum <<= (kFilterBits - kInterRoundBitsHorizontal);
+          intermediate_result_column[y + 7] = sum;
+        }
+        // Vertical filter.
+        DestType* dst_row = dst + start_x - block_start_x;
+        int sy4 =
+            (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+        for (int y = 0; y < 8; ++y) {
+          int sy = sy4 - MultiplyBy4(gamma);
+#if defined(__aarch64__)
+          const int16x8_t intermediate =
+              vld1q_s16(&intermediate_result_column[y]);
+          int16_t tmp[8];
+          for (int x = 0; x < 8; ++x) {
+            const int offset =
+                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]);
+            const int32x4_t product_low =
+                vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate));
+            const int32x4_t product_high =
+                vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate));
+            // vaddvq_s32 is only available on __aarch64__.
+            const int32_t sum =
+                vaddvq_s32(product_low) + vaddvq_s32(product_high);
+            const int16_t sum_descale =
+                RightShiftWithRounding(sum, kRoundBitsVertical);
+            if (is_compound) {
+              dst_row[x] = sum_descale + kCompoundOffset;
+            } else {
+              tmp[x] = sum_descale;
+            }
+            sy += gamma;
+          }
+          if (!is_compound) {
+            const uint16x8_t v_max_bitdepth =
+                vdupq_n_u16((1 << kBitdepth10) - 1);
+            const int16x8_t sum = vld1q_s16(tmp);
+            const uint16x8_t d0 = vminq_u16(
+                vreinterpretq_u16_s16(vmaxq_s16(sum, vdupq_n_s16(0))),
+                v_max_bitdepth);
+            vst1q_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
+          }
+#else   // !defined(__aarch64__)
+          int16x8_t filter[8];
+          for (int x = 0; x < 8; ++x) {
+            const int offset =
+                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            filter[x] = vld1q_s16(kWarpedFilters[offset]);
+            sy += gamma;
+          }
+          Transpose8x8(filter);
+          int32x4_t sum_low = vdupq_n_s32(0);
+          int32x4_t sum_high = sum_low;
+          for (int k = 0; k < 8; ++k) {
+            const int16_t intermediate = intermediate_result_column[y + k];
+            sum_low =
+                vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate);
+            sum_high =
+                vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate);
+          }
+          if (is_compound) {
+            const int16x8_t sum =
+                vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+                             vrshrn_n_s32(sum_high, kRoundBitsVertical));
+            vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
+                      vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
+          } else {
+            const uint16x4_t v_max_bitdepth =
+                vdup_n_u16((1 << kBitdepth10) - 1);
+            const uint16x4_t d0 = vmin_u16(
+                vqrshrun_n_s32(sum_low, kRoundBitsVertical), v_max_bitdepth);
+            const uint16x4_t d1 = vmin_u16(
+                vqrshrun_n_s32(sum_high, kRoundBitsVertical), v_max_bitdepth);
+            vst1_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
+            vst1_u16(reinterpret_cast<uint16_t*>(dst_row + 4), d1);
+          }
+#endif  // defined(__aarch64__)
+          dst_row += dst_stride;
+          sy4 += delta;
+        }
+        // End of region 2. Continue the |start_x| do-while loop.
+        start_x += 8;
+        continue;
+      }
+
+      // Regions 3 and 4.
+      // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+
+      // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+      //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+      // In two special cases, iy4 + y is clipped to either 0 or
+      // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+      // bounded and we can avoid clipping iy4 + y by relying on a reference
+      // frame's boundary extension on the top and bottom.
+      if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+        // Region 3.
+        // Horizontal filter.
+        const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+        const uint16_t* const src_row = src + row * src_stride;
+        // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+        // read but is ignored.
+        //
+        // NOTE: This may read up to 13 pixels before src_row[0] or up to 14
+        // pixels after src_row[source_width - 1]. We assume the source frame
+        // has left and right borders of at least 13 pixels that extend the
+        // frame boundary pixels. We also assume there is at least one extra
+        // padding pixel after the right border of the last source row.
+        const uint16x8x2_t src_row_v = LoadSrcRow(&src_row[ix4 - 7]);
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
+          sx4 += beta;
+        }
+      } else {
+        // Region 4.
+        // Horizontal filter.
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          // We may over-read up to 13 pixels above the top source row, or up
+          // to 13 pixels below the bottom source row. This is proved in
+          // warp.cc.
+          const int row = iy4 + y;
+          const uint16_t* const src_row = src + row * src_stride;
+          // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+          // read but is ignored.
+          //
+          // NOTE: This may read up to 13 pixels before src_row[0] or up to
+          // 14 pixels after src_row[source_width - 1]. We assume the source
+          // frame has left and right borders of at least 13 pixels that
+          // extend the frame boundary pixels. We also assume there is at
+          // least one extra padding pixel after the right border of the last
+          // source row.
+          const uint16x8x2_t src_row_v = LoadSrcRow(&src_row[ix4 - 7]);
+          HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
+          sx4 += beta;
+        }
+      }
+
+      // Regions 3 and 4.
+      // Vertical filter.
+      DestType* dst_row = dst + start_x - block_start_x;
+      int sy4 =
+          (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+      for (int y = 0; y < 8; ++y) {
+        int sy = sy4 - MultiplyBy4(gamma);
+        int16x8_t filter[8];
+        for (auto& f : filter) {
+          const int offset =
+              RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+              kWarpedPixelPrecisionShifts;
+          f = vld1q_s16(kWarpedFilters[offset]);
+          sy += gamma;
+        }
+        Transpose8x8(filter);
+        int32x4_t sum_low = vdupq_n_s32(0);
+        int32x4_t sum_high = sum_low;
+        for (int k = 0; k < 8; ++k) {
+          const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]);
+          sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]),
+                              vget_low_s16(intermediate));
+          sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]),
+                               vget_high_s16(intermediate));
+        }
+        if (is_compound) {
+          const int16x8_t sum =
+              vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+                           vrshrn_n_s32(sum_high, kRoundBitsVertical));
+          vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
+                    vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
+        } else {
+          const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+          const uint16x4_t d0 = vmin_u16(
+              vqrshrun_n_s32(sum_low, kRoundBitsVertical), v_max_bitdepth);
+          const uint16x4_t d1 = vmin_u16(
+              vqrshrun_n_s32(sum_high, kRoundBitsVertical), v_max_bitdepth);
+          vst1_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
+          vst1_u16(reinterpret_cast<uint16_t*>(dst_row + 4), d1);
+        }
+        dst_row += dst_stride;
+        sy4 += delta;
+      }
+      start_x += 8;
+    } while (start_x < block_start_x + block_width);
+    dst += 8 * dst_stride;
+    start_y += 8;
+  } while (start_y < block_start_y + block_height);
+}
+
+void Init10bpp() {
+  Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->warp = Warp_NEON</*is_compound=*/false>;
+  dsp->warp_compound = Warp_NEON</*is_compound=*/true>;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void WarpInit_NEON() {
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif
+}
 
 }  // namespace dsp
 }  // namespace libgav1
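The start of the inner do-loop above evaluates the affine warp model at the centre of each 8x8 block: warp_params[0..5] map the (possibly subsampled) source position to a fixed-point destination position, whose integer part (ix4, iy4) selects the reference window and whose fractional part, together with alpha, beta, gamma and delta, selects the per-row and per-column filters. The following scalar sketch restates that bookkeeping; it assumes kWarpedModelPrecisionBits = 16 (the libgav1 value), and the WarpAnchor struct and function name are purely illustrative, not part of the patch.

#include <cstdint>

constexpr int kWarpedModelPrecisionBits = 16;  // Assumed libgav1 value.

struct WarpAnchor {
  int ix4, iy4;  // Integer sample position of the block centre.
  int sx4, sy4;  // Fractional parts that drive kWarpedFilters selection.
};

WarpAnchor ComputeWarpAnchor(const int warp_params[6], int start_x,
                             int start_y, int subsampling_x,
                             int subsampling_y) {
  const int src_x = (start_x + 4) << subsampling_x;
  const int src_y = (start_y + 4) << subsampling_y;
  const int dst_x =
      src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0];
  const int dst_y =
      src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1];
  const int x4 = dst_x >> subsampling_x;
  const int y4 = dst_y >> subsampling_y;
  WarpAnchor a;
  a.ix4 = x4 >> kWarpedModelPrecisionBits;
  a.iy4 = y4 >> kWarpedModelPrecisionBits;
  // Fractional parts before the per-row adjustments the patch applies
  // (sx4 -= beta * 7, sy4 -= MultiplyBy4(delta)).
  a.sx4 = x4 & ((1 << kWarpedModelPrecisionBits) - 1);
  a.sy4 = y4 & ((1 << kWarpedModelPrecisionBits) - 1);
  return a;
}

The +4 in start_x + 4 and start_y + 4 is the centre of the 8x8 block; the patch then seeds sx4 with -beta * 7 and sy4 with -MultiplyBy4(delta) so the per-row loops only need to add beta and delta.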
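The region diagram and the ix4 +/- 7 and iy4 +/- 7 comparisons decide how much of the filtering can be skipped when a block's 15-sample support lies partly or fully outside the frame. A compact restatement of that decision, using the same comparisons as the patch (the function itself is hypothetical, for illustration only):

#include <cstdio>

// Region numbering follows the diagram in the diff: 4 = general case inside
// the frame, 3 = outside only vertically, 2 = outside only horizontally,
// 1 = outside both ways (constant prediction block).
int ClassifyWarpBlock(int ix4, int iy4, int source_width, int source_height) {
  const bool outside_x = ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0;
  const bool outside_y = iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0;
  if (outside_x && outside_y) return 1;
  if (outside_x) return 2;  // One border sample per row; trivial first pass.
  if (outside_y) return 3;  // All rows identical; full filtering still runs.
  return 4;
}

int main() {
  // A block centred 20 samples left of a 64x48 frame lands in region 2.
  printf("region %d\n", ClassifyWarpBlock(-20, 10, 64, 48));
  return 0;
}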
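The unrolled k = 0..7 section of the new HorizontalFilter computes, with vextq_u16 sliding windows and vmull/vmlal multiply-accumulates, the 8-tap convolution that the pseudo-code comment inside Warp_NEON describes. Below is a minimal scalar sketch of one output row; the constant values mirror libgav1's 10bpp configuration and are assumptions here, the filter table is passed as a parameter in place of kWarpedFilters, and the helper name is illustrative rather than part of the patch.

#include <cstdint>

// Assumed values mirroring libgav1's 10bpp configuration.
constexpr int kWarpedDiffPrecisionBits = 10;
constexpr int kWarpedPixelPrecisionShifts = 64;
constexpr int kInterRoundBitsHorizontal = 3;

inline int RightShiftWithRounding(int value, int bits) {
  return (value + (1 << (bits - 1))) >> bits;
}

// One row of the horizontal pass for an 8x8 warp block. |src_row_centered|
// points at src_row[ix4]; border extension/clipping is assumed to have been
// handled by the caller, as in regions 3 and 4 of the patch.
void HorizontalFilterRowScalar(const uint16_t* src_row_centered,
                               const int16_t filters[][8], int sx4,
                               int16_t alpha,
                               int16_t intermediate_result_row[8]) {
  int sx = sx4 - 4 * alpha;
  for (int x = -4; x < 4; ++x) {
    const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
                       kWarpedPixelPrecisionShifts;
    int sum = 0;
    for (int k = 0; k < 8; ++k) {
      // Same 15-sample window the NEON code loads with LoadSrcRow and then
      // slides with vextq_u16: columns ix4 + x + k - 3 for k = 0..7.
      sum += filters[offset][k] * src_row_centered[x + k - 3];
    }
    intermediate_result_row[x + 4] = static_cast<int16_t>(
        RightShiftWithRounding(sum, kInterRoundBitsHorizontal));
    sx += alpha;
  }
}

The NEON version loads the eight per-pixel filters, transposes them with Transpose8x8 and widens them with vmovl_s8, so that each tap can be applied to all eight output pixels with a single vmlal; that is why the loop nest appears inverted in the vectorized code.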
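Every vertical-filter path ends the same way: the compound variant rounds by kInterRoundBitsCompoundVertical and adds kCompoundOffset so the later blending stage works on biased 16-bit values, while the single-prediction variant rounds by kInterRoundBitsVertical and clamps to the 10-bit range (vqrshrun_n_s32 followed by vmin_u16). A scalar sketch of the two epilogues follows; the round-bit values are the usual libgav1 defaults quoted as assumptions, and kCompoundOffset is taken as a parameter rather than asserting its value.

#include <algorithm>
#include <cstdint>

constexpr int kBitdepth10 = 10;
constexpr int kInterRoundBitsVertical = 11;         // Assumed default.
constexpr int kInterRoundBitsCompoundVertical = 7;  // Assumed default.

inline int RightShiftWithRounding(int32_t value, int bits) {
  return static_cast<int>((value + (1 << (bits - 1))) >> bits);
}

// Non-compound output: round and clamp to [0, 1023], matching the
// vqrshrun_n_s32 + vmin_u16 sequence in the patch.
uint16_t FinalizeSingle(int32_t vertical_sum) {
  const int rounded =
      RightShiftWithRounding(vertical_sum, kInterRoundBitsVertical);
  return static_cast<uint16_t>(std::clamp(rounded, 0, (1 << kBitdepth10) - 1));
}

// Compound output: keep more precision and add the bias (libgav1's
// kCompoundOffset) that the later compound averaging stage expects.
int16_t FinalizeCompound(int32_t vertical_sum, int compound_offset) {
  return static_cast<int16_t>(
      RightShiftWithRounding(vertical_sum, kInterRoundBitsCompoundVertical) +
      compound_offset);
}

Keeping the compound result at higher precision is what allows the blending stage to average or weight two predictions before the final rounding, instead of rounding each prediction to pixel precision first.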