e9dfdf0b9edfe9a3a627e7971e98eb2697ebeb5d
[rust-101.git] / solutions / src / bigint.rs
1 use std::ops;
2 use std::cmp;
3 use std::fmt;
4
5 pub struct BigInt {
6     data: Vec<u64>, // least significant digits first. The last block will *not* be 0.
7 }
8
9 // Add with carry, returning the sum and the carry
10 fn overflowing_add(a: u64, b: u64, carry: bool) -> (u64, bool) {
11     match u64::checked_add(a, b) {
12         Some(sum) if !carry => (sum, false),
13         Some(sum) => { // we have to increment the sum by 1, where it may overflow again
14             match u64::checked_add(sum, 1) {
15                 Some(total_sum) => (total_sum, false),
16                 None => (0, true) // we overflowed incrementing by 1, so we are just "at the edge"
17             }
18         },
19         None => {
20             // Get the remainder, i.e., the wrapping sum. This cannot overflow again by adding just 1, so it is safe
21             // to add the carry here.
22             let rem = u64::wrapping_add(a, b) + if carry { 1 } else { 0 };
23             (rem, true)
24         }
25     }
26 }
27
28
29 impl BigInt {
30     /// Construct a BigInt from a "small" one.
31     pub fn new(x: u64) -> Self {
32         if x == 0 { // take care of our invariant!
33             BigInt { data: vec![] }
34         } else {
35             BigInt { data: vec![x] }
36         }
37     }
38
39     fn test_invariant(&self) -> bool {
40         if self.data.len() == 0 {
41             true
42         } else {
43             self.data[self.data.len() - 1] != 0
44         }
45     }
46
47     /// Construct a BigInt from a vector of 64-bit "digits", with the last significant digit being first
48     pub fn from_vec(mut v: Vec<u64>) -> Self {
49         // remove trailing zeroes
50         while v.len() > 0 && v[v.len()-1] == 0 {
51             v.pop();
52         }
53         BigInt { data: v }
54     }
55
56     /// Return the smaller of the two numbers
57     pub fn min(self, other: Self) -> Self {
58         debug_assert!(self.test_invariant() && other.test_invariant());
59         if self.data.len() < other.data.len() {
60             self
61         } else if self.data.len() > other.data.len() {
62             other
63         } else {
64             // compare back-to-front, i.e., most significant digit first
65             let mut idx = self.data.len()-1;
66             while idx > 0 {
67                 if self.data[idx] < other.data[idx] {
68                     return self;
69                 } else if self.data[idx] > other.data[idx] {
70                     return other;
71                 }
72                 else {
73                     idx = idx-1;
74                 }
75             }
76             // the two are equal
77             return self;
78         }
79     }
80
81     /// Returns a view on the raw digits representing the number.
82     /// 
83     /// ```
84     /// use solutions::bigint::BigInt;
85     /// let b = BigInt::new(13);
86     /// let d = b.data();
87     /// assert_eq!(d, [13]);
88     /// ```
89     pub fn data(&self) -> &[u64] {
90         &self.data[..]
91     }
92
93     /// Increments the number by "by".
94     pub fn inc(&mut self, mut by: u64) {
95         let mut idx = 0;
96         // This loop adds "by * (1 << idx)". Think of "by" as the carry from incrementing the last digit.
97         while idx < self.data.len() {
98             let cur = self.data[idx];
99             let sum = u64::wrapping_add(cur, by);
100             self.data[idx] = sum;
101             if sum >= cur {
102                 // No overflow, we are done.
103                 return;
104             } else {
105                 // We need to add a carry.
106                 by = 1;
107                 idx += 1;
108             }
109         }
110         // If we came here, there is a last carry to add
111         self.data.push(by);
112     }
113
114     /// Return the nth power-of-2 as BigInt
115     pub fn power_of_2(mut power: u64) -> BigInt {
116         let mut v = Vec::new();
117         while power >= 64 {
118             v.push(0);
119             power -= 64;
120         }
121         v.push(1 << power);
122         BigInt::from_vec(v)
123     }
124 }
125
126 impl Clone for BigInt {
127     fn clone(&self) -> Self {
128         BigInt { data: self.data.clone() }
129     }
130 }
131
132
133 impl PartialEq for BigInt {
134     fn eq(&self, other: &BigInt) -> bool {
135         debug_assert!(self.test_invariant() && other.test_invariant());
136         self.data() == other.data()
137     }
138 }
139
140 impl fmt::Debug for BigInt {
141     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142         self.data().fmt(f)
143     }
144 }
145
146 impl<'a, 'b> ops::Add<&'a BigInt> for &'b BigInt {
147     type Output = BigInt;
148     fn add(self, rhs: &'a BigInt) -> Self::Output {
149         let mut result_vec:Vec<u64> = Vec::with_capacity(cmp::max(self.data().len(), rhs.data().len()));
150         let mut carry:bool = false; // the carry bit
151         for (i, val) in self.data().into_iter().enumerate() {
152             // compute next digit and carry
153             let rhs_val = if i < rhs.data().len() { rhs.data()[i] } else { 0 };
154             let (sum, new_carry) = overflowing_add(*val, rhs_val, carry);
155             // store them
156             result_vec.push(sum);
157             carry = new_carry;
158         }
159         BigInt::from_vec(result_vec)
160     }
161 }
162
163 impl<'a> ops::Add<BigInt> for &'a BigInt {
164     type Output = BigInt;
165     #[inline]
166     fn add(self, rhs: BigInt) -> Self::Output {
167         self + &rhs
168     }
169 }
170
171 impl<'a> ops::Add<&'a BigInt> for BigInt {
172     type Output = BigInt;
173     #[inline]
174     fn add(self, rhs: &'a BigInt) -> Self::Output {
175         &self + rhs
176     }
177 }
178
179 impl ops::Add<BigInt> for BigInt {
180     type Output = BigInt;
181     #[inline]
182     fn add(self, rhs: BigInt) -> Self::Output {
183         &self + &rhs
184     }
185 }
186
187 #[cfg(test)]
188 mod tests {
189     use super::overflowing_add;
190     use super::BigInt;
191
192     #[test]
193     fn test_overflowing_add() {
194         assert_eq!(overflowing_add(10, 100, false), (110, false));
195         assert_eq!(overflowing_add(10, 100, true), (111, false));
196         assert_eq!(overflowing_add(1 << 63, 1 << 63, false), (0, true));
197         assert_eq!(overflowing_add(1 << 63, 1 << 63, true), (1, true));
198         assert_eq!(overflowing_add(1 << 63, (1 << 63) -1 , true), (0, true));
199     }
200
201     #[test]
202     fn test_power_of_2() {
203         assert_eq!(BigInt::power_of_2(0), BigInt::new(1));
204         assert_eq!(BigInt::power_of_2(13), BigInt::new(1 << 13));
205         assert_eq!(BigInt::power_of_2(64), BigInt::from_vec(vec![0, 1]));
206         assert_eq!(BigInt::power_of_2(96), BigInt::from_vec(vec![0, 1 << 32]));
207         assert_eq!(BigInt::power_of_2(128), BigInt::from_vec(vec![0, 0, 1]));
208     }
209 }
210
211