def test_laplace_shape_scalar_params(self): laplace = Laplace(0, 1) self.assertEqual(laplace._batch_shape, torch.Size()) self.assertEqual(laplace._event_shape, torch.Size()) self.assertEqual(laplace.sample().size(), torch.Size((1, ))) self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2))) self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample) self.assertEqual( laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) self.assertEqual( laplace.log_prob(self.tensor_sample_2).size(), torch.Size( (3, 2, 3)))
def test_laplace_shape_tensor_params(self): laplace = Laplace(torch.Tensor([0, 0]), torch.Tensor([1, 1])) self.assertEqual(laplace._batch_shape, torch.Size((2,))) self.assertEqual(laplace._event_shape, torch.Size(())) self.assertEqual(laplace.sample().size(), torch.Size((2,))) self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2))) self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch and record the step metrics.

    The loss is the negative Laplace log-likelihood of the batch target
    under the model output, plus the model's log-ratio term weighted by the
    current value of ``self.kl_scheduler()``. No backward pass or optimizer
    step is performed here.

    Args:
        batch: Tensor of observations; its first dimension sets how many
            noise rows are drawn.
        batch_idx: Index of the batch (unused).

    Returns:
        A dict with the loss under ``'loss'`` and a ``'log'`` sub-dict
        holding the same value under ``'train_loss'``.
    """
    # One standard-normal draw per batch row is fed to the model.
    # NOTE(review): the noise is created on the default device while the
    # target below is moved to self.device — confirm the model handles
    # device placement of `eps`.
    noise = torch.randn(batch.shape[0], 1)
    outputs, log_ratio = self.model(eps=noise, s_span=self.s_ext_span)

    # Only the interior of the model output is scored: the first and last
    # entries along the leading dimension are dropped.
    interior = outputs[1:-1]
    likelihood = Laplace(loc=interior, scale=self.scale)

    # HACK: only valid because every tensor in the batch is identical here,
    # so the batch is collapsed to its mean over dim 0 and used as the one
    # target.
    target = batch.mean(dim=0)
    target = target.unsqueeze(1).to(self.device)
    pointwise_logp = likelihood.log_prob(target)
    logp = pointwise_logp.sum(dim=0).mean(dim=0)

    kl_weight = self.kl_scheduler()
    loss = -logp + log_ratio * kl_weight

    # Feed each tracked quantity to its metric accumulator.
    for metric, value in (
        (self.logp_metric, logp),
        (self.log_ratio_metric, log_ratio),
        (self.loss_metric, loss),
    ):
        metric.step(value)

    return {'loss': loss, 'log': {'train_loss': loss}}