def __init__(self, initial_dist, transition_dist, observation_dist, validate_args=None):
    """
    Funsor-based Gaussian Markov random field over a hidden Gaussian chain
    with Gaussian observations.

    :param initial_dist: a ``MultivariateNormal`` over the initial hidden
        state; its event size defines ``hidden_dim``.
    :param transition_dist: a ``MultivariateNormal`` over a concatenated
        pair of successive hidden states, so its event size must be
        ``2 * hidden_dim``; its batch shape may carry a time dimension.
    :param observation_dist: a ``MultivariateNormal`` over a concatenated
        (hidden state, observation) pair; ``obs_dim`` is its event size
        minus ``hidden_dim``.
    :param validate_args: passed through to the distribution base class.
    """
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim = initial_dist.event_shape[0]
    # The transition operates on (state, next state) pairs, hence 2 * hidden_dim.
    assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
    obs_dim = observation_dist.event_shape[0] - hidden_dim
    # The rightmost batch dim is time; initial_dist gets a size-1 time dim
    # so it broadcasts against the time-indexed transition/observation dists.
    shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                            transition_dist.batch_shape,
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim, )

    # Convert distributions to funsors.
    init = dist_to_funsor(initial_dist)(value="state")
    trans = mvn_to_funsor(
        transition_dist, ("time", ),
        OrderedDict([("state", Reals[hidden_dim]),
                     ("state(time=1)", Reals[hidden_dim])]))
    obs = mvn_to_funsor(
        observation_dist, ("time", ),
        OrderedDict([("state(time=1)", Reals[hidden_dim]),
                     ("value", Reals[obs_dim])]))

    # Construct the joint funsor.
    # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
    with interpretation(lazy):
        time = Variable("time", Bint[time_shape[0]])
        value = Variable("value", Reals[time_shape[0], obs_dim])
        # Numerator: joint over hidden chain and observations, with the
        # hidden chain eliminated by a parallel-scan Markov product.
        logp_oh = trans + obs(value=value["time"])
        logp_oh = MarkovProduct(ops.logaddexp, ops.add,
                                logp_oh, time, {"state": "state(time=1)"})
        logp_oh += init
        logp_oh = logp_oh.reduce(ops.logaddexp, frozenset({"state", "state(time=1)"}))
        # Denominator: same chain with observations marginalized out,
        # so the difference below is a normalized log-probability.
        logp_h = trans + obs.reduce(ops.logaddexp, "value")
        logp_h = MarkovProduct(ops.logaddexp, ops.add,
                               logp_h, time, {"state": "state(time=1)"})
        logp_h += init
        logp_h = logp_h.reduce(ops.logaddexp, frozenset({"state", "state(time=1)"}))
        funsor_dist = logp_oh - logp_h

    dtype = "real"
    super(GaussianMRF, self).__init__(funsor_dist, batch_shape, event_shape, dtype, validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
def test_mvn_to_funsor(batch_shape, event_shape, event_sizes):
    """
    Check that ``mvn_to_funsor`` produces a funsor with the expected inputs
    and that substituting a sample reproduces ``MultivariateNormal.log_prob``.

    :param batch_shape: batch shape of the random MVN.
    :param event_shape: extra named integer dims (bound to names "a", "b", "c").
    :param event_sizes: sizes of the real inputs (bound to names "x", "y", "z");
        their sum is the MVN event size.
    """
    event_size = sum(event_sizes)
    mvn = random_mvn(batch_shape + event_shape, event_size)
    int_inputs = OrderedDict((k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict((k, reals(size)) for k, size in zip("xyz", event_sizes))

    f = mvn_to_funsor(mvn, tuple(int_inputs), real_inputs)
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            # Size-1 integer dims are squeezed away, so the *name* must be
            # absent. (BUGFIX: the original asserted `d not in f.inputs`,
            # which was vacuously true since f.inputs is keyed by name.)
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    for k, d in real_inputs.items():
        assert k in f.inputs
        assert f.inputs[k] == d

    # Substitute one slice of a sample per real input and compare against
    # the reference log_prob of the full sample.
    value = mvn.sample()
    subs = {}
    beg = 0
    for k, d in real_inputs.items():
        end = beg + d.num_elements
        subs[k] = tensor_to_funsor(value[..., beg:end], tuple(int_inputs), 1)
        beg = end
    actual_log_prob = f(**subs)
    expected_log_prob = tensor_to_funsor(mvn.log_prob(value), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
def test_matrix_and_mvn_to_funsor(batch_shape, event_shape, x_size, y_size):
    """
    Check that ``matrix_and_mvn_to_funsor`` plus ``mvn_to_funsor`` compose
    to the expected joint log-density over inputs "x" and "y".

    :param batch_shape: batch shape of the random MVNs.
    :param event_shape: extra named integer dims (bound to names "a", "b", "c").
    :param x_size: size of the real input "x".
    :param y_size: size of the real input "y" (and of the noise MVN).
    """
    matrix = torch.randn(batch_shape + event_shape + (x_size, y_size))
    y_mvn = random_mvn(batch_shape + event_shape, y_size)
    xy_mvn = random_mvn(batch_shape + event_shape, x_size + y_size)
    int_inputs = OrderedDict((k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict([("x", reals(x_size)), ("y", reals(y_size))])

    f = (matrix_and_mvn_to_funsor(matrix, y_mvn, tuple(int_inputs), "x", "y") +
         mvn_to_funsor(xy_mvn, tuple(int_inputs), real_inputs))
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            # Size-1 integer dims are squeezed away, so the *name* must be
            # absent. (BUGFIX: the original asserted `d not in f.inputs`,
            # which was vacuously true since f.inputs is keyed by name.)
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    assert f.inputs["x"] == reals(x_size)
    assert f.inputs["y"] == reals(y_size)

    # Reference: log p(x, y) = log N(xy) + log N(y | x @ matrix).
    xy = torch.randn(x_size + y_size)
    x, y = xy[:x_size], xy[x_size:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = f(x=x, y=y)
    expected_log_prob = tensor_to_funsor(
        xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
def __init__(self,
             num_components,                  # the number of switching states K
             hidden_dim,                      # the dimension of the continuous latent space
             obs_dim,                         # the dimension of the continuous outputs
             fine_transition_matrix=True,     # controls whether the transition matrix depends on s_t
             fine_transition_noise=False,     # controls whether the transition noise depends on s_t
             fine_observation_matrix=False,   # controls whether the observation matrix depends on s_t
             fine_observation_noise=False,    # controls whether the observation noise depends on s_t
             moment_matching_lag=1):          # controls the expense of the moment matching approximation
    """
    Switching linear dynamical system: a discrete switching state s_t with
    ``num_components`` values modulates a linear-Gaussian continuous chain
    x_t of dimension ``hidden_dim`` emitting observations of ``obs_dim``.
    The ``fine_*`` flags choose which pieces of the continuous dynamics are
    indexed by s_t (get a leading ``num_components`` parameter dim); at
    least one flag must be set or the two chains would be decoupled.
    """
    self.num_components = num_components
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    self.moment_matching_lag = moment_matching_lag
    self.fine_transition_noise = fine_transition_noise
    self.fine_observation_matrix = fine_observation_matrix
    self.fine_observation_noise = fine_observation_noise
    self.fine_transition_matrix = fine_transition_matrix

    assert moment_matching_lag > 0
    assert fine_transition_noise or fine_observation_matrix or fine_observation_noise or fine_transition_matrix, \
        "The continuous dynamics need to be coupled to the discrete dynamics in at least one way [use at " + \
        "least one of the arguments --ftn --ftm --fon --fom]"

    super(SLDS, self).__init__()

    # initialize the various parameters of the model
    self.transition_logits = nn.Parameter(0.1 * torch.randn(num_components, num_components))
    # Transition matrix starts near the identity so early dynamics are
    # close to a random walk; fine => one matrix per switching state.
    if fine_transition_matrix:
        transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(num_components, hidden_dim,
                                                                       hidden_dim)
    else:
        transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(hidden_dim, hidden_dim)
    self.transition_matrix = nn.Parameter(transition_matrix)
    # Noise scales are parameterized in log space to stay positive.
    if fine_transition_noise:
        self.log_transition_noise = nn.Parameter(0.1 * torch.randn(num_components, hidden_dim))
    else:
        self.log_transition_noise = nn.Parameter(0.1 * torch.randn(hidden_dim))
    if fine_observation_matrix:
        self.observation_matrix = nn.Parameter(0.3 * torch.randn(num_components,
                                                                 hidden_dim, obs_dim))
    else:
        self.observation_matrix = nn.Parameter(0.3 * torch.randn(hidden_dim, obs_dim))
    if fine_observation_noise:
        self.log_obs_noise = nn.Parameter(0.1 * torch.randn(num_components, obs_dim))
    else:
        self.log_obs_noise = nn.Parameter(0.1 * torch.randn(obs_dim))

    # define the prior distribution p(x_0) over the continuous latent at the initial time step t=0
    x_init_mvn = pyro.distributions.MultivariateNormal(torch.zeros(self.hidden_dim),
                                                       torch.eye(self.hidden_dim))
    self.x_init_mvn = mvn_to_funsor(x_init_mvn,
                                    real_inputs=OrderedDict([('x_0', funsor.Reals[self.hidden_dim])]))