java实现任意矩阵Strassen算法

网友投稿 215 2023-07-22


java实现任意矩阵Strassen算法

本例输入为两个任意尺寸的矩阵(m × n 和 n × m),输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了 Strassen 算法。程序为自编,经过测试,请放心使用。基本算法如下:

1. 对于边长为 N 的方阵(正方形矩阵),找到最大的 l,使得 l = 2^k(k 为非负整数)且 l ≤ N。左上角边长为 l 的子方阵采用 Strassen 算法计算;其余部分,以及该子方阵中遗漏的部分(即与第 l 行/列之后的元素相乘的项),用蛮力法计算。

2. 对于非方阵,按行列相应补 0,使其成为方阵;相乘之后再从结果中截取对应的部分。

StrassenMethodTest.java

package matrixalgorithm;

import java.util.Scanner;

public class StrassenMethodTest {

private StrassenMethod strassenMultiply;

StrassenMethodTest(){

strassenMultiply = new StrassenMethod();

}//end cons

public static void main(String[] args){

Scanner input = new Scanner(System.in);

System.out.println("Input row size of the first matrix: ");

int arow = input.nextInt();

System.out.println("Input column size of the first matrix: ");

int acol = input.nextInt();

System.out.println("Input row size of the second matrix: ");

int brow = input.nextInt();

System.out.println("Input column size of the second matrix: ");

int bcol = input.nextInt();

double[][] A = new double[arow][acol];

double[][] B = new double[brow][bcol];

double[][] C = new double[arow][bcol];

System.out.println("Input data for matrix A: ");

/*In all of the codes later in this project,

r means row while c means column.

*/

for (int r = 0; r < arow; r++) {

for (int c = 0; c < acol; c++) {

System.out.printf("Data of A[%d][%d]: ", r, c);

A[r][c] = input.nextDouble();

}//end inner loop

}//end loop

System.out.println("Input data for matrix B: ");

for (int r = 0; r < brow; r++) {

for (int c = 0; c < bcol; c++) {

System.out.printf("Data of A[%d][%d]: ", r, c);

B[r][c] = input.nextDouble();

}//end inner loop

}//end loop

StrassenMethodTest algorithm = new StrassenMethodTest();

C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol);

//Display the calculation result:

System.out.println("Result from matrix C: ");

for (int r = 0; r < arow; r++) {

for (int c = 0; c < bcol; c++) {

System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]);

}//end inner loop

}//end outter loop

}//end main

//Deal with matrices that are not square:

public double[][] multiplyRectMatrix(double[][] A, double[][] B,

int arow, int acol, int brow, int bcol) {

if (arow != bcol) //Invalid multiplicatio

return new double[][]{{0}};

double[][] C = new double[arow][bcol];

if (arow < acol) {

double[][] newA = new double[acol][acol];

double[][] newB = new double[brow][brow];

int n = acol;

for (int r = 0; r < acol; r++)

for (int c = 0; c < acol; c++)

newA[r][c] = 0.0;

for (int r = 0; r < brow; r++)

for (int c = 0; c < brow; c++)

newB[r][c] = 0.0;

for (int r = 0; r < arow; r++)

for (int c = 0; c < acol; c++)

newA[r][c] = A[r][c];

for (int r = 0; r < brow; r++)

for (int c = 0; c < bcol; c++)

newB[r][c] = B[r][c];

double[][] C2 = multiplySquareMatrix(newA, newB, n);

for(int r = 0; r < arow; r++)

for(int c = 0; c < bcol; c++)

C[r][c] = C2[r][c];

}//end if

else if(arow == acol)

C = multiplySquareMatrix(A, B, arow);

else {

int n = arow;

double[][] newA = new double[arow][arow];

double[][] newB = new double[bcol][bcol];

for (int r = 0; r < arow; r++)

for (int c = 0; c < arow; c++)

newA[r][c] = 0.0;

for (int r = 0; r < bcol; r++)

for (int c = 0; c < bcol; c++)

newB[r][c] = 0.0;

for (int r = 0; r < arow; r++)

for (int c = 0; c < acol; c++)

newA[r][c] = A[r][c];

for (int r = 0; r < brow; r++)

for (int c = 0; c < bcol; c++)

newB[r][c] = B[r][c];

double[][] C2 = multiplySquareMatrix(newA, newB, n);

for(int r = 0; r < arow; r++)

for(int c = 0; c < bcol; c++)

C[r][c] = C2[r][c];

}//end else

return C;

}//end method

//Deal with matrices that are square matrices.

public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n){

double[][] C2 = new double[n][n];

for(int r = 0; r < n; r++)

for(int c = 0; c < n; c++)

C2[r][c] = 0;

if(n == 1){

C2[0][0] = A2[0][0] * B2[0][0];

return C2;

}//end if

int exp2k = 2;

while(exp2k <= (n / 2) ){

exp2k *= 2;

}//end loop

if(exp2k == n){

C2 = strassenMultiply.strassenMultiplyMatrix(A2, B2, n);

return C2;

}//end else

//The "biggest" strassen matrix:

double[][][] A = new double[6][exp2k][exp2k];

double[][][] B = new dhttp://ouble[6][exp2k][exp2k];

double[][][] C = new double[6][exp2k][exp2k];

for(int r = 0; r < exp2k; r++){

for(int c = 0; c < exp2k; c++){

A[0][r][c] = A2[r][c];

B[0][r][c] = B2[r][c];

}//end inner loop

}//end outter loop

C[0] = strassenMultiply.strassenMultiplyMatrix(A[0], B[0], exp2k);

for(int r = 0; r < exp2k; r++)

for(int c = 0; c < exp2k; c++)

C2[r][c] = C[0][r][c];

int middle = exp2k / 2;

for(int r = 0; r < middle; r++){

for(int c = exp2k; c < n; c++){

A[1][r][c - exp2k] = http://A2[r][c];

B[3][r][c - exp2k] = B2[r][c];

}//end inner loop

}//end outter loop

for(int r = exp2k; r < n; r++){

for(int c = 0; c < middle; c++){

A[3][r - exp2k][c] = A2[r][c];

B[1][r - exp2k][c] = B2[r][c];

}//end inner loop

}//end outter loop

for(int r = middle; r < exp2k; r++){

for(int c = exp2k; c < n; c++){

A[2][r - middle][c - exp2k] = A2[r][c];

B[4][r - middle][c - exp2k] = B2[r][c];

}//end inner loop

}//end outter loop

for(int r = exp2k; r < n; r++){

for(int c = middle; c < n - exp2k + 1; c++){

A[4][r - exp2k][c - middle] = A2[r][c];

B[2][r - exp2k][c - middle] = B2[r][c];

}//end inner loop

}//end outter loop

for(int i = 1; i <= 4; i++)

C[i] = multiplyRectMatrix(A[i], B[i], middle, A[i].length, A[i].length, middle);

/*

Calculate the final results of grids in the "biggest 2^k square,

according to the rules of matrice multiplication.

*/

for (int row = 0; row < exp2k; row++) {

for (int col = 0; col < exp2k; col++) {

for (int k = exp2k; k < n; k++) {

C2[row][col] += A2[row][k] * B2[k][col];

}//end loop

}//end inner loop

}//end outter loop

//Use brute force to solve the rest, will be improved later:

for(int col = exp2k; col < n; col++){

for(int row = 0; row < n; row++){

for(int k = 0; k < n; k++)

C2[row][col] += A2[row][k] * B2[k][row];

}//end inner loop

}//end outter loop

for(int row = exp2k; row < n; row++){

for(int col = 0; col < exp2k; col++){

for(int k = 0; k < n; k++)

C2[row][col] += A2[row][k] * B2[k][row];

}//end inner loop

}//end outter loop

return C2;

}//end method

}//end class

StrassenMethod.java

package matrixalgorithm;

import java.util.Scanner;

public class StrassenMethod {

    /**
     * Multiplies the leading {@code n x n} blocks of {@code A2} and {@code B2} with Strassen's
     * divide-and-conquer algorithm (7 recursive half-size products instead of 8).
     *
     * <p>An odd size {@code n > 1} is zero-padded by one row and one column so that it can be
     * halved. The product is then cropped back, so any {@code n >= 0} gives the exact product.
     * Power-of-two sizes never need padding.
     *
     * @param A2 left operand, at least {@code n x n}
     * @param B2 right operand, at least {@code n x n}
     * @param n  size of the square operands
     * @return a new {@code n x n} array holding {@code A2 * B2}
     */
    public double[][] strassenMultiplyMatrix(double[][] A2, double[][] B2, int n) {
        double[][] C2 = new double[n][n];
        if (n == 0) {
            return C2; // without this guard n / 2 == 0 would recurse forever
        }
        if (n == 1) {
            C2[0][0] = A2[0][0] * B2[0][0];
            return C2;
        }
        if (n % 2 != 0) {
            int even = n + 1;
            double[][] padded = strassenMultiplyMatrix(
                    zeroPad(A2, n, even), zeroPad(B2, n, even), even);
            for (int r = 0; r < n; r++) {
                System.arraycopy(padded[r], 0, C2[r], 0, n);
            }
            return C2;
        }

        // Slice both operands into 2 x 2 blocks of size h x h.
        int h = n / 2;
        double[][] a11 = quadrant(A2, 0, 0, h);
        double[][] a12 = quadrant(A2, 0, h, h);
        double[][] a21 = quadrant(A2, h, 0, h);
        double[][] a22 = quadrant(A2, h, h, h);
        double[][] b11 = quadrant(B2, 0, 0, h);
        double[][] b12 = quadrant(B2, 0, h, h);
        double[][] b21 = quadrant(B2, h, 0, h);
        double[][] b22 = quadrant(B2, h, h, h);

        // The ten sums/differences S1..S10.
        double[][] s1 = minusMatrix(b12, b22, h);
        double[][] s2 = addMatrix(a11, a12, h);
        double[][] s3 = addMatrix(a21, a22, h);
        double[][] s4 = minusMatrix(b21, b11, h);
        double[][] s5 = addMatrix(a11, a22, h);
        double[][] s6 = addMatrix(b11, b22, h);
        double[][] s7 = minusMatrix(a12, a22, h);
        double[][] s8 = addMatrix(b21, b22, h);
        double[][] s9 = minusMatrix(a11, a21, h);
        double[][] s10 = addMatrix(b11, b12, h);

        // The seven recursive products P1..P7.
        double[][] p1 = strassenMultiplyMatrix(a11, s1, h);
        double[][] p2 = strassenMultiplyMatrix(s2, b22, h);
        double[][] p3 = strassenMultiplyMatrix(s3, b11, h);
        double[][] p4 = strassenMultiplyMatrix(a22, s4, h);
        double[][] p5 = strassenMultiplyMatrix(s5, s6, h);
        double[][] p6 = strassenMultiplyMatrix(s7, s8, h);
        double[][] p7 = strassenMultiplyMatrix(s9, s10, h);

        // Combine into the four result blocks.
        double[][] c11 = addMatrix(minusMatrix(addMatrix(p5, p4, h), p2, h), p6, h);
        double[][] c12 = addMatrix(p1, p2, h);
        double[][] c21 = addMatrix(p3, p4, h);
        double[][] c22 = minusMatrix(minusMatrix(addMatrix(p5, p1, h), p3, h), p7, h);

        for (int r = 0; r < h; r++) {
            System.arraycopy(c11[r], 0, C2[r], 0, h);
            System.arraycopy(c12[r], 0, C2[r], h, h);
            System.arraycopy(c21[r], 0, C2[h + r], 0, h);
            System.arraycopy(c22[r], 0, C2[h + r], h, h);
        }
        return C2;
    }

    /** Copies the {@code size x size} block of {@code m} starting at ({@code rowOffset}, {@code colOffset}). */
    private static double[][] quadrant(double[][] m, int rowOffset, int colOffset, int size) {
        double[][] block = new double[size][size];
        for (int r = 0; r < size; r++) {
            System.arraycopy(m[rowOffset + r], colOffset, block[r], 0, size);
        }
        return block;
    }

    /** Copies the leading {@code n x n} entries of {@code m} into a zero-filled {@code size x size} array. */
    private static double[][] zeroPad(double[][] m, int n, int size) {
        double[][] padded = new double[size][size];
        for (int r = 0; r < n; r++) {
            System.arraycopy(m[r], 0, padded[r], 0, n);
        }
        return padded;
    }

    /** Returns the element-wise sum of the leading {@code n x n} entries of A and B. */
    private double[][] addMatrix(double[][] A, double[][] B, int n) {
        double[][] C = new double[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                C[r][c] = A[r][c] + B[r][c];
            }
        }
        return C;
    }

    /** Returns the element-wise difference A - B of the leading {@code n x n} entries. */
    private double[][] minusMatrix(double[][] A, double[][] B, int n) {
        double[][] C = new double[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                C[r][c] = A[r][c] - B[r][c];
            }
        }
        return C;
    }
}

希望本文所述对大家学习 Java 程序设计有所帮助。


版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:使用Java通过OAuth协议验证发送微博的教程
下一篇:详解Java设计模式编程中的中介者模式
相关文章

 发表评论

暂时没有评论,来抢沙发吧~