diff --git a/src/net.cpp b/src/net.cpp
index 46e53924db..1a0bee8c51 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -1161,6 +1161,42 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
     RandAddEvent((uint32_t)id);
 }
 
+bool CConnman::AddConnection(const std::string& address, ConnectionType conn_type)
+{
+    Optional<int> max_connections;
+    switch (conn_type) {
+    case ConnectionType::INBOUND:
+    case ConnectionType::MANUAL:
+        return false;
+    case ConnectionType::OUTBOUND_FULL_RELAY:
+        max_connections = m_max_outbound_full_relay;
+        break;
+    case ConnectionType::BLOCK_RELAY:
+        max_connections = m_max_outbound_block_relay;
+        break;
+    // no limit for ADDR_FETCH because -seednode has no limit either
+    case ConnectionType::ADDR_FETCH:
+        break;
+    // no limit for FEELER connections since they're short-lived
+    case ConnectionType::FEELER:
+        break;
+    } // no default case, so the compiler can warn about missing cases
+
+    // Count existing connections
+    int existing_connections = WITH_LOCK(cs_vNodes,
+                                         return std::count_if(vNodes.begin(), vNodes.end(), [conn_type](CNode* node) { return node->m_conn_type == conn_type; }););
+
+    // Max connections of specified type already exist
+    if (max_connections != nullopt && existing_connections >= max_connections) return false;
+
+    // Max total outbound connections already exist
+    CSemaphoreGrant grant(*semOutbound, true);
+    if (!grant) return false;
+
+    OpenNetworkConnection(CAddress(), false, &grant, address.c_str(), conn_type);
+    return true;
+}
+
 void CConnman::DisconnectNodes()
 {
     {
diff --git a/src/net.h b/src/net.h
index 77649247d9..2597b5fdc8 100644
--- a/src/net.h
+++ b/src/net.h
@@ -346,6 +346,20 @@ public:
     bool RemoveAddedNode(const std::string& node);
     std::vector<AddedNodeInfo> GetAddedNodeInfo();
 
+    /**
+     * Attempts to open a connection. Currently only used from tests.
+     *
+     * @param[in]   address     Address of node to try connecting to
+     * @param[in]   conn_type   ConnectionType::OUTBOUND_FULL_RELAY, ConnectionType::BLOCK_RELAY,
+     *                          ConnectionType::ADDR_FETCH or ConnectionType::FEELER
+     * @return      bool        Returns false if there are no available
+     *                          slots for this connection:
+     *                          - conn_type not a supported ConnectionType
+     *                          - Max total outbound connection capacity filled
+     *                          - Max connection capacity for type is filled
+     */
+    bool AddConnection(const std::string& address, ConnectionType conn_type);
+
     size_t GetNodeCount(NumConnections num);
     void GetNodeStats(std::vector<CNodeStats>& vstats);
     bool DisconnectNode(const std::string& node);
diff --git a/src/net_permissions.h b/src/net_permissions.h
index bba0ea1695..3b841ab138 100644
--- a/src/net_permissions.h
+++ b/src/net_permissions.h
@@ -30,7 +30,8 @@ enum NetPermissionFlags {
     PF_NOBAN = (1U << 4) | PF_DOWNLOAD,
     // Can query the mempool
     PF_MEMPOOL = (1U << 5),
-    // Can request addrs without hitting a privacy-preserving cache
+    // Can request addrs without hitting a privacy-preserving cache, and send us
+    // unlimited amounts of addrs.
     PF_ADDR = (1U << 7),
 
     // True if the user did not specifically set fine grained permissions
diff --git a/src/net_processing.cpp b/src/net_processing.cpp
index 1f307f39c7..d6504de62b 100644
--- a/src/net_processing.cpp
+++ b/src/net_processing.cpp
@@ -146,6 +146,13 @@ static constexpr uint32_t MAX_GETCFILTERS_SIZE = 1000;
 static constexpr uint32_t MAX_GETCFHEADERS_SIZE = 2000;
 /** the maximum percentage of addresses from our addrman to return in response to a getaddr message. */
 static constexpr size_t MAX_PCT_ADDR_TO_SEND = 23;
+/** The maximum rate of address records we're willing to process on average. Can be bypassed using
+ * the NetPermissionFlags::PF_ADDR permission. */
+static constexpr double MAX_ADDR_RATE_PER_SECOND{0.1};
+/** The soft limit of the address processing token bucket (the regular MAX_ADDR_RATE_PER_SECOND
+ * based increments won't go above this, but the MAX_ADDR_TO_SEND increment following GETADDR
+ * is exempt from this limit). */
+static constexpr size_t MAX_ADDR_PROCESSING_TOKEN_BUCKET{MAX_ADDR_TO_SEND};
 
 struct COrphanTx {
     // When modifying, adapt the copy of this definition in tests/DoS_tests.
@@ -471,6 +478,16 @@ struct Peer {
     /** Work queue of items requested by this peer **/
     std::deque<CInv> m_getdata_requests GUARDED_BY(m_getdata_requests_mutex);
 
+    /** Number of addresses that can be processed from this peer. Start at 1 to
+     *  permit self-announcement. */
+    double m_addr_token_bucket GUARDED_BY(NetEventsInterface::g_msgproc_mutex){1.0};
+    /** When m_addr_token_bucket was last updated */
+    std::chrono::microseconds m_addr_token_timestamp GUARDED_BY(NetEventsInterface::g_msgproc_mutex){GetTime<std::chrono::microseconds>()};
+    /** Total number of addresses that were dropped due to rate limiting. */
+    std::atomic<uint64_t> m_addr_rate_limited{0};
+    /** Total number of addresses that were processed (excludes rate-limited ones). */
+    std::atomic<uint64_t> m_addr_processed{0};
+
     Peer(NodeId id) : m_id(id) {}
 };
 
@@ -930,6 +947,8 @@ bool GetNodeStateStats(NodeId nodeid, CNodeStateStats &stats) {
     PeerRef peer = GetPeerRef(nodeid);
     if (peer == nullptr) return false;
     stats.m_misbehavior_score = WITH_LOCK(peer->m_misbehavior_mutex, return peer->m_misbehavior_score);
+    stats.m_addr_processed = peer->m_addr_processed.load();
+    stats.m_addr_rate_limited = peer->m_addr_rate_limited.load();
     return true;
 }
 
@@ -2489,6 +2508,9 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
             // Get recent addresses
             m_connman.PushMessage(&pfrom, CNetMsgMaker(greatest_common_version).Make(NetMsgType::GETADDR));
             pfrom.fGetAddr = true;
+            // When requesting a getaddr, accept an additional MAX_ADDR_TO_SEND addresses in response
+            // (bypassing the MAX_ADDR_PROCESSING_TOKEN_BUCKET limit).
+            peer->m_addr_token_bucket += MAX_ADDR_TO_SEND;
         }
 
         if (!pfrom.IsInboundConn()) {
@@ -2645,11 +2667,35 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
         std::vector<CAddress> vAddrOk;
         int64_t nNow = GetAdjustedTime();
         int64_t nSince = nNow - 10 * 60;
+
+        // Update/increment addr rate limiting bucket.
+        const auto current_time{GetTime<std::chrono::microseconds>()};
+        if (peer->m_addr_token_bucket < MAX_ADDR_PROCESSING_TOKEN_BUCKET) {
+            // Don't increment bucket if it's already full
+            const auto time_diff = std::max(current_time - peer->m_addr_token_timestamp, 0us);
+            const double increment = Ticks<SecondsDouble>(time_diff) * MAX_ADDR_RATE_PER_SECOND;
+            peer->m_addr_token_bucket = std::min<double>(peer->m_addr_token_bucket + increment, MAX_ADDR_PROCESSING_TOKEN_BUCKET);
+        }
+        peer->m_addr_token_timestamp = current_time;
+
+        const bool rate_limited = !pfrom.HasPermission(NetPermissionFlags::PF_ADDR);
+        uint64_t num_proc = 0;
+        uint64_t num_rate_limit = 0;
+        Shuffle(vAddr.begin(), vAddr.end(), FastRandomContext());
         for (CAddress& addr : vAddr)
         {
             if (interruptMsgProc)
                 return;
 
+            // Apply rate limiting.
+            if (peer->m_addr_token_bucket < 1.0) {
+                if (rate_limited) {
+                    ++num_rate_limit;
+                    continue;
+                }
+            } else {
+                peer->m_addr_token_bucket -= 1.0;
+            }
             // We only bother storing full nodes, though this may include
             // things which we would not make an outbound connection to, in
             // part because we may make feeler connections to them.
@@ -2663,6 +2709,7 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
                 // Do not process banned/discouraged addresses beyond remembering we received them
                 continue;
             }
+            ++num_proc;
             bool fReachable = IsReachable(addr);
             if (addr.nTime > nSince && !pfrom.fGetAddr && vAddr.size() <= 10 && addr.IsRoutable())
             {
@@ -2673,6 +2720,11 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
             if (fReachable)
                 vAddrOk.push_back(addr);
         }
+        peer->m_addr_processed += num_proc;
+        peer->m_addr_rate_limited += num_rate_limit;
+        LogPrint(BCLog::NET, "Received addr: %u addresses (%u processed, %u rate-limited) from peer=%d\n",
+                 vAddr.size(), num_proc, num_rate_limit, pfrom.GetId());
+
         m_connman.AddNewAddresses(vAddrOk, pfrom.addr, 2 * 60 * 60);
         if (vAddr.size() < 1000)
             pfrom.fGetAddr = false;
diff --git a/src/net_processing.h b/src/net_processing.h
index 87eee566de..4a9c76a3fa 100644
--- a/src/net_processing.h
+++ b/src/net_processing.h
@@ -150,6 +150,8 @@ struct CNodeStateStats {
     int nSyncHeight = -1;
     int nCommonHeight = -1;
     std::vector<int> vHeightInFlight;
+    uint64_t m_addr_processed = 0;
+    uint64_t m_addr_rate_limited = 0;
 };
 
 /** Get statistics from node state */
diff --git a/src/rpc/blockchain.cpp b/src/rpc/blockchain.cpp
index c52cbc248b..1ba0e3f429 100644
--- a/src/rpc/blockchain.cpp
+++ b/src/rpc/blockchain.cpp
@@ -81,6 +81,15 @@ ChainstateManager& EnsureChainman(const util::Ref& context)
     return *node.chainman;
 }
 
+CConnman& EnsureConnman(const util::Ref& context)
+{
+    NodeContext& node = EnsureNodeContext(context);
+    if (!node.connman) {
+        throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled");
+    }
+    return *node.connman;
+}
+
 /* Calculate the difficulty for a given block index.
  */
 double GetDifficulty(const CBlockIndex* blockindex)
diff --git a/src/rpc/blockchain.h b/src/rpc/blockchain.h
index 5b362bf211..b8ea216975 100644
--- a/src/rpc/blockchain.h
+++ b/src/rpc/blockchain.h
@@ -15,6 +15,7 @@ extern RecursiveMutex cs_main;
 
 class CBlock;
 class CBlockIndex;
+class CConnman;
 class CTxMemPool;
 class ChainstateManager;
 class UniValue;
@@ -54,5 +55,6 @@ void CalculatePercentilesByWeight(CAmount result[NUM_GETBLOCKSTATS_PERCENTILES],
 NodeContext& EnsureNodeContext(const util::Ref& context);
 CTxMemPool& EnsureMemPool(const util::Ref& context);
 ChainstateManager& EnsureChainman(const util::Ref& context);
+CConnman& EnsureConnman(const util::Ref& context);
 
 #endif
diff --git a/src/rpc/net.cpp b/src/rpc/net.cpp
index 4fb96af72e..207509d71f 100644
--- a/src/rpc/net.cpp
+++ b/src/rpc/net.cpp
@@ -139,6 +139,8 @@ static RPCHelpMan getpeerinfo()
                     {
                         {RPCResult::Type::NUM, "n", "The heights of blocks we're currently asking from this peer"},
                     }},
+                    {RPCResult::Type::NUM, "addr_processed", "The total number of addresses processed, excluding those dropped due to rate limiting"},
+                    {RPCResult::Type::NUM, "addr_rate_limited", "The total number of addresses dropped due to rate limiting"},
                     {RPCResult::Type::BOOL, "whitelisted", /* optional */ true, "Whether the peer is whitelisted with default permissions\n"
                                                                                 "(DEPRECATED, returned only if config option -deprecatedrpc=whitelisted is passed)"},
                     {RPCResult::Type::ARR, "permissions", "Any special permissions that have been granted to this peer",
@@ -236,6 +238,8 @@ static RPCHelpMan getpeerinfo()
             heights.push_back(height);
         }
         obj.pushKV("inflight", heights);
+        obj.pushKV("addr_processed", statestats.m_addr_processed);
+        obj.pushKV("addr_rate_limited", statestats.m_addr_rate_limited);
     }
     if (IsDeprecatedRPCEnabled("whitelisted")) {
         // whitelisted is deprecated in v0.21 for removal in v0.22
@@ -326,6 +330,61 @@ static RPCHelpMan addnode()
     };
 }
 
+static RPCHelpMan addconnection()
+{
+    return RPCHelpMan{"addconnection",
+        "\nOpen an outbound connection to a specified node. This RPC is for testing only.\n",
+        {
+            {"address", RPCArg::Type::STR, RPCArg::Optional::NO, "The IP address and port to attempt connecting to."},
+            {"connection_type", RPCArg::Type::STR, RPCArg::Optional::NO, "Type of connection to open (\"outbound-full-relay\", \"block-relay-only\", \"addr-fetch\" or \"feeler\")."},
+        },
+        RPCResult{
+            RPCResult::Type::OBJ, "", "",
+            {
+                { RPCResult::Type::STR, "address", "Address of newly added connection." },
+                { RPCResult::Type::STR, "connection_type", "Type of connection opened." },
+            }},
+        RPCExamples{
+            HelpExampleCli("addconnection", "\"192.168.0.6:8333\" \"outbound-full-relay\"")
+            + HelpExampleRpc("addconnection", "\"192.168.0.6:8333\" \"outbound-full-relay\"")
+        },
+        [&](const RPCHelpMan& self, const JSONRPCRequest& request) -> UniValue
+{
+    if (Params().NetworkIDString() != CBaseChainParams::REGTEST) {
+        throw std::runtime_error("addconnection is for regression testing (-regtest mode) only.");
+    }
+
+    const std::string address = request.params[0].get_str();
+    const std::string conn_type_in{TrimString(request.params[1].get_str())};
+    ConnectionType conn_type{};
+    if (conn_type_in == "outbound-full-relay") {
+        conn_type = ConnectionType::OUTBOUND_FULL_RELAY;
+    } else if (conn_type_in == "block-relay-only") {
+        conn_type = ConnectionType::BLOCK_RELAY;
+    } else if (conn_type_in == "addr-fetch") {
+        conn_type = ConnectionType::ADDR_FETCH;
+    } else if (conn_type_in == "feeler") {
+        conn_type = ConnectionType::FEELER;
+    } else {
+        throw JSONRPCError(RPC_INVALID_PARAMETER, self.ToString());
+    }
+
+    CConnman& connman = EnsureConnman(request.context);
+
+    const bool success = connman.AddConnection(address, conn_type);
+    if (!success) {
+        throw JSONRPCError(RPC_CLIENT_NODE_CAPACITY_REACHED, "Error: Already at capacity for specified connection type.");
+    }
+
+    UniValue info(UniValue::VOBJ);
+    info.pushKV("address", address);
+    info.pushKV("connection_type", conn_type_in);
+
+    return info;
+},
+    };
+}
+
 static RPCHelpMan disconnectnode()
 {
     return RPCHelpMan{"disconnectnode",
@@ -910,6 +969,7 @@ static const CRPCCommand commands[] =
     { "network",            "clearbanned",            &clearbanned,            {} },
     { "network",            "setnetworkactive",       &setnetworkactive,       {"state"} },
     { "network",            "getnodeaddresses",       &getnodeaddresses,       {"count"} },
+    { "hidden",             "addconnection",          &addconnection,          {"address", "connection_type"} },
     { "hidden",             "addpeeraddress",         &addpeeraddress,         {"address", "port"} },
 };
 // clang-format on
diff --git a/src/rpc/protocol.h b/src/rpc/protocol.h
index d1475f452d..c8ceb2c186 100644
--- a/src/rpc/protocol.h
+++ b/src/rpc/protocol.h
@@ -62,6 +62,7 @@ enum RPCErrorCode
     RPC_CLIENT_NODE_NOT_CONNECTED   = -29, //!< Node to disconnect not found in connected nodes
     RPC_CLIENT_INVALID_IP_OR_SUBNET = -30, //!< Invalid IP/Subnet
     RPC_CLIENT_P2P_DISABLED         = -31, //!< No valid connection manager instance found
+    RPC_CLIENT_NODE_CAPACITY_REACHED = -34, //!< Max number of outbound or block-relay connections already open
 
     //! Chain errors
     RPC_CLIENT_MEMPOOL_DISABLED     = -33, //!< No mempool instance found
diff --git a/src/util/time.h b/src/util/time.h
index af934e423b..e167ad5be2 100644
--- a/src/util/time.h
+++ b/src/util/time.h
@@ -10,21 +10,38 @@
 #include <chrono>
 #include <stdint.h>
 
+using namespace std::chrono_literals;
+
 void UninterruptibleSleep(const std::chrono::microseconds& n);
 
 /**
- * Helper to count the seconds of a duration.
+ * Helper to count the seconds of a duration/time_point.
  *
- * All durations should be using std::chrono and calling this should generally
+ * All durations/time_points should be using std::chrono and calling this should generally
  * be avoided in code. Though, it is still preferred to an inline t.count() to
  * protect against a reliance on the exact type of t.
  *
- * This helper is used to convert durations before passing them over an
+ * This helper is used to convert durations/time_points before passing them over an
  * interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI)
  */
+template <typename Dur1, typename Dur2>
+inline auto Ticks(Dur2 d)
+{
+    return std::chrono::duration_cast<Dur1>(d).count();
+}
+template <typename Duration, typename Timepoint>
+inline auto TicksSinceEpoch(Timepoint t)
+{
+    return Ticks<Duration>(t.time_since_epoch());
+}
 inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); }
+inline int64_t count_milliseconds(std::chrono::milliseconds t) { return t.count(); }
 inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); }
 
+using HoursDouble = std::chrono::duration<double, std::chrono::hours::period>;
+using SecondsDouble = std::chrono::duration<double, std::chrono::seconds::period>;
+using MillisecondsDouble = std::chrono::duration<double, std::chrono::milliseconds::period>;
+
 /**
  * DEPRECATED
  * Use either GetSystemTimeInSeconds (not mockable) or GetTime<T> (mockable)
diff --git a/test/functional/p2p_addr_relay.py b/test/functional/p2p_addr_relay.py
index 80f262d0d3..dc4996b718 100755
--- a/test/functional/p2p_addr_relay.py
+++ b/test/functional/p2p_addr_relay.py
@@ -12,11 +12,15 @@ from test_framework.messages import (
     NODE_WITNESS,
     msg_addr,
 )
-from test_framework.p2p import P2PInterface
+from test_framework.p2p import (
+    P2PInterface,
+    p2p_lock,
+)
 from test_framework.test_framework import BitcoinTestFramework
 from test_framework.util import (
     assert_equal,
 )
+import random
 import time
 
 ADDRS = []
@@ -30,17 +34,33 @@ for i in range(10):
 
 
 class AddrReceiver(P2PInterface):
+    _tokens = 1
     def on_addr(self, message):
         for addr in message.addrs:
             assert_equal(addr.nServices, 9)
             assert addr.ip.startswith('123.123.123.')
             assert (8333 <= addr.port < 8343)
 
+    def on_getaddr(self, message):
+        # When the node sends us a getaddr, it increments the addr relay tokens for the connection by 1000
+        self._tokens += 1000
+
+    @property
+    def tokens(self):
+        with p2p_lock:
+            return self._tokens
+
+    def increment_tokens(self, n):
+        # When we move mocktime forward, the node increments the addr relay tokens for its peers
+        with p2p_lock:
+            self._tokens += n
+
 
 class AddrTest(BitcoinTestFramework):
     def set_test_params(self):
         self.setup_clean_chain = False
         self.num_nodes = 1
+        self.extra_args = [["-whitelist=addr@127.0.0.1"]]
 
     def run_test(self):
         self.log.info('Create connection that sends addr messages')
@@ -64,6 +84,77 @@ class AddrTest(BitcoinTestFramework):
         self.nodes[0].setmocktime(int(time.time()) + 30 * 60)
         addr_receiver.sync_with_ping()
 
+        self.rate_limit_tests()
+
+    def setup_rand_addr_msg(self, num):
+        addrs = []
+        for i in range(num):
+            addr = CAddress()
+            addr.time = self.mocktime + i
+            addr.nServices = NODE_NETWORK | NODE_WITNESS
+            addr.ip = f"{random.randrange(128,169)}.{random.randrange(1,255)}.{random.randrange(1,255)}.{random.randrange(1,255)}"
+            addr.port = 8333
+            addrs.append(addr)
+        msg = msg_addr()
+        msg.addrs = addrs
+        return msg
+
+    def send_addrs_and_test_rate_limiting(self, peer, no_relay, *, new_addrs, total_addrs):
+        """Send an addr message and check that the number of addresses processed and rate-limited is as expected"""
+
+        peer.send_and_ping(self.setup_rand_addr_msg(new_addrs))
+
+        peerinfo = self.nodes[0].getpeerinfo()[0]
+        addrs_processed = peerinfo['addr_processed']
+        addrs_rate_limited = peerinfo['addr_rate_limited']
+        self.log.debug(f"addrs_processed = {addrs_processed}, addrs_rate_limited = {addrs_rate_limited}")
+
+        if no_relay:
+            assert_equal(addrs_processed, 0)
+            assert_equal(addrs_rate_limited, 0)
+        else:
+            assert_equal(addrs_processed, min(total_addrs, peer.tokens))
+            assert_equal(addrs_rate_limited, max(0, total_addrs - peer.tokens))
+
+    def rate_limit_tests(self):
+        self.mocktime = int(time.time())
+        self.restart_node(0, [])
+        self.nodes[0].setmocktime(self.mocktime)
+
+        for conn_type, no_relay in [("outbound-full-relay", False), ("block-relay-only", True), ("inbound", False)]:
+            self.log.info(f'Test rate limiting of addr processing for {conn_type} peers')
+            if conn_type == "inbound":
+                peer = self.nodes[0].add_p2p_connection(AddrReceiver())
+            else:
+                peer = self.nodes[0].add_outbound_p2p_connection(AddrReceiver(), p2p_idx=0, connection_type=conn_type)
+
+            # Send 600 addresses. For all but the block-relay-only peer this should result in addresses being processed.
+            self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=600, total_addrs=600)
+
+            # Send 600 more addresses. For the outbound-full-relay peer (to which we send a GETADDR, and which will thus
+            # process up to 1001 incoming addresses), this means more addresses will be processed.
+            self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=600, total_addrs=1200)
+
+            # Send 10 more. As we reached the processing limit for all nodes, no more addresses should be processed.
+            self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=10, total_addrs=1210)
+
+            # Advance the time by 100 seconds, permitting the processing of 10 more addresses.
+            # Send 200 and verify that 10 are processed.
+            self.mocktime += 100
+            self.nodes[0].setmocktime(self.mocktime)
+            peer.increment_tokens(10)
+
+            self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=200, total_addrs=1410)
+
+            # Advance the time by 1000 seconds, permitting the processing of 100 more addresses.
+            # Send 200 and verify that 100 are processed.
+            self.mocktime += 1000
+            self.nodes[0].setmocktime(self.mocktime)
+            peer.increment_tokens(100)
+
+            self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=200, total_addrs=1610)
+
+            self.nodes[0].disconnect_p2ps()
+
 
 if __name__ == '__main__':
     AddrTest().main()
diff --git a/test/functional/p2p_addrv2_relay.py b/test/functional/p2p_addrv2_relay.py
index 23ce3e5d04..0ee755a28d 100755
--- a/test/functional/p2p_addrv2_relay.py
+++ b/test/functional/p2p_addrv2_relay.py
@@ -39,7 +39,10 @@ class AddrReceiver(P2PInterface):
             assert_equal(addr.nServices, 9)
             assert addr.ip.startswith('123.123.123.')
             assert (8333 <= addr.port < 8343)
-        self.addrv2_received_and_checked = True
+        expected_set = set((addr.ip, addr.port) for addr in ADDRS)
+        received_set = set((addr.ip, addr.port) for addr in message.addrs)
+        if expected_set == received_set:
+            self.addrv2_received_and_checked = True
 
     def wait_for_addrv2(self):
         self.wait_until(lambda: "addrv2" in self.last_message)
@@ -49,6 +52,7 @@ class AddrTest(BitcoinTestFramework):
     def set_test_params(self):
         self.setup_clean_chain = True
         self.num_nodes = 1
+        self.extra_args = [["-whitelist=addr@127.0.0.1"]]
 
     def run_test(self):
         self.log.info('Create connection that sends addrv2 messages')
diff --git a/test/functional/p2p_invalid_messages.py b/test/functional/p2p_invalid_messages.py
index db72a361d9..3934e7611e 100755
--- a/test/functional/p2p_invalid_messages.py
+++ b/test/functional/p2p_invalid_messages.py
@@ -57,6 +57,7 @@ class InvalidMessagesTest(BitcoinTestFramework):
     def set_test_params(self):
         self.num_nodes = 1
         self.setup_clean_chain = True
+        self.extra_args = [["-whitelist=addr@127.0.0.1"]]
 
     def run_test(self):
         self.test_buffer()
diff --git a/test/functional/test_framework/p2p.py b/test/functional/test_framework/p2p.py
index 9da3f59ad2..eef8a98e86 100755
--- a/test/functional/test_framework/p2p.py
+++ b/test/functional/test_framework/p2p.py
@@ -72,7 +72,11 @@ from test_framework.messages import (
     NODE_WITNESS,
     sha256,
 )
-from test_framework.util import wait_until_helper
+from test_framework.util import (
+    MAX_NODES,
+    p2p_port,
+    wait_until_helper,
+)
 
 logger = logging.getLogger("TestFramework.p2p")
 
@@ -140,7 +144,7 @@ class P2PConnection(asyncio.Protocol):
     def is_connected(self):
         return self._transport is not None
 
-    def peer_connect(self, dstaddr, dstport, *, net, timeout_factor):
+    def peer_connect_helper(self, dstaddr, dstport, net, timeout_factor):
         assert not self.is_connected
         self.timeout_factor = timeout_factor
         self.dstaddr = dstaddr
@@ -149,12 +153,20 @@ class P2PConnection(asyncio.Protocol):
         self.on_connection_send_msg = None
         self.recvbuf = b""
         self.magic_bytes = MAGIC_BYTES[net]
-        logger.debug('Connecting to Litecoin Node: %s:%d' % (self.dstaddr, self.dstport))
+
+    def peer_connect(self, dstaddr, dstport, *, net, timeout_factor):
+        self.peer_connect_helper(dstaddr, dstport, net, timeout_factor)
 
         loop = NetworkThread.network_event_loop
-        conn_gen_unsafe = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport)
-        conn_gen = lambda: loop.call_soon_threadsafe(loop.create_task, conn_gen_unsafe)
-        return conn_gen
+        logger.debug('Connecting to Litecoin Node: %s:%d' % (self.dstaddr, self.dstport))
+        coroutine = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport)
+        return lambda: loop.call_soon_threadsafe(loop.create_task, coroutine)
+
+    def peer_accept_connection(self, connect_id, connect_cb=lambda: None, *, net, timeout_factor):
+        self.peer_connect_helper('0', 0, net, timeout_factor)
+
+        logger.debug('Listening for Litecoin Node with id: {}'.format(connect_id))
+        return lambda: NetworkThread.listen(self, connect_cb, idx=connect_id)
 
     def peer_disconnect(self):
         # Connection could have already been closed by other end.
@@ -310,18 +322,27 @@ class P2PInterface(P2PConnection):
 
         self.support_addrv2 = support_addrv2
 
+    def peer_connect_send_version(self, services):
+        # Send a version msg
+        vt = msg_version()
+        vt.nServices = services
+        vt.addrTo.ip = self.dstaddr
+        vt.addrTo.port = self.dstport
+        vt.addrFrom.ip = "0.0.0.0"
+        vt.addrFrom.port = 0
+        self.on_connection_send_msg = vt  # Will be sent in connection_made callback
+
     def peer_connect(self, *args, services=NODE_NETWORK|NODE_WITNESS|NODE_MWEB, send_version=True, **kwargs):
         create_conn = super().peer_connect(*args, **kwargs)
 
         if send_version:
-            # Send a version msg
-            vt = msg_version()
-            vt.nServices = services
-            vt.addrTo.ip = self.dstaddr
-            vt.addrTo.port = self.dstport
-            vt.addrFrom.ip = "0.0.0.0"
-            vt.addrFrom.port = 0
-            self.on_connection_send_msg = vt  # Will be sent soon after connection_made
+            self.peer_connect_send_version(services)
+
+        return create_conn
+
+    def peer_accept_connection(self, *args, services=NODE_NETWORK|NODE_WITNESS|NODE_MWEB, **kwargs):
+        create_conn = super().peer_accept_connection(*args, **kwargs)
+        self.peer_connect_send_version(services)
 
         return create_conn
 
@@ -412,6 +433,10 @@ class P2PInterface(P2PConnection):
 
         wait_until_helper(test_function, timeout=timeout, lock=p2p_lock, timeout_factor=self.timeout_factor)
 
+    def wait_for_connect(self, timeout=60):
+        test_function = lambda: self.is_connected
+        wait_until_helper(test_function, timeout=timeout, lock=p2p_lock)
+
     def wait_for_disconnect(self, timeout=60):
         test_function = lambda: not self.is_connected
         self.wait_until(test_function, timeout=timeout, check_connected=False)
@@ -525,6 +550,10 @@ class NetworkThread(threading.Thread):
 
         # There is only one event loop and no more than one thread must be created
         assert not self.network_event_loop
+        NetworkThread.listeners = {}
+        NetworkThread.protos = {}
+        if sys.platform == 'win32':
+            asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
         NetworkThread.network_event_loop = asyncio.new_event_loop()
 
     def run(self):
@@ -540,6 +569,47 @@ class NetworkThread(threading.Thread):
         # Safe to remove event loop.
         NetworkThread.network_event_loop = None
 
+    @classmethod
+    def listen(cls, p2p, callback, port=None, addr=None, idx=1):
+        """ Ensure a listening server is running on the given port, and run the
+        protocol specified by `p2p` on the next connection to it. Once ready
+        for connections, call `callback`."""
+
+        if port is None:
+            assert 0 < idx <= MAX_NODES
+            port = p2p_port(MAX_NODES - idx)
+        if addr is None:
+            addr = '127.0.0.1'
+
+        coroutine = cls.create_listen_server(addr, port, callback, p2p)
+        cls.network_event_loop.call_soon_threadsafe(cls.network_event_loop.create_task, coroutine)
+
+    @classmethod
+    async def create_listen_server(cls, addr, port, callback, proto):
+        def peer_protocol():
+            """Returns a function that does the protocol handling for a new
+            connection. To allow different connections to have different
+            behaviors, the protocol function is first put in the cls.protos
+            dict. When the connection is made, the function removes the
+            protocol function from that dict, and returns it so the event loop
+            can start executing it."""
+            response = cls.protos.get((addr, port))
+            cls.protos[(addr, port)] = None
+            return response
+
+        if (addr, port) not in cls.listeners:
+            # When creating a listener on a given (addr, port) we only need to
+            # do it once. If we want different behaviors for different
+            # connections, we can accomplish this by providing different
+            # `proto` functions
+
+            listener = await cls.network_event_loop.create_server(peer_protocol, addr, port)
+            logger.debug("Listening server on %s:%d should be started" % (addr, port))
+            cls.listeners[(addr, port)] = listener
+
+        cls.protos[(addr, port)] = proto
+        callback(addr, port)
+
 
 class P2PDataStore(P2PInterface):
     """A P2P data store class.
diff --git a/test/functional/test_framework/test_node.py b/test/functional/test_framework/test_node.py
index 1d0db146ae..81d8327ddb 100755
--- a/test/functional/test_framework/test_node.py
+++ b/test/functional/test_framework/test_node.py
@@ -542,6 +542,38 @@ class TestNode():
 
         return p2p_conn
 
+    def add_outbound_p2p_connection(self, p2p_conn, *, wait_for_verack=True, p2p_idx, connection_type="outbound-full-relay", **kwargs):
+        """Add an outbound p2p connection from node. Must be an
+        "outbound-full-relay", "block-relay-only", "addr-fetch" or "feeler" connection.
+
+        This method adds the p2p connection to the self.p2ps list and returns
+        the connection to the caller.
+
+        p2p_idx must be different for simultaneously connected peers. When reusing it for the next peer
+        after disconnecting the previous one, it is necessary to wait for the disconnect to finish to avoid
+        a race condition.
+        """
+
+        def addconnection_callback(address, port):
+            self.log.debug("Connecting to %s:%d %s" % (address, port, connection_type))
+            self.addconnection('%s:%d' % (address, port), connection_type)
+
+        p2p_conn.peer_accept_connection(connect_cb=addconnection_callback, connect_id=p2p_idx + 1, net=self.chain, timeout_factor=self.timeout_factor, **kwargs)()
+
+        if connection_type == "feeler":
+            # feeler connections are closed as soon as the node receives a `version` message
+            p2p_conn.wait_until(lambda: p2p_conn.message_count["version"] == 1, check_connected=False)
+            p2p_conn.wait_until(lambda: not p2p_conn.is_connected, check_connected=False)
+        else:
+            p2p_conn.wait_for_connect()
+            self.p2ps.append(p2p_conn)
+
+            if wait_for_verack:
+                p2p_conn.wait_for_verack()
+                p2p_conn.sync_with_ping()
+
+        return p2p_conn
+
     def num_test_p2p_connections(self):
         """Return number of test framework p2p connections to the node."""
         return len([peer for peer in self.getpeerinfo() if peer['subver'] == MY_SUBVERSION])
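The token-bucket accounting introduced in net_processing.cpp above can be summarized with a small, illustrative Python model. The class and method names below are hypothetical, MAX_ADDR_TO_SEND is assumed to be 1000 as elsewhere in the codebase, and the service-flag and banned-address checks that the real code applies before counting an address as processed are omitted.

    # Illustrative model of the per-peer addr token bucket; not part of the patch.
    MAX_ADDR_RATE_PER_SECOND = 0.1
    MAX_ADDR_TO_SEND = 1000  # assumed value of the existing constant
    MAX_ADDR_PROCESSING_TOKEN_BUCKET = MAX_ADDR_TO_SEND

    class PeerModel:
        def __init__(self):
            self.addr_token_bucket = 1.0     # start at 1 to permit self-announcement
            self.addr_token_timestamp = 0.0  # seconds
            self.addr_processed = 0
            self.addr_rate_limited = 0

        def on_getaddr_sent(self):
            # After we send GETADDR, accept MAX_ADDR_TO_SEND extra addresses in the
            # response; this increment is allowed to exceed the soft cap.
            self.addr_token_bucket += MAX_ADDR_TO_SEND

        def on_addr(self, addrs, now, has_addr_permission=False):
            # Refill at MAX_ADDR_RATE_PER_SECOND; timed refills never exceed the soft cap.
            if self.addr_token_bucket < MAX_ADDR_PROCESSING_TOKEN_BUCKET:
                elapsed = max(now - self.addr_token_timestamp, 0.0)
                self.addr_token_bucket = min(
                    self.addr_token_bucket + elapsed * MAX_ADDR_RATE_PER_SECOND,
                    MAX_ADDR_PROCESSING_TOKEN_BUCKET)
            self.addr_token_timestamp = now
            for _ in addrs:
                if self.addr_token_bucket < 1.0:
                    if not has_addr_permission:
                        self.addr_rate_limited += 1
                        continue
                else:
                    self.addr_token_bucket -= 1.0
                self.addr_processed += 1

This is why the functional test expects roughly one processed address per ten seconds of mocktime for ordinary peers, plus a 1000-address allowance whenever the node sends the peer a GETADDR.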
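For context, this is roughly how the new pieces fit together from a regtest functional test; the snippet is a hedged sketch based on p2p_addr_relay.py, not an additional change in this patch.

    # Inside a BitcoinTestFramework.run_test(); AddrReceiver is the helper defined
    # in p2p_addr_relay.py above.
    peer = self.nodes[0].add_outbound_p2p_connection(
        AddrReceiver(), p2p_idx=0, connection_type="block-relay-only")
    # Behind the scenes, test_node.py starts a framework listener and calls the
    # hidden RPC:
    #   self.nodes[0].addconnection('127.0.0.1:<framework port>', 'block-relay-only')
    # Once the node is at capacity for that connection type, addconnection fails
    # with RPC_CLIENT_NODE_CAPACITY_REACHED (-34).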