def mass_matrix(self, value):
    """Replace the mass matrix of one or more blocks and reset their adaptation.

    :param value: mapping from a block key (the ``site_names`` used to index
        ``self._adapt_scheme``, ``self._mass_matrix``, ``self._mass_matrix_sqrt``
        and ``self._mass_matrix_sqrt_inverse``) to the new mass matrix of that
        block.

    For every block this stores the matrix, its square root ``sqrt(mass_matrix)``
    and that root's inverse ``triu_inverse(mass_matrix_sqrt)``, and calls
    ``reset()`` on the block's adaptation scheme.

    The update is all-or-nothing. Every factorization (and every adapt-scheme
    lookup) is done before any state is modified. If one of them raises --
    presumably e.g. when a matrix is not positive definite; TODO confirm the
    exception type raised by ``sqrt`` -- no adaptation scheme is reset and no
    stored matrix is overwritten. The exception propagates unchanged.
    """
    # Phase 1: compute everything that can fail, without touching ``self``.
    # The original code reset the adapt scheme *before* factorizing, so a
    # non-positive-definite matrix left the scheme reset but the mass matrix
    # stale (and earlier blocks in ``value`` already overwritten).
    pending = []
    for site_names, mass_matrix in value.items():
        adapt_scheme = self._adapt_scheme[site_names]
        mass_matrix_sqrt = sqrt(mass_matrix)
        mass_matrix_sqrt_inverse = triu_inverse(mass_matrix_sqrt)
        pending.append(
            (site_names, adapt_scheme, mass_matrix, mass_matrix_sqrt,
             mass_matrix_sqrt_inverse)
        )

    # Phase 2: commit. Nothing below is expected to raise.
    for (site_names, adapt_scheme, mass_matrix, mass_matrix_sqrt,
         mass_matrix_sqrt_inverse) in pending:
        adapt_scheme.reset()
        self._mass_matrix[site_names] = mass_matrix
        self._mass_matrix_sqrt[site_names] = mass_matrix_sqrt
        self._mass_matrix_sqrt_inverse[site_names] = mass_matrix_sqrt_inverse
def test_utilities(head_size):
    """Check the arrowhead helpers against dense-matrix reference computations.

    The helpers checked are ``sqrt``, ``triu_inverse``, ``triu_matvecmul`` and
    ``triu_gram``. ``head_size`` is the number of leading dense rows of the
    symmetric arrowhead matrix. The remaining ``size - head_size`` rows are
    zero outside the head columns and the diagonal.
    """
    size = 5
    cov = torch.randn(size, size)
    cov = torch.mm(cov, cov.t())
    # Zeroing the tail off-diagonals of a PSD matrix does NOT preserve positive
    # definiteness, so the Cholesky below could fail randomly. Adding each
    # row's absolute sum to the diagonal makes ``cov`` strictly diagonally
    # dominant with a positive diagonal. Masking only removes off-diagonal
    # entries, so the arrowhead below stays dominant too, and is therefore
    # positive definite (Gershgorin) and reasonably well conditioned.
    cov = cov + torch.diag(cov.abs().sum(-1))

    # Mask keeping the head rows/columns plus the diagonal of the tail block.
    mask = torch.ones(size, size)
    mask[head_size:, head_size:] = 0.0
    mask.view(-1)[::size + 1][head_size:] = 1.0
    arrowhead_full = mask * cov

    # Upper-triangular U with U @ U.T == arrowhead_full, obtained by flipping
    # both axes, taking the (lower) Cholesky factor, and flipping back.
    expected_sqrt = torch.flip(
        torch.linalg.cholesky(torch.flip(arrowhead_full, (-2, -1))), (-2, -1))
    # test if those flip ops give expected upper triangular values
    assert_close(expected_sqrt.triu(), expected_sqrt)
    assert_close(expected_sqrt.matmul(expected_sqrt.t()), arrowhead_full)

    # test sqrt
    arrowhead = SymmArrowhead(cov[:head_size], cov.diag()[head_size:])
    actual_sqrt = sqrt(arrowhead)
    assert_close(actual_sqrt.top, expected_sqrt[:head_size])
    assert_close(actual_sqrt.bottom_diag, expected_sqrt.diag()[head_size:])

    # test triu_inverse
    expected_inv = expected_sqrt.inverse()
    actual_inv = triu_inverse(actual_sqrt)
    assert_close(actual_inv.top, expected_inv[:head_size])
    assert_close(actual_inv.bottom_diag, expected_inv.diag()[head_size:])

    # test triu_matvecmul
    v = torch.randn(size)
    assert_close(triu_matvecmul(actual_inv, v), expected_inv.matmul(v))
    assert_close(triu_matvecmul(actual_inv, v, transpose=True),
                 expected_inv.t().matmul(v))

    # test triu_gram: with no head the result is compared as a vector of
    # reciprocal diagonal entries rather than a dense inverse.
    actual_gram = triu_gram(actual_inv)
    expected_gram = (arrowhead_full.inverse() if head_size > 0
                     else arrowhead_full.diag().reciprocal())
    assert_close(actual_gram, expected_gram)