def test_pdf2(self):
    """Summed unreduced log-probs must equal the log-prob of the reducing CatDist.

    Builds the same two-component CatDist twice, once with
    ``reduce_event_dim=False`` and once with ``reduce_event_dim=True``. It
    checks that summing the former's output over dim 0 reproduces the
    latter's output exactly.
    """
    torch.manual_seed(22)
    make_normal = pyro.distributions.Normal
    dim_a, dim_b = 7, 3
    comp_a = make_normal(torch.ones(dim_a), 0.01 * torch.ones(dim_a)).to_event(1)
    comp_b = make_normal(-1 * torch.ones(dim_b), 0.01 * torch.ones(dim_b)).to_event(1)

    # Each CatDist gets its own list, as in the original construction.
    unreduced = CatDist([comp_a, comp_b], reduce_event_dim=False)
    reduced = CatDist([comp_a, comp_b], reduce_event_dim=True)

    width = dim_a + dim_b
    near_plus_one = torch.normal(1.0, 0.01, size=(150, width))
    near_minus_one = torch.normal(-1.0, 0.01, size=(150, width))
    # First dim_a columns come from the +1 draw, the rest from the -1 draw.
    mixed = torch.cat([near_plus_one[:, :dim_a], near_minus_one[:, dim_a:]], dim=-1)

    summed = unreduced.log_prob(mixed).sum(dim=0).numpy()
    direct = reduced.log_prob(mixed).numpy()
    assert np.array_equal(summed, direct)
def __init__(self, size_in, prior_factor=1.0, weight_prior_std=1.0, bias_prior_std=3.0, **kwargs):
    """Build a factorized Normal hyper-prior over the parameters of a VectorizedGP.

    One ``Normal(...).to_event(1)`` distribution is registered per GP
    parameter via ``self._param_dist``. The distributions are then
    concatenated into ``self.hyper_prior`` with ``CatDist``.

    Args:
        size_in: forwarded as the first argument to ``VectorizedGP``.
        prior_factor: stored as ``self.prior_factor`` (not used in this method).
        weight_prior_std: scale of the zero-mean prior on ``mean_nn`` /
            ``kernel_nn`` parameters whose name contains "weight".
        bias_prior_std: scale of the zero-mean prior on ``mean_nn`` /
            ``kernel_nn`` parameters whose name contains "bias".
        **kwargs: forwarded to ``VectorizedGP``.

    Raises:
        NotImplementedError: a ``mean_nn``/``kernel_nn`` parameter name contains
            neither "weight" nor "bias".
        AssertionError: the registered priors do not match the GP parameters
            one-to-one and in the same order.
    """
    self._params = OrderedDict()
    self._param_dists = OrderedDict()
    self.prior_factor = prior_factor
    self.gp = VectorizedGP(size_in, **kwargs)

    for name, shape in self.gp.parameter_shapes().items():
        if name == 'constant_mean':
            mean_p_loc = torch.zeros(1).to(device)
            mean_p_scale = torch.ones(1).to(device)
            self._param_dist(name, Normal(mean_p_loc, mean_p_scale).to_event(1))
        elif name == 'lengthscale_raw':
            # One prior entry per entry of the parameter's last dimension.
            lengthscale_p_loc = torch.zeros(shape[-1]).to(device)
            lengthscale_p_scale = torch.ones(shape[-1]).to(device)
            self._param_dist(
                name, Normal(lengthscale_p_loc, lengthscale_p_scale).to_event(1))
        elif name == 'noise_raw':
            noise_p_loc = -1. * torch.ones(1).to(device)
            noise_p_scale = torch.ones(1).to(device)
            self._param_dist(
                name, Normal(noise_p_loc, noise_p_scale).to_event(1))
        elif 'mean_nn' in name or 'kernel_nn' in name:
            mean = torch.zeros(shape).to(device)
            if "weight" in name:
                std = weight_prior_std * torch.ones(shape).to(device)
            elif "bias" in name:
                std = bias_prior_std * torch.ones(shape).to(device)
            else:
                raise NotImplementedError
            self._param_dist(name, Normal(mean, std).to_event(1))

    # Check that parameters in prior and gp modules are aligned. Compare the
    # full name lists rather than zip()-ing them: zip() stops at the shorter
    # sequence. It would therefore miss a GP parameter that received no prior
    # (or an extra prior) when the mismatch is at the end.
    gp_param_names = list(self.gp.named_parameters().keys())
    prior_param_names = list(self._param_dists.keys())
    assert gp_param_names == prior_param_names, (
        "prior / gp parameter mismatch: gp=%s prior=%s"
        % (gp_param_names, prior_param_names))

    # Pass a list snapshot (as the CatDist tests do) rather than a live dict view.
    self.hyper_prior = CatDist(list(self._param_dists.values()))
def test_sampling2(self):
    """rsample draws from both components and lays them side by side on the last axis.

    The first 5 columns should average near +1 and the last 3 near -1.
    """
    torch.manual_seed(22)
    make_normal = pyro.distributions.Normal
    head_dim, tail_dim = 5, 3
    head = make_normal(torch.ones(head_dim), 0.01 * torch.ones(head_dim)).to_event(1)
    tail = make_normal(-1 * torch.ones(tail_dim), 0.01 * torch.ones(tail_dim)).to_event(1)

    samples = CatDist([head, tail]).rsample((100, ))
    assert samples.shape == (100, head_dim + tail_dim)

    head_mean = samples[:, :head_dim].mean().item()
    tail_mean = samples[:, head_dim:].mean().item()
    assert abs(head_mean - 1) < 0.2
    assert abs(tail_mean + 1) < 0.2
def test_pdf(self):
    """log_prob separates data that matches both components from data that does not.

    All columns drawn near +1 conflict with the second component (mean -1),
    so the average log-prob must be very negative. Data whose last 3 columns
    are drawn near -1 matches both components, so it must score high.
    """
    torch.manual_seed(22)
    make_normal = pyro.distributions.Normal
    split, total = 7, 10
    tail = total - split
    catdist = CatDist([
        make_normal(torch.ones(split), 0.01 * torch.ones(split)).to_event(1),
        make_normal(-1 * torch.ones(tail), 0.01 * torch.ones(tail)).to_event(1),
    ])

    all_plus = torch.normal(1.0, 0.01, size=(150, total))
    all_minus = torch.normal(-1.0, 0.01, size=(150, total))
    matched = torch.cat([all_plus[:, :split], all_minus[:, split:]], dim=-1)

    mismatched_score = catdist.log_prob(all_plus).mean().item()
    matched_score = catdist.log_prob(matched).mean().item()
    assert mismatched_score < -1000
    assert matched_score > 10