2010-11-29 11 views
6

私はnxn行列Aを持ちます.nは2の累乗です。行列Aは4つの等しいサイズの部分行列に分割されます。 Javaでサブ行列A11、A12、A21、A22をどのように参照することができますか?私は分割を試みると行列乗算アルゴリズム(Strassenの)行列内の部分行列の参照方法

  A11 | A12 
    A --> --------- 
      A21 | A22 

EDITを征服しています:行列は整数配列として格納されます。int [] []。

+1

の実装ですか?多次元配列、または特殊なクラス? – Lagerbaer

+0

それがどのように保存されているかを知ることで、私たちはあなたを助けることができません。 –

+0

編集した質問をご覧ください。 – devnull

答えて

3

もし、ijがあなたのインデックスであれば、i = 0 ..(n/2)-1、j = 0 ..(n/2)-1でA11が得られます。 次に、A12は、i = 0 ..(n/2)-1、j = n/2..n-1などとなります。

「参照」するには、「i_min、i_max、j_min、j_max」が必要で、0からn-1のインデックスを実行するのではなく、最小から最大まで実行します。

+0

私はどのようにサブ行列の参照を方法。 C++では、正しいアドレスを割り当てて配列へのポインタに格納することができます。このポインタをメソッドに渡すことができます。 – devnull

+0

配列は参照渡しであるため、要素の境界を定義する4つの整数とともに行列自体を渡すことができます。または、「スライシング」を可能にする高度なマトリックスライブラリを使用します。 – Lagerbaer

0

毎回各サブマトリックスの内容をコピーするか、アドレッシングを計算するかどうかを決めなければならないと思います。あなたの質問は、あなたのサブマトリックスが分割されているのではなく連続していることを暗示しています(未成年者と補因子の判別の計算 - http://mathworld.wolfram.com/Determinant.html)。あなたがこれをしたい理由を示していないので、あなたがすでに遭遇したパフォーマンスのヒットと、より小さなマトリックスへの再帰があるかどうかは、あなただけがコピーの単純さと複雑さのバランスを決めることができると思う再帰的アドレッシングのしかし、私はすでに図書館があることを期待しており、私はhttp://commons.apache.org/math/をチェックするだろう。

1

これはあなたの行列が保存されているどのようにStrassen algorithm for matrix multiplication

import java.io.*; 

public class MatrixMultiplication { 

    public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); 

    public MatrixMultiplication() throws IOException { 
     int n; 
     int[][] a, b; 

     System.out.print("Enter the number for rows/colums: "); 
     n = Integer.parseInt(br.readLine()); 

     a = new int[n][n]; 
     b = new int[n][n]; 

     System.out.print("\n\n\nEnter the values for the first matrix:\n\n"); 
     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) { 
       System.out.print("Enter the value for cell("+(i+1)+","+(j+1)+"): "); 
       a[i][j] = Integer.parseInt(br.readLine()); 
      } 
     } 
     System.out.print("\n\n\nEnter the values for the second matrix:\n"); 
     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) { 
       System.out.print("Enter the value for cell ("+(i+1)+","+(j+1)+"): "); 
       b[i][j] = Integer.parseInt(br.readLine()); 
      } 
     } 

     System.out.print("\n\nMatrix multiplication using standard method:\n"); 
     print(multiplyWithStandard(a, b)); 

     System.out.print("\n\nMatrix multiplication using Strassen method:\n"); 
     print(multiplyWithStandard(a, b)); 
    } 

    public int[][] multiplyWithStandard(int[][] a, int[][] b) { 
     int n = a.length; 
     int[][] c = new int[n][n]; 

     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) { 
       for (int k = 0; k < n; k++) { 
        c[i][j] += a[i][k] * b[k][j]; 
       } 
      } 
     } 
     return c; 
    } 

    public int[][] multiplyWithStrassen(int [][] A, int [][] B) { 
     int n = A.length; 
     int [][] result = new int[n][n]; 

     if (n == 1) { 
      result[0][0] = A[0][0] * B[0][0]; 
     } else if ((n%2 != 0) && (n != 1)) { 
      int[][] a1, b1, c1; 
      int n1 = n+1; 
      a1 = new int[n1][n1]; 
      b1 = new int[n1][n1]; 
      c1 = new int[n1][n1]; 

      for (int i = 0; i < n; i++) { 
       for (int j = 0; j < n; j++) { 
        a1[i][j] = A[i][j]; 
        b1[i][j] = B[i][j]; 
       } 
      } 
      c1 = multiplyWithStrassen(a1, b1); 
      for (int i = 0; i < n; i++) { 
       for (int j = 0; j < n; j++) { 
        result[i][j] = c1[i][j]; 
       } 
      } 
     } else { 
      int [][] A11 = new int[n/2][n/2]; 
      int [][] A12 = new int[n/2][n/2]; 
      int [][] A21 = new int[n/2][n/2]; 
      int [][] A22 = new int[n/2][n/2]; 

      int [][] B11 = new int[n/2][n/2]; 
      int [][] B12 = new int[n/2][n/2]; 
      int [][] B21 = new int[n/2][n/2]; 
      int [][] B22 = new int[n/2][n/2]; 

      divideArray(A, A11, 0 , 0); 
      divideArray(A, A12, 0 , n/2); 
      divideArray(A, A21, n/2, 0); 
      divideArray(A, A22, n/2, n/2); 

      divideArray(B, B11, 0 , 0); 
      divideArray(B, B12, 0 , n/2); 
      divideArray(B, B21, n/2, 0); 
      divideArray(B, B22, n/2, n/2); 

      int [][] M1 = multiplyWithStrassen(add(A11, A22), add(B11, B22)); 
      int [][] M2 = multiplyWithStrassen(add(A21, A22), B11); 
      int [][] M3 = multiplyWithStrassen(A11, subtract(B12, B22)); 
      int [][] M4 = multiplyWithStrassen(A22, subtract(B21, B11)); 
      int [][] M5 = multiplyWithStrassen(add(A11, A12), B22); 
      int [][] M6 = multiplyWithStrassen(subtract(A21, A11), add(B11, B12)); 
      int [][] M7 = multiplyWithStrassen(subtract(A12, A22), add(B21, B22)); 

      int [][] C11 = add(subtract(add(M1, M4), M5), M7); 
      int [][] C12 = add(M3, M5); 
      int [][] C21 = add(M2, M4); 
      int [][] C22 = add(subtract(add(M1, M3), M2), M6); 

      copyArray(C11, result, 0 , 0); 
      copyArray(C12, result, 0 , n/2); 
      copyArray(C21, result, n/2, 0); 
      copyArray(C22, result, n/2, n/2); 
     } 
     return result; 
    } 

    public int[][] add(int [][] A, int [][] B) { 
     int n = A.length; 
     int [][] result = new int[n][n]; 

     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) 
       result[i][j] = A[i][j] + B[i][j]; 
      } 
     return result; 
    } 

    public int[][] subtract(int [][] A, int [][] B) { 
     int n = A.length; 
     int [][] result = new int[n][n]; 

     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) { 
       result[i][j] = A[i][j] - B[i][j]; 
      } 
     }  
     return result; 
    } 

    private void divideArray(int[][] parent, int[][] child, int iB, int jB) { 
     for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { 
      for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { 
       child[i1][j1] = parent[i2][j2]; 
      } 
     } 
    } 

    private void copyArray(int[][] child, int[][] parent, int iB, int jB) { 
     for(int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) { 
      for(int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) { 
       parent[i2][j2] = child[i1][j1]; 
      } 
     } 
    } 

    public void print(int [][] array) { 
     int n = array.length; 

     System.out.println(); 
     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) { 
       System.out.print(array[i][j] + "\t"); 
      } 
      System.out.println(); 
     } 
     System.out.println(); 
    } 

    public static void main(String[] args) throws IOException { 
     new MatrixMultiplication(); 
    } 
} 
関連する問題