def test_multiple_gpus(self):
    num_gpus = torch.cuda.device_count()
    for gpu_idx in range(num_gpus):
        device = torch.device("cuda:{}".format(gpu_idx))
        torch.manual_seed(0)
        a = torch.randn(N, H, W).to(device)
        b = a.clone()
        a.requires_grad = True
        b.requires_grad = True

        U, S, V = svd(a)
        loss = U.sum() + S.sum() + V.sum()
        loss.backward()

        u, s, v = torch.svd(b[0], some=True, compute_uv=True)
        loss0 = u.sum() + s.sum() + v.sum()
        loss0.backward()

        # eigenvectors are only precise up to sign
        testing.assert_allclose(U[0].abs(), u.abs())
        testing.assert_allclose(S[0].abs(), s.abs())
        testing.assert_allclose(V[0].abs(), v.abs())

        a_ = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
        testing.assert_allclose(a, a_)
def test_double():
    torch.manual_seed(0)
    a = torch.randn(10, 9, 3).cuda().double()
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    U, S, V = svd(a)
    loss = U.sum() + S.sum() + V.sum()
    loss.backward()

    u, s, v = torch.svd(b[0], some=True, compute_uv=True)
    loss0 = u.sum() + s.sum() + v.sum()
    loss0.backward()

    assert U.dtype == torch.double
    assert S.dtype == torch.double
    assert V.dtype == torch.double
    assert a.grad.dtype == torch.double

    # eigenvectors are only precise up to sign
    testing.assert_allclose(U[0].abs(), u.abs())
    testing.assert_allclose(S[0].abs(), s.abs())
    testing.assert_allclose(V[0].abs(), v.abs())

    testing.assert_allclose(
        a, torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1)))
def test_half(self):
    torch.manual_seed(0)
    a = torch.randn(N, H, W).cuda().half()
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    U, S, V = svd(a)
    loss = U.sum() + S.sum() + V.sum()
    loss.backward()

    assert U.dtype == torch.half
    assert S.dtype == torch.half
    assert V.dtype == torch.half
    assert a.grad.dtype == torch.half

    a_ = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
    testing.assert_allclose(a, a_, atol=0.01, rtol=0.01)
def test_half():
    torch.manual_seed(0)
    a = torch.randn(10, 9, 3).cuda().half()
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    U, S, V = svd(a)
    loss = U.sum() + S.sum() + V.sum()
    loss.backward()

    assert U.dtype == torch.half
    assert S.dtype == torch.half
    assert V.dtype == torch.half
    assert a.grad.dtype == torch.half

    testing.assert_allclose(
        a, torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1)))
def test_float():
    torch.manual_seed(0)
    a = torch.randn(N, H, W).cuda()
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    U, S, V = svd(a)
    loss = U.sum() + S.sum() + V.sum()
    loss.backward()

    u, s, v = torch.svd(b[0], some=True, compute_uv=True)
    loss0 = u.sum() + s.sum() + v.sum()
    loss0.backward()

    # eigenvectors are only precise up to sign
    testing.assert_allclose(U[0].abs(), u.abs())
    testing.assert_allclose(S[0].abs(), s.abs())
    testing.assert_allclose(V[0].abs(), v.abs())

    testing.assert_allclose(a, U @ torch.diag_embed(S) @ V.transpose(-2, -1))
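The tests above reference `svd`, `testing`, and the module-level sizes `N`, `H`, `W` from their surrounding test file; a minimal preamble they appear to assume could look like the sketch below (the import path of the batched `svd` and the concrete sizes are assumptions, not taken from this excerpt).

import torch
from torch import testing

# Assumed import path for the batched CUDA SVD exercised by the tests above.
from torch_batch_svd import svd

# Assumed batch/matrix sizes; the tests only require tall matrices (H >= W).
N, H, W = 100, 9, 3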
def batch_least_square(A, B, w):
    # Batched rigid alignment (Kabsch): find R, t such that B ≈ A @ R^T + t
    # for each batch element. A, B: (batch, num_points, 3); `w` is accepted
    # but unused here.
    assert A.shape == B.shape
    num = A.shape[0]

    # Center both point sets.
    centroid_A = torch.mean(A, dim=1)
    centroid_B = torch.mean(B, dim=1)
    AA = A - centroid_A.unsqueeze(1)
    BB = B - centroid_B.unsqueeze(1)

    # Cross-covariance and its batched SVD.
    H = torch.bmm(torch.transpose(AA, 2, 1), BB)
    U, S, Vt = svd(H)
    R = torch.bmm(Vt, U.permute(0, 2, 1))

    # Fix reflections: flip the last column where det(R) < 0.
    i = torch.det(R) < 0
    tmp = torch.ones([num, 3, 3], dtype=torch.float32).cuda()
    tmp[i, :, 2] = -1
    Vt = Vt * tmp
    R = torch.bmm(Vt, U.permute(0, 2, 1))

    t = centroid_B - torch.bmm(R, centroid_A.unsqueeze(2)).squeeze()
    return R, t
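A usage sketch for `batch_least_square`, assuming `A` and `B` are `(batch, num_points, 3)` CUDA float tensors as the centering and `bmm` calls above imply; the rotation construction, tolerances, and passing `None` for the unused `w` are illustrative assumptions.

import torch

# Hypothetical sanity check: recover a known rigid transform B = A @ Q^T + t.
batch, n_pts = 4, 128
A = torch.randn(batch, n_pts, 3).cuda()

# Build proper rotations from the Q factor of random matrices (force det = +1).
Q, _ = torch.linalg.qr(torch.randn(batch, 3, 3).cuda())
Q = Q * torch.det(Q).sign().view(-1, 1, 1)
t_true = torch.randn(batch, 3).cuda()

B = torch.bmm(A, Q.transpose(1, 2)) + t_true.unsqueeze(1)

R, t = batch_least_square(A, B, None)
print(torch.allclose(R, Q, atol=1e-4), torch.allclose(t, t_true, atol=1e-4))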
def bench_speed(N, H, W):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    torch.manual_seed(0)
    a = torch.randn(N, H, W).cuda()
    b = a.clone().cuda()
    torch.cuda.synchronize()

    # Average over 100 runs of the batched SVD.
    start.record()
    for i in range(100):
        U, S, V = svd(a)
    end.record()
    torch.cuda.synchronize()
    t = start.elapsed_time(end) / 100
    print("Perform batched SVD on a {}x{}x{} matrix: {} ms".format(N, H, W, t))

    # Single run of torch.svd on the same data for comparison.
    start.record()
    U, S, V = torch.svd(b, some=True, compute_uv=True)
    end.record()
    torch.cuda.synchronize()
    t = start.elapsed_time(end)
    print("Perform torch.svd on a {}x{}x{} matrix: {} ms".format(N, H, W, t))
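A possible driver for the benchmark; the problem sizes below are illustrative only. The batched kernel targets many small matrices, so a large N with small H, W is the typical case.

if __name__ == "__main__":
    bench_speed(10000, 9, 3)    # many small matrices
    bench_speed(100, 32, 16)    # fewer, slightly larger matrices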
buf['cd'].append(cd_loss.cpu().item())
buf['sym'].append(sym_loss.cpu().item())
buf['nor'].append(nor_loss.cpu().item())
buf['lap'].append(lap_loss.cpu().item())

'''disentanglement loss'''
# Penalize overlap between different deformation bases (off-diagonal terms only).
dot = torch.bmm(basis.abs(), basis.transpose(1, 2).abs())
dot[:, range(opt.num_basis), range(opt.num_basis)] = 0
ortho_loss = dot.norm(p=2, dim=(1, 2)).mean()

# L1 sparsity on both the bases and the coefficients.
sp_loss = basis.view(opt.batch_size, opt.num_basis, 150).norm(p=1, dim=2).mean() \
    + coef.view(opt.batch_size, opt.num_basis).norm(p=1, dim=-1).mean()

# Low-rank penalty: smallest singular value of each basis' 3x3 Gram matrix.
basis = basis.reshape(opt.batch_size * opt.num_basis, 50, 3)
tmp = torch.bmm(basis.transpose(1, 2), basis)
_, s, _ = svd(tmp)
svd_loss = s[:, 2].mean()

# Decorrelate coefficients across the batch via an L1 penalty on their covariance.
coef = coef.view(opt.batch_size, opt.num_basis) - \
    coef.view(opt.batch_size, opt.num_basis).mean(dim=0)
cov = torch.bmm(coef.view(opt.batch_size, opt.num_basis, 1),
                coef.view(opt.batch_size, 1, opt.num_basis))
cov = cov.sum(dim=0) / (opt.batch_size - 1)
cov_loss = cov.norm(p=1, dim=(0, 1))

buf['sp'].append(sp_loss.cpu().item())
buf['svd'].append(svd_loss.cpu().item())
buf['ortho'].append(ortho_loss.cpu().item())
buf['cov'].append(cov_loss.cpu().item())

loss = cd_loss + sym_loss * 1 + lap_loss * 3 + nor_loss * 0.1 + g_loss * 0.006 \
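The svd_loss above is the mean smallest singular value of each basis' 3x3 Gram matrix, which vanishes exactly when that basis' per-point offsets are coplanar. A standalone toy illustration of what the term measures (plain torch.svd on CPU; not part of the training code):

import torch

# Toy illustration of the low-rank penalty: the smallest singular value of
# basis^T @ basis vanishes when the per-point offsets lie in a 2D plane.
torch.manual_seed(0)
planar = torch.randn(50, 3)
planar[:, 2] = 0                      # coplanar offsets -> rank-2 Gram matrix
generic = torch.randn(50, 3)          # generic offsets -> rank-3 Gram matrix

for name, b in [("planar", planar), ("generic", generic)]:
    gram = b.t() @ b                  # 3 x 3 symmetric PSD matrix
    _, s, _ = torch.svd(gram)         # singular values in descending order
    print(name, "smallest singular value:", s[2].item())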