Example #1
def test_memory_tracking_nlp_model():
    """
    Check that we can collect memory traces of a realistic model
    outside of the context of distributed training (DDP or FSDP)
    """

    BATCH_SIZE = 10
    INPUT_DIM = 16
    model = GPT2(embed_dim=256,
                 num_heads=2,
                 num_layers=6,
                 num_positions=INPUT_DIM * INPUT_DIM,
                 num_vocab=512,
                 num_classes=2).cuda()
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)
    input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).cuda()
    output = model(input_tensor)
    output.sum().backward()

    assert len(tracker.memory_traces) > 0, "failed to collect memory traces"
    assert len(tracker.forward_traces) > 0, "failed to collect forward memory traces"
    assert len(tracker.backward_traces) > 0, "failed to collect backward memory traces"
    assert tracker.summary.total_activation_allocations == 12462080
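
The tracking pattern used above (monitor the model, run one forward and backward pass, then read the trace attributes) also works on a toy model. A minimal sketch follows; the import path for LayerwiseMemoryTracker is an assumption and may differ between fairscale versions, and the tiny model simply stands in for GPT2.

import torch
import torch.nn as nn
# Assumed import path; adjust to wherever LayerwiseMemoryTracker lives in your fairscale version.
from fairscale.experimental.tooling.layer_memory_tracker import LayerwiseMemoryTracker

# Toy model in place of GPT2; the tracker only needs an nn.Module on GPU.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2)).cuda()
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)  # attach hooks before running forward/backward

output = model(torch.randn(10, 16).cuda())
output.sum().backward()

# Same trace attributes as in the test above.
print(len(tracker.memory_traces), len(tracker.forward_traces), len(tracker.backward_traces))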
Example #2
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
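    """Run a small GPT2 model with OSS and ShardedDataParallel for a few optimizer steps on one rank."""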
    INPUT_DIM = 32
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
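
The helper above is written to be spawned once per rank. A minimal launcher sketch, assuming a single node with one GPU per process and NCCL as the backend (the world size, backend, and device are illustrative choices, not part of the original test):

import tempfile

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    # Temporary file used as the rendezvous point for init_process_group.
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(
        run_test_gpt2,
        args=(world_size, "nccl", torch.device("cuda"), temp_file_name),
        nprocs=world_size,
        join=True,
    )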
Example #3
def run_test_gpt2(rank, world_size, backend, device, temp_file_name,
                  reduce_buffer_size):
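    """Run GPT2 with OSS and ShardedDataParallel, exercising reduce_buffer_size and post-construction device moves."""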
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(embed_dim=256,
                 num_heads=2,
                 num_layers=12,
                 num_positions=INPUT_DIM * INPUT_DIM,
                 num_vocab=512,
                 num_classes=2)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    reduce_buffer_size=reduce_buffer_size)

    # Move the model to another device post-construction
    model = model.to(device)

    # Optim loop
    set_to_none = True

    def closure():
        nonlocal set_to_none
        ddp_model.zero_grad(set_to_none=set_to_none)
        set_to_none = not set_to_none

        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

        # Stress test the .to() method
        ddp_model.to(device=device, dtype=torch.float16)
        ddp_model.to(device=device, dtype=torch.float32)

    dist.destroy_process_group()
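
This variant additionally takes reduce_buffer_size, so a launcher can exercise both the unbucketed and bucketed reduction paths. A sketch along the same lines as the previous launcher; the buffer sizes are illustrative, and treating 0 as "bucketing disabled" is an assumption about ShardedDataParallel's behavior:

import tempfile

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    for reduce_buffer_size in (0, 2 ** 20):  # assumption: 0 disables the reduce buffer
        temp_file_name = tempfile.mkstemp()[1]
        mp.spawn(
            run_test_gpt2,
            args=(world_size, "nccl", torch.device("cuda"), temp_file_name, reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )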