aboutsummaryrefslogtreecommitdiff
path: root/absl/algorithm/container.h
diff options
context:
space:
mode:
authorEric Astor <epastor@google.com>2023-12-21 08:11:01 -0800
committerCopybara-Service <copybara-worker@google.com>2023-12-21 08:12:11 -0800
commit258e5a15759cc3d122d4a4826bc499af91d40aa9 (patch)
tree0696c01c1d40217b8c339a3e81418dace1e10640 /absl/algorithm/container.h
parent794352a92f09425714b9116974b29e58ce8f9ba9 (diff)
downloadabseil-258e5a15759cc3d122d4a4826bc499af91d40aa9.tar.gz
abseil-258e5a15759cc3d122d4a4826bc499af91d40aa9.tar.bz2
abseil-258e5a15759cc3d122d4a4826bc499af91d40aa9.zip
Add a container-based version of `std::sample()`
PiperOrigin-RevId: 592864147 Change-Id: I83179b0225aa446ae0b57b46b604af14f1fa14df
Diffstat (limited to 'absl/algorithm/container.h')
-rw-r--r--absl/algorithm/container.h30
1 files changed, 30 insertions, 0 deletions
diff --git a/absl/algorithm/container.h b/absl/algorithm/container.h
index 934dd179..c7bafae1 100644
--- a/absl/algorithm/container.h
+++ b/absl/algorithm/container.h
@@ -774,6 +774,36 @@ void c_shuffle(RandomAccessContainer& c, UniformRandomBitGenerator&& gen) {
std::forward<UniformRandomBitGenerator>(gen));
}
+// c_sample()
+//
+// Container-based version of the <algorithm> `std::sample()` function to
+// randomly sample elements from the container without replacement using a
+// `gen()` uniform random number generator and write them to an iterator range.
+template <typename C, typename OutputIterator, typename Distance,
+ typename UniformRandomBitGenerator>
+OutputIterator c_sample(const C& c, OutputIterator result, Distance n,
+ UniformRandomBitGenerator&& gen) {
+#if defined(__cpp_lib_sample) && __cpp_lib_sample >= 201603L
+ return std::sample(container_algorithm_internal::c_begin(c),
+ container_algorithm_internal::c_end(c), result, n,
+ std::forward<UniformRandomBitGenerator>(gen));
+#else
+ // Fall back to a stable selection-sampling implementation.
+ auto first = container_algorithm_internal::c_begin(c);
+ Distance unsampled_elements = c_distance(c);
+ n = (std::min)(n, unsampled_elements);
+ for (; n != 0; ++first) {
+ Distance r =
+ std::uniform_int_distribution<Distance>(0, --unsampled_elements)(gen);
+ if (r < n) {
+ *result++ = *first;
+ --n;
+ }
+ }
+ return result;
+#endif
+}
+
//------------------------------------------------------------------------------
// <algorithm> Partition functions
//------------------------------------------------------------------------------