Code example #1
    def expected_values(self, hist=None):
        """
        Compute the expected values for the parameters
        π_k, μ_k and Λ_k of the re-estimation equations
        for the lower-bound.

        Parameters
        ----------
        hist: dictionary
            Entry of the list returned by the .fit method
            if store_hist=True

        Returns
        -------
        * array(K): weight of each cluster
        * array(K, M): center of each cluster
        * array(K, M, M): covariance matrix of each cluster
        """
        if not self.fitted and hist is None:
            raise NotFittedError("This VBMixture instance is not fitted yet.")
        elif hist is None:
            pi_k = self.alpha_k / self.alpha_k.sum()
            mu_k = self.m_k.T
            Sigma_k = inv(self.eta_k[:, None, None] * self.W_k)
        else:
            pi_k = hist["alpha"] / hist["alpha"].sum()
            mu_k = hist["m"].T
            Sigma_k = inv(hist["eta"][:, None, None] * hist["W"])

        return pi_k, mu_k, Sigma_k
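
A quick shape sketch of the batched inverse used above: `eta_k[:, None, None]` broadcasts the K concentrations over the stack of K scale matrices, and `inv` inverts each (M, M) block independently. A standalone illustration, not the project's code:

import jax.numpy as jnp
from jax.numpy.linalg import inv

# Standalone illustration: jnp.linalg.inv inverts each (M, M) block of a
# (K, M, M) stack, so Sigma_k is the per-cluster covariance (eta_k W_k)^-1.
K, M = 3, 2
eta_k = jnp.array([4.0, 5.0, 6.0])
W_k = jnp.tile(jnp.eye(M), (K, 1, 1))
Sigma_k = inv(eta_k[:, None, None] * W_k)
print(Sigma_k.shape)   # (3, 2, 2)
print(Sigma_k[0])      # 0.25 * I, the inverse of 4 * I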
Code example #2
    def __kalman_filter(self, x_hist):
        """
        Compute the online version of the Kalman-Filter, i.e,
        the one-step-ahead prediction for the hidden state or the
        time update step
        
        Parameters
        ----------
        x_hist: array(timesteps, observation_size)
            
        Returns
        -------
        * array(timesteps, state_size):
            Filtered means mut
        * array(timesteps, state_size, state_size)
            Filtered covariances Sigmat
        * array(timesteps, state_size)
            Filtered conditional means mut|t-1
        * array(timesteps, state_size, state_size)
            Filtered conditional covariances Sigmat|t-1
        """
        I = jnp.eye(self.state_size)
        mu_hist = jnp.zeros((self.timesteps, self.state_size))
        Sigma_hist = jnp.zeros(
            (self.timesteps, self.state_size, self.state_size))
        Sigma_cond_hist = jnp.zeros(
            (self.timesteps, self.state_size, self.state_size))
        mu_cond_hist = jnp.zeros((self.timesteps, self.state_size))

        # Initial configuration
        K1 = self.Sigma0 @ self.C.T @ inv(self.C @ self.Sigma0 @ self.C.T +
                                          self.R)
        mu1 = self.mu0 + K1 @ (x_hist[0] - self.C @ self.mu0)
        Sigma1 = (I - K1 @ self.C) @ self.Sigma0

        mu_hist = index_update(mu_hist, 0, mu1)
        Sigma_hist = index_update(Sigma_hist, 0, Sigma1)
        mu_cond_hist = index_update(mu_cond_hist, 0, self.mu0)
        Sigma_cond_hist = index_update(Sigma_cond_hist, 0, self.Sigma0)

        Sigman = Sigma1
        for n in range(1, self.timesteps):
            # Sigman|{n-1}
            Sigman_cond = self.A @ Sigman @ self.A.T + self.Q
            St = self.C @ Sigman_cond @ self.C.T + self.R
            Kn = Sigman_cond @ self.C.T @ inv(St)

            # mun|{n-1} and xn|{n-1}
            mu_update = self.A @ mu_hist[n - 1]
            x_update = self.C @ mu_update

            mun = mu_update + Kn @ (x_hist[n] - x_update)
            Sigman = (I - Kn @ self.C) @ Sigman_cond

            mu_hist = index_update(mu_hist, n, mun)
            Sigma_hist = index_update(Sigma_hist, n, Sigman)
            mu_cond_hist = index_update(mu_cond_hist, n, mu_update)
            Sigma_cond_hist = index_update(Sigma_cond_hist, n, Sigman_cond)

        return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist
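
Note that `index_update` comes from the old `jax.ops` module, which has since been removed from JAX; the modern equivalent is the functional `.at[]` update syntax:

import jax.numpy as jnp

# Modern replacement for jax.ops.index_update(mu_hist, 0, mu1):
mu_hist = jnp.zeros((5, 2))
mu1 = jnp.ones(2)
mu_hist = mu_hist.at[0].set(mu1)  # returns a new array with row 0 replaced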
Code example #3
def inv_penalized_cov(x, lam=None, backend="cpu"):
    # assert on dimensions
    sys_platform = platform.system()
    if backend in ("gpu", "tpu") and (sys_platform in ("Linux", "Darwin")):
        if lam is None:
            return jla.inv(mo.crossprod(x=x, backend=backend))
        return jla.inv(
            mo.crossprod(x=x, backend=backend) + lam * jnp.eye(x.shape[1])
        )

    if lam is None:
        return la.inv(mo.crossprod(x))
    return la.inv(mo.crossprod(x) + lam * np.eye(x.shape[1]))
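
Stripped of the backend dispatch, the function computes the penalized inverse (XᵀX + λI)⁻¹ at the heart of ridge regression. A minimal standalone sketch, assuming `mo.crossprod(x)` is simply `x.T @ x`:

import jax.numpy as jnp
from jax.numpy import linalg as jla

# Minimal sketch under the assumption that mo.crossprod(x) == x.T @ x.
def inv_penalized_cov_sketch(x, lam=1e-3):
    return jla.inv(x.T @ x + lam * jnp.eye(x.shape[1]))

X = jnp.arange(12.0).reshape(4, 3)
print(inv_penalized_cov_sketch(X).shape)  # (3, 3)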
Code example #4
File: lasso.py Project: HaidYi/stor893
    def __init__(self, A, b, gamma, optimizer='ISTA'):
        super(Lasso, self).__init__()

        # initialize values
        self.A, self.b = A, b
        self.n, self.p = A.shape
        self.gamma = gamma
        self.optimizer = optimizer

        # cache results
        self.AtA = jnp.dot(A.T, A)
        if optimizer == 'ISTA' or optimizer == 'ADMM':
            self.Atb = jnp.dot(A.T, b)
            self.AtA_inverse = inv(self.AtA + jnp.eye(self.p))

        # calculate the Lipschitz constant
        eig_vals, _ = eigh(self.AtA)
        self.L_f = eig_vals[-1]
        self.lr = 1. / self.L_f

        # initialize parameters
        self.x = jnp.zeros(self.p)
        self.z = None
        if optimizer == 'FISTA':
            self.y = self.x
            self.t = 1
        elif optimizer == 'ADMM':
            self.z = jnp.zeros(self.p)
            self.u = jnp.zeros(self.p)

        self.grad_f = jit(grad(self.feval))
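
The step size `1. / self.L_f` is the standard ISTA choice, with the Lipschitz constant taken as the largest eigenvalue of AᵀA; `eigh` returns eigenvalues in ascending order, so `eig_vals[-1]` is indeed the largest. A standalone check:

import jax.numpy as jnp
from jax.numpy.linalg import eigh

# eigh returns eigenvalues in ascending order, so the last entry is the
# largest one: the Lipschitz constant of the gradient of 0.5*||Ax - b||^2.
A = jnp.array([[2.0, 0.0], [0.0, 1.0]])
eig_vals, _ = eigh(A.T @ A)
print(eig_vals)            # [1. 4.]
print(1.0 / eig_vals[-1])  # step size 0.25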
Code example #5
File: lds_lib.py Project: stjordanis/pyprobml
    def __smoother_step(self, state, elements):
        mut_giv_T, Sigmat_giv_T = state
        mutt, Sigmatt, mut_cond_next, Sigmat_cond_next = elements
        Jt = Sigmatt @ self.A.T @ inv(Sigmat_cond_next)
        mut_giv_T = mutt + Jt @ (mut_giv_T - mut_cond_next)
        Sigmat_giv_T = Sigmatt + Jt @ (Sigmat_giv_T - Sigmat_cond_next) @ Jt.T
        return (mut_giv_T, Sigmat_giv_T), (mut_giv_T, Sigmat_giv_T)
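
The `(carry, x) -> (carry, y)` signature suggests this step is meant to be driven by `jax.lax.scan` over the reversed filtered history. A hedged, self-contained sketch with an invented toy model (the real class presumably wires in its own `A` and histories):

import jax
import jax.numpy as jnp

# Toy model for illustration only; shapes and values are invented.
A = 0.9 * jnp.eye(2)
T, D = 5, 2
mu_f = jnp.zeros((T, D))                       # filtered means
Sigma_f = jnp.tile(jnp.eye(D), (T, 1, 1))      # filtered covariances
mu_c = jnp.zeros((T, D))                       # one-step-ahead means
Sigma_c = jnp.tile(2.0 * jnp.eye(D), (T, 1, 1))

def smoother_step(state, elements):
    mut_giv_T, Sigmat_giv_T = state
    mutt, Sigmatt, mut_cond_next, Sigmat_cond_next = elements
    Jt = Sigmatt @ A.T @ jnp.linalg.inv(Sigmat_cond_next)
    mut_giv_T = mutt + Jt @ (mut_giv_T - mut_cond_next)
    Sigmat_giv_T = Sigmatt + Jt @ (Sigmat_giv_T - Sigmat_cond_next) @ Jt.T
    return (mut_giv_T, Sigmat_giv_T), (mut_giv_T, Sigmat_giv_T)

# Scan backwards: seed with the last filtered estimate, pair each earlier
# filtered step with the next step's one-step-ahead prediction.
init = (mu_f[-1], Sigma_f[-1])
elements = (mu_f[-2::-1], Sigma_f[-2::-1], mu_c[1:][::-1], Sigma_c[1:][::-1])
_, (mu_s, Sigma_s) = jax.lax.scan(smoother_step, init, elements)
print(mu_s.shape, Sigma_s.shape)  # (4, 2) (4, 2, 2)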
Code example #6
def solve_hinf(
    A: jnp.ndarray,
    B: jnp.ndarray,
    Q: jnp.ndarray,
    R: jnp.ndarray,
    T: int,
    gamma_range: Tuple[Real, Real] = (0.1, 10.8),
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Description: solves H-infinity control problem

    Args:
        A (jnp.ndarray):
        B (jnp.ndarray):
        T (jnp.ndarray):
        Q (jnp.ndarray):
        R (jnp.ndarray):

    Returns
        Tuple[jnp.ndarray, jnp.ndarray]:
    """
    gamma_low, gamma_high = gamma_range
    n, m = B.shape
    M = [jnp.zeros(shape=(n, n))] * (T + 1)
    K = [jnp.zeros(shape=(m, n))] * T
    W = [jnp.zeros(shape=(n, 1))] * T

    while gamma_high - gamma_low > 0.01:
        K = [jnp.zeros(shape=(m, n))] * T
        W = [jnp.zeros(shape=(n, 1))] * T

        gamma = 0.5 * (gamma_high + gamma_low)
        M[T] = Q
        for t in range(T - 1, -1, -1):
            Lambda = jnp.eye(n) + (B @ inv(R) @ B.T -
                                   gamma ** (-2) * jnp.eye(n)) @ M[t + 1]
            M[t] = Q + A.T @ M[t + 1] @ inv(Lambda) @ A

            K[t] = -inv(R) @ B.T @ M[t + 1] @ inv(Lambda) @ A
            W[t] = (gamma**(-2)) * M[t + 1] @ inv(Lambda) @ A
            if not is_psd(M[t], gamma):
                gamma_low = gamma
                break
        if gamma_low != gamma:
            gamma_high = gamma
    return K, W
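
`is_psd` is not shown here; in the H-infinity bisection it typically checks that γ²I − M stays positive semi-definite, i.e. that the current γ is feasible. A plausible sketch of such a helper (an assumption, not the project's actual source):

import jax.numpy as jnp

# Plausible sketch of the unshown is_psd helper: gamma is feasible as long
# as gamma^2 * I - M remains positive semi-definite.
def is_psd(M, gamma):
    n = M.shape[0]
    eigs = jnp.linalg.eigvalsh(gamma ** 2 * jnp.eye(n) - M)
    return bool(jnp.all(eigs >= 0))

print(is_psd(jnp.eye(2), gamma=2.0))  # True: 4I - I is PSD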
Code example #7
    def _posterior_params(self, X, r, alpha, beta, eta, m, W):
        """
        Compute the posterior parameters for each
        components of the mixture of gaussians
        (Variational M-Step)
        """
        Nk, xbar_k, Sk = self._compute_m_statistics(X, r)
        alpha_k = alpha + Nk
        beta_k = beta + Nk
        eta_k = eta + Nk

        m_k = (beta * m + Nk * xbar_k) / beta_k
        C0 = (beta * Nk) / (beta + Nk)
        f0 = (xbar_k - m)[:, None, :]
        W_k_inv = inv(W) + (Nk * Sk).T + C0[:, None, None] * jnp.einsum(
            "ijk,jik->kij", f0, f0)
        W_k = inv(W_k_inv)

        return alpha_k, beta_k, eta_k, m_k, W_k
Code example #8
File: linalg.py Project: dev-fennek/jax
def tensorinv(a, ind=2):
  a = np.asarray(a)
  oldshape = a.shape
  prod = 1
  if ind > 0:
    invshape = oldshape[ind:] + oldshape[:ind]
    for k in oldshape[ind:]:
      prod *= k
  else:
    raise ValueError("Invalid ind argument.")
  a = a.reshape(prod, -1)
  ia = la.inv(a)
  return ia.reshape(*invshape)
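
A quick sanity check of the contract, mirroring the NumPy documentation for `tensorinv`: contracting the inverse against a right-hand side over `ind` axes reproduces `tensorsolve`:

import numpy as np

# tensordot(tensorinv(a), b, ind) should agree with tensorsolve(a, b).
a = np.eye(4 * 6).reshape(4, 6, 8, 3)
ainv = np.linalg.tensorinv(a, ind=2)
b = np.random.randn(4, 6)
print(np.allclose(np.tensordot(ainv, b, 2), np.linalg.tensorsolve(a, b)))  # True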
Code example #9
    def __kalman_smoother(self, mu_hist, Sigma_hist, mu_cond_hist,
                          Sigma_cond_hist):
        """
        Compute the offline version of the Kalman-Filter, i.e,
        the kalman smoother for the hidden state.
        Note that we require to independently run the kalman_filter function first
        
        Parameters
        ----------
        mu_hist: array(timesteps, state_size):
            Filtered means mut
        Sigma_hist: array(timesteps, state_size, state_size)
            Filtered covariances Sigmat
        mu_cond_hist: array(timesteps, state_size)
            Filtered conditional means mut|t-1
        Sigma_cond_hist: array(timesteps, state_size, state_size)
            Filtered conditional covariances Sigmat|t-1
            
        Returns
        -------
        * array(timesteps, state_size):
            Smoothed means mut
        * array(timesteps, state_size, state_size)
            Smoothed covariances Sigmat
        """
        timesteps, _ = mu_hist.shape
        state_size, _ = self.A.shape
        mu_hist_smooth = jnp.zeros((timesteps, state_size))
        Sigma_hist_smooth = jnp.zeros((timesteps, state_size, state_size))

        mut_giv_T = mu_hist[-1, :]
        Sigmat_giv_T = Sigma_hist[-1, :]

        # Update last step
        mu_hist_smooth = index_update(mu_hist_smooth, -1, mut_giv_T)
        Sigma_hist_smooth = index_update(Sigma_hist_smooth, -1, Sigmat_giv_T)

        elements = zip(mu_hist[-2::-1], Sigma_hist[-2::-1, ...],
                       mu_cond_hist[::-1, ...], Sigma_cond_hist[::-1, ...])
        for t, (mutt, Sigmatt, mut_cond_next,
                Sigmat_cond_next) in enumerate(elements, 1):
            Jt = Sigmatt @ self.A.T @ inv(Sigmat_cond_next)
            mut_giv_T = mutt + Jt @ (mut_giv_T - mut_cond_next)
            Sigmat_giv_T = Sigmatt + Jt @ (Sigmat_giv_T -
                                           Sigmat_cond_next) @ Jt.T

            mu_hist_smooth = index_update(mu_hist_smooth, -(t + 1), mut_giv_T)
            Sigma_hist_smooth = index_update(Sigma_hist_smooth, -(t + 1),
                                             Sigmat_giv_T)

        return mu_hist_smooth, Sigma_hist_smooth
Code example #10
    def expected_values(self):
        """
        Compute the expected values for the parameters
        π_k, μ_k and Λ_k of the re-estimation equations
        for the lower-bound.
        To be used during training
        """
        if not self.fitted:
            raise NotFittedError("This VBMixture instance is not fitted yet.")
        pi_k = self.alpha_k / self.alpha_k.sum()
        mu_k = self.m_k.T
        Sigma_k = inv(self.eta_k[:, None, None] * self.W_k)

        return pi_k, mu_k, Sigma_k
Code example #11
File: lds_lib.py Project: stjordanis/pyprobml
    def __kalman_step(self, state, xt):
        mun, Sigman = state
        I = jnp.eye(self.state_size)
        # Sigman|{n-1}
        Sigman_cond = self.A @ Sigman @ self.A.T + self.Q
        St = self.C @ Sigman_cond @ self.C.T + self.R
        Kn = Sigman_cond @ self.C.T @ inv(St)

        # mun|{n-1} and xn|{n-1}
        mu_update = self.A @ mun
        x_update = self.C @ mu_update

        mun = mu_update + Kn @ (xt - x_update)
        Sigman = (I - Kn @ self.C) @ Sigman_cond
        return (mun, Sigman), (mun, Sigman, mu_update, Sigman_cond)
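
Unlike the explicit Python loop in code example #2, this step already has the `(carry, x) -> (carry, ys)` shape that `jax.lax.scan` expects. A self-contained sketch with an invented toy model:

import jax
import jax.numpy as jnp

# Toy LDS for illustration only; A, C, Q, R and the shapes are invented.
A = 0.9 * jnp.eye(2); C = jnp.eye(2)
Q = 0.1 * jnp.eye(2); R = 0.5 * jnp.eye(2)
I = jnp.eye(2)

def kalman_step(state, xt):
    mun, Sigman = state
    Sigman_cond = A @ Sigman @ A.T + Q                # Sigma_{n|n-1}
    St = C @ Sigman_cond @ C.T + R
    Kn = Sigman_cond @ C.T @ jnp.linalg.inv(St)
    mu_update = A @ mun                               # mu_{n|n-1}
    mun = mu_update + Kn @ (xt - C @ mu_update)
    Sigman = (I - Kn @ C) @ Sigman_cond
    return (mun, Sigman), (mun, Sigman, mu_update, Sigman_cond)

x_hist = jnp.ones((10, 2))
init = (jnp.zeros(2), jnp.eye(2))
_, (mu_hist, Sigma_hist, mu_cond, Sigma_cond) = jax.lax.scan(
    kalman_step, init, x_hist)
print(mu_hist.shape, Sigma_hist.shape)  # (10, 2) (10, 2, 2)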
Code example #12
File: linalg.py Project: dev-fennek/jax
def cond(x, p=None):
  _assertNoEmpty2d(x)
  if p in (None, 2):
    s = la.svd(x, compute_uv=False)
    return s[..., 0] / s[..., -1]
  elif p == -2:
    s = la.svd(x, compute_uv=False)
    r = s[..., -1] / s[..., 0]
  else:
    _assertRankAtLeast2(x)
    _assertNdSquareness(x)
    invx = la.inv(x)
    r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(invx, ord=p, axis=(-2, -1))

  # Convert nans to infs unless the original array had nan entries
  orig_nan_check = np.full_like(r, ~np.isnan(r).any())
  nan_mask = np.logical_and(np.isnan(r), ~np.isnan(x).any(axis=(-2, -1)))
  r = np.where(orig_nan_check, np.where(nan_mask, np.inf, r), r)
  return r
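
Usage sketch: for a diagonal matrix the 2-norm condition number is the ratio of the largest to the smallest singular value:

import jax.numpy as jnp

# cond(diag(10, 1, 0.1)) = 10 / 0.1 = 100 for the default 2-norm;
# p=1 takes the la.inv branch above instead of the SVD branch.
x = jnp.diag(jnp.array([10.0, 1.0, 0.1]))
print(jnp.linalg.cond(x))       # ~100.0
print(jnp.linalg.cond(x, p=1))  # 100.0 as well for a diagonal matrix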
Code example #13
    def E4(self, X, beta, eta, m, W, beta_0, eta_0, m_0, W_0):
        """
        E[log p(μ, Λ)]
        """
        N, M = X.shape
        K, *_ = eta.shape
        decomp = jnp.linalg.cholesky(W_0)
        wishart = tfd.WishartTriL(df=eta_0, scale_tril=decomp)

        diffk = m - m_0
        log_hat_Λ = self._compute_e_log_lambda(X, eta, W)
        mahal_mk = jnp.einsum("im,mij,jm->m", diffk, W, diffk)
        Tr_W0inv_W = jnp.einsum("mij,mij->m", inv(W_0), W)

        term1 = (M * jnp.log(beta_0[0] / (jnp.pi * 2)).sum() + log_hat_Λ -
                 M * beta_0 / beta - beta_0 * eta * mahal_mk).sum() / 2
        term2 = K * wishart.log_normalization().sum()
        term3 = ((eta_0 - M - 1) / 2 * log_hat_Λ).sum()
        term4 = (eta * Tr_W0inv_W).sum() / 2

        E_val = term1 + term2 + term3 - term4

        return E_val
Code example #14
File: weibull.py Project: RustamUzb/WEMIOT
    def __fitComplete2pMLE(self):
        # initial guess:
        shape = 1.2
        scale = self.failures.mean()
        parameters = jnp.array([shape, scale])

        J = jacfwd(self.__logLikelihood2pComp)
        H = jacfwd(jacrev(self.__logLikelihood2pComp))

        epoch = 0
        total = 1
        while not (total < 0.01 or epoch > 200):
            epoch += 1
            grads = J(parameters)
            hess = linalg.inv(H(parameters))  # inverse Hessian
            # q is a damping coefficient that shrinks the Newton step
            # when gradients are large
            q = 1 / (1 + jnp.sqrt(abs(grads / 800)))
            # Newton-Raphson maximisation
            parameters -= q * hess @ grads
            total = abs(grads[0]) + abs(grads[1])

        if epoch < 200:
            self.converged = True
            self.shape = parameters[0]
            self.scale = parameters[1]
            self.method = Method.MLEComplete2p

            # Fisher Matrix confidence bound
            self.variance = [abs(hess[0][0]), abs(hess[1][1])]
            self.beta_eta_covar = [abs(hess[1][0])]

        else:
            # if more than 200 epochs, the fit is considered not to have converged
            self.converged = False
            self.shape = 0.0
            self.scale = 0.0
            self.method = Method.MLEComplete2p
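
The same Newton-Raphson pattern in a minimal, self-contained form; `__logLikelihood2pComp` is not shown in the snippet, so a toy concave objective stands in for it:

import jax.numpy as jnp
from jax import jacfwd, jacrev
from jax.numpy import linalg

def loglik(theta):
    # Toy stand-in objective, maximised at theta = (1, 2).
    return -(theta[0] - 1.0) ** 2 - (theta[1] - 2.0) ** 2

J = jacfwd(loglik)          # gradient
H = jacfwd(jacrev(loglik))  # Hessian

theta = jnp.array([0.0, 0.0])
for _ in range(5):
    theta = theta - linalg.inv(H(theta)) @ J(theta)  # Newton step
print(theta)  # ~[1. 2.]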
Code example #15
File: weibull.py Project: RustamUzb/WEMIOT
    def __fitTypeICensored2pMLE(self):
        # initial guess:
        shape = 1.2
        scale = (self.failures.mean() + self.censored.mean()) / 2
        parameters = jnp.array([shape, scale])

        J = jacfwd(self.__logLikelihood2pTypeICensored)
        H = jacfwd(jacrev(self.__logLikelihood2pTypeICensored))

        epoch = 0
        total = 1
        while not (total < 0.09 or epoch > 200):
            epoch += 1
            grads = J(parameters)
            hess = linalg.inv(H(parameters))  # inverse Hessian
            # q is a damping coefficient that shrinks the Newton step
            # when gradients are large
            q = 1 / (1 + jnp.sqrt(abs(grads / 8)))
            parameters -= q * hess @ grads  # Newton-Raphson maximisation
            total = abs(grads[0]) + abs(grads[1])

        if epoch < 200:
            self.converged = True
            self.shape = parameters[0]
            self.scale = parameters[1]
            self.method = Method.MLECensored2p
            self.variance = [abs(hess[0][0]), abs(hess[1][1])]
            self.beta_eta_covar = [abs(hess[1][0])]

        else:
            # if more than 200 epochs, the fit is considered not to have converged
            self.converged = False
            self.shape = None
            self.scale = None
            self.method = Method.MLECensored2p
Code example #16
    def _calc_canon_mat(X: Arrays, Y: Arrays, λs) -> np.DeviceArray:
        # https://stackoverflow.com/questions/15670094/speed-up-solving-a-triangular-linear-system-with-numpy
        K = (inv(cholesky(CanonicalRidge._ridge_cov(X, λs[0]))) @ (X.T @ Y) /
             (X.T.shape[0] * Y.shape[1]) @ inv(
                 cholesky(CanonicalRidge._ridge_cov(Y, λs[1]))))
        return K
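
The linked StackOverflow thread is about replacing an explicit inverse of a Cholesky factor with a triangular solve; in JAX that alternative would look like the following (a sketch of the suggestion, not the project's code):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# solve_triangular(L, B, lower=True) computes inv(L) @ B without forming
# the inverse explicitly, which is cheaper and more stable.
S = jnp.array([[4.0, 2.0], [2.0, 3.0]])
B = jnp.eye(2)
L = jnp.linalg.cholesky(S)
print(jnp.allclose(solve_triangular(L, B, lower=True),
                   jnp.linalg.inv(L) @ B, atol=1e-5))  # True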
Code example #17
def beta_Sigma_hat_rvfl2(
    X=None,
    y=None,
    Sigma=None,
    sigma=0.05,
    fit_intercept=False,
    X_star=None,  # check when dim = 1
    return_cov=True,
    beta_hat_=None,  # for prediction only (X_star is not None)
    Sigma_hat_=None,  # for prediction only (X_star is not None)
    backend="cpu",
):

    if (X is not None) and (y is not None):

        if len(X.shape) == 1:
            X = X.reshape(-1, 1)

        n, p = X.shape

        if Sigma is None:
            if fit_intercept:
                Sigma = np.eye(p + 1)
            else:
                Sigma = np.eye(p)

        if X_star is not None:
            if len(X_star.shape) == 1:
                X_star = X_star.reshape(-1, 1)

        if fit_intercept:

            X = mo.cbind(np.ones(n), X)
            Cn = (
                la.inv(
                    mo.safe_sparse_dot(
                        a=Sigma,
                        b=mo.crossprod(x=X, backend=backend),
                        backend="cpu",
                    )
                    + (sigma ** 2) * np.eye(p + 1)
                )
                if backend == "cpu"
                else jla.inv(
                    mo.safe_sparse_dot(
                        a=Sigma,
                        b=mo.crossprod(x=X, backend=backend),
                        backend=backend,
                    )
                    + (sigma ** 2) * np.eye(p + 1)
                )
            )

            if X_star is not None:
                X_star = mo.cbind(
                    x=np.ones(X_star.shape[0]), y=X_star, backend=backend
                )
        else:
            # rename to invCn
            Cn = (
                la.inv(
                    mo.safe_sparse_dot(
                        a=Sigma,
                        b=mo.crossprod(x=X, backend=backend),
                        backend="cpu",
                    )
                    + (sigma ** 2) * np.eye(p)
                )
                if backend == "cpu"
                else jla.inv(
                    mo.safe_sparse_dot(
                        a=Sigma,
                        b=mo.crossprod(x=X, backend=backend),
                        backend=backend,
                    )
                    + (sigma ** 2) * np.eye(p)
                )
            )

        temp = mo.safe_sparse_dot(
            a=Cn,
            b=mo.tcrossprod(x=Sigma, y=X, backend=backend),
            backend=backend,
        )
        smoothing_matrix = mo.safe_sparse_dot(a=X, b=temp, backend=backend)
        y_hat = mo.safe_sparse_dot(a=smoothing_matrix, b=y, backend=backend)

        if return_cov:

            if X_star is None:

                return {
                    "beta_hat": mo.safe_sparse_dot(
                        a=temp, b=y, backend=backend
                    ),
                    "Sigma_hat": Sigma
                    - mo.safe_sparse_dot(
                        a=temp,
                        b=mo.safe_sparse_dot(a=X, b=Sigma, backend=backend),
                        backend=backend,
                    ),
                    "GCV": np.mean(
                        ((y - y_hat) / (1 - np.trace(smoothing_matrix) / n))
                        ** 2
                    ),
                }

            else:

                if beta_hat_ is None:
                    beta_hat_ = mo.safe_sparse_dot(a=temp, b=y, backend=backend)

                if Sigma_hat_ is None:
                    Sigma_hat_ = Sigma - mo.safe_sparse_dot(
                        a=temp,
                        b=mo.safe_sparse_dot(a=X, b=Sigma, backend=backend),
                        backend=backend,
                    )

                return {
                    "beta_hat": beta_hat_,
                    "Sigma_hat": Sigma_hat_,
                    "GCV": np.mean(
                        ((y - y_hat) / (1 - np.trace(smoothing_matrix) / n))
                        ** 2
                    ),
                    "preds": mo.safe_sparse_dot(
                        a=X_star, b=beta_hat_, backend=backend
                    ),
                    "preds_std": np.sqrt(
                        np.diag(
                            mo.safe_sparse_dot(
                                a=X_star,
                                b=mo.tcrossprod(
                                    x=Sigma_hat_, y=X_star, backend=backend
                                ),
                                backend=backend,
                            )
                            + (sigma ** 2) * np.eye(X_star.shape[0])
                        )
                    ),
                }

        else:  # return_cov == False

            if X_star is None:

                return {
                    "beta_hat": mo.safe_sparse_dot(
                        a=temp, b=y, backend=backend
                    ),
                    "GCV": np.mean(
                        ((y - y_hat) / (1 - np.trace(smoothing_matrix) / n))
                        ** 2
                    ),
                }

            else:

                if beta_hat_ is None:
                    beta_hat_ = mo.safe_sparse_dot(a=temp, b=y, backend=backend)

                return {
                    "beta_hat": beta_hat_,
                    "GCV": np.mean(
                        ((y - y_hat) / (1 - np.trace(smoothing_matrix) / n))
                        ** 2
                    ),
                    "preds": mo.safe_sparse_dot(
                        a=X_star, b=beta_hat_, backend=backend
                    ),
                }

    else:  # (X is None) | (y is None) # predict

        assert (beta_hat_ is not None) and (X_star is not None)

        if return_cov:

            assert Sigma_hat_ is not None

            return {
                "preds": mo.safe_sparse_dot(
                    a=X_star, b=beta_hat_, backend=backend
                ),
                "preds_std": np.sqrt(
                    np.diag(
                        mo.safe_sparse_dot(
                            a=X_star,
                            b=mo.tcrossprod(
                                Sigma_hat_, X_star, backend=backend
                            ),
                            backend=backend,
                        )
                        + (sigma ** 2) * np.eye(X_star.shape[0])
                    )
                ),
            }

        else:

            return {
                "preds": mo.safe_sparse_dot(
                    a=X_star, b=beta_hat_, backend=backend
                )
            }
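
Stripped of the backend dispatch and the `mo` helpers, the core computation is a Bayesian ridge posterior. A standalone NumPy sketch, assuming `mo.crossprod(X) == X.T @ X` and `mo.tcrossprod(S, X) == S @ X.T`:

import numpy as np

# Core formulas of beta_Sigma_hat_rvfl2 without intercept handling,
# under the stated assumptions about the mo helpers.
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.normal(size=20)
Sigma, sigma = np.eye(3), 0.05

Cn = np.linalg.inv(Sigma @ (X.T @ X) + sigma ** 2 * np.eye(3))
temp = Cn @ (Sigma @ X.T)
beta_hat = temp @ y                      # posterior mean
Sigma_hat = Sigma - temp @ (X @ Sigma)   # posterior covariance
print(beta_hat.shape, Sigma_hat.shape)   # (3,) (3, 3)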
Code example #18
    def filter(self, x_hist, jump_size, dt):
        """
        Compute the online version of the Kalman-Filter, i.e,
        the one-step-ahead prediction for the hidden state or the
        time update step
        
        Parameters
        ----------
        x_hist: array(timesteps, observation_size)
            
        Returns
        -------
        * array(timesteps, state_size):
            Filtered means mut
        * array(timesteps, state_size, state_size)
            Filtered covariances Sigmat
        * array(timesteps, state_size)
            Filtered conditional means mut|t-1
        * array(timesteps, state_size, state_size)
            Filtered conditional covariances Sigmat|t-1
        """
        I = jnp.eye(self.state_size)
        timesteps, *_ = x_hist.shape
        mu_hist = jnp.zeros((timesteps, self.state_size))
        Sigma_hist = jnp.zeros((timesteps, self.state_size, self.state_size))
        Sigma_cond_hist = jnp.zeros(
            (timesteps, self.state_size, self.state_size))
        mu_cond_hist = jnp.zeros((timesteps, self.state_size))

        # Initial configuration
        K1 = self.Sigma0 @ self.C.T @ inv(self.C @ self.Sigma0 @ self.C.T +
                                          self.R)
        mu1 = self.mu0 + K1 @ (x_hist[0] - self.C @ self.mu0)
        Sigma1 = (I - K1 @ self.C) @ self.Sigma0

        mu_hist = index_update(mu_hist, 0, mu1)
        Sigma_hist = index_update(Sigma_hist, 0, Sigma1)
        mu_cond_hist = index_update(mu_cond_hist, 0, self.mu0)
        Sigma_cond_hist = index_update(Sigma_cond_hist, 0, self.Sigma0)

        Sigman = Sigma1.copy()
        mun = mu1.copy()
        for n in range(1, timesteps):
            # Second-order Runge-Kutta (Heun) integration step
            for _ in range(jump_size):
                k1 = self.A @ mun
                k2 = self.A @ (mun + dt * k1)
                mun = mun + dt * (k1 + k2) / 2

                k1 = self.A @ Sigman @ self.A.T + self.Q
                k2 = self.A @ (Sigman + dt * k1) @ self.A.T + self.Q
                Sigman = Sigman + dt * (k1 + k2) / 2

            Sigman_cond = Sigman.copy()
            St = self.C @ Sigman_cond @ self.C.T + self.R
            Kn = Sigman_cond @ self.C.T @ inv(St)

            mu_update = mun.copy()
            x_update = self.C @ mun
            mun = mu_update + Kn @ (x_hist[n] - x_update)
            Sigman = (I - Kn @ self.C) @ Sigman_cond

            mu_hist = index_update(mu_hist, n, mun)
            Sigma_hist = index_update(Sigma_hist, n, Sigman)
            mu_cond_hist = index_update(mu_cond_hist, n, mu_update)
            Sigma_cond_hist = index_update(Sigma_cond_hist, n, Sigman_cond)

        return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist
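
The inner loop is a second-order Runge-Kutta (Heun) integration of the latent dynamics between observations; the mean update follows dμ/dt = Aμ. A minimal scalar sketch of the same update rule:

import jax.numpy as jnp

# Heun's method for dx/dt = a * x, matching the k1/k2 pattern above.
a, dt = -1.0, 0.01
x = 1.0
for _ in range(100):        # integrate up to t = 1
    k1 = a * x
    k2 = a * (x + dt * k1)
    x = x + dt * (k1 + k2) / 2
print(x, jnp.exp(a))        # both ~0.3679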
Code example #19
def inv(a):
    if isinstance(a, JaxArray):
        a = a.value
    return JaxArray(linalg.inv(a))
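
A minimal self-contained version of the wrapper pattern, assuming `JaxArray` is a thin container that exposes the raw array as `.value` (as the code implies):

import jax.numpy as jnp
from jax.numpy import linalg

class JaxArray:
    """Toy stand-in for the project's array wrapper."""
    def __init__(self, value):
        self.value = value

def inv(a):
    # Unwrap before calling jnp.linalg.inv, then re-wrap the result.
    if isinstance(a, JaxArray):
        a = a.value
    return JaxArray(linalg.inv(a))

print(inv(JaxArray(2.0 * jnp.eye(2))).value)  # 0.5 * I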