cpp/biginteger/biginteger.h

639 lines
16 KiB
C
Raw Permalink Normal View History

2023-07-16 07:03:47 +00:00
#include <algorithm>
#include <complex>
#include <cstring>
#include <iomanip>
#include <sstream>
#include <vector>
namespace FastFourierTransform {
typedef double ld;
typedef std::complex<ld> cld;
typedef std::vector<cld> ComplexPolynom;
const ld PI = acosl(-1);
std::vector<int16_t> Muliplication(const std::vector<int16_t>& a,
const std::vector<int16_t>& b,
std::vector<int16_t>& result);
ComplexPolynom ToComplex(const std::vector<int16_t>& a);
void FastFourierTransform_(ComplexPolynom& a, bool invert);
std::vector<int16_t> Muliplication(const std::vector<int16_t>& a,
const std::vector<int16_t>& b,
std::vector<int16_t>& result) {
ComplexPolynom ca = ToComplex(a);
ComplexPolynom cb = ToComplex(b);
size_t n = std::max(a.size(), b.size());
size_t m = 1;
while (m < n) m <<= 1;
m <<= 1;
ca.resize(m);
cb.resize(m);
FastFourierTransform_(ca, false);
FastFourierTransform_(cb, false);
for (size_t i = 0; i < m; ++i) {
ca[i] *= cb[i];
}
FastFourierTransform_(ca, true);
std::vector<int64_t> v(m);
for (size_t i = 0; i < m; ++i) {
v[i] = ca[i].real() + 0.5;
}
for (size_t i = 0; i < m; ++i) {
v[i + 1] += v[i] / 10000;
v[i] %= 10000;
while (v[i] < 0) v[i] += 10000, v[i + 1] -= 1;
}
result.resize(v.size());
std::copy(v.begin(), v.end(), result.begin());
return result;
}
ComplexPolynom ToComplex(const std::vector<int16_t>& a) {
ComplexPolynom res(a.begin(), a.end());
return res;
}
size_t rev(size_t num, size_t lg_n) {
int res = 0;
for (size_t i = 0; i < lg_n; ++i)
if (num & (1 << i)) res |= 1 << (lg_n - 1 - i);
return res;
}
void FastFourierTransform_(ComplexPolynom& a, bool invert) {
size_t n = a.size();
size_t lg_n = 0;
while ((1u << lg_n) < n) ++lg_n;
for (size_t i = 0; i < n; ++i)
if (i < rev(i, lg_n)) swap(a[i], a[rev(i, lg_n)]);
for (size_t len = 2; len <= n; len <<= 1) {
double ang = 2 * PI / len * (invert ? -1 : 1);
cld wlen(cos(ang), sin(ang));
for (size_t i = 0; i < n; i += len) {
cld w(1);
for (size_t j = 0; j < len / 2; ++j) {
cld u = a[i + j], v = a[i + j + len / 2] * w;
a[i + j] = u + v;
a[i + j + len / 2] = u - v;
w *= wlen;
}
}
}
if (invert)
for (size_t i = 0; i < n; ++i) a[i] /= n;
}
}; // namespace FastFourierTransform
class BigInteger;
BigInteger operator-(const BigInteger& a, const BigInteger& b);
BigInteger operator*(const BigInteger& a, const BigInteger& b);
BigInteger operator/(const BigInteger& a, const BigInteger& b);
BigInteger operator+(const BigInteger& a, const BigInteger& b);
bool operator>(const BigInteger& a, const BigInteger& b);
bool operator<=(const BigInteger& a, const BigInteger& b);
bool operator>=(const BigInteger& a, const BigInteger& b);
bool operator==(const BigInteger& a, const BigInteger& b);
bool operator!=(const BigInteger& a, const BigInteger& b);
BigInteger operator%(const BigInteger& a, const BigInteger& b);
class BigInteger {
public:
BigInteger() {}
BigInteger(int64_t x) {
if (x < 0) {
is_negative_ = true, x *= -1;
}
t_[0] = x % BASE;
x /= BASE;
while (x > 0) {
t_.push_back(x % BASE);
x /= BASE;
}
Normalize();
}
BigInteger(const BigInteger& x) : t_(x.t_), is_negative_(x.is_negative_) {}
void swap(BigInteger& x) {
std::swap(t_, x.t_);
std::swap(is_negative_, x.is_negative_);
}
BigInteger& operator=(const BigInteger& x) {
BigInteger tmp = x;
swap(tmp);
return *this;
}
friend bool operator<(const BigInteger& a, const BigInteger& b) {
if (a.is_negative_ ^ b.is_negative_) {
return a.is_negative_;
}
if (a.t_.size() != b.t_.size()) {
return (a.t_.size() < b.t_.size()) ^ a.is_negative_;
}
for (size_t i = a.t_.size(); i > 0; --i) {
if (a.t_[i - 1] != b.t_[i - 1]) {
return (a.t_[i - 1] < b.t_[i - 1]) ^ a.is_negative_;
}
}
return a.is_negative_;
}
BigInteger& operator++() {
*this += 1;
Normalize();
return *this;
}
BigInteger operator++(int) {
BigInteger tmp = *this;
++(*this);
return tmp;
}
void addingDifferentSigns(const BigInteger& x) {
int16_t carry[2] = {0, 0};
// Вычитаем, если *this < x, то последнее число будет отрицательным и
// потом это пофиксим
for (size_t i = 0; i < x.t_.size() || (carry[i & 1] && i < t_.size());
++i) {
if (i < x.t_.size()) {
t_[i] -= x.t_[i];
}
t_[i] += carry[i & 1];
if (t_[i] < 0 && i + 1 < t_.size()) {
t_[i] += BASE;
carry[(i & 1) ^ 1] = -1;
} else {
carry[(i & 1) ^ 1] = 0;
}
}
// Убераем нули в начале и начинаем фиксить отрицательное число
Normalize();
if (t_.back() < 0) {
t_.pop_back();
for (size_t i = 0; i < t_.size(); ++i) {
t_[i] = BASE - 1 - t_[i];
}
++t_[0];
while (t_.size() > 1 && t_.back() == 0) {
t_.pop_back();
}
is_negative_ ^= true;
}
// Когда мы пофиксили отрицательные числа, у нас в ячейках могли
// образоваться числа >= BASE, фиксим это
carry[0] = 0;
carry[1] = 0;
if (t_[0] >= BASE) {
carry[1] = 1;
t_[0] -= BASE;
}
for (size_t i = 1; i < t_.size() && carry[i & 1]; ++i) {
t_[i] += carry[i & 1];
carry[i & 1] = 0;
if (t_[i] >= BASE) {
carry[(i & 1) ^ 1] = 1;
t_[i] -= BASE;
}
}
// У нас может произойти перенос в ячейку, которой не существует, добавим
// ее)
if (carry[t_.size() & 1]) {
t_.push_back(1);
}
}
BigInteger& operator+=(const BigInteger& x) {
t_.resize(std::max(t_.size(), x.t_.size()) + 1);
// Если числа одинаковых знаков - просто сложение
if (!is_negative_ ^ x.is_negative_) {
int16_t carry[2] = {0, 0};
for (size_t i = 0; i < x.t_.size() || carry[i & 1]; ++i) {
if (i < x.t_.size()) {
t_[i] += x.t_[i];
}
t_[i] += carry[i & 1];
if (t_[i] >= BASE) {
t_[i] -= BASE;
carry[(i & 1) ^ 1] = 1;
} else {
carry[(i & 1) ^ 1] = 0;
}
}
Normalize();
} else {
addingDifferentSigns(x);
}
return *this;
}
BigInteger& operator-=(const BigInteger& x) {
// a - b = -(-a + b)
is_negative_ ^= 1;
*this += x;
is_negative_ ^= 1;
Normalize();
return *this;
}
BigInteger& operator*=(int16_t x) {
// Это умножение на маленькое число, работает за O(n) и используется в
// делении
if (x < 0) {
is_negative_ = !is_negative_;
x *= -1;
}
if (x >= BASE) {
return *this *= BigInteger(x);
}
t_.resize(t_.size() + 3);
int16_t carry[2] = {0, 0};
for (size_t i = 0; i < t_.size(); ++i) {
carry[!(i & 1)] = (int(t_[i]) * x + carry[i & 1]) / BASE;
t_[i] = (int(t_[i]) * x + carry[i & 1]) % BASE;
}
while (t_.size() > 1 && t_.back() == 0) t_.pop_back();
Normalize();
return *this;
}
BigInteger& operator*=(const BigInteger& x) {
// Если число на которое нам надо умножить небольшое, выгоднее умножать за
// квадрат
if (x.t_.size() < 100) {
return SquareMultiplication(x);
}
FastFourierTransform::Muliplication(t_, x.t_, t_);
is_negative_ ^= x.is_negative_;
Normalize();
return *this;
}
// Деление работает за O(N^2), мы вычисляем все значащие биты в делении,
// которых O(n), с помощью бинпоиска(константа, так как log 10000), и
// проверкой в бинпоиске за O(n), потому что используется умножение за O(n), и
// в итоге O(N^2)
BigInteger& operator/=(const BigInteger& x) {
bool is_negative = is_negative_ ^ x.is_negative_;
is_negative_ = false;
std::vector<int16_t> res(t_.size() - std::min(x.t_.size(), t_.size()) + 4);
int16_t d = 1;
if (x < 0) {
d = -1;
}
for (int i = res.size(); i > 0; --i) {
int l = 0, r = BASE - 1, m;
while (l != r) {
m = (l + r + 1) / 2;
if ((x * d * m).MuliplicationDegree10(i - 1) <= *this) {
l = m;
} else {
r = m - 1;
}
}
res[i - 1] = l;
*this -= (x * d * res[i - 1]).MuliplicationDegree10(i - 1);
}
t_ = res;
is_negative_ = is_negative;
Normalize();
return *this;
}
BigInteger& operator%=(const BigInteger& x) {
*this -= (*this / x) * x;
Normalize();
return *this;
}
std::string toString() const {
std::ostringstream string;
if (is_negative_) {
string << "-";
}
string << t_.back();
for (auto it = ++t_.rbegin(); it != t_.rend(); ++it) {
string << std::setw(4) << std::setfill('0') << *it;
}
return string.str();
}
explicit operator bool() const { return t_.size() > 1 || t_[0]; }
friend std::istream& operator>>(std::istream& in, BigInteger& x) {
std::string s;
in >> s;
std::reverse(s.begin(), s.end());
if (s.back() == '-') {
x.is_negative_ = true;
s.pop_back();
} else {
x.is_negative_ = false;
}
while (s.size() % 4) {
s.push_back('0');
}
x.t_.resize(s.size() / 4);
for (size_t i = 0; i < s.size(); i += 4) {
x.t_[i / 4] = s[i] - '0' + 10 * (s[i + 1] - '0') +
100 * (s[i + 2] - '0') + 1000 * (s[i + 3] - '0');
}
x.Normalize();
return in;
}
friend std::ostream& operator<<(std::ostream& out, const BigInteger& x) {
out << x.toString();
return out;
}
BigInteger operator-() {
BigInteger tmp(*this);
tmp.is_negative_ ^= true;
tmp.Normalize();
return tmp;
}
private:
void Normalize() {
while (t_.size() > 1 && t_.back() == 0) {
t_.pop_back();
}
if (t_.size() == 1 && t_[0] == 0) {
is_negative_ = false;
}
}
// Это умножение на 10^degree
BigInteger& MuliplicationDegree10(int degree) {
size_t n = t_.size();
t_.resize(t_.size() + degree);
std::rotate(t_.begin(), t_.begin() + n, t_.end());
Normalize();
return *this;
}
// Это используется, если число на которое мы умножаем не очень большое,
// потому что FFT работает достаточно медленно
BigInteger& SquareMultiplication(const BigInteger& x) {
std::vector<uint64_t> t(t_.size() + x.t_.size() + 1);
for (size_t i = 0; i < t_.size(); ++i) {
for (size_t j = 0; j < x.t_.size(); ++j) {
t[i + j] += static_cast<uint64_t>(t_[i]) * x.t_[j];
}
}
t_.resize(t_.size() + x.t_.size() + 1);
for (size_t i = 0; i < t.size(); ++i) {
t[i + 1] += t[i] / BASE;
t_[i] = t[i] % BASE;
}
while (t_.size() > 1 && t_.back() == 0) t_.pop_back();
is_negative_ ^= x.is_negative_;
Normalize();
return *this;
}
static constexpr uint16_t BASE = 10000;
std::vector<int16_t> t_ = {0};
bool is_negative_ = false;
};
BigInteger operator-(const BigInteger& a, const BigInteger& b) {
return BigInteger(a) -= b;
}
BigInteger operator*(const BigInteger& a, const BigInteger& b) {
return BigInteger(a) *= b;
}
BigInteger operator/(const BigInteger& a, const BigInteger& b) {
BigInteger tmp(a);
tmp /= b;
return tmp;
}
BigInteger operator%(const BigInteger& a, const BigInteger& b) {
BigInteger tmp(a);
tmp %= b;
return tmp;
}
BigInteger operator+(const BigInteger& a, const BigInteger& b) {
return BigInteger(a) += b;
}
bool operator>(const BigInteger& a, const BigInteger& b) { return b < a; }
bool operator<=(const BigInteger& a, const BigInteger& b) { return !(a > b); }
bool operator>=(const BigInteger& a, const BigInteger& b) { return b <= a; }
bool operator==(const BigInteger& a, const BigInteger& b) {
return !(a < b) && !(b < a);
}
bool operator!=(const BigInteger& a, const BigInteger& b) { return !(a == b); }
class Rational;
Rational operator+(const Rational& a, const Rational& b);
Rational operator-(const Rational& a, const Rational& b);
Rational operator*(const Rational& a, const Rational& b);
Rational operator/(const Rational& a, const Rational& b);
bool operator>(const Rational& a, const Rational& b);
bool operator<=(const Rational& a, const Rational& b);
bool operator>=(const Rational& a, const Rational& b);
bool operator==(const Rational& a, const Rational& b);
bool operator!=(const Rational& a, const Rational& b);
class Rational {
public:
Rational() = default;
Rational(const BigInteger& a) { numerator_ = a; }
Rational(int a) { numerator_ = a; }
Rational(const Rational& x)
: numerator_(x.numerator_), denominator_(x.denominator_) {}
void swap(Rational& x) {
std::swap(numerator_, x.numerator_);
std::swap(denominator_, x.denominator_);
}
Rational& operator=(const Rational& x) {
Rational tmp(x);
swap(tmp);
return *this;
}
Rational& operator+=(const Rational& a) {
if (this == &a) {
return *this += Rational(a);
}
numerator_ *= a.denominator_;
numerator_ += a.numerator_ * denominator_;
denominator_ *= a.denominator_;
Normalize();
return *this;
}
Rational& operator-=(const Rational& a) {
if (this == &a) {
return *this -= Rational(a);
}
numerator_ *= a.denominator_;
numerator_ -= a.numerator_ * denominator_;
denominator_ *= a.denominator_;
Normalize();
return *this;
}
Rational& operator*=(const Rational& a) {
numerator_ *= a.numerator_;
denominator_ *= a.denominator_;
Normalize();
return *this;
}
Rational& operator/=(const Rational& a) {
if (this == &a) {
return *this /= Rational(a);
}
numerator_ *= a.denominator_;
denominator_ *= a.numerator_;
Normalize();
return *this;
}
Rational operator-() const {
Rational tmp(*this);
tmp.numerator_ *= -1;
return tmp;
}
bool operator<(const Rational& a) const {
return numerator_ * a.denominator_ < denominator_ * a.numerator_;
}
std::string toString() const {
std::ostringstream str;
str << numerator_;
if (denominator_ != 1) {
str << "/" << denominator_;
}
return str.str();
}
std::string asDecimal(size_t precision) {
BigInteger t = 1;
for (size_t i = 0; i < precision; ++i) {
t *= 10;
}
std::ostringstream str;
BigInteger result = (numerator_ * t) / denominator_;
if (result < 0) {
str << "-", result *= -1;
}
str << result / t << "." << std::setw(precision) << std::setfill('0')
<< result % t;
return str.str();
}
explicit operator double() { return 1.0; }
BigInteger gcd(const BigInteger& a, const BigInteger& b) {
return b ? gcd(b, a % b) : a;
}
void Normalize() {
if (numerator_ == 0) {
denominator_ = 1;
return;
}
BigInteger d = gcd(numerator_, denominator_);
numerator_ /= d;
denominator_ /= d;
if (denominator_ < 0) {
numerator_ *= -1;
denominator_ *= -1;
}
}
BigInteger numerator_ = 0;
BigInteger denominator_ = 1;
};
Rational operator+(const Rational& a, const Rational& b) {
Rational tmp(a);
tmp += b;
return tmp;
}
Rational operator-(const Rational& a, const Rational& b) {
Rational tmp(a);
tmp -= b;
return tmp;
}
Rational operator*(const Rational& a, const Rational& b) {
Rational tmp(a);
tmp *= b;
return tmp;
}
Rational operator/(const Rational& a, const Rational& b) {
Rational tmp(a);
tmp /= b;
return tmp;
}
bool operator>(const Rational& a, const Rational& b) { return b < a; }
bool operator<=(const Rational& a, const Rational& b) {
return a < b || a == b;
}
bool operator>=(const Rational& a, const Rational& b) {
return a > b || a == b;
}
bool operator==(const Rational& a, const Rational& b) {
return !(b < a) && !(a < b);
}
bool operator!=(const Rational& a, const Rational& b) { return !(a == b); }