#include<iostream.h>
#include<stdlib.h>

/* Leaf size of the quadtree layout: blocks of at most BREAK x BREAK are
   stored as plain matrices and multiplied with the naive triple loop
   (mul_matrix) instead of recursing further. */
#define BREAK 2
/* Quadrant shorthands for a strassen_matrix node. They expand textually, so
   they only work where a variable with exactly that name (a, b, c or d) is
   in scope. Index mapping (see strassen_submatrix):
     p[0] = xx11 top-left      p[1] = xx12 top-right
     p[2] = xx21 bottom-left   p[3] = xx22 bottom-right */
#define a11 a->p[0]
#define a12 a->p[1]
#define a21 a->p[2]
#define a22 a->p[3]
#define b11 b->p[0]
#define b12 b->p[1]
#define b21 b->p[2]
#define b22 b->p[3]
#define c11 c->p[0]
#define c12 c->p[1]
#define c21 c->p[2]
#define c22 c->p[3]
#define d11 d->p[0]
#define d12 d->p[1]
#define d21 d->p[2]
#define d22 d->p[3]

/* Plain dense matrix: an array of n row pointers, each pointing to n doubles
   (see new_matrix). */
typedef double **matrix;
/* One node of the recursive (quadtree) layout used by the Strassen routines.
   For a size n <= BREAK the node is a leaf and member d (a plain matrix) is
   active; for a larger n member p is active: an array of four child nodes of
   size n/2 (see new_strassen). Which member is live is not stored anywhere --
   it is implied by the size n that is passed alongside every node. */
typedef union _strassen_matrix
{
    matrix d;
    union _strassen_matrix **p;
} *strassen_matrix;

/* Allocation of the two layouts (size n). */
matrix  new_matrix(int);
strassen_matrix new_strassen(int);
/* Conversion between the plain layout and the quadtree layout. */
void normal_to_strassen(matrix, strassen_matrix, int);
void strassen_to_normal(strassen_matrix, matrix, int);
/* Leaf of the quadtree that holds element (i, j) of an n x n matrix. */
matrix  strassen_submatrix(strassen_matrix, int, int, int);
/* Element-wise helpers: the result goes into the last matrix argument
   (copy: second = first; add: third = first + second; sub: third = first - second). */
void copy_matrix(matrix, matrix, int);
void add_matrix(matrix, matrix, matrix, int);
void sub_matrix(matrix, matrix, matrix, int);
/* NOTE(review): no definition of copy_strassen is visible in this file. */
void copy_strassen(strassen_matrix, strassen_matrix, int);
void add_strassen(strassen_matrix, strassen_matrix, strassen_matrix, int);
void sub_strassen(strassen_matrix, strassen_matrix, strassen_matrix, int);
/* Products: third = first * second; mul_strassen's fourth argument is scratch. */
void mul_matrix(matrix, matrix, matrix, int);
void mul_strassen(strassen_matrix, strassen_matrix, strassen_matrix, strassen_matrix, int);
void print_matrix(matrix, int);

/* ceil(log2 n) for n >= 1. */
int least_power_of_two(int);


/* Allocate a plain n x n matrix as n separately allocated rows.
 *
 * n - number of rows and columns.
 * Returns the row-pointer array with every element set to 0.0, or NULL when
 * n <= 0 (there is nothing to allocate). Allocation failure aborts the
 * program: no caller checks the result before dereferencing it.
 */
matrix new_matrix(int n)
{
	if (n <= 0)
		return NULL;

	/* calloc rejects count * size overflow and zero-fills, so the row table
	   never holds wild pointers and unread elements are 0.0. */
	matrix a = (matrix) calloc((size_t) n, sizeof(double *));
	if (a == NULL)
		abort();
	for (int j = 0; j < n; j++)
	{
		a[j] = (double *) calloc((size_t) n, sizeof(double));
		if (a[j] == NULL)
			abort();
	}
	return a;
}

/* Allocate the quadtree layout for an n x n matrix.
 *
 * n <= BREAK: a leaf whose member d is a plain n x n matrix (new_matrix).
 * n >  BREAK: an inner node whose member p holds four children of size n/2,
 *             p[0..3] = quadrants 11, 12, 21, 22.
 * n should be a power of two; otherwise the truncating n/2 split does not
 * cover all n x n elements. Allocation failure aborts the program because
 * callers dereference the result without checking it.
 */
strassen_matrix new_strassen(int n)
{
	strassen_matrix a = (strassen_matrix) malloc(sizeof *a);
	if (a == NULL)
		abort();

	if (n <= BREAK)
	{
		a->d = new_matrix(n);
		return a;
	}

	int m = n / 2;
	a->p = (strassen_matrix *) malloc(4 * sizeof *a->p);
	if (a->p == NULL)
		abort();
	a11 = new_strassen(m);
	a12 = new_strassen(m);
	a21 = new_strassen(m);
	a22 = new_strassen(m);
	return a;
}

/* Return the leaf block of quadtree a (built for size n) that contains
 * element (i, j).
 *
 * At each level the bit of i and j selected by the current half-size mask
 * picks the child: the column bit contributes 1 and the row bit contributes 2,
 * giving p[0]=11, p[1]=12, p[2]=21, p[3]=22. Descent stops once the mask
 * drops below BREAK, i.e. at a leaf.
 */
matrix strassen_submatrix(strassen_matrix a, int i, int j, int n)
{
	if (!(n > BREAK))
		return a->d;

	strassen_matrix node = a;
	int shift = least_power_of_two(n) - 1;
	for (int mask = n / 2; mask >= BREAK; mask >>= 1, --shift)
	{
		int row_bit = (i & mask) * 2;
		int col_bit = j & mask;
		node = node->p[(row_bit | col_bit) >> shift];
	}
	return node->d;
}

/* Scatter the plain n x n matrix a into the quadtree b (built by
 * new_strassen(n)). n is expected to be a multiple of BREAK when n > BREAK.
 */
void normal_to_strassen(matrix a, strassen_matrix b, int n)
{
	/* A size that fits in one leaf is a straight copy. */
	if (n <= BREAK)
	{
		copy_matrix(a, b->d, n);
		return;
	}

	/* Otherwise visit a tile by tile and store each tile into its leaf. */
	for (int row = 0; row < n; row += BREAK)
	{
		for (int col = 0; col < n; col += BREAK)
		{
			matrix leaf = strassen_submatrix(b, row, col, n);
			for (int dr = 0; dr < BREAK; ++dr)
				for (int dc = 0; dc < BREAK; ++dc)
					leaf[dr][dc] = a[row + dr][col + dc];
		}
	}
}

/* Gather the quadtree a (size n) back into the plain n x n matrix b.
 * Inverse of normal_to_strassen.
 */
void strassen_to_normal(strassen_matrix a, matrix b, int n)
{
	/* A single leaf already has the plain layout. */
	if (n <= BREAK)
	{
		copy_matrix(a->d, b, n);
		return;
	}

	int row = 0;
	while (row < n)
	{
		for (int col = 0; col < n; col += BREAK)
		{
			matrix leaf = strassen_submatrix(a, row, col, n);
			for (int dr = 0; dr < BREAK; ++dr)
				for (int dc = 0; dc < BREAK; ++dc)
					b[row + dr][col + dc] = leaf[dr][dc];
		}
		row += BREAK;
	}
}

/* Copy the n x n matrix a into b (b = a). */
void copy_matrix(matrix a, matrix b, int n)
{
	for (int row = 0; row < n; ++row)
	{
		const double *src = a[row];
		double *dst = b[row];
		for (int col = 0; col < n; ++col)
			dst[col] = src[col];
	}
}


/* Element-wise sum of two n x n matrices: c = a + b.
   c may be the same matrix as a or b. */
void add_matrix(matrix a, matrix b, matrix c, int n)
{
	for (int row = 0; row < n; ++row)
	{
		const double *lhs = a[row];
		const double *rhs = b[row];
		double *out = c[row];
		for (int col = 0; col < n; ++col)
			out[col] = rhs[col] + lhs[col];
	}
}


/* Element-wise difference of two n x n matrices: c = a - b.
   c may be the same matrix as a or b. */
void sub_matrix(matrix a, matrix b, matrix c, int n)
{
	for (int row = 0; row < n; ++row)
	{
		const double *lhs = a[row];
		const double *rhs = b[row];
		double *out = c[row];
		for (int col = 0; col < n; ++col)
			out[col] = lhs[col] - rhs[col];
	}
}


/* Quadtree sum: c = a + b, all three built for size n.
   Leaves are added element-wise; inner nodes recurse on the four quadrants
   in the order p[0..3] = 11, 12, 21, 22. */
void add_strassen(strassen_matrix a, strassen_matrix b, strassen_matrix c, int n)
{
	if (n > BREAK)
	{
		int half = n / 2;
		for (int q = 0; q < 4; ++q)
			add_strassen(a->p[q], b->p[q], c->p[q], half);
	}
	else
		add_matrix(a->d, b->d, c->d, n);
}

/* Quadtree difference: c = a - b, all three built for size n.
   Leaves are subtracted element-wise; inner nodes recurse on the four
   quadrants in the order p[0..3] = 11, 12, 21, 22. */
void sub_strassen(strassen_matrix a, strassen_matrix b, strassen_matrix c, int n)
{
	if (n > BREAK)
	{
		int half = n / 2;
		for (int q = 0; q < 4; ++q)
			sub_strassen(a->p[q], b->p[q], c->p[q], half);
	}
	else
		sub_matrix(a->d, b->d, c->d, n);
}

/* Naive O(n^3) product of two n x n matrices: c = a * b.
   Each c element is reset to 0.0 and accumulated in place, so c must not
   share storage with a or b. */
void mul_matrix(matrix a, matrix b, matrix c, int n)
{
	for (int row = 0; row < n; ++row)
	{
		const double *lhs = a[row];
		double *out = c[row];
		for (int col = 0; col < n; ++col)
		{
			out[col] = 0.0;
			for (int k = 0; k < n; ++k)
				out[col] += lhs[k] * b[k][col];
		}
	}
}

/* Strassen multiplication on the quadtree layout: c = a * b.
 *
 * a, b - operands of size n (read, never written)
 * c    - result of size n; must be distinct from a, b and d, because its
 *        quadrants are written before a and b have been fully read
 * d    - scratch quadtree of size n; its contents are overwritten
 * n    - matrix size (leaves of size <= BREAK use the naive mul_matrix)
 *
 * The seven Strassen products are
 *   M1 = (A11 + A22)(B11 + B22)   M5 = (A11 + A12) B22
 *   M2 = (A21 + A22) B11          M6 = (A21 - A11)(B11 + B12)
 *   M3 = A11 (B12 - B22)          M7 = (A12 - A22)(B21 + B22)
 *   M4 = A22 (B21 - B11)
 * and the result quadrants are
 *   C11 = M1 + M4 - M5 + M7       C12 = M3 + M5
 *   C21 = M2 + M4                 C22 = M1 - M2 + M3 + M6
 * d11/d12 hold operand sums, a free quadrant of d is handed to each
 * recursive call as its scratch, and c's quadrants double as accumulators.
 * The statement order below is therefore significant.
 */
void mul_strassen(
   strassen_matrix a,
	strassen_matrix b,
	strassen_matrix c,
	strassen_matrix d,
	int 			n
)
{
	if (n <= BREAK)
			mul_matrix(a->d,b->d,c->d,n);
	else
	{
		int m = n/2;
      sub_strassen(a12, a22, d11, m);	/* d11 = A12 - A22 */
		add_strassen(b21, b22, d12, m);	/* d12 = B21 + B22 */
		mul_strassen(d11, d12, c11, d21, m);	/* c11 = M7 */
		sub_strassen(a21, a11, d11, m);	/* d11 = A21 - A11 */
		add_strassen(b11, b12, d12, m);	/* d12 = B11 + B12 */
		mul_strassen(d11, d12, c22, d21, m);	/* c22 = M6 */
		add_strassen(a11, a12, d11, m);	/* d11 = A11 + A12 */
		mul_strassen(d11, b22, c12, d12, m);	/* c12 = M5 (d12 free as scratch) */
		sub_strassen(c11, c12, c11, m);	/* c11 = M7 - M5 */
		sub_strassen(b21, b11, d11, m);	/* d11 = B21 - B11 */
		mul_strassen(a22, d11, c21, d12, m);	/* c21 = M4 */
		add_strassen(c21, c11, c11, m);	/* c11 = M4 - M5 + M7 */
		sub_strassen(b12, b22, d11, m);	/* d11 = B12 - B22 */
		mul_strassen(a11, d11, d12, d21, m);	/* d12 = M3 */
		add_strassen(d12, c12, c12, m);	/* c12 = M3 + M5            (final C12) */
		add_strassen(d12, c22, c22, m);	/* c22 = M3 + M6 */
		add_strassen(a21, a22, d11, m);	/* d11 = A21 + A22 */
		mul_strassen(d11, b11, d12, d21, m);	/* d12 = M2 */
		add_strassen(d12, c21, c21, m);	/* c21 = M2 + M4            (final C21) */
		sub_strassen(c22, d12, c22, m);	/* c22 = M3 + M6 - M2 */
		add_strassen(a11, a22, d11, m);	/* d11 = A11 + A22 */
		add_strassen(b11, b22, d12, m);	/* d12 = B11 + B22 */
		mul_strassen(d11, d12, d21, d22, m);	/* d21 = M1 (d22 as scratch) */
		add_strassen(d21, c11, c11, m);	/* c11 = M1 + M4 - M5 + M7  (final C11) */
		add_strassen(d21, c22, c22, m);	/* c22 = M1 - M2 + M3 + M6  (final C22) */

	}
}

/* Write the n x n matrix a to cout: each element followed by a tab,
   one row per line. */
void print_matrix(matrix a, int n)
{
	for (int row = 0; row < n; ++row)
	{
		const double *line = a[row];
		for (int col = 0; col < n; ++col)
			cout << line[col] << "\t";
		cout << endl;
	}
}

/* Exponent of the least power of two that is >= n, i.e. ceil(log2 n) for
 * n >= 1 (so 1 -> 0, 2 -> 1, 3 -> 2, 4 -> 2, 5 -> 3). Returns 1 for n <= 0.
 *
 * The running power is doubled only while it is <= n/2. That keeps it
 * <= n <= INT_MAX, so the signed overflow the old 'k <<= 1' hit for
 * n > 2^30 cannot happen.
 */
int least_power_of_two(int n )
{
	if (n == 1)
		return 0;

	int exponent = 1;
	int power = 2;
	while (power < n)
	{
		/* 2 * power is the next candidate. */
		exponent++;
		/* 2 * power >= n already: done, and doubling could overflow. */
		if (power > n / 2)
			break;
		power <<= 1;
	}
	return exponent;
}

/* Fill the n x n matrix a from cin in row-major order. Stream errors are
   left in cin's state for the caller to inspect. */
void readMatrix(matrix a, int n)
{
	int row = 0;
	while (row < n)
	{
		double *line = a[row];
		for (int col = 0; col < n; ++col)
			cin >> line[col];
		++row;
	}
}

void main()
{
int n;
matrix  a1, a2;
strassen_matrix b1, b2, b3, b4;
cout<<"Enter Size Of Matrix(Power Of 2):\n";
cin>>n;
   a1 = (matrix) new_matrix(n);
	a2 = (matrix) new_matrix(n);
   b1 = (strassen_matrix) new_strassen(n);
	b2 = (strassen_matrix) new_strassen(n);
	b3 = (strassen_matrix) new_strassen(n);
	b4 = (strassen_matrix) new_strassen(n);
  cout<<"Enter Matrix One\n";
    readMatrix(a1,n);
  cout<<"Enter Matrix Two\n";
    readMatrix(a2,n);
  normal_to_strassen(a1,b1,n);
  normal_to_strassen(a2,b2,n);
  mul_strassen(b1,b2,b3,b4,n);
  strassen_to_normal(b3,a2,n);
  cout<<"Result Is: \n";
  print_matrix(a2,n);
}

