typos and wording (thanks, Thomas!)
[rust-101.git] / solutions / src / list.rs
1 use std::ptr;
2 use std::mem;
3 use std::marker::PhantomData;
4
5 fn box_into_raw<T>(b: Box<T>) -> *mut T {
6     unsafe { mem::transmute(b) }
7 }
8 unsafe fn raw_into_box<T>(r: *mut T) -> Box<T> {
9     mem::transmute(r)
10 }
11
12 struct Node<T> {
13     data: T,
14     next: NodePtr<T>,
15     prev: NodePtr<T>,
16 }
17 type NodePtr<T> = *mut Node<T>;
18
19 pub struct LinkedList<T> {
20     first: NodePtr<T>,
21     last:  NodePtr<T>,
22     _marker: PhantomData<T>,
23 }
24
25 impl<T> LinkedList<T> {
26     pub fn new() -> Self {
27         LinkedList { first: ptr::null_mut(), last: ptr::null_mut(), _marker: PhantomData }
28     }
29
30     pub fn push_back(&mut self, t: T) {
31         // Create the new node.
32         let new = Box::new( Node { data: t, next: ptr::null_mut(), prev: self.last } );
33         let new = box_into_raw(new);
34         // Update other points to this node.
35         if self.last.is_null() {
36             debug_assert!(self.first.is_null());
37             self.first = new;
38         } else {
39             debug_assert!(!self.first.is_null());
40             unsafe { (*self.last).next  = new; }
41         }
42         // Make this the last node.
43         self.last = new;
44     }
45
46     pub fn pop_back(&mut self) -> Option<T> {
47         if self.last.is_null() {
48             None
49         } else {
50             let last = self.last;
51             let new_last = unsafe { (*self.last).prev };
52             self.last = new_last;
53             if new_last.is_null() {
54                 // The list is now empty.
55                 self.first = new_last;
56             }
57             let last = unsafe { raw_into_box(last) } ;
58             Some(last.data)
59         }
60     }
61
62     pub fn push_front(&mut self, t: T) {
63         // Create the new node.
64         let new = Box::new( Node { data: t, next: self.first, prev: ptr::null_mut() } );
65         let new = box_into_raw(new);
66         // Update other points to this node.
67         if self.first.is_null() {
68             debug_assert!(self.last.is_null());
69             self.last = new;
70         }
71         else {
72             debug_assert!(!self.last.is_null());
73             unsafe { (*self.first).prev = new; }
74         }
75         // Make this the first node.
76         self.first = new;
77     }
78
79     pub fn pop_front(&mut self) -> Option<T> {
80         if self.first.is_null() {
81             None
82         } else {
83             let first = self.first;
84             let new_first = unsafe { (*self.first).next };
85             self.first = new_first;
86             if new_first.is_null() {
87                 // The list is now empty.
88                 self.last = new_first;
89             }
90             let first = unsafe { raw_into_box(first) } ;
91             Some(first.data)
92         }
93     }
94
95     pub fn for_each<F: FnMut(&mut T)>(&mut self, mut f: F) {
96         let mut cur_ptr = self.first;
97         while !cur_ptr.is_null() {
98             // Iterate over every node, and call `f`.
99             f(unsafe{ &mut (*cur_ptr).data });
100             cur_ptr = unsafe{ (*cur_ptr).next };
101         }
102     }
103
104     pub fn iter_mut(&mut self) -> IterMut<T> {
105         IterMut { next: self.first, _marker: PhantomData  }
106     }
107 }
108
109 pub struct IterMut<'a, T> where T: 'a {
110     next: NodePtr<T>,
111     _marker: PhantomData<&'a T>,
112 }
113
114 impl<'a, T> Iterator for IterMut<'a, T> {
115     type Item = &'a mut T;
116
117     fn next(&mut self) -> Option<Self::Item> {
118         if self.next.is_null() {
119            None
120         } else {
121             let ret = unsafe{ &mut (*self.next).data };
122             self.next = unsafe { (*self.next).next };
123             Some(ret)
124         }
125     }
126 }
127
128 impl<T> Drop for LinkedList<T> {
129     fn drop(&mut self) {
130         let mut cur_ptr = self.first;
131         while !cur_ptr.is_null() {
132             let cur = unsafe { raw_into_box(cur_ptr) };
133             cur_ptr = cur.next;
134             drop(cur);
135         }
136     }
137 }
138
139 #[cfg(test)]
140 mod tests {
141     use std::rc::Rc;
142     use std::cell::Cell;
143     use super::LinkedList;
144
145     #[test]
146     fn test_pop_back() {
147         let mut l: LinkedList<i32> = LinkedList::new();
148         for i in 0..3 {
149             l.push_front(-i);
150             l.push_back(i);
151         }
152
153         assert_eq!(l.pop_back(), Some(2));
154         assert_eq!(l.pop_back(), Some(1));
155         assert_eq!(l.pop_back(), Some(0));
156         assert_eq!(l.pop_back(), Some(-0));
157         assert_eq!(l.pop_back(), Some(-1));
158         assert_eq!(l.pop_back(), Some(-2));
159         assert_eq!(l.pop_back(), None);
160         assert_eq!(l.pop_back(), None);
161     }
162
163     #[test]
164     fn test_pop_front() {
165         let mut l: LinkedList<i32> = LinkedList::new();
166         for i in 0..3 {
167             l.push_front(-i);
168             l.push_back(i);
169         }
170
171         assert_eq!(l.pop_front(), Some(-2));
172         assert_eq!(l.pop_front(), Some(-1));
173         assert_eq!(l.pop_front(), Some(-0));
174         assert_eq!(l.pop_front(), Some(0));
175         assert_eq!(l.pop_front(), Some(1));
176         assert_eq!(l.pop_front(), Some(2));
177         assert_eq!(l.pop_front(), None);
178         assert_eq!(l.pop_front(), None);
179     }
180
181     #[derive(Clone)]
182     struct DropChecker {
183         count: Rc<Cell<usize>>,
184     }
185     impl Drop for DropChecker {
186         fn drop(&mut self) {
187             self.count.set(self.count.get() + 1);
188         }
189     }
190
191     #[test]
192     fn test_drop() {
193         let count = DropChecker { count: Rc::new(Cell::new(0)) };
194         {
195             let mut l = LinkedList::new();
196             for _ in 0..10 {
197                 l.push_back(count.clone());
198                 l.push_front(count.clone());
199             }
200         }
201         assert_eq!(count.count.get(), 20);
202     }
203 }