def test_memory_consumption_vs_hbplinear():
    """Compare the memory reports of the plain and the parallel linear layer.

    One layer of each kind is built in turn and ``memory_report()`` is
    queried right after each construction. The two reports must compare
    equal for the test to pass.
    """
    # Reference measurement: HBPLinear stays referenced while reporting.
    layer = example_linear()
    reference_report = memory_report()

    # Candidate measurement: HBPParallelLinear. Rebinding ``layer`` drops
    # the reference layer only after the parallel one has been built.
    layer = example_linear_parallel()
    parallel_report = memory_report()

    assert reference_report == parallel_report
def test_memory_consumption_forward():
    """Compare memory reports after a forward pass through both layer kinds.

    A random input is fed through an HBPLinear layer and then through an
    HBPParallelLinear layer. ``memory_report()`` is queried after each
    pass and both reports must compare equal.
    """
    # Forward pass through the HBPLinear layer.
    module = example_linear()
    output = module(random_input())
    linear_report = memory_report()

    # Drop the first layer and its output before building the second one,
    # so they are not referenced from this frame during the next report.
    del module, output

    # Forward pass through the HBPParallelLinear layer.
    module = example_linear_parallel()
    output = module(random_input())
    parallel_report = memory_report()

    assert linear_report == parallel_report
def test_memory_consumption_before_backward_hessian():
    """Compare memory reports of the plain and the parallel example sequence.

    The report taken with ``example_sequence()`` serves as the baseline.
    It must equal the report taken with ``example_sequence_parallel()``,
    once called without arguments and once called with the argument ``10``.
    """
    model = example_sequence()
    print("No splitting")
    baseline_report = memory_report()

    # Same check for both parallel variants: default arguments, then 10.
    for split_args in ((), (10,)):
        # Rebinding ``model`` releases the previous model only after the
        # new one has been constructed.
        model = example_sequence_parallel(*split_args)
        print("With splitting")
        split_report = memory_report()
        assert baseline_report == split_report
def test_memory_consumption_during_hbp_linear():
    """Check for constant memory consumption during HBP on a parallel layer.

    Repeatedly runs a forward pass, a gradient backward pass and a Hessian
    backward pass through the layer returned by ``example_linear_parallel()``.
    The ``memory_report()`` taken after the first run must equal the one
    taken after the last run.

    This test was originally named ``test_memory_consumption_during_hbp``.
    Another test with that exact name is defined later in this file, which
    rebinds the name at import time so that this test was never collected.
    The distinct name makes pytest run both tests.
    """
    num_runs = 10
    parallel = example_linear_parallel()

    # Filled in after the first and after the last HBP run, respectively.
    mem_stat, mem_stat_after = None, None

    for run in range(num_runs):
        # ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs = random_input()
        out, loss = forward(parallel, inputs)
        # NOTE(review): presumably the Hessian of the loss with respect to
        # the layer output; confirm against the loss used in ``forward``.
        loss_hessian = 2 * eye(out.numel())

        loss.backward()
        parallel.backward_hessian(loss_hessian)

        if run == 0:
            mem_stat = memory_report()
        if run == num_runs - 1:
            mem_stat_after = memory_report()

    assert mem_stat == mem_stat_after
def test_memory_consumption_during_hbp():
    """Check memory consumption during Hessian backpropagation.

    Ten rounds of forward pass, gradient backward pass and Hessian
    backward pass are run through ``example_sequence_parallel()``. The
    ``memory_report()`` taken after the first round must equal the one
    taken after the tenth round.
    """
    model = example_sequence_parallel()
    # Report once right after construction; the result is not compared.
    memory_report()

    first_report = None
    last_report = None

    for step in range(10):
        batch = random_input()
        output, loss = forward(model, batch)
        output_hessian = 2 * eye(output.numel())

        loss.backward()
        model.backward_hessian(output_hessian)

        # Snapshot after the first and after the final round only.
        if step == 0:
            first_report = memory_report()
        elif step == 9:
            last_report = memory_report()

    assert first_report == last_report