/* Declare structures and library functions that are used */
#include <math.h>	// sqrt()
#include <stdio.h>	// printf()
#define	min(x,y)	(((x) < (y)) ? (x) : (y))
// Local declaration of the project's dense matrix class, limited to the
// members used in this file; the definitions are not part of this file.
class GenMatrix {
public:
	// Default constructor.
	GenMatrix();
	// em x en matrix over caller-supplied storage 'data'.  This file relies
	// on writes through such a matrix (e.g. operator-=) landing in 'data',
	// so it presumably wraps rather than copies -- TODO confirm.
	GenMatrix(double *data, int em, int en);
	// em x en matrix.
	GenMatrix(int em, int en);
	// In-place addition / subtraction of another matrix.
	GenMatrix &operator+=(GenMatrix const &M);
	GenMatrix &operator-=(GenMatrix const &M);
	// Print the matrix, labelled by 'message'.
	void Print(const char *message);
};
// Print 'msg' and exit.
void QCrash(const char *msg);
// Print 'msg' as a warning only; execution continues.
void QWarn(const char *msg);
// Real Vector ("VR") routines.  In the comments below, v[n] stands for
// "the whole n-element vector v", not the C/C++ single-element syntax:
// to[n] = from[n]
void VRcopy(double *to, double *from, int n);
// *result = dot(a[n], b[n])
void VRdot(double *result, double *a, double *b, int n);
// v[n] *= s
void VRscale(double *v, int n, double s);
// to[n] += from[n]
void VRadd2(double *to, double *from, int n);
// to[n] -= from[n]
void VRsub2(double *to, double *from, int n);
// v[n] = s
void VRload(double *v, int n, double s);
// y[n] += alpha * x[n]
void VRdxy(double alpha, double *x, double *y, int n);
// A = CofA * A + CofBC * op(B) * op(C), where op() transposes its argument
// when the corresponding TransB / TransC flag is set.
void MatMult(GenMatrix A, GenMatrix B, GenMatrix C,
	bool TransB = false, bool TransC = false,
	double CofA = 0.0, double CofBC = 1.0);
// Allocate storage for n doubles.
double *QAllocDouble(int n);
// Allocate storage for n doubles, initialized to zero.
double *QAllocDoubleWithInit(int n);
// Release an allocation obtained from the allocators above.
void QFree(void *ptr);
// Fortran-flavoured type names used when calling into LAPACK.
typedef char CHAR1;
typedef int INTEGER;
// Declaration of the LAPACK routine of the same name.
// LAPACK itself takes its character arguments by pointer; CHAR1 would need
// to be a class with a constructor from plain 'char' for the by-value calls
// made in this file to be correct.  As written, the file compiles against
// this prototype but would not run properly even if it could be linked.
// NOTE(review): a working binding would also have to follow the platform's
// Fortran symbol-naming and hidden-string-length conventions.
/*
 *  DTRTRS solves a triangular system of the form
 *
 *     A * X = B  or  A**T * X = B,
 */
void dtrtrs(CHAR1 uplo, CHAR1 trans, CHAR1 diag, INTEGER *n, INTEGER *nrhs,
	double *a, INTEGER *lda, double *b, INTEGER *ldb, INTEGER *info);
/* End of declarations of things for which headers are not public */

/*
 * This file implements the GMRES iterative linear solver algorithm.
 * The premise is to use a matrix-vector product routine that computes
 * A.x to solve the linear system A.x == b for a given 'b', without ever
 * computing or diagonalizing A explicitly.
 * Of course, this does not usually give an algorithmic speedup if one
 * is interested in the exact solution vector 'x', but if an approximate
 * one is acceptable (as is always the case for floating-point arithmetic),
 * then this method can give a great speedup, since with an appropriate
 * preconditioner, a solution of a given accuracy can usually be found in
 * a constant number of iterations.
 * A preconditioner is needed to help make 'A' as close to diagonal as
 * possible, as the algorithm tends to converge more quickly when that is
 * the case.  A full discussion is beyond the scope of this comment, but
 * any numerical linear algebra text should give some insight into the
 * nature of this sort of problem.
 *
 * Given a preconditioner, then, we are actually solving
 * A_0^{-1}.A.x == A_0^{-1}.b
 * where A_0^{-1}.x may similarly be evaluated quickly, as we assumed
 * for the product A.x .
 * For the remainder of this discussion, I assume that the preconditioner
 * has already been wrapped into 'A' and 'b', as it makes things shorter
 * to write out.
 * The basic algorithm involves assembling a Krylov space corresponding
 * to the vectors Q_n = {b, A.b, A^2.b, A^3.b, ... A^n.b}, and finding the best
 * candidate solution vector x in Q.  Of course, just computing those
 * vectors as written will quickly become numerically unstable, since
 * the largest eigenvalue of 'A' will become dominant.  Instead, we construct
 * a set of vectors {q_0, q_1, q_2, ..., q_n} which span Q_n but are
 * orthonormal.  This is effected via a Gram-Schmidt orthogonalization at
 * each step of the algorithm.  The initial q_0 is seeded as b/|b|, since
 * this ends up making some of the later algebra much simpler.
 * At each iteration, we compute V = A.q_n, and subtract off the
 * overlap with q_0 through q_n to yield a candidate q_{n+1} which is then
 * normalized. (If this candidate vector has norm almost zero, we have
 * found a closed (invariant) subspace of the Krylov space and must restart
 * the algorithm with a different seed vector 'b'.  This is quite uncommon
 * in practice, and is not implemented here.)
 * We also keep track of a matrix H which has upper Hessenberg structure
 * (every entry below its first subdiagonal is zero).  This matrix represents
 * essentially the projection of 'A' onto the "basis" {q_n} (which obviously
 * cannot span the full vector space in the general case).  Conveniently,
 * in this "basis", our initial vector 'b' is just |b|*e_0, proportional
 * to the elementary vector with the first component 1 and all other
 * components zero.  This means that we can solve "A.x == b" in our
 * incomplete "basis" of q_n rather efficiently (but only approximately
 * in the least-squares sense), since this problem is then
 * H.y == |b|*e_0
 * and H has a nice structure to it.  In particular, if we claim that
 * we have a Q-R decomposition of H (Q is orthogonal and R is
 * upper triangular), we end up just needing to solve the upper
 * triangular system R.y == Q^{-1}.(|b|*e_0) .
 * Since we are starting from a Hessenberg matrix H, we can perform the
 * decomposition by doing a sequence of rotations to eliminate the
 * lower diagonal elements -- since we only need Q^{-1}.(|b|*e_0), we
 * need not explicitly store the matrix Q and can just update our
 * initial vector with these pairwise rotations in-place to obtain
 * the final vector "jQb" == Q^{-1}.(|b|*e_0) .
 * This transformation is factored into the "hessenberg_qr_with_b"
 * routine; it is not particularly transferable due to the very
 * precise structure of the problem it solves; as such, it is given
 * only file-local linkage.
 * The "hessenberg_qr_with_b" routine returns the vector jQb to the
 * least-squares routine, which calls into LAPACK to actually solve the
 * upper-triangular system for the solution vector 'y'.
 * The best solution vector 'x' for our original problem uses the coefficients
 * of 'y' in our "basis" of q_i.
 * We then check whether this approximate solution vector is good enough
 * for the given tolerance (using a callback function to see whether
 * we should check the unpreconditioned residual, which is more accurate
 * but also computationally expensive), and return if an acceptable solution
 * is found.
 */

// File-local helpers, defined after gmres() below.
static void gmres_hessenberg_qr_with_b(double *hess, double *rhs,
	    double scale, int ncols);
static void gmres_least_squares(double *coef, double *hess,
	    double scale, int ld, int ncols);

// Solve A.x == b for x, approximately.
// A is an em x en matrix, which is not explicitly passed in -- only the
// matrix-vector product A.x is ever used.
// GMRES is only defined for em == en, but we use separate parameters for
// the two dimensions for compatibility with other iterative solvers,
// in that we can allow reuse of matrix-vector product routines between
// different iterative solvers.
// jX	    - output, the solution vector X
// jB0	    - input, the RHS vector 'b'
// em, en   - input, the dimension of the matrix 'A' that defines the system
// epsilon  - input, the tolerance for an acceptable residual ||A.x-b||
// ax	    - input, a routine that performs the matrix-vector product
//	      Called as ax(Y, X, em, en, args) to perform Y = A.X.
//	      The void* cookie 'args contains extra arguments which may
//	      be needed by any given matrix-vector routine and may be NULL.
// axa	    - input, the args cookie passed in to ax. ("ax args")
// lp	    - input, a routine to perform left preconditioning on the system
//	      May be NULL.  lp(Y, em, args) updates Y to Y' = A_0^{-1}.Y
//	      in-place.  As for ax, there is a void* args cookie passed along.
// lpa	    - input, "lp args", the argument cookie for lp.
// cc	    - input, "convergence callback" to allow the caller to give us
//	      feedback for when to check the residual of the unpreconditioned
//	      system in addition to the residual of the preconditioned system.
//	      May be NULL.  cc(nn, tmp, lastun, lastres, X, en, eps, args)
//	      is called at iteration nn of the GMRES algorithm with current
//	      (preconditioned) residual tmp.  lastun and lastres are the
//	      iteration on which we last checked the residual of the
//	      unpreconditioned system (or -1 if we have never done so),
//	      and the residual that we found at that time.  X is the
//	      current trial solution vector, en the dimension of the system,
//	      eps the convergence criterion, and args a void* cookie.
//	      cc should return true if the extra check should be performed
//	      (and false otherwise).
//	      If cc is NULL, the default behavior is to only check the
//	      unpreconditioned residual when the preconditioned residual
//	      is smaller than the GMRES tolerance epsilon.
//	    - input, "cc args", the argument cookie for cc.
#define MAX_GMRES_SPACE_SIZE	100
extern "C" void
gmres(double *jX, double *jB0, int em, int en, double epsilon,
		void (*ax)(double *, double *, int, int, void *), void *axa,
		void (*lp)(double *, int, void *), void *lpa,
		bool (*cc)(int, double, int, double, double const * const,
			   int, double, void *), void *cca)
{
    int length = MAX_GMRES_SPACE_SIZE;		    // the size of our arrays
    bool need_restart = false;
    double *jQ = QAllocDoubleWithInit(length * en); // the vectors q_i
    double *jH = QAllocDoubleWithInit(length * en); // the Hessenberg matrix H
    double *jV = QAllocDouble(en);		    // scratch vector
    double *jY = QAllocDouble(en);		    // least-squares solution
    double *jB = QAllocDoubleWithInit(en);	    // our local copy of 'b'
    double *jXf = QAllocDoubleWithInit(en);	    // interim solution x_0
    int lastun = -1;		// the last time we checked without A_0^{-1}
    int counter, nn;		// progress overall and through current block
    double lastres = 0.0;	// residual from last time without A_0^{-1}
    double tmp;			// scratch space
    double b_mag, v_mag;	// |b|, |v|
    bool norm_is_final = false;	// did we converge the unpreconditioned system?
    GenMatrix Xt(jX, 1, en);	// more convenient to have X^T for printing
    GenMatrix Bt(jB, 1, en);	// likewise for B^T
    GenMatrix V(jV, en, 1);

    if (ax == NULL)
	QCrash("no matrix-vector product routine supplied to GMRES");

    // Leave original B untouched, for determination of unpreconditioned error.
    VRcopy(jB, jB0, em);
    // precondition B
    if (lp != NULL)
	(*lp)(jB, em, lpa);

    // Initialize first basis vector and normalize.
    VRcopy(jQ, jB, en);
    VRdot(&b_mag, jQ, jQ, en);
    b_mag = sqrt(b_mag);
    VRscale(jQ, en, 1.0 / b_mag);

    // GMRES loop
    // We can only do so many iterations as the rank of A, but our
    // temporary storage for (e.g.) Q may not be that large.  In that case,
    // reset the Krylov space by solving A.(x_0+x) == b - A.x_0 where
    // x_0 is our current best solution vector.
    for(nn = 1, counter = 1; counter <= min(em, en); ++nn, ++counter) {
	if (nn == length) {
	    // need to reset the GMRES kryolv space
	    printf("Resetting GMRES Kryolv space\n");
	    VRadd2(jXf, jX, en);
	    (*ax)(jV, jX, em, en, axa);
	    if (lp != NULL)
		(*lp)(jV, em, lpa);
	    VRsub2(jB, jV, en);
	    VRcopy(jQ, jB, en);
	    VRdot(&b_mag, jQ, jQ, en);
	    b_mag = sqrt(b_mag);
	    VRscale(jQ, en, 1.0 / b_mag);
	    nn = 1;
	}

	// The size of these matrices changes at each iteration.
	// We must declare them after the previous block because nn may
	// have changed in that block.
	GenMatrix Y(jY, nn, 1);
	GenMatrix Yt(jY, 1, nn);
	GenMatrix Qn(jQ, en, nn);
	GenMatrix Qnn(jQ, en, nn + 1);
	GenMatrix H(jH, en, nn);

	// Start with an Arnoldi iteration.
	(*ax)(jV, jQ + (nn - 1) * en, em, en, axa);
	if (lp != NULL)
	    (*lp)(jV, em, lpa);
	for(int jj = 1; jj <= nn; ++jj) {
	    VRdot(jH + en * (nn - 1) + (jj - 1), jQ + (jj - 1) * en, jV, en);
	    // VRdxy(a,X,Y,N): Y = a*X + Y, for a scalar 'a'
	    // so this is v = v - h_{jn} q_j
	    VRdxy(-*(jH + en * (nn - 1) + (jj - 1)), jQ + (jj - 1) * en, jV, en);
	}
	VRdot(&v_mag, jV, jV, en);
	v_mag = sqrt(v_mag);
	printf("GMRES residual step %i: %.12f\n", counter, v_mag);
	if (fabs(v_mag) < 1e-8)
	    need_restart = true;
	// h_{n+1,n} = |v|
	*(jH + en * (nn - 1) + nn) = v_mag;
	// q_{n+1} += v/h_{n+1,n}
	VRcopy(jQ + en * nn, jV, en);
	if (!need_restart)
	    VRscale(jQ + en * nn, en, 1.0 / *(jH + en * (nn - 1) + nn));
	// Done with Arnoldi iteration -- the above is q_{n+1}
	// Now, find y to minimize |H_n.y - |b|*e_1| == |r_n|
	// Could be algorithmically faster than this QR factorization
	// by using more storage to hold the Givens rotations from each step.
	// XXX dgels for least squares?
	gmres_least_squares(jY, jH, b_mag, en, nn);
	Yt.Print("current Y");
	MatMult(Xt, Y, Qn, 1, 1);
	printf("GMRES iteration %i\n", counter);
	(*ax)(jV, jX, em, en, axa);
	if (lp != NULL)
	    (*lp)(jV, em, lpa);
	GenMatrix B(jB, en, 1);
	V -= B;
	VRdot(&tmp, jV, jV, en);
	tmp = sqrt(tmp);
	printf("residual norm %.12f\n", tmp);
	// If there's no preconditioner and we're done; we're done.
	if ((lp == NULL) && (tmp < epsilon)) {
	    norm_is_final = true;
	    break;
	}
	// Now, check with the caller whether we want to check the residual
	// of the unpreconditioned system.
	// We need to have both a (left) preconditioner and a convergence check
	// routine available for using the latter to make any sense at all.
	// If we don't have both of them, just use a standard check of
	// whether the preconditioned norm would be good enough.
	if ( (lp != NULL) && (((cc == NULL) && (tmp < epsilon))
	     || ((cc != NULL) &&
	     (*cc)(nn, tmp, lastun, lastres, jX, en, epsilon, cca)))) {
	    // Check the unpreconditioned residual.
	    (*ax)(jV, jX, em, en, axa);
	    GenMatrix Vt(jV, 1, em);
	    GenMatrix B0t(jB0, 1, em);
	    Vt -= B0t;
	    VRdot(&tmp, jV, jV, en);
	    tmp = sqrt(tmp);
	    printf("residual norm of unpreconditioned system %.16f\n", tmp);
	    lastun = nn;
	    lastres = tmp;
	    if (tmp < epsilon) {
		norm_is_final = true;
		break;
	    }
	}
	Xt.Print("X");
	if (need_restart)
	    break;
    }
    if (!norm_is_final) {
	// Only recompute the final norm if it's going to be different.
	(*ax)(jV, jX, em, en, axa);
	VRload(jB, en, 0.0);
	VRcopy(jB, jB0, em);
	GenMatrix Vt(jV, 1, en);
	Bt -= Vt;
	VRdot(&tmp, jB, jB, en);
	tmp = sqrt(tmp);
    }
    printf("GMRES final residual norm of unpreconditioned system %.16f\n",tmp);

    // Restore any interim solution stored when we overran fixed storage
    VRadd2(jX, jXf, en);

    // Check for a closed subspace.
    if ((tmp >= epsilon) && need_restart)
	QWarn("implement GMRES restart");

    QFree(jQ);
    QFree(jB);
    QFree(jH);
    QFree(jV);
    QFree(jY);
    QFree(jXf);
}

// Householder QR factorization of a small upper-Hessenberg matrix, applied
// at the same time to the right-hand side b*e_0.
// (pp 73-74 of Trefethen & Bau's _Numerical Linear Algebra_.)
//
// jA  - in/out.  On entry, an (nn + 1) x nn upper-Hessenberg matrix stored
//       column-major with leading dimension nn + 1, i.e. element (i, j) is
//       jA[j * (nn + 1) + i].  On exit, its leading nn x nn upper triangle
//       holds R and the subdiagonal entries are (numerically) zero.
// jQb - output, length nn + 1: Q^T.(b*e_0), where Q is the product of the
//       reflections applied.  Q itself is never formed.
// b   - the scale of the right-hand side (|b| in the GMRES derivation).
// nn  - the number of columns of jA.
//
// Because jA is Hessenberg, each reflection only touches rows kk and kk + 1,
// so the Householder vector is a pair and no scratch storage is needed.
static void
gmres_hessenberg_qr_with_b(double *jA, double *jQb, double b, int nn)
{
    const int mm = nn + 1;	// leading dimension (row count) of jA

    for (int ii = 0; ii < mm; ++ii)
	jQb[ii] = 0.0;
    jQb[0] = b;

    // Eliminate one subdiagonal element of H per step.
    for (int kk = 0; kk < nn; ++kk) {
	double x0 = jA[kk * mm + kk];
	double x1 = jA[kk * mm + kk + 1];
	double norm = sqrt(x0 * x0 + x1 * x1);
	// Nothing to eliminate in an all-zero pair; reflecting it would
	// normalize a zero vector and fill the result with NaNs.
	if (norm == 0.0)
	    continue;
	// v = sign(x0)*|x|*e_1 + x, normalized; |v| >= |x| > 0 here.
	double v0 = (x0 > 0 ? norm : -norm) + x0;
	double v1 = x1;
	double inv = 1.0 / sqrt(v0 * v0 + v1 * v1);
	v0 *= inv;
	v1 *= inv;

	// A_{k:k+1,k:n} -= 2 * v.(v^T.A_{k:k+1,k:n}), one column at a time.
	for (int jj = kk; jj < nn; ++jj) {
	    double *col = jA + jj * mm + kk;
	    double dot = v0 * col[0] + v1 * col[1];
	    col[0] -= 2 * v0 * dot;
	    col[1] -= 2 * v1 * dot;
	}

	// b_{k:k+1} -= 2 * v.(v^T.b_{k:k+1})
	double dot = v0 * jQb[kk] + v1 * jQb[kk + 1];
	jQb[kk] -= 2 * v0 * dot;
	jQb[kk + 1] -= 2 * v1 * dot;
    }
}

// Find y to minimize |H.y - |b|e_q|
// H is en x n with Hessenberg structure, so really (n+1) x n
static void
gmres_least_squares(double *jY, double *jH, double b_mag, int en, int nn)
{
    double *jC = QAllocDoubleWithInit(nn + 1);
    double *jR = QAllocDoubleWithInit(nn * (nn + 1));
    GenMatrix C(jC, nn + 1, 1);
    CHAR1 uu = 'U';
    CHAR1 c_en = 'N';
    INTEGER I_nn = nn;
    INTEGER I_em = nn + 1;
    INTEGER one = 1;
    INTEGER info = 0;

    // H = Q.R for orthogonal Q and upper-triangular R.
    // Make a copy of H as qr_withb will destroy its input to form R
    for(int ii = 0; ii < nn; ii++) {
	VRcopy(jR + ii * (nn + 1), jH + ii * en, nn + 1);
    }
    GenMatrix R(jR, nn + 1, nn);
    // Gets C = Q^T.B and generates R (Q itself is not needed).
    gmres_hessenberg_qr_with_b(jR, jC, b_mag, nn);
    // Solve R.x == C for x
    dtrtrs(uu, c_en, c_en, &I_nn, &one, jR, &I_em, jC, &I_nn, &info);
    if (info != 0)
	printf("in dummy_least_squares dtrtrs info %i\n", info);

    // Store the answer.
    VRcopy(jY, jC, nn);

    // Cleanup.
    QFree(jC);
    QFree(jR);
}
