예제 #1
0
    def __init__(
        self,
        emission_coeff: Tensor,
        transition_coeff: Tensor,
        innovation_coeff: Tensor,
        noise_std: Tensor,
        residuals: Tensor,
        prior_mean: Tensor,
        prior_cov: Tensor,
        latent_dim: int,
        output_dim: int,
        seq_length: int,
        F=None,
    ) -> None:
        """
        Store the model parameters, pre-split into one slice per time step.

        All time-varying arguments carry time on axis 1 and are split along
        that axis into ``seq_length`` pieces so that they can later be
        accessed as ``attribute[t]``.

        Parameters
        ----------
        emission_coeff
            Emission coefficients; each time slice has shape
            (batch_size, obs_dim, latent_dim)
        transition_coeff
            Transition coefficients; each time slice has shape
            (batch_size, latent_dim, latent_dim)
        innovation_coeff
            Innovation coefficients; the time axis is kept when splitting,
            so each time slice has shape (batch_size, 1, latent_dim)
        noise_std
            Observation noise standard deviations; each time slice has shape
            (batch_size, obs_dim)
        residuals
            Residuals; each time slice has shape (batch_size, obs_dim)
        prior_mean
            Mean of the prior over the initial latent state, stored as is
            (presumably (batch_size, latent_dim) -- confirm with callers)
        prior_cov
            Covariance of the prior over the initial latent state, stored as
            is (presumably (batch_size, latent_dim, latent_dim) -- confirm
            with callers)
        latent_dim
            Dimension of the latent state
        output_dim
            Dimension of the observations
        seq_length
            Number of time steps, i.e. number of slices to split into
        F
            Namespace providing the tensor operators. When ``None`` it is
            derived from ``noise_std`` via ``getF``.
        """
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        self.seq_length = seq_length

        def split_over_time(tensor, squeeze_axis):
            # NOTE: with num_outputs == 1 MXNet's imperative `split` hands
            # back the array itself rather than a one-element list, so a
            # later `[t]` lookup would index the batch axis instead of the
            # time step. Build the one-element list explicitly in that case.
            if seq_length == 1:
                return [tensor.squeeze(axis=1) if squeeze_axis else tensor]
            return tensor.split(
                axis=1, num_outputs=seq_length, squeeze_axis=squeeze_axis
            )

        # emission_coeff[t]: (batch_size, obs_dim, latent_dim)
        self.emission_coeff = split_over_time(
            emission_coeff, squeeze_axis=True
        )

        # innovation_coeff[t]: (batch_size, 1, latent_dim)
        # The time axis is deliberately kept so that each slice can be used
        # directly as a matrix operand.
        self.innovation_coeff = split_over_time(
            innovation_coeff, squeeze_axis=False
        )

        # transition_coeff[t]: (batch_size, latent_dim, latent_dim)
        self.transition_coeff = split_over_time(
            transition_coeff, squeeze_axis=True
        )

        # noise_std[t]: (batch_size, obs_dim)
        self.noise_std = split_over_time(noise_std, squeeze_axis=True)

        # residuals[t]: (batch_size, obs_dim)
        self.residuals = split_over_time(residuals, squeeze_axis=True)

        self.prior_mean = prior_mean
        self.prior_cov = prior_cov

        # Compare against None explicitly: only a missing F should trigger
        # the fallback, not whatever truthiness the passed object has.
        self.F = F if F is not None else getF(noise_std)
예제 #2
0
    def kalman_filter(self, targets: Tensor,
                      observed: Tensor) -> Tuple[Tensor, ...]:
        """
        Performs Kalman filtering given observations.

        Starting from the stored prior, each time step first runs
        ``kalman_filter_step`` to obtain the filtered state and the step's
        log-probability, and then propagates the state through the
        transition of that step. For steps flagged as missing, the filtered
        state is discarded and the un-updated prior of that step is
        propagated instead.

        Parameters
        ----------
        targets
            Observations, shape (batch_size, seq_length, output_dim)
        observed
            Flag tensor indicating which observations are genuine (1.0) and
            which are missing (0.0), with time on axis 1. May be ``None``,
            in which case every observation is treated as genuine.

        Returns
        -------
        Tensor
            Log probabilities, shape (batch_size, seq_length). An entry is
            emitted for every time step; no masking of missing steps is
            applied here.
        Tensor
            Mean of p(l_T | l_{T-1}), where T is seq_length, with shape
            (batch_size, latent_dim)
        Tensor
            Covariance of p(l_T | l_{T-1}), where T is seq_length, with shape
            (batch_size, latent_dim, latent_dim)
        """
        F = self.F
        seq_length = self.seq_length

        def split_over_time(tensor):
            # NOTE: with num_outputs == 1 MXNet's imperative `split` hands
            # back the array itself rather than a one-element list, so
            # `[t]` would index the batch axis instead of the time step.
            # Build the one-element list explicitly in that case.
            if seq_length == 1:
                return [tensor.squeeze(axis=1)]
            return tensor.split(
                axis=1, num_outputs=seq_length, squeeze_axis=True
            )

        # targets[t]: (batch_size, obs_dim)
        targets = split_over_time(targets)
        observed = split_over_time(observed) if observed is not None else None

        log_p_seq = []

        mean = self.prior_mean
        cov = self.prior_cov

        for t in range(seq_length):
            # Compute the filtered distribution
            #   p(l_t | z_0, ..., z_t)
            # and log-probability
            #   log p(z_t | z_0, ..., z_{t - 1})
            filtered_mean, filtered_cov, log_p = kalman_filter_step(
                F,
                target=targets[t],
                prior_mean=mean,
                prior_cov=cov,
                emission_coeff=self.emission_coeff[t],
                residual=self.residuals[t],
                noise_std=self.noise_std[t],
                latent_dim=self.latent_dim,
                output_dim=self.output_dim,
            )

            log_p_seq.append(log_p.expand_dims(axis=1))

            # Where the observation is missing, fall back to the prior of
            # this step instead of the filtered estimate.
            if observed is None:
                updated_mean = filtered_mean
                updated_cov = filtered_cov
            else:
                updated_mean = F.where(observed[t], x=filtered_mean, y=mean)
                updated_cov = F.where(observed[t], x=filtered_cov, y=cov)

            transition = self.transition_coeff[t]
            innovation = self.innovation_coeff[t]

            # Mean of p(l_{t+1} | l_t): transition applied to the state mean
            mean = F.linalg_gemm2(
                transition, updated_mean.expand_dims(axis=-1)
            ).squeeze(axis=-1)

            # Covariance of p(l_{t+1} | l_t): the state covariance pushed
            # through the transition, plus the outer product of the
            # innovation coefficient with itself
            cov = F.linalg_gemm2(
                transition,
                F.linalg_gemm2(updated_cov, transition, transpose_b=True),
            ) + F.linalg_gemm2(innovation, innovation, transpose_a=True)

        # Return sequence of log likelihoods, as well as
        # final mean and covariance of p(l_T | l_{T-1}), where T is seq_length
        return F.concat(*log_p_seq, dim=1), mean, cov