/* model2stata: a subroutine to convert the global struct svm_model that lives in the DLL to a mixture of e() entries, variables, and matrices.
 *
 * Besides being usefully modular, this *must* be its own subroutine because it needs to be marked eclass.
 * This is because, due to limitations in the Stata C API, there has to be an awkward dance to get the information out:
 *   _svmachines.plugin writes to the (global!) scalar dictionary and then this subroutine code copies those entries to e().
 *
 * as with svm_load, the extension function is called multiple times with sub-sub-commands, because it doesn't have permission to perform all the operations needed

 * if passed, SV specifies a column (variable) to create and then record svm_model->sv_indices into
 */

/* Load the C extension.
 * These two commands run when this file is loaded, before _svm_model2stata is defined:
 *   1. svm_ensurelib checks for libsvm (presumably it must succeed before the plugin can load -- confirm in svm_ensurelib.ado);
 *   2. the _svmachines plugin, the wrapper around libsvm, is then registered under that name
 *      so the `plugin call _svmachines` statements below can reach it.
 */
svm_ensurelib           // check for libsvm
program _svmachines, plugin    // load the wrapper for libsvm

/* _svm_model2stata: copy the svm_model held by the _svmachines plugin into e().
 *
 * Syntax:   _svm_model2stata [if] [in] [, sv(newvarname) verbose]
 *
 * Options:
 *   sv(name)  create variable `name' as a 0/1 indicator of which observations are support vectors
 *   verbose   passed through to the plugin
 *
 * Results posted to e():
 *   scalars   e(N) (only if the plugin left _model2stata_N behind), e(N_class), e(N_SV)
 *   macros    e(levels)  (only if the plugin reported labels)
 *   matrices  e(sv_coef), e(rho)  (only if they could be allocated and filled)
 *
 * Side effects: the global matrix names sv_coef and rho and the global scalars
 * _model2stata_* are used as scratch space for talking to the plugin and are
 * consumed (dropped or moved into e()) along the way.
 */
program define _svm_model2stata, eclass
  version 13

  syntax [if] [in], [SV(string)] [Verbose]

  * as with loading, this has to call in and out of the plugin because chicken/egg:
  *   the plugin doesn't have permission to allocate Stata memory (in this case matrices),
  *   but we don't know how much to allocate before interrogating the svm_model


  * Phase 1
  * Fetch the model's dimensions

  * the total number of observations
  * this gets set by _svmachines.c::train(); it doesn't exist for a model loaded via import().
  * nevertheless it is in this file instead of svm_train.ado, because it is most similar here
  * but we cap { } around it so the other case is tolerable
  * NOTE: scalar() forces the name to be looked up as a scalar; a bare name in an expression
  *       would be resolved as a (possibly abbreviated) variable name first.
  capture {
    ereturn scalar N = scalar(_model2stata_N)
    scalar drop _model2stata_N
  }


  /* an undefined macro will inconsistently cause an eval error because `have_rho'==1 will eval to ==1 will eval to "unknown variable" */
  /* so define them ahead of time to be safe; presumably the plugin overwrites them during the calls below */
  local have_sv_indices = 0
  local have_sv_coef = 0
  local have_rho = 0
  local labels = ""


  plugin call _svmachines `if' `in', `verbose' "_model2stata" 1

  * the total number of (detected?) classes
  ereturn scalar N_class = scalar(_model2stata_nr_class)
  scalar drop _model2stata_nr_class

  * the number of support vectors
  ereturn scalar N_SV = scalar(_model2stata_l)
  scalar drop _model2stata_l



  * Phase 2
  * Allocate Stata matrices and copy the libsvm matrices and vectors

  * sv_coef and rho are fixed, global matrix names. Clear any leftovers first, so that
  * when allocation below is skipped or fails, Phase 4 cannot post a stale matrix
  * (from an earlier run, or the user's own) to e().
  capture matrix drop sv_coef
  capture matrix drop rho

  if(`have_sv_coef'==1 & `e(N_class)'>1 & `e(N_SV)'>0) {
    // with more than 11000 support vectors, don't create the sv_coef matrix unless running Stata/MP
    // This allows large data sizes in the other flavors of Stata, at the small cost of not getting e(sv_coef)
    if (e(N_SV)<=11000 | c(MP)==1) {
      capture noisily {
        matrix sv_coef = J(e(N_class)-1,e(N_SV),.)

        // there doesn't seem to be an easy way to generate a list of strings with a prefix in Stata
        // so: build it one name at a time.
        // The copy form (no "=") is deliberate: it appends without sending the ever-growing
        // list through the expression evaluator and its string-length limits.
        local cols
        forvalues j = 1/`e(N_SV)' {
          local cols `cols' SV`j'
        }
        matrix colnames sv_coef = `cols'

        // TODO: rows
        //  there is one row per class *less one*. the rows probably represent decision boundaries, then. I'm not sure what this should be labelled.
        // matrix rownames sv_coef = class1..class`e(N_SV)'
      }
    }
  }

  if(`have_rho'==1 & `e(N_class)'>0) {
    capture noisily matrix rho = J(e(N_class),e(N_class),.)
  }



  * TODO: also label the rows according to model->label (libsvm's "labels" are just more integers, but it helps to be consistent anyway);
  *  I can easily extract ->label with the same code, but attaching it to the rownames of the other is tricky
  capture noisily {
    plugin call _svmachines `if' `in', `verbose' "_model2stata" 2

    // Label the resulting matrices and vectors with the 'labels' array, if we have it
    if("`labels'"!="") {
      ereturn local levels = strtrim("`labels'")

      // best effort: rho may not exist, hence the captures
      capture matrix rownames rho = `labels'
      capture matrix colnames rho = `labels'
    }
  }

  * Phase 3
  * Export the SVs
  if("`sv'"!="") {
    if(`have_sv_indices'==0) {
      di as err "Warning: SV statuses missing. Perhaps your underlying version of libsvm is too old to support sv()."
    }
    else {
      capture noisily {
        // The internal libsvm format is a list of indices
        // we want indicators, which are convenient for Stata
        // so we *start* with all 0s (rather than missings) and overwrite with 1s as we discover SVs
        quietly generate `sv' `if' `in' = 0
        plugin call _svmachines `sv' `if' `in', `verbose' "_model2stata" 3
      }
    }
  }

  * Phase 4
  * Export the rest of the values to e()
  * We *cannot* export matrices to e() from the C interface, hence we have to do this very explicit thing
  * NOTE: 'ereturn matrix' erases the old name (unless you specify ,copy), which is why we don't have to explicitly drop things
  *       'ereturn scalar' doesn't do this, which is why the scalars above are dropped by hand.
  * All of these are silenced because various things might kill any of them, and we want failures to be independent of each other

  quietly capture ereturn matrix sv_coef = sv_coef
  quietly capture ereturn matrix rho = rho
end