Example #1
    def model(self, zero_data, covariates):
        duration, data_dim = zero_data.shape
        feature_dim = covariates.size(-1)
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with pyro.plate("item", data_dim, dim=-2):
            bias = pyro.sample("bias", dist.Normal(5.0, 10.0))
            with pyro.plate("features", feature_dim, dim=-1):
                weight = pyro.sample("weight", dist.Normal(0.0, 5))

            # We'll sample a time-global scale parameter outside the time plate,
            # then time-local iid noise inside the time plate.
            with self.time_plate:
                # We combine two different reparameterizers: the inner SymmetricStableReparam
                # is needed for the Stable site, and the outer LocScaleReparam is optional but
                # appears to improve inference.
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))

        motion = drift.cumsum(-1)  # A Levy stable motion.

        # The prediction now includes three terms.
        regression = torch.matmul(covariates, weight[..., None])
        prediction = motion + bias + regression.sum(dim=-1)
        prediction = prediction.unsqueeze(-1).transpose(-1, -3)
        # Finally we can construct a noise distribution.
        # We will share parameters across all time series.
        obs_scale = pyro.sample("obs_scale", dist.LogNormal(-5, 5))
        noise_dist = dist.Normal(loc=0.0, scale=obs_scale.unsqueeze(-1))
        self.predict(noise_dist, prediction)
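
Usage note (not from the source): the method above reads like the body of a
pyro.contrib.forecast.ForecastingModel subclass. Below is a minimal, hedged
sketch of how such a model might be trained and queried with Forecaster; the
PlaceholderModel, tensor shapes, and split point T are hypothetical stand-ins,
not part of the example above.

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel, Forecaster

class PlaceholderModel(ForecastingModel):
    # Hypothetical stand-in for a subclass carrying the model method above.
    def model(self, zero_data, covariates):
        prediction = zero_data  # zeros of shape (duration, data_dim)
        noise_scale = pyro.sample("noise_scale",
                                  dist.LogNormal(-5, 5).expand((1,)).to_event(1))
        self.predict(dist.Normal(0, noise_scale), prediction)

T = 500                                # hypothetical train/test split
data = torch.randn(700, 3).cumsum(0)   # hypothetical (duration, data_dim) series
covariates = torch.zeros(700, 0)       # empty covariates for this sketch
forecaster = Forecaster(PlaceholderModel(), data[:T], covariates[:T],
                        learning_rate=0.1, num_steps=200)
samples = forecaster(data[:T], covariates, num_samples=100)  # forecast times T..699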
Example #2
    def __init__(self, model, data, covariates=None, *,
                 num_warmup=1000, num_samples=1000, num_chains=1, time_reparam=None,
                 dense_mass=False, jit_compile=False, max_tree_depth=10):
        assert data.size(-2) == covariates.size(-2)
        super().__init__()
        if time_reparam == "haar":
            model = poutine.reparam(model, time_reparam_haar)
        elif time_reparam == "dct":
            model = poutine.reparam(model, time_reparam_dct)
        elif time_reparam is not None:
            raise ValueError("unknown time_reparam: {}".format(time_reparam))
        self.model = model
        max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates), {})
        self.max_plate_nesting = max(max_plate_nesting, 1)  # force a time plate

        kernel = NUTS(model, full_mass=dense_mass, jit_compile=jit_compile,
                      ignore_jit_warnings=True, max_tree_depth=max_tree_depth,
                      max_plate_nesting=self.max_plate_nesting)
        mcmc = MCMC(kernel, warmup_steps=num_warmup, num_samples=num_samples,
                    num_chains=num_chains)
        mcmc.run(data, covariates)
        # Only print a summary when there are enough samples to compute r_hat.
        if (num_chains == 1 and num_samples >= 4) or (num_chains > 1 and num_samples >= 2):
            mcmc.summary()

        # Inspect the model with a particles plate of size 1, so that we can
        # reshape samples to add any missing plate dims in front.
        with poutine.trace() as tr:
            with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
                model(data, covariates)

        self._trace = tr.trace
        self._samples = mcmc.get_samples()
        self._num_samples = num_samples * num_chains
        for name, node in list(self._trace.nodes.items()):
            if name not in self._samples:
                del self._trace.nodes[name]
Example #3
def test_stable(Reparam, shape):
    stability = torch.empty(shape).uniform_(1.5, 2.).requires_grad_()
    skew = torch.empty(shape).uniform_(-0.5, 0.5).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_()
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    params = [stability, skew, scale, loc]

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 100000):
                return pyro.sample("x", dist.Stable(stability, skew, scale,
                                                    loc))

    value = model()
    expected_moments = get_moments(value)

    reparam_model = poutine.reparam(model, {"x": Reparam()})
    trace = poutine.trace(reparam_model).get_trace()
    if Reparam is LatentStableReparam:
        assert isinstance(trace.nodes["x"]["fn"], MaskedDistribution)
        assert isinstance(trace.nodes["x"]["fn"].base_dist, dist.Delta)
    else:
        assert isinstance(trace.nodes["x"]["fn"], dist.Normal)
    trace.compute_log_prob()  # smoke test only
    value = trace.nodes["x"]["value"]
    actual_moments = get_moments(value)
    assert_close(actual_moments, expected_moments, atol=0.05)

    for actual_m, expected_m in zip(actual_moments, expected_moments):
        expected_grads = grad(expected_m.sum(), params, retain_graph=True)
        actual_grads = grad(actual_m.sum(), params, retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.2)
        assert_close(actual_grads[1], expected_grads[1], atol=0.1)
        assert_close(actual_grads[2], expected_grads[2], atol=0.1)
        assert_close(actual_grads[3], expected_grads[3], atol=0.1)
Example #4
def test_studentt_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    init_dist = random_studentt(batch_shape + (hidden_dim, )).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_studentt(batch_shape +
                                 (duration, hidden_dim)).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_studentt(batch_shape + (duration, obs_dim)).to_event(1)
    hmm = dist.LinearHMM(init_dist,
                         trans_mat,
                         trans_dist,
                         obs_mat,
                         obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    rep = StudentTReparam()
    with poutine.trace() as tr:
        with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
            model(data)

    assert isinstance(tr.trace.nodes["x"]["fn"], dist.GaussianHMM)
    assert tr.trace.nodes["x_init_gamma"]["fn"].event_shape == (hidden_dim, )
    assert tr.trace.nodes["x_trans_gamma"]["fn"].event_shape == (duration,
                                                                 hidden_dim)
    assert tr.trace.nodes["x_obs_gamma"]["fn"].event_shape == (duration,
                                                               obs_dim)
    tr.trace.compute_log_prob()  # smoke test only
Example #5
    def model(self, zero_data, covariates):
        with pyro.plate("batch", len(zero_data), dim=-2):
            zero_data = pyro.subsample(zero_data, event_dim=1)
            covariates = pyro.subsample(covariates, event_dim=1)

            loc = zero_data[..., :1, :]
            scale = pyro.sample("scale", dist.LogNormal(loc, 1).to_event(1))

            with self.time_plate:
                jumps = pyro.sample("jumps", dist.Normal(0, scale).to_event(1))
            prediction = jumps.cumsum(-2)

            duration, obs_dim = zero_data.shape[-2:]
            noise_dist = dist.LinearHMM(
                dist.Stable(1.9, 0).expand([obs_dim]).to_event(1),
                torch.eye(obs_dim),
                dist.Stable(1.9, 0).expand([obs_dim]).to_event(1),
                torch.eye(obs_dim),
                dist.Stable(1.9, 0).expand([obs_dim]).to_event(1),
                duration=duration,
            )
            rep = StableReparam()
            with poutine.reparam(
                    config={"residual": LinearHMMReparam(rep, rep, rep)}):
                self.predict(noise_dist, prediction)
Example #6
def test_beta_binomial_dependent_sample():
    total = 10
    counts = dist.Binomial(total, 0.3).sample()
    concentration1 = torch.tensor(0.5)
    concentration0 = torch.tensor(1.5)

    prior = dist.Beta(concentration1, concentration0)
    posterior = dist.Beta(concentration1 + counts,
                          concentration0 + total - counts)

    def model(counts):
        prob = pyro.sample("prob", prior)
        pyro.sample("counts", dist.Binomial(total, prob), obs=counts)

    reparam_model = poutine.reparam(model, {
        "prob": ConjugateReparam(
            lambda counts: dist.Beta(1 + counts, 1 + total - counts)),
    })

    with poutine.trace() as tr, pyro.plate("particles", 10000):
        reparam_model(counts)
    samples = tr.trace.nodes["prob"]["value"]

    assert_close(samples.mean(), posterior.mean, atol=0.01)
    assert_close(samples.std(), posterior.variance.sqrt(), atol=0.01)
Example #7
def test_normal(shape, dim, flip):
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()

    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                pyro.sample(
                    "x",
                    dist.Normal(loc, scale).expand(shape).to_event(-dim))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    rep = HaarReparam(dim=dim, flip=flip)
    reparam_model = poutine.reparam(model, {"x": rep})
    trace = poutine.trace(reparam_model).get_trace()
    assert isinstance(trace.nodes["x_haar"]["fn"],
                      dist.TransformedDistribution)
    assert isinstance(trace.nodes["x"]["fn"], dist.Delta)
    value = trace.nodes["x"]["value"]
    actual_probe = get_moments(value)
    assert_close(actual_probe, expected_probe, atol=0.1)

    for actual_m, expected_m in zip(actual_probe[:10], expected_probe[:10]):
        expected_grads = grad(expected_m.sum(), [loc, scale],
                              retain_graph=True)
        actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.05)
        assert_close(actual_grads[1], expected_grads[1], atol=0.05)
Example #8
def test_beta_binomial_hmc():
    num_samples = 1000
    total = 10
    counts = dist.Binomial(total, 0.3).sample()
    concentration1 = torch.tensor(0.5)
    concentration0 = torch.tensor(1.5)

    prior = dist.Beta(concentration1, concentration0)
    likelihood = dist.Beta(1 + counts, 1 + total - counts)
    posterior = dist.Beta(concentration1 + counts,
                          concentration0 + total - counts)

    def model():
        prob = pyro.sample("prob", prior)
        pyro.sample("counts", dist.Binomial(total, prob), obs=counts)

    reparam_model = poutine.reparam(model,
                                    {"prob": ConjugateReparam(likelihood)})

    kernel = HMC(reparam_model)
    mcmc = MCMC(kernel, num_samples, warmup_steps=0)
    mcmc.run()
    samples = mcmc.get_samples()
    pred = Predictive(reparam_model, samples, num_samples=num_samples)
    trace = pred.get_vectorized_trace()
    samples = trace.nodes["prob"]["value"]

    assert_close(samples.mean(), posterior.mean, atol=0.01)
    assert_close(samples.std(), posterior.variance.sqrt(), atol=0.01)
Example #9
def test_normal(dist_type, centered, shape):
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()
    if isinstance(centered, torch.Tensor):
        centered = centered.expand(shape)

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                if "dist_type" == "Normal":
                    pyro.sample("x", dist.Normal(loc, scale))
                else:
                    pyro.sample("x", dist.StudentT(10.0, loc, scale))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    if "dist_type" == "Normal":
        reparam = LocScaleReparam()
    else:
        reparam = LocScaleReparam(shape_params=["df"])
    reparam_model = poutine.reparam(model, {"x": reparam})
    value = poutine.trace(reparam_model).get_trace().nodes["x"]["value"]
    actual_probe = get_moments(value)
    assert_close(actual_probe, expected_probe, atol=0.1)

    for actual_m, expected_m in zip(actual_probe, expected_probe):
        expected_grads = grad(expected_m.sum(), [loc, scale],
                              retain_graph=True)
        actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.05)
        assert_close(actual_grads[1], expected_grads[1], atol=0.05)
Example #10
def test_log_normal(shape):
    loc = torch.empty(shape).uniform_(-1, 1)
    scale = torch.empty(shape).uniform_(0.5, 1.5)

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                return pyro.sample(
                    "x",
                    dist.TransformedDistribution(
                        dist.Normal(torch.zeros_like(loc),
                                    torch.ones_like(scale)),
                        [AffineTransform(loc, scale),
                         ExpTransform()]))

    with poutine.trace() as tr:
        value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.TransformedDistribution)
    expected_moments = get_moments(value)

    with poutine.reparam(config={"x": TransformReparam()}):
        with poutine.trace() as tr:
            value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.Delta)
    actual_moments = get_moments(value)
    assert_close(actual_moments, expected_moments, atol=0.05)
Example #11
def test_pyrocov_reparam(model, Guide, backend):
    T, P, S, F = 2, 3, 4, 5
    dataset = {
        "features": torch.randn(S, F),
        "local_time": torch.randn(T, P),
        "weekly_strains": torch.randn(T, P, S).exp().round(),
    }

    # Reparametrize the model.
    config = {
        "coef": LocScaleReparam(),
        "rate_loc": None if model is pyrocov_model else LocScaleReparam(),
        "rate": LocScaleReparam(),
        "init_loc": LocScaleReparam(),
        "init": LocScaleReparam(),
    }
    model = poutine.reparam(model, config)
    guide = Guide(model, backend=backend)
    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
    for step in range(2):
        with xfail_if_not_implemented():
            svi.step(dataset)
    guide(dataset)
    predictive = Predictive(model, guide=guide, num_samples=2)
    predictive(dataset)
Example #12
def test_stable_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    stability = dist.Uniform(0.5, 2).sample(batch_shape)

    init_dist = random_stable(batch_shape + (hidden_dim, ),
                              stability.unsqueeze(-1),
                              skew=skew).to_event(1)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_stable(batch_shape + (duration, hidden_dim),
                               stability.unsqueeze(-1).unsqueeze(-1),
                               skew=skew).to_event(1)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = random_stable(batch_shape + (duration, obs_dim),
                             stability.unsqueeze(-1).unsqueeze(-1),
                             skew=skew).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    rep = SymmetricStableReparam() if skew == 0 else StableReparam()
    with poutine.trace() as tr:
        with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
            model(data)
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
Example #13
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        feature_dim = covariates.size(-1)
        bias = pyro.sample("bias", dist.Normal(5.0, 10.0).expand((data_dim,)).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0.0, 5).expand((feature_dim,)).to_event(1))

        # We'll sample a time-global scale parameter outside the time plate,
        # then time-local iid noise inside the time plate.
        drift_scale = pyro.sample("drift_scale",
                                  dist.LogNormal(-20, 5.0).expand((1,)).to_event(1))
        with self.time_plate:
            # We'll use a reparameterizer to improve variational fit. The model would still be
            # correct if you removed this context manager, but the fit appears to be worse.
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                drift = pyro.sample("drift", dist.Normal(zero_data.double(), drift_scale.double()).to_event(1))

        # After we sample the iid "drift" noise we can combine it in any time-dependent way.
        # It is important to keep everything inside the plate independent and apply dependent
        # transforms outside the plate.
        motion = drift.cumsum(-2)  # A Brownian motion.

        # The prediction now includes three terms.
        prediction = motion + bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape

        # Construct the noise distribution and predict.
        noise_scale = pyro.sample("noise_scale", dist.LogNormal(0.0, 1.0).expand((1,)).to_event(1))
        noise_dist = dist.Normal(0, noise_scale)
        self.predict(noise_dist, prediction)
Example #14
def test_stable_hmm_distribution(stability, skew, duration, hidden_dim,
                                 obs_dim):
    init_dist = random_stable((hidden_dim, ), stability, skew=skew).to_event(1)
    trans_mat = torch.randn(duration, hidden_dim, hidden_dim)
    trans_dist = random_stable((duration, hidden_dim), stability,
                               skew=skew).to_event(1)
    obs_mat = torch.randn(duration, hidden_dim, obs_dim)
    obs_dist = random_stable((duration, obs_dim), stability,
                             skew=skew).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    num_samples = 200000
    expected_samples = hmm.sample([num_samples]).reshape(
        num_samples, duration * obs_dim)
    expected_loc, expected_scale, expected_corr = get_hmm_moments(
        expected_samples)

    rep = SymmetricStableReparam() if skew == 0 else StableReparam()
    with pyro.plate("samples", num_samples):
        with poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
            actual_samples = pyro.sample("x",
                                         hmm).reshape(num_samples,
                                                      duration * obs_dim)
    actual_loc, actual_scale, actual_corr = get_hmm_moments(actual_samples)

    assert_close(actual_loc, expected_loc, atol=0.05, rtol=0.05)
    assert_close(actual_scale, expected_scale, atol=0.05, rtol=0.05)
    assert_close(actual_corr, expected_corr, atol=0.01)
Example #15
def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (duration, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = dist.LogNormal(
        torch.randn(batch_shape + (duration, obs_dim)),
        torch.rand(batch_shape + (duration, obs_dim)).exp()).to_event(1)
    hmm = dist.LinearHMM(init_dist,
                         trans_mat,
                         trans_dist,
                         obs_mat,
                         obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    with poutine.trace() as tr:
        with poutine.reparam(config={"x": LinearHMMReparam()}):
            model(data)
    fn = tr.trace.nodes["x"]["fn"]
    assert isinstance(fn, dist.TransformedDistribution)
    assert isinstance(fn.base_dist, dist.GaussianHMM)
    tr.trace.compute_log_prob()  # smoke test only
Example #16
def test_log_normal(batch_shape, event_shape):
    shape = batch_shape + event_shape
    loc = torch.empty(shape).uniform_(-1, 1)
    scale = torch.empty(shape).uniform_(0.5, 1.5)

    def model():
        fn = dist.TransformedDistribution(
            dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)),
            [AffineTransform(loc, scale),
             ExpTransform()])
        if event_shape:
            fn = fn.to_event(len(event_shape))
        with pyro.plate_stack("plates", batch_shape):
            with pyro.plate("particles", 200000):
                return pyro.sample("x", fn)

    with poutine.trace() as tr:
        value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"],
                      (dist.TransformedDistribution, dist.Independent))
    expected_moments = get_moments(value)

    with poutine.reparam(config={"x": TransformReparam()}):
        with poutine.trace() as tr:
            value = model()
    assert isinstance(tr.trace.nodes["x"]["fn"],
                      (dist.Delta, dist.MaskedDistribution))
    actual_moments = get_moments(value)
    assert_close(actual_moments, expected_moments, atol=0.05)
Example #17
def test_symmetric_stable(shape):
    stability = torch.empty(shape).uniform_(1.6, 1.9).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_()
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    params = [stability, scale, loc]

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                return pyro.sample("x", dist.Stable(stability, 0, scale, loc))

    value = model()
    expected_moments = get_moments(value)

    reparam_model = poutine.reparam(model, {"x": SymmetricStableReparam()})
    trace = poutine.trace(reparam_model).get_trace()
    assert isinstance(trace.nodes["x"]["fn"], dist.Normal)
    trace.compute_log_prob()  # smoke test only
    value = trace.nodes["x"]["value"]
    actual_moments = get_moments(value)
    assert_close(actual_moments, expected_moments, atol=0.05)

    for actual_m, expected_m in zip(actual_moments, expected_moments):
        expected_grads = grad(expected_m.sum(), params, retain_graph=True)
        actual_grads = grad(actual_m.sum(), params, retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.2)
        assert_close(actual_grads[1], expected_grads[1], atol=0.1)
        assert_close(actual_grads[2], expected_grads[2], atol=0.1)
Example #18
def test_stable_hmm_smoke(batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_stable(batch_shape + (hidden_dim, )).to_event(1)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_stable(batch_shape +
                               (num_steps, hidden_dim)).to_event(1)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_stable(batch_shape + (num_steps, obs_dim)).to_event(1)
    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)

    def model(data):
        hmm = dist.LinearHMM(init_dist,
                             trans_mat,
                             trans_dist,
                             obs_mat,
                             obs_dist,
                             duration=num_steps)
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", hmm)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    # Test that we can combine these two reparameterizers.
    reparam_model = poutine.reparam(model, {
        "z": LinearHMMReparam(StableReparam(), StableReparam(), StableReparam()),
    })
    reparam_model = poutine.reparam(reparam_model, {
        "z": ConjugateReparam(dist.Normal(data, 1).to_event(2)),
    })
    reparam_guide = AutoDiagonalNormal(reparam_model)  # Models auxiliary variables.

    # Smoke test only.
    elbo = Trace_ELBO(num_particles=5, vectorize_particles=True)
    loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    params = [trans_mat, obs_mat]
    torch.autograd.grad(loss, params, retain_graph=True)
Example #19
def test_distribution(df, loc, scale):
    def model():
        with pyro.plate("particles", 20000):
            return pyro.sample("x", dist.StudentT(df, loc, scale))

    expected = model()
    with poutine.reparam(config={"x": StudentTReparam()}):
        actual = model()
    assert ks_2samp(expected, actual).pvalue > 0.05
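
Background note (an assumption consistent with the "*_gamma" auxiliary sites in
Example #4): StudentTReparam exploits the fact that a Student's t variable is a
Gamma-mixed Gaussian. A self-contained numerical sketch of that decomposition:

import torch
import pyro.distributions as dist

df, loc, scale = 5.0, 0.0, 1.0
# If gamma ~ Gamma(df/2, df/2) and x | gamma ~ Normal(loc, scale / sqrt(gamma)),
# then marginally x ~ StudentT(df, loc, scale).
gamma = dist.Gamma(df / 2, df / 2).sample([200000])
x = dist.Normal(loc, scale / gamma.sqrt()).sample()
expected = dist.StudentT(df, loc, scale).sample([200000])
# Both stds approach sqrt(df / (df - 2)) ~= 1.29 as the sample size grows.
print(x.std(), expected.std())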
Example #20
def test_sphere_reparam_ok(auto_class, init_loc_fn):
    def model():
        x = pyro.sample("x", dist.Normal(0., 1.).expand([3]).to_event(1))
        y = pyro.sample("y", dist.ProjectedNormal(x))
        pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1., 0.]))

    model = poutine.reparam(model, {"y": ProjectedNormalReparam()})
    guide = auto_class(model)
    poutine.trace(guide).get_trace().compute_log_prob()
Example #21
    def model(self, zero_data, covariates):
        duration, data_dim = zero_data.shape

        # Let's model each time series as a Levy stable process, and share process parameters
        # across time series. To do that in Pyro, we'll declare the shared random variables
        # outside of the "origin" plate:
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with pyro.plate("origin", data_dim, dim=-2):
            # Now inside of the origin plate we sample drift and seasonal components.
            # All the time series inside the "origin" plate are independent,
            # given the drift parameters above.
            with self.time_plate:
                # We combine two different reparameterizers: the inner SymmetricStableReparam
                # is needed for the Stable site, and the outer LocScaleReparam is optional but
                # appears to improve inference.
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))

            with pyro.plate("hour_of_week", 24 * 7, dim=-1):
                seasonal = pyro.sample("seasonal", dist.Normal(0, 5))

        # Now outside of the time plate we can perform time-dependent operations like
        # integrating over time. This allows us to create a motion with slow drift.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal

        # Next we do some reshaping. Pyro's forecasting framework assumes all data is
        # multivariate of shape (duration, data_dim), but the above code uses an "origin"
        # plate that is left of the time_plate. Our prediction therefore starts off with
        # shape (data_dim, duration):
        assert prediction.shape[-2:] == (data_dim, duration)
        # We need to swap those dimensions but keep the -2 dimension intact, in case Pyro
        # adds sample dimensions to the left of that.
        prediction = prediction.unsqueeze(-1).transpose(-1, -3)
        assert prediction.shape[-3:] == (1, duration, data_dim), prediction.shape

        # Finally we can construct a noise distribution.
        # We will share parameters across all time series.
        obs_scale = pyro.sample("obs_scale", dist.LogNormal(-5, 5))
        noise_dist = dist.Normal(0, obs_scale.unsqueeze(-1))
        self.predict(noise_dist, prediction)
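
A minimal, self-contained check of the unsqueeze/transpose reshaping used above
(the concrete sizes here are arbitrary):

import torch

prediction = torch.zeros(4, 7)  # (data_dim, duration)
prediction = prediction.unsqueeze(-1).transpose(-1, -3)
assert prediction.shape == (1, 7, 4)  # (1, duration, data_dim)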
Example #22
    def reparam(self, model):
        """
        Wrap a model with ``poutine.reparam``.
        """
        # Transform to Haar coordinates.
        config = {}
        for name, dim in self.dims.items():
            config[name] = HaarReparam(dim=dim, flip=True)
        model = poutine.reparam(model, config)

        if self.split:
            # Split into low- and high-frequency parts.
            splits = [self.split, self.duration - self.split]
            config = {}
            for name, dim in self.dims.items():
                config[name + "_haar"] = SplitReparam(splits, dim=dim)
            model = poutine.reparam(model, config)

        return model
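
A hedged, self-contained sketch of the two-stage scheme above, with hypothetical
site names and sizes (assuming HaarReparam and SplitReparam are importable from
pyro.infer.reparam): HaarReparam first moves the site "x" into Haar coordinates
as "x_haar", and SplitReparam then splits that auxiliary site into low- and
high-frequency parts.

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.reparam import HaarReparam, SplitReparam

duration, split = 8, 2

def model():
    pyro.sample("x", dist.Normal(torch.zeros(duration), 1).to_event(1))

# Stage 1: transform the time axis of "x" to Haar coordinates (creates "x_haar").
model = poutine.reparam(model, {"x": HaarReparam(dim=-1, flip=True)})
# Stage 2: split "x_haar" into a [split, duration - split] pair of sites.
model = poutine.reparam(model, {"x_haar": SplitReparam([split, duration - split],
                                                       dim=-1)})

trace = poutine.trace(model).get_trace()
trace.compute_log_prob()  # smoke test, mirroring the tests in this section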
Example #23
def test_normal_auto(centered):
    strategy = AutoReparam(centered=centered)
    model = strategy(normal_model)
    actual = trace_name_is_observed(model)
    if centered == 1.0:  # i.e. no decentering
        expected = [
            ("a", False),
            ("b_base", False),
            ("b", True),
            ("c", False),
            ("d_base", False),
            ("d", True),
            ("e", False),
            ("f_base", False),
            ("f", True),
            ("g", True),
            ("h", True),
            ("i", False),
            ("j_base", False),
            ("j", True),
        ]
    else:
        expected = [
            ("a_decentered", False),
            ("a", True),
            ("b_base_decentered", False),
            ("b_base", True),
            ("b", True),
            ("c_decentered", False),
            ("c", True),
            ("d_base_decentered", False),
            ("d_base", True),
            ("d", True),
            ("e_decentered", False),
            ("e", True),
            ("f_base_decentered", False),
            ("f_base", True),
            ("f", True),
            ("g", True),
            ("h", True),
            ("i_decentered", False),
            ("i", True),
            ("j_base_decentered", False),
            ("j_base", True),
            ("j", True),
        ]
    assert actual == expected

    # Also check that the config dict has been constructed.
    config = strategy.config
    assert isinstance(config, dict)
    model = poutine.reparam(normal_model, config)
    actual = trace_name_is_observed(model)
    assert actual == expected
Example #24
def test_distribution(stability, skew, Reparam):
    if Reparam is SymmetricStableReparam and (skew != 0 or stability == 2):
        pytest.skip()
    if stability == 2 and skew in (-1, 1):
        pytest.skip()

    def model():
        with pyro.plate("particles", 20000):
            return pyro.sample("x", dist.Stable(stability, skew))

    expected = model()
    with poutine.reparam(config={"x": Reparam()}):
        actual = model()
    assert ks_2samp(expected, actual).pvalue > 0.05
Example #25
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        feature_dim = covariates.size(-1)
        bias = pyro.sample("bias", dist.Normal(5.0, 10.0).expand((data_dim,)).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0.0, 5).expand((feature_dim,)).to_event(1))

        # We'll sample a time-global scale parameter outside the time plate,
        # then time-local iid noise inside the time plate.
        drift_scale = pyro.sample("drift_scale",
                                  dist.LogNormal(0, 5.0).expand((1,)).to_event(1))
        with self.time_plate:
            # We'll use a reparameterizer to improve variational fit. The model would still be
            # correct if you removed this context manager, but the fit appears to be worse.
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                drift = pyro.sample("drift", dist.Normal(zero_data.double(), drift_scale.double()).to_event(1))

        # After we sample the iid "drift" noise we can combine it in any time-dependent way.
        # It is important to keep everything inside the plate independent and apply dependent
        # transforms outside the plate.
        motion = drift.cumsum(-2)  # A Brownian motion.

        # The prediction now includes three terms.
        prediction = motion + bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape

        # The next part of the model creates a likelihood or noise distribution.
        # Again we'll be Bayesian and write this as a probabilistic program with
        # priors over parameters.
        stability = pyro.sample("noise_stability", dist.Uniform(1, 2).expand((1,)).to_event(1))
        skew = pyro.sample("noise_skew", dist.Uniform(-1, 1).expand((1,)).to_event(1))
        scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand((1,)).to_event(1))
        noise_dist = dist.Stable(stability, skew, scale)

        # We need to use a reparameterizer to handle the Stable distribution.
        # Note "residual" is the name of Pyro's internal sample site in self.predict().
        with poutine.reparam(config={"residual": StableReparam()}):
            self.predict(noise_dist, prediction)
Example #26
    def __call__(self, msg_or_fn: Union[dict, Callable]):
        """
        Strategies can be used as decorators to reparametrize a model.

        :param msg_or_fn: Public use: a model to be decorated. (Internal use: a
            site to be configured for reparametrization).
        """
        if isinstance(msg_or_fn, dict):  # Internal use during configuration.
            msg = msg_or_fn
            name = msg["name"]
            if name in self.config:
                return self.config[name]
            result = self.configure(msg)
            self.config[name] = result
            return result
        else:  # Public use as a decorator or handler.
            fn = msg_or_fn
            return poutine.reparam(fn, self)
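
A hedged usage sketch of the decorator path described above (the strategy here
is AutoReparam, as used in Example #23; the one-site model is hypothetical):

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.reparam.strategies import AutoReparam

strategy = AutoReparam()

@strategy  # equivalent to model = poutine.reparam(model, strategy)
def model():
    pyro.sample("x", dist.Normal(0.0, 1.0))

poutine.trace(model).get_trace()  # the first run configures sites lazily
assert "x" in strategy.config     # per the caching branch in __call__ above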
Example #27
def test_uniform(shape, dim, smooth):
    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                pyro.sample("x",
                            dist.Uniform(0, 1).expand(shape).to_event(-dim))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    reparam_model = poutine.reparam(
        model, {"x": DiscreteCosineReparam(dim=dim, smooth=smooth)})
    trace = poutine.trace(reparam_model).get_trace()
    assert isinstance(trace.nodes["x_dct"]["fn"], dist.TransformedDistribution)
    assert isinstance(trace.nodes["x"]["fn"], dist.Delta)
    value = trace.nodes["x"]["value"]
    actual_probe = get_moments(value)
    assert_close(actual_probe, expected_probe, atol=0.1)
Example #28
def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)

    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)
    prior = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                             obs_dist)
    likelihood = dist.Normal(data, 1).to_event(2)
    posterior, log_normalizer = prior.conjugate_update(likelihood)

    def model(data):
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", prior)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    def guide(data):
        with pyro.plate_stack("plates", batch_shape):
            pyro.sample("z", posterior)

    reparam_model = poutine.reparam(model, {"z": ConjugateReparam(likelihood)})

    def reparam_guide(data):
        pass

    elbo = Trace_ELBO(num_particles=1000, vectorize_particles=True)
    expected_loss = elbo.differentiable_loss(model, guide, data)
    actual_loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    assert_close(actual_loss, expected_loss, atol=0.01)

    params = [trans_mat, obs_mat]
    expected_grads = torch.autograd.grad(expected_loss,
                                         params,
                                         retain_graph=True)
    actual_grads = torch.autograd.grad(actual_loss, params, retain_graph=True)
    for a, e in zip(actual_grads, expected_grads):
        assert_close(a, e, rtol=0.01)
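
A hedged numerical check of the conjugate_update contract relied on above (for
any value x in the support, prior.log_prob(x) + likelihood.log_prob(x) equals
posterior.log_prob(x) + log_normalizer); the dimensions below are hypothetical:

import torch
import pyro.distributions as dist

hidden_dim, obs_dim, num_steps = 2, 3, 4
init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
trans_mat = 0.3 * torch.randn(num_steps, hidden_dim, hidden_dim)
trans_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
obs_mat = torch.randn(num_steps, hidden_dim, obs_dim)
obs_dist = dist.MultivariateNormal(torch.zeros(obs_dim), torch.eye(obs_dim))
prior = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

data = torch.randn(num_steps, obs_dim)
likelihood = dist.Normal(data, 1).to_event(2)
posterior, log_normalizer = prior.conjugate_update(likelihood)

x = prior.sample()
lhs = prior.log_prob(x) + likelihood.log_prob(x)
rhs = posterior.log_prob(x) + log_normalizer
assert torch.allclose(lhs, rhs, atol=1e-3)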
Example #29
def check_init_reparam(model, reparam):
    assert isinstance(reparam, Reparam)
    with poutine.block():
        init_value = model()
    with InitMessenger(init_to_value(values={"x": init_value})):
        # Sanity check without reparametrizing.
        actual = model()
        assert_close(actual, init_value)

        # Check with reparametrizing.
        with poutine.reparam(config={"x": reparam}):
            with warnings.catch_warnings(record=True) as ws:
                warnings.simplefilter("always", category=RuntimeWarning)
                actual = model()
            for w in ws:
                if (w.category == RuntimeWarning
                        and "falling back to default" in str(w)):
                    pytest.skip("overwriting initial value")
                else:
                    warnings.warn(str(w.message), category=w.category)

            assert_close(actual, init_value)
Example #30
def test_projected_normal(shape, dim):
    concentration = torch.randn(shape + (dim,)).requires_grad_()

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 10000):
                pyro.sample("x", dist.ProjectedNormal(concentration))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    assert dist.ProjectedNormal.support.check(value).all()
    expected_probe = get_moments(value)

    reparam_model = poutine.reparam(model, {"x": ProjectedNormalReparam()})
    value = poutine.trace(reparam_model).get_trace().nodes["x"]["value"]
    assert dist.ProjectedNormal.support.check(value).all()
    actual_probe = get_moments(value)
    assert_close(actual_probe, expected_probe, atol=0.05)

    for actual_m, expected_m in zip(actual_probe, expected_probe):
        expected_grad = grad(expected_m.sum(), [concentration], retain_graph=True)
        actual_grad = grad(actual_m.sum(), [concentration], retain_graph=True)
        assert_close(actual_grad, expected_grad, atol=0.1)