Merge pull request #2 from wheals/master
[rust-101.git] / solutions / src / counter.rs
1 use std::sync::{Arc, RwLock};
2 use std::thread;
3
4 #[derive(Clone)]
5 pub struct ConcurrentCounter(Arc<RwLock<usize>>);
6
7 impl ConcurrentCounter {
8     // The constructor should not be surprising.
9     pub fn new(val: usize) -> Self {
10         ConcurrentCounter(Arc::new(RwLock::new(val)))
11     }
12
13     pub fn increment(&self, by: usize) {
14         let mut counter = self.0.write().unwrap_or_else(|e| e.into_inner());
15         *counter = *counter + by;
16     }
17
18     pub fn compare_and_inc(&self, test: usize, by: usize) {
19         let mut counter = self.0.write().unwrap_or_else(|e| e.into_inner());
20         if *counter == test {
21             *counter += by;
22         }
23     }
24
25     pub fn get(&self) -> usize {
26         let counter = self.0.read().unwrap_or_else(|e| e.into_inner());
27         *counter
28     }
29 }
30
31 // Now our counter is ready for action.
32 pub fn main() {
33     let counter = ConcurrentCounter::new(0);
34
35     // We clone the counter for the first thread, which increments it by 2 every 15ms.
36     let counter1 = counter.clone();
37     let handle1 = thread::spawn(move || {
38         for _ in 0..10 {
39             thread::sleep_ms(15);
40             counter1.increment(2);
41         }
42     });
43
44     // The second thread increments the counter by 3 every 20ms.
45     let counter2 = counter.clone();
46     let handle2 = thread::spawn(move || {
47         for _ in 0..10 {
48             thread::sleep_ms(20);
49             counter2.increment(3);
50         }
51     });
52
53     // Now we want to watch the threads working on the counter.
54     for _ in 0..50 {
55         thread::sleep_ms(5);
56         println!("Current value: {}", counter.get());
57     }
58
59     // Finally, wait for all the threads to finish to be sure we can catch the counter's final value.
60     handle1.join().unwrap();
61     handle2.join().unwrap();
62     println!("Final value: {}", counter.get());
63 }