Example #1
    def forward(self, u, s, x=None, m=None, V=None):
        """
        m and V are from the Gaussian state x_{t-1}.
        Forward marginalises out the Gaussian if m and V given,
        or transforms a sample x if that is given instead.
        """
        # assert len({(x is None), (m is None and V is None)}) == 2
        assert (x is None) or (m is None and V is None)

        h = torch.cat((u, s), dim=-1) if u is not None else s
        switch_to_switch_dist = self.conditional_dist_transform(h)
        if self.F is not None:
            if m is not None:  # marginalise
                mp, Vp = filter_forward_prediction_step(
                    m=m,
                    V=V,
                    A=self.F,
                    R=switch_to_switch_dist.covariance_matrix,
                    b=switch_to_switch_dist.loc,
                )
            else:  # single sample fwd
                mp = matvec(self.F, x) + switch_to_switch_dist.loc
                Vp = switch_to_switch_dist.covariance_matrix
            return MultivariateNormal(loc=mp, scale_tril=cholesky(Vp))
        else:
            return switch_to_switch_dist
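
The prediction helper used throughout these examples is not shown; as a reference, the following is a minimal sketch of what filter_forward_prediction_step presumably computes, namely the standard Gaussian (Kalman) prediction m' = A m + b, V' = A V A^T + R, inferred from the call sites above. The broadcasting and matvec conventions of the real implementation may differ.

import torch

def filter_forward_prediction_step(m, V, A, R, b=None):
    # Sketch (assumed from the call sites): propagate N(m, V) through
    # x_t = A x_{t-1} + b + w,  w ~ N(0, R).
    mp = torch.einsum("...ij,...j->...i", A, m)
    if b is not None:
        mp = mp + b  # b may also be a scalar 0 placeholder, as in later examples
    # Predicted covariance: V' = A V A^T + R
    Vp = torch.einsum("...ij,...jk,...lk->...il", A, V, A) + R
    return mp, Vp
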
Example #2
    def _make_switch_transition_dist(
        self,
        lat_vars_tm1: GLSVariablesSGLS,
        ctrl_t: ControlInputsSGLS,
    ) -> torch.distributions.MultivariateNormal:
        """
        Compute p(s_t | s_{t-1}) = \int p(x_{t-1}, s_t | s_{t-1}) dx_{t-1}
        = \int p(s_t | s_{t-1}, x_{t-1}) N(x_{t-1} | s_{t-1}) dx_{t-1}.
        We use an additive structure, resulting in a convolution of PDFs, i.e.
        i) the conditional from the switch-to-switch transition and
        ii) the marginal from the state-switch-transition (state marginalised).
        The Gaussian is a 'stable distribution': the sum of independent Gaussian
        variables (not of their PDFs!) is again Gaussian, with locs and
        *covariances* summed. (For a weighted sum, the means are weighted by the
        coefficients and the covariances by their squares, i.e. the *scales* by
        the absolute coefficients.)
        """
        if len({lat_vars_tm1.x is None, lat_vars_tm1.m is None}) != 2:
            raise Exception(
                "Provide samples XOR dist params (-> marginalize).")
        marginalize_states = lat_vars_tm1.m is not None

        # i) switch-to-switch conditional
        switch_to_switch_dist = super()._make_switch_transition_dist(
            lat_vars_tm1=lat_vars_tm1,
            ctrl_t=ctrl_t,
        )

        # ii) state-to-switch
        rec_base_params = self.recurrent_base_parameters(
            switch=lat_vars_tm1.switch)
        if marginalize_states:
            m, V = filter_forward_prediction_step(
                m=lat_vars_tm1.m,
                V=lat_vars_tm1.V,
                A=rec_base_params.F,
                R=rec_base_params.S,
                b=None,
            )
        else:
            m = matvec(rec_base_params.F, lat_vars_tm1.x)
            V = rec_base_params.S
        state_to_switch_dist = MultivariateNormal(
            loc=m,
            scale_tril=cholesky(V),
        )

        # combine i) & ii): sum variables (=convolve PDFs).
        switch_model_dist = gaussian_linear_combination({
            state_to_switch_dist: 0.5,
            switch_to_switch_dist: 0.5,
        })
        return switch_model_dist
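
gaussian_linear_combination is not shown either. Under the docstring's reading (a weighted sum of independent Gaussian variables), a minimal sketch with the dict-of-weights signature used above could look as follows; whether the library weights covariances by the squared coefficients exactly like this is an assumption.

import torch
from torch.distributions import MultivariateNormal

def gaussian_linear_combination(dists_to_weights):
    # For independent X_i ~ N(mu_i, Sigma_i) and coefficients a_i, the variable
    # sum_i a_i X_i is Gaussian with loc sum_i a_i mu_i and covariance
    # sum_i a_i^2 Sigma_i (i.e. scales weighted by |a_i|).
    loc = sum(a * dist.loc for dist, a in dists_to_weights.items())
    cov = sum(a ** 2 * dist.covariance_matrix
              for dist, a in dists_to_weights.items())
    return MultivariateNormal(loc=loc, covariance_matrix=cov)
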
Example #3
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # pre-compute biases
    b = matvec(B, u_state) if u_state is not None else 0
    d = matvec(D, u_obs) if u_obs is not None else 0

    m_fw = torch.zeros((dims.timesteps, dims.batch, dims.state),
                       device=device,
                       dtype=dtype)
    V_fw = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )
    for t in range(0, dims.timesteps):
        (mp, Vp) = (filter_forward_prediction_step(
            m=m_fw[t - 1],
            V=V_fw[t - 1],
            R=R,
            A=A,
            b=b[t - 1] if u_state is not None else 0,
        ) if t > 0 else (m0, V0))
        m_fw[t], V_fw[t] = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=Q,
            C=C,
            d=d[t] if u_obs is not None else 0,
        )
    return m_fw, V_fw
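
cov_from_invcholesky_param reconstructs a covariance from the inverse-Cholesky parametrisation suggested by the argument names: a strictly lower-triangular part plus a log-diagonal of L, where L L^T is the precision. A sketch under that assumption; the actual implementation likely avoids the explicit inverse by solving with the triangular factor.

import torch

def cov_from_invcholesky_param(Linv_tril, Linv_logdiag):
    # Assumed parametrisation: Linv_tril holds the strictly lower-triangular
    # part and Linv_logdiag the log-diagonal of L, with L L^T the inverse
    # covariance (precision). Return the covariance (L L^T)^{-1}.
    Linv = torch.tril(Linv_tril, diagonal=-1) + torch.diag_embed(
        torch.exp(Linv_logdiag))
    precision = Linv @ Linv.transpose(-1, -2)
    return torch.inverse(precision)
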
Example #4
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # pre-compute biases
    b = matvec(B, u_state) if u_state is not None else 0
    d = matvec(D, u_obs) if u_obs is not None else 0

    # Note: We cannot use the (more readable) in-place operations due to backprop problems.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps
    loss = torch.zeros((dims.batch, ), device=device, dtype=dtype)
    for t in range(0, dims.timesteps):
        (mp, Vp) = (filter_forward_prediction_step(
            m=m_fw[t - 1],
            V=V_fw[t - 1],
            R=R,
            A=A,
            b=b[t - 1] if u_state is not None else 0,
        ) if t > 0 else (m0, V0))
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t], m=mp, V=Vp, Q=Q, C=C,
            d=d[t] if u_obs is not None else 0,
            return_loss_components=True)
        loss += (
            0.5 * torch.sum(dobs_norm**2, dim=-1) -
            0.5 * 2 * torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1, )) +
            0.5 * dims.target * LOG_2PI)

    return loss
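
Each summand of loss is the Gaussian negative log-likelihood of y_t under the one-step predictive distribution: dobs_norm is the whitened residual and LVpyinv the Cholesky factor of the predictive precision. A small self-contained check of that identity with hypothetical random values (assuming dobs_norm = L^T (y - mean)):

import math
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
d = 3
y, mpy = torch.randn(d), torch.randn(d)
A_rand = torch.randn(d, d)
Vpy = A_rand @ A_rand.T + d * torch.eye(d)           # SPD predictive covariance
LVpyinv = torch.linalg.cholesky(torch.inverse(Vpy))  # Cholesky of the precision
dobs_norm = LVpyinv.T @ (y - mpy)                    # whitened residual

nll_terms = (0.5 * torch.sum(dobs_norm ** 2)
             - torch.sum(torch.log(torch.diagonal(LVpyinv)))
             + 0.5 * d * math.log(2 * math.pi))
nll_ref = -MultivariateNormal(loc=mpy, covariance_matrix=Vpy).log_prob(y)
assert torch.allclose(nll_terms, nll_ref, atol=1e-5)
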
Example #5
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    # TODO: assumes B != None, D != None. u_state and u_obs can be None.
    #  Better to work with vectors b, d. matmul Bu should be done outside!
    m_fw = create_zeros_state_vec(dims=dims, device=A.device, dtype=A.dtype)
    V_fw = create_zeros_state_mat(dims=dims, device=A.device, dtype=A.dtype)
    for t in range(0, dims.timesteps):
        if t == 0:
            V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
            mp, Vp = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)
        else:
            R_tm1 = cov_from_invcholesky_param(LRinv_tril[t - 1],
                                               LRinv_logdiag[t - 1])
            b_tm1 = (matvec(B[t - 1], u_state[t - 1])
                     if u_state is not None else 0)
            mp, Vp = filter_forward_prediction_step(m=m_fw[t - 1],
                                                    V=V_fw[t - 1],
                                                    R=R_tm1,
                                                    A=A[t - 1],
                                                    b=b_tm1)
        Q_t = cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t])
        d_t = matvec(D[t], u_obs[t]) if u_obs is not None else 0
        m_fw[t], V_fw[t] = filter_forward_measurement_step(y=y[t],
                                                           m=mp,
                                                           V=Vp,
                                                           Q=Q_t,
                                                           C=C[t],
                                                           d=d_t)
    return torch.stack(m_fw, dim=0), torch.stack(V_fw, dim=0)
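
filter_forward_measurement_step is presumably the standard Kalman/Bayes update of the predicted moments with the observation; a minimal sketch under that assumption (the library version additionally accepts return_loss_components and likely uses Cholesky-based solves):

import torch

def filter_forward_measurement_step(y, m, V, Q, C, d=None):
    # Assumed standard Kalman update for y_t = C x_t + d + v,  v ~ N(0, Q):
    # gain K = V C^T (C V C^T + Q)^{-1}, then condition N(m, V) on y.
    mpy = torch.einsum("...ij,...j->...i", C, m)
    if d is not None:
        mpy = mpy + d  # d may also be a scalar 0 placeholder
    Vpy = torch.einsum("...ij,...jk,...lk->...il", C, V, C) + Q
    VCt = torch.einsum("...ij,...kj->...ik", V, C)  # V C^T
    K = torch.linalg.solve(Vpy, VCt.transpose(-1, -2)).transpose(-1, -2)
    m_new = m + torch.einsum("...ij,...j->...i", K, y - mpy)
    V_new = V - K @ torch.einsum("...ij,...jk->...ik", C, V)
    return m_new, V_new
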
Example #6
    def marginal_step(
        self,
        lats_tm1: LatentsSGLS,
        ctrl_t: ControlInputsSGLS,
        deterministic: bool = False,
    ) -> Prediction:
        # TODO: duplication with sample_step. Requires refactoring.
        n_batch = lats_tm1.variables.m.shape[1]

        if lats_tm1.variables.switch is None:
            switch_model_dist_t = self._make_switch_prior_dist(
                n_particle=self.n_particle,
                n_batch=n_batch,
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )
        else:
            switch_model_dist_t = self._make_switch_transition_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )

        s_t = (switch_model_dist_t.mean
               if deterministic else switch_model_dist_t.sample())
        gls_params_t = self.gls_base_parameters(
            switch=s_t,
            controls=ctrl_t,
        )

        x_t = None
        m_t, V_t = filter_forward_prediction_step(
            m=lats_tm1.variables.m,
            V=lats_tm1.variables.V,
            R=gls_params_t.R,
            A=gls_params_t.A,
            b=gls_params_t.b,
        )
        mpy_t, Vpy_t = filter_forward_predictive_distribution(
            m=m_t,
            V=V_t,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        emission_dist_t = torch.distributions.MultivariateNormal(
            loc=mpy_t,
            scale_tril=cholesky(Vpy_t),
        )

        # NOTE: Should compute Cov if need forecast joint distribution.
        lats_t = LatentsSGLS(
            log_weights=lats_tm1.log_weights,  # does not change w/o evidence.
            gls_params=None,  # not used outside this function
            variables=GLSVariablesSGLS(
                x=x_t,
                m=m_t,
                V=V_t,
                Cov=None,
                switch=s_t,
            ),
        )
        return Prediction(latents=lats_t, emissions=emission_dist_t)
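
filter_forward_predictive_distribution maps state moments to emission-space moments; a sketch of what it presumably computes under the linear-Gaussian observation model implied by the call sites:

import torch

def filter_forward_predictive_distribution(m, V, Q, C, d=None):
    # Assumed emission moments under y = C x + d + v,  v ~ N(0, Q):
    #   mpy = C m + d,   Vpy = C V C^T + Q
    mpy = torch.einsum("...ij,...j->...i", C, m)
    if d is not None:
        mpy = mpy + d
    Vpy = torch.einsum("...ij,...jk,...lk->...il", C, V, C) + Q
    return mpy, Vpy
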
Example #7
    def filter_step(
        self,
        lats_tm1: Optional[LatentsSGLS],
        tar_t: torch.Tensor,
        ctrl_t: ControlInputsSGLS,
        tar_is_obs_t: Optional[torch.Tensor] = None,
    ):
        if tar_is_obs_t is not None:
            raise NotImplementedError("cannot handle missing data atm.")

        is_initial_step = lats_tm1 is None
        if is_initial_step:
            n_particle, n_batch = self.n_particle, len(tar_t)
            state_prior = self.state_prior_model(
                None,
                batch_shape_to_prepend=(n_particle, n_batch),
            )
            log_norm_weights = normalize_log_weights(
                log_weights=torch.zeros_like(state_prior.loc[..., 0]),
            )
            lats_tm1 = LatentsSGLS(
                log_weights=None,  # Not used. We use log_norm_weights instead.
                gls_params=None,  # First (previous) step no gls_params
                variables=GLSVariablesSGLS(
                    m=state_prior.loc,
                    V=state_prior.covariance_matrix,
                    Cov=None,
                    x=None,
                    switch=None,
                ),
            )
            switch_model_dist = self._make_switch_prior_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
                n_particle=n_particle,
                n_batch=n_batch,
            )
        else:
            log_norm_weights = normalize_log_weights(
                log_weights=lats_tm1.log_weights,
            )
            log_norm_weights, resampled_tensors = resample(
                n_particle=self.n_particle,
                log_norm_weights=log_norm_weights,
                tensors_to_resample={
                    key: val
                    for key, val in lats_tm1.variables.__dict__.items()
                    if key not in ("x", "Cov")  # below set to None explicitly
                },
                resampling_indices_fn=self.resampling_indices_fn,
                criterion_fn=self.resampling_criterion_fn,
            )
            # We dont use gls_params anymore.
            # If needed for e.g. evaluation, remember to re-sample all params!
            lats_tm1 = LatentsSGLS(
                log_weights=None,  # Not used. We use log_norm_weights instead.
                gls_params=None,  # not used outside this function. Read above.
                variables=GLSVariablesSGLS(
                    **resampled_tensors,
                    x=None,
                    Cov=None,
                ),
            )
            switch_model_dist = self._make_switch_transition_dist(
                lat_vars_tm1=lats_tm1.variables,
                ctrl_t=ctrl_t,
            )

        switch_proposal_dist = self._make_switch_proposal_dist(
            switch_model_dist=switch_model_dist,
            switch_encoder_dist=self._make_encoder_dists(
                tar_t=tar_t,
                ctrl_t=ctrl_t,
            ).switch,
        )
        s_t = switch_proposal_dist.rsample()
        gls_params_t = self.gls_base_parameters(
            switch=s_t,
            controls=ctrl_t,
        )

        mp, Vp = filter_forward_prediction_step(
            m=lats_tm1.variables.m,
            V=lats_tm1.variables.V,
            R=gls_params_t.R,
            A=gls_params_t.A,
            b=gls_params_t.b,
        )

        m_t, V_t = filter_forward_measurement_step(
            y=tar_t,
            m=mp,
            V=Vp,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        mpy_t, Vpy_t = filter_forward_predictive_distribution(
            m=mp,
            V=Vp,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        measurement_dist = MultivariateNormal(
            loc=mpy_t,
            scale_tril=cholesky(Vpy_t),
        )

        log_update = (measurement_dist.log_prob(tar_t) +
                      switch_model_dist.log_prob(s_t) -
                      switch_proposal_dist.log_prob(s_t))
        log_weights_t = log_norm_weights + log_update

        return LatentsSGLS(
            log_weights=log_weights_t,
            gls_params=None,  # not used outside this function
            variables=GLSVariablesSGLS(
                m=m_t,
                V=V_t,
                x=None,
                Cov=None,
                switch=s_t,
            ),
        )
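
The log-weight update above is the usual particle-filter importance weight, log w_t = log w_{t-1} + log p(y_t | s_t) + log p(s_t | s_{t-1}) - log q(s_t | y_t, ...). normalize_log_weights is presumably a logsumexp normalisation over the particle dimension; a sketch (treating the leading axis as the particle axis is an assumption):

import torch

def normalize_log_weights(log_weights, particle_dim=0):
    # Subtract logsumexp over the particle dimension so that
    # exp(log_norm_weights) sums to one along that axis.
    return log_weights - torch.logsumexp(
        log_weights, dim=particle_dim, keepdim=True)
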
Example #8
    def filter_step(
        self,
        lats_tm1: Optional[LatentsKVAE],
        tar_t: torch.Tensor,
        ctrl_t: ControlInputs,
        tar_is_obs_t: Optional[torch.Tensor] = None,
    ) -> LatentsKVAE:
        is_initial_step = lats_tm1 is None
        if tar_is_obs_t is None:
            tar_is_obs_t = torch.ones(
                tar_t.shape[:-1],
                dtype=tar_t.dtype,
                device=tar_t.device,
            )

        # 1) Initial step must prepare previous latents with prior and learnt z.
        if is_initial_step:
            n_particle, n_batch = self.n_particle, len(tar_t)
            state_prior = self.state_prior_model(
                None,
                batch_shape_to_prepend=(n_particle, n_batch),
            )
            z_init = self.z_initial[None, None].repeat(n_particle, n_batch, 1)
            lats_tm1 = LatentsKVAE(
                variables=GLSVariablesKVAE(
                    m=state_prior.loc,
                    V=state_prior.covariance_matrix,
                    Cov=None,
                    x=None,
                    auxiliary=z_init,
                    rnn_state=None,
                    m_auxiliary_variational=None,
                    V_auxiliary_variational=None,
                ),
                gls_params=None,
            )
        # 2) Compute GLS params
        rnn_state_t, rnn_output_t = self.compute_deterministic_switch_step(
            rnn_input=lats_tm1.variables.auxiliary,
            rnn_prev_state=lats_tm1.variables.rnn_state,
        )
        gls_params_t = self.gls_base_parameters(
            switch=rnn_output_t,
            controls=ctrl_t,
        )

        # Perform filter step:
        # 3) Prediction Step: Only for t > 0 and using previous GLS params.
        # (In the original KVAE, the update step is done before the prediction step.)
        if is_initial_step:
            mp, Vp = lats_tm1.variables.m, lats_tm1.variables.V
        else:
            mp, Vp = filter_forward_prediction_step(
                m=lats_tm1.variables.m,
                V=lats_tm1.variables.V,
                R=lats_tm1.gls_params.R,
                A=lats_tm1.gls_params.A,
                b=lats_tm1.gls_params.b,
            )
        # 4) Update step
        # 4a) Observed data: Infer pseudo-obs by encoding obs && Bayes update
        auxiliary_variational_dist_t = self.encoder(tar_t)
        z_infer_t = auxiliary_variational_dist_t.rsample([self.n_particle])
        m_infer_t, V_infer_t = filter_forward_measurement_step(
            y=z_infer_t,
            m=mp,
            V=Vp,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )

        # 4b) Choice: inferred / predicted m, V for observed / missing data.
        is_filtered = tar_is_obs_t[None, :].repeat(self.n_particle, 1).bool()
        replace_m_fw = is_filtered[:, :, None].repeat(1, 1, mp.shape[2])
        replace_V_fw = is_filtered[:, :, None, None].repeat(
            1,
            1,
            Vp.shape[2],
            Vp.shape[3],
        )
        assert replace_m_fw.shape == m_infer_t.shape == mp.shape
        assert replace_V_fw.shape == V_infer_t.shape == Vp.shape

        m_t = torch.where(replace_m_fw, m_infer_t, mp)
        V_t = torch.where(replace_V_fw, V_infer_t, Vp)

        # 4c) Missing Data: Predict pseudo-observations && No Bayes update
        mpz_t, Vpz_t = filter_forward_predictive_distribution(
            m=m_t,  # posterior predictive or one-step-predictive (if missing)
            V=V_t,
            Q=gls_params_t.Q,
            C=gls_params_t.C,
            d=gls_params_t.d,
        )
        auxiliary_predictive_dist_t = MultivariateNormal(
            loc=mpz_t,
            covariance_matrix=Vpz_t,
        )
        z_gen_t = auxiliary_predictive_dist_t.rsample()

        # 4d) Choice: inferred / predicted z for observed / missing data.
        # One-step predictive if missing and inferred from encoder otherwise.
        replace_z = is_filtered[:, :, None].repeat(1, 1, z_gen_t.shape[2])
        z_t = torch.where(replace_z, z_infer_t, z_gen_t)

        # 5) Put result in Latents object, used in next iteration
        lats_t = LatentsKVAE(
            variables=GLSVariablesKVAE(
                m=m_t,
                V=V_t,
                Cov=None,
                x=None,
                auxiliary=z_t,
                rnn_state=rnn_state_t,
                m_auxiliary_variational=auxiliary_variational_dist_t.loc,
                V_auxiliary_variational=(
                    auxiliary_variational_dist_t.covariance_matrix),
            ),
            gls_params=gls_params_t,
        )
        return lats_t
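
Steps 4b-4d select between filtered and predicted quantities by broadcasting a per-batch observation flag to the moment shapes and picking with torch.where. An isolated illustration with hypothetical shapes:

import torch

n_particle, n_batch, n_state = 2, 3, 4
tar_is_obs_t = torch.tensor([1.0, 0.0, 1.0])            # per-batch observed flag
mp = torch.randn(n_particle, n_batch, n_state)          # predicted mean
m_infer_t = torch.randn(n_particle, n_batch, n_state)   # filtered mean

is_filtered = tar_is_obs_t[None, :].repeat(n_particle, 1).bool()
replace_m = is_filtered[:, :, None].repeat(1, 1, n_state)
# Keep the filtered mean where the target was observed, otherwise fall back
# to the one-step prediction (no Bayes update for missing data).
m_t = torch.where(replace_m, m_infer_t, mp)
assert m_t.shape == mp.shape
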
Example #9
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """
    Computes the sample-wise (e.g. (particle, batch)) forward / filter loss.
    If particle and batch dims are used,
    the ordering is TPBF (time, particle, batch, feature).

    Note: it would be better numerically (amount of computation and precision)
    to sum/avg over these dimensions in all loss terms.
    However, we must compute it sample-wise for many algorithms,
    i.e. Rao-Blackwellised Particle Filters.
    """
    device, dtype = A.device, A.dtype

    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # Note: We cannot use the (more readable) in-place operations due to backprop problems.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps
    dim_particle = ((dims.particle,)
                    if dims.particle is not None and dims.particle != 0
                    else tuple())
    loss = torch.zeros(dim_particle + (dims.batch, ),
                       device=device,
                       dtype=dtype)
    for t in range(0, dims.timesteps):
        (mp, Vp) = (filter_forward_prediction_step(
            m=m_fw[t - 1],
            V=V_fw[t - 1],
            R=cov_from_invcholesky_param(LRinv_tril[t - 1],
                                         LRinv_logdiag[t - 1]),
            A=A[t - 1],
            b=matvec(B[t - 1], u_state[t - 1]) if u_state is not None else 0,
        ) if t > 0 else (m0, V0))
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t]),
            C=C[t],
            d=matvec(D[t], u_obs[t]) if u_obs is not None else 0,
            return_loss_components=True,
        )
        loss += 0.5 * (torch.sum(dobs_norm**2, dim=-1) - 2 *
                       torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1, )) +
                       dims.target * LOG_2PI)
    return loss