*! version 1.2  09dec2017 Michael D Barker Felix Pöge

/*******************************************************************************
Michael Barker
mdb96@georgetown.edu

Felix Pöge
felix.poege@ip.mpg.de

strdist.ado file
Implements distance calculation between two strings.
Uses the Levenshtein distance as the metric.

Thanks to Sergio Correia for his suggestions on how to improve strdist's
performance.

See also: ustrdist.ado (for version 14 and above, adds unicode support)

*******************************************************************************/

* Interpret the program definitions and Mata code that follow under Stata 10 rules
version 10.0

*** strdist: entry point. Takes exactly two arguments, each either a string
*** variable or a quoted string, and dispatches:
***   two quoted strings          -> strdist0var  (displays and returns r(d))
***   one or two string variables -> strdist12var (creates a new variable)
*** r(d) and r(strdist) are passed through from the subroutine that ran.
program define strdist, rclass
    syntax anything [if] [in], [Generate(name) MAXdist(integer -1)]

    * c(version) is the -version- setting in effect, not the Stata that is
    * running; c(stata_version) says whether ustrdist (Stata 14+) is usable.
    if c(stata_version) >= 14 {
        noi di as text "Hint: From Stata 14 onwards, use of the unicode-compatible ustrdist method is recommended."
    }

    * -1 is the "not specified" default; any other value below 1 is reported
    * and replaced by 0, which the subroutines treat as "no limit".
    if `maxdist' < 1 & `maxdist' != -1 {
        noi di as text "Invalid argument maxdist (`maxdist'). The Levenstein distance without limitation will be computed."
        local maxdist 0
    }

    * Split into two arguments; qed() flags the ones given in quotes.
    gettoken first remain : anything , qed(isstring1)
    gettoken second remain : remain , qed(isstring2)
    if `"`remain'"' != "" error 103
    * A single argument has nothing to be compared with: without this check a
    * lone variable was silently compared against "" (a lone quoted string
    * already failed, with a less helpful message).
    if `"`second'"' == "" & !`isstring2' error 102

    if `isstring1' & `isstring2' {
        * Two quoted strings
        strdist0var `if' `in' , first(`"`first'"') second(`"`second'"') gen(`generate') maxdist(`maxdist')
    }
    else {
        * One or two string variables; a quoted argument is passed as match()
        local strscalar ""
        if `isstring1' {
            * Plain macro assignment: "local x = ..." would evaluate the text
            * as a string expression, which can truncate long strings on
            * older Stata versions.
            local strscalar `"`first'"'
            local first ""
        }
        else if `isstring2' {
            local strscalar `"`second'"'
            local second ""
        }
        strdist12var `first' `second' `if' `in' , m(`"`strscalar'"') gen(`generate') maxdist(`maxdist')
    }

    * Pass through what the subroutine returned (r(d) is missing for the
    * variable version, which does not set it)
    return scalar d = r(d)
    return local strdist "`r(strdist)'"
end

*** strdist12var: variable version. For each observation selected by if/in,
*** computes the distance between the first string variable and either the
*** text in match() (one variable) or the second string variable (two
*** variables), and stores it in the new variable generate() (default: strdist).
*** maxdist() below 1 means no limit. Returns r(strdist) = new variable name.
program define strdist12var , rclass
    syntax varlist(min=1 max=2 string) [if] [in] , [Match(string)] [GENerate(name) MAXdist(integer 0)] 
    tempvar touse
    mark `touse' `if' `in'

    * Limits below 1 mean "no limit", passed to Mata as 0
    if `maxdist' < 1 {
        local maxdist 0
    }

    *** Declare default and confirm newvarname
    if `"`generate'"' == "" local generate "strdist"
    confirm new variable `generate'

    tokenize "`varlist'"
    if "`2'" == "" {
        * Read match() with st_local() instead of pasting its text into the
        * Mata statement, so quote characters in it cannot break the call
        mata: matalev1var("`1'", st_local("match"), "`generate'", "`touse'", `maxdist')
    }
    else {
        mata: matalev2var("`1'", "`2'", "`generate'", "`touse'", `maxdist')
    }

    return local strdist "`generate'"
end

*** strdist0var: string-scalar version. Computes the distance between the
*** texts in first() and second(), displays it and returns it in r(d). With
*** generate(), the same value is also stored in a new variable for the
*** observations selected by if/in (r(strdist) = its name).
*** maxdist() below 1 means no limit.
program define strdist0var, rclass
    syntax [if] [in], [first(string) second(string) Generate(name) MAXdist(integer 0)]
    marksample touse

    * Limits below 1 mean "no limit", passed to Mata as 0
    if `maxdist' < 1 {
        local maxdist 0
    }

    * Read both texts with st_local() instead of pasting them into the Mata
    * statement, so quote characters in them cannot break the call
    tempname dist
    mata: st_numscalar("`dist'", fastlev(st_local("first"), st_local("second"), `maxdist'))

    if `"`generate'"' != "" {
        confirm new variable `generate'
        * int holds at most 32,740; a larger distance would silently become
        * missing, so promote to long in that case
        local type = cond(`dist' > 32740 & `dist' < ., "long", "int")
        generate `type' `generate' = `dist' if `touse'
        return local strdist "`generate'"
    }

    display as result `dist'
    return scalar d = `dist'
end

mata:
/******************************************************************************
   Terminology note
   key:      string to measure each observation against
   trymatch: one of many potential matches to be measured against the key
   TRIES:    Nx1 vector of all "trymatch"es
   KEYS:     Nx1 vector of per-observation keys (two-variable version)
******************************************************************************/

// strdist_addvar(): adds the new variable -newvar- using the smallest of
// byte/int/long that can hold the largest value in -dist-, then stores -dist-
// into the observations selected by -touse-.
void strdist_addvar(string scalar newvar, string scalar touse, real colvector dist)
{
    real scalar   top
    string scalar vartype

    top = max(dist)
    if (top <= 100)                 vartype = "byte"   // byte holds values up to 100
    else if (top > 32740 & top < .) vartype = "long"   // int holds values up to 32,740
    else                            vartype = "int"    // also used when top is missing
    st_store(., st_addvar(vartype, newvar), touse, dist)
}

// matalev1var(): distance between the fixed string -key- and the string
// variable -varname- for every observation where -touse- is nonzero; the
// results go into the new variable -newvar-. maxdist >= 1 caps the distance
// (larger distances are stored as missing); a smaller or omitted maxdist
// means no cap.
void matalev1var(string scalar varname, string scalar key,
                 string scalar newvar, string scalar touse, | real scalar maxdist)
{
    string colvector TRIES
    real colvector   dist
    real scalar      limit, n, t

    // Work on a copy: Mata passes arguments by reference, so assigning to
    // maxdist would write into the caller's variable.
    limit = 0
    if (args() >= 5) {
        if (maxdist >= 1) limit = maxdist
    }

    TRIES = st_sdata(., varname, touse)   // Nx1 string vector with potential matches
    n     = rows(TRIES)
    dist  = J(n, 1, .)                    // Nx1 distances to each match
    for (t = 1; t <= n; t++) {
        dist[t] = fastlev(key, TRIES[t], limit)
    }
    strdist_addvar(newvar, touse, dist)
}

// matalev2var(): observation-by-observation distance between the string
// variables -var1- and -var2- where -touse- is nonzero; the results go into
// the new variable -newvar-. maxdist as in matalev1var().
void matalev2var(string scalar var1, string scalar var2,
                 string scalar newvar, string scalar touse, | real scalar maxdist)
{
    string colvector KEYS, TRIES
    real colvector   dist
    real scalar      limit, n, t

    // Work on a copy rather than assigning to the by-reference argument
    limit = 0
    if (args() >= 5) {
        if (maxdist >= 1) limit = maxdist
    }

    KEYS  = st_sdata(., var1, touse)      // Nx1 keys, one per observation
    TRIES = st_sdata(., var2, touse)      // Nx1 string vector with potential matches
    n     = rows(TRIES)
    dist  = J(n, 1, .)                    // Nx1 distances, row by row
    for (t = 1; t <= n; t++) {
        dist[t] = fastlev(KEYS[t], TRIES[t], limit)
    }
    strdist_addvar(newvar, touse, dist)
}

// fastlev(): Levenshtein distance between a and b, compared byte by byte
// (via ascii()), using the iterative two-row algorithm:
// https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_two_matrix_rows
// With maxdist >= 1, missing is returned as soon as the distance is known to
// exceed maxdist; with a smaller or omitted maxdist the exact distance is
// returned.
real scalar fastlev(string scalar a, string scalar b, | real scalar maxdist)
{
    real scalar    limit, len_a, len_b, i, j, ca, best
    real colvector v0, v1
    real rowvector chars_a, chars_b

    // Cap as a local copy (missing = no cap); maxdist itself is never
    // assigned because it may be the caller's variable.
    limit = .
    if (args() >= 3) {
        if (maxdist >= 1) limit = maxdist
    }

    if (a == b) return(0)                 // Shortcut if strings are the same
    len_a = strlen(a)
    len_b = strlen(b)
    // The length difference is a lower bound on the distance
    if (limit < . & abs(len_a - len_b) > limit) return(.)
    if (len_a == 0 | len_b == 0) return(len_a + len_b)   // One string is empty

    chars_a = ascii(a)
    chars_b = ascii(b)
    v0 = 0::len_b                         // previous row
    v1 = J(len_b + 1, 1, .)               // current row

    for (i = 1; i <= len_a; i++) {
        ca    = chars_a[i]
        v1[1] = i
        for (j = 1; j <= len_b; j++) {
            best = v0[j] + (ca != chars_b[j])                  // substitution (0 on a match)
            if (v1[j] + 1 < best)     best = v1[j] + 1         // insertion
            if (v0[j + 1] + 1 < best) best = v0[j + 1] + 1     // deletion
            v1[j + 1] = best
        }
        swap(v0, v1)
        // Row minima never decrease, so once the whole row exceeds the cap
        // the final distance must exceed it too
        if (limit < . & min(v0) > limit) return(.)
    }

    // Final check; with no cap (limit missing) this comparison is never true
    if (v0[len_b + 1] > limit) return(.)
    return(v0[len_b + 1])
}

end