コード例 #1
0
def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """Smoke-test LinearHMMReparam: after reparameterization the sample site
    should expose a TransformedDistribution over a GaussianHMM base."""
    time_shape = batch_shape + (duration, )
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(time_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(time_shape, hidden_dim)
    obs_mat = torch.randn(time_shape + (hidden_dim, obs_dim))
    loc = torch.randn(time_shape + (obs_dim, ))
    scale = torch.rand(time_shape + (obs_dim, )).exp()
    obs_dist = dist.LogNormal(loc, scale).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    with poutine.trace() as tr:
        with poutine.reparam(config={"x": LinearHMMReparam()}):
            model(data)
    fn = tr.trace.nodes["x"]["fn"]
    assert isinstance(fn, dist.TransformedDistribution)
    assert isinstance(fn.base_dist, dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
コード例 #2
0
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    """Check batch/event shapes of GaussianHMM and of its filter() output
    under broadcasting of all five parameter shapes."""
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        # Replace MVN observation noise by an equivalent diagonal Normal.
        diag_scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, diag_scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    # init contributes a singleton time dim to the broadcast.
    shape = broadcast_shape(init_shape + (1,), trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == batch_shape
    assert d.event_shape == event_shape

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    log_prob = d.log_prob(data)
    assert log_prob.shape == batch_shape
    check_expand(d, data)

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)
コード例 #3
0
ファイル: test_hmm.py プロジェクト: yufengwa/pyro
def test_gamma_gaussian_hmm_shape(scale_shape, init_shape, trans_mat_shape,
                                  trans_mvn_shape, obs_mat_shape,
                                  obs_mvn_shape, hidden_dim, obs_dim):
    """Check batch/event shapes of GammaGaussianHMM and of its filter()
    output, which is a (Gamma, MultivariateNormal) pair."""
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    scale_dist = random_gamma(scale_shape)
    d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                              obs_mat, obs_dist)

    # scale and init each contribute a singleton time dim to the broadcast.
    shape = broadcast_shape(scale_shape + (1, ), init_shape + (1, ),
                            trans_mat_shape, trans_mvn_shape, obs_mat_shape,
                            obs_mvn_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == batch_shape
    assert d.event_shape == event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    log_prob = d.log_prob(data)
    assert log_prob.shape == batch_shape
    check_expand(d, data)

    mixing, final = d.filter(data)
    assert isinstance(mixing, dist.Gamma)
    assert mixing.batch_shape == d.batch_shape
    assert mixing.event_shape == ()
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim, )
コード例 #4
0
ファイル: test_hmm.py プロジェクト: pyro-ppl/pyro
def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps,
                                     hidden_dim, obs_dim):
    """Compare GammaGaussianHMM.log_prob against a hand-unrolled joint
    gamma-gaussian over all time steps, contracted with
    gamma_gaussian_tensordot()."""
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)
    scale_dist = random_gamma(batch_shape)
    d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                              obs_mat, obs_dist)
    obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussian-gammas with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gamma_gaussian_tensordot().
    T = num_steps
    init = gamma_and_mvn_to_gamma_gaussian(scale_dist, init_dist)
    trans = matrix_and_mvn_to_gamma_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gamma_gaussian(obs_mat, obs_mvn)

    # Pad each per-step factor so all factors share one big event space,
    # then sum log-densities by adding the factors.
    unrolled_trans = reduce(
        operator.add,
        [
            trans[..., t].event_pad(left=t * hidden_dim,
                                    right=(T - t - 1) * hidden_dim)
            for t in range(T)
        ],
    )
    unrolled_obs = reduce(
        operator.add,
        [
            obs[..., t].event_pad(left=t * obs.dim(),
                                  right=(T - t - 1) * obs.dim())
            for t in range(T)
        ],
    )
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    # Flatten the (T, obs_dim) event into a single vector event.
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    # Contract out the hidden dims, leaving a factor over observations only.
    logp = gamma_gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gamma_gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    # compute log_prob of the joint student-t distribution
    expected_log_prob = logp.compound().log_prob(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
コード例 #5
0
ファイル: test_hmm.py プロジェクト: pyro-ppl/pyro
def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps,
                               hidden_dim, obs_dim):
    """Compare GaussianMRF.log_prob against a hand-unrolled joint gaussian
    normalized by the partition function over the hidden chain."""
    init_dist = random_mvn(batch_shape, hidden_dim)
    # MRF factors are over pairs: trans couples (hidden, hidden), obs couples
    # (hidden, observed).
    trans_dist = random_mvn(batch_shape + (num_steps, ),
                            hidden_dim + hidden_dim)
    obs_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    # Keep only the observed slice of each sampled (hidden, obs) pair.
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = mvn_to_gaussian(trans_dist)
    obs = mvn_to_gaussian(obs_dist)

    # Pad each per-step factor into one shared event space and add them.
    unrolled_trans = reduce(
        operator.add,
        [
            trans[..., t].event_pad(left=t * hidden_dim,
                                    right=(T - t - 1) * hidden_dim)
            for t in range(T)
        ],
    )
    unrolled_obs = reduce(
        operator.add,
        [
            obs[..., t].event_pad(left=t * obs.dim(),
                                  right=(T - t - 1) * obs.dim())
            for t in range(T)
        ],
    )
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    # logp_oh: joint over hiddens and observations, hiddens contracted out.
    logp_h = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp_oh = gaussian_tensordot(logp_h, unrolled_obs, T * hidden_dim)
    # logp_h: unnormalized marginal over hiddens (obs marginalized out of the
    # obs factor), whose event_logsumexp is the MRF's normalizing constant.
    logp_h += unrolled_obs.marginalize(right=T * obs_dim)
    expected_log_prob = logp_oh.log_density(
        unrolled_data) - logp_h.event_logsumexp()
    assert_close(actual_log_prob, expected_log_prob)
コード例 #6
0
def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """Compare GaussianHMM.log_prob against a hand-unrolled joint gaussian
    contracted with gaussian_tensordot(); `diag` exercises the diagonal
    Normal observation-noise code path."""
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        # Replace MVN observation noise by an equivalent diagonal Normal.
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    if diag:
        # Reconstruct the equivalent MVN for the hand computation below.
        obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc,
                                          scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)

    # Pad each per-step factor into one shared event space and add them.
    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
                     [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    # Flatten the (T, obs_dim) event into a single vector event.
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    # Contract out the hidden dims, leaving a gaussian over observations only.
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
コード例 #7
0
ファイル: test_gaussian.py プロジェクト: youisbaby/pyro
def test_matrix_and_mvn_to_gaussian(sample_shape, batch_shape, x_dim, y_dim):
    """Adding a conditional gaussian p(y|x) and a joint gaussian p(x, y)
    should yield the sum of their log-densities."""
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    xy_mvn = random_mvn(batch_shape, x_dim + y_dim)
    conditional = matrix_and_mvn_to_gaussian(matrix, y_mvn)
    gaussian = conditional + mvn_to_gaussian(xy_mvn)
    xy = torch.randn(sample_shape + (1, ) * len(batch_shape) +
                     (x_dim + y_dim, ))
    x = xy[..., :x_dim]
    y = xy[..., x_dim:]
    # Mean of y given x under the linear-gaussian conditional.
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual = gaussian.log_density(xy)
    expected = xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred)
    assert_close(actual, expected)
コード例 #8
0
ファイル: test_gaussian.py プロジェクト: youisbaby/pyro
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    """Contracting p(x) with the conditional p(y|x) over x should equal the
    analytically-marginalized gaussian p(y)."""
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)
    # Analytic marginal: y = M^T x + noise, with M = matrix.
    mt = matrix.transpose(-2, -1)
    Mx_cov = mt.matmul(x_mvn.covariance_matrix).matmul(matrix)
    Mx_loc = mt.matmul(x_mvn.loc.unsqueeze(-1)).squeeze(-1)
    mvn = dist.MultivariateNormal(Mx_loc + y_mvn.loc,
                                  Mx_cov + y_mvn.covariance_matrix)
    expected = mvn_to_gaussian(mvn)

    actual = gaussian_tensordot(mvn_to_gaussian(x_mvn),
                                matrix_and_mvn_to_gaussian(matrix, y_mvn),
                                dims=x_dim)
    assert_close_gaussian(expected, actual)
コード例 #9
0
ファイル: test_gaussian.py プロジェクト: pyro-ppl/pyro
def test_rsample_shape(sample_shape, batch_shape, dim):
    """Gaussian.rsample should match MultivariateNormal.rsample in dtype and
    shape."""
    mvn = random_mvn(batch_shape, dim)
    gaussian = mvn_to_gaussian(mvn)
    reference = mvn.rsample(sample_shape)
    sample = gaussian.rsample(sample_shape)
    assert sample.dtype == reference.dtype
    assert sample.shape == reference.shape
コード例 #10
0
ファイル: test_gaussian.py プロジェクト: youisbaby/pyro
def test_mvn_to_gaussian(sample_shape, batch_shape, dim):
    """mvn_to_gaussian should preserve log-density pointwise."""
    mvn = random_mvn(batch_shape, dim)
    gaussian = mvn_to_gaussian(mvn)
    value = mvn.sample(sample_shape)
    assert_close(gaussian.log_density(value), mvn.log_prob(value))
コード例 #11
0
def test_gaussian_mrf_shape(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    """Check batch/event shapes of GaussianMRF under broadcasting."""
    init_dist = random_mvn(init_shape, hidden_dim)
    # trans couples (hidden, hidden); obs couples (hidden, observed).
    trans_dist = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs_dist = random_mvn(obs_shape, hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)

    shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == batch_shape
    assert d.event_shape == event_shape

    # Keep only the observed slice of each sampled (hidden, obs) pair.
    data = obs_dist.expand(shape).sample()[..., hidden_dim:]
    assert data.shape == d.shape()
    log_prob = d.log_prob(data)
    assert log_prob.shape == batch_shape
    check_expand(d, data)
コード例 #12
0
ファイル: test_hmm.py プロジェクト: yufengwa/pyro
def test_independent_hmm_shape(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    """Check IndependentHMM shapes when wrapping a batch of obs_dim-many
    one-dimensional GaussianHMMs with duration 6."""

    def insert_obs_dim(shape):
        # Insert an obs_dim batch dim before the time dim; an empty shape
        # defaults the time dim to the HMM duration of 6.
        time = shape[-1] if shape else 6
        return shape[:-1] + (obs_dim, time)

    init_dist = random_mvn(init_shape + (obs_dim, ), hidden_dim)
    trans_mat = torch.randn(
        insert_obs_dim(trans_mat_shape) + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(insert_obs_dim(trans_mvn_shape), hidden_dim)
    obs_mat = torch.randn(insert_obs_dim(obs_mat_shape) + (hidden_dim, 1))
    obs_dist = random_mvn(insert_obs_dim(obs_mvn_shape), 1)
    base = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                            obs_dist, duration=6)
    d = dist.IndependentHMM(base)

    shape = broadcast_shape(init_shape + (6, ), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == batch_shape
    assert d.event_shape == event_shape
    assert d.support.event_dim == d.event_dim

    data = torch.randn(shape + (obs_dim, ))
    assert data.shape == d.shape()
    assert d.log_prob(data).shape == batch_shape
    check_expand(d, data)

    assert d.rsample().shape == d.shape()
    assert d.rsample((6, )).shape == (6, ) + d.shape()
    assert d.expand((6, 5)).rsample().shape == (6, 5) + d.event_shape
コード例 #13
0
def test_gaussian_mrf_log_prob_block_diag(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """With block-diagonal obs precision the observations decouple from the
    hidden chain, so GaussianMRF.log_prob reduces to the sum over time of the
    marginal observation log-probs."""
    obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim)
    precision = obs_dist.precision_matrix
    # Zero the cross blocks so hidden and observed parts are independent.
    precision[..., :hidden_dim, hidden_dim:] = 0
    precision[..., hidden_dim:, :hidden_dim] = 0
    obs_dist = dist.MultivariateNormal(obs_dist.loc, precision_matrix=precision)
    marginal_obs_dist = dist.MultivariateNormal(
        obs_dist.loc[..., hidden_dim:],
        precision_matrix=precision[..., hidden_dim:, hidden_dim:])

    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual = d.log_prob(data)
    expected = marginal_obs_dist.log_prob(data).sum(-1)
    assert_close(actual, expected)
コード例 #14
0
ファイル: test_util.py プロジェクト: pyro-ppl/pyro
def random_dist(Dist, shape, transform=None):
    """Create a random instance of distribution class ``Dist`` whose
    ``.shape()`` equals ``shape``, for use in shape-based tests.

    For HMM classes, ``shape`` is interpreted as
    ``batch_shape + (duration, obs_dim)``.  ``transform`` is only used for
    LinearHMM, where it wraps the observation distribution.
    """
    if Dist is dist.FoldedDistribution:
        return Dist(random_dist(dist.Normal, shape))
    elif Dist is dist.MaskedDistribution:
        base_dist = random_dist(dist.Normal, shape)
        # Random boolean mask with ~50% coverage.
        mask = torch.empty(shape, dtype=torch.bool).bernoulli_(0.5)
        return base_dist.mask(mask)
    elif Dist is dist.TransformedDistribution:
        base_dist = random_dist(dist.Normal, shape)
        transforms = [
            dist.transforms.ExpTransform(),
            dist.transforms.ComposeTransform([
                dist.transforms.AffineTransform(1, 1),
                dist.transforms.ExpTransform().inv,
            ]),
        ]
        return dist.TransformedDistribution(base_dist, transforms)
    elif Dist in (dist.GaussianHMM, dist.LinearHMM):
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        # Arbitrary hidden size; just needs to differ from obs_dim.
        hidden_dim = obs_dim + 1
        init_dist = random_dist(dist.Normal,
                                batch_shape + (hidden_dim, )).to_event(1)
        trans_mat = torch.randn(batch_shape +
                                (duration, hidden_dim, hidden_dim))
        trans_dist = random_dist(dist.Normal, batch_shape +
                                 (duration, hidden_dim)).to_event(1)
        obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
        obs_dist = random_dist(dist.Normal,
                               batch_shape + (duration, obs_dim)).to_event(1)
        if Dist is dist.LinearHMM and transform is not None:
            obs_dist = dist.TransformedDistribution(obs_dist, transform)
        return Dist(init_dist,
                    trans_mat,
                    trans_dist,
                    obs_mat,
                    obs_dist,
                    duration=duration)
    elif Dist is dist.IndependentHMM:
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        # One scalar-observation HMM per obs_dim, combined by IndependentHMM.
        base_shape = batch_shape + (obs_dim, duration, 1)
        base_dist = random_dist(dist.GaussianHMM, base_shape)
        return Dist(base_dist)
    elif Dist is dist.MultivariateNormal:
        return random_mvn(shape[:-1], shape[-1])
    elif Dist is dist.Uniform:
        low = torch.randn(shape)
        # Guarantee high > low elementwise.
        high = low + torch.randn(shape).exp()
        return Dist(low, high)
    else:
        # Generic univariate case: sample each constrained parameter by
        # transforming an unconstrained uniform draw.
        params = {
            name:
            transform_to(Dist.arg_constraints[name])(torch.rand(shape) - 0.5)
            for name in UNIVARIATE_DISTS[Dist]
        }
        return Dist(**params)
コード例 #15
0
ファイル: test_conjugate.py プロジェクト: pyro-ppl/pyro
def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim):
    """Check that ConjugateReparam with a GaussianHMM prior preserves both the
    ELBO value and its gradients w.r.t. the HMM parameters."""
    init_dist = random_mvn(batch_shape, hidden_dim)
    # Matrices require grad so we can compare parameter gradients below.
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)

    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)
    prior = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                             obs_dist)
    likelihood = dist.Normal(data, 1).to_event(2)
    # Analytic posterior used as the (exact) guide for the baseline model.
    posterior, log_normalizer = prior.conjugate_update(likelihood)

    def model(data):
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", prior)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    def guide(data):
        with pyro.plate_stack("plates", batch_shape):
            pyro.sample("z", posterior)

    # Reparameterized model absorbs the likelihood into the "z" site, so its
    # guide has nothing left to sample.
    reparam_model = poutine.reparam(model, {"z": ConjugateReparam(likelihood)})

    def reparam_guide(data):
        pass

    elbo = Trace_ELBO(num_particles=1000, vectorize_particles=True)
    expected_loss = elbo.differentiable_loss(model, guide, data)
    actual_loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    assert_close(actual_loss, expected_loss, atol=0.01)

    params = [trans_mat, obs_mat]
    expected_grads = torch.autograd.grad(expected_loss,
                                         params,
                                         retain_graph=True)
    actual_grads = torch.autograd.grad(actual_loss, params, retain_graph=True)
    for a, e in zip(actual_grads, expected_grads):
        assert_close(a, e, rtol=0.01)
コード例 #16
0
ファイル: test_hmm.py プロジェクト: yufengwa/pyro
def test_gaussian_hmm_high_obs_dim():
    """Smoke-test GaussianHMM.rsample when obs_dim greatly exceeds
    hidden_dim."""
    hidden_dim, obs_dim, duration = 1, 1000, 10
    sample_shape = (100, )
    init_dist = random_mvn((), hidden_dim)
    trans_mat = torch.randn((duration, hidden_dim, hidden_dim))
    trans_dist = random_mvn((duration, ), hidden_dim)
    obs_mat = torch.randn((duration, hidden_dim, obs_dim))
    obs_loc = torch.randn((duration, obs_dim))
    obs_scale = torch.randn((duration, obs_dim)).exp()
    obs_dist = dist.Normal(obs_loc, obs_scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)
    samples = d.rsample(sample_shape)
    assert samples.shape == sample_shape + (duration, obs_dim)
コード例 #17
0
def test_gamma_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, dim):
    """The gamma-gaussian log_density should factor as gamma.log_prob plus
    the log_prob of an MVN whose precision is scaled by the gamma draw."""
    gamma = random_gamma(batch_shape)
    mvn = random_mvn(batch_shape, dim)
    g = gamma_and_mvn_to_gamma_gaussian(gamma, mvn)
    value = mvn.sample(sample_shape)
    s = gamma.sample(sample_shape)
    actual = g.log_density(value, s)

    scaled_prec = mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1)
    scaled_mvn = dist.MultivariateNormal(mvn.loc, precision_matrix=scaled_prec)
    expected = gamma.log_prob(s) + scaled_mvn.log_prob(value)
    assert_close(actual, expected)
コード例 #18
0
def test_matrix_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, x_dim,
                                          y_dim):
    """Check the conditional gamma-gaussian p(y|x, s): its log_density equals
    the log_prob of an MVN centered at x @ matrix + loc with precision scaled
    by s."""
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    g = matrix_and_mvn_to_gamma_gaussian(matrix, y_mvn)
    xy = torch.randn(sample_shape + batch_shape + (x_dim + y_dim, ))
    s = torch.rand(sample_shape + batch_shape)
    actual = g.log_density(xy, s)

    x = xy[..., :x_dim]
    y = xy[..., x_dim:]
    loc = x.unsqueeze(-2).matmul(matrix).squeeze(-2) + y_mvn.loc
    scaled_prec = y_mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1)
    expected = dist.MultivariateNormal(
        loc, precision_matrix=scaled_prec).log_prob(y)
    assert_close(actual, expected)
コード例 #19
0
ファイル: test_gaussian.py プロジェクト: pyro-ppl/pyro
def test_rsample_distribution(batch_shape, dim):
    """Compare empirical mean/std/correlation of Gaussian.rsample against the
    equivalent MultivariateNormal.rsample."""
    num_samples = 20000
    mvn = random_mvn(batch_shape, dim)
    g = mvn_to_gaussian(mvn)
    expected = mvn.rsample((num_samples, ))
    actual = g.rsample((num_samples, ))

    def moments(x):
        mean = x.mean(0)
        centered = x - mean
        cov = (centered.unsqueeze(-1) * centered.unsqueeze(-2)).mean(0)
        std = cov.diagonal(dim1=-1, dim2=-2).sqrt()
        corr = cov / (std.unsqueeze(-1) * std.unsqueeze(-2))
        return mean, std, corr

    expected_mean, expected_std, expected_corr = moments(expected)
    actual_mean, actual_std, actual_corr = moments(actual)
    assert_close(actual_mean, expected_mean, atol=0.1, rtol=0.02)
    assert_close(actual_std, expected_std, atol=0.1, rtol=0.02)
    assert_close(actual_corr, expected_corr, atol=0.05)
コード例 #20
0
def random_dist(Dist, shape, transform=None):
    """Create a random instance of distribution class ``Dist`` whose
    ``.shape()`` equals ``shape``, for use in shape-based tests.

    For HMM classes, ``shape`` is interpreted as
    ``batch_shape + (duration, obs_dim)``.  ``transform`` is only used for
    LinearHMM, where it wraps the observation distribution.
    """
    if Dist is dist.FoldedDistribution:
        return Dist(random_dist(dist.Normal, shape))
    elif Dist in (dist.GaussianHMM, dist.LinearHMM):
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        # Arbitrary hidden size; just needs to differ from obs_dim.
        hidden_dim = obs_dim + 1
        init_dist = random_dist(dist.Normal,
                                batch_shape + (hidden_dim, )).to_event(1)
        trans_mat = torch.randn(batch_shape +
                                (duration, hidden_dim, hidden_dim))
        trans_dist = random_dist(dist.Normal, batch_shape +
                                 (duration, hidden_dim)).to_event(1)
        obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
        obs_dist = random_dist(dist.Normal,
                               batch_shape + (duration, obs_dim)).to_event(1)
        if Dist is dist.LinearHMM and transform is not None:
            obs_dist = dist.TransformedDistribution(obs_dist, transform)
        return Dist(init_dist,
                    trans_mat,
                    trans_dist,
                    obs_mat,
                    obs_dist,
                    duration=duration)
    elif Dist is dist.IndependentHMM:
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        # One scalar-observation HMM per obs_dim, combined by IndependentHMM.
        base_shape = batch_shape + (obs_dim, duration, 1)
        base_dist = random_dist(dist.GaussianHMM, base_shape)
        return Dist(base_dist)
    elif Dist is dist.MultivariateNormal:
        return random_mvn(shape[:-1], shape[-1])
    else:
        # Generic univariate case: sample each constrained parameter by
        # transforming an unconstrained uniform draw.
        params = {
            name:
            transform_to(Dist.arg_constraints[name])(torch.rand(shape) - 0.5)
            for name in UNIVARIATE_DISTS[Dist]
        }
        return Dist(**params)
コード例 #21
0
ファイル: test_hmm.py プロジェクト: yufengwa/pyro
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps,
                                   hidden_dim, obs_dim):
    """End-to-end check of GaussianHMM: log_prob against a hand-unrolled
    joint gaussian, consistency of conjugate_update, and (for the unbatched
    case) empirical moments of prior and posterior samples."""
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)
    if diag:
        # Replace MVN observation noise by an equivalent diagonal Normal.
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist,
                         trans_mat,
                         trans_dist,
                         obs_mat,
                         obs_dist,
                         duration=num_steps)
    if diag:
        # Reconstruct the equivalent MVN for the hand computation below.
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    #    like |         O O O
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    # Extra likelihood factor used below to test conjugate_update moments.
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    # Pad each per-step factor into one shared event space and add them.
    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(),
                              right=(T - t - 1) * obs.dim()) for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    # Flatten the (T, obs_dim) event into a single vector event.
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    # Contract out the hidden dims, leaving a gaussian over observations only.
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    # conjugate_update must preserve the joint density up to log_normalizer.
    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(
        d.log_prob(data) + like_dist.log_prob(data),
        d_posterior.log_prob(data) + log_normalizer)

    # Moment checks below assume scalar batch/sample shapes.
    if batch_shape or sample_shape:
        return

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    for name, d, g in [prior, posterior]:
        logging.info("testing {} moments".format(name))
        with torch.no_grad():
            num_samples = 100000
            samples = d.sample([num_samples]).reshape(num_samples, T * obs_dim)
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) *
                                        actual_std.unsqueeze(-2))

            # Analytic moments from the unrolled gaussian's info-form params.
            expected_cov = g.precision.cholesky().cholesky_inverse()
            expected_mean = expected_cov.matmul(
                g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) *
                                            expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)
コード例 #22
0
ファイル: test_hmm.py プロジェクト: yufengwa/pyro
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    """Exercise shape semantics of a duration-6 GaussianHMM: log_prob,
    rsample, expand, conjugate_update, filter, rsample_posterior and
    prefix_condition."""
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        # Replace MVN observation noise by an equivalent diagonal Normal.
        diag_scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, diag_scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=6)

    shape = broadcast_shape(init_shape + (6, ), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == batch_shape
    assert d.event_shape == event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    assert d.log_prob(data).shape == batch_shape
    check_expand(d, data)

    assert d.rsample().shape == d.shape()
    assert d.rsample((6, )).shape == (6, ) + d.shape()
    assert d.expand((6, 5)).rsample().shape == (6, 5) + d.event_shape

    likelihood = dist.Normal(data, 1).to_event(2)
    p, log_normalizer = d.conjugate_update(likelihood)
    assert p.batch_shape == d.batch_shape
    assert p.event_shape == d.event_shape
    assert p.rsample().shape == d.shape()
    assert p.rsample((6, )).shape == (6, ) + d.shape()
    assert p.expand((6, 5)).rsample().shape == (6, 5) + d.event_shape

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim, )

    z = d.rsample_posterior(data)
    assert z.shape == batch_shape + time_shape + (hidden_dim, )

    for t in range(1, d.duration - 1):
        remaining = d.duration - t
        conditioned = d.prefix_condition(data[..., :t, :])
        assert conditioned.batch_shape == d.batch_shape
        assert conditioned.event_shape == (remaining, obs_dim)