From 258e5a15759cc3d122d4a4826bc499af91d40aa9 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Thu, 21 Dec 2023 08:11:01 -0800 Subject: Add a container-based version of `std::sample()` PiperOrigin-RevId: 592864147 Change-Id: I83179b0225aa446ae0b57b46b604af14f1fa14df --- absl/algorithm/container.h | 30 ++++++++++++++++++++++++++++++ absl/algorithm/container_test.cc | 22 +++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) (limited to 'absl/algorithm') diff --git a/absl/algorithm/container.h b/absl/algorithm/container.h index 934dd179..c7bafae1 100644 --- a/absl/algorithm/container.h +++ b/absl/algorithm/container.h @@ -774,6 +774,36 @@ void c_shuffle(RandomAccessContainer& c, UniformRandomBitGenerator&& gen) { std::forward(gen)); } +// c_sample() +// +// Container-based version of the `std::sample()` function to +// randomly sample elements from the container without replacement using a +// `gen()` uniform random number generator and write them to an iterator range. +template +OutputIterator c_sample(const C& c, OutputIterator result, Distance n, + UniformRandomBitGenerator&& gen) { +#if defined(__cpp_lib_sample) && __cpp_lib_sample >= 201603L + return std::sample(container_algorithm_internal::c_begin(c), + container_algorithm_internal::c_end(c), result, n, + std::forward(gen)); +#else + // Fall back to a stable selection-sampling implementation. + auto first = container_algorithm_internal::c_begin(c); + Distance unsampled_elements = c_distance(c); + n = (std::min)(n, unsampled_elements); + for (; n != 0; ++first) { + Distance r = + std::uniform_int_distribution(0, --unsampled_elements)(gen); + if (r < n) { + *result++ = *first; + --n; + } + } + return result; +#endif +} + //------------------------------------------------------------------------------ // Partition functions //------------------------------------------------------------------------------ diff --git a/absl/algorithm/container_test.cc b/absl/algorithm/container_test.cc index 0fbc7773..c01f5fc0 100644 --- a/absl/algorithm/container_test.cc +++ b/absl/algorithm/container_test.cc @@ -14,6 +14,7 @@ #include "absl/algorithm/container.h" +#include #include #include #include @@ -40,8 +41,10 @@ using ::testing::Each; using ::testing::ElementsAre; using ::testing::Gt; using ::testing::IsNull; +using ::testing::IsSubsetOf; using ::testing::Lt; using ::testing::Pointee; +using ::testing::SizeIs; using ::testing::Truly; using ::testing::UnorderedElementsAre; @@ -963,12 +966,29 @@ TEST(MutatingTest, RotateCopy) { EXPECT_THAT(actual, ElementsAre(3, 4, 1, 2, 5)); } +template +T RandomlySeededPrng() { + std::random_device rdev; + std::seed_seq::result_type data[T::state_size]; + std::generate_n(data, T::state_size, std::ref(rdev)); + std::seed_seq prng_seed(data, data + T::state_size); + return T(prng_seed); +} + TEST(MutatingTest, Shuffle) { std::vector actual = {1, 2, 3, 4, 5}; - absl::c_shuffle(actual, std::random_device()); + absl::c_shuffle(actual, RandomlySeededPrng()); EXPECT_THAT(actual, UnorderedElementsAre(1, 2, 3, 4, 5)); } +TEST(MutatingTest, Sample) { + std::vector actual; + absl::c_sample(std::vector{1, 2, 3, 4, 5}, std::back_inserter(actual), 3, + RandomlySeededPrng()); + EXPECT_THAT(actual, IsSubsetOf({1, 2, 3, 4, 5})); + EXPECT_THAT(actual, SizeIs(3)); +} + TEST(MutatingTest, PartialSort) { std::vector sequence{5, 3, 42, 0}; absl::c_partial_sort(sequence, sequence.begin() + 2); -- cgit v1.2.3