Code example #1
0
def test_periodic_cumsum(period, size, left_shape, right_shape):
    """Check ``periodic_cumsum`` against its defining recurrence.

    Builds a random tensor of shape ``left_shape + (size,) + right_shape``
    and accumulates it along the ``size`` axis with the given ``period``.
    The first ``min(period, size)`` entries along that axis must be passed
    through unchanged, and every later entry ``t`` must equal
    ``tensor[t] + actual[t - period]``.

    :param int period: The accumulation period.
    :param int size: Length of the axis being accumulated.
    :param tuple left_shape: Shape to the left of the accumulated axis.
    :param tuple right_shape: Shape to the right of the accumulated axis.
    """
    # Address the accumulated axis with a negative dim counted from the right,
    # so it is independent of how many dims are on the left.
    dim = -1 - len(right_shape)
    tensor = torch.randn(left_shape + (size, ) + right_shape)
    actual = periodic_cumsum(tensor, period, dim)
    assert actual.shape == tensor.shape

    # Prefix of full slices so that ``dots + (t, )`` selects index t along the
    # accumulated axis (positive position len(left_shape)).
    dots = (slice(None), ) * len(left_shape)

    # Entries in the first period have nothing earlier to accumulate. Clamp to
    # ``size`` so that period > size does not index past the end of the axis.
    for t in range(min(period, size)):
        assert_equal(actual[dots + (t, )], tensor[dots + (t, )])
    # Later entries follow actual[t] = tensor[t] + actual[t - period].
    for t in range(period, size):
        assert_close(actual[dots + (t, )],
                     tensor[dots + (t, )] + actual[dots + (t - period, )])
Code example #2
0
File: bart.py  Project: nwjnwj/pyro
    def model(self, zero_data, covariates):
        """Bivariate seasonal model with a GaussianHMM noise model.

        Samples global noise, transition and observation parameters, an
        initial seasonality of length ``period`` and per-step seasonal noise.
        It then forms a prediction as the repeated initial seasonality plus a
        periodic cumulative sum of that noise. Finally it registers the
        prediction together with a GaussianHMM noise model via
        ``self.predict``.

        :param zero_data: Tensor whose last two dims are ``(duration, dim)``;
            only its shape is read here, and ``dim`` must be 2.
        :param covariates: Not referenced in this method body.
        """
        # Seasonal period in time steps. NOTE(review): 24 * 7 suggests weekly
        # seasonality on hourly data — confirm against the dataset.
        period = 24 * 7
        duration, dim = zero_data.shape[-2:]
        assert dim == 2  # Data is bivariate: (arrivals, departures).

        # Sample global parameters.
        # Per-series scale of the seasonal drift noise sampled further down;
        # LogNormal(-3, 1) has median exp(-3) ~= 0.05, favoring small drift.
        noise_scale = pyro.sample(
            "noise_scale",
            dist.LogNormal(torch.full((dim, ), -3.), 1.).to_event(1))
        assert noise_scale.shape[-1:] == (dim, )
        # Per-series timescale; turned into decay factors exp(-timescale) for
        # the diagonal transition matrix of the noise model.
        trans_timescale = pyro.sample(
            "trans_timescale",
            dist.LogNormal(torch.zeros(dim), 1).to_event(1))
        assert trans_timescale.shape[-1:] == (dim, )

        # A single scalar location shared by both series: it is expanded
        # (not independently sampled) to a trailing dim of size ``dim``.
        trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period))
        trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim, ))
        assert trans_loc.shape[-1:] == (dim, )
        trans_scale = pyro.sample(
            "trans_scale",
            dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        trans_corr = pyro.sample("trans_corr",
                                 dist.LKJCorrCholesky(dim, torch.ones(())))
        # diag(scale) @ corr: scaling row i of the lower-triangular correlation
        # factor by scale[i] gives a lower-triangular covariance factor.
        trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
        assert trans_scale_tril.shape[-2:] == (dim, dim)

        # Same scale-times-correlation construction for the observation noise.
        obs_scale = pyro.sample(
            "obs_scale",
            dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        obs_corr = pyro.sample("obs_corr",
                               dist.LKJCorrCholesky(dim, torch.ones(())))
        obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
        assert obs_scale_tril.shape[-2:] == (dim, dim)

        # Note the initial seasonality should be sampled in a plate with the
        # same dim as the time_plate, dim=-1. That way we can repeat the dim
        # below using periodic_repeat().
        with pyro.plate("season_plate", period, dim=-1):
            season_init = pyro.sample(
                "season_init",
                dist.Normal(torch.zeros(dim), 1).to_event(1))
            assert season_init.shape[-2:] == (period, dim)

        # Sample independent noise at each time step.
        # NOTE(review): self.time_plate is defined elsewhere; the assert below
        # expects it to span ``duration`` steps.
        with self.time_plate:
            season_noise = pyro.sample("season_noise",
                                       dist.Normal(0, noise_scale).to_event(1))
            assert season_noise.shape[-2:] == (duration, dim)

        # Construct a prediction. This prediction has an exactly repeated
        # seasonal part plus slow seasonal drift. We use two deterministic,
        # linear functions to transform our diagonal Normal noise to nontrivial
        # samples from a Gaussian process.
        prediction = (periodic_repeat(season_init, duration, dim=-2) +
                      periodic_cumsum(season_noise, period, dim=-2))
        assert prediction.shape[-2:] == (duration, dim)

        # Construct a joint noise model. This model is a GaussianHMM, whose
        # .rsample() and .log_prob() methods are parallelized over time; thus
        # this entire model is parallelized over time.
        # Wide (scale 100) prior on the initial latent state.
        init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
        # Diagonal transition matrix with entries exp(-timescale), which lie in
        # (0, 1) since timescale > 0: each latent coordinate decays by its own
        # factor per step, with no cross-series coupling in the mean.
        trans_mat = trans_timescale.neg().exp().diag_embed()
        trans_dist = dist.MultivariateNormal(trans_loc,
                                             scale_tril=trans_scale_tril)
        # Identity observation matrix: each observation is its latent state
        # plus correlated observation noise.
        obs_mat = torch.eye(dim)
        obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                           scale_tril=obs_scale_tril)
        noise_model = dist.GaussianHMM(init_dist,
                                       trans_mat,
                                       trans_dist,
                                       obs_mat,
                                       obs_dist,
                                       duration=duration)
        assert noise_model.event_shape == (duration, dim)

        # The final statement registers our noise model and prediction.
        self.predict(noise_model, prediction)