[python刷题模板] 矩阵快速幂 (手写/numpy
- 一、 算法&数据结构
- 1. 描述
- 2. 复杂度分析
- 3. 常见应用
- 4. 常用优化
- 利用numpy库省去手写矩阵乘法的过程.
- 二、 模板代码
- 1. 斐波那契数列(错位写矩阵,手写矩阵乘法)
- 2. 1137. 第 N 个泰波那契数(错位写矩阵,手写矩阵乘法)
- 3. 1220. 统计元音字母序列的数目(状态机DP,用numpy)
- 4. 552. 学生出勤记录 II(2维状态机DP展开成1维,用numpy)
- 5. 2851. 字符串转换(KMP+矩阵快速幂)
- 三、其他
- 四、更多例题
- 五、参考链接
一、 算法&数据结构
1. 描述
矩阵快速幂是一种采用数学办法降低复杂度的操作。(其实就是把递推公式变成通项公式)
- 如果一个递推公式可以写成诸如正方形矩阵的形式,且每个相邻递推公式系数不变,那么可以提出系数变成指数写法.
- 记为: f[i] = m*f[i-1],且m不随i变化。
- 那么可以推出f[n] = mn * f[0]。
- 我们可以用快速幂计算
m^n
,这样本需要递推计算n次的操作,就降低到了log(n)次。
- 若f[i]每一项只有一项,那么就是上述计算过程,非常好理解,一下就推出了通项公式。
- 但若f[i][0/1/2…]这种二维矩阵怎么办呢?其实是一样的.
- 可以将f[i]展开写成一列的矩阵,形如:
- f[i][0] = c00 * f[i-1][0] + c01 * f[i-1][1] …
- f[i][1] = c10 * f[i-1][0] + c11 * f[i-1][1] …
- …
- 即:
- [ f [ i ] [ 0 ] f [ i ] [ 1 ] . . ] = [ c 00 c 01 . . c 10 c 11 . . . . ] ∗ [ f [ i − 1 ] [ 0 ] f [ i − 1 ] [ 1 ] . . ] \left[\begin {array}{c} f[i][0] \\ f[i][1] \\ .. \\ \end{array}\right] = \left[\begin {array}{c} c00 &c01 &.. \\ c10 &c11 &.. \\ .. \\ \end{array}\right] *\left[\begin {array}{c} f[i-1][0] \\ f[i-1][1] \\ .. \\ \end{array}\right] f[i][0]f[i][1].. = c00c10..c01c11.... ∗ f[i−1][0]f[i−1][1]..
- 可推出:
- [ f [ n ] [ 0 ] f [ n ] [ 1 ] . . ] = [ c 00 c 01 . . c 10 c 11 . . . . ] n ∗ [ f [ 0 ] [ 0 ] f [ 0 ] [ 1 ] . . ] \left[\begin {array}{c} f[n][0] \\ f[n][1] \\ .. \\ \end{array}\right] = \left[\begin {array}{c} c00 &c01 &.. \\ c10 &c11 &.. \\ .. \\ \end{array}\right]^n * \left[\begin {array}{c} f[0][0] \\ f[0][1] \\ .. \\ \end{array}\right] f[n][0]f[n][1].. = c00c10..c01c11.... n∗ f[0][0]f[0][1]..
- 于是,我们只需要知道f[0]和矩阵m,就相当于知道了通项公式。
- 要注意的是,这需要
保证
一件事,就是上边提到的相邻递推公式系数不变。 - 换句话说,在递推过程中,每层的转移方法是固定的,不随着当前层的
值
/下标
等因素发生if
等特殊改动(这通常是状态机)。 - 这才能保证可以提出系数,而且需要矩阵是正方形,才能做平方操作。
- 因此这种题可能首先要写普通dp,列出转移方程,再优化。
- 注意,手写快速幂矩阵时,np.eye()(对角线是1)的矩阵,相当于数字里的1。
- 另外,推完f[n],通常要考虑最终答案是什么,可能是fn[i],也可能是sum(fn),等。
2. 复杂度分析
- 把本来 O(n) 的递推操作,优化成了 O(logn) 的数学计算。
3. 常见应用
- 状态转移只依赖上层的线性DP(通常是状态机)。
- 状态转移只依赖上2/3…层的线性DP,采用错位写法。
4. 常用优化
import numpy as np
if n == 1:return 5
m = np.mat([[0, 1, 1, 0, 1],[1, 0, 1, 0, 0],[0, 1, 0, 1, 0],[0, 0, 1, 0, 0],[0, 0, 1, 1, 0]
])
f0 = np.mat([[1],[1],[1],[1],[1],
])
n -= 1
while n:if n & 1:f0 = m * f0 % MODm = m * m % MODn >>= 1
return int(f0.sum()) % MOD
二、 模板代码
1. 斐波那契数列(错位写矩阵,手写矩阵乘法)
例题: 509. 斐波那契数
def matrix_multiply(a, b, MOD=10 ** 9 + 7):m, n, p = len(a), len(a[0]), len(b[0])ans = [[0] * p for _ in range(m)]for i in range(m):for j in range(n):for k in range(p):ans[i][k] = (ans[i][k] + a[i][j] * b[j][k])return ansdef matrix_pow_mod(a, b, MOD=10 ** 9 + 7):n = len(a)ans = [[0] * n for _ in range(n)]for i in range(n):ans[i][i] = 1while b:if b & 1:ans = matrix_multiply(ans, a, MOD)a = matrix_multiply(a, a, MOD)b >>= 1return ansclass Solution:def fib(self, n: int) -> int:if n == 0:return 0m = [[1, 1], [1, 0]]return matrix_pow_mod(m, n - 1)[0][0]
2. 1137. 第 N 个泰波那契数(错位写矩阵,手写矩阵乘法)
链接: 1137. 第 N 个泰波那契数
def matrix_multiply(a, b, MOD=10 ** 9 + 7):m, n, p = len(a), len(a[0]), len(b[0])ans = [[0] * p for _ in range(m)]for i in range(m):for j in range(n):for k in range(p):ans[i][k] = (ans[i][k] + a[i][j] * b[j][k])return ansdef matrix_pow_mod(a, b, MOD=10 ** 9 + 7):n = len(a)ans = [[0] * n for _ in range(n)]for i in range(n):ans[i][i] = 1while b:if b & 1:ans = matrix_multiply(ans, a, MOD)a = matrix_multiply(a, a, MOD)b >>= 1return ansclass Solution:def fib(self, n: int) -> int:if n == 0:return 0m = [[1, 1], [1, 0]]return matrix_pow_mod(m, n - 1)[0][0]
3. 1220. 统计元音字母序列的数目(状态机DP,用numpy)
链接: 1220. 统计元音字母序列的数目
MOD = 10**9+7import numpy as np
class Solution:def countVowelPermutation(self, n: int) -> int:"""定义f[i][0~4]表示长为i+1的字符串,最后结尾是aeiou的种类数显然f[0][0:5] = [1,1,1,1,1]下边g = f[i-1]f[i][0] = 0g[0] + 1g[1] + 1g[2] + 0g[3] + 1g[4]f[i][1] = 1g[0] + 0g[1] + 1g[2] + 0g[3] + 0g[4]f[i][2] = 0g[0] + 1g[1] + 0g[2] + 1g[3] + 0g[4]f[i][3] = 0g[0] + 0g[1] + 1g[2] + 0g[3] + 0g[4]f[i][4] = 0g[0] + 0g[1] + 1g[2] + 1g[3] + 0g[4]"""if n == 1:return 5m = np.mat([[0,1,1,0,1],[1,0,1,0,0],[0,1,0,1,0],[0,0,1,0,0],[0,0,1,1,0]])f0 = np.mat([[1],[1],[1],[1],[1],])n -= 1while n:if n &1:f0 = m*f0 %MOD m = m*m%MOD n >>= 1return int(f0.sum()) %MOD
4. 552. 学生出勤记录 II(2维状态机DP展开成1维,用numpy)
链接: 552. 学生出勤记录 II
- 这题状态维度多一层,还好是2*3,可以直接展开。
MOD = 10 ** 9 + 7
import numpy as np
class Solution:def checkRecord(self, n: int) -> int:'''f[i][0/1][0/1/2]表示i天,A=0,1,最近连续L为0/1/2时的情况'''# f = [[0]*3 for _ in range(2)]# f[0][0] = 1# for _ in range(n):# g = [[0]*3 for _ in range(2)]# g[0][0] = f[0][0] + f[0][1] + f[0][2]# g[0][1] = f[0][0] # g[0][2] = f[0][1]# g[1][0] = f[0][0] + f[0][1] + f[0][2] + f[1][0] + f[1][1] + f[1][2]# g[1][1] = f[1][0]# g[1][2] = f[1][1]# for i in range(2):# for j in range(3):# g[i][j] %= MOD # f = g # return sum(sum(row) for row in f) %MOD'''0 0:00 1:10 2:21 0:31 1:41 2:5f[i][0] = [1, 1, 1, 0, 0, 0]f[i][1] = [1, 0, 0, 0, 0, 0]f[i][2] = [0, 1, 0, 0, 0, 0]f[i][3] = [1, 1, 1, 1, 1, 1]f[i][4] = [0, 0, 0, 1, 0, 0]f[i][5] = [0, 0, 0, 0, 1, 0]f[n] = m^n * [[1],[0],[0],[0],[0],[0]]'''f0 = np.mat([[1],[0],[0],[0],[0],[0]])m = np.mat([[1, 1, 1, 0, 0, 0],[1, 0, 0, 0, 0, 0],[0, 1, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1],[0, 0, 0, 1, 0, 0],[0, 0, 0, 0, 1, 0],])while n:if n & 1:f0 = m * f0 %MOD m = m * m %MOD n >>= 1return int(f0.sum()%MOD)
5. 2851. 字符串转换(KMP+矩阵快速幂)
链接: 2851. 字符串转换
- 难点在于想到kmp,以及状态转移。
MOD = 10 ** 9 + 7
class Kmp:"""kmp算法,计算前缀函数pi,根据pi转移,复杂度O(m+n)"""def __init__(self, t):"""传入模式串,计算前缀函数"""self.t = tn = len(t)self.pi = pi = [0] * nj = 0for i in range(1, n):while j and t[i] != t[j]:j = pi[j - 1] # 失配后缩短期望匹配长度if t[i] == t[j]:j += 1 # 多配一个pi[i] = jdef find_all_yield(self, s):"""查找t在s中的所有位置,注意可能为空"""n, t, pi, j = len(self.t), self.t, self.pi, 0 for i, v in enumerate(s):while j and v != t[j]:j = pi[j - 1]if v == t[j]:j += 1if j == n:yield i - j + 1j = pi[j - 1]def find_one(self, s):"""查找t在s中的第一个位置,如果不存在就返回-1"""for ans in self.find_all_yield(s):return ansreturn -1
def matrix_multiply(a, b, MOD=10**9+7):m, n, p = len(a), len(a[0]), len(b[0])ans = [[0]*p for _ in range(m)]for i in range(m):for j in range(n):for k in range(p):ans[i][k] = (ans[i][k]+a[i][j] * b[j][k]) %MODreturn ans
def matrix_pow_mod(a, b, MOD=10**9+7):n = len(a)ans = [[0]*n for _ in range(n)]for i in range(n):ans[i][i] = 1 while b:if b & 1:ans = matrix_multiply(ans, a, MOD)a = matrix_multiply(a, a, MOD)b >>= 1return ansclass Solution:def numberOfWays(self, s: str, t: str, k: int) -> int:if t not in s+s:return 0n = len(s)c = len(list(Kmp(t).find_all_yield(s+s[:-1]))) m = [[c-1, c],[n-c,n-1-c]]m = matrix_pow_mod(m, k)return m[0][s != t]
三、其他
- 一定别忘了取模。
- 通常小数据特判。
四、更多例题
- 790. 多米诺和托米诺平铺
五、参考链接
- 链接: KMP + 矩阵快速幂优化 DP(附题单)Python/Java/C++/Go/JS