/*
//
// Ultra-fast version of the lazy learning algorithm.
//
//
// Local models                   quadratic
// Kernel functions               rectangular
// Identification                 recursive least squares
// Metric                         L1
// Model selection                Minimum of leave-one-out error
//
// -----------------------------------------------------------------------
//
//            Mauro Birattari       &        Gianluca Bontempi
//
//                               IRIDIA 
//                    Universite' Libre de Bruxelles
//                    email: {mbiro,gbonte}@ulb.ac.be
//
// -----------------------------------------------------------------------
//   
//   
//   How to compile:
//   If a C compiler is correctly installed, and if matlab knows its path,
//   you can compile me from the matlab prompt, as follows:
//   
//   	     >> mex quaM.c
//   
//   How to use:
//   From matlab you can call the resulting mex file using the following syntax.
//   
//   	     >> [h,s,t,k,H,S,T,I] = quaM(X,Y,Q,id_par);
//   
//   where 
//      INPUT:
//   	     X[m,d]              Examples: Input
//   	     Y[m,1]              Examples: Output
//   	     Q[q,d]              Query points
//   	     id_par[2,1]         Identification parameters = [idm,idM];
//
//      OUTPUT:
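//           h[q,1]              Prediction with the selected number of
//                               neighbors
//           s[q,1]              Estimated variance of the prediction, i.e.
//                               the leave-one-out mean squared error of
//                               the selected model
//           t[p,q]              Selected model for each query; here
//                               p=(d+1)*(d+2)/2 coefficients, ordered as
//                               constant, d linear terms, then the
//                               d*(d+1)/2 quadratic terms (i<=j), in
//                               coordinates centred at the query point
//           k[q,1]              Selected number of neighbors
//           H[idM,q]            All the predictions obtained using a
//                               number of neighbors in the range 1:idM
//           S[idM,q]            Estimated variance of ALL the predictions
//                               (Inf when a single neighbor is used)
//           T[p,idM,q]          All the models considered for each query;
//                               here p=(d+1)*(d+2)/2
//           I[idM,q]            0-based row indices (into X) of the
//                               idM nearest neighbors of each query point,
//                               sorted by increasing distance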
//
//   
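//   The identification parameters are [idm,idM]: the minimum and maximum
//   number of examples (neighbors) used in the identification. idm is
//   raised to at least 2, and idM is capped at the number of examples
//   (but never below idm).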
//   
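//   UNDOCUMENTED FEATURE: the function also accepts,
//   as a 5th input, a scalar LAMBDA which is a regularization parameter
//   (default 1E6): the recursive least-squares matrix is initialized to
//   LAMBDA times the identity, so larger values mean weaker regularization.
//   As a 6th input, it accepts a vector W[d,1] of weights that rescale
//   the dimensions in the L1 distance.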
//
//______________________________________________Mao: Feb 5, 1999___
*/


#include <mex.h>
#include<math.h>
#include <string.h>


/*
 * find_neighbors: select the idM examples closest to query q under the
 * (optionally weighted) L1 metric.
 *
 *   X, Q      column-pointer views of the example and query matrices:
 *             X[j][i] is coordinate j of example i, Q[j][q] of query q.
 *   C         output, C[j][i] = X[j][i] - Q[j][q].  The distance loop stops
 *             early once an example can no longer enter the current top-idM
 *             list, so a column is complete only for examples that end up
 *             selected.
 *   Wvec      per-dimension weights (assumed non-negative), or NULL for the
 *             plain L1 metric.
 *   BestIndx  output, idM+1 ints: BestIndx[0..idM-1] are the 0-based
 *             indices of the neighbors sorted by increasing distance (ties
 *             keep the earlier example first); BestIndx[idM] is scratch.
 *   BestDist  scratch, idM+2 doubles: BestDist[p] is the distance of
 *             BestIndx[p-1].  BestDist[0] is a -Inf sentinel that stops the
 *             insertion loop.
 */
static void find_neighbors(double **X, double **Q, double **C,
                           const double *Wvec, int mx, int nx, int q,
                           int idM, int *BestIndx, double *BestDist)
{
  const double inf = mxGetInf();
  double dist, w;
  int i, j, p;

  BestDist[0] = -inf;
  for (p = 1; p <= idM; p++)
    BestDist[p] = inf;

  for (i = 0; i < mx; i++) {
    dist = 0.0;
    for (j = 0; j < nx && dist < BestDist[idM]; j++) {
      w = Wvec ? Wvec[j] : 1.0;   // 1.0*x == x exactly: unweighted metric
      C[j][i] = X[j][i] - Q[j][q];
      dist += w * fabs(C[j][i]);
    }
    // Insertion step: shift every worse entry one slot towards the end.
    for (p = idM; dist < BestDist[p]; p--) {
      BestDist[p+1] = BestDist[p];
      BestIndx[p] = BestIndx[p-1];
    }
    BestDist[p+1] = dist;
    BestIndx[p] = i;
  }
}

/*
 * build_regressors: fill the regressor row Z[i] and the target W[i] for the
 * i-th nearest neighbor (i < idM).  Each row holds, in order, the constant
 * 1, the nx centred coordinates C[j][indx], and the nx*(nx+1)/2 products
 * C[p][indx]*C[m][indx] with p <= m: nz = (nx+1)*(nx+2)/2 entries in all.
 */
static void build_regressors(double **C, const double *Y, const int *BestIndx,
                             int nx, int idM, double **Z, double *W)
{
  int i, j, p, m, indx;
  double *z;

  for (i = 0; i < idM; i++) {
    indx = BestIndx[i];
    z = Z[i];
    W[i] = Y[indx];
    *z++ = 1.0;
    for (j = 0; j < nx; j++)
      *z++ = C[j][indx];
    for (p = 0; p < nx; p++)
      for (m = p; m < nx; m++)
        *z++ = C[p][indx] * C[m][indx];
  }
}

/*
 * rls_update: one recursive-least-squares step that adds example (z, y).
 *
 *   v  nz-by-nz matrix (accessed as v[j][i]), updated in place as
 *      v <- v - (v z)(v z)' / (1 + z' v z).
 *   t  nz coefficients, updated in place as t <- t + e * (v z), using the
 *      updated v, where e = y - z' t is the residual before the update.
 *   a  nz doubles of scratch space.
 */
static void rls_update(double **v, double *t, double *a,
                       const double *z, double y, int nz)
{
  double e = y, b = 1.0, tmp;
  int i, j;

  for (i = 0; i < nz; i++) {
    tmp = 0.0;
    for (j = 0; j < nz; j++)
      tmp += v[j][i] * z[j];
    a[i] = tmp;
    b += z[i] * tmp;
    e -= z[i] * t[i];
  }
  for (i = 0; i < nz; i++)
    for (j = 0; j < nz; j++)
      v[j][i] -= a[i] * a[j] / b;
  for (i = 0; i < nz; i++) {
    tmp = 0.0;
    for (j = 0; j < nz; j++)
      tmp += v[j][i] * z[j];
    t[i] += e * tmp;
  }
}

/*
 * loo_mse: mean of the squared leave-one-out residuals of the current fit
 * over the first n examples.  For each example, e = y - z' t is the fitted
 * residual and b = 1 - z' v z, and e/b is used as the leave-one-out
 * residual.
 */
static double loo_mse(double **v, const double *t, double **Z,
                      const double *W, int n, int nz)
{
  double sse = 0.0, e, b, tmp;
  int m, i, j;

  for (m = 0; m < n; m++) {
    e = W[m];
    b = 1.0;
    for (i = 0; i < nz; i++) {
      tmp = 0.0;
      for (j = 0; j < nz; j++)
        tmp += v[j][i] * Z[m][j];
      b -= Z[m][i] * tmp;
      e -= Z[m][i] * t[i];
    }
    sse += pow(e/b, 2);
  }
  return sse / n;
}

/*
 * mexFunction: MATLAB entry point,
 *   [h,s,t,k,H,S,T,I] = quaM(X,Y,Q,id_par[,LAMBDA[,W]]).
 * See the header comment of this file for inputs and outputs.
 */
void mexFunction(int nlhs, mxArray *plhs[],
		 int nrhs, const mxArray *prhs[]){

  double LAMBDA;                // initial diagonal of the RLS matrix v
  int mx, nx, my, ny;
  int mq, nq;
  int mw = 0, nw = 0;
  int nz;                       // number of quadratic-model coefficients
  double *id;
  int mid, nid;
  int idm, idM;
  double initDist;
  double *Xvec, *Qvec, *Wvec;
  double *Cbase, *Cvec;
  double *Zbase;
  double *Y;
  double *W;                    // targets of the current neighbors
  double *y_hat, *s_hat, *t_hat, *k_hat;
  double *Y_hat, *S_hat, *T_hat, *I_hat;
  mxArray *outS, *outT, *outK;
  mwSize dimSize[3];            // mxCreateNumericArray expects mwSize
  double **X, **Q, **C, **Z;
  int *BestIndx;
  double *BestDist;
  double eB, eC;
  int kB;
  double *Vc;
  double **v;
  double *t, *tB, *a;
  int i, j, k, q;

  initDist = mxGetInf();
  LAMBDA = 1E6;
  Wvec = NULL;

  if (nrhs < 4 || nrhs > 6)
    mexErrMsgTxt("Number of argument no good.");
  if (nlhs > 8)
    mexErrMsgTxt("Too many output arguments.");

  // mxGetPr is only meaningful for real, full double arrays.
  for (i = 0; i < 4; i++)
    if (!mxIsDouble(prhs[i]) || mxIsComplex(prhs[i]) || mxIsSparse(prhs[i]))
      mexErrMsgTxt("X, Y, Q and id_par must be real full double matrices.");

  // Optional 5th input: LAMBDA.
  if (nrhs >= 5)
    LAMBDA = mxGetScalar(prhs[4]);

  // Optional 6th input (prhs[5]): weights of the L1 metric.  An empty matrix
  // is treated as "no weights".
  if (nrhs >= 6 && !mxIsEmpty(prhs[5])) {
    if (!mxIsDouble(prhs[5]) || mxIsComplex(prhs[5]) || mxIsSparse(prhs[5]))
      mexErrMsgTxt("Vector of weights no good.");
    mw = (int)mxGetM(prhs[5]);
    nw = (int)mxGetN(prhs[5]);
    Wvec = mxGetPr(prhs[5]);
  }

  // Examples: input
  mx = (int)mxGetM(prhs[0]);           // number of examples
  nx = (int)mxGetN(prhs[0]);           // input dimension

  // Examples: output
  my = (int)mxGetM(prhs[1]);
  ny = (int)mxGetN(prhs[1]);

  // Queries
  mq = (int)mxGetM(prhs[2]);           // number of queries
  nq = (int)mxGetN(prhs[2]);

  // Range of identification examples
  mid = (int)mxGetM(prhs[3]);
  nid = (int)mxGetN(prhs[3]);

  if (  (ny != 1)      ||
        (mid*nid != 2) ||
        (mx != my)     ||
        (nq != nx)     )
    mexErrMsgTxt("Matrix dimensions must agree.");

  if (Wvec) {
    if (mw*nw != nx)
      mexErrMsgTxt("Vector of weights no good.");
    // Negative or NaN weights break the metric. The early exit of the
    // distance loop also assumes partial sums never decrease.
    for (j = 0; j < nx; j++)
      if (!(Wvec[j] >= 0.0))
        mexErrMsgTxt("Vector of weights no good.");
  }

  Xvec = mxGetPr(prhs[0]);
  Y = mxGetPr(prhs[1]);
  Qvec = mxGetPr(prhs[2]);

  id = mxGetPr(prhs[3]);
  idm = (int)id[0];
  idM = (int)id[1];
  idm = (idm<2)? 2 : idm;
  idM = (idM>mx)? mx : idM;
  idM = (idM<idm)? idm : idM;

  // After clamping, idM > mx only when idm > mx.  The neighbor list could
  // then not be filled, and stale or out-of-range indices would be used.
  if (idM > mx)
    mexErrMsgTxt("Not enough examples for the identification range.");

  nz = (nx+1)*(nx+2)/2;

  plhs[0] = mxCreateDoubleMatrix(mq,1,mxREAL);
  y_hat = mxGetPr(plhs[0]);

  // Outputs 2-4 are always computed, but MATLAB documents plhs as holding
  // only nlhs entries (plhs[0] is always usable).  They are therefore
  // handed to plhs at the end only if they were requested.
  outS = mxCreateDoubleMatrix(mq,1,mxREAL);
  s_hat = mxGetPr(outS);
  outT = mxCreateDoubleMatrix(nz,mq,mxREAL);
  t_hat = mxGetPr(outT);
  outK = mxCreateDoubleMatrix(mq,1,mxREAL);
  k_hat = mxGetPr(outK);

  if (nlhs > 4){
    plhs[4] = mxCreateDoubleMatrix(idM,mq,mxREAL);
    Y_hat = mxGetPr(plhs[4]);
  } else
    Y_hat = NULL;

  if (nlhs > 5){
    plhs[5] = mxCreateDoubleMatrix(idM,mq,mxREAL);
    S_hat = mxGetPr(plhs[5]);
  } else
    S_hat = NULL;

  if (nlhs > 6){
    dimSize[0] = (mwSize)nz;
    dimSize[1] = (mwSize)idM;
    dimSize[2] = (mwSize)mq;
    plhs[6] = mxCreateNumericArray(3,dimSize,mxDOUBLE_CLASS,mxREAL);
    T_hat = mxGetPr(plhs[6]);
  } else
    T_hat = NULL;

  if (nlhs > 7){
    plhs[7] = mxCreateDoubleMatrix(idM,mq,mxREAL);
    I_hat = mxGetPr(plhs[7]);
  } else
    I_hat = NULL;

  // Column-pointer views of the column-major Matlab matrices.
  X = mxCalloc(nx,sizeof(double*));
  Q = mxCalloc(nq,sizeof(double*));
  C = mxCalloc(nx,sizeof(double*));
  Cbase = mxCalloc(nx*mx,sizeof(double));
  Cvec = Cbase;
  for (i=0; i<nx; i++,Xvec+=mx,Qvec+=mq,Cvec+=mx){
    X[i] = Xvec;
    Q[i] = Qvec;
    C[i] = Cvec;
  }

  // One regressor row of nz entries per neighbor.
  Zbase = mxCalloc(idM*nz,sizeof(double));
  Z = mxCalloc(idM,sizeof(double*));
  W = mxCalloc(idM,sizeof(double));
  for (i=0; i<idM; i++)
    Z[i] = Zbase + i*nz;

  Vc = mxCalloc(nz*nz,sizeof(double));
  v  = mxCalloc(nz,sizeof(double*));
  t  = mxCalloc(nz,sizeof(double));
  tB = mxCalloc(nz,sizeof(double));
  a  = mxCalloc(nz,sizeof(double));
  for (i=0; i<nz; i++)
    v[i] = Vc + i*nz;

  BestIndx = mxCalloc(idM+1,sizeof(int));
  BestDist = mxCalloc(idM+2,sizeof(double));

  for (q=0; q<mq; q++){

    find_neighbors(X, Q, C, Wvec, mx, nx, q, idM, BestIndx, BestDist);

    if (I_hat)
      for (i=0; i<idM; i++)
        *(I_hat++) = (double)BestIndx[i];   // 0-based indices

    // Restart the recursive fit for this query: v = LAMBDA*I, t = 0.
    for (i=0; i<nz*nz; i++)
      Vc[i] = 0.0;
    for (j=0; j<nz; j++)
      v[j][j] = LAMBDA;
    for (i=0; i<nz; i++)
      t[i] = 0.0;

    build_regressors(C, Y, BestIndx, nx, idM, Z, W);

    eB = initDist;
    eC = initDist;
    kB = 0;                               // 0: no model selected yet
    for (k=0; k<idM; k++){
      rls_update(v, t, a, Z[k], W[k], nz);

      // With a single neighbor the leave-one-out error is not computed and
      // is reported as Inf.
      eC = (k > 0) ? loo_mse(v, t, Z, W, k+1, nz) : initDist;

      if (Y_hat)
        *(Y_hat++) = t[0];
      if (S_hat)
        *(S_hat++) = eC;
      if (T_hat){
        memcpy(T_hat,t,nz*sizeof(double));
        T_hat += nz;
      }

      if ( (k >= idm-1) && (eC < eB) ){
        memcpy(tB,t,nz*sizeof(double));
        eB = eC;
        kB = k+1;
      }
    }

    // No finite leave-one-out error was found (all Inf/NaN).  Use the
    // largest neighborhood instead of a model left over from the previous
    // query.
    if (kB == 0){
      memcpy(tB,t,nz*sizeof(double));
      eB = eC;
      kB = idM;
    }

    y_hat[q] = tB[0];
    k_hat[q] = (double)kB;
    s_hat[q] = eB;
    memcpy(t_hat,tB,nz*sizeof(double));
    t_hat += nz;
  }

  if (nlhs > 1) plhs[1] = outS; else mxDestroyArray(outS);
  if (nlhs > 2) plhs[2] = outT; else mxDestroyArray(outT);
  if (nlhs > 3) plhs[3] = outK; else mxDestroyArray(outK);

  mxFree(BestDist);
  mxFree(BestIndx);
  mxFree(a);
  mxFree(tB);
  mxFree(t);
  mxFree(v);
  mxFree(Vc);
  mxFree(W);
  mxFree(Z);
  mxFree(Zbase);
  mxFree(Cbase);
  mxFree(C);
  mxFree(Q);
  mxFree(X);
}
