diff options
Diffstat (limited to 'src/utils/common_test.cc')
-rw-r--r-- | src/utils/common_test.cc | 604 |
1 files changed, 604 insertions, 0 deletions
diff --git a/src/utils/common_test.cc b/src/utils/common_test.cc new file mode 100644 index 0000000..fdb218d --- /dev/null +++ b/src/utils/common_test.cc @@ -0,0 +1,604 @@ +// Copyright 2021 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/utils/common.h" + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <string> + +#include "absl/base/macros.h" +#include "gtest/gtest.h" +#include "src/utils/constants.h" +#include "src/utils/memory.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace { + +int BitLength(int64_t n) { + int count = 0; + while (n != 0) { + ++count; + n >>= 1; + } + return count; +} + +TEST(CommonUtilsTest, Align) { + for (int i = 0; i <= 8; ++i) { + const int alignment = 1 << i; + SCOPED_TRACE("alignment: " + std::to_string(alignment)); + EXPECT_EQ(Align(0, alignment), 0); + EXPECT_EQ(Align(1, alignment), alignment); + EXPECT_EQ(Align(alignment + 1, alignment), 2 * alignment); + if (i > 1) { + EXPECT_EQ(Align(alignment - 1, alignment), alignment); + EXPECT_EQ(Align(2 * alignment - 1, alignment), 2 * alignment); + } + } +} + +TEST(CommonUtilsTest, AlignAddr) { + auto buf = MakeAlignedUniquePtr<uint8_t>(/*alignment=*/1024, 512); + ASSERT_NE(buf, nullptr); + auto* const bufptr = buf.get(); + ASSERT_EQ(reinterpret_cast<uintptr_t>(bufptr) % 1024, 0); + + for (int i = 0; i <= 8; ++i) { + const int alignment = 1 << i; + ASSERT_LE(alignment, 1024); + SCOPED_TRACE("alignment: " + std::to_string(alignment)); + EXPECT_EQ(AlignAddr(nullptr, alignment), nullptr); + EXPECT_EQ(AlignAddr(bufptr, alignment), bufptr); + EXPECT_EQ(AlignAddr(bufptr + 1, alignment), bufptr + alignment); + EXPECT_EQ(AlignAddr(bufptr + alignment + 1, alignment), + bufptr + 2 * alignment); + if (i > 1) { + EXPECT_EQ(AlignAddr(bufptr + alignment - 1, alignment), + bufptr + alignment); + EXPECT_EQ(AlignAddr(bufptr + 2 * alignment - 1, alignment), + bufptr + 2 * alignment); + } + } +} + +TEST(CommonUtilsTest, Clip3) { + // Value <= lower boundary. + EXPECT_EQ(Clip3(10, 20, 30), 20); + EXPECT_EQ(Clip3(20, 20, 30), 20); + // Value >= higher boundary. + EXPECT_EQ(Clip3(40, 20, 30), 30); + EXPECT_EQ(Clip3(30, 20, 30), 30); + // Value within boundary. + EXPECT_EQ(Clip3(25, 20, 30), 25); + // Clipping based on bitdepth (clamp between 0 and 2^bitdepth - 1). Make sure + // that the resulting values are always in the pixel range for the + // corresponding bitdepth. + static constexpr int bitdepths[] = {8, 10, 12}; + static constexpr int pixels[] = {100, 500, 5000, -100, -500, -5000}; + for (const auto& bitdepth : bitdepths) { + for (const auto& pixel : pixels) { + const int clipped_pixel = Clip3(pixel, 0, (1 << bitdepth) - 1); + EXPECT_GE(clipped_pixel, 0) + << "Clip3 mismatch for bitdepth: " << bitdepth << " pixel: " << pixel; + EXPECT_LE(clipped_pixel, (1 << bitdepth) - 1) + << "Clip3 mismatch for bitdepth: " << bitdepth << " pixel: " << pixel; + } + } +} + +template <typename Pixel> +void TestExtendLine(int width, const int left, int right, Pixel left_value, + Pixel right_value) { + constexpr int size = 1000; + ASSERT_LE(width + left + right, size); + Pixel line[size]; + Pixel* line_start = line + left; + line_start[0] = left_value; + line_start[width - 1] = right_value; + ExtendLine<Pixel>(line_start, width, left, right); + for (int x = 0; x < left; x++) { + EXPECT_EQ(left_value, line[x]) << "Left side mismatch at x: " << x; + } + for (int x = 0; x < right; x++) { + EXPECT_EQ(right_value, line[left + width + x]) + << "Right side mismatch at x: " << x; + } +} + +TEST(CommonUtilsTest, ExtendLine) { + TestExtendLine<uint8_t>(300, 0, 0, 31, 13); + TestExtendLine<uint8_t>(100, 10, 20, 31, 13); + TestExtendLine<uint8_t>(257, 31, 77, 59, 255); + TestExtendLine<uint16_t>(600, 0, 0, 1234, 4321); + TestExtendLine<uint16_t>(200, 55, 88, 12345, 54321); + TestExtendLine<uint16_t>(2, 99, 333, 257, 513); +} + +template <typename T> +void TestMemSetBlock(int rows, int columns, ptrdiff_t stride, T value) { + constexpr int size = 1000; + T block[size]; + static_assert(sizeof(T) == 1, ""); + ASSERT_LE(rows * stride, size); + ASSERT_LE(columns, stride); + MemSetBlock<T>(rows, columns, value, block, stride); + for (int y = 0; y < rows; y++) { + for (int x = 0; x < columns; x++) { + EXPECT_EQ(value, block[y * stride + x]) + << "Mismatch at y: " << y << " x: " << x; + } + } +} + +TEST(CommonUtilsTest, MemSetBlock) { + TestMemSetBlock<bool>(15, 28, 29, true); + TestMemSetBlock<bool>(17, 1, 24, false); + TestMemSetBlock<bool>(7, 2, 13, true); + TestMemSetBlock<int8_t>(35, 17, 19, 123); + TestMemSetBlock<uint8_t>(19, 16, 16, 234); +} + +template <typename T> +void TestSetBlock(int rows, int columns, ptrdiff_t stride, T value) { + constexpr int size = 1000; + T block[size]; + ASSERT_LE(rows * stride, size); + ASSERT_LE(columns, stride); + SetBlock<T>(rows, columns, value, block, stride); + for (int y = 0; y < rows; y++) { + for (int x = 0; x < columns; x++) { + EXPECT_EQ(value, block[y * stride + x]) + << "Mismatch at y: " << y << " x: " << x; + } + } +} + +TEST(CommonUtilsTest, SetBlock) { + // Test 1-byte block set. + TestSetBlock<bool>(15, 28, 29, true); + TestSetBlock<bool>(17, 1, 24, false); + TestSetBlock<bool>(7, 2, 13, true); + TestSetBlock<int8_t>(35, 17, 19, 123); + TestSetBlock<uint8_t>(19, 16, 16, 234); + // Test 2-byte block set. + TestSetBlock<int16_t>(23, 27, 28, 1234); + TestSetBlock<uint16_t>(13, 39, 44, 4321); + // Test 4-byte block set. + TestSetBlock<int>(14, 7, 7, 12345); + TestSetBlock<int>(33, 4, 15, 54321); + // Test pointer block set. + int data; + TestSetBlock<int*>(23, 8, 25, &data); +} + +TEST(CommonUtilsTest, CountTrailingZeros) { + EXPECT_EQ(CountTrailingZeros(0x1), 0); + EXPECT_EQ(CountTrailingZeros(0x3), 0); + EXPECT_EQ(CountTrailingZeros(0x7), 0); + EXPECT_EQ(CountTrailingZeros(0xF), 0); + EXPECT_EQ(CountTrailingZeros(0x2), 1); + EXPECT_EQ(CountTrailingZeros(0x6), 1); + EXPECT_EQ(CountTrailingZeros(0xE), 1); + EXPECT_EQ(CountTrailingZeros(0x4), 2); + EXPECT_EQ(CountTrailingZeros(0xC), 2); + EXPECT_EQ(CountTrailingZeros(0x8), 3); + EXPECT_EQ(CountTrailingZeros(0x10), 4); + EXPECT_EQ(CountTrailingZeros(0x30), 4); + EXPECT_EQ(CountTrailingZeros(0x70), 4); + EXPECT_EQ(CountTrailingZeros(0xF0), 4); + EXPECT_EQ(CountTrailingZeros(0x20), 5); + EXPECT_EQ(CountTrailingZeros(0x60), 5); + EXPECT_EQ(CountTrailingZeros(0xE0), 5); + EXPECT_EQ(CountTrailingZeros(0x40), 6); + EXPECT_EQ(CountTrailingZeros(0xC0), 6); + EXPECT_EQ(CountTrailingZeros(0x80), 7); + EXPECT_EQ(CountTrailingZeros(0x31), 0); + EXPECT_EQ(CountTrailingZeros(0x32), 1); + EXPECT_EQ(CountTrailingZeros(0x34), 2); + EXPECT_EQ(CountTrailingZeros(0x38), 3); + EXPECT_EQ(CountTrailingZeros(0x310), 4); + EXPECT_EQ(CountTrailingZeros(0x320), 5); + EXPECT_EQ(CountTrailingZeros(0x340), 6); + EXPECT_EQ(CountTrailingZeros(0x380), 7); +} + +TEST(CommonUtilsTest, FloorLog2) { + // Powers of 2. + EXPECT_EQ(FloorLog2(1), 0); + EXPECT_EQ(FloorLog2(2), 1); + EXPECT_EQ(FloorLog2(8), 3); + EXPECT_EQ(FloorLog2(64), 6); + // Powers of 2 +/- 1. + EXPECT_EQ(FloorLog2(9), 3); + EXPECT_EQ(FloorLog2(15), 3); + EXPECT_EQ(FloorLog2(63), 5); + // Large value, smaller than 32 bit. + EXPECT_EQ(FloorLog2(0x7fffffff), 30); + EXPECT_EQ(FloorLog2(0x80000000), 31); + // Larger than 32 bit. + EXPECT_EQ(FloorLog2(uint64_t{0x7fffffffffffffff}), 62); + EXPECT_EQ(FloorLog2(uint64_t{0x8000000000000000}), 63); + EXPECT_EQ(FloorLog2(uint64_t{0xffffffffffffffff}), 63); +} + +TEST(CommonUtilsTest, CeilLog2) { + // Even though log2(0) is -inf, here we explicitly define it to be 0. + EXPECT_EQ(CeilLog2(0), 0); + // Powers of 2. + EXPECT_EQ(CeilLog2(1), 0); + EXPECT_EQ(CeilLog2(2), 1); + EXPECT_EQ(CeilLog2(8), 3); + EXPECT_EQ(CeilLog2(64), 6); + // Powers of 2 +/- 1. + EXPECT_EQ(CeilLog2(9), 4); + EXPECT_EQ(CeilLog2(15), 4); + EXPECT_EQ(CeilLog2(63), 6); + // Large value. + EXPECT_EQ(CeilLog2(0x7fffffff), 31); +} + +TEST(CommonUtilsTest, RightShiftWithCeiling) { + // Shift 1 bit. + EXPECT_EQ(RightShiftWithCeiling(1, 1), 1); + EXPECT_EQ(RightShiftWithCeiling(2, 1), 1); + EXPECT_EQ(RightShiftWithCeiling(3, 1), 2); + EXPECT_EQ(RightShiftWithCeiling(4, 1), 2); + EXPECT_EQ(RightShiftWithCeiling(5, 1), 3); + // Shift 2 bits. + EXPECT_EQ(RightShiftWithCeiling(1, 2), 1); + EXPECT_EQ(RightShiftWithCeiling(2, 2), 1); + EXPECT_EQ(RightShiftWithCeiling(3, 2), 1); + EXPECT_EQ(RightShiftWithCeiling(4, 2), 1); + EXPECT_EQ(RightShiftWithCeiling(5, 2), 2); + // Shift 20 bits. + EXPECT_EQ(RightShiftWithCeiling(1, 20), 1); + EXPECT_EQ(RightShiftWithCeiling((1 << 20) - 1, 20), 1); + EXPECT_EQ(RightShiftWithCeiling(1 << 20, 20), 1); + EXPECT_EQ(RightShiftWithCeiling((1 << 20) + 1, 20), 2); + EXPECT_EQ(RightShiftWithCeiling((1 << 21) - 1, 20), 2); +} + +template <typename Input, typename Output> +void VerifyRightShiftWithRounding(const Input* const values, + const int* const bits, + const Output* const rounded_values, + size_t count) { + for (size_t i = 0; i < count; ++i) { + const Output rounded_value = RightShiftWithRounding(values[i], bits[i]); + EXPECT_EQ(rounded_value, rounded_values[i]) << "Mismatch at index " << i; + // Rounding reduces the bit length by |bits[i]| - 1. + EXPECT_LE(BitLength(rounded_value), BitLength(values[i]) - (bits[i] - 1)) + << "Mismatch at index " << i; + } +} + +TEST(CommonUtilTest, RightShiftWithRoundingInt32) { + static constexpr int32_t values[] = {5, 203, 204, 255, 40000, 50000}; + static constexpr int bits[] = {0, 3, 3, 3, 12, 12}; + static constexpr int32_t rounded_values[] = {5, 25, 26, 32, 10, 12}; + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(bits), ""); + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(rounded_values), ""); + VerifyRightShiftWithRounding<int32_t, int32_t>(values, bits, rounded_values, + ABSL_ARRAYSIZE(values)); +} + +TEST(CommonUtilTest, RightShiftWithRoundingUint32) { + static constexpr uint32_t values[] = {5, 203, 204, 255, + 40000, 50000, 0x7fffffff}; + static constexpr int bits[] = {0, 3, 3, 3, 12, 12, 20}; + static constexpr uint32_t rounded_values[] = {5, 25, 26, 32, 10, 12, 2048}; + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(bits), ""); + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(rounded_values), ""); + VerifyRightShiftWithRounding<uint32_t, uint32_t>(values, bits, rounded_values, + ABSL_ARRAYSIZE(values)); +} + +TEST(CommonUtilTest, RightShiftWithRoundingInt64) { + static constexpr int64_t values[] = {5, 203, 204, 255, + 40000, 50000, 0x7fffffff, 0x8fffffff}; + static constexpr int bits[] = {0, 3, 3, 3, 12, 12, 20, 20}; + static constexpr int32_t rounded_values[] = {5, 25, 26, 32, + 10, 12, 2048, 2304}; + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(bits), ""); + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(rounded_values), ""); + VerifyRightShiftWithRounding<int64_t, int32_t>(values, bits, rounded_values, + ABSL_ARRAYSIZE(values)); +} + +template <typename Input> +void VerifyRightShiftWithRoundingSigned(const Input* const values, + const int* const bits, + const int32_t* const rounded_values, + int count) { + for (int i = 0; i < count; ++i) { + int32_t rounded_value = RightShiftWithRoundingSigned(values[i], bits[i]); + EXPECT_EQ(rounded_value, rounded_values[i]) << "Mismatch at index " << i; + rounded_value = RightShiftWithRoundingSigned(-values[i], bits[i]); + EXPECT_EQ(rounded_value, -rounded_values[i]) << "Mismatch at index " << i; + } +} + +TEST(CommonUtilTest, RightShiftWithRoundingSignedInt32) { + static constexpr int32_t values[] = {203, 204, 255, 40000, 50000}; + static constexpr int bits[] = {3, 3, 3, 12, 12}; + static constexpr int32_t rounded_values[] = {25, 26, 32, 10, 12}; + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(bits), ""); + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(rounded_values), ""); + VerifyRightShiftWithRoundingSigned<int32_t>(values, bits, rounded_values, + ABSL_ARRAYSIZE(values)); +} + +TEST(CommonUtilTest, RightShiftWithRoundingSignedInt64) { + static constexpr int64_t values[] = {203, 204, 255, 40000, + 50000, 0x7fffffff, 0x8fffffff}; + static constexpr int bits[] = {3, 3, 3, 12, 12, 20, 20}; + static constexpr int32_t rounded_values[] = {25, 26, 32, 10, 12, 2048, 2304}; + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(bits), ""); + static_assert(ABSL_ARRAYSIZE(values) == ABSL_ARRAYSIZE(rounded_values), ""); + VerifyRightShiftWithRoundingSigned<int64_t>(values, bits, rounded_values, + ABSL_ARRAYSIZE(values)); +} + +TEST(CommonUtilTest, GetResidualBufferSize) { + // No subsampling. + EXPECT_EQ(GetResidualBufferSize(64, 64, 0, 0, 2), + /* 2*(64*64*3/1 + 32*4) = */ 24832); + // Only X is subsampled. + EXPECT_EQ(GetResidualBufferSize(64, 64, 1, 0, 2), + /* 2*(64*64*2/1 + 32*4) = */ 16640); + // Only Y is subsampled. + EXPECT_EQ(GetResidualBufferSize(64, 64, 0, 1, 2), + /* 2*(64*64*2/1 + 32*4) = */ 16640); + // Both X and Y are subsampled. + EXPECT_EQ(GetResidualBufferSize(64, 64, 1, 1, 2), + /* 2*(64*64*3/2 + 32*4) = */ 12544); +} + +//------------------------------------------------------------------------------ +// Tests for bitstream util functions + +TEST(BitstreamUtilTest, IsIntraFrame) { + EXPECT_TRUE(IsIntraFrame(kFrameKey)); + EXPECT_TRUE(IsIntraFrame(kFrameIntraOnly)); + EXPECT_FALSE(IsIntraFrame(kFrameInter)); + EXPECT_FALSE(IsIntraFrame(kFrameSwitch)); +} + +TEST(BitstreamUtilTest, GetTransformClass) { + static constexpr TransformClass expected_classes[kNumTransformTypes] = { + kTransformClass2D, kTransformClass2D, + kTransformClass2D, kTransformClass2D, + kTransformClass2D, kTransformClass2D, + kTransformClass2D, kTransformClass2D, + kTransformClass2D, kTransformClass2D, + kTransformClassVertical, kTransformClassHorizontal, + kTransformClassVertical, kTransformClassHorizontal, + kTransformClassVertical, kTransformClassHorizontal, + }; + for (int i = 0; i < kNumTransformTypes; ++i) { + EXPECT_EQ(GetTransformClass(static_cast<TransformType>(i)), + expected_classes[i]) + << "Mismatch at index " << i; + } +} + +TEST(BitstreamUtilTest, RowOrColumn4x4ToPixel) { + EXPECT_EQ(RowOrColumn4x4ToPixel(10, kPlaneY, 0), 40); + EXPECT_EQ(RowOrColumn4x4ToPixel(10, kPlaneY, 1), + 40); // Subsampling should have no effect on Y plane. + EXPECT_EQ(RowOrColumn4x4ToPixel(10, kPlaneU, 0), 40); + EXPECT_EQ(RowOrColumn4x4ToPixel(10, kPlaneU, 1), 20); + EXPECT_EQ(RowOrColumn4x4ToPixel(10, kPlaneV, 0), 40); + EXPECT_EQ(RowOrColumn4x4ToPixel(10, kPlaneV, 1), 20); +} + +TEST(BitstreamUtilTest, GetPlaneType) { + EXPECT_EQ(GetPlaneType(kPlaneY), kPlaneTypeY); + EXPECT_EQ(GetPlaneType(kPlaneU), kPlaneTypeUV); + EXPECT_EQ(GetPlaneType(kPlaneV), kPlaneTypeUV); +} + +TEST(BitstreamUtils, IsDirectionalMode) { + static constexpr bool is_directional_modes[kNumPredictionModes] = { + false, true, true, true, true, true, true, true, true, + false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, + }; + for (int i = 0; i < kNumPredictionModes; ++i) { + EXPECT_EQ(IsDirectionalMode(static_cast<PredictionMode>(i)), + is_directional_modes[i]) + << "Mismatch at index " << i; + } +} + +TEST(BitstreamUtils, GetRelativeDistance) { + // Both order_hint_bits and order_hint_shift_bits are zero. (a and b must be + // zero.) + EXPECT_EQ(GetRelativeDistance(0, 0, 0), 0); + EXPECT_EQ(GetRelativeDistance(10, 20, 27), -10); + + EXPECT_EQ(GetRelativeDistance(2, 1, 30), 1); + EXPECT_EQ(GetRelativeDistance(2, 1, 29), 1); + + EXPECT_EQ(GetRelativeDistance(1, 2, 30), -1); + EXPECT_EQ(GetRelativeDistance(1, 2, 29), -1); + + // With an order_hint_bits of 4 and an order_hint_shift_bits of 28, 16 is the + // same as 0, 17 is the same as 1, etc. The most positive distance is 7, and + // the most negative distance is -8. + + EXPECT_EQ(GetRelativeDistance(2, 6, 28), -4); + EXPECT_EQ(GetRelativeDistance(6, 2, 28), 4); + // 18 - 14 = 4. + EXPECT_EQ(GetRelativeDistance(2, 14, 28), 4); + // 14 - 18 = -4. + EXPECT_EQ(GetRelativeDistance(14, 2, 28), -4); + // If a and b are exactly 8 apart, GetRelativeDistance() cannot tell whether + // a is before or after b. GetRelativeDistance(a, b) and + // GetRelativeDistance(b, a) are both -8. + // 1 - 9 = -8. + EXPECT_EQ(GetRelativeDistance(1, 9, 28), -8); + // 9 - 17 = -8. + EXPECT_EQ(GetRelativeDistance(9, 1, 28), -8); + + // With an order_hint_bits of 5 and an order_hint_shift_bits of 27, 32 is the + // same as 0, 33 is the same as 1, etc. The most positive distance is 15, and + // the most negative distance is -16. + + // 31 - 32 = -1. + EXPECT_EQ(GetRelativeDistance(31, 0, 27), -1); + // 32 - 31 = 1. + EXPECT_EQ(GetRelativeDistance(0, 31, 27), 1); + // 30 - 33 = -3. + EXPECT_EQ(GetRelativeDistance(30, 1, 27), -3); + // 33 - 30 = 3. + EXPECT_EQ(GetRelativeDistance(1, 30, 27), 3); + // 25 - 36 = -11. + EXPECT_EQ(GetRelativeDistance(25, 4, 27), -11); + // 36 - 25 = 11. + EXPECT_EQ(GetRelativeDistance(4, 25, 27), 11); + // 15 - 0 = 15. + EXPECT_EQ(GetRelativeDistance(15, 0, 27), 15); + // If a and b are exactly 16 apart, GetRelativeDistance() cannot tell whether + // a is before or after b. GetRelativeDistance(a, b) and + // GetRelativeDistance(b, a) are both -16. + // 16 - 32 = -16. + EXPECT_EQ(GetRelativeDistance(16, 0, 27), -16); + // 0 - 16 = -16. + EXPECT_EQ(GetRelativeDistance(0, 16, 27), -16); +} + +TEST(BitstreamUtils, ApplySign) { + // ApplyPositive(0) = 0 + EXPECT_EQ(ApplySign(0, 0), 0); + // ApplyNegative(0) = 0 + EXPECT_EQ(ApplySign(0, -1), 0); + + // ApplyPositive(1) = 1 + EXPECT_EQ(ApplySign(1, 0), 1); + // ApplyNegative(1) = -1 + EXPECT_EQ(ApplySign(1, -1), -1); + + // ApplyPositive(-1) = -1 + EXPECT_EQ(ApplySign(-1, 0), -1); + // ApplyNegative(-1) = 1 + EXPECT_EQ(ApplySign(-1, -1), 1); + + // ApplyPositive(1234) = 1234 + EXPECT_EQ(ApplySign(1234, 0), 1234); + // ApplyNegative(1234) = -1234 + EXPECT_EQ(ApplySign(1234, -1), -1234); + + // ApplyPositive(-1234) = -1234 + EXPECT_EQ(ApplySign(-1234, 0), -1234); + // ApplyNegative(-1234) = 1234 + EXPECT_EQ(ApplySign(-1234, -1), 1234); +} + +// 7.9.3. (without the clamp for numerator and denominator). +int SpecGetMvProjectionKernel(int mv, int numerator, int denominator) { + int value = mv * numerator * kProjectionMvDivisionLookup[denominator]; + if (value >= 0) { + value += 1 << 13; + value >>= 14; + } else { + value = -value; + value += 1 << 13; + value >>= 14; + value = -value; + } + if (value < (-(1 << 14) + 1)) value = -(1 << 14) + 1; + if (value > (1 << 14) - 1) value = (1 << 14) - 1; + return value; +} + +void SpecGetMvProjectionNoClamp(const MotionVector& mv, int numerator, + int denominator, MotionVector* projection_mv) { + for (int i = 0; i < 2; ++i) { + projection_mv->mv[i] = + SpecGetMvProjectionKernel(mv.mv[i], numerator, denominator); + } +} + +TEST(BitstreamUtils, GetMvProjection) { + const int16_t mvs[5][2] = { + {0, 0}, {11, 73}, {-84, 272}, {733, -827}, {-472, -697}}; + for (auto& mv_value : mvs) { + for (int numerator = -kMaxFrameDistance; numerator <= kMaxFrameDistance; + ++numerator) { + for (int denominator = 0; denominator <= kMaxFrameDistance; + ++denominator) { + MotionVector mv, projection_mv, spec_projection_mv; + mv.mv[0] = mv_value[0]; + mv.mv[1] = mv_value[1]; + GetMvProjection(mv, numerator, kProjectionMvDivisionLookup[denominator], + &projection_mv); + SpecGetMvProjectionNoClamp(mv, numerator, denominator, + &spec_projection_mv); + EXPECT_EQ(projection_mv.mv32, spec_projection_mv.mv32); + } + } + } +} + +// 7.9.4. +int SpecProject(int value, int delta, int dst_sign) { + constexpr int kMiSizeLog2 = 2; + const int sign = (dst_sign == 0) ? 1 : dst_sign; + int offset; + if (delta >= 0) { + offset = delta >> (3 + 1 + kMiSizeLog2); + } else { + offset = -((-delta) >> (3 + 1 + kMiSizeLog2)); + } + return value + sign * offset; +} + +TEST(BitstreamUtils, Project) { + for (int value = -10; value <= 10; ++value) { + for (int delta = -256; delta <= 256; ++delta) { + for (int dst_sign = -1; dst_sign <= 0; ++dst_sign) { + EXPECT_EQ(Project(value, delta, dst_sign), + SpecProject(value, delta, dst_sign)); + } + } + } +} + +TEST(BitstreamUtils, IsBlockSmallerThan8x8) { + static constexpr bool is_block_smaller_than8x8[kMaxBlockSizes] = { + true, true, false, true, false, false, false, false, + false, false, false, false, false, false, false, false, + false, false, false, false, false, false, + }; + for (int i = 0; i < kMaxBlockSizes; ++i) { + EXPECT_EQ(IsBlockSmallerThan8x8(static_cast<BlockSize>(i)), + is_block_smaller_than8x8[i]) + << "Mismatch at index " << i; + } +} + +TEST(BitstreamUtils, TransformSizeToSquareTransformIndex) { + EXPECT_EQ(TransformSizeToSquareTransformIndex(kTransformSize4x4), 0); + EXPECT_EQ(TransformSizeToSquareTransformIndex(kTransformSize8x8), 1); + EXPECT_EQ(TransformSizeToSquareTransformIndex(kTransformSize16x16), 2); + EXPECT_EQ(TransformSizeToSquareTransformIndex(kTransformSize32x32), 3); + EXPECT_EQ(TransformSizeToSquareTransformIndex(kTransformSize64x64), 4); +} + +} // namespace +} // namespace libgav1 |