def test_jacobian(transform):
    """Compare ``transform.log_abs_det_jacobian`` with a Jacobian-based reference.

    The reference value is built in one of three ways:

    1. reshape transforms (or their inverses): zeros with the batch shape;
    2. transforms whose domain has ``event_dim == 0``: the log of the absolute
       diagonal of the Jacobian, after asserting the off-diagonal is zero;
    3. everything else: ``slogdet`` of a batched square Jacobian.

    The test is skipped when the transform raises ``NotImplementedError`` from
    either the forward call or ``log_abs_det_jacobian``.
    """
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # Test shape: the log-det has one entry per batch element.
    target_shape = x.shape[:x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    ndims = len(x.shape)
    event_dim = ndims - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[event_dim:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(
            transform.inv, ReshapeTransform):
        # Zeros with exactly the shape already asserted for ``actual``. The
        # previous code indexed a single entry of ``x.shape`` (and assigned it
        # twice), which only compared correctly through broadcasting.
        expected = x.new_zeros(target_shape)
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        # Some transforms map between spaces of different dimension, so the
        # Jacobian is taken with respect to a flattened / truncated view that
        # makes it square.
        if isinstance(transform, CorrCholeskyTransform):
            jac = jacobian(
                lambda v: tril_matrix_to_vec(transform(v), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            jac = jacobian(
                lambda v: transform(vec_to_tril_matrix(v, diag=-1)),
                tril_matrix_to_vec(x_, diag=-1))
        elif isinstance(transform, StickBreakingTransform):
            jac = jacobian(lambda v: transform(v)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape
        # (batch_dims, y_event_dims, batch_dims, x_event_dims).
        # However, batches are independent, so this can be converted into a
        # (batch_dims, event_dims, event_dims) tensor after reshaping the event
        # dims (see above), giving a batched square matrix whose determinant
        # can be computed.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = torch.arange(n).reshape(
            (n,) + (1,) * (len(jac.shape) - 1)).expand(gather_idx_shape)
        # For each batch row, keep only the derivative w.r.t. that same row.
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        # Remove extra zero-valued dims (for inverse stick-breaking).
        jac = jac[..., :out_ndims]
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)
def _inverse(self, y):
    """Invert the stick-breaking construction and the tanh squashing.

    See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html

    Each strictly-lower-triangular entry of ``y`` is divided by the square
    root of one minus the sum of squares of the entries to its left in the same
    row; the result is then mapped through the inverse of tanh.

    Args:
        y: tensor whose last two dimensions form a square matrix (per the
            reference above, presumably a Cholesky factor of a correlation
            matrix -- not checked here).

    Returns:
        Tensor holding one value per strictly-lower-triangular entry of ``y``,
        in the order produced by ``tril_matrix_to_vec(..., diag=-1)``.
    """
    # inverse stick-breaking
    remaining = 1 - torch.cumsum(y * y, dim=-1)
    # Shift right by one column (filling with 1) so each position sees the
    # remainder *before* its own entry is subtracted. This assumes ``pad``
    # follows torch.nn.functional.pad semantics.
    remaining_shifted = pad(remaining[..., :-1], [1, 0], value=1)
    y_vec = tril_matrix_to_vec(y, diag=-1)
    remaining_vec = tril_matrix_to_vec(remaining_shifted, diag=-1)
    t = y_vec / remaining_vec.sqrt()
    # inverse of tanh, written with log1p: log((1 + t) / (1 - t)) rounds
    # 1 +/- t to 1 for small |t| and would return exactly 0.
    x = (t.log1p() - t.neg().log1p()) / 2
    return x
def test_tril_matrix_to_vec(shape):
    """Round-trip a lower-triangular matrix through its vector form.

    For every diagonal offset in ``[-n, n)``, where ``n`` is the size of the
    last dimension, the matrix is masked with ``tril``, flattened with
    ``tril_matrix_to_vec`` and rebuilt with ``vec_to_tril_matrix``; the result
    must match the masked matrix.
    """
    matrix = torch.randn(shape)
    side = matrix.shape[-1]
    for offset in range(-side, side):
        expected = matrix.tril(offset)
        rebuilt = vec_to_tril_matrix(
            tril_matrix_to_vec(expected, offset), offset)
        assert torch.allclose(rebuilt, expected)
def log_abs_det_jacobian(self, x, y, intermediates=None):
    """Return the log absolute Jacobian determinant of ``x -> tril_vec(y)``.

    The domain and codomain have different dimensions, so a determinant of the
    full Jacobian is not well-defined. Instead this is the log-abs-det of the
    map from ``x`` to the flattened lower triangular part of ``y``.
    See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html

    Args:
        x: input of the transform.
        y: output of the transform for ``x``; its last two dims are treated as
            a matrix.
        intermediates: accepted for interface compatibility; not used here.

    Returns:
        Sum of the stick-breaking and tanh log-determinant contributions,
        reduced over the last dimension.
    """
    # One minus the running sum of squares along each row of ``y``.
    remainder = 1 - (y * y).cumsum(dim=-1)
    # Selecting with diag=-2 picks the needed terms directly, so no right
    # shift of the running quantity is required; this also works for a
    # 2 x 2 matrix.
    remainder_vec = tril_matrix_to_vec(remainder, diag=-2)
    stick_breaking_logdet = 0.5 * remainder_vec.log().sum(-1)

    # Contribution of the elementwise tanh applied to ``x``.
    tanh_terms = x + softplus(-2 * x) - math.log(2.)
    tanh_logdet = -2 * tanh_terms.sum(dim=-1)

    return stick_breaking_logdet + tanh_logdet