def test_independent(base_dist, sample_shape, batch_shape, reinterpreted_batch_ndims):
    """Check shapes and log-density of ``dist.Independent`` wrapped around ``base_dist``.

    ``base_dist`` is first expanded by ``batch_shape`` (when non-empty).
    ``dist.Independent`` then moves the rightmost ``reinterpreted_batch_ndims``
    batch dimensions into the event shape. Requesting more reinterpreted dims
    than the (expanded) base distribution has batch dims must raise ``ValueError``.
    """
    if batch_shape:
        base_dist = base_dist.expand_by(batch_shape)

    if reinterpreted_batch_ndims > len(base_dist.batch_shape):
        # Too many batch dims requested: construction itself must fail.
        with pytest.raises(ValueError):
            dist.Independent(base_dist, reinterpreted_batch_ndims)
        return

    d = dist.Independent(base_dist, reinterpreted_batch_ndims)

    # Split point between batch dims that stay batch dims and those folded into the event.
    # NOTE(review): the expected shapes below assume base_dist had an empty
    # batch_shape before expand_by — confirm against the fixture.
    split = len(batch_shape) - reinterpreted_batch_ndims
    kept_batch = batch_shape[:split]
    folded_batch = batch_shape[split:]
    full_shape = batch_shape + base_dist.event_shape

    assert d.batch_shape == kept_batch
    assert d.event_shape == folded_batch + base_dist.event_shape
    assert d.sample().shape == full_shape
    assert d.mean.shape == full_shape
    assert d.variance.shape == full_shape

    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape

    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + kept_batch
    assert not torch_isnan(log_prob)

    # The wrapped log-density must equal the base log-density summed over the folded dims.
    base_log_prob = base_dist.log_prob(x)
    assert_equal(log_prob, _sum_rightmost(base_log_prob, reinterpreted_batch_ndims))
def test_kl_independent_normal(batch_shape, event_shape):
    """KL between two ``Independent`` Normals equals the elementwise KL summed over event dims."""
    shape = batch_shape + event_shape
    n_event_dims = len(event_shape)

    def _random_normal():
        # Random loc; exp() of a random tensor gives a strictly positive scale.
        return dist.Normal(torch.randn(shape), torch.randn(shape).exp())

    p = _random_normal()
    q = _random_normal()

    actual = kl_divergence(
        dist.Independent(p, n_event_dims),
        dist.Independent(q, n_event_dims),
    )
    expected = sum_rightmost(kl_divergence(p, q), n_event_dims)
    assert_close(actual, expected)
def prs_model(beta_hat, obs_error):
    """Pyro model: sparse latent effects observed through a linear map with Gaussian noise.

    Sample sites:
      * ``'z'``: N independent Bernoulli(p_causal) indicators.
      * ``'beta_latent'``: effects drawn from Normal(GENETIC_MEAN, GENETIC_SD),
        with one dimension reinterpreted as an event dimension.
      * ``'beta_hat'``: observed, MultivariateNormal with mean
        ``obs_error @ (beta * z)`` and covariance ``obs_error * sigma_sq_e``.

    Relies on module-level globals ``p_causal``, ``N``, ``GENETIC_MEAN``,
    ``GENETIC_SD`` and ``sigma_sq_e``.

    Args:
        beta_hat: observed value conditioned on at site ``'beta_hat'``.
        obs_error: 2-D matrix (required by ``torch.mv``). It is used both as the
            linear map and, scaled by ``sigma_sq_e``, as the covariance.

    Returns:
        The value at site ``'beta_hat'``; since it is observed, this is the
        ``beta_hat`` argument.
    """
    causal_prior = dist.Bernoulli(torch.tensor([p_causal] * N))
    z = pyro.sample('z', dist.Independent(causal_prior, 1))

    effect_prior = dist.Normal(GENETIC_MEAN, GENETIC_SD)
    beta = pyro.sample('beta_latent', dist.Independent(effect_prior, 1))

    # Only effects switched on by z contribute to the mean.
    likelihood = dist.MultivariateNormal(
        torch.mv(obs_error, beta * z),
        covariance_matrix=obs_error * sigma_sq_e,
    )
    observed = pyro.sample('beta_hat', likelihood, obs=beta_hat)
    return observed
def test_kl_independent_delta_mvn_shape(batch_shape, size):
    """KL(Independent(Delta) || MultivariateNormal) must have shape ``batch_shape``."""
    vec_shape = batch_shape + (size,)

    point = torch.randn(vec_shape)
    p = dist.Independent(dist.Delta(point), 1)

    loc = torch.randn(vec_shape)
    root = torch.randn(batch_shape + (size, size))
    # root @ root^T is positive semi-definite; the diagonal jitter makes it positive definite.
    cov = root @ root.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    assert kl_divergence(p, q).shape == batch_shape
def prs_guide(index):
    """Mean-field Pyro guide over the ``'z'`` and ``'beta_latent'`` sample sites.

    Learnable parameters are namespaced by ``index`` so that several guides can
    coexist in the param store:
      * ``var_psi_causal_{index}``: Bernoulli probabilities for ``'z'``,
        initialised to ``p_causal`` and constrained to [0, 1].
      * ``var_mean_{index}`` / ``var_scale_{index}``: Normal location (init 0)
        and positive scale (init 1) for ``'beta_latent'``.

    Relies on module-level globals ``N`` and ``p_causal``.

    Returns:
        ``(z, beta_latent)`` — the sampled values at the two sites.
    """
    # NOTE(review): initial values come from numpy, so these params are float64;
    # prs_model builds its Bernoulli probs from a Python list — check dtypes agree.
    psi_causal = pyro.param(
        f'var_psi_causal_{index}',
        torch.tensor(np.ones(N) * p_causal),
        constraint=constraints.unit_interval,
    )
    z = pyro.sample('z', dist.Independent(dist.Bernoulli(psi_causal), 1))

    loc_name = f'var_mean_{index}'
    scale_name = f'var_scale_{index}'
    means = pyro.param(loc_name, torch.tensor(np.zeros(N)))
    scales = pyro.param(
        scale_name,
        torch.tensor(np.ones(N)),
        constraint=constraints.positive,
    )
    effect_posterior = dist.Independent(dist.Normal(means, scales), 1)
    beta_latent = pyro.sample('beta_latent', effect_posterior)

    return z, beta_latent
def _generate_noise_dist_parameters(self):
    """Build and persist the distractor-noise distribution, then write noisy observations.

    Steps:
      1. From a numpy RNG seeded with 42, draw locations and lower-triangular
         scale factors for a uniform-weight mixture of ``n_noise_comps``
         multivariate Student-t (df=2) components in ``noise_dim`` dimensions.
         Despite the variable name ``gmm`` this is a Student-t mixture, not a
         Gaussian one. It is saved to ``self.path / "files" / "gmm.torch"``.
      2. Draw a random permutation of ``noise_dim + 8`` column indices and save
         it to ``self.path / "files" / "permutation_idx.torch"``.
      3. With ``torch.manual_seed(42)``, for each observation number
         1..``self.num_observations``: append one mixture sample to the stored
         observation, reorder the columns with the permutation, and save the
         result as ``observation_distractors.csv`` in that observation's folder.
    """
    import numpy as np

    noise_dim = 92
    n_noise_comps = 20
    # Total columns after appending noise; presumably the base observation has 8.
    dim_data = noise_dim + 8
    files_dir = self.path / "files"

    # A private RandomState seeded with 42 produces exactly the same stream as
    # np.random.seed(42) on the global generator, but does not clobber numpy's
    # process-wide RNG state as a side effect.
    rng = np.random.RandomState(42)

    loc = torch.from_numpy(
        np.array([15 * rng.normal(size=noise_dim) for _ in range(n_noise_comps)])
    )
    # Random lower-triangular matrix plus a positive (exp) diagonal per component.
    cholesky_factors = [
        np.tril(rng.normal(size=(noise_dim, noise_dim)))
        + np.diag(np.exp(rng.normal(size=noise_dim)))
        for _ in range(n_noise_comps)
    ]
    scale_tril = torch.from_numpy(3 * np.array(cholesky_factors))

    mix = pdist.Categorical(torch.ones(n_noise_comps))
    comp = pdist.Independent(
        pdist.MultivariateStudentT(df=2, loc=loc, scale_tril=scale_tril), 0,
    )
    gmm = pdist.MixtureSameFamily(mix, comp)
    # Save next to the per-observation files (under self.path), not relative to the CWD.
    torch.save(gmm, files_dir / "gmm.torch")

    permutation_idx = torch.from_numpy(rng.permutation(dim_data))
    torch.save(permutation_idx, files_dir / "permutation_idx.torch")

    torch.manual_seed(42)
    for num_observation in range(1, self.num_observations + 1):
        observation = self.get_observation(num_observation)
        noise = gmm.sample().reshape((1, -1)).type(observation.dtype)
        observation_and_noise = torch.cat([observation, noise], dim=1)

        path = (
            files_dir
            / f"num_observation_{num_observation}"
            / "observation_distractors.csv"
        )
        # Set after get_observation and before save_data, matching the original
        # ordering; presumably save_data reads self.dim_data — verify.
        self.dim_data = dim_data
        self.save_data(path, observation_and_noise[:, permutation_idx])