예제 #1
0
    def forward(ctx, X):
        """Eigen-decompose ``X`` and return the recomposed ``U diag(S) U^T`` volume.

        The eigenvalues, their element-wise exponential, the eigenvectors and
        the raw input are stored on ``ctx`` via ``save_for_backward``.

        Args:
            ctx: context object exposing ``save_for_backward``.
            X: 5-D tensor unpacked as ``(b, c, d, h, w)``; the reshape below
                treats ``c`` as a flattened 3x3 matrix (i.e. ``c == 9``).

        Returns:
            Tensor reshaped back to ``(b, c, d, h, w)``.
        """
        batch, channels, depth, height, width = X.size()
        S_log, U = vSymEig(X, eigenvectors=True, flatten_output=True)

        ctx.save_for_backward(S_log, torch.exp(S_log), U, X)

        # Recompose one 3x3 matrix per flattened batch entry: U * diag(S) * U^T.
        recomposed = U.bmm(torch.diag_embed(S_log)).bmm(U.transpose(1, 2))

        # Move the 3x3 entries back into the channel axis of the input layout.
        per_voxel = recomposed.reshape(batch, depth * height * width, 3, 3)
        channels_first = per_voxel.permute(0, 2, 3, 1)
        return channels_first.reshape(batch, channels, depth, height, width)
예제 #2
0
    def test_should_compute_eigen_values(self):
        """``EigVals`` must agree with the eigenvalues returned by ``vSymEig``."""
        shape = (1, 9, 32, 32, 32)
        tensor = sym(torch.rand(*shape))

        # Reference eigenvalues; the eigenvectors are not needed for this check.
        reference_vals, _ = vSymEig(tensor,
                                    eigenvectors=True,
                                    flatten_output=True)

        computed_vals = EigVals()(tensor)

        values_match = torch.allclose(reference_vals, computed_vals, atol=1e-6)
        assert_that(values_match, equal_to(True))
예제 #3
0
    def test_should_not_fail_on_empty_matrix(self):
        """All-zero input must give zero eigenvalues and identity eigenvectors."""
        zero_volume = sym(torch.zeros(2, 9, 1, 1, 1))

        eig_vals, eig_vecs = vSymEig(zero_volume,
                                     eigenvectors=True,
                                     flatten_output=True)

        expected_vals = torch.zeros((2, 3))
        # One 3x3 identity per batch entry.
        expected_vecs = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)

        vals_ok = torch.allclose(eig_vals, expected_vals, atol=1e-6)
        assert_that(vals_ok, equal_to(True))

        vecs_ok = torch.allclose(eig_vecs, expected_vecs, atol=1e-6)
        assert_that(vecs_ok, equal_to(True))
예제 #4
0
    def test_should_decompose_symmetric_matrices(self):
        """Decomposition must reconstruct the input for both eigenvalue orders.

        For each value of ``descending_eigenvals`` the test rebuilds the input
        as ``U diag(V) U^T`` and then checks the ordering of the eigenvalue
        columns: with ascending order no entry may exceed its right-hand
        neighbour; with descending order at least one entry must.
        """
        b, c, d, h, w = 1, 9, 32, 32, 32
        volume = sym(torch.rand(b, c, d, h, w))

        # Ascending first, then descending. The flag doubles as the expected
        # result of the "is any left column larger than the right one" checks.
        for descending in (False, True):
            eig_vals, eig_vecs = vSymEig(volume,
                                         eigenvectors=True,
                                         flatten_output=True,
                                         descending_eigenvals=descending)

            # U V U^T
            scaled_vecs = eig_vecs.bmm(torch.diag_embed(eig_vals))
            rebuilt = scaled_vecs.bmm(eig_vecs.transpose(1, 2))
            rebuilt = rebuilt.reshape(b, d * h * w, 3, 3)
            rebuilt = rebuilt.permute(0, 2, 3, 1).reshape(b, c, d, h, w)

            assert_that(torch.allclose(rebuilt, volume, atol=1e-6),
                        equal_to(True))

            first_exceeds_second = torch.any(eig_vals[:, 0] > eig_vals[:, 1])
            assert_that(first_exceeds_second, equal_to(descending))

            second_exceeds_third = torch.any(eig_vals[:, 1] > eig_vals[:, 2])
            assert_that(second_exceeds_third, equal_to(descending))
예제 #5
0
    def test_should_decompose_identity_matrix(self):
        """Unit-diagonal input must give unit eigenvalues and identity vectors."""
        identity_volume = torch.zeros(2, 9, 1, 1, 1)
        # Channels 0, 4 and 8 are set to one; presumably these are the diagonal
        # of a row-major flattened 3x3 matrix (matches the test name).
        for diag_channel in (0, 4, 8):
            identity_volume[:, diag_channel, :, :, :] = 1.0

        eig_vals, eig_vecs = vSymEig(identity_volume,
                                     eigenvectors=True,
                                     flatten_output=True)

        expected_vals = torch.ones((2, 3))
        expected_vecs = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)

        vals_ok = torch.allclose(eig_vals, expected_vals, atol=1e-6)
        assert_that(vals_ok, equal_to(True))

        vecs_ok = torch.allclose(eig_vecs, expected_vecs, atol=1e-6)
        assert_that(vecs_ok, equal_to(True))
예제 #6
0
    def test_should_decompose_matrix_with_same_diag(self):
        """Compare ``vSymEig`` with PyTorch's reference symmetric eigensolver.

        Batch entry 0 is a random symmetric matrix whose three diagonal
        channels (0, 4, 8) are forced to 1.0; batch entry 1 is all zeros.
        Eigenvectors are compared by absolute value because an eigenvector is
        only defined up to sign.
        """
        b, c, d, h, w = 2, 9, 1, 1, 1
        matrices = sym(torch.rand(b, c, d, h, w))
        for diag_channel in (0, 4, 8):
            matrices[0, diag_channel, :, :, :] = 1.0
        matrices[1, :, :, :, :] = 0.0

        eig_vals, eig_vecs = vSymEig(matrices,
                                     eigenvectors=True,
                                     flatten_output=True)

        # ``Tensor.symeig`` is deprecated and removed in current PyTorch, so
        # the reference comes from ``torch.linalg.eigh``. UPLO="U" mirrors
        # symeig's default ``upper=True``; both return ascending eigenvalues.
        expected_eig_vals, expected_eig_vecs = torch.linalg.eigh(
            matrices.reshape(b, 3, 3), UPLO="U")

        assert_that(torch.allclose(eig_vals, expected_eig_vals, atol=0.000001),
                    equal_to(True))
        assert_that(
            torch.allclose(torch.abs(eig_vecs), torch.abs(expected_eig_vecs)),
            equal_to(True))
import torch
from torchvectorized.utils import sym
from torchvectorized.vlinalg import vSymEig
import matplotlib.pyplot as plt

# Number of iterations of the optimisation loop below.
OPTIMIZER_STEPS = 5000

# Script entry point: builds a target eigen-decomposition from a random tensor
# and sets up an Adam optimiser over a second, trainable tensor whose
# decomposition is compared against that target with L1 losses.
if __name__ == "__main__":
    # NOTE(review): not used in the visible part of the script; presumably
    # feeds the cos_sim_metrics_* lists further down -- confirm.
    cos_sim_computer = torch.nn.CosineSimilarity()

    # Target eigenvalues / eigenvectors.
    # NOTE(review): unlike `input` below, this random tensor is not passed
    # through sym() before decomposition -- confirm this is intended.
    gt_vals, gt_vecs = vSymEig(torch.rand(16, 9, 8, 8, 8), eigenvectors=True, descending_eigenvals=True)
    # Trainable tensor, initialised from sym() of a random tensor.
    input = torch.nn.Parameter(sym(torch.rand(16, 9, 8, 8, 8)), requires_grad=True)

    optimizer = torch.optim.Adam([input], lr=0.001, betas=[0.5, 0.999])

    # Per-step history buffers; only `steps` is appended to in the visible code.
    steps = []
    eig_vals_loss = []
    eig_vecs_loss = []
    cos_sim_metrics_v1 = []
    cos_sim_metrics_v2 = []
    cos_sim_metrics_v3 = []

    for optim_step in range(OPTIMIZER_STEPS):
        optimizer.zero_grad()
        # Decompose the current estimate with the same ordering as the target.
        eig_vals, eig_vecs = vSymEig(input, eigenvectors=True, descending_eigenvals=True)
        loss_eig_val = torch.nn.functional.l1_loss(eig_vals, gt_vals)
        loss_eig_vecs = torch.nn.functional.l1_loss(eig_vecs, gt_vecs)

        total_loss = loss_eig_vecs + loss_eig_val

        # NOTE(review): in the visible code total_loss.backward() and
        # optimizer.step() are never called, so `input` is never updated;
        # the snippet looks truncated here -- confirm against the full script.
        steps.append(optim_step)
예제 #8
0
def vectorized_func_gpu(b):
    """Return the mean wall-clock time of one CUDA ``vSymEig`` run.

    Each timed run builds a fresh random symmetric ``(b, 9, 10, 10, 10)``
    tensor, moves it to the GPU and decomposes it, so tensor creation and the
    host-to-device copy are included in the measurement.

    Args:
        b: batch size of the generated tensor.

    Returns:
        Total ``timeit`` time for five runs divided by five.
    """
    channels, depth, height, width = 9, 10, 10, 10
    repeats = 5

    def run_once():
        volume = sym(torch.rand(b, channels, depth, height, width)).cuda()
        return vSymEig(volume, eigenvectors=True)

    elapsed = timeit.timeit(run_once, number=repeats)
    return elapsed / repeats
예제 #9
0
    def forward(ctx, X):
        """Return the eigenvalues of ``X`` as computed by ``vSymEig``.

        The eigenvalues, the eigenvectors and the input are stored on ``ctx``
        via ``save_for_backward``.

        Args:
            ctx: context object exposing ``save_for_backward``.
            X: tensor handed to ``vSymEig`` with ``flatten_output=True``.

        Returns:
            The eigenvalue tensor returned by ``vSymEig``.
        """
        eigenvalues, eigenvectors = vSymEig(X,
                                            eigenvectors=True,
                                            flatten_output=True)
        ctx.save_for_backward(eigenvalues, eigenvectors, X)

        return eigenvalues
예제 #10
0
def vectorized_func(b):
    """Time ``vSymEig`` on a random symmetric ``(b, 9, 32, 32, 32)`` tensor.

    Each timed run also includes creating and symmetrising the tensor.

    Args:
        b: batch size of the generated tensor.

    Returns:
        Total ``timeit`` time for five runs divided by ``5 * 1000000``.
    """
    channels, depth, height, width = 9, 32, 32, 32
    repeats = 5
    # NOTE(review): timeit reports seconds, so this extra divisor scales the
    # per-run mean down by 1e6; the GPU variant divides by the run count
    # only -- confirm which unit callers expect.
    scale = 1000000

    def run_once():
        volume = sym(torch.rand(b, channels, depth, height, width))
        return vSymEig(volume, eigenvectors=True)

    elapsed = timeit.timeit(run_once, number=repeats)
    return elapsed / (repeats * scale)