def newton(fn, jac_fn, U):
    """Solve fn(U, Uold) = 0 for U with a plain (undamped) Newton iteration.

    The second argument of ``fn``/``jac_fn`` is frozen at the initial
    iterate throughout (as in the original implementation).

    Parameters
    ----------
    fn : callable
        Residual function ``fn(U, Uold)``.
    jac_fn : callable
        Jacobian of ``fn`` with respect to its first argument.
    U : array
        Initial guess.

    Returns
    -------
    U : array
        Final iterate.
    fail : int
        0 on convergence, 1 on NaN/complex step or non-convergence.
    """
    # NOTE(review): the original assigned maxit twice (20, then 5); the
    # effective value was 5 and is preserved here.
    maxit = 5
    tol = 1e-8
    count = 0
    res = 100
    fail = 0
    Uold = U

    # First Newton step, timed separately (exposes one-off costs such as
    # JIT compilation of jac_fn/fn).
    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    print("computed jacobian")
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = count + 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)

    while count < maxit and res > tol:
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count = count + 1
        end1 = timeit.default_timer()
        print("time per loop", end1 - start1)
        print(count, res)

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")
    # BUG FIX: the original ended with `else: fail == 0`, a no-op comparison
    # used as a statement; fail is already 0 on the success path.
    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')
    return U, fail
def _make_associative_smoothing_params_generic(transition_function, Qk, filtered_state, linearization_state):
    """Build the associative smoothing triplet (g, E, L) by statistically
    linearizing the transition around ``linearization_state``."""
    # Propagate the sigma points of the linearization state through the
    # transition function.
    sp = get_sigma_points(linearization_state)
    prop_sp = SigmaPoints(transition_function(sp.points), sp.wm, sp.wc)
    prop_state = get_mv_normal_parameters(prop_sp)

    # Cross-covariance between the linearization points and their images.
    cross_cov = covariance_sigma_points(sp, linearization_state.mean,
                                        prop_sp, prop_state.mean)

    # Statistically linearized transition matrix F = (P^{-1} C)^T.
    F = jlinalg.solve(linearization_state.cov, cross_cov, sym_pos=True).T

    # Predicted covariance, corrected for the filtered-state deviation.
    Pp = Qk + prop_state.cov + F @ (filtered_state.cov - linearization_state.cov) @ F.T

    # Smoother gain and the affine smoothing parameters.
    E = jlinalg.solve(Pp, F @ filtered_state.cov, sym_pos=True).T
    g = filtered_state.mean - E @ (prop_state.mean + F @ (filtered_state.mean - linearization_state.mean))
    L = filtered_state.cov - E @ F @ filtered_state.cov
    return g, E, 0.5 * (L + L.T)
def filtering_operator(elem1, elem2):
    """
    Associative operator described in TODO: put the reference

    Parameters
    ----------
    elem1: tuple of array
        a_i, b_i, C_i, eta_i, J_i
    elem2: tuple of array
        a_j, b_j, C_j, eta_j, J_j

    Returns
    -------
    tuple of array
        Combined element, with C and J returned symmetrized.
    """
    A1, b1, C1, eta1, J1 = elem1
    A2, b2, C2, eta2, J2 = elem2

    dim = b1.shape[0]
    eye = jnp.eye(dim)

    # The factors (I + C1 J2) and (I + J2 C1) are inverted against A2 / A1
    # via transposed solves.
    left = eye + jnp.dot(C1, J2)
    right = eye + jnp.dot(J2, C1)
    A2_solved = jlinalg.solve(left.T, A2.T, sym_pos=False).T
    A1_solved = jlinalg.solve(right.T, A1, sym_pos=False).T

    A = jnp.dot(A2_solved, A1)
    b = jnp.dot(A2_solved, b1 + jnp.dot(C1, eta2)) + b2
    C = jnp.dot(A2_solved, jnp.dot(C1, A2.T)) + C2
    eta = jnp.dot(A1_solved, eta2 - jnp.dot(J2, b1)) + eta1
    J = jnp.dot(A1_solved, jnp.dot(J2, A1)) + J1
    return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T)
def update(observation_function: Callable,
           observation_covariance: jnp.ndarray,
           predicted_state: MVNormalParameters,
           observation: jnp.ndarray,
           linearization_state: MVNormalParameters) -> "Tuple[float, MVNormalParameters]":
    """ Computes the sigma-point (statistically linearized) Kalman filter
    update of :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable
        :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        observation_error :math:`\Sigma` fed to observation_function
    predicted_state: MVNormalParameters
        predicted approximate mv normal parameters of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_state: MVNormalParameters
        state for the linearization of the update; when None the
        predicted state is used instead

    Returns
    -------
    loglikelihood: float
        Gaussian log-likelihood of the observation residual
    updated_mvn_parameters: MVNormalParameters
        filtered state
    """
    if linearization_state is None:
        linearization_state = predicted_state
    # Statistical linearization of the observation model via sigma points.
    sigma_points = get_sigma_points(linearization_state)
    obs_points = observation_function(sigma_points.points)
    obs_sigma_points = SigmaPoints(obs_points, sigma_points.wm, sigma_points.wc)
    obs_state = get_mv_normal_parameters(obs_sigma_points)
    cross_covariance = covariance_sigma_points(sigma_points, linearization_state.mean,
                                               obs_sigma_points, obs_state.mean)
    H = jlinalg.solve(linearization_state.cov, cross_covariance,
                      sym_pos=True).T  # linearized observation function
    d = obs_state.mean - jnp.dot(H, linearization_state.mean)  # linearized observation offset
    residual_cov = H @ (predicted_state.cov - linearization_state.cov) @ H.T + \
                   observation_covariance + obs_state.cov
    gain = jlinalg.solve(residual_cov, H @ predicted_state.cov).T
    predicted_observation = H @ predicted_state.mean + d
    residual = observation - predicted_observation
    mean = predicted_state.mean + gain @ residual
    cov = predicted_state.cov - gain @ residual_cov @ gain.T
    loglikelihood = multivariate_normal.logpdf(residual, jnp.zeros_like(residual),
                                               residual_cov)
    # Symmetrize the covariance to guard against numerical asymmetry.
    return loglikelihood, MVNormalParameters(mean, 0.5 * (cov + cov.T))
def _make_associative_filtering_params_generic(observation_function, Rk, transition_function, Qk_1,
                                               prev_linearization_state, linearization_state, yk):
    # Generic (non-initial) associative filtering element (A, b, C, eta, J)
    # for the parallel sigma-point Kalman filter: the transition is
    # statistically linearized around prev_linearization_state and the
    # observation around linearization_state.
    # Prediction part
    sigma_points = get_sigma_points(prev_linearization_state)
    propagated_points = transition_function(sigma_points.points)
    propagated_sigma_points = SigmaPoints(propagated_points, sigma_points.wm, sigma_points.wc)
    propagated_state = get_mv_normal_parameters(propagated_sigma_points)
    pred_cross_covariance = covariance_sigma_points(
        sigma_points, prev_linearization_state.mean,
        propagated_sigma_points, propagated_state.mean)
    F = jlinalg.solve(prev_linearization_state.cov, pred_cross_covariance,
                      sym_pos=True).T  # Linearized transition function
    # Affine residuals of the linearized transition.
    pred_mean_residual = propagated_state.mean - F @ prev_linearization_state.mean
    pred_cov_residual = propagated_state.cov - F @ prev_linearization_state.cov @ F.T + Qk_1
    # Update part
    linearization_points = get_sigma_points(linearization_state)
    obs_points = observation_function(linearization_points.points)
    obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm, linearization_points.wc)
    obs_mvn = get_mv_normal_parameters(obs_sigma_points)
    update_cross_covariance = covariance_sigma_points(linearization_points, linearization_state.mean,
                                                      obs_sigma_points, obs_mvn.mean)
    # Linearized observation matrix and its affine residuals.
    H = jlinalg.solve(linearization_state.cov, update_cross_covariance, sym_pos=True).T
    obs_mean_residual = obs_mvn.mean - jnp.dot(H, linearization_state.mean)
    obs_cov_residual = obs_mvn.cov - H @ linearization_state.cov @ H.T
    S = H @ pred_cov_residual @ H.T + Rk + obs_cov_residual  # total residual covariance
    total_obs_residual = (yk - H @ pred_mean_residual - obs_mean_residual)
    S_invH = jlinalg.solve(S, H, sym_pos=True)
    K = (S_invH @ pred_cov_residual).T  # Kalman gain
    A = F - K @ H @ F
    b = pred_mean_residual + K @ total_obs_residual
    C = pred_cov_residual - K @ S @ K.T
    # Information-form terms eta, J.
    temp = (S_invH @ F).T
    HF = H @ F
    eta = temp @ total_obs_residual
    J = temp @ HF
    # Symmetrize the covariance-like outputs before returning.
    return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T)
def damped_newton(fn, jac_fn, U):
    """Newton iteration for fn(U, Uold) = 0, with Uold frozen at the
    initial iterate.

    NOTE(review): the step-halving line search this function is named for
    was commented out in the original; full Newton steps are taken here,
    exactly as the original executed.

    Returns
    -------
    U : array
        Final iterate.
    fail : int
        0 on success, 1 on NaN/complex step or non-convergence.
    """
    maxit = 10
    tol = 1e-8
    count = 0
    res = 100
    fail = 0
    Uold = U

    # Initial full Newton step (not counted towards maxit).
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = solve(J, y)
    U = U - delta
    res0 = norm(y / norm(U, np.inf), np.inf)
    print(count, res0)

    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        print(count, res)
        delta = solve(J, y)
        U = U - delta
        count = count + 1

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")
    # BUG FIX: `else: fail == 0` in the original was a no-op comparison
    # used as a statement; fail is already 0 on the success path.
    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')
    return U, fail
def _make_associative_filtering_params_first(observation_function, R, transition_function, Q,
                                             initial_state, linearization_state, y):
    # First associative filtering element (A, b, C, eta, J) for the parallel
    # sigma-point Kalman filter: one prediction from the initial state
    # followed by an update with the first observation y.
    # Prediction part
    initial_sigma_points = get_sigma_points(initial_state)
    propagated_points = transition_function(initial_sigma_points.points)
    propagated_sigma_points = SigmaPoints(propagated_points, initial_sigma_points.wm,
                                          initial_sigma_points.wc)
    propagated_state = get_mv_normal_parameters(propagated_sigma_points)
    pred_cross_covariance = covariance_sigma_points(initial_sigma_points, initial_state.mean,
                                                    propagated_sigma_points, propagated_state.mean)
    F = jlinalg.solve(initial_state.cov, pred_cross_covariance,
                      sym_pos=True).T  # Linearized transition function
    m1 = propagated_state.mean
    P1 = propagated_state.cov + Q
    # Update part
    linearization_points = get_sigma_points(linearization_state)
    obs_points = observation_function(linearization_points.points)
    obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm, linearization_points.wc)
    obs_mvn = get_mv_normal_parameters(obs_sigma_points)
    update_cross_covariance = covariance_sigma_points(linearization_points, linearization_state.mean,
                                                      obs_sigma_points, obs_mvn.mean)
    H = jlinalg.solve(linearization_state.cov, update_cross_covariance,
                      sym_pos=True).T  # linearized observation matrix
    d = obs_mvn.mean - jnp.dot(H, linearization_state.mean)  # linearized observation offset
    predicted_observation = H @ m1 + d
    S = H @ (P1 - linearization_state.cov) @ H.T + R + obs_mvn.cov
    K = jlinalg.solve(S, H @ P1, sym_pos=True).T  # Kalman gain
    # For the first element A, eta and J are zero by construction.
    A = jnp.zeros(F.shape)
    b = m1 + K @ (y - predicted_observation)
    C = P1 - K @ S @ K.T
    eta = jnp.zeros(F.shape[0])
    J = jnp.zeros(F.shape)
    return A, b, 0.5 * (C + C.T), eta, J
def update(state):
    # One Levenberg–Marquardt step.  Relies on names defined elsewhere in
    # the file (jacobian, damped_hessian, jac_err_prod, error, cost, cond,
    # solve, and the constants rho_c, rho_min, c, mu_min, mu_max,
    # LevenbergMarquardtState) — TODO confirm their semantics.
    data, p_, e_, C_, mu, iters, _ = state
    x, y = data
    mu = np.float32(mu)
    # Damped normal-equation step.
    J = jacobian(p_, x, y)
    H = damped_hessian(J, mu)
    Je = jac_err_prod(J, e_, p_)
    # Solve for the parameter step and evaluate the trial point.
    dp = solve(H, Je, sym_pos=True)
    p = p_ - dp
    e = error(p, x, y)
    C = cost(e, p)
    # Gain ratio: actual cost reduction over the predicted reduction.
    rho = (C_ - C) / (dp.T @ (mu * dp + Je))
    # Shrink the damping on a sufficiently good step…
    mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
    # …and reject the step (restoring p, e, C) while growing mu when bad.
    bad_step = (rho < rho_min) | np.any(np.isnan(p))
    mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
    p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
    e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
    C = np.where(bad_step, C_, C)
    improved = (C_ > C) | bad_step
    # iters only advances on accepted steps (~bad_step).
    return LevenbergMarquardtState(data, p, e, C, mu, iters + ~bad_step, improved)
def update(state):
    # One Levenberg–Marquardt step with Bayesian regularization: a weight
    # prior with precision alpha is folded into the objective, and alpha is
    # re-estimated on accepted steps.  Relies on names defined elsewhere in
    # the file (jacobian, error, sum_squares, cond, solve,
    # update_hyperparams, rho_c, rho_min, c, mu_min, mu_max,
    # LevenbergMarquardtBRState) — TODO confirm their semantics.
    data, p_, e_, C_, mu, alpha, iters, _ = state
    x, y = data
    mu = np.float32(mu)
    alpha_ = np.float32(alpha)
    # Regularized Gauss–Newton system.
    J = jacobian(p_, x, y)
    H = J.T @ J
    Je = J.T @ e_ + alpha_ * p_
    I = np.diag_indices_from(H)
    # Damped, regularized normal equations: (H + (alpha + mu) I) dp = Je.
    dp = solve(H.at[I].add(alpha_ + mu), Je, sym_pos=True)
    p = p_ - dp
    e = error(p, x, y)
    C = (sum_squares(e) + alpha * sum_squares(p)) / 2
    # Gain ratio of actual vs. predicted reduction.
    rho = (C_ - C) / (dp.T @ (mu * dp + Je))
    # Shrink damping on a good step.
    mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
    # Reject bad steps, restoring the previous iterate and growing mu.
    bad_step = (rho < rho_min) | np.any(np.isnan(p))
    mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
    p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
    e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
    # Recompute the cost ingredients for the (possibly restored) iterate.
    sse = sum_squares(e)
    ssp = sum_squares(p)
    C = np.where(bad_step, C_, C)
    improved = (C_ > C) | bad_step
    # Evidence-style hyperparameter update, applied only on accepted steps.
    bundle = (alpha, H, I, sse, ssp, x.size)
    alpha, *_ = cond(bad_step, lambda t: t, update_hyperparams, bundle)
    C = (sse + alpha * ssp) / 2
    # iters only advances on accepted steps (~bad_step).
    return LevenbergMarquardtBRState(data, p, e, C, mu, alpha, iters + ~bad_step, improved)
def _make_associative_filtering_params(args):
    """Build one associative filtering element (Ak, bk, Ck, etak, Jk) from
    the linear-Gaussian model pieces packed in ``args``.

    Parameters
    ----------
    args: tuple
        (Hk, Rk, Fk_1, Qk_1, uk_1, yk, dk, I_dim): observation matrix and
        noise covariance, transition matrix and noise covariance, nominal
        mean, observation, observation offset, and an identity matrix of
        the state dimension.

    Returns
    -------
    tuple
        (Ak, bk, Ck, etak, Jk), the associative filtering element.
    """
    Hk, Rk, Fk_1, Qk_1, uk_1, yk, dk, I_dim = args
    # FIRST TERM
    ############
    # temp variable
    HQ = jnp.dot(Hk, Qk_1)  # Hk @ Qk_1
    Sk = jnp.dot(HQ, Hk.T) + Rk
    # BUG FIX: the original called `jlnialg.solve` — a typo for `jlinalg`
    # that would raise NameError at runtime.
    Kk = jlinalg.solve(
        Sk, HQ, sym_pos=True).T  # using the fact that S and Q are symmetric
    # temp variable:
    I_KH = I_dim - jnp.dot(Kk, Hk)  # I - Kk @ Hk
    Ck = jnp.dot(I_KH, Qk_1)
    residual = (yk - jnp.dot(Hk, uk_1) - dk)
    bk = uk_1 + jnp.dot(Kk, residual)
    Ak = jnp.dot(I_KH, Fk_1)
    # SECOND TERM
    #############
    HF = jnp.dot(Hk, Fk_1)
    # BUG FIX: the original used `jsolve`, which is not defined anywhere in
    # view; every sibling helper solves against the symmetric S with
    # `jlinalg.solve(..., sym_pos=True)`, so the same call is used here.
    FHS_inv = jlinalg.solve(Sk, HF, sym_pos=True).T
    etak = jnp.dot(FHS_inv, residual)
    Jk = jnp.dot(FHS_inv, HF)
    return Ak, bk, Ck, etak, Jk
def _make_associative_filtering_params_first(observation_function, jac_observation_function, R,
                                             transition_function, jac_transition_function, Q,
                                             m0, P0, x_k, y):
    """First associative filtering element (A, b, C, eta, J) for the
    parallel extended Kalman filter: predict from (m0, P0), then update
    with the first observation y, linearizing the observation at x_k."""
    # Predict through the transition linearized at m0.
    F = jac_transition_function(m0)
    predicted_mean = transition_function(m0)
    predicted_cov = F @ P0 @ F.T + Q

    # Linearize the observation model at x_k and form the gain.
    H = jac_observation_function(x_k)
    innovation_cov = H @ predicted_cov @ H.T + R
    gain = jlinalg.solve(innovation_cov, H @ predicted_cov, sym_pos=True).T

    # Affine observation offset at the linearization point.
    alpha = observation_function(x_k) + H @ (predicted_mean - x_k)

    # A, eta, J are zero for the first element by construction.
    A = jnp.zeros(F.shape)
    b = predicted_mean + gain @ (y - alpha)
    C = predicted_cov - (gain @ innovation_cov @ gain.T)
    eta = jnp.zeros(F.shape[0])
    J = jnp.zeros(F.shape)
    return A, b, C, eta, J
def _make_associative_filtering_params_first(
        observation_function, jac_observation_function, R, transition_function,
        jac_transition_function, Q, m0, P0, x_k_1, x_k, y, propagate_first):
    """First associative element for the parallel EKF.

    When ``propagate_first`` is true, the prior (m0, P0) is pushed through
    the transition linearized at x_k_1 and the observation is linearized
    at x_k; otherwise the observation, linearized at x_k_1, is applied
    directly to the prior.
    """
    if propagate_first:
        F = jac_transition_function(x_k_1)
        mean = F @ (m0 - x_k_1) + transition_function(x_k_1)
        cov = F @ P0 @ F.T + Q
        H = jac_observation_function(x_k)
        alpha = observation_function(x_k) + H @ (mean - x_k)
    else:
        mean, cov = m0, P0
        H = jac_observation_function(x_k_1)
        alpha = observation_function(x_k_1) + H @ (m0 - x_k_1)

    # Innovation covariance and Kalman gain.
    innovation_cov = H @ cov @ H.T + R
    gain = jlinalg.solve(innovation_cov, H @ cov, sym_pos=True).T

    # A, eta, J are zero for the first element by construction.
    A = jnp.zeros_like(P0)
    b = mean + gain @ (y - alpha)
    C = cov - (gain @ innovation_cov @ gain.T)
    eta = jnp.zeros_like(m0)
    J = jnp.zeros_like(P0)
    return A, b, C, eta, J
def _make_associative_filtering_params_generic(observation_function, jac_observation_function, Rk,
                                               transition_function, jac_transition_function,
                                               x_k_1, x_k, Qk_1, yk):
    """Generic associative filtering element (A, b, C, eta, J) for the
    parallel EKF, linearizing the transition at x_k_1 and the observation
    at x_k."""
    F = jac_transition_function(x_k_1)
    H = jac_observation_function(x_k)

    # Affine pieces of the linearized transition and observation models.
    F_x_k_1 = F @ x_k_1
    nominal_pred = transition_function(x_k_1)
    alpha = observation_function(x_k) + H @ (nominal_pred - F_x_k_1 - x_k)
    innovation = yk - alpha

    # Innovation covariance (symmetric PD) and the gain.
    HQ = H @ Qk_1
    S = HQ @ H.T + Rk
    S_invH = jlinalg.solve(S, H, sym_pos=True)
    K = (S_invH @ Qk_1).T

    A = F - K @ H @ F
    b = K @ innovation + nominal_pred - F_x_k_1
    C = Qk_1 - K @ H @ Qk_1

    # Information-form terms.
    proj = (S_invH @ F).T
    eta = proj @ innovation
    J = proj @ (H @ F)
    return A, b, C, eta, J
def _bl_update(H, C, R, state):
    # Bayesian-learning hyperparameter update for (alpha, beta) given the
    # Hessian H, data cost C and regularizer value R.  Uses module-level
    # `I`, `n` and `x` — presumably the identity matching H, the number of
    # parameters, and the training inputs; TODO confirm against the caller.
    G, (α, _), μ, τ = state
    # NOTE(review): solve's sym_pos expects a bool; the string "sym" is
    # truthy so this behaves as sym_pos=True — possibly assume_a="sym"
    # was intended.  Confirm against the solver in scope.
    tr_inv_H = np.trace(solve(H, I, sym_pos="sym"))
    # Effective number of well-determined parameters.
    γ = n - α * tr_inv_H
    α = np.float32(n / (2 * R + tr_inv_H))
    β = np.float32((x.shape[0] - γ) / (2 * C))
    return G, (α, β), μ, τ
def lax_newton(fn, jac_fn, U, maxit, tol):
    """Newton iteration for fn(U, Uold) = 0 driven by lax.while_loop, with
    Uold frozen at the initial iterate.

    Returns the final NewtonInfo state (count, converged, fail, U).
    """
    Uold = U
    state = NewtonInfo(count=0, converged=0, fail=0, U=U)

    def body(state):
        J = jac_fn(state.U, Uold)
        y = fn(state.U, Uold)
        delta = solve(J, y)
        U_new = state.U - delta
        res = norm(y / norm(U_new, np.inf), np.inf)
        converged1 = res < tol
        # BUG FIX: NamedTuple._replace returns a NEW tuple; the original
        # discarded its result and returned the unchanged state, so the
        # loop never made progress.
        return state._replace(count=state.count + 1,
                              converged=converged1,
                              fail=np.any(np.isnan(delta)),
                              U=U_new)

    # One explicit first step before entering the loop (as in the original).
    J = jac_fn(state.U, Uold)
    y = fn(state.U, Uold)
    delta = solve(J, y)
    # BUG FIX: same discarded-_replace bug as in body; the result is now
    # assigned back to state.
    state = state._replace(U=state.U - delta)

    state = lax.while_loop(
        lambda s: np.logical_and(
            np.logical_and(~s.converged, ~s.fail),
            s.count < maxit),
        body, state)
    return state
def _lm_update(θ, H, Je, y, Λ, state):
    # One damped (Levenberg–Marquardt) update given Hessian H and
    # gradient-like term Je.  Uses module-level `I`, `x`, `errors`, `obj`,
    # `μs` and `LMState`, all defined elsewhere in the file.
    α, β = Λ
    # Step from the damped normal equations; state.μ is the current damping.
    # NOTE(review): sym_pos expects a bool — "sym" is truthy, so this acts
    # as sym_pos=True; possibly assume_a="sym" was intended.
    p = θ - solve(H + state.μ * I, Je, sym_pos="sym").T
    e = errors(p, x, y)
    C = obj.cost(e)
    # NOTE(review): the regularizer is evaluated at the OLD θ, not the new
    # p — confirm this is intended.
    R = obj.regularizer(θ)
    # Regularized objective with precision hyperparameters (α, β).
    G = np.float32(β * C + α * R)
    # μs rescales the damping for the next iteration.
    return LMState(p, e, G, C, R, state.μ * μs)
def newton(fn, jac_fn, U):
    """Newton iteration for fn(U, Uold) = 0, with Uold frozen at the
    initial iterate.

    Parameters
    ----------
    fn : callable
        Residual function ``fn(U, Uold)``.
    jac_fn : callable
        Jacobian of ``fn`` with respect to its first argument.
    U : array
        Initial guess.

    Returns
    -------
    U : array
        Final iterate.
    fail : int
        0 on convergence, 1 on NaN/complex step or non-convergence.
    """
    maxit = 20
    tol = 1e-8
    count = 0
    res = 100
    fail = 0
    Uold = U

    # First step timed separately (exposes one-off costs such as JIT
    # compilation of jac_fn/fn).
    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = count + 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)

    while count < maxit and res > tol:
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count = count + 1
        end1 = timeit.default_timer()
        print(count, res)
        print("time per loop", end1 - start1)

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")
    # BUG FIX: the original ended with `else: fail == 0`, a no-op comparison
    # used as a statement; fail is already 0 on the success path.
    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')
    return U, fail
def newton_while_lax(fn, jac_fn, U, maxit, tol):
    """Newton iteration for fn(U, Uold) = 0 expressed as a lax.while_loop,
    with Uold frozen at the initial iterate.

    Returns (U, fail): fail=0 on success, 1 on NaN/complex step or
    non-convergence.
    """
    count = 0
    # BUG FIX: res must be a float so the while_loop carry keeps a stable
    # dtype (the body produces a float residual; an int initial value makes
    # lax.while_loop reject the carry).
    res = 100.0
    fail = 0
    Uold = U

    # First explicit step (not counted towards maxit).
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = solve(J, y)
    U = U - delta
    # BUG FIX: the original assembled `val` BEFORE this first step, so the
    # loop restarted from the unstepped U and the first step was discarded.
    val = (U, count, res, fail)

    def cond_fun(val):
        # BUG FIX: the original recomputed the residual here from the
        # pre-loop `y` captured by closure (stale after the first body
        # iteration); the residual carried in `val` is the current one.
        U, count, res, _ = val
        return np.logical_and(res > tol, count < maxit)

    def body_fun(val):
        U, count, res, fail = val
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        delta = solve(J, y)
        U = U - delta
        res = norm(y / norm(U, np.inf), np.inf)
        count = count + 1
        print(count, res)
        return U, count, res, fail

    val = lax.while_loop(cond_fun, body_fun, val)
    U, count, res, _ = val

    # delta here is from the explicit first step (the loop's delta is local
    # to the traced body), as in the original.
    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")
    # BUG FIX: `else: fail == 0` in the original was a no-op comparison.
    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')
    return U, fail
def predict(self):
    """Computing the prediction mean and standard deviation

    Returns
    -------
    means : ndarray tuple
        tuple containing the mean components
    stds : ndarray tuple
        tuple containing the standard deviations
    """
    # Spectral weights of the reduced-rank prior.
    lambdam = self.getlambda()
    # Posterior mean of the basis-function regression:
    # Phi_pred (Phi^T Phi + sigma_n diag(1/lambda))^{-1} Phi^T y.
    mean = self.Phi_pred_T @ jscl.solve(
        self.PhiTPhi + np.diag(self.sigma_n / lambdam),
        self.Phi.T @ self.y)
    # Pointwise predictive standard deviation: only the diagonal of the
    # posterior covariance, via the row-wise elementwise product.
    std = np.sqrt(self.sigma_n * np.sum(
        self.Phi_pred_T * jscl.solve(self.PhiTPhi + np.diag(self.sigma_n / lambdam),
                                     self.Phi_pred_T.T).T, 1))
    # Outputs are interleaved 3-wise — presumably three coordinate
    # components; confirm against the caller.
    return (mean[::3], mean[1::3], mean[2::3]), (std[::3], std[1::3], std[2::3])
def _make_associative_smoothing_params_generic(transition_function, jac_transition_function, Qk, mk, Pk, xk):
    """One-step associative smoothing element (g, E, L) for the extended
    RTS smoother, linearizing the transition at xk."""
    # Jacobian of the transition at the linearization point.
    F = jac_transition_function(xk)
    # Predicted covariance under the linearized transition.
    predicted_cov = F @ Pk @ F.T + Qk
    # Smoother gain E = (Pp^{-1} F Pk)^T, exploiting the symmetry of Pp.
    E = jlinalg.solve(predicted_cov, F @ Pk, sym_pos=True).T
    L = Pk - E @ predicted_cov @ E.T
    g = mk - E @ (transition_function(xk) + F @ (mk - xk))
    return g, E, L
def predict(transition_function: Callable,
            transition_covariance: jnp.ndarray,
            previous_state: MVNormalParameters,
            linearization_state: MVNormalParameters,
            return_linearized_transition: bool = False) -> MVNormalParameters:
    """ Computes the cubature Kalman filter linearization of
    :math:`x_{t+1} = f(x_t, \mathcal{N}(0, \Sigma))`

    Parameters
    ----------
    transition_function: callable
        :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}`
        transition function of the state space model
    transition_covariance: (D,D) array
        covariance :math:`\Sigma` of the noise fed to transition_function
    previous_state: MVNormalParameters
        previous state for the filter x
    linearization_state: MVNormalParameters
        state for the linearization of the prediction; when None the
        previous state is used instead
    return_linearized_transition: bool, optional
        Returns the linearized transition matrix A

    Returns
    -------
    mvn_parameters: MVNormalParameters
        Propagated approximate Normal distribution
    F: array_like
        returned if return_linearized_transition is True
    """
    if linearization_state is None:
        linearization_state = previous_state
    # Statistical linearization of the transition via sigma points.
    sigma_points = get_sigma_points(linearization_state)
    propagated_points = transition_function(sigma_points.points)
    propagated_sigma_points = SigmaPoints(propagated_points, sigma_points.wm, sigma_points.wc)
    propagated_state = get_mv_normal_parameters(propagated_sigma_points)
    cross_covariance = covariance_sigma_points(sigma_points, linearization_state.mean,
                                               propagated_sigma_points, propagated_state.mean)
    F = jlinalg.solve(linearization_state.cov, cross_covariance,
                      sym_pos=True).T  # Linearized transition function
    b = propagated_state.mean - jnp.dot(
        F, linearization_state.mean)  # Linearized offset
    mean = F @ previous_state.mean + b
    cov = transition_covariance + propagated_state.cov + F @ (
        previous_state.cov - linearization_state.cov) @ F.T
    # NOTE(review): this branch returns the raw (unsymmetrized) covariance,
    # unlike the default return below — confirm that is intentional.
    if return_linearized_transition:
        return MVNormalParameters(mean, cov), F
    return MVNormalParameters(mean, 0.5 * (cov + cov.T))
def body_fun(val):
    """One Newton step on the while-loop carry (U, count, res, fail).

    Relies on ``fn``, ``jac_fn`` and ``Uold`` from the enclosing scope.
    """
    U, count, res, fail = val
    jacobian = jac_fn(U, Uold)
    residual = fn(U, Uold)
    step = solve(jacobian, residual)
    U = U - step
    # Residual norm, scaled by the size of the current iterate.
    res = norm(residual / norm(U, np.inf), np.inf)
    count = count + 1
    print(count, res)
    return U, count, res, fail
def update(
        observation_function: Callable[[jnp.ndarray], jnp.ndarray],
        observation_covariance: jnp.ndarray,
        predicted: MVNormalParameters,
        observation: jnp.ndarray,
        linearization_point: jnp.ndarray) -> Tuple[float, MVNormalParameters]:
    """ Computes the extended kalman filter linearization of
    :math:`x_t \mid y_t`

    Parameters
    ----------
    observation_function: callable
        :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K,K) array
        observation_error :math:`\Sigma` fed to observation_function
    predicted: MVNormalParameters
        predicted state of the filter :math:`x`
    observation: (K) array
        Observation :math:`y`
    linearization_point: jnp.ndarray
        Where to compute the Jacobian; falls back to the predicted mean
        when None

    Returns
    -------
    loglikelihood: float
        Log-likelihood increment for observation
    updated_state: MVNormalParameters
        filtered state
    """
    if linearization_point is None:
        linearization_point = predicted.mean
    # First-order (EKF) linearization of the observation model.
    jac_x = jacfwd(observation_function, 0)(linearization_point)
    obs_mean = observation_function(linearization_point) + jnp.dot(
        jac_x, predicted.mean - linearization_point)
    residual = observation - obs_mean
    residual_covariance = jnp.dot(jac_x, jnp.dot(predicted.cov, jac_x.T))
    residual_covariance = residual_covariance + observation_covariance
    # Kalman gain via a symmetric solve against the innovation covariance.
    gain = jnp.dot(predicted.cov,
                   jlag.solve(residual_covariance, jac_x, sym_pos=True).T)
    mean = predicted.mean + jnp.dot(gain, residual)
    cov = predicted.cov - jnp.dot(gain, jnp.dot(residual_covariance, gain.T))
    # Symmetrize the covariance to guard against numerical asymmetry.
    updated_state = MVNormalParameters(mean, 0.5 * (cov + cov.T))
    loglikelihood = multivariate_normal.logpdf(residual, jnp.zeros_like(residual),
                                               residual_covariance)
    return loglikelihood, updated_state
def body(state):
    """One Newton step on a NewtonInfo carry.

    Relies on ``fn``, ``jac_fn``, ``Uold`` and ``tol`` from the enclosing
    scope.
    """
    J = jac_fn(state.U, Uold)
    y = fn(state.U, Uold)
    delta = solve(J, y)
    U = state.U - delta
    res = norm(y / norm(U, np.inf), np.inf)
    converged1 = res < tol
    # BUG FIX: NamedTuple._replace returns a NEW tuple; the original
    # discarded its result and returned the unchanged state, so the
    # iteration never progressed.
    return state._replace(count=state.count + 1,
                          converged=converged1,
                          fail=np.any(np.isnan(delta)),
                          U=U)
def smooth(transition_function: Callable[[jnp.ndarray], jnp.ndarray],
           transition_covariance: jnp.array,
           filtered_state: MVNormalParameters,
           previous_smoothed: MVNormalParameters,
           linearization_point: jnp.ndarray) -> MVNormalParameters:
    """
    One step extended kalman smoother

    Parameters
    ----------
    transition_function: callable
        :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}`
        transition function of the state space model
    transition_covariance: (D,D) array
        covariance :math:`\Sigma` of the noise fed to transition_function
    filtered_state: MVNormalParameters
        mean and cov computed by Kalman Filtering
    previous_smoothed: MVNormalParameters,
        smoothed state of the previous step
    linearization_point: jnp.ndarray
        Where to compute the Jacobian

    Returns
    -------
    smoothed_state: MVNormalParameters
        smoothed state
    """
    # First-order (EKF) linearization of the transition model.
    jac_x = jacfwd(transition_function, 0)(linearization_point)
    mean = transition_function(linearization_point) + jnp.dot(
        jac_x, filtered_state.mean - linearization_point)
    mean_diff = previous_smoothed.mean - mean
    cov = jnp.dot(jac_x, jnp.dot(filtered_state.cov, jac_x.T)) + transition_covariance
    cov_diff = previous_smoothed.cov - cov
    # RTS smoother gain via a symmetric solve against the predicted cov.
    gain = jnp.dot(filtered_state.cov, jlag.solve(cov, jac_x, sym_pos=True).T)
    mean = filtered_state.mean + jnp.dot(gain, mean_diff)
    cov = filtered_state.cov + jnp.dot(gain, jnp.dot(cov_diff, gain.T))
    return MVNormalParameters(mean, cov)
def stream_vel(bb):
    # Solve for the ice-stream velocity u given a basal-friction
    # perturbation bb, using n_nl Picard iterations of the nonlinear
    # viscosity.  Relies on module-level grid_size, rhoi, g, n_nl and the
    # stream_vel_* / stream_assemble helpers defined elsewhere.
    n = grid_size
    h, beta_fric, dx = stream_vel_init(n, rhoi, g)
    beta_fric = bb + beta_fric
    f, fend = stream_vel_taud(h, n, dx, rhoi, g)
    u = jnp.zeros(n + 1)  # driving stress
    f_plus1 = jnp.roll(f, -1)
    # Right-hand side: pairwise combination of f at adjacent cells, with
    # the boundary forcing fend added to the last entry.
    b = jnp.append(-dx * f[0:n - 1] - f_plus1[0:n - 1] * dx,
                   -dx * f[n - 1] - f_plus1[n - 1] * dx + fend)
    for i in range(n_nl):
        # update viscosities
        nu = stream_vel_visc(h, u, n, dx)
        # assemble tridiag matrix. This represents the discretization of
        # (nu^(i-1) u^(i)_x)_x - \beta^2 u^(i) = f
        A = stream_assemble(nu, beta_fric, n, dx)
        # solve linear system for new u
        # effectively apply boundary condition u(0)==0
        u = jnp.append(jnp.zeros(1), la.solve(A, b))
    return u
def _T_bar(F: np.ndarray, N_T_inv: np.ndarray, N_inv_d: np.ndarray) -> np.ndarray: """Function to calculate the expected component amplitudes, `T_bar`. This is an implementation of Equation (A4) in 1608.00551. See also Equation (A10) for interpretation. Parameters ---------- F: ndarray SED matrix N_T_inv: ndarray Inverse component covariance. N_inv_d: ndarray Inverse covariance-weighted data. Returns ------- ndarray T_bar, the expected component amplitude. """ y = np.sum(F[None, :, :] * N_inv_d[:, None, :], axis=2) return linalg.solve(N_T_inv, y)
def newton_tol(fn, jac_fn, U, tol):
    """Newton iteration for fn(U, Uold) = 0 with a caller-supplied
    tolerance, Uold frozen at the initial iterate.

    NOTE(review): the convergence measure max|y_i| / ||y||_2 is scale-free
    and bounded below by 1/sqrt(len(y)) for nonzero y, so for
    multi-dimensional residuals it only drops below a small tol when y
    becomes (numerically) zero — confirm this criterion is intended.

    Returns (U, fail): fail=0 on success, 1 on NaN/complex step or
    non-convergence within 20 iterations.
    """
    maxit = 20
    count = 0
    res = 100
    fail = 0
    Uold = U

    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = max(abs(y / norm(y, 2)))
        print(count, res)
        delta = solve(J, y)
        U = U - delta
        count = count + 1

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")
    # BUG FIX: `else: fail == 0` in the original was a no-op comparison
    # used as a statement; fail is already 0 on the success path.
    if fail == 0 and res > tol:
        fail = 1
        print('Newton fail: no convergence')
    return U, fail
def _make_associative_filtering_params_first(
        observation_function, R, transition_function, Q, initial_state,
        prev_linearization_state, linearization_state, y, propagate_first):
    # First associative filtering element for the parallel sigma-point
    # filter.  When propagate_first is True the initial state is pushed
    # through the statistically linearized transition before the first
    # update; otherwise the update is applied directly to the initial state.
    # Prediction part
    if propagate_first:
        initial_sigma_points = get_sigma_points(prev_linearization_state)
        propagated_points = transition_function(initial_sigma_points.points)
        propagated_sigma_points = SigmaPoints(propagated_points,
                                              initial_sigma_points.wm,
                                              initial_sigma_points.wc)
        propagated_state = get_mv_normal_parameters(propagated_sigma_points)
        pred_cross_covariance = covariance_sigma_points(
            initial_sigma_points, prev_linearization_state.mean,
            propagated_sigma_points, propagated_state.mean)
        F = jlinalg.solve(prev_linearization_state.cov, pred_cross_covariance,
                          sym_pos=True).T  # Linearized transition function
        m = propagated_state.mean + F @ (initial_state.mean - prev_linearization_state.mean)
        P = propagated_state.cov + Q + F @ (initial_state.cov - prev_linearization_state.cov) @ F.T
        # Update part, linearized around linearization_state.
        linearization_points = get_sigma_points(linearization_state)
        obs_points = observation_function(linearization_points.points)
        obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm,
                                       linearization_points.wc)
        obs_mvn = get_mv_normal_parameters(obs_sigma_points)
        update_cross_covariance = covariance_sigma_points(
            linearization_points, linearization_state.mean,
            obs_sigma_points, obs_mvn.mean)
        H = jlinalg.solve(linearization_state.cov, update_cross_covariance,
                          sym_pos=True).T
        d = obs_mvn.mean - jnp.dot(H, linearization_state.mean)
        predicted_observation = H @ m + d
        S = H @ (P - linearization_state.cov) @ H.T + R + obs_mvn.cov
    else:
        m = initial_state.mean
        P = initial_state.cov
        # Update part, linearized around prev_linearization_state.
        linearization_points = get_sigma_points(prev_linearization_state)
        obs_points = observation_function(linearization_points.points)
        obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm,
                                       linearization_points.wc)
        obs_mvn = get_mv_normal_parameters(obs_sigma_points)
        # NOTE(review): the sigma points and the solve below use
        # prev_linearization_state, but the mean passed here is
        # linearization_state.mean — confirm this asymmetry is intended.
        update_cross_covariance = covariance_sigma_points(
            linearization_points, linearization_state.mean,
            obs_sigma_points, obs_mvn.mean)
        H = jlinalg.solve(prev_linearization_state.cov, update_cross_covariance,
                          sym_pos=True).T
        d = obs_mvn.mean - jnp.dot(H, prev_linearization_state.mean)
        predicted_observation = H @ m + d
        S = H @ (P - prev_linearization_state.cov) @ H.T + R + obs_mvn.cov
    K = jlinalg.solve(S, H @ P, sym_pos=True).T  # Kalman gain
    # A, eta, J are zero for the first element by construction.
    A = jnp.zeros_like(initial_state.cov)
    b = m + K @ (y - predicted_observation)
    C = P - K @ S @ K.T
    eta = jnp.zeros_like(initial_state.mean)
    J = jnp.zeros_like(initial_state.cov)
    return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T)
def psd_inv_cholesky(matrix: jnp.ndarray, damping: jnp.ndarray) -> jnp.ndarray:
    """Invert a damped PSD matrix by solving (matrix + damping * I) X = I.

    Parameters
    ----------
    matrix: jnp.ndarray
        (D, D) positive semi-definite matrix.
    damping: jnp.ndarray
        Scalar Tikhonov damping added to the diagonal.

    Returns
    -------
    jnp.ndarray
        Inverse of the damped matrix.
    """
    assert matrix.ndim == 2
    eye = jnp.eye(matrix.shape[0])
    damped = matrix + damping * eye
    return linalg.solve(damped, eye, sym_pos=True)