Merge pull request #34 from louy2/patch-4
[rust-101.git] / solutions / src / counter.rs
1 use std::sync::{Arc, RwLock};
2 use std::thread;
3 use std::time::Duration;
4
5 #[derive(Clone)]
6 pub struct ConcurrentCounter(Arc<RwLock<usize>>);
7
8 impl ConcurrentCounter {
9     // The constructor should not be surprising.
10     pub fn new(val: usize) -> Self {
11         ConcurrentCounter(Arc::new(RwLock::new(val)))
12     }
13
14     pub fn increment(&self, by: usize) {
15         let mut counter = self.0.write().unwrap_or_else(|e| e.into_inner());
16         *counter = *counter + by;
17     }
18
19     pub fn compare_and_inc(&self, test: usize, by: usize) {
20         let mut counter = self.0.write().unwrap_or_else(|e| e.into_inner());
21         if *counter == test {
22             *counter += by;
23         }
24     }
25
26     pub fn get(&self) -> usize {
27         let counter = self.0.read().unwrap_or_else(|e| e.into_inner());
28         *counter
29     }
30 }
31
32 // Now our counter is ready for action.
33 pub fn main() {
34     let counter = ConcurrentCounter::new(0);
35
36     // We clone the counter for the first thread, which increments it by 2 every 15ms.
37     let counter1 = counter.clone();
38     let handle1 = thread::spawn(move || {
39         for _ in 0..10 {
40             thread::sleep(Duration::from_millis(15));
41             counter1.increment(2);
42         }
43     });
44
45     // The second thread increments the counter by 3 every 20ms.
46     let counter2 = counter.clone();
47     let handle2 = thread::spawn(move || {
48         for _ in 0..10 {
49             thread::sleep(Duration::from_millis(20));
50             counter2.increment(3);
51         }
52     });
53
54     // Now we want to watch the threads working on the counter.
55     for _ in 0..50 {
56         thread::sleep(Duration::from_millis(5));
57         println!("Current value: {}", counter.get());
58     }
59
60     // Finally, wait for all the threads to finish to be sure we can catch the counter's final value.
61     handle1.join().unwrap();
62     handle2.join().unwrap();
63     println!("Final value: {}", counter.get());
64 }