// cv: cross-validated predictions from any Stata estimation command.
// Splits the active sample into a number of subsets ("folds"); for each fold it
// fits the estimator on all the other folds and predicts for the held-out fold.
program define crossvalidate 
*! version 1.2.2  Oct 28, 2020   name change inside crossv.ado

  // Usage, as parsed below:
  //   crossvalidate newvar estimator varlist [if] [in] [, folds(#) gen(newvar) shuffle estimator_options]
  //
  //   newvar       new variable that receives the out-of-fold predictions
  //   estimator    estimation command, called as: estimator varlist if ..., estimator_options
  //   folds(#)     number of folds; integer, default 5, must be >= 1 and < number of active observations
  //   gen(newvar)  optionally keep the fold id of every observation in newvar
  //   shuffle      assign observations to folds in random order instead of dataset order
  //   anything else after the comma is handed to the estimator untouched
  //
  // Side effects: creates `target' (and `gen' if requested) and clears e().
  // Errors: 198 for discrim estimators, 1 for an invalid fold count, the estimator's
  //         own return code if fitting fails on any fold.

  version 14.1
  /* parse arguments */
  gettoken target 0 : 0
  gettoken estimator 0 : 0
  //  estimators starting with discrim have two words
  if "`estimator'"=="discrim"  gettoken estimator2 0 : 0

  if "`estimator'"=="discrim"  { 
     di as error "Requires special parsing : varlist is no longer y xvars, but instead ,group(y)" 
	 di as error " not yet implemented; see gridsearch.ado"
	 // if ("`estimator'"=="discrim") local addstuff = `"group(`group')"' "
	 exit 198
  }
  
  syntax varlist (fv ts) [if] [in], [folds(string)] [gen(string)] [shuffle] [*]

  confirm name `estimator'
  confirm new variable `target'
  // fail fast: without this check a bad gen() name is only discovered after every fold has been fitted
  if "`gen'"!="" {
    confirm new variable `gen'
  }
  
  // Stata is funky: because I use [*] above, if I declare folds(int 5) and you pass a real (e.g. folds(18.5)), 
  // 	rather than giving the usual "option folds() incorrectly specified" error, Stata *ignores* that folds, 
  //  	gives the default value, and pushes the wrong folds into the `options' macro
  // instead, I take a string (i.e. anything) so the folds option is always captured, and then parse manually
  if("`folds'"=="") {
    local folds = 5
  }
  confirm integer number `folds'
  
  //di as txt "folds= `folds' options=`options'" //DEBUG
  
  // number of observations in the if/in subset; reused below to size the folds
  qui count `if' `in'
  local N = r(N)
  if(`folds'<1 | `folds'>=`N') {
    // the message states what the check above actually enforces
    di as error "Invalid number of folds: `folds'. Must be at least 1 and less than the number of active observations (`N')."
    exit 1
  }
  
  if(`folds'==1) {
    // special case: 1-fold is the same as just training
    `estimator' `estimator2' `varlist' `if' `in', `options'
    predict `target'
    exit 
  }
  
  // NOTE: strata() is not declared in the syntax line above, so `strata' is empty here
  // and this branch does not fire; a strata() option ends up in `options' and is passed
  // to the estimator. Kept as a placeholder for a future stratification feature.
  if("`strata'" != "") {
    confirm variable `strata'
    di as error "crossvalidate: stratification not implemented."
    exit 2
  }
  
  
  /* generate folds */
  // the easiest way to do this in Stata is simply to mark a new column
  // and stamp out id numbers into it
  // the tricky part is dealing with if/in
  // and the trickier (and currently not implemented) part is dealing with
  // stratification (making sure each fold has equal proportions of a categorical variable)
  tempvar fold
    
  // compute the size of each group *as a float*
  // derivation:
  // we have N items in total -- you can also think of this as the last item, which should get mapped to group `folds' 
  // we want `folds' groups
  // if we divide each _n by `folds' then the largest ID generated is N/`folds' == # of items per group
  // so we can't do that
  // if instead we divide each _n by N/`folds', then the largest is N/(N/`folds') = `folds'
  // Also, maybe clearer, this python script empirically proves the formula:
  /*
  for G in range(1,302):
      for N in range(G,1302):
          folds = {k: len(list(g)) for k,g in groupby(int((i-1)//(N/G)+1) for i in range(1,N+1)) }
          print("N =", N, "G =", G, "keys:", set(folds.keys()));
          assert set(folds.keys()) == set(range(1,G+1))
  */
  local g =  `N'/`folds'
    // generate a pseudo-_n which is the observation number *within the if/in subset*
    // if you do not give if/in this should be equal to _n
    // observations outside the subset keep a missing `fold'; the loop below relies on that
  qui gen int `fold' = 1 `if' `in'
  
  /* shuffling */
  // this is tricky: shuffling has to happen *after* partially generating fold IDs,
  // because the shuffle invalidates the `in', but it must happen *before* the IDs
  // are actually assigned because otherwise there's no point
  if("`shuffle'"!="") {
    tempvar original_order
    tempvar random_order
    // long, not the default float: float cannot hold distinct integers above 16,777,216,
    // which would make it impossible to restore the original order of very large datasets
    qui gen long `original_order' = _n
    // double keeps the full precision of the random draw, so ties in the sort key are less likely
    qui gen double `random_order' = uniform()
    sort `random_order'
  }
  
  // running sum turns the 1s into 1,2,3,... within the subset
  // (egen has 'fill()' which is more complicated than this, and so does not allow if/in)
  qui replace `fold' = sum(`fold') if !missing(`fold')
  
  // map the pseudo-_n into a fold id number
  // nopromote keeps the integer storage type, so the fractional part is dropped, which is needed for id numbers
  // Stata counts from 1, which is why the -1 and +1s are there
  // (the computation itself is done counting from 0)
  qui replace `fold' = (`fold'-1)/`g'+1 if !missing(`fold'), nopromote

  // because shuffling can only affect which folds data ends up in,
  // immediately after generating fold labels we can put the data back as they were.
  // (better to do this early lest something later break and the dataset be left mangled)
  // (this can't use snapshot or preserve because restoring those will erase `fold')
  if("`shuffle'"!="") {
    sort `original_order'
  }
  
  // make sure the trickery above worked: ids run from 1 to `folds' with no gaps
  qui sum `fold'
  assert `r(min)'==1
  assert `r(max)'==`folds'
  qui levelsof `fold'
  assert `: word count `r(levels)''==`folds'
  
  
  /* cross-predict */
  // We don't actually predict into target directly, because most estimation commands
  // get annoyed at you trying to overwrite an old variable (even if an unused region).
  // Instead we repeatedly predict into B, copy the fold into target, and destroy B.
  // 
  // We don't actually create `target' until we have done one fold, at which point we *clone* it
  // because we do not know what types/labels the predictor wants to attach to its predictions,
  // (which can lead to strangeness if the predictor is inconsistent with itself)
  tempvar B
  forvalues f = 1/`folds' {
    // train on everything in the if/in subset that isn't the held-out fold.
    // `fold' is non-missing exactly for the observations selected by `if' `in', so the
    // training sample is expressed through `fold' alone. Appending "& `fold'!=`f'" to the
    // user's `if' would be wrong: & is evaluated before |, so "if a | b" would turn into
    // "a | (b & fold!=f)" and held-out observations matching a would leak into training.
    local train "!missing(`fold') & `fold' != `f'"
    qui count if `train'
    di as text "[fold `f'/`folds': training on `r(N)' observations]"
    capture noi `estimator' `estimator2' `varlist' if `train', `options'
    if(_rc!=0) {
      // remember the code now: the capture used for the cleanup below resets _rc
      local rc = _rc
      di as error "`estimator' `estimator2' failed"
      // do not leave a half-filled `target' behind (it does not exist yet if fold 1 failed)
      capture drop `target'
      exit `rc'
    }
    
    // predict on the fold
    qui count if `fold' == `f'
    di "[fold `f'/`folds': predicting on `r(N)' observations]"
    predict `B' if `fold' == `f'
    
    // on the first fold, *clone* B to our real output ("if 0" copies type/labels but no values)
    if(`f' == 1) {
      qui clonevar  `target' = `B' if 0
    }
    
    // save the predictions from the current fold
    qui replace `target' = `B' if `fold' == `f'
    drop `B'
  }
  
  if "`gen'"!="" {
	// optionally, keep fold variable
	gen `gen' = `fold'
  }
  
  /* clean up */
  // drop e(), because its contents at this point are only valid for the last fold
  // and that's just confusing
  ereturn clear
  
end
///////////////////////////////////////////////////////////////////////////////////
//Version History: 
//
//version 1.2.1  Oct 28, 2020   remove eclass; not needed
//version 1.2.0  June, 2020   gen(newvar) option
//version 1.1.0  May, 2020
//version 1.0.0  May, 2017