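/*=================================================================
 *
 * stackgddfast.c
 *
 * MATLAB MEX gateway: power reflectance, group delay and group-delay
 * dispersion of a two-material layer stack on a substrate, evaluated
 * per wavenumber with a transfer-matrix method.
 *
 *=================================================================*/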


#include <math.h>
#include <complex.h>
#include <string.h>
#include "mex.h"


// Input Arguments
// MATLAB call: [r2, gd, gdd] = stackgddfast(ks, ds, n0, ns, dns, theta, pol)
//   ks     wavenumbers
//   ds     layer thicknesses
//   n0     refractive index of the incidence medium (scalar)
//   ns     indices [n1, n2, nsub], three consecutive entries per wavenumber
//   dns    derivatives of ns with respect to k, same layout as ns
//   theta  angle of incidence in the incidence medium, radians (scalar)
//   pol    polarization string; anything starting with "TM" selects TM,
//          everything else selects TE
#define	KS_IN		prhs[0]
#define	DS_IN		prhs[1]
#define N0_IN		prhs[2]
#define NS_IN		prhs[3]
#define DNS_IN		prhs[4]
#define THETA_IN	prhs[5]
#define POL_IN		prhs[6]

// Output Arguments (one entry per wavenumber)
//   r2   power reflectance |r|^2
//   gd   group delay: d(arg r)/dk divided by C
//   gdd  group-delay dispersion: d^2(arg r)/dk^2 divided by C^2
#define	R2_OUT		plhs[0]
#define GD_OUT	        plhs[1]
#define GDD_OUT        plhs[2]


// Definitions
// NIN / NOUT: expected number of input and (maximum) output arguments.
#define NIN 7
#define NOUT 3
#define C  0.2997924580  // speed of light (um/fs)
// C squared (um^2/fs^2); appears to be unused in this file.
#define C2 0.08987551787


// Core computation routine
//
// Transfer-matrix model of n layers on a substrate. For each vacuum
// wavenumber ks[kx] (the phase accumulated in a layer is ds*ks*neff) it
// writes:
//   r2[kx]   power reflectance |r|^2, with r = -T2/T1 of the total matrix
//   gd[kx]   group delay, d(arg r)/dk / C
//   gdd[kx]  group-delay dispersion, d^2(arg r)/dk^2 / C^2
// C is in um/fs, so with ks in 1/um and ds in um these are presumably fs
// and fs^2 -- NOTE(review): confirm units against the caller.
//
// Layer layout (0-based dx): layer 0 is material n1 entered from the
// incidence medium (index n0, angle theta); odd layers are n2; even layers
// after the first are n1. After the last layer the light enters the
// substrate nsub; the medium before that interface is the incidence medium
// for n == 0, n1 for odd n and n2 for even n > 0.
//
// Parameters:
//   nk       number of wavenumbers (length of ks, r2, gd, gdd)
//   n        number of layers (length of ds)
//   ks, ds   wavenumbers and layer thicknesses
//   n0       incidence-medium index (no dispersion)
//   ns       3*nk indices: [n1, n2, nsub] for ks[0], then ks[1], ...
//   dns      3*nk values used as dn/dk, same layout as ns
//   theta    angle of incidence in radians
//   isTM     nonzero for TM polarization, zero for TE
//   r2, gd, gdd  outputs, nk entries each
//
// Approximations: derivatives only account for the k dependence of each
// layer's propagation phase (interface coefficients are treated as
// constant in k, and d^2n/dk^2 is taken as zero). If |r| == 0 the group
// delay and dispersion divide by zero and come out as Inf/NaN.

// Store 1 + p in *pp and 1 - p in *pm for the interface from medium a into
// medium b. p = neff_a/neff_b for TE; for TM it is further divided by the
// squared normal-incidence ratio (n_a/n_b)^2.
static void interface_coeffs(const double neff_a, const double neff_b,
			     const double n_a, const double n_b,
			     const int isTM, double *pp, double *pm)
{
  double p = neff_a/neff_b;
  if (isTM)
    {
      const double p0 = n_a/n_b;
      p /= p0*p0;
    }
  *pp = 1 + p;
  *pm = 1 - p;
}

static void stackgddfast(
			 const int nk,
			 const int n,
			 const double ks[],
			 const double ds[],
			 const double n0,
			 const double ns[],
			 const double dns[],
			 const double theta,
			 const int isTM,

			 double r2[],
			 double gd[],
			 double gdd[])
{
  // Incidence medium: n0*cos(theta) is its effective index, and
  // n0^2 sin^2(theta) is the invariant used for every layer's neff.
  const double n0costheta = n0*cos(theta);
  const double n0sintheta2 = n0*n0*sin(theta)*sin(theta);

  // Medium just before the substrate: -1 = incidence medium (empty stack),
  // 0 = n1 (odd layer count), 1 = n2 (even, non-zero layer count).
  const int lastmed = (n == 0) ? -1 : ((n & 1) ? 0 : 1);

  const double *nsofk = ns, *dnsofk = dns;  // [n1, n2, nsub] at ks[kx]

  for (int kx = 0; kx < nk; kx++)
    {
      const double k = ks[kx];

      // Effective indices of [n1, n2, nsub] and their k-derivatives.
      double neffs[3], dneffs[3], dterms[3], dterms2[3], kneffs[3];
      for (int j = 0; j < 3; j++)
	{
	  neffs[j] = sqrt(nsofk[j]*nsofk[j] - n0sintheta2);
	  dneffs[j] = nsofk[j]*dnsofk[j]/neffs[j];  // d(neff)/dk
	  dterms[j] = neffs[j] + k*dneffs[j];       // d(k*neff)/dk
	  kneffs[j] = k*neffs[j];
	  dterms2[j] = dterms[j]*dterms[j];
	}

      // Interface coefficients 1 +/- p for
      // [n0 -> n1, n1 -> n2, n2 -> n1, last medium -> nsub].
      double pps[4], pms[4];
      interface_coeffs(n0costheta, neffs[0], n0, nsofk[0], isTM,
		       &pps[0], &pms[0]);
      interface_coeffs(neffs[0], neffs[1], nsofk[0], nsofk[1], isTM,
		       &pps[1], &pms[1]);
      interface_coeffs(neffs[1], neffs[0], nsofk[1], nsofk[0], isTM,
		       &pps[2], &pms[2]);
      if (lastmed < 0)
	interface_coeffs(n0costheta, neffs[2], n0, nsofk[2], isTM,
			 &pps[3], &pms[3]);
      else
	interface_coeffs(neffs[lastmed], neffs[2], nsofk[lastmed], nsofk[2],
			 isTM, &pps[3], &pms[3]);

      // Running total matrix T = [[T1, T2], [conj(T2), conj(T1)]] and its
      // first/second k-derivatives; only the first row is stored.
      double complex T1 = 1, T2 = 0;
      double complex dT1 = 0, dT2 = 0;
      double complex ddT1 = 0, ddT2 = 0;

      for (int dx = 0; dx < n; dx++)
	{
	  // Medium of layer dx (nx) and the interface entering it (px).
	  int nx, px;
	  if (dx & 1) {         // odd layer: n1 -> n2
	    nx = 1;
	    px = 1; }
	  else if (dx == 0) {   // first layer: n0 -> n1
	    nx = 0;
	    px = 0; }
	  else {                // even layer after the first: n2 -> n1
	    nx = 0;
	    px = 2;
	  }

	  // Layer matrix elements TL1/TL2 = exp(-i*phi)/2 * (1 +/- p),
	  // phi = ds*k*neff. D = TL'/TL and D2 = TL''/TL (derivatives in k).
	  const double phi = ds[dx]*kneffs[nx];
	  const double complex ephi = (cos(phi) - I*sin(phi))/2;
	  const double complex D = -I*(ds[dx]*dterms[nx]);
	  const double complex D2 = -ds[dx]*(2*I*dneffs[nx] + ds[dx]*dterms2[nx]);
	  const double complex TL1 = ephi*pps[px];
	  const double complex TL2 = ephi*pms[px];

	  // T <- TL*T
	  const double complex T1prev = T1;
	  T1 = TL1*T1 + TL2*conj(T2);
	  T2 = TL1*T2 + TL2*conj(T1prev);

	  // TL*dT (old dT), shared by both derivative updates.
	  const double complex TLdT1 = TL1*dT1 + TL2*conj(dT2);
	  const double complex TLdT2 = TL1*dT2 + TL2*conj(dT1);

	  // dT <- TL'*T + TL*dT = D*(new T) + TL*dT
	  dT1 = D*T1 + TLdT1;
	  dT2 = D*T2 + TLdT2;

	  // ddT <- TL''*T + 2*TL'*dT + TL*ddT
	  const double complex ddT1prev = ddT1;
	  ddT1 = D2*T1 + 2*D*TLdT1 + TL1*ddT1 + TL2*conj(ddT2);
	  ddT2 = D2*T2 + 2*D*TLdT2 + TL1*ddT2 + TL2*conj(ddT1prev);
	} // end: loop over structure

      // Final interface into the (dispersionless) substrate: no propagation
      // phase, so T, dT and ddT are all multiplied by the same constant
      // matrix.
      const double complex S1 = pps[3]/2, S2 = pms[3]/2;
      const double complex T1in = T1, dT1in = dT1, ddT1in = ddT1;
      T1 = S1*T1 + S2*conj(T2);
      T2 = S1*T2 + S2*conj(T1in);
      dT1 = S1*dT1 + S2*conj(dT2);
      dT2 = S1*dT2 + S2*conj(dT1in);
      ddT1 = S1*ddT1 + S2*conj(ddT2);
      ddT2 = S1*ddT2 + S2*conj(ddT1in);

      // Reflection coefficient and its first two k-derivatives.
      const double complex r = -T2/T1;
      const double absr = cabs(r);
      r2[kx] = absr*absr;
      const double complex dr = (T2*dT1 - dT2*T1)/(T1*T1);
      const double complex ddr = (T2*(-2*dT1*dT1 + T1*ddT1) +
				  T1*(2*dT1*dT2 - T1*ddT2))/(T1*T1*T1);

      // gd  = Im(conj(r)*dr)/|r|^2 / C = d(arg r)/dk / C
      gd[kx] = (cimag(dr)*creal(r) - creal(dr)*cimag(r))/r2[kx]/C;
      // gdd = [Im(conj(r)*ddr) - 2*Re(conj(r)*dr)*d(arg r)/dk] / |r|^2 / C^2
      gdd[kx] = ((cimag(ddr)*creal(r) - creal(ddr)*cimag(r))/C -
		 2*(creal(dr)*creal(r) + cimag(dr)*cimag(r))*gd[kx])/r2[kx]/C;

      // Advance to the next wavenumber's [n1, n2, nsub] triple.
      nsofk += 3;
      dnsofk += 3;
    }
}


// MEX function gateway routine
//
// MATLAB usage:
//   [r2, gd, gdd] = stackgddfast(ks, ds, n0, ns, dns, theta, pol)
//
// Validates the arguments, allocates 1 x numel(ks) outputs and calls the
// core routine. ks and ds may be row or column vectors; ns and dns must hold
// three entries ([n1, n2, nsub]) per wavenumber. pol must be a char array:
// anything starting with "TM" selects TM polarization, everything else TE.
// Errors are raised through mexErrMsgIdAndTxt (which does not return).
void mexFunction(
		 int nlhs, mxArray* plhs[],
		 int nrhs, const mxArray* prhs[]
		 )
{
  // Argument counts: prhs/plhs must never be indexed beyond what MATLAB
  // actually passed.
  if (nrhs != NIN)
    mexErrMsgIdAndTxt("stackgddfast:nrhs",
		      "Expected 7 inputs: ks, ds, n0, ns, dns, theta, pol.");
  if (nlhs > NOUT)
    mexErrMsgIdAndTxt("stackgddfast:nlhs",
		      "At most 3 outputs: r2, gd, gdd.");

  // Numeric inputs are read via mxGetPr/mxGetScalar as real doubles, so
  // reject anything else instead of reinterpreting its memory.
  const mxArray *numeric[] = { KS_IN, DS_IN, N0_IN, NS_IN, DNS_IN, THETA_IN };
  for (size_t a = 0; a < sizeof numeric / sizeof numeric[0]; a++)
    if (!mxIsDouble(numeric[a]) || mxIsComplex(numeric[a]))
      mexErrMsgIdAndTxt("stackgddfast:type",
			"ks, ds, n0, ns, dns and theta must be real double arrays.");
  if (!mxIsChar(POL_IN))
    mexErrMsgIdAndTxt("stackgddfast:pol",
		      "pol must be a char array ('TE' or 'TM').");

  // Sizes. Using the element count accepts both row and column vectors.
  const mwSize nk = mxGetNumberOfElements(KS_IN);
  const mwSize n = mxGetNumberOfElements(DS_IN);
  // The core routine reads 3 entries of ns and dns per wavenumber.
  if (mxGetNumberOfElements(NS_IN) < 3*nk ||
      mxGetNumberOfElements(DNS_IN) < 3*nk)
    mexErrMsgIdAndTxt("stackgddfast:size",
		      "ns and dns need 3 entries ([n1; n2; nsub]) per wavenumber.");

  // Polarization: only the first two characters matter. Passing the real
  // buffer size makes mxGetString truncate (returning 1) instead of writing
  // past the end of polstr, so the return value is deliberately ignored.
  char polstr[8] = {0};
  (void)mxGetString(POL_IN, polstr, sizeof polstr);
  const int isTM = (strncmp("TM", polstr, 2) == 0);

  // Create matrices for return arguments
  mxArray *r2arr = mxCreateDoubleMatrix(1, nk, mxREAL);
  mxArray *gdarr = mxCreateDoubleMatrix(1, nk, mxREAL);
  mxArray *gddarr = mxCreateDoubleMatrix(1, nk, mxREAL);

  // Do the actual computation
  stackgddfast((int)nk, (int)n,
	       mxGetPr(KS_IN), mxGetPr(DS_IN), mxGetScalar(N0_IN),
	       mxGetPr(NS_IN), mxGetPr(DNS_IN), mxGetScalar(THETA_IN),
	       isTM,
	       mxGetPr(r2arr), mxGetPr(gdarr), mxGetPr(gddarr));

  // plhs is only guaranteed to hold max(nlhs, 1) entries; free outputs the
  // caller did not request instead of writing past it.
  R2_OUT = r2arr;
  if (nlhs >= 2)
    GD_OUT = gdarr;
  else
    mxDestroyArray(gdarr);
  if (nlhs >= 3)
    GDD_OUT = gddarr;
  else
    mxDestroyArray(gddarr);
}
