Rate limit the processing of rumoured addresses. Ref: https://github.com/bitcoin/bitcoin/pull/22387

(cherry picked from commit 8b256769fd)
pull/938/head
David Burkett 1 year ago
parent 32a108600d
commit a8f12c2615

@ -1161,6 +1161,42 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
RandAddEvent((uint32_t)id);
}
bool CConnman::AddConnection(const std::string& address, ConnectionType conn_type)
{
Optional<int> max_connections;
switch (conn_type) {
case ConnectionType::INBOUND:
case ConnectionType::MANUAL:
return false;
case ConnectionType::OUTBOUND_FULL_RELAY:
max_connections = m_max_outbound_full_relay;
break;
case ConnectionType::BLOCK_RELAY:
max_connections = m_max_outbound_block_relay;
break;
// no limit for ADDR_FETCH because -seednode has no limit either
case ConnectionType::ADDR_FETCH:
break;
// no limit for FEELER connections since they're short-lived
case ConnectionType::FEELER:
break;
} // no default case, so the compiler can warn about missing cases
// Count existing connections
int existing_connections = WITH_LOCK(cs_vNodes,
return std::count_if(vNodes.begin(), vNodes.end(), [conn_type](CNode* node) { return node->m_conn_type == conn_type; }););
// Max connections of specified type already exist
if (max_connections != nullopt && existing_connections >= max_connections) return false;
// Max total outbound connections already exist
CSemaphoreGrant grant(*semOutbound, true);
if (!grant) return false;
OpenNetworkConnection(CAddress(), false, &grant, address.c_str(), conn_type);
return true;
}
void CConnman::DisconnectNodes()
{
{

@ -346,6 +346,20 @@ public:
bool RemoveAddedNode(const std::string& node);
std::vector<AddedNodeInfo> GetAddedNodeInfo();
/**
* Attempts to open a connection. Currently only used from tests.
*
* @param[in] address Address of node to try connecting to
* @param[in] conn_type ConnectionType::OUTBOUND, ConnectionType::BLOCK_RELAY,
* ConnectionType::ADDR_FETCH or ConnectionType::FEELER
* @return bool Returns false if there are no available
* slots for this connection:
* - conn_type not a supported ConnectionType
* - Max total outbound connection capacity filled
* - Max connection capacity for type is filled
*/
bool AddConnection(const std::string& address, ConnectionType conn_type);
size_t GetNodeCount(NumConnections num);
void GetNodeStats(std::vector<CNodeStats>& vstats);
bool DisconnectNode(const std::string& node);

@ -30,7 +30,8 @@ enum NetPermissionFlags {
PF_NOBAN = (1U << 4) | PF_DOWNLOAD,
// Can query the mempool
PF_MEMPOOL = (1U << 5),
// Can request addrs without hitting a privacy-preserving cache
// Can request addrs without hitting a privacy-preserving cache, and send us
// unlimited amounts of addrs.
PF_ADDR = (1U << 7),
// True if the user did not specifically set fine grained permissions

@ -146,6 +146,13 @@ static constexpr uint32_t MAX_GETCFILTERS_SIZE = 1000;
static constexpr uint32_t MAX_GETCFHEADERS_SIZE = 2000;
/** the maximum percentage of addresses from our addrman to return in response to a getaddr message. */
static constexpr size_t MAX_PCT_ADDR_TO_SEND = 23;
/** The maximum rate of address records we're willing to process on average. Can be bypassed using
* the NetPermissionFlags::Addr permission. */
static constexpr double MAX_ADDR_RATE_PER_SECOND{0.1};
/** The soft limit of the address processing token bucket (the regular MAX_ADDR_RATE_PER_SECOND
* based increments won't go above this, but the MAX_ADDR_TO_SEND increment following GETADDR
* is exempt from this limit). */
static constexpr size_t MAX_ADDR_PROCESSING_TOKEN_BUCKET{MAX_ADDR_TO_SEND};
struct COrphanTx {
// When modifying, adapt the copy of this definition in tests/DoS_tests.
@ -471,6 +478,16 @@ struct Peer {
/** Work queue of items requested by this peer **/
std::deque<CInv> m_getdata_requests GUARDED_BY(m_getdata_requests_mutex);
/** Number of addresses that can be processed from this peer. Start at 1 to
* permit self-announcement. */
double m_addr_token_bucket GUARDED_BY(NetEventsInterface::g_msgproc_mutex){1.0};
/** When m_addr_token_bucket was last updated */
std::chrono::microseconds m_addr_token_timestamp GUARDED_BY(NetEventsInterface::g_msgproc_mutex){GetTime<std::chrono::microseconds>()};
/** Total number of addresses that were dropped due to rate limiting. */
std::atomic<uint64_t> m_addr_rate_limited{0};
/** Total number of addresses that were processed (excludes rate-limited ones). */
std::atomic<uint64_t> m_addr_processed{0};
Peer(NodeId id) : m_id(id) {}
};
@ -930,6 +947,8 @@ bool GetNodeStateStats(NodeId nodeid, CNodeStateStats &stats) {
PeerRef peer = GetPeerRef(nodeid);
if (peer == nullptr) return false;
stats.m_misbehavior_score = WITH_LOCK(peer->m_misbehavior_mutex, return peer->m_misbehavior_score);
stats.m_addr_processed = peer->m_addr_processed.load();
stats.m_addr_rate_limited = peer->m_addr_rate_limited.load();
return true;
}
@ -2489,6 +2508,9 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
// Get recent addresses
m_connman.PushMessage(&pfrom, CNetMsgMaker(greatest_common_version).Make(NetMsgType::GETADDR));
pfrom.fGetAddr = true;
// When requesting a getaddr, accept an additional MAX_ADDR_TO_SEND addresses in response
// (bypassing the MAX_ADDR_PROCESSING_TOKEN_BUCKET limit).
peer->m_addr_token_bucket += MAX_ADDR_TO_SEND;
}
if (!pfrom.IsInboundConn()) {
@ -2645,11 +2667,35 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
std::vector<CAddress> vAddrOk;
int64_t nNow = GetAdjustedTime();
int64_t nSince = nNow - 10 * 60;
// Update/increment addr rate limiting bucket.
const auto current_time{GetTime<std::chrono::microseconds>()};
if (peer->m_addr_token_bucket < MAX_ADDR_PROCESSING_TOKEN_BUCKET) {
// Don't increment bucket if it's already full
const auto time_diff = std::max(current_time - peer->m_addr_token_timestamp, 0us);
const double increment = Ticks<SecondsDouble>(time_diff) * MAX_ADDR_RATE_PER_SECOND;
peer->m_addr_token_bucket = std::min<double>(peer->m_addr_token_bucket + increment, MAX_ADDR_PROCESSING_TOKEN_BUCKET);
}
peer->m_addr_token_timestamp = current_time;
const bool rate_limited = !pfrom.HasPermission(NetPermissionFlags::PF_ADDR);
uint64_t num_proc = 0;
uint64_t num_rate_limit = 0;
Shuffle(vAddr.begin(), vAddr.end(), FastRandomContext());
for (CAddress& addr : vAddr)
{
if (interruptMsgProc)
return;
// Apply rate limiting.
if (peer->m_addr_token_bucket < 1.0) {
if (rate_limited) {
++num_rate_limit;
continue;
}
} else {
peer->m_addr_token_bucket -= 1.0;
}
// We only bother storing full nodes, though this may include
// things which we would not make an outbound connection to, in
// part because we may make feeler connections to them.
@ -2663,6 +2709,7 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
// Do not process banned/discouraged addresses beyond remembering we received them
continue;
}
++num_proc;
bool fReachable = IsReachable(addr);
if (addr.nTime > nSince && !pfrom.fGetAddr && vAddr.size() <= 10 && addr.IsRoutable())
{
@ -2673,6 +2720,11 @@ void PeerManager::ProcessMessage(CNode& pfrom, const std::string& msg_type, CDat
if (fReachable)
vAddrOk.push_back(addr);
}
peer->m_addr_processed += num_proc;
peer->m_addr_rate_limited += num_rate_limit;
LogPrint(BCLog::NET, "Received addr: %u addresses (%u processed, %u rate-limited) from peer=%d\n",
vAddr.size(), num_proc, num_rate_limit, pfrom.GetId());
m_connman.AddNewAddresses(vAddrOk, pfrom.addr, 2 * 60 * 60);
if (vAddr.size() < 1000)
pfrom.fGetAddr = false;

@ -150,6 +150,8 @@ struct CNodeStateStats {
int nSyncHeight = -1;
int nCommonHeight = -1;
std::vector<int> vHeightInFlight;
uint64_t m_addr_processed = 0;
uint64_t m_addr_rate_limited = 0;
};
/** Get statistics from node state */

@ -81,6 +81,15 @@ ChainstateManager& EnsureChainman(const util::Ref& context)
return *node.chainman;
}
CConnman& EnsureConnman(const util::Ref& context)
{
NodeContext& node = EnsureNodeContext(context);
if (!node.connman) {
throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled");
}
return *node.connman;
}
/* Calculate the difficulty for a given block index.
*/
double GetDifficulty(const CBlockIndex* blockindex)

@ -15,6 +15,7 @@ extern RecursiveMutex cs_main;
class CBlock;
class CBlockIndex;
class CConnman;
class CTxMemPool;
class ChainstateManager;
class UniValue;
@ -54,5 +55,6 @@ void CalculatePercentilesByWeight(CAmount result[NUM_GETBLOCKSTATS_PERCENTILES],
NodeContext& EnsureNodeContext(const util::Ref& context);
CTxMemPool& EnsureMemPool(const util::Ref& context);
ChainstateManager& EnsureChainman(const util::Ref& context);
CConnman& EnsureConnman(const util::Ref& context);
#endif

@ -139,6 +139,8 @@ static RPCHelpMan getpeerinfo()
{
{RPCResult::Type::NUM, "n", "The heights of blocks we're currently asking from this peer"},
}},
{RPCResult::Type::NUM, "addr_processed", "The total number of addresses processed, excluding those dropped due to rate limiting"},
{RPCResult::Type::NUM, "addr_rate_limited", "The total number of addresses dropped due to rate limiting"},
{RPCResult::Type::BOOL, "whitelisted", /* optional */ true, "Whether the peer is whitelisted with default permissions\n"
"(DEPRECATED, returned only if config option -deprecatedrpc=whitelisted is passed)"},
{RPCResult::Type::ARR, "permissions", "Any special permissions that have been granted to this peer",
@ -236,6 +238,8 @@ static RPCHelpMan getpeerinfo()
heights.push_back(height);
}
obj.pushKV("inflight", heights);
obj.pushKV("addr_processed", statestats.m_addr_processed);
obj.pushKV("addr_rate_limited", statestats.m_addr_rate_limited);
}
if (IsDeprecatedRPCEnabled("whitelisted")) {
// whitelisted is deprecated in v0.21 for removal in v0.22
@ -326,6 +330,61 @@ static RPCHelpMan addnode()
};
}
static RPCHelpMan addconnection()
{
return RPCHelpMan{"addconnection",
"\nOpen an outbound connection to a specified node. This RPC is for testing only.\n",
{
{"address", RPCArg::Type::STR, RPCArg::Optional::NO, "The IP address and port to attempt connecting to."},
{"connection_type", RPCArg::Type::STR, RPCArg::Optional::NO, "Type of connection to open (\"outbound-full-relay\", \"block-relay-only\", \"addr-fetch\" or \"feeler\")."},
},
RPCResult{
RPCResult::Type::OBJ, "", "",
{
{ RPCResult::Type::STR, "address", "Address of newly added connection." },
{ RPCResult::Type::STR, "connection_type", "Type of connection opened." },
}},
RPCExamples{
HelpExampleCli("addconnection", "\"192.168.0.6:8333\" \"outbound-full-relay\"")
+ HelpExampleRpc("addconnection", "\"192.168.0.6:8333\" \"outbound-full-relay\"")
},
[&](const RPCHelpMan& self, const JSONRPCRequest& request) -> UniValue
{
if (Params().NetworkIDString() != CBaseChainParams::REGTEST) {
throw std::runtime_error("addconnection is for regression testing (-regtest mode) only.");
}
const std::string address = request.params[0].get_str();
const std::string conn_type_in{TrimString(request.params[1].get_str())};
ConnectionType conn_type{};
if (conn_type_in == "outbound-full-relay") {
conn_type = ConnectionType::OUTBOUND_FULL_RELAY;
} else if (conn_type_in == "block-relay-only") {
conn_type = ConnectionType::BLOCK_RELAY;
} else if (conn_type_in == "addr-fetch") {
conn_type = ConnectionType::ADDR_FETCH;
} else if (conn_type_in == "feeler") {
conn_type = ConnectionType::FEELER;
} else {
throw JSONRPCError(RPC_INVALID_PARAMETER, self.ToString());
}
CConnman& connman = EnsureConnman(request.context);
const bool success = connman.AddConnection(address, conn_type);
if (!success) {
throw JSONRPCError(RPC_CLIENT_NODE_CAPACITY_REACHED, "Error: Already at capacity for specified connection type.");
}
UniValue info(UniValue::VOBJ);
info.pushKV("address", address);
info.pushKV("connection_type", conn_type_in);
return info;
},
};
}
static RPCHelpMan disconnectnode()
{
return RPCHelpMan{"disconnectnode",
@ -910,6 +969,7 @@ static const CRPCCommand commands[] =
{ "network", "clearbanned", &clearbanned, {} },
{ "network", "setnetworkactive", &setnetworkactive, {"state"} },
{ "network", "getnodeaddresses", &getnodeaddresses, {"count"} },
{ "hidden", "addconnection", &addconnection, {"address", "connection_type"} },
{ "hidden", "addpeeraddress", &addpeeraddress, {"address", "port"} },
};
// clang-format on

@ -62,6 +62,7 @@ enum RPCErrorCode
RPC_CLIENT_NODE_NOT_CONNECTED = -29, //!< Node to disconnect not found in connected nodes
RPC_CLIENT_INVALID_IP_OR_SUBNET = -30, //!< Invalid IP/Subnet
RPC_CLIENT_P2P_DISABLED = -31, //!< No valid connection manager instance found
RPC_CLIENT_NODE_CAPACITY_REACHED= -34, //!< Max number of outbound or block-relay connections already open
//! Chain errors
RPC_CLIENT_MEMPOOL_DISABLED = -33, //!< No mempool instance found

@ -10,21 +10,38 @@
#include <string>
#include <chrono>
using namespace std::chrono_literals;
void UninterruptibleSleep(const std::chrono::microseconds& n);
/**
* Helper to count the seconds of a duration.
* Helper to count the seconds of a duration/time_point.
*
* All durations should be using std::chrono and calling this should generally
* All durations/time_points should be using std::chrono and calling this should generally
* be avoided in code. Though, it is still preferred to an inline t.count() to
* protect against a reliance on the exact type of t.
*
* This helper is used to convert durations before passing them over an
* This helper is used to convert durations/time_points before passing them over an
* interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI)
*/
template <typename Dur1, typename Dur2>
inline auto Ticks(Dur2 d)
{
return std::chrono::duration_cast<Dur1>(d).count();
}
template <typename Duration, typename Timepoint>
inline auto TicksSinceEpoch(Timepoint t)
{
return Ticks<Duration>(t.time_since_epoch());
}
inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); }
inline int64_t count_milliseconds(std::chrono::milliseconds t) { return t.count(); }
inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); }
using HoursDouble = std::chrono::duration<double, std::chrono::hours::period>;
using SecondsDouble = std::chrono::duration<double, std::chrono::seconds::period>;
using MillisecondsDouble = std::chrono::duration<double, std::chrono::milliseconds::period>;
/**
* DEPRECATED
* Use either GetSystemTimeInSeconds (not mockable) or GetTime<T> (mockable)

@ -12,11 +12,15 @@ from test_framework.messages import (
NODE_WITNESS,
msg_addr,
)
from test_framework.p2p import P2PInterface
from test_framework.p2p import (
P2PInterface,
p2p_lock,
)
from test_framework.test_framework import BitcoinTestFramework
from test_framework.util import (
assert_equal,
)
import random
import time
ADDRS = []
@ -30,17 +34,33 @@ for i in range(10):
class AddrReceiver(P2PInterface):
_tokens = 1
def on_addr(self, message):
for addr in message.addrs:
assert_equal(addr.nServices, 9)
assert addr.ip.startswith('123.123.123.')
assert (8333 <= addr.port < 8343)
def on_getaddr(self, message):
# When the node sends us a getaddr, it increments the addr relay tokens for the connection by 1000
self._tokens += 1000
@property
def tokens(self):
with p2p_lock:
return self._tokens
def increment_tokens(self, n):
# When we move mocktime forward, the node increments the addr relay tokens for its peers
with p2p_lock:
self._tokens += n
class AddrTest(BitcoinTestFramework):
def set_test_params(self):
self.setup_clean_chain = False
self.num_nodes = 1
self.extra_args = [["-whitelist=addr@127.0.0.1"]]
def run_test(self):
self.log.info('Create connection that sends addr messages')
@ -64,6 +84,77 @@ class AddrTest(BitcoinTestFramework):
self.nodes[0].setmocktime(int(time.time()) + 30 * 60)
addr_receiver.sync_with_ping()
self.rate_limit_tests()
def setup_rand_addr_msg(self, num):
addrs = []
for i in range(num):
addr = CAddress()
addr.time = self.mocktime + i
addr.nServices = NODE_NETWORK | NODE_WITNESS
addr.ip = f"{random.randrange(128,169)}.{random.randrange(1,255)}.{random.randrange(1,255)}.{random.randrange(1,255)}"
addr.port = 8333
addrs.append(addr)
msg = msg_addr()
msg.addrs = addrs
return msg
def send_addrs_and_test_rate_limiting(self, peer, no_relay, *, new_addrs, total_addrs):
"""Send an addr message and check that the number of addresses processed and rate-limited is as expected"""
peer.send_and_ping(self.setup_rand_addr_msg(new_addrs))
peerinfo = self.nodes[0].getpeerinfo()[0]
addrs_processed = peerinfo['addr_processed']
addrs_rate_limited = peerinfo['addr_rate_limited']
self.log.debug(f"addrs_processed = {addrs_processed}, addrs_rate_limited = {addrs_rate_limited}")
if no_relay:
assert_equal(addrs_processed, 0)
assert_equal(addrs_rate_limited, 0)
else:
assert_equal(addrs_processed, min(total_addrs, peer.tokens))
assert_equal(addrs_rate_limited, max(0, total_addrs - peer.tokens))
def rate_limit_tests(self):
self.mocktime = int(time.time())
self.restart_node(0, [])
self.nodes[0].setmocktime(self.mocktime)
for conn_type, no_relay in [("outbound-full-relay", False), ("block-relay-only", True), ("inbound", False)]:
self.log.info(f'Test rate limiting of addr processing for {conn_type} peers')
if conn_type == "inbound":
peer = self.nodes[0].add_p2p_connection(AddrReceiver())
else:
peer = self.nodes[0].add_outbound_p2p_connection(AddrReceiver(), p2p_idx=0, connection_type=conn_type)
# Send 600 addresses. For all but the block-relay-only peer this should result in addresses being processed.
self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=600, total_addrs=600)
# Send 600 more addresses. For the outbound-full-relay peer (which we send a GETADDR, and thus will
# process up to 1001 incoming addresses), this means more addresses will be processed.
self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=600, total_addrs=1200)
# Send 10 more. As we reached the processing limit for all nodes, no more addresses should be procesesd.
self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=10, total_addrs=1210)
# Advance the time by 100 seconds, permitting the processing of 10 more addresses.
# Send 200 and verify that 10 are processed.
self.mocktime += 100
self.nodes[0].setmocktime(self.mocktime)
peer.increment_tokens(10)
self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=200, total_addrs=1410)
# Advance the time by 1000 seconds, permitting the processing of 100 more addresses.
# Send 200 and verify that 100 are processed.
self.mocktime += 1000
self.nodes[0].setmocktime(self.mocktime)
peer.increment_tokens(100)
self.send_addrs_and_test_rate_limiting(peer, no_relay, new_addrs=200, total_addrs=1610)
self.nodes[0].disconnect_p2ps()
if __name__ == '__main__':
AddrTest().main()

@ -39,7 +39,10 @@ class AddrReceiver(P2PInterface):
assert_equal(addr.nServices, 9)
assert addr.ip.startswith('123.123.123.')
assert (8333 <= addr.port < 8343)
self.addrv2_received_and_checked = True
expected_set = set((addr.ip, addr.port) for addr in ADDRS)
received_set = set((addr.ip, addr.port) for addr in message.addrs)
if expected_set == received_set:
self.addrv2_received_and_checked = True
def wait_for_addrv2(self):
self.wait_until(lambda: "addrv2" in self.last_message)
@ -49,6 +52,7 @@ class AddrTest(BitcoinTestFramework):
def set_test_params(self):
self.setup_clean_chain = True
self.num_nodes = 1
self.extra_args = [["-whitelist=addr@127.0.0.1"]]
def run_test(self):
self.log.info('Create connection that sends addrv2 messages')

@ -57,6 +57,7 @@ class InvalidMessagesTest(BitcoinTestFramework):
def set_test_params(self):
self.num_nodes = 1
self.setup_clean_chain = True
self.extra_args = [["-whitelist=addr@127.0.0.1"]]
def run_test(self):
self.test_buffer()

@ -72,7 +72,11 @@ from test_framework.messages import (
NODE_WITNESS,
sha256,
)
from test_framework.util import wait_until_helper
from test_framework.util import (
MAX_NODES,
p2p_port,
wait_until_helper,
)
logger = logging.getLogger("TestFramework.p2p")
@ -140,7 +144,7 @@ class P2PConnection(asyncio.Protocol):
def is_connected(self):
return self._transport is not None
def peer_connect(self, dstaddr, dstport, *, net, timeout_factor):
def peer_connect_helper(self, dstaddr, dstport, net, timeout_factor):
assert not self.is_connected
self.timeout_factor = timeout_factor
self.dstaddr = dstaddr
@ -149,12 +153,20 @@ class P2PConnection(asyncio.Protocol):
self.on_connection_send_msg = None
self.recvbuf = b""
self.magic_bytes = MAGIC_BYTES[net]
logger.debug('Connecting to Litecoin Node: %s:%d' % (self.dstaddr, self.dstport))
def peer_connect(self, dstaddr, dstport, *, net, timeout_factor):
self.peer_connect_helper(dstaddr, dstport, net, timeout_factor)
loop = NetworkThread.network_event_loop
conn_gen_unsafe = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport)
conn_gen = lambda: loop.call_soon_threadsafe(loop.create_task, conn_gen_unsafe)
return conn_gen
logger.debug('Connecting to Bitcoin Node: %s:%d' % (self.dstaddr, self.dstport))
coroutine = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport)
return lambda: loop.call_soon_threadsafe(loop.create_task, coroutine)
def peer_accept_connection(self, connect_id, connect_cb=lambda: None, *, net, timeout_factor):
self.peer_connect_helper('0', 0, net, timeout_factor)
logger.debug('Listening for Bitcoin Node with id: {}'.format(connect_id))
return lambda: NetworkThread.listen(self, connect_cb, idx=connect_id)
def peer_disconnect(self):
# Connection could have already been closed by other end.
@ -310,18 +322,27 @@ class P2PInterface(P2PConnection):
self.support_addrv2 = support_addrv2
def peer_connect_send_version(self, services):
# Send a version msg
vt = msg_version()
vt.nServices = services
vt.addrTo.ip = self.dstaddr
vt.addrTo.port = self.dstport
vt.addrFrom.ip = "0.0.0.0"
vt.addrFrom.port = 0
self.on_connection_send_msg = vt # Will be sent in connection_made callback
def peer_connect(self, *args, services=NODE_NETWORK|NODE_WITNESS|NODE_MWEB, send_version=True, **kwargs):
create_conn = super().peer_connect(*args, **kwargs)
if send_version:
# Send a version msg
vt = msg_version()
vt.nServices = services
vt.addrTo.ip = self.dstaddr
vt.addrTo.port = self.dstport
vt.addrFrom.ip = "0.0.0.0"
vt.addrFrom.port = 0
self.on_connection_send_msg = vt # Will be sent soon after connection_made
self.peer_connect_send_version(services)
return create_conn
def peer_accept_connection(self, *args, services=NODE_NETWORK|NODE_WITNESS|NODE_MWEB, **kwargs):
create_conn = super().peer_accept_connection(*args, **kwargs)
self.peer_connect_send_version(services)
return create_conn
@ -412,6 +433,10 @@ class P2PInterface(P2PConnection):
wait_until_helper(test_function, timeout=timeout, lock=p2p_lock, timeout_factor=self.timeout_factor)
def wait_for_connect(self, timeout=60):
test_function = lambda: self.is_connected
wait_until_helper(test_function, timeout=timeout, lock=p2p_lock)
def wait_for_disconnect(self, timeout=60):
test_function = lambda: not self.is_connected
self.wait_until(test_function, timeout=timeout, check_connected=False)
@ -525,6 +550,10 @@ class NetworkThread(threading.Thread):
# There is only one event loop and no more than one thread must be created
assert not self.network_event_loop
NetworkThread.listeners = {}
NetworkThread.protos = {}
if sys.platform == 'win32':
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
NetworkThread.network_event_loop = asyncio.new_event_loop()
def run(self):
@ -540,6 +569,47 @@ class NetworkThread(threading.Thread):
# Safe to remove event loop.
NetworkThread.network_event_loop = None
@classmethod
def listen(cls, p2p, callback, port=None, addr=None, idx=1):
""" Ensure a listening server is running on the given port, and run the
protocol specified by `p2p` on the next connection to it. Once ready
for connections, call `callback`."""
if port is None:
assert 0 < idx <= MAX_NODES
port = p2p_port(MAX_NODES - idx)
if addr is None:
addr = '127.0.0.1'
coroutine = cls.create_listen_server(addr, port, callback, p2p)
cls.network_event_loop.call_soon_threadsafe(cls.network_event_loop.create_task, coroutine)
@classmethod
async def create_listen_server(cls, addr, port, callback, proto):
def peer_protocol():
"""Returns a function that does the protocol handling for a new
connection. To allow different connections to have different
behaviors, the protocol function is first put in the cls.protos
dict. When the connection is made, the function removes the
protocol function from that dict, and returns it so the event loop
can start executing it."""
response = cls.protos.get((addr, port))
cls.protos[(addr, port)] = None
return response
if (addr, port) not in cls.listeners:
# When creating a listener on a given (addr, port) we only need to
# do it once. If we want different behaviors for different
# connections, we can accomplish this by providing different
# `proto` functions
listener = await cls.network_event_loop.create_server(peer_protocol, addr, port)
logger.debug("Listening server on %s:%d should be started" % (addr, port))
cls.listeners[(addr, port)] = listener
cls.protos[(addr, port)] = proto
callback(addr, port)
class P2PDataStore(P2PInterface):
"""A P2P data store class.

@ -542,6 +542,38 @@ class TestNode():
return p2p_conn
def add_outbound_p2p_connection(self, p2p_conn, *, wait_for_verack=True, p2p_idx, connection_type="outbound-full-relay", **kwargs):
"""Add an outbound p2p connection from node. Must be an
"outbound-full-relay", "block-relay-only", "addr-fetch" or "feeler" connection.
This method adds the p2p connection to the self.p2ps list and returns
the connection to the caller.
p2p_idx must be different for simultaneously connected peers. When reusing it for the next peer
after disconnecting the previous one, it is necessary to wait for the disconnect to finish to avoid
a race condition.
"""
def addconnection_callback(address, port):
self.log.debug("Connecting to %s:%d %s" % (address, port, connection_type))
self.addconnection('%s:%d' % (address, port), connection_type)
p2p_conn.peer_accept_connection(connect_cb=addconnection_callback, connect_id=p2p_idx + 1, net=self.chain, timeout_factor=self.timeout_factor, **kwargs)()
if connection_type == "feeler":
# feeler connections are closed as soon as the node receives a `version` message
p2p_conn.wait_until(lambda: p2p_conn.message_count["version"] == 1, check_connected=False)
p2p_conn.wait_until(lambda: not p2p_conn.is_connected, check_connected=False)
else:
p2p_conn.wait_for_connect()
self.p2ps.append(p2p_conn)
if wait_for_verack:
p2p_conn.wait_for_verack()
p2p_conn.sync_with_ping()
return p2p_conn
def num_test_p2p_connections(self):
"""Return number of test framework p2p connections to the node."""
return len([peer for peer in self.getpeerinfo() if peer['subver'] == MY_SUBVERSION])

Loading…
Cancel
Save