saito_core/core/consensus/
merkle.rs1use std::collections::LinkedList;
2
3use rayon::prelude::*;
4
5use crate::core::consensus::transaction::Transaction;
6use crate::core::defs::SaitoHash;
7use crate::core::util::crypto::hash;
8use crate::iterate_mut;
9
10#[derive(PartialEq)]
11pub enum TraverseMode {
12 DepthFist,
13 BreadthFirst,
14}
15
16enum NodeType {
17 Node {
18 left: Option<Box<MerkleTreeNode>>,
19 right: Option<Box<MerkleTreeNode>>,
20 },
21 Transaction {
22 index: usize,
23 },
24}
25
26pub struct MerkleTreeNode {
27 node_type: NodeType,
28 hash: Option<SaitoHash>,
29 count: usize,
30 is_spv: bool,
31}
32
33impl MerkleTreeNode {
34 fn new(
35 node_type: NodeType,
36 hash: Option<SaitoHash>,
37 count: usize,
38 is_spv: bool,
39 ) -> MerkleTreeNode {
40 MerkleTreeNode {
41 node_type,
42 hash,
43 count,
44 is_spv,
45 }
46 }
47
48 pub fn get_hash(&self) -> Option<SaitoHash> {
49 return self.hash;
50 }
51}
52
53pub struct MerkleTree {
54 root: Box<MerkleTreeNode>,
55}
56
57impl MerkleTree {
58 pub fn len(&self) -> usize {
59 self.root.count
60 }
61
62 pub fn get_root_hash(&self) -> SaitoHash {
63 return self.root.hash.unwrap();
64 }
65
66 pub fn generate(transactions: &Vec<Transaction>) -> Option<Box<MerkleTree>> {
67 if transactions.is_empty() {
68 return None;
69 }
70
71 let mut leaves: LinkedList<Box<MerkleTreeNode>> = LinkedList::new();
72
73 for (index, tx) in transactions.iter().enumerate() {
75 if tx.txs_replacements > 1 {
76 for _ in 0..tx.txs_replacements {
77 leaves.push_back(Box::new(MerkleTreeNode::new(
78 NodeType::Transaction { index },
79 Some(tx.hash_for_signature.unwrap_or([0; 32])),
80 1,
81 true, )));
83 }
84 } else {
85 leaves.push_back(Box::new(MerkleTreeNode::new(
86 NodeType::Transaction { index },
87 tx.hash_for_signature,
88 1,
89 false, )));
91 }
92 }
93
94 while leaves.len() > 1 {
97 let mut nodes: LinkedList<MerkleTreeNode> = Default::default();
98
99 while !leaves.is_empty() {
101 let left = leaves.pop_front();
102 let right = leaves.pop_front(); let count = MerkleTree::calculate_child_count(&left, &right);
104
105 if right.is_some() {
106 nodes.push_back(MerkleTreeNode::new(
107 NodeType::Node { left, right },
108 None,
109 count,
110 false,
111 ));
112 } else {
113 let hash = left.as_ref().unwrap().get_hash();
114 nodes.push_back(MerkleTreeNode::new(
115 NodeType::Node { left, right },
116 hash,
117 count,
118 false,
119 ));
120 }
121 }
122
123 iterate_mut!(nodes).all(MerkleTree::generate_hash);
125 leaves.clear();
127
128 while !nodes.is_empty() {
129 let node = nodes.pop_front().unwrap();
130 leaves.push_back(Box::new(node));
131 }
132
133 }
135
136 Some(Box::new(MerkleTree {
137 root: leaves.pop_front().unwrap(),
138 }))
139 }
140
141 pub fn compute_combined_hash(
142 left_hash: Option<[u8; 32]>,
143 right_hash: Option<[u8; 32]>,
144 ) -> [u8; 32] {
145 let mut vbytes: Vec<u8> = vec![];
146 vbytes.extend(left_hash.unwrap());
147 vbytes.extend(right_hash.unwrap());
148 hash(&vbytes)
149 }
150 pub fn traverse(&self, mode: TraverseMode, read_func: impl Fn(&MerkleTreeNode)) {
151 MerkleTree::traverse_node(&mode, &self.root, &read_func);
152 }
153
154 pub fn create_clone(&self) -> Box<MerkleTree> {
155 Box::new(MerkleTree {
156 root: MerkleTree::clone_node(Some(&self.root)).unwrap(),
157 })
158 }
159
160 pub fn prune(&mut self, prune_func: impl Fn(usize) -> bool) {
161 MerkleTree::prune_node(Some(&mut self.root), &prune_func);
162 }
163
164 pub fn calculate_child_count(
165 left: &Option<Box<MerkleTreeNode>>,
166 right: &Option<Box<MerkleTreeNode>>,
167 ) -> usize {
168 let mut count = 1 as usize;
169
170 if left.is_some() {
171 count += left.as_ref().unwrap().count;
172 }
173
174 if right.is_some() {
175 count += right.as_ref().unwrap().count;
176 }
177
178 count
179 }
180
181 fn generate_hash(node: &mut MerkleTreeNode) -> bool {
182 if node.hash.is_some() {
183 return true;
184 }
185
186 match &node.node_type {
187 NodeType::Node { left, right } => {
188 let mut vbytes: Vec<u8> = vec![];
189 vbytes.extend(left.as_ref().unwrap().hash.unwrap());
190 vbytes.extend(right.as_ref().unwrap().hash.unwrap());
191
192 node.hash = Some(hash(&vbytes));
194 }
200 NodeType::Transaction { .. } => {}
201 }
202
203 return true;
204 }
205
206 fn traverse_node(
207 mode: &TraverseMode,
208 node: &MerkleTreeNode,
209 read_func: &impl Fn(&MerkleTreeNode),
210 ) {
211 if *mode == TraverseMode::BreadthFirst {
212 read_func(node);
213 }
214
215 match &node.node_type {
216 NodeType::Node { left, right } => {
217 if left.is_some() {
218 MerkleTree::traverse_node(mode, left.as_ref().unwrap(), read_func);
219 }
220
221 if right.is_some() {
222 MerkleTree::traverse_node(mode, right.as_ref().unwrap(), read_func);
223 }
224 }
225 NodeType::Transaction { .. } => {}
226 }
227
228 if *mode == TraverseMode::DepthFist {
229 read_func(node);
230 }
231 }
232
233 fn clone_node(node: Option<&Box<MerkleTreeNode>>) -> Option<Box<MerkleTreeNode>> {
234 if node.is_some() {
235 Some(Box::new(MerkleTreeNode::new(
236 match &node.unwrap().node_type {
237 NodeType::Node { left, right } => NodeType::Node {
238 left: MerkleTree::clone_node(left.as_ref()),
239 right: MerkleTree::clone_node(right.as_ref()),
240 },
241 NodeType::Transaction { index } => NodeType::Transaction { index: *index },
242 },
243 node.as_ref().unwrap().hash,
244 node.as_ref().unwrap().count,
245 node.as_ref().unwrap().is_spv,
246 )))
247 } else {
248 None
249 }
250 }
251
252 fn prune_node(
253 node: Option<&mut Box<MerkleTreeNode>>,
254 prune_func: &impl Fn(usize) -> bool,
255 ) -> bool {
256 return if node.is_some() {
257 let node = node.unwrap();
258 match &mut node.node_type {
259 NodeType::Node { left, right } => {
260 let mut prune = MerkleTree::prune_node(left.as_mut(), prune_func);
261 prune &= MerkleTree::prune_node(right.as_mut(), prune_func);
262
263 if prune {
264 node.node_type = NodeType::Node {
265 left: None,
266 right: None,
267 };
268 node.count = 1;
269 } else {
270 node.count = MerkleTree::calculate_child_count(&left, &right);
271 }
272
273 prune
274 }
275 NodeType::Transaction { index } => prune_func(*index),
276 }
277 } else {
278 true
279 };
280 }
281 }
283
284#[cfg(test)]
285mod tests {
286 use crate::core::consensus::merkle::MerkleTree;
287 use crate::core::consensus::transaction::{Transaction, TransactionType};
288 use crate::core::consensus::wallet::Wallet;
289 use crate::core::util::crypto::generate_keys;
290
291 #[test]
292 fn merkle_tree_generation_test() {
293 let keys = generate_keys();
294 let wallet = Wallet::new(keys.1, keys.0);
295
296 let mut transactions = vec![];
297 for i in 0..5 {
298 let mut transaction = Transaction::default();
299 transaction.timestamp = i;
300 transaction.sign(&wallet.private_key);
301 transactions.push(transaction);
302 }
303
304 let tree1 = MerkleTree::generate(&transactions).unwrap();
305 transactions[0].timestamp = 10;
306 transactions[0].sign(&wallet.private_key);
307 let tree2 = MerkleTree::generate(&transactions).unwrap();
308 transactions[4].timestamp = 11;
309 transactions[4].sign(&wallet.private_key);
310 let tree3 = MerkleTree::generate(&transactions).unwrap();
311
312 transactions[2].timestamp = 12;
313 transactions[2].sign(&wallet.private_key);
314 let tree4 = MerkleTree::generate(&transactions).unwrap();
315 let tree5 = MerkleTree::generate(&transactions).unwrap();
316
317 dbg!(tree1.get_root_hash(), tree2.get_root_hash());
318 assert_ne!(tree1.get_root_hash(), tree2.get_root_hash());
319 assert_ne!(tree2.get_root_hash(), tree3.get_root_hash());
320 assert_ne!(tree3.get_root_hash(), tree4.get_root_hash());
321 assert_eq!(tree4.get_root_hash(), tree5.get_root_hash());
322 }
323
324 #[test]
325 fn test_generate_odd_number_of_transactions() {
326 let keys = generate_keys();
327 let wallet = Wallet::new(keys.1, keys.0);
328
329 let mut transactions = Vec::new();
330 for i in 0..3 {
331 let mut transaction = Transaction::default();
333 transaction.timestamp = i;
334 transaction.sign(&wallet.private_key);
335 transactions.push(transaction);
336 }
337
338 let tree = MerkleTree::generate(&transactions).unwrap();
340
341 let root_hash = tree.get_root_hash();
342
343 assert_ne!(root_hash, [0u8; 32], "Root hash should not be all zeros.");
344
345 let mut altered_transactions = transactions.clone();
346 altered_transactions[0].timestamp += 1;
347 altered_transactions[0].sign(&wallet.private_key);
348 let altered_tree = MerkleTree::generate(&altered_transactions).unwrap();
349 let altered_root_hash = altered_tree.get_root_hash();
350 assert_ne!(
351 root_hash, altered_root_hash,
352 "Root hash should change when a transaction is altered."
353 );
354 }
355
356 #[test]
357 fn merkle_tree_pruning_test() {
358 let keys = generate_keys();
359 let wallet = Wallet::new(keys.1, keys.0);
360
361 let mut transactions = vec![];
362
363 for i in 0..5 {
364 let mut transaction = Transaction::default();
365 transaction.timestamp = i;
366 transaction.sign(&wallet.private_key);
367 transactions.push(transaction);
368 }
369
370 let target_hash = transactions[0].hash_for_signature.unwrap();
371
372 let tree = MerkleTree::generate(&transactions).unwrap();
373 let cloned_tree = tree.create_clone();
374 let mut pruned_tree = tree.create_clone();
375 pruned_tree.prune(|index| target_hash != transactions[index].hash_for_signature.unwrap());
376
377 assert_eq!(tree.get_root_hash(), cloned_tree.get_root_hash());
378 assert_eq!(cloned_tree.get_root_hash(), pruned_tree.get_root_hash());
379 assert_eq!(tree.len(), 11);
380 assert_eq!(cloned_tree.len(), tree.len());
381 assert_eq!(pruned_tree.len(), 7);
382 }
383
384 #[test]
385 fn test_generate_with_spv_transactions() {
386 let keys = generate_keys();
387 let wallet = Wallet::new(keys.1, keys.0);
388
389 let mut transactions = Vec::new();
391 for _ in 0..5 {
392 let mut tx = Transaction::default();
393 tx.sign(&wallet.private_key);
394 transactions.push(tx);
396 }
397
398 let merkle_tree_original = MerkleTree::generate(&transactions).unwrap();
400 let root_original = merkle_tree_original.get_root_hash();
401
402 let mut transactions_with_spv_b = transactions.clone();
404 transactions_with_spv_b[1] = create_spv_transaction(&wallet, transactions[1].clone(), 1);
405 let merkle_tree_spv_b = MerkleTree::generate(&transactions_with_spv_b).unwrap();
406 assert_eq!(
407 root_original,
408 merkle_tree_spv_b.get_root_hash(),
409 "Merkle roots should be equal after replacing b with b(SPV)."
410 );
411
412 let transactions_all_spv = transactions
414 .iter()
415 .map(|tx| create_spv_transaction(&wallet, tx.clone(), 1))
416 .collect::<Vec<Transaction>>();
417 let merkle_tree_all_spv = MerkleTree::generate(&transactions_all_spv).unwrap();
418 assert_eq!(
419 root_original,
420 merkle_tree_all_spv.get_root_hash(),
421 "Merkle roots should be equal after replacing all with SPV transactions."
422 );
423
424 let mut transactions_cd_spv = transactions.clone();
427 let combined_tx3_tx4: Transaction =
428 combine_transactions_into_spv(transactions[2].clone(), transactions[3].clone());
429
430 dbg!(&combined_tx3_tx4);
431 transactions_cd_spv.splice(2..4, std::iter::once(combined_tx3_tx4));
432 let merkle_tree_cd_spv = MerkleTree::generate(&transactions_cd_spv).unwrap();
433 assert_eq!(
434 root_original,
435 merkle_tree_cd_spv.get_root_hash(),
436 "Merkle roots should be equal after replacing cd with combined cd(SPV)."
437 );
438
439 let mut transactions_mixed_spv = transactions.clone();
441 transactions_mixed_spv[1] = create_spv_transaction(&wallet, transactions[1].clone(), 1);
442 transactions_mixed_spv.splice(
443 2..4,
444 std::iter::once(combine_transactions_into_spv(
445 transactions[2].clone(),
446 transactions[3].clone(),
447 )),
448 );
449 transactions_mixed_spv[3] = create_spv_transaction(&wallet, transactions[4].clone(), 1);
450 let merkle_tree_mixed_spv = MerkleTree::generate(&transactions_mixed_spv).unwrap();
451 assert_eq!(
452 root_original,
453 merkle_tree_mixed_spv.get_root_hash(),
454 "Merkle roots should be equal after various SPV replacements."
455 );
456
457 let mut transactions_broken_spv = transactions_mixed_spv.clone();
459 transactions_broken_spv[1].txs_replacements = 3;
460 let merkle_tree_broken_spv = MerkleTree::generate(&transactions_broken_spv).unwrap();
461 assert_ne!(
462 root_original,
463 merkle_tree_broken_spv.get_root_hash(),
464 "Merkle root should differ due to changed SPV tx replacements."
465 );
466 }
467
468 fn create_spv_transaction(
469 wallet: &Wallet,
470 mut tx: Transaction,
471 txs_replacements: u32,
472 ) -> Transaction {
473 tx.sign(&wallet.private_key);
474 Transaction {
475 timestamp: tx.timestamp,
476 from: tx.from,
477 to: tx.to,
478 data: tx.data,
479 transaction_type: TransactionType::SPV,
480 txs_replacements, signature: tx.signature,
482 path: tx.path,
483 hash_for_signature: tx.hash_for_signature,
484 total_in: tx.total_in,
485 total_out: tx.total_out,
486 total_fees: tx.total_fees,
487 total_work_for_me: tx.total_work_for_me,
488 cumulative_fees: tx.cumulative_fees,
489 }
490 }
491 fn combine_transactions_into_spv(mut tx1: Transaction, tx2: Transaction) -> Transaction {
492 let combined_hash =
493 MerkleTree::compute_combined_hash(tx1.hash_for_signature, tx2.hash_for_signature);
494
495 dbg!(combined_hash);
496 tx1.hash_for_signature = tx2.hash_for_signature;
497 tx1.transaction_type = TransactionType::SPV;
498 tx1.txs_replacements = 2;
499 tx1
500 }
506}