7a61b4b5de9ccfac4a6dc717145055f181a88494
[rust-101.git] / solutions / src / bigint.rs
1 use std::ops;
2 use std::cmp;
3 use std::fmt;
4
5 pub trait Minimum {
6     /// Return the smaller of the two
7     fn min<'a>(&'a self, other: &'a Self) -> &'a Self;
8 }
9
10 /// Return a pointer to the minimal value of `v`.
11 pub fn vec_min<T: Minimum>(v: &Vec<T>) -> Option<&T> {
12     let mut min = None;
13     for e in v {
14         min = Some(match min {
15             None => e,
16             Some(n) => e.min(n)
17         });
18     }
19     min
20 }
21
22 pub struct BigInt {
23     data: Vec<u64>, // least significant digits first. The last block will *not* be 0.
24 }
25
26 // Add with carry, returning the sum and the carry
27 fn overflowing_add(a: u64, b: u64, carry: bool) -> (u64, bool) {
28     let sum = u64::wrapping_add(a, b);
29     let carry_n = if carry { 1 } else { 0 };
30     if sum >= a { // the first sum did not overflow
31         let sum_total = u64::wrapping_add(sum, carry_n);
32         let had_overflow = sum_total < sum;
33         (sum_total, had_overflow)
34     } else { // the first sum did overflow
35         // it is impossible for this to overflow again, as we are just adding 0 or 1
36         (sum + carry_n, true)
37     }
38 }
39
40 // Subtract with carry, returning the difference and the carry
41 fn overflowing_sub(a: u64, b: u64, carry: bool) -> (u64, bool) {
42     let diff = u64::wrapping_sub(a, b);
43     let carry_n = if carry { 1 } else { 0 };
44     if diff <= a { // the first diff did not wrap
45         let diff_total = u64::wrapping_sub(diff, carry_n);
46         let had_wrap = diff_total > diff;
47         (diff_total, had_wrap)
48     } else { // the first diff did wrap
49         // it is impossible for this to wrap again, as we are just substracting 0 or 1
50         (diff - carry_n, true)
51     }
52 }
53
54 impl BigInt {
55     /// Construct a BigInt from a "small" one.
56     pub fn new(x: u64) -> Self {
57         if x == 0 { // take care of our invariant!
58             BigInt { data: vec![] }
59         } else {
60             BigInt { data: vec![x] }
61         }
62     }
63
64     fn test_invariant(&self) -> bool {
65         if self.data.len() == 0 {
66             true
67         } else {
68             self.data[self.data.len() - 1] != 0
69         }
70     }
71
72     /// Construct a BigInt from a vector of 64-bit "digits", with the last significant digit being first. Solution to 05.1.
73     pub fn from_vec(mut v: Vec<u64>) -> Self {
74         // remove trailing zeros
75         while v.len() > 0 && v[v.len()-1] == 0 {
76             v.pop();
77         }
78         BigInt { data: v }
79     }
80
81     /// Increments the number by 1.
82     pub fn inc1(&mut self) {
83         let mut idx = 0;
84         // This loop adds "(1 << idx)". If there is no more carry, we leave.
85         while idx < self.data.len() {
86             let cur = self.data[idx];
87             let sum = u64::wrapping_add(cur, 1);
88             self.data[idx] = sum;
89             if sum >= cur {
90                 // No overflow, we are done.
91                 return;
92             } else {
93                 // We need to go on.
94                 idx += 1;
95             }
96         }
97         // If we came here, there is a last carry to add
98         self.data.push(1);
99     }
100
101     /// Increments the number by "by".
102     pub fn inc(&mut self, mut by: u64) {
103         let mut idx = 0;
104         // This loop adds "by * (1 << idx)". Think of "by" as the carry from incrementing the last digit.
105         while idx < self.data.len() {
106             let cur = self.data[idx];
107             let sum = u64::wrapping_add(cur, by);
108             self.data[idx] = sum;
109             if sum >= cur {
110                 // No overflow, we are done.
111                 return;
112             } else {
113                 // We need to add a carry.
114                 by = 1;
115                 idx += 1;
116             }
117         }
118         // If we came here, there is a last carry to add
119         self.data.push(by);
120     }
121
122     /// Return the nth power-of-2 as BigInt
123     pub fn power_of_2(mut power: u64) -> BigInt {
124         let mut v = Vec::new();
125         while power >= 64 {
126             v.push(0);
127             power -= 64;
128         }
129         v.push(1 << power);
130         BigInt::from_vec(v)
131     }
132 }
133
134 impl Clone for BigInt {
135     fn clone(&self) -> Self {
136         BigInt { data: self.data.clone() }
137     }
138 }
139
140 impl PartialEq for BigInt {
141     fn eq(&self, other: &BigInt) -> bool {
142         debug_assert!(self.test_invariant() && other.test_invariant());
143         self.data == other.data
144     }
145 }
146
147 impl Minimum for BigInt {
148     // This is essentially the solution to 06.1.
149     fn min<'a>(&'a self, other: &'a Self) -> &'a Self {
150         debug_assert!(self.test_invariant() && other.test_invariant());
151         if self.data.len() < other.data.len() {
152             self
153         } else if self.data.len() > other.data.len() {
154             other
155         } else {
156             // compare back-to-front, i.e., most significant digit first
157             let mut idx = self.data.len()-1;
158             while idx > 0 {
159                 if self.data[idx] < other.data[idx] {
160                     return self;
161                 } else if self.data[idx] > other.data[idx] {
162                     return other;
163                 }
164                 else {
165                     idx = idx-1;
166                 }
167             }
168             // the two are equal
169             return self;
170         }
171     }
172 }
173
174 impl fmt::Debug for BigInt {
175     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176         self.data.fmt(f)
177     }
178 }
179
180 impl<'a, 'b> ops::Add<&'a BigInt> for &'b BigInt {
181     type Output = BigInt;
182     fn add(self, rhs: &'a BigInt) -> Self::Output {
183         let max_len = cmp::max(self.data.len(), rhs.data.len());
184         let mut result_vec:Vec<u64> = Vec::with_capacity(max_len);
185         let mut carry:bool = false; // the carry bit
186         for i in 0..max_len {
187             // compute next digit and carry
188             let lhs_val = if i < self.data.len() { self.data[i] } else { 0 };
189             let rhs_val = if i < rhs.data.len() { rhs.data[i] } else { 0 };
190             let (sum, new_carry) = overflowing_add(lhs_val, rhs_val, carry);
191             // store them
192             result_vec.push(sum);
193             carry = new_carry;
194         }
195         if carry {
196             result_vec.push(1);
197         }
198         // We know that the invariant holds: overflowing_add would only return (0, false) if
199         // the arguments are (0, 0, false), but we know that in the last iteration, one od the two digits
200         // is the last of its number and hence not 0.
201         BigInt { data: result_vec }
202     }
203 }
204
205 impl<'a> ops::Add<BigInt> for &'a BigInt {
206     type Output = BigInt;
207     #[inline]
208     fn add(self, rhs: BigInt) -> Self::Output {
209         self + &rhs
210     }
211 }
212
213 impl<'a> ops::Add<&'a BigInt> for BigInt {
214     type Output = BigInt;
215     #[inline]
216     fn add(self, rhs: &'a BigInt) -> Self::Output {
217         &self + rhs
218     }
219 }
220
221 impl ops::Add<BigInt> for BigInt {
222     type Output = BigInt;
223     #[inline]
224     fn add(self, rhs: BigInt) -> Self::Output {
225         &self + &rhs
226     }
227 }
228
229 impl<'a, 'b> ops::Sub<&'a BigInt> for &'b BigInt {
230     type Output = BigInt;
231     fn sub(self, rhs: &'a BigInt) -> Self::Output {
232         let max_len = cmp::max(self.data.len(), rhs.data.len());
233         let mut result_vec:Vec<u64> = Vec::with_capacity(max_len);
234         let mut carry:bool = false; // the carry bit
235         for i in 0..max_len {
236             // compute next digit and carry
237             let lhs_val = if i < self.data.len() { self.data[i] } else { 0 };
238             let rhs_val = if i < rhs.data.len() { rhs.data[i] } else { 0 };
239             let (sum, new_carry) = overflowing_sub(lhs_val, rhs_val, carry);
240             // store them
241             result_vec.push(sum);
242             carry = new_carry;
243         }
244         if carry {
245             panic!("Wrapping subtraction of BigInt");
246         }
247         // We may have trailing zeroes, so get rid of them
248         BigInt::from_vec(result_vec)
249     }
250 }
251
252 impl<'a> ops::Sub<BigInt> for &'a BigInt {
253     type Output = BigInt;
254     #[inline]
255     fn sub(self, rhs: BigInt) -> Self::Output {
256         self - &rhs
257     }
258 }
259
260 impl<'a> ops::Sub<&'a BigInt> for BigInt {
261     type Output = BigInt;
262     #[inline]
263     fn sub(self, rhs: &'a BigInt) -> Self::Output {
264         &self - rhs
265     }
266 }
267
268 impl ops::Sub<BigInt> for BigInt {
269     type Output = BigInt;
270     #[inline]
271     fn sub(self, rhs: BigInt) -> Self::Output {
272         &self - &rhs
273     }
274 }
275
276 #[cfg(test)]
277 mod tests {
278     use std::u64;
279     use super::{overflowing_add,overflowing_sub,BigInt};
280
281     #[test]
282     fn test_overflowing_add() {
283         assert_eq!(overflowing_add(10, 100, false), (110, false));
284         assert_eq!(overflowing_add(10, 100, true), (111, false));
285         assert_eq!(overflowing_add(1 << 63, 1 << 63, false), (0, true));
286         assert_eq!(overflowing_add(1 << 63, 1 << 63, true), (1, true));
287         assert_eq!(overflowing_add(1 << 63, (1 << 63) -1 , true), (0, true));
288     }
289
290     #[test]
291     fn test_overflowing_sub() {
292         assert_eq!(overflowing_sub(100, 10, false), (90, false));
293         assert_eq!(overflowing_sub(100, 10, true), (89, false));
294         assert_eq!(overflowing_sub(10, 1 << 63, false), ((1 << 63) + 10, true));
295         assert_eq!(overflowing_sub(10, 1 << 63, true), ((1 << 63) + 9, true));
296         assert_eq!(overflowing_sub(42, 42 , true), (u64::max_value(), true));
297     }
298
299     #[test]
300     fn test_add() {
301         let b1 = BigInt::new(1 << 32);
302         let b2 = BigInt::from_vec(vec![0, 1]);
303         let b3 = BigInt::from_vec(vec![0, 0, 1]);
304         let b4 = BigInt::new(1 << 63);
305
306         assert_eq!(&b1 + &b2, BigInt::from_vec(vec![1 << 32, 1]));
307         assert_eq!(&b2 + &b1, BigInt::from_vec(vec![1 << 32, 1]));
308         assert_eq!(&b2 + &b3, BigInt::from_vec(vec![0, 1, 1]));
309         assert_eq!(&b2 + &b3 + &b4 + &b4, BigInt::from_vec(vec![0, 2, 1]));
310         assert_eq!(&b2 + &b4 + &b3 + &b4, BigInt::from_vec(vec![0, 2, 1]));
311         assert_eq!(&b4 + &b2 + &b3 + &b4, BigInt::from_vec(vec![0, 2, 1]));
312     }
313
314     #[test]
315     fn test_sub() {
316         let b1 = BigInt::new(1 << 32);
317         let b2 = BigInt::from_vec(vec![0, 1]);
318         let b3 = BigInt::from_vec(vec![0, 0, 1]);
319         let b4 = BigInt::new(1 << 63);
320
321         assert_eq!(&b2 - &b1, BigInt::from_vec(vec![u64::max_value() - (1 << 32) + 1]));
322         assert_eq!(&b3 - &b2, BigInt::from_vec(vec![0, u64::max_value(), 0]));
323         assert_eq!(&b2 - &b4 - &b4, BigInt::from_vec(vec![0]));
324         assert_eq!(&b3 - &b2 - &b4 - &b4, BigInt::from_vec(vec![0, u64::max_value() - 1]));
325         assert_eq!(&b3 - &b4 - &b2 - &b4, BigInt::from_vec(vec![0, u64::max_value() - 1]));
326         assert_eq!(&b3 - &b4 - &b4 - &b2, BigInt::from_vec(vec![0, u64::max_value() - 1]));
327     }
328
329     #[test]
330     fn test_inc1() {
331         let mut b = BigInt::new(0);
332         b.inc1();
333         assert_eq!(b, BigInt::new(1));
334         b.inc1();
335         assert_eq!(b, BigInt::new(2));
336
337         b = BigInt::new(u64::MAX);
338         b.inc1();
339         assert_eq!(b, BigInt::from_vec(vec![0, 1]));
340         b.inc1();
341         assert_eq!(b, BigInt::from_vec(vec![1, 1]));
342     }
343
344     #[test]
345     fn test_power_of_2() {
346         assert_eq!(BigInt::power_of_2(0), BigInt::new(1));
347         assert_eq!(BigInt::power_of_2(13), BigInt::new(1 << 13));
348         assert_eq!(BigInt::power_of_2(64), BigInt::from_vec(vec![0, 1]));
349         assert_eq!(BigInt::power_of_2(96), BigInt::from_vec(vec![0, 1 << 32]));
350         assert_eq!(BigInt::power_of_2(128), BigInt::from_vec(vec![0, 0, 1]));
351     }
352 }
353
354