From ddb7d26cfd96c1f626def4755e0e1b5aaac94d3e Mon Sep 17 00:00:00 2001 From: Pieter Wuille Date: Sun, 10 Mar 2024 12:38:14 -0400 Subject: [PATCH] random: add RandomMixin::randbits with compile-known bits In many cases, it is known at compile time how many bits are requested from randbits. Provide a variant of randbits that accepts this number as a template, to make sure the compiler can make use of this knowledge. This is used immediately in rand32() and randbool(), and a few further call sites. --- src/addrman.cpp | 2 +- src/random.cpp | 2 +- src/random.h | 28 ++++++++++++++++++++++++++-- src/test/crypto_tests.cpp | 6 +++--- src/test/random_tests.cpp | 22 ++++++++++++++++++++-- test/sanitizer_suppressions/ubsan | 1 + 6 files changed, 52 insertions(+), 9 deletions(-) diff --git a/src/addrman.cpp b/src/addrman.cpp index d0b820ee651..054a9bee32d 100644 --- a/src/addrman.cpp +++ b/src/addrman.cpp @@ -776,7 +776,7 @@ std::pair AddrManImpl::Select_(bool new_only, std::option const AddrInfo& info{it_found->second}; // With probability GetChance() * chance_factor, return the entry. - if (insecure_rand.randbits(30) < chance_factor * info.GetChance() * (1 << 30)) { + if (insecure_rand.randbits<30>() < chance_factor * info.GetChance() * (1 << 30)) { LogPrint(BCLog::ADDRMAN, "Selected %s from %s\n", info.ToStringAddrPort(), search_tried ? "tried" : "new"); return {info, info.m_last_try}; } diff --git a/src/random.cpp b/src/random.cpp index bb19d70d922..10ad4e2558a 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -741,6 +741,6 @@ void RandomInit() std::chrono::microseconds GetExponentialRand(std::chrono::microseconds now, std::chrono::seconds average_interval) { - double unscaled = -std::log1p(GetRand(uint64_t{1} << 48) * -0.0000000000000035527136788 /* -1/2^48 */); + double unscaled = -std::log1p(FastRandomContext().randbits<48>() * -0.0000000000000035527136788 /* -1/2^48 */); return now + std::chrono::duration_cast(unscaled * average_interval + 0.5us); } diff --git a/src/random.h b/src/random.h index efcdae5142b..007853bc364 100644 --- a/src/random.h +++ b/src/random.h @@ -223,6 +223,30 @@ public: return ret & ((uint64_t{1} << bits) - 1); } + /** Same as above, but with compile-time fixed bits count. */ + template + uint64_t randbits() noexcept + { + static_assert(Bits >= 0 && Bits <= 64); + if constexpr (Bits == 64) { + return Impl().rand64(); + } else { + uint64_t ret; + if (Bits <= bitbuf_size) { + ret = bitbuf; + bitbuf >>= Bits; + bitbuf_size -= Bits; + } else { + uint64_t gen = Impl().rand64(); + ret = (gen << bitbuf_size) | bitbuf; + bitbuf = gen >> (Bits - bitbuf_size); + bitbuf_size = 64 + bitbuf_size - Bits; + } + constexpr uint64_t MASK = (uint64_t{1} << Bits) - 1; + return ret & MASK; + } + } + /** Generate a random integer in the range [0..range). * Precondition: range > 0. */ @@ -247,7 +271,7 @@ public: } /** Generate a random 32-bit integer. */ - uint32_t rand32() noexcept { return Impl().randbits(32); } + uint32_t rand32() noexcept { return Impl().template randbits<32>(); } /** generate a random uint256. */ uint256 rand256() noexcept @@ -258,7 +282,7 @@ public: } /** Generate a random boolean. */ - bool randbool() noexcept { return Impl().randbits(1); } + bool randbool() noexcept { return Impl().template randbits<1>(); } /** Return the time point advanced by a uniform random duration. */ template diff --git a/src/test/crypto_tests.cpp b/src/test/crypto_tests.cpp index 46acc6fc9f5..d78957e35a4 100644 --- a/src/test/crypto_tests.cpp +++ b/src/test/crypto_tests.cpp @@ -1195,7 +1195,7 @@ BOOST_AUTO_TEST_CASE(muhash_tests) uint256 res; int table[4]; for (int i = 0; i < 4; ++i) { - table[i] = g_insecure_rand_ctx.randbits(3); + table[i] = g_insecure_rand_ctx.randbits<3>(); } for (int order = 0; order < 4; ++order) { MuHash3072 acc; @@ -1215,8 +1215,8 @@ BOOST_AUTO_TEST_CASE(muhash_tests) } } - MuHash3072 x = FromInt(g_insecure_rand_ctx.randbits(4)); // x=X - MuHash3072 y = FromInt(g_insecure_rand_ctx.randbits(4)); // x=X, y=Y + MuHash3072 x = FromInt(g_insecure_rand_ctx.randbits<4>()); // x=X + MuHash3072 y = FromInt(g_insecure_rand_ctx.randbits<4>()); // x=X, y=Y MuHash3072 z; // x=X, y=Y, z=1 z *= x; // x=X, y=Y, z=X z *= y; // x=X, y=Y, z=X*Y diff --git a/src/test/random_tests.cpp b/src/test/random_tests.cpp index 53c29f31cab..2617cc4a2a5 100644 --- a/src/test/random_tests.cpp +++ b/src/test/random_tests.cpp @@ -107,7 +107,7 @@ BOOST_AUTO_TEST_CASE(fastrandom_randbits) BOOST_AUTO_TEST_CASE(randbits_test) { FastRandomContext ctx_lens; //!< RNG for producing the lengths requested from ctx_test. - FastRandomContext ctx_test; //!< The RNG being tested. + FastRandomContext ctx_test1(true), ctx_test2(true); //!< The RNGs being tested. int ctx_test_bitsleft{0}; //!< (Assumed value of) ctx_test::bitbuf_len // Run the entire test 5 times. @@ -122,7 +122,25 @@ BOOST_AUTO_TEST_CASE(randbits_test) // Decide on a number of bits to request (0 through 64, inclusive; don't use randbits/randrange). int bits = ctx_lens.rand64() % 65; // Generate that many bits. - uint64_t gen = ctx_test.randbits(bits); + uint64_t gen = ctx_test1.randbits(bits); + // For certain bits counts, also test randbits and compare. + uint64_t gen2; + if (bits == 0) { + gen2 = ctx_test2.randbits<0>(); + } else if (bits == 1) { + gen2 = ctx_test2.randbits<1>(); + } else if (bits == 7) { + gen2 = ctx_test2.randbits<7>(); + } else if (bits == 32) { + gen2 = ctx_test2.randbits<32>(); + } else if (bits == 51) { + gen2 = ctx_test2.randbits<51>(); + } else if (bits == 64) { + gen2 = ctx_test2.randbits<64>(); + } else { + gen2 = ctx_test2.randbits(bits); + } + BOOST_CHECK_EQUAL(gen, gen2); // Make sure the result is in range. if (bits < 64) BOOST_CHECK_EQUAL(gen >> bits, 0); // Mark all the seen bits in the output. diff --git a/test/sanitizer_suppressions/ubsan b/test/sanitizer_suppressions/ubsan index be9c7fb300a..d949aabf846 100644 --- a/test/sanitizer_suppressions/ubsan +++ b/test/sanitizer_suppressions/ubsan @@ -77,3 +77,4 @@ shift-base:streams.h shift-base:FormatHDKeypath shift-base:xoroshiro128plusplus.h shift-base:RandomMixin<*>::randbits +shift-base:RandomMixin<*>::randbits<*>