diff options
Diffstat (limited to 'src/dsp/inverse_transform.cc')
-rw-r--r-- | src/dsp/inverse_transform.cc | 267 |
1 files changed, 235 insertions, 32 deletions
diff --git a/src/dsp/inverse_transform.cc b/src/dsp/inverse_transform.cc index 1b0064f..0bbdffa 100644 --- a/src/dsp/inverse_transform.cc +++ b/src/dsp/inverse_transform.cc @@ -18,6 +18,7 @@ #include <cassert> #include <cstdint> #include <cstring> +#include <type_traits> #include "src/dsp/dsp.h" #include "src/utils/array_2d.h" @@ -25,6 +26,15 @@ #include "src/utils/compiler_attributes.h" #include "src/utils/logging.h" +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) +#undef LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK +#endif + +#if defined(LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK) && \ + LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK +#include <cinttypes> +#endif + namespace libgav1 { namespace dsp { namespace { @@ -34,24 +44,25 @@ namespace { constexpr uint8_t kTransformColumnShift = 4; -#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) -#undef LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK -#endif - -int32_t RangeCheckValue(int32_t value, int8_t range) { +template <typename T> +int32_t RangeCheckValue(T value, int8_t range) { #if defined(LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK) && \ LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK + static_assert( + std::is_same<T, int32_t>::value || std::is_same<T, std::int64_t>::value, + ""); assert(range <= 32); const auto min = static_cast<int32_t>(-(uint32_t{1} << (range - 1))); const auto max = static_cast<int32_t>((uint32_t{1} << (range - 1)) - 1); if (min > value || value > max) { - LIBGAV1_DLOG(ERROR, "coeff out of bit range, value: %d bit range %d\n", - value, range); + LIBGAV1_DLOG(ERROR, + "coeff out of bit range, value: %" PRId64 " bit range %d", + static_cast<int64_t>(value), range); assert(min <= value && value <= max); } #endif // LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK static_cast<void>(range); - return value; + return static_cast<int32_t>(value); } template <typename Residual> @@ -433,7 +444,13 @@ void Adst4_C(void* dest, int8_t range) { // Section 7.13.2.6: It is a requirement of bitstream conformance that all // values stored in the s and x arrays 
by this process are representable by // a signed integer using range + 12 bits of precision. - int32_t s[7]; + // Note the intermediate value can only exceed INT32_MAX with invalid 12-bit + // content. For simplicity in unoptimized code, int64_t is used for both 10 & + // 12-bit. SIMD implementations can allow these to rollover on platforms + // where this has defined behavior. + using Intermediate = + typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type; + Intermediate s[7]; s[0] = RangeCheckValue(kAdst4Multiplier[0] * dst[0], range + 12); s[1] = RangeCheckValue(kAdst4Multiplier[1] * dst[0], range + 12); s[2] = RangeCheckValue(kAdst4Multiplier[2] * dst[1], range + 12); @@ -454,19 +471,23 @@ void Adst4_C(void* dest, int8_t range) { s[0] = RangeCheckValue(s[0] + s[3], range + 12); s[1] = RangeCheckValue(s[1] - s[4], range + 12); s[3] = s[2]; - s[2] = RangeCheckValue(kAdst4Multiplier[2] * b7, range + 12); + // With range checking enabled b7 would be trapped above. This prevents an + // integer sanitizer warning. In SIMD implementations the multiply can be + // allowed to rollover on platforms where this has defined behavior. + const auto adst2_b7 = static_cast<Intermediate>(kAdst4Multiplier[2]) * b7; + s[2] = RangeCheckValue(adst2_b7, range + 12); // stage 4. s[0] = RangeCheckValue(s[0] + s[5], range + 12); s[1] = RangeCheckValue(s[1] - s[6], range + 12); // stages 5 and 6. 
- const int32_t x0 = RangeCheckValue(s[0] + s[3], range + 12); - const int32_t x1 = RangeCheckValue(s[1] + s[3], range + 12); - int32_t x3 = RangeCheckValue(s[0] + s[1], range + 12); + const Intermediate x0 = RangeCheckValue(s[0] + s[3], range + 12); + const Intermediate x1 = RangeCheckValue(s[1] + s[3], range + 12); + Intermediate x3 = RangeCheckValue(s[0] + s[1], range + 12); x3 = RangeCheckValue(x3 - s[3], range + 12); - int32_t dst_0 = RightShiftWithRounding(x0, 12); - int32_t dst_1 = RightShiftWithRounding(x1, 12); - int32_t dst_2 = RightShiftWithRounding(s[2], 12); - int32_t dst_3 = RightShiftWithRounding(x3, 12); + auto dst_0 = static_cast<int32_t>(RightShiftWithRounding(x0, 12)); + auto dst_1 = static_cast<int32_t>(RightShiftWithRounding(x1, 12)); + auto dst_2 = static_cast<int32_t>(RightShiftWithRounding(s[2], 12)); + auto dst_3 = static_cast<int32_t>(RightShiftWithRounding(x3, 12)); if (sizeof(Residual) == 2) { // If the first argument to RightShiftWithRounding(..., 12) is only // slightly smaller than 2^27 - 1 (e.g., 0x7fffe4e), adding 2^11 to it @@ -840,6 +861,10 @@ void Adst16DcOnly_C(void* dest, int8_t range, bool should_round, int row_shift, template <typename Residual> void Identity4Row_C(void* dest, int8_t shift) { + // Note the intermediate value can only exceed 32 bits with 12-bit content. + // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit. + using Intermediate = + typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type; assert(shift == 0 || shift == 1); auto* const dst = static_cast<Residual*>(dest); // If |shift| is 0, |rounding| should be 1 << 11. If |shift| is 1, |rounding| @@ -847,10 +872,10 @@ void Identity4Row_C(void* dest, int8_t shift) { // values of |shift|. const int32_t rounding = (1 + (shift << 1)) << 11; for (int i = 0; i < 4; ++i) { - // The intermediate value here will have to fit into an int32_t for it to be - // bitstream conformant. 
The multiplication is promoted to int32_t by - // defining kIdentity4Multiplier as int32_t. - int32_t dst_i = (dst[i] * kIdentity4Multiplier + rounding) >> (12 + shift); + const auto intermediate = + static_cast<Intermediate>(dst[i]) * kIdentity4Multiplier; + int32_t dst_i = + static_cast<int32_t>((intermediate + rounding) >> (12 + shift)); if (sizeof(Residual) == 2) { dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); } @@ -874,16 +899,24 @@ void Identity4Column_C(void* dest, int8_t /*shift*/) { template <int bitdepth, typename Residual> void Identity4DcOnly_C(void* dest, int8_t /*range*/, bool should_round, int row_shift, bool is_row) { + // Note the intermediate value can only exceed 32 bits with 12-bit content. + // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit. + using Intermediate = + typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type; auto* const dst = static_cast<Residual*>(dest); if (is_row) { if (should_round) { - dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + const auto intermediate = + static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier; + dst[0] = RightShiftWithRounding(intermediate, 12); } const int32_t rounding = (1 + (row_shift << 1)) << 11; + const auto intermediate = + static_cast<Intermediate>(dst[0]) * kIdentity4Multiplier; int32_t dst_i = - (dst[0] * kIdentity4Multiplier + rounding) >> (12 + row_shift); + static_cast<int32_t>((intermediate + rounding) >> (12 + row_shift)); if (sizeof(Residual) == 2) { dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); } @@ -923,11 +956,17 @@ void Identity8Column_C(void* dest, int8_t /*shift*/) { template <int bitdepth, typename Residual> void Identity8DcOnly_C(void* dest, int8_t /*range*/, bool should_round, int row_shift, bool is_row) { + // Note the intermediate value can only exceed 32 bits with 12-bit content. + // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit. 
+ using Intermediate = + typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type; auto* const dst = static_cast<Residual*>(dest); if (is_row) { if (should_round) { - dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + const auto intermediate = + static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier; + dst[0] = RightShiftWithRounding(intermediate, 12); } int32_t dst_i = RightShiftWithRounding(MultiplyBy2(dst[0]), row_shift); @@ -954,13 +993,19 @@ void Identity8DcOnly_C(void* dest, int8_t /*range*/, bool should_round, template <typename Residual> void Identity16Row_C(void* dest, int8_t shift) { assert(shift == 1 || shift == 2); + // Note the intermediate value can only exceed 32 bits with 12-bit content. + // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit. + using Intermediate = + typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type; auto* const dst = static_cast<Residual*>(dest); const int32_t rounding = (1 + (1 << shift)) << 11; for (int i = 0; i < 16; ++i) { - // The intermediate value here will have to fit into an int32_t for it to be - // bitstream conformant. The multiplication is promoted to int32_t by - // defining kIdentity16Multiplier as int32_t. - int32_t dst_i = (dst[i] * kIdentity16Multiplier + rounding) >> (12 + shift); + // Note the intermediate value can only exceed 32 bits with 12-bit content. + // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit. 
+ const auto intermediate = + static_cast<Intermediate>(dst[i]) * kIdentity16Multiplier; + int32_t dst_i = + static_cast<int32_t>((intermediate + rounding) >> (12 + shift)); if (sizeof(Residual) == 2) { dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); } @@ -985,16 +1030,24 @@ void Identity16Column_C(void* dest, int8_t /*shift*/) { template <int bitdepth, typename Residual> void Identity16DcOnly_C(void* dest, int8_t /*range*/, bool should_round, int row_shift, bool is_row) { + // Note the intermediate value can only exceed 32 bits with 12-bit content. + // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit. + using Intermediate = + typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type; auto* const dst = static_cast<Residual*>(dest); if (is_row) { if (should_round) { - dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + const auto intermediate = + static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier; + dst[0] = RightShiftWithRounding(intermediate, 12); } const int32_t rounding = (1 + (1 << row_shift)) << 11; + const auto intermediate = + static_cast<Intermediate>(dst[0]) * kIdentity16Multiplier; int32_t dst_i = - (dst[0] * kIdentity16Multiplier + rounding) >> (12 + row_shift); + static_cast<int32_t>((intermediate + rounding) >> (12 + row_shift)); if (sizeof(Residual) == 2) { dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX); } @@ -1034,11 +1087,17 @@ void Identity32Column_C(void* dest, int8_t /*shift*/) { template <int bitdepth, typename Residual> void Identity32DcOnly_C(void* dest, int8_t /*range*/, bool should_round, int row_shift, bool is_row) { + // Note the intermediate value can only exceed 32 bits with 12-bit content. + // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit. 
+ using Intermediate = + typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type; auto* const dst = static_cast<Residual*>(dest); if (is_row) { if (should_round) { - dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12); + const auto intermediate = + static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier; + dst[0] = RightShiftWithRounding(intermediate, 12); } int32_t dst_i = RightShiftWithRounding(MultiplyBy4(dst[0]), row_shift); @@ -1612,6 +1671,148 @@ void Init10bpp() { } #endif // LIBGAV1_MAX_BITDEPTH >= 10 +#if LIBGAV1_MAX_BITDEPTH == 12 +void Init12bpp() { + Dsp* const dsp = dsp_internal::GetWritableDspTable(12); + assert(dsp != nullptr); + static_cast<void>(dsp); +#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS + InitAll<12, int32_t, uint16_t>(dsp); +#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dDct + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 2>, Dct_C<int32_t, 2>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 2>, Dct_C<int32_t, 2>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize8_Transform1dDct + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 3>, Dct_C<int32_t, 3>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 3>, Dct_C<int32_t, 3>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize16_Transform1dDct + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 4>, 
Dct_C<int32_t, 4>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 4>, Dct_C<int32_t, 4>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize32_Transform1dDct + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 5>, Dct_C<int32_t, 5>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 5>, Dct_C<int32_t, 5>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize64_Transform1dDct + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 6>, Dct_C<int32_t, 6>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct, + DctDcOnly_C<12, int32_t, 6>, Dct_C<int32_t, 6>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dAdst + dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst, + Adst4DcOnly_C<12, int32_t>, Adst4_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst, + Adst4DcOnly_C<12, int32_t>, Adst4_C<int32_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize8_Transform1dAdst + dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst, + Adst8DcOnly_C<12, int32_t>, Adst8_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] = + 
TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst, + Adst8DcOnly_C<12, int32_t>, Adst8_C<int32_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize16_Transform1dAdst + dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst, + Adst16DcOnly_C<12, int32_t>, Adst16_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst, + Adst16DcOnly_C<12, int32_t>, Adst16_C<int32_t>, + /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dIdentity + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity4DcOnly_C<12, int32_t>, Identity4Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity4DcOnly_C<12, int32_t>, + Identity4Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize8_Transform1dIdentity + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity8DcOnly_C<12, int32_t>, Identity8Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity8DcOnly_C<12, int32_t>, + Identity8Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize16_Transform1dIdentity + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity16DcOnly_C<12, int32_t>, Identity16Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] = + 
TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity16DcOnly_C<12, int32_t>, + Identity16Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize32_Transform1dIdentity + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity32DcOnly_C<12, int32_t>, Identity32Row_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity, + Identity32DcOnly_C<12, int32_t>, + Identity32Column_C<int32_t>, /*is_row=*/false>; +#endif +#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dWht + dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dWht, + Wht4DcOnly_C<12, int32_t>, Wht4_C<int32_t>, + /*is_row=*/true>; + dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] = + TransformLoop_C<12, int32_t, uint16_t, kTransform1dWht, + Wht4DcOnly_C<12, int32_t>, Wht4_C<int32_t>, + /*is_row=*/false>; +#endif +#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS +} +#endif // LIBGAV1_MAX_BITDEPTH == 12 + } // namespace void InverseTransformInit_C() { @@ -1619,10 +1820,12 @@ void InverseTransformInit_C() { #if LIBGAV1_MAX_BITDEPTH >= 10 Init10bpp(); #endif +#if LIBGAV1_MAX_BITDEPTH == 12 + Init12bpp(); +#endif // Local functions that may be unused depending on the optimizations // available. - static_cast<void>(RangeCheckValue); static_cast<void>(kBitReverseLookup); } |