def testForward(self):
    """Test Forward Pass.

    Checks the row and column sums of the transport plans returned by
    ``sinkhorn`` and ``OptimalTransportFcn``. The first case uses the default
    (uniform) marginals, where rows should sum to 1/H and columns to 1/W.
    The second case uses random marginals ``r`` and ``c``, where the sums
    should match ``normalize(r, p=1.0)`` and ``normalize(c, p=1.0)``.
    """
    B, H, W = 2, 500, 700
    M = torch.randn((B, H, W), dtype=torch.float)
    r = torch.rand((B, H), dtype=M.dtype)
    c = torch.rand((B, W), dtype=M.dtype)

    def assert_marginals(P, row_sums, col_sums):
        # Summing over dim 2 gives the (B, H) row sums; over dim 1 the (B, W) column sums.
        self.assertTrue(torch.allclose(torch.sum(P, 2), row_sums))
        self.assertTrue(torch.allclose(torch.sum(P, 1), col_sums))

    # Default marginals: every row sums to 1/H and every column to 1/W.
    uniform_rows = torch.full((B, H), 1.0 / H)
    uniform_cols = torch.full((B, W), 1.0 / W)
    assert_marginals(sinkhorn(M, eps=1.0e-9), uniform_rows, uniform_cols)
    # autograd Functions are static: call `apply` on the class, do not instantiate it
    # (instantiation is deprecated in recent PyTorch releases).
    assert_marginals(OptimalTransportFcn.apply(M, None, None, 1.0e-9), uniform_rows, uniform_cols)

    # User-supplied marginals.
    r_hat = normalize(r, p=1.0)
    c_hat = normalize(c, p=1.0)
    assert_marginals(sinkhorn(M, r_hat, c_hat, eps=1.0e-9), r_hat, c_hat)
    # The raw r and c are passed here, yet the plan is expected to have the
    # normalized marginals.
    assert_marginals(OptimalTransportFcn.apply(M, r, c, 1.0e-9), r_hat, c_hat)
def toy_example():
    """Test sinkhorn, approximate gradient and implicit gradient on a toy problem.

    A random 2x50x50 matrix ``M_true`` generates target transport plans.
    ``learnM`` is run on three settings:

    * calibrated (uniform): the target comes from ``sinkhorn(M_true)`` and is
      learned with ``r = c = None``;
    * calibrated (non-uniform): the target comes from
      ``sinkhorn(M_true, r_true, c_true)`` and is learned with the same
      ``r_true`` and ``c_true``;
    * mis-calibrated: the non-uniform target is learned with ``r = c = None``.

    ``learnMRC`` then learns M, r and c jointly on the mis-calibrated target.
    Each setting's loss curves are plotted in their own figure.
    """
    torch.manual_seed(0)
    M_true = torch.randn((2, 50, 50), dtype=torch.float)
    r_true = normalize(torch.rand((1, 50), dtype=M_true.dtype), p=1.0)
    c_true = normalize(torch.rand((1, 50), dtype=M_true.dtype), p=1.0)
    # Same constructor arguments as plot_running_time / speed_memory_test use for the
    # approximate-gradient layer (the old `approx_grad=True` keyword is not used there).
    fcns = [sinkhorn, OptimalTransportLayer(method='approx'), OptimalTransportLayer()]
    labels = ['autograd', 'approx', 'implicit']

    # calibrated (uniform)
    print("Learning calibrated (uniform) models...")
    P_true = sinkhorn(M_true)
    M_init = torch.log(torch.rand_like(M_true))
    _, h_good, _ = learnM(fcns, M_init, None, None, P_true)

    # calibrated (non-uniform)
    print("Learning calibrated (non-uniform) models...")
    P_true = sinkhorn(M_true, r_true, c_true)
    M_init = torch.log(torch.rand_like(M_true))
    _, h_good2, _ = learnM(fcns, M_init, r_true, c_true, P_true)

    # mis-calibrated: same target and the same M_init as the non-uniform run,
    # but learned without r and c.
    print("Learning mis-calibrated models...")
    P_true = sinkhorn(M_true, r_true, c_true)
    _, h_bad, _ = learnM(fcns, M_init, None, None, P_true)

    # learning M, r and c jointly with the implicit layer
    r_init = normalize(torch.rand_like(r_true), p=1.0)
    c_init = normalize(torch.rand_like(c_true), p=1.0)
    h_mrc = learnMRC([OptimalTransportLayer()], M_init, r_init, c_init, P_true)
    print("...done")

    # plot learning curves, one figure per setting
    num_models = len(fcns)
    panels = [
        ('Calibrated Model (Uniform)', [h_good[k] for k in range(num_models)], labels),
        ('Calibrated Model (Non-uniform)', [h_good2[k] for k in range(num_models)], labels),
        ('Mis-calibrated Model',
         [h_bad[k] for k in range(num_models)] + [h_mrc[0]],
         labels + ['implicit w/ r and c']),
    ]
    for title, histories, legend in panels:
        plt.figure()
        for history in histories:
            plt.semilogy(history)
        plt.title(title)
        plt.xlabel('iteration')
        plt.ylabel('loss (log scale)')
        plt.legend(legend)
def speed_memory_test(device=None, batch_size=1, repeats=100):
    """Run speed and memory tests.

    NOTE(review): ``speed_memory_test`` is defined a second time later in this
    module. That later definition rebinds the name at import time, so this
    version is never the one callers get. Its ``OptimalTransportLayer``
    arguments are kept in sync with the later version; consider deleting
    this copy.

    For each square problem size, time ``learnM`` over ``repeats`` iterations
    and profile memory for a single iteration. Four models are compared:
    autograd ``sinkhorn``, approximate gradient, implicit with full inverse,
    and implicit with block inverse. Both measurements are then plotted
    against problem size.
    """
    torch.manual_seed(0)
    if device is None:
        device = torch.device('cpu')
    n = [10, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    t = [[], [], [], []]
    m = [[], [], [], []]
    # method='approx' / method='full' replace the old approx_grad=True / block_inverse=False
    # keywords, matching the later definition of this function.
    fcns = [sinkhorn,
            OptimalTransportLayer(method='approx'),
            OptimalTransportLayer(method='full'),
            OptimalTransportLayer()]
    for ni in n:
        print("Profiling on {}-by-{} problem...".format(ni, ni))
        M_true = torch.randn((batch_size, ni, ni), dtype=torch.float)
        P_true = sinkhorn(M_true).to(device)
        # torch.rand samples from [0, 1), so it can return exactly 0, whose log is -inf.
        # The tiny offset keeps M_init finite.
        M_init = torch.log(torch.rand_like(M_true) + 1.0e-16).to(device)

        # profile speed
        _, _, ti = learnM(fcns, M_init, None, None, P_true, repeats)
        for i in range(len(fcns)):
            t[i].append(ti[i])

        # profile memory
        # NOTE(review): cpu_memory_usage is recorded even when `device` is a GPU -- confirm intent.
        for i, f in enumerate(fcns):
            with profiler.profile(profile_memory=True) as prof:
                _ = learnM([f], M_init, None, None, P_true, 1)
            m[i].append(prof.total_average().cpu_memory_usage)
    print("...done")

    legend = ['autograd', 'approx', 'implicit (full inv)', 'implicit (blk inv)']
    plt.figure()
    plt.plot(n, t[0], n, t[1], n, t[2], n, t[3])
    plt.xlabel('problem size')
    plt.ylabel('running time')
    plt.legend(legend)
    plt.title('Running time on {} with batch size {}'.format(device, batch_size))

    plt.figure()
    plt.plot(n, m[0], n, m[1], n, m[2], n, m[3])
    plt.xlabel('problem size')
    plt.ylabel('memory usage')
    plt.legend(legend)
    plt.title('Memory usage on {} with batch size {}'.format(device, batch_size))
def plot_memory():
    """Plot memory usage of autograd versus implicit differentiation through Sinkhorn.

    Draws two figures:

    1. memory (MB) against the number of Sinkhorn iterations (1 to 10) for a
       single 500-by-500 problem;
    2. memory (MB) against problem size for 10 Sinkhorn iterations.

    Each point is the profiler's ``total_average().cpu_memory_usage``,
    recorded over one forward pass plus the backward pass of
    ``||P - I||`` (matrix norm).
    """
    bytes_per_mb = 1024 * 1024

    def measure(M_source, maxiters, implicit):
        # Fresh leaf copy of the input so that each measurement builds its own graph.
        size = M_source.shape[1]
        M = M_source.clone()
        M.requires_grad = True
        # The implicit layer is built before profiling starts, so its construction is not measured.
        layer = OptimalTransportLayer(eps=0.0, maxiters=maxiters) if implicit else None
        with profiler.profile(profile_memory=True) as prof:
            if layer is None:
                P = sinkhorn(M, eps=0.0, maxiters=maxiters)
            else:
                P = layer(M)
            torch.linalg.norm(P - torch.eye(size)).backward()
        return prof.total_average().cpu_memory_usage / bytes_per_mb

    M_fixed = torch.randn((1, 500, 500), dtype=torch.float)
    maxiters_range = list(range(1, 11))
    probsize_range = [5, 10, 25, 50, 100, 200, 500, 800, 1000]
    # Index 0 holds autograd results, index 1 the implicit-layer results.
    memory_by_maxiters = [[], []]
    memory_by_probsize = [[], []]

    for maxiters in maxiters_range:
        for slot, implicit in enumerate((False, True)):
            memory_by_maxiters[slot].append(measure(M_fixed, maxiters, implicit))

    for size in probsize_range:
        M_sized = torch.randn((1, size, size), dtype=torch.float)
        for slot, implicit in enumerate((False, True)):
            memory_by_probsize[slot].append(measure(M_sized, 10, implicit))

    # Only the first figure carries a legend.
    figures = (
        (maxiters_range, memory_by_maxiters, 'iterations', True),
        (probsize_range, memory_by_probsize, 'problem size', False),
    )
    for xs, series, xlabel, with_legend in figures:
        plt.figure(figsize=(7, 7))
        plt.plot(xs, series[0], linestyle='-')
        plt.plot(xs, series[1], linestyle='-.')
        plt.xlabel(xlabel, fontsize=30)
        plt.ylabel('memory usage (MB)', fontsize=30)
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20, rotation=90)
        if with_legend:
            plt.legend(['autograd', 'implicit'], fontsize=30)
        plt.tight_layout()
def plot_running_time(batch_size, device, enable_legend=False):
    """Plot running time for given device.

    For each square problem size, time 500 iterations of ``learnM`` for four
    models, then plot the timings against problem size. The models are
    unrolled ``sinkhorn`` (autograd) and ``OptimalTransportLayer`` with
    method='approx', method='full' and its default method.

    Args:
        batch_size: leading (batch) dimension of every generated problem.
        device: device the target plan and initial matrix are moved to.
        enable_legend: draw the legend when True.
    """
    def time_learning(fcn, M_init, P_true):
        # One timed run of learnM (500 iterations) with a single model.
        return timeit(wrapper(learnM, [fcn], M_init, None, None, P_true, iters=500), number=1)

    torch.manual_seed(22)
    print("Running on {} with batch size of {}...".format(device, batch_size))
    sizes = [5, 10, 25, 50, 100, 200, 300, 500]
    t_autograd, t_approx, t_full, t_block = [], [], [], []
    for size in sizes:
        print("Timing on {}-by-{} problem...".format(size, size))
        M_true = torch.randn((batch_size, size, size), dtype=torch.float)
        P_true = sinkhorn(M_true).to(device)
        M_init = torch.log(torch.rand_like(M_true)).to(device)
        # Models are timed in this order: autograd, full inverse, approx, block inverse.
        t_autograd.append(time_learning(sinkhorn, M_init, P_true))
        t_full.append(time_learning(OptimalTransportLayer(method='full'), M_init, P_true))
        t_approx.append(time_learning(OptimalTransportLayer(method='approx'), M_init, P_true))
        t_block.append(time_learning(OptimalTransportLayer(), M_init, P_true))
    print("...done")

    plt.figure(figsize=(7, 7))
    for series, marker in ((t_autograd, 'x'), (t_approx, '*'), (t_full, 'o'), (t_block, '<')):
        plt.plot(sizes, series, marker=marker, markersize=14)
    plt.xlabel('problem size', fontsize=30)
    plt.ylabel('running time (s)', fontsize=30)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20, rotation=90)
    plt.tight_layout()
    if enable_legend:
        plt.legend(['autograd', 'approx', 'implicit (full inv)', 'implicit (blk inv)'], fontsize=30)
def speed_memory_test(device=None, batch_size=1, repeats=100):
    """Run speed and memory tests.

    For each square problem size, four models are measured: unrolled
    ``sinkhorn`` (autograd), and ``OptimalTransportLayer`` with
    method='approx', method='full' and its default method. For each model,
    ``learnM`` is timed over ``repeats`` iterations and memory is measured
    for a single iteration. On CPU, memory is the profiler's
    ``cpu_memory_usage``; otherwise it is the CUDA peak allocated on
    ``device``. A model that raises is recorded as NaN and the failure is
    printed. Results are printed as a table and plotted against problem size.

    Args:
        device: ``torch.device`` or device string (e.g. ``'cpu'``, ``'cuda:1'``);
            defaults to CPU.
        batch_size: leading (batch) dimension of every generated problem.
        repeats: number of ``learnM`` iterations used for timing.
    """
    torch.manual_seed(0)
    # Normalise so that string devices work and the CPU check does not depend on the index.
    device = torch.device('cpu') if device is None else torch.device(device)
    on_cpu = device.type == 'cpu'

    n = [10, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    t = [[], [], [], []]
    m = [[], [], [], []]
    fcns = [
        sinkhorn,
        OptimalTransportLayer(method='approx'),
        OptimalTransportLayer(method='full'),
        OptimalTransportLayer()
    ]
    names = ['autograd', 'approx', 'implicit (full)', 'implicit (blk)']

    def measure_memory(fcn, M_init, P_true):
        # Memory for a single learnM iteration, in bytes.
        if on_cpu:
            with profiler.profile(profile_memory=True) as prof:
                _ = learnM([fcn], M_init, None, None, P_true, 1)
            return prof.total_average().cpu_memory_usage
        # Pass `device` explicitly: with None these calls use the *current* CUDA device,
        # which differs from `device` for e.g. 'cuda:1'.
        torch.cuda.reset_peak_memory_stats(device)
        _ = learnM([fcn], M_init, None, None, P_true, 1)
        return torch.cuda.max_memory_allocated(device)

    for ni in n:
        print("Profiling on {}-by-{} problem...".format(ni, ni))
        M_true = torch.randn((batch_size, ni, ni), dtype=torch.float)
        P_true = sinkhorn(M_true).detach().to(device)
        # torch.rand can return exactly 0; the offset keeps the log finite.
        M_init = torch.log(torch.rand_like(M_true) + 1.0e-16).to(device)

        # profile speed
        for i, fcn in enumerate(fcns):
            try:
                _, _, ti = learnM((fcn, ), M_init, None, None, P_true, repeats)
                elapsed = ti[0]
            except Exception as err:  # best effort (e.g. out of memory); Ctrl-C still propagates
                print("  timing failed for {}: {}".format(names[i], err))
                elapsed = float('nan')
            t[i].append(elapsed)
            torch.cuda.empty_cache()

        # profile memory
        for i, fcn in enumerate(fcns):
            try:
                usage = measure_memory(fcn, M_init, P_true)
            except Exception as err:  # best effort (e.g. out of memory); Ctrl-C still propagates
                print("  memory profiling failed for {}: {}".format(names[i], err))
                usage = float('nan')
            m[i].append(usage)
            torch.cuda.empty_cache()
    print("...done")

    _mb = 1.0 / (1024.0 * 1024.0)
    print("-" * 80)
    print("Profiling results on {}".format(device))
    print("-" * 80)
    print("{:<4} {:<18} {:<18} {:<18} {:<18}".format("", *names))
    for row, size in enumerate(n):
        values = [size]
        for k in range(len(fcns)):
            values += [t[k][row], m[k][row] * _mb]
        print("{:<4} {:6.1f}s {:6.1f}MB {:6.1f}s {:6.1f}MB {:6.1f}s {:6.1f}MB {:6.1f}s {:6.1f}MB"
              .format(*values))

    legend = ['autograd', 'approx', 'implicit (full inv)', 'implicit (blk inv)']
    plt.figure()
    plt.plot(n, t[0], n, t[1], n, t[2], n, t[3])
    plt.xlabel('problem size')
    plt.ylabel('running time')
    plt.legend(legend)
    plt.title('Running time on {} with batch size {}'.format(device, batch_size))

    plt.figure()
    plt.plot(n, m[0], n, m[1], n, m[2], n, m[3])
    plt.xlabel('problem size')
    plt.ylabel('memory usage')
    plt.legend(legend)
    plt.title('Memory usage on {} with batch size {}'.format(device, batch_size))