def backward(self, *inputs: Tensor) -> List[MultivariateNormal]:
    """Run the backward recursion over a sequence, from the last step to the first.

    Args:
        *inputs: exactly two tensors, ``(output_sequence, input_sequence)``.
            ``output_sequence`` is unpacked as
            ``(batch_size, sequence_length, dim_outputs)`` and
            ``input_sequence`` as ``(batch, seq, dim_inputs)``.

    Returns:
        A list of ``sequence_length`` distributions in chronological order
        (index 0 is the first time step). Each distribution concatenates,
        along dim 1, the emission pseudo-measurement ``self.emissions(y_t)``
        with a distribution over the remaining
        ``dim_delta = dim_states - dim_outputs`` state components. Assuming
        ``self.emissions`` returns ``dim_outputs`` components, each ``loc`` is
        ``(batch_size, dim_states, num_particles)`` — TODO confirm against the
        emission / backward-model implementations.
    """
    output_sequence, input_sequence = inputs
    _, _, dim_inputs = input_sequence.shape
    batch_size, sequence_length, dim_outputs = output_sequence.shape
    dim_states = self.dim_states
    num_particles = self.num_particles
    dim_delta = dim_states - dim_outputs
    shape = (batch_size, dim_delta, num_particles)

    # Permutation moving state components [dim_outputs, dim_states) to the
    # front, followed by [0, dim_outputs). It is loop-invariant, so build it once.
    device = output_sequence.device
    delta_idx = torch.arange(dim_outputs, dim_states, device=device)
    idx = torch.cat((delta_idx, torch.arange(dim_outputs, device=device)))

    ############################################################################
    # Final pseudo-measurement                                                 #
    ############################################################################
    # Replicate the last observation across particles: (batch, dim_outputs, P).
    y = output_sequence[:, -1].expand(num_particles, -1, -1).permute(1, 2, 0)
    x_tilde_obs = self.emissions(y)
    obs_loc = x_tilde_obs.loc
    # Pad the unobserved components with a zero mean and identity covariance.
    # Allocate on the emission output's device/dtype so this also works on GPU.
    loc = torch.cat((obs_loc, obs_loc.new_zeros(shape)), dim=1)
    cov = torch.cat(
        (x_tilde_obs.covariance_matrix, torch.diag_embed(obs_loc.new_ones(shape))),
        dim=1,
    )
    x_tilde = MultivariateNormal(loc, cov)
    outputs = [x_tilde]

    for t in reversed(range(sequence_length - 1)):
        ########################################################################
        # Predict the previous pseudo-measurement                              #
        ########################################################################
        y = output_sequence[:, t].expand(num_particles, -1, -1).permute(1, 2, 0)
        x_tilde_obs = self.emissions(y)
        # Replicate the control input across particles: (batch, dim_inputs, P).
        u = input_sequence[:, t].expand(num_particles, batch_size, dim_inputs)
        u = u.permute(1, 2, 0)

        # Reorder the sampled state before feeding the backward model. Per the
        # original author, this supports identity dynamics (those returning the
        # first dim_outputs states), because the emission pseudo-measurement
        # already occupies the first components of x_tilde.
        x_tilde_samples = x_tilde.rsample()[:, idx]
        x_tilde_u = torch.cat((x_tilde_samples, u), dim=1)
        next_x_tilde = self.backward_model(x_tilde_u)

        # Add the reordered leading dim_delta samples to the model's mean
        # (presumably a residual / increment parameterisation). Out-of-place,
        # so the backward model's output tensor is not mutated; an in-place
        # update can break autograd.
        next_loc = next_x_tilde.loc + x_tilde_samples[:, :dim_delta]

        loc = torch.cat((x_tilde_obs.loc, next_loc), dim=1)
        cov = torch.cat(
            (x_tilde_obs.covariance_matrix, next_x_tilde.covariance_matrix), dim=1
        )
        x_tilde = MultivariateNormal(loc, cov)
        outputs.append(x_tilde)

    assert len(outputs) == sequence_length
    # outputs were built last-step-first; return them in chronological order.
    return outputs[::-1]
def _draw_gp_function(self, X, lengthscale=10.0, kernel_str="RBF"): if kernel_str == "RBF": kernel = RBFKernel() elif kernel_str == "Mat": kernel = MaternKernel(nu=0.5) else: raise Exception("Invalid kernel string: {}".format(kernel_str)) kernel.lengthscale = lengthscale with torch.no_grad(): lazy_cov = kernel(X) mean = torch.zeros(lazy_cov.size(0)) mvn = MultivariateNormal(mean, lazy_cov) Y = mvn.rsample()[:, None] return Y
def test_base_sample_shape(self):
    """Check which base-sample widths ``rsample`` accepts for root vs dense covariances.

    With a ``RootLazyTensor`` built from a 5x10 factor, base samples of width
    10 are accepted and width 5 raises ``RuntimeError``. After the covariance
    is evaluated into a plain lazy tensor, width-5 base samples are accepted.
    """
    factor = torch.randn(5, 10)
    root_covar = RootLazyTensor(lazify(factor))
    draw_shape = torch.Size((16, ))
    expected_shape = torch.Size((16, 5))

    root_dist = MultivariateNormal(torch.zeros(5), root_covar)

    # Base samples matching the root factor's width (10) are fine.
    drawn = root_dist.rsample(draw_shape, base_samples=torch.randn(16, 10))
    self.assertEqual(drawn.shape, expected_shape)

    # Base samples with the event width (5) are rejected for the root tensor.
    with self.assertRaises(RuntimeError):
        root_dist.rsample(draw_shape, base_samples=torch.randn(16, 5))

    # For a non-root lazy tensor, event-width (5) base samples are the right shape.
    dense_covar = lazify(root_covar.evaluate())
    dense_dist = MultivariateNormal(torch.zeros(5), dense_covar)
    drawn = dense_dist.rsample(draw_shape, base_samples=torch.randn(16, 5))
    self.assertEqual(drawn.shape, expected_shape)
def test_gauss_hermite_quadrature_1D_mvn_batch(self, cuda=False):
    """Gauss-Hermite quadrature of ``sin`` should match a Monte-Carlo estimate.

    Uses a (3, 10) batch of means with a diagonal lazy covariance. The
    quadrature result is compared against the mean of ``sin`` over 20000
    samples, and the average absolute gap must be below 0.1.
    """
    loc = torch.randn(3, 10)
    variances = torch.randn(3, 10).abs()
    gh_quad = GaussHermiteQuadrature1D()
    if cuda:
        loc = loc.cuda()
        variances = variances.cuda()
        gh_quad = gh_quad.cuda()

    # NOTE(review): variances.sqrt() is what is passed as the diagonal here. If
    # DiagLazyTensor expects the covariance diagonal, the effective variances
    # are sqrt(variances). Both estimates use the same ``dist``, so the
    # comparison is still self-consistent.
    dist = MultivariateNormal(loc, DiagLazyTensor(variances.sqrt()))

    quad_estimate = gh_quad(torch.sin, dist)

    mc_draws = dist.rsample(torch.Size([20000]))
    mc_estimate = torch.sin(mc_draws).mean(0)

    gap = torch.mean(torch.abs(mc_estimate - quad_estimate))
    self.assertLess(gap, 0.1)