예제 #1
0
파일: test_hmm.py 프로젝트: yufengwa/pyro
def test_stable_hmm_shape(init_shape, trans_mat_shape, trans_dist_shape,
                          obs_mat_shape, obs_dist_shape, hidden_dim, obs_dim):
    """Check distribution and sample shapes of a Stable-noise LinearHMM.

    One stability value, drawn from ``Uniform(0, 2)``, is shared by the
    initial, transition and observation noise. The HMM is built with an
    explicit duration of 4 steps. The test asserts that:

    * ``batch_shape`` / ``event_shape`` match the broadcast of all
      parameter shapes (with the time axis split off the end);
    * ``rsample`` honours both sample shapes and ``expand``.
    """
    num_steps = 4
    stability = dist.Uniform(0, 2).sample()

    def stable_noise(leading_shape, size):
        # Stable noise over a trailing dim of ``size``, treated as one event.
        return random_stable(stability,
                             leading_shape + (size, )).to_event(1)

    init_dist = stable_noise(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = stable_noise(trans_dist_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = stable_noise(obs_dist_shape, obs_dim)
    d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                       duration=num_steps)

    # The last broadcast dim is time; everything to its left is batch.
    full_shape = broadcast_shape(init_shape + (num_steps, ), trans_mat_shape,
                                 trans_dist_shape, obs_mat_shape,
                                 obs_dist_shape)
    expected_batch_shape = full_shape[:-1]
    expected_event_shape = full_shape[-1:] + (obs_dim, )
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    assert d.rsample().shape == d.shape()
    assert d.rsample((6, )).shape == (6, ) + d.shape()
    expanded = d.expand((6, 5))
    assert expanded.rsample().shape == (6, 5) + d.event_shape
예제 #2
0
파일: test_hmm.py 프로젝트: pyro-ppl/pyro
def test_studentt_hmm_shape(
    init_shape,
    trans_mat_shape,
    trans_dist_shape,
    obs_mat_shape,
    obs_dist_shape,
    hidden_dim,
    obs_dim,
):
    """Check distribution and sample shapes of a Student-t-noise LinearHMM.

    No ``duration`` is passed to ``LinearHMM``. The expected shape is taken
    from broadcasting a length-1 time axis on ``init_shape`` against the
    transition/observation parameter shapes; the last broadcast dim is time.
    """
    def studentt_event(shape):
        # Student-t noise whose rightmost dim is declared as the event.
        return random_studentt(shape).to_event(1)

    init_dist = studentt_event(init_shape + (hidden_dim, ))
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = studentt_event(trans_dist_shape + (hidden_dim, ))
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = studentt_event(obs_dist_shape + (obs_dim, ))
    d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    broadcasted = broadcast_shape(
        init_shape + (1, ),
        trans_mat_shape,
        trans_dist_shape,
        obs_mat_shape,
        obs_dist_shape,
    )
    assert d.batch_shape == broadcasted[:-1]
    assert d.event_shape == broadcasted[-1:] + (obs_dim, )
    assert d.support.event_dim == d.event_dim

    assert d.rsample().shape == d.shape()
    assert d.rsample((6, )).shape == (6, ) + d.shape()
    expanded = d.expand((6, 5))
    assert expanded.rsample().shape == (6, 5) + d.event_shape
예제 #3
0
    def model(self, zero_data, covariates):
        """Cumulative Normal jumps plus Stable LinearHMM noise.

        Inside a ``"batch"`` plate at ``dim=-2``:

        * a per-series ``"scale"`` is drawn from a LogNormal whose
          log-location is the first time step of the (subsampled) data;
        * ``"jumps"`` are drawn per time step under ``self.time_plate``
          and cumulatively summed along dim -2 to form the prediction;
        * ``self.predict`` is called with a LinearHMM whose init, transition
          and observation noise are all ``Stable(1.9, 0)`` with identity
          matrices. It runs under a ``LinearHMMReparam`` configured for the
          ``"residual"`` site (presumably the site ``self.predict`` creates
          — TODO confirm).
        """
        with pyro.plate("batch", len(zero_data), dim=-2):
            data = pyro.subsample(zero_data, event_dim=1)
            # Subsampled in step with the data; not otherwise used here.
            covariates = pyro.subsample(covariates, event_dim=1)

            first_step = data[..., :1, :]
            scale = pyro.sample("scale",
                                dist.LogNormal(first_step, 1).to_event(1))

            with self.time_plate:
                jumps = pyro.sample("jumps",
                                    dist.Normal(0, scale).to_event(1))
            trend = jumps.cumsum(-2)

            num_steps, obs_dim = data.shape[-2:]

            def stable_noise():
                # Fresh heavy-tailed noise distribution over obs_dim.
                return dist.Stable(1.9, 0).expand([obs_dim]).to_event(1)

            noise_dist = dist.LinearHMM(
                stable_noise(),
                torch.eye(obs_dim),
                stable_noise(),
                torch.eye(obs_dim),
                stable_noise(),
                duration=num_steps,
            )
            stable_rep = StableReparam()
            hmm_rep = LinearHMMReparam(stable_rep, stable_rep, stable_rep)
            with poutine.reparam(config={"residual": hmm_rep}):
                self.predict(noise_dist, trend)
예제 #4
0
파일: test_hmm.py 프로젝트: youisbaby/pyro
def test_stable_hmm_distribution(stability, skew, duration, hidden_dim,
                                 obs_dim):
    """Compare moments of a Stable LinearHMM with and without reparam.

    Samples drawn directly from the HMM are compared, via
    ``get_hmm_moments``, with samples drawn under ``LinearHMMReparam``.
    ``SymmetricStableReparam`` is used when ``skew == 0``, otherwise
    ``StableReparam``. Location and scale must agree within 0.05 (abs and
    rel); correlations must agree within 0.01 absolute.
    """
    def stable_event(shape):
        return random_stable(shape, stability, skew=skew).to_event(1)

    init_dist = stable_event((hidden_dim, ))
    trans_mat = torch.randn(duration, hidden_dim, hidden_dim)
    trans_dist = stable_event((duration, hidden_dim))
    obs_mat = torch.randn(duration, hidden_dim, obs_dim)
    obs_dist = stable_event((duration, obs_dim))
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    num_samples = 200000
    # Flatten each sample's (duration, obs_dim) event into one row.
    flat_shape = (num_samples, duration * obs_dim)

    reference = hmm.sample([num_samples]).reshape(flat_shape)
    expected_loc, expected_scale, expected_corr = get_hmm_moments(reference)

    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()
    with pyro.plate("samples", num_samples):
        with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
            reparam_samples = pyro.sample("x", hmm).reshape(flat_shape)
    actual_loc, actual_scale, actual_corr = get_hmm_moments(reparam_samples)

    assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05)
    assert_close(actual_scale, expected_scale, atol=0.05, rtol=0.05)
    assert_close(actual_corr, expected_corr, atol=0.01)
예제 #5
0
파일: test_hmm.py 프로젝트: youisbaby/pyro
def test_stable_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """Reparametrizing a batched Stable LinearHMM yields a GaussianHMM site.

    Data is first sampled from the model. The model is then replayed on that
    data under ``LinearHMMReparam``, and the traced ``"x"`` site must be a
    ``GaussianHMM``. ``compute_log_prob`` is only a smoke test.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    # Per-batch stability, with singleton dims appended so it broadcasts
    # against one (init) or two (trans/obs) trailing dims.
    stability_1d = stability.unsqueeze(-1)
    stability_2d = stability_1d.unsqueeze(-1)

    init_dist = random_stable(batch_shape + (hidden_dim, ), stability_1d,
                              skew=skew).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_stable(batch_shape + (duration, hidden_dim),
                               stability_2d, skew=skew).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_stable(batch_shape + (duration, obs_dim),
                             stability_2d, skew=skew).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    observed = model()
    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()
    with poutine.trace() as tr, \
            poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
        model(observed)
    x_site = tr.trace.nodes["x"]
    assert isinstance(x_site["fn"], dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
예제 #6
0
def test_studentt_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """Reparametrizing a Student-t LinearHMM yields a GaussianHMM site.

    After replaying sampled data under ``LinearHMMReparam`` with
    ``StudentTReparam`` for every noise term, the ``"x"`` site must be a
    ``GaussianHMM``. The auxiliary ``*_gamma`` sites must have the event
    shapes listed below. ``compute_log_prob`` is only a smoke test.
    """
    init_dist = random_studentt(batch_shape + (hidden_dim, )).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_studentt(
        batch_shape + (duration, hidden_dim)).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_studentt(batch_shape + (duration, obs_dim)).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    observed = model()
    rep = StudentTReparam()
    with poutine.trace() as tr, \
            poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
        model(observed)

    nodes = tr.trace.nodes
    assert isinstance(nodes["x"]["fn"], dist.GaussianHMM)
    # Auxiliary sites introduced by the Student-t reparametrization.
    expected_event_shapes = {
        "x_init_gamma": (hidden_dim, ),
        "x_trans_gamma": (duration, hidden_dim),
        "x_obs_gamma": (duration, obs_dim),
    }
    for site_name, event_shape in expected_event_shapes.items():
        assert nodes[site_name]["fn"].event_shape == event_shape
    tr.trace.compute_log_prob()  # smoke test only
예제 #7
0
def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    """LogNormal observation noise reparametrizes to a transformed HMM.

    With a default-constructed ``LinearHMMReparam()``, the traced ``"x"``
    site must be a ``TransformedDistribution`` whose base is a
    ``GaussianHMM``. ``compute_log_prob`` is only a smoke test.
    """
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (duration, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_shape = batch_shape + (duration, obs_dim)
    obs_loc = torch.randn(obs_shape)
    obs_scale = torch.rand(obs_shape).exp()
    obs_dist = dist.LogNormal(obs_loc, obs_scale).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    observed = model()
    with poutine.trace() as tr, \
            poutine.reparam(config={"x": LinearHMMReparam()}):
        model(observed)
    x_fn = tr.trace.nodes["x"]["fn"]
    assert isinstance(x_fn, dist.TransformedDistribution)
    assert isinstance(x_fn.base_dist, dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
예제 #8
0
파일: test_hmm.py 프로젝트: pyro-ppl/pyro
def test_init_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """Run ``check_init_reparam`` on a batched Stable LinearHMM model.

    First asserts the HMM has the expected batch shape and an event shape
    of ``(duration, obs_dim)``. Then checks initialization under
    ``LinearHMMReparam`` with a Stable reparam: symmetric when ``skew == 0``,
    general otherwise.
    """
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    # Singleton dims appended so stability broadcasts over 1 or 2 dims.
    stability_1d = stability.unsqueeze(-1)
    stability_2d = stability_1d.unsqueeze(-1)

    init_dist = random_stable(batch_shape + (hidden_dim, ), stability_1d,
                              skew=skew).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_stable(batch_shape + (duration, hidden_dim),
                               stability_2d, skew=skew).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_stable(batch_shape + (duration, obs_dim),
                             stability_2d, skew=skew).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (duration, obs_dim)

    def model():
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm)

    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()
    check_init_reparam(model, LinearHMMReparam(rep, rep, rep))
예제 #9
0
 def model(data):
     """Latent LinearHMM ``"z"`` observed through unit-scale Normal noise.

     Relies on ``init_dist``, ``trans_mat``, ``trans_dist``, ``obs_mat``,
     ``obs_dist``, ``num_steps`` and ``batch_shape`` from an enclosing
     scope that is not visible here.
     """
     latent_hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat,
                                 obs_dist, duration=num_steps)
     with pyro.plate_stack("plates", batch_shape):
         z = pyro.sample("z", latent_hmm)
         # Both trailing dims (time, obs) form the likelihood's event.
         likelihood = dist.Normal(z, 1).to_event(2)
         pyro.sample("x", likelihood, obs=data)
예제 #10
0
def test_stable_hmm_shape_error(batch_shape, duration, hidden_dim, obs_dim):
    """Observing mismatched data under ``LinearHMMReparam`` raises.

    The HMM's parameters have a time axis of length 1, so its event shape
    is ``(1, obs_dim)``. The test expects ``ValueError`` when the model is
    run on data of shape ``(duration, obs_dim)`` under a Stable
    ``LinearHMMReparam``.
    """
    single_step = 1
    stability = dist.Uniform(0.5, 2).sample(batch_shape)
    stability_1d = stability.unsqueeze(-1)
    stability_2d = stability_1d.unsqueeze(-1)

    init_dist = random_stable(batch_shape + (hidden_dim, ),
                              stability_1d).to_event(1)
    trans_mat = torch.randn(batch_shape +
                            (single_step, hidden_dim, hidden_dim))
    trans_dist = random_stable(batch_shape + (single_step, hidden_dim),
                               stability_2d).to_event(1)
    obs_mat = torch.randn(batch_shape + (single_step, hidden_dim, obs_dim))
    obs_dist = random_stable(batch_shape + (single_step, obs_dim),
                             stability_2d).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (single_step, obs_dim)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    mismatched = torch.randn(duration, obs_dim)
    rep = StableReparam()
    with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}), \
            pytest.raises(ValueError):
        model(mismatched)