def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to a :class:`DependentMaternGP`.
    """
    trans_matrix, trans_dist, stat_covar = \
        self._trans_matrix_distribution_stat_covar(self.dt)
    return dist.GaussianHMM(self._get_init_dist(stat_covar), trans_matrix,
                            trans_dist, self._get_obs_matrix(),
                            self._get_obs_dist())

def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                              obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist,
                                     obs_mat, obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)

def __call__(self, name, fn, obs):
    # Reparameterize the initial distribution as conditionally Gaussian.
    init_dist = fn.initial_dist
    if self.init is not None:
        init_dist, _ = self.init("{}_init".format(name), init_dist, None)

    # Reparameterize the transition distribution as conditionally Gaussian.
    trans_dist = fn.transition_dist
    if self.trans is not None:
        trans_dist, _ = self.trans("{}_trans".format(name),
                                   trans_dist.to_event(1), None)
        trans_dist = trans_dist.to_event(-1)

    # Reparameterize the observation distribution as conditionally Gaussian.
    obs_dist = fn.observation_dist
    if self.obs is not None:
        obs_dist, obs = self.obs("{}_obs".format(name),
                                 obs_dist.to_event(1), obs)
        obs_dist = obs_dist.to_event(-1)

    # Reparameterize the entire HMM as conditionally Gaussian.
    hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist,
                           fn.observation_matrix, obs_dist)

    # Apply any observation transforms.
    if fn.transforms:
        hmm = dist.TransformedDistribution(hmm, fn.transforms)

    return hmm, obs

def __call__(self, name, fn, obs):
    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
    if fn.duration is None:
        raise ValueError("LinearHMMReparam requires duration to be specified "
                         "on targeted LinearHMM distributions")

    # Unwrap IndependentHMM.
    if isinstance(fn, dist.IndependentHMM):
        if obs is not None:
            obs = obs.transpose(-1, -2).unsqueeze(-1)
        hmm, obs = self(name, fn.base_dist.to_event(1), obs)
        hmm = dist.IndependentHMM(hmm.to_event(-1))
        if obs is not None:
            obs = obs.squeeze(-1).transpose(-1, -2)
        return hmm, obs

    # Reparameterize the initial distribution as conditionally Gaussian.
    init_dist = fn.initial_dist
    if self.init is not None:
        init_dist, _ = self.init("{}_init".format(name),
                                 self._wrap(init_dist, event_dim - 1), None)
        init_dist = init_dist.to_event(1 - init_dist.event_dim)

    # Reparameterize the transition distribution as conditionally Gaussian.
    trans_dist = fn.transition_dist
    if self.trans is not None:
        if trans_dist.batch_shape[-1] != fn.duration:
            trans_dist = trans_dist.expand(trans_dist.batch_shape[:-1] +
                                           (fn.duration,))
        trans_dist, _ = self.trans("{}_trans".format(name),
                                   self._wrap(trans_dist, event_dim), None)
        trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

    # Reparameterize the observation distribution as conditionally Gaussian.
    obs_dist = fn.observation_dist
    if self.obs is not None:
        if obs_dist.batch_shape[-1] != fn.duration:
            obs_dist = obs_dist.expand(obs_dist.batch_shape[:-1] +
                                       (fn.duration,))
        obs_dist, obs = self.obs("{}_obs".format(name),
                                 self._wrap(obs_dist, event_dim), obs)
        obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)

    # Reparameterize the entire HMM as conditionally Gaussian.
    hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist,
                           fn.observation_matrix, obs_dist,
                           duration=fn.duration)
    hmm = self._wrap(hmm, event_dim)

    # Apply any observation transforms.
    if fn.transforms:
        hmm = dist.TransformedDistribution(hmm, fn.transforms)

    return hmm, obs

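# Illustrative usage sketch, not part of the snippets above: LinearHMMReparam
# composes with poutine.reparam to rewrite a heavy-tailed LinearHMM site into
# the conditionally Gaussian HMM constructed by __call__ above. Here `model`
# and the site name "x" are hypothetical; LinearHMMReparam and StableReparam
# are real Pyro reparameterizers (check your Pyro version for import paths).
from pyro import poutine
from pyro.infer.reparam import LinearHMMReparam, StableReparam

reparam_model = poutine.reparam(model, {
    "x": LinearHMMReparam(init=StableReparam(),
                          trans=StableReparam(),
                          obs=StableReparam()),
})
# Inside reparam_model, the "x" site is a GaussianHMM (possibly wrapped in a
# TransformedDistribution), whose analytic log_prob inference can exploit.
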
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)

def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to :class:`GenericLGSSMWithGPNoiseModel`.
    """
    gp_trans_matrix, gp_process_covar = \
        self.kernel.transition_matrix_and_covariance(dt=self.dt)

    trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                                self.full_state_dim)
    trans_covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = \
        block_diag_embed(gp_process_covar)
    eye = torch.eye(self.state_dim, device=trans_covar.device,
                    dtype=trans_covar.dtype)
    trans_covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = \
        self.log_trans_noise_scale_sq.exp() * eye
    trans_dist = MultivariateNormal(
        trans_covar.new_zeros(self.full_state_dim), trans_covar)

    full_trans_mat = trans_covar.new_zeros(self.full_state_dim,
                                           self.full_state_dim)
    full_trans_mat[:self.full_gp_state_dim, :self.full_gp_state_dim] = \
        block_diag_embed(gp_trans_matrix)
    full_trans_mat[self.full_gp_state_dim:, self.full_gp_state_dim:] = \
        self.z_trans_matrix

    return dist.GaussianHMM(self._get_init_dist(), full_trans_mat, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist())

def get_dist(self, duration=None):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to :class:`GenericLGSSMWithGPNoiseModel`.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """
    gp_trans_matrix, gp_process_covar = \
        self.kernel.transition_matrix_and_covariance(dt=self.dt)

    trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                                self.full_state_dim)
    trans_covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = \
        block_diag_embed(gp_process_covar)
    trans_covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = \
        self.trans_noise_scale_sq.diag_embed()
    trans_dist = MultivariateNormal(
        trans_covar.new_zeros(self.full_state_dim), trans_covar)

    full_trans_mat = trans_covar.new_zeros(self.full_state_dim,
                                           self.full_state_dim)
    full_trans_mat[:self.full_gp_state_dim, :self.full_gp_state_dim] = \
        block_diag_embed(gp_trans_matrix)
    full_trans_mat[self.full_gp_state_dim:, self.full_gp_state_dim:] = \
        self.z_trans_matrix

    return dist.GaussianHMM(self._get_init_dist(), full_trans_mat, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist(),
                            duration=duration)

def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to :class:`GenericLGSSMWithGPNoiseModel`.
    """
    gp_trans_matrix, gp_process_covar = \
        self.kernel.transition_matrix_and_covariance(dt=self.dt)

    trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                                self.full_state_dim)
    trans_covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = \
        block_diag_embed(gp_process_covar)
    trans_covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = \
        self.trans_noise_scale_sq.diag_embed()
    trans_dist = MultivariateNormal(
        trans_covar.new_zeros(self.full_state_dim), trans_covar)

    full_trans_mat = trans_covar.new_zeros(self.full_state_dim,
                                           self.full_state_dim)
    full_trans_mat[:self.full_gp_state_dim, :self.full_gp_state_dim] = \
        block_diag_embed(gp_trans_matrix)
    full_trans_mat[self.full_gp_state_dim:, self.full_gp_state_dim:] = \
        self.z_trans_matrix

    return dist.GaussianHMM(self._get_init_dist(), full_trans_mat, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist())

def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to :class:`GenericLGSSM`.
    """
    return dist.GaussianHMM(self._get_init_dist(), self.trans_matrix,
                            self._get_trans_dist(), self.obs_matrix,
                            self._get_obs_dist())

def _forward_pyro(self, features, trip_counts):
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours <= total_hours
    assert num_origins == self.num_stations
    assert num_destins == self.num_stations
    time_plate = pyro.plate("time", observed_hours, dim=-3)
    origins_plate = pyro.plate("origins", num_origins, dim=-2)
    destins_plate = pyro.plate("destins", num_destins, dim=-1)

    # The first half of the model performs exact inference over
    # the observed portion of the time series.
    hmm = dist.GaussianHMM(*self._dynamics(features[:observed_hours]))
    gate_rate = pyro.sample("gate_rate", hmm)
    gate, rate = self._unpack_gate_rate(gate_rate, event_dim=2)
    with time_plate, origins_plate, destins_plate:
        pyro.sample("trip_count", dist.ZeroInflatedPoisson(gate, rate),
                    obs=trip_counts)

    # The second half of the model forecasts forward.
    if total_hours > observed_hours:
        state_dist = hmm.filter(gate_rate)
        return self._forward_pyro_forecast(features, trip_counts,
                                           origins_plate, destins_plate,
                                           state_dist=state_dist)

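# Illustrative sketch (all names hypothetical, not from the snippet above):
# hmm.filter() returns the Gaussian posterior over the final hidden state,
# which can seed a fresh GaussianHMM over the forecast window when the same
# dynamics are reused, instead of rolling the forecast forward step by step.
state_dist = hmm.filter(gate_rate)  # MultivariateNormal over the last state
forecast_hmm = dist.GaussianHMM(state_dist, trans_matrix, trans_dist,
                                obs_matrix, obs_dist, duration=horizon)
future_gate_rate = forecast_hmm.rsample()  # batch_shape + (horizon, obs_dim)
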
def get_dist(self, duration=None):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to a :class:`DependentMaternGP`.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """
    trans_matrix, trans_dist, stat_covar = \
        self._trans_matrix_distribution_stat_covar(self.dt)
    return dist.GaussianHMM(self._get_init_dist(stat_covar), trans_matrix,
                            trans_dist, self._get_obs_matrix(),
                            self._get_obs_dist(), duration=duration)

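# Illustrative usage sketch (assumes a hypothetical fitted model `gp` exposing
# get_dist() as above): a homogeneous HMM broadcasts its parameters along
# time, so `duration` must be passed before .rsample() can determine the
# length of the time axis.
hmm = gp.get_dist(duration=100)
series = hmm.rsample()      # shape: (100, obs_dim)
final = hmm.filter(series)  # Gaussian posterior over the final hidden state
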
def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to ``obs_dim``-many independent Matern GPs.
    """
    trans_matrix, process_covar = \
        self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_dist = MultivariateNormal(
        self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim),
        process_covar.unsqueeze(-3))
    trans_matrix = trans_matrix.unsqueeze(-3)
    return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                            self.obs_matrix, self._get_obs_dist())

def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps,
                               hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    if diag:
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    # ------+-----------------------------------------
    #  init | H
    # trans | H H H H              H = hidden
    #   obs |   H H H O O O        O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(),
                              right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

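# Note on the contractions above (a hedged reading of pyro.ops.gaussian, added
# here as commentary rather than taken from the source): gaussian_tensordot(x,
# y, dims) integrates out the last `dims` event dimensions of x against the
# first `dims` event dimensions of y. The two calls therefore marginalize all
# T hidden states, leaving a joint Gaussian over the T * obs_dim observations
# whose log_density reproduces GaussianHMM.log_prob.
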
def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to a :class:`LinearlyCoupledMaternGP`.
    """
    trans_matrix, process_covar = \
        self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_matrix = block_diag_embed(trans_matrix)
    process_covar = block_diag_embed(process_covar)
    loc = self.A.new_zeros(self.full_state_dim)
    trans_dist = MultivariateNormal(loc, process_covar)
    return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist())

def get_dist(self, duration=None):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to :class:`GenericLGSSM`.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """
    return dist.GaussianHMM(self._get_init_dist(), self.trans_matrix,
                            self._get_trans_dist(), self.obs_matrix,
                            self._get_obs_dist(), duration=duration)

def forward(self, features, trip_counts):
    pyro.module("model", self)
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours <= total_hours
    assert num_origins == self.num_stations
    assert num_destins == self.num_stations
    time_plate = pyro.plate("time", observed_hours, dim=-3)
    origins_plate = pyro.plate("origins", num_origins, dim=-2)
    destins_plate = pyro.plate("destins", num_destins, dim=-1)

    # The first half of the model performs exact inference over
    # the observed portion of the time series.
    hmm = dist.GaussianHMM(*self._dynamics(features[:observed_hours]))
    gate_rate = pyro.sample("gate_rate", hmm)
    gate, rate = self._unpack_gate_rate(gate_rate, event_dim=2)
    with time_plate, origins_plate, destins_plate:
        pyro.sample("trip_count", dist.ZeroInflatedPoisson(gate, rate),
                    obs=trip_counts)

    # The second half of the model forecasts forward.
    forecast = []
    forecast_hours = total_hours - observed_hours
    if forecast_hours > 0:
        _, trans_matrix, trans_dist, obs_matrix, obs_dist = \
            self._dynamics(features[observed_hours:])
    state = None
    for t in range(forecast_hours):
        if state is None:  # on first step
            state_dist = hmm.filter(gate_rate)
        else:
            loc = vm(state, trans_matrix) + trans_dist.loc
            scale_tril = trans_dist.scale_tril
            state_dist = dist.MultivariateNormal(loc, scale_tril=scale_tril)
        state = pyro.sample("state_{}".format(t), state_dist)

        loc = vm(state, obs_matrix) + obs_dist.base_dist.loc[..., t, :]
        scale = obs_dist.base_dist.scale[..., t, :]
        gate_rate = pyro.sample("gate_rate_{}".format(t),
                                dist.Normal(loc, scale).to_event(1))
        gate, rate = self._unpack_gate_rate(gate_rate, event_dim=1)
        with origins_plate, destins_plate:
            forecast.append(pyro.sample("trip_count_{}".format(t),
                                        dist.ZeroInflatedPoisson(gate, rate)))
    return forecast

def get_dist(self, duration=None):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to ``obs_dim``-many independent Matern GPs.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """
    trans_matrix, process_covar = \
        self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_dist = MultivariateNormal(
        self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim),
        process_covar.unsqueeze(-3))
    trans_matrix = trans_matrix.unsqueeze(-3)
    return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                            self.obs_matrix, self._get_obs_dist(),
                            duration=duration)

def get_dist(self, duration=None):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that
    corresponds to a :class:`LinearlyCoupledMaternGP`.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """
    trans_matrix, process_covar = \
        self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_matrix = block_diag_embed(trans_matrix)
    process_covar = block_diag_embed(process_covar)
    loc = self.A.new_zeros(self.full_state_dim)
    trans_dist = MultivariateNormal(loc, process_covar)
    return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist(),
                            duration=duration)

def model(self, zero_data, covariates):
    duration = zero_data.size(-2)
    init_dist = dist.Normal(0, 1).expand([1]).to_event(1)
    obs_dist = dist.Normal(0, 2).expand([1]).to_event(1)
    obs_matrix = torch.tensor([[1.]])
    trans_dist = dist.Normal(0, 1).expand([1]).to_event(1)
    trans_matrix = torch.tensor([[1.]])
    pre_dist = dist.GaussianHMM(init_dist, trans_matrix, trans_dist,
                                obs_matrix, obs_dist, duration=duration)
    prediction = periodic_repeat(torch.zeros(1, 1), duration,
                                 dim=-1).unsqueeze(-1)
    self.predict(pre_dist, prediction)

def test_independent_hmm_shape(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    base_init_shape = init_shape + (obs_dim,)
    base_trans_mat_shape = trans_mat_shape[:-1] + (
        obs_dim, trans_mat_shape[-1] if trans_mat_shape else 6)
    base_trans_mvn_shape = trans_mvn_shape[:-1] + (
        obs_dim, trans_mvn_shape[-1] if trans_mvn_shape else 6)
    base_obs_mat_shape = obs_mat_shape[:-1] + (
        obs_dim, obs_mat_shape[-1] if obs_mat_shape else 6)
    base_obs_mvn_shape = obs_mvn_shape[:-1] + (
        obs_dim, obs_mvn_shape[-1] if obs_mvn_shape else 6)

    init_dist = random_mvn(base_init_shape, hidden_dim)
    trans_mat = torch.randn(base_trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(base_trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(base_obs_mat_shape + (hidden_dim, 1))
    obs_dist = random_mvn(base_obs_mvn_shape, 1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=6)
    d = dist.IndependentHMM(d)

    shape = broadcast_shape(init_shape + (6,), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = torch.randn(shape + (obs_dim,))
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape

def model(self, zero_data, covariates):
    with pyro.plate_stack("batch", zero_data.shape[:-2], rightmost_dim=-2):
        loc = zero_data[..., :1, :]
        scale = pyro.sample("scale", dist.LogNormal(loc, 1).to_event(1))

        with self.time_plate:
            jumps = pyro.sample("jumps", dist.Normal(0, scale).to_event(1))
        prediction = jumps.cumsum(-2)

        duration, obs_dim = zero_data.shape[-2:]
        noise_dist = dist.GaussianHMM(
            dist.Normal(0, 1).expand([obs_dim]).to_event(1),
            torch.eye(obs_dim),
            dist.Normal(0, 1).expand([obs_dim]).to_event(1),
            torch.eye(obs_dim),
            dist.Normal(0, 1).expand([obs_dim]).to_event(1),
            duration=duration,
        )
        self.predict(noise_dist, prediction)

def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)

    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)
    prior = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                             obs_dist)
    likelihood = dist.Normal(data, 1).to_event(2)
    posterior, log_normalizer = prior.conjugate_update(likelihood)

    def model(data):
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", prior)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    def guide(data):
        with pyro.plate_stack("plates", batch_shape):
            pyro.sample("z", posterior)

    reparam_model = poutine.reparam(model, {"z": ConjugateReparam(likelihood)})

    def reparam_guide(data):
        pass

    elbo = Trace_ELBO(num_particles=1000, vectorize_particles=True)
    expected_loss = elbo.differentiable_loss(model, guide, data)
    actual_loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    assert_close(actual_loss, expected_loss, atol=0.01)

    params = [trans_mat, obs_mat]
    expected_grads = torch.autograd.grad(expected_loss, params,
                                         retain_graph=True)
    actual_grads = torch.autograd.grad(actual_loss, params, retain_graph=True)
    for a, e in zip(actual_grads, expected_grads):
        assert_close(a, e, rtol=0.01)

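# Illustrative check (a sketch reusing names from the test above) of the
# identity conjugate_update() satisfies and that the ELBO comparison relies
# on: log p(x) + log l(x) == log q(x) + log_normalizer for all x, where
# q, log_normalizer = p.conjugate_update(l).
x = posterior.rsample()
lhs = prior.log_prob(x) + likelihood.log_prob(x)
rhs = posterior.log_prob(x) + log_normalizer
assert_close(lhs, rhs, atol=1e-4)
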
def test_gaussian_hmm_high_obs_dim():
    hidden_dim = 1
    obs_dim = 1000
    duration = 10
    sample_shape = (100,)
    init_dist = random_mvn((), hidden_dim)
    trans_mat = torch.randn((duration,) + (hidden_dim, hidden_dim))
    trans_dist = random_mvn((duration,), hidden_dim)
    obs_mat = torch.randn((duration,) + (hidden_dim, obs_dim))
    loc = torch.randn((duration, obs_dim))
    scale = torch.randn((duration, obs_dim)).exp()
    obs_dist = dist.Normal(loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)
    x = d.rsample(sample_shape)
    assert x.shape == sample_shape + (duration, obs_dim)

def test_gaussian_hmm_log_prob_null_dynamics(init_shape, trans_mat_shape,
                                             trans_mvn_shape, obs_mvn_shape,
                                             hidden_dim):
    obs_dim = hidden_dim
    init_dist = random_mvn(init_shape, hidden_dim)

    # impose null dynamics
    trans_mat = torch.zeros(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim, diag=True)

    # trivial observation matrix (hidden_dim = obs_dim)
    obs_mat = torch.eye(hidden_dim)
    obs_dist = random_mvn(obs_mvn_shape, obs_dim, diag=True)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                              obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist,
                                     obs_mat, obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape,
                            trans_mvn_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)

    # With null dynamics, each observation is independently distributed as
    # transition noise plus observation noise, so the log prob factorizes
    # across time and dimensions.
    obs_cov = obs_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    trans_cov = trans_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    sum_scale = (obs_cov + trans_cov).sqrt()
    sum_loc = trans_dist.loc + obs_dist.loc
    analytic_log_prob = dist.Normal(
        sum_loc, sum_scale).log_prob(data).sum(-1).sum(-1)
    assert_close(analytic_log_prob, actual_log_prob, atol=1.0e-5)

def test_gaussian_filter():
    dim = 4
    init_dist = dist.MultivariateNormal(torch.zeros(dim),
                                        scale_tril=torch.eye(dim) * 10)
    trans_mat = torch.eye(dim)
    trans_dist = dist.MultivariateNormal(torch.zeros(dim),
                                         scale_tril=torch.eye(dim))
    obs_mat = torch.eye(dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                       scale_tril=torch.eye(dim) * 2)
    hmm = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    class Model:
        def init(self, state):
            state["z"] = pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum=None):
            state["z"] = pyro.sample(
                "z_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=trans_dist.scale_tril))
            datum = pyro.sample(
                "obs_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=obs_dist.scale_tril),
                obs=datum)
            self.t += 1
            return datum

    class Guide:
        def init(self, state):
            pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum):
            pyro.sample(
                "z_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=trans_dist.scale_tril * 2))
            self.t += 1

    # Generate data.
    num_steps = 20
    model = Model()
    state = {}
    model.init(state)
    data = torch.stack([model.step(state) for _ in range(num_steps)])

    # Perform inference, comparing the SMC estimate against the exact
    # GaussianHMM filter at each step.
    model = Model()
    guide = Guide()
    smc = SMCFilter(model, guide, num_particles=1000, max_plate_nesting=0)
    smc.init()
    for t, datum in enumerate(data):
        smc.step(datum)
        expected = hmm.filter(data[:1 + t])
        actual = smc.get_empirical()["z"]
        assert_close(actual.variance ** 0.5, expected.variance ** 0.5,
                     atol=0.1, rtol=0.5)
        sigma = actual.variance.max().item() ** 0.5
        assert_close(actual.mean, expected.mean, atol=3 * sigma)

def model(self, zero_data, covariates):
    period = 24 * 7
    duration, dim = zero_data.shape[-2:]
    assert dim == 2  # Data is bivariate: (arrivals, departures).

    # Sample global parameters.
    noise_scale = pyro.sample(
        "noise_scale",
        dist.LogNormal(torch.full((dim,), -3.), 1.).to_event(1))
    assert noise_scale.shape[-1:] == (dim,)
    trans_timescale = pyro.sample(
        "trans_timescale", dist.LogNormal(torch.zeros(dim), 1).to_event(1))
    assert trans_timescale.shape[-1:] == (dim,)

    trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period))
    trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim,))
    assert trans_loc.shape[-1:] == (dim,)
    trans_scale = pyro.sample(
        "trans_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
    trans_corr = pyro.sample("trans_corr",
                             dist.LKJCorrCholesky(dim, torch.ones(())))
    trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
    assert trans_scale_tril.shape[-2:] == (dim, dim)

    obs_scale = pyro.sample(
        "obs_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
    obs_corr = pyro.sample("obs_corr",
                           dist.LKJCorrCholesky(dim, torch.ones(())))
    obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
    assert obs_scale_tril.shape[-2:] == (dim, dim)

    # Note the initial seasonality should be sampled in a plate with the
    # same dim as the time_plate, dim=-1. That way we can repeat the dim
    # below using periodic_repeat().
    with pyro.plate("season_plate", period, dim=-1):
        season_init = pyro.sample(
            "season_init", dist.Normal(torch.zeros(dim), 1).to_event(1))
        assert season_init.shape[-2:] == (period, dim)

    # Sample independent noise at each time step.
    with self.time_plate:
        season_noise = pyro.sample(
            "season_noise", dist.Normal(0, noise_scale).to_event(1))
        assert season_noise.shape[-2:] == (duration, dim)

    # Construct a prediction. This prediction has an exactly repeated
    # seasonal part plus slow seasonal drift. We use two deterministic,
    # linear functions to transform our diagonal Normal noise to nontrivial
    # samples from a Gaussian process.
    prediction = (periodic_repeat(season_init, duration, dim=-2) +
                  periodic_cumsum(season_noise, period, dim=-2))
    assert prediction.shape[-2:] == (duration, dim)

    # Construct a joint noise model. This model is a GaussianHMM, whose
    # .rsample() and .log_prob() methods are parallelized over time; thus
    # this entire model is parallelized over time.
    init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
    trans_mat = trans_timescale.neg().exp().diag_embed()
    trans_dist = dist.MultivariateNormal(trans_loc,
                                         scale_tril=trans_scale_tril)
    obs_mat = torch.eye(dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                       scale_tril=obs_scale_tril)
    noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                                   obs_dist, duration=duration)
    assert noise_model.event_shape == (duration, dim)

    # The final statement registers our noise model and prediction.
    self.predict(noise_model, prediction)

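# Illustrative usage sketch (the class name `Model` and the data/covariates
# tensors are hypothetical; it assumes the .model() method above lives on a
# ForecastingModel subclass): models written in this style are trained and
# queried through pyro.contrib.forecast.Forecaster.
from pyro.contrib.forecast import Forecaster

forecaster = Forecaster(Model(), data[:T_obs], covariates[:T_obs],
                        num_steps=1000)
samples = forecaster(data[:T_obs], covariates, num_samples=100)
# `samples` covers the forecast window described by covariates[T_obs:].
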
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps,
                                   hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=num_steps)
    if diag:
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    # ------+-----------------------------------------
    #  init | H
    # trans | H H H H              H = hidden
    #   obs |   H H H O O O        O = observed
    #  like |         O O O
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(),
                              right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(d.log_prob(data) + like_dist.log_prob(data),
                 d_posterior.log_prob(data) + log_normalizer)

    if batch_shape or sample_shape:
        return

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    for name, d, g in [prior, posterior]:
        logging.info("testing {} moments".format(name))
        with torch.no_grad():
            num_samples = 100000
            samples = d.sample([num_samples]).reshape(num_samples,
                                                      T * obs_dim)
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) *
                                        actual_std.unsqueeze(-2))

            expected_cov = g.precision.cholesky().cholesky_inverse()
            expected_mean = expected_cov.matmul(
                g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) *
                                            expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)

def apply(self, msg):
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]

    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
    if fn.duration is None:
        raise ValueError("LinearHMMReparam requires duration to be specified "
                         "on targeted LinearHMM distributions")

    # Unwrap IndependentHMM.
    if isinstance(fn, dist.IndependentHMM):
        indep_value = None
        if value is not None:
            indep_value = value.transpose(-1, -2).unsqueeze(-1)
        msg = self.apply({
            "name": name,
            "fn": fn.base_dist.to_event(1),
            "value": indep_value,
            "is_observed": is_observed,
        })
        hmm = msg["fn"]
        hmm = dist.IndependentHMM(hmm.to_event(-1))
        if msg["value"] is not indep_value:
            value = msg["value"].squeeze(-1).transpose(-1, -2)
        return {"fn": hmm, "value": value, "is_observed": is_observed}

    # Reparameterize the initial distribution as conditionally Gaussian.
    init_dist = fn.initial_dist
    if self.init is not None:
        msg = self.init.apply({
            "name": f"{name}_init",
            "fn": self._wrap(init_dist, event_dim - 1),
            "value": None,
            "is_observed": False,
        })
        init_dist = msg["fn"]
        init_dist = init_dist.to_event(1 - init_dist.event_dim)

    # Reparameterize the transition distribution as conditionally Gaussian.
    trans_dist = fn.transition_dist
    if self.trans is not None:
        if trans_dist.batch_shape[-1] != fn.duration:
            trans_dist = trans_dist.expand(trans_dist.batch_shape[:-1] +
                                           (fn.duration,))
        msg = self.trans.apply({
            "name": f"{name}_trans",
            "fn": self._wrap(trans_dist, event_dim),
            "value": None,
            "is_observed": False,
        })
        trans_dist = msg["fn"]
        trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

    # Reparameterize the observation distribution as conditionally Gaussian.
    obs_dist = fn.observation_dist
    if self.obs is not None:
        if obs_dist.batch_shape[-1] != fn.duration:
            obs_dist = obs_dist.expand(obs_dist.batch_shape[:-1] +
                                       (fn.duration,))
        msg = self.obs.apply({
            "name": f"{name}_obs",
            "fn": self._wrap(obs_dist, event_dim),
            "value": value,
            "is_observed": is_observed,
        })
        obs_dist = msg["fn"]
        obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)
        value = msg["value"]
        is_observed = msg["is_observed"]

    # Reparameterize the entire HMM as conditionally Gaussian.
    hmm = dist.GaussianHMM(
        init_dist,
        fn.transition_matrix,
        trans_dist,
        fn.observation_matrix,
        obs_dist,
        duration=fn.duration,
    )
    hmm = self._wrap(hmm, event_dim)

    # Apply any observation transforms.
    if fn.transforms:
        hmm = dist.TransformedDistribution(hmm, fn.transforms)

    return {"fn": hmm, "value": value, "is_observed": is_observed}

def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=6)

    shape = broadcast_shape(init_shape + (6,), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape

    likelihood = dist.Normal(data, 1).to_event(2)
    p, log_normalizer = d.conjugate_update(likelihood)
    assert p.batch_shape == d.batch_shape
    assert p.event_shape == d.event_shape
    x = p.rsample()
    assert x.shape == d.shape()
    x = p.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = p.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)

    z = d.rsample_posterior(data)
    assert z.shape == expected_batch_shape + time_shape + (hidden_dim,)

    for t in range(1, d.duration - 1):
        f = d.duration - t
        d2 = d.prefix_condition(data[..., :t, :])
        assert d2.batch_shape == d.batch_shape
        assert d2.event_shape == (f, obs_dim)

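# Illustrative sketch reusing names from the test above (`d`, `data`,
# `obs_dim` are only in scope inside that test): the prefix_condition() loop
# is the basis for forecasting, giving the predictive distribution over the
# remaining duration - t steps conditioned on the first t observations.
t = d.duration // 2
predictive = d.prefix_condition(data[..., :t, :])
assert predictive.event_shape == (d.duration - t, obs_dim)
forecast = predictive.rsample()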