/*$Id: ex24.c,v 1.25 2001/08/07 03:04:16 balay Exp $*/

/* Help text for this example; it is handed to PetscInitialize() in main(). */
static char help[] = "Solves PDE optimization problem of ex22.c with AD for adjoint.\n\n";

#include "petscda.h"
#include "petscpf.h"
#include "petscmg.h"
#include "petscsnes.h"

/*

              Minimize F(w,u) such that G(w,u) = 0

         L(w,u,lambda) = F(w,u) + lambda^T G(w,u)

       w - design variables (what we change to get an optimal solution)
       u - state variables (i.e. the PDE solution)
       lambda - the Lagrange multipliers

            U = (w u lambda)

       fu, fw, flambda contain the gradient of L(w,u,lambda)

            FU = (fw fu flambda)

       In this example the PDE is 
                             u_xx - u^2 = 2,   0 < x < 1,
                            u(0) = w(0), where w(0) is the free (design) parameter
                            u(1) = 0
       the function we wish to minimize is 
                            \int_0^1 u^2 dx

       The exact solution for u is given by u(x) = x*x - 1.25*x + .25

       Use the usual centered finite differences.

       Note we treat the problem as non-linear though it happens to be linear

       The lambda and u are NOT interlaced.

          We optionally (with the -bdp option) provide a preconditioner on each level
          from the operator below, where J is the Jacobian of the PDE and J' its transpose

              (1   0   0)
              (0   J   0)
              (0   0   J')

  
*/


/*
   Forward declarations of the residual routines defined further down in this file:
     FormFunction         - full residual (gradient of the Lagrangian) handed to DMMGSetSNES()
     PDEFormFunctionLocal - local PDE residual registered with DASetLocalFunction()
*/
extern int FormFunction(SNES,Vec,Vec,void*);
extern int PDEFormFunctionLocal(DALocalInfo*,PetscScalar*,PetscScalar*,PassiveScalar*);

/*
   Per-level application context.  main() allocates one of these for every DMMG
   level and stores it in dmmg[i]->user.
*/
typedef struct {
  Mat        J;           /* Jacobian of PDE system; created with DAGetMatrix() and filled in FormFunction() */
  SLES       sles;        /* Solver for that Jacobian; only created when -bdp is given, used by myPCApply() */
} AppCtx;

/*
   myPCApply - Shell preconditioner application, y = B x, for the block operator

                   (1   0   0 )
                   (0   J   0 )
                   (0   0   J')

   Input Parameters:
     dmmg - multigrid level; supplies the VecPack (dmmg->dm) and the AppCtx (dmmg->user)
     x    - packed input vector (w, u, lambda)

   Output Parameter:
     y    - packed result vector

   The w block is copied, the u block is obtained by a solve with the level's SLES
   and the lambda block by a transpose solve with that same SLES.
*/
#undef __FUNCT__
#define __FUNCT__ "myPCApply"
int myPCApply(DMMG dmmg,Vec x,Vec y)
{
  AppCtx       *ctx = (AppCtx*)dmmg->user;
  VecPack      pack = (VecPack)dmmg->dm;
  PetscScalar  *in_w,*out_w;
  Vec          in_u,in_lambda;
  Vec          out_u,out_lambda;
  int          ierr;

  PetscFunctionBegin;
  /* split both packed vectors into their (w, u, lambda) pieces */
  ierr = VecPackGetAccess(pack,x,&in_w,&in_u,&in_lambda);CHKERRQ(ierr);
  ierr = VecPackGetAccess(pack,y,&out_w,&out_u,&out_lambda);CHKERRQ(ierr);

  /* identity block: only touched when both array pointers are non-null */
  if (in_w && out_w) {
    out_w[0] = in_w[0];
  }

  /* J block, then J' block */
  ierr = SLESSolve(ctx->sles,in_u,out_u,PETSC_IGNORE);CHKERRQ(ierr);
  ierr = SLESSolveTranspose(ctx->sles,in_lambda,out_lambda,PETSC_IGNORE);CHKERRQ(ierr);

  ierr = VecPackRestoreAccess(pack,x,&in_w,&in_u,&in_lambda);CHKERRQ(ierr);
  ierr = VecPackRestoreAccess(pack,y,&out_w,&out_u,&out_lambda);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
   myPCView - View routine for the shell preconditioner; displays the SLES stored
   in the level's AppCtx (dmmg->user) on the given viewer.
*/
#undef __FUNCT__
#define __FUNCT__ "myPCView"
int myPCView(DMMG dmmg,PetscViewer v)
{
  AppCtx *ctx = (AppCtx*)dmmg->user;
  int    ierr;

  PetscFunctionBegin;
  ierr = SLESView(ctx->sles,v);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNCT__
#define __FUNCT__ "main"
int main(int argc,char **argv)
{
  int        ierr,nlevels,i,j;
  DA         da;
  DMMG       *dmmg;
  VecPack    packer;
  AppCtx     *appctx;
  ISColoring iscoloring;
  PetscTruth bdp;

  PetscInitialize(&argc,&argv,PETSC_NULL,help);

  /* Hardwire several options; can be changed at command line */
  ierr = PetscOptionsSetValue("-dmmg_grid_sequence",PETSC_NULL);CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-ksp_type","fgmres");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-ksp_max_it","5");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-pc_mg_type","full");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-mg_coarse_ksp_type","gmres");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-mg_levels_ksp_type","gmres");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-mg_coarse_ksp_max_it","6");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-mg_levels_ksp_max_it","3");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-snes_mf_type","wp");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-snes_mf_compute_norma","no");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-snes_mf_compute_normu","no");CHKERRQ(ierr);
  ierr = PetscOptionsSetValue("-snes_ls","basic");CHKERRQ(ierr); 
  ierr = PetscOptionsSetValue("-dmmg_jacobian_mf_fd",0);CHKERRQ(ierr); 
  /* ierr = PetscOptionsSetValue("-snes_ls","basicnonorms");CHKERRQ(ierr); */
  ierr = PetscOptionsInsert(&argc,&argv,PETSC_NULL);CHKERRQ(ierr);   

  /* create VecPack object to manage composite vector */
  ierr = VecPackCreate(PETSC_COMM_WORLD,&packer);CHKERRQ(ierr);
  ierr = VecPackAddArray(packer,1);CHKERRQ(ierr);
  ierr = DACreate1d(PETSC_COMM_WORLD,DA_NONPERIODIC,-5,1,1,PETSC_NULL,&da);CHKERRQ(ierr);
  ierr = VecPackAddDA(packer,da);CHKERRQ(ierr);
  ierr = VecPackAddDA(packer,da);CHKERRQ(ierr);
  ierr = DADestroy(da);CHKERRQ(ierr);

  /* create nonlinear multi-level solver */
  ierr = DMMGCreate(PETSC_COMM_WORLD,2,PETSC_NULL,&dmmg);CHKERRQ(ierr);
  ierr = DMMGSetDM(dmmg,(DM)packer);CHKERRQ(ierr);
  ierr = VecPackDestroy(packer);CHKERRQ(ierr);

  /* Create Jacobian of PDE function for each level */
  nlevels = DMMGGetLevels(dmmg);
  for (i=0; i<nlevels; i++) {
    packer = (VecPack)dmmg[i]->dm;
    ierr   = VecPackGetEntries(packer,PETSC_NULL,&da,PETSC_NULL);CHKERRQ(ierr);
    ierr   = PetscNew(AppCtx,&appctx);CHKERRQ(ierr);
    ierr   = DAGetColoring(da,IS_COLORING_GHOSTED,&iscoloring);CHKERRQ(ierr);
    ierr   = DAGetMatrix(da,MATMPIAIJ,&appctx->J);CHKERRQ(ierr);
    ierr   = MatSetColoring(appctx->J,iscoloring);CHKERRQ(ierr);
    ierr   = ISColoringDestroy(iscoloring);CHKERRQ(ierr);
    ierr   = DASetLocalFunction(da,(DALocalFunction1)PDEFormFunctionLocal);CHKERRQ(ierr);
    ierr   = DASetLocalAdicFunction(da,ad_PDEFormFunctionLocal);CHKERRQ(ierr);
    dmmg[i]->user = (void*)appctx;
  }

  ierr = DMMGSetSNES(dmmg,FormFunction,PETSC_NULL);CHKERRQ(ierr);

  ierr = PetscOptionsHasName(PETSC_NULL,"-bdp",&bdp);CHKERRQ(ierr);
  if (bdp) {
    for (i=0; i<nlevels; i++) {
      SLES sles;
      PC   pc,mpc;

      appctx = (AppCtx*) dmmg[i]->user;
      ierr   = SLESCreate(PETSC_COMM_WORLD,&appctx->sles);CHKERRQ(ierr);
      ierr   = SLESSetOptionsPrefix(appctx->sles,"bdp_");CHKERRQ(ierr);
      ierr   = SLESSetFromOptions(appctx->sles);CHKERRQ(ierr);

      ierr = SNESGetSLES(dmmg[i]->snes,&sles);CHKERRQ(ierr);
      ierr = SLESGetPC(sles,&pc);CHKERRQ(ierr);
      for (j=0; j<=i; j++) {
	ierr = MGGetSmoother(pc,j,&sles);CHKERRQ(ierr);
	ierr = SLESGetPC(sles,&mpc);CHKERRQ(ierr);
	ierr = PCSetType(mpc,PCSHELL);CHKERRQ(ierr);
	ierr = PCShellSetApply(mpc,(int (*)(void*,Vec,Vec))myPCApply,dmmg[j]);CHKERRQ(ierr);
	ierr = PCShellSetView(mpc,(int (*)(void*,PetscViewer))myPCView);CHKERRQ(ierr);
      }
    }
  }

  ierr = DMMGSolve(dmmg);CHKERRQ(ierr);

  /* ierr = VecView(DMMGGetx(dmmg),PETSC_VIEWER_SOCKET_WORLD);CHKERRQ(ierr); */
  for (i=0; i<nlevels; i++) {
    appctx = (AppCtx*)dmmg[i]->user;
    ierr   = MatDestroy(appctx->J);CHKERRQ(ierr);
    if (appctx->sles) {ierr = SLESDestroy(appctx->sles);CHKERRQ(ierr);}
    ierr   = PetscFree(appctx);CHKERRQ(ierr);  
  }
  ierr = DMMGDestroy(dmmg);CHKERRQ(ierr);

  ierr = PetscFinalize();CHKERRQ(ierr);
  return 0;
}
 
/*
     Enforces the PDE on the grid
     This local function acts on the ghosted version of U (accessed via DAGetLocalVector())
     BUT the global, nonghosted version of FU

     Input Parameters:
       info - local grid description; xs and xm give the range of points looped over,
              mx is used as the total number of grid points (d = mx-1, h = 1/d)
       u    - state values, indexed by grid point (neighbours i-1 and i+1 are read)
       w    - design variable; only w[0] is read

     Output Parameter:
       fu   - PDE residual at the points xs .. xs+xm-1

     Rows written:
       i == 0     : 2*d*(u[0] - w[0]) + h*u[0]^2          (ties u at the left end to w[0])
       i == mx-1  : 2*d*u[i] + h*u[i]^2                    (right end)
       otherwise  : -(d*(u[i+1] - 2*u[i] + u[i-1]) - 2*h) + h*u[i]^2   (centered difference)

     Note that the flop count logged is based on mx, not on the number of points
     looped over (xm).

     The source of this routine is transformed by ADIC (see the directive below), so its
     statements are deliberately left untouched.

     Process adiC: PDEFormFunctionLocal
*/
#undef __FUNCT__
#define __FUNCT__ "PDEFormFunctionLocal"
int PDEFormFunctionLocal(DALocalInfo *info,PetscScalar *u,PetscScalar *fu,PassiveScalar *w)
{
  int          xs = info->xs,xm = info->xm,i,mx = info->mx;
  PetscScalar  d,h;

  d    = mx-1.0;
  h    = 1.0/d;

  for (i=xs; i<xs+xm; i++) {
    if      (i == 0)    fu[i]   = 2.0*d*(u[i] - w[0]) + h*u[i]*u[i];
    else if (i == mx-1) fu[i]   = 2.0*d*u[i] + h*u[i]*u[i];
    else                fu[i]   = -(d*(u[i+1] - 2.0*u[i] + u[i-1]) - 2.0*h) + h*u[i]*u[i];
  } 

  PetscLogFlops(9*mx);
  return 0;
}

/*
      Evaluates FU = Gradient(L(w,u,lambda))

      This is the function handed to DMMGSetSNES(); it defines the nonlinear set of
    equations that are to be solved.

      Input Parameters:
        snes  - the nonlinear solver (not referenced here)
        U     - packed vector (w, u, lambda)
        dummy - the DMMG level; supplies the VecPack (dm) and the AppCtx (user)

      Output Parameter:
        FU    - packed residual (fw, fu, flambda)

      U is read through ghosted local vectors (VecPackGetLocalVectors() followed by
    VecPackScatter()), while FU is written through its global, nonghosted pieces
    (VecPackGetAccess()).

      The PDE constraint block is evaluated with DAFormFunction1().  The adjoint block
    J' lambda is formed either with the ADIC-computed Jacobian (default) or, when the
    option -skipadic is present, with a hand-coded stencil.
*/
#undef __FUNCT__
#define __FUNCT__ "FormFunction"
int FormFunction(SNES snes,Vec U,Vec FU,void* dummy)
{
  DMMG         level = (DMMG)dummy;
  VecPack      pack  = (VecPack)level->dm;
  AppCtx       *ctx  = (AppCtx*)level->user;
  DA           da;
  Vec          u_loc,lambda_loc;             /* ghosted pieces of U */
  Vec          fu_glob,flambda_glob;         /* global pieces of FU */
  Vec          lambda_glob;                  /* global lambda piece of U */
  PetscScalar  *w,*fw,*u,*fu,*lambda,*flambda;
  PetscScalar  d,h,h2;
  PetscTruth   skipadic;
  int          ierr,xs,xm,xe,i,N,nredundant;

  PetscFunctionBegin;
  ierr = PetscOptionsHasName(PETSC_NULL,"-skipadic",&skipadic);CHKERRQ(ierr);

  /* grid layout and scalings */
  ierr = VecPackGetEntries(pack,&nredundant,&da,PETSC_IGNORE);CHKERRQ(ierr);
  ierr = DAGetCorners(da,&xs,PETSC_NULL,PETSC_NULL,&xm,PETSC_NULL,PETSC_NULL);CHKERRQ(ierr);
  ierr = DAGetInfo(da,0,&N,0,0,0,0,0,0,0,0,0);CHKERRQ(ierr);
  xe   = xs + xm;
  d    = (N-1.0);
  h    = 1.0/d;
  h2   = 2.0*h;

  /* unpack the input (ghosted) and the output (global) */
  ierr = VecPackGetLocalVectors(pack,&w,&u_loc,&lambda_loc);CHKERRQ(ierr);
  ierr = VecPackScatter(pack,U,w,u_loc,lambda_loc);CHKERRQ(ierr);
  ierr = VecPackGetAccess(pack,FU,&fw,&fu_glob,&flambda_glob);CHKERRQ(ierr);
  ierr = VecPackGetAccess(pack,U,PETSC_NULL,PETSC_NULL,&lambda_glob);CHKERRQ(ierr);

  /* G(): the PDE constraint */
  ierr = DAFormFunction1(da,u_loc,fu_glob,w);CHKERRQ(ierr);

  if (!skipadic) {
    /* lambda^T G_u(): apply the transpose of the ADIC-computed Jacobian */
    ierr = DAComputeJacobian1WithAdic(da,u_loc,ctx->J,w);CHKERRQ(ierr);
    if (ctx->sles) {
      /* keep the preconditioner's solver in step with the new Jacobian */
      ierr = SLESSetOperators(ctx->sles,ctx->J,ctx->J,SAME_NONZERO_PATTERN);CHKERRQ(ierr);
    }
    ierr = MatMultTranspose(ctx->J,lambda_glob,flambda_glob);CHKERRQ(ierr);
  }

  ierr = DAVecGetArray(da,u_loc,(void**)&u);CHKERRQ(ierr);
  ierr = DAVecGetArray(da,fu_glob,(void**)&fu);CHKERRQ(ierr);
  ierr = DAVecGetArray(da,lambda_loc,(void**)&lambda);CHKERRQ(ierr);
  ierr = DAVecGetArray(da,flambda_glob,(void**)&flambda);CHKERRQ(ierr);

  /* L_w: only the process owning grid point 0 computes this */
  if (xs == 0) {
    fw[0] = -2.*d*lambda[0];
  }

  /* lambda^T G_u(), hand-coded variant */
  if (skipadic) {
    for (i=xs; i<xe; i++) {
      if (i == 0 || i == 1) {
        /* two leftmost rows: only the right neighbour contributes */
        flambda[i] = 2.*d*lambda[i] - d*lambda[i+1] + h2*lambda[i]*u[i];
      } else if (i == N-1 || i == N-2) {
        /* two rightmost rows: only the left neighbour contributes */
        flambda[i] = 2.*d*lambda[i] - d*lambda[i-1] + h2*lambda[i]*u[i];
      } else {
        flambda[i] = - d*(lambda[i+1] - 2.0*lambda[i] + lambda[i-1]) + h2*lambda[i]*u[i];
      }
    }
  }

  /* F_u: weight h at the two end points, 2h everywhere else */
  for (i=xs; i<xe; i++) {
    /* the i != 1 guard gives point 1 the weight 2h even when it is also the last point */
    PetscScalar weight = (i == 0 || (i != 1 && i == N-1)) ? h : h2;

    flambda[i] += weight*u[i];
  }

  ierr = DAVecRestoreArray(da,u_loc,(void**)&u);CHKERRQ(ierr);
  ierr = DAVecRestoreArray(da,fu_glob,(void**)&fu);CHKERRQ(ierr);
  ierr = DAVecRestoreArray(da,lambda_loc,(void**)&lambda);CHKERRQ(ierr);
  ierr = DAVecRestoreArray(da,flambda_glob,(void**)&flambda);CHKERRQ(ierr);

  ierr = VecPackRestoreLocalVectors(pack,&w,&u_loc,&lambda_loc);CHKERRQ(ierr);
  ierr = VecPackRestoreAccess(pack,FU,&fw,&fu_glob,&flambda_glob);CHKERRQ(ierr);
  ierr = VecPackRestoreAccess(pack,U,PETSC_NULL,PETSC_NULL,&lambda_glob);CHKERRQ(ierr);

  PetscLogFlops(9*N);
  PetscFunctionReturn(0);
}






