net: pass socket closing responsibility up to caller for outgoing connections

This allows const references to be passed around, making it clear where the
socket may and may not be invalidated.
pull/476/merge
Cory Fields 7 years ago
parent 9e3b2f576b
commit df3bcf89e4

@ -317,12 +317,11 @@ std::string Socks5ErrorString(uint8_t err)
} }
/** Connect using SOCKS5 (as described in RFC1928) */ /** Connect using SOCKS5 (as described in RFC1928) */
static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, SOCKET& hSocket) static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, const SOCKET& hSocket)
{ {
IntrRecvError recvr; IntrRecvError recvr;
LogPrint(BCLog::NET, "SOCKS5 connecting %s\n", strDest); LogPrint(BCLog::NET, "SOCKS5 connecting %s\n", strDest);
if (strDest.size() > 255) { if (strDest.size() > 255) {
CloseSocket(hSocket);
return error("Hostname too long"); return error("Hostname too long");
} }
// Accepted authentication methods // Accepted authentication methods
@ -338,17 +337,14 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
} }
ssize_t ret = send(hSocket, (const char*)vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL); ssize_t ret = send(hSocket, (const char*)vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL);
if (ret != (ssize_t)vSocks5Init.size()) { if (ret != (ssize_t)vSocks5Init.size()) {
CloseSocket(hSocket);
return error("Error sending to proxy"); return error("Error sending to proxy");
} }
uint8_t pchRet1[2]; uint8_t pchRet1[2];
if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
CloseSocket(hSocket);
LogPrintf("Socks5() connect to %s:%d failed: InterruptibleRecv() timeout or other failure\n", strDest, port); LogPrintf("Socks5() connect to %s:%d failed: InterruptibleRecv() timeout or other failure\n", strDest, port);
return false; return false;
} }
if (pchRet1[0] != SOCKSVersion::SOCKS5) { if (pchRet1[0] != SOCKSVersion::SOCKS5) {
CloseSocket(hSocket);
return error("Proxy failed to initialize"); return error("Proxy failed to initialize");
} }
if (pchRet1[1] == SOCKS5Method::USER_PASS && auth) { if (pchRet1[1] == SOCKS5Method::USER_PASS && auth) {
@ -363,23 +359,19 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end()); vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end());
ret = send(hSocket, (const char*)vAuth.data(), vAuth.size(), MSG_NOSIGNAL); ret = send(hSocket, (const char*)vAuth.data(), vAuth.size(), MSG_NOSIGNAL);
if (ret != (ssize_t)vAuth.size()) { if (ret != (ssize_t)vAuth.size()) {
CloseSocket(hSocket);
return error("Error sending authentication to proxy"); return error("Error sending authentication to proxy");
} }
LogPrint(BCLog::PROXY, "SOCKS5 sending proxy authentication %s:%s\n", auth->username, auth->password); LogPrint(BCLog::PROXY, "SOCKS5 sending proxy authentication %s:%s\n", auth->username, auth->password);
uint8_t pchRetA[2]; uint8_t pchRetA[2];
if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
CloseSocket(hSocket);
return error("Error reading proxy authentication response"); return error("Error reading proxy authentication response");
} }
if (pchRetA[0] != 0x01 || pchRetA[1] != 0x00) { if (pchRetA[0] != 0x01 || pchRetA[1] != 0x00) {
CloseSocket(hSocket);
return error("Proxy authentication unsuccessful"); return error("Proxy authentication unsuccessful");
} }
} else if (pchRet1[1] == SOCKS5Method::NOAUTH) { } else if (pchRet1[1] == SOCKS5Method::NOAUTH) {
// Perform no authentication // Perform no authentication
} else { } else {
CloseSocket(hSocket);
return error("Proxy requested wrong authentication method %02x", pchRet1[1]); return error("Proxy requested wrong authentication method %02x", pchRet1[1]);
} }
std::vector<uint8_t> vSocks5; std::vector<uint8_t> vSocks5;
@ -393,12 +385,10 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
vSocks5.push_back((port >> 0) & 0xFF); vSocks5.push_back((port >> 0) & 0xFF);
ret = send(hSocket, (const char*)vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL); ret = send(hSocket, (const char*)vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL);
if (ret != (ssize_t)vSocks5.size()) { if (ret != (ssize_t)vSocks5.size()) {
CloseSocket(hSocket);
return error("Error sending to proxy"); return error("Error sending to proxy");
} }
uint8_t pchRet2[4]; uint8_t pchRet2[4];
if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
CloseSocket(hSocket);
if (recvr == IntrRecvError::Timeout) { if (recvr == IntrRecvError::Timeout) {
/* If a timeout happens here, this effectively means we timed out while connecting /* If a timeout happens here, this effectively means we timed out while connecting
* to the remote node. This is very common for Tor, so do not print an * to the remote node. This is very common for Tor, so do not print an
@ -409,17 +399,14 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
} }
} }
if (pchRet2[0] != SOCKSVersion::SOCKS5) { if (pchRet2[0] != SOCKSVersion::SOCKS5) {
CloseSocket(hSocket);
return error("Proxy failed to accept request"); return error("Proxy failed to accept request");
} }
if (pchRet2[1] != SOCKS5Reply::SUCCEEDED) { if (pchRet2[1] != SOCKS5Reply::SUCCEEDED) {
// Failures to connect to a peer that are not proxy errors // Failures to connect to a peer that are not proxy errors
CloseSocket(hSocket);
LogPrintf("Socks5() connect to %s:%d failed: %s\n", strDest, port, Socks5ErrorString(pchRet2[1])); LogPrintf("Socks5() connect to %s:%d failed: %s\n", strDest, port, Socks5ErrorString(pchRet2[1]));
return false; return false;
} }
if (pchRet2[2] != 0x00) { // Reserved field must be 0 if (pchRet2[2] != 0x00) { // Reserved field must be 0
CloseSocket(hSocket);
return error("Error: malformed proxy response"); return error("Error: malformed proxy response");
} }
uint8_t pchRet3[256]; uint8_t pchRet3[256];
@ -431,21 +418,18 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
{ {
recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, hSocket); recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, hSocket);
if (recvr != IntrRecvError::OK) { if (recvr != IntrRecvError::OK) {
CloseSocket(hSocket);
return error("Error reading from proxy"); return error("Error reading from proxy");
} }
int nRecv = pchRet3[0]; int nRecv = pchRet3[0];
recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, hSocket); recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, hSocket);
break; break;
} }
default: CloseSocket(hSocket); return error("Error: malformed proxy response"); default: return error("Error: malformed proxy response");
} }
if (recvr != IntrRecvError::OK) { if (recvr != IntrRecvError::OK) {
CloseSocket(hSocket);
return error("Error reading from proxy"); return error("Error reading from proxy");
} }
if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
CloseSocket(hSocket);
return error("Error reading from proxy"); return error("Error reading from proxy");
} }
LogPrint(BCLog::NET, "SOCKS5 connected %s\n", strDest); LogPrint(BCLog::NET, "SOCKS5 connected %s\n", strDest);
@ -488,7 +472,7 @@ SOCKET CreateSocket(const CService &addrConnect)
return hSocket; return hSocket;
} }
bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocket, int nTimeout) bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, int nTimeout)
{ {
struct sockaddr_storage sockaddr; struct sockaddr_storage sockaddr;
socklen_t len = sizeof(sockaddr); socklen_t len = sizeof(sockaddr);
@ -498,7 +482,6 @@ bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocket, int nTi
} }
if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
LogPrintf("Cannot connect to %s: unsupported network\n", addrConnect.ToString()); LogPrintf("Cannot connect to %s: unsupported network\n", addrConnect.ToString());
CloseSocket(hSocket);
return false; return false;
} }
if (connect(hSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) if (connect(hSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
@ -515,13 +498,11 @@ bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocket, int nTi
if (nRet == 0) if (nRet == 0)
{ {
LogPrint(BCLog::NET, "connection to %s timeout\n", addrConnect.ToString()); LogPrint(BCLog::NET, "connection to %s timeout\n", addrConnect.ToString());
CloseSocket(hSocket);
return false; return false;
} }
if (nRet == SOCKET_ERROR) if (nRet == SOCKET_ERROR)
{ {
LogPrintf("select() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); LogPrintf("select() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError()));
CloseSocket(hSocket);
return false; return false;
} }
socklen_t nRetSize = sizeof(nRet); socklen_t nRetSize = sizeof(nRet);
@ -532,13 +513,11 @@ bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocket, int nTi
#endif #endif
{ {
LogPrintf("getsockopt() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); LogPrintf("getsockopt() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError()));
CloseSocket(hSocket);
return false; return false;
} }
if (nRet != 0) if (nRet != 0)
{ {
LogPrintf("connect() to %s failed after select(): %s\n", addrConnect.ToString(), NetworkErrorString(nRet)); LogPrintf("connect() to %s failed after select(): %s\n", addrConnect.ToString(), NetworkErrorString(nRet));
CloseSocket(hSocket);
return false; return false;
} }
} }
@ -549,7 +528,6 @@ bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocket, int nTi
#endif #endif
{ {
LogPrintf("connect() to %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); LogPrintf("connect() to %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError()));
CloseSocket(hSocket);
return false; return false;
} }
} }
@ -604,7 +582,7 @@ bool IsProxy(const CNetAddr &addr) {
return false; return false;
} }
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, SOCKET& hSocket, int nTimeout, bool *outProxyConnectionFailed) bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocket, int nTimeout, bool *outProxyConnectionFailed)
{ {
// first connect to proxy server // first connect to proxy server
if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout)) { if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout)) {
@ -618,12 +596,10 @@ bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int
static std::atomic_int counter(0); static std::atomic_int counter(0);
random_auth.username = random_auth.password = strprintf("%i", counter++); random_auth.username = random_auth.password = strprintf("%i", counter++);
if (!Socks5(strDest, (unsigned short)port, &random_auth, hSocket)) { if (!Socks5(strDest, (unsigned short)port, &random_auth, hSocket)) {
CloseSocket(hSocket);
return false; return false;
} }
} else { } else {
if (!Socks5(strDest, (unsigned short)port, 0, hSocket)) { if (!Socks5(strDest, (unsigned short)port, 0, hSocket)) {
CloseSocket(hSocket);
return false; return false;
} }
} }

@ -52,8 +52,8 @@ bool Lookup(const char *pszName, std::vector<CService>& vAddr, int portDefault,
CService LookupNumeric(const char *pszName, int portDefault = 0); CService LookupNumeric(const char *pszName, int portDefault = 0);
bool LookupSubNet(const char *pszName, CSubNet& subnet); bool LookupSubNet(const char *pszName, CSubNet& subnet);
SOCKET CreateSocket(const CService &addrConnect); SOCKET CreateSocket(const CService &addrConnect);
bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRet, int nTimeout); bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout);
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, SOCKET& hSocketRet, int nTimeout, bool *outProxyConnectionFailed); bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool *outProxyConnectionFailed);
/** Return readable error string for a network error code */ /** Return readable error string for a network error code */
std::string NetworkErrorString(int err); std::string NetworkErrorString(int err);
/** Close socket and set hSocket to INVALID_SOCKET */ /** Close socket and set hSocket to INVALID_SOCKET */

Loading…
Cancel
Save