def compute_entropy(
    dims: TensorDims,
    V: torch.Tensor,
    Cov: torch.Tensor,
):
    """
    Compute sample-wise entropy of the smoothing posterior.
    We factorise the smoothing posterior such that the entropy sums over all time steps.
    For t == 0, we use the entropy of the marginal p(z_1 | y_{1:T}).

    If a particle and/or batch dimension is present, the order is TPBF,
    i.e. time, particle, batch, feature.
    """
    entropy = 0.0
    for t in range(0, dims.timesteps):
        if t == 0:  # marginal entropy (t==0)
            LVt = cholesky(V[t])
            entropy += 0.5 * (
                2.0 * torch.sum(torch.log(batch_diag(LVt)), dim=(-1, ))  # F
                + dims.state * (1.0 + LOG_2PI))
        else:  # Joint entropy (t, t-1) - marginal entropy (t-1)
            Vtm1inv = batch_cholesky_inverse(cholesky(V[t - 1]))
            Cov_cond = V[t] - matmul(Cov[t - 1].transpose(-1, -2),
                                     matmul(Vtm1inv, Cov[t - 1]))
            LCov_cond = cholesky(Cov_cond)
            entropy += 0.5 * (
                2.0 *
                torch.sum(torch.log(batch_diag(LCov_cond)), dim=(-1, ))  # F
                + dims.state * (1.0 + LOG_2PI))
    return entropy
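Each per-step term above is the closed-form entropy of a Gaussian, H(N(m, V)) = 0.5 * (d * (1 + log 2*pi) + log det V), with log det V = 2 * sum(log diag L) for the Cholesky factor L of V. A minimal check of that identity (not part of the library; all names below are local):

import math
import torch

LOG_2PI = math.log(2.0 * math.pi)
d = 3
# build a valid Cholesky factor: strictly lower triangle plus a positive diagonal
L = torch.randn(d, d).tril(-1) + torch.diag(torch.rand(d) + 0.5)
closed_form = 0.5 * (2.0 * torch.log(torch.diagonal(L)).sum() + d * (1.0 + LOG_2PI))
reference = torch.distributions.MultivariateNormal(torch.zeros(d), scale_tril=L).entropy()
assert torch.allclose(closed_form, reference)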
Example #2
def natural_to_dist_params(natural_params: dict, dist_cls):
    """
    Map the natural parameters to one of pytorch's accepted distribution parameters.
    This is not necessarily the so-called canonical response function.
    E.g. Bernoulli and Categorical are mapped to logits, not to probs,
    so that we do not unnecessarily switch between the two.
    """
    if dist_cls is Normal:
        eta, neg_precision = (
            natural_params["eta"],
            natural_params["neg_precision"],
        )
        return {
            "loc": -0.5 * eta / neg_precision,
            "scale": torch.sqrt(-0.5 / neg_precision),
        }
    elif dist_cls is MultivariateNormal:
        eta, neg_precision = (
            natural_params["eta"],
            natural_params["neg_precision"],
        )
        precision_matrix = -2 * neg_precision
        covariance_matrix = batch_cholesky_inverse(
            cholesky(precision_matrix))
        return {
            "loc": matvec(covariance_matrix, eta),
            "covariance_matrix": covariance_matrix,
        }
    elif dist_cls is Bernoulli:
        return natural_params
    elif dist_cls in (Categorical, OneHotCategorical):
        return natural_params
    else:
        raise NotImplementedError(f"Not implemented for {dist_cls}")
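For the univariate Normal branch, the assumed convention (based on the formulas above) is eta = mu / sigma^2 and neg_precision = -1 / (2 * sigma^2), which the mapping inverts exactly. A small round-trip sketch (mu, sigma and the dict literal are local illustrations, not library code):

import torch
from torch.distributions import Normal

mu, sigma = torch.tensor(1.5), torch.tensor(0.7)
natural = {"eta": mu / sigma**2, "neg_precision": -0.5 / sigma**2}
params = {
    "loc": -0.5 * natural["eta"] / natural["neg_precision"],
    "scale": torch.sqrt(-0.5 / natural["neg_precision"]),
}
assert torch.allclose(params["loc"], mu) and torch.allclose(params["scale"], sigma)
dist = Normal(**params)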
Example #3
    def emit(
        self,
        lats_t: LatentsSGLS,
        ctrl_t: ControlInputsSGLS,
    ) -> torch.distributions.MultivariateNormal:
        # Unfortunately we need to recompute gls_params here.
        # Trade-off: faster, lower-memory training vs. slower sampling/forecasting.
        gls_params_t = self.gls_base_parameters(
            switch=lats_t.variables.switch,
            controls=ctrl_t,
        )

        if lats_t.variables.m is not None:  # marginalise states
            mpy_t, Vpy_t = filter_forward_predictive_distribution(
                m=lats_t.variables.m,
                V=lats_t.variables.V,
                Q=gls_params_t.Q,
                C=gls_params_t.C,
                d=gls_params_t.d,
            )
            return MultivariateNormal(
                loc=mpy_t,
                scale_tril=cholesky(Vpy_t),
            )
        else:  # emit from state sample
            return torch.distributions.MultivariateNormal(
                loc=matvec(gls_params_t.C, lats_t.variables.x) +
                (gls_params_t.d if gls_params_t.d is not None else 0.0),
                scale_tril=gls_params_t.LQ,
            )
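emit either marginalises the Gaussian state (m, V) through the emission model or pushes a state sample x through it. A minimal unbatched sketch of the marginalisation that filter_forward_predictive_distribution is presumably performing (an assumption inferred from its arguments, not the library's code; predictive_moments is a hypothetical local helper):

import torch

def predictive_moments(m, V, C, Q, d=None):
    # With x ~ N(m, V) and y | x ~ N(C x + d, Q), the emission marginal is
    #   y ~ N(C m + d, C V C^T + Q).
    # (Unbatched; the library's matvec/matmul handle extra batch dims.)
    mpy = C @ m + (d if d is not None else 0.0)
    Vpy = C @ V @ C.transpose(-1, -2) + Q
    return mpy, Vpy

mpy, Vpy = predictive_moments(
    m=torch.zeros(2), V=torch.eye(2),
    C=torch.randn(3, 2), Q=0.1 * torch.eye(3), d=torch.ones(3))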
Example #4
    def forward(self, u, s, x=None, m=None, V=None):
        """
        m and V are from the Gaussian state x_{t-1}.
        Forward marginalises out the Gaussian if m and V given,
        or transforms a sample x if that is given instead.
        """
        # either a sample x or the moments (m, V) may be given, but not both
        assert (x is None) or (m is None and V is None)

        h = torch.cat((u, s), dim=-1) if u is not None else s
        switch_to_switch_dist = self.conditional_dist_transform(h)
        if self.F is not None:
            if m is not None:  # marginalise
                mp, Vp = filter_forward_prediction_step(
                    m=m,
                    V=V,
                    A=self.F,
                    R=switch_to_switch_dist.covariance_matrix,
                    b=switch_to_switch_dist.loc,
                )
            else:  # single sample fwd
                mp = matvec(self.F, x) + switch_to_switch_dist.loc
                Vp = switch_to_switch_dist.covariance_matrix
            return MultivariateNormal(loc=mp, scale_tril=cholesky(Vp))
        else:
            return switch_to_switch_dist
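The docstring above says that marginalising over x ~ N(m, V) and transforming a sample of x are two views of the same affine-Gaussian model. A minimal numerical sketch of that claim, using the usual affine-Gaussian rule that filter_forward_prediction_step presumably implements (all names below are local assumptions, not library code):

import torch

# marginalising x ~ N(m, V) through z = F x + loc + noise, noise ~ N(0, R),
# gives z ~ N(F m + loc, F V F^T + R); pushing samples of x through the same
# map recovers the mean in the large-sample limit.
torch.manual_seed(0)
d = 2
F = torch.randn(d, d)
m, V = torch.randn(d), 0.3 * torch.eye(d)
loc, R = torch.randn(d), 0.1 * torch.eye(d)

mp = F @ m + loc                  # marginal mean
Vp = F @ V @ F.T + R              # marginal covariance

x = torch.distributions.MultivariateNormal(m, V).sample((100_000,))
mc_mean = (x @ F.T + loc).mean(dim=0)
assert torch.allclose(mp, mc_mean, atol=0.05)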
Example #5
    def forecast(
        self,
        n_steps_forecast: int,
        initial_latent: Latents,
        future_controls: Optional[Union[Sequence[ControlInputs],
                                        ControlInputs]] = None,
        deterministic: bool = False,
    ) -> Sequence[Prediction]:

        # TODO: we only support sample forecasts atm
        #  The metrics such as CRPS in GluonTS are evaluated with samples only.
        #  Some models could retain states closed-form though.
        if initial_latent.variables.x is None:
            initial_latent.variables.x = MultivariateNormal(
                loc=initial_latent.variables.m,
                scale_tril=cholesky(initial_latent.variables.V),
            ).rsample()
            initial_latent.variables.m = None
            initial_latent.variables.V = None
            initial_latent.variables.Cov = None

        initial_latent, future_controls = self._prepare_forecast(
            initial_latent=initial_latent,
            controls=future_controls,
            deterministic=deterministic,
        )

        return self._sample_trajectory_from_initial(
            n_steps_forecast=n_steps_forecast,
            initial_latent=initial_latent,
            future_controls=future_controls,
            deterministic=deterministic,
        )
Example #6
    def _make_switch_transition_dist(
        self,
        lat_vars_tm1: GLSVariablesSGLS,
        ctrl_t: ControlInputsSGLS,
    ) -> torch.distributions.MultivariateNormal:
        """
        Compute p(s_t | s_{t-1}) = \int p(x_{t-1}, s_t | s_{t-1}) dx_{t-1}
        = \int p(s_t | s_{t-1}, x_{t-1}) N(x_{t-1} | s_{t-1}) dx_{t-1}.
        We use an additive structure, resulting in a convolution of PDFs, i.e.
        i) the conditional from the switch-to-switch transition and
        ii) the marginal from the state-switch-transition (state marginalised).
        The Gaussian is a 'stable distribution' -> The sum of Gaussian variables
        (not PDFs!) is Gaussian, for which locs and *covariances* are summed.
        (In case of weighted sum, means and *scales* are weighted.)
        """
        if len({lat_vars_tm1.x is None, lat_vars_tm1.m is None}) != 2:
            raise Exception(
                "Provide samples XOR dist params (-> marginalize).")
        marginalize_states = lat_vars_tm1.m is not None

        # i) switch-to-switch conditional
        switch_to_switch_dist = super()._make_switch_transition_dist(
            lat_vars_tm1=lat_vars_tm1,
            ctrl_t=ctrl_t,
        )

        # ii) state-to-switch
        rec_base_params = self.recurrent_base_parameters(
            switch=lat_vars_tm1.switch)
        if marginalize_states:
            m, V = filter_forward_prediction_step(
                m=lat_vars_tm1.m,
                V=lat_vars_tm1.V,
                A=rec_base_params.F,
                R=rec_base_params.S,
                b=None,
            )
        else:
            m = matvec(rec_base_params.F, lat_vars_tm1.x)
            V = rec_base_params.S
        state_to_switch_dist = MultivariateNormal(
            loc=m,
            scale_tril=cholesky(V),
        )

        # combine i) & ii): sum variables (=convolve PDFs).
        switch_model_dist = gaussian_linear_combination({
            state_to_switch_dist: 0.5,
            switch_to_switch_dist: 0.5,
        })
        return switch_model_dist
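Following the docstring, the 50/50 combination above should be the distribution of the weighted sum 0.5 * a + 0.5 * b of independent Gaussian variables a and b: means are weighted by 0.5 and covariances by 0.5 ** 2. A minimal sketch of that rule, under the assumption that gaussian_linear_combination (a library helper not shown here) does exactly this; combine_half_half is a hypothetical stand-in:

import torch
from torch.distributions import MultivariateNormal

def combine_half_half(dist_a: MultivariateNormal,
                      dist_b: MultivariateNormal) -> MultivariateNormal:
    # weighted sum of independent Gaussian variables: weight the means,
    # weight the covariances by the squared weights
    loc = 0.5 * dist_a.loc + 0.5 * dist_b.loc
    cov = 0.25 * dist_a.covariance_matrix + 0.25 * dist_b.covariance_matrix
    return MultivariateNormal(loc=loc, covariance_matrix=cov)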
Example #7
def smooth_backward_step(m_sm, V_sm, m_fw, V_fw, A, b, R):
    # filter one-step predictive variance
    P = (matmul(A, matmul(V_fw, A.transpose(-1, -2)))
         if A is not None else V_fw) + R
    Pinv = batch_cholesky_inverse(cholesky(P))
    # smoother gain; m and V share J in the conditioning operation (joint -> backward posterior)
    J = (matmul(V_fw, matmul(A.transpose(-1, -2), Pinv))
         if A is not None else matmul(V_fw, Pinv))

    m_sm_t = m_fw + matvec(
        J, m_sm - (matvec(A, m_fw) if A is not None else m_fw) - b)
    V_sm_t = symmetrize(V_fw +
                        matmul(J, matmul(V_sm - P, J.transpose(-1, -2))))
    Cov_sm_t = matmul(J, V_sm)
    return m_sm_t, V_sm_t, Cov_sm_t
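For reference, this is the standard Rauch-Tung-Striebel backward recursion, with (m_fw, V_fw) the filtered moments at time t, (m_sm, V_sm) the already-smoothed moments at t+1, and A = None treated as the identity:

$$
\begin{aligned}
P_t &= A V^{f}_t A^\top + R, &\qquad J_t &= V^{f}_t A^\top P_t^{-1},\\
m^{s}_t &= m^{f}_t + J_t\,\big(m^{s}_{t+1} - A m^{f}_t - b\big), &\qquad
V^{s}_t &= V^{f}_t + J_t\,\big(V^{s}_{t+1} - P_t\big)\,J_t^\top,
\end{aligned}
$$

and the returned cross term is $\mathrm{Cov}(x_t, x_{t+1} \mid y_{1:T}) = J_t V^{s}_{t+1}$.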
Example #8
def compute_entropy(
    dims: TensorDims,
    V: torch.Tensor,
    Cov: torch.Tensor,
):
    """ Compute entropy of Gaussian posterior from E-step (in Markovian SSM) """
    entropy = 0.0
    for t in range(0, dims.timesteps):
        if t == 0:  # marginal entropy (t==0)
            LVt = cholesky(V[t])
            entropy += 0.5 * 2.0 * torch.sum(
                torch.log(batch_diag(LVt)),
                dim=(-1, )) + 0.5 * dims.state * (1 + LOG_2PI)
        else:  # joint entropy (t, t-1) - marginal entropy (t-1)
            Vtm1inv = batch_cholesky_inverse(cholesky(V[t - 1]))
            Cov_cond = V[t] - matmul(Cov[t - 1].transpose(-1, -2),
                                     matmul(Vtm1inv, Cov[t - 1]))
            LCov_cond = cholesky(Cov_cond)

            entropy += 0.5 * 2.0 * torch.sum(
                torch.log(batch_diag(LCov_cond)),
                dim=(-1, )) + 0.5 * dims.state * (1.0 + LOG_2PI)
    return entropy
Example #9
    def _make_auxiliary_model_dist(
        self,
        mp: torch.Tensor,
        Vp: torch.Tensor,
        gls_params: Box,
    ):
        mpz, Vpz = filter_forward_predictive_distribution(
            m=mp,
            V=Vp,
            Q=gls_params.Q,
            C=gls_params.C,
            d=gls_params.d,
        )
        auxiliary_model_dist = MultivariateNormal(loc=mpz,
                                                  scale_tril=cholesky(Vpz))
        return auxiliary_model_dist
Example #10
def filter_forward_measurement_step(y,
                                    m,
                                    V,
                                    Q,
                                    C,
                                    d=None,
                                    return_loss_components=False):
    """ Single measurement/update step in forward filtering cycle
    (prediction --> measurement) """
    mpy, Vpy = filter_forward_predictive_distribution(m=m, V=V, Q=Q, C=C, d=d)
    LVpyinv = torch.inverse(cholesky(Vpy))
    # C @ V is the transposed cross-covariance; V @ C.T could equally be an
    # output of the predictive distribution.
    S = matmul(LVpyinv, matmul(C, V))
    dobs = y - mpy
    dobs_norm = matvec(LVpyinv, dobs)
    mt = m + matvec(S.transpose(-1, -2), dobs_norm)
    Vt = symmetrize(V - matmul(S.transpose(-1, -2), S))
    if return_loss_components:  # not pretty, but convenient for the loss terms
        return mt, Vt, dobs_norm, LVpyinv
    else:
        return mt, Vt
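For reference, this is the textbook Kalman measurement update written in whitened form. With innovation covariance and gain

$$
S_y = C V C^\top + Q = L L^\top, \qquad K = V C^\top S_y^{-1}, \qquad
m_t = m + K\,(y - \hat y), \qquad V_t = V - K S_y K^\top,
$$

the code's $S = L^{-1} C V$ satisfies $S^\top S = K S_y K^\top$ and $S^\top L^{-1}(y - \hat y) = K (y - \hat y)$, so $S_y$ is never inverted explicitly; only its Cholesky factor $L$ is.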
Example #11
    def marginal_step(
        self,
        lats_tm1: LatentsSGLS,
        ctrl_t: ControlInputsSGLS,
        deterministic: bool = False,
    ) -> Prediction:
        # TODO: duplication with sample_step. Requires refactoring.
        n_batch = lats_tm1.variables.m.shape[1]

        if lats_tm1.variables.switch is None:
            switch_model_dist_t = self._make_switch_prior_dist(
                n_particle=self.n_particle,
                n_batch=n_batch,
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )
        else:
            switch_model_dist_t = self._make_switch_transition_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )

        s_t = (switch_model_dist_t.mean
               if deterministic else switch_model_dist_t.sample())
        gls_params_t = self.gls_base_parameters(
            switch=s_t,
            controls=ctrl_t,
        )

        x_t = None
        m_t, V_t = filter_forward_prediction_step(
            m=lats_tm1.variables.m,
            V=lats_tm1.variables.V,
            R=gls_params_t.R,
            A=gls_params_t.A,
            b=gls_params_t.b,
        )
        mpy_t, Vpy_t = filter_forward_predictive_distribution(
            m=m_t,
            V=V_t,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        emission_dist_t = torch.distributions.MultivariateNormal(
            loc=mpy_t,
            scale_tril=cholesky(Vpy_t),
        )

        # NOTE: Cov should be computed here if the joint forecast distribution is needed.
        lats_t = LatentsSGLS(
            log_weights=lats_tm1.log_weights,  # does not change w/o evidence.
            gls_params=None,  # not used outside this function
            variables=GLSVariablesSGLS(
                x=x_t,
                m=m_t,
                V=V_t,
                Cov=None,
                switch=s_t,
            ),
        )
        return Prediction(latents=lats_t, emissions=emission_dist_t)
Example #12
    def filter_step(
        self,
        lats_tm1: Optional[LatentsSGLS],
        tar_t: torch.Tensor,
        ctrl_t: ControlInputsSGLS,
        tar_is_obs_t: Optional[torch.Tensor] = None,
    ):
        if tar_is_obs_t is not None:
            raise NotImplementedError("cannot handle missing data atm.")

        is_initial_step = lats_tm1 is None
        if is_initial_step:
            n_particle, n_batch = self.n_particle, len(tar_t)
            state_prior = self.state_prior_model(
                None,
                batch_shape_to_prepend=(n_particle, n_batch),
            )
            log_norm_weights = normalize_log_weights(
                log_weights=torch.zeros_like(state_prior.loc[..., 0]), )
            lats_tm1 = LatentsSGLS(
                log_weights=None,  # Not used. We use log_norm_weights instead.
                gls_params=None,  # First (previous) step no gls_params
                variables=GLSVariablesSGLS(
                    m=state_prior.loc,
                    V=state_prior.covariance_matrix,
                    Cov=None,
                    x=None,
                    switch=None,
                ),
            )
            switch_model_dist = self._make_switch_prior_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
                n_particle=n_particle,
                n_batch=n_batch,
            )
        else:
            log_norm_weights = normalize_log_weights(
                log_weights=lats_tm1.log_weights, )
            log_norm_weights, resampled_tensors = resample(
                n_particle=self.n_particle,
                log_norm_weights=log_norm_weights,
                tensors_to_resample={
                    key: val
                    for key, val in lats_tm1.variables.__dict__.items()
                    if key not in ("x", "Cov")  # below set to None explicitly
                },
                resampling_indices_fn=self.resampling_indices_fn,
                criterion_fn=self.resampling_criterion_fn,
            )
            # We don't use gls_params any more.
            # If needed, e.g. for evaluation, remember to re-sample all params!
            lats_tm1 = LatentsSGLS(
                log_weights=None,  # Not used. We use log_norm_weights instead.
                gls_params=None,  # not used outside this function. Read above.
                variables=GLSVariablesSGLS(
                    **resampled_tensors,
                    x=None,
                    Cov=None,
                ),
            )
            switch_model_dist = self._make_switch_transition_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )

        switch_proposal_dist = self._make_switch_proposal_dist(
            switch_model_dist=switch_model_dist,
            switch_encoder_dist=self._make_encoder_dists(
                tar_t=tar_t,
                ctrl_t=ctrl_t,
            ).switch,
        )
        s_t = switch_proposal_dist.rsample()
        gls_params_t = self.gls_base_parameters(
            switch=s_t,
            controls=ctrl_t,
        )

        mp, Vp = filter_forward_prediction_step(
            m=lats_tm1.variables.m,
            V=lats_tm1.variables.V,
            R=gls_params_t.R,
            A=gls_params_t.A,
            b=gls_params_t.b,
        )

        m_t, V_t = filter_forward_measurement_step(
            y=tar_t,
            m=mp,
            V=Vp,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        mpy_t, Vpy_t = filter_forward_predictive_distribution(
            m=mp,
            V=Vp,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        measurement_dist = MultivariateNormal(
            loc=mpy_t,
            scale_tril=cholesky(Vpy_t),
        )

        log_update = (measurement_dist.log_prob(tar_t) +
                      switch_model_dist.log_prob(s_t) -
                      switch_proposal_dist.log_prob(s_t))
        log_weights_t = log_norm_weights + log_update

        return LatentsSGLS(
            log_weights=log_weights_t,
            gls_params=None,  # not used outside this function
            variables=GLSVariablesSGLS(
                m=m_t,
                V=V_t,
                x=None,
                Cov=None,
                switch=s_t,
            ),
        )
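For reference, the weight update at the end of filter_step is the usual sequential importance-sampling ratio (my reading of the three log-prob terms): the emission likelihood is Rao-Blackwellised, i.e. the linear state is marginalised via (mp, Vp), and the sampled switch enters through the model/proposal ratio:

$$
\log w_t \;=\; \log \bar w_{t-1}
\;+\; \log p\big(y_t \mid s_{1:t}, y_{1:t-1}\big)
\;+\; \log p\big(s_t \mid s_{t-1}\big)
\;-\; \log q\big(s_t \mid y_t, s_{t-1}\big).
$$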
Example #13
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """ compute posterior by direct inversion of unrolled model """
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state, dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device,
                          dtype=dtype)

    # pre-compute biases (None/0 when the corresponding control input is absent)
    b = matvec(B, u_state) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(torch.cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(torch.cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(torch.cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * dims.state,
    ))
    Q_obs = kron(
        torch.eye(dims.timesteps, dtype=dtype, device=device),
        matmul(C.transpose(-1, -2), matmul(Qinv, C)),
    )

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat((dims.batch, ) + (1, ) *
                                                       (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat((dims.batch, ) +
                                                         (1, ) *
                                                         (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        idx = t * dims.state
        if b is not None:
            h_field[:, idx:idx + dims.state] += -matvec(
                RinvA.transpose(-1, -2), b[t])
            h_field[:, idx + dims.state:idx +
                    2 * dims.state] += matvec(Rinv, b[t])
        Q_field[:, idx:idx + dims.state, idx:idx + dims.state] += AtRinvA
        Q_field[:, idx:idx + dims.state, idx + dims.state:idx +
                2 * dims.state] += -RinvA.transpose(-1, -2)
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx:idx + dims.state] += -RinvA
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx + dims.state:idx + 2 * dims.state, ] += Rinv

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-style (column-major) reshape, hence reshape then transpose.
    m = m_all.reshape((dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        idx = t * dims.state
        V.append(V_all[:, idx:idx + dims.state, idx:idx + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, idx:idx + dims.state,
                             idx + dims.state:idx + 2 * dims.state, ])
        else:
            Cov.append(
                torch.zeros(
                    (dims.batch, dims.state, dims.state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov
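For reference, smooth_global assembles the joint posterior over the unrolled state sequence in information (natural-parameter) form: Q_field holds the prior and transition terms ($V_0^{-1}$, $A^\top R^{-1} A$, $-A^\top R^{-1}$, $R^{-1}$) on the block tridiagonals, Q_obs holds the likelihood terms $C^\top Q^{-1} C$, and the linear terms go into h_field and h_obs. The posterior is then

$$
p(x_{1:T} \mid y_{1:T}) = \mathcal{N}\big(\Lambda^{-1} h,\; \Lambda^{-1}\big),
\qquad \Lambda = Q_\text{field} + Q_\text{obs}, \qquad h = h_\text{field} + h_\text{obs},
$$

with the per-step marginals $V_t$ and lag-one covariances $\mathrm{Cov}_t$ read off as blocks of $\Lambda^{-1}$, exactly as in the loop over t above.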
Example #14
    def sample_step(
        self,
        lats_tm1: LatentsASGLS,
        ctrl_t: ControlInputsSGLS,
        deterministic: bool = False,
    ) -> Prediction:
        n_batch = lats_tm1.variables.x.shape[1]

        if lats_tm1.variables.switch is None:
            switch_model_dist_t = self._make_switch_prior_dist(
                n_particle=self.n_particle,
                n_batch=n_batch,
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )
        else:
            switch_model_dist_t = self._make_switch_transition_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )

        s_t = (switch_model_dist_t.mean
               if deterministic else switch_model_dist_t.sample())
        gls_params_t = self.gls_base_parameters(
            switch=s_t,
            controls=ctrl_t,
        )

        x_dist_t = torch.distributions.MultivariateNormal(
            loc=(matvec(gls_params_t.A, lats_tm1.variables.x)
                 if gls_params_t.A is not None else lats_tm1.variables.x) +
            (gls_params_t.b if gls_params_t.b is not None else 0.0),
            scale_tril=gls_params_t.LR,  # stable with scale and 0 variance.
        )

        x_t = x_dist_t.mean if deterministic else x_dist_t.sample()

        z_dist_t = torch.distributions.MultivariateNormal(
            loc=matvec(gls_params_t.C, x_t) +
            (gls_params_t.d if gls_params_t.d is not None else 0.0),
            scale_tril=cholesky(gls_params_t.Q),
        )
        z_t = z_dist_t.mean if deterministic else z_dist_t.sample()

        lats_t = LatentsASGLS(
            log_weights=lats_tm1.log_weights,  # does not change w/o evidence.
            gls_params=None,  # not used outside this function
            variables=GLSVariablesASGLS(
                x=x_t,
                m=None,
                V=None,
                Cov=None,
                switch=s_t,
                auxiliary=z_t,
            ),
        )
        emission_dist = self.emit(lats_t=lats_t, ctrl_t=ctrl_t)
        emissions_t = (emission_dist.mean
                       if deterministic else emission_dist.sample())

        return Prediction(latents=lats_t, emissions=emissions_t)
Example #15
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """ compute posterior by direct inversion of unrolled model """
    # This implementation works only if all mats have time and batch dimension.
    #  Otherwise does not broadcast correctly.
    assert A.ndim == 4
    assert B.ndim == 4
    assert C.ndim == 4
    assert D.ndim == 4
    assert LQinv_tril.ndim == 4
    assert LQinv_logdiag.ndim == 3
    assert LRinv_tril.ndim == 4
    assert LRinv_logdiag.ndim == 3

    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state, dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device,
                          dtype=dtype)

    b = matvec(B, u_state[:-1]) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(batch_cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(batch_cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(batch_cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * dims.state,
    ))

    CtQinvC = matmul(C.transpose(-1, -2), matmul(Qinv, C))
    assert len(CtQinvC) == dims.timesteps
    Q_obs = make_block_diagonal(mats=tuple(mat_t for mat_t in CtQinvC))

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat((dims.batch, ) + (1, ) *
                                                       (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat((dims.batch, ) +
                                                         (1, ) *
                                                         (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        id = t * dims.state
        if b is not None:
            h_field[:, id:id + dims.state] += -matvec(
                RinvA[t].transpose(-1, -2), b[t])
            h_field[:, id + dims.state:id +
                    2 * dims.state] += matvec(Rinv[t], b[t])
        Q_field[:, id:id + dims.state, id:id + dims.state] += AtRinvA[t]
        Q_field[:, id:id + dims.state, id + dims.state:id +
                2 * dims.state] += -RinvA[t].transpose(-1, -2)
        Q_field[:, id + dims.state:id + 2 * dims.state,
                id:id + dims.state] += -RinvA[t]
        Q_field[:, id + dims.state:id + 2 * dims.state,
                id + dims.state:id + 2 * dims.state, ] += Rinv[t]

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-style (column-major) reshape, hence reshape then transpose.
    m = m_all.reshape((dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        id = t * dims.state
        V.append(V_all[:, id:id + dims.state, id:id + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, id:id + dims.state,
                             id + dims.state:id + 2 * dims.state, ])
        else:
            Cov.append(
                torch.zeros(
                    (dims.batch, dims.state, dims.state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov