Example #1
0
def test_memory_consumption_vs_hbplinear():
    """Require HBPParallelLinear to use as much memory as HBPLinear.

    Each layer is built in turn and ``memory_report()`` is taken right after
    construction; the two reports must compare equal.
    """
    # One name is rebound for both layers (as the original did with ``_``), so
    # the HBPLinear instance is only released when the parallel layer replaces it.
    layer = example_linear()
    report_plain = memory_report()

    layer = example_linear_parallel()
    report_parallel = memory_report()

    assert report_plain == report_parallel
Example #2
0
def test_memory_consumption_forward():
    """Compare memory after one forward pass: HBPLinear vs. HBPParallelLinear.

    A random input is fed through each layer and ``memory_report()`` is taken
    while the layer and its output are still alive; both reports must match.
    """
    # Forward pass through the plain HBPLinear layer.
    module = example_linear()
    result = module(random_input())
    report_plain = memory_report()
    # Release layer and output so they do not count toward the next report.
    del module, result

    # Forward pass through the HBPParallelLinear layer.
    module = example_linear_parallel()
    result = module(random_input())
    report_parallel = memory_report()

    assert report_plain == report_parallel
Example #3
0
def test_memory_consumption_before_backward_hessian():
    """Compare memory of un-split and split sequences before any Hessian pass.

    ``example_sequence()`` is built first as the reference. Then
    ``example_sequence_parallel()`` is built with its default argument and
    with ``10``. Each parallel variant's ``memory_report()`` must equal the
    reference.
    """
    # The single name ``_`` is rebound each time, so the previous model is
    # released only once its replacement has been constructed.
    _ = example_sequence()
    print("No splitting")
    baseline = memory_report()

    # ``()`` reproduces the default-argument call; ``(10,)`` the explicit one.
    for extra_args in ((), (10,)):
        _ = example_sequence_parallel(*extra_args)
        print("With splitting")
        assert baseline == memory_report()
Example #4
0
def test_memory_consumption_during_hbp():
    """Check for constant memory consumption during HBP.

    Repeats forward pass, ``loss.backward()`` and
    ``parallel.backward_hessian(...)`` on an HBPParallelLinear layer. The
    ``memory_report()`` after the first run must equal the one after the last
    run, i.e. repeated HBP must not accumulate memory.
    """
    num_runs = 10
    # HBPParallel layer
    parallel = example_linear_parallel()
    # Set after the first and after the last HBP run, respectively.
    mem_stat, mem_stat_after = None, None
    for run in range(num_runs):
        # Named ``x`` (not ``input``) to avoid shadowing the builtin.
        x = random_input()
        out, loss = forward(parallel, x)
        # 2 * identity of size out.numel(); presumably the Hessian of the loss
        # computed in ``forward`` w.r.t. ``out`` -- confirm against ``forward``.
        loss_hessian = 2 * eye(out.numel())
        loss.backward()
        parallel.backward_hessian(loss_hessian)
        if run == 0:
            mem_stat = memory_report()
        # Derived from num_runs so changing the run count cannot leave this unset.
        if run == num_runs - 1:
            mem_stat_after = memory_report()
    assert mem_stat == mem_stat_after
Example #5
0
def test_memory_consumption_during_hbp():
    """Check memory consumption during Hessian backpropagation.

    Repeats forward pass, ``loss.backward()`` and
    ``parallel.backward_hessian(...)`` on the model from
    ``example_sequence_parallel()``. The ``memory_report()`` after the first
    run must equal the one after the last run.
    """
    num_runs = 10
    parallel = example_sequence_parallel()

    # Result is discarded; presumably called to print a pre-HBP baseline
    # -- confirm against memory_report's implementation.
    memory_report()

    # Set after the first and after the last HBP run, respectively.
    mem_stat = None
    mem_stat_after = None

    for run in range(num_runs):
        # Named ``x`` (not ``input``) to avoid shadowing the builtin.
        x = random_input()
        out, loss = forward(parallel, x)
        # 2 * identity of size out.numel(); presumably the Hessian of the loss
        # computed in ``forward`` w.r.t. ``out`` -- confirm against ``forward``.
        loss_hessian = 2 * eye(out.numel())
        loss.backward()
        parallel.backward_hessian(loss_hessian)
        if run == 0:
            mem_stat = memory_report()
        # Derived from num_runs so changing the run count cannot leave this unset.
        if run == num_runs - 1:
            mem_stat_after = memory_report()

    assert mem_stat == mem_stat_after