import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def test_train_eval_change():
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1],
                            backend="gloo",
                            rank=0,
                            world_size=1)

    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # the backward pass triggers the gradient reduction hooks

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    grad = next(model.parameters()).grad
    assert grad is None or torch.norm(grad) < 1e-6

    # Get back to training
    model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()
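

# A minimal sketch of the `_get_mlp` helper referenced above; the exact layer
# sizes are an assumption, any small module taking 2 input features would do.
def _get_mlp():
    return Sequential(Linear(2, 3), Linear(3, 3))
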
def run_test_training_change(rank, world_size, backend, device, temp_file_name,
                             reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    # init_process_group() returns None, so grab the default (world) group explicitly
    group = dist.group.WORLD
    if device == "cuda":
        torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    process_group=group,
                                    reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # would assert if the wrapped module were not set up properly
    outputs.norm().backward()

    # Switch to eval: repeated forward passes without a backward in between
    # must not trip the "unreduced gradients" assert
    ddp_model.eval()
    ddp_model(inputs)
    ddp_model(inputs)

    dist.destroy_process_group()
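

# A hedged sketch of how `run_test_training_change` could be launched; the
# world size, backend/device pair and buffer size are illustrative assumptions
# (reduce_buffer_size=0 disables the reduce buckets entirely).
def launch_test_training_change():
    import torch.multiprocessing as mp

    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(run_test_training_change,
             args=(world_size, "gloo", "cpu", temp_file_name, 0),
             nprocs=world_size,
             join=True)
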
def run_eval_mode(_unused):
    """Check that eval mode runs without tripping any asserts."""
    dist.init_process_group(init_method=f"file://{tempfile.mkstemp()[1]}",
                            backend=dist.Backend.GLOO,
                            rank=0,
                            world_size=1)
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer_params = {"lr": 0.1, "momentum": 0.99}
    # Note: this snippet relies on an older ShardedDataParallel constructor,
    # which takes the optimizer class, its kwargs and the world size directly
    # instead of a pre-built OSS optimizer as in the examples above.
    ddp = ShardedDataParallel(model,
                              torch.optim.SGD,
                              optimizer_params,
                              1,  # world size
                              broadcast_buffers=False)
    optimizer = ddp.optimizer

    ddp.eval()
    for _ in range(5):
        input_tensor = torch.rand((64, 2))
        output = ddp(input_tensor)

    ddp.train()
    try:
        for _ in range(5):
            input_tensor = torch.rand((64, 2))
            _ = ddp(input_tensor)
    except RuntimeError:
        pass
    else:
        assert False, "Consecutive forward passes in training mode, without a backward in between, should raise"

    dist.destroy_process_group()
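

# A minimal sketch of an entry point for the eval-mode check above; mp.spawn
# injects the rank as the first argument, which lands in `_unused`.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    mp.spawn(run_eval_mode, args=(), nprocs=1, join=True)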