Example #1
    def sample_step(
        self,
        lats_tm1: LatentsSGLS,
        ctrl_t: ControlInputsSGLS,
        deterministic: bool = False,
    ) -> Prediction:
        n_batch = lats_tm1.variables.x.shape[1]

        if lats_tm1.variables.switch is None:
            switch_model_dist_t = self._make_switch_prior_dist(
                n_particle=self.n_particle,
                n_batch=n_batch,
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )
        else:
            switch_model_dist_t = self._make_switch_transition_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )

        s_t = (switch_model_dist_t.mean
               if deterministic else switch_model_dist_t.sample())
        gls_params_t = self.gls_base_parameters(
            switch=s_t,
            controls=ctrl_t,
        )

        x_dist_t = torch.distributions.MultivariateNormal(
            loc=(matvec(gls_params_t.A, lats_tm1.variables.x)
                 if gls_params_t.A is not None else lats_tm1.variables.x) +
            (gls_params_t.b if gls_params_t.b is not None else 0.0),
            scale_tril=gls_params_t.LR,  # scale_tril keeps this stable even with zero variance.
        )

        x_t = x_dist_t.mean if deterministic else x_dist_t.sample()
        (m_t, V_t) = (None, None)
        # emission_dist = self.emit(lats_t=lats_t, ctrl_t=ctrl_t)
        emission_dist_t = torch.distributions.MultivariateNormal(
            loc=matvec(gls_params_t.C, x_t) +
            (gls_params_t.d if gls_params_t.d is not None else 0.0),
            scale_tril=gls_params_t.LQ,
        )
        emissions_t = (emission_dist_t.mean
                       if deterministic else emission_dist_t.sample())

        # NOTE: Cov would need to be computed here for the joint forecast distribution.
        lats_t = LatentsSGLS(
            log_weights=lats_tm1.log_weights,  # does not change w/o evidence.
            gls_params=None,  # not used outside this function
            variables=GLSVariablesSGLS(
                x=x_t,
                m=m_t,
                V=V_t,
                Cov=None,
                switch=s_t,
            ),
        )
        return Prediction(latents=lats_t, emissions=emissions_t)
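# --- Hedged usage sketch (illustration, not part of the original code base) ---
# `sample_step` above produces one forecast step; a multi-step forecast feeds the
# sampled latents back in. `model`, `lats_0` and `ctrls` are hypothetical stand-ins
# for an SGLS instance, its last filtered latents, and a sequence of per-step controls.
def rollout(model, lats_0, ctrls, deterministic=False):
    lats, emissions = lats_0, []
    for ctrl_t in ctrls:
        pred = model.sample_step(
            lats_tm1=lats, ctrl_t=ctrl_t, deterministic=deterministic)
        lats = pred.latents  # feed sampled latents back as the new conditioning state
        emissions.append(pred.emissions)
    return emissions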
Example #2
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # pre-compute biases (zero tensors when controls are absent, so that the
    # per-time-step indexing below still works)
    b = (matvec(B, u_state) if u_state is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.state), device=device, dtype=dtype))
    d = (matvec(D, u_obs) if u_obs is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.target), device=device, dtype=dtype))

    m_fw = torch.zeros((dims.timesteps, dims.batch, dims.state),
                       device=device,
                       dtype=dtype)
    V_fw = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )
    for t in range(0, dims.timesteps):
        (mp, Vp) = (filter_forward_prediction_step(
            m=m_fw[t - 1],
            V=V_fw[t - 1],
            R=R,
            A=A,
            b=b[t - 1],
        ) if t > 0 else (m0, V0))
        m_fw[t], V_fw[t] = filter_forward_measurement_step(y=y[t],
                                                           m=mp,
                                                           V=Vp,
                                                           Q=Q,
                                                           C=C,
                                                           d=d[t])
    return m_fw, V_fw
Example #3
    def forward(self, switch, controls: ControlInputsSGLSISSM) -> GLSParams:
        weights = self.link_transformers(switch=switch)

        # biases
        B = (torch.einsum("...k,koi->...oi", weights.B, self.B)
             if self.B is not None else None)
        D = (torch.einsum("...k,koi->...oi", weights.D, self.D)
             if self.D is not None else None)
        b = self.compute_bias(
            s=switch,
            u=controls.state,
            bias_fn=self.b_fn,
            bias_matrix=B,
        )
        d = self.compute_bias(
            s=switch,
            u=controls.target,
            bias_fn=self.d_fn,
            bias_matrix=D,
        )

        # transition and emission from ISSM
        _, C, R_diag_projector = self.issm(
            seasonal_indicators=controls.seasonal_indicators,
        )
        A = None  # instead of identity, we use None to reduce computation

        # covariance matrices
        if self.make_cov_from_cholesky_avg:
            Q_diag, LQ_diag = self.var_from_average_scales(
                weights=weights.Q,
                Linv_logdiag=self.LQinv_logdiag_limiter(self.LQinv_logdiag),
            )
            R_diag, LR_diag = self.var_from_average_scales(
                weights=weights.R,
                Linv_logdiag=self.LRinv_logdiag_limiter(self.LRinv_logdiag),
            )
        else:
            Q_diag, LQ_diag = self.var_from_average_variances(
                weights=weights.Q,
                Linv_logdiag=self.LQinv_logdiag_limiter(self.LQinv_logdiag),
            )
            R_diag, LR_diag = self.var_from_average_variances(
                weights=weights.R,
                Linv_logdiag=self.LRinv_logdiag_limiter(self.LRinv_logdiag),
            )
        Q = batch_diag_matrix(Q_diag)
        R = batch_diag_matrix(matvec(R_diag_projector, R_diag))
        LQ = batch_diag_matrix(LQ_diag)
        LR = batch_diag_matrix(matvec(R_diag_projector, LR_diag))
        return GLSParams(A=A, b=b, C=C, d=d, Q=Q, R=R, LR=LR, LQ=LQ)
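# --- Hedged illustration (self-contained; not from the original code base) ---
# The einsum "...k,koi->...oi" above mixes K base matrices with switch-dependent
# weights. A small toy check of that contraction with plain torch:
import torch

_K, _n_out, _n_in = 3, 2, 4
_base = torch.randn(_K, _n_out, _n_in)               # like self.B: K base matrices
_w = torch.softmax(torch.randn(5, _K), dim=-1)       # like weights.B: per-sample weights
_mixed = torch.einsum("...k,koi->...oi", _w, _base)  # -> shape (5, _n_out, _n_in)
# each output matrix is the weighted sum of the K base matrices
assert torch.allclose(_mixed[0], sum(_w[0, k] * _base[k] for k in range(_K)))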
Example #4
def smooth_backward_step(m_sm, V_sm, m_fw, V_fw, A, b, R):
    # filter one-step predictive variance
    P = (matmul(A, matmul(V_fw, A.transpose(-1, -2)))
         if A is not None else V_fw) + R
    Pinv = batch_cholesky_inverse(cholesky(P))
    # m and V share the same gain J in the conditioning step (joint -> backward posterior).
    J = matmul(V_fw, matmul(A.transpose(-1, -2), Pinv))

    m_sm_t = m_fw + matvec(
        J, m_sm - (matvec(A, m_fw) if A is not None else m_fw) - b)
    V_sm_t = symmetrize(V_fw +
                        matmul(J, matmul(V_sm - P, J.transpose(-1, -2))))
    Cov_sm_t = matmul(J, V_sm)
    return m_sm_t, V_sm_t, Cov_sm_t
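# The step above is the standard Rauch-Tung-Striebel (RTS) backward recursion
# (with A=None treated as the identity):
#   P   = A V_fw A^T + R                       one-step predictive covariance
#   J   = V_fw A^T P^{-1}                      smoother gain
#   m_t = m_fw + J (m_{t+1}^{sm} - A m_fw - b)
#   V_t = V_fw + J (V_{t+1}^{sm} - P) J^T
#   Cov(x_t, x_{t+1} | y_{1:T}) = J V_{t+1}^{sm}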
Example #5
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # pre-compute biases (zero tensors when controls are absent, so that the
    # per-time-step indexing below still works)
    b = (matvec(B, u_state) if u_state is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.state), device=device, dtype=dtype))
    d = (matvec(D, u_obs) if u_obs is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.target), device=device, dtype=dtype))

    # Note: We cannot use (more readable) in-place operations due to backprop problems.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps
    loss = torch.zeros((dims.batch, ), device=device, dtype=dtype)
    for t in range(0, dims.timesteps):
        (mp, Vp) = (filter_forward_prediction_step(
            m=m_fw[t - 1],
            V=V_fw[t - 1],
            R=R,
            A=A,
            b=b[t - 1],
        ) if t > 0 else (m0, V0))
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t], m=mp, V=Vp, Q=Q, C=C, d=d[t], return_loss_components=True)
        loss += (
            0.5 * torch.sum(dobs_norm**2, dim=-1) -
            0.5 * 2 * torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1, )) +
            0.5 * dims.target * LOG_2PI)

    return loss
Example #6
    def forward(self, u, s, x=None, m=None, V=None):
        """
        m and V are from the Gaussian state x_{t-1}.
        Forward marginalises out the Gaussian if m and V given,
        or transforms a sample x if that is given instead.
        """
        # Either a sample x or the distribution params (m, V) may be given, not both.
        assert (x is None) or (m is None and V is None)

        h = torch.cat((u, s), dim=-1) if u is not None else s
        switch_to_switch_dist = self.conditional_dist_transform(h)
        if self.F is not None:
            if m is not None:  # marginalise
                mp, Vp = filter_forward_prediction_step(
                    m=m,
                    V=V,
                    A=self.F,
                    R=switch_to_switch_dist.covariance_matrix,
                    b=switch_to_switch_dist.loc,
                )
            else:  # single sample fwd
                mp = matvec(self.F, x) + switch_to_switch_dist.loc
                Vp = switch_to_switch_dist.covariance_matrix
            return MultivariateNormal(loc=mp, scale_tril=cholesky(Vp))
        else:
            return switch_to_switch_dist
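    # Above: with dist params (m, V) of x_{t-1}, the extra mean term F x_{t-1} is
    # marginalised via the prediction step (mean F m, covariance F V F^T + Cov);
    # with a sample x, only the mean is shifted by F x and the conditional
    # covariance is used directly.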
Example #7
 def natural_to_dist_params(natural_params: dict, dist_cls):
     """
     Map the natural parameters to one of pytorch's accepted distribution parameters.
     This is not necessarily the so called canonical response function.
     E.g. the Bernoulli and Categoricals are mapped to logits, not to probs,
     so that we do not unnecessarily switch around between the two.
     """
     if dist_cls is Normal:
         eta, neg_precision = (
             natural_params["eta"],
             natural_params["neg_precision"],
         )
         return {
             "loc": -0.5 * eta / neg_precision,
             "scale": torch.sqrt(-0.5 / neg_precision),
         }
     elif dist_cls is MultivariateNormal:
         eta, neg_precision = (
             natural_params["eta"],
             natural_params["neg_precision"],
         )
         precision_matrix = -2 * neg_precision
         covariance_matrix = batch_cholesky_inverse(
             cholesky(precision_matrix))
         return {
             "loc": matvec(covariance_matrix, eta),
             "covariance_matrix": covariance_matrix,
         }
     elif dist_cls is Bernoulli:
         return natural_params
     elif dist_cls in (Categorical, OneHotCategorical):
         return natural_params
     else:
         raise NotImplementedError(f"Not implemented for {dist_cls}")
Example #8
    def emit(
        self,
        lats_t: LatentsSGLS,
        ctrl_t: ControlInputsSGLS,
    ) -> torch.distributions.MultivariateNormal:
        # Unfortunately, gls_params must be recomputed here.
        # Trade-off: faster, lower-memory training vs. slower sampling / forecasting.
        gls_params_t = self.gls_base_parameters(
            switch=lats_t.variables.switch,
            controls=ctrl_t,
        )

        if lats_t.variables.m is not None:  # marginalise states
            mpy_t, Vpy_t = filter_forward_predictive_distribution(
                m=lats_t.variables.m,
                V=lats_t.variables.V,
                Q=gls_params_t.Q,
                C=gls_params_t.C,
                d=gls_params_t.d,
            )
            return MultivariateNormal(
                loc=mpy_t,
                scale_tril=cholesky(Vpy_t),
            )
        else:  # emit from state sample
            return torch.distributions.MultivariateNormal(
                loc=matvec(gls_params_t.C, lats_t.variables.x) +
                (gls_params_t.d if gls_params_t.d is not None else 0.0),
                scale_tril=gls_params_t.LQ,
            )
Example #9
def filter_forward_predictive_distribution(m, V, Q, C, d=None):
    """ The predictive distirbution p(yt | y_{1:t-1}),
    obtained by marginalising the state prediction distribution in the measurement model. """
    Vpy = symmetrize(matmul(C, matmul(V, C.transpose(-1, -2))) + Q)
    mpy = matvec(C, m)
    if d is not None:
        mpy += d
    return mpy, Vpy
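# Above: the filtered one-step predictive is obtained by pushing the state
# prediction N(m, V) through the linear measurement model y = C x + d + v,
# v ~ N(0, Q), giving p(y_t | y_{1:t-1}) = N(C m + d, C V C^T + Q).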
Example #10
 def compute_bias(self, s, u=None, bias_fn=None, bias_matrix=None):
     if bias_fn is None and bias_matrix is None:
         b = None
     else:
         b_nonlin = bias_fn(s) if bias_fn is not None else 0.0
         b_lin = matvec(bias_matrix, u) if bias_matrix is not None else 0.0
         b = b_lin + b_nonlin
     return b
Example #11
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    # TODO: assumes B is not None and D is not None (u_state and u_obs may be None).
    #  Better to work with bias vectors b, d; the matmul B @ u should be done outside.
    m_fw = create_zeros_state_vec(dims=dims, device=A.device, dtype=A.dtype)
    V_fw = create_zeros_state_mat(dims=dims, device=A.device, dtype=A.dtype)
    for t in range(0, dims.timesteps):
        if t == 0:
            V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
            mp, Vp = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)
        else:
            R_tm1 = cov_from_invcholesky_param(LRinv_tril[t - 1],
                                               LRinv_logdiag[t - 1])
            b_tm1 = (matvec(B[t - 1], u_state[t - 1])
                     if u_state is not None else 0)
            mp, Vp = filter_forward_prediction_step(m=m_fw[t - 1],
                                                    V=V_fw[t - 1],
                                                    R=R_tm1,
                                                    A=A[t - 1],
                                                    b=b_tm1)
        Q_t = cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t])
        d_t = matvec(D[t], u_obs[t]) if u_obs is not None else 0
        m_fw[t], V_fw[t] = filter_forward_measurement_step(y=y[t],
                                                           m=mp,
                                                           V=Vp,
                                                           Q=Q_t,
                                                           C=C[t],
                                                           d=d_t)
    return torch.stack(m_fw, dim=0), torch.stack(V_fw, dim=0)
Example #12
def sample(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    # generate noise
    wz = torch.randn(dims.timesteps, dims.batch, dims.state)
    wy = torch.randn(dims.timesteps, dims.batch, dims.target)

    # Initial step.
    # Note: We cannot use in-place operations here because we must backprop through y.
    LV0 = torch.inverse(
        torch.tril(LV0inv_tril, -1) + torch.diag(torch.exp(LV0inv_logdiag)))
    LQ_0 = torch.inverse(
        torch.tril(LQinv_tril[0], -1) +
        torch.diag(torch.exp(LQinv_logdiag[0])))
    d_0 = matvec(D[0], u_obs[0]) if u_obs is not None else 0
    x = [m0 + matvec(LV0, wz[0])] + [None] * (dims.timesteps - 1)
    y = [matvec(C[0], x[0]) + d_0 + matvec(LQ_0, wy[0])
         ] + [None] * (dims.timesteps - 1)
    for t in range(1, dims.timesteps):
        LR_tm1 = torch.inverse(
            torch.tril(LRinv_tril[t - 1], -1) +
            torch.diag(torch.exp(LRinv_logdiag[t - 1])))
        LQ_t = torch.inverse(
            torch.tril(LQinv_tril[t], -1) +
            torch.diag(torch.exp(LQinv_logdiag[t])))
        b_tm1 = matvec(B[t - 1], u_state[t - 1]) if u_state is not None else 0
        d_t = matvec(D[t], u_obs[t]) if u_obs is not None else 0

        x[t] = matvec(A[t - 1], x[t - 1]) + b_tm1 + matvec(LR_tm1, wz[t])
        y[t] = matvec(C[t], x[t]) + d_t + matvec(LQ_t, wy[t])

    return torch.stack(x, dim=0), torch.stack(y, dim=0)
Example #13
def filter_forward_prediction_step(m, V, R, A, b=None):
    """ Single prediction step in forward filtering cycle
    (prediction --> measurement) """
    mp = matvec(A, m) if A is not None else m
    if b is not None:
        mp = mp + b  # avoid in-place 'mp += b'; it breaks autograd
    Vp = symmetrize((
        matmul(A, matmul(V, A.transpose(-1, -2))) if A is not None else V) + R)
    return mp, Vp
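# The step above implements the Kalman prediction
#   m_p = A m + b,    V_p = A V A^T + R
# with A=None treated as the identity (skips the matmul) and b=None as a zero bias.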
Example #14
def smooth_forward_backward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    m_sm = create_zeros_state_vec(dims=dims, device=A.device, dtype=A.dtype)
    V_sm = create_zeros_state_mat(dims=dims, device=A.device, dtype=A.dtype)
    Cov_sm = create_zeros_state_mat(dims=dims, device=A.device, dtype=A.dtype)
    m_fw, V_fw = filter_forward(
        dims=dims,
        A=A,
        B=B,
        C=C,
        D=D,
        LQinv_tril=LQinv_tril,
        LQinv_logdiag=LQinv_logdiag,
        LRinv_tril=LRinv_tril,
        LRinv_logdiag=LRinv_logdiag,
        LV0inv_tril=LV0inv_tril,
        LV0inv_logdiag=LV0inv_logdiag,
        m0=m0,
        y=y,
        u_state=u_state,
        u_obs=u_obs,
    )
    m_sm[-1], V_sm[-1] = m_fw[-1], V_fw[-1]
    for t in reversed(range(0, dims.timesteps - 1)):
        Rt = cov_from_invcholesky_param(LRinv_tril[t], LRinv_logdiag[t])
        bt = matvec(B[t], u_state[t]) if u_state is not None else 0
        m_sm[t], V_sm[t], Cov_sm[t] = smooth_backward_step(
            m_sm=m_sm[t + 1],
            V_sm=V_sm[t + 1],
            m_fw=m_fw[t],
            V_fw=V_fw[t],
            A=A[t],
            R=Rt,
            b=bt,
        )
    return (
        torch.stack(m_sm, dim=0),
        torch.stack(V_sm, dim=0),
        torch.stack(Cov_sm, dim=0),
    )
Example #15
def filter_forward_measurement_step(y,
                                    m,
                                    V,
                                    Q,
                                    C,
                                    d=None,
                                    return_loss_components=False):
    """ Single measurement/update step in forward filtering cycle
    (prediction --> measurement) """
    mpy, Vpy = filter_forward_predictive_distribution(m=m, V=V, Q=Q, C=C, d=d)
    LVpyinv = torch.inverse(cholesky(Vpy))
    S = matmul(LVpyinv, matmul(C, V))
    # C @ V is the transposed state-observation cross-covariance (V @ C^T); the
    # cross-covariance could alternatively be returned by the predictive distribution.
    dobs = y - mpy
    dobs_norm = matvec(LVpyinv, dobs)
    mt = m + matvec(S.transpose(-1, -2), dobs_norm)
    Vt = symmetrize(V - matmul(S.transpose(-1, -2), S))
    if return_loss_components:  # not pretty, but convenient for loss_forward
        return mt, Vt, dobs_norm, LVpyinv
    else:
        return mt, Vt
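# The update above is the Kalman measurement step written with a whitened
# innovation (L = chol(Vpy), so Vpy^{-1} = L^{-T} L^{-1}):
#   Vpy = C V C^T + Q,   S = L^{-1} C V,   e = L^{-1} (y - mpy)
#   m_t = m + S^T e      (equals m + K (y - mpy) with gain K = V C^T Vpy^{-1})
#   V_t = V - S^T S      (equals V - K Vpy K^T)
# The extra outputs e (= dobs_norm) and L^{-1} (= LVpyinv) let loss_forward compute
# -log N(y | mpy, Vpy) = 0.5 ||e||^2 - sum(log diag L^{-1}) + 0.5 * dim_y * log(2*pi).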
Example #16
    def _make_switch_transition_dist(
        self,
        lat_vars_tm1: GLSVariablesSGLS,
        ctrl_t: ControlInputsSGLS,
    ) -> torch.distributions.MultivariateNormal:
        """
        Compute p(s_t | s_{t-1}) = \int p(x_{t-1}, s_t | s_{t-1}) dx_{t-1}
        = \int p(s_t | s_{t-1}, x_{t-1}) N(x_{t-1} | s_{t-1}) dx_{t-1}.
        We use an additive structure, resulting in a convolution of PDFs, i.e.
        i) the conditional from the switch-to-switch transition and
        ii) the marginal from the state-switch-transition (state marginalised).
        The Gaussian is a 'stable distribution' -> The sum of Gaussian variables
        (not PDFs!) is Gaussian, for which locs and *covariances* are summed.
        (In case of weighted sum, means and *scales* are weighted.)
        """
        if (lat_vars_tm1.x is None) == (lat_vars_tm1.m is None):
            raise Exception(
                "Provide samples XOR dist params (-> marginalize).")
        marginalize_states = lat_vars_tm1.m is not None

        # i) switch-to-switch conditional
        switch_to_switch_dist = super()._make_switch_transition_dist(
            lat_vars_tm1=lat_vars_tm1,
            ctrl_t=ctrl_t,
        )

        # ii) state-to-switch
        rec_base_params = self.recurrent_base_parameters(
            switch=lat_vars_tm1.switch)
        if marginalize_states:
            m, V = filter_forward_prediction_step(
                m=lat_vars_tm1.m,
                V=lat_vars_tm1.V,
                A=rec_base_params.F,
                R=rec_base_params.S,
                b=None,
            )
        else:
            m = matvec(rec_base_params.F, lat_vars_tm1.x)
            V = rec_base_params.S
        state_to_switch_dist = MultivariateNormal(
            loc=m,
            scale_tril=cholesky(V),
        )

        # combine i) & ii): sum variables (=convolve PDFs).
        switch_model_dist = gaussian_linear_combination({
            state_to_switch_dist:
            0.5,
            switch_to_switch_dist:
            0.5
        })
        return switch_model_dist
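# --- Hedged sketch (the `gaussian_linear_combination` helper is not shown in this
# excerpt; this minimal version follows the docstring above: a weighted sum of
# independent Gaussian *variables*, so locs are weighted and covariances are
# weighted by the squared weights; the repository's helper may differ in detail). ---
import torch
from torch.distributions import MultivariateNormal


def gaussian_linear_combination_sketch(dists_to_weights):
    """Distribution of sum_i w_i X_i for independent Gaussians X_i ~ N(mu_i, Sigma_i)."""
    loc = sum(w * d.loc for d, w in dists_to_weights.items())
    cov = sum(w ** 2 * d.covariance_matrix for d, w in dists_to_weights.items())
    return MultivariateNormal(loc=loc, scale_tril=torch.linalg.cholesky(cov))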
Example #17
def sample(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype

    # generate noise
    wz = torch.randn(dims.timesteps, dims.batch, dims.state)
    wy = torch.randn(dims.timesteps, dims.batch, dims.target)

    # pre-compute cholesky matrices
    LR = torch.inverse(
        torch.tril(LRinv_tril, -1) + torch.diag(torch.exp(LRinv_logdiag)))
    LQ = torch.inverse(
        torch.tril(LQinv_tril, -1) + torch.diag(torch.exp(LQinv_logdiag)))
    LV0 = torch.inverse(
        torch.tril(LV0inv_tril, -1) + torch.diag(torch.exp(LV0inv_logdiag)))

    # pre-compute biases (zero tensors when controls are absent, so that the
    # per-time-step indexing below still works)
    b = (matvec(B, u_state) if u_state is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.state), device=device, dtype=dtype))
    d = (matvec(D, u_obs) if u_obs is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.target), device=device, dtype=dtype))

    # Initial step.
    # Note: We cannot use in-place operations here because we must backprop through y.
    x = [m0 + matvec(LV0, wz[0])] + [None] * (dims.timesteps - 1)
    y = [matvec(C, x[0]) + d[0] + matvec(LQ, wy[0])
         ] + [None] * (dims.timesteps - 1)
    for t in range(1, dims.timesteps):
        x[t] = matvec(A, x[t - 1]) + b[t - 1] + matvec(LR, wz[t])
        y[t] = matvec(C, x[t]) + d[t] + matvec(LQ, wy[t])
    x = torch.stack(x, dim=0)
    y = torch.stack(y, dim=0)
    return x, y
Example #18
 def get_natural_params(dists: Tuple[Distribution]):
     dist_cls = type(dists[0])
     if dist_cls is Normal:
         return tuple(
             {
                 "eta": dist.loc / dist.scale.pow(2),
                 "neg_precision": -0.5 * dist.scale.pow(2).reciprocal(),
             } for dist in dists)
     elif dist_cls is MultivariateNormal:
         return tuple({
             "eta": matvec(dist.precision_matrix, dist.loc),
             "neg_precision": -0.5 * dist.precision_matrix,
         } for dist in dists)
     elif dist_cls is Bernoulli:
         return tuple({
             "logits": dist.logits,
         } for dist in dists)
     elif dist_cls in (Categorical, OneHotCategorical):
         return tuple({
             "logits": dist.logits,
         } for dist in dists)
     else:
         raise NotImplementedError(f"Not implemented for {dist_cls}")
Example #19
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """
    Computes the sample-wise (e.g. (particle, batch)) forward / filter loss.
    If particle and batch dims are used,
    the ordering is TPBF (time, particle, batch, feature).

    Note: it would be better numerically (amount of computation and precision)
    to sum/avg over these dimensions in all loss terms.
    However, we must compute it sample-wise for many algorithms,
    i.e. Rao-Blackwellised Particle Filters.
    """
    device, dtype = A.device, A.dtype

    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # Note: We cannot use (more readable) in-place operations due to backprop problems.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps
    dim_particle = ((dims.particle, ) if dims.particle is not None
                    and dims.particle != 0 else tuple())
    loss = torch.zeros(dim_particle + (dims.batch, ),
                       device=device,
                       dtype=dtype)
    for t in range(0, dims.timesteps):
        (mp, Vp) = (filter_forward_prediction_step(
            m=m_fw[t - 1],
            V=V_fw[t - 1],
            R=cov_from_invcholesky_param(LRinv_tril[t - 1],
                                         LRinv_logdiag[t - 1]),
            A=A[t - 1],
            b=matvec(B[t - 1], u_state[t - 1]) if u_state is not None else 0,
        ) if t > 0 else (m0, V0))
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t]),
            C=C[t],
            d=matvec(D[t], u_obs[t]) if u_obs is not None else 0,
            return_loss_components=True,
        )
        loss += 0.5 * (torch.sum(dobs_norm**2, dim=-1) - 2 *
                       torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1, )) +
                       dims.target * LOG_2PI)
    return loss
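# As the docstring above notes, every term is kept sample-wise: the returned loss has
# shape (particle, batch) (or (batch,) without a particle dim), with sums taken only
# over the feature dimension per step and accumulated over time, so that e.g.
# Rao-Blackwellised particle filters can reweight individual samples.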
Example #20
    def _loss_em_mc_efficient(
        self,
        past_targets: [Sequence[torch.Tensor], torch.Tensor],
        past_controls: Optional[Union[Sequence[ControlInputs],
                                      ControlInputs]] = None,
    ) -> torch.Tensor:
        """
        Monte Carlo loss as computed in KVAE paper.
        Can be computed more efficiently if no missing data (no imputation),
        by batching some things along time-axis.
        """
        past_controls = self._expand_particle_dim(past_controls)
        n_batch = len(past_targets[0])

        # A) SSM related distributions:
        # A1) smoothing.
        latents_smoothed = self._smooth_efficient(
            past_targets=past_targets,
            past_controls=past_controls,
            return_time_tensor=True,
        )

        state_smoothed_dist = MultivariateNormal(
            loc=latents_smoothed.variables.m,
            covariance_matrix=latents_smoothed.variables.V,
        )
        x = state_smoothed_dist.rsample()
        gls_params = latents_smoothed.gls_params

        # A2) prior && posterior transition distribution.
        prior_dist = self.state_prior_model(
            None, batch_shape_to_prepend=(self.n_particle, n_batch))

        # A, B, R are already 0:T-1.
        transition_dist = MultivariateNormal(
            loc=matvec(gls_params.A[:-1], x[:-1]) +
            (matvec(gls_params.B[:-1], past_controls.state[:-1])
             if gls_params.B is not None else 0.0),
            covariance_matrix=gls_params.R[:-1],
        )
        # A3) posterior predictive (auxiliary) distribution.
        auxiliary_predictive_dist = MultivariateNormal(
            loc=matvec(gls_params.C, x) +
            (matvec(gls_params.D, past_controls.target)
             if gls_params.D is not None else 0.0),
            covariance_matrix=gls_params.Q,
        )

        # A4) SSM related losses
        l_prior = (-prior_dist.log_prob(x[0:1]).sum(dim=(0, 1)) /
                   self.n_particle)  # time and particle dim
        l_transition = (-transition_dist.log_prob(x[1:]).sum(dim=(0, 1)) /
                        self.n_particle)  # time and particle dim
        l_auxiliary = (-auxiliary_predictive_dist.log_prob(
            latents_smoothed.variables.auxiliary).sum(dim=(0, 1)) /
                       self.n_particle)  # time and particle dim
        l_entropy = (
            state_smoothed_dist.log_prob(x).sum(dim=(0, 1))  # negative entropy
            / self.n_particle)  # time and particle dim

        # B) VAE related distributions
        # B1) inv_measurement_dist already obtained from smoothing (we don't want to re-compute it)
        # B2) measurement (decoder) distribution
        # transpose TPBF -> PTBF to broadcast log_prob of y (TBF) correctly
        z_particle_first = latents_smoothed.variables.auxiliary.transpose(0, 1)
        measurement_dist = self.measurement_model(z_particle_first)
        # B3) VAE related losses
        l_measurement = (
            -measurement_dist.log_prob(past_targets).sum(dim=(0, 1)) /
            self.n_particle)  # time and particle dim

        auxiliary_variational_dist = MultivariateNormal(
            loc=latents_smoothed.variables.m_auxiliary_variational,
            covariance_matrix=latents_smoothed.variables.
            V_auxiliary_variational,
        )
        l_inv_measurement = (
            auxiliary_variational_dist.log_prob(z_particle_first).sum(
                dim=(0, 1)) / self.n_particle)  # time and particle dim

        assert all(t.shape == l_prior.shape for t in (
            l_prior,
            l_transition,
            l_auxiliary,
            l_measurement,
            l_inv_measurement,
        ))

        l_total = (self.reconstruction_weight * l_measurement +
                   l_inv_measurement + l_auxiliary + l_prior + l_transition +
                   l_entropy)
        return l_total
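# The terms above assemble the (negative) KVAE evidence lower bound:
#   l_measurement          = -E_q[log p(y | z)]    (decoder / reconstruction)
#   l_auxiliary            = -E_q[log p(z | x)]    (LGSSM emission of the auxiliary z)
#   l_prior + l_transition = -E_q[log p(x)]        (LGSSM prior and transitions)
#   l_inv_measurement      = +E_q[log q(z | y)]    (encoder / inverse measurement)
#   l_entropy              = +E_q[log q(x | z)]    (negative entropy of the smoother)
# so minimising l_total maximises the ELBO (up to the reconstruction weight).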
Example #21
    def _loss_em_mc(
        self,
        past_targets: [Sequence[torch.Tensor], torch.Tensor],
        past_controls: Optional[Union[Sequence[ControlInputs],
                                      ControlInputs]] = None,
        past_targets_is_observed: Optional[Union[Sequence[torch.Tensor],
                                                 torch.Tensor]] = None,
    ) -> torch.Tensor:
        """" Monte Carlo loss as computed in KVAE paper """
        n_batch = len(past_targets[0])

        past_controls = self._expand_particle_dim(past_controls)

        # A) SSM related distributions:
        # A1) smoothing.
        latents_smoothed = self.smooth(
            past_targets=past_targets,
            past_controls=past_controls,
            past_targets_is_observed=past_targets_is_observed,
        )
        m = torch.stack([l.variables.m for l in latents_smoothed])
        V = torch.stack([l.variables.V for l in latents_smoothed])
        z = torch.stack([l.variables.auxiliary for l in latents_smoothed])
        state_smoothed_dist = MultivariateNormal(loc=m, covariance_matrix=V)
        x = state_smoothed_dist.rsample()

        A = torch.stack([l.gls_params.A for l in latents_smoothed])
        C = torch.stack([l.gls_params.C for l in latents_smoothed])
        LR = torch.stack([l.gls_params.LR for l in latents_smoothed])
        LQ = torch.stack([l.gls_params.LQ for l in latents_smoothed])
        if latents_smoothed[0].gls_params.B is not None:
            B = torch.stack([l.gls_params.B for l in latents_smoothed])
        else:
            B = None
        if latents_smoothed[0].gls_params.D is not None:
            D = torch.stack([l.gls_params.D for l in latents_smoothed])
        else:
            D = None

        # A2) prior && posterior transition distribution.
        prior_dist = self.state_prior_model(
            None, batch_shape_to_prepend=(self.n_particle, n_batch))

        # A, B, R are already 0:T-1.
        transition_dist = MultivariateNormal(
            loc=matvec(A[:-1], x[:-1]) + (matvec(
                B[:-1], past_controls.state[:-1]) if B is not None else 0.0),
            scale_tril=LR[:-1],
        )
        # A3) posterior predictive (auxiliary) distribution.
        auxiliary_predictive_dist = MultivariateNormal(
            loc=matvec(C, x) +
            (matvec(D, past_controls.target) if D is not None else 0.0),
            scale_tril=LQ,
        )

        # A4) SSM related losses
        # mean over particle dim, sum over time (after masking), leave batch dim
        l_prior = -prior_dist.log_prob(x[0:1]).mean(dim=1).sum(dim=0)
        l_transition = -transition_dist.log_prob(x[1:]).mean(dim=1).sum(dim=0)
        l_entropy = state_smoothed_dist.log_prob(x).mean(dim=1).sum(dim=0)

        _l_aux_timewise = -auxiliary_predictive_dist.log_prob(z).mean(dim=1)
        if past_targets_is_observed is not None:
            _l_aux_timewise = _l_aux_timewise * past_targets_is_observed
        l_auxiliary = _l_aux_timewise.sum(dim=0)

        # B) VAE related distributions
        # B1) inv_measurement_dist already obtained from smoothing (we don't want to re-compute it)
        # B2) measurement (decoder) distribution
        # transpose TPBF -> PTBF to broadcast log_prob of y (TBF) correctly
        z_particle_first = z.transpose(0, 1)
        measurement_dist = self.measurement_model(z_particle_first)
        # B3) VAE related losses
        # We use z_particle_first for correct broadcasting -> dim=0 is particle.
        _l_meas_timewise = -measurement_dist.log_prob(past_targets).mean(dim=0)
        if past_targets_is_observed is not None:
            _l_meas_timewise = _l_meas_timewise * past_targets_is_observed
        l_measurement = _l_meas_timewise.sum(dim=0)

        auxiliary_variational_dist = MultivariateNormal(
            loc=torch.stack([
                l.variables.m_auxiliary_variational for l in latents_smoothed
            ]),
            covariance_matrix=torch.stack([
                l.variables.V_auxiliary_variational for l in latents_smoothed
            ]),
        )
        _l_variational_timewise = auxiliary_variational_dist.log_prob(
            z_particle_first).mean(dim=0)  # again dim=0 is particle dim here.
        if past_targets_is_observed is not None:
            _l_variational_timewise = (_l_variational_timewise *
                                       past_targets_is_observed)
        l_inv_measurement = _l_variational_timewise.sum(dim=0)

        assert all(t.shape == l_prior.shape for t in (
            l_prior,
            l_transition,
            l_auxiliary,
            l_measurement,
            l_inv_measurement,
        ))

        l_total = (self.reconstruction_weight * l_measurement +
                   l_inv_measurement + l_auxiliary + l_prior + l_transition +
                   l_entropy)
        return l_total
Example #22
def smooth_forward_backward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype
    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)

    # pre-compute biases (zero tensor when controls are absent, so that the
    # per-time-step indexing below still works)
    b = (matvec(B, u_state) if u_state is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.state), device=device, dtype=dtype))
    m_sm = torch.zeros((dims.timesteps, dims.batch, dims.state),
                       device=device,
                       dtype=dtype)
    V_sm = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )
    Cov_sm = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )

    m_fw, V_fw = filter_forward(
        dims=dims,
        A=A,
        B=B,
        C=C,
        D=D,
        LQinv_tril=LQinv_tril,
        LQinv_logdiag=LQinv_logdiag,
        LRinv_tril=LRinv_tril,
        LRinv_logdiag=LRinv_logdiag,
        LV0inv_tril=LV0inv_tril,
        LV0inv_logdiag=LV0inv_logdiag,
        m0=m0,
        y=y,
        u_state=u_state,
        u_obs=u_obs,
    )
    m_sm[-1], V_sm[-1] = m_fw[-1], V_fw[-1]
    for t in reversed(range(0, dims.timesteps - 1)):
        m_sm[t], V_sm[t], Cov_sm[t] = smooth_backward_step(
            m_sm=m_sm[t + 1],
            V_sm=V_sm[t + 1],
            m_fw=m_fw[t],
            V_fw=V_fw[t],
            A=A,
            R=R,
            b=b[t],
        )
    return m_sm, V_sm, Cov_sm
Example #23
def loss_em(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    Rinv = inv_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Qinv = inv_from_invcholesky_param(LQinv_tril, LQinv_logdiag)

    with torch.no_grad():  # E-Step is optimal --> analytically zero gradients.
        m, V, Cov = smooth_forward_backward(
            dims=dims,
            A=A,
            B=B,
            C=C,
            D=D,
            LQinv_tril=LQinv_tril,
            LQinv_logdiag=LQinv_logdiag,
            LRinv_tril=LRinv_tril,
            LRinv_logdiag=LRinv_logdiag,
            LV0inv_tril=LV0inv_tril,
            LV0inv_logdiag=LV0inv_logdiag,
            m0=m0,
            y=y,
            u_state=u_state,
            u_obs=u_obs,
        )
        loss_entropy = -compute_entropy(dims=dims, V=V, Cov=Cov)

    Cov_sum = torch.sum(Cov[:-1], dim=0)  # idx -1 is Cov_{T, T+1}.
    V_sum = torch.sum(V, dim=0)
    V_sum_head = V_sum - V[-1]
    V_sum_tail = V_sum - V[0]

    # initial prior loss
    V0inv = inv_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    delta_init = m[0] - m0
    quad_init = matmul(delta_init[..., None], delta_init[..., None, :]) + V[0]
    loss_init = 0.5 * (torch.sum(V0inv * quad_init, dim=(-1, -2)) -
                       2.0 * torch.sum(LV0inv_logdiag) + dims.state * LOG_2PI)

    # transition losses - summed over all time-steps
    b = matvec(B, u_state[:-1]) if u_state is not None else 0
    delta_trans = m[1:] - matvec(A, m[:-1]) - b
    quad_trans = (matmul(
        delta_trans.transpose(0, 1).transpose(-1, -2),
        delta_trans.transpose(0, 1),
    ) + V_sum_tail - matmul(A, Cov_sum) -
                  matmul(Cov_sum.transpose(-1, -2), A.transpose(-1, -2)) +
                  matmul(matmul(A, V_sum_head), A.transpose(-1, -2)))
    loss_trans = 0.5 * (
        torch.sum(Rinv * quad_trans, dim=(-1, -2)) - 2.0 *
        (dims.timesteps - 1) * torch.sum(LRinv_logdiag, dim=-1) +
        (dims.timesteps - 1) * dims.state * LOG_2PI)

    # observation losses - summed over all time-steps
    d = matvec(D, u_obs) if u_obs is not None else 0
    delta_obs = y - matvec(C, m) - d
    quad_obs = matmul(
        delta_obs.transpose(0, 1).transpose(-1, -2), delta_obs.transpose(
            0, 1)) + matmul(C, matmul(V_sum, C.transpose(-1, -2)))
    loss_obs = 0.5 * (torch.sum(Qinv * quad_obs, dim=(-1, -2)) -
                      2.0 * dims.timesteps * torch.sum(LQinv_logdiag, dim=-1) +
                      dims.timesteps * dims.target * LOG_2PI)

    loss = loss_trans + loss_obs + loss_init + loss_entropy
    return loss
Example #24
def loss_em(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """
    Computes the sample-wise (e.g. (particle, batch)) EM loss.
    If particle and batch dims are used,
    the ordering is TPBF (time, particle, batch, feature).
    """
    # Inference (E-Step) is optimal --> analytically zero gradients.
    with torch.no_grad():
        m, V, Cov = smooth_forward_backward(
            dims=dims,
            A=A,
            B=B,
            C=C,
            D=D,
            LQinv_tril=LQinv_tril,
            LQinv_logdiag=LQinv_logdiag,
            LRinv_tril=LRinv_tril,
            LRinv_logdiag=LRinv_logdiag,
            LV0inv_tril=LV0inv_tril,
            LV0inv_logdiag=LV0inv_logdiag,
            m0=m0,
            y=y,
            u_state=u_state,
            u_obs=u_obs,
        )
    Rinv = inv_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Qinv = inv_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    # No batch and particle dims --> add at least a batch dimension.
    if A.ndim == 3:
        Rinv, Qinv = Rinv[:, None, :, :], Qinv[:, None, :, :]
        A, C = A[:, None, :, :], C[:, None, :, :]
        if B is not None:
            B = B[:, None, :, :]
        if D is not None:
            D = D[:, None, :, :]
    # initial prior loss
    V0inv = inv_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    dinit = m[0] - m0
    quad_init = matmul(dinit[..., None], dinit[..., None, :]) + V[0]

    loss_init = 0.5 * (
        torch.sum(V0inv * quad_init, dim=(-1, -2))  # FF
        - 2.0 * torch.sum(LV0inv_logdiag, dim=(-1, ))  # F
        + dims.state * LOG_2PI)

    # transition: note that we do not sum the quads over all time-steps here.
    b = matvec(B, u_state[:-1]) if u_state is not None else 0
    dtrans = (m[1:] - matvec(A, m[:-1]) - b)[..., None]
    quad_trans = (matmul(dtrans, dtrans.transpose(-1, -2)) + V[1:] -
                  matmul(A, Cov[:-1]) -
                  matmul(Cov[:-1].transpose(-1, -2), A.transpose(-1, -2)) +
                  matmul(matmul(A, V[:-1]), A.transpose(-1, -2)))

    loss_trans = 0.5 * (
        torch.sum(Rinv * quad_trans, dim=(0, -1, -2))  # T...FF
        - 2.0 * torch.sum(LRinv_logdiag, dim=(0, -1))  # T...F
        + (dims.timesteps - 1) * dims.state * LOG_2PI)

    # likelihood
    d = matvec(D, u_obs) if u_obs is not None else 0
    dobs = (y - matvec(C, m) - d)[..., None]
    quad_obs = matmul(dobs, dobs.transpose(-1, -2)) + matmul(
        C, matmul(V, C.transpose(-1, -2)))
    loss_obs = 0.5 * (
        torch.sum(Qinv * quad_obs, dim=(0, -1, -2))  # T...FF
        - 2.0 * torch.sum(LQinv_logdiag, dim=(0, -1))  # T...F
        + dims.timesteps * dims.target * LOG_2PI)

    # Posterior is optimal --> entropy has analytically zero gradients as well.
    with torch.no_grad():
        loss_entropy = -compute_entropy(dims=dims, V=V, Cov=Cov)
    assert (loss_trans.shape == loss_obs.shape == loss_init.shape ==
            loss_entropy.shape)
    loss_all = loss_trans + loss_obs + loss_init + loss_entropy
    return loss_all
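# The terms above are the expected complete-data negative log-likelihoods under the
# smoothing posterior q(x_{0:T-1}) (given by m, V, Cov):
#   loss_init  = E_q[-log p(x_0)]
#   loss_trans = sum_{t>=1} E_q[-log p(x_t | x_{t-1}, u_{t-1})]
#   loss_obs   = sum_t     E_q[-log p(y_t | x_t, u_t)]
# Adding loss_entropy = -H[q] gives the negative ELBO; since the E-step posterior is
# optimal, its parameters receive analytically zero gradients (hence the no_grad blocks).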
Example #25
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """ compute posterior by direct inversion of unrolled model """
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state, dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device,
                          dtype=dtype)

    # pre-compute biases (b as a zero tensor when controls are absent, so that the
    # per-time-step indexing below still works)
    b = (matvec(B, u_state) if u_state is not None else torch.zeros(
        (dims.timesteps, dims.batch, dims.state), device=device, dtype=dtype))
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(torch.cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(torch.cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(torch.cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * dims.state,
    ))
    Q_obs = kron(
        torch.eye(dims.timesteps, dtype=dtype, device=device),
        matmul(C.transpose(-1, -2), matmul(Qinv, C)),
    )

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat((dims.batch, ) + (1, ) *
                                                       (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat((dims.batch, ) +
                                                         (1, ) *
                                                         (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        idx = t * dims.state
        h_field[:,
                idx:idx + dims.state] += -matvec(RinvA.transpose(-1, -2), b[t])
        h_field[:, idx + dims.state:idx + 2 * dims.state] += matvec(Rinv, b[t])
        Q_field[:, idx:idx + dims.state, idx:idx + dims.state] += AtRinvA
        Q_field[:, idx:idx + dims.state, idx + dims.state:idx +
                2 * dims.state] += -RinvA.transpose(-1, -2)
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx:idx + dims.state] += -RinvA
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx + dims.state:idx + 2 * dims.state, ] += Rinv

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-style (column-major) reshape, hence reshape then transpose.
    m = m_all.reshape((dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        idx = t * dims.state
        V.append(V_all[:, idx:idx + dims.state, idx:idx + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, idx:idx + dims.state,
                             idx + dims.state:idx + 2 * dims.state, ])
        else:
            Cov.append(
                torch.zeros(
                    (dims.batch, dims.state, dims.state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov
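# Above: the unrolled LGSSM is a Gaussian with block-tridiagonal precision over the
# stacked states x_{0:T-1}. Q_field/h_field collect the prior and transition terms,
# Q_obs/h_obs the measurement terms; the smoothing posterior is
#   N( (Q_field + Q_obs)^{-1} (h_field + h_obs), (Q_field + Q_obs)^{-1} ),
# from which the per-step marginals V_t and the neighbouring Cov(x_t, x_{t+1}) blocks
# are read off the inverse.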
Example #26
    def sample_step(
        self,
        lats_tm1: LatentsKVAE,
        ctrl_t: ControlInputsKVAE,
        deterministic: bool = False,
    ) -> Prediction:
        first_step = lats_tm1.gls_params is None

        if first_step:  # from t == 0, i.e. lats_tm1 is t == -1.
            n_batch = lats_tm1.variables.auxiliary.shape[1]
            assert lats_tm1.variables.x is None
            assert lats_tm1.variables.m is None
            assert lats_tm1.variables.V is None
            x_t_dist = self.state_prior_model(
                None, batch_shape_to_prepend=(self.n_particle, n_batch))
        else:
            x_t_dist = torch.distributions.MultivariateNormal(
                loc=(matvec(lats_tm1.gls_params.A, lats_tm1.variables.x)
                     if lats_tm1.gls_params.A is not None else
                     lats_tm1.variables.x) +
                (lats_tm1.gls_params.b
                 if lats_tm1.gls_params.b is not None else 0.0),
                scale_tril=lats_tm1.gls_params.LR,
            )

        rnn_state_t, rnn_output_t = self.compute_deterministic_switch_step(
            rnn_input=lats_tm1.variables.auxiliary,
            rnn_prev_state=lats_tm1.variables.rnn_state,
        )
        gls_params_t = self.gls_base_parameters(
            switch=rnn_output_t,
            controls=ctrl_t,
        )

        x_t = x_t_dist.mean if deterministic else x_t_dist.rsample()
        z_t_dist = torch.distributions.MultivariateNormal(
            loc=matvec(gls_params_t.C, x_t) +
            (gls_params_t.d if gls_params_t.d is not None else 0.0),
            covariance_matrix=gls_params_t.Q,
        )
        z_t = z_t_dist.mean if deterministic else z_t_dist.rsample()

        lats_t = LatentsKVAE(
            variables=GLSVariablesKVAE(
                m=None,
                V=None,
                Cov=None,
                x=x_t,
                auxiliary=z_t,
                rnn_state=rnn_state_t,
                m_auxiliary_variational=None,
                V_auxiliary_variational=None,
            ),
            gls_params=gls_params_t,
        )

        emission_dist_t = self.emit(lats_t=lats_t, ctrl_t=ctrl_t)
        emissions_t = (emission_dist_t.mean
                       if deterministic else emission_dist_t.sample())

        return Prediction(
            latents=lats_t,
            emissions=emissions_t,
        )
Example #27
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """ compute posterior by direct inversion of unrolled model """
    # This implementation works only if all matrices have time and batch dimensions;
    #  otherwise it does not broadcast correctly.
    assert A.ndim == 4
    assert B.ndim == 4
    assert C.ndim == 4
    assert D.ndim == 4
    assert LQinv_tril.ndim == 4
    assert LQinv_logdiag.ndim == 3
    assert LRinv_tril.ndim == 4
    assert LRinv_logdiag.ndim == 3

    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state, dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device,
                          dtype=dtype)

    b = (matvec(B, u_state[:-1]) if u_state is not None else torch.zeros(
        (dims.timesteps - 1, dims.batch, dims.state), device=device, dtype=dtype))
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(batch_cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(batch_cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(batch_cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * dims.state,
    ))

    CtQinvC = matmul(C.transpose(-1, -2), matmul(Qinv, C))
    assert len(CtQinvC) == dims.timesteps
    Q_obs = make_block_diagonal(mats=tuple(mat_t for mat_t in CtQinvC))

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat((dims.batch, ) + (1, ) *
                                                       (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat((dims.batch, ) +
                                                         (1, ) *
                                                         (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        id = t * dims.state
        h_field[:, id:id +
                dims.state] += -matvec(RinvA[t].transpose(-1, -2), b[t])
        h_field[:,
                id + dims.state:id + 2 * dims.state] += matvec(Rinv[t], b[t])
        Q_field[:, id:id + dims.state, id:id + dims.state] += AtRinvA[t]
        Q_field[:, id:id + dims.state, id + dims.state:id +
                2 * dims.state] += -RinvA[t].transpose(-1, -2)
        Q_field[:, id + dims.state:id + 2 * dims.state,
                id:id + dims.state] += -RinvA[t]
        Q_field[:, id + dims.state:id + 2 * dims.state,
                id + dims.state:id + 2 * dims.state, ] += Rinv[t]

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-style (column-major) reshape, hence reshape then transpose.
    m = m_all.reshape((dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        id = t * dims.state
        V.append(V_all[:, id:id + dims.state, id:id + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, id:id + dims.state,
                             id + dims.state:id + 2 * dims.state, ])
        else:
            Cov.append(
                torch.zeros(
                    (dims.batch, dims.state, dims.state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov