def test_distributions(state_dim, obs_dim):
    """Integration test: build a 2-step linear-Gaussian model with a shared
    observation bias directly from funsor distributions, then marginalize
    all latent variables and check the result grounds to a Tensor."""
    # Two timesteps of observations; indexing by "time" makes it a free input.
    data = Tensor(torch.randn(2, obs_dim))["time"]

    # Gaussian prior over the shared observation bias.
    bias = Variable("bias", reals(obs_dim))
    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)

    # Transition factor expressed through the residual curr - prev @ trans_mat,
    # i.e. curr ~ MVN(prev @ trans_mat + loc, scale_tril).
    prev = Variable("prev", reals(state_dim))
    curr = Variable("curr", reals(state_dim))
    trans_mat = Tensor(
        torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim))
    trans_mvn = random_mvn((), state_dim)
    trans_dist = dist.MultivariateNormal(
        loc=trans_mvn.loc,
        scale_tril=trans_mvn.scale_tril,
        value=curr - prev @ trans_mat)

    # Observation factor via the residual state @ obs_mat + bias - obs.
    state = Variable("state", reals(state_dim))
    obs = Variable("obs", reals(obs_dim))
    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
    obs_mvn = random_mvn((), obs_dim)
    obs_dist = dist.MultivariateNormal(
        loc=obs_mvn.loc,
        scale_tril=obs_mvn.scale_tril,
        value=state @ obs_mat + bias - obs)

    # Accumulate log-density factors for two timesteps.
    log_prob = 0
    log_prob += bias_dist

    state_0 = Variable("state_0", reals(state_dim))
    log_prob += obs_dist(state=state_0, obs=data(time=0))

    state_1 = Variable("state_1", reals(state_dim))
    log_prob += trans_dist(prev=state_0, curr=state_1)
    log_prob += obs_dist(state=state_1, obs=data(time=1))

    # Marginalizing out every free variable should leave a ground Tensor.
    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
def test_mvn_affine_one_var():
    """Check MVN handling of the single-variable affine substitution 2*x + 1."""
    free_var = Variable('x', reals(2))
    subs = dict(x=Tensor(torch.randn(2)))
    with interpretation(lazy):
        density = dist_to_funsor(random_mvn((), 2))
        density = density(value=2 * free_var + 1)
    _check_mvn_affine(density, subs)
def test_pyro_convert():
    """Assemble a small biased linear-Gaussian factor graph from converted
    pyro objects and check it reduces to a ground Tensor."""
    observations = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))]))
    prior_bias = dist_to_funsor(random_mvn((), 2))
    transition = matrix_and_mvn_to_funsor(torch.randn(3, 3), random_mvn((), 3),
                                          (), "prev", "curr")
    emission = matrix_and_mvn_to_funsor(torch.randn(3, 2), random_mvn((), 2),
                                        (), "state", "obs")

    bias = Variable("bias", reals(2))
    s0 = Variable("state_0", reals(3))
    s1 = Variable("state_1", reals(3))

    joint = 0
    joint = joint + prior_bias(value=bias)
    joint = joint + emission(state=s0, obs=bias + observations(time=0))
    joint = joint + transition(prev=s0, curr=s1)
    joint = joint + emission(state=s1, obs=bias + observations(time=1))

    marginal = joint.reduce(ops.logaddexp)
    assert isinstance(marginal, Tensor), marginal.pretty()
def test_mvn_affine_getitem():
    """Check MVN handling of an affine value built via getitem: x[0] - x[1]."""
    mat = Variable('x', reals(2, 2))
    subs = dict(x=Tensor(torch.randn(2, 2)))
    with interpretation(lazy):
        density = dist_to_funsor(random_mvn((), 2))
        density = density(value=mat[0] - mat[1])
    _check_mvn_affine(density, subs)
def test_mvn_affine_two_vars():
    """Check MVN handling of a two-variable affine substitution x - y."""
    x = Variable('x', reals(2))
    y = Variable('y', reals(2))
    # Use torch.randn explicitly, consistent with the sibling affine tests
    # (the original called a bare `randn`, unlike every other test here).
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=x - y)
    _check_mvn_affine(d, data)
def test_mvn_affine_reshape():
    """Check MVN handling of an affine value involving a reshape."""
    mat = Variable('x', reals(2, 2))
    vec = Variable('y', reals(4))
    subs = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4)))
    with interpretation(lazy):
        density = dist_to_funsor(random_mvn((), 4))
        density = density(value=mat.reshape((4, )) - vec)
    _check_mvn_affine(density, subs)
def test_dist_to_funsor_categorical(batch_shape, cardinality):
    """dist_to_funsor on a Categorical should yield its normalized logits table."""
    raw = torch.randn(batch_shape + (cardinality, ))
    raw -= raw.logsumexp(dim=-1, keepdim=True)
    converted = dist_to_funsor(dist.Categorical(logits=raw))
    assert isinstance(converted, Tensor)
    reference = tensor_to_funsor(raw, ("value", ))
    assert_close(converted, reference)
def test_mvn_affine_einsum():
    """Check MVN handling of an affine value built from an Einsum plus a scalar."""
    coeffs = Tensor(torch.randn(3, 2, 2))
    mat = Variable('x', reals(2, 2))
    scalar = Variable('y', reals())
    subs = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(())))
    with interpretation(lazy):
        density = dist_to_funsor(random_mvn((), 3))
        density = density(value=Einsum("abc,bc->a", coeffs, mat) + scalar)
    _check_mvn_affine(density, subs)
def test_mvn_affine_matmul():
    """Check MVN handling of an affine value built with matmul: x @ m - y."""
    vec_x = Variable('x', reals(2))
    vec_y = Variable('y', reals(3))
    weight = Tensor(torch.randn(2, 3))
    subs = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3)))
    with interpretation(lazy):
        density = dist_to_funsor(random_mvn((), 3))
        density = density(value=vec_x @ weight - vec_y)
    _check_mvn_affine(density, subs)
def test_dist_to_funsor_bernoulli(batch_shape):
    """A converted Bernoulli should reproduce log_prob on a sampled value."""
    d = dist.Bernoulli(logits=torch.randn(batch_shape))
    converted = dist_to_funsor(d)
    assert isinstance(converted, Funsor)
    draw = d.sample()
    actual = converted(value=tensor_to_funsor(draw))
    expected = tensor_to_funsor(d.log_prob(draw))
    assert_close(actual, expected)
def test_dist_to_funsor_normal(batch_shape):
    """A converted Normal should reproduce log_prob on a sampled value."""
    d = dist.Normal(torch.randn(batch_shape), torch.randn(batch_shape).exp())
    converted = dist_to_funsor(d)
    assert isinstance(converted, Funsor)
    draw = d.sample()
    actual = converted(value=tensor_to_funsor(draw))
    expected = tensor_to_funsor(d.log_prob(draw))
    assert_close(actual, expected, rtol=1e-5)
def __init__(self, initial_dist, transition_dist, observation_dist,
             validate_args=None):
    """Build a funsor-backed Gaussian MRF over a hidden state sequence.

    :param initial_dist: MultivariateNormal over the initial hidden state.
    :param transition_dist: MultivariateNormal over pairs of adjacent hidden
        states, hence its event size is 2 * hidden_dim.
    :param observation_dist: MultivariateNormal over the concatenation of a
        hidden state and an observation (event size hidden_dim + obs_dim).
    """
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim = initial_dist.event_shape[0]
    # The transition factor couples (state, state(time=1)), so 2 * hidden_dim.
    assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
    obs_dim = observation_dist.event_shape[0] - hidden_dim
    # The trailing batch dim of transition/observation dists is the time axis;
    # pad initial_dist's batch shape with 1 so it broadcasts over time.
    shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                            transition_dist.batch_shape,
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim, )

    # Convert distributions to funsors.
    init = dist_to_funsor(initial_dist)(value="state")
    trans = mvn_to_funsor(
        transition_dist, ("time", ),
        OrderedDict([("state", Reals[hidden_dim]),
                     ("state(time=1)", Reals[hidden_dim])]))
    obs = mvn_to_funsor(
        observation_dist, ("time", ),
        OrderedDict([("state(time=1)", Reals[hidden_dim]),
                     ("value", Reals[obs_dim])]))

    # Construct the joint funsor.
    # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
    with interpretation(lazy):
        time = Variable("time", Bint[time_shape[0]])
        value = Variable("value", Reals[time_shape[0], obs_dim])
        # Joint log-density over hidden states and observations.
        logp_oh = trans + obs(value=value["time"])
        logp_oh = MarkovProduct(ops.logaddexp, ops.add, logp_oh, time,
                                {"state": "state(time=1)"})
        logp_oh += init
        logp_oh = logp_oh.reduce(ops.logaddexp,
                                 frozenset({"state", "state(time=1)"}))
        # Normalizing constant: same chain with observations integrated out.
        logp_h = trans + obs.reduce(ops.logaddexp, "value")
        logp_h = MarkovProduct(ops.logaddexp, ops.add, logp_h, time,
                               {"state": "state(time=1)"})
        logp_h += init
        logp_h = logp_h.reduce(ops.logaddexp,
                               frozenset({"state", "state(time=1)"}))
        # Conditional log-density of observations given the MRF.
        funsor_dist = logp_oh - logp_h

    dtype = "real"
    super(GaussianMRF, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
def test_dist_to_funsor_independent(batch_shape, event_shape):
    """A to_event'd Normal should convert with a matching log_prob."""
    full_shape = batch_shape + event_shape
    d = dist.Normal(torch.randn(full_shape),
                    torch.randn(full_shape).exp()).to_event(len(event_shape))
    converted = dist_to_funsor(d)
    assert isinstance(converted, Funsor)
    draw = d.sample()
    wrapped = tensor_to_funsor(draw, event_output=len(event_shape))
    actual = converted(value=wrapped)
    expected = tensor_to_funsor(d.log_prob(draw))
    assert_close(actual, expected, rtol=1e-5)
def test_dist_to_funsor_masked(batch_shape):
    """A masked Normal should convert with a matching (masked) log_prob."""
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    mask = torch.bernoulli(torch.full(batch_shape, 0.5)).byte()
    masked = dist.Normal(loc, scale).mask(mask)
    assert isinstance(masked, MaskedDistribution)
    converted = dist_to_funsor(masked)
    assert isinstance(converted, Funsor)
    draw = masked.sample()
    actual = converted(value=tensor_to_funsor(draw))
    expected = tensor_to_funsor(masked.log_prob(draw))
    assert_close(actual, expected)
def test_dist_to_funsor_mvn(batch_shape, event_size):
    """A converted MultivariateNormal should reproduce log_prob on a sample."""
    loc = torch.randn(batch_shape + (event_size, ))
    # Build a PSD covariance from a random wide factor, then factorize it.
    factor = torch.randn(batch_shape + (event_size, 2 * event_size))
    cov = factor.matmul(factor.transpose(-1, -2))
    d = dist.MultivariateNormal(loc, scale_tril=torch.cholesky(cov))
    converted = dist_to_funsor(d)
    assert isinstance(converted, Funsor)
    draw = d.sample()
    actual = converted(value=tensor_to_funsor(draw, event_output=1))
    expected = tensor_to_funsor(d.log_prob(draw))
    assert_close(actual, expected)
def _forward_funsor(self, features, trip_counts):
    """Compute (prior, likelihood) funsors over the flattened latent
    "gate_rate_t", for a zero-inflated Poisson model of trip counts.

    :param features: per-hour feature rows; len(features) is the horizon.
    :param trip_counts: (hours, origins, destins) tensor of observed counts.
    :returns: a pair (prior, likelihood), each with single input "gate_rate_t".
    """
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours == total_hours
    assert num_origins == self.num_stations
    assert num_destins == self.num_stations
    n = self.num_stations
    # Free latent over all hours, flattened to 2*n*n per hour; indexing by
    # "time" exposes the time axis as a funsor input.
    gate_rate = funsor.Variable("gate_rate_t",
                                reals(observed_hours, 2 * n * n))["time"]

    @funsor.torch.function(reals(2 * n * n), (reals(n, n, 2), reals(n, n)))
    def unpack_gate_rate(gate_rate):
        # Split the flat vector into gate logits and log-rates, then map to
        # clamped probabilities and bounded positive rates.
        batch_shape = gate_rate.shape[:-1]
        gate, rate = gate_rate.reshape(batch_shape + (2, n, n)).unbind(-3)
        gate = gate.sigmoid().clamp(min=0.01, max=0.99)
        rate = bounded_exp(rate, bound=1e4)
        # Stack (1-p, p) so gate parameterizes a 2-way Categorical.
        gate = torch.stack((1 - gate, gate), dim=-1)
        return gate, rate

    # Create a Gaussian latent dynamical system.
    init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = \
        self._dynamics(features[:observed_hours])
    init = dist_to_funsor(init_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(trans_matrix, trans_dist,
                                     ("time", ), "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(obs_matrix, obs_dist,
                                   ("time", ), "state(time=1)", "gate_rate")

    # Compute dynamic prior over gate_rate by eliminating the hidden chain.
    prior = trans + obs(gate_rate=gate_rate)
    prior = MarkovProduct(ops.logaddexp, ops.add,
                          prior, "time", {"state": "state(time=1)"})
    prior += init
    prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

    # Compute zero-inflated Poisson likelihood: a mixture (indexed by "gated")
    # of a Poisson count model and a point mass at zero counts.
    gate, rate = unpack_gate_rate(gate_rate)
    likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
    trip_counts = tensor_to_funsor(trip_counts, ("time", "origin", "destin"))
    likelihood += funsor.Stack(
        "gated", (fdist.Poisson(rate["origin", "destin"], value=trip_counts),
                  fdist.Delta(0, value=trip_counts)))
    likelihood = likelihood.reduce(ops.logaddexp, "gated")
    likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

    # Both results must depend only on the flat latent.
    assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
    assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
    return prior, likelihood
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist, validate_args=None):
    """Build a funsor-backed Gaussian HMM (linear-Gaussian state space model).

    :param initial_dist: MultivariateNormal over the initial hidden state.
    :param transition_matrix: (hidden_dim, hidden_dim) linear dynamics,
        optionally batched over time.
    :param transition_dist: MultivariateNormal process noise.
    :param observation_matrix: (hidden_dim, obs_dim) emission matrix.
    :param observation_dist: MultivariateNormal observation noise.
    """
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_matrix, torch.Tensor)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
    assert initial_dist.event_shape == (hidden_dim, )
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim, )
    assert observation_dist.event_shape == (obs_dim, )
    # The trailing batch dim of the time-varying inputs is the time axis;
    # pad initial_dist's batch shape with 1 so it broadcasts over time.
    shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                            transition_matrix.shape[:-2],
                            transition_dist.batch_shape,
                            observation_matrix.shape[:-2],
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim, )

    # Convert distributions to funsors.
    init = dist_to_funsor(initial_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                     ("time", ), "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                   ("time", ), "state(time=1)", "value")
    dtype = "real"

    # Construct the joint funsor by eliminating the hidden chain.
    with interpretation(lazy):
        value = Variable("value", Reals[time_shape[0], obs_dim])
        result = trans + obs(value=value["time"])
        result = MarkovProduct(ops.logaddexp, ops.add,
                               result, "time", {"state": "state(time=1)"})
        result = init + result.reduce(ops.logaddexp, "state(time=1)")
        funsor_dist = result.reduce(ops.logaddexp, "state")

    super(GaussianHMM, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
def test_funsor_to_mvn(batch_shape, event_shape, real_size):
    """Round-trip an MVN through dist_to_funsor and funsor_to_mvn."""
    reference = random_mvn(batch_shape + event_shape, real_size)
    event_dims = tuple("abc"[:len(event_shape)])
    as_funsor = dist_to_funsor(reference, event_dims)(value="value")
    assert isinstance(as_funsor, Funsor)
    recovered = funsor_to_mvn(as_funsor, len(reference.batch_shape), event_dims)
    assert isinstance(recovered, dist.MultivariateNormal)
    assert recovered.batch_shape == reference.batch_shape
    assert_close(recovered.loc, reference.loc, atol=1e-3, rtol=None)
    assert_close(recovered.precision_matrix, reference.precision_matrix,
                 atol=1e-3, rtol=None)
def __init__(self, initial_logits, transition_logits, observation_dist,
             validate_args=None):
    """Build a funsor-backed discrete HMM.

    :param initial_logits: (possibly unnormalized) logits over the initial
        hidden state; trailing dim is the state cardinality.
    :param transition_logits: logits over next state given current state;
        trailing two dims are (state, next state).
    :param observation_dist: per-state observation distribution; its trailing
        batch dim ranges over hidden states.
    """
    assert isinstance(initial_logits, torch.Tensor)
    assert isinstance(transition_logits, torch.Tensor)
    assert isinstance(observation_dist, torch.distributions.Distribution)
    assert initial_logits.dim() >= 1
    assert transition_logits.dim() >= 2
    assert len(observation_dist.batch_shape) >= 1
    # Broadcast batch shapes; the second-to-last surviving dim is time.
    shape = broadcast_shape(initial_logits.shape[:-1] + (1, ),
                            transition_logits.shape[:-2],
                            observation_dist.batch_shape[:-1])
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + observation_dist.event_shape
    self._has_rsample = observation_dist.has_rsample

    # Normalize logits so each factor is a proper conditional distribution.
    initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    transition_logits = transition_logits - transition_logits.logsumexp(
        -1, True)

    # Convert tensors and distributions to funsors.
    init = tensor_to_funsor(initial_logits, ("state", ))
    trans = tensor_to_funsor(transition_logits,
                             ("time", "state", "state(time=1)"))
    obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
    dtype = obs.inputs["value"].dtype

    # Construct the joint funsor.
    with interpretation(lazy):
        # TODO perform math here once sequential_sum_product has been
        # implemented as a first-class funsor.
        funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
        # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
        self._init = init
        self._trans = trans
        self._obs = obs

    super(DiscreteHMM, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
def test_funsor_to_cat_and_mvn(batch_shape, event_shape, int_size, real_size):
    """Round-trip a Categorical+MVN mixture through funsor and back."""
    logits = torch.randn(batch_shape + event_shape + (int_size, ))
    cat_in = dist.Categorical(logits=logits)
    mvn_in = random_mvn(batch_shape + event_shape + (int_size, ), real_size)
    event_dims = tuple("abc"[:len(event_shape)]) + ("component", )
    mixture = (tensor_to_funsor(logits, event_dims) +
               dist_to_funsor(mvn_in, event_dims)(value="value"))
    assert isinstance(mixture, Funsor)
    cat_out, mvn_out = funsor_to_cat_and_mvn(mixture,
                                             len(cat_in.batch_shape),
                                             event_dims)
    assert isinstance(cat_out, dist.Categorical)
    assert isinstance(mvn_out, dist.MultivariateNormal)
    assert cat_out.batch_shape == cat_in.batch_shape
    assert mvn_out.batch_shape == mvn_in.batch_shape
    assert_close(cat_out.logits, cat_in.logits, atol=1e-4, rtol=None)
    assert_close(mvn_out.loc, mvn_in.loc, atol=1e-4, rtol=None)
    assert_close(mvn_out.precision_matrix, mvn_in.precision_matrix,
                 atol=1e-4, rtol=None)
def forward(self, features, trip_counts):
    """Guide: compute an amortized diagonal-Normal posterior over gate_rate
    (and optionally a mean-field posterior over the hidden state).

    :param features: (hours, feature_dim) tensor.
    :param trip_counts: (hours, origins, destins) tensor of counts.
    :returns: in funsor mode, a funsor over "gate_rate_t"; otherwise None
        (posteriors are registered via pyro.sample).
    """
    pyro.module("guide", self)
    assert features.dim() == 2
    assert trip_counts.dim() == 3
    observed_hours = len(trip_counts)
    # log1p-compress counts into a stable feature representation.
    log_counts = trip_counts.reshape(observed_hours, -1).log1p()
    # Combine a learned diagonal term on counts with a low-rank network on
    # concatenated (features, counts).
    loc_scale = ((self.diag_part * log_counts.unsqueeze(-2)).reshape(
        observed_hours, -1) + self.lowrank(
            torch.cat([features[:observed_hours], log_counts], dim=-1)))
    loc, scale = loc_scale.reshape(observed_hours, 2, -1).unbind(1)
    scale = bounded_exp(scale, bound=10.)
    if self.args.funsor:
        # Funsor mode: return the posterior as a funsor instead of sampling.
        diag_normal = dist.Normal(loc, scale).to_event(2)
        return dist_to_funsor(diag_normal)(value="gate_rate_t")
    pyro.sample("gate_rate", dist.Normal(loc, scale).to_event(2))
    if self.args.mean_field:
        # Mean-field posterior over the hidden state, from a small MLP on
        # (time index, time-to-end, features, counts).
        time = torch.arange(observed_hours, dtype=features.dtype,
                            device=features.device)
        temp = torch.cat([
            time.unsqueeze(-1),
            (observed_hours - 1 - time).unsqueeze(-1),
            features[:observed_hours],
            log_counts,
        ], dim=-1)
        temp = self.mf_layer_0(temp).sigmoid()
        temp = self.mf_layer_1(temp).sigmoid()
        # Blend per-hour features with their mean over hours.
        temp = (self.mf_highpass * temp +
                self.mf_lowpass * temp.mean(0, keepdim=True))
        temp = torch.cat([temp[:1], temp], dim=0)  # copy initial state.
        loc = temp[:, :self.args.state_dim]
        scale = bounded_exp(temp[:, self.args.state_dim:], bound=10.)
        pyro.sample("state", dist.Normal(loc, scale).to_event(2))
def forward(self, observations, add_bias=True):
    """Run a biased multi-sensor NCV Kalman filter expressed with funsors.

    :param observations: iterable of per-timestep observation funsors/tensors.
    :param bool add_bias: whether to add the learned sensor bias to the
        predicted observation.
    :returns: (marginal log-likelihood data, MVN posterior over final state).
    """
    obs_dim = 2 * self.num_sensors
    bias_scale = self.log_bias_scale.exp()
    obs_noise = self.log_obs_noise.exp()
    trans_noise = self.log_trans_noise.exp()
    # bias distribution
    bias = Variable('bias', reals(obs_dim))
    assert not torch.isnan(bias_scale), "bias scales was nan"
    bias_dist = dist_to_funsor(
        dist.MultivariateNormal(
            torch.zeros(obs_dim),
            scale_tril=bias_scale *
            torch.eye(2 * self.num_sensors)))(value=bias)
    # Diffuse prior over the 4-dim (position, velocity) state.
    init_dist = torch.distributions.MultivariateNormal(
        torch.zeros(4), scale_tril=100. * torch.eye(4))
    self.init = dist_to_funsor(init_dist)(value="state")
    # hidden states
    prev = Variable("prev", reals(4))
    curr = Variable("curr", reals(4))
    # NCV (nearly-constant-velocity) dynamics with learned noise scale.
    self.trans_dist = f_dist.MultivariateNormal(
        loc=prev @ NCV_TRANSITION_MATRIX,
        scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky(),
        value=curr)
    state = Variable('state', reals(4))
    obs = Variable("obs", reals(obs_dim))
    # Each sensor observes the 2-dim position; tile eye(4, 2) across sensors.
    observation_matrix = Tensor(
        torch.eye(4, 2).unsqueeze(-1).expand(-1, -1,
                                             self.num_sensors).reshape(4, -1))
    assert observation_matrix.output.shape == (
        4, obs_dim), observation_matrix.output.shape
    obs_loc = state @ observation_matrix
    if add_bias:
        obs_loc += bias
    self.observation_dist = f_dist.MultivariateNormal(
        loc=obs_loc, scale_tril=obs_noise * torch.eye(obs_dim), value=obs)

    # Forward filtering: accumulate factors, eliminating states as we go.
    logp = bias_dist
    curr = "state_init"
    logp += self.init(state=curr)
    for t, x in enumerate(observations):
        prev, curr = curr, f"state_{t}"
        logp += self.trans_dist(prev=prev, curr=curr)
        logp += self.observation_dist(state=curr, obs=x)
        # marginalize out previous state
        logp = logp.reduce(ops.logaddexp, prev)
    # marginalize out bias variable
    logp = logp.reduce(ops.logaddexp, "bias")
    # save posterior over the final state
    assert set(logp.inputs) == {f'state_{len(observations) - 1}'}
    posterior = funsor_to_mvn(logp, ndims=0)
    # marginalize out remaining variables
    logp = logp.reduce(ops.logaddexp)
    assert isinstance(logp, Tensor) and logp.shape == (), logp.pretty()
    return logp.data, posterior
def __init__(self, initial_logits, initial_mvn, transition_logits,
             transition_matrix, transition_mvn, observation_matrix,
             observation_mvn, exact=False, validate_args=None):
    """Build a funsor-backed switching linear HMM: a discrete "class" chain
    modulating a linear-Gaussian "state" chain.

    :param initial_logits: logits over the initial discrete class.
    :param initial_mvn: MultivariateNormal over the initial hidden state.
    :param transition_logits: class-transition logits.
    :param transition_matrix: (hidden_dim, hidden_dim) state dynamics,
        possibly batched over (time, class).
    :param transition_mvn: process noise MVN.
    :param observation_matrix: (hidden_dim, obs_dim) emission matrix.
    :param observation_mvn: observation noise MVN.
    :param bool exact: flag saved for use by inference (semantics defined
        elsewhere in the class).
    """
    assert isinstance(initial_logits, torch.Tensor)
    assert isinstance(initial_mvn, torch.distributions.MultivariateNormal)
    assert isinstance(transition_logits, torch.Tensor)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_mvn, torch.distributions.MultivariateNormal)
    assert isinstance(observation_matrix, torch.Tensor)
    assert isinstance(observation_mvn, torch.distributions.MultivariateNormal)
    hidden_cardinality = initial_logits.size(-1)
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
    assert initial_mvn.event_shape[0] == hidden_dim
    assert transition_logits.size(-1) == hidden_cardinality
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_mvn.event_shape[0] == hidden_dim
    assert observation_mvn.event_shape[0] == obs_dim
    # Broadcast all batch shapes; the trailing dims of `shape` are
    # (..., time, class), so insert a time dim of 1 into the initial shape.
    init_shape = broadcast_shape(initial_logits.shape,
                                 initial_mvn.batch_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            transition_logits.shape[:-1],
                            transition_matrix.shape[:-2],
                            transition_mvn.batch_shape,
                            observation_matrix.shape[:-2],
                            observation_mvn.batch_shape)
    assert shape[-1] == hidden_cardinality
    batch_shape, time_shape = shape[:-2], shape[-2:-1]
    event_shape = time_shape + (obs_dim, )

    # Normalize.
    initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    transition_logits = transition_logits - transition_logits.logsumexp(
        -1, True)

    # Convert tensors and distributions to funsors.
    init = (tensor_to_funsor(initial_logits, ("class", )) +
            dist_to_funsor(initial_mvn, ("class", ))(value="state"))
    trans = (tensor_to_funsor(transition_logits,
                              ("time", "class", "class(time=1)")) +
             matrix_and_mvn_to_funsor(transition_matrix, transition_mvn,
                                      ("time", "class(time=1)"),
                                      "state", "state(time=1)"))
    obs = matrix_and_mvn_to_funsor(observation_matrix, observation_mvn,
                                   ("time", "class(time=1)"),
                                   "state(time=1)", "value")
    # The model is only "switching" if some factor depends on the class.
    if "class(time=1)" not in set(trans.inputs).union(obs.inputs):
        raise ValueError(
            "neither transition nor observation depend on discrete state")
    dtype = "real"

    # Construct the joint funsor.
    with interpretation(lazy):
        # TODO perform math here once sequential_sum_product has been
        # implemented as a first-class funsor.
        funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
        # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
        self._init = init
        self._trans = trans
        self._obs = obs

    super(SwitchingLinearHMM, self).__init__(funsor_dist, batch_shape,
                                             event_shape, dtype,
                                             validate_args)
    self.exact = exact