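/*=================================================================
 *
 * stackgdfastmex.c
 *
 * MEX gateway that computes the reflectance and group delay of a
 * multilayer stack (alternating materials n1/n2 on a substrate),
 * using 2x2 transfer matrices.
 *
 *=================================================================*/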


#include <math.h>
#include <complex.h>
#include <string.h>
#include "mex.h"


// Input Arguments (MATLAB call order)
//   ks    - wavenumbers; one output sample is produced per element
//   ds    - layer thicknesses; layer 0 is material n1, then n2, n1, n2, ...
//   n0    - refractive index of the incident medium (read as a scalar)
//   ns    - indices [n1, n2, nsub] for each wavenumber, stored consecutively
//           (3 values per element of ks)
//   dns   - same layout as ns; NOTE(review): presumably dn/dk -- confirm
//   theta - angle of incidence in radians (passed straight to sin/cos)
//   pol   - polarization string; anything starting with "TM" selects TM,
//           everything else is treated as TE
#define	KS_IN		prhs[0]
#define	DS_IN		prhs[1]
#define N0_IN		prhs[2]
#define NS_IN		prhs[3]
#define DNS_IN		prhs[4]
#define THETA_IN	prhs[5]
#define POL_IN		prhs[6]

// Output Arguments (both 1 x numel(ks))
//   r2 - reflectance |r|^2
//   gd - group delay d(arg r)/dk divided by C; in fs if ks is in 1/um and
//        ds in um (C is in um/fs) -- units of the inputs not verified here
#define	R2_OUT		plhs[0]
#define GD_OUT	        plhs[1]


// Definitions
#define NIN 7
#define NOUT 2
#define C 0.2997924580  // speed of light (um/fs)


// Core computation routine
static void stackgdfast(
			const int nk,
			const int n,
			const double ks[],
			const double ds[],
			const double n0,
			const double ns[],
			const double dns[],
			const double theta,
			const int isTM,

			double r2[],
			double gd[])
{
  // Variables
  int dx, kx;
  int j;  // iterator
  double neffs[3], dneffs[3], dterms[3], kneffs[3];
  double pTEs[4], p0s[4], ps, pps[4], pms[4];
  int nx, px;  // material data indices
  complex ephi, D;
  complex T1, T2, old;  // layer T matrix elements
  complex dT1, dT2, TL1, TL2;  // T matrix elements
  complex T1new, T2new, dT1new, dT2new;

  // Common terms
  const double n0sintheta2 = n0*n0*sin(theta)*sin(theta);
  const double *nsofk = ns, *dnsofk = dns;  // offset index pointers

  /*
   * Loop over all wavenumbers.
   */
  for (kx = 0; kx < nk; kx++)
    {
      // Precalculate all material parameters.
      // index variables: [n1, n2, nsub]
      // p variables: [p01, p12, p21, p2sub] */
      // Calculate effective indices and common expressions.
      for (j = 0; j < 3; j++)
	{
	  neffs[j] = sqrt(nsofk[j]*nsofk[j] - n0sintheta2);
	  dneffs[j] = nsofk[j]*dnsofk[j]/neffs[j];
	  dterms[j] = neffs[j] + ks[kx]*dneffs[j];
	  kneffs[j] = ks[kx]*neffs[j];
	}
      // Calculate pTEs (needed regardless of polarization).
      pTEs[0] = n0*cos(theta)/neffs[0];  // neff0/neff1
      pTEs[1] = neffs[0]/neffs[1];
      pTEs[2] = 1/pTEs[1];
      pTEs[3] = neffs[1]/neffs[2];
      // Calculate TM or TE as needed.
      if (isTM)  // TM
	{
	  // Calculate pTM from p0 (normal incidence) and pTE.
	  p0s[0] = n0/nsofk[0];
	  p0s[1] = nsofk[0]/nsofk[1];
	  p0s[2] = 1/p0s[1];
	  p0s[3] = nsofk[1]/nsofk[2];
	  for (j = 0; j < 4; j++)
	    {
	      ps = pTEs[j]/(p0s[j]*p0s[j]);
	      pps[j] = 1 + ps;
	      pms[j] = 1 - ps;
	    }
	}
      else  // TE
	{
	  for (j = 0; j < 4; j++)
	    {
	      pps[j] = 1 + pTEs[j];
	      pms[j] = 1 - pTEs[j];
	    }
	}

      /*
       * Loop through structure.
       */
      T1 = 1; T2 = 0;
      dT1 = 0; dT2 = 0;
      for (dx = 0; dx < n; dx++)
	{
	  //d = ds[dx];  // current layer thickness

	  // Select appropriate material parameters
	  if (dx & 1) {  // even layer, n1->n2 (mod(dx,2) == 1)
	    nx = 1;
	    px = 1; }
	  else if (dx == 0) {  // first layer, n0->n1
	    nx = 0;
	    px = 0; }
	  else {  // random odd layer, n2->n1
	    nx = 0;
	    px = 2;
	  }

	  // Calculate new matrices for structure up to layer dx + 1.
	  ephi = (cos(ds[dx]*kneffs[nx]) - I*sin(ds[dx]*kneffs[nx]))/2;
	  D = -I*ds[dx]*dterms[nx];  // differential operator
	  TL1 = ephi*pps[px];  // current layer element 1
	  TL2 = ephi*pms[px];  // current layer element 2

	  // Tnew = TL * T
	  old = T1;
	  T1 = TL1*T1 + TL2*conj(T2);
	  T2new = TL1*T2 + TL2*conj(old);

	  // dTnew = D*Tnew + TL*dT
	  old = dT1;
	  dT1 = D*T1new + TL1*dT1 + TL2*conj(dT2);
	  dT2 = D*T2new + TL1*dT2 + TL2*conj(old);
	}
      // Handle propagation of last layer into subtrate.
      TL1 = pps[3];
      TL2 = pms[3];
      old = T1;
      T1 = TL1*T1 + TL2*conj(T2);
      T2 = TL1*T2 + TL2*conj(old);
      old = dT1;
      dT1 = TL1*dT1 + TL2*conj(dT2);
      dT2 = TL1*dT2 + TL2*conj(old);

      // Calculate reflectance and group delay from total transfer matrices.
      complex r = -T2/T1;
      complex dr = (T2*dT1 - dT2*T1)/(T1*T1);

      // Output
      r2[kx] = cabs(r)*cabs(r);
      gd[kx] = (cimag(dr)*creal(r) - creal(dr)*cimag(r))/(r2[kx]*C);

      // Update offset pointers by the three index spaces.
      nsofk += 3;
      dnsofk += 3;
    }

  return;
}


// MEX function gateway routine.
//
// MATLAB call: [r2, gd] = stackgdfastmex(ks, ds, n0, ns, dns, theta, pol)
// Validates the arguments, then runs stackgdfast over every element of ks.
// r2 and gd are returned as 1 x numel(ks) double row vectors.
void mexFunction(
		 int nlhs, mxArray* plhs[],
		 int nrhs, const mxArray* prhs[]
		 )
{
  // Validate the call before reading prhs or writing plhs.
  if (nrhs < NIN)
    {
      mexErrMsgIdAndTxt("stackgdfastmex:nrhs",
                        "%d input arguments required.", NIN);
    }
  if (nlhs > NOUT)
    {
      mexErrMsgIdAndTxt("stackgdfastmex:nlhs",
                        "Too many output arguments (at most %d).", NOUT);
    }

  // Arrays read through mxGetPr must hold real doubles.
  const mxArray *const arrays[] = { KS_IN, DS_IN, NS_IN, DNS_IN };
  for (size_t i = 0; i < sizeof arrays / sizeof arrays[0]; i++)
    {
      if (!mxIsDouble(arrays[i]) || mxIsComplex(arrays[i]))
        {
          mexErrMsgIdAndTxt("stackgdfastmex:type",
                            "ks, ds, ns and dns must be real double arrays.");
        }
    }
  if (!mxIsChar(POL_IN))
    {
      mexErrMsgIdAndTxt("stackgdfastmex:pol",
                        "pol must be a character array such as 'TE' or 'TM'.");
    }

  // Get dimensions.
  const int nk = (int) mxGetNumberOfElements(KS_IN);
  const int n = (int) mxGetNumberOfElements(DS_IN);

  // The core routine reads three values (n1, n2, nsub) per wavenumber
  // from both ns and dns.
  if (mxGetNumberOfElements(NS_IN) < 3 * (size_t) nk
      || mxGetNumberOfElements(DNS_IN) < 3 * (size_t) nk)
    {
      mexErrMsgIdAndTxt("stackgdfastmex:size",
                        "ns and dns must each have 3*numel(ks) elements.");
    }

  // Only the first two characters of pol are inspected. The buffer is
  // bounded by its own size (not by the string length) so long strings
  // cannot overflow it; the last byte always stays NUL.
  char polstr[8] = {0};
  mxGetString(POL_IN, polstr, sizeof polstr - 1);
  const int isTM = (strncmp("TM", polstr, 2) == 0);

  // Create matrices for return arguments. gd is always computed, but it is
  // only handed back when the caller asked for a second output.
  R2_OUT = mxCreateDoubleMatrix(1, nk, mxREAL);
  mxArray *gd_array = mxCreateDoubleMatrix(1, nk, mxREAL);

  // Do the actual computation
  stackgdfast(nk, n,
              mxGetPr(KS_IN), mxGetPr(DS_IN), mxGetScalar(N0_IN),
              mxGetPr(NS_IN), mxGetPr(DNS_IN), mxGetScalar(THETA_IN),
              isTM,
              mxGetPr(R2_OUT), mxGetPr(gd_array));

  if (nlhs >= 2)
    {
      GD_OUT = gd_array;
    }
  else
    {
      mxDestroyArray(gd_array);
    }
}
