def test_periodic_cumsum(period, size, left_shape, right_shape):
    """Check ``periodic_cumsum`` against its defining recurrence.

    Along the cumsum dim, the first ``period`` positions of the result must
    equal the input. Every later position ``t`` must equal
    ``input[t] + result[t - period]``. The result must keep the input's shape.
    """
    # Negative dim so the cumsum axis sits immediately left of right_shape.
    dim = -1 - len(right_shape)
    tensor = torch.randn(left_shape + (size, ) + right_shape)
    actual = periodic_cumsum(tensor, period, dim)
    assert actual.shape == tensor.shape

    # Full slices over the left_shape dims, so the next index selects
    # position t along the cumsum dim.
    dots = (slice(None), ) * len(left_shape)
    # When size < period only `size` leading positions exist; iterating
    # all the way to `period` would raise IndexError instead of testing.
    for t in range(min(period, size)):
        assert_equal(actual[dots + (t, )], tensor[dots + (t, )])
    for t in range(period, size):
        assert_close(actual[dots + (t, )],
                     tensor[dots + (t, )] + actual[dots + (t - period, )])
def model(self, zero_data, covariates):
    """Bivariate seasonal model with a GaussianHMM noise term.

    Only the shape of ``zero_data`` is read. Its last two dimensions are
    taken as ``(duration, dim)``, and ``dim`` must be 2
    (arrivals, departures). ``covariates`` is not used in this method.

    The method samples global parameters and builds a deterministic seasonal
    prediction plus a GaussianHMM noise model. It then passes both to
    ``self.predict``.
    """
    # 24 * 7 steps per season (presumably one week of hourly data).
    season_length = 24 * 7
    duration, dim = zero_data.shape[-2:]
    assert dim == 2  # Data is bivariate: (arrivals, departures).

    def sample_scale_tril(scale_name, corr_name):
        # Sample per-series scales and an LKJ Cholesky correlation factor.
        # Combine them into a (dim, dim) lower-triangular scale matrix.
        scale = pyro.sample(
            scale_name, dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        corr = pyro.sample(
            corr_name, dist.LKJCorrCholesky(dim, torch.ones(())))
        tril = scale.unsqueeze(-1) * corr
        assert tril.shape[-2:] == (dim, dim)
        return tril

    # Global parameters. Sites are sampled in a fixed order; reordering them
    # would change how random draws are consumed.
    noise_scale = pyro.sample(
        "noise_scale",
        dist.LogNormal(torch.full((dim, ), -3.), 1.).to_event(1))
    assert noise_scale.shape[-1:] == (dim, )
    trans_timescale = pyro.sample(
        "trans_timescale",
        dist.LogNormal(torch.zeros(dim), 1).to_event(1))
    assert trans_timescale.shape[-1:] == (dim, )

    # A single scalar drift location, broadcast across both series.
    drift = pyro.sample("trans_loc", dist.Cauchy(0, 1 / season_length))
    trans_loc = drift.unsqueeze(-1).expand(drift.shape + (dim, ))
    assert trans_loc.shape[-1:] == (dim, )
    trans_scale_tril = sample_scale_tril("trans_scale", "trans_corr")
    obs_scale_tril = sample_scale_tril("obs_scale", "obs_corr")

    # The initial seasonality is sampled in a plate on dim=-1, the same dim
    # as the time plate, so periodic_repeat() can later tile it along time.
    with pyro.plate("season_plate", season_length, dim=-1):
        season_init = pyro.sample(
            "season_init", dist.Normal(torch.zeros(dim), 1).to_event(1))
    assert season_init.shape[-2:] == (season_length, dim)

    # Independent noise at each time step.
    with self.time_plate:
        season_noise = pyro.sample(
            "season_noise", dist.Normal(0, noise_scale).to_event(1))
    assert season_noise.shape[-2:] == (duration, dim)

    # Deterministic prediction with two parts:
    #   - an exactly repeating seasonal pattern;
    #   - a slow seasonal drift, the per-period cumulative sum of the noise.
    # Both are linear maps of the sampled Normal variables.
    repeated = periodic_repeat(season_init, duration, dim=-2)
    drifted = periodic_cumsum(season_noise, season_length, dim=-2)
    prediction = repeated + drifted
    assert prediction.shape[-2:] == (duration, dim)

    # Joint noise model: a GaussianHMM whose .rsample() and .log_prob() are
    # parallelized over time, so the whole model is parallelized over time.
    # The transition matrix is diagonal with entries exp(-trans_timescale);
    # the observation matrix is the identity.
    noise_model = dist.GaussianHMM(
        dist.Normal(torch.zeros(dim), 100).to_event(1),
        trans_timescale.neg().exp().diag_embed(),
        dist.MultivariateNormal(trans_loc, scale_tril=trans_scale_tril),
        torch.eye(dim),
        dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale_tril),
        duration=duration)
    assert noise_model.event_shape == (duration, dim)

    # Register the noise model and the prediction.
    self.predict(noise_model, prediction)