aboutsummaryrefslogtreecommitdiff
path: root/absl/random/internal
diff options
context:
space:
mode:
Diffstat (limited to 'absl/random/internal')
-rw-r--r--absl/random/internal/BUILD.bazel24
-rw-r--r--absl/random/internal/mock_helpers.h40
-rw-r--r--absl/random/internal/mock_overload_set.h82
-rw-r--r--absl/random/internal/mock_validators.h98
4 files changed, 204 insertions, 40 deletions
diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel
index 71a742ee..5e05130d 100644
--- a/absl/random/internal/BUILD.bazel
+++ b/absl/random/internal/BUILD.bazel
@@ -137,7 +137,7 @@ cc_library(
cc_library(
name = "explicit_seed_seq",
- testonly = 1,
+ testonly = True,
hdrs = [
"explicit_seed_seq.h",
],
@@ -151,7 +151,7 @@ cc_library(
cc_library(
name = "sequence_urbg",
- testonly = 1,
+ testonly = True,
hdrs = [
"sequence_urbg.h",
],
@@ -375,7 +375,7 @@ cc_binary(
cc_library(
name = "distribution_test_util",
- testonly = 1,
+ testonly = True,
srcs = [
"chi_square.cc",
"distribution_test_util.cc",
@@ -527,6 +527,7 @@ cc_library(
hdrs = ["mock_helpers.h"],
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
+ "//absl/base:config",
"//absl/base:fast_type_id",
"//absl/types:optional",
],
@@ -534,11 +535,12 @@ cc_library(
cc_library(
name = "mock_overload_set",
- testonly = 1,
+ testonly = True,
hdrs = ["mock_overload_set.h"],
linkopts = ABSL_DEFAULT_LINKOPTS,
deps = [
":mock_helpers",
+ "//absl/base:config",
"//absl/random:mocking_bit_gen",
"@com_google_googletest//:gtest",
],
@@ -712,7 +714,19 @@ cc_library(
":traits",
"//absl/base:config",
"//absl/meta:type_traits",
- "//absl/numeric:int128",
+ ],
+)
+
+cc_library(
+ name = "mock_validators",
+ hdrs = ["mock_validators.h"],
+ deps = [
+ ":iostream_state_saver",
+ ":uniform_helper",
+ "//absl/base:config",
+ "//absl/base:raw_logging_internal",
+ "//absl/strings",
+ "//absl/strings:string_view",
],
)
diff --git a/absl/random/internal/mock_helpers.h b/absl/random/internal/mock_helpers.h
index a7a97bfc..19d05612 100644
--- a/absl/random/internal/mock_helpers.h
+++ b/absl/random/internal/mock_helpers.h
@@ -16,10 +16,9 @@
#ifndef ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
#define ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
-#include <tuple>
-#include <type_traits>
#include <utility>
+#include "absl/base/config.h"
#include "absl/base/internal/fast_type_id.h"
#include "absl/types/optional.h"
@@ -27,6 +26,16 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
+// A no-op validator meeting the ValidatorT requirements for MockHelpers.
+//
+// Custom validators should follow a similar structure, passing the type to
+// MockHelpers::MockFor<KeyT>(m, CustomValidatorT()).
+struct NoOpValidator {
+ // Default validation: do nothing.
+ template <typename ResultT, typename... Args>
+ static void Validate(ResultT, Args&&...) {}
+};
+
// MockHelpers works in conjunction with MockOverloadSet, MockingBitGen, and
// BitGenRef to enable the mocking capability for absl distribution functions.
//
@@ -109,22 +118,39 @@ class MockHelpers {
0, urbg, std::forward<Args>(args)...);
}
- // Acquire a mock for the KeyT (may or may not be a signature).
+ // Acquire a mock for the KeyT (may or may not be a signature), set up to use
+ // the ValidatorT to verify that the result is in the range of the RNG
+ // function.
//
// KeyT is used to generate a typeid-based lookup for the mock.
// KeyT is a signature of the form:
// result_type(discriminator_type, std::tuple<args...>)
// The mocked function signature will be composed from KeyT as:
// result_type(args...)
- template <typename KeyT, typename MockURBG>
- static auto MockFor(MockURBG& m)
+ // ValidatorT::Validate will be called after the result of the RNG. The
+ // signature is expected to be of the form:
+ // ValidatorT::Validate(result, args...)
+ template <typename KeyT, typename ValidatorT, typename MockURBG>
+ static auto MockFor(MockURBG& m, ValidatorT)
-> decltype(m.template RegisterMock<
typename KeySignature<KeyT>::result_type,
typename KeySignature<KeyT>::arg_tuple_type>(
- m, std::declval<IdType>())) {
+ m, std::declval<IdType>(), ValidatorT())) {
return m.template RegisterMock<typename KeySignature<KeyT>::result_type,
typename KeySignature<KeyT>::arg_tuple_type>(
- m, ::absl::base_internal::FastTypeId<KeyT>());
+ m, ::absl::base_internal::FastTypeId<KeyT>(), ValidatorT());
+ }
+
+ // Acquire a mock for the KeyT (may or may not be a signature).
+ //
+ // KeyT is used to generate a typeid-based lookup for the mock.
+ // KeyT is a signature of the form:
+ // result_type(discriminator_type, std::tuple<args...>)
+ // The mocked function signature will be composed from KeyT as:
+ // result_type(args...)
+ template <typename KeyT, typename MockURBG>
+ static decltype(auto) MockFor(MockURBG& m) {
+ return MockFor<KeyT>(m, NoOpValidator());
}
};
diff --git a/absl/random/internal/mock_overload_set.h b/absl/random/internal/mock_overload_set.h
index 0d9c6c12..cfaeeeef 100644
--- a/absl/random/internal/mock_overload_set.h
+++ b/absl/random/internal/mock_overload_set.h
@@ -16,9 +16,11 @@
#ifndef ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_
#define ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_
+#include <tuple>
#include <type_traits>
#include "gmock/gmock.h"
+#include "absl/base/config.h"
#include "absl/random/internal/mock_helpers.h"
#include "absl/random/mocking_bit_gen.h"
@@ -26,7 +28,7 @@ namespace absl {
ABSL_NAMESPACE_BEGIN
namespace random_internal {
-template <typename DistrT, typename Fn>
+template <typename DistrT, typename ValidatorT, typename Fn>
struct MockSingleOverload;
// MockSingleOverload
@@ -38,8 +40,8 @@ struct MockSingleOverload;
// arguments to MockingBitGen::Register.
//
// The underlying KeyT must match the KeyT constructed by DistributionCaller.
-template <typename DistrT, typename Ret, typename... Args>
-struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> {
+template <typename DistrT, typename ValidatorT, typename Ret, typename... Args>
+struct MockSingleOverload<DistrT, ValidatorT, Ret(MockingBitGen&, Args...)> {
static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
"Overload signature must have return type matching the "
"distribution result_type.");
@@ -47,15 +49,21 @@ struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> {
template <typename MockURBG>
auto gmock_Call(MockURBG& gen, const ::testing::Matcher<Args>&... matchers)
- -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...)) {
- static_assert(std::is_base_of<MockingBitGen, MockURBG>::value,
- "Mocking requires an absl::MockingBitGen");
- return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...);
+ -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matchers...)) {
+ static_assert(
+ std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value ||
+ std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value,
+ "Mocking requires an absl::MockingBitGen");
+ return MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matchers...);
}
};
-template <typename DistrT, typename Ret, typename Arg, typename... Args>
-struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
+template <typename DistrT, typename ValidatorT, typename Ret, typename Arg,
+ typename... Args>
+struct MockSingleOverload<DistrT, ValidatorT,
+ Ret(Arg, MockingBitGen&, Args...)> {
static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
"Overload signature must have return type matching the "
"distribution result_type.");
@@ -64,14 +72,44 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
template <typename MockURBG>
auto gmock_Call(const ::testing::Matcher<Arg>& matcher, MockURBG& gen,
const ::testing::Matcher<Args>&... matchers)
- -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher,
- matchers...)) {
- static_assert(std::is_base_of<MockingBitGen, MockURBG>::value,
- "Mocking requires an absl::MockingBitGen");
- return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, matchers...);
+ -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matcher, matchers...)) {
+ static_assert(
+ std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value ||
+ std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value,
+ "Mocking requires an absl::MockingBitGen");
+ return MockHelpers::MockFor<KeyT>(gen, ValidatorT())
+ .gmock_Call(matcher, matchers...);
}
};
+// MockOverloadSetWithValidator
+//
+// MockOverloadSetWithValidator is a wrapper around MockOverloadSet which takes
+// an additional Validator parameter, allowing for customization of the mock
+// behavior.
+//
+// `ValidatorT::Validate(result, args...)` will be called after the mock
+// distribution returns a value in `result`, allowing for validation against the
+// args.
+template <typename DistrT, typename ValidatorT, typename... Fns>
+struct MockOverloadSetWithValidator;
+
+template <typename DistrT, typename ValidatorT, typename Sig>
+struct MockOverloadSetWithValidator<DistrT, ValidatorT, Sig>
+ : public MockSingleOverload<DistrT, ValidatorT, Sig> {
+ using MockSingleOverload<DistrT, ValidatorT, Sig>::gmock_Call;
+};
+
+template <typename DistrT, typename ValidatorT, typename FirstSig,
+ typename... Rest>
+struct MockOverloadSetWithValidator<DistrT, ValidatorT, FirstSig, Rest...>
+ : public MockSingleOverload<DistrT, ValidatorT, FirstSig>,
+ public MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...> {
+ using MockSingleOverload<DistrT, ValidatorT, FirstSig>::gmock_Call;
+ using MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...>::gmock_Call;
+};
+
// MockOverloadSet
//
// MockOverloadSet takes a distribution and a collection of signatures and
@@ -79,20 +117,8 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
// `EXPECT_CALL(mock_overload_set, Call(...))` expand and do overload resolution
// correctly.
template <typename DistrT, typename... Signatures>
-struct MockOverloadSet;
-
-template <typename DistrT, typename Sig>
-struct MockOverloadSet<DistrT, Sig> : public MockSingleOverload<DistrT, Sig> {
- using MockSingleOverload<DistrT, Sig>::gmock_Call;
-};
-
-template <typename DistrT, typename FirstSig, typename... Rest>
-struct MockOverloadSet<DistrT, FirstSig, Rest...>
- : public MockSingleOverload<DistrT, FirstSig>,
- public MockOverloadSet<DistrT, Rest...> {
- using MockSingleOverload<DistrT, FirstSig>::gmock_Call;
- using MockOverloadSet<DistrT, Rest...>::gmock_Call;
-};
+using MockOverloadSet =
+ MockOverloadSetWithValidator<DistrT, NoOpValidator, Signatures...>;
} // namespace random_internal
ABSL_NAMESPACE_END
diff --git a/absl/random/internal/mock_validators.h b/absl/random/internal/mock_validators.h
new file mode 100644
index 00000000..d76d169c
--- /dev/null
+++ b/absl/random/internal/mock_validators.h
@@ -0,0 +1,98 @@
+// Copyright 2024 The Abseil Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
+#define ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_
+
+#include <type_traits>
+
+#include "absl/base/config.h"
+#include "absl/base/internal/raw_logging.h"
+#include "absl/random/internal/iostream_state_saver.h"
+#include "absl/random/internal/uniform_helper.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+
+namespace absl {
+ABSL_NAMESPACE_BEGIN
+namespace random_internal {
+
+template <typename NumType>
+class UniformDistributionValidator {
+ public:
+ // Handle absl::Uniform<NumType>(gen, absl::IntervalTag, lo, hi).
+ template <typename TagType>
+ static void Validate(NumType x, TagType tag, NumType lo, NumType hi) {
+ // For invalid ranges, absl::Uniform() simply returns one of the bounds.
+ if (x == lo && lo == hi) return;
+
+ ValidateImpl(std::is_floating_point<NumType>{}, x, tag, lo, hi);
+ }
+
+ // Handle absl::Uniform<NumType>(gen, lo, hi).
+ static void Validate(NumType x, NumType lo, NumType hi) {
+ Validate(x, IntervalClosedOpenTag(), lo, hi);
+ }
+
+ // Handle absl::Uniform<NumType>(gen).
+ static void Validate(NumType) {
+ // absl::Uniform<NumType>(gen) spans the entire range of `NumType`, so any
+ // value is okay. This overload exists because the validation logic attempts
+ // to call it anyway rather than adding extra SFINAE.
+ }
+
+ private:
+ static absl::string_view TagLbBound(IntervalClosedOpenTag) { return "["; }
+ static absl::string_view TagLbBound(IntervalOpenOpenTag) { return "("; }
+ static absl::string_view TagLbBound(IntervalClosedClosedTag) { return "["; }
+ static absl::string_view TagLbBound(IntervalOpenClosedTag) { return "("; }
+ static absl::string_view TagUbBound(IntervalClosedOpenTag) { return ")"; }
+ static absl::string_view TagUbBound(IntervalOpenOpenTag) { return ")"; }
+ static absl::string_view TagUbBound(IntervalClosedClosedTag) { return "]"; }
+ static absl::string_view TagUbBound(IntervalOpenClosedTag) { return "]"; }
+
+ template <typename TagType>
+ static void ValidateImpl(std::true_type /* is_floating_point */, NumType x,
+ TagType tag, NumType lo, NumType hi) {
+ UniformDistributionWrapper<NumType> dist(tag, lo, hi);
+ NumType lb = dist.a();
+ NumType ub = dist.b();
+ // uniform_real_distribution is always closed-open, so the upper bound is
+ // always non-inclusive.
+ ABSL_INTERNAL_CHECK(lb <= x && x < ub,
+ absl::StrCat(x, " is not in ", TagLbBound(tag), lo,
+ ", ", hi, TagUbBound(tag)));
+ }
+
+ template <typename TagType>
+ static void ValidateImpl(std::false_type /* is_floating_point */, NumType x,
+ TagType tag, NumType lo, NumType hi) {
+ using stream_type =
+ typename random_internal::stream_format_type<NumType>::type;
+
+ UniformDistributionWrapper<NumType> dist(tag, lo, hi);
+ NumType lb = dist.a();
+ NumType ub = dist.b();
+ ABSL_INTERNAL_CHECK(
+ lb <= x && x <= ub,
+ absl::StrCat(stream_type{x}, " is not in ", TagLbBound(tag),
+ stream_type{lo}, ", ", stream_type{hi}, TagUbBound(tag)));
+ }
+};
+
+} // namespace random_internal
+ABSL_NAMESPACE_END
+} // namespace absl
+
+#endif // ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_