Example #1
def get_ll(M1,C1,M2,C2,beta,s,y):
    '''
    Compute the log-likelihood up to a constant.

    Parameters
    ----------
    M1: prior mean
    C1: prior covariance
    M2: posterior mean
    C2: posterior covariance
    beta: projection onto log-rate
    s: offset 
    y: point-process observation count
    '''
    M1   = assertfinitereal(ascolumn(M1))
    M2   = assertfinitereal(ascolumn(M2))
    C1   = assertfinitereal(assquare(C1))
    C2   = assertfinitereal(assquare(C2))
    beta = assertfinitereal(ascolumn(beta))
    logr   = beta.T.dot(M1)+s
    logPyx = y*logr - sexp(logr)
    Ch     = trychol(C1,1e-6)
    RR     = linv(Ch,M2-M1)
    ll     = logPyx + np.sum(slog(np.diag(chol(C2)))) - np.sum(slog(np.diag(Ch))) - 0.5*RR.T.dot(RR)
    return scalar(ll)
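
A minimal usage sketch for get_ll, with illustrative toy values; it assumes the module helpers used above (ascolumn, assquare, assertfinitereal, trychol, linv, chol, slog, sexp, scalar) are in scope.

import numpy as np
M1   = np.zeros((2,1))           # prior mean
C1   = np.eye(2)*0.5             # prior covariance
M2   = np.array([[0.1],[-0.2]])  # posterior mean
C2   = np.eye(2)*0.4             # posterior covariance
beta = np.array([[1.0],[0.5]])   # projection onto log-rate
ll   = get_ll(M1, C1, M2, C2, beta, s=-2.0, y=1)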
def univariate_lgp_update_laplace(m,v,y,s,dt,
                    tol     = 1e-6,
                    maxiter = 20,
                    eps     = 1e-12):
    '''
    Update a univariate log-Gaussian prior with a Poisson observation
    using the Laplace approximation.
    '''
    v = max(v,eps)
    scale = dt*np.exp(s)
    def objective(mu):
        rate = scale*sexp(mu)
        if not np.isfinite(mu) or not np.isfinite(rate):
            return np.inf
        return -y*mu+rate+0.5*(mu-m)**2/v
    def gradient(mu):
        rate = scale*sexp(mu)
        if not np.isfinite(mu) or not np.isfinite(rate):
            return np.inf
        return -y+rate+(mu-m)/v
    def hessian(mu):
        rate = scale*sexp(mu)
        if not np.isfinite(mu) or not np.isfinite(rate):
            return np.inf
        return rate+1/v
    mu = minimize_retry(objective,m,gradient,hessian,tol=tol,show_progress=False,printerrors=False)
    vv = 1/hessian(mu)
    # Get likelihood at posterior mode
    logr   = mu+s+slog(dt)
    logPyx = y*logr-sexp(logr)
    ll     = logPyx + 0.5*slog(vv/v) - 0.5*(mu-m)**2/v 
    return mu, vv, ll
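
Usage sketch (illustrative values; assumes minimize_retry, sexp, and slog from this module are in scope): update a univariate log-Gaussian prior N(m, v) on the log-rate with a single Poisson count.

mu, vv, ll = univariate_lgp_update_laplace(m=-1.0, v=0.5, y=2, s=0.0, dt=0.01)
# mu, vv: Laplace-approximate posterior mean and variance of the log-rate
# ll: approximate log-likelihood of the observation y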
def univariate_lgp_update_moment(m,v,y,s,dt,
                    tol      = 1e-3,
                    maxiter  = 100,
                    eps      = 1e-7,
                    minlrate = -200,
                    maxlrate = 20,
                    ngrid    = 50,
                    minprec  = 1e-6,
                    maxrange = 150):
    '''
    Update a log-Gaussian distribution with a Poisson measurement by
    integrating to extract the posterior mean and variance.
    '''
    # Get moments by integrating
    v = max(v,eps)
    t = 1/v
    m = np.clip(m,minlrate,maxlrate)
    # Set integration limits
    m0,s0 = (m,np.sqrt(v)) if t>minprec else (slog(y+0.25),np.sqrt(1/(y+1)))
    m0 = np.clip(m0,minlrate,maxlrate)
    delta = min(4*s0,maxrange)
    x = np.linspace(m0-delta,m0+delta,ngrid)
    # Calculate likelihood contribution
    r  = x + s + slog(dt)
    ll = y*r-sexp(r)
    # Calculate prior contribution
    lq = -.5*((x-m)**2/v+slog(2*np.pi*v))
    # "clean up" prior (numerical stability)
    q  = np.maximum(eps,sexp(lq))
    q  = q/np.sum(q)
    lq = slog(q)
    # Normalize to prevent overflow, calculate log-posterior
    nn = np.max(ll)
    lp = (ll - nn) + lq
    # Estimate posterior
    p  = np.maximum(eps,sexp(lp))
    Z  = np.sum(p)   # normalizer (renamed to avoid shadowing the offset s)
    p /= Z
    # Integrate to get posterior moments and likelihood
    pm = np.sum(x*p)
    pv = np.sum((x-pm)**2*p)
    ll = scipy.special.logsumexp(ll+lq)
    assertfinitereal(pm)
    assertfinitereal(pv)
    assertfinitereal(ll)
    return pm,pv,ll
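
Usage sketch (illustrative): the moment update integrates the posterior on a grid rather than optimizing, so for well-conditioned inputs it should roughly agree with the Laplace update above.

pm, pv, ll = univariate_lgp_update_moment(m=-1.0, v=0.5, y=2, s=0.0, dt=0.01)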
def univariate_lgp_update_variational_so(m,v,y,s,dt,
                    tol     = 1e-6,
                    maxiter = 20,
                    eps     = 1e-12):
    '''
    Optimize the variational approximation. The mean and variance must
    be optimized jointly. Uses a second-order approximation of the
    Gaussian/Poisson model.
    '''
    v = max(v,eps)
    scale = dt*sexp(s)
    # Use minimization to solve for variational solution
    def objective(parameters):
        mq,vq = parameters
        rate  = scale*np.exp(mq)*(1 + 0.5*vq)
        if vq<eps or not np.isfinite(mq) or not np.isfinite(vq) or not np.isfinite(rate):
            return np.inf
        return -y*mq + rate + 0.5*( -slog(vq) + vq/v + (mq-m)**2/v )
    def gradient(parameters):
        mq,vq = parameters
        rate  = scale*np.exp(mq)*(1 + 0.5*vq)
        if vq<eps or not np.isfinite(mq) or not np.isfinite(vq) or not np.isfinite(rate):
            return np.array([np.nan, np.nan])
        dm    = -y + rate + (mq-m)/v
        dv    = rate*0.5 + 0.5*(1/v-1/vq)
        return np.array([dm, dv]).squeeze()
    def hessian(parameters):
        mq,vq = parameters
        rate  = scale*np.exp(mq)*(1 + 0.5*vq)
        if vq<eps or not np.isfinite(mq) or not np.isfinite(vq) or not np.isfinite(rate): 
            return np.full((2,2), np.nan)
        dmdm  = rate + 1/v
        dvdm  = rate*0.5
        dvdv  = rate*0.25 + 0.5/vq**2
        return np.array([[dmdm,dvdm],[dvdm,dvdv]]).squeeze()
    mp,vp = minimize_retry(objective,[m,v],gradient,hessian,tol=tol,show_progress=False,printerrors=False)
    # Get likelihood using second order assumption
    logr   = mp+s+slog(dt)
    logPyx = y*logr-sexp(logr)*(1+0.5*vp)
    ll     = logPyx + 0.5*slog(vp/v) - 0.5*(mp-m)**2/v 
    return mp, vp, ll
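
Usage sketch (illustrative values): the variational update returns the jointly optimized posterior mean and variance together with the second-order likelihood estimate.

mp, vp, ll = univariate_lgp_update_variational_so(m=-1.0, v=0.5, y=2, s=0.0, dt=0.01)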
Example #5
def get_ll_univariate(m1,v1,m2,v2,beta,s,y):
    '''
    Compute the log-likelihood up to a constant

    Parameters
    ----------
    m1: prior mean
    v1: prior variance
    m2: posterior mean
    v2: posterior variance
    beta: projection onto log-rate (unused here)
    s: offset 
    y: point-process observation count
    '''
    m1 = scalar(m1)
    v1 = scalar(v1)
    m2 = scalar(m2)
    v2 = scalar(v2)
    logr = m2+s
    logPyx = y*logr - sexp(logr)
    ll = logPyx + 0.5*slog(v2/v1) - 0.5*(m2-m1)**2/v1 
    return scalar(ll)
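
A self-contained check of the formula above, substituting plain np.exp and np.log for the module's guarded sexp and slog (safe for these well-scaled toy values):

import numpy as np
m1, v1 = -1.0, 0.5   # prior mean and variance of the log-rate
m2, v2 = -0.8, 0.3   # posterior mean and variance of the log-rate
s, y   = 0.0, 1      # offset and observed count
logr   = m2 + s
ll = y*logr - np.exp(logr) + 0.5*np.log(v2/v1) - 0.5*(m2 - m1)**2/v1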
Example #6
def filter_moments(stim,Y,A,beta,C,m,
    dt          = 1.0,
    oversample  = 10,
    maxrate     = 500,
    maxvcorr    = 2000,
    method      = "moment_closure",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = 0,
    reg_rate    = 0,
    return_surrogates  = False,
    use_surrogates     = None,
    initial_conditions = None,
    progress           = False,
    safe               = True):
    
    '''
    Parameters
    ----------
    stim : zero-lag effective input (filtered stimulus plus mean offset)
    Y : point-process count observations, same length as stim
    A : forward operator for delay-line evolution
    beta : basis history weights
    C : projection of current state onto the delay line
    m : log-rate bias parameter; log-rates are regularized toward this value
    
    Other Parameters
    ----------------
    dt : time step
    oversample : int
        Integration steps per time step. Should be larger if using 
        Gaussian moment closure, which is stiff. Can be small if using
        second-order approximations, which are less stiff.
    maxrate : 
        maximum rate tolerated
    maxvcorr: 
        Maximum variance correction ('convexity correction' in some literature)
        tolerated during the moment closure.
    method : 
        Moment-closure method. Can be "LNA" for mean-field with linear
        noise approximation, "moment_closure" for Gaussian moment-closure
        on the history process, or "second_order", which discards higher
        moments of the rate which emerge when exponentiating.
    int_method:
        Integration method. Can be either "euler" for forward-Euler, or 
        "exponential", which integrates the locally-linearized system 
        forward using matrix exponentiation (slower).
    measurement:
        "moment", "laplace", or "variational"
    reg_cov:
        Diagonal covariance regularization
    reg_rate:
        Small regularization toward the log mean-rate; this parameter reflects
        the precision of a Gaussian prior about the log mean-rate, applied 
        at every measurement update.
    return_surrogates: bool
        If true, Gaussian approximations of measurement likelihoods are
        returned. 
    use_surrogates: None or tuple
        Can be set as tuple of (means, variances) for Gaussian
        approximations of measurement updates.
    initial_conditions: None or tuple
        Can be set to a tuple (M1,M2) of initial conditions for moment
        filtering.
    progress: boolean
        Whether to report progress
    
    Returns
    -------
    allLR : single-time marginal mean of log-rate
    allLV : single-time marginal variance of log-rate
    allM1 : low-dimensional approximation of history process, mean
    allM2 : low-dimensional approximation of history process, covariance
    nll   : negative log-likelihood
    '''
    # check arguments
    stim = asvector(stim)
    Y    = asvector(Y)
    A    = assquare(A)
    if oversample<1:
        raise ValueError('oversample must be a positive integer')
    if method=="moment_closure" and measurement=="variational":
        warnings.warn("There are unresolved numerical stability issues "\
        "when using the log-Gaussian variational update with Gaussian "\
        "moment closure. Suggest using the second-order moment closure "\
        "instead")
    # Precompute constants
    maxlogr   = np.log(maxrate)
    maxratemc = maxvcorr*maxrate
    dtfine    = dt/oversample
    T         = len(stim)
    K         = beta.size
    I         = np.eye(K)
    Cb        = C.dot(beta.T)
    CC        = C.dot(C.T)
    Adt       = A*dtfine
    if use_surrogates is not None:
        MR,VR = use_surrogates
    # Get measurement update function
    measurement = get_measurement(measurement)
    # Build moment integrator functions
    mean_update, cov_update = get_moment_integrator(int_method,Adt)
    # Get update function (computes expected rate from moments)
    update = get_update_function(method,Cb,Adt,maxvcorr)
    # Accumulate negative log-likelihood up to a constant
    nll = 0
    llrescale = 1.0/len(stim)
    if initial_conditions is None:
        # Initial condition for moments
        M1 = np.zeros((K,1))
        M2 = np.eye(K)*1e-6
    else:
        M1,M2 = initial_conditions
    # Store moments
    allM1 = np.zeros((T,K))
    allM2 = np.zeros((T,K,K))
    allLR = np.zeros((T))
    allLV = np.zeros((T))
    allmr = np.zeros((T))
    allvr = np.zeros((T))
    if progress:
        last_shown = current_milli_time()
    for i,s in enumerate(stim):
        # Regularize
        if reg_cov>0:
            strength = reg_cov+max(0,-np.min(np.diag(M2)))
            M2 = 0.5*(M2+M2.T) + strength*np.eye(K) 
        # Integrate moments forward
        for j in range(oversample):
            logv  = beta.T.dot(M2).dot(beta)
            logx  = min(beta.T.dot(M1)+s,maxlogr)
            R0    = sexp(logx)
            R0    = min(maxrate,R0)
            R0   *= dtfine
            Rm,J  = update(logx,logv,R0,M1,M2)
            M2    = cov_update(M2,J) + CC*Rm
            M1    = mean_update(M1)  + C *Rm
            if safe:
                M1    = np.clip(M1,-100,100)
                M2    = np.clip(M2,-100,100)
        # Measurement update
        pM1,pM2 = M1,M2
        if use_surrogates is None:
            # Use specified approximation method to handle non-conjugate
            # log-Gaussian Poisson measurement update
            M1,M2,ll,mr,vr = measurement_update_projected_gaussian(\
                      M1,M2,Y[i],beta,s,dt,m,reg_rate,measurement)
            allmr[i] = mr
            allvr[i] = vr
        else:
            # Use specified Gaussian approximations (MR,VR) to the 
            # measurement likelihoods.
            M1,M2,ll = measurement_update_projected_gaussian_surrogate(\
                      M1,M2,Y[i],beta,s,dt,m,reg_rate,measurement,
                                      return_surrogate=False,
                                      surrogate=(MR[i],VR[i]))
        if safe:
            M1    = np.clip(M1,-100,100)
            M2    = np.clip(M2,-100,100)
        # Store moments
        allM1[i] = M1[:,0].copy()
        allM2[i] = M2.copy()
        allLR[i] = min(beta.T.dot(M1)+s,maxlogr)
        allLV[i] = beta.T.dot(M2).dot(beta)
        nll -= ll*llrescale
        if safe:
            # Heuristic: detect numerical failure and exit early
            failed = np.any(M1 < -1e5)
            failed|= logx>100*maxlogr
            failed|= nll<-1e10
            if failed:
                nll = np.inf
                break
        if progress and current_milli_time()-last_shown>500:
            sys.stdout.write('\r%02.02f%%'%(i*100/T))
            sys.stdout.flush()
            last_shown = current_milli_time()
    if progress:
        sys.stdout.write('\r100.00%')
        sys.stdout.flush()
        sys.stdout.write('\n')
    if return_surrogates:
        return allLR,allLV,allM1,allM2,nll,allmr,allvr
    else:
        return allLR,allLV,allM1,allM2,nll
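
Usage sketch with synthetic inputs; shapes and values are illustrative, not from the source, and the module helpers (get_measurement, get_moment_integrator, get_update_function, ...) must be in scope.

import numpy as np
T, K = 1000, 5
stim = np.random.randn(T)*0.1 - 2.0           # effective input
Y    = np.random.poisson(0.05, size=T)        # spike counts
A    = np.diag(np.ones(K-1), -1) - np.eye(K)  # toy delay-line operator
beta = np.random.randn(K, 1)*0.1              # history weights
C    = np.zeros((K, 1)); C[0, 0] = 1.0        # inject rate at the first lag
LR, LV, M1s, M2s, nll = filter_moments(stim, Y, A, beta, C, m=-2.0)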
Example #7
def integrate_moments(stim,A,beta,C,
    dt         = 1.0,
    oversample = 10,
    maxrate    = 500,
    maxvcorr   = 2000,
    method     = "moment_closure",
    int_method = "euler",
    reg_cov    = 1e-6,
    safemode   = "assert"):
    '''
    Integrate moment equations for autoregressive PPGLM
    
    Parameters
    ----------
    stim : zero-lag effective input (filtered stimulus plus mean offset)
    A : forward operator for delay-line evolution
    beta : basis history weights
    C : projection of current state onto the delay line
    
    Other Parameters
    ----------------
    dt : time step
    oversample : int
        Integration steps per time step. Should be larger if using 
        Gaussian moment closure, which is stiff. Can be small if using
        second-order approximations, which are less stiff.
    maxrate : 
        maximum rate tolerated
    maxvcorr: 
        Maximum variance correction ('convexity correction' in some literature)
        tolerated during the moment closure.
    method : 
        Moment-closure method. Can be "LNA" for mean-field with linear
        noise approximation, "moment_closure" for Gaussian moment-closure
        on the history process, or "second_order", which discards higher
        moments of the rate which emerge when exponentiating.
    int_method:
        Integration method. Can be either "euler" for forward-Euler, or 
        "exponential", which integrates the locally-linearized system 
        forward using matrix exponentiation (slower).
    reg_cov: 
        Small diagonal regularization for covariance matrix
    safemode: string
        Whether to check for and repair NaNs and singular covariances.
        If "assert", NaNs trigger an assertion error;
        if "repair", NaNs are silently corrected.
    
    Returns
    -------
    allLR : single-time marginal mean of log-rate
    allLV : single-time marginal variance of log-rate
    allM1 : low-dimensional approximation of history process, mean
    allM2 : low-dimensional approximation of history process, covariance
    '''
    maxlogr   = np.log(maxrate)
    maxratemc = maxvcorr*maxrate
    dtfine    = dt/oversample
    T         = len(stim)
    K         = beta.size
    # Precompute constants
    Cb  = C.dot(beta.T)
    CC  = C.dot(C.T)
    Adt = A*dtfine
    F1  = scipy.linalg.expm(Adt)
    # Build moment integrator functions
    mean_update, cov_update = get_moment_integrator(int_method,Adt)
    # Get update function (computes expected rate from moments)
    update = get_update_function(method,Cb,Adt,maxvcorr)
    # Initial condition for moments
    M1 = np.zeros((K,1))
    M2 = np.eye(K)*1e-6
    # Store moments
    allM1 = np.zeros((T,K))
    allM2 = np.zeros((T,K,K))
    allLR = np.zeros((T))
    allLV = np.zeros((T))
    # Integrate
    for i,s in enumerate(stim):
        for j in range(oversample):
            if   safemode=="assert":
                assert(np.all(np.isfinite(M1)) and np.all(np.isfinite(M2)))
            elif safemode=="repair":
                M1[~np.isfinite(M1)] = 0
                M2[~np.isfinite(M2)] = 0
            # Marginal variance and mean of the predicted log-rate
            # (This is a linear projection of the process history)
            logv  = beta.T.dot(M2).dot(beta)
            logx  = min(beta.T.dot(M1)+s,maxlogr)
            # Uncorrected rate estimate
            R0    = sexp(logx)*dtfine
            # Covariance corrected rate estimate Rm
            # And system Jacobian J
            Rm,J  = update(logx,logv,R0,M1,M2)
            # Moment updates with new mean and variance
            # proportional to corrected rate Rm
            M2    = cov_update(M2,J) + CC*Rm
            M1    = mean_update(M1)  + C *Rm
        if safemode=="repair":
            M2 = repair_covariance(M2,reg_cov)
        allM1[i] = M1[:,0]
        allM2[i] = M2
        allLR[i] = logx
        allLV[i] = beta.T.dot(M2).dot(beta)
    return allLR,allLV,allM1,allM2
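
Usage sketch (illustrative; reuses the toy stim, A, beta, C from the filter_moments sketch above): integrate the prior moment equations with no observations, e.g. to inspect the predicted log-rate trajectory.

LR, LV, M1s, M2s = integrate_moments(stim, A, beta, C, dt=1.0, oversample=10)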
Example #8
# This appears to be a nested closure (cf. get_update_function above);
# Cb, Adt, and maxvcorr are captured from the enclosing scope.
def update(logx,logv,R0,M1,M2):
    Rm = R0 * min(sexp(0.5*logv),maxvcorr)
    J  = Cb*Rm+Adt
    return Rm,J
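
The factor sexp(0.5*logv) above is the log-normal mean correction E[exp(x)] = exp(m + v/2) for x ~ N(m, v), capped at maxvcorr. A quick self-contained Monte-Carlo check of that identity (illustrative):

import numpy as np
m, v = -1.0, 0.5
x = np.random.default_rng(0).normal(m, np.sqrt(v), 100_000)
print(np.exp(x).mean(), np.exp(m + 0.5*v))  # the two values should be close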
def measurement_update_projected_gaussian_surrogate(m1,m2,y,b,s,dt,pm,pt,
                                      univariate_method = univariate_lgp_update_moment,
                                      return_surrogate=False,
                                      surrogate=None,
                                      eps=1e-12,
                                      safe=False):
    '''
    Please see `measurement_update_projected_gaussian`. 
    
    This function is the same, except that it can return "surrogate"
    Gaussian approximations of the measurement updates, which can then
    be used in subsequent filtering of the same data for much more rapid
    measurement updates.
    
    If the parameters do not change too much, these surrogate updates
    will remain approximately correct. This provides a path toward an 
    approximate EM-style algorithm for optimizing the likelihood using
    moment-closures as an additional likelihood penalty (regularizer) for
    slow dynamics.
    
    Parameters
    ----------
    m1 : first moment (mean of multivariate gaussian)
    m2 : covariance of multivariate gaussian
    y  : spiking observation (scalar)
    b  : projection from multivariate gaussian onto univariate log-rate distribution
    s  : stimulus or bias term for log-rate
    dt : time-step scaling of rate (can also be used as generic gain parameter)
    pm : regularizing prior mean log-rate
    pt : regularizing prior precision
    univariate_method : function, one of
        univariate_lgp_update_moment
        univariate_lgp_update_variational
        univariate_lgp_update_laplace
    '''
    
    if surrogate is None:
        return measurement_update_projected_gaussian(m1,m2,y,b,s,dt,pm,pt,
                                          univariate_method = univariate_method,
                                          eps=eps,
                                          safe=safe)
    (mr, vr) = surrogate
    # Validate arguments
    if safe:
        m2 = assertfinitereal(assertsquare(m2))
        m1 = assertfinitereal(assertcolumn(m1))
        b  = assertfinitereal(assertcolumn(b))
    # Precompute constants
    m2 = 0.5*(m2+m2.T)
    # Gaussian state prior on log-rate
    m2b = m2.dot(b)
    if safe:
        v   = scalar(b.T.dot(m2b))
        m   = scalar(b.T.dot(m1))
    v   = max(eps,(b.T.dot(m2b))[0,0])
    m   = (b.T.dot(m1))[0,0]
    t   = 1/v
    # Regularizing Gaussian prior on log-rate
    tq = pt + t
    mq = (m*t + pm*pt)/tq
    vq = 1/tq
    tr = 1/vr
    tp = tq + tr
    mp = (mr*tr + mq*tq)/tp
    vp = 1/tp
    if safe:
        mp = scalar(assertfinitereal(mp))
        vp = scalar(assertfinitereal(vp))
    # Kalman-style rank-one update of the joint moments (further optimized)
    K   = m2b/(vr+v)
    m2p = m2 - K.dot(m2b.T)
    m1p = m1 + K*(mr-m)
    # Also compute log-likelihood from univariate
    logr   = mp+s
    logPyx = y*logr-sexp(logr)
    ll     = logPyx + 0.5*slog(vp/v) - 0.5*(mp-m)**2/v 
    return m1p, m2p, scalar(ll)
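
The surrogate path above fuses Gaussians on the log-rate by adding precisions, as in the tq/mq and tp/mp lines. A minimal self-contained sketch of that product-of-Gaussians rule (illustrative values):

m_a, v_a = -1.0, 0.5      # prior on the log-rate
m_b, v_b = -0.8, 0.2      # surrogate Gaussian likelihood
t_a, t_b = 1/v_a, 1/v_b   # precisions add under the product
v_post = 1/(t_a + t_b)
m_post = (m_a*t_a + m_b*t_b)*v_post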