Example No. 1
    def test_multiple_gpus(self):
        num_gpus = torch.cuda.device_count()

        for gpu_idx in range(num_gpus):
            device = torch.device("cuda:{}".format(gpu_idx))

            torch.manual_seed(0)
            a = torch.randn(N, H, W).to(device)
            b = a.clone()
            a.requires_grad = True
            b.requires_grad = True

            U, S, V = svd(a)
            loss = U.sum() + S.sum() + V.sum()
            loss.backward()

            u, s, v = torch.svd(b[0], some=True, compute_uv=True)
            loss0 = u.sum() + s.sum() + v.sum()
            loss0.backward()

            # eigenvectors are only precise up to sign
            testing.assert_allclose(U[0].abs(), u.abs())
            testing.assert_allclose(S[0].abs(), s.abs())
            testing.assert_allclose(V[0].abs(), v.abs())

            a_ = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
            testing.assert_allclose(a, a_)
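
The test snippets on this page omit their module-level setup. A minimal sketch of what they appear to rely on is shown below; the torch_batch_svd package name and the concrete values of N, H, W are assumptions, not taken from the listing.

# Assumed module-level setup for the snippets on this page (a sketch, not
# copied from the original test file).
import torch
from torch import testing            # torch.testing, used as testing.assert_allclose
from torch_batch_svd import svd      # batched CUDA SVD; package name assumed

# batch size and per-matrix dimensions; the concrete values are illustrative
N, H, W = 100, 9, 3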
Example No. 2
def test_double():
    torch.manual_seed(0)
    a = torch.randn(10, 9, 3).cuda().double()
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    U, S, V = svd(a)
    loss = U.sum() + S.sum() + V.sum()
    loss.backward()

    u, s, v = torch.svd(b[0], some=True, compute_uv=True)
    loss0 = u.sum() + s.sum() + v.sum()
    loss0.backward()

    assert U.dtype == torch.double
    assert S.dtype == torch.double
    assert V.dtype == torch.double
    assert a.grad.dtype == torch.double
    # eigenvectors are only precise up to sign
    testing.assert_allclose(U[0].abs(), u.abs())
    testing.assert_allclose(S[0].abs(), s.abs())
    testing.assert_allclose(V[0].abs(), v.abs())
    testing.assert_allclose(a, U @ torch.diag_embed(S) @ V.transpose(-2, -1))
Example No. 3
    def test_half(self):
        torch.manual_seed(0)
        a = torch.randn(N, H, W).cuda().half()
        b = a.clone()
        a.requires_grad = True
        b.requires_grad = True

        U, S, V = svd(a)
        loss = U.sum() + S.sum() + V.sum()
        loss.backward()

        assert U.dtype == torch.half
        assert S.dtype == torch.half
        assert V.dtype == torch.half
        assert a.grad.dtype == torch.half

        a_ = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
        testing.assert_allclose(a, a_, atol=0.01, rtol=0.01)
Example No. 4
def test_half():
    torch.manual_seed(0)
    a = torch.randn(10, 9, 3).cuda().half()
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    U, S, V = svd(a)
    loss = U.sum() + S.sum() + V.sum()
    loss.backward()

    assert U.dtype == torch.half
    assert S.dtype == torch.half
    assert V.dtype == torch.half
    assert a.grad.dtype == torch.half
    testing.assert_allclose(a, U @ torch.diag_embed(S) @ V.transpose(-2, -1))
Example No. 5
def test_float():
    torch.manual_seed(0)
    a = torch.randn(N, H, W).cuda()
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    U, S, V = svd(a)
    loss = U.sum() + S.sum() + V.sum()
    loss.backward()

    u, s, v = torch.svd(b[0], some=True, compute_uv=True)
    loss0 = u.sum() + s.sum() + v.sum()
    loss0.backward()

    # eigenvectors are only precise up to sign
    testing.assert_allclose(U[0].abs(), u.abs())
    testing.assert_allclose(S[0].abs(), s.abs())
    testing.assert_allclose(V[0].abs(), v.abs())
    testing.assert_allclose(a, U @ torch.diag_embed(S) @ V.transpose(-2, -1))
Example No. 6
def batch_least_square(A, B, w):
    # Batched rigid alignment (Kabsch algorithm): find R, t per batch element
    # such that B ≈ A @ R^T + t. The weight argument w is accepted but unused
    # in this snippet.
    assert A.shape == B.shape
    num = A.shape[0]

    # centre both point clouds
    centroid_A = torch.mean(A, dim=1)
    centroid_B = torch.mean(B, dim=1)
    AA = A - centroid_A.unsqueeze(1)
    BB = B - centroid_B.unsqueeze(1)

    # cross-covariance matrices and their batched SVD
    H = torch.bmm(torch.transpose(AA, 2, 1), BB)
    U, S, Vt = svd(H)

    # assuming svd follows torch.svd's convention (returns V, not V^T),
    # this is the Kabsch rotation R = V @ U^T
    R = torch.bmm(Vt, U.permute(0, 2, 1))

    # where det(R) < 0, flip the sign of V's last column to avoid reflections
    i = torch.det(R) < 0
    tmp = torch.ones([num, 3, 3], dtype=torch.float32).cuda()
    tmp[i, :, 2] = -1
    Vt = Vt * tmp

    R = torch.bmm(Vt, U.permute(0, 2, 1))
    t = centroid_B - torch.bmm(R, centroid_A.unsqueeze(2)).squeeze()
    return R, t
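
A hedged usage sketch for batch_least_square (the shapes, rotation angle and translation below are illustrative assumptions): rotate and translate a batch of point clouds and check that the recovered transform matches.

# Usage sketch: B is A rotated by R_true and shifted by t_true, so the
# function should recover R ~= R_true and t ~= t_true for every batch element.
import math
import torch

c, s = math.cos(0.3), math.sin(0.3)
R_true = torch.tensor([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]], device='cuda')
t_true = torch.tensor([1., 2., 3.], device='cuda')

A = torch.randn(8, 64, 3, device='cuda')    # 8 clouds of 64 points each
B = A @ R_true.T + t_true                   # so B = A @ R^T + t exactly
R, t = batch_least_square(A, B, None)       # w is unused, see comment above

torch.testing.assert_allclose(R, R_true.expand_as(R), atol=1e-4, rtol=1e-4)
torch.testing.assert_allclose(t, t_true.expand_as(t), atol=1e-4, rtol=1e-4)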
Example No. 7
def bench_speed(N, H, W):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    torch.manual_seed(0)
    a = torch.randn(N, H, W).cuda()
    b = a.clone().cuda()
    torch.cuda.synchronize()

    start.record()
    for i in range(100):
        U, S, V = svd(a)
    end.record()
    torch.cuda.synchronize()
    t = start.elapsed_time(end) / 100
    print("Perform batched SVD on a {}x{}x{} matrix: {} ms".format(N, H, W, t))

    start.record()
    U, S, V = torch.svd(b, some=True, compute_uv=True)
    end.record()
    torch.cuda.synchronize()
    t = start.elapsed_time(end)
    print("Perform torch.svd on a {}x{}x{} matrix: {} ms".format(N, H, W, t))
Example No. 8
        buf['cd'].append(cd_loss.cpu().item())
        buf['sym'].append(sym_loss.cpu().item())
        buf['nor'].append(nor_loss.cpu().item())
        buf['lap'].append(lap_loss.cpu().item())
        '''disentanglement loss'''
        # orthogonality: penalise overlap between different deformation bases
        dot = torch.bmm(basis.abs(), basis.transpose(1, 2).abs())
        dot[:, range(opt.num_basis), range(opt.num_basis)] = 0  # ignore self-overlap
        ortho_loss = dot.norm(p=2, dim=(1, 2)).mean()

        # sparsity of the basis vectors and of their per-shape coefficients
        sp_loss = basis.view(opt.batch_size, opt.num_basis, 150).norm(p=1, dim=2).mean() \
            + coef.view(opt.batch_size, opt.num_basis).norm(p=1, dim=-1).mean()

        # penalise the smallest singular value of basis^T basis, pushing each
        # reshaped (50, 3) basis towards rank <= 2
        basis = basis.reshape(opt.batch_size * opt.num_basis, 50, 3)
        tmp = torch.bmm(basis.transpose(1, 2), basis)
        _, s, _ = svd(tmp)
        svd_loss = s[:, 2].mean()

        # sample covariance of the coefficients across the batch, penalised
        # through its entry-wise L1 norm
        coef = coef.view(opt.batch_size, opt.num_basis) - \
            coef.view(opt.batch_size, opt.num_basis).mean(dim=0)
        cov = torch.bmm(coef.view(opt.batch_size, opt.num_basis, 1),
                        coef.view(opt.batch_size, 1, opt.num_basis))
        cov = cov.sum(dim=0) / (opt.batch_size - 1)
        cov_loss = cov.norm(p=1, dim=(0, 1))

        buf['sp'].append(sp_loss.cpu().item())
        buf['svd'].append(svd_loss.cpu().item())
        buf['ortho'].append(ortho_loss.cpu().item())
        buf['cov'].append(cov_loss.cpu().item())

        loss = cd_loss + sym_loss * 1 + lap_loss * 3 + nor_loss * 0.1 + g_loss * 0.006 \