def forward(ctx, X):
    b, c, d, h, w = X.size()
    # Batched eigendecomposition of every voxel's 3x3 symmetric matrix.
    S, U = vSymEig(X, eigenvectors=True, flatten_output=True)
    S_log = torch.log(S)  # log of the eigenvalues; vSymEig returns the raw eigenvalues
    ctx.save_for_backward(S_log, S, U, X)

    # log(X) = U diag(log(S)) U^T, remapped to the (b, c, d, h, w) layout.
    return U.bmm(torch.diag_embed(S_log)).bmm(U.transpose(1, 2)) \
        .reshape(b, d * h * w, 3, 3).permute(0, 2, 3, 1).reshape(b, c, d, h, w)
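# The decomposition-and-remap pattern in this forward can be sanity-checked
# outside the autograd machinery. A minimal sketch, assuming torchvectorized
# is installed; the 3*I shift is ours, added only so every voxel matrix is
# SPD and the log is defined:
import torch
from torchvectorized.utils import sym
from torchvectorized.vlinalg import vSymEig

b, c, d, h, w = 1, 9, 2, 2, 2
X = sym(torch.rand(b, c, d, h, w)) + 3 * torch.eye(3).reshape(1, 9, 1, 1, 1)
S, U = vSymEig(X, eigenvectors=True, flatten_output=True)
log_X = U.bmm(torch.diag_embed(torch.log(S))).bmm(U.transpose(1, 2))

# Invert the (b, c, d, h, w) layout and check exp(log(X)) == X per voxel.
X_dense = X.reshape(b, 3, 3, d * h * w).permute(0, 3, 1, 2).reshape(-1, 3, 3)
assert torch.allclose(torch.matrix_exp(log_X), X_dense, atol=1e-4)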
def test_should_compute_eigen_values(self):
    b, c, d, h, w = 1, 9, 32, 32, 32
    input = sym(torch.rand(b, c, d, h, w))

    eig_vals_expected, eig_vecs = vSymEig(input, eigenvectors=True, flatten_output=True)
    eig_vals = EigVals()(input)

    assert_that(torch.allclose(eig_vals_expected, eig_vals, atol=0.000001), equal_to(True))
def test_should_not_fail_on_empty_matrix(self):
    b, c, d, h, w = 2, 9, 1, 1, 1
    input = sym(torch.zeros(b, c, d, h, w))

    eig_vals, eig_vecs = vSymEig(input, eigenvectors=True, flatten_output=True)

    assert_that(torch.allclose(eig_vals, torch.zeros((2, 3)), atol=0.000001), equal_to(True))
    assert_that(torch.allclose(eig_vecs, torch.eye(3).unsqueeze(0).repeat(2, 1, 1), atol=0.000001),
                equal_to(True))
def test_should_decompose_symmetric_matrices(self):
    b, c, d, h, w = 1, 9, 32, 32, 32
    input = sym(torch.rand(b, c, d, h, w))

    eig_vals, eig_vecs = vSymEig(input, eigenvectors=True, flatten_output=True,
                                 descending_eigenvals=False)

    # UVU^T
    reconstructed_input = eig_vecs.bmm(torch.diag_embed(eig_vals)).bmm(eig_vecs.transpose(1, 2))
    reconstructed_input = reconstructed_input.reshape(b, d * h * w, 3, 3) \
        .permute(0, 2, 3, 1).reshape(b, c, d, h, w)

    assert_that(torch.allclose(reconstructed_input, input, atol=0.000001), equal_to(True))
    assert_that(torch.any(eig_vals[:, 0] > eig_vals[:, 1]), equal_to(False))
    assert_that(torch.any(eig_vals[:, 1] > eig_vals[:, 2]), equal_to(False))

    eig_vals, eig_vecs = vSymEig(input, eigenvectors=True, flatten_output=True,
                                 descending_eigenvals=True)

    # UVU^T
    reconstructed_input = eig_vecs.bmm(torch.diag_embed(eig_vals)).bmm(eig_vecs.transpose(1, 2))
    reconstructed_input = reconstructed_input.reshape(b, d * h * w, 3, 3) \
        .permute(0, 2, 3, 1).reshape(b, c, d, h, w)

    assert_that(torch.allclose(reconstructed_input, input, atol=0.000001), equal_to(True))
    assert_that(torch.any(eig_vals[:, 0] > eig_vals[:, 1]), equal_to(True))
    assert_that(torch.any(eig_vals[:, 1] > eig_vals[:, 2]), equal_to(True))
def test_should_decompose_identity_matrix(self):
    b, c, d, h, w = 2, 9, 1, 1, 1
    # Channels 0, 4 and 8 hold the diagonal, so this builds identity matrices.
    input = torch.zeros(b, c, d, h, w)
    input[:, 0, :, :, :] = 1.0
    input[:, 4, :, :, :] = 1.0
    input[:, 8, :, :, :] = 1.0

    eig_vals, eig_vecs = vSymEig(input, eigenvectors=True, flatten_output=True)

    assert_that(torch.allclose(eig_vals, torch.ones((2, 3)), atol=0.000001), equal_to(True))
    assert_that(torch.allclose(eig_vecs, torch.eye(3).unsqueeze(0).repeat(2, 1, 1), atol=0.000001),
                equal_to(True))
def test_should_decompose_matrix_with_same_diag(self):
    b, c, d, h, w = 2, 9, 1, 1, 1
    input = sym(torch.rand(b, c, d, h, w))
    # First matrix: constant diagonal; second matrix: all zeros.
    input[0, 0, :, :, :] = 1.0
    input[0, 4, :, :, :] = 1.0
    input[0, 8, :, :, :] = 1.0
    input[1, :, :, :, :] = 0.0

    eig_vals, eig_vecs = vSymEig(input, eigenvectors=True, flatten_output=True)
    expected_eig_vals, expected_eig_vecs = input.unsqueeze(0).reshape(b, 3, 3) \
        .symeig(eigenvectors=True)

    assert_that(torch.allclose(eig_vals, expected_eig_vals, atol=0.000001), equal_to(True))
    # Eigenvectors are only defined up to sign, so compare absolute values.
    assert_that(torch.allclose(torch.abs(eig_vecs), torch.abs(expected_eig_vecs)), equal_to(True))
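# All of these tests rely on one layout convention: the 9 channels of each
# voxel store its 3x3 matrix in row-major order. A hypothetical helper
# (to_dense is our name, not part of torchvectorized) that inverts the
# reshape/permute used in the snippets above:
import torch

def to_dense(X: torch.Tensor) -> torch.Tensor:
    # (b, 9, d, h, w) voxel layout -> (b * d * h * w, 3, 3) dense matrices.
    b, c, d, h, w = X.size()
    return X.reshape(b, 3, 3, d * h * w).permute(0, 3, 1, 2).reshape(-1, 3, 3)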
import torch
import matplotlib.pyplot as plt

from torchvectorized.utils import sym
from torchvectorized.vlinalg import vSymEig

OPTIMIZER_STEPS = 5000

if __name__ == "__main__":
    cos_sim_computer = torch.nn.CosineSimilarity()

    # Ground-truth decomposition of a random target; vSymEig expects
    # symmetric input, so the target is symmetrized too.
    gt_vals, gt_vecs = vSymEig(sym(torch.rand(16, 9, 8, 8, 8)), eigenvectors=True,
                               descending_eigenvals=True)
    input = torch.nn.Parameter(sym(torch.rand(16, 9, 8, 8, 8)), requires_grad=True)
    optimizer = torch.optim.Adam([input], lr=0.001, betas=(0.5, 0.999))

    steps = []
    eig_vals_loss = []
    eig_vecs_loss = []
    cos_sim_metrics_v1 = []
    cos_sim_metrics_v2 = []
    cos_sim_metrics_v3 = []

    for optim_step in range(OPTIMIZER_STEPS):
        optimizer.zero_grad()
        eig_vals, eig_vecs = vSymEig(input, eigenvectors=True, descending_eigenvals=True)

        # L1 losses on eigenvalues and eigenvectors drive the parameter
        # toward the ground-truth decomposition.
        loss_eig_val = torch.nn.functional.l1_loss(eig_vals, gt_vals)
        loss_eig_vecs = torch.nn.functional.l1_loss(eig_vecs, gt_vecs)
        total_loss = loss_eig_vecs + loss_eig_val

        total_loss.backward()
        optimizer.step()

        steps.append(optim_step)
        eig_vals_loss.append(loss_eig_val.item())
        eig_vecs_loss.append(loss_eig_vecs.item())
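# The script above imports matplotlib but is cut off before any plotting; a
# minimal sketch of how the recorded losses could be plotted afterwards
# (labels and styling are our assumptions):
plt.plot(steps, eig_vals_loss, label="eigenvalue L1 loss")
plt.plot(steps, eig_vecs_loss, label="eigenvector L1 loss")
plt.xlabel("optimizer step")
plt.ylabel("L1 loss")
plt.legend()
plt.show()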
def vectorized_func_gpu(b):
    c, d, h, w = 9, 10, 10, 10
    # Mean wall-clock time over 5 runs of a batched GPU decomposition.
    return timeit.timeit(lambda: vSymEig(sym(torch.rand(b, c, d, h, w)).cuda(), eigenvectors=True),
                         number=5) / 5
def forward(ctx, X):
    # Only the eigenvalues are returned; the full decomposition is saved
    # for the backward pass.
    V, U = vSymEig(X, eigenvectors=True, flatten_output=True)
    ctx.save_for_backward(V, U, X)

    return V
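# A quick sanity check for the eigenvalues this forward returns, sketched
# under the assumption that vSymEig's default ordering may differ from
# torch.linalg.eigvalsh's ascending order (hence the sort):
import torch
from torchvectorized.utils import sym
from torchvectorized.vlinalg import vSymEig

b, c, d, h, w = 1, 9, 2, 2, 2
X = sym(torch.rand(b, c, d, h, w))
vals, _ = vSymEig(X, eigenvectors=True, flatten_output=True)

# Same matrices as a dense (b*d*h*w, 3, 3) batch, then reference eigenvalues.
X_dense = X.reshape(b, 3, 3, d * h * w).permute(0, 3, 1, 2).reshape(-1, 3, 3)
ref = torch.linalg.eigvalsh(X_dense)  # ascending
assert torch.allclose(torch.sort(vals, dim=1).values, ref, atol=1e-5)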
def vectorized_func(b):
    c, d, h, w = 9, 32, 32, 32
    # Mean wall-clock time over 5 runs, scaled down by the 1e6 factor.
    return timeit.timeit(lambda: vSymEig(sym(torch.rand(b, c, d, h, w)), eigenvectors=True),
                         number=5) / (5 * 1000000)
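# Neither benchmark helper is shown with a caller; a hypothetical driver
# sweeping a few batch sizes (the sizes and print format are ours):
import timeit

import torch
from torchvectorized.utils import sym
from torchvectorized.vlinalg import vSymEig

if __name__ == "__main__":
    for batch in (1, 2, 4, 8):
        print(f"batch={batch}: cpu={vectorized_func(batch):.3e}")
        if torch.cuda.is_available():
            print(f"batch={batch}: gpu={vectorized_func_gpu(batch):.3e}")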