def inverse_mass_matrix(self):
    """Return the inverse mass matrix for every group of site names.

    Each entry of ``self._mass_matrix_sqrt_inverse`` is passed through
    ``triu_gram`` and stored under the same key.

    NB: this computation is O(N^2 x head_size). However, the HMC/NUTS kernel
    does not require us to compute ``inverse_mass_matrix``, so all linear
    algebra cost in HMC/NUTS is still O(N x head_size^2). This is still
    exposed for testing and for backward compatibility.

    :return: dict keyed like ``self._mass_matrix_sqrt_inverse`` whose values
        are the results of ``triu_gram`` on the corresponding entries.
    """
    return {
        site_names: triu_gram(sqrt_inverse)
        for site_names, sqrt_inverse in self._mass_matrix_sqrt_inverse.items()
    }
def test_utilities(head_size):
    """Check the arrowhead helpers against dense torch linear algebra.

    A random 5 x 5 matrix ``cov = X X^T`` is masked so that only its first
    ``head_size`` rows/columns and its diagonal survive. ``sqrt``,
    ``triu_inverse``, ``triu_matvecmul`` and ``triu_gram`` are then compared
    with dense reference values computed from that masked matrix.

    :param head_size: number of leading dense rows kept in the arrowhead.
    """
    dim = 5
    flip_dims = (-2, -1)

    # Random Gram matrix X X^T.
    cov = torch.randn(dim, dim)
    cov = torch.mm(cov, cov.t())

    # Keep the first `head_size` rows/columns plus the whole diagonal.
    keep = torch.ones(dim, dim)
    keep[head_size:, head_size:] = 0.0
    # Stepping by dim + 1 through the flattened view walks the diagonal.
    keep.view(-1)[:: dim + 1][head_size:] = 1.0
    arrowhead_full = keep * cov

    # Cholesky of the doubly flipped matrix, flipped back, yields an upper
    # triangular U with U U^T == arrowhead_full (both asserted just below).
    flipped_chol = torch.linalg.cholesky(torch.flip(arrowhead_full, flip_dims))
    dense_sqrt = torch.flip(flipped_chol, flip_dims)
    assert_close(dense_sqrt.triu(), dense_sqrt)
    assert_close(dense_sqrt.matmul(dense_sqrt.t()), arrowhead_full)

    # sqrt: compare the structured factor with the dense one.
    arrowhead = SymmArrowhead(cov[:head_size], cov.diag()[head_size:])
    structured_sqrt = sqrt(arrowhead)
    assert_close(structured_sqrt.top, dense_sqrt[:head_size])
    assert_close(structured_sqrt.bottom_diag, dense_sqrt.diag()[head_size:])

    # triu_inverse: compare with the dense inverse of the factor.
    dense_sqrt_inv = dense_sqrt.inverse()
    structured_sqrt_inv = triu_inverse(structured_sqrt)
    assert_close(structured_sqrt_inv.top, dense_sqrt_inv[:head_size])
    assert_close(
        structured_sqrt_inv.bottom_diag, dense_sqrt_inv.diag()[head_size:]
    )

    # triu_matvecmul: plain and transposed products with a random vector.
    vec = torch.randn(dim)
    assert_close(
        triu_matvecmul(structured_sqrt_inv, vec), dense_sqrt_inv.matmul(vec)
    )
    assert_close(
        triu_matvecmul(structured_sqrt_inv, vec, transpose=True),
        dense_sqrt_inv.t().matmul(vec),
    )

    # triu_gram: with no head the reference is the reciprocal of the
    # diagonal (a vector) rather than a dense inverse.
    gram = triu_gram(structured_sqrt_inv)
    if head_size > 0:
        reference_gram = arrowhead_full.inverse()
    else:
        reference_gram = arrowhead_full.diag().reciprocal()
    assert_close(gram, reference_gram)