about | summary | refs | log | tree | commit | diff
path: root/src/dsp/inverse_transform.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/dsp/inverse_transform.cc')
-rw-r--r-- src/dsp/inverse_transform.cc 267
1 files changed, 235 insertions, 32 deletions
diff --git a/src/dsp/inverse_transform.cc b/src/dsp/inverse_transform.cc
index 1b0064f..0bbdffa 100644
--- a/src/dsp/inverse_transform.cc
+++ b/src/dsp/inverse_transform.cc
@@ -18,6 +18,7 @@
#include <cassert>
#include <cstdint>
#include <cstring>
+#include <type_traits>
#include "src/dsp/dsp.h"
#include "src/utils/array_2d.h"
@@ -25,6 +26,15 @@
#include "src/utils/compiler_attributes.h"
#include "src/utils/logging.h"
+#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION)
+#undef LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
+#endif
+
+#if defined(LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK) && \
+ LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
+#include <cinttypes>
+#endif
+
namespace libgav1 {
namespace dsp {
namespace {
@@ -34,24 +44,25 @@ namespace {
constexpr uint8_t kTransformColumnShift = 4;
-#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION)
-#undef LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
-#endif
-
-int32_t RangeCheckValue(int32_t value, int8_t range) {
+template <typename T>
+int32_t RangeCheckValue(T value, int8_t range) {
#if defined(LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK) && \
LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
+ static_assert(
+ std::is_same<T, int32_t>::value || std::is_same<T, std::int64_t>::value,
+ "");
assert(range <= 32);
const auto min = static_cast<int32_t>(-(uint32_t{1} << (range - 1)));
const auto max = static_cast<int32_t>((uint32_t{1} << (range - 1)) - 1);
if (min > value || value > max) {
- LIBGAV1_DLOG(ERROR, "coeff out of bit range, value: %d bit range %d\n",
- value, range);
+ LIBGAV1_DLOG(ERROR,
+ "coeff out of bit range, value: %" PRId64 " bit range %d",
+ static_cast<int64_t>(value), range);
assert(min <= value && value <= max);
}
#endif // LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
static_cast<void>(range);
- return value;
+ return static_cast<int32_t>(value);
}
template <typename Residual>
@@ -433,7 +444,13 @@ void Adst4_C(void* dest, int8_t range) {
// Section 7.13.2.6: It is a requirement of bitstream conformance that all
// values stored in the s and x arrays by this process are representable by
// a signed integer using range + 12 bits of precision.
- int32_t s[7];
+ // Note the intermediate value can only exceed INT32_MAX with invalid 12-bit
+ // content. For simplicity in unoptimized code, int64_t is used for both 10 &
+ // 12-bit. SIMD implementations can allow these to roll over on platforms
+ // where this has defined behavior.
+ using Intermediate =
+ typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type;
+ Intermediate s[7];
s[0] = RangeCheckValue(kAdst4Multiplier[0] * dst[0], range + 12);
s[1] = RangeCheckValue(kAdst4Multiplier[1] * dst[0], range + 12);
s[2] = RangeCheckValue(kAdst4Multiplier[2] * dst[1], range + 12);
@@ -454,19 +471,23 @@ void Adst4_C(void* dest, int8_t range) {
s[0] = RangeCheckValue(s[0] + s[3], range + 12);
s[1] = RangeCheckValue(s[1] - s[4], range + 12);
s[3] = s[2];
- s[2] = RangeCheckValue(kAdst4Multiplier[2] * b7, range + 12);
+ // With range checking enabled, b7 would be trapped above. This prevents an
+ // integer sanitizer warning. In SIMD implementations the multiply can be
+ // allowed to roll over on platforms where this has defined behavior.
+ const auto adst2_b7 = static_cast<Intermediate>(kAdst4Multiplier[2]) * b7;
+ s[2] = RangeCheckValue(adst2_b7, range + 12);
// stage 4.
s[0] = RangeCheckValue(s[0] + s[5], range + 12);
s[1] = RangeCheckValue(s[1] - s[6], range + 12);
// stages 5 and 6.
- const int32_t x0 = RangeCheckValue(s[0] + s[3], range + 12);
- const int32_t x1 = RangeCheckValue(s[1] + s[3], range + 12);
- int32_t x3 = RangeCheckValue(s[0] + s[1], range + 12);
+ const Intermediate x0 = RangeCheckValue(s[0] + s[3], range + 12);
+ const Intermediate x1 = RangeCheckValue(s[1] + s[3], range + 12);
+ Intermediate x3 = RangeCheckValue(s[0] + s[1], range + 12);
x3 = RangeCheckValue(x3 - s[3], range + 12);
- int32_t dst_0 = RightShiftWithRounding(x0, 12);
- int32_t dst_1 = RightShiftWithRounding(x1, 12);
- int32_t dst_2 = RightShiftWithRounding(s[2], 12);
- int32_t dst_3 = RightShiftWithRounding(x3, 12);
+ auto dst_0 = static_cast<int32_t>(RightShiftWithRounding(x0, 12));
+ auto dst_1 = static_cast<int32_t>(RightShiftWithRounding(x1, 12));
+ auto dst_2 = static_cast<int32_t>(RightShiftWithRounding(s[2], 12));
+ auto dst_3 = static_cast<int32_t>(RightShiftWithRounding(x3, 12));
if (sizeof(Residual) == 2) {
// If the first argument to RightShiftWithRounding(..., 12) is only
// slightly smaller than 2^27 - 1 (e.g., 0x7fffe4e), adding 2^11 to it
@@ -840,6 +861,10 @@ void Adst16DcOnly_C(void* dest, int8_t range, bool should_round, int row_shift,
template <typename Residual>
void Identity4Row_C(void* dest, int8_t shift) {
+ // Note the intermediate value can only exceed 32 bits with 12-bit content.
+ // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit.
+ using Intermediate =
+ typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type;
assert(shift == 0 || shift == 1);
auto* const dst = static_cast<Residual*>(dest);
// If |shift| is 0, |rounding| should be 1 << 11. If |shift| is 1, |rounding|
@@ -847,10 +872,10 @@ void Identity4Row_C(void* dest, int8_t shift) {
// values of |shift|.
const int32_t rounding = (1 + (shift << 1)) << 11;
for (int i = 0; i < 4; ++i) {
- // The intermediate value here will have to fit into an int32_t for it to be
- // bitstream conformant. The multiplication is promoted to int32_t by
- // defining kIdentity4Multiplier as int32_t.
- int32_t dst_i = (dst[i] * kIdentity4Multiplier + rounding) >> (12 + shift);
+ const auto intermediate =
+ static_cast<Intermediate>(dst[i]) * kIdentity4Multiplier;
+ int32_t dst_i =
+ static_cast<int32_t>((intermediate + rounding) >> (12 + shift));
if (sizeof(Residual) == 2) {
dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
}
@@ -874,16 +899,24 @@ void Identity4Column_C(void* dest, int8_t /*shift*/) {
template <int bitdepth, typename Residual>
void Identity4DcOnly_C(void* dest, int8_t /*range*/, bool should_round,
int row_shift, bool is_row) {
+ // Note the intermediate value can only exceed 32 bits with 12-bit content.
+ // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit.
+ using Intermediate =
+ typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type;
auto* const dst = static_cast<Residual*>(dest);
if (is_row) {
if (should_round) {
- dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+ const auto intermediate =
+ static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier;
+ dst[0] = RightShiftWithRounding(intermediate, 12);
}
const int32_t rounding = (1 + (row_shift << 1)) << 11;
+ const auto intermediate =
+ static_cast<Intermediate>(dst[0]) * kIdentity4Multiplier;
int32_t dst_i =
- (dst[0] * kIdentity4Multiplier + rounding) >> (12 + row_shift);
+ static_cast<int32_t>((intermediate + rounding) >> (12 + row_shift));
if (sizeof(Residual) == 2) {
dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
}
@@ -923,11 +956,17 @@ void Identity8Column_C(void* dest, int8_t /*shift*/) {
template <int bitdepth, typename Residual>
void Identity8DcOnly_C(void* dest, int8_t /*range*/, bool should_round,
int row_shift, bool is_row) {
+ // Note the intermediate value can only exceed 32 bits with 12-bit content.
+ // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit.
+ using Intermediate =
+ typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type;
auto* const dst = static_cast<Residual*>(dest);
if (is_row) {
if (should_round) {
- dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+ const auto intermediate =
+ static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier;
+ dst[0] = RightShiftWithRounding(intermediate, 12);
}
int32_t dst_i = RightShiftWithRounding(MultiplyBy2(dst[0]), row_shift);
@@ -954,13 +993,19 @@ void Identity8DcOnly_C(void* dest, int8_t /*range*/, bool should_round,
template <typename Residual>
void Identity16Row_C(void* dest, int8_t shift) {
assert(shift == 1 || shift == 2);
+ // Note the intermediate value can only exceed 32 bits with 12-bit content.
+ // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit.
+ using Intermediate =
+ typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type;
auto* const dst = static_cast<Residual*>(dest);
const int32_t rounding = (1 + (1 << shift)) << 11;
for (int i = 0; i < 16; ++i) {
- // The intermediate value here will have to fit into an int32_t for it to be
- // bitstream conformant. The multiplication is promoted to int32_t by
- // defining kIdentity16Multiplier as int32_t.
- int32_t dst_i = (dst[i] * kIdentity16Multiplier + rounding) >> (12 + shift);
+ // Note the intermediate value can only exceed 32 bits with 12-bit content.
+ // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit.
+ const auto intermediate =
+ static_cast<Intermediate>(dst[i]) * kIdentity16Multiplier;
+ int32_t dst_i =
+ static_cast<int32_t>((intermediate + rounding) >> (12 + shift));
if (sizeof(Residual) == 2) {
dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
}
@@ -985,16 +1030,24 @@ void Identity16Column_C(void* dest, int8_t /*shift*/) {
template <int bitdepth, typename Residual>
void Identity16DcOnly_C(void* dest, int8_t /*range*/, bool should_round,
int row_shift, bool is_row) {
+ // Note the intermediate value can only exceed 32 bits with 12-bit content.
+ // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit.
+ using Intermediate =
+ typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type;
auto* const dst = static_cast<Residual*>(dest);
if (is_row) {
if (should_round) {
- dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+ const auto intermediate =
+ static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier;
+ dst[0] = RightShiftWithRounding(intermediate, 12);
}
const int32_t rounding = (1 + (1 << row_shift)) << 11;
+ const auto intermediate =
+ static_cast<Intermediate>(dst[0]) * kIdentity16Multiplier;
int32_t dst_i =
- (dst[0] * kIdentity16Multiplier + rounding) >> (12 + row_shift);
+ static_cast<int32_t>((intermediate + rounding) >> (12 + row_shift));
if (sizeof(Residual) == 2) {
dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
}
@@ -1034,11 +1087,17 @@ void Identity32Column_C(void* dest, int8_t /*shift*/) {
template <int bitdepth, typename Residual>
void Identity32DcOnly_C(void* dest, int8_t /*range*/, bool should_round,
int row_shift, bool is_row) {
+ // Note the intermediate value can only exceed 32 bits with 12-bit content.
+ // For simplicity in unoptimized code, int64_t is used for both 10 & 12-bit.
+ using Intermediate =
+ typename std::conditional<sizeof(Residual) == 2, int32_t, int64_t>::type;
auto* const dst = static_cast<Residual*>(dest);
if (is_row) {
if (should_round) {
- dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+ const auto intermediate =
+ static_cast<Intermediate>(dst[0]) * kTransformRowMultiplier;
+ dst[0] = RightShiftWithRounding(intermediate, 12);
}
int32_t dst_i = RightShiftWithRounding(MultiplyBy4(dst[0]), row_shift);
@@ -1612,6 +1671,148 @@ void Init10bpp() {
}
#endif // LIBGAV1_MAX_BITDEPTH >= 10
+#if LIBGAV1_MAX_BITDEPTH == 12
+void Init12bpp() {
+ Dsp* const dsp = dsp_internal::GetWritableDspTable(12);
+ assert(dsp != nullptr);
+ static_cast<void>(dsp);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+ InitAll<12, int32_t, uint16_t>(dsp);
+#else // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dDct
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 2>, Dct_C<int32_t, 2>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 2>, Dct_C<int32_t, 2>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize8_Transform1dDct
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 3>, Dct_C<int32_t, 3>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 3>, Dct_C<int32_t, 3>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize16_Transform1dDct
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 4>, Dct_C<int32_t, 4>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 4>, Dct_C<int32_t, 4>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize32_Transform1dDct
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 5>, Dct_C<int32_t, 5>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 5>, Dct_C<int32_t, 5>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize64_Transform1dDct
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 6>, Dct_C<int32_t, 6>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dDct,
+ DctDcOnly_C<12, int32_t, 6>, Dct_C<int32_t, 6>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dAdst
+ dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst,
+ Adst4DcOnly_C<12, int32_t>, Adst4_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst,
+ Adst4DcOnly_C<12, int32_t>, Adst4_C<int32_t>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize8_Transform1dAdst
+ dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst,
+ Adst8DcOnly_C<12, int32_t>, Adst8_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst,
+ Adst8DcOnly_C<12, int32_t>, Adst8_C<int32_t>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize16_Transform1dAdst
+ dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst,
+ Adst16DcOnly_C<12, int32_t>, Adst16_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dAdst,
+ Adst16DcOnly_C<12, int32_t>, Adst16_C<int32_t>,
+ /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dIdentity
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity4DcOnly_C<12, int32_t>, Identity4Row_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity4DcOnly_C<12, int32_t>,
+ Identity4Column_C<int32_t>, /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize8_Transform1dIdentity
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity8DcOnly_C<12, int32_t>, Identity8Row_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity8DcOnly_C<12, int32_t>,
+ Identity8Column_C<int32_t>, /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize16_Transform1dIdentity
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity16DcOnly_C<12, int32_t>, Identity16Row_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity16DcOnly_C<12, int32_t>,
+ Identity16Column_C<int32_t>, /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize32_Transform1dIdentity
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity32DcOnly_C<12, int32_t>, Identity32Row_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dIdentity,
+ Identity32DcOnly_C<12, int32_t>,
+ Identity32Column_C<int32_t>, /*is_row=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp12bpp_Transform1dSize4_Transform1dWht
+ dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dWht,
+ Wht4DcOnly_C<12, int32_t>, Wht4_C<int32_t>,
+ /*is_row=*/true>;
+ dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
+ TransformLoop_C<12, int32_t, uint16_t, kTransform1dWht,
+ Wht4DcOnly_C<12, int32_t>, Wht4_C<int32_t>,
+ /*is_row=*/false>;
+#endif
+#endif // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+#endif // LIBGAV1_MAX_BITDEPTH == 12
+
} // namespace
void InverseTransformInit_C() {
@@ -1619,10 +1820,12 @@ void InverseTransformInit_C() {
#if LIBGAV1_MAX_BITDEPTH >= 10
Init10bpp();
#endif
+#if LIBGAV1_MAX_BITDEPTH == 12
+ Init12bpp();
+#endif
// Local functions that may be unused depending on the optimizations
// available.
- static_cast<void>(RangeCheckValue);
static_cast<void>(kBitReverseLookup);
}