*! Source of lmoremata11.mlib


*! {smcl}
*! {marker mm_density}{bf:mm_density.mata}{asis}
*! version 1.1.1  19nov2022  Ben Jann
version 11.2

// Stata local macros used as shorthand for names and type declarations in
// the Mata code below; each macro reference is replaced by the text defined
// here when the file is processed
// class & struct
local MAIN   mm_density
local SETUP  _`MAIN'_setup
local Setup  struct `SETUP' scalar
// real
local RS     real scalar
local RR     real rowvector
local RC     real colvector
local RV     real vector
// counters
local Int    real scalar
local IntC   real colvector
// string
local SS     string scalar
// boolean
local Bool   real scalar
local TRUE   1
local FALSE  0
// transmorphic
local T      transmorphic
local TS     transmorphic scalar
// pointers
local Pf     pointer(function) scalar
local PC     pointer(real colvector) scalar

mata:

// ---------------------------------------------------------------------------
// class definition
// ---------------------------------------------------------------------------

// container holding all settings of a density object (data pointers, kernel,
// bandwidth selection, support/boundary correction, approximation grid); the
// class stores one instance of it in its private member "setup"
struct `SETUP' {
    // data
    `PC'    X             // pointer to X
    `PC'    w             // pointer to w
    `RS'    nobs          // number of obs/sum of weights
    `Bool'  pw            // weights are sampling weights
    `Bool'  sorted        // whether data is sorted

    // kernel
    `SS'    kernel        // name of kernel
    `Int'   adapt         // stages of adaptive estimator
    `Pf'    k             // kernel function
    `Pf'    kd            // kernel derivative
    `Pf'    kint          // kernel integral function
    `RS'    kh            // canonical bandwidth of kernel

    // bandwidth selection
    `RS'    h0            // user provided bandwidth
    `SS'    bwmethod      // bandwidth estimation method
    `RS'    adjust        // bandwidth adjustment factor
    `Int'   dpi           // number of DPI stages; default is 2
    `Bool'  qui           // quietly (omit SJPI/ISJ failure message)

    // support/boundary correction
    `RS'    lb            // lower boundary (missing if unbounded)
    `RS'    ub            // upper boundary (missing if unbounded)
    `SS'    bcmethod      // boundary correction method
    `Int'   bc            // 0=none, 1=renorm, 2=reflect, 3=linear correction
    `Bool'  rd            // relative data

    // other settings
    `Int'   n             // size of approximation grid
    `RS'    pad           // padding proportion of approximation grid
}

// kernel density estimation class: settings are stored in "setup"; results
// (bandwidth, grids, estimates) are stored in private members that are filled
// on demand and reset by clear() whenever data or settings change
class `MAIN' {
    // settings
    public:
        void    data()        // set data
        `T'     kernel()      // kernel settings/retrieve kernel name
        `T'     bw()          // bandwidth settings
        `T'     support()     // support/boundary correction
        `T'     n()           // set/retrieve size of approximation grid
    private:
        void    new()         // initialize class with default settings
        void    clear()       // clear all results
        `Setup' setup         // settings
        `RC'    k(), kd(), kint() // raw kernel functions
        void    klc()         // compute terms for linear correction kernel
        `RC'    _K()          // kernel function including boundary correction
        void    checksuprt()  // check whether data is within support
    public:
        `RC'    X(), w()      // retrieve X and w
        `RS'    nobs()        // retrieve N (sum of weights)
        `Bool'  pw()          // retrieve pweights flag
        `Bool'  sorted()      // retrieve sorted flag
        `Int'   adapt()       // retrieve stages of adaptive estimator
        `RS'    kh()          // retrieve canonical bandwidth of kernel
        `RC'    K()           // kernel function for external use
        `RC'    Kd()          // kernel derivative for external use
        `RS'    adjust()      // retrieve bw adjustment factor
        `Int'   dpi()         // retrieve dpi level
        `RS'    lb(), ub()    // retrieve lower and upper bounds of support
        `SS'    bc()          // retrieve boundary-correction method
        `Bool'  rd()          // retrieve relative data flag
        `RS'    pad()         // retrieve padding proportion
    
    // results
    public:
        `RC'    d()          // estimate/retrieve d
        `RS'    h()          // estimate/retrieve h
        `RC'    at()         // retrieve at
        `RC'    l()          // compute/retrieve l
        `RC'    D()          // estimate/retrieve D
        `RC'    AT()         // generate/retrieve AT
        `RC'    W()          // compute/retrieve W
        `RC'    L()          // compute/retrieve L
    private:
        `RS'    _h()         // estimate/retrieve h; return error if missing
        `RC'    d            // density estimate
        `RS'    h            // bandwidth
        `RC'    at           // evaluation grid
        `RC'    l            // local bandwidth factors (at observation level)
        `RC'    D            // full grid approximation estimate
        `RC'    D0           // adapt-1 approximation estimate
        `RC'    D0()         // compute/retrieve adapt-1 approximation estimate
        `RC'    AT           // approximation grid
        `RC'    W            // grid counts
        `RC'    L            // local bandwidth factors
        
    // internal functions
    private:
        `RS'    h_sjpi(), h_isj(), h_dpi(), h_si(), h_ov(), h_no(), h_rd(), 
                _h_sjpi(), _h_isj(), h_root(), _h_root_fn()
        `Int'   _h_root()
        `RS'    df()
        `RC'    dd()
        void    dexact()
        `RC'    _dexact()
        `RC'    dapprox(), dapprox_fft(), dapprox_fft_rf(), dapprox_std()
        `RC'    lbwf()
        `RC'    ipolate()
        `RC'    grid()
        `RS'    scale()
}

// ---------------------------------------------------------------------------
// settings
// ---------------------------------------------------------------------------

void `MAIN'::new()
{   // constructor: call each setter with an empty/missing argument so that
    // the setter fills in its default (gaussian kernel, sjpi bandwidth,
    // unbounded support, grid size 1024)
    kernel("")
    bw("")
    support(.)
    n(.)
}

void `MAIN'::clear()
{   // discard all stored results: bandwidth becomes missing and every
    // result vector becomes an empty column vector
    h  = .
    L  = J(0,1,.)
    W  = J(0,1,.)
    AT = J(0,1,.)
    D0 = J(0,1,.)
    D  = J(0,1,.)
    l  = J(0,1,.)
    at = J(0,1,.)
    d  = J(0,1,.)
}

// D.data() -------------------------------------------------------------------

void `MAIN'::data(`RC' X, | `RC' w, `Bool' pw, `Bool' sorted)
{   // register data (pointers to X and w are stored, not copies)
    //   w      : weights; default is 1
    //   pw     : weights are sampling weights; default is no
    //   sorted : caller guarantees that X is sorted and non-missing and that
    //            w is non-missing and non-negative; all checks are skipped
    // defaults for omitted arguments
    if (args()<4) sorted = `FALSE'
    if (args()<3) pw = `FALSE'
    if (args()<2) w = 1
    // input checks (unless the caller vouches for the data)
    if (sorted==`FALSE') {
        if (missing(X) | missing(w)) _error(3351)
        if (any(w:<0)) {
            errprintf("{it:w} must not be negative\n")
            _error(3300)
        }
        checksuprt(X, lb(), ub())
    }
    // store settings and discard results based on previous data
    setup.sorted = (sorted!=`FALSE')
    setup.pw     = (pw!=`FALSE')
    setup.w      = &w
    setup.X      = &X
    setup.nobs   = mm_nobs(X, w)
    clear()
}

`RC' `MAIN'::X()
{   // return the data; empty column vector if no data has been set
    if (setup.X!=NULL) return(*setup.X)
    return(J(0,1,.))
}

`RC' `MAIN'::w()
{   // return the weights; 1 if no weights have been set
    if (setup.w!=NULL) return(*setup.w)
    return(1)
}

`RS' `MAIN'::nobs()
{   // number of observations (sum of weights) as stored by data()
    return(setup.nobs)
}

`Bool' `MAIN'::pw()
{   // sampling-weights flag as stored by data()
    return(setup.pw)
}

`Bool' `MAIN'::sorted()
{   // sorted flag as stored by data()
    return(setup.sorted)
}

// D.kernel() -----------------------------------------------------------------

`T' `MAIN'::kernel(| `SS' kernel0, `Int' adapt)
{   // without arguments: return the name of the current kernel
    // with arguments: set kernel (default "gaussian") and the number of
    // stages of the adaptive estimator (default 0), then clear all results
    `SS' name

    // get
    if (args()==0) return(setup.kernel)
    // set
    if (adapt<0) _error(3300)
    name = strlower(strtrim(kernel0))
    if (name=="") name = "gaussian"
    setup.kernel = _mm_unabkern(name)
    // look up the functions belonging to the kernel
    setup.k      = _mm_findkern(setup.kernel)
    setup.kint   = _mm_findkint(setup.kernel)
    setup.kd     = _mm_findkderiv(setup.kernel)
    setup.kh     = (*_mm_findkdel0(setup.kernel))()
    if (adapt<.) setup.adapt = trunc(adapt)
    else         setup.adapt = 0
    clear()
}

`RS' `MAIN'::kh()
{   // canonical bandwidth of the current kernel
    return(setup.kh)
}

`Int' `MAIN'::adapt()
{   // number of stages of the adaptive estimator
    return(setup.adapt)
}

`RC' `MAIN'::k(`RC' X)
{   // evaluate the raw kernel function at X
    `Pf' f

    f = setup.k
    return((*f)(X))
}

`RC' `MAIN'::kd(`RC' X)
{   // evaluate the raw kernel derivative function at X
    `Pf' f

    f = setup.kd
    return((*f)(X))
}

`RC' `MAIN'::kint(`RS' l, | `RC' X)
{   // evaluate the kernel integral function; X is passed through only if
    // it has been specified
    `Pf' f

    f = setup.kint
    if (args()!=1) return((*f)(l, X))
    return((*f)(l))
}

`RC' `MAIN'::K(`RC' X, `RC' x, `RC' h, | `Bool' fast)
{   // boundary-corrected kernel divided by h
    // fast!=0: do not reset results to 0 outside support
    `RC' res

    res = _K(X, x, h) :/ h
    // no boundary correction: nothing to reset
    if (setup.bc==0) return(res)
    // reset skipped on request (only if fast has been specified)
    if (fast & args()==4) return(res)
    // set result to zero wherever X or x lies outside the support
    if (lb()<.) res = res :* (X:>=lb() :& x:>=lb())
    if (ub()<.) res = res :* (X:<=ub() :& x:<=ub())
    return(res)
}

`RC' `MAIN'::_K(`RC' X, `RC' x, `RC' h)
{   // kernel function including boundary correction (not divided by h);
    // the correction applied depends on setup.bc
    `RC' z, res, a0, a1, a2

    z = (x:-X):/h   // standardized distances
    // unbounded support
    if (setup.bc==0) return(k(z))
    // renormalization: divide by the kernel mass within the support
    if (setup.bc==1) {
        res = k(z)
        if (ub()>=.) return(res :/ kint(1, (x:-lb()):/h)) // lower bound only
        if (lb()>=.) return(res :/ kint(1, (ub():-x):/h)) // upper bound only
        return(res :/ (kint(1, (ub():-x):/h) - kint(1, (lb():-x):/h)))
    }
    // reflection: add kernel evaluated at the data mirrored at each bound
    if (setup.bc==2) {
        res = k(z)
        if (lb()<.) res = res + k((x:-2*lb():+X):/h)
        if (ub()<.) res = res + k((x:-2*ub():+X):/h)
        return(res)
    }
    // linear correction
    if (setup.bc==3) {
        klc(x, h, a0=., a1=., a2=.)
        return((a2 :- a1:*z):/(a0:*a2-a1:^2) :* k(z))
    }
    _error(3498) // not reached
}

void `MAIN'::klc(`RC' x, `RC' h, `RC' a0, `RC' a1, `RC' a2)
{   // terms a0, a1, a2 of the linear correction kernel; results are
    // returned in the arguments a0, a1, and a2
    `RC' u

    if (ub()<.) {
        // contribution of the upper bound
        u  = (ub():-x):/h
        a0 =  kint(1, u)
        a1 = -kint(3, u)
        a2 =  kint(4, u)
        if (lb()<.) { // both bounds: add contribution of the lower bound
            u  = (lb():-x):/h
            a0 = a0 :- kint(1, u)
            a1 = a1 :+ kint(3, u)
            a2 = a2 :- kint(4, u)
        }
        return
    }
    // lower bound only
    u  = (x:-lb()):/h
    a0 = kint(1, u)
    a1 = kint(3, u)
    a2 = kint(4, u)
}

`RC' `MAIN'::Kd(`RC' X, `RC' x, `RC' h, | `Bool' fast)
{   // kernel derivative including boundary correction
    // fast!=0: do not reset results to 0 outside support
    `RC' z, res, a0, a1, a2

    z = (x:-X):/h   // standardized distances
    // unbounded support
    if (setup.bc==0) return(-kd(z) :/ h:^2)
    // renormalization
    if (setup.bc==1) {
        res = kd(z)
        if      (ub()>=.) res = res :/  kint(1, (x:-lb()):/h) // lower bound only
        else if (lb()>=.) res = res :/  kint(1, (ub():-x):/h) // upper bound only
        else res = res :/ (kint(1, (ub():-x):/h) - kint(1, (lb():-x):/h))
        res = -res :/ h:^2
    }
    // reflection
    else if (setup.bc==2) {
        res = kd(z)
        if (lb()<.) res = res - kd((x:-2*lb():+X):/h)
        if (ub()<.) res = res - kd((x:-2*ub():+X):/h)
        res = -res :/ h:^2
    }
    // linear correction
    else if (setup.bc==3) {
        klc(x, h, a0=., a1=., a2=.)
        res = (a1:*k(z) - (a2 :- a1:*z):*kd(z)) :/ ((a0:*a2-a1:^2):*h:^2)
    }
    else _error(3498) // not reached
    // reset to zero outside of support (unless skipped on request)
    if (fast & args()==4) return(res)
    if (lb()<.) res = res :* (X:>=lb() :& x:>=lb())
    if (ub()<.) res = res :* (X:<=ub() :& x:<=ub())
    return(res)
}

// D.bw() -------------------------------------------------------------------

`T' `MAIN'::bw(| `TS' bw, `RS' adj, `Int' dpi, `Bool' qui)
{   // without arguments: return the user bandwidth if one has been set, else
    // the name of the bandwidth selection method
    // with arguments: set bandwidth (numeric bw) or selection method (string
    // bw; default "sjpi"), adjustment factor (default 1), number of DPI
    // stages (default 2), and the quiet flag; then clear all results
    `RS' h0
    `SS' name

    // get
    if (args()==0) {
        if (setup.h0>=.) return(setup.bwmethod)
        return(setup.h0)
    }
    // set
    if (adj<=0) _error(3300)
    if (dpi<0)  _error(3300)
    if (args()<4) qui = `FALSE'
    if (isstring(bw)) name = bw
    else {
        if (bw<=0) _error(3300)
        if (bw<.) h0 = bw   // h0 stays missing if bw is missing
    }
    name = mm_strexpand(strlower(strtrim(name)),
        ("silverman", "normalscale", "oversmoothed", "sjpi", "dpi", "isj"),
        "sjpi")
    setup.h0       = h0
    setup.bwmethod = name
    if (adj<.) setup.adjust = adj
    else       setup.adjust = 1
    if (dpi<.) setup.dpi = trunc(dpi)
    else       setup.dpi = 2
    setup.qui      = (qui!=`FALSE')
    clear()
}

`RS' `MAIN'::adjust()
{   // bandwidth adjustment factor
    return(setup.adjust)
}

`Int' `MAIN'::dpi()
{   // number of DPI stages
    return(setup.dpi)
}

// D.support() ----------------------------------------------------------------

`T' `MAIN'::support(| `RV' minmax, `SS' method, `Bool' rd)
{   // without arguments: return (lower bound, upper bound)
    // with arguments: set bounds of the support (missing = unbounded),
    // boundary correction method (default "renormalization"), and the
    // relative data flag; then clear all results
    `RS'  lo, up
    `Int' len

    // get
    if (args()==0) return((setup.lb, setup.ub))
    // set
    if (args()<3) rd = `FALSE'
    len = length(minmax)
    if (len>2) _error(3200)
    lo = .
    up = .
    if (len>0) lo = minmax[1]
    if (len>1) up = minmax[2]
    if (lo<. & up<. & lo>=up) _error(3300)
    if (rd) {
        // relative data: support defaults to [0,1] and must lie within it
        if (lo >= .)  lo = 0
        if (lo <  0)  _error(3300)
        if (up >= .)  up = 1
        if (up >  1)  _error(3300)
        if (lo > up)  _error(3300)
    }
    if (sorted()==`FALSE') checksuprt(X(), lo, up)
    setup.bcmethod = mm_strexpand(strlower(strtrim(stritrim(method))),
        ("renormalization", "reflection", "linear correction"),
        "renormalization")
    setup.lb = lo
    setup.ub = up
    setup.rd = (rd!=`FALSE')
    // numeric code of the boundary correction; 0 if the support is unbounded
    if (lo>=. & up>=.)                              setup.bc = 0
    else if (setup.bcmethod=="renormalization")     setup.bc = 1
    else if (setup.bcmethod=="reflection")          setup.bc = 2
    else if (setup.bcmethod=="linear correction")   setup.bc = 3
    else _error(3498) // cannot be reached
    clear()
}

`RS' `MAIN'::lb()
{   // lower bound of support (missing if unbounded)
    return(setup.lb)
}

`RS' `MAIN'::ub()
{   // upper bound of support (missing if unbounded)
    return(setup.ub)
}

`SS' `MAIN'::bc()
{   // name of the boundary-correction method
    return(setup.bcmethod)
}

`Bool' `MAIN'::rd()
{   // relative data flag
    return(setup.rd)
}

void `MAIN'::checksuprt(`RC' X, `RS' lb, `RS' ub)
{   // abort with error if X contains values below lb or above ub; a missing
    // bound is not checked
    `Bool' outside

    if (rows(X)==0) return // no data
    outside = `FALSE'
    if (lb<.)            outside = (min(X)<lb)
    if (!outside & ub<.) outside = (max(X)>ub)
    if (outside) {
        errprintf("{it:X} contains values out of support\n")
        _error(3300)
    }
}

// D.n() ----------------------------------------------------------------------

`T' `MAIN'::n(| `Int' n, `RS' pad)
{   // without arguments: return the size of the approximation grid
    // with arguments: set grid size (default 1024) and padding proportion
    // (default 0.1), then clear all results
    // get
    if (args()==0) return(setup.n)
    // set
    if (n<1)   _error(3300)
    if (pad<0) _error(3300)
    if (n>=.)   setup.n = 1024
    else        setup.n = trunc(n)
    if (pad>=.) setup.pad = 0.1
    else        setup.pad = pad
    clear()
}

`RS' `MAIN'::pad()
{   // padding proportion of the approximation grid
    return(setup.pad)
}

// ---------------------------------------------------------------------------
// bandwidth selection
// ---------------------------------------------------------------------------

`RS' `MAIN'::h()
{   // return the bandwidth; it is determined on the first call and stored
    // result is missing if no strictly positive bandwidth could be obtained
    `SS' method

    if (h<.) return(h)
    // user bandwidth
    if (setup.h0<.) {
        h = setup.h0 * adjust()
        return(h)
    }
    // data-driven bandwidth selection
    method = bw()
    if (method=="sjpi") {
        h = h_sjpi()
        if (h>=. & setup.qui==`FALSE')
            display("{txt}(SJPI bandwidth estimation failed; using DPI method)")
        if (h>=.) h = h_dpi()
    }
    else if (method=="isj") {
        h = h_isj()
        if (h>=. & setup.qui==`FALSE')
            display("{txt}(ISJ bandwidth estimation failed; using DPI method)")
        if (h>=.) h = h_dpi()
    }
    else if (method=="dpi")          h = h_dpi()
    else if (method=="silverman")    h = h_si()
    else if (method=="oversmoothed") h = h_ov()
    else if (method=="normalscale")  h = h_no()
    else _error(3498) // cannot be reached
    // correction for pweights
    if (pw() & rows(w())!=1) {
        h = h * (sum(w():^2)/(nobs()))^.2
    }
    // rescale to the kernel in use and apply the adjustment factor
    h = h * kh() * adjust()
    if (h<=0) h = . // h must be strictly positive
    return(h)
}

`RS' `MAIN'::h_rd(`RS' s) // relative data correction factor
{   // factor is 1 unless the relative data flag is set
    return(rd() ? (1 + 1/(2 * sqrt(pi()) * s))^.2 : 1)
}

`RS' `MAIN'::_h()
{   // like h(), but abort with error if the bandwidth is missing
    `RS' bw

    bw = h()
    if (bw>=.) {
        errprintf("bandwidth could not be determined\n")
        _error(3498)
    }
    return(bw)
}

// Sheather-Jones solve-the-equation plug-in selection rule -------------------

// SJPI bandwidth based on the binned data (AT, W); the bandwidth is the root
// of the objective function _h_sjpi(), searched for by h_root() starting from
// the oversmoothed bandwidth h_os and not going below hmin; the result is
// missing if h_root() returns missing
`RS' `MAIN'::h_sjpi()
{
    `RS'  n, s, hmin, h_os, sda, tdb, tdc, alpha, beta
    `RC'  AT, W
    
    AT = AT()
    W = W()
    s = scale(0, AT, W, 1)           // min of sd and iqr
    if (pw()) {                      // pweights: normalize grid counts
        n = rows(X())
        W = W * (n / nobs())
    }
    else n = nobs()
    // pilot estimates of density functionals
    sda = df(AT, W, n, 1.241 * s * n^(-1/7), 2)
    tdb = df(AT, W, n, 1.230 * s * n^(-1/9), 3)
    alpha = 1.357 * (sda/tdb)^(1/7)
    if (rd()) {                      // additional term for relative data
        tdc = df(AT, W, n, 1.304 * s * n^(-1/5), 1)
        beta = 1.414 * (sda/tdc)^(1/3)
    }
    // lower limit of search range: proportional to the grid step width
    hmin = .5 * (AT[n()]-AT[1])/(n()-1) * mm_kdel0_gaussian()/mm_kdel0_rectangle()
    h_os = s * (243/(35*n))^.2 * mm_kdel0_gaussian() * h_rd(s) // oversmoothed (modified)
        // // plot objective function:
        // `Int' i
        // `RC'  at, hh
        // at = rangen(hmin, 3*h_os, 100)
        // hh = J(rows(at),1,.)
        // for (i=rows(at);i;i--) hh[i] = _h_sjpi(at[i], AT, W, n, alpha, beta)
        // mm_plot((hh,at),"line",sprintf("xline(%g)",h_os))
    // solve; rescale from n to nobs() and remove the gaussian canonical
    // bandwidth (h() multiplies by kh() later on)
    return(h_root(0, hmin, h_os, AT, W, n, alpha, beta) * 
        (n/nobs())^.2 / mm_kdel0_gaussian())
}

// search for a root of the SJPI (m!=1) or ISJ (m=1) objective function
//   hmin  : smallest admissible solution
//   h_os  : upper end of the first search interval
//   o1-o5 : passed through to the objective function
// the first interval is [max(hmin, h_os/2), h_os]; if _h_root() reports that
// there is no sign change within the interval (return code 2 or 3), the
// interval is moved down (code 2) or up (code 3) until the return code
// changes; a solution at or below hmin is discarded and the search is
// restarted above h_os; returns missing if no solution is found
`RS' `MAIN'::h_root(`Int' m, `RS' hmin, `RS' h_os,
    | `T' o1, `T' o2, `T' o3, `T' o4, `T' o5)
{
    `RS'  ax, bx, h
    `Int' rc
    
    bx = h_os
    ax = max((hmin, bx/2))
    rc = _h_root(m, h=., ax, bx, o1, o2, o3, o4, o5)
    if (rc==2) { // move down
        bx = ax
        while (1) {
            if (bx<=hmin) break // cannot go below hmin
            ax = max((hmin, bx/2))
            rc = _h_root(m, h, ax, bx, o1, o2, o3, o4, o5)
            if (rc==2) bx = ax // continue moving down
            else break // this also stops if there is a change in 
                      //  direction (rc==3) (in this case: h = current bx)
        }
    }
    else if (rc==3) { // move up
        ax = bx; bx = ax*1.5
        while (1) {
            rc = _h_root(m, h, ax, bx, o1, o2, o3, o4, o5)
            if (rc==3) {; ax = bx; bx = ax*1.5; } // continue moving up
            else break // this also stops if there is a change in 
                       //  direction (rc==2) (in this case: h = current ax)
        }
    }
    if (h<=hmin) { // solution is smaller than hmin
        ax = h_os; bx = ax*1.5
        rc = _h_root(m, h, ax, bx, o1, o2, o3, o4, o5)
        if (rc==2) return(.) // moving up does not help
        if (rc==3) { // continue moving up
            ax = bx; bx = ax*1.5
            while (1) {
                rc = _h_root(m, h, ax, bx, o1, o2, o3, o4, o5)
                if (rc==3) {; ax = bx; bx = ax*1.5; } // continue moving up
                else break // this also stops if there is a change in 
                           //  direction (rc==2) (in this case: h = current ax)
            }
        }
    }
    return(h)
}

// find a root of the objective function within [ax, bx]; the solution is
// returned in x (missing if the function evaluates to missing)
// return codes:
//   0 = converged, or aborted because the function evaluated to missing
//   1 = maximum number of iterations reached (x = current estimate)
//   2 = function has same sign at ax and bx and is closer to zero at ax
//       (x = ax)
//   3 = function has same sign at ax and bx and is not closer to zero at ax
//       (x = bx)
`Int' `MAIN'::_h_root(`Int' m, `RS' x, `RS' ax, `RS' bx, 
    | `T' o1, `T' o2, `T' o3, `T' o4, `T' o5)
{   // root finder for SJPI (m!=1) and ISJ (m=1); adapted from mm_root()
    `Int' maxit, itr
    `RS'  tol, tol_act
    `RS'  a, b, c, fa, fb, fc, prev_step, p, q, new_step, t1, cb, t2

    tol   = 0
    maxit = 100
    x = .; a = ax; b = bx
    fa = _h_root_fn(m, a, o1, o2, o3, o4, o5)
    fb = _h_root_fn(m, b, o1, o2, o3, o4, o5)
    c = a; fc = fa
    if ( fa==. ) return(0) // abort if fa missing => x=.
    // no sign change within [ax, bx]: report direction to the caller
    if ( (fa > 0 & fb > 0) | (fa < 0 & fb < 0) ) {
        if ( abs(fa) < abs(fb) ) {
            x = a; return(2)
        }
        x = b; return(3)
    }
    for (itr=1; itr<=maxit; itr++) {
        if ( fb==. ) return(0) // abort if fb missing => x=.
        prev_step = b-a
        // make b the point at which the function is closest to zero
        if ( abs(fc) < abs(fb) ) {
            a = b;  b = c;  c = a; fa = fb; fb = fc; fc = fa
        }
        tol_act = 2*epsilon(b) + tol/2
        new_step = (c-b)/2 // default step: bisection
        if ( abs(new_step) <= tol_act | fb == 0 ) {
             x = b
             return(0)
        }
        // try interpolation if the previous step was large enough and
        // moved in the right direction
        if ( abs(prev_step) >= tol_act & abs(fa) > abs(fb) ) {
            cb = c-b
            if ( a==c ) { // only two distinct points: linear interpolation
                t1 = fb/fa
                p = cb*t1
                q = 1.0 - t1
            }
            else {        // inverse quadratic interpolation
                q = fa/fc;  t1 = fb/fc;  t2 = fb/fa
                p = t2 * ( cb*q*(q-t1) - (b-a)*(t1-1.0) )
                q = (q-1.0) * (t1-1.0) * (t2-1.0)
            }
            if ( p>0 ) q = -q
            else      p = -p
            // accept interpolated step only if it is within bounds
            if ( p < (0.75*cb*q-abs(tol_act*q)/2) & p < abs(prev_step*q/2) )
                new_step = p/q
        }
        // step must not be smaller than the tolerance
        if ( abs(new_step) < tol_act ) {
            if ( new_step > 0 ) new_step = tol_act
            else                new_step = -tol_act
        }
        a = b;  fa = fb
        b = b + new_step; fb = _h_root_fn(m, b, o1, o2, o3, o4, o5)
        // keep b and c on opposite sides of the root
        if ( (fb > 0 & fc > 0) | (fb < 0 & fc < 0) ) {
            c = a;  fc = fa
        }
    }
    x = b
    return(1)
}

`RS' `MAIN'::_h_root_fn(`Int' m, `RS' x, | `T' o1, `T' o2, `T' o3, `T' o4, `T' o5)
{   // evaluate the objective function selected by m at x
    if (m!=1) return(_h_sjpi(x, o1, o2, o3, o4, o5))    // SJPI method
    return(_h_isj(x, o1, o2, o3))                        // ISJ method
}

`RS' `MAIN'::_h_sjpi(`RS' h, `RC' AT, `RC' W, `RS' n, `RS' alpha, `RS' beta)
{   // SJPI objective function; its root in h is the bandwidth
    `RS' num, den

    num = mm_kint_gaussian(2)
    // relative data: additional term in the numerator
    if (rd()) num = num * (1 + df(AT, W, n, beta * h^(5/3), 0))
    den = n * df(AT, W, n, alpha * h^(5/7), 2)
    return((num / den)^0.2 - h)
}

`RS' `MAIN'::df(`RC' AT, `RC' W, `RS' n, `RS' h, `Int' d)
{   // density functional: sign-adjusted weighted average of dd() over the
    // grid, using the grid counts W as weights
    `RS' total

    total = sum(W :* dd(AT, W, n, h, d))
    return( (-1)^d * total / n )
}

`RC' `MAIN'::dd(`RC' AT, `RC' W, `RS' n, `RS' h, `Int' d)
{    // d'th density derivative using gaussian kernel
     // (binned estimator on the grid AT, computed using convolve())
     //   AT : evaluation grid
     //   W  : grid counts
     //   n  : sample size used for normalization
     //   h  : bandwidth
     //   d  : order; the gaussian density is multiplied by the Hermite
     //        polynomial of degree 2*d
     // if the support is bounded, the grid counts are mirrored at each
     // finite boundary (reflection estimator)
    `Int' M, L, i, first, last
    `RS'  a, b
    `RC'  kappam, arg, hmold0, hmold1, hmnew, w
    
    // compute kappam
    M = rows(AT)
    a = AT[1]; b = AT[M]
    // maximum number of kernel evaluation points: the kernel may extend over
    // the mirrored grid counts; parentheses are required because relational
    // operators have lower precedence than addition
    L = (M-1) * (1 + (lb()<.) + (ub()<.))
    L = max( (min( (floor((4 + 2*d) * h * (M-1)/(b-a)), L) ), 1) )
    arg = (0::L) * (b-a) / (h*(M-1))
    kappam = normalden(arg)
    hmold0 = 1; hmold1 = arg; hmnew  = 1
    for (i=2; i<=2*d; i++) { // compute mth degree Hermite polynomial
        hmnew  = arg :* hmold1 :- (i-1) * hmold0
        hmold0 = hmold1; hmold1 = hmnew
    }
    kappam = hmnew :* kappam
    
    // unbounded estimator (neither lower nor upper bound specified)
    if (lb()>=. & ub()>=.) {
        return( convolve((kappam[L+1::1] \ kappam[|2 \ L+1|]), 
            W)[|L+1 \ L+M|] / (n*h^(2*d+1)) )
    }
    
    // reflection estimator: expand grid counts by their mirror image at
    // each finite bound; the count at the bound itself is doubled
    w = W
    if (lb()<.) {
        first = M
        w = W[M::2] \ w
        w[first] = 2 * w[first]
    }
    else first = 1
    last = first + (M-1) // first and last: position of original grid in w
    if (ub()<.) {
        w = w \ W[M-1::1]
        w[last] = 2 * w[last]
    }
    return( convolve((kappam[L+1::1] \ kappam[|2 \ L+1|]), 
        w)[|L+first \ L+last|] / (n*h^(2*d+1)) )
}

// "improved" SJPI (diffusion method) -----------------------------------------

// ISJ bandwidth: data are binned on a regular grid whose size is a power of
// 2 (at least 1024), the binned data are transformed, and the root of the
// objective function _h_isj() is searched for by h_root(); the root is a
// squared bandwidth on the scale of the grid range (square root and
// multiplication by the grid range are applied to the result); the result is
// missing if h_root() returns missing
`RS' `MAIN'::h_isj()
{
    `Int' n
    `RS'  N, s, hmin, h_os
    `RC'  AT, W, a
    
    // step 1: bin data on regular grid
    n = 2^ceil(ln(n())/ln(2))        // round up to next power of 2
    n = max((n, 1024))               // enforce min grid size of at least 1024
    AT = grid(n)                     // generate grid
    if (sorted()) W = _mm_exactbin(X(), w(), grid(n+1))
    else          W = mm_fastexactbin(X(), w(), grid(n+1))
        // need to use exact binning because linear binning would introduce
        // some (non-vanishing) bias at the boundaries (doubling the first and
        // last grid count does not seem to help); a consequence of exact 
        // binning is that the density estimate will be slightly shifted/stretched
        // to the left; this error can be substantial if the grid size is small,
        // but it vanishes with increasing grid size
    s = scale(0, AT, W, 1)           // min of sd and iqr
    W = W / nobs()                   // relative frequencies
    if (pw()) N = rows(X())          // obtain sample size
    else      N = nobs()
    // step 2: obtain discrete cosine transform of binned data
    a = Re( (1 \ 2 * exp(1i * (1::n-1) * pi() / (2*n)))
         :* fft(W[mm_seq(1,n-1,2)] \ W[mm_seq(n,2,2)]) )
    // step 3: compute bandwidth
    hmin = (.5/(n-1) * mm_kdel0_gaussian()/mm_kdel0_rectangle())^2
    h_os = (s * (243/(35*N))^.2 * mm_kdel0_gaussian() * h_rd(s) / (AT[n]-AT[1]))^2
        // // plot objective function:
        // `Int' i
        // `RC'  at, hh
        // at = rangen(hmin, 3*h_os, 100)
        // hh = J(rows(at),1,.)
        // for (i=rows(at);i;i--) hh[i] = _h_isj(at[i], N, (1::n-1):^2, (a[2::n]/2):^2)
        // mm_plot((hh,at),"line",sprintf("xline(%g)",h_os))
    return(sqrt(h_root(1, hmin, h_os, N, (1::n-1):^2, (a[2::n]/2):^2)) * 
        (AT[n]-AT[1]) * (N/nobs())^.2 / mm_kdel0_gaussian() * h_rd(s))
}

// ISJ objective function; its root in h is searched for by h_root()
//   h  : candidate value
//   N  : sample size
//   I  : squared indices as passed by h_isj(), i.e. (1::n-1):^2
//   a2 : squared and rescaled transform coefficients as passed by h_isj()
`RS' `MAIN'::_h_isj(`RS' h, `RS' N, `RC' I, `RC' a2)
{
    `Int' l, s
    `RS'  K0, c
    `RC'  f, t
    
    l = 7 // number of stages
    f = 2 * pi()^(2*l) * sum(I:^l :* a2 :* exp(-I * pi()^2 * h))
    // work down from stage l-1 to stage 2, each stage using the result of
    // the previous stage
    for (s=l-1; s>=2; s--) {
        K0 = mm_prod(mm_seq(1, 2*s-1, 2)) / sqrt(2*pi())
        c  = (1 + (1/2)^(s + 1/2)) / 3
        t  = (2 * c * K0/N/f):^(2/(3 + 2*s))
        f  = 2 * pi()^(2*s) * sum(I:^s :* a2 :* exp(-I * pi()^2 * t))
    }
    return((2 * N * sqrt(pi()) * f)^(-2/5) - h)
}

// Sheather-Jones direct plug-in selection rule -------------------------------

// DPI bandwidth with dpi() stages based on the binned data (AT, W); the
// normal scale bandwidth is returned if the number of stages is zero
`RS' `MAIN'::h_dpi()
{
    `RS'  n, s, alpha, psi, psi0, alpha0
    `Int' i
    `RC'  AT, W
    
    i = dpi()
    if (i==0) return(h_no()) // h normalscale
    else {
        AT = AT()
        W  = W()
        s = scale(0, AT, W, 1)           // min of sd and iqr
        if (pw()) {                      // pweights: normalize grid counts
            n = rows(X())
            W = W * (n / nobs())
        }
        else n = nobs()
        // pilot bandwidths for the first stage, based on scale estimate s
        alpha = (2 * (sqrt(2) * s)^(3 + 2 * (i+1)) /
                ((1 + 2 * (i+1)) * n))^(1/(3 + 2 * (i+1)))
        if (rd()) alpha0 = (2 * (sqrt(2) * s)^(3 + 2 * (i-1)) /
                           ((1 + 2 * (i-1)) * n))^(1/(3 + 2 * (i-1)))
        // in each stage, estimate the density functionals and use them to
        // update the pilot bandwidths for the next stage
        for (; i; i--) {
            psi = df(AT, W, n, alpha, i+1)
            if (rd()) psi0 = df(AT, W, n, alpha0, i-1) // relative data only
            else      psi0 = 0
            if (i>1) {
                alpha = ( factorial(i*2) / (2^i * factorial(i)) *
                          sqrt(2/pi()) / (psi*n) )^(1/(3 + 2*i))
                if (rd()) alpha0 = ( factorial((i-2)*2) / (2^(i-2) * 
                                   factorial(i-2)) * sqrt(2/pi()) / 
                                   (psi0*n) )^(1/(3 + 2*(i-2)))
            }
        }
    }
    // bandwidth based on the functionals of the last stage
    return( ((1 + psi0) / (psi * nobs()))^.2 )
}

// optimal of Silverman selection rule ----------------------------------------

`RS' `MAIN'::h_si()
{   // bandwidth based on the minimum of sd and iqr of the data
    `RS' s, c

    c = 0.9/mm_kdel0_gaussian()
    s = scale(0, X(), w(), sorted()) // min of sd and iqr
    return(c * s / nobs()^.2 * h_rd(s))
}

// oversmoothed selection rule ------------------------------------------------

`RS' `MAIN'::h_ov()
{   // bandwidth based on the sd of the data
    `RS' s, c

    c = (243/35)^.2
    s = scale(1, X(), w(), sorted()) // sd
    return(c * s / nobs()^.2 * h_rd(s))
}

// normal scale selection rule ------------------------------------------------

`RS' `MAIN'::h_no()
{   // bandwidth based on the minimum of sd and iqr of the data
    `RS' s, c

    c = (8*sqrt(pi())/3)^.2
    s = scale(0, X(), w(), sorted()) // min of sd and iqr
    return(c * s / nobs()^.2 * h_rd(s))
}

// ---------------------------------------------------------------------------
// density estimation
// ---------------------------------------------------------------------------

`RC' `MAIN'::d(| `RV' o1, `RS' o2, `RS' o3, `RS' o4)
{   // estimate/retrieve the density
    //   d()                    : return existing result
    //   d(at [, exact])        : estimate at the points in at
    //   d(n, from, to [, exact]) : estimate on a grid of n points
    // exact!=0 requests the exact estimator; the default is to interpolate
    // the binned approximation estimator
    `Int'  nargs
    `Bool' exact

    nargs = args()
    // case 0: return existing result
    if (nargs==0) return(d)
    if (nargs<=2) {
        // case 1: o1 contains grid (stored as column vector)
        if (cols(o1)==1) at = o1
        else             at = o1'
        exact = (nargs==2 & o2)
    }
    else {
        // case 2: o1=n, o2=from, o3=to
        at = grid(o1, _h(), o2, o3)
        exact = (nargs==4 & o4)
    }
    if (exact) dexact()
    else       d = ipolate(AT(), D(), at, mm_issorted(at))
    return(d)
}

`RC' `MAIN'::at()
{   // evaluation grid of the last call to d()
    return(at)
}

// exact estimator ------------------------------------------------------------

void `MAIN'::dexact()
{   // exact density estimate at the points in at; result is stored in d
    `Int'  n
    `IntC' keep

    n = rows(at)
    if (setup.bc & n>0) {
        // density is set to zero at evaluation points out of support; the
        // estimator is only applied to the points within the support
        if (lb()<. & ub()<.) keep = select(1::n, at:>=lb() :& at:<=ub())
        else if (lb()<.)     keep = select(1::n, at:>=lb())
        else if (ub()<.)     keep = select(1::n, at:<=ub())
        if (length(keep)==0) { // all points out of support
            d = J(n,1,0)
            return
        }
        if (length(keep)!=n) { // some points out of support
            d = J(n,1,0)
            d[keep] = _dexact(X(), w(), _h() :* l(), at[keep])
            return
        }
    }
    d = _dexact(X(), w(), _h() :* l(), at)
}

`RC' `MAIN'::_dexact(`RC' x, `RC' w, `RC' h, `RC' at)
{   // weighted sum of kernel contributions of x at each point in at,
    // divided by nobs(); w and h may be scalar or may have one row per row
    // of x
    `Int'  i, n
    `Bool' hconst, wconst
    `RC'   res

    n      = rows(at)
    res    = J(n,1,.)
    hconst = (rows(h)==1)
    wconst = (rows(w)==1)
    // scalar factors are moved out of the sum to save a bit of computer time
    if (hconst & wconst) {
        for (i=1;i<=n;i++) res[i] = w/h * sum(_K(x, at[i], h))
    }
    else if (hconst) {
        for (i=1;i<=n;i++) res[i] = sum(w :* _K(x, at[i], h)) / h
    }
    else if (wconst) {
        for (i=1;i<=n;i++) res[i] = w * sum(_K(x, at[i], h) :/ h)
    }
    else {
        for (i=1;i<=n;i++) res[i] = sum(w:/h :* _K(x, at[i], h))
    }
    return(res / nobs())
}

`RC' `MAIN'::l()
{   // local bandwidth factors at the observation level; returns 1 if the
    // estimator is not adaptive
    if (rows(l)) return(l)
    if (!adapt()) return(1)
    // store the result so that the factors are not recomputed on each call;
    // clear() resets l whenever data or settings change
    l = lbwf(ipolate(AT(), D0(), X(), sorted()), w())
    return(l)
}

// binned approximation estimator ---------------------------------------------

`RC' `MAIN'::D()
{   // approximation estimate on the full grid; computed on first call
    if (rows(D)==0) {
        // adaptive estimator: preliminary estimate must be available first
        if (adapt()) (void) D0()
        D = dapprox()
    }
    return(D)
}

`RC' `MAIN'::D0()
{   // preliminary estimate for the adaptive estimator; dapprox() is applied
    // adapt() times, each run using the result of the previous run
    `Int' stage

    if (rows(D0)) return(D0)
    stage = adapt()
    while (stage) {
        D0 = dapprox()
        stage--
    }
    return(D0)
}

`RC' `MAIN'::AT()
{   // approximation grid of size n(); generated on first call
    if (rows(AT)==0) AT = grid(n())
    return(AT)
}

`RC' `MAIN'::W()
{   // grid counts (linear binning of the data on AT); computed on first call
    if (rows(W)==0) {
        if (sorted()) W = _mm_linbin(X(), w(), AT())
        else          W = mm_fastlinbin(X(), w(), AT())
    }
    return(W)
}

`RC' `MAIN'::L()
{   // local bandwidth factors at the grid level; returns 1 if the estimator
    // is not adaptive
    if (rows(L)==0) {
        if (adapt()==0) return(1)
        L = lbwf(D0(), W())
    }
    return(L)
}

`RC' `MAIN'::dapprox()
{   // binned approximation estimator; selects the algorithm depending on
    // whether the bandwidth is constant and on the boundary correction
    `RC' bw

    // create grid and compute grid counts if necessary
    if (rows(W)==0) (void) W()
    // obtain bandwidth; adaptive estimator if preliminary estimate exists
    bw = _h()
    if (rows(D0)) bw = bw :* lbwf(D0, W)
    // varying bandwidth: standard estimator
    if (rows(bw)!=1)  return(dapprox_std(bw))
    // constant bandwidth: FFT estimation unless linear correction
    if (setup.bc==3)  return(dapprox_std(bw))    // linear correction (no FFT)
    if (setup.bc==2)  return(dapprox_fft_rf(bw)) // reflection
    return(dapprox_fft(bw))                      // no bc or renormalization
}

`RC' `MAIN'::dapprox_fft(`RS' h)
{   // approximation estimator for constant bandwidth using convolve();
    // unbounded support or renormalization boundary correction
    `Int' M, L
    `RS'  a, b, tau
    `RC'  kappa, res

    M = rows(AT)
    a = AT[1]
    b = AT[M]
    L = M - 1
    // reduce number of evaluation points if possible
    if (kernel()!="gaussian") {
        tau = 1
        if      (kernel()=="cosine")       tau = .5
        else if (kernel()=="epanechnikov") tau = sqrt(5)
        L = max( (min( (floor(tau*h*(M-1)/(b-a)), L) ), 1) )
    }
    // evaluate kernel at multiples of the grid step and convolve with counts
    kappa = k( (0::L) * (b-a) / (h*(M-1)) )
    res = convolve((kappa[L+1::1]\kappa[|2 \ L+1|]), W)[|L+1 \ L+M|] / (nobs()*h)
    if (setup.bc==0) return(res)
    // renormalization boundary correction
    if (ub()>=.) return(res :/  kint(1, (AT:-lb()):/h)) // lower bound only
    if (lb()>=.) return(res :/  kint(1, (ub():-AT):/h)) // upper bound only
    return(res :/ (kint(1, (ub():-AT):/h) - kint(1, (lb():-AT):/h)))
}

// approximation estimator for constant bandwidth using convolve(), with
// reflection boundary correction: the grid counts are expanded by their
// mirror image at each finite bound before the convolution is applied
`RC' `MAIN'::dapprox_fft_rf(`RS' h)
{
    `Int' M, L, first, last
    `RS'  a, b, tau
    `RC'  kappa, w
    
    M = rows(AT)
    a = AT[1]; b = AT[M]
    // maximum number of kernel evaluation points (the kernel may extend
    // over the mirrored grid counts)
    L = (M-1) * (1 + (lb()<.) + (ub()<.))
    // reduce number of evaluation points if possible
    if (kernel()!="gaussian") {
        if (kernel()=="cosine")            tau = .5
        else if (kernel()=="epanechnikov") tau = sqrt(5)
        else                               tau = 1
        L = max( (min( (floor(tau*h*(M-1)/(b-a)), L) ), 1) )
    }
    // expand vector of grid counts
    // (the count at the bound itself is doubled; first and last hold the
    // position of the original grid within the expanded vector)
    w = W
    if (lb()<.) {
        first = M
        w = W[M::2] \ w
        w[first] = 2 * w[first]
    }
    else first = 1
    last = first + (M-1)
    if (ub()<.) {
        w = w \ W[M-1::1]
        w[last] = 2 * w[last]
    }
    // compute kappa and obtain FFT
    kappa = k( (0::L) * (b-a) / (h*(M-1)) )
    return(convolve((kappa[L+1::1]\kappa[|2 \ L+1|]), w)[|L+first \ L+last|] / 
        (nobs()*h))
}

// approximation estimator without FFT: for each grid point, the kernel
// contribution of its grid count is added to the surrounding grid points;
// h is either a scalar or contains one bandwidth per grid point
`RC' `MAIN'::dapprox_std(`RC' h)
{
    `Int' n, i, a, b
    `RC'  r, D
    
    // no computational shortcut in case of gaussian kernel
    if (kernel()=="gaussian") return(_dexact(AT, W, h, AT))
    // other kernels: restrict computation to relevant range of evaluation points
    n = n()
    r = h
    if (kernel()=="cosine")            r = r * .5
    else if (kernel()=="epanechnikov") r = r * sqrt(5)
    // r: reach of the kernel in number of grid points
    r = trunc(r * (n-1) / (AT[n]-AT[1])) :+ 1 // add 1 to prevent roundoff error
    D = J(n,1,0)
    if (rows(h)==1) { // constant bandwidth
        for (i=n;i;i--) {
            a = max((1, i-r))
            b = min((n, i+r))
            D[|a \ b|] = D[|a \ b|] + W[i] / h * _K(AT[i], AT[|a \ b|], h)
        }
        return(D :/ nobs())
    }
    // varying bandwidth
    for (i=n;i;i--) {
        a = max((1, i-r[i]))
        b = min((n, i+r[i]))
        D[|a \ b|] = D[|a \ b|] + W[i] / h[i] * _K(AT[i], AT[|a \ b|], h[i])
    }
    return(D :/ nobs())
}

// ---------------------------------------------------------------------------
// helper functions
// ---------------------------------------------------------------------------

// Returns n equally spaced evaluation points (via rangen()).
// from/to: requested limits of the grid. A nonmissing value is used as is.
//   The missing codes select how the limit is determined:
//     .z         use the observed minimum/maximum of X()
//     below .y   use lb()/ub() if that boundary is set, else the padded
//                data range described below
//     .y         use the padded data range (still truncated at lb()/ub())
// h: optional bandwidth used to limit the padding.
// Aborts with error 3300 if the resulting lower limit exceeds the upper limit.
`RC' `MAIN'::grid(`Int' n, | `RS' h, `RS' from, `RS' to)
{
    `RS' tau
    `RR' range, minmax
    
    range = J(1,2,.)
    if (from<.)                 range[1] = from
    else if (lb()<. & from<.y)  range[1] = lb()
    if (to<.)                   range[2] = to
    else if (ub()<. & to<.y)    range[2] = ub()
    if (missing(range)) {
        // at least one limit is still undetermined; look at the data
        minmax = minmax(X())
        if (range[1]>=. & from==.z) range[1] = minmax[1]
        if (range[2]>=. & to==.z)   range[2] = minmax[2]
        if (missing(range)) {
            // if h is not provided:
            // - extend grid below min(x) and above max(x) by pad()% of data range
            // if h is provided:
            // - extend grid by the kernel halfwidth such that density can go to 
            //   zero outside data range (only approximately for gaussian kernel or
            //   in case of adaptive estimator), but limit by pad()% of data range
            tau = (minmax[2]-minmax[1]) * pad()
            if (h<.) tau = min((tau, h * (kernel()=="epanechnikov" ? sqrt(5) : 
                (kernel()=="cosine" ? .5 : (kernel()=="gaussian" ? 3 : 1)))))
            minmax[1] = minmax[1] - tau
            minmax[2] = minmax[2] + tau
            if (range[1]>=.) {
                // never extend the grid beyond a boundary that is set
                if (lb()<.) minmax[1] = max((lb(),minmax[1]))
                range[1] = minmax[1]
            }
            if (range[2]>=.) {
                if (ub()<.) minmax[2] = min((ub(),minmax[2]))
                range[2] = minmax[2]
            }
        }
    }
    if (range[1]>range[2]) _error(3300)
    return(rangen(range[1], range[2], n))
}

// Scale estimate of X (weights w) as used for bandwidth selection.
// type: 0 = min(sd, iqr), 1 = sd, 2 = iqr; the iqr is the interquartile range
// divided by 1.349 and is replaced by the sd if it is not positive.
// sorted: if nonzero, _mm_iqrange() is used instead of mm_iqrange().
`RS' `MAIN'::scale(`Int' type, `RC' X, `RC' w, `Bool' sorted)
{
    `RS' q, s
    
    // rescaled interquartile range (not needed if only the sd is requested)
    if (type!=1) {
        if (sorted) q = _mm_iqrange(X, w) / 1.349
        else        q =  mm_iqrange(X, w) / 1.349
    }
    // standard deviation (needed unless a usable iqr was requested)
    if (type!=2 | q<=0) {
        s = sqrt(variance(X, w))
        // correction factor applied in case of sampling weights
        if (pw()) s = s * sqrt( (nobs()-1) / (nobs() - nobs()/rows(X())) )
    }
    if (type==1) return(s)
    if (q<=0) q = s
    return(type==2 ? q : min((s, q)))
}

// Local bandwidth factors: square root of the ratio between the weighted
// geometric mean of the density values d and each individual value of d.
// Factors that evaluate to missing are set to 1.
`RC' `MAIN'::lbwf(`RC' d, `RC' w)
{
    `RR' gmean
    
    gmean = exp(mean(log(d), w))  // geometric mean of d
    return(editmissing(sqrt(gmean :/ d), 1))
}

// Interpolates the density D, given on grid AT, at the points in at.
// sorted: nonzero if at is already in ascending order; otherwise at is
// ordered for the interpolation and results are returned in input order.
// Points for which the interpolation returns missing get a density of 0.
`RC' `MAIN'::ipolate(`RC' AT, `RC' D, `RC' at, `Bool' sorted)
{
    `RC' dens, ord
    
    if (!sorted) {
        ord = order(at, 1)
        dens = mm_fastipolate(AT, D, at[ord])
        dens[ord] = dens  // back to original order
    }
    else dens = mm_fastipolate(AT, D, at)
    return(editmissing(dens, 0))  // density outside of grid is 0
}

end



*! {smcl}
*! {marker mm_qr}{bf:mm_qr.mata}{asis}
*! version 1.0.5  30mar2021  Ben Jann
version 11
mata:

// Quantile regression class. Data are attached with data(); settings are
// changed through the setter/getter functions (called without argument they
// return the current value); results are computed on demand when a result
// function such as b() is called. Value .z marks elements not yet computed.
class mm_qr
{
    // constructors
    private:
        void           new()
        void           init() // default settings
        void           clear(), clear1(), clear2(), clear2b(), clear3()

    // setup
    public:
        void           data()
        transmorphic   qd()
        transmorphic   demean()
        transmorphic   collin()
        transmorphic   p()
        transmorphic   b_init()
        transmorphic   tol()
        transmorphic   maxiter()
        transmorphic   beta()
        transmorphic   method()
        transmorphic   log()
    
    private:
        pointer scalar y             // pointer to dependent variable
        pointer scalar X             // pointer to predictors
        pointer scalar w             // pointer to weights (or to 1)
        real scalar    n, N          // number of obs, sum of weights
        real scalar    cons          // whether model has a constant
        real scalar    k, kadj       // number of predictors (total, non-omitted)
        real scalar    K, Kadj       // number of coefficients (total, non-omitted)
        real scalar    p             // quantile
        real colvector b_init, b_ls  // starting values, least-squares fit
        real scalar    b_init_user   // whether b_init has been set by the user
        real scalar    maxiter
        real scalar    tol
        real scalar    beta          // step length factor used in minselect()
        string scalar  method
        real scalar    qd            // whether quad precision cross products are used
        real scalar    demean
        real scalar    collin
        real colvector omit          // flags omitted coefficients
        real rowvector xindx         // column indices of non-omitted predictors
        real scalar    log
    
    // results
    public:
        real colvector b()
        real colvector xb()
        real scalar    gap()
        real scalar    sdev()
        real scalar    iter()
        real scalar    converged()
        real scalar    n() // number of observations
        real scalar    N() // sum of weights
        real scalar    cons()
        real scalar    k() // number of predictors
        real scalar    K() // number of coefficients
        real colvector omit()
        real scalar    k_omit()
        real scalar    ymean()
        real rowvector means()
    
    private:
        real colvector b
        real scalar    conv
        real scalar    iter
        real scalar    gap
        real scalar    sdev
        real scalar    ymean
        real rowvector means
        
    // functions
    private:
        void           err_nodata()
        void           printflush()
        void           printlog()
        real colvector _xb()
        real scalar    _sdev()
        void           set_collin()
        transmorphic   lsfit()
        void           _b_init()
        real colvector rmomit(), addomit()
        real colvector meanadj()
        void           fit(), fnb()
        pointer scalar get_X()
        void           gen_z_w()
        real matrix    cross()
        real scalar    minselect()
}

// Constructor: apply the default settings, then mark data and results as unset.
void mm_qr::new()
{
    init()
    clear()
}

// Default settings.
void mm_qr::init()
{
    // model
    p       = .5
    method  = "fnb"
    // data handling
    demean  = 1
    collin  = 1
    qd      = 1
    // algorithm
    beta    = 0.99995
    tol     = 1e-8
    maxiter = st_numscalar("c(maxiter)")
    // display
    log     = 0
}

// Full reset; used by new() and data().
void mm_qr::clear()
{
    // no data attached
    y = NULL
    X = NULL
    w = NULL
    // dimensions and means are unknown
    n     = .z
    N     = .z
    cons  = .z
    k     = .z
    K     = .z
    ymean = .z
    means = .z
    // b_init can only be set after data has been set
    b_init_user = 0
    clear1()
}

// Reset after a change of qd() or demean(): the least-squares fit and, unless
// provided by the user, the starting values are discarded.
void mm_qr::clear1()
{
    if (b_init_user==0) b_init = .z
    b_ls = .z
    clear2()
}

// Reset after a change of collin(): forget which terms are omitted.
void mm_qr::clear2()
{
    kadj  = .z
    Kadj  = .z
    omit  = .z
    xindx = .z
    clear3()
}

// Reset after a change of p(): in a model with constant, starting values
// that were not provided by the user are discarded.
void mm_qr::clear2b()
{
    if (cons) {
        if (!b_init_user) b_init = .z
    }
    clear3()
}

// Reset after a change of b_init(), tol(), maxiter(), or beta(): discard
// the estimation results.
void mm_qr::clear3()
{
    b    = .z
    conv = .z
    iter = .z
    gap  = .z
    sdev = .z
}

// Abort with error 3498 if no data has been attached yet.
void mm_qr::err_nodata()
{
    if (y!=NULL) return
    display("{err}data not set")
    _error(3498)
}

// Print s (passed to printf() as the format string) and flush the display
// so that the output appears right away.
void mm_qr::printflush(string scalar s)
{
    printf(s)
    displayflush()
}

// Print one line of the iteration log: iteration number, the change in the
// coefficients (d), and the current duality gap.
void mm_qr::printlog(real scalar d, real scalar iter)
{
    printf("{txt}")
    printf("{txt}Iteration %g:", iter)
    printf("{col 16}mreldif() in b = {res}%11.0g;" +
        "{txt}  duality gap = {res}%11.0g\n", d, gap)
    displayflush()
}

// Attach data. y0: dependent variable; X0: predictors (may be omitted, `.',
// or 0x0 for a model without predictors); w0: weights (one per observation or
// a single value; omitted = unweighted); cons0: 0 suppresses the constant.
// Pointers to the arguments are stored, not copies. All previous data and
// results are discarded. Error 3200 if the dimensions do not match.
void mm_qr::data(real colvector y0, | real matrix X0, real colvector w0,
    real colvector cons0)
{
    real scalar noX, nw
    
    clear()
    noX = (X0==. | X0==J(0,0,.))
    nw  = rows(w0)
    // validate dimensions before anything is stored
    if (!noX) {
        if (rows(X0)!=rows(y0)) _error(3200)
    }
    if (nw) {
        if (nw!=rows(y0) & nw!=1) _error(3200)
    }
    // dependent variable
    y = &y0
    n = rows(y0)
    // predictors
    if (noX) X = &J(n, 0, 1)
    else     X = &X0
    // weights and sum of weights
    if (!nw) {
        N = n
        w = &1
    }
    else {
        if (nw==1) N = n * w0
        else       N = quadsum(w0)
        w = &w0
    }
    // dimensions of the model
    cons = (cons0!=0)
    k = cols(*X)
    K = k + cons
}

// Get or set the qd flag (nonzero = on). A change discards everything that
// depends on the least-squares fit.
transmorphic mm_qr::qd(| real scalar qd0)
{
    real scalar flag
    
    if (args()==0) return(qd)
    flag = (qd0!=0)
    if (flag!=qd) {
        qd = flag
        clear1()
    }
}

// Get or set the demean flag (nonzero = on). A change discards everything
// that depends on the least-squares fit.
transmorphic mm_qr::demean(| real scalar demean0)
{
    real scalar flag
    
    if (args()==0) return(demean)
    flag = (demean0!=0)
    if (flag!=demean) {
        demean = flag
        clear1()
    }
}

// Get or set the collin flag (nonzero = check for collinear predictors).
// A change discards the information on omitted terms.
transmorphic mm_qr::collin(| real scalar collin0)
{
    real scalar flag
    
    if (args()==0) return(collin)
    flag = (collin0!=0)
    if (flag!=collin) {
        collin = flag
        clear2()
    }
}

// Get or set the quantile; must lie strictly between 0 and 1 (error 3300).
transmorphic mm_qr::p(| real scalar p0)
{
    if (args()==0) return(p)
    if (!(p0>0 & p0<1)) _error(3300)
    if (p!=p0) {
        p = p0
        clear2b()
    }
}

// Get or set the starting values.
// - without argument: returns the starting values; if none are stored they
//   are computed first (requires data)
// - b_init(.z): removes user-provided starting values so that default
//   starting values will be used again
// - otherwise: stores b_init0 as user-provided starting values; requires data
//   and one value per coefficient (error 3200 if the length does not match)
// Setting or clearing starting values discards existing results.
transmorphic mm_qr::b_init(| real colvector b_init0)
{
    if (args()==0) {
        if (b_init==.z) {
            err_nodata()
            _b_init()
        }
        return(b_init)
    }
    if (b_init0==.z) {  // use b_init(.z) to clear starting values
        // the stored vector must be dropped as well; otherwise the old user
        // values would still be returned by b_init() and used by fit()
        b_init = .z
        b_init_user = 0
    }
    else {
        err_nodata()
        if (rows(b_init0)!=K) _error(3200) // wrong number of parameters
        b_init = b_init0
        b_init_user = 1
    }
    clear3()
}

// Get or set the convergence tolerance; values <= 0 are rejected (error 3300).
transmorphic mm_qr::tol(| real scalar tol0)
{
    if (args()==0) return(tol)
    if (!(tol0>0)) _error(3300)
    if (tol!=tol0) {
        tol = tol0
        clear3()
    }
}

// Get or set the maximum number of iterations; negative values are rejected
// (error 3300).
transmorphic mm_qr::maxiter(| real scalar maxiter0)
{
    if (args()==0) return(maxiter)
    if (!(maxiter0>=0)) _error(3300)
    if (maxiter!=maxiter0) {
        maxiter = maxiter0
        clear3()
    }
}

// Get or set the step length factor; must lie within [0, 1] (error 3300).
transmorphic mm_qr::beta(| real scalar beta0)
{
    if (args()==0) return(beta)
    if (!(beta0>=0 & beta0<=1)) _error(3300)
    if (beta!=beta0) {
        beta = beta0
        clear3()
    }
}

// Get or set the estimation method; "fnb" is the only supported value
// (error 3300 otherwise).
transmorphic mm_qr::method(| string scalar method0)
{
    if (args()==0) return(method)
    if (method!=method0) {
        if (!anyof(("fnb"), method0)) _error(3300)
        method = method0
        clear3()
    }
}

// Get or set the log level; allowed values are 0, 1, 2, and 3 (error 3300
// otherwise). Existing results are kept.
transmorphic mm_qr::log(| real scalar log0)
{
    if (args()==0) return(log)
    if (anyof((0,1,2,3), log0)) log = log0
    else                        _error(3300)
}

// number of observations
real scalar mm_qr::n()
{
    return(n)
}

// sum of weights
real scalar mm_qr::N()
{
    return(N)
}

// whether the model has a constant
real scalar mm_qr::cons()
{
    return(cons)
}

// number of predictors
real scalar mm_qr::k()
{
    return(k)
}

// number of coefficients
real scalar mm_qr::K()
{
    return(K)
}

// Flags of omitted coefficients; determined on first use if data is set.
real colvector mm_qr::omit()
{
    if (omit==.z & y!=NULL) set_collin()
    return(omit)
}

// Number of omitted coefficients; determined on first use if data is set.
real scalar mm_qr::k_omit()
{
    if (kadj==.z & y!=NULL) set_collin()
    return(K-Kadj)
}

// Weighted mean of the dependent variable; computed on first use if data
// is set.
real scalar mm_qr::ymean()
{
    if (ymean==.z & y!=NULL) ymean = quadcross(*w, *y) / N
    return(ymean)
}

// Weighted means of the predictors; computed on first use if data is set.
real rowvector mm_qr::means()
{
    if (means==.z & y!=NULL) means = quadcross(*w, *X) / N
    return(means)
}

// Coefficient vector; the model is fitted on first use (requires data).
real colvector mm_qr::b()
{
    if (b!=.z) return(b)
    err_nodata()
    fit()
    return(b)
}

// Linear predictions, either for the estimation data (no argument) or for
// X0, which must have one column per predictor (error 3200 otherwise).
real colvector mm_qr::xb(| real matrix X0)
{
    if (b==.z) (void) b()  // fit the model if necessary
    if (args()!=1) return(_xb(*X, b))
    if (cols(X0)!=k) _error(3200)
    return(_xb(X0, b))
}

// Linear prediction X*b; if the model has a constant, the last element of b
// is the constant and the remaining elements are the slopes.
// X: predictors (one column per slope in b); b: coefficients.
// Returns one prediction per row of X (missing if b is empty).
real colvector mm_qr::_xb(real matrix X, real colvector b)
{
    real scalar k
    
    k = rows(b) - cons  // number of slope coefficients in b
    if (k==0) {
        // the result must conform with the X that was passed in, not with the
        // estimation sample, so that predictions for other data have the
        // right length
        if (cons) return(J(rows(X), 1, b))
        return(J(rows(X), 1, .)) // model has no parameters
    }
    if (cons) return(X * b[|1\k|] :+ b[k+1])
    return(X * b)
}

// Duality gap of the fit; the model is fitted first if necessary.
real scalar mm_qr::gap()
{
    if (b==.z) {
        (void) b()
    }
    return(gap)
}

// Weighted sum of check-function deviations at the solution; the model is
// fitted first if necessary and the value is computed on first use.
real scalar mm_qr::sdev()
{
    if (b==.z) {
        (void) b()
    }
    if (sdev!=.z) return(sdev)
    sdev = _sdev(b)
    return(sdev)
}

// Weighted sum of (p - (e<0))*e, where e are the residuals implied by b.
real scalar mm_qr::_sdev(real colvector b)
{
    real colvector res, dev
    
    res = *y - _xb(*X, b)
    dev = (p :- (res:<0)) :* res
    return(cross(*w, dev))
}

// Number of iterations of the fit; the model is fitted first if necessary.
real scalar mm_qr::iter()
{
    if (b==.z) {
        (void) b()
    }
    return(iter)
}

// Convergence flag of the fit; the model is fitted first if necessary.
real scalar mm_qr::converged()
{
    if (b==.z) {
        (void) b()
    }
    return(conv)
}

// Fit the model and store the coefficients. The coefficients are all
// missing if there is nothing to estimate (no non-omitted parameters or no
// observations).
void mm_qr::fit()
{
    real colvector bfit
    
    // identify collinear variables
    set_collin()
    iter = conv = 0
    // nothing to do without (non-omitted) parameters or without observations
    if (Kadj==0 | n==0) {
        b = J(K,1,.)
        return
    }
    // starting values: adjust for demeaning, drop omitted terms, flip sign
    if (b_init==.z) _b_init()
    bfit = -rmomit(meanadj(b_init, -1))
    
    // interior point algorithm (bfit will be replaced by solution)
    if (method!="fnb") _error(3300)
    fnb(bfit)
    
    // flip sign back, re-insert omitted terms, undo demeaning
    b = meanadj(addomit(-bfit), 1)
}

// Determine which coefficients are omitted and set kadj, Kadj, omit, and
// xindx accordingly; does nothing if this has been done already.
void mm_qr::set_collin()
{
    real scalar nomit
    
    if (kadj!=.z) return    // already applied
    if (collin==0 | k==0) {
        // no check requested or no predictors: nothing is omitted
        omit  = J(K, 1, 0)
        nomit = 0
    }
    else {
        omit = mm_ls_omit(lsfit())
        if (cons) omit[K] = 0 // not necessary, I believe
        nomit = sum(omit)
    }
    kadj = k - nomit
    Kadj = K - nomit
    // xindx stays empty if all predictors are kept
    if (nomit) xindx = select(1..k, !omit[|1\k|]')
    else       xindx = J(1, 0, .)
}

// Run the least-squares fit and return the mm_ls() object. The coefficients
// are kept as candidates for the starting values; in a demeaned model with
// constant the means are kept as well.
transmorphic mm_qr::lsfit()
{
    transmorphic ls
    
    ls = mm_ls(*y, *X, *w, cons, qd, demean)
    b_ls = mm_ls_b(ls)
    if (cons) {
        if (demean) {
            ymean = mm_ls_ymean(ls)
            means = mm_ls_means(ls)
        }
    }
    return(ls)
}

// Default starting values: the least-squares coefficients, or a vector of
// missings if the model has no parameters or there are no observations.
void mm_qr::_b_init()
{
    if (K==0 | n==0) {
        b_init = J(K,1,.)
        return
    }
    if (b_ls==.z) (void) lsfit()
    b_init = b_ls
}

// Reduce b to the elements selected by xindx (plus the constant, if any);
// b is returned unchanged if xindx is empty.
real colvector mm_qr::rmomit(real colvector b)
{
    if (!length(xindx)) return(b)
    if (cons) return(b[xindx] \ b[K])
    return(b[xindx] \ J(0,1,.))
}

// Inverse of rmomit(): expand b0 to full length, with zeros at the positions
// not listed in xindx; b0 is returned unchanged if xindx is empty.
real colvector mm_qr::addomit(real colvector b0)
{
    real colvector full
    
    if (!length(xindx)) return(b0)
    full = J(K,1,0)
    if (cons) full[K] = b0[Kadj]
    full[xindx] = b0[|1\kadj|]
    return(full)
}

// Adjust the constant for the demeaning of the data: with sign=1 the mean of
// y is added and the means of X times the slopes are subtracted; sign=-1
// applies the reverse. b0 is returned unchanged unless the model is demeaned
// and has a constant.
real colvector mm_qr::meanadj(real colvector b0, real scalar sign)
{
    real colvector badj
    
    if (!demean | !cons) return(b0)
    badj = b0
    badj[K] = badj[K] + ymean() * sign
    if (k) {
        badj[K] = badj[K] - (means()*badj[|1\k|]) * sign
    }
    return(badj)
}

// translation of rqfnb.f from "quantreg" package version 5.85 for R
// Interior point algorithm. On input, y holds the starting values for the
// coefficients (on the internal, sign-flipped scale set up by fit()); on
// output it holds the solution. Sets iter, gap, and conv as a side effect
// (iter and conv are expected to have been initialized by the caller).
void mm_qr::fnb(real colvector y)
{
    real scalar    fp, fd, mu, g
    real colvector b, dy, y0
    real colvector W, x, s, z, w, dx, ds, dz, dw, dxdz, dsdw, rhs, d, r
    real matrix    ada
    pointer scalar a
    
    // data
    s = -(*this.y)                           // negated dependent variable
    if (cons & demean)    s = s :+ ymean()   // ... centered if applicable
    a = get_X()                              // (centered/reduced) predictors
    if (rows(*this.w)!=1) W = (*this.w) * (n/N) // weights rescaled to sum to n
    else                  W = 1
    
    // algorithm
    // initialize z and w from the residuals at the starting values
    gen_z_w(z=., w=., s - _xb(*a, y), tol)
    x = J(n, 1, 1-p)
    s = J(n, 1, p)
    if (maxiter<=0) {
        // make sure gap is filled in even if maxiter=0
        gap = cross(z, W, x) + cross(w, W, s)
    }
    b = cross(*a,cons, W, x,0)
    while (iter < maxiter) {
        iter = iter + 1
        // first direction
        d    = 1 :/ (z :/ x + w :/ s)
        ada  = cross(*a,cons, d:*W, *a,cons)
        r    = z - w
        rhs  = b - cross(*a,cons, W, x :- d:*r,0)
        dy   = cholsolve(ada, rhs)
        if (hasmissing(dy)) dy = invsym(ada) * rhs // singularity encountered
        dx   = d :* (_xb(*a, dy) - r)
        ds   = -dx
        dz   = -z :* (1 :+ dx :/ x)
        dw   = -w :* (1 :+ ds :/ s)
        // step lengths (at most 1) that keep x, s and w, z positive
        fp   = minselect(x, dx, s, ds)
        fd   = minselect(w, dw, z, dz)
        if (min((fp, fd)) < 1) {
            // full step not possible: compute modified direction
            mu   = cross(z, W, x) + cross(w, W, s)
            g    = cross(z + fd * dz, W, x + fp * dx) + 
                   cross(w + fd * dw, W, s + fp * ds)
            mu   = mu * (g / mu)^3 / (2 * n)
            dxdz = dx :* dz
            dsdw = ds :* dw
            r    = mu:/s - mu:/x + dxdz:/x - dsdw:/s
            rhs  = rhs + cross(*a,cons, d:*W, r,0)
            dy = cholsolve(ada, rhs)
            if (hasmissing(dy)) dy = invsym(ada) * rhs
            dx   = d :* (_xb(*a, dy) - z + w - r) 
            ds   = -dx
            dz   = (mu :- z :* dx :- dxdz) :/ x - z
            dw   = (mu :- w :* ds :- dsdw) :/ s - w
            fp   = minselect(x, dx, s, ds)
            fd   = minselect(w, dw, z, dz)
        }
        // take the step
        x    = x + fp * dx
        s    = s + fp * ds
        w    = w + fd * dw
        z    = z + fd * dz
        y0   = y
        y    = y + fd * dy
        gap = cross(z, W, x) + cross(w, W, s)
        if (log) {
            if (log>=2) printflush(".")
            else        printlog(mreldif(y, y0), iter)
        }
        // converged once the duality gap falls below the tolerance
        if (gap<tol) {
            conv = 1
            break
        }
    }
    if (log==2 & iter) printflush("\n")
}

// Pointer to the predictor matrix used by the algorithm: the original X if
// no transformation is needed, else a copy that is centered (demeaned model
// with constant) and/or reduced to the columns listed in xindx.
pointer scalar mm_qr::get_X()
{
    real scalar center, subset
    real matrix Z
    
    center = (cons & demean)
    subset = (length(xindx)!=0)
    if (!center & !subset) return(X)
    Z = *X
    if (center) Z = Z :- means()
    if (subset) Z = Z[,xindx]
    return(&Z)
}

// Fill z and w (replaced in place) from the residuals r: z receives the
// positive parts and w the negative parts (as positive numbers); eps is added
// to both wherever abs(r) is smaller than eps.
void mm_qr::gen_z_w(real colvector z, real colvector w, real colvector r,
    real scalar eps)
{
    real colvector shift
    
    shift = eps :* (abs(r):<eps)
    z = r :* (r :> 0)
    w = z - r
    z = z + shift
    w = w + shift
}

// Cross product with two to five arguments; forwards to Mata's quadcross()
// if qd is on and to Mata's cross() otherwise.
real matrix mm_qr::cross(real matrix a, real matrix b, | real matrix c,
    real matrix d, real matrix e)
{
    if (qd) {
        if (args()==2) return(quadcross(a, b))
        if (args()==3) return(quadcross(a, b, c))
        if (args()==4) return(quadcross(a, b, c, d))
        return(quadcross(a, b, c, d, e))
    }
    if (args()==2) return(::cross(a, b))
    if (args()==3) return(::cross(a, b, c))
    if (args()==4) return(::cross(a, b, c, d))
    return(::cross(a, b, c, d, e))
}

// Step length: the smallest of 1, beta times the smallest -x/dx among
// elements with dx<0, and beta times the smallest -s/ds among elements with
// ds<0.
real scalar mm_qr::minselect(real colvector x, real colvector dx, 
    real colvector s, real colvector ds)
{
    real scalar    step
    real colvector idx
    
    step = 1
    idx = select(1::n, dx:<0)
    if (length(idx)) step = min((step, min(-x[idx] :/ dx[idx]) * beta))
    idx = select(1::n, ds:<0)
    if (length(idx)) step = min((step, min(-s[idx] :/ ds[idx]) * beta))
    return(step)
}

// - wrapper for quick QR fit
// Convenience wrapper: fit a quantile regression of y on X and return the
// coefficients.
// w: weights; p: quantile (class default if `.'); cons: 0 suppresses the
// constant; b_init: starting values (ignored if `.'); relax: if nonzero,
// results are returned even if the algorithm did not converge (error 3360
// otherwise).
real colvector mm_qrfit(
    real colvector      y,
    | real matrix       X, 
      real colvector    w, 
      real scalar       p,
      real scalar       cons,
      real colvector    b_init,
      real scalar       relax)
{
    class mm_qr scalar Q
    
    if (args()<7) relax = 0
    Q.data(y, X, w, cons)
    if (p!=.) Q.p(p)
    if (args()>=6 & b_init!=.) Q.b_init(b_init)
    if (!Q.converged()) {
        if (!relax) _error(3360)
    }
    return(Q.b())
}

end


*! {smcl}
*! {marker mm_mloc}{bf:mm_mloc.mata}{asis}
*! version 1.0.3  12mar2021  Ben Jann
version 11
mata:

// M-estimate of location
// Settings and results shared by mm_mloc() and mm_mscale()
struct _mm_mls_struct {
    real scalar eff   // not used by mscale
    real scalar bp    // only used by mscale
    real scalar delta // only used by mscale
    real scalar k     // tuning constant
    real scalar b     // estimate
    real scalar b0    // starting value of estimate
    real scalar s     // not used by mscale
    real scalar l     // only used by mscale
    real scalar d     // convergence criterion at last iteration
    real scalar tol   // convergence tolerance
    real scalar iter  // iteration counter
    real scalar maxiter // maximum number of iterations
    real scalar trace // whether to display iteration log
    real scalar conv  // whether the algorithm converged
}

// M-estimate of location by iteratively reweighted means.
// x: data; w: weights (default 1); eff: efficiency in percent (default 95);
// obj: "huber" (default) or "biweight"; b: starting value (default: median);
// s: scale (default: median absolute deviation from the median divided by
// invnormal(0.75)); trace: display iteration log; tol: tolerance (default
// 1e-10); iter: maximum number of iterations (default c(maxiter)).
// Returns a results structure; use mm_mloc_b() etc. to read it. Aborts with
// error 3351 if x or w contain missing values and with error 3498 if obj is
// not supported.
struct _mm_mls_struct scalar mm_mloc(real colvector x,
  | real colvector w, real scalar eff, string scalar obj,
    real scalar b, real scalar s, 
    real scalar trace, real scalar tol, real scalar iter)
{
    real scalar b0, i
    real colvector z
    struct _mm_mls_struct scalar S
    pointer(real scalar function) scalar f
    
    if (args()<2) w = 1
    if (hasmissing(w)) _error(3351)
    if (hasmissing(x)) _error(3351)
    S.eff     = (eff<.   ? eff   : 95)
    S.trace   = (trace<. ? trace : 0)
    S.tol     = (tol<.   ? tol   : 1e-10)
    S.maxiter = (iter<.  ? iter  : st_numscalar("c(maxiter)"))
    S.iter    = 0
    S.conv    = 0
    S.b       = .
    // weight function and tuning constant
    if (obj=="" | obj=="huber") {
        f = &mm_huber_w()
        S.k = mm_huber_k(S.eff)
    }
    else if (obj=="biweight") {
        f = &mm_biweight_w()
        S.k = mm_biweight_k(S.eff)
    }
    else {
        printf("{err}'%s' not supported; invalid objective function\n", obj)
        _error(3498)
    }
    if (rows(x)==0) return(S) // nothing to do (b = ., conv = 0)
    S.b = S.b0 = (b<. ? b : mm_median(x, w))
    // the default scale is always computed around the median, even if a
    // starting value for the location has been provided
    S.s  = (s<. ? s : mm_median(abs(x :- (b<. ? mm_median(x, w) : S.b)), w)
                      / invnormal(0.75))
    if (S.s<=0) return(S) // scale is 0 (b = starting value, conv = 0)
    // iterate
    z = (x :- S.b) / S.s
    for (i=1; i<=S.maxiter; i++) {
        b0 = S.b
        z = (*f)(z, S.k) :* w
        S.b = mean(x, z)
        S.d = abs(S.b - b0) / S.s
        if (S.trace) printf("{txt}{lalign 16:Iteration %g:}" +
            "absolute difference = %9.0g\n", i, S.d)
        if (S.d <= S.tol) break
        z = (x :- S.b) / S.s
    }
    // i is one past the last iteration if the loop ended without a break;
    // report the number of iterations that were actually performed
    S.conv = (i<=S.maxiter)
    S.iter = (S.conv ? i : i - 1)
    return(S)
}

// accessors for the results returned by mm_mloc()
real scalar mm_mloc_b(struct _mm_mls_struct scalar S)
{
    return(S.b)
}
real scalar mm_mloc_b0(struct _mm_mls_struct scalar S)
{
    return(S.b0)
}
real scalar mm_mloc_s(struct _mm_mls_struct scalar S)
{
    return(S.s)
}
real scalar mm_mloc_conv(struct _mm_mls_struct scalar S)
{
    return(S.conv)
}
real scalar mm_mloc_d(struct _mm_mls_struct scalar S)
{
    return(S.d)
}
real scalar mm_mloc_iter(struct _mm_mls_struct scalar S)
{
    return(S.iter)
}
real scalar mm_mloc_k(struct _mm_mls_struct scalar S)
{
    return(S.k)
}
real scalar mm_mloc_eff(struct _mm_mls_struct scalar S)
{
    return(S.eff)
}

// M-estimate of scale
// M-estimate of scale based on the biweight objective function.
// x: data; w: weights (default 1); bp: breakdown point in percent (default
// 50); b: starting value (default: median absolute deviation from the median
// divided by invnormal(0.75)); l: location (default: median); trace: display
// iteration log; tol: tolerance (default 1e-10); iter: maximum number of
// iterations (default c(maxiter)).
// Returns a results structure; use mm_mscale_b() etc. to read it. Aborts
// with error 3351 if x or w contain missing values.
struct _mm_mls_struct scalar mm_mscale(real colvector x,
  | real colvector w, real scalar bp,
    real scalar b, real scalar l, 
    real scalar trace, real scalar tol, real scalar iter)
{
    real scalar b0, i
    real colvector z
    struct _mm_mls_struct scalar S
    
    if (args()<2) w = 1
    if (hasmissing(w)) _error(3351)
    if (hasmissing(x)) _error(3351)
    S.bp      = (bp<.    ? bp    : 50)
    S.trace   = (trace<. ? trace : 0)
    S.tol     = (tol<.   ? tol   : 1e-10)
    S.maxiter = (iter<.  ? iter  : st_numscalar("c(maxiter)"))
    S.iter    = 0
    S.conv    = 0
    S.b       = .
    S.k       = mm_biweight_k_bp(S.bp)
    if (rows(x)==0) return(S) // nothing to do (b = ., conv = 0)
    S.l = (l<. ? l : mm_median(x, w))
    // the default starting value is always computed around the median, even
    // if a location has been provided
    S.b = S.b0 = (b<. ? b : 
        mm_median(abs(x :- (l<. ? mm_median(x, w) : S.l)), w) / invnormal(0.75))
    if (S.b<=0) return(S) // scale is 0 (b = starting value, conv = 0)
    // iterate
    S.delta = S.bp/100 * S.k^2/6
    z = (x :- S.l)
    for (i=1; i<=S.maxiter; i++) {
        b0 = S.b
        S.b = sqrt(mean(mm_biweight_rho(z/b0, S.k), w) / S.delta) * b0
        S.d = abs(S.b/b0 - 1)
        if (S.trace) printf("{txt}{lalign 16:Iteration %g:}" +
            "absolute relative difference = %9.0g\n", i, S.d)
        if (S.d <= S.tol) break
    }
    // i is one past the last iteration if the loop ended without a break;
    // report the number of iterations that were actually performed
    S.conv = (i<=S.maxiter)
    S.iter = (S.conv ? i : i - 1)
    return(S)
}

// accessors for the results returned by mm_mscale()
real scalar mm_mscale_b(struct _mm_mls_struct scalar S)
{
    return(S.b)
}
real scalar mm_mscale_b0(struct _mm_mls_struct scalar S)
{
    return(S.b0)
}
real scalar mm_mscale_l(struct _mm_mls_struct scalar S)
{
    return(S.l)
}
real scalar mm_mscale_conv(struct _mm_mls_struct scalar S)
{
    return(S.conv)
}
real scalar mm_mscale_d(struct _mm_mls_struct scalar S)
{
    return(S.d)
}
real scalar mm_mscale_iter(struct _mm_mls_struct scalar S)
{
    return(S.iter)
}
real scalar mm_mscale_k(struct _mm_mls_struct scalar S)
{
    return(S.k)
}
real scalar mm_mscale_bp(struct _mm_mls_struct scalar S)
{
    return(S.bp)
}
real scalar mm_mscale_delta(struct _mm_mls_struct scalar S)
{
    return(S.delta)
}

// Huber tuning constant for given efficiency
// Huber tuning constant for efficiency eff (in percent, 63.7 to 99.9).
// Stored constants are returned for 80, 85, 90, and 95; other values are
// obtained by numerically inverting mm_huber_eff().
real scalar mm_huber_k(real scalar eff)
{
    // range check (the stored efficiencies all lie within the range)
    if (eff<63.7)  _error(3498, "efficiency may not be smaller than 63.7")
    if (eff>99.9) _error(3498, "efficiency may not be larger than 99.9")
    // stored constants
    if (eff==80) return( .52942958)
    if (eff==85) return( .73173882)
    if (eff==90) return( .98180232)
    if (eff==95) return(1.34499751)
    // numerical solution
    return(round(mm_finvert(eff/100, &mm_huber_eff(), 0.001, 3), 1e-8))
}

// Huber efficiency for given tuning constant
// Huber efficiency (as a proportion) for tuning constant k. Returns 2/pi
// for k<=0; for 0<k<1e-4 the value is extrapolated linearly from the values
// at 1e-4 and 2e-4, with 2/pi as lower limit.
real scalar mm_huber_eff(real scalar k)
{
    real scalar e1, e2, F
    
    if (k<=0) return(2/pi())
    if (k<1e-4) {
        // use linear interpolation at bottom
        e1 = mm_huber_eff(1e-4)
        e2 = mm_huber_eff(2e-4)
        return(max((2/pi(), e1 - (1e-4 - k)/1e-4 * (e2-e1))))
    }
    F = normal(k)
    return((F-normal(-k))^2 / 
        (2 * (k^2 * (1 - F) + F - 0.5 - k * normalden(k))))
}

// Huber objective functions and weights
// Huber objective function: x^2/2 where abs(x)<=k, else k*abs(x) - k^2/2
real colvector mm_huber_rho(real colvector x, real scalar k)
{
    real colvector ax, inside, quad, lin
    
    ax     = abs(x)
    inside = ax:<=k
    quad   = 0.5 * x:^2
    lin    = k*ax :- 0.5*k^2
    return(quad:*inside :+ lin:*(1 :- inside))
}

// Huber psi function: x where abs(x)<=k, else sign(x)*k
real colvector mm_huber_psi(real colvector x, real scalar k)
{
    real colvector inside
    
    inside = abs(x):<=k
    return(x:*inside :+ (sign(x)*k):*(1:-inside))
}

// Huber phi function: 1 where abs(x)<=k, else 0
real colvector mm_huber_phi(real colvector x, real scalar k)
{
    real colvector ax
    
    ax = abs(x)
    return(ax:<=k)
}

// Huber weights: 1 where abs(x)<=k, else k/abs(x)
real colvector mm_huber_w(real colvector x, real scalar k)
{
    real colvector ax, ratio
    
    ax    = abs(x)
    ratio = editmissing(k :/ ax, 0)
    // exponent 0 turns the ratio into 1 wherever abs(x) does not exceed k
    return(ratio:^(ax:>k))
}

// biweight tuning constant for given efficiency
// Biweight tuning constant for efficiency eff0 (in percent, 0.1 to 99.9).
// Stored constants are returned for 80, 85, 90, and 95; other values are
// obtained by numerically inverting mm_biweight_eff() around an approximation.
real scalar mm_biweight_k(real scalar eff0)
{
    real scalar guess, e
    
    // range check (the stored efficiencies all lie within the range)
    if (eff0<0.1)  _error(3498, "efficiency may not be smaller than 0.1")
    if (eff0>99.9) _error(3498, "efficiency may not be larger than 99.9")
    // stored constants
    if (eff0==80) return(3.1369087)
    if (eff0==85) return(3.4436898)
    if (eff0==90) return(3.8826616)
    if (eff0==95) return(4.6850649)
    // approximation formula used to set the search interval
    e = eff0/100
    guess = 0.8376 + 1.499*e + 0.7509*sin(1.301*e^6) + 
        0.04945*e/sin(3.136*e) + 0.9212*e/cos(1.301*e^6)
    return(round(mm_finvert(e, &mm_biweight_eff(), guess/5, guess*1.1), 1e-7))
}

// biweight efficiency for given tuning constant
// Biweight efficiency (as a proportion) for tuning constant k; returns 0 for
// k<=0 and 1 for k>100. The integrals over [0, k] are approximated by
// Simpson's rule with 1000 intervals.
real scalar mm_biweight_eff(real scalar k)
{
    real scalar    lo, hi, m, step, num, den
    real colvector pts, sw
    
    if (k<=0)  return(0)
    if (k>100) return(1)
    lo = 0; hi = k; m = 1000; step = (hi-lo)/m 
    pts = rangen(lo, hi, m+1)
    // Simpson weights 1, 4, 2, 4, ..., 2, 4, 1
    sw = 1 \ colshape((J(m/2,1,4), J(m/2,1,2)), 1)
    sw[m+1] = 1
    num = 2 * (step / 3) * quadcolsum(_mm_biweight_eff_phi(pts, k):*sw)
    den = 2 * (step / 3) * quadcolsum(_mm_biweight_eff_psi2(pts, k):*sw)
    return(num^2 / den)
}

// Integrand for mm_biweight_eff(): normal density times biweight phi
real matrix _mm_biweight_eff_phi(real matrix x, real scalar k)
{
    real matrix u2, phi

    u2  = (x / k):^2
    phi = ((1 :- u2) :* (1 :- 5*u2)) :* (u2:<=1)
    return(normalden(x) :* phi)
}

// Integrand for mm_biweight_eff(): normal density times squared biweight psi
real matrix _mm_biweight_eff_psi2(real matrix x, real scalar k)
{
    real matrix u2, psi

    u2  = (x / k):^2
    psi = (x :* (1 :- u2):^2) :* (u2:<=1)
    return(normalden(x) :* psi:^2)
}

// biweight tuning constant for given breakdown point
// Biweight tuning constant for breakdown point bp (in percent, 1 to 50).
// A stored constant is returned for 50; other values are obtained by
// numerically inverting mm_biweight_bp().
real scalar mm_biweight_k_bp(real scalar bp)
{
    // range check (50 lies within the range)
    if (bp<1)  _error(3498, "bp may not be smaller than 1")
    if (bp>50) _error(3498, "bp may not be larger than 50")
    if (bp==50) return(1.547645)
    return(round(mm_finvert(bp/100, &mm_biweight_bp(), 1.5, 18), 1e-7))
}

// biweight breakdown point for given tuning constant
// Biweight breakdown point (as a proportion) for tuning constant k; returns
// 1 for k<=0. The integral over [0, k] is approximated by Simpson's rule
// with 1000 intervals.
real scalar mm_biweight_bp(real scalar k)
{
    real scalar    lo, hi, m, step
    real colvector pts, sw

    if (k<=0) return(1)
    lo = 0; hi = k; m = 1000; step = (hi-lo)/m
    pts = rangen(lo, hi, m+1)
    // Simpson weights 1, 4, 2, 4, ..., 2, 4, 1
    sw = 1 \ colshape((J(m/2,1,4), J(m/2,1,2)), 1)
    sw[m+1] = 1
    return(2 * (normal(-k) +  step/3 *
        quadcolsum(normalden(pts):*mm_biweight_rho(pts, k):*sw) / (k^2/6)))
}

// biweight objective functions and weights
// Biweight objective function: k^2/6 * (1 - (1 - (x/k)^2)^3) where
// abs(x)<=k, else k^2/6
real colvector mm_biweight_rho(real colvector x, real scalar k)
{
    real colvector u2, core

    u2   = (x / k):^2
    core = 1 :- (1 :- u2):^3
    // exponent 0 turns the term into 1 wherever abs(x) exceeds k
    return(k^2/6 * core:^(u2:<=1))
}

// Biweight psi function: x * (1 - (x/k)^2)^2 where abs(x)<=k, else 0
real colvector mm_biweight_psi(real colvector x, real scalar k)
{
    real colvector u2, inside

    u2     = (x / k):^2
    inside = u2:<=1
    return((x :* (1 :- u2):^2) :* inside)
}

// Biweight phi function: (1 - (x/k)^2) * (1 - 5*(x/k)^2) where abs(x)<=k,
// else 0
real colvector mm_biweight_phi(real colvector x, real scalar k)
{
    real colvector u2, inside

    u2     = (x / k):^2
    inside = u2:<=1
    return(((1 :- u2) :* (1 :- 5*u2)) :* inside)
}
// Biweight weights: (1 - (x/k)^2)^2 where abs(x)<=k, else 0
real colvector mm_biweight_w(real colvector x, real scalar k)
{
    real colvector u2, inside

    u2     = (x / k):^2
    inside = u2:<=1
    return(((1 :- u2):^2) :* inside)
}

end


*! {smcl}
*! {marker mm_hl}{bf:mm_hl.mata}{asis}
*! version 1.0.0  21oct2020  Ben Jann
version 11
mata:

// robust statistics based on pairwise comparisons (code adapted from
// robstat.ado, version 1.0.3)
// - HL: location
// - Qn: scale
// - mc (medcouple): skewness

// HL estimator

// entry point: checks the input and dispatches to the naive or the fast
// implementation (unweighted, weighted, or frequency-weighted); the fast
// implementations receive the data in ascending order
real scalar mm_hl(real colvector X, | real colvector w, real scalar fw, 
    real scalar naive)
{
    real scalar    hasw
    real colvector p

    // defaults for optional arguments
    if (args()<2) w = 1
    if (args()<3) fw = 0
    if (args()<4) naive = 0
    // input checks
    if (hasmissing(w)) _error(3351)
    if (any(w:<0)) _error(3498, "negative weights not allowed")
    if (rows(X)==0) return(.)
    if (hasmissing(X)) _error(3351)
    hasw = (rows(w)!=1)   // a single-row w is treated as "no weights"
    // naive algorithm
    if (naive) {
        if (!hasw) return(_mm_hl_naive(X))
        if (fw) return(_mm_hl_naive_fw(X, w))
        return(_mm_hl_naive_w(X, w))
    }
    // fast algorithm, no weights
    if (!hasw) {
        if (mm_issorted(X)) return(_mm_hl(X))
        return(_mm_hl(sort(X,1)))
    }
    // fast algorithm, weights; data already sorted
    if (mm_issorted(X)) {
        if (fw) return(_mm_hl_fw(X, w))
        return(_mm_hl_w(X, w))
    }
    // fast algorithm, weights; sort data first
    p = order((X,w), (1,2))
    if (fw) return(_mm_hl_fw(X[p], w[p]))
    return(_mm_hl_w(X[p], w[p]))
}

// HL estimator without weights; fast algorithm
// x: data values; returns x itself if there is only one observation
// NOTE(review): the border-moving loops appear to rely on x being sorted in
// ascending order; mm_hl() passes sorted data -- confirm for other callers
// The search narrows a band of candidate cells (row-wise column range
// l[i]..r[i]) in the matrix of pairwise means until at most n candidates
// remain, or until the trial value is found to be the target quantile.
real scalar _mm_hl(real colvector x) // no weights
{
    // the trick of this algorithm is to consider only elements that are on the 
    // right of the main diagonal
    real scalar     i, j, m, k, n, nl, nr, nL, nR, trial
    real colvector  xx, /*ww,*/ l, r, L, R
    
    n = rows(x)                    // dimension of search matrix
    if (n==1) return(x)            // returning observed value if n=1
    xx      = /*ww =*/ J(n, 1, .)  // temp vector for matrix elements
    l = L   = (1::n):+1            // indices of left boundary (old and new)
    r = R   = J(n, 1, n)           // indices of right boundary (old and new)
    nl = nL = n + comb(n, 2)       // number of cells below left boundary
    nr = nR = n * n                // number of cells within right boundary
    k       = nl + comb(n, 2)/2    // target quantile
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=1; i<n; i++) { // last row cannot contain candidates
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = __mm_hl_el(x, i, l[i]+trunc((r[i]-l[i]+1)/2))
                /*m++
                ww[m] = r[i] - l[i] + 1
                xx[m] = __mm_hl_el(x, i, l[i]+trunc(ww[m]/2))*/
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        /*trial = _mm_hl_qhi_w(xx[|1 \ m|], ww[|1 \ m|], .5)*/
        /*the unweighted quantile is faster; results are the same*/
        // move right border
        j = n-1
        for (i=(n-1); i>=1; i--) {
            if (i==j) {
                if (__mm_hl_el(x, i, j)>=trial) {
                    R[i] = j
                    j = i-1
                    continue
                }
            }
            if (j<n) {
                while (__mm_hl_el(x, i, j+1)<trial) {
                    j++
                    if (j==n) break
                }
            }
            R[i] = j
        }
        nR = sum(R)
        if (nR>k) {
            // target lies within the new right border; accept it and iterate
            swap(r, R)
            nr = nR
            continue
        }
        // move left border
        j = n + 1
        for (i=1; i<=n; i++) {
            if (j>(i+1)) {
                while (__mm_hl_el(x, i, j-1)>trial) {
                    j--
                    if (j==(i+1)) break
                }
            }
            if (j<(i+1)) j = (i+1)
            L[i] = j
        }
        nL = sum(L) - n
        if (nL<k) {
            // target lies beyond the new left border; accept it and iterate
            swap(l, L)
            nl = nL
            continue
        }
        // trial = low quantile = high quantile
        if (ceil(k)!=k | (nR<k & nL>k)) return(trial)
        // trial = low quantile
        if (nL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if (L[i]>n) continue
                xx[++m] = __mm_hl_el(x, i, L[i])
            }
            return((trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        m = 0
        for (i=1; i<=n; i++) {
            if (R[i]<=i) continue
            xx[++m] = __mm_hl_el(x, i, R[i])
        }
        return((trial+max(xx[|1 \ m|]))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=1; i<n; i++) { // last row cannot contain candidates
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = __mm_hl_el(x, i, j)
            }
        }
    }
    return(_mm_hl_q(xx[|1 \ m|], k, nl))
}

// element (i,j) of the HL search matrix: mean of y[i] and y[j]
real scalar __mm_hl_el(real colvector y, real scalar i, real scalar j)
{
    real scalar pairsum

    pairsum = y[i] + y[j]
    return(pairsum/2)
}

// HL estimator with (non-frequency) weights; fast algorithm
// x: data values; w: weight per observation (indexed like x)
// returns x itself if there is only one observation
// NOTE(review): the border-moving loops appear to rely on x being sorted in
// ascending order; mm_hl() passes sorted data -- confirm for other callers
// Each pair (i,j) carries weight w[i]*w[j]; the search narrows a band of
// candidate cells (row-wise column range l[i]..r[i]) until at most n
// candidates remain, then the target is taken from these candidates.
real scalar _mm_hl_w(real colvector x, real colvector w)
{
    real scalar     i, j, m, k, n, nl, nr, trial, Wl, WR, WL, W0, W1
    real colvector  xx, ww, l, r, L, R, ccw
    
    n       = rows(x)              // dimension of search matrix
    if (n==1) return(x)            // returning observed value if n=1
    xx = ww = J(n, 1, .)           // temp vector for matrix elements
    l = L   = (1::n):+1            // indices of left boundary (old and new)
    r = R   = J(n, 1, n)           // indices of right boundary (old and new)
    nl      = comb(n, 2) + n       // number of cells below left boundary
    nr      = n * n                // number of cells within right boundary
    ccw     = quadrunningsum(w)    // cumulative column weights
    W1      = quadsum(w[|2 \ .|] :* ccw[|1 \ n-1|]) // sum of weights in target triangle
    W0      = W1 + quadsum(w:*w)   // sum of weights in rest of search matrix
    Wl = WL = W0                   // sum of weights below left boundary
    WR      = W0 + W1              // sum of weights within right boundary
    k       = W0 + W1/2            // target quantile
    // iterate while more than n candidate cells remain between the borders
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=1; i<n; i++) { // last row cannot contain candidates
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = __mm_hl_el(x, i, l[i]+trunc((r[i]-l[i]+1)/2))
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        // move right border
        j = n-1
        for (i=(n-1); i>=1; i--) {
            if (i==j) {
                if (__mm_hl_el(x, i, j)>=trial) {
                    R[i] = j
                    j = i-1
                    continue
                }
            }
            if (j<n) {
                while (__mm_hl_el(x, i, j+1)<trial) {
                    j++
                    if (j==n) break
                }
            }
            R[i] = j
        }
        // weight within the new right border (row weight times cumulative
        // column weight up to the border)
        WR = quadsum(w:*ccw[R])
        if (WR>k) {
            // target lies within the new right border; accept it and iterate
            swap(r, R)
            nr = sum(R)
            continue
        }
        // move left border
        j = n + 1
        for (i=1; i<=n; i++) {
            if (j>(i+1)) {
                while (__mm_hl_el(x, i, j-1)>trial) {
                    j--
                    if (j==(i+1)) break
                }
            }
            if (j<(i+1)) j = (i+1)
            L[i] = j
        }
        // weight below the new left border
        WL = quadsum(w:*ccw[L:-1])
        if (WL<k) {
            // target lies beyond the new left border; accept it and iterate
            swap(l, L)
            Wl = WL
            nl = sum(L) - n
            continue
        }
        // trial = low quantile = high quantile
        if (WR==WL | (WR<k & WL>k)) return(trial)
        // trial = low quantile
        if (WL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if (L[i]>n) continue
                xx[++m] = __mm_hl_el(x, i, L[i])
            }
            return((trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        m = 0
        for (i=1; i<=n; i++) {
            if (R[i]<=i) continue
            xx[++m] = __mm_hl_el(x, i, R[i])
        }
        return((trial+max(xx[|1 \ m|]))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=1; i<n; i++) { // last row cannot contain candidates
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = __mm_hl_el(x, i, j)
                ww[m] = w[i] * w[j]   // weight of pair (i,j)
            }
        }
    }
    // target k is passed together with Wl, the weight below the left border
    return(_mm_hl_q_w(xx[|1 \ m|], ww[|1 \ m|], k, Wl))
}

// HL estimator with frequency weights; fast algorithm
// x: data values; w: integer frequency per observation (indexed like x);
// aborts with error 3498 if w contains non-integer values
// returns x itself if there is only one observation
// NOTE(review): the border-moving loops appear to rely on x being sorted in
// ascending order; mm_hl() passes sorted data -- confirm for other callers
// Column indices in l, r, L, R run from 1 to n+1 because of the duplicated
// diagonal; a column index c > i in row i refers to observation c-1.
real scalar _mm_hl_fw(real colvector x, real colvector w)
{   // the algorithm "duplicates" the diagonal so that the relevant pairs in 
    // case of w>1 can be taken into account
    real scalar     i, j, m, k, n, nl, nr, trial, Wl, WR, WL, W0, W1
    real colvector  xx, ww, l, r, L, R, ccw, wcorr, idx
    
    if (any(trunc(w):!=w)) _error(3498, "non-integer frequency not allowed")
    n       = rows(x)          // dimension of search matrix
    if (n==1) return(x)        // returning observed value if n=1
    xx = ww = J(n, 1, .)       // temp vector for matrix elements
    ccw     = runningsum(w)    // cumulative column weights
    // number of within-observation pairs: comb(w,2) if w>1, else 0
    wcorr   = mm_cond(w:<=1, 0, comb(w, 2)) // correction of weights
    idx     = 1::n             // diagonal indices
    l = L   = idx :+ 1 :+ (wcorr:==0) // indices of left boundary (old and new)
    r = R   = J(n, 1, n+1)     // indices of right boundary (old and new)
    nl      = comb(n, 2) + n + sum(wcorr:==0) // n. of cells below left boundary
    nr      = n * (n+1)        // number of cells within right boundary
    W0 = W1 = sum(w[|2 \ .|] :* ccw[|1 \ n-1|])
    W1      = W1 + sum(wcorr)  // sum of weights in target triangle
    W0      = W0 + sum(w) + sum(wcorr) // sum of weights in rest of search matrix
    Wl = WL = W0               // sum of weights below left boundary
    WR      = W0 + W1          // sum of weights within right boundary
    k       = W0 + W1/2        // target quantile
    // iterate while more than n candidate cells remain between the borders
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=1; i<=n; i++) {
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = __mm_hl_el(x, i, 
                    (l[i]-1) + trunc(((r[i]-1)-(l[i]-1)+1)/2))
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        // move right border
        j = n
        for (i=n; i>=1; i--) {
            if (j==i) {
                if (__mm_hl_el(x, i, j)>=trial) {
                    R[i] = j + (wcorr[i]==0)
                    j = i-1
                    continue
                }
            }
            if (j<=n) {
                // (j+1)-((j+1)>i) maps the column index to an observation
                while (__mm_hl_el(x, i, (j+1)-((j+1)>i))<trial) {
                    j++
                    if (j>n) break
                }
            }
            R[i] = j
        }
        // weight within the new right border
        WR = sum(w:*ccw[R:-(R:>idx)]) - sum(wcorr:*(R:==idx))
        if (WR>k) {
            // target lies within the new right border; accept it and iterate
            swap(r, R)
            nr = sum(R)
            continue
        }
        // move left border
        j = n + 2
        for (i=1; i<=n; i++) {
            if (j>(i+1)) {
                while (__mm_hl_el(x, i, (j-1)-((j-1)>i))>trial) {
                    j--
                    if (j==(i+1)) break
                }
            }
            if (j<(i+1)) j = (i+1)
            if (j==(i+1)) {
                // skip the duplicated diagonal cell if it carries no pairs
                if (wcorr[i]==0) j++
            }
            L[i] = j
        }
        // weight below the new left border
        WL = sum(w:*ccw[L:-1:-((L:-1):>idx)]) - sum(wcorr:*((L:-1):==idx))
        if (WL<k) {
            // target lies beyond the new left border; accept it and iterate
            swap(l, L)
            Wl = WL
            nl = sum(L) - n
            continue
        }
        // trial = low quantile = high quantile
        if (ceil(k)!=k | (WR<k & WL>k)) return(trial)
        // trial = low quantile
        if (WL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if ((L[i]-(L[i]>i))>n) continue
                xx[++m] = __mm_hl_el(x, i, L[i]-(L[i]>i))
            }
            return((trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        m = 0
        for (i=1; i<=n; i++) {
            if (R[i]<=i) continue
            if (R[i]==i+1) {
                if (wcorr[i]==0) continue
            }
            xx[++m] = __mm_hl_el(x, i, R[i]-1)
        }
        return((trial+max(xx[|1 \ m|]))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=1; i<=n; i++) {
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = __mm_hl_el(x, i, j-1)
                // diagonal cell: pairs within observation i; else pairs (i,j-1)
                ww[m] = (i==(j-1) ? comb(w[i], 2) : w[i]*w[j-1])
            }
        }
    }
    // target k is passed together with Wl, the weight below the left border
    return(_mm_hl_q_w(xx[|1 \ m|], ww[|1 \ m|], k, Wl))
}

// HL estimator without weights; naive algorithm: builds all comb(n,2)
// pairwise means (i<j) and takes their median via _mm_hl_q()
real scalar _mm_hl_naive(real colvector x) // no weights
{
    real scalar    a, b, pos, N
    real colvector avg

    N = rows(x)
    if (N==1) return(x) // HL undefined if n=1; returning observed value
    avg = J(comb(N,2), 1, .)
    pos = 0
    for (a=1; a<N; a++) {
        for (b=a+1; b<=N; b++) {
            pos++
            avg[pos] = __mm_hl_el(x, a, b)
        }
    }
    return(_mm_hl_q(avg, .5))
}

// HL estimator with weights; naive algorithm: builds all comb(n,2) pairwise
// means (i<j) with weights w[i]*w[j] and takes their weighted median via
// _mm_hl_q_w()
real scalar _mm_hl_naive_w(real colvector x, real colvector w)
{
    real scalar    a, b, pos, N
    real colvector avg, pw

    N = rows(x)
    if (N==1) return(x) // HL undefined if n=1; returning observed value
    avg = J(comb(N,2), 1, .)
    pw  = J(rows(avg), 1, .)
    pos = 0
    for (a=1; a<N; a++) {
        for (b=a+1; b<=N; b++) {
            pos++
            avg[pos] = __mm_hl_el(x, a, b)
            pw[pos]  = w[a]*w[b]
        }
    }
    return(_mm_hl_q_w(avg, pw, .5))
}

// HL estimator with frequency weights; naive algorithm
// x: data values; w: integer frequency per observation (indexed like x);
// aborts with error 3498 if w contains non-integer values
// Builds all pairs i<j with weight w[i]*w[j], plus one "diagonal" pair (i,i)
// with weight comb(w[i],2) for each observation with w[i]>1, and takes the
// weighted median via _mm_hl_q_w().
real scalar _mm_hl_naive_fw(real colvector x, real colvector w)
{
    real scalar    i, j, m, n
    real colvector xx, ww

    if (any(trunc(w):!=w)) _error(3498, "non-integer frequency not allowed")
    n = rows(x)
    if (n==1) return(x) // HL undefined if n=1; returning observed value
    m = 0
    xx = J(comb(n,2)+sum(w:>1), 1, .)  // off-diagonal pairs + diagonals with w>1
    ww = J(rows(xx), 1, .)
    for (i=1; i<=n; i++) {
        for (j=i; j<=n; j++) {
            if (i==j) {
                // only observations with w>1 contribute within-observation
                // pairs; this must match the sum(w:>1) allocation above (a
                // test on w[i]==1 would let zero frequencies through)
                if (w[i]<=1) continue
            }
            m++
            xx[m] = __mm_hl_el(x, i, j)
            ww[m] = (i==j ? comb(w[i], 2) : w[i]*w[j])
        }
    }
    return(_mm_hl_q_w(xx, ww, .5))
}

// Qn estimator

// entry point: checks the input and dispatches to the naive or the fast
// implementation (unweighted, weighted, or frequency-weighted); the result is
// multiplied by the constant 1/(sqrt(2)*invnormal(5/8))
real scalar mm_qn(real colvector X, | real colvector w, real scalar fw, 
    real scalar naive)
{
    real scalar    hasw, scale
    real colvector p

    // defaults for optional arguments
    if (args()<2) w = 1
    if (args()<3) fw = 0
    if (args()<4) naive = 0
    // input checks
    if (hasmissing(w)) _error(3351)
    if (any(w:<0)) _error(3498, "negative weights not allowed")
    if (rows(X)==0) return(.)
    if (hasmissing(X)) _error(3351)
    scale = 1 / (sqrt(2) * invnormal(5/8))
    hasw  = (rows(w)!=1)   // a single-row w is treated as "no weights"
    // naive algorithm
    if (naive) {
        if (!hasw) return(_mm_qn_naive(X) * scale)
        if (fw) return(_mm_qn_naive_fw(X, w) * scale)
        return(_mm_qn_naive_w(X, w) * scale)
    }
    // fast algorithm, no weights
    if (!hasw) {
        if (mm_issorted(X)) return(_mm_qn(X) * scale)
        return(_mm_qn(sort(X,1)) * scale)
    }
    // fast algorithm, weights; data already sorted
    if (mm_issorted(X)) {
        if (fw) return(_mm_qn_fw(X, w) * scale)
        return(_mm_qn_w(X, w) * scale)
    }
    // fast algorithm, weights; sort data first
    p = order((X,w), (1,2))
    if (fw) return(_mm_qn_fw(X[p], w[p]) * scale)
    return(_mm_qn_w(X[p], w[p]) * scale)
}

// Qn estimator (before rescaling) without weights; fast algorithm
// x: data values; returns 0 if there is only one observation
// NOTE(review): the border-moving loops appear to rely on x being sorted in
// ascending order; mm_qn() passes sorted data -- confirm for other callers
// Cell (i,j) of the search matrix holds x[i] - x[n-j+1]; the search narrows
// a band of candidate cells (row-wise column range l[i]..r[i]) until at most
// n candidates remain, or until the trial value is found to be the target.
real scalar _mm_qn(real colvector x) // no weights
{
    real scalar     i, j, m, k, n, nl, nr, nL, nR, trial
    real colvector  xx, /*ww,*/ l, r, L, R

    n = rows(x)                    // dimension of search matrix
    if (n==1) return(0)            // returning zero if n=1
    xx      = /*ww =*/ J(n, 1, .)  // temp vector for matrix elements
    l = L   = (n::1):+1            // indices of left boundary (old and new)
    r = R   = J(n, 1, n)           // indices of right boundary (old and new)
    nl = nL = comb(n, 2) + n       // number of cells below left boundary
    nr = nR = n * n                // number of cells within right boundary
    k       = nl + comb(n, 2)/4    // target quantile
    /*k = nl + comb(trunc(n/2) + 1, 2)*/
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=2; i<=n; i++) { // first row cannot contain candidates
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = __mm_qn_el(x, i, l[i]+trunc((r[i]-l[i]+1)/2), n)
                /*m++
                ww[m] = r[i] - l[i] + 1
                xx[m] = __mm_qn_el(x, i, l[i]+trunc(ww[m]/2), n)*/
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        /*trial = _mm_hl_qhi_w(xx[|1 \ m|], ww[|1 \ m|], .5)*/
        /*the unweighted quantile is faster; results are the same*/
        //move right border
        j = 0
        for (i=n; i>=1; i--) {
            if (j<n) {
                while (__mm_qn_el(x, i, j+1, n)<trial) {
                    j++
                    if (j==n) break
                }
            }
            R[i] = j
        }
        nR = sum(R)
        if (nR>k) {
            // target lies within the new right border; accept it and iterate
            swap(r, R)
            nr = nR
            continue
        }
        // move left border
        j = n + 1
        for (i=1; i<=n; i++) {
            while (__mm_qn_el(x, i, j-1, n)>trial) {
                j--
            }
            L[i] = j
        }
        nL = sum(L) - n
        if (nL<k) {
            // target lies beyond the new left border; accept it and iterate
            swap(l, L)
            nl = nL
            continue
        }
        // trial = low quantile = high quantile
        if (ceil(k)!=k | (nR<k & nL>k)) return(trial)
        // trial = low quantile
        if (nL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if (L[i]>n) continue
                xx[++m] = __mm_qn_el(x, i, L[i], n)
            }
            return((trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        for (i=1; i<=n; i++) {
            xx[i] = __mm_qn_el(x, i, R[i], n)
        }
        return((trial+max(xx))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=2; i<=n; i++) { // first row cannot contain candidates
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = __mm_qn_el(x, i, j, n)
            }
        }
    }
    return(_mm_hl_q(xx[|1 \ m|], k, nl))
}

// element (i,j) of the Qn search matrix: y[i] minus the j-th element of y
// counted from the end (n = number of rows of the search matrix)
real scalar __mm_qn_el(real colvector y, real scalar i, real scalar j, 
    real scalar n)
{
    real scalar col

    col = n - j + 1
    return(y[i] - y[col])
}

// Qn estimator (before rescaling) with (non-frequency) weights; fast algorithm
// x: data values; w: weight per observation (indexed like x)
// returns 0 if there is only one observation
// NOTE(review): the border-moving loops appear to rely on x being sorted in
// ascending order; mm_qn() passes sorted data -- confirm for other callers
// Cell (i,j) of the search matrix holds x[i] - x[n-j+1] with weight
// w[i]*w[n-j+1]; columns therefore run through the data in reverse order.
real scalar _mm_qn_w(real colvector x, real colvector w)
{
    real scalar     i, j, m, k, n, nl, nr, trial, Wl, WR, WL, W0, W1
    real colvector  xx, ww, l, r, L, R, p, ccw

    n       = rows(x)              // dimension of search matrix
    if (n==1) return(0)            // returning zero if n=1
    xx = ww = J(n, 1, .)           // temp vector for matrix elements
    l = L   = (n::1):+1            // indices of left boundary (old and new)
    r = R   = J(n, 1, n)           // indices of right boundary (old and new)
    nl      = comb(n, 2) + n       // number of cells below left boundary
    nr      = n * n                // number of cells within right boundary
    ccw     = quadrunningsum(w[n::1]) // cumulative column weights
    W0      = quadsum(w:*ccw[n::1]) // sum weights in rest of search matrix
    W1      = quadsum(w:*ccw[rows(ccw)]) - W0 // sum of weights target triangle
    Wl = WL = W0                   // sum of weights below left boundary
    WR      = W0 + W1              // sum of weights within right boundary
    k       = W0 + W1/4            // target sum (high 25% quantile)
    // iterate while more than n candidate cells remain between the borders
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=2; i<=n; i++) { // first row cannot contain candidates
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = __mm_qn_el(x, i, l[i]+trunc((r[i]-l[i]+1)/2), n)
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        //move right border
        j = 0
        for (i=n; i>=1; i--) {
            if (j<n) {
                while (__mm_qn_el(x, i, j+1, n)<trial) {
                    j++
                    if (j==n) break
                }
            }
            R[i] = j
        }
        // R may contain zeros (empty rows); exclude them from subscripting
        p = (R:>0)
        WR = quadsum(select(w, p) :* ccw[select(R, p)])
        if (WR>k) {
            // target lies within the new right border; accept it and iterate
            swap(r, R)
            nr = sum(R)
            continue
        }
        // move left border
        j = n + 1
        for (i=1; i<=n; i++) {
            while (__mm_qn_el(x, i, j-1, n)>trial) {
                j--
            }
            L[i] = j
        }
        // weight below the new left border
        WL = quadsum(w :* ccw[L:-1])
        if (WL<k) {
            // target lies beyond the new left border; accept it and iterate
            swap(l, L)
            Wl = WL
            nl = sum(L) - n
            continue
        }
        // trial = low quantile = high quantile
        if (WR==WL | (WR<k & WL>k)) return(trial)
        // trial = low quantile
        if (WL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if (L[i]>n) continue
                xx[++m] =  __mm_qn_el(x, i, L[i], n)
            }
            return((trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        for (i=1; i<=n; i++) {
            xx[i] = __mm_qn_el(x, i, R[i], n)
        }
        return((trial+max(xx))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=2; i<=n; i++) { // first row cannot contain candidates
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = __mm_qn_el(x, i, j, n)
                ww[m] = w[i] * w[n-j+1]   // weight of the pair in cell (i,j)
            }
        }
    }
    // target k is passed together with Wl, the weight below the left border
    return(_mm_hl_q_w(xx[|1 \ m|], ww[|1 \ m|], k, Wl))
}

// Qn estimator (before rescaling) with frequency weights; fast algorithm
// x: data values; w: integer frequency per observation (indexed like x);
// aborts with error 3498 if w contains non-integer values
// returns 0 if there is only one observation
// NOTE(review): the border-moving loops appear to rely on x being sorted in
// ascending order; mm_qn() passes sorted data -- confirm for other callers
// Column indices in l, r, L, R run from 1 to n+1; a column index c > n-i+1
// in row i is mapped to column c-1 of the n x n matrix of differences
// x[i] - x[n-j+1], so that the cell on the minor diagonal appears twice.
real scalar _mm_qn_fw(real colvector x, real colvector w)
{
    real scalar     i, j, m, k, n, nl, nr, trial, Wl, WR, WL, W0, W1
    real colvector  xx, ww, l, r, L, R, p, ccw, wcorr, idx

    if (any(trunc(w):!=w)) _error(3498, "non-integer frequency not allowed")
    n       = rows(x)              // dimension of search matrix
    if (n==1) return(0)            // returning zero if n=1
    xx = ww = J(n, 1, .)           // temp vector for matrix elements
    idx     = n::1                 // (minor) diagonal indices
    l = L   = idx:+1               // indices of left boundary (old and new)
    r = R   = J(n, 1, n+1)         // indices of right boundary (old and new)
    nl      = comb(n, 2) + n       // number of cells below left boundary
    nr      = n * (n+1)            // number of cells within right boundary
    ccw     = runningsum(w[idx])   // cumulative column weights
    // number of within-observation pairs: comb(w,2) if w>1, else 0 (reversed)
    wcorr   = mm_cond(w:<=1, 0, comb(w, 2))[idx] // correction of weights
    W0      = sum(w:*ccw[idx]) - sum(wcorr) // sum of weights in rest of search matrix
    W1      = sum(w:*ccw[rows(ccw)]) - W0 // sum of weights in target triangle
    Wl = WL = W0                   // sum of weights below left boundary
    WR      = W0 + W1              // sum of weights within right boundary
    k       = W0 + W1/4            // target quantile
    // iterate while more than n candidate cells remain between the borders
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=1; i<=n; i++) {
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = __mm_qn_el(x, i, 
                    (l[i]-1)+trunc(((r[i]-1)-(l[i]-1)+1)/2), n)
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        //move right border
        j = 0
        for (i=n; i>=1; i--) {
            if (j<=n) {
                while (__mm_qn_el(x, i, (j+1)-((j+1)>(n-i+1)), n)<trial) {
                    j++
                    if (j>n) break
                }
            }
            R[i] = j
        }
        // R may contain zeros (empty rows); exclude them from subscripting
        p = (R:>0)
        WR = sum(select(w, p) :* ccw[select(R:-(R:>idx), p)]) - sum(wcorr:*(R:==idx))
        if (WR>k) {
            // target lies within the new right border; accept it and iterate
            swap(r, R)
            nr = sum(R)
            continue
        }
        // move left border
        j = n + 2
        for (i=1; i<=n; i++) {
            while (__mm_qn_el(x, i, (j-1)-((j-1)>(n-i+1)), n)>trial) {
                j--
            }
            L[i] = j
        }
        // weight below the new left border
        WL = sum(w:*ccw[L:-1:-((L:-1):>idx)]) - sum(wcorr:*((L:-1):==idx))
        if (WL<k) {
            // target lies beyond the new left border; accept it and iterate
            swap(l, L)
            Wl = WL
            nl = sum(L) - n
            continue
        }
        // trial = low quantile = high quantile
        if (ceil(k)!=k | (WR<k & WL>k)) return(trial)
        // trial = low quantile
        if (WL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if ((L[i]-(L[i]>(n-i+1)))>n) continue
                xx[++m] = __mm_qn_el(x, i, L[i]-(L[i]>(n-i+1)), n)
            }
            return((trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        for (i=1; i<=n; i++) {
            xx[i] = __mm_qn_el(x, i, R[i]-(R[i]>(n-i+1)), n)
        }
        return((trial+max(xx))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=1; i<=n; i++) {
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = __mm_qn_el(x, i, j-1, n)
                // diagonal cell: pairs within observation i; else cross pairs
                ww[m] = ((n-i+1)==(j-1) ? comb(w[i], 2) : w[i]*w[n-(j-1)+1])
            }
        }
    }
    // target k is passed together with Wl, the weight below the left border
    return(_mm_hl_q_w(xx[|1 \ m|], ww[|1 \ m|], k, Wl))
}

// Qn estimator (before rescaling) without weights; naive algorithm: builds
// all comb(n,2) absolute pairwise differences (i<j) and takes their 0.25
// quantile via _mm_hl_q()
real scalar _mm_qn_naive(real colvector x) // no weights
{
    real scalar    a, b, pos, N
    real colvector dist
    
    N = rows(x)
    if (N==1) return(0) // returning zero if n=1
    dist = J(comb(N,2), 1, .)
    pos  = 0
    for (a=1; a<N; a++) {
        for (b=a+1; b<=N; b++) {
            pos++
            dist[pos] = abs(x[a] - x[b])
        }
    }
    return(_mm_hl_q(dist, 0.25))
}

// Qn estimator (before rescaling) with weights; naive algorithm: builds all
// comb(n,2) absolute pairwise differences (i<j) with weights w[i]*w[j] and
// takes their weighted 0.25 quantile via _mm_hl_q_w()
real scalar _mm_qn_naive_w(real colvector x, real colvector w)
{
    real scalar    a, b, pos, N
    real colvector dist, pw
    
    N = rows(x)
    if (N==1) return(0) // Qn undefined if n=1; returning zero
    dist = J(comb(N,2), 1, .)
    pw   = J(rows(dist), 1, .)
    pos  = 0
    for (a=1; a<N; a++) {
        for (b=a+1; b<=N; b++) {
            pos++
            dist[pos] = abs(x[a] - x[b])
            pw[pos]   = w[a]*w[b]
        }
    }
    return(_mm_hl_q_w(dist, pw, 0.25))
}

// Qn estimator (before rescaling) with frequency weights; naive algorithm
// x: data values; w: integer frequency per observation (indexed like x);
// aborts with error 3498 if w contains non-integer values
// Builds all pairs i<j with weight w[i]*w[j], plus one "diagonal" pair (i,i)
// (distance zero) with weight comb(w[i],2) for each observation with w[i]>1,
// and takes the weighted 0.25 quantile via _mm_hl_q_w().
real scalar _mm_qn_naive_fw(real colvector x, real colvector w)
{
    real scalar    i, j, m, n
    real colvector xx, ww

    if (any(trunc(w):!=w)) _error(3498, "non-integer frequency not allowed")
    n = rows(x)
    if (n==1) return(0) // returning zero if n=1
    m = 0
    xx = J(comb(n,2)+sum(w:>1), 1, .)  // off-diagonal pairs + diagonals with w>1
    ww = J(rows(xx), 1, .)
    for (i=1; i<=n; i++) {
        for (j=i; j<=n; j++) {
            if (i==j) {
                // only observations with w>1 contribute within-observation
                // pairs; this must match the sum(w:>1) allocation above (a
                // test on w[i]==1 would let zero frequencies through)
                if (w[i]<=1) continue
            }
            m++
            xx[m] = abs(x[i] - x[j])
            ww[m] = (i==j ? comb(w[i], 2) : w[i]*w[j])
        }
    }
    return(_mm_hl_q_w(xx, ww, 0.25))
}

// medcouple

real scalar mm_mc(real colvector X, | real colvector w, real scalar fw, 
    real scalar naive)
{
    // Medcouple of X (entry point).
    //   X      data; missing values cause error 3351
    //   w      weights (default 1 = unweighted); must be nonmissing and
    //          nonnegative
    //   fw     nonzero: treat w as frequency weights (default 0)
    //   naive  nonzero: use the brute-force implementation (default 0)
    // Returns missing if X has no rows. The data are centered at the median
    // before being handed to the workers; the non-naive workers receive the
    // centered data (and weights) in reversed sort order.
    real scalar    hasw
    real colvector p, r
    
    // defaults for omitted arguments
    if (args()<2) w = 1
    if (args()<3) fw = 0
    if (args()<4) naive = 0
    // input checks
    if (hasmissing(w)) _error(3351)
    if (any(w:<0)) _error(3498, "negative weights not allowed")
    if (rows(X)==0) return(.)
    if (hasmissing(X)) _error(3351)
    hasw = (rows(w)!=1)
    // case 1: data already sorted; no reordering needed
    if (mm_issorted(X)) {
        if (naive) {
            if (!hasw)   return(_mm_mc_naive(X:-_mm_median(X)))
            else if (fw) return(_mm_mc_naive_fw(X:-_mm_median(X,w), w))
            else         return(_mm_mc_naive_w(X:-_mm_median(X,w), w))
        }
        r = rows(X)::1 // reversed row index
        if (!hasw)   return(_mm_mc(X[r]:-_mm_median(X)))
        else if (fw) return(_mm_mc_fw(X[r]:-_mm_median(X,w), w[r]))
        else         return(_mm_mc_w(X[r]:-_mm_median(X,w), w[r]))
    }
    // case 2: unsorted data, naive algorithm (mm_median() takes X as is)
    if (naive) {
        if (!hasw)   return(_mm_mc_naive(X:-mm_median(X)))
        else if (fw) return(_mm_mc_naive_fw(X:-mm_median(X,w), w))
        else         return(_mm_mc_naive_w(X:-mm_median(X,w), w))
    }
    // case 3: unsorted data, fast algorithm; obtain sort order first
    if (!hasw) {
        p = order(X, 1)
        r = p[rows(X)::1]
        return(_mm_mc(X[r]:-_mm_median(X[p])))
    }
    p = order((X,w), (1,2))
    r = p[rows(X)::1]
    if (fw) return(_mm_mc_fw(X[r]:-_mm_median(X[p], w[p]), w[r]))
    else    return(_mm_mc_w(X[r]:-_mm_median(X[p], w[p]), w[r]))
}

real scalar _mm_mc(real colvector x) // no weights; assumes med(x)=0
{
    // Unweighted medcouple of median-centered data x (mm_mc() passes x in
    // reversed sort order). The function searches an implicit n x q matrix
    // whose cells are the negated values of __mm_mc_el() (rows: observations
    // >= median, columns: observations <= median). Per-row left/right
    // boundaries (l, r) are narrowed around the target rank k = n*q/2 until
    // no more than n candidate cells remain; the target is then picked from
    // the remaining candidates with _mm_hl_q(). The result is the negative
    // of the quantile found among the negated cell values.
    // Returns 0 if x has at most one row.
    real scalar     i, j, m, k, n, q, nl, nr, nL, nR, trial, npos, nzero
    real colvector  xx, /*ww,*/ l, r, L, R, xpos, xneg

    if (rows(x)<=1) return(0)      // returning zero if n<=1
    xpos    = select(x, x:>0)      // observations > median
    npos    = rows(xpos)           // number of obs > median
    xneg    = select(x, x:<0)      // observations < median
    nzero   = sum(x:==0)           // number of obs = median
    n       = npos + nzero         // number of rows in search matrix
    q       = rows(xneg) + nzero   // number of columns in search matrix
    xx      = /*ww =*/ J(n, 1, .)  // temp vector for matrix elements
    l = L   = J(n, 1, 1)           // indices of left boundary (old and new)
    r = R   = J(n, 1, q)           // indices of right boundary (old and new)
    nl = nL = 0                    // number of cells below left boundary
    nr = nR = n * q                // number of cells within right boundary
    k       = n*q/2                // target quantile
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=1; i<=n; i++) {
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = -__mm_mc_el(xpos, xneg, npos, nzero, i,
                           l[i]+trunc((r[i]-l[i]+1)/2))
                /*m++
                ww[m] = r[i] - l[i] + 1
                xx[m] = -__mm_mc_el(xpos, xneg, npos, nzero, i,
                          l[i]+trunc((ww[m])/2))*/
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        /*the unweighted quantile is faster; results are the same*/
        /*trial = _mm_hl_qhi_w(xx[|1 \ m|], ww[|1 \ m|], .5)*/
        // move right border
        // (R[i] = number of leading cells in row i that are < trial; j is
        // carried over from row to row, i.e. it never decreases while i
        // runs from n down to 1)
        j = 0
        for (i=n; i>=1; i--) {
            if (j<q) {
                while (-__mm_mc_el(xpos, xneg, npos, nzero, i, j+1)<trial) {
                    j++
                    if (j==q) break
                }
            }
            R[i] = j
        }
        nR = sum(R)
        if (nR>k) {
            // target lies among the cells < trial: adopt new right boundary
            swap(r, R)
            nr = nR
            continue
        }
        // move left border
        // (L[i] = index of first cell in row i that is > trial; j is
        // carried over from row to row and never increases)
        j = q + 1
        for (i=1; i<=n; i++) {
            while (-__mm_mc_el(xpos, xneg, npos, nzero, i, j-1)>trial) {
                j--
            }
            L[i] = j
        }
        nL = sum(L) - n
        if (nL<k) {
            // target lies among the cells > trial: adopt new left boundary
            swap(l, L)
            nl = nL
            continue
        }
        // trial = low quantile = high quantile
        if (ceil(k)!=k | (nR<k & nL>k)) return(-trial)
        // trial = low quantile
        // (average trial with the smallest cell value > trial)
        if (nL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if (L[i]>q) continue
                xx[++m] = -__mm_mc_el(xpos, xneg, npos, nzero, i, L[i])
            }
            return(-(trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        // (average trial with the largest cell value < trial)
        for (i=1; i<=n; i++) {
            xx[i] = -__mm_mc_el(xpos, xneg, npos, nzero, i, R[i])
        }
        return(-(trial+max(xx))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=1; i<=n; i++) {
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = -__mm_mc_el(xpos, xneg, npos, nzero, i, j)
            }
        }
    }
    // k is passed as a count together with offset nl (cells already
    // excluded on the left)
    return(-_mm_hl_q(xx[|1 \ m|], k, nl))
}

real scalar __mm_mc_el(real colvector xpos, real colvector xneg, 
    real scalar npos, real scalar nzero, real scalar i, real scalar j)
{
    // Value of cell (i, j) of the medcouple search matrix.
    //   rows    1..npos index xpos; rows beyond npos stand for values equal
    //           to the median
    //   columns 1..nzero stand for values equal to the median; columns
    //           beyond nzero index xneg (shifted by nzero)
    real scalar a, b

    if (i>npos) {
        // row belongs to a value equal to the median
        if (j>nzero) return(-1)
        // both row and column equal the median: 1, 0, or -1 depending on
        // the position of the cell relative to the anti-diagonal
        return(sign((npos+nzero-i+1)-j))
    }
    // row belongs to a value above the median
    if (j<=nzero) return(1)
    a = xpos[i]
    b = xneg[j-nzero]
    return((a + b)/(a - b))
}

real scalar _mm_mc_w(real colvector x, real colvector w) // assumes med(x)=0
{
    // Weighted medcouple of median-centered data x with weights w (mm_mc()
    // passes both in reversed sort order). Same boundary search as in
    // _mm_mc(), but progress is measured in sums of cell weights instead of
    // cell counts: cell (i, j) has weight wpos[i]*wneg[j] and the target is
    // half of the total weight W. The loop still stops once at most n cells
    // remain between the boundaries; the target is then taken from the
    // remaining candidates with _mm_hl_q_w().
    // Returns 0 if x has at most one row.
    real scalar     i, j, m, k, n, q, nl, nr, trial, npos, nzero, Wl, WR, WL, W
    real colvector  xx, ww, l, r, L, R, p, ccw, xpos, xneg, wpos, wneg, wzero

    if (rows(x)<=1) return(0)      // returning zero if n<=1
    p = (x:>0)
    xpos    = select(x, p)         // observations > median
    wpos    = select(w, p)         // weights of observations >= median
    npos    = rows(xpos)           // number of obs > median
    p = (x:<0)
    xneg    = select(x, p)         // observations < median
    wneg    = select(w, p)         // weights of observations <= median
    wzero   = select(w, x:==0)     // weights of observations = median
    nzero   = rows(wzero)          // number of obs = median
    if (nzero>0) {
        // observations equal to the median form the last rows and the
        // first columns of the search matrix
        wpos = wpos \ wzero
        wneg = wzero[nzero::1] \ wneg // need to use reverse ordered wzero
    }
    n       = npos + nzero         // number of rows in search matrix
    q       = rows(xneg) + nzero   // number of columns in search matrix
    xx = ww = J(n, 1, .)           // temp vector for matrix elements
    l = L   = J(n, 1, 1)           // indices of left boundary (old and new)
    r = R   = J(n, 1, q)           // indices of right boundary (old and new)
    nl      = 0                    // number of cells below left boundary
    nr      = n * q                // number of cells within right boundary
    ccw     = quadrunningsum(wneg) // cumulative column weights
    W       = quadsum(wpos:*ccw[rows(ccw)]) // sum of weights in search matrix
    Wl = WL = 0                    // sum of weights below left boundary
    WR      = W                    // sum of weights within right boundary
    k       = W/2                  // target quantile
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=1; i<=n; i++) {
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = -__mm_mc_el(xpos, xneg, npos, nzero, i,
                           l[i]+trunc((r[i]-l[i]+1)/2))
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        // move right border
        // (R[i] = number of leading cells in row i that are < trial)
        j = 0
        for (i=n; i>=1; i--) {
            if (j<q) {
                while (-__mm_mc_el(xpos, xneg, npos, nzero, i, j+1)<trial) {
                    j++
                    if (j==q) break
                }
            }
            R[i] = j
        }
        // weight of all cells < trial (rows with R=0 contribute nothing)
        p = (R:>0)
        if (any(p)) WR = quadsum(select(wpos, p) :* ccw[select(R, p)])
        else        WR = 0
        if (WR>k) {
            swap(r, R)
            nr = sum(R)
            continue
        }
        // move left border
        // (L[i] = index of first cell in row i that is > trial)
        j = q + 1
        for (i=1; i<=n; i++) {
            while (-__mm_mc_el(xpos, xneg, npos, nzero, i, j-1)>trial) {
                j--
            }
            L[i] = j
        }
        // weight of all cells <= trial
        p = (L:>1)
        WL = quadsum(select(wpos, p) :* ccw[select(L, p):-1])
        if (WL<k) {
            swap(l, L)
            Wl = WL
            nl = sum(L) - n
            continue
        }
        // trial = low quantile = high quantile
        if (WR==WL | (WR<k & WL>k)) return(-trial)
        // trial = low quantile
        // (average trial with the smallest cell value > trial)
        if (WL==k) {
            m = 0
            for (i=1; i<=n; i++) {
                if (L[i]>q) continue
                xx[++m] = -__mm_mc_el(xpos, xneg, npos, nzero, i, L[i])
            }
            return(-(trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        // (average trial with the largest cell value < trial)
        for (i=1; i<=n; i++) {
            xx[i] = -__mm_mc_el(xpos, xneg, npos, nzero, i, R[i])
        }
        return(-(trial+max(xx))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=1; i<=n; i++) {
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = -__mm_mc_el(xpos, xneg, npos, nzero, i, j)
                ww[m] = wpos[i] * wneg[j]
            }
        }
    }
    // k is passed as a weight count together with offset Wl (weight
    // already excluded on the left)
    return(-_mm_hl_q_w(xx[|1 \ m|], ww[|1 \ m|], k, Wl))
}

real scalar _mm_mc_fw(real colvector x, real colvector w) // assumes med(x)=0
{
    // Medcouple of median-centered data x with frequency weights w (must be
    // integers; error 3498 otherwise). Follows the scheme of _mm_mc_w(), but
    // all observations equal to the median are collapsed into a single row
    // and a single column with aggregate frequency wzero. If wzero>1 the
    // last row of the search matrix (median row) gets two extra cells: its
    // first three cells are evaluated through __mm_mc_el() with arguments
    // (npos-1, 3), and carry the weights comb(wzero,2), wzero, comb(wzero,2)
    // stored at the top of wnegz. NOTE(review): these three cells presumably
    // represent the pairs among the tied median observations -- confirm.
    // Returns 0 if x has at most one row.
    real scalar     i, j, m, k, n, q, nl, nr, trial, npos, nzero, Wl, WR, WL, W
    real colvector  xx, ww, l, r, L, R, p, ccw, ccwz, xpos, xneg, wpos, wneg, wnegz, wzero

    if (any(trunc(w):!=w)) _error(3498, "non-integer frequency not allowed")
    if (rows(x)<=1) return(0)       // returning zero if n<=1
    p = (x:>0)
    xpos    = select(x, p)          // observations > median
    wpos    = select(w, p)          // weights of observations >= median
    npos    = rows(xpos)            // number of obs > median
    p = (x:<0)
    xneg    = select(x, p)          // observations < median
    wneg    = select(w, p)          // weights of observations <= median
    wzero   = sum(select(w, x:==0)) // aggregate weights for x==median
    nzero   = (wzero>0)             // has x==median
    if (nzero>0) {
        // the median row enters wpos with weight 1; its actual cell weights
        // are kept separately in wnegz
        wpos  = wpos \ 1
        wnegz = (wzero>1 ? (comb(wzero, 2), wzero, comb(wzero, 2))' : wzero) \ wneg*wzero
        wneg  = wzero \ wneg
    }
    n       = npos + nzero          // number of rows in search matrix
    q       = nzero + rows(xneg)    // number of columns in search matrix
    xx = ww = J(n, 1, .)            // temp vector for matrix elements
    l = L   = J(n, 1, 1)            // indices of left boundary (old and new)
    r = R   = J(n, 1, q)            // indices of right boundary (old and new)
    if (wzero>1) {
        // last row has two additional cells
        r[n] = q+2
        R[n] = q+2
    }
    nl      = 0                     // number of cells below left boundary
    nr      = sum(R)                // number of cells within right boundary
    ccw     = runningsum(wneg)      // cumulative column weights
    if (nzero==0) {
        W   = quadsum(wpos:*ccw[rows(ccw)]) // sum of weights in search matrix
    }
    else {
        ccwz = runningsum(wnegz)    // cumulative column weights in last row
        if (npos>=1) W = quadsum(wpos[|1 \ n-1|]:*ccw[rows(ccw)]) + ccwz[rows(ccwz)]
        else         W = ccwz[rows(ccwz)]
    }
    Wl = WL = 0                       // sum of weights below left boundary
    WR      = W                       // sum of weights within right boundary
    k       = W/2                     // target quantile
    while ((nr-nl)>n) {
        // get trial value
        m = 0
        for (i=1; i<=(n-(wzero>1)); i++) {
            if (l[i]<=r[i]) {
                // high median within row
                xx[++m] = -__mm_mc_el(xpos, xneg, npos, nzero, i,
                           l[i]+trunc((r[i]-l[i]+1)/2))
            }
        }
        if (wzero>1) { // handle last row
            if (l[n]<=r[n]) {
                m++
                xx[m] = -__mm_mc_el(xpos, xneg, npos-1, 3, n,
                           l[n]+trunc((r[n]-l[n]+1)/2))
            }
        }
        trial = _mm_hl_qhi(xx[|1 \ m|], .5)
        // move right border
        // (R[i] = number of leading cells in row i that are < trial)
        if (wzero>1) { // handle last row
            j = 0
            while (-__mm_mc_el(xpos, xneg, npos-1, 3, n, j+1)<trial) {
                j++
                if (j==(q+2)) break
            }
            R[n] = j
        }
        j = 0
        for (i=(n-(wzero>1)); i>=1; i--) {
            if (j<q) {
                while (-__mm_mc_el(xpos, xneg, npos, nzero, i, j+1)<trial) {
                    j++
                    if (j==q) break
                }
            }
            R[i] = j
        }
        // weight of all cells < trial; the median row is excluded from the
        // regular sum and added from its own cumulative weights (ccwz)
        p = (R:>0)
        if (nzero>0) p[n] = 0
        if (any(p)) WR = quadsum(select(wpos, p) :* ccw[select(R, p)])
        else        WR = 0
        if (nzero>0) {
            if (R[n]>0) WR = WR + ccwz[R[n]]
        }
        if (WR>k) {
            swap(r, R)
            nr = sum(R)
            continue
        }
        // move left border
        // (L[i] = index of first cell in row i that is > trial)
        j = q + 1
        for (i=1; i<=(n-(wzero>1)); i++) {
            while (-__mm_mc_el(xpos, xneg, npos, nzero, i, j-1)>trial) {
                j--
            }
            L[i] = j
        }
        if (wzero>1) { // handle last row
            j = q + 3
            while (-__mm_mc_el(xpos, xneg, npos-1, 3, n, j-1)>trial) {
                j--
            }
            L[n] = j
        }
        // weight of all cells <= trial; the subscripted ccw is transposed
        // if ccw has a single row so that it lines up with the selected
        // wpos
        p = (L:>1)
        if (nzero>0) p[n] = 0
        if (any(p)) WL = quadsum(select(wpos, p) :* 
            (rows(ccw)==1 ? ccw[select(L, p):-1]' : ccw[select(L, p):-1]))
        else        WL = 0
        if (nzero>0) {
            if (L[n]>1) WL = WL + ccwz[L[n]-1]
        }
        if (WL<k) {
            swap(l, L)
            Wl = WL
            nl = sum(L) - n
            continue
        }
        // trial = low quantile = high quantile
        if (ceil(k)!=k | (WR<k & WL>k)) return(-trial)
        // trial = low quantile
        // (average trial with the smallest cell value > trial)
        if (WL==k) {
            m = 0
            for (i=1; i<=(n-(wzero>1)); i++) {
                if (L[i]>q) continue
                xx[++m] = -__mm_mc_el(xpos, xneg, npos, nzero, i, L[i])
            }
            if (wzero>1) { // handle last row
                if (L[n]<=(q+2)) {
                    xx[++m] = -__mm_mc_el(xpos, xneg, npos-1, 3, n, L[n])
                }
            }
            return(-(trial+min(xx[|1 \ m|]))/2)
        }
        // trial = high quantile
        // (average trial with the largest cell value < trial)
        for (i=1; i<=(n-(wzero>1)); i++) {
            xx[i] = -__mm_mc_el(xpos, xneg, npos, nzero, i, R[i])
        }
        if (wzero>1) { // handle last row
            xx[n] = -__mm_mc_el(xpos, xneg, npos-(1), 3, n, R[n])
        }
        return(-(trial+max(xx))/2)
    }
    // get target value from remaining candidates
    m = 0
    for (i=1; i<=(n-nzero); i++) {
        if (l[i]<=r[i]) {
            for (j=l[i]; j<=r[i]; j++) {
                m++
                xx[m] = -__mm_mc_el(xpos, xneg, npos, nzero, i, j)
                ww[m] = wpos[i] * wneg[j]
            }
        }
    }
    if (nzero>0) { // handle last row
        if (l[n]<=r[n]) {
            for (j=l[n]; j<=r[n]; j++) {
                m++
                xx[m] = -__mm_mc_el(xpos, xneg, npos-(wzero>1),
                    (wzero>1 ? 3 : 1), n, j)
                ww[m] = wnegz[j]
            }
        }
    }
    // k is passed as a weight count together with offset Wl (weight
    // already excluded on the left)
    return(-_mm_hl_q_w(xx[|1 \ m|], ww[|1 \ m|], k, Wl))
}

real scalar _mm_mc_naive(real colvector x) // noweights; assumes med(x)=0
{
    // Brute-force medcouple (unweighted): evaluates every cell of the
    // search matrix through __mm_mc_el() and returns the 0.5 quantile of
    // all cell values as computed by _mm_hl_q().
    real scalar    row, col, base, nr, nc, nabove, nties
    real colvector cell, above, below

    if (rows(x)<=1) return(0) // returning zero if n<=1
    above  = select(x, x:>0)  // values above the median
    below  = select(x, x:<0)  // values below the median
    nabove = rows(above)
    nties  = sum(x:==0)       // number of values equal to the median
    nr     = nabove + nties   // rows of the search matrix
    nc     = rows(below) + nties // columns of the search matrix
    cell   = J(nr*nc, 1, .)
    for (row=1; row<=nr; row++) {
        base = (row-1)*nc     // cells are stored row by row
        for (col=1; col<=nc; col++) {
            cell[base+col] = __mm_mc_el(above, below, nabove, nties, row, col)
        }
    }
    return(_mm_hl_q(cell, .5))
}

real scalar _mm_mc_naive_w(real colvector x, real colvector w) // assumes med(x)=0
{
    // Brute-force weighted medcouple: evaluates every cell of the search
    // matrix through __mm_mc_el(), weights cell (row, col) by the product of
    // the row weight and the column weight, and returns the weighted 0.5
    // quantile as computed by _mm_hl_q_w().
    real scalar    row, col, base, nr, nc, nabove, nties
    real colvector cell, cellw, above, below, wrow, wcol, wties

    if (rows(x)<=1) return(0) // returning zero if n<=1
    above  = select(x, x:>0)
    wrow   = select(w, x:>0)
    nabove = rows(above)
    below  = select(x, x:<0)
    wcol   = select(w, x:<0)
    wties  = select(w, x:==0) // weights of values equal to the median
    nties  = rows(wties)
    if (nties>0) {
        // ties form the last rows and, in reversed order, the first columns
        wrow = wrow \ wties
        wcol = wties[nties::1] \ wcol
    }
    nr     = nabove + nties
    nc     = rows(below) + nties
    cell   = cellw = J(nr*nc, 1, .)
    for (row=1; row<=nr; row++) {
        base = (row-1)*nc     // cells are stored row by row
        for (col=1; col<=nc; col++) {
            cell[base+col]  = __mm_mc_el(above, below, nabove, nties, row, col)
            cellw[base+col] = wrow[row] * wcol[col]
        }
    }
    return(_mm_hl_q_w(cell, cellw, .5))
}

real scalar _mm_mc_naive_fw(real colvector x, real colvector w) // assumes med(x)=0
{
    // Brute-force medcouple with frequency weights (must be integers; error
    // 3498 otherwise). All values equal to the median are collapsed into one
    // row and one column with aggregate frequency wties. The cell where the
    // tie row meets the tie column is entered as value 0 with weight wties
    // and, if wties>1, additionally as values 1 and -1 with weight
    // comb(wties, 2) each. Returns the weighted 0.5 quantile of all cells as
    // computed by _mm_hl_q_w().
    real scalar    row, col, pos, nr, nc, nabove, hasties, wties
    real colvector cell, cellw, above, below, wrow, wcol

    if (any(trunc(w):!=w)) _error(3498, "non-integer frequency not allowed")
    if (rows(x)<=1) return(0) // returning zero if n<=1
    above   = select(x, x:>0)
    wrow    = select(w, x:>0)
    nabove  = rows(above)
    below   = select(x, x:<0)
    wcol    = select(w, x:<0)
    wties   = sum(select(w, x:==0)) // aggregate weights for x==median
    hasties = (wties>0)
    if (hasties>0) {
        wrow = wrow \ wties
        wcol = wties \ wcol
    }
    nr      = nabove + hasties
    nc      = rows(below) + hasties
    pos     = 0
    // two extra slots for the additional tie cells
    cell    = cellw = J(nr*nc + 2*(wties>1), 1, .)
    for (row=1; row<=nr; row++) {
        for (col=1; col<=nc; col++) {
            pos++
            if (!(row>nabove & col==hasties)) {
                // regular cell
                cell[pos]  = __mm_mc_el(above, below, nabove, hasties, row, col)
                cellw[pos] = wrow[row] * wcol[col]
                continue
            }
            // x==median
            cell[pos]  = 0
            cellw[pos] = wties
            if (wties>1) {
                pos++
                cell[pos]  = 1
                cellw[pos] = comb(wties, 2)
                pos++
                cell[pos]  = -1
                cellw[pos] = comb(wties, 2)
            }
        }
    }
    return(_mm_hl_q_w(cell, cellw, .5))
}

// helper functions for quantiles

real scalar _mm_hl_q(real colvector x, real scalar P, 
    | real scalar offset) // must be integer; changes meaning of P if specified
{   // quantile (definition 2)
    // Without offset, P is a proportion and the target rank is P*rows(x).
    // With offset, P is itself a rank (possibly noninteger) and offset is
    // subtracted from the resulting indices. If the low and high index
    // differ, the two order statistics are averaged. Indices are clamped to
    // 1..rows(x). Returns missing if x has no rows.
    real scalar    lo, hi, n, rank, shift
    real colvector xs

    n = rows(x)
    if (n<1) return(.)
    if (n==1) return(x)
    if (args()==3) {
        rank  = P     // P is a count (possibly noninteger)
        shift = offset
    }
    else {
        rank  = P * n // P is a proportion
        shift = 0
    }
    lo = ceil(rank)      - shift // index of low quantile
    hi = floor(rank) + 1 - shift // index of high quantile
    lo = (lo<1 ? 1 : (lo>n ? n : lo))
    hi = (hi<1 ? 1 : (hi>n ? n : hi))
    xs = sort(x, 1)
    if (lo!=hi) return((xs[lo] + xs[hi])/2)
    return(xs[hi])
}

real scalar _mm_hl_qhi(real colvector x, real scalar P)
{   // high quantile
    // Returns the order statistic at index floor(P*rows(x))+1, clamped to
    // 1..rows(x); missing if x has no rows.
    real scalar    idx, n
    real colvector xs

    n = rows(x)
    if (n<1) return(.)
    xs  = sort(x, 1)
    idx = floor(P * n) + 1
    idx = (idx<1 ? 1 : (idx>n ? n : idx))
    return(xs[idx])
}

real scalar _mm_hl_q_w(real colvector x, real colvector w, real scalar P, 
    | real scalar offset) // changes meaning of P if specified
{
    // Weighted quantile of x with weights w. Observations with zero weight
    // are ignored. Without offset, P is a proportion of the total weight;
    // with offset, P is a target on the cumulative-weight scale and the
    // cumulative weights start at offset. If the target coincides exactly
    // with a cumulative weight, the value and its successor are averaged.
    // Returns missing if no observation (with nonzero weight) is available.
    real scalar    n, i, target
    real colvector p, cw, xs

    n = rows(x)
    if (n<1) return(.)
    p = order(x, 1)
    if (anyof(w, 0)) {
         // drop zero-weight observations from the sort order
         p = select(p, w[p]:!=0)
         n = rows(p)
    }
    if (n<1) return(.)
    if (n==1) return(x[p])
    if (args()!=4) {
        cw = quadrunningsum(w[p])
        target = P * cw[n] // P is a proportion
    }
    else {
        cw = quadrunningsum(offset \ w[p])[|2 \ n+1|]
        target = P // P is a count
    }
    xs = x[p]
    if (target>=cw[n]) return(xs[n])
    // locate first cumulative weight that reaches the target; because
    // target<cw[n] at this point, the loop always stops with i<=n
    for (i=1; i<=n; i++) {
        if (target<=cw[i]) break
    }
    if (target==cw[i]) return((xs[i]+xs[i+1])/2)
    return(xs[i])
}

/*
real scalar _mm_hl_qhi_w(real colvector x, real colvector w, real scalar P)
{   // high quantile (weighted)
    real scalar    i, n, k
    real colvector p, cw

    n = rows(x)
    if (n<1) return(.)
    p = order(x, 1)
    cw = quadrunningsum(w[p])
    k  = cw[n] * P
    if (k>=cw[n]) return(x[p[n]])
    for (i=1; i<=n; i++) {
        if (k>=cw[i]) continue
        return(x[p[i]])
    }
    // cannot be reached
}
*/

end


*! {smcl}
*! {marker mm_ebalance}{bf:mm_ebalance.mata}{asis}
*! version 1.0.8  27apr2022  Ben Jann

version 11.2

// Local macros used as shorthands in the Mata code of this section; each
// macro expands to the name or type declaration given on its right.
// class & struct
local MAIN   mm_ebalance
local SETUP  _`MAIN'_setup
local Setup  struct `SETUP' scalar
local IF     _`MAIN'_IF
local If     struct `IF' scalar
// real
local RS     real scalar
local RR     real rowvector
local RC     real colvector
local RM     real matrix
// counters
local Int    real scalar
local IntC   real colvector
local IntR   real rowvector
local IntM   real matrix
// string
local SS     string scalar
// boolean
local Bool   real scalar
local BoolC  real colvector
// transmorphic
local T      transmorphic
local TS     transmorphic scalar
// pointers
local PC     pointer(real colvector) scalar
local PM     pointer(real matrix) scalar

mata:

// class ----------------------------------------------------------------------

// Container for the data (held as pointers), statistics derived from the
// data, and user settings of an mm_ebalance object. Members without suffix
// refer to the main data; members with suffix 0 refer to the reference data.
struct `SETUP' {
    // data
    `PM'    X, X0      // pointers to main data and reference data
    `PC'    w, w0      // pointers to base weights
    `Int'   N, N0      // number of obs
    `RS'    W, W0      // sum of weights
    `RR'    m, m0      // means
    `RR'    s, s0      // scales
    `Int'   k          // number of terms
    `BoolC' omit       // flag collinear terms
    `Int'   k_omit     // number of omitted terms
    
    // settings
    `IntR'  adj, noadj // indices of columns to be adjusted/not adjusted
    `T'     tau        // target sum of weights (keyword or number)
    `SS'    scale      // type of scales
    `RC'    btol       // balancing tolerance
    `SS'    ltype      // type of loss function
    `SS'    etype      // evaluator type
    `SS'    trace      // trace level
    `Int'   maxiter    // max number of iterations
    `RS'    ptol       // convergence tolerance for the parameter vector
    `RS'    vtol       // convergence tolerance for the balancing loss
    `Bool'  difficult  // use hybrid optimization
    `Bool'  nostd      // do not standardize
    `Bool'  nowarn     // do not display no convergence/balance warning
}

// Container for influence functions; members with suffix 0 presumably
// belong to the reference data (cf. IFref_b()/IFref_a() in the class).
struct `IF' {
    `RM'   b, b0      // influence functions of coefficients
    `RC'   a, a0      // influence function of intercept
}

// Class declaration of mm_ebalance. Data and settings are kept in the
// private struct `setup'; results are kept in private members and exposed
// through public retrieval functions of the same name. Functions declared
// as transmorphic (`T') act as combined setters/getters.
class `MAIN' {
    // settings
    private:
        void    new()             // initialize class with default settings
        void    clear()           // clear all results
        `Bool'  nodata()          // whether data is set
        `Setup' setup             // container for data and settings
    public:
        void    data()            // set main and reference data
        `RM'    X(), Xref()       // retrieve data
        `RC'    w(), wref()       // retrieve base weights
        `RS'    N(), Nref()       // retrieve number of obs
        `RS'    W(), Wref()       // retrieve sum of weights
        `RR'    m(), mref()       // retrieve moments
        `RR'    s(), sref()       // retrieve scales
        `RR'    mu()              // retrieve target moments
        `Int'   k()               // retrieve number of terms
        `RC'    omit()            // retrieve omitted flags
        `Int'   k_omit()          // retrieve number of omitted terms
        `T'     adj(), noadj()    // set/retrieve adj/noadj
        `T'     tau()             // set/retrieve target sum of weights
        `T'     scale()           // set/retrieve scales
        `T'     btol()            // set/retrieve balancing tolerance
        `T'     ltype()           // set/retrieve loss function
        `T'     etype()           // set/retrieve evaluator type
        `T'     alteval()         // set/retrieve alteval flag (old)
        `T'     trace()           // set/retrieve trace level
        `T'     maxiter()         // set/retrieve max iterations
        `T'     ptol()            // set/retrieve p-tolerance
        `T'     vtol()            // set/retrieve v-tolerance
        `T'     difficult()       // set/retrieve difficult flag
        `T'     nostd()           // set/retrieve nostd flag
        `T'     nowarn()          // set/retrieve nowarn flag
    
    // results
    public:
        `RC'    b()               // retrieve coefficients
        `RS'    a()               // retrieve normalizing intercept
        `RC'    xb()              // retrieve linear prediction
        `RC'    wbal()            // retrieve balancing weights
        `RC'    pr()              // retrieve propensity score
        `RR'    madj()            // adjusted (reweighted) means
        `RS'    wsum()            // retrieve sum of balancing weights
        `RS'    loss()            // retrieve final balancing loss
        `Bool'  balanced()        // retrieve balancing flag
        `RS'    value()           // retrieve value of optimization criterion
        `Int'   iter()            // retrieve number of iterations
        `Bool'  converged()       // retrieve convergence flag
        `RM'    IF_b(), IFref_b() // retrieve IF of coefficients
        `RC'    IF_a(), IFref_a() // retrieve IF of intercept
    private:
        `RS'    tau               // target sum of weight
        `RR'    mu                // target means
        `IntM'  adj, noadj        // permutation vectors for source of mu
        `RR'    scale             // scales for standardization
        `RC'    b                 // coefficients (.z after new()/clear())
        `RS'    a                 // normalizing intercept
        `RC'    xb                // linear prediction (without a)
        `RC'    wbal              // balancing weights
        `RR'    madj              // adjusted (reweighted) means
        `RS'    wsum              // sum of balancing weights
        `RC'    loss              // balancing loss
        `Bool'  balanced          // balance achieved
        `RS'    value             // value of optimization criterion
        `Int'   iter              // number of iterations
        `Bool'  conv              // optimize() convergence
        `If'    IF                // influence functions
        void    _IF_b(), _IF_a()  // generate influence functions
        void    Fit()             // fit coefficients
        void    _Fit_b(), _Fit_a()
        `RM'    _Fit_b_X(), _Fit_b_Xc()
        `RR'    _Fit_b_mu()
        void    _setadj()         // fill in adj and noadj
}

// init -----------------------------------------------------------------------

void `MAIN'::new()
{
    // Constructor: fills in the default settings. No data is attached yet.

    // target, scaling, and columns to adjust
    setup.tau       = "Wref"
    setup.scale     = "main"
    setup.nostd     = 0
    setup.adj       = .
    // loss function and evaluator
    setup.ltype     = "reldif"
    setup.etype     = "bl"
    // optimization settings; trace level and maximum number of iterations
    // are taken from the current Stata settings
    setup.difficult = 0
    setup.maxiter   = st_numscalar("c(maxiter)")
    setup.ptol      = 1e-6
    setup.vtol      = 1e-7
    setup.btol      = 1e-6
    // reporting
    setup.trace     = (st_global("c(iterlog)")=="off" ? "none" : "value")
    setup.nowarn    = 0
    // results: coefficients are set to .z (same value as used by clear())
    b               = .z
}

void `MAIN'::clear()
{
    // Resets all derived quantities and results to their empty state;
    // data and settings stored in setup are left untouched.

    // targets derived from data and settings
    tau      = .
    mu       = J(1,0,.)
    noadj    = J(0,0,.)
    adj      = J(0,0,.)
    // estimates
    b        = .z
    a        = .
    wbal     = J(0,1,.)
    xb       = J(0,1,.)
    madj     = J(1,0,.)
    wsum     = .
    // diagnostics of the fit
    conv     = .
    iter     = .
    value    = .
    balanced = .
    loss     = .
    // influence functions
    IF       = `IF'()
}

// data -----------------------------------------------------------------------

void `MAIN'::data(`RM' X, `RC' w, `RM' X0, `RC' w0, | `Bool' fast)
{
    // Registers the main data (X, w) and the reference data (X0, w0).
    //   X, X0  data matrices; must have the same number of columns
    //   w, w0  base weights; either a scalar or one value per row
    //   fast   nonzero: skip the checks for missing values and negative
    //          weights (default 0)
    // Only pointers to the arguments are stored. Means, scales of the main
    // data, and flags for collinear terms are computed right away; all
    // existing results are cleared at the end.
    `RM' CP
    
    // check for missing values and negative weights
    if (args()<5) fast = 0
    if (!fast) {
        if (missing(X) | missing(w) | missing(X0) | missing(w0)) _error(3351)
        if (any(w:<0) | any(w0:<0)) _error(3498, "w and wref must be positive")
    }
    // obtain main data
    setup.k = cols(X)
    setup.X = &X
    setup.w = &w
    setup.N = rows(X)
    if (setup.N==0) _error(2000, "no observations in main data")
    if (rows(w)!=1) {
        if (rows(X)!=rows(w)) _error(3200, "X and w not conformable")
        setup.W = quadsum(w)
    }
    else setup.W = setup.N * w // scalar weight applies to every row
    if (setup.N!=0 & setup.W==0) _error(3498, "sum(w) must be > 0")
    // obtain reference data
    setup.X0 = &X0
    setup.w0 = &w0
    setup.N0 = rows(X0)
    if (setup.N0==0) _error(2000, "no observations in reference data")
    if (rows(w0)!=1) {
        if (rows(X0)!=rows(w0)) _error(3200, "Xref and wref not conformable")
        setup.W0 = quadsum(w0)
    }
    else setup.W0 = setup.N0 * w0 // scalar weight applies to every row
    if (setup.N0!=0 & setup.W0==0) _error(3498, "sum(wref) must be > 0")
    // if scale is set by user
    if (setup.scale=="user") {
        if (setup.k!=length(scale)) _error(3200, "X not conformable with scale")
    }
    else scale = J(1,0,.) // (clear scale)
    // target moments
    if (setup.k!=cols(X0)) _error(3200, "X and Xref not conformable")
    setup.m0 = mean(X0, w0)
    // identify collinear terms
    setup.m  = mean(X, w)
    if (setup.k==0) {
        setup.omit = J(0,1,.)
        setup.s = J(1,0,.)
        setup.k_omit = 0
    }
    else {
        // cross products of deviations from the means; a zero on the
        // diagonal of the generalized inverse flags a collinear term
        CP = quadcrossdev(X, setup.m, w, X, setup.m)
        setup.s = sqrt(diagonal(CP)' / setup.W)
        setup.omit = (diagonal(invsym(CP)):==0) // or: diagonal(invsym(CP, 1..setup.k)):==0
        setup.k_omit = sum(setup.omit)
    }
    // clear results
    setup.s0 = J(1,0,.) // (scale of refdata will be set later only if needed)
    clear()
}

`Bool' `MAIN'::nodata()
{
    // true (1) as long as no main data pointer has been stored
    return(setup.X==NULL)
}

`RM' `MAIN'::X()
{
    // main data; an empty (0 x 1) matrix if no data has been set
    if (!nodata()) return(*setup.X)
    return(J(0,1,.))
}

`RM' `MAIN'::Xref()
{
    // reference data; an empty (0 x 1) matrix if no data has been set
    if (!nodata()) return(*setup.X0)
    return(J(0,1,.))
}

`RC' `MAIN'::w()
{
    // base weights of main data; an empty (0 x 1) vector if no data is set
    if (!nodata()) return(*setup.w)
    return(J(0,1,.))
}

`RC' `MAIN'::wref()
{
    // base weights of reference data; an empty (0 x 1) vector if no data
    // is set
    if (!nodata()) return(*setup.w0)
    return(J(0,1,.))
}

// number of observations in main data
`Int' `MAIN'::N()
{
    return(setup.N)
}

// number of observations in reference data
`Int' `MAIN'::Nref()
{
    return(setup.N0)
}

// sum of weights in main data
`RS' `MAIN'::W()
{
    return(setup.W)
}

// sum of weights in reference data
`RS' `MAIN'::Wref()
{
    return(setup.W0)
}

// means of main data
`RR' `MAIN'::m()
{
    return(setup.m)
}

// means of reference data
`RR' `MAIN'::mref()
{
    return(setup.m0)
}

// scales of main data
`RR' `MAIN'::s()
{
    return(setup.s)
}

`RR' `MAIN'::sref()
{
    // Scales of the reference data. They are computed on first request and
    // then kept in setup.s0; if no data has been set, the (empty) stored
    // value is returned as is.
    `RM' CP

    if (length(setup.s0)==0 & !nodata()) {
        CP = quadcrossdev(*setup.X0, setup.m0, *setup.w0, *setup.X0, setup.m0)
        setup.s0 = sqrt(diagonal(CP)'/setup.W0)
    }
    return(setup.s0)
}

// number of terms (columns of the data)
`Int' `MAIN'::k()
{
    return(setup.k)
}

// flags of collinear (omitted) terms
`RC'  `MAIN'::omit()
{
    return(setup.omit)
}

// number of omitted terms
`Int' `MAIN'::k_omit()
{
    return(setup.k_omit)
}

// settings -------------------------------------------------------------------

// tau(): retrieve or set the target sum of weights
// - without argument: returns the setting as stored if no data has been set;
//   else returns the resolved value, where the keywords "Wref", "W", "Nref",
//   and "N" are translated to setup.W0, setup.W, setup.N0, and setup.N, and
//   any other setting is passed through as is
// - with argument: tau must be one of the four keywords or a positive,
//   nonmissing number; results are cleared if the setting changes
// NOTE(review): the unqualified name tau in the getter branch is this
// function's optional argument; if a class member of the same name is meant
// to serve as a cache (cf. this.adj in adj()), it would have to be addressed
// as this.tau -- confirm against the class definition
`T' `MAIN'::tau(| `TS' tau)
{
    if (args()==0) {
        if (tau<.) return(tau)
        if (nodata()) return(setup.tau)
        if      (setup.tau=="Wref") tau = setup.W0
        else if (setup.tau=="W")    tau = setup.W
        else if (setup.tau=="Nref") tau = setup.N0
        else if (setup.tau=="N")    tau = setup.N
        else                        tau = setup.tau
        return(tau)
    }
    if (setup.tau==tau) return // no change
    if (isstring(tau)) {
        // string setting: only the known keywords are accepted
        if (!anyof(("Wref", "W", "Nref", "N"), tau)) {
            printf("{err}'%s' not allowed\n", tau)
            _error(3498)
        }
    }
    // numeric setting: must be positive and nonmissing
    else if (tau<=0 | tau>=.) _error(3498, "setting out of range")
    setup.tau = tau
    clear()
}

// scale(): retrieve or set the scale used for standardization
// - without argument: returns the scale vector; if no data has been set and
//   the scale is not user provided, the name of the scale method is returned.
//   The vector is computed from the data according to the method ("main",
//   "ref", "avg", "wavg", "pooled") on first request and then kept in scale;
//   zeros are replaced by 1
// - with string argument: selects the scale method (error if not allowed)
// - with numeric argument: sets user-provided scale values; zeros are
//   replaced by 1; missing or negative values are not allowed; if data has
//   been set, the length must be equal to the number of terms
`T' `MAIN'::scale(| `T' scale0)
{
    `RR' s
    
    if (args()==0) {
        if (length(scale)) return(scale)
        if (setup.scale=="user") return(scale)
        if (nodata()) return(setup.scale)
        if      (setup.scale=="main") scale = s()
        else if (setup.scale=="ref")  scale = sref()
        else if (setup.scale=="avg")  scale = (s() + sref()) / 2
        else if (setup.scale=="wavg") scale = (s()*W() + sref()*Wref()) / 
                                              (W() + Wref())
        else if (setup.scale=="pooled") {
            // stack main and reference data; scalar weights are expanded
            scale = sqrt(diagonal(mm_variance0(X() \ Xref(), 
                (rows(w())==1  ? J(N(), 1, w())   : w()) \ 
                (rows(wref())==1 ? J(Nref(), 1, wref()) : wref())))')
        }
        _editvalue(scale, 0, 1)
        return(scale)
    }
    if (isstring(scale0)) {
        if (length(scale0)!=1) _error(3200)
        if (setup.scale==scale0) return // no change
        if (!anyof(("main", "ref", "avg", "wavg","pooled"), scale0)) {
            printf("{err}'%s' not allowed\n", scale0)
            _error(3498)
        }
        setup.scale = scale0
        scale = J(1,0,.)
    }
    else {
        s = editvalue(vec(scale0)', 0, 1)
        // only treat as "no change" if the scale is already user provided;
        // a cached data-derived scale that happens to be equal to s must not
        // prevent the switch to a fixed user scale
        if (setup.scale=="user" & scale==s) return // no change
        if (missing(s)) _error(3351)
        if (any(s:<=0)) _error(3498, "scale must be positive")
        if (nodata()==0) {
            if (setup.k!=length(s)) _error(3200, "scale not conformable with X")
        }
        setup.scale = "user"
        scale = s
    }
    clear()
}

// btol(): retrieve or set the balancing tolerance (must be positive);
// results are cleared if the setting changes
`T' `MAIN'::btol(| `RS' btol)
{
    if (args()==0) return(setup.btol)
    if (setup.btol!=btol) {
        if (btol<=0) _error(3498, "setting out of range")
        setup.btol = btol
        clear()
    }
}

// ltype(): retrieve or set the type of balancing loss ("reldif", "absdif",
// or "norm"); results are cleared if the setting changes
`T' `MAIN'::ltype(| `SS' ltype)
{
    if (args()==0) return(setup.ltype)
    if (setup.ltype!=ltype) {
        if (anyof(("reldif", "absdif", "norm"), ltype)==0) {
            printf("{err}'%s' not allowed\n", ltype)
            _error(3498)
        }
        setup.ltype = ltype
        clear()
    }
}

// etype(): retrieve or set the evaluator type ("bl", "wl", "mm", or "mma");
// an error is issued if the type is not allowed; results are cleared if the
// setting changes
`T' `MAIN'::etype(| `SS' etype)
{
    if (args()==0) return(setup.etype)
    if (setup.etype==etype) return // no change
    if (!anyof(("bl","wl","mm","mma"), etype)) {
        // etype is a string; use %s (a numeric format would fail in printf)
        printf("{err}'%s' not allowed\n", etype)
        _error(3498)
    }
    setup.etype = etype
    clear()
}

// alteval(): kept for backward compatibility; nonzero selects evaluator type
// "wl", zero selects "bl"; the getter returns 1 if the type is "wl"
`T' `MAIN'::alteval(| `Bool' alteval)
{
    `SS' et
    
    if (args()==0) return(setup.etype=="wl")
    et = (alteval!=0 ? "wl" : "bl")
    if (setup.etype!=et) {
        setup.etype = et
        clear()
    }
}

// trace(): retrieve or set the trace level passed on to optimize();
// results are cleared if the setting changes
`T' `MAIN'::trace(| `SS' trace)
{
    `T' O
    
    if (args()==0) return(setup.trace)
    if (setup.trace!=trace) {
        // validate by applying the level to a scratch optimization object;
        // optimize_init_tracelevel() aborts with error if trace is invalid
        O = optimize_init()
        optimize_init_tracelevel(O, trace)
        setup.trace = trace
        clear()
    }
}

// maxiter(): retrieve or set the maximum number of iterations (must not be
// negative); results are cleared if the setting changes
`T' `MAIN'::maxiter(| `Int' maxiter)
{
    if (args()==0) return(setup.maxiter)
    if (setup.maxiter!=maxiter) {
        if (maxiter<0) _error(3498, "setting out of range")
        setup.maxiter = maxiter
        clear()
    }
}

// ptol(): retrieve or set the parameter convergence tolerance (must be
// positive); results are cleared if the setting changes
`T' `MAIN'::ptol(| `RS' ptol)
{
    if (args()==0) return(setup.ptol)
    if (setup.ptol!=ptol) {
        if (ptol<=0) _error(3498, "setting out of range")
        setup.ptol = ptol
        clear()
    }
}

// vtol(): retrieve or set the value convergence tolerance (must be
// positive); results are cleared if the setting changes
`T' `MAIN'::vtol(| `RS' vtol)
{
    if (args()==0) return(setup.vtol)
    if (setup.vtol!=vtol) {
        if (vtol<=0) _error(3498, "setting out of range")
        setup.vtol = vtol
        clear()
    }
}

// difficult(): retrieve or set the difficult flag (stored as 0/1);
// results are cleared if the setting changes
`T' `MAIN'::difficult(| `Bool' difficult)
{
    `Bool' flag
    
    if (args()==0) return(setup.difficult)
    flag = (difficult!=0)
    if (setup.difficult!=flag) {
        setup.difficult = flag
        clear()
    }
}

// nostd(): retrieve or set the no-standardization flag (stored as 0/1);
// results are cleared if the setting changes
`T' `MAIN'::nostd(| `Bool' nostd)
{
    `Bool' flag
    
    if (args()==0) return(setup.nostd)
    flag = (nostd!=0)
    if (setup.nostd!=flag) {
        setup.nostd = flag
        clear()
    }
}

// nowarn(): retrieve or set the no-warning flag (stored as 0/1);
// results are cleared if the setting changes
`T' `MAIN'::nowarn(| `Bool' nowarn)
{
    `Bool' flag
    
    if (args()==0) return(setup.nowarn)
    flag = (nowarn!=0)
    if (setup.nowarn!=flag) {
        setup.nowarn = flag
        clear()
    }
}

// target means ---------------------------------------------------------------

// adj(): retrieve or set the indices of the terms to be adjusted
// - without argument: returns the setting as stored if no data has been set;
//   else returns the effective index vector as determined by _setadj()
// - with argument: stores the unique truncated indices (as row vector) and
//   resets noadj; indices below 1 are not allowed, unless the argument is a
//   single missing value (which stands for "adjust all", see _setadj())
`T' `MAIN'::adj(| `RM' adj0)
{
    `RR' adj
    
    if (args()==0) {
        if (nodata()) return(setup.adj)
        _setadj()
        return(this.adj)
    }
    adj = mm_unique(trunc(vec(adj0)))'
    // "no change" only applies if noadj is not currently in effect
    if (length(setup.noadj)==0) {
        if (setup.adj==adj) return // no change
    }
    if (adj!=.) {
        if (any(adj:<1)) _error(3498, "setting out of range")
    }
    setup.adj   = adj
    setup.noadj = J(1,0,.)
    clear()
}

// noadj(): retrieve or set the indices of the terms not to be adjusted
// - without argument: returns the setting as stored if no data has been set;
//   else returns the effective index vector as determined by _setadj()
// - with argument: stores the unique truncated indices (as row vector) and
//   resets adj to missing; indices below 1 are not allowed
`T' `MAIN'::noadj(| `RM' noadj0)
{
    `RR' noadj
    
    if (args()==0) {
        if (nodata()) return(setup.noadj)
        _setadj()
        return(this.noadj)
    }
    noadj = mm_unique(trunc(vec(noadj0)))'
    // "no change" only applies if adj is not currently in effect
    if (setup.adj==.) {
        if (setup.noadj==noadj) return // no change
    }
    if (any(noadj:<1)) _error(3498, "setting out of range")
    setup.adj   = .
    setup.noadj = noadj
    clear()
}

// _setadj(): translate the adj()/noadj() settings into the effective index
// vectors adj and noadj given the number of terms in the data
// - adj = . and empty noadj if all terms are to be adjusted
// - else adj contains the indices of the adjusted terms and noadj the
//   indices of the remaining terms
// - indices larger than the number of terms are ignored
void `MAIN'::_setadj() // assumes that data has been set
{
    `Int'  i, j, k
    `IntR' p
    
    if (rows(adj)) return // already set
    // zero variables
    k = k()
    if (k==0) {
        adj = noadj = J(1,0,.)
        return
    }
    // case 1: noadj() has been set
    if (length(setup.noadj)) {
        p = J(1,k,1) // start with all terms selected, then unselect
        for (i=length(setup.noadj); i; i--) {
            j = setup.noadj[i]
            if (j>k) continue // be tolerant and ignore invalid subscripts
            p[j] = 0
        }
    }
    // case 2: adj() has been set
    else if (setup.adj!=.) {
        p = J(1,k,0) // start with no terms selected, then select
        for (i=length(setup.adj); i; i--) {
            j = setup.adj[i]
            if (j>k) continue // be tolerant and ignore invalid subscripts
            p[j] = 1
        }
    }
    // case 3: default (adjust all)
    else p = 1
    // fill in adj/noadj
    if (allof(p, 1)) {
        adj   = .
        noadj = J(1,0,.)
    }
    else {
        adj   = select(1..k, p)
        noadj = select(1..k,!p)
    }
}

// mu(): target means; equal to the means of the reference data for adjusted
// terms and to the means of the main data for all other terms; computed on
// first request and then kept in mu
`RR' `MAIN'::mu()
{
    if (length(mu) | nodata()) return(mu) // already set or nothing to do
    if (adj()==.) mu = setup.m0           // all terms adjusted
    else {
        mu = setup.m
        if (length(adj())) mu[adj()] = setup.m0[adj()]
    }
    return(mu)
}

// results --------------------------------------------------------------------

// b(): coefficients; Fit() is called first if b is still .z
`RC' `MAIN'::b()
{
    if (b==.z) Fit()
    return(b)
}

// a(): intercept; Fit() is called first if b is still .z
`RS' `MAIN'::a()
{
    if (b==.z) Fit()
    return(a)
}

// wbal(): balancing weights; Fit() is called first if b is still .z
`RC' `MAIN'::wbal()
{
    if (b==.z) Fit()
    return(wbal)
}

// xb(): linear predictor including the intercept (stored xb plus a);
// Fit() is called first if b is still .z
`RC' `MAIN'::xb()
{
    if (b==.z) Fit()
    return(xb :+ a)
}

// pr(): inverse logit of the linear predictor shifted by ln(Wref()/tau())
// (presumably the implied propensity score -- confirm);
// Fit() is called first if b is still .z
`RC' `MAIN'::pr()
{
    if (b==.z) Fit()
    return(invlogit(xb :+ (a + ln(Wref()/tau()))))
}

// madj(): means of the main data using the balancing weights; computed on
// first request and then kept in madj
`RR' `MAIN'::madj()
{
    if (length(madj)) return(madj)
    if (b==.z) Fit()
    madj = mean(X(), wbal)
    return(madj)
}

// wsum(): sum of the balancing weights; computed on first request and then
// kept in wsum
`RS' `MAIN'::wsum()
{
    if (wsum<.) return(wsum)
    if (b==.z) Fit()
    wsum = quadsum(wbal)
    return(wsum)
}

// loss(): final balancing loss as set by Fit(); Fit() is called first if b
// is still .z
`RS' `MAIN'::loss()
{
    if (b==.z) Fit()
    return(loss)
}

// balanced(): 1 if the balancing loss is below btol(), 0 else (as set by
// Fit()); Fit() is called first if b is still .z
`RS' `MAIN'::balanced()
{
    if (b==.z) Fit()
    return(balanced)
}

// value(): value of the optimization criterion (0 if no optimization was
// run); Fit() is called first if b is still .z
`RS' `MAIN'::value()
{
    if (b==.z) Fit()
    return(value)
}

// iter(): number of iterations (0 if no optimization was run); Fit() is
// called first if b is still .z
`Int' `MAIN'::iter()
{
    if (b==.z) Fit()
    return(iter)
}

// converged(): convergence flag (1 if no optimization was run); Fit() is
// called first if b is still .z
`Bool' `MAIN'::converged()
{
    if (b==.z) Fit()
    return(conv)
}

// IF_b(): influence functions of b for the main data; computed by _IF_b()
// on first request
`RM' `MAIN'::IF_b()
{
    if (rows(IF.b)==0) _IF_b()
    return(IF.b)
}

// IFref_b(): influence functions of b for the reference data; computed by
// _IF_b() on first request
`RM' `MAIN'::IFref_b()
{
    if (rows(IF.b0)==0) _IF_b()
    return(IF.b0)
}

// IF_a(): influence function of a for the main data; computed by _IF_a()
// on first request
`RC' `MAIN'::IF_a()
{
    if (rows(IF.a)==0) _IF_a()
    return(IF.a)
}

// IFref_a(): influence function of a for the reference data; computed by
// _IF_a() on first request
`RC' `MAIN'::IFref_a()
{
    if (rows(IF.a0)==0) _IF_a()
    return(IF.a0)
}

// optimization ---------------------------------------------------------------

// Fit(): run the estimation
// - aborts with error if no data has been set
// - skips the optimization (b = 0, iter = value = 0, conv = 1) if there is
//   nothing to fit; else calls _Fit_b() to obtain the coefficients
// - calls _Fit_a() to obtain the intercept and the balancing weights
// - computes the balancing loss, sets the balanced flag, and displays a
//   message if balance is not achieved (unless nowarn() is set)
void `MAIN'::Fit()
{
    `Bool' nofit
    
    // optimize
    if (nodata()) _error(3498, "data not set")
    nofit = 0
    if ((k()-k_omit())<=0)   nofit = 1     // no covariates
    else if (!length(adj())) nofit = 1     // no adjustments
    else if (length(noadj()) & k_omit()) { // no adjustments among non-omitted
        nofit = all(omit()[adj()])
    }
    if (nofit) {
        b = J(k(), 1, 0)
        iter = value = 0
        conv = 1
    }
    else _Fit_b() // fit coefficients
    _Fit_a()      // compute intercept (and balancing weights)
    
    // check balancing
    if (etype()!="bl" | nostd()==0 | k_omit()) {
        // compute balancing loss using raw data
        loss = _mm_ebalance_loss(ltype(), mean(X():-mu(), wbal), mu())
    }
    else loss = value // criterion value is taken as the loss in this case
    if (trace()!="none") {
        printf("{txt}Final fit:     balancing loss = {res}%10.0g\n", loss)
    }
    balanced = (loss<btol())
    if (balanced) return
    if (nowarn()) return
    display("{err}balance not achieved")
}

// _Fit_a(): given b, compute the linear predictor xb = X*b, the intercept a,
// and the balancing weights wbal = w*exp(xb + a); a is chosen such that
// a = ln(tau) - ln(sum(w*exp(xb))), i.e. the balancing weights sum to tau()
void `MAIN'::_Fit_a()
{
    `RS' ul
    
    xb = X() * b
    ul = max(xb) // set exp(max)=1 to avoid numerical overflow
    a  = ln(tau()) - ln(quadsum(w() :* exp(xb :- ul))) - ul
    wbal = w() :* exp(xb :+ a)
}

// _Fit_b(): estimate the coefficients using optimize()
// - sets up a minimization problem (technique "nr", evaluator type "d2")
//   using the evaluator that corresponds to etype()
// - omitted terms are excluded from the optimization; their coefficients
//   are set to 0
// - aborts with the return code of optimize() if optimize() reports an error
// - stores coefficients (b), number of iterations (iter), criterion value
//   (value), and convergence flag (conv)
void `MAIN'::_Fit_b()
{
    `RR'   beta
    `IntC' p
    `T'    S
    
    // setup
    if (k_omit()) p = select(1::k(), omit():==0) // indices of non-omitted terms
    S = optimize_init()
    optimize_init_which(S, "min")
    optimize_init_technique(S, "nr")
    optimize_init_evaluatortype(S, "d2")
    optimize_init_conv_ignorenrtol(S, "on")
    optimize_init_tracelevel(S, trace())
    optimize_init_conv_maxiter(S, maxiter())
    optimize_init_conv_ptol(S, ptol())
    optimize_init_conv_vtol(S, vtol())
    optimize_init_singularHmethod(S, difficult() ? "hybrid" : "")
    optimize_init_conv_warning(S, nowarn() ? "off" : "on")
    if (etype()=="mma") {
        optimize_init_evaluator(S, &_mm_ebalance_mma())   // gmm w/ alpha
        optimize_init_valueid(S, "criterion Q(p)")
        optimize_init_params(S, J(1, k()-k_omit()+1, 0)) // starting values
        optimize_init_argument(S, 1, _Fit_b_X(p))        // data
        optimize_init_argument(S, 2, _Fit_b_mu(p))       // target moments
        optimize_init_argument(S, 3, w()/W())            // norm. base weights
        optimize_init_argument(S, 4, tau()/W())          // normalized tau
    }
    else if (etype()=="mm") {
        optimize_init_evaluator(S, &_mm_ebalance_mm())   // gmm
        optimize_init_valueid(S, "criterion Q(p)")
        optimize_init_params(S, J(1, k()-k_omit(), 0))   // starting values
        optimize_init_argument(S, 1, _Fit_b_X(p))        // data
        optimize_init_argument(S, 2, _Fit_b_mu(p))       // target moments
        optimize_init_argument(S, 3, w())                // base weights
    }
    else if (etype()=="wl") {
        optimize_init_evaluator(S, &_mm_ebalance_lw())   // sum(ln(w))
        optimize_init_valueid(S, "criterion L(w)")
        optimize_init_params(S, J(1, k()-k_omit(), 0))   // starting values
        optimize_init_argument(S, 1, _Fit_b_Xc(p))       // centered data
        optimize_init_argument(S, 2, w())                // base weights
    }
    else {
        optimize_init_evaluator(S, &_mm_ebalance_bl())  // balance loss
        optimize_init_valueid(S, "balancing loss")
        optimize_init_params(S, J(1, k()-k_omit(), 0))  // starting values
        optimize_init_argument(S, 1, _Fit_b_Xc(p))      // centered data
        optimize_init_argument(S, 2, _Fit_b_mu(p))      // target moments
        optimize_init_argument(S, 3, w())               // base weights
        optimize_init_argument(S, 4, ltype())           // loss type
    }
    
    // run optimizer
    (void) _optimize(S)
    if (optimize_result_errorcode(S)) {
        errprintf("{p}\n")
        errprintf("%s\n", optimize_result_errortext(S))
        errprintf("{p_end}\n")
        exit(optimize_result_returncode(S))
    }
    
    // obtain results
    beta = optimize_result_params(S)
    if (etype()=="mma") beta = beta[|1\length(beta)-1|] // discard alpha
    // translate coefficients back to the original scale of the data unless
    // standardization has been turned off
    if (k_omit()) {
        b = J(k(), 1, 0)
        if (nostd()) b[p] = beta'
        else         b[p] = (beta :/ scale()[p])'
    }
    else {
        if (nostd()) b = beta'
        else         b = (beta :/ scale())'
    }
    iter  = optimize_result_iterations(S)
    value = optimize_result_value(S)
    conv  = optimize_result_converged(S)
}

// _Fit_b_X(): data for the optimizer; restricted to the columns in p if
// there are omitted terms; divided by scale() unless nostd() is set
`RM' `MAIN'::_Fit_b_X(`IntC' p)
{
    if (!k_omit()) {
        // no omitted terms: use all columns
        if (nostd()) return(X())
        return(X() :/ scale())
    }
    if (nostd()) return(X()[,p])
    return(X()[,p] :/ scale()[p])
}

// _Fit_b_Xc(): data centered at mu() for the optimizer; restricted to the
// columns in p if there are omitted terms; divided by scale() unless
// nostd() is set
`RM' `MAIN'::_Fit_b_Xc(`IntC' p)
{
    if (!k_omit()) {
        // no omitted terms: use all columns
        if (nostd()) return(X() :- mu())
        return((X() :- mu()) :/ scale())
    }
    if (nostd()) return(X()[,p] :- mu()[p])
    return((X()[,p] :- mu()[p]) :/ scale()[p])
}

// _Fit_b_mu(): target moments for the optimizer; restricted to the elements
// in p if there are omitted terms; divided by scale() unless nostd() is set
`RR' `MAIN'::_Fit_b_mu(`IntC' p)
{
    if (!k_omit()) {
        // no omitted terms: use all elements
        if (nostd()) return(mu())
        return(mu() :/ scale())
    }
    if (nostd()) return(mu()[p])
    return((mu() :/ scale())[p])
}

// _mm_ebalance_loss(): balancing loss given differences d and targets mu
// - "norm":   Euclidean norm of d
// - "absdif": maximum absolute difference
// - else:     maximum relative difference between d+mu and mu ("reldif")
`RS' _mm_ebalance_loss(`SS' ltype, `RR' d, `RR' mu)
{
    if      (ltype=="norm")   return(sqrt(d*d'))
    else if (ltype=="absdif") return(max(abs(d)))
    return(mreldif(d+mu, mu))
}

// _mm_ebalance_bl(): optimize() evaluator (type d2) based on balancing loss
//   todo   requested derivatives (the Hessian is only filled in if todo==2)
//   b      parameter vector
//   X      data (the caller in _Fit_b() passes data centered at the targets)
//   mu     target moments (passed through to _mm_ebalance_loss())
//   w0     base weights
//   ltype  loss type
//   v,g,H  returned: criterion value, gradient, Hessian
void _mm_ebalance_bl(`Int' todo, `RR' b, `RM' X, `RR' mu, `RC' w0, `SS' ltype,
    `RS' v, `RR' g, `RM' H)
{   // evaluator based on balance loss
    `RS' W
    `RC' w
    
    w = X * b'
    w = w0 :* exp(w :- max(w)) // avoid numerical overflow
    W = quadsum(w)
    g = quadcross(w, X) / W    // weighted means of X
    v = _mm_ebalance_loss(ltype, g, mu)
    if (todo==2) H = quadcross(X, w, X) / W
}

// _mm_ebalance_lw(): optimize() evaluator (type d2) using the log of the
// sum of weights, ln(sum(w0*exp(X*b'))), as criterion
//   todo   requested derivatives (0: none, 1: gradient, 2: also Hessian)
//   b      parameter vector
//   X      data (the caller in _Fit_b() passes data centered at the targets)
//   w0     base weights
//   v,g,H  returned: criterion value, gradient, Hessian
void _mm_ebalance_lw(`Int' todo, `RR' b, `RM' X, `RC' w0,
    `RS' v, `RR' g, `RM' H)
{   // evaluator using sum(ln(weights)) as criterion
    `RS' W
    `RC' w
    
    w = X * b'
    W = max(w)
    w = w0 :* exp(w :- W) // avoid numerical overflow
    v = ln(quadsum(w)) + W // add back the maximum that has been subtracted
    if (todo>=1) {
        W = quadsum(w) // (W is reused to hold the sum of weights)
        g = quadcross(w, X) / W
        if (todo==2) H = quadcross(X, w, X) / W
    }
}

// _mm_ebalance_mm(): optimize() evaluator (type d2) of gmm type; the
// criterion is the sum of squared differences between the weighted means of
// X (using normalized weights) and the target moments
//   todo   requested derivatives (0: none, 1: gradient, 2: also Hessian)
//   b      parameter vector
//   X      data
//   mu     target moments
//   w0     base weights
//   v,g,H  returned: criterion value, gradient, Hessian
void _mm_ebalance_mm(`Int' todo, `RR' b, `RM' X, `RR' mu, `RC' w0,
    `RS' v, `RR' g, `RM' H)
{   // gmm type evaluator
    `RC' w
    `RR' d
    `RM' h, G
    
    w = X * b'
    w = w0 :* exp(w :- max(w))   // avoid numerical overflow
    w = w :/ quadcolsum(w)       // normalize weights to sum to 1
    h = w :* (X :- mu)
    d = quadcolsum(h)            // moment conditions
    v = d * d'
    if (todo>=1) {
        G = quadcross(h, X :- quadcolsum(w:*X))
        g = d * G
        if (todo==2) H = G'G 
    }
}

// _mm_ebalance_mma(): optimize() evaluator (type d2) of gmm type in which
// the intercept is part of the optimization problem
//   todo   requested derivatives (0: none, 1: gradient, 2: also Hessian)
//   b      parameter vector; the last element is the intercept
//   X      data
//   mu     target moments
//   w0     base weights (the caller in _Fit_b() passes normalized weights)
//   tau    target sum of weights (the caller passes tau divided by W)
//   v,g,H  returned: criterion value, gradient, Hessian
// NOTE(review): colsum() is used for the sum of w in G whereas quad
// precision functions are used elsewhere in this function -- confirm that
// this is intended
void _mm_ebalance_mma(`Int' todo, `RR' b, `RM' X, `RR' mu, `RC' w0, `RS' tau, 
    `RS' v, `RR' g, `RM' H)
{   // gmm type evaluator including alpha in optimization problem
    `RS' d_a
    `RC' w, h_a
    `RR' d_b
    `RM' h_b, G
    
    w   = w0 :* exp(X*b[|1\cols(b)-1|]' :+ b[cols(b)])
    h_b = w :* (X :- mu)
    h_a = w :- w0:*tau  // = w0 * (e(xb+a) - tau/sum(w0))
    d_b = quadcolsum(h_b)
    d_a = quadsum(h_a)
    v   = (d_b, d_a) * (d_b, d_a)'
    if (todo>=1) {
        G = (quadcross(h_b, X), d_b') \ (quadcross(w, X), colsum(w))
        g = (d_b, d_a) * G
        if (todo==2) H = G'G 
    }
}

// influence functions --------------------------------------------------------

// _IF_b(): fill in the influence functions of b (IF.b and IF.b0) by passing
// the current data, weights, and settings to _mm_ebalance_IF_b(); Fit() is
// called first if b is still .z
void `MAIN'::_IF_b()
{
    if (b==.z) Fit()
    _mm_ebalance_IF_b(IF, X(), Xref(), w(), wbal, madj(), mu(), tau(), W(),
        Wref(), adj(), noadj(), omit())
}

// _IF_a(): fill in the influence functions of a (IF.a and IF.a0) using
// _mm_ebalance_IF_a(); the influence functions of b are computed first if
// not yet available
void `MAIN'::_IF_a()
{
    if (rows(IF.b)==0) _IF_b()
    _mm_ebalance_IF_a(IF, X(), w(), wbal, tau(), W())
}

// _mm_ebalance_IF_b(): compute the influence functions of the coefficients
//   IF     structure to be filled in (IF.b: main data, IF.b0: reference data)
//   X      main data
//   Xref   reference data
//   w      base weights
//   wbal   balancing weights
//   madj   means of X using the balancing weights
//   mu     target means
//   tau    target sum of weights
//   W      sum of base weights
//   Wref   sum of weights in reference data
//   adj    indices of adjusted terms (. if all terms are adjusted)
//   noadj  indices of non-adjusted terms
//   omit   optional: indicators for omitted terms; if not provided, qrinv()
//          is used for the matrix inversion
void _mm_ebalance_IF_b(`If' IF, `RM' X, `RM' Xref, `RC' w, `RC' wbal, `RR' madj,
    `RR' mu, `RS' tau, `RS' W, `RS' Wref, `IntR' adj, `IntR' noadj,
    | `BoolC' omit)
{   // using "alternative approach" formulas
    `Int'  k
    `IntC' p
    `RM'   G, Q
    
    G = quadcrossdev(X, mu, wbal/tau, X, madj)
    if (length(omit)==0) {
        // no information on omitted terms; use qrinv()
        Q = -qrinv(G)
    }
    else if (any(omit)) {
        // discard omitted terms during inversion
        k = cols(X)
        p = select(1::k, omit:==0)
        Q = J(k, k, 0)
        if (length(p)) Q[p,p] = -luinv(G[p,p])
    }
    else {
        // no omitted terms; safe to use luinv()
        Q = -luinv(G)
    }
    IF.b  = (wbal/tau):/w :* (X :- madj) * Q' // [sic!] using madj instead of mu
    if (adj==.) IF.b0 = (-1/Wref) * (Xref :- mu) * Q'
    else {
        // only adjusted terms receive a contribution from the reference
        // data; non-adjusted terms receive an extra contribution from the
        // main data
        IF.b0 = J(rows(Xref), cols(Xref), 0)
        if (length(adj)) IF.b0[, adj] = 
            (-1/Wref) * (Xref[,adj] :- mu[adj]) * Q[adj,adj]'
        if (length(noadj)) IF.b[,noadj] = 
            IF.b[,noadj] - (X[,noadj] :- mu[noadj])/W * Q[noadj,noadj]'
    }
}

// _mm_ebalance_IF_a(): compute the influence functions of the intercept
// (IF.a: main data, IF.a0: reference data); requires that IF.b and IF.b0
// have already been filled in
//   X      main data
//   w      base weights
//   wbal   balancing weights
//   tau    target sum of weights
//   W      sum of base weights
void _mm_ebalance_IF_a(`If' IF, `RM' X, `RC' w, `RC' wbal, `RS' tau, `RS' W)
{    // using simplified IF assuming tau as fixed
    `RM' Q
    
    Q = quadcross(wbal, X) // column totals of X using the balancing weights
    IF.a  = ((wbal:/w :- tau/W) + IF.b * Q') / -tau
    IF.a0 = (IF.b0 * Q') / -tau
}

end



*! {smcl}
*! {marker mm_wbal}{bf:mm_wbal.mata}{asis}
*! version 1.0.0  25jul2021  Ben Jann
version 11.2

mata:

// mm_wbal(): convenience wrapper around class mm_ebalance; returns the
// balancing weights for X (with base weights w) given reference data X0
// (with base weights w0); trace output is turned off; nowarn defaults to 0
real colvector mm_wbal(real matrix X, real colvector w, 
    real matrix X0, real colvector w0, | real scalar nowarn)
{
    class mm_ebalance scalar S
    real colvector           wbal
    
    if (args()<5) nowarn = 0
    // settings
    S.nowarn(nowarn)
    S.trace("none")
    // data and results
    S.data(X, w, X0, w0)
    wbal = S.wbal()
    return(wbal)
}

end



*! {smcl}
*! {marker mm_ebal}{bf:mm_ebal.mata}{asis}
*! version 1.0.4  05aug2019  Ben Jann
version 11.2
local Bool  real scalar
local BoolR real rowvector
local Int   real scalar
local IntV  real vector
local IntC  real colvector
local RS    real scalar
local RV    real vector
local RC    real colvector
local RR    real rowvector
local RM    real matrix
local SS    string scalar
local T     transmorphic
local pRC   pointer (`RC') scalar
local pRM   pointer (`RM') scalar
local S     struct _mm_ebal_struct scalar
mata:

// main structure of the mm_ebal() suite; initialized by mm_ebal_init(),
// results are filled in by mm_ebal()
struct _mm_ebal_struct {
    `T'    O         // optimization object
    `RS'   N         // target group size (sum of weights)
    `RM'   C         // constraint matrix (excluding collinear columns)
    `RM'   CC        // collinear columns from constraint matrix
    `RC'   Q         // base weights
    `RC'   W         // balancing weights
    `Bool' nc        // normalizing constraint included in optimization
    `RS'   btol      // balancing tolerance
    `Bool' balanced  // 1 if balance achieved, 0 else
    `RR'   Z         // coefficients
    `RR'   g         // gradient
    `RS'   v         // max difference
    `Bool' conv      // 1 if optimize() converged
    `Int'  rc        // return code of optimize()
    `Int'  i         // number of iterations
}

// functions to retrieve elements from the main structure
// NOTE(review): mm_ebal_N() is declared to return a rowvector although
// element N is declared as a scalar in the structure -- confirm
`RR'   mm_ebal_N(`S' S)        return(S.N)
`RM'   mm_ebal_C(`S' S)        return(S.C)
`RM'   mm_ebal_CC(`S' S)       return(S.CC)
`RC'   mm_ebal_Q(`S' S)        return(S.Q)
`RC'   mm_ebal_W(`S' S)        return(S.W)
`Bool' mm_ebal_balanced(`S' S) return(S.balanced)
`RR'   mm_ebal_g(`S' S)        return(S.g)
`RS'   mm_ebal_v(`S' S)        return(S.v)
`Bool' mm_ebal_conv(`S' S)     return(S.conv)
`Int'  mm_ebal_rc(`S' S)       return(S.rc)
`Int'  mm_ebal_i(`S' S)        return(S.i)

// mm_ebal_init(): initialize the main structure
// - checks the input (no missing values, no negative weights, positive
//   group sizes, conformability) and aborts with error if a check fails
// - expands the data according to the requested targets (see _mm_ebal_C())
//   and removes terms that are collinear across both groups
// - centers the control group data at the target moments obtained from the
//   treatment group and, unless nostd is set, divides by the standard
//   deviations obtained from the treatment group
// - sets up the optimization object and the starting values
// defaults for optional arguments: tar = 1, cov = 0, nc = 0, dfc = 0,
// nostd = 0; the balancing tolerance is initialized to 1e-5
`S' mm_ebal_init(
    `RM'   X1,       // treatment group data
    `RC'   w1,       // treatment group base weights
    `RM'   X0,       // control group data
    `RC'   w0,       // control group base weights
  | `IntV' tar,      // balancing targets
    `Bool' cov,      // whether to balance covariances
    `Bool' nc,       // include normalizing constraint in optimization
    `Bool' dfc,      // apply degrees of freedom correction
    `Bool' nostd)    // do not standardize data
{
    `S'     S        // main struct
    `RM'    C1       // expanded X1
    `RM'    M        // target moments
    `RR'    sd       // standard deviations of C1
    `BoolR' collin   // collinear terms
    `RS'    N0       // size of control group (sum of weights)
    
    // defaults
    if (args()<5)  tar   = 1
    if (args()<6)  cov   = 0
    if (args()<7)  nc    = 0
    if (args()<8)  dfc   = 0
    if (args()<9)  nostd = 0
    S.nc   = nc
    S.btol = 1e-5
    // compute group sizes and check input
    // - treatment group
    if (hasmissing(X1)) _error(3351, "missing values in X1 not allowed")
    if (hasmissing(w1)) _error(3351, "missing values in w1 not allowed")
    if (any(w1:<0))     _error(3498, "negative values in w1 not allowed")
    if (rows(w1)==1) S.N = w1*rows(X1)
    else {
        if (rows(w1)!=rows(X1)) _error(3200, "X1 and w1 not conformable")
        S.N = quadsum(w1)
    }
    if (S.N<=0) _error(3498, "treatment group size (sum of weights) must be > 0")
    // - control group
    if (hasmissing(X0)) _error(3351, "missing values in X0 not allowed")
    if (hasmissing(w0)) _error(3351, "missing values in w0 not allowed")
    if (any(w0:<0))     _error(3498, "negative values in w0 not allowed")
    if (rows(w0)==1) N0 = w0*rows(X0)
    else {
        if (rows(w0)!=rows(X0)) _error(3200, "X0 and w0 not conformable")
        N0 = quadsum(w0)
    }
    if (N0<=0) _error(3498, "control group size (sum of weights) must be > 0")
    // - number of variables
    if (cols(X1)!=cols(X0)) _error(3200, "X1 and X0 must contain the same number of columns")
    // expand X1 and X0
    C1  = _mm_ebal_C(tar, cov, X1)
    S.C = _mm_ebal_C(tar, cov, X0)
    // determine collinear columns (across both groups)
    collin = _mm_ebal_collin((C1 \ S.C),
             ((rows(w1)==1 ? J(rows(C1), 1, w1) : w1)
            \ (rows(w0)==1 ? J(rows(S.C), 1, w0) : w0)))
    // compute target moments
    M = mean(C1, w1)
    // apply degrees-of-freedom correction to target moments
    if (dfc) _mm_ebal_dfc(tar, cov, M, cols(X1), (1-1/S.N) / (1-1/N0))
    // remove collinear columns
    if (any(collin)) {
        C1  = select(C1, collin:==0)
        M   = select(M, collin:==0)
        S.C = select(S.C, collin:==0)
    }
    // determine scale for standardization
    if (rows(C1)==1) sd = J(1, cols(C1), 1)
    else {
        if (nostd | cols(C1)==0) sd = J(1, cols(C1), 1)
        else {
            sd = sqrt(diagonal(variance(C1, w1)))'
            _editvalue(sd, 0, 1) // (sd must not be zero)
        }
    }
    // prepare constraints matrix
    // - determine collinear columns
    collin = _mm_ebal_collin(S.C, w0)
    // - standardize constraints
    S.C = (S.C :- M) :/ sd
    // - put collinear columns aside
    if (any(collin)) {
        S.CC = select(S.C, collin)
        S.C  = select(S.C, collin:==0)
    }
    // - add normalization constraint
    if (S.nc) S.C = J(rows(X0),1,1), S.C
    // prepare base weights
    if (S.nc) S.Q = w0
    else      S.Q = w0 / N0
    // prepare optimization object
    S.O = optimize_init()
    optimize_init_which(S.O, "min")
    optimize_init_evaluator(S.O, &_mm_ebal_eval())
    optimize_init_evaluatortype(S.O, "d2")
    optimize_init_conv_ignorenrtol(S.O, "on")
    optimize_init_valueid(S.O, "max difference")
    // set starting values
    if (S.nc) S.Z = (ln(S.N/N0), J(1, cols(S.C)-1, 0))
    else      S.Z = J(1, cols(S.C), 0)
    // done
    return(S)
}

// _mm_ebal_C(): expand data matrix X according to the balancing targets
//   t    targets (1, 2, or 3); t is recycled if it has fewer elements than
//        X has columns
//   cov  whether to include cross products between all pairs of columns
//   X    data
// the returned matrix contains, in this order, X, the squares of the
// columns with target 2 or 3, the cross products (if cov), and the cubes of
// the columns with target 3; aborts with error if t contains other values
`RM' _mm_ebal_C(`RV' t, `Bool' cov, `RM' X)
{
    `RS' i, j, k, l, c
    `RM' C
    
    // confirm that t only contains 1, 2, or 3
    l = length(t)
    for (i=1; i<=l; i++) {
        if (!anyof((1,2,3), t[i])) _error("invalid target")
    }
    // means
    k = cols(X)
    C = X
    // variances
    if (any(t:>=2)) {
        c = 0
        for (i=1; i<=k; i++) { // count number of terms
            if (t[mod(i-1, l) + 1] >= 2) c++
        }
        C = C, J(rows(C), c, .)
        c = cols(C) - c // index of last column before the added block
        for (i=1; i<=k; i++) {
            if (t[mod(i-1, l) + 1] >= 2) {
                C[,++c] = X[,i]:^2
            }
        }
    }
    // covariances
    if (cov) {
        c = (k^2 - k)/2 // number of pairs
        C = C, J(rows(C), c, .)
        c = cols(C) - c // index of last column before the added block
        for (i=1; i<k; i++) {
            for (j=i+1; j<=k; j++) {
                C[,++c] = X[,i] :* X[,j]
            }
        }
    }
    // skewnesses
    if (anyof(t,3)) {
        c = 0
        for (i=1; i<=k; i++) { // count number of terms
            if (t[mod(i-1, l) + 1] == 3) c++
        }
        C = C, J(rows(C), c, .)
        c = cols(C) - c // index of last column before the added block
        for (i=1; i<=k; i++) {
            if (t[mod(i-1, l) + 1] == 3) {
                C[,++c] = X[,i]:^3
            }
        }
    }
    return(C)
}

// _mm_ebal_collin(): flag collinear columns of C given weights w; returns a
// row vector with 1 for each column for which invsym() of the weighted
// cross product of deviations from the means has a zero on the diagonal
`BoolR' _mm_ebal_collin(`RM' C, `RC' w)
{
    `RR' ctr
    `RM' SS

    if (!cols(C)) return(J(1,0,.))  // no columns; nothing to check
    ctr = mean(C, w)
    SS  = quadcrossdev(C, ctr, w, C, ctr)
    SS  = invsym(SS, 1..cols(SS))
    return((diagonal(SS):==0)')
}

// _mm_ebal_dfc(): apply a degrees-of-freedom correction to the target
// moments of the higher-order terms; M is modified in place
//   t    targets (recycled if shorter than k)
//   cov  whether cross products are included
//   M    target moments in the order generated by _mm_ebal_C()
//   k    number of variables
//   dfc  correction factor
// nothing is done if M only contains the k means or if dfc is equal to 1
void _mm_ebal_dfc(`RV' t, `Bool' cov, `RM' M, `RS' k, `RS' dfc)
{
    `RS' i, j, l, c
    `RM' V
    
    if (cols(M)==k) return
    if (dfc==1)     return
    c = k // position in M; higher-order terms start after the k means
    l = length(t)
    // variances
    if (any(t:>=2)) {
        V = J(2, k, 1) // needed for skewness correction
        for (i=1; i<=k; i++) {
            if (t[mod(i-1, l) + 1] >= 2) {
                c++
                V[1,i] = M[c] // uncorrected moment
                M[c]   = (M[c] - (1-dfc) * M[i]^2) / dfc
                V[2,i] = M[c] // corrected moment
            }
        }
    }
    // covariances
    if (cov) {
        for (i=1; i<k; i++) {
            for (j=i+1; j<=k; j++) {
                c++
                M[c] = (M[c] - (1-dfc) * M[i] * M[j]) / dfc
            }
        }
    }
    // skewnesses
    if (anyof(t,3)) {
        for (i=1; i<=k; i++) {
            if (t[mod(i-1, l) + 1] == 3) {
                c++
                M[c] = (M[c] - (V[1,i] - dfc^1.5 * V[2,i]) * 3 *M[i]
                             + (1 - dfc^1.5) * 2 * M[i]^3) / dfc^1.5
            }
        }
    }
}

// mm_ebal_btol(): set the balancing tolerance if btol is specified; else
// return the current setting
`T' mm_ebal_btol(`S' S, | `RS' btol)
{
    if (args()>1) S.btol = btol
    else return(S.btol)
}

// mm_ebal_trace(): set the trace level of the optimization object if
// tracelevel is specified; else return the current setting
`T' mm_ebal_trace(`S' S, | `SS' tracelevel)
{
    if (args()>1) optimize_init_tracelevel(S.O, tracelevel)
    else return(optimize_init_tracelevel(S.O))
}

// mm_ebal_difficult(): set the singular-H method of the optimization object
// to "hybrid" (flag!=0) or to the default (flag==0) if flag is specified;
// else return whether the method is currently "hybrid"
`T' mm_ebal_difficult(`S' S, | `Bool' flag)
{
    if (args()>1) optimize_init_singularHmethod(S.O, flag ? "hybrid" : "")
    else return(optimize_init_singularHmethod(S.O)=="hybrid")
}

// mm_ebal_maxiter(): set the maximum number of iterations of the
// optimization object if maxiter is specified; else return the current
// setting
`T' mm_ebal_maxiter(`S' S, | `Int' maxiter)
{
    if (args()>1) optimize_init_conv_maxiter(S.O, maxiter)
    else return(optimize_init_conv_maxiter(S.O))
}

// mm_ebal_ptol(): set the parameter tolerance of the optimization object if
// tol is specified; else return the current setting
`T' mm_ebal_ptol(`S' S, | `RS' tol)
{
    if (args()>1) optimize_init_conv_ptol(S.O, tol)
    else return(optimize_init_conv_ptol(S.O))
}

// mm_ebal_vtol(): set the value tolerance of the optimization object if tol
// is specified; else return the current setting
`T' mm_ebal_vtol(`S' S, | `RS' tol)
{
    if (args()>1) optimize_init_conv_vtol(S.O, tol)
    else return(optimize_init_conv_vtol(S.O))
}

// mm_ebal_nowarn(): turn the convergence warning of the optimization object
// off (flag!=0) or on (flag==0) if flag is specified; else return whether
// the warning is currently turned off
`T' mm_ebal_nowarn(`S' S, | `Bool' flag)
{
    if (args()>1) optimize_init_conv_warning(S.O, flag ? "off" : "on")
    else return(optimize_init_conv_warning(S.O)=="off")
}

// mm_ebal_Z(): set the coefficients (starting values) if Z is specified;
// else return the current coefficients
`T' mm_ebal_Z(`S' S, | `RR' Z)
{
    if (args()>1) S.Z = Z
    else return(S.Z)
}

// mm_ebal(): run the balancing algorithm; fills in the results in S and
// returns 1 if balance has been achieved (max difference below the
// balancing tolerance), 0 else
// - the current coefficients in S.Z are used as starting values
// - if optimize() reports an error, the error text is displayed, but
//   execution continues using the results returned by optimize()
// - collinear terms that have been put aside are taken into account when
//   evaluating the final max difference
`Bool' mm_ebal(`S' S)
{
    // find solution
    if (S.nc==0 & cols(S.C)==0) {
        // no covariates; nothing to do
        S.Z = S.g = J(1,0,.)
        S.conv = 1
        S.rc = S.i = S.v = 0
    }
    else {
        optimize_init_argument(S.O, 1, S.nc)
        optimize_init_argument(S.O, 2, &S.Q)
        optimize_init_argument(S.O, 3, &S.C)
        optimize_init_argument(S.O, 4, S.N)
        optimize_init_params(S.O, S.Z)
        (void) _optimize(S.O)
        if (optimize_result_errorcode(S.O)) {
            errprintf("{p}\n")
            errprintf("%s\n", optimize_result_errortext(S.O))
            errprintf("{p_end}\n")
        }
        S.Z    = optimize_result_params(S.O)
        S.conv = optimize_result_converged(S.O)
        S.rc   = optimize_result_returncode(S.O)
        S.i    = optimize_result_iterations(S.O)
        S.v    = optimize_result_value(S.O)
        S.g    = optimize_result_gradient(S.O)
    }
    // obtain balancing weights from solution and update balancing loss
    if (S.nc) {
        S.W = S.Q :* exp(quadcross(S.C', S.Z'))
        if (cols(S.CC)>0) S.v = max((S.v, abs(quadcross(S.W, S.CC)) / S.N))
    }
    else {
        S.W = quadcross(S.C', S.Z')
        S.W = S.Q :* exp(S.W:-max(S.W))
        S.W = S.W / quadsum(S.W)
        if (cols(S.CC)>0) S.v = max((S.v, abs(quadcross(S.W, S.CC))))
        S.W = S.W * S.N // rescale weights to sum to the target group size
    }
    if (cols(S.CC)>0) {
        if (optimize_init_tracelevel(S.O)!="none") {
            printf("{txt}Final fit including collinear terms:\n")
            printf("               max difference = {res}%10.0g\n", S.v)
        }
    }
    // check whether balancing criterion fulfilled
    S.balanced = (S.v < S.btol)
    return(S.balanced)
}

// Evaluator passed to the optimizer by mm_ebal().
//   todo     if equal to 2, H is filled in as well
//   Z        current coefficient vector
//   nc       selects between the two weight formulas below
//   Q, C     pointers to the base weights and the constraint matrix
//   N        target size; subtracted from the first element of g and used
//            to scale v in the nc case
//   v, g, H  returned: criterion (max absolute element of g, divided by N
//            in the nc case), vector g, and matrix H
void _mm_ebal_eval(`Int' todo, `RR' Z, 
    `Bool' nc, `pRC' Q, `pRM' C, `RS' N, 
    `RS' v, `RR' g, `RM' H)
{
    `RC' xb, W
    
    xb = quadcross(*C', Z')
    if (!nc) {
        // shift by max so that the largest exponent is 0 (avoids overflow),
        // then normalize the weights to sum to one
        W = *Q :* exp(xb :- max(xb))
        W = W / quadsum(W)
        g = quadcross(W, *C)
        v = max(abs(g))
    }
    else {
        W = *Q :* exp(xb)
        g = quadcross(W, *C)
        g[1] = g[1] - N
        v = max(abs(g)) / N
    }
    if (todo!=2) return
    H = quadcross(*C, W, *C)
}

end



*! {smcl}
*! {marker mm_crosswalk}{bf:mm_crosswalk.mata}{asis}
*! version 1.0.0  Ben Jann  23aug2021

version 11.2
mata:

// Translate the elements of x using the dictionary from -> to.
//
//   x     values to be translated
//   from  dictionary keys; must have the same eltype as x
//   to    dictionary values; must have the same length as from
//   d     value(s) used for elements of x that have no match in from; either
//         a scalar or a vector with one element per element of x; must have
//         the same eltype as to; default is missingof(to)
//   n     maximum length of the lookup table used by the index algorithm;
//         default is 1e6. The hash algorithm is used instead if x is not
//         real, if n<1, if the value range of from and x exceeds n, or (for
//         n<.z) if from or x contain missing or noninteger values.
//
// Returns a vector with the same dimensions as x. Aborts with error 3250 or
// 3200 if the arguments are inconsistent.
transmorphic vector mm_crosswalk(
    transmorphic vector x,
    transmorphic vector from,
    transmorphic vector to,
    | transmorphic vector d, 
      transmorphic scalar n)
{
    real scalar         usehash, l, offset, i
    real rowvector      minmax
    real vector         hit
    transmorphic vector res

    // defaults
    if (args()<4) d = missingof(to)
    if (args()<5) n = 1e6
    
    // consistency checks
    if (eltype(x)!=eltype(from))  _error(3250, "{it:x} and {it:from} must be of same type")
    if (length(from)!=length(to)) _error(3200, "{it:from} and {it:to} must have same length")
    if (eltype(to)!=eltype(d))    _error(3250, "{it:to} and {it:d} must be of same type")
    if (length(d)!=1) {
        if (length(d)!=length(x)) _error(3200, "{it:x} and {it:d} not conformable")
    }
    
    // nothing to translate
    // - empty x
    if (!length(x)) return(J(rows(x), cols(x), missingof(to)))
    // - empty dictionary: result consists of the default value(s) only
    if (!length(to)) {
        if (length(d)!=1) return(J(1,1, rows(x)==rows(d) ? d : d'))
        return(J(rows(x), cols(x), d))
    }
    
    // pick algorithm
    usehash = 0
    if (!isreal(x)) usehash = 1
    else if (n<1)   usehash = 1
    else {
        if (n<.z) {
            // the index algorithm requires nonmissing integer keys
            if      (hasmissing(from))  usehash = 1
            else if (hasmissing(x))     usehash = 1
            else if (trunc(from)!=from) usehash = 1
            else if (trunc(x)!=x)       usehash = 1
        }
        if (!usehash) {
            // length of lookup table covering all values in from and x
            minmax = minmax((minmax(from),minmax(x)))
            l = minmax[2] - minmax[1] + 1
            if (l>n) usehash = 1
        }
    }
    
    // use hash algorithm
    if (usehash) return(_mm_crosswalk_hash(x, from, to, d))
    
    // use index algorithm: lookup table indexed by value minus offset
    offset = minmax[1] - 1
    res = rows(to)!=1 ? J(l, 1, d[1]) : J(1, l, d[1])
    res[from :- offset] = to
    res = res[x :- offset]
    if (rows(res)!=rows(x)) res = res'
    if (length(d)!=1) {
        // Per-element defaults must be filled in by position, not via the
        // lookup table: the table has a single slot per value, so repeated
        // unmatched values in x would otherwise all receive the same default.
        hit = J(1, l, 0)
        hit[from :- offset] = J(1, length(from), 1)
        hit = hit[x :- offset]
        for (i=length(x); i; i--) {
            if (!hit[i]) res[i] = d[i]
        }
    }
    return(res)
}

// Translate the elements of x using the dictionary from -> to, always using
// the hash algorithm. Arguments are as for mm_crosswalk(), without n:
//   x, from  must have the same eltype
//   from, to must have the same length
//   d        scalar default or one default per element of x, of the same
//            eltype as to; default is missingof(to)
// Returns a vector with the same dimensions as x.
transmorphic vector mm_crosswalk_hash(
    transmorphic vector x,
    transmorphic vector from,
    transmorphic vector to,
    | transmorphic vector d)
{
    // default for unmatched elements
    if (args()<4) d = missingof(to)

    // validate arguments
    if (eltype(x)!=eltype(from)) {
        _error(3250, "{it:x} and {it:from} must be of same type")
    }
    if (length(from)!=length(to)) {
        _error(3200, "{it:from} and {it:to} must have same length")
    }
    if (eltype(to)!=eltype(d)) {
        _error(3250, "{it:to} and {it:d} must be of same type")
    }
    if (length(d)!=1 & length(d)!=length(x)) {
        _error(3200, "{it:x} and {it:d} not conformable")
    }

    // shortcuts that require no lookup
    if (length(x)==0) {
        // empty input
        return(J(rows(x), cols(x), missingof(to)))
    }
    if (length(to)==0) {
        // empty dictionary: every element receives its default
        if (length(d)==1) return(J(rows(x), cols(x), d))
        return(J(1,1, rows(x)==rows(d) ? d : d'))
    }

    // general case
    return(_mm_crosswalk_hash(x, from, to, d))
}

// Hash-based worker behind mm_crosswalk() and mm_crosswalk_hash(); performs
// no argument checks. Stores from -> to in an associative array and looks up
// each element of x; elements without match receive d (scalar d) or their
// own d[i] (non-scalar d). Returns a vector with the dimensions of x.
transmorphic vector _mm_crosswalk_hash(
    transmorphic vector x,
    transmorphic vector from,
    transmorphic vector to,
    transmorphic vector d)
{
    real scalar         j
    real matrix         nokey
    transmorphic        dict
    transmorphic scalar val
    transmorphic vector out

    // fill the dictionary back to front, so that for repeated keys the
    // first occurrence is the one that is kept
    dict = asarray_create(eltype(from))
    for (j=length(from); j; j--) asarray(dict, from[j], to[j])

    if (length(d)!=1) {
        // one default per element: start from d and overwrite the matches;
        // a 0x0 matrix serves as the marker returned for unknown keys
        out   = rows(x)==rows(d) ? d : d'
        nokey = J(0,0,.)
        asarray_notfound(dict, nokey)
        for (j=length(x); j; j--) {
            val = asarray(dict, x[j])
            // comparing to the marker is faster than asarray_contains()
            if (!(val==nokey)) out[j] = val
        }
        return(out)
    }

    // single default: let the array return d for unknown keys
    out = J(rows(x), cols(x), missingof(to))
    asarray_notfound(dict, d)
    for (j=length(x); j; j--) out[j] = asarray(dict, x[j])
    return(out)
}

end