// Copyright 2021 The libgav1 Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/threading_strategy.h" #include #include #include #include "absl/strings/str_cat.h" #include "gtest/gtest.h" #include "src/frame_scratch_buffer.h" #include "src/obu_parser.h" #include "src/utils/constants.h" #include "src/utils/threadpool.h" #include "src/utils/types.h" namespace libgav1 { namespace { class ThreadingStrategyTest : public testing::Test { protected: ThreadingStrategy strategy_; ObuFrameHeader frame_header_ = {}; }; TEST_F(ThreadingStrategyTest, MaxThreadEnforced) { frame_header_.tile_info.tile_count = 32; ASSERT_TRUE(strategy_.Reset(frame_header_, 32)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); for (int i = 0; i < 32; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); } TEST_F(ThreadingStrategyTest, UseAllThreadsForTiles) { frame_header_.tile_info.tile_count = 8; ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); for (int i = 0; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); } TEST_F(ThreadingStrategyTest, RowThreads) { frame_header_.tile_info.tile_count = 2; ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); // Each tile should get 3 threads each. for (int i = 0; i < 2; ++i) { EXPECT_NE(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); } TEST_F(ThreadingStrategyTest, RowThreadsUnequal) { frame_header_.tile_info.tile_count = 2; ASSERT_TRUE(strategy_.Reset(frame_header_, 9)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); EXPECT_NE(strategy_.row_thread_pool(0), nullptr); EXPECT_NE(strategy_.row_thread_pool(1), nullptr); EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); } // Test a random combination of tile_count and thread_count. TEST_F(ThreadingStrategyTest, MultipleCalls) { frame_header_.tile_info.tile_count = 2; ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); for (int i = 0; i < 2; ++i) { EXPECT_NE(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); frame_header_.tile_info.tile_count = 8; ASSERT_TRUE(strategy_.Reset(frame_header_, 8)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); // Row threads must have been reset. for (int i = 0; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); frame_header_.tile_info.tile_count = 8; ASSERT_TRUE(strategy_.Reset(frame_header_, 16)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); for (int i = 0; i < 8; ++i) { // See ThreadingStrategy::Reset(). #if defined(__ANDROID__) if (i >= 4) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i; continue; } #endif EXPECT_NE(strategy_.row_thread_pool(i), nullptr) << "i = " << i; } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); frame_header_.tile_info.tile_count = 4; ASSERT_TRUE(strategy_.Reset(frame_header_, 16)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); for (int i = 0; i < 4; ++i) { EXPECT_NE(strategy_.row_thread_pool(i), nullptr); } // All the other row threads must be reset. for (int i = 4; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); frame_header_.tile_info.tile_count = 4; ASSERT_TRUE(strategy_.Reset(frame_header_, 6)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); // First two tiles will get 1 thread each. for (int i = 0; i < 2; ++i) { // See ThreadingStrategy::Reset(). #if defined(__ANDROID__) if (i == 1) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i; continue; } #endif EXPECT_NE(strategy_.row_thread_pool(i), nullptr) << "i = " << i; } // All the other row threads must be reset. for (int i = 2; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i; } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); ASSERT_TRUE(strategy_.Reset(frame_header_, 1)); EXPECT_EQ(strategy_.tile_thread_pool(), nullptr); for (int i = 0; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_EQ(strategy_.post_filter_thread_pool(), nullptr); } // Tests the following order of calls (with thread count fixed at 4): // * 1 Tile - 2 Tiles - 1 Tile. TEST_F(ThreadingStrategyTest, MultipleCalls2) { frame_header_.tile_info.tile_count = 1; ASSERT_TRUE(strategy_.Reset(frame_header_, 4)); // When there is only one tile, tile thread pool must be nullptr. EXPECT_EQ(strategy_.tile_thread_pool(), nullptr); EXPECT_NE(strategy_.row_thread_pool(0), nullptr); for (int i = 1; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); frame_header_.tile_info.tile_count = 2; ASSERT_TRUE(strategy_.Reset(frame_header_, 4)); EXPECT_NE(strategy_.tile_thread_pool(), nullptr); for (int i = 0; i < 2; ++i) { // See ThreadingStrategy::Reset(). #if defined(__ANDROID__) if (i == 1) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i; continue; } #endif EXPECT_NE(strategy_.row_thread_pool(i), nullptr); } for (int i = 2; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); frame_header_.tile_info.tile_count = 1; ASSERT_TRUE(strategy_.Reset(frame_header_, 4)); EXPECT_EQ(strategy_.tile_thread_pool(), nullptr); EXPECT_NE(strategy_.row_thread_pool(0), nullptr); for (int i = 1; i < 8; ++i) { EXPECT_EQ(strategy_.row_thread_pool(i), nullptr); } EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr); } void VerifyFrameParallel(int thread_count, int tile_count, int tile_columns, int expected_frame_threads, const std::vector& expected_tile_threads) { ASSERT_EQ(expected_frame_threads, expected_tile_threads.size()); ASSERT_GT(thread_count, 1); std::unique_ptr frame_thread_pool; FrameScratchBufferPool frame_scratch_buffer_pool; ASSERT_TRUE(InitializeThreadPoolsForFrameParallel( thread_count, tile_count, tile_columns, &frame_thread_pool, &frame_scratch_buffer_pool)); if (expected_frame_threads == 0) { EXPECT_EQ(frame_thread_pool, nullptr); return; } EXPECT_NE(frame_thread_pool.get(), nullptr); EXPECT_EQ(frame_thread_pool->num_threads(), expected_frame_threads); std::vector> frame_scratch_buffers; int actual_thread_count = frame_thread_pool->num_threads(); for (int i = 0; i < expected_frame_threads; ++i) { SCOPED_TRACE(absl::StrCat("i: ", i)); frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get()); ThreadPool* const thread_pool = frame_scratch_buffers.back()->threading_strategy.thread_pool(); if (expected_tile_threads[i] > 0) { EXPECT_NE(thread_pool, nullptr); EXPECT_EQ(thread_pool->num_threads(), expected_tile_threads[i]); actual_thread_count += thread_pool->num_threads(); } else { EXPECT_EQ(thread_pool, nullptr); } } EXPECT_EQ(thread_count, actual_thread_count); for (auto& frame_scratch_buffer : frame_scratch_buffers) { frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer)); } } TEST(FrameParallelStrategyTest, FrameParallel) { // This loop has thread_count <= 3 * tile count. So there should be no frame // threads irrespective of the number of tile columns. for (int thread_count = 2; thread_count <= 6; ++thread_count) { VerifyFrameParallel(thread_count, /*tile_count=*/2, /*tile_columns=*/1, /*expected_frame_threads=*/0, /*expected_tile_threads=*/{}); VerifyFrameParallel(thread_count, /*tile_count=*/2, /*tile_columns=*/2, /*expected_frame_threads=*/0, /*expected_tile_threads=*/{}); } // Equal number of tile threads for each frame thread. VerifyFrameParallel( /*thread_count=*/8, /*tile_count=*/1, /*tile_columns=*/1, /*expected_frame_threads=*/4, /*expected_tile_threads=*/{1, 1, 1, 1}); VerifyFrameParallel( /*thread_count=*/12, /*tile_count=*/2, /*tile_columns=*/2, /*expected_frame_threads=*/4, /*expected_tile_threads=*/{2, 2, 2, 2}); VerifyFrameParallel( /*thread_count=*/18, /*tile_count=*/2, /*tile_columns=*/2, /*expected_frame_threads=*/6, /*expected_tile_threads=*/{2, 2, 2, 2, 2, 2}); VerifyFrameParallel( /*thread_count=*/16, /*tile_count=*/3, /*tile_columns=*/3, /*expected_frame_threads=*/4, /*expected_tile_threads=*/{3, 3, 3, 3}); // Unequal number of tile threads for each frame thread. VerifyFrameParallel( /*thread_count=*/7, /*tile_count=*/1, /*tile_columns=*/1, /*expected_frame_threads=*/3, /*expected_tile_threads=*/{2, 1, 1}); VerifyFrameParallel( /*thread_count=*/14, /*tile_count=*/2, /*tile_columns=*/2, /*expected_frame_threads=*/4, /*expected_tile_threads=*/{3, 3, 2, 2}); VerifyFrameParallel( /*thread_count=*/20, /*tile_count=*/2, /*tile_columns=*/2, /*expected_frame_threads=*/6, /*expected_tile_threads=*/{3, 3, 2, 2, 2, 2}); VerifyFrameParallel( /*thread_count=*/17, /*tile_count=*/3, /*tile_columns=*/3, /*expected_frame_threads=*/4, /*expected_tile_threads=*/{4, 3, 3, 3}); } TEST(FrameParallelStrategyTest, ThreadCountDoesNotExceedkMaxThreads) { std::unique_ptr frame_thread_pool; FrameScratchBufferPool frame_scratch_buffer_pool; ASSERT_TRUE(InitializeThreadPoolsForFrameParallel( /*thread_count=*/kMaxThreads + 10, /*tile_count=*/2, /*tile_columns=*/2, &frame_thread_pool, &frame_scratch_buffer_pool)); EXPECT_NE(frame_thread_pool.get(), nullptr); std::vector> frame_scratch_buffers; int actual_thread_count = frame_thread_pool->num_threads(); for (int i = 0; i < frame_thread_pool->num_threads(); ++i) { SCOPED_TRACE(absl::StrCat("i: ", i)); frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get()); ThreadPool* const thread_pool = frame_scratch_buffers.back()->threading_strategy.thread_pool(); if (thread_pool != nullptr) { actual_thread_count += thread_pool->num_threads(); } } // In this case, the exact number of frame threads and tile threads depend on // the value of kMaxThreads. So simply ensure that the total number of threads // does not exceed kMaxThreads. EXPECT_LE(actual_thread_count, kMaxThreads); for (auto& frame_scratch_buffer : frame_scratch_buffers) { frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer)); } } } // namespace } // namespace libgav1