Consider the function below, which converts the result of a * b into a pair of numbers i and j, where:
- a, b, x, y are int (suppose they are always >= 32 bits long)
- a and b are <= n*m, where n = 10^3 and m = 10^5; n*m = BASE
- a * b can be written as i*BASE + j
How would you calculate j without using any type larger than int (and, if needed, being careful about int overflow, which is UB)?
#include <iostream>
#include <cstdlib>
using namespace std;
int n = 1000, m = 100000;
struct N {
    int i, j;
};
N f(int a, int b) {
    N x;
    int a0, a1, b0, b1, o;
    a1 = a / n;
    a0 = a - (a1 * n); // a0 = a % n
    b1 = b / m;
    b0 = b - (b1 * m); // b0 = b % m
    o = a1 * b1 + (a0 * b1) / n + (b0 * a1) / m;
    x.i = o;
    x.j = 0; // CALCULATE J WITH INTs MATH
    return x;
}
int main(int argc, char* argv[]) {
    if (argc < 3) { // both operands are required on the command line
        cerr << "usage: " << argv[0] << " a b" << endl;
        return 1;
    }
    int a = atoi(argv[1]),
        b = atoi(argv[2]);
    N x = f(a, b);
    cout << a << " * " << b << " = " << x.i << "*" << n*m
         << " + " << x.j << endl;
    cout << "which is: " << (long long)a * b << endl;
    return 0;
}
You started correctly, but lost the plot around the calculation of o. First, my assumptions: you don't want to deal with any integer greater than n*m, so taking mod n*m is cheating. I say this because, given m > 2^16, I have to assume int is 32 bits long, which can hold numbers up to n*m without overflowing.
In any case, you have correctly (I guess, since the purpose of n and m is not specified) written:
a = a0 + a1*n (a0 < n)
b = b0 + b1*m (b0 < m)
So, if we do the math:
a*b = a0*b0 + a0*b1*m + a1*b0*n + a1*b1*n*m
Here, a0*b0 < n*m, so it is part of j, and a1*b1*n*m is a multiple of n*m, so a1*b1 is part of i. It is the other two terms that you need to split in two again. But you cannot calculate each one and take it mod n*m, since that would be cheating (as per my rule above). If you write:
a0*b1 = a0b1_0 + a0b1_1*n
You get:
a0*b1*m = a0b1_0*m + a0b1_1*n*m
Since a0b1_0 < n, a0b1_0*m < n*m, which means this part goes to j. Obviously, a0b1_1 goes to i.
Repeat similar logic for a1*b0, this time splitting it as a1b0_0 + a1b0_1*m (so a1b0_0*n < n*m goes to j and a1b0_1 goes to i), and you've got three terms to add up for j and three more to add up for i.
EDIT: Forgot to mention a few things:
- You need the constraints a < n^2 and b < m^2 for this to work. Otherwise, you need more ai "words", e.g. a = a0 + a1*n + a2*n^2, with ai < n.
- The final sum for j may be greater than n*m. You need to watch for overflow (n*m - j <= addend, or similar logic) and add 1 to i when this happens, while calculating j + addend - n*m without overflow.
I think the answer is j = a0 * b0 (here / denotes exact division, not integer division):
(a*b)/(n*m) = (a/n) * (b/m)
= (a1 + a0/n) * (b1 + b0/m)
= a1*b1 + a1*b0/m + a0*b1/n + (a0*b0)/(n*m)
Now let
o = a1*b1 + a1*b0/m + a0*b1/n
and multiply both sides by n*m:
a * b = o * n*m + a0*b0
Since n*m is BASE:
a * b = o * BASE + a0*b0
j = a0*b0
QED