Example #1
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 32
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
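These per-rank test bodies are normally driven by a multiprocessing launcher. Below is a minimal sketch, assuming torch.multiprocessing.spawn and a temporary file for the rendezvous; launch_gpt2_test and the backend/device selection are illustrative, not part of the original tests.

import tempfile

import torch
import torch.multiprocessing as mp


def launch_gpt2_test(world_size: int = 2) -> None:
    # Rendezvous file shared by all ranks, as in the test above
    temp_file_name = tempfile.mkstemp()[1]
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Spawn one process per rank; each process runs run_test_gpt2(rank, ...)
    mp.spawn(
        run_test_gpt2,
        args=(world_size, backend, device, temp_file_name),
        nprocs=world_size,
        join=True,
    )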
Example #2
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name,
                        reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
Example #3
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
Example #4
    def __init__(
        self,
        module: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
    ):
        super().__init__()

        self.module = module
        self.sharded_optimizers = [sharded_optimizer] if isinstance(
            sharded_optimizer, OSS) else sharded_optimizer
        self.enable_broadcast_buffers = broadcast_buffers

        # Handle a no_sync() context, which skips the gradient synchronization
        # and accumulates gradients in place
        self.should_accumulate_grads = False

        # Communication related attributes
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.world_size = dist.get_world_size(self.process_group)
        self.reference_global_rank = OSS.get_global_rank(
            self.process_group, 0)  # picking rank 0 as the reference
        self.rank = dist.get_rank(self.process_group)
        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)

        # Expose some of the PyTorch DDP attributes; some frameworks rely on them.
        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
        # device_id related logic is not present and is not handled here
        devices = {p.device for p in self.module.parameters()}
        self.is_multi_device_module = len(devices) > 1
        self.device = list(devices)[0]

        distinct_device_types = {
            p.device.type
            for p in self.module.parameters()
        }
        assert len(distinct_device_types) == 1, (
            "ShardedDataParallel's input module must be on "
            "the same type of devices, but input module parameters are located on {} different device types."
        ).format(distinct_device_types)
        self.device_type = list(distinct_device_types)[0]

        # Scaffolding to be able to reduce the grads during the BW pass.
        # Several optimizers can be present, each working on separate parameter sets;
        # we build an iterator which goes through all the parameters involved globally
        self._param_iterator = chain(*[
            optim.should_bucket_param.keys()
            for optim in self.sharded_optimizers
        ])
        self._grad_to_be_reduced = [True for _ in self._param_iterator]
        self._grad_accs: List[Callable] = []
        self._setup_backward_hooks()

        # Make sure that all ranks start with the same model
        if sync_models_at_startup:
            self._sync_params_and_buffers()
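For reference, a minimal end-to-end usage sketch of this wrapper, assuming fairscale's public OSS and ShardedDataParallel and an already-initialized process group; the model shape, hyper-parameters, and sharded_train_step name are illustrative.

import torch
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def sharded_train_step(device: torch.device) -> torch.Tensor:
    # Assumes dist.init_process_group(...) has already been called on every rank
    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)

    # OSS shards the optimizer state across ranks; ShardedDataParallel
    # reduces each gradient only to the rank that owns its optimizer shard
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    optimizer.zero_grad()
    loss = ddp_model(torch.rand((8, 2), device=device)).abs().sum()
    loss.backward()
    optimizer.step()
    return loss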
Example #5
def run_ddp_parity(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ between DDP and ShardedDDP"

    # The models should be synchronized across the ranks at construction time; check that
    check_same_model_params()

    # The models should stay in sync across the ranks
    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        _ = ddp_optimizer.step(closure=closure_ddp)
        _ = sharded_optimizer.step(closure=closure_sharded)

        check_same_model_params()

    dist.destroy_process_group()
Example #6
def run_test_gpt2(rank, world_size, backend, device, temp_file_name,
                  reduce_buffer_size):
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(embed_dim=256,
                 num_heads=2,
                 num_layers=12,
                 num_positions=INPUT_DIM * INPUT_DIM,
                 num_vocab=512,
                 num_classes=2)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    reduce_buffer_size=reduce_buffer_size)

    # Move the model to another device post-construction
    model = model.to(device)

    # Optim loop
    set_to_none = True

    def closure():
        nonlocal set_to_none
        ddp_model.zero_grad(set_to_none=set_to_none)
        set_to_none = not set_to_none

        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

        # Stress test the .to() method
        ddp_model.to(device=device, dtype=torch.float16)
        ddp_model.to(device=device, dtype=torch.float32)

    dist.destroy_process_group()
Example #7
def _test_basic_func(rank,
                     world_size,
                     tempfile_name,
                     test_case,
                     oss,
                     model=None):
    _dist_init(rank, world_size, tempfile_name, backend="nccl")

    if model is None:
        model = Linear(2, 2)
        model.bias.data.fill_(0.0)

    model.to("cuda")
    model = DDP(model, device_ids=[rank])

    assert oss in ["none", "ada-oss", "wrapper-oss", "oss-wrapper"]
    if oss == "ada-oss":
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
    elif oss == "wrapper-oss":
        optim = AdaScaleWrapper(model.parameters(),
                                optim_cls=OSS,
                                optim=SGD,
                                lr=0.1)
    elif oss == "oss-wrapper":
        optim = OSS(model.parameters(), AdaScaleWrapper, optim_cls=SGD, lr=0.1)
    else:
        assert oss == "none"
        optim = AdaScale(SGD(model.parameters(), lr=0.1))

    if "input" in test_case:
        inputs = [test_case["input"]]
    else:
        inputs = test_case["inputs"]

    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if "expected_gain" in test_case:
        assert np.allclose(optim.gain(),
                           test_case["expected_gain"]), "{} vs {}".format(
                               optim.gain(), test_case["expected_gain"])

    if "expected_mean_weight" in test_case:
        mean_weight = mean(
            [model.module[i].weight.data.mean().item() for i in range(4)])
        assert np.allclose(mean_weight,
                           test_case["expected_mean_weight"]), mean_weight

    dist.destroy_process_group()
Example #8
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3),
                                  Linear(3, 3), Linear(3, 3), Linear(3, 3))

        def forward(self, x, y):
            x1 = self.mlp(x)
            x2 = self.mlp(y)
            return torch.cat((x1, x2), dim=1)

    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10],
                      optim=torch.optim.SGD,
                      lr=0.01,
                      momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:],
                      optim=torch.optim.SGD,
                      lr=0.01,
                      momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
Example #9
def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
    ddp = OssDdp(model, optimizer, world_size)
    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).sum()
    output.backward()
    ddp.reduce()
    optimizer.step()
Example #10
def test_train_eval_change():
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1],
                            backend="gloo",
                            rank=0,
                            world_size=1)

    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # make sure that the gradients are reduced

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    assert next(model.parameters()).grad is None or torch.norm(
        next(model.parameters()).grad) < 1e-6

    # Get back to training
    model = model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()
Example #11
 def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
     for x, optimizer in enumerate(optimizers):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                 is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
                 # the model performance.
                 zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
             optimizers[x] = zero_optimizer
             del optimizer
     return optimizers
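The same re-wrapping pattern works outside Lightning. A small sketch follows, assuming only fairscale's OSS and a plain PyTorch optimizer; to_oss is a hypothetical helper name chosen for illustration.

import torch
from fairscale.optim import OSS


def to_oss(optimizer: torch.optim.Optimizer) -> OSS:
    # Rebuild an existing optimizer as a sharded OSS instance, keeping the
    # same optimizer class, parameter groups and hyper-parameter defaults
    if isinstance(optimizer, OSS):
        return optimizer
    return OSS(params=optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)


# Usage: sharded = to_oss(torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9))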
Example #12
    def init_components(
        self,
        model_fn=None,
        criterion_fn=None,
        optimizer_fn=None,
        scheduler_fn=None,
    ):
        """Inits the runs components."""
        model = model_fn()
        model = self.sync_device(model)
        if self._sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

        criterion = criterion_fn()
        criterion = self.sync_device(criterion)

        optimizer = optimizer_fn(model)
        optimizer = self.sync_device(optimizer)

        optimizer = OSS(model.parameters(),
                        optim=optimizer.__class__,
                        **optimizer.defaults)
        model = ShardedDataParallel(model, optimizer, **self.ddp_kwargs)

        scheduler = scheduler_fn(optimizer)
        scheduler = self.sync_device(scheduler)
        return model, criterion, optimizer, scheduler
Example #13
def run_test_device_change(rank, world_size, backend, device, temp_file_name,
                           reduce_buffer_size):
    # Check that the wrapped module can change devices
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)

    # Deliberately left on CPU; the test changes the device after the fact
    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    sync_models_at_startup=False,
                                    reduce_buffer_size=reduce_buffer_size)
    try:
        ddp_model.to(device)
        assert False, "Changing devices should be caught and not supported"
    except AssertionError:
        pass

    dist.destroy_process_group()
Example #14
def test_catch_grad_grad():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP catches a gradient which itself requires grad
        dist.init_process_group(init_method="file://" + temp_files[0],
                                backend="gloo",
                                rank=0,
                                world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.train()
        chained_grad = torch.zeros_like(next(model.parameters()))
        chained_grad.requires_grad = True
        next(model.parameters()).grad = chained_grad

        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=1e-3,
                        momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        inputs = torch.rand(100, 2)
        with pytest.raises(RuntimeError):
            _ = ddp_model(inputs)

        dist.destroy_process_group()
Example #15
def run_test_training_change(rank, world_size, backend, device, temp_file_name,
                             reduce_buffer_size):
    group = dist.init_process_group(init_method="file://" + temp_file_name,
                                    backend=backend,
                                    rank=rank,
                                    world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    process_group=group,
                                    reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()
Example #16
 def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
     for x, optimizer in enumerate(optimizers):
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             optimizers[x] = zero_optimizer
             del optimizer
     return optimizers
Example #17
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Edited to use fixed Adafactor.

        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in self.model.named_parameters()
                        if not any(nd in n for nd in no_decay)
                    ],
                    "weight_decay":
                    self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in self.model.named_parameters()
                        if any(nd in n for nd in no_decay)
                    ],
                    "weight_decay":
                    0.0,
                },
            ]
            optimizer_cls = FixedAdafactor if self.args.adafactor else AdamW
            if self.args.adafactor:
                optimizer_kwargs = {
                    "scale_parameter": False,
                    "relative_step": False
                }
            else:
                optimizer_kwargs = {
                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
                    "eps": self.args.adam_epsilon,
                }
            optimizer_kwargs["lr"] = self.args.learning_rate
            if self.sharded_dpp:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                               **optimizer_kwargs)

        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                self.optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=num_training_steps,
            )
Example #18
 def _reinit_optimizers_with_oss(self):
     optimizers = self.lightning_module.trainer.optimizers
     for x, optimizer in enumerate(optimizers):
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             optimizers[x] = zero_optimizer
             del optimizer
     trainer = self.lightning_module.trainer
     trainer.optimizers = optimizers
Example #19
 def _reinit_optimizers_with_oss(self):
     optimizers = self.lightning_module.trainer.optimizers
     for x, optimizer in enumerate(optimizers):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                 precision = self.lightning_module.trainer.precision
                 is_fp16 = precision in ("mixed", 16)
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
                 # the model performance.
                 zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
             optimizers[x] = zero_optimizer
             del optimizer
     trainer = self.lightning_module.trainer
     trainer.optimizers = optimizers
     trainer.convert_to_lightning_optimizers()
Example #20
def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
    # Only the even ranks do work, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)

    # Make sure that all the ranks get different training data
    # So that the sync check between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # Model that does not fit in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, momentum is required to get a state to shard
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)
Example #21
 def _reinit_with_fairscale_oss(self, trainer):
     optimizers = trainer.optimizers
     for x, optimizer in enumerate(optimizers):
         if not isinstance(optimizer, OSS):
             optim_class = type(optimizer)
             zero_optimizer = OSS(
                 params=optimizer.param_groups,
                 optim=optim_class,
                 **optimizer.defaults
             )
             optimizers[x] = zero_optimizer
             del optimizer
Example #22
    def __init__(self,
                 module: nn.Module,
                 oss: OSS,
                 world_size: int,
                 process_group: Any = None,
                 buffer_size: int = 2**28):
        super().__init__()

        self.module = module
        self.world_size = world_size
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.rank = dist.get_rank(self.process_group)

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(
            buffer_size, sum(p.numel() for p in self.module.parameters()))
        self.buffer: Optional[Tensor] = None

        # Flag used to make sure we only reduce gradients one time in the execution engine
        self.need_reduction = False

        # We can also forcibly accumulate grads locally and only do the
        # gradients-reduce at some later time
        self.accumulate_grads = False

        # TODO (Min): The algorithm here can be improved. We are sorting params by device
        #     and by rank. Then in reduction_fn below, we pack smaller ones into
        #     a buffer for reduction.
        #     We can pre-sort them here and simplify the reduction_fn logic below
        #     since their size shouldn't change.

        # make per-device lists of parameters
        paramlists: OrderedDict = OrderedDict()
        for param in self.module.parameters():
            device = param.device
            if paramlists.get(device) is None:
                paramlists[device] = []
            paramlists[device] += [param]
        self.per_device_params = list(paramlists.values())

        # query oss and build a param-to-rank table
        self.param_rank = {}
        for rank, param_groups in enumerate(oss.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    self.param_rank[param] = rank

        # sanity checks
        assert len(self.param_rank) == len(list(
            self.module.parameters())), "number of params do not match"
        for param in self.module.parameters():
            assert param in self.param_rank, f"{param} not in the optimizer"
Example #23
    def __init__(
        self,
        module: nn.Module,
        optimizer: Type[torch.optim.Optimizer],
        optimizer_params: Dict[str, Any],
        world_size: int,
        broadcast_buffers: bool,
        process_group: Any = None,
        buffer_size: int = 2**19,
    ):
        super().__init__()

        self.module = module
        self.world_size = world_size
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.rank = dist.get_rank(self.process_group)
        self.broadcast_buffers = broadcast_buffers
        self.authoritative_rank = 0

        # Flag used to make sure we only reduce gradients one time in the execution engine
        self.need_reduction = False

        # We can also forcibly accumulate grads locally and only do the
        # gradients-reduce at some later time
        self.accumulate_grads = False

        # Build the sharded optimizer
        self.sharded_optimizer = OSS(self.module.parameters(),
                                     optim=optimizer,
                                     group=process_group,
                                     **optimizer_params)

        # Allocate reduce buffers
        # - Never use a bigger buffer than the number of model params
        buffer_size = min(buffer_size,
                          sum(p.numel() for p in self.module.parameters()))
        self._reduce_buffers: Dict[torch.device, List[torch.Tensor]] = {}

        # - One buffer per rank per device
        for device, per_device in self.sharded_optimizer.per_device_params.items(
        ):
            buffer_dtype = per_device[0][0].dtype
            self._reduce_buffers[device] = [
                torch.zeros(buffer_size, dtype=buffer_dtype, device=device)
                for _ in range(len(per_device))
            ]

        # Sanity checks
        assert len(self.sharded_optimizer.param_to_rank) == len(
            list(self.module.parameters())), "number of params do not match"
        for param in self.module.parameters():
            assert param in self.sharded_optimizer.param_to_rank, f"{param} not in the optimizer"
Example #24
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()
Example #25
def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as PyTorch's DDP
    # - is_multi_device_module
    # - device_type
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")
    dist.destroy_process_group()
Example #26
def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()
Example #27
            def reduce(*_: Any) -> None:
                # Skip the gradient reduction (and leave status flags untouched) when accumulating grads or if already reduced
                if not self.should_accumulate_grads and self._grad_to_be_reduced[
                        index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(
                            self._flush_buckets)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    param.grad.mul_(self.world_size_scaling)

                    if self.reduce_fp16:
                        param.grad.data = param.grad.data.half()

                    # Future work includes clearing up the buffer if possible
                    def cleanup() -> None:
                        if dst_rank != self.global_rank:
                            param.grad = None
                        else:
                            assert param.grad is not None
                            param.grad.data = param.grad.data.to(
                                dtype=param.dtype)

                    # Async reduce for this buffer, log the future
                    dst_global_rank = OSS.get_global_rank(
                        self.process_group, dst_rank)

                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(tensor=param.grad.data,
                                               dst=dst_global_rank,
                                               group=self.process_group,
                                               async_op=True),
                            callback=cleanup,
                        ))
                    self._reduced_grads += 1

                    # Opportunistically try to empty the queue
                    self._try_consume_work_handle()

                    # If all the reduce operations have been called,
                    # make sure that all the asynchronous calls have concluded before moving on
                    # and execute the delayed actions (release gradients, unroll the buckets)
                    if self._reduced_grads == self._reduced_grads_max:
                        self._consume_work_handles()
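For context, a hook like reduce above is attached once per parameter to its AccumulateGrad node during setup (the wrapper's __init__ calls self._setup_backward_hooks()). Below is a sketch of that registration using the standard autograd idiom also used by PyTorch DDP; _get_reduce_fn is a hypothetical stand-in for the closure factory that builds reduce, and the exact bookkeeping is simplified.

    def _setup_backward_hooks(self) -> None:
        # Attach one reduction callback per trainable parameter
        for index, param in enumerate(p for p in self.module.parameters() if p.requires_grad):
            # expand_as returns a view whose grad_fn points back to the
            # parameter's AccumulateGrad node
            p_tmp = param.expand_as(param)
            grad_acc = p_tmp.grad_fn.next_functions[0][0]

            # _get_reduce_fn would build a closure like `reduce` above
            grad_acc.register_hook(self._get_reduce_fn(index, param))

            # Keep a reference so the hook is not garbage collected
            self._grad_accs.append(grad_acc)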
Example #28
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
    # Check that the wrapped module can change devices

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)
    ddp_model.to(device)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    dist.destroy_process_group()
Example #29
def test_mixed_types():
    # Check that ShardedDDP handles a model holding mixed parameter types
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1],
                            backend="gloo",
                            rank=0,
                            world_size=1)

    model = _get_mlp(tripwire=True)

    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    _ = model(input_tensor)

    dist.destroy_process_group()
Example #30
def _test_basic_func(rank, ddp_cls, world_size, tempfile_name, test_case):
    _dist_init(rank, world_size, tempfile_name, backend="nccl")  # Covers nccl

    model = Linear(2, 2)
    model.to("cuda")
    if ddp_cls is DDP:
        model = ddp_cls(model, device_ids=[rank])
        optim = AdaScale(SGD(model.parameters(), lr=0.1))
    elif ddp_cls is SDP:
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
        model = ddp_cls(model, sharded_optimizer=optim)
    else:
        assert ddp_cls is FSDP, ddp_cls
        # Two cases:
        #    flatten=True : AdaScale wrapper must be after FSDP and it receives
        #                   a single grad tensor. It won't receive grad if
        #                   wrapped before.
        #    flatten=False: AdaScale can be both before or after FSDP.
        # So, it is better to do AdaScale after FSDP.
        model = ddp_cls(model, flatten_parameters=False)
        optim = AdaScale(SGD(model.parameters(), lr=0.1))
    if "input" in test_case:
        # single iter
        in_data = Tensor(test_case["input"][rank])
        in_data = in_data.cuda()
        out = model(in_data)
        out.sum().backward()
        if ddp_cls is DDP:
            assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
        optim.step()
        optim.zero_grad()
    else:
        # multiple iters
        for in_data in test_case["inputs"]:
            in_data = Tensor(in_data[rank]).cuda()
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()
        if ddp_cls is DDP:
            assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()

    dist.destroy_process_group()