/* _svm_train: this is the meat of the Stata interface to the fitting algorithm.
This is called by svm_train. Although Stata programs can call subprograms defined in the same file as them
(similar to Matlab), this has to live in a separate file, because the special command 'xi' used there
apparently cannot call subprograms defined in the same file.
*/
/* Load the C extension.
   These are top-level statements (outside any program definition): first make sure
   libsvm is available, then register the compiled plugin under the name _svmachines
   so that _svm_train can reach it with "plugin call _svmachines". */
svm_ensurelib // check for libsvm (svm_ensurelib is defined outside this section)
program _svmachines, plugin // load the wrapper for libsvm
/* _svm_train: fit a support vector machine through the _svmachines plugin.

   Syntax:
     _svm_train varlist [if] [in] [, type(string) kernel(string) gamma(#) coef0(#)
                degree(#) c(#) epsilon(#) nu(#) tolerance(#) shrinking probability
                cache_size(#) sv(newvarname) verbose seed(#)]

   - varlist must be numeric. For every type except ONE_CLASS the first variable is
     the dependent variable; for ONE_CLASS the first variable is duplicated so the
     plugin still receives something in the outcome slot.
   - type() defaults to SVC and kernel() to RBF; both are upper-cased before use.
   - For SVR / NU_SVR an integer-typed (byte/int/long) dependent variable is recast
     to double IN THE USER'S DATASET. This only happens after all option validation
     has passed, so a rejected call leaves the data untouched.
   - probability is only accepted with SVC / NU_SVC; otherwise the program exits
     with return code 2.
   - On success e() is cleared and repopulated (cmd, cmdline, predict, estat, title,
     model, svm_type, svm_kernel, depvar), then _svm_model2stata is called to append
     the fitted model.
*/
program define _svm_train, eclass
  version 13

  /* argument parsing */
  // these defaults were taken from svm-train.c
  // (except that we have shrinking off by default)
  #delimit ;
  syntax varlist (numeric)
         [if] [in]
         [,
           // strings cannot have default values
           // ints and reals *must*
           // (and yes the only other data types known to syntax are int and real, despite the stata datatypes being str, int, byte, float, double, ...)
           //
           // also be careful of the mixed-case shenanigans
           Type(string)
           Kernel(string)
           Gamma(real 0) COEF0(real 0) DEGree(int 3)
           C(real 1) EPSilon(real 0.1) NU(real 0.5)
           // weights() --> char* weight_label[], double weight[nr_weight] // how should this work?
           // apparently syntax has a special 'weights' argument which is maybe meant just for this purpose
           // but how to pass it on?
           TOLerance(real 0.001)
           SHRINKing PROBability
           CACHE_size(int 100)
           // if specified, a column to generate to mark which rows were detected as SVs
           SV(string)
           // turn on internal libsvm printing
           Verbose
           //set the C random seed
           seed(int 1)
         ];
  #delimit cr

  // Stash these because we run syntax again below, which will smash them.
  // Copy with compound double quotes and WITHOUT "=": the command line may itself
  // contain double quotes (e.g. if group == "a"), and evaluating "`0'" as a string
  // expression would then end the literal early and fail with a syntax error.
  local cmd `"`0'"'
  local _varlist `"`varlist'"'
  local _if `"`if'"'
  local _in `"`in'"'

  // make the string variables case insensitive (by forcing them to CAPS and letting the .c deal with them that way)
  local type = upper("`type'")
  local kernel = upper("`kernel'")

  // translate the boolean flags into integers
  // the protocol here is silly, because syntax special-cases "no" prefixes:
  // *if* the user gives the no form of the option, a macro is defined with "noprobability" in lower case in it
  // in all *other* cases, the macro is undefined (so if you eval it you get "")
  // conversely, with regular option flags, if the user gives it you get a macro with "shrinking" in it, and otherwise the macro is undefined
  if("`shrinking'"=="shrinking") {
    local shrinking = 1
  }
  else {
    local shrinking = 0
  }
  if("`probability'"=="probability") {
    local probability = 1
  }
  else {
    local probability = 0
  }

  /* fill in default values (only for the string vars, because syntax doesn't support defaults for them) */
  if("`type'"=="") {
    local type = "SVC"
  }
  if("`kernel'"=="") {
    local kernel = "RBF"
  }

  /* preprocessing */
  if("`type'" == "ONE_CLASS") {
    // handle the special case that one-class is unsupervised and so takes no
    // dependent variable.
    // libsvm still reads a Y vector though; it just, apparently, ignores it
    // rather than tweaking numbers to be off-by-one, the easiest is to silently
    // duplicate the pointer to one of the variables.
    // NB: depvar is deliberately left undefined in this branch.
    gettoken Y : _varlist
    local _varlist "`Y' `_varlist'"
  }
  else {
    gettoken depvar indepvars : _varlist
  }

  /* sanity checks */
  // The pure validation checks (probability, sv()) run first, so that a rejected
  // call cannot leave the user's dataset modified by the recast further down.
  if(`probability'==1) {
    // ensure model is a classification
    if("`type'" != "SVC" & "`type'" != "NU_SVC") {
      // the command line tools *allow* this combination, but at prediction time silently change the parameters
      // "Errors should never pass silently. Unless explicitly silenced." -- Tim Peters, The Zen of Python
      di as error "Error: requested model is a `type'. You can only use the probability option with classification models (SVC, NU_SVC)."
      exit 2
    }
  }

  if("`sv'"!="") {
    // fail-fast on name errors in sv()
    // (this re-runs syntax, which overwrites `varlist'; we use `_varlist' from here on)
    local 0 = "`sv'"
    syntax newvarname
  }

  if("`type'" == "SVC" | "`type'" == "NU_SVC") {
    // "ensure" type is categorical
    local T : type `depvar'
    /*
    if("`T'"=="float" | "`T'"=="double") {
    di as error "Warning: `depvar' is a `T', which is usually used for continuous variables."
    di as error " SV classification will cast real numbers to integers before fitting." //<-- this is done by libsvm with no control from us
    di as error
    di as error " If your outcome is actually categorical, consider storing it so:"
    di as error " . tempvar B"
    di as error " . generate byte \`B' = `depvar'" //CAREFUL: B is meant to be quoted and depvar is meant to be unquoted.
    di as error " . drop `depvar'"
    di as error " . rename \`B' `depvar'"
    di as error " (If your category coding uses floating point levels you must choose a different coding)"
    di as error
    di as error " Alternately, consider SV regression: type(SVR) or type(NU_SVR)."
    di as error
    }
    */
  }

  if("`type'" == "SVR" | "`type'" == "NU_SVR") {
    // "ensure" type is float or double
    // CAREFUL: this changes the storage type of the variable in the user's dataset.
    local T : type `depvar'
    if "`T'"=="byte" | "`T'"=="int" | "`T'"=="long" {
      di `"Your dependent variable is of type "`T'". "' ///
      "The prediction variable will take the same type. " ///
      `"To allow for continuous predictions, your dependent variable "`depvar'" has been recast to type "double" "'
      recast double `depvar'
    }
  }

  /* call down into C */
  /* CAREFUL: epsilon() => svm_param->p and tol() => svm_param->epsilon */
  #delimit ;
  plugin call _svmachines `_varlist' `_if' `_in',
  `verbose' // notice: this is *not* in quotes, which means that if it's not there it's not there at all
  "train"
  "`type'" "`kernel'"
  "`gamma'" "`coef0'" "`degree'"
  "`c'" "`epsilon'" "`nu'"
  "`tolerance'"
  "`shrinking'" "`probability'"
  "`cache_size'" "`seed'"
  ;
  #delimit cr

  // *reparse* the command line in order to fix varlist at its current value.
  // If "varlist" includes tokens that get expanded to multiple variables
  // then when svm_predict reparses it again, it will get a different set.
  // (compound quotes again, so a quoted if-expression or option survives the copy)
  local 0 `"`cmd'"'
  syntax varlist [if] [in], [*]
  local cmd `"`varlist' `if' `in', `options'"'

  /* fixup the e() dictionary */
  ereturn clear

  // set standard Stata estimation (e()) properties
  ereturn local cmd = "svmachines"
  ereturn local cmdline `"`e(cmd)' `cmd'"'
  ereturn local predict = "svm_predict" //this is a function pointer, or as close as Stata has to that: causes "predict" to run "svm_predict"
  ereturn local estat = "svm_estat" //ditto. NOT IMPLEMENTED

  ereturn local title = "Support Vector Machine"
  ereturn local model = "svmachines"
  ereturn local svm_type = "`type'"
  ereturn local svm_kernel = "`kernel'"

  ereturn local depvar = "`depvar'" //NB: if depvar is "", namely if we're in ONE_CLASS, then Stata effectively ignores this line (which we want).
  //ereturn local indepvars = "`indepvars'" //XXX Instead svm_predict reparses cmdline. This needs vetting.

  // append the svm_model structure to e()
  _svm_model2stata `_if' `_in', sv(`sv') `verbose'
end