Example #1
def test_independent(base_dist, sample_shape, batch_shape,
                     reinterpreted_batch_ndims):
    if batch_shape:
        base_dist = base_dist.expand_by(batch_shape)
    if reinterpreted_batch_ndims > len(base_dist.batch_shape):
        with pytest.raises(ValueError):
            d = dist.Independent(base_dist, reinterpreted_batch_ndims)
    else:
        d = dist.Independent(base_dist, reinterpreted_batch_ndims)
        assert (d.batch_shape == batch_shape[:len(batch_shape) -
                                             reinterpreted_batch_ndims])
        assert (d.event_shape == batch_shape[len(batch_shape) -
                                             reinterpreted_batch_ndims:] +
                base_dist.event_shape)

        assert d.sample().shape == batch_shape + base_dist.event_shape
        assert d.mean.shape == batch_shape + base_dist.event_shape
        assert d.variance.shape == batch_shape + base_dist.event_shape
        x = d.sample(sample_shape)
        assert x.shape == sample_shape + d.batch_shape + d.event_shape

        log_prob = d.log_prob(x)
        assert (log_prob.shape == sample_shape +
                batch_shape[:len(batch_shape) - reinterpreted_batch_ndims])
        assert not torch_isnan(log_prob)
        log_prob_0 = base_dist.log_prob(x)
        assert_equal(log_prob,
                     _sum_rightmost(log_prob_0, reinterpreted_batch_ndims))
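For reference, a minimal, self-contained sketch of the shape semantics this test exercises; it uses torch.distributions directly, which pyro's dist.Independent mirrors.

import torch
from torch.distributions import Independent, Normal

base = Normal(torch.zeros(3, 4), torch.ones(3, 4))      # batch_shape (3, 4), event_shape ()
d = Independent(base, 1)                                 # reinterpret the rightmost batch dim as an event dim
assert d.batch_shape == torch.Size([3])
assert d.event_shape == torch.Size([4])
assert d.log_prob(d.sample()).shape == torch.Size([3])   # one joint log-prob per remaining batch element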
Example #2
def test_kl_independent_normal(batch_shape, event_shape):
    shape = batch_shape + event_shape
    p = dist.Normal(torch.randn(shape), torch.randn(shape).exp())
    q = dist.Normal(torch.randn(shape), torch.randn(shape).exp())
    actual = kl_divergence(dist.Independent(p, len(event_shape)),
                           dist.Independent(q, len(event_shape)))
    expected = sum_rightmost(kl_divergence(p, q), len(event_shape))
    assert_close(actual, expected)
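A minimal sketch of the identity this test checks: the KL divergence between Independent-wrapped distributions equals the elementwise KL summed over the reinterpreted dimensions (shown here with torch.distributions, which registers this case).

import torch
from torch.distributions import Independent, Normal, kl_divergence

p = Normal(torch.zeros(2, 3), torch.ones(2, 3))
q = Normal(torch.ones(2, 3), 2 * torch.ones(2, 3))
actual = kl_divergence(Independent(p, 1), Independent(q, 1))   # shape (2,)
expected = kl_divergence(p, q).sum(-1)                         # sum over the event dimension
assert torch.allclose(actual, expected)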
Example #3
def prs_model(beta_hat, obs_error):
    z = pyro.sample(
        'z',
        dist.Independent(dist.Bernoulli(torch.tensor([p_causal]*N)), 1)
    )
    beta = pyro.sample(
        'beta_latent',
        dist.Independent(dist.Normal(GENETIC_MEAN,
                                     GENETIC_SD), 1)
    )
    beta_hat = pyro.sample(
        'beta_hat',
        dist.MultivariateNormal(torch.mv(obs_error, beta*z),
                                covariance_matrix=obs_error*sigma_sq_e),
        obs=beta_hat
    )
    return beta_hat
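The model above relies on module-level names the snippet does not show (N, p_causal, GENETIC_MEAN, GENETIC_SD, sigma_sq_e). A hypothetical set-up with purely illustrative values might look like this:

import torch

N = 100                            # number of genetic variants (illustrative)
p_causal = 0.1                     # prior probability that a variant is causal (illustrative)
GENETIC_MEAN = torch.zeros(N)      # prior mean of the latent effect sizes
GENETIC_SD = torch.ones(N) * 0.05  # prior scale of the latent effect sizes
sigma_sq_e = 1e-3                  # residual variance scaling the observation covariance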
Example #4
def test_kl_independent_delta_mvn_shape(batch_shape, size):
    v = torch.randn(batch_shape + (size, ))
    p = dist.Independent(dist.Delta(v), 1)

    loc = torch.randn(batch_shape + (size, ))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)
    assert kl_divergence(p, q).shape == batch_shape
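Why the result has shape batch_shape: for a point mass at v, KL(Delta(v) || q) reduces to -q.log_prob(v), since pyro's Delta has zero log-density at its support point by default, and q.log_prob returns one value per batch element. A small sketch under that assumption:

import torch
import pyro.distributions as dist
from torch.distributions import kl_divergence

v = torch.randn(5, 3)
p = dist.Independent(dist.Delta(v), 1)
q = dist.MultivariateNormal(torch.zeros(5, 3),
                            covariance_matrix=torch.eye(3).expand(5, 3, 3))
assert kl_divergence(p, q).shape == (5,)
assert torch.allclose(kl_divergence(p, q), -q.log_prob(v))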
Example #5
def prs_guide(index):
    psi_causal = pyro.param(
        'var_psi_causal_{}'.format(index),
        torch.tensor(np.ones(N)*p_causal),
        constraint=constraints.unit_interval
    )
    z = pyro.sample(
        'z',
        dist.Independent(dist.Bernoulli(psi_causal), 1)
    )
    means = pyro.param(
        'var_mean_{}'.format(index),
        torch.tensor(np.zeros(N))
    )
    scales = pyro.param(
        'var_scale_{}'.format(index),
        torch.tensor(np.ones(N)),
        constraint=constraints.positive
    )
    beta_latent = pyro.sample(
        'beta_latent',
        dist.Independent(dist.Normal(means, scales), 1)
    )
    return z, beta_latent
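For context, a mean-field guide like this is normally trained against its model with pyro's SVI. A small, self-contained sketch of that pattern (the model, guide, and data here are illustrative placeholders, not the PRS model above):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(10)

def model():
    w = pyro.sample("w", dist.Independent(dist.Normal(torch.zeros(10), torch.ones(10)), 1))
    pyro.sample("obs", dist.Independent(dist.Normal(w, torch.ones(10)), 1), obs=data)

def guide():
    loc = pyro.param("loc", torch.zeros(10))
    scale = pyro.param("scale", torch.ones(10), constraint=dist.constraints.positive)
    pyro.sample("w", dist.Independent(dist.Normal(loc, scale), 1))

svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for _ in range(100):
    svi.step()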
Example #6
    def _generate_noise_dist_parameters(self):
        import numpy as np

        noise_dim = 92
        n_noise_comps = 20

        rng = np.random
        rng.seed(42)

        loc = torch.from_numpy(
            np.array([
                15 * rng.normal(size=noise_dim) for i in range(n_noise_comps)
            ]))

        cholesky_factors = [
            np.tril(rng.normal(size=(noise_dim, noise_dim))) +
            np.diag(np.exp(rng.normal(size=noise_dim)))
            for i in range(n_noise_comps)
        ]
        scale_tril = torch.from_numpy(3 * np.array(cholesky_factors))

        mix = pdist.Categorical(torch.ones(n_noise_comps, ))
        comp = pdist.Independent(
            pdist.MultivariateStudentT(df=2, loc=loc, scale_tril=scale_tril),
            0,
        )
        gmm = pdist.MixtureSameFamily(mix, comp)
        torch.save(gmm, "files/gmm.torch")

        permutation_idx = torch.from_numpy(rng.permutation(noise_dim + 8))
        torch.save(permutation_idx, "files/permutation_idx.torch")

        torch.manual_seed(42)

        for i in range(self.num_observations):
            num_observation = i + 1

            observation = self.get_observation(num_observation)
            noise = gmm.sample().reshape((1, -1)).type(observation.dtype)

            observation_and_noise = torch.cat([observation, noise], dim=1)

            path = (self.path / "files" /
                    f"num_observation_{num_observation}" /
                    "observation_distractors.csv")
            self.dim_data = noise_dim + 8
            self.save_data(path, observation_and_noise[:, permutation_idx])
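A minimal sketch of the mixture construction used above, with small shapes; MultivariateNormal stands in for MultivariateStudentT here, and the Independent wrapper with 0 reinterpreted dims is a no-op that leaves batch and event shapes unchanged.

import torch
import pyro.distributions as pdist

mix = pdist.Categorical(torch.ones(3))            # 3 equally weighted components
comp = pdist.Independent(
    pdist.MultivariateNormal(torch.randn(3, 4),
                             scale_tril=torch.eye(4).expand(3, 4, 4)),
    0,                                            # reinterpret no batch dims; shapes unchanged
)
gmm = pdist.MixtureSameFamily(mix, comp)
assert gmm.sample().shape == (4,)                 # one draw from the mixture of 4-dimensional components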