def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (duration,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
    obs_dist = dist.LogNormal(
        torch.randn(batch_shape + (duration, obs_dim)),
        torch.rand(batch_shape + (duration, obs_dim)).exp()).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = model()
    with poutine.trace() as tr:
        with poutine.reparam(config={"x": LinearHMMReparam()}):
            model(data)
    fn = tr.trace.nodes["x"]["fn"]
    assert isinstance(fn, dist.TransformedDistribution)
    assert isinstance(fn.base_dist, dist.GaussianHMM)

    tr.trace.compute_log_prob()  # smoke test only
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()

    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)
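# Illustrative sketch (values chosen only for this example): broadcast_shape
# from pyro.distributions.util combines shapes exactly like torch tensor
# broadcasting, which is why init_shape is padded with a trailing singleton
# above before broadcasting against the time-indexed trans/obs shapes.
def test_broadcast_shape_sketch():
    assert broadcast_shape((5, 1), (6,)) == (5, 6)  # batch x time
    assert broadcast_shape((1,), (4, 6), (6,)) == (4, 6)  # singletons expand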
def test_gamma_gaussian_hmm_shape(scale_shape, init_shape, trans_mat_shape,
                                  trans_mvn_shape, obs_mat_shape, obs_mvn_shape,
                                  hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    scale_dist = random_gamma(scale_shape)
    d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                              obs_mat, obs_dist)

    shape = broadcast_shape(scale_shape + (1,), init_shape + (1,),
                            trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()

    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    mixing, final = d.filter(data)
    assert isinstance(mixing, dist.Gamma)
    assert mixing.batch_shape == d.batch_shape
    assert mixing.event_shape == ()
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)
def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps,
                                     hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    scale_dist = random_gamma(batch_shape)
    d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                              obs_mat, obs_dist)
    obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussian-gammas with shapes:
    #     t | 0 1 2 3 1 2 3     T = 3 in this example
    # ------+----------------
    #  init | H
    # trans | H H H H           H = hidden
    #   obs |   H H H O O O     O = observed
    # and then combine these using gamma_gaussian_tensordot().
    T = num_steps
    init = gamma_and_mvn_to_gamma_gaussian(scale_dist, init_dist)
    trans = matrix_and_mvn_to_gamma_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gamma_gaussian(obs_mat, obs_mvn)
    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gamma_gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gamma_gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    # Compute log_prob of the joint Student-t compound distribution.
    expected_log_prob = logp.compound().log_prob(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps,
                               hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim)
    obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #     t | 0 1 2 3 1 2 3     T = 3 in this example
    # ------+----------------
    #  init | H
    # trans | H H H H           H = hidden
    #   obs |   H H H O O O     O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = mvn_to_gaussian(trans_dist)
    obs = mvn_to_gaussian(obs_dist)
    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp_h = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp_oh = gaussian_tensordot(logp_h, unrolled_obs, T * hidden_dim)
    logp_h += unrolled_obs.marginalize(right=T * obs_dim)
    # The MRF density is normalized: log p(data) is the log of the unnormalized
    # joint with hiddens integrated out, minus the overall log normalizer.
    expected_log_prob = logp_oh.log_density(
        unrolled_data) - logp_h.event_logsumexp()
    assert_close(actual_log_prob, expected_log_prob)
def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps,
                               hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    if diag:
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #     t | 0 1 2 3 1 2 3     T = 3 in this example
    # ------+----------------
    #  init | H
    # trans | H H H H           H = hidden
    #   obs |   H H H O O O     O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
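# A worked example (values chosen only for this sketch) of the HOHOHO -> HHHOOO
# permutation used above, with T=3, hidden_dim=1, obs_dim=1: each per-step obs
# factor lives on a (hidden, obs) coordinate pair, so the unrolled event
# coordinates come out interleaved as [h0, o0, h1, o1, h2, o2], and perm
# gathers all hiddens first, then all observations.
def test_event_permute_sketch():
    T, hidden_dim, obs_dim = 3, 1, 1
    step = hidden_dim + obs_dim  # event size per time step, i.e. obs.dim()
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * step for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * step for t in range(T)])
    assert perm.tolist() == [0, 2, 4, 1, 3, 5]  # h0 h1 h2 | o0 o1 o2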
def test_matrix_and_mvn_to_gaussian(sample_shape, batch_shape, x_dim, y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    xy_mvn = random_mvn(batch_shape, x_dim + y_dim)
    gaussian = matrix_and_mvn_to_gaussian(matrix, y_mvn) + mvn_to_gaussian(xy_mvn)
    xy = torch.randn(sample_shape + (1,) * len(batch_shape) + (x_dim + y_dim,))
    x, y = xy[..., :x_dim], xy[..., x_dim:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = gaussian.log_density(xy)
    expected_log_prob = xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred)
    assert_close(actual_log_prob, expected_log_prob)
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)
    Mx_cov = matrix.transpose(-2, -1).matmul(
        x_mvn.covariance_matrix).matmul(matrix)
    Mx_loc = matrix.transpose(-2, -1).matmul(x_mvn.loc.unsqueeze(-1)).squeeze(-1)
    mvn = dist.MultivariateNormal(Mx_loc + y_mvn.loc,
                                  Mx_cov + y_mvn.covariance_matrix)
    expected = mvn_to_gaussian(mvn)
    actual = gaussian_tensordot(mvn_to_gaussian(x_mvn),
                                matrix_and_mvn_to_gaussian(matrix, y_mvn),
                                dims=x_dim)
    assert_close_gaussian(expected, actual)
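# A Monte Carlo sanity check (illustrative only; all names and constants are
# local to this sketch) of the identity exercised above: if x ~ N(mu_x, S_x)
# and y = M^T x + e with e ~ N(mu_y, S_y), then marginally
# y ~ N(M^T mu_x + mu_y, M^T S_x M + S_y). Unit covariances keep the algebra
# short.
def test_affine_gaussian_marginal_sketch():
    torch.manual_seed(0)
    x_dim, y_dim, num_samples = 2, 3, 200000
    matrix = torch.randn(x_dim, y_dim)
    x_mvn = dist.MultivariateNormal(torch.randn(x_dim), torch.eye(x_dim))
    y_mvn = dist.MultivariateNormal(torch.randn(y_dim), torch.eye(y_dim))
    x = x_mvn.sample((num_samples,))
    y = x @ matrix + y_mvn.sample((num_samples,))
    expected_mean = x_mvn.loc @ matrix + y_mvn.loc
    expected_cov = matrix.t() @ matrix + torch.eye(y_dim)  # M^T I M + I
    delta = y - y.mean(0)
    actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
    assert_close(y.mean(0), expected_mean, atol=0.05)
    assert_close(actual_cov, expected_cov, atol=0.05)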
def test_rsample_shape(sample_shape, batch_shape, dim):
    mvn = random_mvn(batch_shape, dim)
    g = mvn_to_gaussian(mvn)
    expected = mvn.rsample(sample_shape)
    actual = g.rsample(sample_shape)
    assert actual.dtype == expected.dtype
    assert actual.shape == expected.shape
def test_mvn_to_gaussian(sample_shape, batch_shape, dim):
    mvn = random_mvn(batch_shape, dim)
    gaussian = mvn_to_gaussian(mvn)
    value = mvn.sample(sample_shape)
    actual_log_prob = gaussian.log_density(value)
    expected_log_prob = mvn.log_prob(value)
    assert_close(actual_log_prob, expected_log_prob)
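# Background sketch for mvn_to_gaussian (constants local to this sketch): the
# Gaussian funsor-style object stores the density in information form,
# log p(x) = -0.5 x'Px + x'(P mu) + const with precision P and info vector
# P mu (the same precision/info_vec fields used by the moment checks below).
# A direct 2-D check against MultivariateNormal:
def test_information_form_sketch():
    import math  # local import to keep this sketch self-contained
    loc = torch.tensor([0.5, -1.0])
    precision = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
    mvn = dist.MultivariateNormal(loc, precision_matrix=precision)
    x = torch.tensor([0.2, 0.7])
    info_vec = precision @ loc
    quad = -0.5 * x @ precision @ x + x @ info_vec
    const = (0.5 * torch.logdet(precision) - math.log(2 * math.pi)
             - 0.5 * loc @ info_vec)
    assert_close(mvn.log_prob(x), quad + const, atol=1e-5)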
def test_gaussian_mrf_shape(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_dist = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs_dist = random_mvn(obs_shape, hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)

    shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape

    data = obs_dist.expand(shape).sample()[..., hidden_dim:]
    assert data.shape == d.shape()

    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)
def test_independent_hmm_shape(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    base_init_shape = init_shape + (obs_dim,)
    base_trans_mat_shape = trans_mat_shape[:-1] + (
        obs_dim, trans_mat_shape[-1] if trans_mat_shape else 6)
    base_trans_mvn_shape = trans_mvn_shape[:-1] + (
        obs_dim, trans_mvn_shape[-1] if trans_mvn_shape else 6)
    base_obs_mat_shape = obs_mat_shape[:-1] + (
        obs_dim, obs_mat_shape[-1] if obs_mat_shape else 6)
    base_obs_mvn_shape = obs_mvn_shape[:-1] + (
        obs_dim, obs_mvn_shape[-1] if obs_mvn_shape else 6)

    init_dist = random_mvn(base_init_shape, hidden_dim)
    trans_mat = torch.randn(base_trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(base_trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(base_obs_mat_shape + (hidden_dim, 1))
    obs_dist = random_mvn(base_obs_mvn_shape, 1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=6)
    d = dist.IndependentHMM(d)

    shape = broadcast_shape(init_shape + (6,), trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = torch.randn(shape + (obs_dim,))
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape
def test_gaussian_mrf_log_prob_block_diag(sample_shape, batch_shape, num_steps,
                                          hidden_dim, obs_dim):
    # Construct a block-diagonal obs dist, so observations are independent of
    # hidden state.
    obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim)
    precision = obs_dist.precision_matrix
    precision[..., :hidden_dim, hidden_dim:] = 0
    precision[..., hidden_dim:, :hidden_dim] = 0
    obs_dist = dist.MultivariateNormal(obs_dist.loc, precision_matrix=precision)
    marginal_obs_dist = dist.MultivariateNormal(
        obs_dist.loc[..., hidden_dim:],
        precision_matrix=precision[..., hidden_dim:, hidden_dim:])
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)

    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)
    expected_log_prob = marginal_obs_dist.log_prob(data).sum(-1)
    assert_close(actual_log_prob, expected_log_prob)
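# The independence fact behind the test above, in miniature (constants local
# to this sketch): zeroing the off-diagonal blocks of a precision matrix
# factorizes the Gaussian, so the joint log density is the sum of the blocks'
# log densities. Here with two 1x1 blocks of precisions 2 and 3 (i.e.
# variances 1/2 and 1/3):
def test_block_diag_precision_sketch():
    precision = torch.diag(torch.tensor([2.0, 3.0]))
    mvn = dist.MultivariateNormal(torch.zeros(2), precision_matrix=precision)
    n0 = dist.Normal(0.0, (1 / 2.0) ** 0.5)
    n1 = dist.Normal(0.0, (1 / 3.0) ** 0.5)
    x = torch.tensor([0.4, -1.2])
    assert_close(mvn.log_prob(x), n0.log_prob(x[0]) + n1.log_prob(x[1]),
                 atol=1e-5)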
def random_dist(Dist, shape, transform=None):
    if Dist is dist.FoldedDistribution:
        return Dist(random_dist(dist.Normal, shape))
    elif Dist is dist.MaskedDistribution:
        base_dist = random_dist(dist.Normal, shape)
        mask = torch.empty(shape, dtype=torch.bool).bernoulli_(0.5)
        return base_dist.mask(mask)
    elif Dist is dist.TransformedDistribution:
        base_dist = random_dist(dist.Normal, shape)
        transforms = [
            dist.transforms.ExpTransform(),
            dist.transforms.ComposeTransform([
                dist.transforms.AffineTransform(1, 1),
                dist.transforms.ExpTransform().inv,
            ]),
        ]
        return dist.TransformedDistribution(base_dist, transforms)
    elif Dist in (dist.GaussianHMM, dist.LinearHMM):
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        hidden_dim = obs_dim + 1
        init_dist = random_dist(dist.Normal,
                                batch_shape + (hidden_dim,)).to_event(1)
        trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
        trans_dist = random_dist(dist.Normal,
                                 batch_shape + (duration, hidden_dim)).to_event(1)
        obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
        obs_dist = random_dist(dist.Normal,
                               batch_shape + (duration, obs_dim)).to_event(1)
        if Dist is dist.LinearHMM and transform is not None:
            obs_dist = dist.TransformedDistribution(obs_dist, transform)
        return Dist(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                    duration=duration)
    elif Dist is dist.IndependentHMM:
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        base_shape = batch_shape + (obs_dim, duration, 1)
        base_dist = random_dist(dist.GaussianHMM, base_shape)
        return Dist(base_dist)
    elif Dist is dist.MultivariateNormal:
        return random_mvn(shape[:-1], shape[-1])
    elif Dist is dist.Uniform:
        low = torch.randn(shape)
        high = low + torch.randn(shape).exp()
        return Dist(low, high)
    else:
        params = {
            name: transform_to(Dist.arg_constraints[name])(torch.rand(shape) - 0.5)
            for name in UNIVARIATE_DISTS[Dist]
        }
        return Dist(**params)
def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)

    prior = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    likelihood = dist.Normal(data, 1).to_event(2)
    posterior, log_normalizer = prior.conjugate_update(likelihood)

    def model(data):
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", prior)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    def guide(data):
        with pyro.plate_stack("plates", batch_shape):
            pyro.sample("z", posterior)

    reparam_model = poutine.reparam(model, {"z": ConjugateReparam(likelihood)})

    def reparam_guide(data):
        pass

    elbo = Trace_ELBO(num_particles=1000, vectorize_particles=True)
    expected_loss = elbo.differentiable_loss(model, guide, data)
    actual_loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    assert_close(actual_loss, expected_loss, atol=0.01)

    params = [trans_mat, obs_mat]
    expected_grads = torch.autograd.grad(expected_loss, params, retain_graph=True)
    actual_grads = torch.autograd.grad(actual_loss, params, retain_graph=True)
    for a, e in zip(actual_grads, expected_grads):
        assert_close(a, e, rtol=0.01)
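# A 1-D analogue (illustrative; constants local to this sketch) of the
# conjugate_update identity relied on above: prior(z) * likelihood(x|z) =
# posterior(z|x) * marginal(x), which is why ConjugateReparam can swap the
# prior for the posterior while log_normalizer absorbs the marginal
# likelihood. For a Normal prior N(m, v) and unit-noise Normal likelihood:
def test_conjugate_update_sketch():
    m, v = 0.3, 2.0  # prior mean and variance
    x = torch.tensor(1.1)  # observation with unit observation noise
    post_var = 1.0 / (1.0 / v + 1.0)
    post_mean = post_var * (m / v + x)
    z = torch.tensor(0.5)
    lhs = (dist.Normal(m, v ** 0.5).log_prob(z)
           + dist.Normal(z, 1.0).log_prob(x))
    rhs = (dist.Normal(post_mean, post_var ** 0.5).log_prob(z)
           + dist.Normal(m, (v + 1.0) ** 0.5).log_prob(x))  # marginal of x
    assert_close(lhs, rhs, atol=1e-5)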
def test_gaussian_hmm_high_obs_dim():
    hidden_dim = 1
    obs_dim = 1000
    duration = 10
    sample_shape = (100,)
    init_dist = random_mvn((), hidden_dim)
    trans_mat = torch.randn((duration, hidden_dim, hidden_dim))
    trans_dist = random_mvn((duration,), hidden_dim)
    obs_mat = torch.randn((duration, hidden_dim, obs_dim))
    loc = torch.randn((duration, obs_dim))
    scale = torch.randn((duration, obs_dim)).exp()
    obs_dist = dist.Normal(loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)
    x = d.rsample(sample_shape)
    assert x.shape == sample_shape + (duration, obs_dim)
def test_gamma_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, dim):
    gamma = random_gamma(batch_shape)
    mvn = random_mvn(batch_shape, dim)
    g = gamma_and_mvn_to_gamma_gaussian(gamma, mvn)
    value = mvn.sample(sample_shape)
    s = gamma.sample(sample_shape)
    actual_log_prob = g.log_density(value, s)

    s_log_prob = gamma.log_prob(s)
    scaled_prec = mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1)
    mvn_log_prob = dist.MultivariateNormal(
        mvn.loc, precision_matrix=scaled_prec).log_prob(value)
    expected_log_prob = s_log_prob + mvn_log_prob
    assert_close(actual_log_prob, expected_log_prob)
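# Background sketch for the gamma-Gaussian family (constants local to this
# sketch): mixing a Normal's precision with a Gamma-distributed scale yields a
# Student-t, which is the compound form referenced in
# test_gamma_gaussian_hmm_log_prob above. A 1-D numerical check, integrating
# the mixing scale out on a grid: with s ~ Gamma(a, b) and
# x | s ~ N(loc, 1 / (s * prec)), marginally x ~ StudentT(2a, loc,
# sqrt(b / (a * prec))).
def test_gamma_gaussian_compound_sketch():
    concentration, rate, prec, loc = 3.0, 2.0, 1.5, 0.7
    x = torch.linspace(-3.0, 3.0, 61)
    s = torch.linspace(1e-3, 50.0, 20001)  # grid over the mixing scale
    log_joint = (dist.Gamma(concentration, rate).log_prob(s)
                 + dist.Normal(loc, (s * prec).rsqrt()).log_prob(x.unsqueeze(-1)))
    log_marginal = torch.logsumexp(log_joint, dim=-1) + (s[1] - s[0]).log()
    df = 2.0 * concentration
    scale = (rate / (concentration * prec)) ** 0.5
    expected = dist.StudentT(df, loc, scale).log_prob(x)
    assert_close(log_marginal, expected, atol=1e-2)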
def test_matrix_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, x_dim, y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    g = matrix_and_mvn_to_gamma_gaussian(matrix, y_mvn)
    xy = torch.randn(sample_shape + batch_shape + (x_dim + y_dim,))
    s = torch.rand(sample_shape + batch_shape)
    actual_log_prob = g.log_density(xy, s)

    x, y = xy[..., :x_dim], xy[..., x_dim:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    loc = y_pred + y_mvn.loc
    scaled_prec = y_mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1)
    expected_log_prob = dist.MultivariateNormal(
        loc, precision_matrix=scaled_prec).log_prob(y)
    assert_close(actual_log_prob, expected_log_prob)
def test_rsample_distribution(batch_shape, dim):
    num_samples = 20000
    mvn = random_mvn(batch_shape, dim)
    g = mvn_to_gaussian(mvn)
    expected = mvn.rsample((num_samples,))
    actual = g.rsample((num_samples,))

    def get_moments(x):
        mean = x.mean(0)
        x = x - mean
        cov = (x.unsqueeze(-1) * x.unsqueeze(-2)).mean(0)
        std = cov.diagonal(dim1=-1, dim2=-2).sqrt()
        corr = cov / (std.unsqueeze(-1) * std.unsqueeze(-2))
        return mean, std, corr

    expected_mean, expected_std, expected_corr = get_moments(expected)
    actual_mean, actual_std, actual_corr = get_moments(actual)
    assert_close(actual_mean, expected_mean, atol=0.1, rtol=0.02)
    assert_close(actual_std, expected_std, atol=0.1, rtol=0.02)
    assert_close(actual_corr, expected_corr, atol=0.05)
def random_dist(Dist, shape, transform=None):
    if Dist is dist.FoldedDistribution:
        return Dist(random_dist(dist.Normal, shape))
    elif Dist in (dist.GaussianHMM, dist.LinearHMM):
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        hidden_dim = obs_dim + 1
        init_dist = random_dist(dist.Normal,
                                batch_shape + (hidden_dim,)).to_event(1)
        trans_mat = torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
        trans_dist = random_dist(dist.Normal,
                                 batch_shape + (duration, hidden_dim)).to_event(1)
        obs_mat = torch.randn(batch_shape + (duration, hidden_dim, obs_dim))
        obs_dist = random_dist(dist.Normal,
                               batch_shape + (duration, obs_dim)).to_event(1)
        if Dist is dist.LinearHMM and transform is not None:
            obs_dist = dist.TransformedDistribution(obs_dist, transform)
        return Dist(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                    duration=duration)
    elif Dist is dist.IndependentHMM:
        batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
        base_shape = batch_shape + (obs_dim, duration, 1)
        base_dist = random_dist(dist.GaussianHMM, base_shape)
        return Dist(base_dist)
    elif Dist is dist.MultivariateNormal:
        return random_mvn(shape[:-1], shape[-1])
    else:
        params = {
            name: transform_to(Dist.arg_constraints[name])(torch.rand(shape) - 0.5)
            for name in UNIVARIATE_DISTS[Dist]
        }
        return Dist(**params)
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps,
                                   hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=num_steps)
    if diag:
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #     t | 0 1 2 3 1 2 3     T = 3 in this example
    # ------+----------------
    #  init | H
    # trans | H H H H           H = hidden
    #   obs |   H H H O O O     O = observed
    #  like |         O O O
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(
        d.log_prob(data) + like_dist.log_prob(data),
        d_posterior.log_prob(data) + log_normalizer)

    if batch_shape or sample_shape:
        return

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    for name, d, g in [prior, posterior]:
        logging.info("testing {} moments".format(name))
        with torch.no_grad():
            num_samples = 100000
            samples = d.sample([num_samples]).reshape(num_samples, T * obs_dim)
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) *
                                        actual_std.unsqueeze(-2))

            expected_cov = g.precision.cholesky().cholesky_inverse()
            expected_mean = expected_cov.matmul(
                g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) *
                                            expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=6)

    shape = broadcast_shape(init_shape + (6,), trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape

    likelihood = dist.Normal(data, 1).to_event(2)
    p, log_normalizer = d.conjugate_update(likelihood)
    assert p.batch_shape == d.batch_shape
    assert p.event_shape == d.event_shape
    x = p.rsample()
    assert x.shape == d.shape()
    x = p.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = p.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)

    z = d.rsample_posterior(data)
    assert z.shape == expected_batch_shape + time_shape + (hidden_dim,)

    for t in range(1, d.duration - 1):
        f = d.duration - t
        d2 = d.prefix_condition(data[..., :t, :])
        assert d2.batch_shape == d.batch_shape
        assert d2.event_shape == (f, obs_dim)