1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
def _pad_square(matrix, size):
    """Return a ``size`` x ``size`` copy of ``matrix``, zero-filled on the right and bottom."""
    padded = [list(row) + [0] * (size - len(row)) for row in matrix]
    padded.extend([0] * size for _ in range(size - len(matrix)))
    return padded


def matrix_strassen(a, b):
    """Multiply two square matrices using Strassen's seven-product scheme.

    Args:
        a: Left operand, an n x n matrix given as a list of row lists.
        b: Right operand, indexed the same way as ``a`` (same n x n shape).

    Returns:
        A new n x n list of lists holding the product ``a * b``. The inputs
        are not modified.

    Sizes that are not a power of two are zero-padded up to the next power
    of two before recursing and the result is cropped back to n x n, because
    the quadrant split only works when every level halves evenly.
    """
    n = len(a)
    if n == 0:
        # Without this guard the empty quadrants would be split and
        # multiplied again forever.
        return []
    if n == 1:
        return [[a[0][0] * b[0][0]]]

    # Smallest power of two that is >= n.
    size = 1 << (n - 1).bit_length()
    if size != n:
        # Zero padding leaves the top-left n x n corner of the product intact.
        product = matrix_strassen(_pad_square(a, size), _pad_square(b, size))
        return [row[:n] for row in product[:n]]

    a11, a12, a21, a22 = division(a)
    b11, b12, b21, b22 = division(b)

    # Third argument of matrix_add_sub: 1 adds, 0 subtracts.
    s1 = matrix_add_sub(b12, b22, 0)
    s2 = matrix_add_sub(a11, a12, 1)
    s3 = matrix_add_sub(a21, a22, 1)
    s4 = matrix_add_sub(b21, b11, 0)
    s5 = matrix_add_sub(a11, a22, 1)
    s6 = matrix_add_sub(b11, b22, 1)
    s7 = matrix_add_sub(a12, a22, 0)
    s8 = matrix_add_sub(b21, b22, 1)
    s9 = matrix_add_sub(a11, a21, 0)
    s10 = matrix_add_sub(b11, b12, 1)

    # The seven half-size products that replace the naive eight.
    p1 = matrix_strassen(a11, s1)
    p2 = matrix_strassen(s2, b22)
    p3 = matrix_strassen(s3, b11)
    p4 = matrix_strassen(a22, s4)
    p5 = matrix_strassen(s5, s6)
    p6 = matrix_strassen(s7, s8)
    p7 = matrix_strassen(s9, s10)

    # c11 = p5 + p4 - p2 + p6
    c11 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p4, 1), p2, 0), p6, 1)
    # c12 = p1 + p2
    c12 = matrix_add_sub(p1, p2, 1)
    # c21 = p3 + p4
    c21 = matrix_add_sub(p3, p4, 1)
    # c22 = p5 + p1 - p3 - p7
    c22 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p1, 1), p3, 0), p7, 0)

    return matrix_combination(c11, c12, c21, c22)
def division(a):
    """Split a square matrix into its four quadrants.

    Args:
        a: Matrix as a list of row lists. ``len(a) // 2`` is used as the
            quadrant size, so a trailing odd row/column is not included.

    Returns:
        Tuple ``(a11, a12, a21, a22)`` of new lists: top-left, top-right,
        bottom-left and bottom-right quadrants.
    """
    half = len(a) // 2
    top = a[:half]
    bottom = a[half:2 * half]
    a11 = [row[:half] for row in top]
    a12 = [row[half:2 * half] for row in top]
    a21 = [row[:half] for row in bottom]
    a22 = [row[half:2 * half] for row in bottom]
    return (a11, a12, a21, a22)


def matrix_add_sub(a, b, keys):
    """Add or subtract two matrices element-wise.

    Args:
        a: Left operand as a list of row lists.
        b: Right operand, expected to have the same shape as ``a``.
        keys: ``1`` selects ``a + b``; any other value selects ``a - b``.

    Returns:
        A new list of lists with the element-wise result.
    """
    if keys == 1:
        return [[x + y for x, y in zip(row_a, row_b)]
                for row_a, row_b in zip(a, b)]
    return [[x - y for x, y in zip(row_a, row_b)]
            for row_a, row_b in zip(a, b)]


def matrix_combination(a11, a12, a21, a22):
    """Assemble four equally sized quadrants into one matrix.

    Args:
        a11: Top-left quadrant.
        a12: Top-right quadrant.
        a21: Bottom-left quadrant.
        a22: Bottom-right quadrant.

    Returns:
        A new list of lists whose rows are the left and right quadrant rows
        joined, top half first.
    """
    top = [left + right for left, right in zip(a11, a12)]
    bottom = [left + right for left, right in zip(a21, a22)]
    return top + bottom


if __name__ == "__main__":
    # Demo: multiply a 4 x 4 matrix by itself and print the product.
    a = [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [2, 2, 2, 2]]
    b = a
    print(matrix_strassen(a, b))
|