diff options
Diffstat (limited to 'src/threading_strategy_test.cc')
-rw-r--r-- | src/threading_strategy_test.cc | 281 |
1 files changed, 281 insertions, 0 deletions
diff --git a/src/threading_strategy_test.cc b/src/threading_strategy_test.cc new file mode 100644 index 0000000..2a7a781 --- /dev/null +++ b/src/threading_strategy_test.cc @@ -0,0 +1,281 @@ +// Copyright 2021 The libgav1 Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/threading_strategy.h" + +#include <memory> +#include <utility> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "gtest/gtest.h" +#include "src/frame_scratch_buffer.h" +#include "src/obu_parser.h" +#include "src/utils/constants.h" +#include "src/utils/threadpool.h" +#include "src/utils/types.h" + +namespace libgav1 { +namespace { + +class ThreadingStrategyTest : public testing::Test { + protected: + ThreadingStrategy strategy_; + ObuFrameHeader frame_header_ = {}; +}; + +TEST_F(ThreadingStrategyTest, MaxThreadEnforced) { + frame_header_.tile_info.tile_count = 32; + ASSERT_TRUE(strategy_.Reset(frame_header_, 32)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + for (int i = 0; i < 32; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); +} + +TEST_F(ThreadingStrategyTest, UseAllThreadsForTiles) { + frame_header_.tile_info.tile_count = 8; + ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); +} + +TEST_F(ThreadingStrategyTest, RowThreads) { + frame_header_.tile_info.tile_count = 2; + ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + // Each tile should get 3 threads each. + for (int i = 0; i < 2; ++i) { + EXPECT_NE(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); +} + +TEST_F(ThreadingStrategyTest, RowThreadsUnequal) { + frame_header_.tile_info.tile_count = 2; + + ASSERT_TRUE(strategy_.Reset(frame_header_, 9)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + EXPECT_NE(strategy_.row_thread_pool(0), nullptr); + EXPECT_NE(strategy_.row_thread_pool(1), nullptr); + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); +} + +// Test a random combination of tile_count and thread_count. +TEST_F(ThreadingStrategyTest, MultipleCalls) { + frame_header_.tile_info.tile_count = 2; + ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + for (int i = 0; i < 2; ++i) { + EXPECT_NE(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); + + frame_header_.tile_info.tile_count = 8; + ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + // Row threads must have been reset. + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); + + frame_header_.tile_info.tile_count = 8; + ASSERT_TRUE(strategy_.Reset(frame_header_, 16)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + for (int i = 0; i < 8; ++i) { + EXPECT_NE(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); + + frame_header_.tile_info.tile_count = 4; + ASSERT_TRUE(strategy_.Reset(frame_header_, 16)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + for (int i = 0; i < 4; ++i) { + EXPECT_NE(strategy_.row_thread_pool(i), nullptr); + } + // All the other row threads must be reset. + for (int i = 4; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); + + frame_header_.tile_info.tile_count = 4; + ASSERT_TRUE(strategy_.Reset(frame_header_, 6)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + // First two tiles will get 1 thread each. + for (int i = 0; i < 2; ++i) { + EXPECT_NE(strategy_.row_thread_pool(i), nullptr); + } + // All the other row threads must be reset. + for (int i = 2; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); + + ASSERT_TRUE(strategy_.Reset(frame_header_, 1)); + EXPECT_EQ(strategy_.tile_thread_pool(), nullptr); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_EQ(strategy_.post_filter_thread_pool(), nullptr); +} + +// Tests the following order of calls (with thread count fixed at 4): +// * 1 Tile - 2 Tiles - 1 Tile. +TEST_F(ThreadingStrategyTest, MultipleCalls2) { + frame_header_.tile_info.tile_count = 1; + ASSERT_TRUE(strategy_.Reset(frame_header_, 4)); + // When there is only one tile, tile thread pool must be nullptr. + EXPECT_EQ(strategy_.tile_thread_pool(), nullptr); + EXPECT_NE(strategy_.row_thread_pool(0), nullptr); + for (int i = 1; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); + + frame_header_.tile_info.tile_count = 2; + ASSERT_TRUE(strategy_.Reset(frame_header_, 4)); + EXPECT_NE(strategy_.tile_thread_pool(), nullptr); + for (int i = 0; i < 2; ++i) { + EXPECT_NE(strategy_.row_thread_pool(i), nullptr); + } + for (int i = 2; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); + + frame_header_.tile_info.tile_count = 1; + ASSERT_TRUE(strategy_.Reset(frame_header_, 4)); + EXPECT_EQ(strategy_.tile_thread_pool(), nullptr); + EXPECT_NE(strategy_.row_thread_pool(0), nullptr); + for (int i = 1; i < 8; ++i) { + EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); + } + EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); +} + +void VerifyFrameParallel(int thread_count, int tile_count, int tile_columns, + int expected_frame_threads, + const std::vector<int>& expected_tile_threads) { + ASSERT_EQ(expected_frame_threads, expected_tile_threads.size()); + ASSERT_GT(thread_count, 1); + std::unique_ptr<ThreadPool> frame_thread_pool; + FrameScratchBufferPool frame_scratch_buffer_pool; + ASSERT_TRUE(InitializeThreadPoolsForFrameParallel( + thread_count, tile_count, tile_columns, &frame_thread_pool, + &frame_scratch_buffer_pool)); + if (expected_frame_threads == 0) { + EXPECT_EQ(frame_thread_pool, nullptr); + return; + } + EXPECT_NE(frame_thread_pool.get(), nullptr); + EXPECT_EQ(frame_thread_pool->num_threads(), expected_frame_threads); + std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers; + int actual_thread_count = frame_thread_pool->num_threads(); + for (int i = 0; i < expected_frame_threads; ++i) { + SCOPED_TRACE(absl::StrCat("i: ", i)); + frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get()); + ThreadPool* const thread_pool = + frame_scratch_buffers.back()->threading_strategy.thread_pool(); + if (expected_tile_threads[i] > 0) { + EXPECT_NE(thread_pool, nullptr); + EXPECT_EQ(thread_pool->num_threads(), expected_tile_threads[i]); + actual_thread_count += thread_pool->num_threads(); + } else { + EXPECT_EQ(thread_pool, nullptr); + } + } + EXPECT_EQ(thread_count, actual_thread_count); + for (auto& frame_scratch_buffer : frame_scratch_buffers) { + frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer)); + } +} + +TEST(FrameParallelStrategyTest, FrameParallel) { + // This loop has thread_count <= 3 * tile count. So there should be no frame + // threads irrespective of the number of tile columns. + for (int thread_count = 2; thread_count <= 6; ++thread_count) { + VerifyFrameParallel(thread_count, /*tile_count=*/2, /*tile_columns=*/1, + /*expected_frame_threads=*/0, + /*expected_tile_threads=*/{}); + VerifyFrameParallel(thread_count, /*tile_count=*/2, /*tile_columns=*/2, + /*expected_frame_threads=*/0, + /*expected_tile_threads=*/{}); + } + + // Equal number of tile threads for each frame thread. + VerifyFrameParallel( + /*thread_count=*/8, /*tile_count=*/1, /*tile_columns=*/1, + /*expected_frame_threads=*/4, /*expected_tile_threads=*/{1, 1, 1, 1}); + VerifyFrameParallel( + /*thread_count=*/12, /*tile_count=*/2, /*tile_columns=*/2, + /*expected_frame_threads=*/4, /*expected_tile_threads=*/{2, 2, 2, 2}); + VerifyFrameParallel( + /*thread_count=*/18, /*tile_count=*/2, /*tile_columns=*/2, + /*expected_frame_threads=*/6, + /*expected_tile_threads=*/{2, 2, 2, 2, 2, 2}); + VerifyFrameParallel( + /*thread_count=*/16, /*tile_count=*/3, /*tile_columns=*/3, + /*expected_frame_threads=*/4, /*expected_tile_threads=*/{3, 3, 3, 3}); + + // Unequal number of tile threads for each frame thread. + VerifyFrameParallel( + /*thread_count=*/7, /*tile_count=*/1, /*tile_columns=*/1, + /*expected_frame_threads=*/3, /*expected_tile_threads=*/{2, 1, 1}); + VerifyFrameParallel( + /*thread_count=*/14, /*tile_count=*/2, /*tile_columns=*/2, + /*expected_frame_threads=*/4, /*expected_tile_threads=*/{3, 3, 2, 2}); + VerifyFrameParallel( + /*thread_count=*/20, /*tile_count=*/2, /*tile_columns=*/2, + /*expected_frame_threads=*/6, + /*expected_tile_threads=*/{3, 3, 2, 2, 2, 2}); + VerifyFrameParallel( + /*thread_count=*/17, /*tile_count=*/3, /*tile_columns=*/3, + /*expected_frame_threads=*/4, /*expected_tile_threads=*/{4, 3, 3, 3}); +} + +TEST(FrameParallelStrategyTest, ThreadCountDoesNotExceedkMaxThreads) { + std::unique_ptr<ThreadPool> frame_thread_pool; + FrameScratchBufferPool frame_scratch_buffer_pool; + ASSERT_TRUE(InitializeThreadPoolsForFrameParallel( + /*thread_count=*/kMaxThreads + 10, /*tile_count=*/2, /*tile_columns=*/2, + &frame_thread_pool, &frame_scratch_buffer_pool)); + EXPECT_NE(frame_thread_pool.get(), nullptr); + std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers; + int actual_thread_count = frame_thread_pool->num_threads(); + for (int i = 0; i < frame_thread_pool->num_threads(); ++i) { + SCOPED_TRACE(absl::StrCat("i: ", i)); + frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get()); + ThreadPool* const thread_pool = + frame_scratch_buffers.back()->threading_strategy.thread_pool(); + if (thread_pool != nullptr) { + actual_thread_count += thread_pool->num_threads(); + } + } + // In this case, the exact number of frame threads and tile threads depend on + // the value of kMaxThreads. So simply ensure that the total number of threads + // does not exceed kMaxThreads. + EXPECT_LE(actual_thread_count, kMaxThreads); + for (auto& frame_scratch_buffer : frame_scratch_buffers) { + frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer)); + } +} + +} // namespace +} // namespace libgav1 |