def __init__(self, initial_dist, transition_dist, observation_dist, validate_args=None):
    """
    Build the lazy funsor backing a ``GaussianMRF`` distribution.

    :param initial_dist: ``MultivariateNormal`` over the first hidden state;
        its event size defines ``hidden_dim``.
    :param transition_dist: ``MultivariateNormal`` whose event size must be
        ``2 * hidden_dim``; it is converted into a factor over the pair
        ``("state", "state(time=1)")``.
    :param observation_dist: ``MultivariateNormal`` whose event size is
        ``hidden_dim + obs_dim``; it is converted into a factor over
        ``("state(time=1)", "value")``.
    :param validate_args: forwarded unchanged to the parent constructor.
    :raises AssertionError: if an argument is not a ``MultivariateNormal`` or
        the transition event size is not ``2 * hidden_dim``.
    """
    mvn_type = torch.distributions.MultivariateNormal
    for factor_dist in (initial_dist, transition_dist, observation_dist):
        assert isinstance(factor_dist, mvn_type)

    hidden_dim = initial_dist.event_shape[0]
    assert transition_dist.event_shape[0] == 2 * hidden_dim
    obs_dim = observation_dist.event_shape[0] - hidden_dim

    # The initial distribution gets an extra trailing size-1 dim so that it
    # broadcasts against the trailing dim of the other two batch shapes; that
    # trailing dim is then split off and treated as the time axis.
    full_shape = broadcast_shape(
        initial_dist.batch_shape + (1,),
        transition_dist.batch_shape,
        observation_dist.batch_shape,
    )
    batch_shape = full_shape[:-1]
    time_shape = full_shape[-1:]
    event_shape = time_shape + (obs_dim,)
    num_steps = time_shape[0]

    # Turn each torch distribution into a funsor with named inputs.
    state_domain = Reals[hidden_dim]
    init = dist_to_funsor(initial_dist)(value="state")
    trans = mvn_to_funsor(
        transition_dist,
        ("time",),
        OrderedDict([("state", state_domain), ("state(time=1)", state_domain)]),
    )
    obs = mvn_to_funsor(
        observation_dist,
        ("time",),
        OrderedDict([("state(time=1)", state_domain), ("value", Reals[obs_dim])]),
    )

    latent_vars = frozenset({"state", "state(time=1)"})

    # Assemble the joint funsor lazily.
    # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
    with interpretation(lazy):
        time = Variable("time", Bint[num_steps])
        value = Variable("value", Reals[num_steps, obs_dim])

        def eliminate_states(factors):
            # Chain the per-step factors along time, attach the initial
            # factor, then reduce out both hidden-state variables.
            chained = MarkovProduct(
                ops.logaddexp, ops.add, factors, time, {"state": "state(time=1)"}
            )
            chained += init
            return chained.reduce(ops.logaddexp, latent_vars)

        # Hidden states eliminated, with "value" bound to the per-step slice
        # of the observation variable.
        logp_oh = eliminate_states(trans + obs(value=value["time"]))
        # Same elimination, but with "value" also reduced out of the
        # observation factor (presumably serving as the normalizer).
        logp_h = eliminate_states(trans + obs.reduce(ops.logaddexp, "value"))
        funsor_dist = logp_oh - logp_h

    super(GaussianMRF, self).__init__(
        funsor_dist, batch_shape, event_shape, "real", validate_args
    )
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist, validate_args=None):
    """
    Build the lazy funsor backing a ``GaussianHMM`` distribution.

    :param initial_dist: ``MultivariateNormal`` with event shape
        ``(hidden_dim,)``.
    :param transition_matrix: ``torch.Tensor`` whose trailing two dims are
        ``(hidden_dim, hidden_dim)``.
    :param transition_dist: ``MultivariateNormal`` with event shape
        ``(hidden_dim,)``.
    :param observation_matrix: ``torch.Tensor`` whose trailing two dims are
        ``(hidden_dim, obs_dim)``; both sizes are read from it.
    :param observation_dist: ``MultivariateNormal`` with event shape
        ``(obs_dim,)``.
    :param validate_args: forwarded unchanged to the parent constructor.
    :raises AssertionError: on a type or shape mismatch, or when
        ``obs_dim < hidden_dim // 2``.
    """
    mvn_type = torch.distributions.MultivariateNormal
    expected_types = (
        (initial_dist, mvn_type),
        (transition_matrix, torch.Tensor),
        (transition_dist, mvn_type),
        (observation_matrix, torch.Tensor),
        (observation_dist, mvn_type),
    )
    for argument, required_type in expected_types:
        assert isinstance(argument, required_type)

    # Both sizes come from the trailing dims of the observation matrix.
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
    assert initial_dist.event_shape == (hidden_dim,)
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim,)
    assert observation_dist.event_shape == (obs_dim,)

    # The initial distribution gets an extra trailing size-1 dim so that it
    # broadcasts against the trailing dim of the other batch shapes; that
    # trailing dim is then split off and treated as the time axis.
    full_shape = broadcast_shape(
        initial_dist.batch_shape + (1,),
        transition_matrix.shape[:-2],
        transition_dist.batch_shape,
        observation_matrix.shape[:-2],
        observation_dist.batch_shape,
    )
    batch_shape = full_shape[:-1]
    time_shape = full_shape[-1:]
    event_shape = time_shape + (obs_dim,)

    # Turn the torch inputs into funsors with named inputs.
    init = dist_to_funsor(initial_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(
        transition_matrix, transition_dist, ("time",), "state", "state(time=1)"
    )
    obs = matrix_and_mvn_to_funsor(
        observation_matrix, observation_dist, ("time",), "state(time=1)", "value"
    )
    dtype = "real"

    # Assemble the joint funsor lazily.
    with interpretation(lazy):
        value = Variable("value", Reals[time_shape[0], obs_dim])
        # Bind "value" to the per-step slice of the observation variable.
        emission = obs(value=value["time"])
        joint = MarkovProduct(
            ops.logaddexp, ops.add, trans + emission, "time", {"state": "state(time=1)"}
        )
        # Reduce out the final state first, attach the initial factor, then
        # reduce out the initial state.
        without_last_state = joint.reduce(ops.logaddexp, "state(time=1)")
        joint = init + without_last_state
        funsor_dist = joint.reduce(ops.logaddexp, "state")

    super(GaussianHMM, self).__init__(
        funsor_dist, batch_shape, event_shape, dtype, validate_args
    )
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
def _forward_funsor(self, features, trip_counts):
    """
    Build the prior and likelihood funsors over the packed ``gate_rate_t``
    variable.

    :param features: sequence whose length must equal the number of observed
        hours; it is passed to ``self._dynamics``.
    :param trip_counts: 3-dimensional tensor of shape
        ``(observed_hours, num_stations, num_stations)``.
    :returns: a ``(prior, likelihood)`` pair; each is asserted to depend only
        on the input ``"gate_rate_t"``.
    :raises AssertionError: if the shapes disagree with ``len(features)`` or
        ``self.num_stations``, or if either result has unexpected inputs.
    """
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours == total_hours
    n = self.num_stations
    assert num_origins == n
    assert num_destins == n

    # Gate and rate are packed side by side into one flat real vector per hour.
    packed_size = 2 * n * n
    gate_rate = funsor.Variable("gate_rate_t", reals(observed_hours, packed_size))["time"]

    @funsor.torch.function(reals(packed_size), (reals(n, n, 2), reals(n, n)))
    def unpack_gate_rate(gate_rate):
        # Split the flat vector into its (n, n) gate half and (n, n) rate half.
        lead_shape = gate_rate.shape[:-1]
        halves = gate_rate.reshape(lead_shape + (2, n, n))
        gate, rate = halves.unbind(-3)
        # Keep the gate strictly inside (0, 1) and bound the exponentiated rate.
        gate = gate.sigmoid().clamp(min=0.01, max=0.99)
        rate = bounded_exp(rate, bound=1e4)
        # A new last axis holds (1 - gate, gate).
        return torch.stack((1 - gate, gate), dim=-1), rate

    # Gaussian latent dynamical system produced by the model's dynamics.
    dynamics = self._dynamics(features[:observed_hours])
    init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = dynamics
    init = dist_to_funsor(init_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(
        trans_matrix, trans_dist, ("time",), "state", "state(time=1)"
    )
    obs = matrix_and_mvn_to_funsor(
        obs_matrix, obs_dist, ("time",), "state(time=1)", "gate_rate"
    )

    # Dynamic prior over gate_rate: chain the per-step factors along time,
    # attach the initial factor, and reduce out the hidden states.
    prior = trans + obs(gate_rate=gate_rate)
    prior = MarkovProduct(
        ops.logaddexp, ops.add, prior, "time", {"state": "state(time=1)"}
    )
    prior += init
    prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

    # Zero-inflated Poisson likelihood: a categorical over "gated" picks
    # between a Poisson branch (index 0) and a point mass at zero (index 1).
    gate, rate = unpack_gate_rate(gate_rate)
    likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
    counts = tensor_to_funsor(trip_counts, ("time", "origin", "destin"))
    branches = (
        fdist.Poisson(rate["origin", "destin"], value=counts),
        fdist.Delta(0, value=counts),
    )
    likelihood += funsor.Stack("gated", branches)
    likelihood = likelihood.reduce(ops.logaddexp, "gated")
    likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

    assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
    assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
    return prior, likelihood