def test_mvn_to_funsor(batch_shape, event_shape, event_sizes):
    event_size = sum(event_sizes)
    mvn = random_mvn(batch_shape + event_shape, event_size)
    int_inputs = OrderedDict((k, bint(size))
                             for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict((k, reals(size))
                              for k, size in zip("xyz", event_sizes))

    f = mvn_to_funsor(mvn, tuple(int_inputs), real_inputs)
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    for k, d in real_inputs.items():
        assert k in f.inputs
        assert f.inputs[k] == d

    value = mvn.sample()
    subs = {}
    beg = 0
    for k, d in real_inputs.items():
        end = beg + d.num_elements
        subs[k] = tensor_to_funsor(value[..., beg:end], tuple(int_inputs), 1)
        beg = end
    actual_log_prob = f(**subs)
    expected_log_prob = tensor_to_funsor(mvn.log_prob(value), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
def test_dist_to_funsor_bernoulli(batch_shape):
    logits = torch.randn(batch_shape)
    d = dist.Bernoulli(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
def test_dist_to_funsor_normal(batch_shape):
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    d = dist.Normal(loc, scale)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob, rtol=1e-5)
def test_dist_to_funsor_independent(batch_shape, event_shape):
    loc = torch.randn(batch_shape + event_shape)
    scale = torch.randn(batch_shape + event_shape).exp()
    d = dist.Normal(loc, scale).to_event(len(event_shape))
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    funsor_value = tensor_to_funsor(value, event_output=len(event_shape))
    actual_log_prob = f(value=funsor_value)
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob, rtol=1e-5)
def test_dist_to_funsor_mvn(batch_shape, event_size):
    loc = torch.randn(batch_shape + (event_size,))
    cov = torch.randn(batch_shape + (event_size, 2 * event_size))
    cov = cov.matmul(cov.transpose(-1, -2))
    scale_tril = torch.cholesky(cov)
    d = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value, event_output=1))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
def test_dist_to_funsor_masked(batch_shape):
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    mask = torch.bernoulli(torch.full(batch_shape, 0.5)).byte()
    d = dist.Normal(loc, scale).mask(mask)
    assert isinstance(d, MaskedDistribution)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
def __init__(self, logits, validate_args=None):
    batch_shape = logits.shape[:-1]
    event_shape = torch.Size()
    funsor_dist = tensor_to_funsor(logits, ("value",))
    dtype = int(logits.size(-1))
    super(Categorical, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
def filter(self, value):
    """
    Compute posterior over final state given a sequence of observations.

    :param ~torch.Tensor value: A sequence of observations.
    :return: A posterior distribution over latent states at the final time
        step, represented as a pair ``(cat, mvn)``, where ``cat`` is a
        :class:`~pyro.distributions.Categorical` distribution over mixture
        components and ``mvn`` is a
        :class:`~pyro.distributions.MultivariateNormal` with rightmost batch
        dimension ranging over mixture components. This can then be used to
        initialize a sequential Pyro model for prediction.
    :rtype: tuple
    """
    ndims = max(len(self.batch_shape), value.dim() - 2)
    time = Variable("time", Bint[self.event_shape[0]])
    value = tensor_to_funsor(value, ("time",), 1)

    seq_sum_prod = naive_sequential_sum_product if self.exact else sequential_sum_product
    with interpretation(eager if self.exact else moment_matching):
        logp = self._trans + self._obs(value=value)
        logp = seq_sum_prod(ops.logaddexp, ops.add, logp, time,
                            {"class": "class(time=1)", "state": "state(time=1)"})
        logp += self._init
        logp = logp.reduce(ops.logaddexp, frozenset(["class", "state"]))

    cat, mvn = funsor_to_cat_and_mvn(logp, ndims, ("class(time=1)",))
    cat = cat.expand(self.batch_shape)
    mvn = mvn.expand(self.batch_shape + cat.logits.shape[-1:])
    return cat, mvn
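# Usage sketch (not from the source; ``hmm`` and ``data`` are hypothetical
# names for a constructed SwitchingLinearHMM and an observation tensor of
# shape ``(T, obs_dim)``). ``filter`` runs the same moment-matched sweep as
# ``log_prob`` but keeps the posterior at the final time step:
#
#     cat, mvn = hmm.filter(data)
#     component = cat.sample()   # posterior mixture component at final step
#     states = mvn.sample()      # final latent states, batched over components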
def test_matrix_and_mvn_to_funsor(batch_shape, event_shape, x_size, y_size):
    matrix = torch.randn(batch_shape + event_shape + (x_size, y_size))
    y_mvn = random_mvn(batch_shape + event_shape, y_size)
    xy_mvn = random_mvn(batch_shape + event_shape, x_size + y_size)
    int_inputs = OrderedDict((k, bint(size))
                             for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict([("x", reals(x_size)), ("y", reals(y_size))])

    f = (matrix_and_mvn_to_funsor(matrix, y_mvn, tuple(int_inputs), "x", "y") +
         mvn_to_funsor(xy_mvn, tuple(int_inputs), real_inputs))
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    assert f.inputs["x"] == reals(x_size)
    assert f.inputs["y"] == reals(y_size)

    xy = torch.randn(x_size + y_size)
    x, y = xy[:x_size], xy[x_size:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = f(x=x, y=y)
    expected_log_prob = tensor_to_funsor(
        xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
def test_dist_to_funsor_categorical(batch_shape, cardinality):
    logits = torch.randn(batch_shape + (cardinality,))
    logits -= logits.logsumexp(dim=-1, keepdim=True)
    d = dist.Categorical(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Tensor)

    expected = tensor_to_funsor(logits, ("value",))
    assert_close(f, expected)
def log_prob(self, value):
    if self._validate_args:
        self._validate_sample(value)
    ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
    value = tensor_to_funsor(value, event_output=self.event_dim,
                             dtype=self.dtype)
    log_prob = self.funsor_dist(value=value)
    log_prob = funsor_to_tensor(log_prob, ndims=ndims)
    return log_prob
def __init__(self, initial_logits, transition_logits, observation_dist,
             validate_args=None):
    assert isinstance(initial_logits, torch.Tensor)
    assert isinstance(transition_logits, torch.Tensor)
    assert isinstance(observation_dist, torch.distributions.Distribution)
    assert initial_logits.dim() >= 1
    assert transition_logits.dim() >= 2
    assert len(observation_dist.batch_shape) >= 1
    shape = broadcast_shape(initial_logits.shape[:-1] + (1,),
                            transition_logits.shape[:-2],
                            observation_dist.batch_shape[:-1])
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + observation_dist.event_shape
    self._has_rsample = observation_dist.has_rsample

    # Normalize.
    initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    transition_logits = transition_logits - transition_logits.logsumexp(-1, True)

    # Convert tensors and distributions to funsors.
    init = tensor_to_funsor(initial_logits, ("state",))
    trans = tensor_to_funsor(transition_logits, ("time", "state", "state(time=1)"))
    obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
    dtype = obs.inputs["value"].dtype

    # Construct the joint funsor.
    with interpretation(lazy):
        # TODO perform math here once sequential_sum_product has been
        # implemented as a first-class funsor.
        funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
        # Until funsor_dist is defined, we save factors for hand-computation
        # in .log_prob().
        self._init = init
        self._trans = trans
        self._obs = obs

    super(DiscreteHMM, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
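# Construction sketch (hypothetical shapes, assuming this class is exported as
# ``DiscreteHMM``): a T-step HMM with C latent states and per-state Normal
# observations. Note that ``transition_logits`` and the observation batch
# carry a leading time dimension, which broadcasting then aligns:
#
#     T, C = 10, 3
#     hmm = DiscreteHMM(
#         initial_logits=torch.zeros(C),
#         transition_logits=torch.randn(T, C, C),
#         observation_dist=dist.Normal(torch.randn(T, C), torch.ones(T, C)))
#     hmm.log_prob(torch.randn(T))  # log-probability of one observed sequence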
def expand(self, batch_shape, _instance=None):
    new = self._get_checked_instance(type(self), _instance)
    batch_shape = torch.Size(batch_shape)
    funsor_dist = self.funsor_dist + tensor_to_funsor(torch.zeros(batch_shape))
    super(type(self), new).__init__(funsor_dist, batch_shape,
                                    self.event_shape, self.dtype,
                                    validate_args=False)
    new.validate_args = self.__dict__.get('_validate_args')
    return new
def test_matrix_and_mvn_to_funsor_diag(batch_shape, x_size, y_size):
    matrix = torch.randn(batch_shape + (x_size, y_size))
    loc = torch.randn(batch_shape + (y_size,))
    scale = torch.randn(batch_shape + (y_size,)).exp()

    normal = dist.Normal(loc, scale).to_event(1)
    actual = matrix_and_mvn_to_funsor(matrix, normal)
    assert isinstance(actual, AffineNormal)

    mvn = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())
    expected = matrix_and_mvn_to_funsor(matrix, mvn)

    y = tensor_to_funsor(torch.randn(batch_shape + (y_size,)), (), 1)
    actual_like = actual(value_y=y)
    expected_like = expected(value_y=y)
    assert_close(actual_like, expected_like, atol=1e-4, rtol=1e-4)

    x = tensor_to_funsor(torch.randn(batch_shape + (x_size,)), (), 1)
    actual_norm = actual_like(value_x=x)
    expected_norm = expected_like(value_x=x)
    assert_close(actual_norm, expected_norm, atol=1e-4, rtol=1e-4)
def _forward_funsor(self, features, trip_counts):
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours == total_hours
    assert num_origins == self.num_stations
    assert num_destins == self.num_stations
    n = self.num_stations
    gate_rate = funsor.Variable("gate_rate_t",
                                reals(observed_hours, 2 * n * n))["time"]

    @funsor.torch.function(reals(2 * n * n), (reals(n, n, 2), reals(n, n)))
    def unpack_gate_rate(gate_rate):
        batch_shape = gate_rate.shape[:-1]
        gate, rate = gate_rate.reshape(batch_shape + (2, n, n)).unbind(-3)
        gate = gate.sigmoid().clamp(min=0.01, max=0.99)
        rate = bounded_exp(rate, bound=1e4)
        gate = torch.stack((1 - gate, gate), dim=-1)
        return gate, rate

    # Create a Gaussian latent dynamical system.
    init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = \
        self._dynamics(features[:observed_hours])
    init = dist_to_funsor(init_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(trans_matrix, trans_dist,
                                     ("time",), "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(obs_matrix, obs_dist,
                                   ("time",), "state(time=1)", "gate_rate")

    # Compute dynamic prior over gate_rate.
    prior = trans + obs(gate_rate=gate_rate)
    prior = MarkovProduct(ops.logaddexp, ops.add,
                          prior, "time", {"state": "state(time=1)"})
    prior += init
    prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

    # Compute zero-inflated Poisson likelihood.
    gate, rate = unpack_gate_rate(gate_rate)
    likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
    trip_counts = tensor_to_funsor(trip_counts, ("time", "origin", "destin"))
    likelihood += funsor.Stack("gated", (
        fdist.Poisson(rate["origin", "destin"], value=trip_counts),
        fdist.Delta(0, value=trip_counts)))
    likelihood = likelihood.reduce(ops.logaddexp, "gated")
    likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

    assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
    assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
    return prior, likelihood
def expand(self, batch_shape, _instance=None):
    new = self._get_checked_instance(SwitchingLinearHMM, _instance)
    batch_shape = torch.Size(batch_shape)
    new._init = self._init + tensor_to_funsor(torch.zeros(batch_shape))
    new._trans = self._trans
    new._obs = self._obs
    new.exact = self.exact
    super(SwitchingLinearHMM, new).__init__(self.funsor_dist, batch_shape,
                                            self.event_shape, self.dtype,
                                            validate_args=False)
    new.validate_args = self.__dict__.get('_validate_args')
    return new
def log_prob(self, value):
    if self._validate_args:
        self._validate_sample(value)
    ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
    time = Variable("time", Bint[self.event_shape[0]])
    value = tensor_to_funsor(value, ("time",),
                             event_output=self.event_dim - 1,
                             dtype=self.dtype)

    # Compare with pyro.distributions.hmm.DiscreteHMM.log_prob().
    obs = self._obs(value=value)
    result = self._trans + obs
    result = sequential_sum_product(ops.logaddexp, ops.add,
                                    result, time, {"state": "state(time=1)"})
    result = self._init + result.reduce(ops.logaddexp, "state(time=1)")
    result = result.reduce(ops.logaddexp, "state")

    result = funsor_to_tensor(result, ndims=ndims)
    return result
def log_prob(self, value):
    ndims = max(len(self.batch_shape), value.dim() - 2)
    time = Variable("time", Bint[self.event_shape[0]])
    value = tensor_to_funsor(value, ("time",), 1)

    seq_sum_prod = naive_sequential_sum_product if self.exact else sequential_sum_product
    with interpretation(eager if self.exact else moment_matching):
        result = self._trans + self._obs(value=value)
        result = seq_sum_prod(ops.logaddexp, ops.add, result, time,
                              {"class": "class(time=1)", "state": "state(time=1)"})
        result += self._init
        result = result.reduce(
            ops.logaddexp,
            frozenset(["class", "state", "class(time=1)", "state(time=1)"]))

    result = funsor_to_tensor(result, ndims=ndims)
    return result
def test_funsor_to_cat_and_mvn(batch_shape, event_shape, int_size, real_size):
    logits = torch.randn(batch_shape + event_shape + (int_size,))
    expected_cat = dist.Categorical(logits=logits)
    expected_mvn = random_mvn(batch_shape + event_shape + (int_size,), real_size)
    event_dims = tuple("abc"[:len(event_shape)]) + ("component",)
    ndims = len(expected_cat.batch_shape)

    funsor_ = (tensor_to_funsor(logits, event_dims) +
               dist_to_funsor(expected_mvn, event_dims)(value="value"))
    assert isinstance(funsor_, Funsor)

    actual_cat, actual_mvn = funsor_to_cat_and_mvn(funsor_, ndims, event_dims)
    assert isinstance(actual_cat, dist.Categorical)
    assert isinstance(actual_mvn, dist.MultivariateNormal)
    assert actual_cat.batch_shape == expected_cat.batch_shape
    assert actual_mvn.batch_shape == expected_mvn.batch_shape
    assert_close(actual_cat.logits, expected_cat.logits, atol=1e-4, rtol=None)
    assert_close(actual_mvn.loc, expected_mvn.loc, atol=1e-4, rtol=None)
    assert_close(actual_mvn.precision_matrix, expected_mvn.precision_matrix,
                 atol=1e-4, rtol=None)
def test_tensor_funsor_tensor(batch_shape, event_shape, event_output):
    event_inputs = ("foo", "bar", "baz")[:len(event_shape) - event_output]
    t = torch.randn(batch_shape + event_shape)
    f = tensor_to_funsor(t, event_inputs, event_output)
    t2 = funsor_to_tensor(f, t.dim(), event_inputs)
    assert_close(t2, t)
def __init__(self, initial_logits, initial_mvn,
             transition_logits, transition_matrix, transition_mvn,
             observation_matrix, observation_mvn,
             exact=False, validate_args=None):
    assert isinstance(initial_logits, torch.Tensor)
    assert isinstance(initial_mvn, torch.distributions.MultivariateNormal)
    assert isinstance(transition_logits, torch.Tensor)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_mvn, torch.distributions.MultivariateNormal)
    assert isinstance(observation_matrix, torch.Tensor)
    assert isinstance(observation_mvn, torch.distributions.MultivariateNormal)
    hidden_cardinality = initial_logits.size(-1)
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert initial_mvn.event_shape[0] == hidden_dim
    assert transition_logits.size(-1) == hidden_cardinality
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_mvn.event_shape[0] == hidden_dim
    assert observation_mvn.event_shape[0] == obs_dim
    init_shape = broadcast_shape(initial_logits.shape, initial_mvn.batch_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            transition_logits.shape[:-1],
                            transition_matrix.shape[:-2],
                            transition_mvn.batch_shape,
                            observation_matrix.shape[:-2],
                            observation_mvn.batch_shape)
    assert shape[-1] == hidden_cardinality
    batch_shape, time_shape = shape[:-2], shape[-2:-1]
    event_shape = time_shape + (obs_dim,)

    # Normalize.
    initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    transition_logits = transition_logits - transition_logits.logsumexp(-1, True)

    # Convert tensors and distributions to funsors.
    init = (tensor_to_funsor(initial_logits, ("class",)) +
            dist_to_funsor(initial_mvn, ("class",))(value="state"))
    trans = (tensor_to_funsor(transition_logits,
                              ("time", "class", "class(time=1)")) +
             matrix_and_mvn_to_funsor(transition_matrix, transition_mvn,
                                      ("time", "class(time=1)"),
                                      "state", "state(time=1)"))
    obs = matrix_and_mvn_to_funsor(observation_matrix, observation_mvn,
                                   ("time", "class(time=1)"),
                                   "state(time=1)", "value")
    if "class(time=1)" not in set(trans.inputs).union(obs.inputs):
        raise ValueError(
            "neither transition nor observation depend on discrete state")
    dtype = "real"

    # Construct the joint funsor.
    with interpretation(lazy):
        # TODO perform math here once sequential_sum_product has been
        # implemented as a first-class funsor.
        funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
        # Until funsor_dist is defined, we save factors for hand-computation
        # in .log_prob().
        self._init = init
        self._trans = trans
        self._obs = obs

    super(SwitchingLinearHMM, self).__init__(funsor_dist, batch_shape,
                                             event_shape, dtype, validate_args)
    self.exact = exact
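# Construction sketch (hypothetical sizes): C discrete classes switching among
# C sets of linear-Gaussian dynamics over T time steps, with hidden dimension
# H and observation dimension O. The class-dependent matrices carry leading
# (time, class) batch dimensions:
#
#     T, C, H, O = 10, 2, 3, 4
#     mvn = dist.MultivariateNormal
#     hmm = SwitchingLinearHMM(
#         initial_logits=torch.zeros(C),
#         initial_mvn=mvn(torch.zeros(H), torch.eye(H)),
#         transition_logits=torch.zeros(T, C, C),
#         transition_matrix=torch.randn(T, C, H, H),
#         transition_mvn=mvn(torch.zeros(H), torch.eye(H)),
#         observation_matrix=torch.randn(T, C, H, O),
#         observation_mvn=mvn(torch.zeros(O), torch.eye(O)))
#     hmm.log_prob(torch.randn(T, O))  # moment matching unless exact=True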