diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 519f022cc88..feb0563409e 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -2372,6 +2372,13 @@ bool CWallet::SelectCoins(const std::vector& vAvailableCoins, const CAm ++it; } + unsigned int limit_ancestor_count = 0; + unsigned int limit_descendant_count = 0; + chain().getPackageLimits(limit_ancestor_count, limit_descendant_count); + size_t max_ancestors = (size_t)std::max(1, limit_ancestor_count); + size_t max_descendants = (size_t)std::max(1, limit_descendant_count); + bool fRejectLongChains = gArgs.GetBoolArg("-walletrejectlongchains", DEFAULT_WALLET_REJECT_LONG_CHAINS); + // form groups from remaining coins; note that preset coins will not // automatically have their associated (same address) coins included if (coin_control.m_avoid_partial_spends && vCoins.size() > OUTPUT_GROUP_MAX_ENTRIES) { @@ -2380,14 +2387,7 @@ bool CWallet::SelectCoins(const std::vector& vAvailableCoins, const CAm // explicitly shuffling the outputs before processing Shuffle(vCoins.begin(), vCoins.end(), FastRandomContext()); } - std::vector groups = GroupOutputs(vCoins, !coin_control.m_avoid_partial_spends); - - unsigned int limit_ancestor_count; - unsigned int limit_descendant_count; - chain().getPackageLimits(limit_ancestor_count, limit_descendant_count); - size_t max_ancestors = (size_t)std::max(1, limit_ancestor_count); - size_t max_descendants = (size_t)std::max(1, limit_descendant_count); - bool fRejectLongChains = gArgs.GetBoolArg("-walletrejectlongchains", DEFAULT_WALLET_REJECT_LONG_CHAINS); + std::vector groups = GroupOutputs(vCoins, !coin_control.m_avoid_partial_spends, max_ancestors); bool res = value_to_select <= 0 || SelectCoinsMinConf(value_to_select, CoinEligibilityFilter(1, 6, 0), groups, setCoinsRet, nValueRet, coin_selection_params, bnb_used) || @@ -4184,32 +4184,49 @@ bool CWalletTx::IsImmatureCoinBase() const return GetBlocksToMaturity() > 0; } -std::vector CWallet::GroupOutputs(const std::vector& outputs, bool single_coin) const { +std::vector CWallet::GroupOutputs(const std::vector& outputs, bool single_coin, const size_t max_ancestors) const { std::vector groups; std::map gmap; - CTxDestination dst; + std::set full_groups; + for (const auto& output : outputs) { if (output.fSpendable) { + CTxDestination dst; CInputCoin input_coin = output.GetInputCoin(); size_t ancestors, descendants; chain().getTransactionAncestry(output.tx->GetHash(), ancestors, descendants); if (!single_coin && ExtractDestination(output.tx->tx->vout[output.i].scriptPubKey, dst)) { - // Limit output groups to no more than 10 entries, to protect - // against inadvertently creating a too-large transaction - // when using -avoidpartialspends - if (gmap[dst].m_outputs.size() >= OUTPUT_GROUP_MAX_ENTRIES) { - groups.push_back(gmap[dst]); - gmap.erase(dst); + auto it = gmap.find(dst); + if (it != gmap.end()) { + // Limit output groups to no more than OUTPUT_GROUP_MAX_ENTRIES + // number of entries, to protect against inadvertently creating + // a too-large transaction when using -avoidpartialspends to + // prevent breaking consensus or surprising users with a very + // high amount of fees. + if (it->second.m_outputs.size() >= OUTPUT_GROUP_MAX_ENTRIES) { + groups.push_back(it->second); + it->second = OutputGroup{}; + full_groups.insert(dst); + } + it->second.Insert(input_coin, output.nDepth, output.tx->IsFromMe(ISMINE_ALL), ancestors, descendants); + } else { + gmap[dst].Insert(input_coin, output.nDepth, output.tx->IsFromMe(ISMINE_ALL), ancestors, descendants); } - gmap[dst].Insert(input_coin, output.nDepth, output.tx->IsFromMe(ISMINE_ALL), ancestors, descendants); } else { groups.emplace_back(input_coin, output.nDepth, output.tx->IsFromMe(ISMINE_ALL), ancestors, descendants); } } } if (!single_coin) { - for (const auto& it : gmap) groups.push_back(it.second); + for (auto& it : gmap) { + auto& group = it.second; + if (full_groups.count(it.first) > 0) { + // Make this unattractive as we want coin selection to avoid it if possible + group.m_ancestors = max_ancestors - 1; + } + groups.push_back(group); + } } return groups; } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 577a739d89d..638c8562c9a 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -830,7 +830,7 @@ public: bool IsSpentKey(const uint256& hash, unsigned int n) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); void SetSpentKeyState(WalletBatch& batch, const uint256& hash, unsigned int n, bool used, std::set& tx_destinations) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); - std::vector GroupOutputs(const std::vector& outputs, bool single_coin) const; + std::vector GroupOutputs(const std::vector& outputs, bool single_coin, const size_t max_ancestors) const; bool IsLockedCoin(uint256 hash, unsigned int n) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); void LockCoin(const COutPoint& output) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); diff --git a/test/functional/wallet_avoidreuse.py b/test/functional/wallet_avoidreuse.py index 2ce8d459c6d..78a51a1d5ff 100755 --- a/test/functional/wallet_avoidreuse.py +++ b/test/functional/wallet_avoidreuse.py @@ -85,15 +85,19 @@ class AvoidReuseTest(BitcoinTestFramework): self.sync_all() self.test_change_remains_change(self.nodes[1]) reset_balance(self.nodes[1], self.nodes[0].getnewaddress()) - self.test_fund_send_fund_senddirty() + self.test_sending_from_reused_address_without_avoid_reuse() reset_balance(self.nodes[1], self.nodes[0].getnewaddress()) - self.test_fund_send_fund_send("legacy") + self.test_sending_from_reused_address_fails("legacy") reset_balance(self.nodes[1], self.nodes[0].getnewaddress()) - self.test_fund_send_fund_send("p2sh-segwit") + self.test_sending_from_reused_address_fails("p2sh-segwit") reset_balance(self.nodes[1], self.nodes[0].getnewaddress()) - self.test_fund_send_fund_send("bech32") + self.test_sending_from_reused_address_fails("bech32") reset_balance(self.nodes[1], self.nodes[0].getnewaddress()) self.test_getbalances_used() + reset_balance(self.nodes[1], self.nodes[0].getnewaddress()) + self.test_full_destination_group_is_preferred() + reset_balance(self.nodes[1], self.nodes[0].getnewaddress()) + self.test_all_destination_groups_are_used() def test_persistence(self): '''Test that wallet files persist the avoid_reuse flag.''' @@ -162,13 +166,13 @@ class AvoidReuseTest(BitcoinTestFramework): for logical_tx in node.listtransactions(): assert logical_tx.get('address') != changeaddr - def test_fund_send_fund_senddirty(self): + def test_sending_from_reused_address_without_avoid_reuse(self): ''' - Test the same as test_fund_send_fund_send, except send the 10 BTC with + Test the same as test_sending_from_reused_address_fails, except send the 10 BTC with the avoid_reuse flag set to false. This means the 10 BTC send should succeed, - where it fails in test_fund_send_fund_send. + where it fails in test_sending_from_reused_address_fails. ''' - self.log.info("Test fund send fund send dirty") + self.log.info("Test sending from reused address with avoid_reuse=false") fundaddr = self.nodes[1].getnewaddress() retaddr = self.nodes[0].getnewaddress() @@ -213,7 +217,7 @@ class AvoidReuseTest(BitcoinTestFramework): assert_approx(self.nodes[1].getbalance(), 5, 0.001) assert_approx(self.nodes[1].getbalance(avoid_reuse=False), 5, 0.001) - def test_fund_send_fund_send(self, second_addr_type): + def test_sending_from_reused_address_fails(self, second_addr_type): ''' Test the simple case where [1] generates a new address A, then [0] sends 10 BTC to A. @@ -222,7 +226,7 @@ class AvoidReuseTest(BitcoinTestFramework): [1] tries to spend 10 BTC (fails; dirty). [1] tries to spend 4 BTC (succeeds; change address sufficient) ''' - self.log.info("Test fund send fund send") + self.log.info("Test sending from reused {} address fails".format(second_addr_type)) fundaddr = self.nodes[1].getnewaddress(label="", address_type="legacy") retaddr = self.nodes[0].getnewaddress() @@ -313,5 +317,66 @@ class AvoidReuseTest(BitcoinTestFramework): assert_unspent(self.nodes[1], total_count=2, total_sum=6, reused_count=1, reused_sum=1) assert_balances(self.nodes[1], mine={"used": 1, "trusted": 5}) + def test_full_destination_group_is_preferred(self): + ''' + Test the case where [1] only has 11 outputs of 1 BTC in the same reused + address and tries to send a small payment of 0.5 BTC. The wallet + should use 10 outputs from the reused address as inputs and not a + single 1 BTC input, in order to join several outputs from the reused + address. + ''' + self.log.info("Test that full destination groups are preferred in coin selection") + + # Node under test should be empty + assert_equal(self.nodes[1].getbalance(avoid_reuse=False), 0) + + new_addr = self.nodes[1].getnewaddress() + ret_addr = self.nodes[0].getnewaddress() + + # Send 11 outputs of 1 BTC to the same, reused address in the wallet + for _ in range(11): + self.nodes[0].sendtoaddress(new_addr, 1) + + self.nodes[0].generate(1) + self.sync_all() + + # Sending a transaction that is smaller than each one of the + # available outputs + txid = self.nodes[1].sendtoaddress(address=ret_addr, amount=0.5) + inputs = self.nodes[1].getrawtransaction(txid, 1)["vin"] + + # The transaction should use 10 inputs exactly + assert_equal(len(inputs), 10) + + def test_all_destination_groups_are_used(self): + ''' + Test the case where [1] only has 22 outputs of 1 BTC in the same reused + address and tries to send a payment of 20.5 BTC. The wallet + should use all 22 outputs from the reused address as inputs. + ''' + self.log.info("Test that all destination groups are used") + + # Node under test should be empty + assert_equal(self.nodes[1].getbalance(avoid_reuse=False), 0) + + new_addr = self.nodes[1].getnewaddress() + ret_addr = self.nodes[0].getnewaddress() + + # Send 22 outputs of 1 BTC to the same, reused address in the wallet + for _ in range(22): + self.nodes[0].sendtoaddress(new_addr, 1) + + self.nodes[0].generate(1) + self.sync_all() + + # Sending a transaction that needs to use the full groups + # of 10 inputs but also the incomplete group of 2 inputs. + txid = self.nodes[1].sendtoaddress(address=ret_addr, amount=20.5) + inputs = self.nodes[1].getrawtransaction(txid, 1)["vin"] + + # The transaction should use 22 inputs exactly + assert_equal(len(inputs), 22) + + if __name__ == '__main__': AvoidReuseTest().main()