120ae6c8d008e9d6b9d682728f6b7bc1f09a3657
[rust-101.git] / solutions / src / bigint.rs
1 use std::ops;
2 use std::cmp;
3 use std::fmt;
4
5 pub trait Minimum {
6     /// Return the smaller of the two
7     fn min<'a>(&'a self, other: &'a Self) -> &'a Self;
8 }
9
10 /// Return a pointer to the minimal value of `v`.
11 pub fn vec_min<T: Minimum>(v: &Vec<T>) -> Option<&T> {
12     let mut min = None;
13     for e in v {
14         min = Some(match min {
15             None => e,
16             Some(n) => e.min(n)
17         });
18     }
19     min
20 }
21
22 pub struct BigInt {
23     data: Vec<u64>, // least significant digits first. The last block will *not* be 0.
24 }
25
26 // Add with carry, returning the sum and the carry
27 fn overflowing_add(a: u64, b: u64, carry: bool) -> (u64, bool) {
28     let sum = u64::wrapping_add(a, b);
29     let carry_n = if carry { 1 } else { 0 };
30     if sum >= a { // the first sum did not overflow
31         let sum_total = u64::wrapping_add(sum, carry_n);
32         let had_overflow = sum_total < sum;
33         (sum_total, had_overflow)
34     } else { // the first sum did overflow
35         // it is impossible for this to overflow again, as we are just adding 0 or 1
36         (sum + carry_n, true)
37     }
38 }
39
40 impl BigInt {
41     /// Construct a BigInt from a "small" one.
42     pub fn new(x: u64) -> Self {
43         if x == 0 { // take care of our invariant!
44             BigInt { data: vec![] }
45         } else {
46             BigInt { data: vec![x] }
47         }
48     }
49
50     fn test_invariant(&self) -> bool {
51         if self.data.len() == 0 {
52             true
53         } else {
54             self.data[self.data.len() - 1] != 0
55         }
56     }
57
58     /// Construct a BigInt from a vector of 64-bit "digits", with the last significant digit being first. Solution to 05.1.
59     pub fn from_vec(mut v: Vec<u64>) -> Self {
60         // remove trailing zeros
61         while v.len() > 0 && v[v.len()-1] == 0 {
62             v.pop();
63         }
64         BigInt { data: v }
65     }
66
67     /// Increments the number by 1.
68     pub fn inc1(&mut self) {
69         let mut idx = 0;
70         // This loop adds "(1 << idx)". If there is no more carry, we leave.
71         while idx < self.data.len() {
72             let cur = self.data[idx];
73             let sum = u64::wrapping_add(cur, 1);
74             self.data[idx] = sum;
75             if sum >= cur {
76                 // No overflow, we are done.
77                 return;
78             } else {
79                 // We need to go on.
80                 idx += 1;
81             }
82         }
83         // If we came here, there is a last carry to add
84         self.data.push(1);
85     }
86
87     /// Increments the number by "by".
88     pub fn inc(&mut self, mut by: u64) {
89         let mut idx = 0;
90         // This loop adds "by * (1 << idx)". Think of "by" as the carry from incrementing the last digit.
91         while idx < self.data.len() {
92             let cur = self.data[idx];
93             let sum = u64::wrapping_add(cur, by);
94             self.data[idx] = sum;
95             if sum >= cur {
96                 // No overflow, we are done.
97                 return;
98             } else {
99                 // We need to add a carry.
100                 by = 1;
101                 idx += 1;
102             }
103         }
104         // If we came here, there is a last carry to add
105         self.data.push(by);
106     }
107
108     /// Return the nth power-of-2 as BigInt
109     pub fn power_of_2(mut power: u64) -> BigInt {
110         let mut v = Vec::new();
111         while power >= 64 {
112             v.push(0);
113             power -= 64;
114         }
115         v.push(1 << power);
116         BigInt::from_vec(v)
117     }
118 }
119
120 impl Clone for BigInt {
121     fn clone(&self) -> Self {
122         BigInt { data: self.data.clone() }
123     }
124 }
125
126 impl PartialEq for BigInt {
127     fn eq(&self, other: &BigInt) -> bool {
128         debug_assert!(self.test_invariant() && other.test_invariant());
129         self.data == other.data
130     }
131 }
132
133 impl Minimum for BigInt {
134     // This is essentially the solution to 06.1.
135     fn min<'a>(&'a self, other: &'a Self) -> &'a Self {
136         debug_assert!(self.test_invariant() && other.test_invariant());
137         if self.data.len() < other.data.len() {
138             self
139         } else if self.data.len() > other.data.len() {
140             other
141         } else {
142             // compare back-to-front, i.e., most significant digit first
143             let mut idx = self.data.len()-1;
144             while idx > 0 {
145                 if self.data[idx] < other.data[idx] {
146                     return self;
147                 } else if self.data[idx] > other.data[idx] {
148                     return other;
149                 }
150                 else {
151                     idx = idx-1;
152                 }
153             }
154             // the two are equal
155             return self;
156         }
157     }
158 }
159
160 impl fmt::Debug for BigInt {
161     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
162         self.data.fmt(f)
163     }
164 }
165
166 impl<'a, 'b> ops::Add<&'a BigInt> for &'b BigInt {
167     type Output = BigInt;
168     fn add(self, rhs: &'a BigInt) -> Self::Output {
169         let max_len = cmp::max(self.data.len(), rhs.data.len());
170         let mut result_vec:Vec<u64> = Vec::with_capacity(max_len);
171         let mut carry:bool = false; // the carry bit
172         for i in 0..max_len {
173             // compute next digit and carry
174             let lhs_val = if i < self.data.len() { self.data[i] } else { 0 };
175             let rhs_val = if i < rhs.data.len() { rhs.data[i] } else { 0 };
176             let (sum, new_carry) = overflowing_add(lhs_val, rhs_val, carry);
177             // store them
178             result_vec.push(sum);
179             carry = new_carry;
180         }
181         if carry {
182             result_vec.push(1);
183         }
184         // We know that the invariant holds: overflowing_add would only return (0, false) if
185         // the arguments are (0, 0, false), but we know that in the last iteration, one od the two digits
186         // is the last of its number and hence not 0.
187         BigInt { data: result_vec }
188     }
189 }
190
191 impl<'a> ops::Add<BigInt> for &'a BigInt {
192     type Output = BigInt;
193     #[inline]
194     fn add(self, rhs: BigInt) -> Self::Output {
195         self + &rhs
196     }
197 }
198
199 impl<'a> ops::Add<&'a BigInt> for BigInt {
200     type Output = BigInt;
201     #[inline]
202     fn add(self, rhs: &'a BigInt) -> Self::Output {
203         &self + rhs
204     }
205 }
206
207 impl ops::Add<BigInt> for BigInt {
208     type Output = BigInt;
209     #[inline]
210     fn add(self, rhs: BigInt) -> Self::Output {
211         &self + &rhs
212     }
213 }
214
215 #[cfg(test)]
216 mod tests {
217     use std::u64;
218     use super::overflowing_add;
219     use super::BigInt;
220
221     #[test]
222     fn test_overflowing_add() {
223         assert_eq!(overflowing_add(10, 100, false), (110, false));
224         assert_eq!(overflowing_add(10, 100, true), (111, false));
225         assert_eq!(overflowing_add(1 << 63, 1 << 63, false), (0, true));
226         assert_eq!(overflowing_add(1 << 63, 1 << 63, true), (1, true));
227         assert_eq!(overflowing_add(1 << 63, (1 << 63) -1 , true), (0, true));
228     }
229
230     #[test]
231     fn test_add() {
232         let b1 = BigInt::new(1 << 32);
233         let b2 = BigInt::from_vec(vec![0, 1]);
234
235         assert_eq!(&b1 + &b2, BigInt::from_vec(vec![1 << 32, 1]));
236     }
237
238     #[test]
239     fn test_inc1() {
240         let mut b = BigInt::new(0);
241         b.inc1();
242         assert_eq!(b, BigInt::new(1));
243         b.inc1();
244         assert_eq!(b, BigInt::new(2));
245
246         b = BigInt::new(u64::MAX);
247         b.inc1();
248         assert_eq!(b, BigInt::from_vec(vec![0, 1]));
249         b.inc1();
250         assert_eq!(b, BigInt::from_vec(vec![1, 1]));
251     }
252
253     #[test]
254     fn test_power_of_2() {
255         assert_eq!(BigInt::power_of_2(0), BigInt::new(1));
256         assert_eq!(BigInt::power_of_2(13), BigInt::new(1 << 13));
257         assert_eq!(BigInt::power_of_2(64), BigInt::from_vec(vec![0, 1]));
258         assert_eq!(BigInt::power_of_2(96), BigInt::from_vec(vec![0, 1 << 32]));
259         assert_eq!(BigInt::power_of_2(128), BigInt::from_vec(vec![0, 0, 1]));
260     }
261 }
262
263