aboutsummaryrefslogtreecommitdiff
path: root/src/threading_strategy_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/threading_strategy_test.cc')
-rw-r--r--src/threading_strategy_test.cc281
1 files changed, 281 insertions, 0 deletions
diff --git a/src/threading_strategy_test.cc b/src/threading_strategy_test.cc
new file mode 100644
index 0000000..2a7a781
--- /dev/null
+++ b/src/threading_strategy_test.cc
@@ -0,0 +1,281 @@
+// Copyright 2021 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/threading_strategy.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "gtest/gtest.h"
+#include "src/frame_scratch_buffer.h"
+#include "src/obu_parser.h"
+#include "src/utils/constants.h"
+#include "src/utils/threadpool.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace {
+
+class ThreadingStrategyTest : public testing::Test {
+ protected:
+ ThreadingStrategy strategy_;
+ ObuFrameHeader frame_header_ = {};
+};
+
+TEST_F(ThreadingStrategyTest, MaxThreadEnforced) {
+ frame_header_.tile_info.tile_count = 32;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 32));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ for (int i = 0; i < 32; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+}
+
+TEST_F(ThreadingStrategyTest, UseAllThreadsForTiles) {
+ frame_header_.tile_info.tile_count = 8;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+}
+
+TEST_F(ThreadingStrategyTest, RowThreads) {
+ frame_header_.tile_info.tile_count = 2;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ // Each tile should get 3 threads each.
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+}
+
+TEST_F(ThreadingStrategyTest, RowThreadsUnequal) {
+ frame_header_.tile_info.tile_count = 2;
+
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 9));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
+ EXPECT_NE(strategy_.row_thread_pool(1), nullptr);
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+}
+
+// Test a random combination of tile_count and thread_count.
+TEST_F(ThreadingStrategyTest, MultipleCalls) {
+ frame_header_.tile_info.tile_count = 2;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+
+ frame_header_.tile_info.tile_count = 8;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ // Row threads must have been reset.
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+
+ frame_header_.tile_info.tile_count = 8;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+
+ frame_header_.tile_info.tile_count = 4;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
+ }
+ // All the other row threads must be reset.
+ for (int i = 4; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+
+ frame_header_.tile_info.tile_count = 4;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 6));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ // First two tiles will get 1 thread each.
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
+ }
+ // All the other row threads must be reset.
+ for (int i = 2; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 1));
+ EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_EQ(strategy_.post_filter_thread_pool(), nullptr);
+}
+
+// Tests the following order of calls (with thread count fixed at 4):
+// * 1 Tile - 2 Tiles - 1 Tile.
+TEST_F(ThreadingStrategyTest, MultipleCalls2) {
+ frame_header_.tile_info.tile_count = 1;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
+ // When there is only one tile, tile thread pool must be nullptr.
+ EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
+ EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
+ for (int i = 1; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+
+ frame_header_.tile_info.tile_count = 2;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
+ EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
+ }
+ for (int i = 2; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+
+ frame_header_.tile_info.tile_count = 1;
+ ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
+ EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
+ EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
+ for (int i = 1; i < 8; ++i) {
+ EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
+ }
+ EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
+}
+
+void VerifyFrameParallel(int thread_count, int tile_count, int tile_columns,
+ int expected_frame_threads,
+ const std::vector<int>& expected_tile_threads) {
+ ASSERT_EQ(expected_frame_threads, expected_tile_threads.size());
+ ASSERT_GT(thread_count, 1);
+ std::unique_ptr<ThreadPool> frame_thread_pool;
+ FrameScratchBufferPool frame_scratch_buffer_pool;
+ ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
+ thread_count, tile_count, tile_columns, &frame_thread_pool,
+ &frame_scratch_buffer_pool));
+ if (expected_frame_threads == 0) {
+ EXPECT_EQ(frame_thread_pool, nullptr);
+ return;
+ }
+ EXPECT_NE(frame_thread_pool.get(), nullptr);
+ EXPECT_EQ(frame_thread_pool->num_threads(), expected_frame_threads);
+ std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
+ int actual_thread_count = frame_thread_pool->num_threads();
+ for (int i = 0; i < expected_frame_threads; ++i) {
+ SCOPED_TRACE(absl::StrCat("i: ", i));
+ frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
+ ThreadPool* const thread_pool =
+ frame_scratch_buffers.back()->threading_strategy.thread_pool();
+ if (expected_tile_threads[i] > 0) {
+ EXPECT_NE(thread_pool, nullptr);
+ EXPECT_EQ(thread_pool->num_threads(), expected_tile_threads[i]);
+ actual_thread_count += thread_pool->num_threads();
+ } else {
+ EXPECT_EQ(thread_pool, nullptr);
+ }
+ }
+ EXPECT_EQ(thread_count, actual_thread_count);
+ for (auto& frame_scratch_buffer : frame_scratch_buffers) {
+ frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
+ }
+}
+
+TEST(FrameParallelStrategyTest, FrameParallel) {
+ // This loop has thread_count <= 3 * tile count. So there should be no frame
+ // threads irrespective of the number of tile columns.
+ for (int thread_count = 2; thread_count <= 6; ++thread_count) {
+ VerifyFrameParallel(thread_count, /*tile_count=*/2, /*tile_columns=*/1,
+ /*expected_frame_threads=*/0,
+ /*expected_tile_threads=*/{});
+ VerifyFrameParallel(thread_count, /*tile_count=*/2, /*tile_columns=*/2,
+ /*expected_frame_threads=*/0,
+ /*expected_tile_threads=*/{});
+ }
+
+ // Equal number of tile threads for each frame thread.
+ VerifyFrameParallel(
+ /*thread_count=*/8, /*tile_count=*/1, /*tile_columns=*/1,
+ /*expected_frame_threads=*/4, /*expected_tile_threads=*/{1, 1, 1, 1});
+ VerifyFrameParallel(
+ /*thread_count=*/12, /*tile_count=*/2, /*tile_columns=*/2,
+ /*expected_frame_threads=*/4, /*expected_tile_threads=*/{2, 2, 2, 2});
+ VerifyFrameParallel(
+ /*thread_count=*/18, /*tile_count=*/2, /*tile_columns=*/2,
+ /*expected_frame_threads=*/6,
+ /*expected_tile_threads=*/{2, 2, 2, 2, 2, 2});
+ VerifyFrameParallel(
+ /*thread_count=*/16, /*tile_count=*/3, /*tile_columns=*/3,
+ /*expected_frame_threads=*/4, /*expected_tile_threads=*/{3, 3, 3, 3});
+
+ // Unequal number of tile threads for each frame thread.
+ VerifyFrameParallel(
+ /*thread_count=*/7, /*tile_count=*/1, /*tile_columns=*/1,
+ /*expected_frame_threads=*/3, /*expected_tile_threads=*/{2, 1, 1});
+ VerifyFrameParallel(
+ /*thread_count=*/14, /*tile_count=*/2, /*tile_columns=*/2,
+ /*expected_frame_threads=*/4, /*expected_tile_threads=*/{3, 3, 2, 2});
+ VerifyFrameParallel(
+ /*thread_count=*/20, /*tile_count=*/2, /*tile_columns=*/2,
+ /*expected_frame_threads=*/6,
+ /*expected_tile_threads=*/{3, 3, 2, 2, 2, 2});
+ VerifyFrameParallel(
+ /*thread_count=*/17, /*tile_count=*/3, /*tile_columns=*/3,
+ /*expected_frame_threads=*/4, /*expected_tile_threads=*/{4, 3, 3, 3});
+}
+
+TEST(FrameParallelStrategyTest, ThreadCountDoesNotExceedkMaxThreads) {
+ std::unique_ptr<ThreadPool> frame_thread_pool;
+ FrameScratchBufferPool frame_scratch_buffer_pool;
+ ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
+ /*thread_count=*/kMaxThreads + 10, /*tile_count=*/2, /*tile_columns=*/2,
+ &frame_thread_pool, &frame_scratch_buffer_pool));
+ EXPECT_NE(frame_thread_pool.get(), nullptr);
+ std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
+ int actual_thread_count = frame_thread_pool->num_threads();
+ for (int i = 0; i < frame_thread_pool->num_threads(); ++i) {
+ SCOPED_TRACE(absl::StrCat("i: ", i));
+ frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
+ ThreadPool* const thread_pool =
+ frame_scratch_buffers.back()->threading_strategy.thread_pool();
+ if (thread_pool != nullptr) {
+ actual_thread_count += thread_pool->num_threads();
+ }
+ }
+ // In this case, the exact number of frame threads and tile threads depend on
+ // the value of kMaxThreads. So simply ensure that the total number of threads
+ // does not exceed kMaxThreads.
+ EXPECT_LE(actual_thread_count, kMaxThreads);
+ for (auto& frame_scratch_buffer : frame_scratch_buffers) {
+ frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
+ }
+}
+
+} // namespace
+} // namespace libgav1