start doing the course-part of BigInt
[rust-101.git] / solutions / src / bigint.rs
1 use std::ops;
2 use std::cmp;
3 use std::fmt;
4
5 pub trait Minimum {
6     /// Return the smaller of the two
7     fn min<'a>(&'a self, other: &'a Self) -> &'a Self;
8 }
9
10 /// Return a pointer to the minimal value of `v`.
11 pub fn vec_min<T: Minimum>(v: &Vec<T>) -> Option<&T> {
12     let mut min = None;
13     for e in v {
14         min = Some(match min {
15             None => e,
16             Some(n) => e.min(n)
17         });
18     }
19     min
20 }
21
22 pub struct BigInt {
23     data: Vec<u64>, // least significant digits first. The last block will *not* be 0.
24 }
25
26 // Add with carry, returning the sum and the carry
27 fn overflowing_add(a: u64, b: u64, carry: bool) -> (u64, bool) {
28     let sum = u64::wrapping_add(a, b);
29     let carry_n = if carry { 1 } else { 0 };
30     if sum >= a { // the first sum did not overflow
31         let sum_total = u64::wrapping_add(sum, carry_n);
32         let had_overflow = sum_total < sum;
33         (sum_total, had_overflow)
34     } else { // the first sum did overflow
35         // it is impossible for this to overflow again, as we are just adding 0 or 1
36         (sum + carry_n, true)
37     }
38 }
39
40 impl BigInt {
41     /// Construct a BigInt from a "small" one.
42     pub fn new(x: u64) -> Self {
43         if x == 0 { // take care of our invariant!
44             BigInt { data: vec![] }
45         } else {
46             BigInt { data: vec![x] }
47         }
48     }
49
50     fn test_invariant(&self) -> bool {
51         if self.data.len() == 0 {
52             true
53         } else {
54             self.data[self.data.len() - 1] != 0
55         }
56     }
57
58     /// Construct a BigInt from a vector of 64-bit "digits", with the last significant digit being first
59     pub fn from_vec(mut v: Vec<u64>) -> Self {
60         // remove trailing zeroes
61         while v.len() > 0 && v[v.len()-1] == 0 {
62             v.pop();
63         }
64         BigInt { data: v }
65     }
66
67     /// Increments the number by "by".
68     pub fn inc(&mut self, mut by: u64) {
69         let mut idx = 0;
70         // This loop adds "by * (1 << idx)". Think of "by" as the carry from incrementing the last digit.
71         while idx < self.data.len() {
72             let cur = self.data[idx];
73             let sum = u64::wrapping_add(cur, by);
74             self.data[idx] = sum;
75             if sum >= cur {
76                 // No overflow, we are done.
77                 return;
78             } else {
79                 // We need to add a carry.
80                 by = 1;
81                 idx += 1;
82             }
83         }
84         // If we came here, there is a last carry to add
85         self.data.push(by);
86     }
87
88     /// Return the nth power-of-2 as BigInt
89     pub fn power_of_2(mut power: u64) -> BigInt {
90         let mut v = Vec::new();
91         while power >= 64 {
92             v.push(0);
93             power -= 64;
94         }
95         v.push(1 << power);
96         BigInt::from_vec(v)
97     }
98 }
99
100 impl Clone for BigInt {
101     fn clone(&self) -> Self {
102         BigInt { data: self.data.clone() }
103     }
104 }
105
106 impl PartialEq for BigInt {
107     fn eq(&self, other: &BigInt) -> bool {
108         debug_assert!(self.test_invariant() && other.test_invariant());
109         self.data == other.data
110     }
111 }
112
113 impl Minimum for BigInt {
114     fn min<'a>(&'a self, other: &'a Self) -> &'a Self {
115         debug_assert!(self.test_invariant() && other.test_invariant());
116         if self.data.len() < other.data.len() {
117             self
118         } else if self.data.len() > other.data.len() {
119             other
120         } else {
121             // compare back-to-front, i.e., most significant digit first
122             let mut idx = self.data.len()-1;
123             while idx > 0 {
124                 if self.data[idx] < other.data[idx] {
125                     return self;
126                 } else if self.data[idx] > other.data[idx] {
127                     return other;
128                 }
129                 else {
130                     idx = idx-1;
131                 }
132             }
133             // the two are equal
134             return self;
135         }
136     }
137 }
138
139 impl fmt::Debug for BigInt {
140     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
141         self.data.fmt(f)
142     }
143 }
144
145 impl<'a, 'b> ops::Add<&'a BigInt> for &'b BigInt {
146     type Output = BigInt;
147     fn add(self, rhs: &'a BigInt) -> Self::Output {
148         let mut result_vec:Vec<u64> = Vec::with_capacity(cmp::max(self.data.len(), rhs.data.len()));
149         let mut carry:bool = false; // the carry bit
150         for (i, val) in (&self.data).into_iter().enumerate() {
151             // compute next digit and carry
152             let rhs_val = if i < rhs.data.len() { rhs.data[i] } else { 0 };
153             let (sum, new_carry) = overflowing_add(*val, rhs_val, carry);
154             // store them
155             result_vec.push(sum);
156             carry = new_carry;
157         }
158         if carry {
159             result_vec.push(1);
160         }
161         // We know that the invariant holds: overflowing_add would only return (0, false) if
162         // the arguments are (0, 0, false), but we know that in the last iteration, `val` is the
163         // last digit of `self` and hence not 0.
164         BigInt { data: result_vec }
165     }
166 }
167
168 impl<'a> ops::Add<BigInt> for &'a BigInt {
169     type Output = BigInt;
170     #[inline]
171     fn add(self, rhs: BigInt) -> Self::Output {
172         self + &rhs
173     }
174 }
175
176 impl<'a> ops::Add<&'a BigInt> for BigInt {
177     type Output = BigInt;
178     #[inline]
179     fn add(self, rhs: &'a BigInt) -> Self::Output {
180         &self + rhs
181     }
182 }
183
184 impl ops::Add<BigInt> for BigInt {
185     type Output = BigInt;
186     #[inline]
187     fn add(self, rhs: BigInt) -> Self::Output {
188         &self + &rhs
189     }
190 }
191
192 #[cfg(test)]
193 mod tests {
194     use super::overflowing_add;
195     use super::BigInt;
196
197     #[test]
198     fn test_overflowing_add() {
199         assert_eq!(overflowing_add(10, 100, false), (110, false));
200         assert_eq!(overflowing_add(10, 100, true), (111, false));
201         assert_eq!(overflowing_add(1 << 63, 1 << 63, false), (0, true));
202         assert_eq!(overflowing_add(1 << 63, 1 << 63, true), (1, true));
203         assert_eq!(overflowing_add(1 << 63, (1 << 63) -1 , true), (0, true));
204     }
205
206     #[test]
207     fn test_power_of_2() {
208         assert_eq!(BigInt::power_of_2(0), BigInt::new(1));
209         assert_eq!(BigInt::power_of_2(13), BigInt::new(1 << 13));
210         assert_eq!(BigInt::power_of_2(64), BigInt::from_vec(vec![0, 1]));
211         assert_eq!(BigInt::power_of_2(96), BigInt::from_vec(vec![0, 1 << 32]));
212         assert_eq!(BigInt::power_of_2(128), BigInt::from_vec(vec![0, 0, 1]));
213     }
214 }
215
216