def __init__(self, initial_dist, transition_dist, observation_dist, validate_args=None):
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim = initial_dist.event_shape[0]
    assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
    obs_dim = observation_dist.event_shape[0] - hidden_dim
    shape = broadcast_shape(initial_dist.batch_shape + (1,),
                            transition_dist.batch_shape,
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)
    super(GaussianMRF, self).__init__(batch_shape, event_shape, validate_args=validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    self._init = mvn_to_gaussian(initial_dist)
    self._trans = mvn_to_gaussian(transition_dist)
    self._obs = mvn_to_gaussian(observation_dist)
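
# Hedged usage sketch (added for illustration, not from the source). It shows the shape
# conventions enforced by the asserts above: transition_dist is over concatenated pairs
# (h_{t-1}, h_t) and observation_dist over pairs (h_t, o_t). Assumes `import torch` and
# `import pyro.distributions as dist`; the helper name and sizes are hypothetical.
def example_gaussian_mrf(hidden_dim=2, obs_dim=3, num_steps=4):
    init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
    trans_dist = dist.MultivariateNormal(
        torch.zeros(num_steps, 2 * hidden_dim),
        torch.eye(2 * hidden_dim).expand(num_steps, -1, -1))
    obs_dist = dist.MultivariateNormal(
        torch.zeros(num_steps, hidden_dim + obs_dim),
        torch.eye(hidden_dim + obs_dim).expand(num_steps, -1, -1))
    mrf = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    assert mrf.event_shape == (num_steps, obs_dim)  # time_shape + (obs_dim,)
    return mrf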
def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + hidden_dim)
    obs_dist = random_mvn(batch_shape + (num_steps,), hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = mvn_to_gaussian(trans_dist)
    obs = mvn_to_gaussian(obs_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
                     [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp_h = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp_oh = gaussian_tensordot(logp_h, unrolled_obs, T * hidden_dim)
    logp_h += unrolled_obs.marginalize(right=T * obs_dim)
    expected_log_prob = logp_oh.log_density(unrolled_data) - logp_h.event_logsumexp()
    assert_close(actual_log_prob, expected_log_prob)
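
# Note (added for clarity, not from the source): the hand computation above relies on the
# unnormalized factorization
#     p~(h, o) = exp(g_init(h_0) + sum_t g_trans(h_{t-1}, h_t) + sum_t g_obs(h_t, o_t)),
# for which
#     log p(o) = log \int p~(h, o) dh  -  log \int\int p~(h, o) dh do,
# i.e. exactly logp_oh.log_density(unrolled_data) - logp_h.event_logsumexp() above.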
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)
    Mx_cov = matrix.transpose(-2, -1).matmul(x_mvn.covariance_matrix).matmul(matrix)
    Mx_loc = matrix.transpose(-2, -1).matmul(x_mvn.loc.unsqueeze(-1)).squeeze(-1)
    mvn = dist.MultivariateNormal(Mx_loc + y_mvn.loc, Mx_cov + y_mvn.covariance_matrix)
    expected = mvn_to_gaussian(mvn)

    actual = gaussian_tensordot(mvn_to_gaussian(x_mvn),
                                matrix_and_mvn_to_gaussian(matrix, y_mvn),
                                dims=x_dim)
    assert_close_gaussian(expected, actual)
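
# Note (added for clarity, not from the source): this checks the affine-Gaussian identity
# behind matrix_and_mvn_to_gaussian: if x ~ N(mu_x, S_x) and independently e ~ N(mu_y, S_y),
# then y = x @ matrix + e ~ N(matrix^T mu_x + mu_y, matrix^T S_x matrix + S_y), and
# integrating x out of the joint Gaussian (gaussian_tensordot over x_dim) recovers that marginal.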
def test_mvn_to_gaussian(sample_shape, batch_shape, dim):
    mvn = random_mvn(batch_shape, dim)
    gaussian = mvn_to_gaussian(mvn)
    value = mvn.sample(sample_shape)
    actual_log_prob = gaussian.log_density(value)
    expected_log_prob = mvn.log_prob(value)
    assert_close(actual_log_prob, expected_log_prob)
def test_rsample_shape(sample_shape, batch_shape, dim):
    mvn = random_mvn(batch_shape, dim)
    g = mvn_to_gaussian(mvn)
    expected = mvn.rsample(sample_shape)
    actual = g.rsample(sample_shape)
    assert actual.dtype == expected.dtype
    assert actual.shape == expected.shape
def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    if diag:
        obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc,
                                          scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
                     [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
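
# Note (added for clarity, not from the source): unlike the GaussianMRF test, no normalizer
# is subtracted here because the HMM factors are normalized conditionals
# p(h_0), p(h_t | h_{t-1}), p(o_t | h_t), so integrating out the hidden states with
# gaussian_tensordot() yields the marginal density p(o_{1:T}) directly.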
def test_matrix_and_mvn_to_gaussian(sample_shape, batch_shape, x_dim, y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    xy_mvn = random_mvn(batch_shape, x_dim + y_dim)
    gaussian = matrix_and_mvn_to_gaussian(matrix, y_mvn) + mvn_to_gaussian(xy_mvn)
    xy = torch.randn(sample_shape + (1,) * len(batch_shape) + (x_dim + y_dim,))
    x, y = xy[..., :x_dim], xy[..., x_dim:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = gaussian.log_density(xy)
    expected_log_prob = xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred)
    assert_close(actual_log_prob, expected_log_prob)
def test_rsample_distribution(batch_shape, dim):
    num_samples = 20000
    mvn = random_mvn(batch_shape, dim)
    g = mvn_to_gaussian(mvn)
    expected = mvn.rsample((num_samples,))
    actual = g.rsample((num_samples,))

    def get_moments(x):
        mean = x.mean(0)
        x = x - mean
        cov = (x.unsqueeze(-1) * x.unsqueeze(-2)).mean(0)
        std = cov.diagonal(dim1=-1, dim2=-2).sqrt()
        corr = cov / (std.unsqueeze(-1) * std.unsqueeze(-2))
        return mean, std, corr

    expected_mean, expected_std, expected_corr = get_moments(expected)
    actual_mean, actual_std, actual_corr = get_moments(actual)
    assert_close(actual_mean, expected_mean, atol=0.1, rtol=0.02)
    assert_close(actual_std, expected_std, atol=0.1, rtol=0.02)
    assert_close(actual_corr, expected_corr, atol=0.05)
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist, validate_args=None):
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_matrix, torch.Tensor)
    assert (isinstance(observation_dist, torch.distributions.MultivariateNormal) or
            (isinstance(observation_dist, torch.distributions.Independent) and
             isinstance(observation_dist.base_dist, torch.distributions.Normal)))
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert initial_dist.event_shape == (hidden_dim,)
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim,)
    assert observation_dist.event_shape == (obs_dim,)
    shape = broadcast_shape(initial_dist.batch_shape + (1,),
                            transition_matrix.shape[:-2],
                            transition_dist.batch_shape,
                            observation_matrix.shape[:-2],
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)
    super(GaussianHMM, self).__init__(batch_shape, event_shape, validate_args=validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    self._init = mvn_to_gaussian(initial_dist)
    self._trans = matrix_and_mvn_to_gaussian(transition_matrix, transition_dist)
    self._obs = matrix_and_mvn_to_gaussian(observation_matrix, observation_dist)
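
# Hedged usage sketch (added for illustration, not from the source). It constructs a
# GaussianHMM with the shape conventions asserted above: per-step transition and
# observation matrices plus per-step noise distributions. Assumes `import torch` and
# `import pyro.distributions as dist`; the helper name and sizes are hypothetical.
def example_gaussian_hmm(hidden_dim=3, obs_dim=2, num_steps=5):
    init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
    trans_mat = torch.randn(num_steps, hidden_dim, hidden_dim)
    trans_dist = dist.MultivariateNormal(
        torch.zeros(num_steps, hidden_dim),
        torch.eye(hidden_dim).expand(num_steps, -1, -1))
    obs_mat = torch.randn(num_steps, hidden_dim, obs_dim)
    obs_dist = dist.MultivariateNormal(
        torch.zeros(num_steps, obs_dim),
        torch.eye(obs_dim).expand(num_steps, -1, -1))
    hmm = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    assert hmm.event_shape == (num_steps, obs_dim)  # time_shape + (obs_dim,)
    return hmm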
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=num_steps)
    if diag:
        obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc,
                                          scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    #    like |         O O O
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
                     [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(d.log_prob(data) + like_dist.log_prob(data),
                 d_posterior.log_prob(data) + log_normalizer)

    if batch_shape or sample_shape:
        return

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    for name, d, g in [prior, posterior]:
        logging.info("testing {} moments".format(name))
        with torch.no_grad():
            num_samples = 100000
            samples = d.sample([num_samples]).reshape(num_samples, T * obs_dim)
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) * actual_std.unsqueeze(-2))

            expected_cov = g.precision.cholesky().cholesky_inverse()
            expected_mean = expected_cov.matmul(g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) * expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)
def _init(self):
    # To save computation in _sequential_gaussian_tensordot(), we expand
    # only _init, which is applied only after
    # _sequential_gaussian_tensordot().
    return mvn_to_gaussian(self.initial_dist).expand(self.batch_shape)