def test_laplace_shape_scalar_params(self):
    laplace = Laplace(0, 1)
    self.assertEqual(laplace._batch_shape, torch.Size())
    self.assertEqual(laplace._event_shape, torch.Size())
    self.assertEqual(laplace.sample().size(), torch.Size((1,)))
    self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample)
    self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    self.assertEqual(laplace.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

def test_laplace_shape_tensor_params(self):
    laplace = Laplace(torch.Tensor([0, 0]), torch.Tensor([1, 1]))
    self.assertEqual(laplace._batch_shape, torch.Size((2,)))
    self.assertEqual(laplace._event_shape, torch.Size())
    self.assertEqual(laplace.sample().size(), torch.Size((2,)))
    self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
    self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
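
For reference, a minimal standalone sketch of the batch/event shape semantics these tests exercise (whether scalar parameters yield samples of shape () or (1,) depends on the PyTorch version; tensor parameters behave as shown):

import torch
from torch.distributions import Laplace

# Two independent Laplace distributions, batched along dim 0.
laplace = Laplace(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))

print(laplace.batch_shape)           # torch.Size([2])
print(laplace.event_shape)           # torch.Size([])
print(laplace.sample().shape)        # torch.Size([2])
print(laplace.sample((3, 2)).shape)  # torch.Size([3, 2, 2])

# log_prob broadcasts its argument against batch_shape: a (3, 2) input
# is scored elementwise, one column per batched distribution.
print(laplace.log_prob(torch.zeros(3, 2)).shape)  # torch.Size([3, 2])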
Example #3
    def training_step(self, batch, batch_idx):
        # x, y = torch.split(batch, split_size_or_sections=1, dim=0)
        x = batch
        eps = torch.randn(batch.shape[0], 1)

        # zs: model outputs over the extended span; log_ratio is weighted
        # as a KL-style penalty in the loss below.
        zs, log_ratio = self.model(eps=eps, s_span=self.s_ext_span)
        # Drop the first and last solution points (presumably the padding
        # introduced by the extended span s_ext_span).
        zs = zs[1:-1]

        likelihood = Laplace(loc=zs, scale=self.scale)

        # Hack, valid only in this case: every tensor in the batch is
        # identical, so the batch mean equals each individual sample.
        target = x.mean(dim=0).unsqueeze(1).to(self.device)
        logp = likelihood.log_prob(target).sum(dim=0).mean(dim=0)
        loss = -logp + log_ratio * self.kl_scheduler()

        # Backward pass and optimizer/scheduler steps are left commented out;
        # Lightning's automatic optimization handles them when training_step
        # returns a 'loss' key.
        # loss.backward()
        # self.optimizer.step()
        # self.scheduler.step()
        self.logp_metric.step(logp)
        self.log_ratio_metric.step(log_ratio)
        self.loss_metric.step(loss)

        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}
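
To isolate the likelihood term above: a Laplace observation model's negative log-likelihood is just a scaled L1 error plus a constant, which gives a quick sanity check for the logp computed in training_step (laplace_nll below is a hypothetical helper, not part of the module):

import math
import torch
from torch.distributions import Laplace

def laplace_nll(pred, target, scale=0.1):
    # Mean negative log-likelihood of target under Laplace(pred, scale).
    return -Laplace(loc=pred, scale=scale).log_prob(target).mean()

pred, target, scale = torch.randn(8, 1), torch.randn(8, 1), 0.1
nll = laplace_nll(pred, target, scale)

# Closed form: |pred - target| / scale + log(2 * scale), averaged.
manual = ((pred - target).abs() / scale + math.log(2 * scale)).mean()
assert torch.allclose(nll, manual)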