/*
 * C++ class for MxN matrices.
 *
 * $Id: Matrix.C,v 1.6 91/04/14 18:27:48 bjaspan Exp $
 */


/*
 * (1) An m x n matrix has m rows and n columns -- i.e. height x width.
 * 
 * (2) If nam == 1 (the matrix is a NAM, the result of an invalid
 * operation), nothing else about the object is guaranteed (some code
 * here might violate that... sigh).
 *
 * (3) Any two NAMs compare equal.
 */

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include "Matrix.H"

// Shared NAM instance, built once during static initialization via MakeNAM().
Matrix *Matrix::M_NAM(Matrix::MakeNAM());

/* Constructors */

void Matrix::Init(const int m, const int n)
{
     // Mark the object as a real (non-NAM) matrix, record its shape, and
     // allocate an m*n element buffer.  The elements are left uninitialized.
     nam = 0;
     M_m = m;
     M_n = n;
     d = new double[m * n];
}

// Default construction yields an empty (0 x 0) NAM that owns no storage.
Matrix::Matrix()
     : M_m(0), M_n(0), nam(1), d(0)
{
}

// Allocates an m x n matrix whose element values are left uninitialized.
Matrix::Matrix(int m, int n)
{
     this->Init(m, n);
}

/*
 * Builds an m x n matrix from m*n double arguments, given in storage order.
 * The caller must pass exactly m*n doubles; the first is named `first`.
 */
Matrix::Matrix(int m, int n, double first ...)
{
     va_list ap;

     Init(m, n);

     const int count = m * n;
     if (count <= 0)		// no storage: writing d[0] would overrun
	  return;

     d[0] = first;
     va_start(ap, first);
     for (int i = 1; i < count; i++)
	  d[i] = va_arg(ap, double);
     va_end(ap);		// required after every va_start
}

/*
 * Builds an m x n matrix from m*n int arguments, given in storage order and
 * converted to double.  The caller must pass exactly m*n ints.
 */
Matrix::Matrix(int m, int n, int first ...)
{
     va_list ap;

     Init(m, n);

     const int count = m * n;
     if (count <= 0)		// no storage: writing d[0] would overrun
	  return;

     d[0] = double(first);
     va_start(ap, first);
     for (int i = 1; i < count; i++)
	  d[i] = double(va_arg(ap, int));
     va_end(ap);		// required after every va_start
}

// Builds an m x n matrix by converting m*n ints, read from `array` in
// storage order, to double.
Matrix::Matrix(int m, int n, int *array)
{
     Init(m, n);

     double *dst = d;
     const int *src = array;
     for (int left = m * n; left > 0; --left)
	  *dst++ = double(*src++);
}

// Builds an m x n matrix by copying m*n doubles from `array` in storage order.
Matrix::Matrix(int m, int n, double *array)
{
     Init(m, n);

     double *dst = d;
     const double *src = array;
     for (int left = m * n; left > 0; --left)
	  *dst++ = *src++;
}

/*
 * Copy constructor.  Allocates a buffer of the same shape and copies the
 * elements.  The NAM flag is propagated.  Init() always clears it, so
 * without the explicit assignment below a copy of a NAM would silently
 * become an ordinary matrix holding garbage.
 */
Matrix::Matrix(const Matrix& m1)
{
     Init(m1.M_m, m1.M_n);
     nam = m1.nam;

     if (nam)			// a NAM's element data carries no meaning
	  return;

     for (int i = 0; i < M_m*M_n; i++)
	  d[i] = m1.d[i];
}

/*
 * Sub-matrix constructor: copies the m x n block of m1 whose upper-left
 * element is at 1-based position (r, c).
 *
 * The result is a NAM in these cases:
 *  - the block does not fit inside m1;
 *  - m or n is negative;
 *  - m1 is itself a NAM.
 */
Matrix::Matrix(const Matrix& m1, int r, int c, int m, int n)
{
     if (m1.nam || r < 1 || c < 1 || m < 0 || n < 0 ||
	 m1.Rows() < r-1+m || m1.Cols() < c-1+n) {
	  // Leave every member defined (as Matrix() does) so that the
	  // destructor never deletes an uninitialized pointer.
	  nam = 1;
	  M_m = M_n = 0;
	  d = 0;
	  return;
     }

     Init(m, n);
     for (int i = 1; i <= m; i++)
	  for (int j = 1; j <= n; j++)
	       (*this)(i, j) = m1(r+i-1, c+j-1);
}

/*
 * Releases the element buffer.  The obsolete `delete [count] d` form is not
 * valid standard C++.  Array delete needs no count, and deleting a null `d`
 * (as in a default-constructed NAM) is a no-op.
 */
Matrix::~Matrix()
{
     delete [] d;
}

/* Overloaded operators */

/*
 * Equality test.  Two matrices are equal when they are both NAMs, or when
 * both are ordinary matrices of the same size whose elements compare equal
 * one-for-one.  Returns 1 for equal, 0 otherwise.
 */
int operator==(const Matrix& m1, const Matrix& m2)
{
     // Differing NAM flags -> unequal; matching NAM flags of 1 -> equal.
     if (m1.nam != m2.nam || m1.nam == 1)
	  return m1.nam == m2.nam;

     if (!m1.SameSize(m2))
	  return 0;

     for (int k = 0, total = m1.M_m * m1.M_n; k < total; ++k) {
	  if (m1.d[k] != m2.d[k])
	       return 0;
     }
     return 1;
}

/*
 * Element-wise sum.  A NAM is returned when the shapes differ or when
 * either operand is a NAM.
 */
Matrix operator+(const Matrix& m1, const Matrix& m2)
{
     const int sameShape = (m1.M_m == m2.M_m && m1.M_n == m2.M_n);
     if (!sameShape || m1 == Matrix::NAM() || m2 == Matrix::NAM())
	  return Matrix::NAM();

     Matrix sum(m1);
     const int count = m1.M_m * m1.M_n;
     for (int k = 0; k < count; ++k)
	  sum.d[k] += m2.d[k];
     return sum;
}

/*
 * Element-wise difference m1 - m2.  A NAM is returned when the shapes
 * differ or when either operand is a NAM.
 */
Matrix operator-(const Matrix& m1, const Matrix& m2)
{
     const int sameShape = (m1.M_m == m2.M_m && m1.M_n == m2.M_n);
     if (!sameShape || m1 == Matrix::NAM() || m2 == Matrix::NAM())
	  return Matrix::NAM();

     Matrix diff(m1);
     const int count = m1.M_m * m1.M_n;
     for (int k = 0; k < count; ++k)
	  diff.d[k] -= m2.d[k];
     return diff;
}

/*
 * Matrix product m1 * m2, where m1 is (p x q) and m2 is (q x r), giving a
 * (p x r) result.  A NAM is returned when the inner dimensions disagree or
 * when either operand is a NAM.
 */
Matrix operator*(const Matrix& m1, const Matrix& m2)
{
     if (m1.M_n != m2.M_m || m1 == Matrix::NAM() || m2 == Matrix::NAM())
	  return Matrix::NAM();

     Matrix m(m1.M_m, m2.M_n);

     for (int i = 1; i <= m1.M_m; i++)
	  for (int j = 1; j <= m2.M_n; j++) {
	       // Accumulate the dot product locally, then store it once.
	       double sum = 0.0;
	       for (int k = 1; k <= m1.M_n; k++)
		    sum += m1(i, k) * m2(k, j);
	       m(i, j) = sum;
	  }

     return m;
}

/*
 * In-place element-wise addition.  On a shape mismatch, or when either
 * side is a NAM, *this is flagged as a NAM instead.
 */
Matrix& Matrix::operator+=(const Matrix& m)
{
     const int mismatch = (M_m != m.M_m || M_n != m.M_n);
     if (mismatch || *this == Matrix::NAM() || m == Matrix::NAM()) {
	  nam = 1;
	  return *this;
     }

     for (int k = 0, count = M_m * M_n; k < count; ++k)
	  d[k] += m.d[k];
     return *this;
}

/*
 * In-place element-wise subtraction.  On a shape mismatch, or when either
 * side is a NAM, *this is flagged as a NAM instead.
 */
Matrix& Matrix::operator-=(const Matrix& m)
{
     const int mismatch = (M_m != m.M_m || M_n != m.M_n);
     if (mismatch || *this == Matrix::NAM() || m == Matrix::NAM()) {
	  nam = 1;
	  return *this;
     }

     for (int k = 0, count = M_m * M_n; k < count; ++k)
	  d[k] -= m.d[k];
     return *this;
}

/*
 * Scalar product k * m.  Every element of m is multiplied by k.  As with
 * the other arithmetic operators, a NAM operand yields a NAM.
 */
Matrix operator*(double k, const Matrix& m)
{
     if (m == Matrix::NAM())
	  return Matrix::NAM();

     Matrix n(m);
     for (int i = 0; i < n.M_m*n.M_n; i++)
	  n.d[i] *= k;
     return n;
}

/*
 * Copy assignment.  Adopts m's NAM flag.  For a non-NAM source it also
 * adopts m's shape and elements, reusing the existing buffer when the
 * element count already matches.
 *
 * Fixes over the previous version:
 *  - self-assignment no longer frees the buffer it is about to read;
 *  - the obsolete `delete [n]` form and non-standard bcopy are gone;
 *  - the new buffer is allocated before the old one is released, so a
 *    failed allocation cannot leave `d` dangling.
 */
Matrix& Matrix::operator=(const Matrix& m)
{
     if (this == &m)
	  return *this;

     nam = m.nam;
     if (nam)			// NAM: nothing else about *this is guaranteed
	  return *this;

     const int count = m.M_m * m.M_n;
     if (M_m * M_n != count) {
	  double *fresh = new double[count];
	  delete [] d;
	  d = fresh;
     }

     M_m = m.M_m;
     M_n = m.M_n;
     for (int i = 0; i < count; i++)
	  d[i] = m.d[i];

     return *this;
}

Matrix operator|(const Matrix& m1, const Matrix& m2)
   // Concatenates two matrices horizontally.  Requires that m1 and m2
   // have the same number of rows; otherwise (or if either is a NAM) the
   // result is a NAM.
   // [ A ] | [ B ] --> [ A | B ]
{
     if (m1.Rows() != m2.Rows() || m1 == Matrix::NAM() || m2 == Matrix::NAM())
	  return Matrix::NAM();

     const int rows = m1.Rows();
     const int left = m1.Cols();
     const int right = m2.Cols();
     Matrix joined(rows, left + right);

     // Fill each output row: m1's columns first, then m2's shifted right.
     for (int r = 1; r <= rows; ++r) {
	  for (int c = 1; c <= left; ++c)
	       joined(r, c) = m1(r, c);
	  for (int c = 1; c <= right; ++c)
	       joined(r, left + c) = m2(r, c);
     }

     return joined;
}

Matrix operator&(const Matrix& m1, const Matrix& m2)
   // Concatenates two matrices vertically.  Requires that m1 and m2
   // have the same number of columns.
   //                   [ A ]
   // [ A ] & [ B ] --> [ - ]
   //                   [ B ]
{
     if (m1.Cols() != m2.Cols() || m1 == Matrix::NAM() || m2 == Matrix::NAM())
	  return Matrix::NAM();

     Matrix m(m1.Rows()+m2.Rows(), m1.Cols());
     int r, c;

     // Rows of m1 go first, then rows of m2 shifted down by m1.Rows().
     for (r = 1; r <= m1.Rows(); r += 1)
	  for (c = 1; c <= m1.Cols(); c += 1)
	       m(r, c) = m1(r, c);
     for (r = 1; r <= m2.Rows(); r += 1)
	  for (c = 1; c <= m2.Cols(); c += 1)
	       m(r + m1.Rows(), c) = m2(r, c);

     // NOTE: an earlier, disabled bulk-copy version failed because it added
     // a byte count (elements * sizeof(double)) to a double*.  Pointer
     // arithmetic already scales by the element size, so the offset must be
     // in elements.  A bulk copy would also silently depend on the
     // (header-defined) storage layout, so the element-wise loops are kept.

     return m;
}

/*
 * Writes a matrix to a stream one row per line, each element formatted as
 * "%6.2f " (the format is unchanged).  A NAM is written as "NAM\n".
 *
 * The old code used the libg++ ostream::form() member and the cfront form()
 * function, neither of which exists in standard C++.  snprintf into a local
 * buffer produces the same text portably.
 */
ostream& operator<<(ostream& s, const Matrix& m)
{
     if (m.nam) {
	  s << "NAM\n";
	  return s;
     }

     // Large enough for "%6.2f " of any finite double (DBL_MAX needs ~314).
     char buf[512];

     for (int i = 1; i <= m.M_m; i++) {
	  for (int j = 1; j <= m.M_n; j++) {
	       snprintf(buf, sizeof buf, "%6.2f ", m(i, j));
	       s << buf;
	  }
	  s << "\n";
     }

     return s;
}

/* Member functions */

/*
 * Returns the transpose: an (n x m) matrix t with t(j, i) == (*this)(i, j).
 * The transpose of a NAM is a NAM; previously it came back as an ordinary
 * matrix full of garbage.
 */
Matrix Matrix::T() const
{
     if (*this == Matrix::NAM())
	  return Matrix::NAM();

     Matrix t(M_n, M_m);

     for (int i = 1; i <= M_m; i++)
	  for (int j = 1; j <= M_n; j++)
	       t(j, i) = (*this)(i, j);

     return t;
}

void Matrix::Print() const
{
   cout << *this;
}

/* Special matrix constructors */
Matrix *Matrix::MakeNAM()
{
     // Heap-allocates a 1x1 matrix and flags it as a NAM.  Its single
     // element is never initialized.
     Matrix *result = new Matrix(1, 1);
     result->nam = 1;
     return result;
}

/* Matrix mutators */
/*
 * Overwrites a square matrix with the identity: 1.0 on the diagonal and
 * 0.0 elsewhere.  A non-square matrix is flagged as a NAM instead.
 *
 * The old version reused the loop variable `i` after its `for` scope
 * ended (pre-standard scoping rules), which does not compile in standard
 * C++.  A single pass now writes every element.
 */
void Matrix::Identify()
{
     if (M_m != M_n) {
	  nam = 1;
	  return;
     }

     for (int i = 1; i <= M_m; i++)
	  for (int j = 1; j <= M_n; j++)
	       (*this)(i, j) = (i == j) ? 1.0 : 0.0;
}
