Replace CAffectedKeysVisitor with descriptor based logic

pull/643/head
Pieter Wuille 6 years ago
parent fdf146f329
commit 0e75f44a09

@ -22,6 +22,7 @@
#include <policy/rbf.h> #include <policy/rbf.h>
#include <primitives/block.h> #include <primitives/block.h>
#include <primitives/transaction.h> #include <primitives/transaction.h>
#include <script/descriptor.h>
#include <script/script.h> #include <script/script.h>
#include <shutdown.h> #include <shutdown.h>
#include <timedata.h> #include <timedata.h>
@ -104,68 +105,18 @@ std::string COutput::ToString() const
return strprintf("COutput(%s, %d, %d) [%s]", tx->GetHash().ToString(), i, nDepth, FormatMoney(tx->tx->vout[i].nValue)); return strprintf("COutput(%s, %d, %d) [%s]", tx->GetHash().ToString(), i, nDepth, FormatMoney(tx->tx->vout[i].nValue));
} }
/** A class to identify which pubkeys a script and a keystore have in common. */ std::vector<CKeyID> GetAffectedKeys(const CScript& spk, const SigningProvider& provider)
class CAffectedKeysVisitor : public boost::static_visitor<void> {
private:
const CKeyStore &keystore;
std::vector<CKeyID> &vKeys;
public:
/**
* @param[in] keystoreIn The CKeyStore that is queried for the presence of a pubkey.
* @param[out] vKeysIn A vector to which a script's pubkey identifiers are appended if they are in the keystore.
*/
CAffectedKeysVisitor(const CKeyStore &keystoreIn, std::vector<CKeyID> &vKeysIn) : keystore(keystoreIn), vKeys(vKeysIn) {}
/**
* Apply the visitor to each destination in a script, recursively to the redeemscript
* in the case of p2sh destinations.
* @param[in] script The CScript from which destinations are extracted.
* @post Any CKeyIDs that script and keystore have in common are appended to the visitor's vKeys.
*/
void Process(const CScript &script) {
txnouttype type;
std::vector<CTxDestination> vDest;
int nRequired;
if (ExtractDestinations(script, type, vDest, nRequired)) {
for (const CTxDestination &dest : vDest)
boost::apply_visitor(*this, dest);
}
}
void operator()(const CKeyID &keyId) {
if (keystore.HaveKey(keyId))
vKeys.push_back(keyId);
}
void operator()(const CScriptID &scriptId) {
CScript script;
if (keystore.GetCScript(scriptId, script))
Process(script);
}
void operator()(const WitnessV0ScriptHash& scriptID)
{
CScriptID id;
CRIPEMD160().Write(scriptID.begin(), 32).Finalize(id.begin());
CScript script;
if (keystore.GetCScript(id, script)) {
Process(script);
}
}
void operator()(const WitnessV0KeyHash& keyid)
{ {
CKeyID id(keyid); std::vector<CScript> dummy;
if (keystore.HaveKey(id)) { FlatSigningProvider out;
vKeys.push_back(id); InferDescriptor(spk, provider)->Expand(0, DUMMY_SIGNING_PROVIDER, dummy, out);
std::vector<CKeyID> ret;
for (const auto& entry : out.pubkeys) {
ret.push_back(entry.first);
} }
return ret;
} }
template<typename X>
void operator()(const X &none) {}
};
const CWalletTx* CWallet::GetWalletTx(const uint256& hash) const const CWalletTx* CWallet::GetWalletTx(const uint256& hash) const
{ {
LOCK(cs_wallet); LOCK(cs_wallet);
@ -977,9 +928,7 @@ bool CWallet::AddToWalletIfInvolvingMe(const CTransactionRef& ptx, const CBlockI
// loop though all outputs // loop though all outputs
for (const CTxOut& txout: tx.vout) { for (const CTxOut& txout: tx.vout) {
// extract addresses and check if they match with an unused keypool key // extract addresses and check if they match with an unused keypool key
std::vector<CKeyID> vAffected; for (const auto& keyid : GetAffectedKeys(txout.scriptPubKey, *this)) {
CAffectedKeysVisitor(*this, vAffected).Process(txout.scriptPubKey);
for (const CKeyID &keyid : vAffected) {
std::map<CKeyID, int64_t>::const_iterator mi = m_pool_key_to_index.find(keyid); std::map<CKeyID, int64_t>::const_iterator mi = m_pool_key_to_index.find(keyid);
if (mi != m_pool_key_to_index.end()) { if (mi != m_pool_key_to_index.end()) {
WalletLogPrintf("%s: Detected a used keypool key, mark all keypool key up to this key as used\n", __func__); WalletLogPrintf("%s: Detected a used keypool key, mark all keypool key up to this key as used\n", __func__);
@ -3693,7 +3642,6 @@ void CWallet::GetKeyBirthTimes(interfaces::Chain::Lock& locked_chain, std::map<C
return; return;
// find first block that affects those keys, if there are any left // find first block that affects those keys, if there are any left
std::vector<CKeyID> vAffected;
for (const auto& entry : mapWallet) { for (const auto& entry : mapWallet) {
// iterate over all wallet transactions... // iterate over all wallet transactions...
const CWalletTx &wtx = entry.second; const CWalletTx &wtx = entry.second;
@ -3703,14 +3651,12 @@ void CWallet::GetKeyBirthTimes(interfaces::Chain::Lock& locked_chain, std::map<C
int nHeight = pindex->nHeight; int nHeight = pindex->nHeight;
for (const CTxOut &txout : wtx.tx->vout) { for (const CTxOut &txout : wtx.tx->vout) {
// iterate over all their outputs // iterate over all their outputs
CAffectedKeysVisitor(*this, vAffected).Process(txout.scriptPubKey); for (const auto &keyid : GetAffectedKeys(txout.scriptPubKey, *this)) {
for (const CKeyID &keyid : vAffected) {
// ... and all their affected keys // ... and all their affected keys
std::map<CKeyID, CBlockIndex*>::iterator rit = mapKeyFirstBlock.find(keyid); std::map<CKeyID, CBlockIndex*>::iterator rit = mapKeyFirstBlock.find(keyid);
if (rit != mapKeyFirstBlock.end() && nHeight < rit->second->nHeight) if (rit != mapKeyFirstBlock.end() && nHeight < rit->second->nHeight)
rit->second = pindex; rit->second = pindex;
} }
vAffected.clear();
} }
} }
} }

Loading…
Cancel
Save