// saito_core/core/consensus/merkle.rs

1use std::collections::LinkedList;
2
3use rayon::prelude::*;
4
5use crate::core::consensus::transaction::Transaction;
6use crate::core::defs::SaitoHash;
7use crate::core::util::crypto::hash;
8use crate::iterate_mut;
9
/// Order in which [`MerkleTree::traverse`] reports nodes to its callback.
///
/// Both variants walk the tree recursively, left child before right child.
/// Despite their names, neither is a true breadth-first or depth-first
/// (level) order:
/// - `BreadthFirst` reports a node *before* its children (pre-order).
/// - `DepthFist` reports a node *after* its children (post-order).
///   The misspelling is kept so existing callers keep compiling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraverseMode {
    /// Post-order: children first, then the node itself.
    DepthFist,
    /// Pre-order: the node itself, then its children.
    BreadthFirst,
}
15
16enum NodeType {
17    Node {
18        left: Option<Box<MerkleTreeNode>>,
19        right: Option<Box<MerkleTreeNode>>,
20    },
21    Transaction {
22        index: usize,
23    },
24}
25
26pub struct MerkleTreeNode {
27    node_type: NodeType,
28    hash: Option<SaitoHash>,
29    count: usize,
30    is_spv: bool,
31}
32
33impl MerkleTreeNode {
34    fn new(
35        node_type: NodeType,
36        hash: Option<SaitoHash>,
37        count: usize,
38        is_spv: bool,
39    ) -> MerkleTreeNode {
40        MerkleTreeNode {
41            node_type,
42            hash,
43            count,
44            is_spv,
45        }
46    }
47
48    pub fn get_hash(&self) -> Option<SaitoHash> {
49        return self.hash;
50    }
51}
52
53pub struct MerkleTree {
54    root: Box<MerkleTreeNode>,
55}
56
57impl MerkleTree {
58    pub fn len(&self) -> usize {
59        self.root.count
60    }
61
62    pub fn get_root_hash(&self) -> SaitoHash {
63        return self.root.hash.unwrap();
64    }
65
66    pub fn generate(transactions: &Vec<Transaction>) -> Option<Box<MerkleTree>> {
67        if transactions.is_empty() {
68            return None;
69        }
70
71        let mut leaves: LinkedList<Box<MerkleTreeNode>> = LinkedList::new();
72
73        // Create leaves for the Merkle tree
74        for (index, tx) in transactions.iter().enumerate() {
75            if tx.txs_replacements > 1 {
76                for _ in 0..tx.txs_replacements {
77                    leaves.push_back(Box::new(MerkleTreeNode::new(
78                        NodeType::Transaction { index },
79                        Some(tx.hash_for_signature.unwrap_or([0; 32])),
80                        1,
81                        true, // is_spv
82                    )));
83                }
84            } else {
85                leaves.push_back(Box::new(MerkleTreeNode::new(
86                    NodeType::Transaction { index },
87                    tx.hash_for_signature,
88                    1,
89                    false, // is_spv
90                )));
91            }
92        }
93
94        // Combine leaves into nodes to form the tree
95
96        while leaves.len() > 1 {
97            let mut nodes: LinkedList<MerkleTreeNode> = Default::default();
98
99            // Create a node per two leaves
100            while !leaves.is_empty() {
101                let left = leaves.pop_front();
102                let right = leaves.pop_front(); //Can be None, this is expected
103                let count = MerkleTree::calculate_child_count(&left, &right);
104
105                if right.is_some() {
106                    nodes.push_back(MerkleTreeNode::new(
107                        NodeType::Node { left, right },
108                        None,
109                        count,
110                        false,
111                    ));
112                } else {
113                    let hash = left.as_ref().unwrap().get_hash();
114                    nodes.push_back(MerkleTreeNode::new(
115                        NodeType::Node { left, right },
116                        hash,
117                        count,
118                        false,
119                    ));
120                }
121            }
122
123            // Compute the node hashes in parallel
124            iterate_mut!(nodes).all(MerkleTree::generate_hash);
125            // Collect the next set of leaves for the computation
126            leaves.clear();
127
128            while !nodes.is_empty() {
129                let node = nodes.pop_front().unwrap();
130                leaves.push_back(Box::new(node));
131            }
132
133            // trace!("---------------------");
134        }
135
136        Some(Box::new(MerkleTree {
137            root: leaves.pop_front().unwrap(),
138        }))
139    }
140
141    pub fn compute_combined_hash(
142        left_hash: Option<[u8; 32]>,
143        right_hash: Option<[u8; 32]>,
144    ) -> [u8; 32] {
145        let mut vbytes: Vec<u8> = vec![];
146        vbytes.extend(left_hash.unwrap());
147        vbytes.extend(right_hash.unwrap());
148        hash(&vbytes)
149    }
150    pub fn traverse(&self, mode: TraverseMode, read_func: impl Fn(&MerkleTreeNode)) {
151        MerkleTree::traverse_node(&mode, &self.root, &read_func);
152    }
153
154    pub fn create_clone(&self) -> Box<MerkleTree> {
155        Box::new(MerkleTree {
156            root: MerkleTree::clone_node(Some(&self.root)).unwrap(),
157        })
158    }
159
160    pub fn prune(&mut self, prune_func: impl Fn(usize) -> bool) {
161        MerkleTree::prune_node(Some(&mut self.root), &prune_func);
162    }
163
164    pub fn calculate_child_count(
165        left: &Option<Box<MerkleTreeNode>>,
166        right: &Option<Box<MerkleTreeNode>>,
167    ) -> usize {
168        let mut count = 1 as usize;
169
170        if left.is_some() {
171            count += left.as_ref().unwrap().count;
172        }
173
174        if right.is_some() {
175            count += right.as_ref().unwrap().count;
176        }
177
178        count
179    }
180
181    fn generate_hash(node: &mut MerkleTreeNode) -> bool {
182        if node.hash.is_some() {
183            return true;
184        }
185
186        match &node.node_type {
187            NodeType::Node { left, right } => {
188                let mut vbytes: Vec<u8> = vec![];
189                vbytes.extend(left.as_ref().unwrap().hash.unwrap());
190                vbytes.extend(right.as_ref().unwrap().hash.unwrap());
191
192                // dbg!(hash(&vbytes));
193                node.hash = Some(hash(&vbytes));
194                // trace!(
195                //     "Node : buffer = {:?}, hash = {:?}",
196                //     hex::encode(vbytes),
197                //     hex::encode(node.hash.unwrap())
198                // );
199            }
200            NodeType::Transaction { .. } => {}
201        }
202
203        return true;
204    }
205
206    fn traverse_node(
207        mode: &TraverseMode,
208        node: &MerkleTreeNode,
209        read_func: &impl Fn(&MerkleTreeNode),
210    ) {
211        if *mode == TraverseMode::BreadthFirst {
212            read_func(node);
213        }
214
215        match &node.node_type {
216            NodeType::Node { left, right } => {
217                if left.is_some() {
218                    MerkleTree::traverse_node(mode, left.as_ref().unwrap(), read_func);
219                }
220
221                if right.is_some() {
222                    MerkleTree::traverse_node(mode, right.as_ref().unwrap(), read_func);
223                }
224            }
225            NodeType::Transaction { .. } => {}
226        }
227
228        if *mode == TraverseMode::DepthFist {
229            read_func(node);
230        }
231    }
232
233    fn clone_node(node: Option<&Box<MerkleTreeNode>>) -> Option<Box<MerkleTreeNode>> {
234        if node.is_some() {
235            Some(Box::new(MerkleTreeNode::new(
236                match &node.unwrap().node_type {
237                    NodeType::Node { left, right } => NodeType::Node {
238                        left: MerkleTree::clone_node(left.as_ref()),
239                        right: MerkleTree::clone_node(right.as_ref()),
240                    },
241                    NodeType::Transaction { index } => NodeType::Transaction { index: *index },
242                },
243                node.as_ref().unwrap().hash,
244                node.as_ref().unwrap().count,
245                node.as_ref().unwrap().is_spv,
246            )))
247        } else {
248            None
249        }
250    }
251
252    fn prune_node(
253        node: Option<&mut Box<MerkleTreeNode>>,
254        prune_func: &impl Fn(usize) -> bool,
255    ) -> bool {
256        return if node.is_some() {
257            let node = node.unwrap();
258            match &mut node.node_type {
259                NodeType::Node { left, right } => {
260                    let mut prune = MerkleTree::prune_node(left.as_mut(), prune_func);
261                    prune &= MerkleTree::prune_node(right.as_mut(), prune_func);
262
263                    if prune {
264                        node.node_type = NodeType::Node {
265                            left: None,
266                            right: None,
267                        };
268                        node.count = 1;
269                    } else {
270                        node.count = MerkleTree::calculate_child_count(&left, &right);
271                    }
272
273                    prune
274                }
275                NodeType::Transaction { index } => prune_func(*index),
276            }
277        } else {
278            true
279        };
280    }
281    // Generates a Merkle proof for the given transaction hash.
282}
283
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use crate::core::consensus::merkle::{MerkleTree, TraverseMode};
    use crate::core::consensus::transaction::{Transaction, TransactionType};
    use crate::core::consensus::wallet::Wallet;
    use crate::core::util::crypto::generate_keys;

    #[test]
    fn merkle_tree_generation_test() {
        let keys = generate_keys();
        let wallet = Wallet::new(keys.1, keys.0);

        let mut transactions = vec![];
        for i in 0..5 {
            let mut transaction = Transaction::default();
            transaction.timestamp = i;
            transaction.sign(&wallet.private_key);
            transactions.push(transaction);
        }

        let tree1 = MerkleTree::generate(&transactions).unwrap();
        transactions[0].timestamp = 10;
        transactions[0].sign(&wallet.private_key);
        let tree2 = MerkleTree::generate(&transactions).unwrap();
        transactions[4].timestamp = 11;
        transactions[4].sign(&wallet.private_key);
        let tree3 = MerkleTree::generate(&transactions).unwrap();

        transactions[2].timestamp = 12;
        transactions[2].sign(&wallet.private_key);
        let tree4 = MerkleTree::generate(&transactions).unwrap();
        let tree5 = MerkleTree::generate(&transactions).unwrap();

        assert_ne!(tree1.get_root_hash(), tree2.get_root_hash());
        assert_ne!(tree2.get_root_hash(), tree3.get_root_hash());
        assert_ne!(tree3.get_root_hash(), tree4.get_root_hash());
        assert_eq!(tree4.get_root_hash(), tree5.get_root_hash());
    }

    #[test]
    fn test_generate_returns_none_for_empty_input() {
        let transactions: Vec<Transaction> = Vec::new();
        assert!(MerkleTree::generate(&transactions).is_none());
    }

    #[test]
    fn test_generate_single_transaction() {
        let keys = generate_keys();
        let wallet = Wallet::new(keys.1, keys.0);

        let mut transaction = Transaction::default();
        transaction.sign(&wallet.private_key);
        let expected_root = transaction.hash_for_signature.unwrap();
        let transactions = vec![transaction];

        // A single leaf is the root, so the root hash is the transaction hash.
        let tree = MerkleTree::generate(&transactions).unwrap();
        assert_eq!(tree.get_root_hash(), expected_root);
        assert_eq!(tree.len(), 1);
    }

    #[test]
    fn test_generate_two_transactions() {
        let keys = generate_keys();
        let wallet = Wallet::new(keys.1, keys.0);

        let mut transactions = Vec::new();
        for i in 0..2 {
            let mut transaction = Transaction::default();
            transaction.timestamp = i;
            transaction.sign(&wallet.private_key);
            transactions.push(transaction);
        }

        let tree = MerkleTree::generate(&transactions).unwrap();
        let expected_root = MerkleTree::compute_combined_hash(
            transactions[0].hash_for_signature,
            transactions[1].hash_for_signature,
        );
        assert_eq!(tree.get_root_hash(), expected_root);
        // Two leaves plus their parent.
        assert_eq!(tree.len(), 3);
    }

    #[test]
    fn test_generate_odd_number_of_transactions() {
        let keys = generate_keys();
        let wallet = Wallet::new(keys.1, keys.0);

        let mut transactions = Vec::new();
        // Use 3 for an odd number of transactions.
        for i in 0..3 {
            let mut transaction = Transaction::default();
            transaction.timestamp = i;
            transaction.sign(&wallet.private_key);
            transactions.push(transaction);
        }

        let tree = MerkleTree::generate(&transactions).unwrap();
        let root_hash = tree.get_root_hash();

        assert_ne!(root_hash, [0u8; 32], "Root hash should not be all zeros.");

        let mut altered_transactions = transactions.clone();
        altered_transactions[0].timestamp += 1;
        altered_transactions[0].sign(&wallet.private_key);
        let altered_tree = MerkleTree::generate(&altered_transactions).unwrap();
        let altered_root_hash = altered_tree.get_root_hash();
        assert_ne!(
            root_hash, altered_root_hash,
            "Root hash should change when a transaction is altered."
        );
    }

    #[test]
    fn merkle_tree_pruning_test() {
        let keys = generate_keys();
        let wallet = Wallet::new(keys.1, keys.0);

        let mut transactions = vec![];

        for i in 0..5 {
            let mut transaction = Transaction::default();
            transaction.timestamp = i;
            transaction.sign(&wallet.private_key);
            transactions.push(transaction);
        }

        let target_hash = transactions[0].hash_for_signature.unwrap();

        let tree = MerkleTree::generate(&transactions).unwrap();
        let cloned_tree = tree.create_clone();
        let mut pruned_tree = tree.create_clone();
        pruned_tree.prune(|index| target_hash != transactions[index].hash_for_signature.unwrap());

        assert_eq!(tree.get_root_hash(), cloned_tree.get_root_hash());
        assert_eq!(cloned_tree.get_root_hash(), pruned_tree.get_root_hash());
        assert_eq!(tree.len(), 11);
        assert_eq!(cloned_tree.len(), tree.len());
        assert_eq!(pruned_tree.len(), 7);
    }

    #[test]
    fn test_traverse_visits_every_node() {
        let keys = generate_keys();
        let wallet = Wallet::new(keys.1, keys.0);

        let mut transactions = Vec::new();
        for i in 0..5 {
            let mut transaction = Transaction::default();
            transaction.timestamp = i;
            transaction.sign(&wallet.private_key);
            transactions.push(transaction);
        }
        let tree = MerkleTree::generate(&transactions).unwrap();

        let visited = Cell::new(0usize);
        tree.traverse(TraverseMode::BreadthFirst, |_| visited.set(visited.get() + 1));
        assert_eq!(visited.get(), tree.len());

        visited.set(0);
        tree.traverse(TraverseMode::DepthFist, |_| visited.set(visited.get() + 1));
        assert_eq!(visited.get(), tree.len());
    }

    #[test]
    fn test_generate_with_spv_transactions() {
        let keys = generate_keys();
        let wallet = Wallet::new(keys.1, keys.0);

        // Create 5 normal transactions and sign them.
        // NOTE(review): these transactions are built identically, so their
        // hashes presumably coincide. That makes the SPV equalities below
        // weaker than they look, because they cannot detect a leaf carrying
        // the wrong transaction's hash.
        let mut transactions = Vec::new();
        for _ in 0..5 {
            let mut tx = Transaction::default();
            tx.sign(&wallet.private_key);
            transactions.push(tx);
        }

        // Generate the Merkle tree from the original transactions.
        let merkle_tree_original = MerkleTree::generate(&transactions).unwrap();
        let root_original = merkle_tree_original.get_root_hash();

        // Replace the second transaction with an SPV transaction (b becomes b(SPV)).
        let mut transactions_with_spv_b = transactions.clone();
        transactions_with_spv_b[1] = create_spv_transaction(&wallet, transactions[1].clone(), 1);
        let merkle_tree_spv_b = MerkleTree::generate(&transactions_with_spv_b).unwrap();
        assert_eq!(
            root_original,
            merkle_tree_spv_b.get_root_hash(),
            "Merkle roots should be equal after replacing b with b(SPV)."
        );

        // Replace all transactions with SPV transactions.
        let transactions_all_spv = transactions
            .iter()
            .map(|tx| create_spv_transaction(&wallet, tx.clone(), 1))
            .collect::<Vec<Transaction>>();
        let merkle_tree_all_spv = MerkleTree::generate(&transactions_all_spv).unwrap();
        assert_eq!(
            root_original,
            merkle_tree_all_spv.get_root_hash(),
            "Merkle roots should be equal after replacing all with SPV transactions."
        );

        // Replace c and d with a single SPV transaction standing in for both (cd(SPV)).
        let mut transactions_cd_spv = transactions.clone();
        let combined_tx3_tx4: Transaction =
            combine_transactions_into_spv(transactions[2].clone(), transactions[3].clone());
        transactions_cd_spv.splice(2..4, std::iter::once(combined_tx3_tx4));
        let merkle_tree_cd_spv = MerkleTree::generate(&transactions_cd_spv).unwrap();
        assert_eq!(
            root_original,
            merkle_tree_cd_spv.get_root_hash(),
            "Merkle roots should be equal after replacing cd with combined cd(SPV)."
        );

        // Various SPV transactions (b, cd and e are SPV).
        let mut transactions_mixed_spv = transactions.clone();
        transactions_mixed_spv[1] = create_spv_transaction(&wallet, transactions[1].clone(), 1);
        transactions_mixed_spv.splice(
            2..4,
            std::iter::once(combine_transactions_into_spv(
                transactions[2].clone(),
                transactions[3].clone(),
            )),
        );
        transactions_mixed_spv[3] = create_spv_transaction(&wallet, transactions[4].clone(), 1);
        let merkle_tree_mixed_spv = MerkleTree::generate(&transactions_mixed_spv).unwrap();
        assert_eq!(
            root_original,
            merkle_tree_mixed_spv.get_root_hash(),
            "Merkle roots should be equal after various SPV replacements."
        );

        // Break it by changing the txs_replacements value somewhere.
        let mut transactions_broken_spv = transactions_mixed_spv.clone();
        transactions_broken_spv[1].txs_replacements = 3;
        let merkle_tree_broken_spv = MerkleTree::generate(&transactions_broken_spv).unwrap();
        assert_ne!(
            root_original,
            merkle_tree_broken_spv.get_root_hash(),
            "Merkle root should differ due to changed SPV tx replacements."
        );
    }

    /// Re-signs `tx` with `wallet` and returns it as an SPV transaction.
    ///
    /// The result stands in for `txs_replacements` transactions. Every other
    /// field is carried over from the signed `tx`.
    fn create_spv_transaction(
        wallet: &Wallet,
        mut tx: Transaction,
        txs_replacements: u32,
    ) -> Transaction {
        tx.sign(&wallet.private_key);
        Transaction {
            transaction_type: TransactionType::SPV,
            txs_replacements,
            ..tx
        }
    }

    /// Turns `tx1` into an SPV transaction standing in for two transactions.
    ///
    /// NOTE(review): the result carries `tx2`'s `hash_for_signature`, not a
    /// combined hash of both. `MerkleTree::generate` therefore emits two
    /// leaves with that same hash.
    fn combine_transactions_into_spv(mut tx1: Transaction, tx2: Transaction) -> Transaction {
        tx1.hash_for_signature = tx2.hash_for_signature;
        tx1.transaction_type = TransactionType::SPV;
        tx1.txs_replacements = 2;
        tx1
    }
}