def test_gaussian_log_prob_more_dims(cov_matrix: torch.Tensor,
                                     mean_vector: torch.Tensor,
                                     x: torch.Tensor):
    # form a symmetric positive-definite covariance from the random factor
    cov = cov_matrix @ cov_matrix.transpose(-1, -2) + torch.eye(5)
    distribution = torch.distributions.multivariate_normal.MultivariateNormal(
        mean_vector, cov)
    torch_result = distribution.log_prob(x)
    assert_allclose_tensors(torch_result,
                            gaussian_log_prob(mean_vector, cov, x),
                            rtol=1e-2)
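The gaussian_log_prob helper itself is not shown on this page. A minimal
sketch consistent with the assertion above, assuming it evaluates the
multivariate normal log-density with batched leading dimensions (the
repository's actual implementation may instead compute this in closed form):

import torch


def gaussian_log_prob_sketch(mean: torch.Tensor, cov: torch.Tensor,
                             x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: defer to torch.distributions for the
    # N(mean, cov) log-density; leading batch dimensions broadcast.
    dist = torch.distributions.multivariate_normal.MultivariateNormal(
        mean, cov)
    return dist.log_prob(x)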
Example #2
    def eval_loss(self, viz: VisData, filt_points: int,
                  pred_points: int) -> None:
        """Prints evaluation losses for the model.

        This is the default implementation suitable for Gaussian estimators.

        Parameters
        ----------
        viz : VisData
            The visualization data with which to compute the loss.
        filt_points : int
            The number of points with which to filter.
        pred_points : int
            The desired number of prediction points.
        """
        assert filt_points + pred_points <= len(viz.t)

        t = viz.t
        y_data = viz.y
        u = viz.u
        B = y_data.shape[1]

        # filtering and prediction time/data
        t_filt = t[:filt_points]
        y_filt = y_data[:filt_points]
        u_filt = u[:filt_points]
        t_pred = t[(filt_points - 1):(filt_points + pred_points - 1)]
        y_pred = y_data[(filt_points - 1):(filt_points + pred_points - 1)]
        u_pred = u[(filt_points - 1):(filt_points + pred_points - 1)]

        # filtering
        z0_f = self.get_initial_hidden_state(B)
        z_mu_f, z_cov_f = self(t_filt,
                               y_filt,
                               u_filt,
                               z0_f,
                               return_hidden=True)

        # prediction
        z0_mu_p = z_mu_f[-1]
        z0_cov_p = z_cov_f[-1]
        y_mu_p, y_cov_p = self.predict(z0_mu_p, z0_cov_p, t_pred, u_pred)

        # computing losses (NLL and L2)
        loss_nll = -torch.mean(gaussian_log_prob(y_mu_p, y_cov_p, y_pred))
        y_samps = torch.stack(
            [reparameterize_gauss(y_mu_p, y_cov_p) for i in range(100)])
        loss_ade = torch.mean(torch.sqrt(
            (y_samps - y_pred.unsqueeze(0))**2))  # ADE

        # reporting the evaluation
        print(
            f"Prediction Loss (filt_pts={filt_points}, pred_pts={pred_points}) \t"
            f"NLL Loss: {loss_nll.item():.3f} \t ADE Loss: {loss_ade.item():.5f}"
        )
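reparameterize_gauss is also external to this snippet. A plausible sketch,
assuming it draws a differentiable (reparameterized) sample from N(mean, cov)
via a Cholesky factor; the real helper may instead accept a diagonal variance:

import torch


def reparameterize_gauss_sketch(mean: torch.Tensor,
                                cov: torch.Tensor) -> torch.Tensor:
    # Hypothetical: z = mean + L @ eps with L = chol(cov), so gradients
    # flow through mean and cov (the reparameterization trick).
    L = torch.linalg.cholesky(cov)
    eps = torch.randn_like(mean)
    return mean + (L @ eps.unsqueeze(-1)).squeeze(-1)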
Example #3
    def prediction_loss(
        self,
        z_mean: torch.Tensor,
        z_cov: torch.Tensor,
        batch_t: torch.Tensor,
        batch_y: torch.Tensor,
        batch_u: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        l2: bool = False,
        avg: bool = True,
    ) -> torch.Tensor:
        """Prediction loss computation.

        Parameters
        ----------
        z_mean : torch.Tensor, shape=(T, B, n)
            Latent means.
        z_cov : torch.Tensor, shape=(T, B, n)
            Latent covariances.
        batch_t : torch.Tensor, shape=(T)
            Times.
        batch_y : torch.Tensor, shape=(T, B, p)
            Observation trajectories.
        batch_u : torch.Tensor, shape=(T, B, m)
            Control inputs.
        cond : Optional[torch.Tensor], shape=(B, C)
            Conditional context.
        l2 : bool
            Whether to use the l2 loss.
        avg : bool, default=True
            Flag indicating whether to average the loss.

        Returns
        -------
        torch.Tensor, shape=(1)
            Prediction loss.
        """
        T, B = batch_y.shape[:2]

        # take prediction loss over obs y or latent state z
        if not self._z_pred:
            y_mu_p, y_cov_p = self.predict(z_mean[0],
                                           z_cov[0],
                                           batch_t,
                                           batch_u,
                                           cond=cond,
                                           return_hidden=False)

            if l2:
                loss_p = -((y_mu_p - batch_y)**2)
            else:
                loss_p = -gaussian_log_prob(y_mu_p, y_cov_p, batch_y)
        else:
            z_mu_p, z_cov_p = self.predict(
                z_mean[0],
                z_cov[0],
                batch_t,
                batch_u,
                cond=cond,
                return_hidden=True,
            )
            z_mu_s, z_cov_s = self.get_smooth()  # use smoothed vals as targets

            if l2:
                loss_p = -((z_mu_p - z_mu_s)**2)
            else:
                loss_p = -gaussian_log_prob(z_mu_p, z_cov_p, z_mu_s)

        if avg:
            loss_p = torch.sum(loss_p) / (T * B)

        assert not torch.isnan(loss_p).any()
        return loss_p
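When avg=True, the division by T * B above reduces to a plain mean, assuming
gaussian_log_prob collapses the event dimension and returns one value per
(time, batch) pair; a small standalone check of that equivalence:

import torch

T, B = 10, 4
loss_p = torch.randn(T, B)  # stand-in per-step negative log-likelihoods
assert torch.allclose(torch.sum(loss_p) / (T * B), loss_p.mean())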
Example #4
    def loss(
        self,
        batch_t: torch.Tensor,
        batch_y: torch.Tensor,
        batch_u: torch.Tensor,
        iteration: int,
        cond: Optional[torch.Tensor] = None,
        avg: bool = True,
        return_components: bool = False,
    ):
        """See parent class.

        New Parameters
        --------------
        return_components : bool, default=False
            Flag indicating whether to return loss components.
        """
        T, B = batch_y.shape[:2]

        # loss coefficients: prediction burn-in ramp and KL annealing
        burn_in_coeff = min(1.0, iteration / self._burn_in)
        anneal_coeff = min(1.0, iteration / self._dkl_anneal_iter)

        z0_p = self.get_initial_hidden_state(B)
        z_mean, z_cov = self(batch_t,
                             batch_y,
                             batch_u,
                             z0_p,
                             cond=cond,
                             return_hidden=True)
        y_mean, y_cov = self.latent_to_observation(z_mean, z_cov, cond=cond)

        if not self._is_smooth:
            raise NotImplementedError
        z_mean_s, z_cov_s = self.get_smooth()

        # filter and kl loss
        # the order of the loss computations is important to preserve for LE-EKF!
        loss_dkl = self.kl_loss(z_mean_s,
                                z_cov_s,
                                batch_t,
                                batch_u,
                                cond=cond,
                                avg=avg)
        loss_f = -gaussian_log_prob(y_mean, y_cov, batch_y)

        # smoothing/prediction loss
        y_mean_s, y_cov_s = self.latent_to_observation(z_mean_s,
                                                       z_cov_s,
                                                       cond=cond)
        loss_s = -gaussian_log_prob(y_mean_s, y_cov_s, batch_y)
        loss_p = self.prediction_loss(z_mean_s,
                                      z_cov_s,
                                      batch_t,
                                      batch_y,
                                      batch_u,
                                      cond=cond,
                                      avg=avg)

        if avg:
            loss_f = torch.sum(loss_f) / (T * B)
            loss_s = torch.sum(loss_s) / (T * B)

        if return_components:
            return loss_f, loss_s, loss_p, loss_dkl
        else:
            return (self._alpha * loss_s +
                    (1 - self._alpha) * burn_in_coeff * loss_p +
                    self._beta * anneal_coeff * loss_dkl)
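The burn-in and KL-annealing coefficients computed at the top of this method
are linear ramps clipped at 1.0. A standalone illustration with hypothetical
values for self._burn_in and self._dkl_anneal_iter:

burn_in = 1_000          # hypothetical value of self._burn_in
dkl_anneal_iter = 5_000  # hypothetical value of self._dkl_anneal_iter

for iteration in (0, 500, 1_000, 5_000, 10_000):
    burn_in_coeff = min(1.0, iteration / burn_in)
    anneal_coeff = min(1.0, iteration / dkl_anneal_iter)
    print(f"iter={iteration}: burn_in={burn_in_coeff:.2f}, "
          f"anneal={anneal_coeff:.2f}")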
Example #5
    def eval_loss(self, viz: VisData, filt_points: int,
                  pred_points: int) -> Tuple[float, float]:
        """Prints and returns evaluation losses for the model.

        Parameters
        ----------
        viz : VisData
            The visualization data with which to compute the loss.
        filt_points : int
            The number of points with which to filter.
        pred_points : int
            The desired number of prediction points.

        Returns
        -------
        Tuple[float, float]
            The NLL and ADE loss values.
        """
        assert filt_points + pred_points <= len(viz.t)

        # pend_img
        if isinstance(viz, VisDataIMG):
            # filtering and prediction time/data
            t_filt = viz.t[:filt_points]
            o_filt = viz.y[:filt_points]
            u_filt = viz.u[:filt_points]
            t_pred = viz.t[(filt_points - 1):(filt_points + pred_points - 1)]
            o_pred = viz.y[(filt_points - 1):(filt_points + pred_points - 1)]
            u_pred = viz.u[(filt_points - 1):(filt_points + pred_points - 1)]

            # filtering
            T, B = viz.y.shape[:2]
            z0_f = self.get_initial_hidden_state(B)
            z_mu_f, z_cov_f = self(t_filt,
                                   o_filt,
                                   u_filt,
                                   z0_f,
                                   return_hidden=True)

            # prediction
            z0_mu_p = z_mu_f[-1]
            z0_cov_p = z_cov_f[-1]

            # split off conditional context channels (if any)
            if self._cond_channels > 0:
                _o_pred = o_pred[:, :, :-self._cond_channels, 0, 0]
                cond = o_pred[0, :, -self._cond_channels:, 0,
                              0]  # conditional context
            else:
                _o_pred = o_pred
                cond = None

            # computing l2 loss over 100 samples
            log_dsd_p = self.predict(z0_mu_p,
                                     z0_cov_p,
                                     t_pred,
                                     u_pred,
                                     cond=cond)
            _img_samples = []
            for i in range(100):
                _img_samples.append(self._image_model.sample_img(log_dsd_p))
            img_samples = torch.stack(_img_samples)
            loss_ade = torch.mean(
                torch.sqrt((img_samples - _o_pred.unsqueeze(0))**2))

            # NLL Loss on discrete log-softmax distribution
            o_quant = (_o_pred.squeeze(2) *
                       (self._image_model.pixel_res - 1)).long()
            o_quant = o_quant.reshape(-1, *_o_pred.shape[-2:])
            logits = log_dsd_p.reshape(
                -1, *log_dsd_p.shape[-3:])  # (B, num_cats, img_h, img_w)

            # average loss over batches. dkl already batch-averaged.
            nll_loss_func = nn.NLLLoss(reduction="sum")
            loss_nll = nll_loss_func(logits, o_quant) / logits.shape[0]

            # reporting the evaluation
            print(
                f"Prediction Loss (filt_pts={filt_points}, pred_pts={pred_points}) \t"
                f"L2 Loss: {loss_ade.item():.5f} \t"
                f"NLL: {loss_nll.item():.3f}")

        # stripped datasets
        else:
            # filtering and prediction time/data
            t_filt = viz.t[:filt_points]
            y_filt = viz.y[:filt_points]
            u_filt = viz.u[:filt_points]
            t_pred = viz.t[(filt_points - 1):(filt_points + pred_points - 1)]
            y_pred = viz.y[(filt_points - 1):(filt_points + pred_points - 1)]
            u_pred = viz.u[(filt_points - 1):(filt_points + pred_points - 1)]

            # filtering
            B = viz.y.shape[1]
            z0_f = self.get_initial_hidden_state(B)
            z_mu_f, z_cov_f = self(t_filt,
                                   y_filt,
                                   u_filt,
                                   z0_f,
                                   return_hidden=True)

            # prediction
            z0_mu_p = z_mu_f[-1]
            z0_cov_p = z_cov_f[-1]

            y_samples = []
            for i in range(100):
                if i % 10 == 0:
                    print(f"Rollout {i} / 100")
                y_mu_p, y_cov_p = self.predict(z0_mu_p, z0_cov_p, t_pred,
                                               u_pred)
                y_sample = reparameterize_gauss(y_mu_p, y_cov_p)
                y_samples.append(y_sample)

            y_samples_torch = torch.stack(y_samples)
            mean = y_samples_torch.mean(dim=0)
            var = y_samples_torch.var(dim=0)

            # computing losses (NLL and L2)
            loss_nll = -torch.mean(
                gaussian_log_prob(mean, torch.diag_embed(var), y_pred))
            loss_ade = torch.mean(
                torch.sqrt((y_samples_torch - y_pred.unsqueeze(0))**2))

            # reporting the evaluation
            print(
                f"Prediction Loss (filt_pts={filt_points}, pred_pts={pred_points}) \t"
                f"NLL Loss: {loss_nll.item():.3f} \t ADE Loss: {loss_ade.item():.5f}"
            )
        return loss_nll.item(), loss_ade.item()
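In the non-image branch, torch.diag_embed turns the per-dimension sample
variances into batched diagonal covariance matrices before the NLL is
evaluated; a standalone illustration:

import torch

var = torch.tensor([[0.1, 0.2, 0.3]])  # (B, p) per-dimension variances
cov = torch.diag_embed(var)            # (B, p, p) diagonal covariances
print(cov[0])
# tensor([[0.1000, 0.0000, 0.0000],
#         [0.0000, 0.2000, 0.0000],
#         [0.0000, 0.0000, 0.3000]])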
Example #6
    def loss(
        self,
        batch_t: torch.Tensor,
        batch_y: torch.Tensor,
        batch_u: torch.Tensor,
        iteration: int,
        avg: bool = True,
    ) -> torch.Tensor:
        """See parent class."""
        T, B = batch_y.shape[:2]
        z0_p = self.get_initial_hidden_state(B)
        loss_cache: LossFeatures
        z_samples, z_cov, loss_cache = self(batch_t,
                                            batch_y,
                                            batch_u,
                                            z0_p,
                                            return_hidden=True,
                                            return_loss_cache=True)

        if self._is_image:
            log_dsd = self.latent_to_observation(z_samples, z_cov)
            reconstruction_loss = self._image_model.get_reconstruction_loss(
                batch_y, log_dsd, avg=False).reshape(T, B)

        else:
            y_mean, y_cov = self.latent_to_observation(z_samples, z_cov)
            reconstruction_loss = -gaussian_log_prob(y_mean, y_cov, batch_y)

        z_mean = torch.stack(loss_cache.q_mu_posterior_list)
        z_log_var = torch.stack(loss_cache.q_log_var_posterior_list)
        z_mean_prior = torch.stack(loss_cache.p_mu_prior_list)
        z_log_var_prior = torch.stack(loss_cache.p_log_var_prior_list)

        z_posterior = torch.distributions.normal.Normal(
            z_mean,
            torch.exp(z_log_var)**0.5 + 1e-6)
        z_prior = torch.distributions.normal.Normal(
            z_mean_prior,
            torch.exp(z_log_var_prior)**0.5 + 1e-6)
        kl = torch.distributions.kl.kl_divergence(z_posterior,
                                                  z_prior).mean(-1)
        # Reference for kl divergence computation
        # https://github.com/google-research/planet/blob/cbe77fc011299becf6c3805d6007c5bf58012f87/planet/models/rssm.py#L87-L94
        overshoot_loss = 0
        if self.config.overshoot[0] != OverShoot.NONE:
            K = min(self.config.overshoot[1], T - 1)
            for t, (_, z_log_var_t) in enumerate(
                    zip(z_mean[:-K], z_log_var[:-K])):
                z_sample_t_t_k, z_mu_t_t_k, z_log_var_t_t_k = self.predict(
                    z_samples[t],
                    z_log_var_t,
                    batch_t[t:t + K],
                    batch_u[t:t + K, ...],
                    return_hidden=True,
                    with_dist=True,
                )

                if self.config.overshoot[0] == OverShoot.LATENT:
                    z_posterior = torch.distributions.normal.Normal(
                        z_mean[t:t + K, ...],
                        torch.exp(z_log_var[t:t + K, ...])**0.5 + 1e-6,
                    )
                    z_prior = torch.distributions.normal.Normal(
                        z_mu_t_t_k,
                        torch.exp(z_log_var_t_t_k)**0.5 + 1e-6)
                    kl_t_K = torch.distributions.kl.kl_divergence(
                        z_posterior, z_prior).mean(-1)
                    overshoot_loss += torch.sum(kl_t_K) / (K * B)
                elif self.config.overshoot[0] == OverShoot.OBSERVATION:
                    if self._is_image:
                        log_dsd = self.latent_to_observation(
                            z_sample_t_t_k, None)
                        overshoot_loss += (
                            self._image_model.get_reconstruction_loss(
                                batch_y[t:t + K], log_dsd,
                                avg=False).reshape(-1, B).mean())

                    else:
                        y_mean, y_cov = self.latent_to_observation(
                            z_sample_t_t_k, None)
                        overshoot_loss += -gaussian_log_prob(
                            y_mean, y_cov, batch_y[t:t + K]).mean()
            overshoot_loss = overshoot_loss / (t + 1)

        if avg:
            recon_kl_loss = torch.sum(reconstruction_loss + kl) / (T * B)
            return recon_kl_loss + overshoot_loss
        else:
            return reconstruction_loss + kl + overshoot_loss
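For reference, the KL term computed above with
torch.distributions.kl.kl_divergence between two diagonal Gaussians has a
closed form; a sketch parameterized by log-variances (ignoring the 1e-6
stabilizer added to the standard deviations in the code above), which the
subsequent .mean(-1) averages over the latent dimension:

import torch


def diag_gauss_kl(mu_q: torch.Tensor, log_var_q: torch.Tensor,
                  mu_p: torch.Tensor,
                  log_var_p: torch.Tensor) -> torch.Tensor:
    # Elementwise KL(q || p) for diagonal Gaussians parameterized by
    # mean and log-variance.
    var_q = log_var_q.exp()
    var_p = log_var_p.exp()
    return 0.5 * (log_var_p - log_var_q +
                  (var_q + (mu_q - mu_p)**2) / var_p - 1.0)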