Fixes as per review.
[rust-101.git] / solutions / src / bigint.rs
1 use std::ops;
2 use std::cmp;
3 use std::fmt;
4
5 pub trait Minimum {
6     /// Return the smaller of the two
7     fn min<'a>(&'a self, other: &'a Self) -> &'a Self;
8 }
9
10 /// Return a pointer to the minimal value of `v`.
11 pub fn vec_min<T: Minimum>(v: &Vec<T>) -> Option<&T> {
12     let mut min = None;
13     for e in v {
14         min = Some(match min {
15             None => e,
16             Some(n) => e.min(n)
17         });
18     }
19     min
20 }
21
22 pub struct BigInt {
23     data: Vec<u64>, // least significant digits first. The last block will *not* be 0.
24 }
25
26 // Add with carry, returning the sum and the carry
27 fn overflowing_add(a: u64, b: u64, carry: bool) -> (u64, bool) {
28     let sum = u64::wrapping_add(a, b);
29     let carry_n = if carry { 1 } else { 0 };
30     if sum >= a { // the first sum did not overflow
31         let sum_total = u64::wrapping_add(sum, carry_n);
32         let had_overflow = sum_total < sum;
33         (sum_total, had_overflow)
34     } else { // the first sum did overflow
35         // it is impossible for this to overflow again, as we are just adding 0 or 1
36         (sum + carry_n, true)
37     }
38 }
39
40 // Subtract with carry, returning the difference and the carry
41 fn overflowing_sub(a: u64, b: u64, carry: bool) -> (u64, bool) {
42     let diff = u64::wrapping_sub(a, b);
43     let carry_n = if carry { 1 } else { 0 };
44     if diff <= a { // the first diff did not wrap
45         let diff_total = u64::wrapping_sub(diff, carry_n);
46         let had_wrap = diff_total > diff;
47         (diff_total, had_wrap)
48     } else { // the first diff did wrap
49         // it is impossible for this to wrap again, as we are just substracting 0 or 1
50         (diff - carry_n, true)
51     }
52 }
53
54 impl BigInt {
55     /// Construct a BigInt from a "small" one.
56     pub fn new(x: u64) -> Self {
57         if x == 0 { // take care of our invariant!
58             BigInt { data: vec![] }
59         } else {
60             BigInt { data: vec![x] }
61         }
62     }
63
64     fn test_invariant(&self) -> bool {
65         if self.data.len() == 0 {
66             true
67         } else {
68             self.data[self.data.len() - 1] != 0
69         }
70     }
71
72     /// Construct a BigInt from a vector of 64-bit "digits", with the last significant digit being first. Solution to 05.1.
73     pub fn from_vec(mut v: Vec<u64>) -> Self {
74         // remove trailing zeros
75         while v.len() > 0 && v[v.len()-1] == 0 {
76             v.pop();
77         }
78         BigInt { data: v }
79     }
80
81     /// Increments the number by 1.
82     pub fn inc1(&mut self) {
83         let mut idx = 0;
84         // This loop adds "(1 << idx)". If there is no more carry, we leave.
85         while idx < self.data.len() {
86             let cur = self.data[idx];
87             let sum = u64::wrapping_add(cur, 1);
88             self.data[idx] = sum;
89             if sum >= cur {
90                 // No overflow, we are done.
91                 return;
92             } else {
93                 // We need to go on.
94                 idx += 1;
95             }
96         }
97         // If we came here, there is a last carry to add
98         self.data.push(1);
99     }
100
101     /// Increments the number by "by".
102     pub fn inc(&mut self, mut by: u64) {
103         let mut idx = 0;
104         // This loop adds "by * (1 << idx)". Think of "by" as the carry from incrementing the last digit.
105         while idx < self.data.len() {
106             let cur = self.data[idx];
107             let sum = u64::wrapping_add(cur, by);
108             self.data[idx] = sum;
109             if sum >= cur {
110                 // No overflow, we are done.
111                 return;
112             } else {
113                 // We need to add a carry.
114                 by = 1;
115                 idx += 1;
116             }
117         }
118         // If we came here, there is a last carry to add
119         self.data.push(by);
120     }
121
122     /// Return the nth power-of-2 as BigInt
123     pub fn power_of_2(mut power: u64) -> BigInt {
124         let mut v = Vec::new();
125         while power >= 64 {
126             v.push(0);
127             power -= 64;
128         }
129         v.push(1 << power);
130         BigInt::from_vec(v)
131     }
132 }
133
134 impl Clone for BigInt {
135     fn clone(&self) -> Self {
136         BigInt { data: self.data.clone() }
137     }
138 }
139
140 impl PartialEq for BigInt {
141     fn eq(&self, other: &BigInt) -> bool {
142         debug_assert!(self.test_invariant() && other.test_invariant());
143         self.data == other.data
144     }
145 }
146
147 impl Minimum for BigInt {
148     // This is essentially the solution to 06.1.
149     fn min<'a>(&'a self, other: &'a Self) -> &'a Self {
150         debug_assert!(self.test_invariant() && other.test_invariant());
151         if self.data.len() < other.data.len() {
152             self
153         } else if self.data.len() > other.data.len() {
154             other
155         } else {
156             // compare back-to-front, i.e., most significant digit first
157             let mut idx = self.data.len();
158             while idx > 0 {
159                 idx = idx-1;
160                 if self.data[idx] < other.data[idx] {
161                     return self;
162                 } else if self.data[idx] > other.data[idx] {
163                     return other;
164                 }
165             }
166             // the two are equal
167             return self;
168         }
169     }
170 }
171
172 impl fmt::Debug for BigInt {
173     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
174         self.data.fmt(f)
175     }
176 }
177
178 impl<'a, 'b> ops::Add<&'a BigInt> for &'b BigInt {
179     type Output = BigInt;
180     fn add(self, rhs: &'a BigInt) -> Self::Output {
181         let max_len = cmp::max(self.data.len(), rhs.data.len());
182         let mut result_vec:Vec<u64> = Vec::with_capacity(max_len);
183         let mut carry:bool = false; // the carry bit
184         for i in 0..max_len {
185             // compute next digit and carry
186             let lhs_val = if i < self.data.len() { self.data[i] } else { 0 };
187             let rhs_val = if i < rhs.data.len() { rhs.data[i] } else { 0 };
188             let (sum, new_carry) = overflowing_add(lhs_val, rhs_val, carry);
189             // store them
190             result_vec.push(sum);
191             carry = new_carry;
192         }
193         if carry {
194             result_vec.push(1);
195         }
196         // We know that the invariant holds: overflowing_add would only return (0, false) if
197         // the arguments are (0, 0, false), but we know that in the last iteration, one od the two digits
198         // is the last of its number and hence not 0.
199         BigInt { data: result_vec }
200     }
201 }
202
203 impl<'a> ops::Add<BigInt> for &'a BigInt {
204     type Output = BigInt;
205     #[inline]
206     fn add(self, rhs: BigInt) -> Self::Output {
207         self + &rhs
208     }
209 }
210
211 impl<'a> ops::Add<&'a BigInt> for BigInt {
212     type Output = BigInt;
213     #[inline]
214     fn add(self, rhs: &'a BigInt) -> Self::Output {
215         &self + rhs
216     }
217 }
218
219 impl ops::Add<BigInt> for BigInt {
220     type Output = BigInt;
221     #[inline]
222     fn add(self, rhs: BigInt) -> Self::Output {
223         &self + &rhs
224     }
225 }
226
227 impl<'a, 'b> ops::Sub<&'a BigInt> for &'b BigInt {
228     type Output = BigInt;
229     fn sub(self, rhs: &'a BigInt) -> Self::Output {
230         let max_len = cmp::max(self.data.len(), rhs.data.len());
231         let mut result_vec:Vec<u64> = Vec::with_capacity(max_len);
232         let mut carry:bool = false; // the carry bit
233         for i in 0..max_len {
234             // compute next digit and carry
235             let lhs_val = if i < self.data.len() { self.data[i] } else { 0 };
236             let rhs_val = if i < rhs.data.len() { rhs.data[i] } else { 0 };
237             let (sum, new_carry) = overflowing_sub(lhs_val, rhs_val, carry);
238             // store them
239             result_vec.push(sum);
240             carry = new_carry;
241         }
242         if carry {
243             panic!("Wrapping subtraction of BigInt");
244         }
245         // We may have trailing zeroes, so get rid of them
246         BigInt::from_vec(result_vec)
247     }
248 }
249
250 impl<'a> ops::Sub<BigInt> for &'a BigInt {
251     type Output = BigInt;
252     #[inline]
253     fn sub(self, rhs: BigInt) -> Self::Output {
254         self - &rhs
255     }
256 }
257
258 impl<'a> ops::Sub<&'a BigInt> for BigInt {
259     type Output = BigInt;
260     #[inline]
261     fn sub(self, rhs: &'a BigInt) -> Self::Output {
262         &self - rhs
263     }
264 }
265
266 impl ops::Sub<BigInt> for BigInt {
267     type Output = BigInt;
268     #[inline]
269     fn sub(self, rhs: BigInt) -> Self::Output {
270         &self - &rhs
271     }
272 }
273
274 #[cfg(test)]
275 mod tests {
276     use std::u64;
277     use super::{overflowing_add,overflowing_sub,BigInt,Minimum,vec_min};
278
279     #[test]
280     fn test_min() {
281         let b1 = BigInt::new(1);
282         let b2 = BigInt::new(42);
283         let b3 = BigInt::from_vec(vec![0, 1]);
284
285         assert_eq!(b1.min(&b2), &b1);
286         assert_eq!(b2.min(&b1), &b1);
287         assert_eq!(b3.min(&b2), &b2);
288         assert_eq!(b2.min(&b3), &b2);
289     }
290
291     #[test]
292     fn test_vec_min() {
293     let b1 = BigInt::new(1);
294         let b2 = BigInt::new(42);
295         let b3 = BigInt::from_vec(vec![0, 1]);
296
297         assert_eq!(vec_min(&vec![b2.clone(), b1.clone(), b3.clone()]), Some(&b1));
298         assert_eq!(vec_min(&vec![b2.clone(), b3.clone()]), Some(&b2));
299         assert_eq!(vec_min(&vec![b3.clone()]), Some(&b3));
300         assert_eq!(vec_min::<BigInt>(&vec![]), None);
301     }
302
303     #[test]
304     fn test_overflowing_add() {
305         assert_eq!(overflowing_add(10, 100, false), (110, false));
306         assert_eq!(overflowing_add(10, 100, true), (111, false));
307         assert_eq!(overflowing_add(1 << 63, 1 << 63, false), (0, true));
308         assert_eq!(overflowing_add(1 << 63, 1 << 63, true), (1, true));
309         assert_eq!(overflowing_add(1 << 63, (1 << 63) -1 , true), (0, true));
310     }
311
312     #[test]
313     fn test_overflowing_sub() {
314         assert_eq!(overflowing_sub(100, 10, false), (90, false));
315         assert_eq!(overflowing_sub(100, 10, true), (89, false));
316         assert_eq!(overflowing_sub(10, 1 << 63, false), ((1 << 63) + 10, true));
317         assert_eq!(overflowing_sub(10, 1 << 63, true), ((1 << 63) + 9, true));
318         assert_eq!(overflowing_sub(42, 42 , true), (u64::max_value(), true));
319     }
320
321     #[test]
322     fn test_add() {
323         let b1 = BigInt::new(1 << 32);
324         let b2 = BigInt::from_vec(vec![0, 1]);
325         let b3 = BigInt::from_vec(vec![0, 0, 1]);
326         let b4 = BigInt::new(1 << 63);
327
328         assert_eq!(&b1 + &b2, BigInt::from_vec(vec![1 << 32, 1]));
329         assert_eq!(&b2 + &b1, BigInt::from_vec(vec![1 << 32, 1]));
330         assert_eq!(&b2 + &b3, BigInt::from_vec(vec![0, 1, 1]));
331         assert_eq!(&b2 + &b3 + &b4 + &b4, BigInt::from_vec(vec![0, 2, 1]));
332         assert_eq!(&b2 + &b4 + &b3 + &b4, BigInt::from_vec(vec![0, 2, 1]));
333         assert_eq!(&b4 + &b2 + &b3 + &b4, BigInt::from_vec(vec![0, 2, 1]));
334     }
335
336     #[test]
337     fn test_sub() {
338         let b1 = BigInt::new(1 << 32);
339         let b2 = BigInt::from_vec(vec![0, 1]);
340         let b3 = BigInt::from_vec(vec![0, 0, 1]);
341         let b4 = BigInt::new(1 << 63);
342
343         assert_eq!(&b2 - &b1, BigInt::from_vec(vec![u64::max_value() - (1 << 32) + 1]));
344         assert_eq!(&b3 - &b2, BigInt::from_vec(vec![0, u64::max_value(), 0]));
345         assert_eq!(&b2 - &b4 - &b4, BigInt::new(0));
346         assert_eq!(&b3 - &b2 - &b4 - &b4, BigInt::from_vec(vec![0, u64::max_value() - 1]));
347         assert_eq!(&b3 - &b4 - &b2 - &b4, BigInt::from_vec(vec![0, u64::max_value() - 1]));
348         assert_eq!(&b3 - &b4 - &b4 - &b2, BigInt::from_vec(vec![0, u64::max_value() - 1]));
349     }
350
351     #[test]
352     #[should_panic(expected = "Wrapping subtraction of BigInt")]
353     fn test_sub_panic1() {
354         let _ = BigInt::new(1) - BigInt::new(5);
355     }
356
357     #[test]
358     #[should_panic(expected = "Wrapping subtraction of BigInt")]
359     fn test_sub_panic2() {
360         let _ = BigInt::from_vec(vec![5,8,3,33,1<<13,46,1<<49, 1, 583,1<<60,2533]) - BigInt::from_vec(vec![5,8,3,33,1<<13,46,1<<49, 5, 583,1<<60,2533]);
361     }
362
363     #[test]
364     fn test_inc1() {
365         let mut b = BigInt::new(0);
366         b.inc1();
367         assert_eq!(b, BigInt::new(1));
368         b.inc1();
369         assert_eq!(b, BigInt::new(2));
370
371         b = BigInt::new(u64::MAX);
372         b.inc1();
373         assert_eq!(b, BigInt::from_vec(vec![0, 1]));
374         b.inc1();
375         assert_eq!(b, BigInt::from_vec(vec![1, 1]));
376     }
377
378     #[test]
379     fn test_power_of_2() {
380         assert_eq!(BigInt::power_of_2(0), BigInt::new(1));
381         assert_eq!(BigInt::power_of_2(13), BigInt::new(1 << 13));
382         assert_eq!(BigInt::power_of_2(64), BigInt::from_vec(vec![0, 1]));
383         assert_eq!(BigInt::power_of_2(96), BigInt::from_vec(vec![0, 1 << 32]));
384         assert_eq!(BigInt::power_of_2(128), BigInt::from_vec(vec![0, 0, 1]));
385     }
386 }
387
388