Beispiel #1
0
def test_jacobian(transform):
    """Compare ``transform.log_abs_det_jacobian`` with an autograd Jacobian.

    The value returned by the transform is first checked for shape (it must
    drop the trailing ``transform.domain.event_dim`` dims of the input), and
    is then compared against a reference computed in one of three ways:

    1. reshape transforms (or their inverses): the reference is all zeros;
    2. transforms whose domain has ``event_dim == 0``: the Jacobian must be
       diagonal and the reference is ``log|diag|``;
    3. everything else: the reference is ``slogdet`` of the per-sample
       Jacobian, with special input/output flattening for the
       correlation-Cholesky and stick-breaking transforms.

    The test is skipped when the transform raises ``NotImplementedError``.
    """
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip('Not implemented.')
    # Test shape
    target_shape = x.shape[:x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    ndims = len(x.shape)
    event_dim = ndims - transform.domain.event_dim
    x_ = x.view((-1, ) + x.shape[event_dim:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(
            transform.inv, ReshapeTransform):
        # Use the shape already verified against ``actual`` above, so the
        # comparison below does not depend on broadcasting a 1-D zero vector
        # whose length is an event dimension.
        expected = x.new_zeros(target_shape)
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            # Differentiate w.r.t. the strictly-lower-triangular part of the output.
            jac = jacobian(lambda v: tril_matrix_to_vec(transform(v), diag=-1),
                           x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            # Feed the strictly-lower-triangular part of the input as a vector.
            jac = jacobian(lambda v: transform(vec_to_tril_matrix(v, diag=-1)),
                           tril_matrix_to_vec(x_, diag=-1))
        elif isinstance(transform, StickBreakingTransform):
            # Drop the last output coordinate before differentiating.
            jac = jacobian(lambda v: transform(v)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims)
        # However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims)
        # after reshaping the event dims (see above) to give a batched square matrix whose determinant
        # can be computed.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = torch.arange(n).reshape(
            (n, ) + (1, ) * (len(jac.shape) - 1)).expand(gather_idx_shape)
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        # Remove extra zero-valued dims (for inverse stick-breaking).
        jac = jac[..., :out_ndims]
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)
Beispiel #2
0
 def _inverse(self, y):
     # inverse stick-breaking
     # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
     y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
     y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
     y_vec = tril_matrix_to_vec(y, diag=-1)
     y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
     t = y_vec / (y_cumsum_vec).sqrt()
     # inverse of tanh
     x = ((1 + t) / (1 - t)).log() / 2
     return x
Beispiel #3
0
def test_tril_matrix_to_vec(shape):
    mat = torch.randn(shape)
    n = mat.shape[-1]
    for diag in range(-n, n):
        actual = mat.tril(diag)
        vec = tril_matrix_to_vec(actual, diag)
        tril_mat = vec_to_tril_matrix(vec, diag)
        assert torch.allclose(tril_mat, actual)
Beispiel #4
0
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """Return the log absolute Jacobian determinant of the map ``x -> y``.

        The domain and codomain are spaces of different dimension, so the
        determinant of the Jacobian is not well-defined; the value returned is
        the one for ``x`` and the flattened lower-triangular part of ``y``.

        See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html

        ``intermediates`` is accepted but not used by this implementation.
        """
        # Stick-breaking contribution: one minus the running sum of squares
        # along the last dim. Selecting from diagonal -2 downward means the
        # cumulative sums do not have to be shifted one column to the right;
        # this also works for a 2 x 2 matrix.
        remaining = 1 - (y * y).cumsum(dim=-1)
        remaining_vec = tril_matrix_to_vec(remaining, diag=-2)
        stick_part = 0.5 * remaining_vec.log().sum(-1)

        # Tanh contribution, summed over the last dim of ``x``.
        log_two = math.log(2.)
        per_coordinate = x + softplus(-2 * x) - log_two
        tanh_part = -2 * per_coordinate.sum(dim=-1)

        return stick_part + tanh_part