write part 16
[rust-101.git] / solutions / src / list.rs
1
2 use std::ptr;
3 use std::mem;
4 use std::marker::PhantomData;
5
6 fn box_into_raw<T>(b: Box<T>) -> *mut T {
7     unsafe { mem::transmute(b) }
8 }
9 unsafe fn raw_into_box<T>(r: *mut T) -> Box<T> {
10     mem::transmute(r)
11 }
12
13 struct Node<T> {
14     data: T,
15     next: NodePtr<T>,
16     prev: NodePtr<T>,
17 }
18 type NodePtr<T> = *mut Node<T>;
19
20 pub struct LinkedList<T> {
21     first: NodePtr<T>,
22     last:  NodePtr<T>,
23     _marker: PhantomData<T>,
24 }
25
26 impl<T> LinkedList<T> {
27     pub fn new() -> Self {
28         LinkedList { first: ptr::null_mut(), last: ptr::null_mut(), _marker: PhantomData }
29     }
30
31     pub fn push_back(&mut self, t: T) {
32         // Create the new node.
33         let new = Box::new( Node { data: t, next: ptr::null_mut(), prev: self.last } );
34         let new = box_into_raw(new);
35         // Update other points to this node.
36         if self.last.is_null() {
37             debug_assert!(self.first.is_null());
38             self.first = new;
39         } else {
40             debug_assert!(!self.first.is_null());
41             unsafe { (*self.last).next  = new; }
42         }
43         // Make this the last node.
44         self.last = new;
45     }
46
47     pub fn pop_back(&mut self) -> Option<T> {
48         if self.last.is_null() {
49             None
50         } else {
51             let last = self.last;
52             let new_last = unsafe { (*self.last).prev };
53             self.last = new_last;
54             if new_last.is_null() {
55                 // The list is now empty.
56                 self.first = new_last;
57             }
58             let last = unsafe { raw_into_box(last) } ;
59             Some(last.data)
60         }
61     }
62
63     pub fn push_front(&mut self, t: T) {
64         // Create the new node.
65         let new = Box::new( Node { data: t, next: self.first, prev: ptr::null_mut() } );
66         let new = box_into_raw(new);
67         // Update other points to this node.
68         if self.first.is_null() {
69             debug_assert!(self.last.is_null());
70             self.last = new;
71         }
72         else {
73             debug_assert!(!self.last.is_null());
74             unsafe { (*self.first).prev = new; }
75         }
76         // Make this the first node.
77         self.first = new;
78     }
79
80     pub fn pop_front(&mut self) -> Option<T> {
81         if self.first.is_null() {
82             None
83         } else {
84             let first = self.first;
85             let new_first = unsafe { (*self.first).next };
86             self.first = new_first;
87             if new_first.is_null() {
88                 // The list is now empty.
89                 self.last = new_first;
90             }
91             let first = unsafe { raw_into_box(first) } ;
92             Some(first.data)
93         }
94     }
95
96     pub fn for_each<F: FnMut(&mut T)>(&mut self, mut f: F) {
97         let mut cur_ptr = self.first;
98         while !cur_ptr.is_null() {
99             // Iterate over every node, and call `f`.
100             f(unsafe{ &mut (*cur_ptr).data });
101             cur_ptr = unsafe{ (*cur_ptr).next };
102         }
103     }
104
105     pub fn iter_mut(&self) -> IterMut<T> {
106         IterMut { next: self.first, _marker: PhantomData  }
107     }
108 }
109
110 pub struct IterMut<'a, T> where T: 'a {
111     next: NodePtr<T>,
112     _marker: PhantomData<&'a T>,
113 }
114
115 impl<'a, T> Iterator for IterMut<'a, T> {
116     type Item = &'a mut T;
117
118     fn next(&mut self) -> Option<Self::Item> {
119         if self.next.is_null() {
120            None
121         } else {
122             let ret = unsafe{ &mut (*self.next).data };
123             self.next = unsafe { (*self.next).next };
124             Some(ret)
125         }
126     }
127 }
128
129 impl<T> Drop for LinkedList<T> {
130     fn drop(&mut self) {
131         let mut cur_ptr = self.first;
132         while !cur_ptr.is_null() {
133             let cur = unsafe { raw_into_box(cur_ptr) };
134             cur_ptr = cur.next;
135             drop(cur);
136         }
137     }
138 }
139
140 #[cfg(test)]
141 mod tests {
142     use std::rc::Rc;
143     use std::cell::Cell;
144     use super::LinkedList;
145
146     #[test]
147     fn test_pop_back() {
148         let mut l: LinkedList<i32> = LinkedList::new();
149         for i in 0..3 {
150             l.push_front(-i);
151             l.push_back(i);
152         }
153
154         assert_eq!(l.pop_back(), Some(2));
155         assert_eq!(l.pop_back(), Some(1));
156         assert_eq!(l.pop_back(), Some(0));
157         assert_eq!(l.pop_back(), Some(-0));
158         assert_eq!(l.pop_back(), Some(-1));
159         assert_eq!(l.pop_back(), Some(-2));
160         assert_eq!(l.pop_back(), None);
161         assert_eq!(l.pop_back(), None);
162     }
163
164     #[test]
165     fn test_pop_front() {
166         let mut l: LinkedList<i32> = LinkedList::new();
167         for i in 0..3 {
168             l.push_front(-i);
169             l.push_back(i);
170         }
171
172         assert_eq!(l.pop_front(), Some(-2));
173         assert_eq!(l.pop_front(), Some(-1));
174         assert_eq!(l.pop_front(), Some(-0));
175         assert_eq!(l.pop_front(), Some(0));
176         assert_eq!(l.pop_front(), Some(1));
177         assert_eq!(l.pop_front(), Some(2));
178         assert_eq!(l.pop_front(), None);
179         assert_eq!(l.pop_front(), None);
180     }
181
182     #[derive(Clone)]
183     struct DropChecker {
184         count: Rc<Cell<usize>>,
185     }
186     impl Drop for DropChecker {
187         fn drop(&mut self) {
188             self.count.set(self.count.get() + 1);
189         }
190     }
191
192     #[test]
193     fn test_drop() {
194         let count = DropChecker { count: Rc::new(Cell::new(0)) };
195         {
196             let mut l = LinkedList::new();
197             for _ in 0..10 {
198                 l.push_back(count.clone());
199                 l.push_front(count.clone());
200             }
201         }
202         assert_eq!(count.count.get(), 20);
203     }
204 }