Fix potential invalid memory reference.
[rust-101.git] / solutions / src / list.rs
1 use std::ptr;
2 use std::mem;
3 use std::marker::PhantomData;
4
5 fn box_into_raw<T>(b: Box<T>) -> *mut T {
6     unsafe { mem::transmute(b) }
7 }
8 unsafe fn raw_into_box<T>(r: *mut T) -> Box<T> {
9     mem::transmute(r)
10 }
11
12 struct Node<T> {
13     data: T,
14     next: NodePtr<T>,
15     prev: NodePtr<T>,
16 }
17 type NodePtr<T> = *mut Node<T>;
18
19 pub struct LinkedList<T> {
20     first: NodePtr<T>,
21     last:  NodePtr<T>,
22     _marker: PhantomData<T>,
23 }
24
25 impl<T> LinkedList<T> {
26     pub fn new() -> Self {
27         LinkedList { first: ptr::null_mut(), last: ptr::null_mut(), _marker: PhantomData }
28     }
29
30     pub fn push_back(&mut self, t: T) {
31         // Create the new node.
32         let new = Box::new( Node { data: t, next: ptr::null_mut(), prev: self.last } );
33         let new = box_into_raw(new);
34         // Update other points to this node.
35         if self.last.is_null() {
36             debug_assert!(self.first.is_null());
37             self.first = new;
38         } else {
39             debug_assert!(!self.first.is_null());
40             unsafe { (*self.last).next  = new; }
41         }
42         // Make this the last node.
43         self.last = new;
44     }
45
46     pub fn pop_back(&mut self) -> Option<T> {
47         if self.last.is_null() {
48             None
49         } else {
50             let last = self.last;
51             let new_last = unsafe { (*self.last).prev };
52             self.last = new_last;
53             if new_last.is_null() {
54                 // The list is now empty.
55                 self.first = new_last;
56             } else {
57                 unsafe { (*new_last).next = ptr::null_mut() };
58             }
59             let last = unsafe { raw_into_box(last) } ;
60             Some(last.data)
61         }
62     }
63
64     pub fn push_front(&mut self, t: T) {
65         // Create the new node.
66         let new = Box::new( Node { data: t, next: self.first, prev: ptr::null_mut() } );
67         let new = box_into_raw(new);
68         // Update other points to this node.
69         if self.first.is_null() {
70             debug_assert!(self.last.is_null());
71             self.last = new;
72         }
73         else {
74             debug_assert!(!self.last.is_null());
75             unsafe { (*self.first).prev = new; }
76         }
77         // Make this the first node.
78         self.first = new;
79     }
80
81     pub fn pop_front(&mut self) -> Option<T> {
82         if self.first.is_null() {
83             None
84         } else {
85             let first = self.first;
86             let new_first = unsafe { (*self.first).next };
87             self.first = new_first;
88             if new_first.is_null() {
89                 // The list is now empty.
90                 self.last = new_first;
91             } else {
92                 unsafe { (*new_first).prev = ptr::null_mut() };
93             }
94             let first = unsafe { raw_into_box(first) } ;
95             Some(first.data)
96         }
97     }
98
99     pub fn for_each<F: FnMut(&mut T)>(&mut self, mut f: F) {
100         let mut cur_ptr = self.first;
101         while !cur_ptr.is_null() {
102             // Iterate over every node, and call `f`.
103             f(unsafe{ &mut (*cur_ptr).data });
104             cur_ptr = unsafe{ (*cur_ptr).next };
105         }
106     }
107
108     pub fn iter_mut(&mut self) -> IterMut<T> {
109         IterMut { next: self.first, _marker: PhantomData  }
110     }
111 }
112
113 pub struct IterMut<'a, T> where T: 'a {
114     next: NodePtr<T>,
115     _marker: PhantomData<&'a T>,
116 }
117
118 impl<'a, T> Iterator for IterMut<'a, T> {
119     type Item = &'a mut T;
120
121     fn next(&mut self) -> Option<Self::Item> {
122         if self.next.is_null() {
123            None
124         } else {
125             let ret = unsafe{ &mut (*self.next).data };
126             self.next = unsafe { (*self.next).next };
127             Some(ret)
128         }
129     }
130 }
131
132 impl<T> Drop for LinkedList<T> {
133     fn drop(&mut self) {
134         let mut cur_ptr = self.first;
135         while !cur_ptr.is_null() {
136             let cur = unsafe { raw_into_box(cur_ptr) };
137             cur_ptr = cur.next;
138             drop(cur);
139         }
140     }
141 }
142
143 #[cfg(test)]
144 mod tests {
145     use std::rc::Rc;
146     use std::cell::Cell;
147     use super::LinkedList;
148
149     #[test]
150     fn test_pop_back() {
151         let mut l: LinkedList<i32> = LinkedList::new();
152         for i in 0..3 {
153             l.push_front(-i);
154             l.push_back(i);
155         }
156
157         assert_eq!(l.pop_back(), Some(2));
158         assert_eq!(l.pop_back(), Some(1));
159         assert_eq!(l.pop_back(), Some(0));
160         assert_eq!(l.pop_back(), Some(-0));
161         assert_eq!(l.pop_back(), Some(-1));
162         assert_eq!(l.pop_back(), Some(-2));
163         assert_eq!(l.pop_back(), None);
164         assert_eq!(l.pop_back(), None);
165     }
166
167     #[test]
168     fn test_pop_front() {
169         let mut l: LinkedList<i32> = LinkedList::new();
170         for i in 0..3 {
171             l.push_front(-i);
172             l.push_back(i);
173         }
174
175         assert_eq!(l.pop_front(), Some(-2));
176         assert_eq!(l.pop_front(), Some(-1));
177         assert_eq!(l.pop_front(), Some(-0));
178         assert_eq!(l.pop_front(), Some(0));
179         assert_eq!(l.pop_front(), Some(1));
180         assert_eq!(l.pop_front(), Some(2));
181         assert_eq!(l.pop_front(), None);
182         assert_eq!(l.pop_front(), None);
183     }
184
185     #[derive(Clone)]
186     struct DropChecker {
187         count: Rc<Cell<usize>>,
188     }
189     impl Drop for DropChecker {
190         fn drop(&mut self) {
191             self.count.set(self.count.get() + 1);
192         }
193     }
194
195     #[test]
196     fn test_drop() {
197         let count = DropChecker { count: Rc::new(Cell::new(0)) };
198         {
199             let mut l = LinkedList::new();
200             for _ in 0..10 {
201                 l.push_back(count.clone());
202                 l.push_front(count.clone());
203             }
204         }
205         assert_eq!(count.count.get(), 20);
206     }
207 }