def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 32
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()

        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()
        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
def __init__(
    self,
    module: nn.Module,
    sharded_optimizer: Union[OSS, List[OSS]],
    process_group: Any = None,
    broadcast_buffers: bool = True,
    sync_models_at_startup: bool = True,
):
    super().__init__()

    self.module = module
    self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
    self.enable_broadcast_buffers = broadcast_buffers

    # Handle a no_sync() context which prevents the gradient synchronization,
    # accumulate in place
    self.should_accumulate_grads = False

    # Communication related attributes
    self.process_group = process_group if process_group is not None else dist.group.WORLD
    self.world_size = dist.get_world_size(self.process_group)
    self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
    self.rank = dist.get_rank(self.process_group)
    self.global_rank = OSS.get_global_rank(self.process_group, self.rank)

    # Expose some of the PyTorch DDP attributes, some frameworks rely on them.
    # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
    # device_id related logic is not present, this is not handled
    devices = {p.device for p in self.module.parameters()}
    self.is_multi_device_module = len(devices) > 1
    self.device = list(devices)[0]

    distinct_device_types = {p.device.type for p in self.module.parameters()}
    assert len(distinct_device_types) == 1, (
        "ShardedDataParallel's input module must be on "
        "the same type of devices, but input module parameters are located on {} different device types."
    ).format(distinct_device_types)
    self.device_type = list(distinct_device_types)[0]

    # Scaffolding to be able to reduce the grads during the BW pass.
    # Several optimizers can be present, each working on a separate parameter set;
    # we build an iterator which goes through all the parameters involved globally
    self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
    self._grad_to_be_reduced = [True for _ in self._param_iterator]
    self._grad_accs: List[Callable] = []
    self._setup_backward_hooks()

    # Make sure that all ranks start with the same model
    if sync_models_at_startup:
        self._sync_params_and_buffers()
def run_ddp_parity(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params()

    # The models should stay the same in between the ranks
    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        _ = ddp_optimizer.step(closure=closure_ddp)
        _ = sharded_optimizer.step(closure=closure_sharded)

        check_same_model_params()

    dist.destroy_process_group()
def run_test_gpt2(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    )
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Move the model to another device post-construction
    model = model.to(device)

    # Optim loop
    set_to_none = True

    def closure():
        nonlocal set_to_none
        ddp_model.zero_grad(set_to_none=set_to_none)
        set_to_none = not set_to_none

        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    # Stress test the .to() method
    ddp_model.to(device=device, dtype=torch.float16)
    ddp_model.to(device=device, dtype=torch.float32)

    dist.destroy_process_group()
def _test_basic_func(rank, world_size, tempfile_name, test_case, oss, model=None):
    _dist_init(rank, world_size, tempfile_name, backend="nccl")

    if model is None:
        model = Linear(2, 2)
        model.bias.data.fill_(0.0)

    model.to("cuda")
    model = DDP(model, device_ids=[rank])

    assert oss in ["none", "ada-oss", "wrapper-oss", "oss-wrapper"]
    if oss == "ada-oss":
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
    elif oss == "wrapper-oss":
        optim = AdaScaleWrapper(model.parameters(), optim_cls=OSS, optim=SGD, lr=0.1)
    elif oss == "oss-wrapper":
        optim = OSS(model.parameters(), AdaScaleWrapper, optim_cls=SGD, lr=0.1)
    else:
        assert oss == "none"
        optim = AdaScale(SGD(model.parameters(), lr=0.1))

    if "input" in test_case:
        inputs = [test_case["input"]]
    else:
        inputs = test_case["inputs"]

    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if "expected_gain" in test_case:
        assert np.allclose(optim.gain(), test_case["expected_gain"]), "{} vs {}".format(
            optim.gain(), test_case["expected_gain"]
        )

    if "expected_mean_weight" in test_case:
        mean_weight = mean([model.module[i].weight.data.mean().item() for i in range(4)])
        assert np.allclose(mean_weight, test_case["expected_mean_weight"]), mean_weight

    dist.destroy_process_group()
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))

        def forward(self, x, y):
            x1 = self.mlp(x)
            x2 = self.mlp(y)
            return torch.cat((x1, x2), dim=1)

    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()
        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
    ddp = OssDdp(model, optimizer, world_size)
    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).sum()
    output.backward()
    ddp.reduce()
    optimizer.step()
def test_train_eval_change():
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # make sure that the gradients are reduced

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    assert next(model.parameters()).grad is None or torch.norm(next(model.parameters()).grad) < 1e-6

    # Get back to training
    model = model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()
def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
    for x, optimizer in enumerate(optimizers):
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                # For multi-node training, compressing the model shards in fp16 before broadcasting
                # improves performance. When using PyTorch AMP, it will not degrade
                # the model performance.
                zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
            optimizers[x] = zero_optimizer
            del optimizer
    return optimizers
def init_components(
    self,
    model_fn=None,
    criterion_fn=None,
    optimizer_fn=None,
    scheduler_fn=None,
):
    """Inits the run's components."""
    model = model_fn()
    model = self.sync_device(model)

    if self._sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = criterion_fn()
    criterion = self.sync_device(criterion)

    optimizer = optimizer_fn(model)
    optimizer = self.sync_device(optimizer)

    optimizer = OSS(model.parameters(), optim=optimizer.__class__, **optimizer.defaults)
    model = ShardedDataParallel(model, optimizer, **self.ddp_kwargs)

    scheduler = scheduler_fn(optimizer)
    scheduler = self.sync_device(scheduler)
    return model, criterion, optimizer, scheduler
def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that the wrapped module can change devices
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Not on the target device on purpose, we test changing it after the fact
    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size)

    try:
        ddp_model.to(device)
    except AssertionError:
        pass
    else:
        raise RuntimeError("Changing devices should be caught and not supported")

    dist.destroy_process_group()
def test_catch_grad_grad():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP catches a gradient which itself requires grad (chained grad / double backward)
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)
        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.train()
        chained_grad = torch.zeros_like(next(model.parameters()))
        chained_grad.requires_grad = True
        next(model.parameters()).grad = chained_grad

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        inputs = torch.rand(100, 2)
        with pytest.raises(RuntimeError):
            _ = ddp_model(inputs)

        dist.destroy_process_group()
def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    group = dist.group.WORLD  # init_process_group() returns None, use the default group explicitly
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # this will assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # this will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()
def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
    for x, optimizer in enumerate(optimizers):
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            optimizers[x] = zero_optimizer
            del optimizer
    return optimizers
def create_optimizer_and_scheduler(self, num_training_steps: int):
    """
    Edited to use fixed Adafactor. Setup the optimizer and the learning rate scheduler.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
    the Trainer's init through :obj:`optimizers`, or subclass and override this method.
    """
    if self.optimizer is None:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer_cls = FixedAdafactor if self.args.adafactor else AdamW
        if self.args.adafactor:
            optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
        else:
            optimizer_kwargs = {
                "betas": (self.args.adam_beta1, self.args.adam_beta2),
                "eps": self.args.adam_epsilon,
            }
        optimizer_kwargs["lr"] = self.args.learning_rate
        if self.sharded_dpp:
            self.optimizer = OSS(
                params=optimizer_grouped_parameters,
                optim=optimizer_cls,
                **optimizer_kwargs,
            )
        else:
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

    if self.lr_scheduler is None:
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps,
        )
def _reinit_optimizers_with_oss(self):
    optimizers = self.lightning_module.trainer.optimizers
    for x, optimizer in enumerate(optimizers):
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            optimizers[x] = zero_optimizer
            del optimizer
    trainer = self.lightning_module.trainer
    trainer.optimizers = optimizers
def _reinit_optimizers_with_oss(self):
    optimizers = self.lightning_module.trainer.optimizers
    for x, optimizer in enumerate(optimizers):
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                precision = self.lightning_module.trainer.precision
                is_fp16 = precision in ("mixed", 16)
                # For multi-node training, compressing the model shards in fp16 before broadcasting
                # improves performance. When using PyTorch AMP, it will not degrade
                # the model performance.
                zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
            optimizers[x] = zero_optimizer
            del optimizer
    trainer = self.lightning_module.trainer
    trainer.optimizers = optimizers
    trainer.convert_to_lightning_optimizers()
def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
    # Only work with the even ranks, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)
    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # Model not-fitting in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, Momentum is required to get a state to shard
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)
def _reinit_with_fairscale_oss(self, trainer):
    optimizers = trainer.optimizers
    for x, optimizer in enumerate(optimizers):
        if not isinstance(optimizer, OSS):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            optimizers[x] = zero_optimizer
            del optimizer
def __init__(self, module: nn.Module, oss: OSS, world_size: int, process_group: Any = None, buffer_size: int = 2**28):
    super().__init__()
    self.module = module
    self.world_size = world_size
    self.process_group = process_group if process_group is not None else dist.group.WORLD
    self.rank = dist.get_rank(self.process_group)

    # Never use a bigger buffer than the number of model params
    self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
    self.buffer: Optional[Tensor] = None

    # Flag used to make sure we only reduce gradients one time in the execution engine
    self.need_reduction = False

    # We can also forcibly accumulate grads locally and only do the
    # gradients-reduce at some later time
    self.accumulate_grads = False

    # TODO (Min): The algorithm here can be improved. We are sorting params by device
    #             and by rank. Then in reduction_fn below, we pack smaller ones into
    #             a buffer for reduction.
    #             We can pre-sort them here and simplify the reduction_fn logic below
    #             since their size shouldn't change.

    # make per-device lists of parameters
    paramlists: OrderedDict = OrderedDict()
    for param in self.module.parameters():
        device = param.device
        if paramlists.get(device) is None:
            paramlists[device] = []
        paramlists[device] += [param]
    self.per_device_params = list(paramlists.values())

    # query oss and build a param-to-rank table
    self.param_rank = {}
    for rank, param_groups in enumerate(oss.partition_parameters()):
        for param_group in param_groups:
            for param in param_group["params"]:
                self.param_rank[param] = rank

    # sanity checks
    assert len(self.param_rank) == len(list(self.module.parameters())), "number of params do not match"
    for param in self.module.parameters():
        assert param in self.param_rank, f"{param} not in the optimizer"
def __init__(
    self,
    module: nn.Module,
    optimizer: Type[torch.optim.Optimizer],
    optimizer_params: Dict[str, Any],
    world_size: int,
    broadcast_buffers: bool,
    process_group: Any = None,
    buffer_size: int = 2**19,
):
    super().__init__()
    self.module = module
    self.world_size = world_size
    self.process_group = process_group if process_group is not None else dist.group.WORLD
    self.rank = dist.get_rank(self.process_group)
    self.broadcast_buffers = broadcast_buffers
    self.authoritative_rank = 0

    # Flag used to make sure we only reduce gradients one time in the execution engine
    self.need_reduction = False

    # We can also forcibly accumulate grads locally and only do the
    # gradients-reduce at some later time
    self.accumulate_grads = False

    # Build the sharded optimizer
    self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

    # Allocate reduce buffers
    # - Never use a bigger buffer than the number of model params
    buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
    self._reduce_buffers: Dict[torch.device, List[torch.Tensor]] = {}

    # - One buffer per rank per device
    for device, per_device in self.sharded_optimizer.per_device_params.items():
        buffer_dtype = per_device[0][0].dtype
        self._reduce_buffers[device] = [
            torch.zeros(buffer_size, dtype=buffer_dtype, device=device) for _ in range(len(per_device))
        ]

    # Sanity checks
    assert len(self.sharded_optimizer.param_to_rank) == len(
        list(self.module.parameters())
    ), "number of params do not match"
    for param in self.module.parameters():
        assert param in self.sharded_optimizer.param_to_rank, f"{param} not in the optimizer"
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()
def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
    # - is multi_device_module
    # - device_type
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")
    dist.destroy_process_group()
def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()
def reduce(*_: Any) -> None:
    # Skip gradient reduction, do not alter status flags
    if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
        assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

        if not self._bucket_flush_callback_set:
            Variable._execution_engine.queue_callback(self._flush_buckets)
            self._bucket_flush_callback_set = True

        # Make sure that this is not fired twice
        self._grad_to_be_reduced[index] = False
        param.grad.mul_(self.world_size_scaling)

        if self.reduce_fp16:
            param.grad.data = param.grad.data.half()

        # Future work includes clearing up the buffer if possible
        def cleanup() -> None:
            if dst_rank != self.global_rank:
                param.grad = None
            else:
                assert param.grad is not None
                param.grad.data = param.grad.data.to(dtype=param.dtype)

        # Async reduce for this buffer, log the future
        dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
        self._work_handles.append(
            Workhandle(
                handle=dist.reduce(tensor=param.grad.data, dst=dst_global_rank, group=self.process_group, async_op=True),
                callback=cleanup,
            )
        )
        self._reduced_grads += 1

        # Opportunistically try to empty the queue
        self._try_consume_work_handle()

        # If all the reduce operations have been called,
        # make sure that all the asynchronous calls have concluded before moving on
        # and execute the delayed actions (release gradients, unroll the buckets)
        if self._reduced_grads == self._reduced_grads_max:
            self._consume_work_handles()
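# For context: a minimal, hypothetical sketch of how a per-parameter callback such as `reduce` above is
# typically attached to the autograd graph. This section never shows `_setup_backward_hooks` (referenced in
# the ShardedDataParallel constructor earlier), so the names `_all_params` and `_get_reduce_fn` below are
# illustrative assumptions, not the library's exact code.
def _setup_backward_hooks_sketch(self) -> None:
    for index, param in enumerate(self._all_params):  # assumed flat list of parameters across all sharded optimizers
        if not param.requires_grad:
            continue

        # `expand_as` yields a view whose grad_fn points back to this parameter's AccumulateGrad node
        p_tmp = param.expand_as(param)
        grad_acc = p_tmp.grad_fn.next_functions[0][0]

        # `_get_reduce_fn` stands in for a closure factory that returns a `reduce`-like callback bound to
        # this parameter, its index, and the destination rank owning its optimizer shard
        grad_acc.register_hook(self._get_reduce_fn(index, param))

        # Keep a reference to the AccumulateGrad node so the hook is not garbage collected
        self._grad_accs.append(grad_acc)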
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
    # Check that the wrapped module can change devices
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)
    ddp_model.to(device)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # this will assert if the module has not been moved properly
    outputs.norm().backward()

    dist.destroy_process_group()
def test_mixed_types():
    # Check that ShardedDDP handles a module with mixed parameter types (built via _get_mlp(tripwire=True))
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = _get_mlp(tripwire=True)

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    _ = model(input_tensor)

    dist.destroy_process_group()
def _test_basic_func(rank, ddp_cls, world_size, tempfile_name, test_case):
    _dist_init(rank, world_size, tempfile_name, backend="nccl")  # Covers nccl

    model = Linear(2, 2)
    model.to("cuda")
    if ddp_cls is DDP:
        model = ddp_cls(model, device_ids=[rank])
        optim = AdaScale(SGD(model.parameters(), lr=0.1))
    elif ddp_cls is SDP:
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
        model = ddp_cls(model, sharded_optimizer=optim)
    else:
        assert ddp_cls is FSDP, ddp_cls
        # Two cases:
        #    flatten=True : the AdaScale wrapper must come after FSDP and it receives
        #                   a single grad tensor. It won't receive grads if
        #                   wrapped before.
        #    flatten=False: AdaScale can be wrapped either before or after FSDP.
        # So, it is better to apply AdaScale after FSDP.
        model = ddp_cls(model, flatten_parameters=False)
        optim = AdaScale(SGD(model.parameters(), lr=0.1))

    if "input" in test_case:
        # single iter
        in_data = Tensor(test_case["input"][rank])
        in_data = in_data.cuda()
        out = model(in_data)
        out.sum().backward()
        if ddp_cls is DDP:
            assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
        optim.step()
        optim.zero_grad()
    else:
        # multiple iters
        for in_data in test_case["inputs"]:
            in_data = Tensor(in_data[rank]).cuda()
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()
        if ddp_cls is DDP:
            assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()

    dist.destroy_process_group()