def quadratic_residue(p):
    """Return an ``x`` in ``[0, p)`` satisfying ``x * x % p == p - 1``.

    Candidates are ``pow(a, (p - 1) // 4, p)`` for ``a = 2, 3, ..., p - 2``,
    tried in increasing order of ``a``; the first one whose square is
    congruent to -1 modulo ``p`` is returned.

    Presumably meant for primes ``p`` congruent to 1 modulo 4 -- confirm
    against callers.  If no base in the range qualifies (including when the
    range is empty), ``StopIteration`` propagates out of ``next``.
    """
    exponent = (p - 1) // 4
    minus_one = p - 1
    # Lazy stream of candidates so the search stops at the first hit.
    candidates = (pow(base, exponent, p) for base in range(2, p - 1))
    return next(root for root in candidates if root * root % p == minus_one)
def gcd(a: "GaussianInt", b: "GaussianInt") -> "GaussianInt":
    """Return a greatest common divisor of ``a`` and ``b`` by Euclid's algorithm.

    Repeatedly replaces ``(a, b)`` with ``(b, a % b)`` until ``b`` has both
    ``real`` and ``imag`` equal to zero, then returns ``a``.  The result is
    whatever associate the remainder sequence happens to end on; it is not
    normalised.

    The annotations are written as strings (forward references) because
    ``GaussianInt`` is defined further down the file; an unquoted name would
    be evaluated when this ``def`` executes and raise ``NameError``.
    """
    while b.real != 0 or b.imag != 0:
        a, b = b, a % b
    return a
class GaussianInt:
    """Gaussian integer ``real + imag*i``.

    Supports subtraction, multiplication, division rounded to the nearest
    Gaussian integer (``//``) and the matching remainder (``%``), which is
    what a Euclidean algorithm needs.
    """

    __slots__ = ('real', 'imag')

    def __init__(self, real, imag=0):
        # real: real component; imag: imaginary component (defaults to 0).
        self.real = real
        self.imag = imag

    def __repr__(self):
        return f"GaussianInt({self.real!r}, {self.imag!r})"

    def __sub__(self, other):
        """Return the component-wise difference ``self - other``."""
        return GaussianInt(self.real - other.real, self.imag - other.imag)

    def __mul__(self, other):
        """Return the complex product ``self * other``."""
        return GaussianInt(
            self.real * other.real - self.imag * other.imag,
            self.real * other.imag + self.imag * other.real,
        )

    @staticmethod
    def _round_div(numerator, denominator):
        """Return ``numerator / denominator`` rounded to the nearest integer.

        Ties go to the even neighbour, matching built-in ``round``.  Uses only
        integer arithmetic, so the result is exact for arbitrarily large ints
        (``round(numerator / denominator)`` goes through a float and loses
        precision beyond 2**53).  ``denominator`` must be positive; zero
        raises ``ZeroDivisionError``.
        """
        quotient, remainder = divmod(numerator, denominator)
        doubled = 2 * remainder
        if doubled > denominator or (doubled == denominator and quotient % 2 == 1):
            quotient += 1
        return quotient

    def __floordiv__(self, other):
        """Return ``self / other`` with each component rounded to nearest.

        The exact quotient is ``self * conj(other) / norm(other)``; both
        components are rounded independently.  Raises ``ZeroDivisionError``
        when ``other`` is zero.
        """
        norm = other.real * other.real + other.imag * other.imag
        real_num = self.real * other.real + self.imag * other.imag
        imag_num = self.imag * other.real - self.real * other.imag
        return GaussianInt(
            self._round_div(real_num, norm),
            self._round_div(imag_num, norm),
        )

    def __mod__(self, other):
        """Return the remainder ``self - (self // other) * other``."""
        return self - self // other * other
def sum_of_two_squares_if_possible(n) -> tuple[int,int] | None:
    """Return ``(a, b)`` with ``a*a + b*b == n``, or ``None`` if impossible.

    Builds a Gaussian integer of norm ``n`` as a product of factors:
    the integer square root of ``n`` divided by its odd-exponent primes,
    ``1 + i`` for the prime 2, and for every other odd-exponent prime ``p``
    the value ``gcd(p, x + i)`` where ``x = quadratic_residue(p)``.

    ``None`` is returned when some prime with odd exponent is congruent to
    3 modulo 4.  Assumes ``numtheory.prime_factorization`` returns a mapping
    from prime to exponent -- confirm against that module.
    """
    factorization = numtheory.prime_factorization(n)
    odd_primes = [
        prime for prime, exponent in factorization.items() if exponent % 2 == 1
    ]
    for prime in odd_primes:
        if prime % 4 == 3:
            return None

    # What remains after removing one copy of each odd-exponent prime.
    square_part = n // math.prod(odd_primes)
    result = GaussianInt(math.isqrt(square_part))
    for prime in odd_primes:
        if prime == 2:
            factor = GaussianInt(1, 1)
        else:
            root = quadratic_residue(prime)
            factor = gcd(GaussianInt(prime), GaussianInt(root, 1))
        result = result * factor
    return (abs(result.real), abs(result.imag))
def hermite_serret(p):
    """Return ``(c, d)``, the first two Euclidean remainders below ``sqrt(p)``.

    Runs the Euclidean algorithm on ``p`` and ``quadratic_residue(p)`` and
    returns the first two remainders ``r`` with ``r * r < p``, larger first.
    For a prime ``p`` congruent to 1 modulo 4 this is the Hermite-Serret
    construction of ``c*c + d*d == p``.  ``p == 2`` is answered directly
    with ``(1, 1)``.

    Returns ``None`` if the remainder sequence reaches zero before two such
    remainders are seen (presumably only for inputs that are not 2 or a
    prime congruent to 1 modulo 4 -- confirm against callers).
    """
    if p == 2:
        return (1, 1)
    larger, smaller = p, quadratic_residue(p)
    first = None
    while smaller > 0:
        # Exact integer form of ``smaller < sqrt(p)``.  A float ``math.sqrt``
        # rounds for large p (and overflows for very large ints), which can
        # misclassify a remainder that lies next to the true square root.
        if smaller * smaller < p:
            if first is None:
                first = smaller
            else:
                return (first, smaller)
        larger, smaller = smaller, larger % smaller
    return None
def sum_of_two_squares_if_possible(n):
    """Return ``(a, b)`` with ``a*a + b*b == n``, or ``None`` if impossible.

    Starts from ``(isqrt(n // prod(odd-exponent primes)), 0)`` and, for each
    prime with odd exponent, combines the running pair with the pair from
    ``hermite_serret(prime)`` using the two-square product identity
    ``(a*c + b*d)**2 + (a*d - b*c)**2 == (a*a + b*b) * (c*c + d*d)``.

    ``None`` is returned when some prime with odd exponent is congruent to
    3 modulo 4.  Assumes ``prime_factorization`` returns a mapping from
    prime to exponent -- confirm against its definition.
    """
    factorization = prime_factorization(n)
    odd_primes = [
        prime for prime, exponent in factorization.items() if exponent % 2 == 1
    ]
    for prime in odd_primes:
        if prime % 4 == 3:
            return None

    x = math.isqrt(n // math.prod(odd_primes))
    y = 0
    for prime in odd_primes:
        u, v = hermite_serret(prime)
        x, y = x * u + y * v, x * v - y * u
    return (abs(x), abs(y))