// cv: cross-validated predictions from any Stata estimation command.
// Breaks a dataset into a number of subsets ("folds"), and for each
// runs an estimator on everything but that subset, and predicts results.
//
// Usage:
//   crossvalidate newvar estimator varlist [if] [in] [, folds(#) gen(newvar) shuffle other_options]
//
//   newvar       new variable that receives the out-of-fold predictions
//   estimator    estimation command to run; -predict- is called after each fit
//   folds(#)     number of folds (default 5); an integer between 1 and the
//                number of observations selected by if/in. folds(1) fits and
//                predicts on the same data; folds(N) is leave-one-out.
//   gen(newvar)  optionally save each observation's fold id (1..folds)
//   shuffle      assign observations to folds in random order (call -set seed-
//                first for reproducibility); without it, folds are contiguous
//                runs of the if/in observations in the current sort order
//   any other options are passed through to the estimator unchanged.
program define crossvalidate
	*! version 1.2.2 Oct 28, 2020 name change inside crossv.ado
	version 14.1

	/* parse arguments */
	gettoken target 0 : 0
	gettoken estimator 0 : 0

	// estimators starting with discrim have two words
	if "`estimator'"=="discrim" gettoken estimator2 0 : 0
	if "`estimator'"=="discrim" {
		di as error "Requires special parsing : varlist is no longer y xvars, but instead ,group(y)"
		di as error " not yet implemented; see gridsearch.ado"
		// if ("`estimator'"=="discrim") local addstuff = `"group(`group')"' "
		exit 198
	}

	syntax varlist (fv ts) [if] [in], [folds(string)] [gen(string)] [shuffle] [*]
	confirm name `estimator'
	confirm new variable `target'
	// validate gen() up front, so a bad name fails before any model is fit
	if ("`gen'"!="") {
		confirm new variable `gen'
	}

	// Stata is funky: because of the [*] above, if folds(integer 5) were declared and
	// a non-integer were passed (e.g. folds(18.5)), Stata would not raise the usual
	// "option folds() incorrectly specified" error; it would silently use the default
	// and push the bad folds() into the options macro instead.
	// So folds() is taken as a string (i.e. anything) and validated by hand.
	if ("`folds'"=="") {
		local folds = 5
	}
	confirm integer number `folds'

	// N = number of active (if/in) observations; reused for the fold mapping below
	qui count `if' `in'
	local N = r(N)
	if (`folds'<1 | `folds'>`N') {
		di as error "Invalid number of folds: `folds'. Must be between 1 and the number of active observations `N'."
		exit 1
	}

	if (`folds'==1) {
		// special case: 1-fold is the same as just training (and predicting in-sample)
		`estimator' `estimator2' `varlist' `if' `in', `options'
		predict `target'
		if ("`gen'"!="") {
			// every active observation belongs to the single fold
			gen long `gen' = 1 `if' `in'
		}
		exit
	}

	// NOTE(review): no strata() option is parsed by -syntax- above, so this branch
	// is currently unreachable; it is a placeholder for future stratification.
	if ("`strata'" != "") {
		confirm variable `strata'
		di as error "crossvalidate: stratification not implemented."
		exit 2
	}

	/* generate folds */
	// Mark a new column and stamp fold id numbers into it.
	// The tricky part is dealing with if/in; the trickier (and not implemented)
	// part would be stratification (equal proportions of a categorical variable
	// in every fold).
	//
	// Observation i (i = 1..N, counted within the if/in subset) goes to fold
	//     floor((i-1)*folds/N) + 1
	// which produces every id 1..folds, with fold sizes differing by at most one.
	// This is computed in exact integer arithmetic rather than by dividing by the
	// fractional group size N/folds, so rounding cannot move an observation
	// across a fold boundary.
	// This python script empirically checks the formula:
	/*
	for G in range(1,302):
		for N in range(G,1302):
			keys = {(i-1)*G//N + 1 for i in range(1,N+1)}
			assert keys == set(range(1,G+1))
	*/
	tempvar fold

	// Start with 1 on the active observations and missing elsewhere; the running
	// sum below turns this into a pseudo-_n *within the if/in subset*
	// (equal to _n when no if/in is given).
	qui gen long `fold' = 1 `if' `in'

	/* shuffling */
	// Shuffling has to happen *after* marking the active observations, because the
	// shuffle invalidates -in-, but *before* the ids are numbered, because
	// otherwise there is no point.
	if ("`shuffle'"!="") {
		tempvar original_order random_order
		// long and double (not the default float) so that row positions stay exact
		// on very large datasets and random keys almost never tie
		qui gen long `original_order' = _n
		qui gen double `random_order' = runiform()
		sort `random_order'
	}

	// sum() treats missing as 0, so active observations get 1..N in current order
	// (egen's fill() is more complicated than this and does not allow if/in)
	qui replace `fold' = sum(`fold') if !missing(`fold')
	// map the pseudo-_n into a fold id (see the derivation above)
	qui replace `fold' = floor((`fold'-1)*`folds'/`N') + 1 if !missing(`fold')

	// Shuffling only affects which fold each observation lands in, so put the data
	// back in their original order right away, lest something later break and leave
	// the dataset mangled. (This cannot use snapshot or preserve: restoring those
	// would erase the fold variable.)
	if ("`shuffle'"!="") {
		sort `original_order'
	}

	// make sure the fold assignment worked
	qui sum `fold'
	assert `r(min)'==1
	assert `r(max)'==`folds'
	qui levelsof `fold'
	assert `: word count `r(levels)''==`folds'

	/* cross-predict */
	// We don't predict into the target directly, because most estimation commands
	// refuse to overwrite an existing variable (even in an unused region).
	// Instead we repeatedly predict into B, copy that fold into the target, and drop B.
	//
	// The target is only created after the first fold, by *cloning* B, because we
	// do not know what types/labels the predictor attaches to its predictions
	// (which can lead to strangeness if the predictor is inconsistent with itself).
	//
	// The training sample is written as "!missing(fold) & fold != f". The fold
	// variable is nonmissing exactly on the if/in sample, so this is equivalent to
	// the user's if/in restricted to the other folds. Appending "& fold != f" to the
	// user's if-expression would be wrong: & binds tighter than |, so "if a | b"
	// would become "a | (b & fold != f)" and leak the held-out fold into training.
	tempvar B
	forvalues f = 1/`folds' {
		// train on every active observation outside the current fold
		qui count if !missing(`fold') & `fold'!=`f'
		di as text "[fold `f'/`folds': training on `r(N)' observations]"
		capture noisily `estimator' `estimator2' `varlist' if !missing(`fold') & `fold'!=`f', `options'
		if (_rc!=0) {
			// save the code first: -capture drop- below resets _rc
			local rc = _rc
			di as error "`estimator' `estimator2' failed"
			// do not leave partially filled predictions behind
			capture drop `target'
			exit `rc'
		}

		// predict on the held-out fold
		qui count if `fold' == `f'
		di "[fold `f'/`folds': predicting on `r(N)' observations]"
		capture noisily predict `B' if `fold' == `f'
		if (_rc!=0) {
			local rc = _rc
			capture drop `target'
			exit `rc'
		}

		// on the first fold, *clone* B (type and labels, no values) as the real output
		if (`f' == 1) {
			qui clonevar `target' = `B' if 0
		}

		// save the predictions from the current fold
		qui replace `target' = `B' if `fold' == `f'
		drop `B'
	}

	if ("`gen'"!="") {
		// optionally, keep the fold variable
		gen long `gen' = `fold'
	}

	/* clean up */
	// drop e(), because its contents at this point are only valid for the last fold
	// and that's just confusing
	ereturn clear
end

///////////////////////////////////////////////////////////////////////////////////
//Version History:
//
//version 1.2.1 Oct 28, 2020 remove eclass; not needed
//version 1.2.0 June, 2020 gen(newvar) option
//version 1.1.0 May, 2020
//version 1.0.0 May, 2017