def test_expand(extra_shape, log_normalizer_shape, info_vec_shape, precision_shape, dim):
    rank = dim + dim
    log_normalizer = torch.randn(log_normalizer_shape)
    info_vec = torch.randn(info_vec_shape + (dim,))
    precision = torch.randn(precision_shape + (dim, rank))
    precision = precision.matmul(precision.transpose(-1, -2))
    gaussian = Gaussian(log_normalizer, info_vec, precision)

    expected_shape = extra_shape + broadcast_shape(
        log_normalizer_shape, info_vec_shape, precision_shape)
    actual = gaussian.expand(expected_shape)
    assert actual.batch_shape == expected_shape
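
# A minimal sketch (plain torch, not part of the original tests) of how the
# expected batch shape above is formed: the three parameter batch shapes
# broadcast against each other right-to-left, and extra_shape is prepended.
# The concrete shapes below are illustrative assumptions.
import torch

log_normalizer_shape, info_vec_shape, precision_shape = (5, 1), (1, 6), (6,)
extra_shape = (2,)
batch_shape = torch.broadcast_shapes(log_normalizer_shape, info_vec_shape, precision_shape)
assert extra_shape + batch_shape == (2, 5, 6)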
def random_gaussian(batch_shape, dim, rank=None):
    """
    Generate a random Gaussian for testing.
    """
    if rank is None:
        rank = dim + dim
    log_normalizer = torch.randn(batch_shape)
    info_vec = torch.randn(batch_shape + (dim,))
    samples = torch.randn(batch_shape + (dim, rank))
    precision = torch.matmul(samples, samples.transpose(-2, -1))
    result = Gaussian(log_normalizer, info_vec, precision)
    assert result.dim() == dim
    assert result.batch_shape == batch_shape
    return result
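
# A minimal sketch (plain torch, illustrative only) of why random_gaussian's
# precision is a valid precision matrix: samples @ samples.T is symmetric
# positive semidefinite, and with rank = 2 * dim it is almost surely positive
# definite.  The batch size of 2 here is an arbitrary assumption.
import torch

dim, rank = 3, 6
samples = torch.randn(2, dim, rank)
precision = samples @ samples.transpose(-2, -1)  # (2, dim, dim)
assert torch.allclose(precision, precision.transpose(-2, -1))  # symmetric
assert (torch.linalg.eigvalsh(precision) > -1e-5).all()  # nonnegative eigenvalues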
def log_prob(self, value):
    # We compute a normalized distribution as p(obs,hidden) / p(hidden).
    logp_oh = self._trans
    logp_h = self._trans

    # Combine observation and transition factors.
    logp_oh += self._obs.condition(value).event_pad(left=self.hidden_dim)
    logp_h += self._obs.marginalize(right=self.obs_dim).event_pad(
        left=self.hidden_dim)

    # Concatenate p(obs,hidden) and p(hidden) into a single Gaussian.
    batch_dim = 1 + max(
        len(self._init.batch_shape) + 1, len(logp_oh.batch_shape))
    batch_shape = (1,) * (batch_dim - len(logp_oh.batch_shape)) + logp_oh.batch_shape
    logp = Gaussian.cat(
        [logp_oh.expand(batch_shape), logp_h.expand(batch_shape)])

    # Eliminate time dimension.
    logp = _sequential_gaussian_tensordot(logp)

    # Combine initial factor.
    logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

    # Marginalize out final state.
    logp_oh, logp_h = logp.event_logsumexp()
    return logp_oh - logp_h  # = log( p(obs,hidden) / p(hidden) )
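
# A minimal sketch (torch.distributions only, not the Gaussian ops used above)
# of the normalization idea in this log_prob: dividing a joint factor over
# (hidden, obs) by the corresponding hidden-only factor leaves the marginal
# likelihood of obs.  Here that is checked through the related identity
# log p(obs) = log p(hidden, obs) - log p(hidden | obs); the dimensions and
# variable names below are illustrative assumptions.
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
hidden_dim, obs_dim = 2, 1
d = hidden_dim + obs_dim
loc = torch.randn(d)
scale = torch.randn(d, d)
cov = scale @ scale.T + torch.eye(d)
joint = MultivariateNormal(loc, covariance_matrix=cov)

hidden, obs = torch.randn(hidden_dim), torch.randn(obs_dim)

# Closed-form marginal p(obs) and conditional p(hidden | obs).
cov_hh = cov[:hidden_dim, :hidden_dim]
cov_ho = cov[:hidden_dim, hidden_dim:]
cov_oo = cov[hidden_dim:, hidden_dim:]
marg_obs = MultivariateNormal(loc[hidden_dim:], covariance_matrix=cov_oo)
cond_loc = loc[:hidden_dim] + cov_ho @ torch.linalg.solve(cov_oo, obs - loc[hidden_dim:])
cond_cov = cov_hh - cov_ho @ torch.linalg.solve(cov_oo, cov_ho.T)
cond_cov = 0.5 * (cond_cov + cond_cov.T)  # symmetrize against roundoff
cond_hidden = MultivariateNormal(cond_loc, covariance_matrix=cond_cov)

lhs = marg_obs.log_prob(obs)
rhs = joint.log_prob(torch.cat([hidden, obs])) - cond_hidden.log_prob(hidden)
assert torch.allclose(lhs, rhs, atol=1e-4)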
def test_cat(shape, cat_dim, split, dim):
    assert sum(split) == shape[cat_dim]
    gaussian = random_gaussian(shape, dim)
    parts = []
    end = 0
    for size in split:
        beg, end = end, end + size
        if cat_dim == -1:
            part = gaussian[..., beg:end]
        elif cat_dim == -2:
            part = gaussian[..., beg:end, :]
        elif cat_dim == 1:
            part = gaussian[:, beg:end]
        else:
            raise ValueError
        parts.append(part)

    actual = Gaussian.cat(parts, cat_dim)
    assert_close_gaussian(actual, gaussian)
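
# A minimal sketch (plain torch, illustrative only) of the round trip test_cat
# exercises for Gaussian.cat: slicing a batch dimension into contiguous pieces
# and concatenating them back along that dimension recovers the original.
import torch

shape, cat_dim, split = (4, 7, 6), 1, (2, 3, 2)
x = torch.randn(shape)
parts, end = [], 0
for size in split:
    beg, end = end, end + size
    parts.append(x.narrow(cat_dim, beg, size))
assert torch.equal(torch.cat(parts, dim=cat_dim), x)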
def _sequential_gaussian_tensordot(gaussian):
    """
    Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]
    """
    assert isinstance(gaussian, Gaussian)
    assert gaussian.dim() % 2 == 0, "dim is not even"
    batch_shape = gaussian.batch_shape[:-1]
    state_dim = gaussian.dim() // 2
    while gaussian.batch_shape[-1] > 1:
        time = gaussian.batch_shape[-1]
        even_time = time // 2 * 2
        even_part = gaussian[..., :even_time]
        x_y = even_part.reshape(batch_shape + (even_time // 2, 2))
        x, y = x_y[..., 0], x_y[..., 1]
        contracted = gaussian_tensordot(x, y, state_dim)
        if time > even_time:
            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
        gaussian = contracted
    return gaussian[..., 0]
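
# A minimal sketch (plain torch, with matrices standing in for Gaussian factors)
# of the pairwise-doubling reduction above: adjacent pairs are contracted in
# each round, carrying over a leftover odd element, so T factors are combined
# in O(log T) rounds rather than a length-T sequential chain.
import torch

def _sequential_matmul(mats):
    # mats: (T, n, n); returns mats[0] @ mats[1] @ ... @ mats[T - 1]
    while mats.size(0) > 1:
        time = mats.size(0)
        even_time = time // 2 * 2
        x = mats[0:even_time:2]  # elements 0, 2, 4, ...
        y = mats[1:even_time:2]  # elements 1, 3, 5, ...
        contracted = x @ y  # contract adjacent pairs in parallel
        if time > even_time:
            contracted = torch.cat([contracted, mats[-1:]], dim=0)
        mats = contracted
    return mats[0]

mats = torch.randn(5, 3, 3)
expected = mats[0]
for m in mats[1:]:
    expected = expected @ m
assert torch.allclose(_sequential_matmul(mats), expected, atol=1e-4, rtol=1e-4)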
def test_gaussian_funsor(batch_shape):
    # This tests sample distribution, rsample gradients, log_prob, and log_prob
    # gradients for both Pyro's and Funsor's Gaussian.
    import funsor
    funsor.set_backend("torch")
    num_samples = 100000

    # Declare unconstrained parameters.
    loc = torch.randn(batch_shape + (3,)).requires_grad_()
    t = transform_to(constraints.positive_definite)
    m = torch.randn(batch_shape + (3, 3))
    precision_unconstrained = t.inv(m @ m.transpose(-1, -2)).requires_grad_()

    # Transform to constrained space.
    log_normalizer = torch.zeros(batch_shape)
    precision = t(precision_unconstrained)
    info_vec = (precision @ loc[..., None])[..., 0]

    def check_equal(actual, expected, atol=0.01, rtol=0):
        assert_close(actual.data, expected.data, atol=atol, rtol=rtol)
        grads = torch.autograd.grad(
            (actual - expected).abs().sum(),
            [loc, precision_unconstrained],
            retain_graph=True,
        )
        for grad in grads:
            assert grad.abs().max() < atol

    entropy = dist.MultivariateNormal(loc, precision_matrix=precision).entropy()

    # Monte carlo estimate entropy via pyro.
    p_gaussian = Gaussian(log_normalizer, info_vec, precision)
    p_log_Z = p_gaussian.event_logsumexp()
    p_rsamples = p_gaussian.rsample((num_samples,))
    pp_entropy = (p_log_Z - p_gaussian.log_density(p_rsamples)).mean(0)
    check_equal(pp_entropy, entropy)

    # Monte carlo estimate entropy via funsor.
    inputs = OrderedDict([(k, funsor.Bint[v]) for k, v in zip("ij", batch_shape)])
    inputs["x"] = funsor.Reals[3]
    f_gaussian = funsor.gaussian.Gaussian(mean=loc, precision=precision, inputs=inputs)
    f_log_Z = f_gaussian.reduce(funsor.ops.logaddexp, "x")
    sample_inputs = OrderedDict(particle=funsor.Bint[num_samples])
    deltas = f_gaussian.sample("x", sample_inputs)
    f_rsamples = funsor.montecarlo.extract_samples(deltas)["x"]
    ff_entropy = (f_log_Z - f_gaussian(x=f_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(ff_entropy.data, entropy)

    # Check Funsor's .rsample against Pyro's .log_prob.
    pf_entropy = (p_log_Z - p_gaussian.log_density(f_rsamples.data)).mean(0)
    check_equal(pf_entropy, entropy)

    # Check Pyro's .rsample against Funsor's .log_prob.
    fp_rsamples = funsor.Tensor(p_rsamples)["particle"]
    for i in "ij"[:len(batch_shape)]:
        fp_rsamples = fp_rsamples[i]
    fp_entropy = (f_log_Z - f_gaussian(x=fp_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(fp_entropy.data, entropy)
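
# A minimal sketch (torch.distributions only, illustrative parameters) of the
# Monte Carlo entropy check used above: for samples x ~ p, the average of
# -log p(x) estimates the entropy, which is compared against the closed form.
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
num_samples = 100000
loc = torch.randn(3)
m = torch.randn(3, 3)
p = MultivariateNormal(loc, covariance_matrix=m @ m.T + torch.eye(3))

samples = p.rsample((num_samples,))
mc_entropy = -p.log_prob(samples).mean(0)
assert torch.allclose(mc_entropy, p.entropy(), atol=0.01, rtol=0.01)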