ps:problems:boj:25974
거듭제곱의 합 1
ps | |
---|---|
링크 | acmicpc.net/… |
출처 | BOJ |
문제 번호 | 25974 |
문제명 | 거듭제곱의 합 1 |
레벨 | 플래티넘 2 |
분류 |
수학 |
시간복잡도 | O(plogp) |
인풋사이즈 | p<=1000 |
사용한 언어 | Python 3.11 |
제출기록 | 31256KB / 44ms |
최고기록 | 44ms |
해결날짜 | 2023/02/09 |
풀이
* 거듭제곱의 합을 계산하는 문제.
코드
"""Solution code for "BOJ 25974. 거듭제곱의 합 1".
- Problem link: https://www.acmicpc.net/problem/25974
- Solution link: http://www.teferi.net/ps/problems/boj/25974
Tags: [Lagrangian interpolation]
"""
MOD = 10**9 + 7
def multiple_mod_inv(nums, mod):
a = list(nums)
b = [v := 1] + [v := v * x % mod for x in reversed(a)]
b_inv = pow(b.pop(), -1, mod)
return [b_inv * b.pop() % mod] + [
(b_inv := b_inv * a_ % mod) * b_ % mod for a_, b_ in zip(a, reversed(b))
]
def lagrangian_interpolation(y, n, prime_mod):
"""Finds k-th order func f(x) from y=[f(0), ..., f(k)], and returns f(n)."""
l = len(y)
if n < l:
return y[n]
invs = multiple_mod_inv(range(n, n - l, -1), prime_mod)
factorials = [v := 1] + [v := v * i % prime_mod for i in range(1, l)]
finv = multiple_mod_inv(factorials, prime_mod)
answer = 0
sign = 1 if l % 2 else -1
for inv_i, finv_i, finv_j, y_i in zip(invs, finv, reversed(finv), y):
answer += sign * inv_i * finv_i * finv_j * y_i
sign = -sign
for i in range(n - l + 1, n + 1):
answer = answer * i % prime_mod
return answer
def sum_of_powers(n, k, prime_mod):
"""Returns (1^k + 2^k + ... + n^k) % prime_mod."""
y = [v := 0] + [v := v + pow(i, k, prime_mod) for i in range(1, k + 2)]
return lagrangian_interpolation(y, n, prime_mod)
def main():
n, p = [int(x) for x in input().split()]
print(sum_of_powers(n, p, MOD))
if __name__ == '__main__':
main()
ps/problems/boj/25974.txt · 마지막으로 수정됨: 2023/02/09 16:31 저자 teferi
토론