def test_jacobian(transform):
    """Check ``transform.log_abs_det_jacobian`` against an autograd Jacobian.

    The analytic log|det J| returned by the transform is compared with a value
    derived from ``jacobian`` on the same data:

    1. Reshape transforms (forward or inverse) have a unit Jacobian, so the
       expected log|det J| is zero for every batch element.
    2. Element-wise transforms (``domain.event_dim == 0``) must have a
       diagonal Jacobian; the expected value is log|diag|.
    3. All other transforms have a dense per-sample Jacobian whose log|det|
       is computed with ``torch.slogdet``.

    The test is skipped when the transform raises ``NotImplementedError``
    from its forward map or from ``log_abs_det_jacobian``.
    """
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # Test shape: one log|det J| per batch element.
    target_shape = x.shape[:x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    ndims = len(x.shape)
    event_dim = ndims - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[event_dim:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
        # One zero per batch element. Slicing (not indexing) x.shape is what
        # yields the batch shape; indexing returned a single event size.
        expected = x.new_zeros(target_shape)
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            # Differentiate only the strictly-lower entries of the output.
            jac = jacobian(lambda v: tril_matrix_to_vec(transform(v), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            # Parametrize the input by its strictly-lower entries.
            jac = jacobian(lambda v: transform(vec_to_tril_matrix(v, diag=-1)),
                           tril_matrix_to_vec(x_, diag=-1))
        elif isinstance(transform, StickBreakingTransform):
            # Drop the last output coordinate, which is determined by the others.
            jac = jacobian(lambda v: transform(v)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims)
        # However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims)
        # after reshaping the event dims (see above) to give a batched square matrix whose determinant
        # can be computed.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        # Select jac[i, :, i, :] for every batch index i.
        gather_idxs = torch.arange(n, device=jac.device).reshape(
            (n,) + (1,) * (jac.dim() - 1)).expand(gather_idx_shape)
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        jac = jac[..., :out_ndims]  # Remove extra zero-valued dims (for inverse stick-breaking).
        # slogdet yields one value per flattened batch element; restore the
        # batch shape so the comparison is exact rather than broadcast.
        expected = torch.slogdet(jac).logabsdet.reshape(target_shape)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_tril_matrix_to_vec(shape): mat = torch.randn(shape) n = mat.shape[-1] for diag in range(-n, n): actual = mat.tril(diag) vec = tril_matrix_to_vec(actual, diag) tril_mat = vec_to_tril_matrix(vec, diag) assert torch.allclose(tril_mat, actual)
def _call(self, x): x = torch.tanh(x) eps = torch.finfo(x.dtype).eps x = x.clamp(min=-1 + eps, max=1 - eps) r = vec_to_tril_matrix(x, diag=-1) # apply stick-breaking on the squared values # Note that y = sign(r) * sqrt(z * z1m_cumprod) # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) z = r**2 z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) # Diagonal elements must be 1. r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device) y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1) return y