diff options
author | Eric Astor <epastor@google.com> | 2023-12-21 08:11:01 -0800 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2023-12-21 08:12:11 -0800 |
commit | 258e5a15759cc3d122d4a4826bc499af91d40aa9 (patch) | |
tree | 0696c01c1d40217b8c339a3e81418dace1e10640 /absl/algorithm/container.h | |
parent | 794352a92f09425714b9116974b29e58ce8f9ba9 (diff) | |
download | abseil-258e5a15759cc3d122d4a4826bc499af91d40aa9.tar.gz abseil-258e5a15759cc3d122d4a4826bc499af91d40aa9.tar.bz2 abseil-258e5a15759cc3d122d4a4826bc499af91d40aa9.zip |
Add a container-based version of `std::sample()`
PiperOrigin-RevId: 592864147
Change-Id: I83179b0225aa446ae0b57b46b604af14f1fa14df
Diffstat (limited to 'absl/algorithm/container.h')
-rw-r--r-- | absl/algorithm/container.h | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/absl/algorithm/container.h b/absl/algorithm/container.h index 934dd179..c7bafae1 100644 --- a/absl/algorithm/container.h +++ b/absl/algorithm/container.h @@ -774,6 +774,36 @@ void c_shuffle(RandomAccessContainer& c, UniformRandomBitGenerator&& gen) { std::forward<UniformRandomBitGenerator>(gen)); } +// c_sample() +// +// Container-based version of the <algorithm> `std::sample()` function to +// randomly sample elements from the container without replacement using a +// `gen()` uniform random number generator and write them to an iterator range. +template <typename C, typename OutputIterator, typename Distance, + typename UniformRandomBitGenerator> +OutputIterator c_sample(const C& c, OutputIterator result, Distance n, + UniformRandomBitGenerator&& gen) { +#if defined(__cpp_lib_sample) && __cpp_lib_sample >= 201603L + return std::sample(container_algorithm_internal::c_begin(c), + container_algorithm_internal::c_end(c), result, n, + std::forward<UniformRandomBitGenerator>(gen)); +#else + // Fall back to a stable selection-sampling implementation. + auto first = container_algorithm_internal::c_begin(c); + Distance unsampled_elements = c_distance(c); + n = (std::min)(n, unsampled_elements); + for (; n != 0; ++first) { + Distance r = + std::uniform_int_distribution<Distance>(0, --unsampled_elements)(gen); + if (r < n) { + *result++ = *first; + --n; + } + } + return result; +#endif +} + //------------------------------------------------------------------------------ // <algorithm> Partition functions //------------------------------------------------------------------------------ |