Example #1
def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)

    ddp = ShardedDataParallel(
        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
    )
    optimizer = ddp.optimizer

    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    if device == torch.device("cuda"):
        torch.cuda.synchronize()  # flush any remaining cuda op, just in case

    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.requires_grad:
                assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

    # Check that the optimization process makes sense (i.e. the loss goes down for the same data)
    optimizer.step()
    new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
    assert new_eval < output, "The loss should decrease after an optimizer step on the same data"
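All of these helpers take the process rank as their first argument and are meant to run once per process. A minimal launcher sketch using torch.multiprocessing.spawn is shown below; the world size, backend and device are placeholder assumptions, and run_one_step refers to the helper above.

import tempfile

import torch
import torch.multiprocessing as mp


def main():
    # Hypothetical launch parameters, adjust to the available hardware
    world_size = 2
    backend = "nccl"
    device = torch.device("cuda")
    temp_file_name = tempfile.mkstemp()[1]  # rendezvous file for init_process_group

    # spawn() prepends the process rank to the arguments passed to run_one_step
    mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()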
Example #2
    def init_distributed_data_parallel_model(self):
        """
        Initialize ShardedDataParallel,
        needed for sharded distributed training.
        This is where a model should be wrapped by DDP.
        """

        # Init the base class, everything but the distributed model wrap is to be reused
        super().init_distributed_data_parallel_model()

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS)

        # Replace the original DDP wrapper with the shard-aware ShardedDDP.
        # Use fairscale's default reduce_buffer_size, unless the user sets a
        # different value in the config.
        reduce_buffer_size = 2**23
        if self.config.MODEL.SHARDED_DDP_SETUP.reduce_buffer_size >= 0:
            reduce_buffer_size = self.config.MODEL.SHARDED_DDP_SETUP.reduce_buffer_size
        logging.info(f"Setting reduce_buffer_size: {reduce_buffer_size}")
        if isinstance(self.optimizer, ZeRO):
            logging.info("Using ShardedDDP")
            self.distributed_model = ShardedDataParallel(
                module=self.base_model,
                sharded_optimizer=self.optimizer.optimizer,
                broadcast_buffers=broadcast_buffers,
                reduce_buffer_size=reduce_buffer_size,
            )
        else:
            raise NotImplementedError(
                "This DataParallel engine should only be used in conjunction with ZeRO"
            )
Example #3
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
Example #4
    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/
        docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS
        )

        if self.use_sharded_ddp:
            if not isinstance(self.optimizer, ZeRO):
                raise ValueError(
                    "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer"
                )
            from fairscale.nn.data_parallel import ShardedDataParallel

            # Replace the original DDP wrap by the shard-aware ShardedDDP
            self.distributed_model = ShardedDataParallel(
                module=self.base_model,
                sharded_optimizer=self.optimizer.optimizer,
                broadcast_buffers=broadcast_buffers,
            )
        else:
            self.distributed_model = init_distributed_data_parallel_model(
                self.base_model,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
            if self.fp16_grad_compress:

                from torch.distributed.algorithms import ddp_comm_hooks

                # FP16 hook is stateless and only takes a process group as the state.
                # We use the default process group so we set the state to None.
                process_group = None
                self.distributed_model.register_comm_hook(
                    process_group,
                    ddp_comm_hooks.default_hooks.fp16_compress_hook,
                )
        if (
            isinstance(self.base_loss, ClassyLoss)
            and self.base_loss.has_learned_parameters()
        ):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
Example #5
def test_train_eval_change():
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1],
                            backend="gloo",
                            rank=0,
                            world_size=1)

    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # make sure that the gradients are reduced

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    assert next(model.parameters()).grad is None or torch.norm(
        next(model.parameters()).grad) < 1e-6

    # Get back to training
    model = model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()
Example #6
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name,
                        reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
Example #7
    def init_components(
        self,
        model_fn=None,
        criterion_fn=None,
        optimizer_fn=None,
        scheduler_fn=None,
    ):
        """Inits the runs components."""
        model = model_fn()
        model = self.sync_device(model)
        if self._sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

        criterion = criterion_fn()
        criterion = self.sync_device(criterion)

        optimizer = optimizer_fn(model)
        optimizer = self.sync_device(optimizer)

        optimizer = OSS(model.parameters(),
                        optim=optimizer.__class__,
                        **optimizer.defaults)
        model = ShardedDataParallel(model, optimizer, **self.ddp_kwargs)

        scheduler = scheduler_fn(optimizer)
        scheduler = self.sync_device(scheduler)
        return model, criterion, optimizer, scheduler
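The components returned above are regular PyTorch objects, so a training step can use them directly. Below is a minimal sketch; the batch unpacking and the train_step name are assumptions, not part of the example.

def train_step(model, criterion, optimizer, batch):
    # Hypothetical single step with the sharded model/optimizer returned by init_components
    inputs, targets = batch
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()  # gradients are reduced to the rank that owns each parameter shard
    optimizer.step()
    return loss.item()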
Example #8
def run_eval_mode(_unused):
    """ Testing eval mode make sure this is no asserts. """
    dist.init_process_group(init_method=f"file://{tempfile.mkstemp()[1]}",
                            backend=dist.Backend.GLOO,
                            rank=0,
                            world_size=1)
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer_params = {"lr": 0.1, "momentum": 0.99}
    ddp = ShardedDataParallel(model,
                              torch.optim.SGD,
                              optimizer_params,
                              1,
                              broadcast_buffers=False)
    optimizer = ddp.optimizer

    ddp.eval()
    for _ in range(5):
        input_tensor = torch.rand((64, 2))
        output = ddp(input_tensor)

    ddp.train()
    try:
        for _ in range(5):
            input_tensor = torch.rand((64, 2))
            output = ddp(input_tensor)
    except RuntimeError:
        pass
    else:
        assert False, "Multiple forward passes on training mode should not pass"

    dist.destroy_process_group()
Example #9
def run_test_device_change(rank, world_size, backend, device, temp_file_name,
                           reduce_buffer_size):
    # Check that the wrapped module can change devices
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # kept on CPU on purpose; the device change is tested after the fact
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    sync_models_at_startup=False,
                                    reduce_buffer_size=reduce_buffer_size)
    try:
        ddp_model.to(device)
        assert False, "Changing devices should be caught and not supported"
    except AssertionError:
        pass

    dist.destroy_process_group()
Example #10
def test_catch_grad_grad():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP rejects a parameter whose gradient itself requires grad
        dist.init_process_group(init_method="file://" + temp_files[0],
                                backend="gloo",
                                rank=0,
                                world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.train()
        chained_grad = torch.zeros_like(next(model.parameters()))
        chained_grad.requires_grad = True
        next(model.parameters()).grad = chained_grad

        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=1e-3,
                        momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        inputs = torch.rand(100, 2)
        with pytest.raises(RuntimeError):
            _ = ddp_model(inputs)

        dist.destroy_process_group()
Example #11
    def parallelize_model(self) -> None:
        registry.register("data_parallel", False)
        registry.register("distributed", False)
        if ("cuda" in str(self.device) and torch.cuda.device_count() > 1
                and not self.distributed):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            set_torch_ddp = True
            try:
                from fairscale.nn.data_parallel import ShardedDataParallel
                from fairscale.optim.oss import OSS

                if isinstance(self.optimizer, OSS):
                    self.model = ShardedDataParallel(self.model,
                                                     self.optimizer)
                    set_torch_ddp = False
                    logger.info("Using FairScale ShardedDataParallel")
            except ImportError:
                logger.info("Using PyTorch DistributedDataParallel")
                warnings.warn(
                    "You can enable ZeRO and Sharded DDP, by installing fairscale "
                    + "and setting optimizer.enable_state_sharding=True.")

            if set_torch_ddp:
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model,
                    device_ids=[self.local_rank],
                    output_device=self.local_rank,
                    find_unused_parameters=self.config.training.
                    find_unused_parameters,
                )
Example #12
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 32
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
Example #13
def run_test_training_change(rank, world_size, backend, device, temp_file_name,
                             reduce_buffer_size):
    group = dist.init_process_group(init_method="file://" + temp_file_name,
                                    backend=backend,
                                    rank=rank,
                                    world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    process_group=group,
                                    reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(
        inputs)  # will assert if the module state has not been handled properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(
        inputs
    )  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()
Example #14
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)  # Any model works. Add one different buffer per rank

    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2

    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )

    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    check_same_model_params()  # The models should stay the same in between the ranks

    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        check_same_model_params()

    dist.destroy_process_group()
Example #15
def run_test_gpt2(rank, world_size, backend, device, temp_file_name,
                  reduce_buffer_size):
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(embed_dim=256,
                 num_heads=2,
                 num_layers=12,
                 num_positions=INPUT_DIM * INPUT_DIM,
                 num_vocab=512,
                 num_classes=2)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    reduce_buffer_size=reduce_buffer_size)

    # Move the model to another device post-construction
    model = model.to(device)

    # Optim loop
    set_to_none = True

    def closure():
        nonlocal set_to_none
        ddp_model.zero_grad(set_to_none=set_to_none)
        set_to_none = not set_to_none

        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

        # Stress test the .to() method
        ddp_model.to(device=device, dtype=torch.float16)
        ddp_model.to(device=device, dtype=torch.float32)

    dist.destroy_process_group()
Example #16
def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3)).to(device)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)

    def weights_init(m):
        if isinstance(m, Linear):
            torch.nn.init.constant_(m.weight.data, 1.0)
            torch.nn.init.constant_(m.bias.data, 1.0)

    model.apply(weights_init)
    model.to(device)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={
            "lr": 0.01,
            "momentum": 0.99
        },
        world_size=world_size,
        broadcast_buffers=True,
    )
    optimizer = ddp.optimizer
    model = ddp.module

    # Different input per rank, allows for checking that the gradients have been properly reduced
    input_tensor = (torch.ones((64, 2)) * rank).to(device)
    output = ddp(input_tensor).abs().sum()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.shape == torch.Size([3, 2]):
                assert param.grad[0, 0].cpu() == torch.tensor([32.0])
            if param.shape == torch.Size([3]):
                assert param.grad[0].cpu() == torch.tensor([64.0])

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    for b in model.buffers():
        assert b.cpu().item() == 0.0

    dist.destroy_process_group()
Example #17
def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
    # Only work with the even ranks, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # Model that does not fit in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, Momentum is required to get a state to shard
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)
Example #18
def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
    # - is multi_device_module
    # - device_type
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")
    dist.destroy_process_group()
Example #19
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()
Example #20
def train(rank, args, model, device, train_loader, num_epochs):
    ##############
    # SETUP
    dist_init(rank, WORLD_SIZE, BACKEND)
    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.Adadelta,
        optimizer_params={"lr": 1e-4},
        world_size=WORLD_SIZE,
        broadcast_buffers=True,
    )

    ddp.train()
    optimizer = ddp.optimizer
    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()

    loss_fn = nn.CrossEntropyLoss()
    ##############

    model.train()
    measurements = []
    for epoch in range(num_epochs):
        epoch_start = time.monotonic()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            def closure():
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= WORLD_SIZE
                loss.backward()

                # if dist.get_rank() == 0:
                #     print(f"Loss: {loss.item()}")

                ddp.reduce()  # Send the gradients to the appropriate shards
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    print("Total Time:", training_stop - training_start)
Example #21
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3),
                                  Linear(3, 3), Linear(3, 3), Linear(3, 3))

        def forward(self, x, y):
            x1 = self.mlp(x)
            x2 = self.mlp(y)
            return torch.cat((x1, x2), dim=1)

    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10],
                      optim=torch.optim.SGD,
                      lr=0.01,
                      momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:],
                      optim=torch.optim.SGD,
                      lr=0.01,
                      momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
Example #22
def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()
Example #23
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
    # Check that the wrapped module can change devices

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)
    ddp_model.to(device)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # will assert if the module has not been moved properly
    loss = outputs.norm()
    loss.backward()

    dist.destroy_process_group()
Example #24
def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={
            "lr": 0.01,
            "momentum": 0.99
        },
        world_size=world_size,
        broadcast_buffers=True,
    )
    optimizer = ddp.optimizer
    model = ddp.module

    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    if device == torch.device("cuda"):
        torch.cuda.synchronize()  # flush any remaining cuda op, just in case

    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.requires_grad:
                assert param.grad.abs().sum().item(
                ) > 0.0, "The reduce step should have populated all the gradients"

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    for b in model.buffers():
        assert b.cpu().item() == 0.0
Example #25
def test_mixed_types():
    # Check that ShardedDDP handles a model holding parameters of mixed types
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1],
                            backend="gloo",
                            rank=0,
                            world_size=1)

    model = _get_mlp(tripwire=True)

    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    _ = model(input_tensor)

    dist.destroy_process_group()
Example #26
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    assert not use_sdp or (use_sdp
                           and use_oss), "ShardedDataParallel requires OSS"
    # DDP
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={
                "lr": 1e-4,
                "momentum": 0.9
            },
            world_size=world_size,
            broadcast_buffers=False,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9))

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce(
                    )  # Send the gradients to the appropriate shards

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}"
            )

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall"
    )
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and average img per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean +
                3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) -
                   reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")
Example #27
def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    next(model.parameters()
         ).requires_grad = False  # Test non-trainable parameters

    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    broadcast_buffers=broadcast_buffers,
                                    reduce_buffer_size=reduce_buffer_size)

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_models_across_ranks(ddp_model,
                                   dist.group.WORLD,
                                   params_should_be_equal=True,
                                   check_broadcast_buffers=broadcast_buffers)

    # Optim loop
    def closure():
        optimizer.zero_grad()

        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)
        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or grad_accumulation
        check_same_models_across_ranks(
            ddp_model,
            dist.group.WORLD,
            params_should_be_equal=same_params,
            check_broadcast_buffers=broadcast_buffers)

    dist.destroy_process_group()
Example #28
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name,
                             reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)  # Any model works. Add one different buffer per rank

    BATCHS = 20

    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
    optim_settings = {"lr": 1e-3, "momentum": 0.99}

    sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params],
                            optim=torch.optim.SGD,
                            **optim_settings)
    sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:],
                              optim=torch.optim.SGD,
                              **optim_settings)

    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(
        list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
    ddp_optimizer_2 = torch.optim.SGD(
        list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    for i in range(BATCHS):
        input_tensor = _get_random_inputs(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()
        torch.cuda.synchronize(device)

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        torch.cuda.synchronize(device)

        check_same_model_params(
            sharded_ddp_model,
            ddp_model,
            f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()
Example #29
    def check(broadcast_buffers: bool,
              grad_accumulation: bool = False) -> None:
        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3),
                           Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        next(model.parameters()
             ).requires_grad = False  # Test non-trainable parameters

        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=0.01,
                        momentum=0.99)
        ddp_model = ShardedDataParallel(model,
                                        optimizer,
                                        broadcast_buffers=broadcast_buffers)

        def check_same_model_params(same_params: bool):
            # Check that all the params are the same on all ranks
            # This should be true with and without broadcast_buffers, we don't have any real buffer here
            receptacle: List[torch.Tensor] = []

            if dist.get_backend() != "nccl":
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        # Check the params
                        receptacle = [p.clone() for _ in range(world_size)
                                      ] if rank == 0 else []
                        dist.gather(p, receptacle, dst=0)
                        if rank == 0:
                            for sync_p in receptacle[1:]:
                                if same_params:
                                    assert torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Gradients should not have been synced"

                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
                if broadcast_buffers:
                    for b in ddp_model.buffers():
                        receptacle = [b.clone() for _ in range(world_size)
                                      ] if rank == 0 else []
                        dist.gather(b, receptacle, dst=0)
                        if rank == 0:
                            for sync_b in receptacle[1:]:
                                if same_params:
                                    assert torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Gradients should not have been synced"

                        assert b.cpu().item() == 0.0

        # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
        check_same_model_params(same_params=True)

        # Optim loop
        def closure():
            optimizer.zero_grad()

            with ddp_model.no_sync() if grad_accumulation else suppress():
                input_tensor = torch.rand((64, 2)).to(device)
                loss = ddp_model(input_tensor).abs().sum()
                loss.backward()
            return loss

        # The models should stay the same in between the ranks
        for i in range(5):
            _ = optimizer.step(closure=closure)
            # when running on cpu/gloo the "nodes" are not really different
            same_params = device == torch.device("cpu") or grad_accumulation
            check_same_model_params(same_params=same_params)
Example #30
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}" +
        f" - change train graph {change_train_graph}" + f" - amp {amp}" +
        f" - manual reduction {manual_reduction}" +
        f" - buffers {reduce_buffer_size}" + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model,
                scaler,
                input_tensor,
                should_accumulate,
                _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1

        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(),
                            optim=torch.optim.SGD,
                            lr=1e-4,
                            momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                    lr=1e-4,
                                    momentum=0.99)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True,
                    find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None,
                                     hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor,
                           grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model,
                                f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norm are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(),
                    0.3,
                    norm_type=2.0,
                    error_if_nonfinite=False)  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(
                    0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm,
                                          total_norm,
                                          atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(
                            zip(ddp_model.parameters(),
                                sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      torch.norm(p_sdp.grad),
                                      flush=True)
                            else:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      "not owned",
                                      flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(
                sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(
                ddp_model.parameters()).requires_grad
            check_same_model_params(
                sharded_ddp_model, ddp_model,
                f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()