def __call__(self, name, fn, obs):
    """
    Rewrite a linear HMM site as a conditionally Gaussian HMM.

    :param name: Site name. It prefixes the ``{name}_init``, ``{name}_trans``
        and ``{name}_obs`` names handed to the component reparameterizers.
    :param fn: Distribution that ``self._unwrap`` reduces to a
        ``dist.LinearHMM`` or ``dist.IndependentHMM``.
    :param obs: Observed value, or ``None``.
    :returns: A ``(hmm, obs)`` pair holding the replacement distribution and
        the observation as returned by the observation reparameterizer
        (unchanged when ``self.obs`` is ``None``).
    :raises ValueError: If ``fn.duration`` is ``None``.
    """
    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
    if fn.duration is None:
        raise ValueError(
            "LinearHMMReparam requires duration to be specified "
            "on targeted LinearHMM distributions")

    if isinstance(fn, dist.IndependentHMM):
        # Recurse on the base distribution. The observation has its last two
        # dims swapped and a trailing singleton dim appended on the way in,
        # and the inverse applied on the way out.
        if obs is not None:
            obs = obs.transpose(-1, -2).unsqueeze(-1)
        base_hmm, obs = self(name, fn.base_dist.to_event(1), obs)
        wrapped = dist.IndependentHMM(base_hmm.to_event(-1))
        if obs is not None:
            obs = obs.squeeze(-1).transpose(-1, -2)
        return wrapped, obs

    def _match_duration(d):
        # Expand the last batch dim so that it equals ``fn.duration``.
        if d.batch_shape[-1] == fn.duration:
            return d
        return d.expand(d.batch_shape[:-1] + (fn.duration,))

    # Initial distribution: optionally reparameterized, never observed.
    initial = fn.initial_dist
    if self.init is not None:
        initial, _ = self.init(
            "{}_init".format(name),
            self._wrap(initial, event_dim - 1),
            None,
        )
        initial = initial.to_event(1 - initial.event_dim)

    # Transition noise: optionally reparameterized, never observed.
    transition = fn.transition_dist
    if self.trans is not None:
        transition = _match_duration(transition)
        transition, _ = self.trans(
            "{}_trans".format(name),
            self._wrap(transition, event_dim),
            None,
        )
        transition = transition.to_event(1 - transition.event_dim)

    # Observation noise: the only component that receives ``obs``.
    observation = fn.observation_dist
    if self.obs is not None:
        observation = _match_duration(observation)
        observation, obs = self.obs(
            "{}_obs".format(name),
            self._wrap(observation, event_dim),
            obs,
        )
        observation = observation.to_event(1 - observation.event_dim)

    # Assemble the Gaussian HMM from the (possibly replaced) components.
    result = self._wrap(
        dist.GaussianHMM(
            initial,
            fn.transition_matrix,
            transition,
            fn.observation_matrix,
            observation,
            duration=fn.duration,
        ),
        event_dim,
    )

    # Re-apply the observation transforms of the original distribution.
    if fn.transforms:
        result = dist.TransformedDistribution(result, fn.transforms)
    return result, obs
def test_independent_hmm_shape(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    """Check the shapes of a ``GaussianHMM`` wrapped in ``IndependentHMM``."""
    duration = 6

    def with_obs_dim(shape):
        # Insert ``obs_dim`` just before the trailing (time) dim; an empty
        # shape gets a time dim of length ``duration``.
        return shape[:-1] + (obs_dim, shape[-1] if shape else duration)

    base_init_shape = init_shape + (obs_dim,)
    base_trans_mat_shape = with_obs_dim(trans_mat_shape)
    base_trans_mvn_shape = with_obs_dim(trans_mvn_shape)
    base_obs_mat_shape = with_obs_dim(obs_mat_shape)
    base_obs_mvn_shape = with_obs_dim(obs_mvn_shape)

    # Build the components in a fixed order so the random stream is stable.
    initial = random_mvn(base_init_shape, hidden_dim)
    transition_matrix = torch.randn(base_trans_mat_shape + (hidden_dim, hidden_dim))
    transition = random_mvn(base_trans_mvn_shape, hidden_dim)
    observation_matrix = torch.randn(base_obs_mat_shape + (hidden_dim, 1))
    observation = random_mvn(base_obs_mvn_shape, 1)

    d = dist.IndependentHMM(
        dist.GaussianHMM(
            initial,
            transition_matrix,
            transition,
            observation_matrix,
            observation,
            duration=duration,
        )
    )

    # Expected shapes come from broadcasting the un-augmented input shapes.
    shape = broadcast_shape(
        init_shape + (duration,),
        trans_mat_shape,
        trans_mvn_shape,
        obs_mat_shape,
        obs_mvn_shape,
    )
    expected_batch_shape = shape[:-1]
    time_shape = shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)

    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    # log_prob reduces over the event dims.
    data = torch.randn(shape + (obs_dim,))
    assert data.shape == d.shape()
    log_prob = d.log_prob(data)
    assert log_prob.shape == expected_batch_shape
    check_expand(d, data)

    # Sampling without, with, and after expanding the batch shape.
    sample = d.rsample()
    assert sample.shape == d.shape()
    sample = d.rsample((6,))
    assert sample.shape == (6,) + d.shape()
    sample = d.expand((6, 5)).rsample()
    assert sample.shape == (6, 5) + d.event_shape
def test_independent_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_dim):
    """Smoke-test ``LinearHMMReparam`` on an ``IndependentHMM`` with stable noise."""
    base_batch_shape = batch_shape + (obs_dim,)
    time_batch_shape = base_batch_shape + (duration,)

    stability = dist.Uniform(0.5, 2).sample(base_batch_shape)
    # Stability with two trailing singleton dims, used for the per-step noise.
    stability_per_step = stability.unsqueeze(-1).unsqueeze(-1)

    # Build the components in a fixed order so the random stream is stable.
    initial = random_stable(
        base_batch_shape + (hidden_dim,),
        stability.unsqueeze(-1),
        skew=skew,
    ).to_event(1)
    transition_matrix = torch.randn(time_batch_shape + (hidden_dim, hidden_dim))
    transition = random_stable(
        time_batch_shape + (hidden_dim,),
        stability_per_step,
        skew=skew,
    ).to_event(1)
    observation_matrix = torch.randn(time_batch_shape + (hidden_dim, 1))
    observation = random_stable(
        time_batch_shape + (1,),
        stability_per_step,
        skew=skew,
    ).to_event(1)

    hmm = dist.LinearHMM(
        initial,
        transition_matrix,
        transition,
        observation_matrix,
        observation,
        duration=duration,
    )
    assert hmm.batch_shape == base_batch_shape
    assert hmm.event_shape == (duration, 1)

    # Wrapping moves the trailing batch dim (obs_dim) into the event shape.
    hmm = dist.IndependentHMM(hmm)
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (duration, obs_dim)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = torch.randn(duration, obs_dim)
    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()

    with poutine.trace() as tr, \
            poutine.reparam(config={"x": LinearHMMReparam(rep, rep, rep)}):
        model(data)

    assert isinstance(tr.trace.nodes["x"]["fn"], dist.IndependentHMM)
    tr.trace.compute_log_prob()  # smoke test only
def apply(self, msg):
    """
    Rewrite a linear HMM site as a conditionally Gaussian HMM.

    :param dict msg: Site message; the keys ``"name"``, ``"fn"``,
        ``"value"`` and ``"is_observed"`` are read.
    :returns: A dict with keys ``"fn"``, ``"value"`` and ``"is_observed"``
        describing the replacement site.
    :rtype: dict
    :raises ValueError: If the unwrapped distribution has ``duration`` set
        to ``None``.
    """
    name = msg["name"]
    site_fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]

    fn, event_dim = self._unwrap(site_fn)
    assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
    if fn.duration is None:
        raise ValueError(
            "LinearHMMReparam requires duration to be specified "
            "on targeted LinearHMM distributions")

    if isinstance(fn, dist.IndependentHMM):
        # Recurse on the base distribution. The value has its last two dims
        # swapped and a trailing singleton dim appended on the way in.
        if value is None:
            indep_value = None
        else:
            indep_value = value.transpose(-1, -2).unsqueeze(-1)
        inner = self.apply({
            "name": name,
            "fn": fn.base_dist.to_event(1),
            "value": indep_value,
            "is_observed": is_observed,
        })
        wrapped = dist.IndependentHMM(inner["fn"].to_event(-1))
        # Undo the layout change only when the recursive call produced a
        # different value object than the one it was given.
        inner_value = inner["value"]
        if inner_value is not indep_value:
            value = inner_value.squeeze(-1).transpose(-1, -2)
        return {"fn": wrapped, "value": value, "is_observed": is_observed}

    def _match_duration(d):
        # Expand the last batch dim so that it equals ``fn.duration``.
        if d.batch_shape[-1] == fn.duration:
            return d
        return d.expand(d.batch_shape[:-1] + (fn.duration,))

    # Initial distribution: optionally reparameterized, never observed.
    initial = fn.initial_dist
    if self.init is not None:
        init_msg = self.init.apply({
            "name": f"{name}_init",
            "fn": self._wrap(initial, event_dim - 1),
            "value": None,
            "is_observed": False,
        })
        initial = init_msg["fn"]
        initial = initial.to_event(1 - initial.event_dim)

    # Transition noise: optionally reparameterized, never observed.
    transition = fn.transition_dist
    if self.trans is not None:
        transition = _match_duration(transition)
        trans_msg = self.trans.apply({
            "name": f"{name}_trans",
            "fn": self._wrap(transition, event_dim),
            "value": None,
            "is_observed": False,
        })
        transition = trans_msg["fn"]
        transition = transition.to_event(1 - transition.event_dim)

    # Observation noise: the only component that receives the site's value
    # and observed flag, and the only one allowed to update them.
    observation = fn.observation_dist
    if self.obs is not None:
        observation = _match_duration(observation)
        obs_msg = self.obs.apply({
            "name": f"{name}_obs",
            "fn": self._wrap(observation, event_dim),
            "value": value,
            "is_observed": is_observed,
        })
        observation = obs_msg["fn"]
        observation = observation.to_event(1 - observation.event_dim)
        value = obs_msg["value"]
        is_observed = obs_msg["is_observed"]

    # Assemble the Gaussian HMM from the (possibly replaced) components.
    result = self._wrap(
        dist.GaussianHMM(
            initial,
            fn.transition_matrix,
            transition,
            fn.observation_matrix,
            observation,
            duration=fn.duration,
        ),
        event_dim,
    )

    # Re-apply the observation transforms of the original distribution.
    if fn.transforms:
        result = dist.TransformedDistribution(result, fn.transforms)
    return {"fn": result, "value": value, "is_observed": is_observed}
def _(d, batch_shape):
    """
    Reshape the batch of an ``IndependentHMM`` by reshaping its base
    distribution.

    The base distribution is reshaped to ``batch_shape`` followed by the last
    dim of ``d.event_shape``, then wrapped in a new ``IndependentHMM``.
    """
    target_shape = batch_shape + d.event_shape[-1:]
    return dist.IndependentHMM(reshape_batch(d.base_dist, target_shape))
def _(d, data):
    """
    Apply ``prefix_condition`` to an ``IndependentHMM`` via its base
    distribution.

    ``data`` has its last two dims swapped and a trailing singleton dim
    appended before being passed to ``prefix_condition`` together with
    ``d.base_dist``; the result is wrapped in a new ``IndependentHMM``.
    """
    swapped = data.transpose(-1, -2)
    conditioned = prefix_condition(d.base_dist, swapped.unsqueeze(-1))
    return dist.IndependentHMM(conditioned)