diff options
Diffstat (limited to 'absl/random/internal')
-rw-r--r-- | absl/random/internal/BUILD.bazel | 24 | ||||
-rw-r--r-- | absl/random/internal/mock_helpers.h | 40 | ||||
-rw-r--r-- | absl/random/internal/mock_overload_set.h | 82 | ||||
-rw-r--r-- | absl/random/internal/mock_validators.h | 98 |
4 files changed, 204 insertions, 40 deletions
diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel index 71a742ee..5e05130d 100644 --- a/absl/random/internal/BUILD.bazel +++ b/absl/random/internal/BUILD.bazel @@ -137,7 +137,7 @@ cc_library( cc_library( name = "explicit_seed_seq", - testonly = 1, + testonly = True, hdrs = [ "explicit_seed_seq.h", ], @@ -151,7 +151,7 @@ cc_library( cc_library( name = "sequence_urbg", - testonly = 1, + testonly = True, hdrs = [ "sequence_urbg.h", ], @@ -375,7 +375,7 @@ cc_binary( cc_library( name = "distribution_test_util", - testonly = 1, + testonly = True, srcs = [ "chi_square.cc", "distribution_test_util.cc", @@ -527,6 +527,7 @@ cc_library( hdrs = ["mock_helpers.h"], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ + "//absl/base:config", "//absl/base:fast_type_id", "//absl/types:optional", ], @@ -534,11 +535,12 @@ cc_library( cc_library( name = "mock_overload_set", - testonly = 1, + testonly = True, hdrs = ["mock_overload_set.h"], linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ ":mock_helpers", + "//absl/base:config", "//absl/random:mocking_bit_gen", "@com_google_googletest//:gtest", ], @@ -712,7 +714,19 @@ cc_library( ":traits", "//absl/base:config", "//absl/meta:type_traits", - "//absl/numeric:int128", + ], +) + +cc_library( + name = "mock_validators", + hdrs = ["mock_validators.h"], + deps = [ + ":iostream_state_saver", + ":uniform_helper", + "//absl/base:config", + "//absl/base:raw_logging_internal", + "//absl/strings", + "//absl/strings:string_view", ], ) diff --git a/absl/random/internal/mock_helpers.h b/absl/random/internal/mock_helpers.h index a7a97bfc..19d05612 100644 --- a/absl/random/internal/mock_helpers.h +++ b/absl/random/internal/mock_helpers.h @@ -16,10 +16,9 @@ #ifndef ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_ #define ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_ -#include <tuple> -#include <type_traits> #include <utility> +#include "absl/base/config.h" #include "absl/base/internal/fast_type_id.h" #include "absl/types/optional.h" @@ -27,6 +26,16 @@ namespace absl { ABSL_NAMESPACE_BEGIN namespace random_internal { +// A no-op validator meeting the ValidatorT requirements for MockHelpers. +// +// Custom validators should follow a similar structure, passing the type to +// MockHelpers::MockFor<KeyT>(m, CustomValidatorT()). +struct NoOpValidator { + // Default validation: do nothing. + template <typename ResultT, typename... Args> + static void Validate(ResultT, Args&&...) {} +}; + // MockHelpers works in conjunction with MockOverloadSet, MockingBitGen, and // BitGenRef to enable the mocking capability for absl distribution functions. // @@ -109,22 +118,39 @@ class MockHelpers { 0, urbg, std::forward<Args>(args)...); } - // Acquire a mock for the KeyT (may or may not be a signature). + // Acquire a mock for the KeyT (may or may not be a signature), set up to use + // the ValidatorT to verify that the result is in the range of the RNG + // function. // // KeyT is used to generate a typeid-based lookup for the mock. // KeyT is a signature of the form: // result_type(discriminator_type, std::tuple<args...>) // The mocked function signature will be composed from KeyT as: // result_type(args...) - template <typename KeyT, typename MockURBG> - static auto MockFor(MockURBG& m) + // ValidatorT::Validate will be called after the result of the RNG. The + // signature is expected to be of the form: + // ValidatorT::Validate(result, args...) + template <typename KeyT, typename ValidatorT, typename MockURBG> + static auto MockFor(MockURBG& m, ValidatorT) -> decltype(m.template RegisterMock< typename KeySignature<KeyT>::result_type, typename KeySignature<KeyT>::arg_tuple_type>( - m, std::declval<IdType>())) { + m, std::declval<IdType>(), ValidatorT())) { return m.template RegisterMock<typename KeySignature<KeyT>::result_type, typename KeySignature<KeyT>::arg_tuple_type>( - m, ::absl::base_internal::FastTypeId<KeyT>()); + m, ::absl::base_internal::FastTypeId<KeyT>(), ValidatorT()); + } + + // Acquire a mock for the KeyT (may or may not be a signature). + // + // KeyT is used to generate a typeid-based lookup for the mock. + // KeyT is a signature of the form: + // result_type(discriminator_type, std::tuple<args...>) + // The mocked function signature will be composed from KeyT as: + // result_type(args...) + template <typename KeyT, typename MockURBG> + static decltype(auto) MockFor(MockURBG& m) { + return MockFor<KeyT>(m, NoOpValidator()); } }; diff --git a/absl/random/internal/mock_overload_set.h b/absl/random/internal/mock_overload_set.h index 0d9c6c12..cfaeeeef 100644 --- a/absl/random/internal/mock_overload_set.h +++ b/absl/random/internal/mock_overload_set.h @@ -16,9 +16,11 @@ #ifndef ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_ #define ABSL_RANDOM_INTERNAL_MOCK_OVERLOAD_SET_H_ +#include <tuple> #include <type_traits> #include "gmock/gmock.h" +#include "absl/base/config.h" #include "absl/random/internal/mock_helpers.h" #include "absl/random/mocking_bit_gen.h" @@ -26,7 +28,7 @@ namespace absl { ABSL_NAMESPACE_BEGIN namespace random_internal { -template <typename DistrT, typename Fn> +template <typename DistrT, typename ValidatorT, typename Fn> struct MockSingleOverload; // MockSingleOverload @@ -38,8 +40,8 @@ struct MockSingleOverload; // arguments to MockingBitGen::Register. // // The underlying KeyT must match the KeyT constructed by DistributionCaller. -template <typename DistrT, typename Ret, typename... Args> -struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> { +template <typename DistrT, typename ValidatorT, typename Ret, typename... Args> +struct MockSingleOverload<DistrT, ValidatorT, Ret(MockingBitGen&, Args...)> { static_assert(std::is_same<typename DistrT::result_type, Ret>::value, "Overload signature must have return type matching the " "distribution result_type."); @@ -47,15 +49,21 @@ struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> { template <typename MockURBG> auto gmock_Call(MockURBG& gen, const ::testing::Matcher<Args>&... matchers) - -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...)) { - static_assert(std::is_base_of<MockingBitGen, MockURBG>::value, - "Mocking requires an absl::MockingBitGen"); - return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...); + -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matchers...)) { + static_assert( + std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value || + std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value, + "Mocking requires an absl::MockingBitGen"); + return MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matchers...); } }; -template <typename DistrT, typename Ret, typename Arg, typename... Args> -struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { +template <typename DistrT, typename ValidatorT, typename Ret, typename Arg, + typename... Args> +struct MockSingleOverload<DistrT, ValidatorT, + Ret(Arg, MockingBitGen&, Args...)> { static_assert(std::is_same<typename DistrT::result_type, Ret>::value, "Overload signature must have return type matching the " "distribution result_type."); @@ -64,14 +72,44 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { template <typename MockURBG> auto gmock_Call(const ::testing::Matcher<Arg>& matcher, MockURBG& gen, const ::testing::Matcher<Args>&... matchers) - -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, - matchers...)) { - static_assert(std::is_base_of<MockingBitGen, MockURBG>::value, - "Mocking requires an absl::MockingBitGen"); - return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, matchers...); + -> decltype(MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matcher, matchers...)) { + static_assert( + std::is_base_of<MockingBitGenImpl<true>, MockURBG>::value || + std::is_base_of<MockingBitGenImpl<false>, MockURBG>::value, + "Mocking requires an absl::MockingBitGen"); + return MockHelpers::MockFor<KeyT>(gen, ValidatorT()) + .gmock_Call(matcher, matchers...); } }; +// MockOverloadSetWithValidator +// +// MockOverloadSetWithValidator is a wrapper around MockOverloadSet which takes +// an additional Validator parameter, allowing for customization of the mock +// behavior. +// +// `ValidatorT::Validate(result, args...)` will be called after the mock +// distribution returns a value in `result`, allowing for validation against the +// args. +template <typename DistrT, typename ValidatorT, typename... Fns> +struct MockOverloadSetWithValidator; + +template <typename DistrT, typename ValidatorT, typename Sig> +struct MockOverloadSetWithValidator<DistrT, ValidatorT, Sig> + : public MockSingleOverload<DistrT, ValidatorT, Sig> { + using MockSingleOverload<DistrT, ValidatorT, Sig>::gmock_Call; +}; + +template <typename DistrT, typename ValidatorT, typename FirstSig, + typename... Rest> +struct MockOverloadSetWithValidator<DistrT, ValidatorT, FirstSig, Rest...> + : public MockSingleOverload<DistrT, ValidatorT, FirstSig>, + public MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...> { + using MockSingleOverload<DistrT, ValidatorT, FirstSig>::gmock_Call; + using MockOverloadSetWithValidator<DistrT, ValidatorT, Rest...>::gmock_Call; +}; + // MockOverloadSet // // MockOverloadSet takes a distribution and a collection of signatures and @@ -79,20 +117,8 @@ struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> { // `EXPECT_CALL(mock_overload_set, Call(...))` expand and do overload resolution // correctly. template <typename DistrT, typename... Signatures> -struct MockOverloadSet; - -template <typename DistrT, typename Sig> -struct MockOverloadSet<DistrT, Sig> : public MockSingleOverload<DistrT, Sig> { - using MockSingleOverload<DistrT, Sig>::gmock_Call; -}; - -template <typename DistrT, typename FirstSig, typename... Rest> -struct MockOverloadSet<DistrT, FirstSig, Rest...> - : public MockSingleOverload<DistrT, FirstSig>, - public MockOverloadSet<DistrT, Rest...> { - using MockSingleOverload<DistrT, FirstSig>::gmock_Call; - using MockOverloadSet<DistrT, Rest...>::gmock_Call; -}; +using MockOverloadSet = + MockOverloadSetWithValidator<DistrT, NoOpValidator, Signatures...>; } // namespace random_internal ABSL_NAMESPACE_END diff --git a/absl/random/internal/mock_validators.h b/absl/random/internal/mock_validators.h new file mode 100644 index 00000000..d76d169c --- /dev/null +++ b/absl/random/internal/mock_validators.h @@ -0,0 +1,98 @@ +// Copyright 2024 The Abseil Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ +#define ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ + +#include <type_traits> + +#include "absl/base/config.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/internal/uniform_helper.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace absl { +ABSL_NAMESPACE_BEGIN +namespace random_internal { + +template <typename NumType> +class UniformDistributionValidator { + public: + // Handle absl::Uniform<NumType>(gen, absl::IntervalTag, lo, hi). + template <typename TagType> + static void Validate(NumType x, TagType tag, NumType lo, NumType hi) { + // For invalid ranges, absl::Uniform() simply returns one of the bounds. + if (x == lo && lo == hi) return; + + ValidateImpl(std::is_floating_point<NumType>{}, x, tag, lo, hi); + } + + // Handle absl::Uniform<NumType>(gen, lo, hi). + static void Validate(NumType x, NumType lo, NumType hi) { + Validate(x, IntervalClosedOpenTag(), lo, hi); + } + + // Handle absl::Uniform<NumType>(gen). + static void Validate(NumType) { + // absl::Uniform<NumType>(gen) spans the entire range of `NumType`, so any + // value is okay. This overload exists because the validation logic attempts + // to call it anyway rather than adding extra SFINAE. + } + + private: + static absl::string_view TagLbBound(IntervalClosedOpenTag) { return "["; } + static absl::string_view TagLbBound(IntervalOpenOpenTag) { return "("; } + static absl::string_view TagLbBound(IntervalClosedClosedTag) { return "["; } + static absl::string_view TagLbBound(IntervalOpenClosedTag) { return "("; } + static absl::string_view TagUbBound(IntervalClosedOpenTag) { return ")"; } + static absl::string_view TagUbBound(IntervalOpenOpenTag) { return ")"; } + static absl::string_view TagUbBound(IntervalClosedClosedTag) { return "]"; } + static absl::string_view TagUbBound(IntervalOpenClosedTag) { return "]"; } + + template <typename TagType> + static void ValidateImpl(std::true_type /* is_floating_point */, NumType x, + TagType tag, NumType lo, NumType hi) { + UniformDistributionWrapper<NumType> dist(tag, lo, hi); + NumType lb = dist.a(); + NumType ub = dist.b(); + // uniform_real_distribution is always closed-open, so the upper bound is + // always non-inclusive. + ABSL_INTERNAL_CHECK(lb <= x && x < ub, + absl::StrCat(x, " is not in ", TagLbBound(tag), lo, + ", ", hi, TagUbBound(tag))); + } + + template <typename TagType> + static void ValidateImpl(std::false_type /* is_floating_point */, NumType x, + TagType tag, NumType lo, NumType hi) { + using stream_type = + typename random_internal::stream_format_type<NumType>::type; + + UniformDistributionWrapper<NumType> dist(tag, lo, hi); + NumType lb = dist.a(); + NumType ub = dist.b(); + ABSL_INTERNAL_CHECK( + lb <= x && x <= ub, + absl::StrCat(stream_type{x}, " is not in ", TagLbBound(tag), + stream_type{lo}, ", ", stream_type{hi}, TagUbBound(tag))); + } +}; + +} // namespace random_internal +ABSL_NAMESPACE_END +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ |