aboutsummaryrefslogtreecommitdiff
path: root/src/utils/blocking_counter.h
blob: 6d664f8bb2b35eb2652a1c63428d2aec31111b0a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*
 * Copyright 2019 The libgav1 Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBGAV1_SRC_UTILS_BLOCKING_COUNTER_H_
#define LIBGAV1_SRC_UTILS_BLOCKING_COUNTER_H_

#include <cassert>
#include <condition_variable>  // NOLINT (unapproved c++11 header)
#include <mutex>               // NOLINT (unapproved c++11 header)

#include "src/utils/compiler_attributes.h"

namespace libgav1 {

// Implementation of a Blocking Counter that is used for the "fork-join"
// use case. Typical usage would be as follows:
//   BlockingCounter counter(num_jobs);
//     - spawn the jobs.
//     - call counter.Wait() on the master thread.
//     - worker threads will call counter.Decrement().
//     - master thread will return from counter.Wait() when all workers are
//       complete.
//
// |has_failure_status| selects which Decrement() overload may be used:
//   - false (the BlockingCounter alias): Decrement() with no arguments, and
//     Wait() always returns true.
//   - true (the BlockingCounterWithStatus alias): Decrement(bool), and Wait()
//     returns false if any job reported failure.
//
// Every member function takes an internal mutex, so Decrement() may be called
// concurrently from any number of worker threads.
template <bool has_failure_status>
class BlockingCounterImpl {
 public:
  // |initial_count| is the number of Decrement() calls that Wait() will block
  // for (more can be added later with IncrementBy()). It must be
  // non-negative.
  explicit BlockingCounterImpl(int initial_count)
      : count_(initial_count), job_failed_(false) {
    assert(initial_count >= 0);
  }

  // Increment the counter by |count|, which must be non-negative. This must be
  // called before Wait() is called. This must be called from the same thread
  // that will call Wait().
  void IncrementBy(int count) {
    assert(count >= 0);
    std::lock_guard<std::mutex> lock(mutex_);
    count_ += count;
  }

  // Decrement the counter by 1. This function can be called only when
  // |has_failure_status| is false (i.e.) when this class is being used with the
  // |BlockingCounter| alias. Wakes the thread blocked in Wait() when the
  // counter reaches 0.
  void Decrement() {
    static_assert(!has_failure_status,
                  "Decrement() requires has_failure_status == false; use "
                  "Decrement(bool) with BlockingCounterWithStatus.");
    std::lock_guard<std::mutex> lock(mutex_);
    // More Decrement() calls than pending jobs would drive the counter
    // negative, and a later Wait() would then never observe 0.
    assert(count_ > 0);
    if (--count_ == 0) {
      // Notify while still holding |mutex_|: Wait() cannot reacquire the
      // mutex (and therefore cannot return and let the caller destroy this
      // object) until notify_one() has finished with |condition_|.
      condition_.notify_one();
    }
  }

  // Decrement the counter by 1. This function can be called only when
  // |has_failure_status| is true (i.e.) when this class is being used with the
  // |BlockingCounterWithStatus| alias. |job_succeeded| is used to update the
  // state of |job_failed_|: a single false makes the whole group fail. Wakes
  // the thread blocked in Wait() when the counter reaches 0.
  void Decrement(bool job_succeeded) {
    static_assert(has_failure_status,
                  "Decrement(bool) requires has_failure_status == true; use "
                  "Decrement() with BlockingCounter.");
    std::lock_guard<std::mutex> lock(mutex_);
    // More Decrement() calls than pending jobs would drive the counter
    // negative, and a later Wait() would then never observe 0.
    assert(count_ > 0);
    job_failed_ |= !job_succeeded;
    if (--count_ == 0) {
      // Notify under |mutex_| for the same lifetime reason as Decrement().
      condition_.notify_one();
    }
  }

  // Block until the counter becomes 0. This function can be called only once
  // per object. If |has_failure_status| is true, true is returned if all the
  // jobs succeeded and false is returned if any of the jobs failed. If
  // |has_failure_status| is false, this function always returns true.
  bool Wait() {
    // std::condition_variable::wait() requires a std::unique_lock.
    std::unique_lock<std::mutex> lock(mutex_);
    // The predicate form handles spurious wakeups and the case where the
    // counter already reached 0 before Wait() was entered.
    condition_.wait(lock, [this]() { return count_ == 0; });
    // If |has_failure_status| is false, we simply return true.
    return has_failure_status ? !job_failed_ : true;
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  int count_ LIBGAV1_GUARDED_BY(mutex_);
  bool job_failed_ LIBGAV1_GUARDED_BY(mutex_);
};

using BlockingCounterWithStatus = BlockingCounterImpl<true>;
using BlockingCounter = BlockingCounterImpl<false>;

}  // namespace libgav1

#endif  // LIBGAV1_SRC_UTILS_BLOCKING_COUNTER_H_