*! 2.3.0 Ariel Linden 28Oct2020 // Changed -randomforest- to -rforest- since the former no longer exists as a package
*! 2.2.0 Ariel Linden 20Aug2018 // Added -svmachines- and -boost- as model options
// Added p-value to table to compare full with test data
*! 2.0.0 Ariel Linden 20Aug2018 // Added -randomforest- as model option
// Changed k-group generating process to ensure proportions of the outcome variable are equal in each k-group to that of the full sample
*! 1.1.0 Ariel Linden 059Nov2017 // Added the "save" option
*! 1.0.0 Ariel Linden 05Oct2017

* kfoldclass: k-fold cross-validation of a binary-outcome classifier.
*
* Syntax:
*   kfoldclass depvar indepvars [if] [in] [weight], model(name)
*              [k(#) cutoff(#) save figure model_options]
*
*   model(name)  one of: probit, logit, boost, svmachines, rforest
*   k(#)         number of folds (integer > 1); defaults to 5
*   cutoff(#)    probability threshold in (0,1) for a "+" classification;
*                defaults to 0.5
*   save         stores the fold id and the predictions in new variables
*                named group, full and test
*   figure       draws the full-sample vs. test-sample ROC curves (roccomp)
*   anything else is passed through to the estimation command
*
* The model is fitted once on the whole estimation sample ("full") and then
* k times, each time leaving one fold out and predicting for that fold
* ("test"). Classification tables, the derived rates and the ROC areas are
* displayed for both and returned in r():
*   r(P_corr_#) r(P_p1_#) r(P_n0_#) r(P_p0_#) r(P_n1_#)
*   r(P_1p_#)   r(P_0n_#) r(P_0p_#) r(P_1n_#) r(roc#)     (# = 1 full, 2 test)
*   r(rocp)     p-value from roccomp comparing the two ROC areas
*
* Notes visible from the code:
*   - weights are only forwarded to probit and logit;
*   - the data are re-sorted while the folds are built;
*   - svmachines predictions use the fixed variable names fullSVM* and
*     yhatSVM*, which are dropped again before the program finishes.

program define kfoldclass, rclass
version 11.0

	syntax varlist(min=2 numeric fv) [if] [in] ///
		[fweight iweight pweight] , ///
		MODel(string) ///
		[k(numlist int max=1 >1) ///
		CUToff(numlist max=1 >0 <1) ///
		SAve ///
		FIGure *]

	// Reject an unsupported model() immediately; otherwise no estimation
	// branch runs and the program fails later with an unrelated error.
	if !inlist(`"`model'"', "probit", "logit", "boost", "svmachines", "rforest") {
		di as err "model() must be one of: probit, logit, boost, svmachines, rforest"
		exit 198
	}

	// save creates variables with fixed names; make sure they are free
	// before fitting k+1 models rather than failing at the very end.
	if "`save'" != "" {
		confirm new variable group full test
	}

	quietly {

		// Get Y and X variables
		gettoken dvar xvar : varlist

		marksample touse
		count if `touse'
		if r(N) == 0 error 2000
		local N = r(N)
		// flip the marker to -1 so that in-sample observations sort first;
		// -1 is still "true" in every -if `touse'- below
		replace `touse' = -`touse'

		// Validate options
		if "`k'" == "" local k = 5
		if `k' > `N' {
			di as err "Number of folds cannot exceed number of observations"
			exit 198
		}
		if "`cutoff'" == "" local cutoff = 0.5

		tempvar u group tokeep u1 u2 I T full1 full2 yhat1 yhat2 test

		*******************************************************************
		// generate k groups with equal proportions of dvar to full sample
		local l = `k'-1
		gen `group' = .
		gen `tokeep' = 0 if `touse'
		gen double `u1' = runiform()
		gen double `u2' = runiform()
		sort `touse' `tokeep' `dvar' `u1' `u2'
		// T = target fold size within each level of dvar (level count / k, rounded)
		bys `touse' `dvar': gen long `T' = sum(`tokeep'==0) if `touse'
		bys `touse' `dvar': replace `T' = int(`T'[_N]/`k' +.5) if `touse'

		// folds 1..k-1: within each level of dvar, take the next T
		// still-unassigned observations and label them fold i
		forval i = 1/`l' {
			bys `touse' `dvar': gen long `I' = sum(`tokeep'==0) if `group' ==. & `touse'
			bys `touse' `tokeep' `dvar' `I': replace `tokeep'= 1 if `tokeep'==0 & `I'<=`T'
			replace `group' = `i' if `tokeep'==1 & `group' ==. & `touse'
			drop `I'
		}
		// whatever is left forms the last fold
		replace `group' = `k' if `group' ==. & `touse'
		*******************************************************************

		************************
		// Full data
		************************
		// run model on full sample
		if "`model'" == "probit" {
			probit `dvar' `xvar' if `touse' [`weight' `exp'], `options'
			predict `full2' if `touse'
		}
		else if "`model'" == "logit" {
			logit `dvar' `xvar' if `touse' [`weight' `exp'], `options'
			predict `full2' if `touse'
		}
		else if "`model'" == "boost" {
			boost `dvar' `xvar' if `touse', distribution(logistic) predict(`full2') `options'
			replace `full2' = . if !`touse'
		}
		else if "`model'" == "svmachines" {
			tempvar dvar2
			gen byte `dvar2' = `dvar' // replicate dvar to avoid value labels
			svmachines `dvar2' `xvar' if `touse', prob `options'
			predict fullSVM if `touse', prob // hardcoded variables are needed for predictions in svmachines
			local full2 fullSVM_1
		}
		else if "`model'" == "rforest" {
			rforest `dvar' `xvar' if `touse', type(class) `options'
			predict `full1' `full2' if `touse', pr
		}

		local full `full2'

		// collect cell values for classification
		count if `dvar' !=0 & `full' >= `cutoff' & `touse'
		local a1 = r(N)
		count if `dvar' ==0 & `full' >= `cutoff' & `touse'
		local b1 = r(N)
		count if `dvar' !=0 & `full' <`cutoff' & `touse'
		local c1 = r(N)
		count if `dvar' ==0 & `full' <`cutoff' & `touse'
		local d1 = r(N)

		************************
		// test (k-fold) data
		************************
		// gen variable to hold test values
		gen `test'=.

	} //end quietly

	// fancy setup for dots
	di _n
	di as txt "Iterating across (" as res `k' as txt ") hold-out samples"
	di as txt "{hline 4}{c +}{hline 3} 1 " "{hline 3}{c +}{hline 3} 2 " "{hline 3}{c +}{hline 3} 3 " "{hline 3}{c +}{hline 3} 4 " "{hline 3}{c +}{hline 3} 5 "

	// run model on test (k-fold) sample:
	// fit on every fold except i, predict for fold i only
	forval i = 1/`k' {
		_dots `i' 0
		quietly {
			if "`model'" == "probit" {
				probit `dvar' `xvar' if `group'!=`i' & `touse' [`weight' `exp'], `options'
				predict `yhat2' if `group'==`i'
			}
			else if "`model'" == "logit" {
				logit `dvar' `xvar' if `group'!=`i' & `touse' [`weight' `exp'], `options'
				predict `yhat2' if `group'==`i'
			}
			else if "`model'" == "boost" {
				boost `dvar' `xvar' if `group'!=`i' &`touse', distribution(logistic) predict(`yhat2') `options'
				replace `yhat2' = . if !`touse' & `group'!=`i'
			}
			else if "`model'" == "svmachines" {
				svmachines `dvar2' `xvar' if `group'!=`i' &`touse', prob `options'
				predict yhatSVM if `group'==`i', prob
				// from here on `yhat2' refers to the hardcoded yhatSVM_1
				local yhat2 yhatSVM_1
				drop yhatSVM yhatSVM_0
			}
			else if "`model'" == "rforest" {
				rforest `dvar' `xvar' if `group'!=`i' & `touse', type(class) `options'
				predict `yhat1' `yhat2' if `group'==`i', pr
				drop `yhat1'
			}
			replace `test' = `yhat2' if `group'==`i'
			drop `yhat2'
		} // end quietly
	} // end forval

	quietly {

		// collect cell values for classification
		count if `dvar' !=0 & `test' >=`cutoff' & `touse'
		local a2 = r(N)
		count if `dvar' ==0 & `test' >= `cutoff' & `touse'
		local b2 = r(N)
		count if `dvar' !=0 & `test' <`cutoff' & `touse'
		local c2 = r(N)
		count if `dvar' ==0 & `test' <`cutoff' & `touse'
		local d2 = r(N)

		// collect values for ROC area
		roctab `dvar' `full'
		local roc1 : di %05.4f r(area)

		roctab `dvar' `test'
		local roc2 : di %05.4f r(area)

		roccomp `dvar' `full' `test'
		local rocp : di %05.4f r(p)

		// Graph the ROC curves and save
		if "`figure'" != "" {
			roccomp `dvar' `full' `test' , graph legend(rows(2) order(1 2 3) label(1 "Full ROC area: `roc1'") label(2 "Test ROC area: `roc2'") label(3 "Reference") )
		}

		if "`save'" != "" {
			gen group = `group'
			label var group "k group"
			gen full = `full'
			label var full "Full-sample predictions"
			gen test = `test'
			label var test "Test-sample predictions"
		}

	} // end quietly

	// Classification calculations and save r()
	// cells: a = true +, b = false +, c = false -, d = true -
	* for full data
	ret scalar P_corr_1 = ((`a1'+`d1')/(`a1'+`b1'+`c1'+`d1'))*100 /* correctly classified */
	ret scalar P_p1_1 = (`a1'/(`a1'+`c1'))*100 /* sensitivity */
	ret scalar P_n0_1 = (`d1'/(`b1'+`d1'))*100 /* specificity */
	ret scalar P_p0_1 = (`b1'/(`b1'+`d1'))*100 /* false + given ~D */
	ret scalar P_n1_1 = (`c1'/(`a1'+`c1'))*100 /* false - given D */
	ret scalar P_1p_1 = (`a1'/(`a1'+`b1'))*100 /* + pred value */
	ret scalar P_0n_1 = (`d1'/(`c1'+`d1'))*100 /* - pred value */
	ret scalar P_0p_1 = (`b1'/(`a1'+`b1'))*100 /* false + given + */
	ret scalar P_1n_1 = (`c1'/(`c1'+`d1'))*100 /* false - given - */
	ret scalar roc1 = `roc1' /* roc curve */

	* for test data
	ret scalar P_corr_2 = ((`a2'+`d2')/(`a2'+`b2'+`c2'+`d2'))*100 /* correctly classified */
	ret scalar P_p1_2 = (`a2'/(`a2'+`c2'))*100 /* sensitivity */
	ret scalar P_n0_2 = (`d2'/(`b2'+`d2'))*100 /* specificity */
	ret scalar P_p0_2 = (`b2'/(`b2'+`d2'))*100 /* false + given ~D */
	ret scalar P_n1_2 = (`c2'/(`a2'+`c2'))*100 /* false - given D */
	ret scalar P_1p_2 = (`a2'/(`a2'+`b2'))*100 /* + pred value */
	ret scalar P_0n_2 = (`d2'/(`c2'+`d2'))*100 /* - pred value */
	ret scalar P_0p_2 = (`b2'/(`a2'+`b2'))*100 /* false + given + */
	ret scalar P_1n_2 = (`c2'/(`c2'+`d2'))*100 /* false - given - */
	ret scalar roc2 = `roc2' /* roc curve */
	ret scalar rocp = `rocp'

	// svmachines predictions live in hardcoded (non-temporary) variables
	if "`model'" == "svmachines" {
		drop fullSVM fullSVM_1 fullSVM_0
	}

	// Produce output tables
	#delimit ;
	di _n ;
	di _n in gr `"Classification Table for Full Data:"' ;
	di _n in smcl in gr _col(15) "{hline 8} True {hline 8}" _n
		`"Classified {c |}"' _col(22) `"D"' _col(35)
		`"~D {c |}"' _col(46) `"Total"' ;
	di in smcl in gr "{hline 11}{c +}{hline 26}{c +}{hline 11}" ;
	di in smcl in gr _col(6) "+" _col(12) `"{c |} "'
		in ye %9.0g `a1' _col(28) %9.0g `b1'
		in gr `" {c |} "'
		in ye %9.0g `a1'+`b1' ;
	di in smcl in gr _col(6) "-" _col(12) "{c |} "
		in ye %9.0g `c1' _col(28) %9.0g `d1'
		in gr `" {c |} "'
		in ye %9.0g `c1'+`d1' ;
	di in smcl in gr "{hline 11}{c +}{hline 26}{c +}{hline 11}" ;
	di in smcl in gr `" Total {c |} "'
		in ye %9.0g `a1'+`c1' _col(28) %9.0g `b1'+`d1'
		in gr `" {c |} "'
		in ye %9.0g `a1'+`b1'+`c1'+`d1' ;

	di _n ;
	di _n in gr `"Classification Table for Test Data:"' ;
	di _n in smcl in gr _col(15) "{hline 8} True {hline 8}" _n
		`"Classified {c |}"' _col(22) `"D"' _col(35)
		`"~D {c |}"' _col(46) `"Total"' ;
	di in smcl in gr "{hline 11}{c +}{hline 26}{c +}{hline 11}" ;
	di in smcl in gr _col(6) "+" _col(12) `"{c |} "'
		in ye %9.0g `a2' _col(28) %9.0g `b2'
		in gr `" {c |} "'
		in ye %9.0g `a2'+`b2' ;
	di in smcl in gr _col(6) "-" _col(12) "{c |} "
		in ye %9.0g `c2' _col(28) %9.0g `d2'
		in gr `" {c |} "'
		in ye %9.0g `c2'+`d2' ;
	di in smcl in gr "{hline 11}{c +}{hline 26}{c +}{hline 11}" ;
	di in smcl in gr `" Total {c |} "'
		in ye %9.0g `a2'+`c2' _col(28) %9.0g `b2'+`d2'
		in gr `" {c |} "'
		in ye %9.0g `a2'+`b2'+`c2'+`d2' ;

	di _n ;
	di _n in gr `"Classified + if predicted Pr(D) >= `cutoff'"' _n
		`"True D defined as `dvar' != 0"' ;
	di in gr _col(45) `"Full"' _col(58) `"Test"';
	di in smcl in gr "{hline 64}" ;
	di in gr `"Sensitivity"' _col(33) `"Pr( +| D)"'
		in ye %8.2f return(P_p1_1) `"%"' _col(55) in ye %8.2f return(P_p1_2) `"%"' _n
		in gr `"Specificity"' _col(33) `"Pr( -|~D)"'
		in ye %8.2f return(P_n0_1) `"%"' _col(55) in ye %8.2f return(P_n0_2) `"%"' _n
		in gr `"Positive predictive value"' _col(33) `"Pr( D| +)"'
		in ye %8.2f return(P_1p_1) `"%"' _col(55) in ye %8.2f return(P_1p_2) `"%"' _n
		in gr `"Negative predictive value"' _col(33) `"Pr(~D| -)"'
		in ye %8.2f return(P_0n_1) `"%"' _col(55) in ye %8.2f return(P_0n_2) `"%"' ;
	di in smcl in gr "{hline 64}" ;
	di in gr `"False + rate for true ~D"' _col(33) `"Pr( +|~D)"'
		in ye %8.2f return(P_p0_1) `"%"' _col(55) in ye %8.2f return(P_p0_2) `"%"' _n
		in gr `"False - rate for true D"' _col(33) `"Pr( -| D)"'
		in ye %8.2f return(P_n1_1) `"%"' _col(55) in ye %8.2f return(P_n1_2) `"%"' _n
		in gr `"False + rate for classified +"' _col(33) `"Pr(~D| +)"'
		in ye %8.2f return(P_0p_1) `"%"' _col(55) in ye %8.2f return(P_0p_2) `"%"' _n
		in gr `"False - rate for classified -"' _col(33) `"Pr( D| -)"'
		in ye %8.2f return(P_1n_1) `"%"' _col(55) in ye %8.2f return(P_1n_2) `"%"';
	di in smcl in gr "{hline 64}" ;
	di in gr `"Correctly classified"' _col(42)
		in ye %8.2f return(P_corr_1) `"%"' _col(55) in ye %8.2f return(P_corr_2) `"%"' ;
	di in smcl in gr "{hline 64}" ;
	di in gr `"ROC area"' _col(42)
		in ye %9.4f return(roc1) _col(55) in ye %9.4f return(roc2) ;
	di in smcl in gr "{hline 64}" ;
	di in gr `"p-value for Full vs Test ROC areas"' _col(42) _col(55)
		in ye %9.4f return(rocp) ;
	di in smcl in gr "{hline 64}" ;

end ;