diff --git a/src/net.cpp b/src/net.cpp index 1ae4b8fe08..941ea3c4ac 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -615,7 +614,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete if (m_deserializer->Complete()) { // decompose a transport agnostic CNetMessage from the deserializer uint32_t out_err_raw_size{0}; - Optional result{m_deserializer->GetMessage(Params().MessageStart(), time, out_err_raw_size)}; + Optional result{m_deserializer->GetMessage(time, out_err_raw_size)}; if (!result) { // store the size of the corrupt message mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size; @@ -697,15 +696,14 @@ const uint256& V1TransportDeserializer::GetMessageHash() const return data_hash; } -Optional V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time, uint32_t& out_err_raw_size) +Optional V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, uint32_t& out_err_raw_size) { // decompose a single CNetMessage from the TransportDeserializer Optional msg(std::move(vRecv)); // store state about valid header, netmagic and checksum - msg->m_valid_header = hdr.IsValid(message_start); - msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0); - uint256 hash = GetMessageHash(); + msg->m_valid_header = hdr.IsValid(m_chain_params.MessageStart()); + msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) == 0); // store command string, time, and sizes msg->m_command = hdr.GetCommand(); @@ -713,6 +711,8 @@ Optional V1TransportDeserializer::GetMessage(const CMessageHeader:: msg->m_message_size = hdr.nMessageSize; msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE; + uint256 hash = GetMessageHash(); + // We just received a message off the wire, harvest entropy from the time (and the message checksum) RandAddEvent(ReadLE32(hash.begin())); @@ -2846,7 +2846,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn LogPrint(BCLog::NET, "Added connection peer=%d\n", id); } - m_deserializer = MakeUnique(V1TransportDeserializer(GetId(), SER_NETWORK, INIT_PROTO_VERSION)); + m_deserializer = MakeUnique(V1TransportDeserializer(Params(), GetId(), SER_NETWORK, INIT_PROTO_VERSION)); m_serializer = MakeUnique(V1TransportSerializer()); } diff --git a/src/net.h b/src/net.h index cec201c5d2..29941b9622 100644 --- a/src/net.h +++ b/src/net.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -732,13 +733,14 @@ public: // read and deserialize data virtual int Read(const char *data, unsigned int bytes) = 0; // decomposes a message from the context - virtual Optional GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err) = 0; + virtual Optional GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0; virtual ~TransportDeserializer() {} }; class V1TransportDeserializer final : public TransportDeserializer { private: + const CChainParams& m_chain_params; const NodeId m_node_id; // Only for logging mutable CHash256 hasher; mutable uint256 data_hash; @@ -765,8 +767,9 @@ private: } public: - V1TransportDeserializer(const NodeId node_id, int nTypeIn, int nVersionIn) - : m_node_id(node_id), + V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) + : m_chain_params(chain_params), + m_node_id(node_id), hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) { @@ -789,7 +792,7 @@ public: if (ret < 0) Reset(); return ret; } - Optional GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err_raw_size) override; + Optional GetMessage(std::chrono::microseconds time, uint32_t& out_err_raw_size) override; }; /** The TransportSerializer prepares messages for the network transport diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 5349fd3f68..6252b8e91b 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -20,7 +20,7 @@ void initialize() void test_one_input(const std::vector& buffer) { // Construct deserializer, with a dummy NodeId - V1TransportDeserializer deserializer{(NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; + V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; const char* pch = (const char*)buffer.data(); size_t n_bytes = buffer.size(); while (n_bytes > 0) { @@ -33,7 +33,7 @@ void test_one_input(const std::vector& buffer) if (deserializer.Complete()) { const std::chrono::microseconds m_time{std::numeric_limits::max()}; uint32_t out_err_raw_size{0}; - Optional result{deserializer.GetMessage(Params().MessageStart(), m_time, out_err_raw_size)}; + Optional result{deserializer.GetMessage(m_time, out_err_raw_size)}; if (result) { assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE); assert(result->m_raw_message_size <= buffer.size());