Example #1
def newton(fn, jac_fn, U):
    maxit = 5
    tol = 1e-8
    count = 0
    res = 100
    fail = 0

    Uold = U
    # First Newton step, timed separately.
    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = count + 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)
    while count < maxit and res > tol:
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count = count + 1
        end1 = timeit.default_timer()
        print("time per loop", end1 - start1)
        print(count, res)
    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")

    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")

    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
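
A jitted helper can fuse one Newton iteration into a single compiled step, which is what the per-loop timing above is trying to measure. A minimal sketch under the assumption that fn and jac_fn are JAX-traceable (newton_step is a hypothetical name, not part of the original code):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(0, 1))
def newton_step(fn, jac_fn, U, Uold):
    # One Newton iteration: linearize, solve the linear system, update.
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = jnp.linalg.solve(J, y)
    U_new = U - delta
    res = jnp.linalg.norm(y / jnp.linalg.norm(U_new, jnp.inf), jnp.inf)
    return U_new, res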
Example #2
def _make_associative_smoothing_params_generic(transition_function, Qk,
                                               filtered_state,
                                               linearization_state):
    # Prediction part
    sigma_points = get_sigma_points(linearization_state)

    propagated_points = transition_function(sigma_points.points)
    propagated_sigma_points = SigmaPoints(propagated_points, sigma_points.wm,
                                          sigma_points.wc)
    propagated_state = get_mv_normal_parameters(propagated_sigma_points)

    pred_cross_covariance = covariance_sigma_points(sigma_points,
                                                    linearization_state.mean,
                                                    propagated_sigma_points,
                                                    propagated_state.mean)

    F = jlinalg.solve(linearization_state.cov,
                      pred_cross_covariance,
                      sym_pos=True).T  # Linearized transition function

    Pp = Qk + propagated_state.cov + F @ (filtered_state.cov -
                                          linearization_state.cov) @ F.T

    E = jlinalg.solve(Pp, F @ filtered_state.cov, sym_pos=True).T

    g = filtered_state.mean - E @ (propagated_state.mean + F @ (
        filtered_state.mean - linearization_state.mean))
    L = filtered_state.cov - E @ F @ filtered_state.cov

    return g, E, 0.5 * (L + L.T)
Example #3
def filtering_operator(elem1, elem2):
    """
    Associative operator described in TODO: put the reference

    Parameters
    ----------
    elem1: tuple of array
        a_i, b_i, C_i, eta_i, J_i
    elem2: tuple of array
        a_j, b_j, C_j, eta_j, J_j

    Returns
    -------
    tuple of array
        the combined element a, b, C, eta, J
    """
    A1, b1, C1, eta1, J1 = elem1
    A2, b2, C2, eta2, J2 = elem2
    dim = b1.shape[0]

    I_dim = jnp.eye(dim)

    IpCJ = I_dim + jnp.dot(C1, J2)
    IpJC = I_dim + jnp.dot(J2, C1)

    AIpCJ_inv = jlinalg.solve(IpCJ.T, A2.T, sym_pos=False).T
    AIpJC_inv = jlinalg.solve(IpJC.T, A1, sym_pos=False).T

    A = jnp.dot(AIpCJ_inv, A1)
    b = jnp.dot(AIpCJ_inv, b1 + jnp.dot(C1, eta2)) + b2
    C = jnp.dot(AIpCJ_inv, jnp.dot(C1, A2.T)) + C2
    eta = jnp.dot(AIpJC_inv, eta2 - jnp.dot(J2, b1)) + eta1
    J = jnp.dot(AIpJC_inv, jnp.dot(J2, A1)) + J1
    return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T)
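
These five-tuples are built to be combined in parallel with an associative scan. A minimal usage sketch, assuming elems is a tuple of stacked (A, b, C, eta, J) arrays with a leading time axis (the names here are illustrative):

import jax

# The operator combines two single elements, so vmap it over the
# leading time axis before handing it to the parallel scan.
filtered_elems = jax.lax.associative_scan(jax.vmap(filtering_operator), elems)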
Example #4
def update(observation_function: Callable, observation_covariance: jnp.ndarray,
           predicted_state: MVNormalParameters, observation: jnp.ndarray,
           linearization_state: MVNormalParameters
           ) -> Tuple[float, MVNormalParameters]:
    r""" Computes the cubature Kalman filter linearization of :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        observation_error :math:`\Sigma` fed to observation_function
    predicted_state: MVNormalParameters
        predicted approximate mv normal parameters of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_state: MVNormalParameters
        state for the linearization of the update

    Returns
    -------
    loglikelihood: float
        Log-likelihood increment for the observation
    updated_mvn_parameters: MVNormalParameters
        filtered state
    """
    if linearization_state is None:
        linearization_state = predicted_state
    sigma_points = get_sigma_points(linearization_state)
    obs_points = observation_function(sigma_points.points)
    obs_sigma_points = SigmaPoints(obs_points, sigma_points.wm,
                                   sigma_points.wc)

    obs_state = get_mv_normal_parameters(obs_sigma_points)
    cross_covariance = covariance_sigma_points(sigma_points,
                                               linearization_state.mean,
                                               obs_sigma_points,
                                               obs_state.mean)

    H = jlinalg.solve(linearization_state.cov, cross_covariance,
                      sym_pos=True).T  # linearized observation function

    d = obs_state.mean - jnp.dot(
        H, linearization_state.mean)  # linearized observation offset

    residual_cov = H @ (predicted_state.cov - linearization_state.cov) @ H.T + \
                   observation_covariance + obs_state.cov

    gain = jlinalg.solve(residual_cov, H @ predicted_state.cov).T

    predicted_observation = H @ predicted_state.mean + d

    residual = observation - predicted_observation
    mean = predicted_state.mean + gain @ residual
    cov = predicted_state.cov - gain @ residual_cov @ gain.T
    loglikelihood = multivariate_normal.logpdf(residual,
                                               jnp.zeros_like(residual),
                                               residual_cov)

    return loglikelihood, MVNormalParameters(mean, 0.5 * (cov + cov.T))
Example #5
def _make_associative_filtering_params_generic(observation_function, Rk,
                                               transition_function, Qk_1,
                                               prev_linearization_state,
                                               linearization_state, yk):
    # Prediction part
    sigma_points = get_sigma_points(prev_linearization_state)

    propagated_points = transition_function(sigma_points.points)
    propagated_sigma_points = SigmaPoints(propagated_points, sigma_points.wm,
                                          sigma_points.wc)
    propagated_state = get_mv_normal_parameters(propagated_sigma_points)

    pred_cross_covariance = covariance_sigma_points(
        sigma_points, prev_linearization_state.mean, propagated_sigma_points,
        propagated_state.mean)

    F = jlinalg.solve(prev_linearization_state.cov,
                      pred_cross_covariance,
                      sym_pos=True).T  # Linearized transition function
    pred_mean_residual = propagated_state.mean - F @ prev_linearization_state.mean
    pred_cov_residual = propagated_state.cov - F @ prev_linearization_state.cov @ F.T + Qk_1

    # Update part
    linearization_points = get_sigma_points(linearization_state)
    obs_points = observation_function(linearization_points.points)
    obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm,
                                   linearization_points.wc)
    obs_mvn = get_mv_normal_parameters(obs_sigma_points)
    update_cross_covariance = covariance_sigma_points(linearization_points,
                                                      linearization_state.mean,
                                                      obs_sigma_points,
                                                      obs_mvn.mean)

    H = jlinalg.solve(linearization_state.cov,
                      update_cross_covariance,
                      sym_pos=True).T
    obs_mean_residual = obs_mvn.mean - jnp.dot(H, linearization_state.mean)
    obs_cov_residual = obs_mvn.cov - H @ linearization_state.cov @ H.T

    S = H @ pred_cov_residual @ H.T + Rk + obs_cov_residual  # total residual covariance
    total_obs_residual = (yk - H @ pred_mean_residual - obs_mean_residual)
    S_invH = jlinalg.solve(S, H, sym_pos=True)

    K = (S_invH @ pred_cov_residual).T
    A = F - K @ H @ F
    b = pred_mean_residual + K @ total_obs_residual
    C = pred_cov_residual - K @ S @ K.T

    temp = (S_invH @ F).T
    HF = H @ F

    eta = temp @ total_obs_residual
    J = temp @ HF
    return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T)
Example #6
def damped_newton(fn, jac_fn, U):
    maxit = 10
    tol = 1e-8
    count = 0
    res = 100
    fail = 0
    Uold = U

    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = solve(J, y)
    U = U - delta
    res0 = norm(y / norm(U, np.inf), np.inf)
    print(count, res0)
    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        print(count, res)
        delta = solve(J, y)

        # Damping: backtrack on the step length until the residual meets
        # a sufficient-decrease condition, halving alpha each time.
        alpha = 1.0
        while norm(fn(U - alpha * delta, Uold)) > (1 - 0.5 * alpha) * norm(y):
            alpha = alpha / 2
            if alpha < 1e-8:
                break

        U = U - alpha * delta
        count = count + 1

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")

    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")

    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #7
def _make_associative_filtering_params_first(observation_function, R,
                                             transition_function, Q,
                                             initial_state,
                                             linearization_state, y):
    # Prediction part
    initial_sigma_points = get_sigma_points(initial_state)
    propagated_points = transition_function(initial_sigma_points.points)
    propagated_sigma_points = SigmaPoints(propagated_points,
                                          initial_sigma_points.wm,
                                          initial_sigma_points.wc)
    propagated_state = get_mv_normal_parameters(propagated_sigma_points)

    pred_cross_covariance = covariance_sigma_points(initial_sigma_points,
                                                    initial_state.mean,
                                                    propagated_sigma_points,
                                                    propagated_state.mean)

    F = jlinalg.solve(initial_state.cov, pred_cross_covariance,
                      sym_pos=True).T  # Linearized transition function

    m1 = propagated_state.mean
    P1 = propagated_state.cov + Q

    # Update part
    linearization_points = get_sigma_points(linearization_state)
    obs_points = observation_function(linearization_points.points)
    obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm,
                                   linearization_points.wc)
    obs_mvn = get_mv_normal_parameters(obs_sigma_points)
    update_cross_covariance = covariance_sigma_points(linearization_points,
                                                      linearization_state.mean,
                                                      obs_sigma_points,
                                                      obs_mvn.mean)

    H = jlinalg.solve(linearization_state.cov,
                      update_cross_covariance,
                      sym_pos=True).T
    d = obs_mvn.mean - jnp.dot(H, linearization_state.mean)
    predicted_observation = H @ m1 + d

    S = H @ (P1 - linearization_state.cov) @ H.T + R + obs_mvn.cov
    K = jlinalg.solve(S, H @ P1, sym_pos=True).T
    A = jnp.zeros(F.shape)
    b = m1 + K @ (y - predicted_observation)
    C = P1 - K @ S @ K.T

    eta = jnp.zeros(F.shape[0])
    J = jnp.zeros(F.shape)

    return A, b, 0.5 * (C + C.T), eta, J
Example #8
def update(state):
    data, p_, e_, C_, mu, iters, _ = state
    x, y = data
    mu = np.float32(mu)
    #
    J = jacobian(p_, x, y)
    H = damped_hessian(J, mu)
    Je = jac_err_prod(J, e_, p_)
    #
    dp = solve(H, Je, sym_pos=True)
    p = p_ - dp
    e = error(p, x, y)
    C = cost(e, p)
    rho = (C_ - C) / (dp.T @ (mu * dp + Je))
    #
    mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
    #
    bad_step = (rho < rho_min) | np.any(np.isnan(p))
    mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
    p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
    e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
    C = np.where(bad_step, C_, C)
    improved = (C_ > C) | bad_step
    #
    return LevenbergMarquardtState(data, p, e, C, mu, iters + ~bad_step,
                                   improved)
Example #9
def update(state):
    data, p_, e_, C_, mu, alpha, iters, _ = state
    x, y = data
    mu = np.float32(mu)
    alpha_ = np.float32(alpha)
    #
    J = jacobian(p_, x, y)
    H = J.T @ J
    Je = J.T @ e_ + alpha_ * p_
    I = np.diag_indices_from(H)
    #
    dp = solve(H.at[I].add(alpha_ + mu), Je, sym_pos=True)
    p = p_ - dp
    e = error(p, x, y)
    C = (sum_squares(e) + alpha * sum_squares(p)) / 2
    rho = (C_ - C) / (dp.T @ (mu * dp + Je))
    #
    mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
    #
    bad_step = (rho < rho_min) | np.any(np.isnan(p))
    mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
    p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
    e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
    #
    sse = sum_squares(e)
    ssp = sum_squares(p)
    C = np.where(bad_step, C_, C)
    improved = (C_ > C) | bad_step
    #
    bundle = (alpha, H, I, sse, ssp, x.size)
    alpha, *_ = cond(bad_step, lambda t: t, update_hyperparams, bundle)
    C = (sse + alpha * ssp) / 2
    #
    return LevenbergMarquardtBRState(data, p, e, C, mu, alpha,
                                     iters + ~bad_step, improved)
Example #10
def _make_associative_filtering_params(args):
    Hk, Rk, Fk_1, Qk_1, uk_1, yk, dk, I_dim = args

    # FIRST TERM
    ############

    # temp variable
    HQ = jnp.dot(Hk, Qk_1)  # Hk @ Qk_1

    Sk = jnp.dot(HQ, Hk.T) + Rk
    Kk = jlinalg.solve(
        Sk, HQ, sym_pos=True).T  # using the fact that S and Q are symmetric

    # temp variable:
    I_KH = I_dim - jnp.dot(Kk, Hk)  # I - Kk @ Hk

    Ck = jnp.dot(I_KH, Qk_1)

    residual = (yk - jnp.dot(Hk, uk_1) - dk)

    bk = uk_1 + jnp.dot(Kk, residual)
    Ak = jnp.dot(I_KH, Fk_1)

    # SECOND TERM
    #############
    HF = jnp.dot(Hk, Fk_1)
    FHS_inv = jlinalg.solve(Sk, HF).T

    etak = jnp.dot(FHS_inv, residual)
    Jk = jnp.dot(FHS_inv, HF)

    return Ak, bk, Ck, etak, Jk
Example #11
def _make_associative_filtering_params_first(observation_function,
                                             jac_observation_function, R,
                                             transition_function,
                                             jac_transition_function, Q, m0,
                                             P0, x_k, y):
    F = jac_transition_function(m0)

    m1 = transition_function(m0)
    P1 = F @ P0 @ F.T + Q

    H = jac_observation_function(x_k)

    S = H @ P1 @ H.T + R
    K = jlinalg.solve(S, H @ P1, sym_pos=True).T
    A = jnp.zeros(F.shape)

    alpha = observation_function(x_k) + H @ (m1 - x_k)

    b = m1 + K @ (y - alpha)
    C = P1 - (K @ S @ K.T)

    eta = jnp.zeros(F.shape[0])
    J = jnp.zeros(F.shape)

    return A, b, C, eta, J
Example #12
def _make_associative_filtering_params_first(
        observation_function, jac_observation_function, R, transition_function,
        jac_transition_function, Q, m0, P0, x_k_1, x_k, y, propagate_first):
    if propagate_first:
        F = jac_transition_function(x_k_1)
        m = F @ (m0 - x_k_1) + transition_function(x_k_1)
        P = F @ P0 @ F.T + Q
        H = jac_observation_function(x_k)
        alpha = observation_function(x_k) + H @ (m - x_k)
    else:
        P = P0
        m = m0
        H = jac_observation_function(x_k_1)
        alpha = observation_function(x_k_1) + H @ (m0 - x_k_1)

    S = H @ P @ H.T + R
    K = jlinalg.solve(S, H @ P, sym_pos=True).T
    A = jnp.zeros_like(P0)

    b = m + K @ (y - alpha)
    C = P - (K @ S @ K.T)

    eta = jnp.zeros_like(m0)
    J = jnp.zeros_like(P0)

    return A, b, C, eta, J
Example #13
def _make_associative_filtering_params_generic(observation_function,
                                               jac_observation_function, Rk,
                                               transition_function,
                                               jac_transition_function, x_k_1,
                                               x_k, Qk_1, yk):
    F = jac_transition_function(x_k_1)
    H = jac_observation_function(x_k)

    F_x_k_1 = F @ x_k_1
    x_k_hat = transition_function(x_k_1)

    alpha = observation_function(x_k) + H @ (x_k_hat - F_x_k_1 - x_k)
    residual = yk - alpha
    HQ = H @ Qk_1

    S = HQ @ H.T + Rk
    S_invH = jlinalg.solve(S, H, sym_pos=True)
    K = (S_invH @ Qk_1).T
    A = F - K @ H @ F
    b = K @ residual + x_k_hat - F_x_k_1
    C = Qk_1 - K @ H @ Qk_1

    HF = H @ F

    temp = (S_invH @ F).T
    eta = temp @ residual
    J = temp @ HF

    return A, b, C, eta, J
Example #14
def _bl_update(H, C, R, state):
    G, (α, _), μ, τ = state
    tr_inv_H = np.trace(solve(H, I, sym_pos="sym"))
    γ = n - α * tr_inv_H
    α = np.float32(n / (2 * R + tr_inv_H))
    β = np.float32((x.shape[0] - γ) / (2 * C))
    return G, (α, β), μ, τ
Example #15
def lax_newton(fn, jac_fn, U, maxit, tol):

    Uold = U
    state = NewtonInfo(count=0, converged=0, fail=0, U=U)

    def body(state):
        J = jac_fn(state.U, Uold)
        y = fn(state.U, Uold)
        delta = solve(J, y)
        U = state.U - delta
        res = norm(y / norm(U, np.inf), np.inf)
        # NamedTuple._replace returns a new instance; it must be reassigned.
        state = state._replace(count=state.count + 1,
                               converged=res < tol,
                               fail=np.any(np.isnan(delta)),
                               U=U)
        return state

    J = jac_fn(state.U, Uold)
    y = fn(state.U, Uold)
    delta = solve(J, y)
    U = state.U - delta
    state = state._replace(U=U)
    state = lax.while_loop(
        lambda state: np.logical_and(
            np.logical_and(~state.converged, ~state.fail), state.count < maxit
        ), body, state)

    return state
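
NewtonInfo is not defined in the snippet; a minimal compatible carrier, assuming a NamedTuple with exactly the fields used above (and that np here is jax.numpy, so lax.while_loop can trace the state):

from typing import NamedTuple
import jax.numpy as np

class NewtonInfo(NamedTuple):
    count: int       # iterations performed so far
    converged: bool  # residual fell below tol
    fail: bool       # a NaN appeared in the Newton step
    U: np.ndarray    # current iterate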
Example #16
def _lm_update(θ, H, Je, y, Λ, state):
    α, β = Λ
    p = θ - solve(H + state.μ * I, Je, sym_pos="sym").T
    e = errors(p, x, y)
    C = obj.cost(e)
    R = obj.regularizer(θ)
    G = np.float32(β * C + α * R)
    return LMState(p, e, G, C, R, state.μ * μs)
Example #17
def newton(fn, jac_fn, U):
    maxit = 20
    tol = 1e-8
    count = 0
    res = 100
    fail = 0
    Uold = U

    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = count + 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)
    while (count < maxit and res > tol):
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count = count + 1
        end1 = timeit.default_timer()
        print(count, res)
        print("time per loop", end1 - start1)

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")

    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")

    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #18
def newton_while_lax(fn, jac_fn, U, maxit, tol):
    count = 0
    res = 100
    fail = 0

    Uold = U
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = solve(J, y)
    U = U - delta

    # Build the loop carry only after the first step, so the while_loop
    # continues from the updated U rather than the initial guess.
    val = (U, count, res, fail)

    def cond_fun(val):
        _, count, res, _ = val
        # res is carried in val; recomputing it here from the stale
        # closed-over y would test the wrong residual.
        return np.logical_and(res > tol, count < maxit)

    def body_fun(val):
        U, count, res, fail = val
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        delta = solve(J, y)
        U = U - delta
        res = norm(y / norm(U, np.inf), np.inf)
        count = count + 1
        return U, count, res, fail

    val = lax.while_loop(cond_fun, body_fun, val)
    U, count, res, _ = val

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")

    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")

    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #19
    def predict(self):
        """Computing the prediction mean and standard deviation

        Returns
        -------
        means : ndarray tuple
            tuple containing the mean components
        stds : ndarray tuple
            tuple containing the standard deviations
        """
        lambdam = self.getlambda()
        mean = self.Phi_pred_T @ jscl.solve(
            self.PhiTPhi + np.diag(self.sigma_n / lambdam),
            self.Phi.T @ self.y)
        std = np.sqrt(self.sigma_n * np.sum(
            self.Phi_pred_T *
            jscl.solve(self.PhiTPhi + np.diag(self.sigma_n / lambdam),
                       self.Phi_pred_T.T).T, 1))
        return (mean[::3], mean[1::3], mean[2::3]), (std[::3], std[1::3],
                                                     std[2::3])
Example #20
def _make_associative_smoothing_params_generic(transition_function,
                                               jac_transition_function, Qk, mk,
                                               Pk, xk):
    F = jac_transition_function(xk)
    Pp = F @ Pk @ F.T + Qk

    E = jlinalg.solve(Pp, F @ Pk, sym_pos=True).T

    g = mk - E @ (transition_function(xk) + F @ (mk - xk))
    L = Pk - E @ Pp @ E.T

    return g, E, L
Example #21
def predict(transition_function: Callable,
            transition_covariance: jnp.ndarray,
            previous_state: MVNormalParameters,
            linearization_state: MVNormalParameters,
            return_linearized_transition: bool = False) -> MVNormalParameters:
    """ Computes the cubature Kalman filter linearization of :math:`x_{t+1} = f(x_t, \mathcal{N}(0, \Sigma))`

    Parameters
    ----------
    transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}`
        transition function of the state space model
    transition_covariance: (D,D) array
        covariance :math:`\Sigma` of the noise fed to transition_function
    previous_state: MVNormalParameters
        previous state for the filter x
    linearization_state: MVNormalParameters
        state for the linearization of the prediction
    return_linearized_transition: bool, optional
        Returns the linearized transition matrix A

    Returns
    -------
    mvn_parameters: MVNormalParameters
        Propagated approximate Normal distribution

    F: array_like
        returned if return_linearized_transition is True
    """
    if linearization_state is None:
        linearization_state = previous_state

    sigma_points = get_sigma_points(linearization_state)
    propagated_points = transition_function(sigma_points.points)
    propagated_sigma_points = SigmaPoints(propagated_points, sigma_points.wm,
                                          sigma_points.wc)

    propagated_state = get_mv_normal_parameters(propagated_sigma_points)
    cross_covariance = covariance_sigma_points(sigma_points,
                                               linearization_state.mean,
                                               propagated_sigma_points,
                                               propagated_state.mean)

    F = jlinalg.solve(linearization_state.cov, cross_covariance,
                      sym_pos=True).T  # Linearized transition function
    b = propagated_state.mean - jnp.dot(
        F, linearization_state.mean)  # Linearized offset

    mean = F @ previous_state.mean + b
    cov = transition_covariance + propagated_state.cov + F @ (
        previous_state.cov - linearization_state.cov) @ F.T
    cov = 0.5 * (cov + cov.T)
    if return_linearized_transition:
        return MVNormalParameters(mean, cov), F
    return MVNormalParameters(mean, cov)
Example #22
def body_fun(val):
    U, count, res, fail = val
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = solve(J, y)
    U = U - delta
    res = norm(y / norm(U, np.inf), np.inf)
    count = count + 1
    print(count, res)
    return U, count, res, fail
Example #23
def update(
        observation_function: Callable[[jnp.ndarray], jnp.ndarray],
        observation_covariance: jnp.ndarray, predicted: MVNormalParameters,
        observation: jnp.ndarray,
        linearization_point: jnp.ndarray) -> Tuple[float, MVNormalParameters]:
    """ Computes the extended kalman filter linearization of :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        observation_error :math:`\Sigma` fed to observation_function
    predicted: MVNormalParameters
        predicted state of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_point: jnp.ndarray
        Where to compute the Jacobian

    Returns
    -------
    loglikelihood: float
        Log-likelihood increment for observation
    updated_state: MVNormalParameters
        filtered state
    """
    if linearization_point is None:
        linearization_point = predicted.mean
    jac_x = jacfwd(observation_function, 0)(linearization_point)

    obs_mean = observation_function(linearization_point) + jnp.dot(
        jac_x, predicted.mean - linearization_point)

    residual = observation - obs_mean
    residual_covariance = jnp.dot(jac_x, jnp.dot(predicted.cov, jac_x.T))
    residual_covariance = residual_covariance + observation_covariance

    gain = jnp.dot(predicted.cov,
                   jlag.solve(residual_covariance, jac_x, sym_pos=True).T)

    mean = predicted.mean + jnp.dot(gain, residual)
    cov = predicted.cov - jnp.dot(gain, jnp.dot(residual_covariance, gain.T))
    updated_state = MVNormalParameters(mean, 0.5 * (cov + cov.T))

    loglikelihood = multivariate_normal.logpdf(residual,
                                               jnp.zeros_like(residual),
                                               residual_covariance)
    return loglikelihood, updated_state
Example #24
    def body(state):
        J = jac_fn(state.U, Uold)
        y = fn(state.U, Uold)
        delta = solve(J, y)
        U = state.U - delta
        res = norm(y / norm(U, np.inf), np.inf)
        # NamedTuple._replace returns a new instance; it must be reassigned.
        state = state._replace(count=state.count + 1,
                               converged=res < tol,
                               fail=np.any(np.isnan(delta)),
                               U=U)
        return state
Example #25
def smooth(transition_function: Callable[[jnp.ndarray], jnp.ndarray],
           transition_covariance: jnp.array,
           filtered_state: MVNormalParameters,
           previous_smoothed: MVNormalParameters,
           linearization_point: jnp.ndarray) -> MVNormalParameters:
    """
    One step extended kalman smoother

    Parameters
    ----------
    transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}`
        transition function of the state space model
    transition_covariance: (D,D) array
        covariance :math:`\Sigma` of the noise fed to transition_function
    filtered_state: MVNormalParameters
        mean and cov computed by Kalman Filtering
    previous_smoothed: MVNormalParameters,
        smoothed state of the previous step
    linearization_point: jnp.ndarray
        Where to compute the Jacobian

    Returns
    -------
    smoothed_state: MVNormalParameters
        smoothed state
    """

    jac_x = jacfwd(transition_function, 0)(linearization_point)

    mean = transition_function(linearization_point) + jnp.dot(
        jac_x, filtered_state.mean - linearization_point)
    mean_diff = previous_smoothed.mean - mean

    cov = jnp.dot(jac_x, jnp.dot(filtered_state.cov,
                                 jac_x.T)) + transition_covariance
    cov_diff = previous_smoothed.cov - cov

    gain = jnp.dot(filtered_state.cov, jlag.solve(cov, jac_x, sym_pos=True).T)

    mean = filtered_state.mean + jnp.dot(gain, mean_diff)
    cov = filtered_state.cov + jnp.dot(gain, jnp.dot(cov_diff, gain.T))
    return MVNormalParameters(mean, cov)
Example #26
def stream_vel(bb):
    n = grid_size
    h, beta_fric, dx = stream_vel_init(n, rhoi, g)
    beta_fric = bb + beta_fric
    f, fend = stream_vel_taud(h, n, dx, rhoi, g)
    u = jnp.zeros(n + 1)
    # driving stress
    f_plus1 = jnp.roll(f, -1)
    b = jnp.append(-dx * f[0:n - 1] - f_plus1[0:n - 1] * dx,
                   -dx * f[n - 1] - f_plus1[n - 1] * dx + fend)

    for i in range(n_nl):
        #update viscosities
        nu = stream_vel_visc(h, u, n, dx)
        # assemble tridiag matrix. This represents the discretization of
        #  (nu^(i-1) u^(i)_x)_x - \beta^2 u^(i) = f
        A = stream_assemble(nu, beta_fric, n, dx)

        # solve linear system for new u
        # effectively apply boundary condition u(0)==0
        u = jnp.append(jnp.zeros(1), la.solve(A, b))
    return u
Example #27
def _T_bar(F: np.ndarray, N_T_inv: np.ndarray,
           N_inv_d: np.ndarray) -> np.ndarray:
    """Function to calculate the expected component amplitudes, `T_bar`.
    This is an implementation of Equation (A4) in 1608.00551. See also
    Equation (A10) for interpretation.

    Parameters
    ----------
    F: ndarray
        SED matrix
    N_T_inv: ndarray
        Inverse component covariance.
    N_inv_d: ndarray
        Inverse covariance-weighted data.

    Returns
    -------
    ndarray
        T_bar, the expected component amplitude.
    """
    y = np.sum(F[None, :, :] * N_inv_d[:, None, :], axis=2)
    return linalg.solve(N_T_inv, y)
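
The sum in the first line of _T_bar is a batched inner product over the frequency axis; an equivalent spelling, assuming F has shape (n_comp, n_freq) and N_inv_d has shape (n_pix, n_freq):

import numpy as np

# y[n, c] = sum_f F[c, f] * N_inv_d[n, f], i.e. y = N_inv_d @ F.T
y = np.einsum('cf,nf->nc', F, N_inv_d)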
Example #28
def newton_tol(fn, jac_fn, U, tol):
    maxit = 20
    count = 0
    res = 100
    fail = 0
    Uold = U

    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = max(abs(y / norm(y, 2)))
        print(count, res)
        delta = solve(J, y)
        U = U - delta
        count = count + 1

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")

    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")

    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #29
def _make_associative_filtering_params_first(
        observation_function, R, transition_function, Q, initial_state,
        prev_linearization_state, linearization_state, y, propagate_first):
    # Prediction part

    if propagate_first:
        initial_sigma_points = get_sigma_points(prev_linearization_state)
        propagated_points = transition_function(initial_sigma_points.points)
        propagated_sigma_points = SigmaPoints(propagated_points,
                                              initial_sigma_points.wm,
                                              initial_sigma_points.wc)
        propagated_state = get_mv_normal_parameters(propagated_sigma_points)

        pred_cross_covariance = covariance_sigma_points(
            initial_sigma_points, prev_linearization_state.mean,
            propagated_sigma_points, propagated_state.mean)

        F = jlinalg.solve(prev_linearization_state.cov,
                          pred_cross_covariance,
                          sym_pos=True).T  # Linearized transition function

        m = propagated_state.mean + F @ (initial_state.mean -
                                         prev_linearization_state.mean)
        P = propagated_state.cov + Q + F @ (initial_state.cov -
                                            prev_linearization_state.cov) @ F.T
        linearization_points = get_sigma_points(linearization_state)
        obs_points = observation_function(linearization_points.points)
        obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm,
                                       linearization_points.wc)
        obs_mvn = get_mv_normal_parameters(obs_sigma_points)
        update_cross_covariance = covariance_sigma_points(
            linearization_points, linearization_state.mean, obs_sigma_points,
            obs_mvn.mean)

        H = jlinalg.solve(linearization_state.cov,
                          update_cross_covariance,
                          sym_pos=True).T
        d = obs_mvn.mean - jnp.dot(H, linearization_state.mean)
        predicted_observation = H @ m + d

        S = H @ (P - linearization_state.cov) @ H.T + R + obs_mvn.cov
    else:
        m = initial_state.mean
        P = initial_state.cov
        linearization_points = get_sigma_points(prev_linearization_state)
        obs_points = observation_function(linearization_points.points)
        obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm,
                                       linearization_points.wc)
        obs_mvn = get_mv_normal_parameters(obs_sigma_points)
        update_cross_covariance = covariance_sigma_points(
            linearization_points, linearization_state.mean, obs_sigma_points,
            obs_mvn.mean)

        H = jlinalg.solve(prev_linearization_state.cov,
                          update_cross_covariance,
                          sym_pos=True).T
        d = obs_mvn.mean - jnp.dot(H, prev_linearization_state.mean)
        predicted_observation = H @ m + d

        S = H @ (P - prev_linearization_state.cov) @ H.T + R + obs_mvn.cov

    K = jlinalg.solve(S, H @ P, sym_pos=True).T
    A = jnp.zeros_like(initial_state.cov)
    b = m + K @ (y - predicted_observation)
    C = P - K @ S @ K.T

    eta = jnp.zeros_like(initial_state.mean)
    J = jnp.zeros_like(initial_state.cov)

    return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T)
Example #30
def psd_inv_cholesky(matrix: jnp.ndarray, damping: jnp.ndarray) -> jnp.ndarray:
    assert matrix.ndim == 2
    identity = jnp.eye(matrix.shape[0])
    matrix = matrix + damping * identity
    return linalg.solve(matrix, identity, sym_pos=True)
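
A minimal usage sketch, assuming linalg here is jax.scipy.linalg (older releases accepted sym_pos=True to request a Cholesky-based solve; newer ones spell this assume_a='pos'):

import jax.numpy as jnp

# Damped inverse of a PSD matrix: (A + damping * I)^{-1}.
A = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])
A_inv = psd_inv_cholesky(A, jnp.float32(1e-6))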