Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion divide_and_conquer/strassen_matrix_multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,21 @@ def print_matrix(matrix: list) -> None:
def actual_strassen(matrix_a: list, matrix_b: list) -> list:
"""
Recursive function to calculate the product of two matrices, using the Strassen
Algorithm. It only supports square matrices of any size that is a power of 2.
Algorithm.

Time complexity:
The recurrence is T(n) = 7 T(n/2) + \u0398(n^2), which solves to
T(n) = \u0398(n^{log_2 7}) \u2248 \u0398(n^{2.8074}). This is asymptotically
faster than the naive \u0398(n^3) algorithm for sufficiently large n.

Space complexity:
Uses additional memory for temporary submatrices and padding; overall
space complexity is O(n^2).

Notes:
This function expects square matrices whose size is a power of two.
Matrices of other sizes are handled by `strassen` which pads to the
next power of two.
"""
if matrix_dimensions(matrix_a) == (2, 2):
return default_matrix_multiplication(matrix_a, matrix_b)
Expand Down Expand Up @@ -106,6 +120,16 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list:

def strassen(matrix1: list, matrix2: list) -> list:
"""
Multiply two matrices using Strassen's divide-and-conquer algorithm.

Time complexity:
\u0398(n^{log_2 7}) \u2248 \u0398(n^{2.8074})
(recurrence T(n) = 7 T(n/2) + \u0398(n^2)).

Space complexity:
O(n^2) due to padding and temporary matrices used during recursion.

Examples:
>>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
[[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
>>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])
Expand Down