def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.1, "momentum": 0.99},
        world_size=world_size,
    )
    optimizer = ddp.optimizer

    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    if device == torch.device("cuda"):
        torch.cuda.synchronize()  # flush any remaining cuda op, just in case

    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.requires_grad:
                assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

    # Check that the optimization process makes sense (ie. loss goes down for the same data)
    optimizer.step()
    new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
    assert new_eval.item() < output.item(), "The loss should go down after an optimizer step on the same data"
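# Illustrative launcher sketch (not part of the original tests): per-rank functions such as
# run_one_step are typically started with torch.multiprocessing.spawn, one process per rank,
# sharing a temporary file for the file-based rendezvous. The helper name and default
# arguments below are assumptions made for the example.
def _spawn_run_one_step(world_size: int = 2, backend: str = "gloo", device: str = "cpu"):
    import tempfile

    import torch.multiprocessing as mp

    temp_file_name = tempfile.mkstemp()[1]
    # mp.spawn prepends the rank to the argument tuple for each worker process
    mp.spawn(
        run_one_step,
        args=(world_size, backend, torch.device(device), temp_file_name),
        nprocs=world_size,
        join=True,
    )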
def init_distributed_data_parallel_model(self): """ Initialize ShardedDataParallel, needed for sharded distributed training. This is where a model should be wrapped by DDP. """ # Init the base class, everything but the distributed model wrap is to be reused super().init_distributed_data_parallel_model() broadcast_buffers = ( self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS) # Replace the original DDP wrap by the shard-aware ShardedDDP # we use the fairscale reduce_buffer_size by default however, if user sets it to # some different value, we use the different value. reduce_buffer_size = 2**23 if self.config.MODEL.SHARDED_DDP_SETUP.reduce_buffer_size >= 0: reduce_buffer_size = self.config.MODEL.SHARDED_DDP_SETUP.reduce_buffer_size logging.info(f"Setting reduce_buffer_size: {reduce_buffer_size}") if isinstance(self.optimizer, ZeRO): logging.info("Using ShardedDDP") self.distributed_model = ShardedDataParallel( module=self.base_model, sharded_optimizer=self.optimizer.optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size, ) else: raise NotImplementedError( "This DataParallel engine should only be used in conjunction with ZeRO" )
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
def init_distributed_data_parallel_model(self): """ Initialize `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/ docs/stable/nn.html#distributeddataparallel>`_. Needed for distributed training. This is where a model should be wrapped by DDP. """ if not is_distributed_training_run(): return assert ( self.distributed_model is None ), "init_ddp_non_elastic must only be called once" broadcast_buffers = ( self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS ) if self.use_sharded_ddp: if not isinstance(self.optimizer, ZeRO): raise ValueError( "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer" ) from fairscale.nn.data_parallel import ShardedDataParallel # Replace the original DDP wrap by the shard-aware ShardedDDP self.distributed_model = ShardedDataParallel( module=self.base_model, sharded_optimizer=self.optimizer.optimizer, broadcast_buffers=broadcast_buffers, ) else: self.distributed_model = init_distributed_data_parallel_model( self.base_model, broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, bucket_cap_mb=self.ddp_bucket_cap_mb, ) if self.fp16_grad_compress: from torch.distributed.algorithms import ddp_comm_hooks # FP16 hook is stateless and only takes a process group as the state. # We use the default process group so we set the state to None. process_group = None self.distributed_model.register_comm_hook( process_group, ddp_comm_hooks.default_hooks.fp16_compress_hook, ) if ( isinstance(self.base_loss, ClassyLoss) and self.base_loss.has_learned_parameters() ): logging.info("Initializing distributed loss") self.distributed_loss = init_distributed_data_parallel_model( self.base_loss, broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, bucket_cap_mb=self.ddp_bucket_cap_mb, )
def test_train_eval_change():
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # make sure that the gradients are reduced

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    assert next(model.parameters()).grad is None or torch.norm(next(model.parameters()).grad) < 1e-6

    # Get back to training
    model = model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
def init_components(
    self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
):
    """Inits the run's components."""
    model = model_fn()
    model = self.sync_device(model)
    if self._sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = criterion_fn()
    criterion = self.sync_device(criterion)

    optimizer = optimizer_fn(model)
    optimizer = self.sync_device(optimizer)
    # Wrap the base optimizer class in OSS so that its state is sharded across ranks,
    # then wrap the model so gradients are reduced to the rank that owns each shard
    optimizer = OSS(model.parameters(), optim=optimizer.__class__, **optimizer.defaults)
    model = ShardedDataParallel(model, optimizer, **self.ddp_kwargs)

    scheduler = scheduler_fn(optimizer)
    scheduler = self.sync_device(scheduler)
    return model, criterion, optimizer, scheduler
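# Minimal usage sketch for the components returned above (an assumption, not part of the
# original runner): once the model is wrapped in ShardedDataParallel and the optimizer in
# OSS, a training step looks like any other PyTorch loop. Gradients are reduced to the rank
# that owns each parameter shard during backward, and optimizer.step() updates only that shard.
def _train_step_sketch(model, criterion, optimizer, batch, targets):
    optimizer.zero_grad()
    outputs = model(batch)  # forward through the ShardedDataParallel wrapper
    loss = criterion(outputs, targets)
    loss.backward()  # grads are bucketed and reduced to their owner ranks
    optimizer.step()  # each rank updates only the parameters it owns
    return loss.detach()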
def run_eval_mode(_unused):
    """Test eval mode, making sure that no assert is triggered."""
    dist.init_process_group(init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1)
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer_params = {"lr": 0.1, "momentum": 0.99}
    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
    optimizer = ddp.optimizer

    ddp.eval()
    for _ in range(5):
        input_tensor = torch.rand((64, 2))
        output = ddp(input_tensor)

    ddp.train()
    try:
        for _ in range(5):
            input_tensor = torch.rand((64, 2))
            output = ddp(input_tensor)
    except RuntimeError:
        pass
    else:
        assert False, "Multiple forward passes in training mode should not be allowed"

    dist.destroy_process_group()
def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that the wrapped module can change devices
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Deliberately left on CPU, to test changing the device after the fact
    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
    )
    try:
        ddp_model.to(device)
        assert False, "Changing devices should be caught and not supported"
    except AssertionError:
        pass

    dist.destroy_process_group()
def test_catch_grad_grad():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP catches a parameter gradient which itself requires grad
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)
        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.train()
        chained_grad = torch.zeros_like(next(model.parameters()))
        chained_grad.requires_grad = True
        next(model.parameters()).grad = chained_grad

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        inputs = torch.rand(100, 2)
        with pytest.raises(RuntimeError):
            _ = ddp_model(inputs)

        dist.destroy_process_group()
def parallelize_model(self) -> None:
    registry.register("data_parallel", False)
    registry.register("distributed", False)

    if "cuda" in str(self.device) and torch.cuda.device_count() > 1 and not self.distributed:
        registry.register("data_parallel", True)
        self.model = torch.nn.DataParallel(self.model)

    if "cuda" in str(self.device) and self.distributed:
        registry.register("distributed", True)
        set_torch_ddp = True
        try:
            from fairscale.nn.data_parallel import ShardedDataParallel
            from fairscale.optim.oss import OSS

            if isinstance(self.optimizer, OSS):
                self.model = ShardedDataParallel(self.model, self.optimizer)
                set_torch_ddp = False
                logger.info("Using FairScale ShardedDataParallel")
        except ImportError:
            logger.info("Using PyTorch DistributedDataParallel")
            warnings.warn(
                "You can enable ZeRO and Sharded DDP, by installing fairscale "
                + "and setting optimizer.enable_state_sharding=True."
            )

        if set_torch_ddp:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=self.config.training.find_unused_parameters,
            )
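# Hedged sketch (an assumption, not part of the original trainer): for parallelize_model to
# take the ShardedDataParallel branch, self.optimizer must already be a fairscale OSS
# instance, e.g. built by wrapping the base optimizer class before parallelize_model runs.
def _build_sharded_optimizer_sketch(model, lr: float = 1e-3, momentum: float = 0.9):
    from fairscale.optim.oss import OSS

    # OSS shards the optimizer state across ranks while keeping the step() API unchanged
    return OSS(params=model.parameters(), optim=torch.optim.SGD, lr=lr, momentum=momentum)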
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 32
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    model = GPT2(
        embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    group = dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2

    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    check_same_model_params()  # The models should stay the same in between the ranks

    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()

        check_same_model_params()

    dist.destroy_process_group()
def run_test_gpt2(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    model = GPT2(
        embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    )
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Move the model to another device post-construction
    model = model.to(device)

    # Optim loop
    set_to_none = True

    def closure():
        nonlocal set_to_none
        ddp_model.zero_grad(set_to_none=set_to_none)
        set_to_none = not set_to_none

        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    # Stress test the .to() method
    ddp_model.to(device=device, dtype=torch.float16)
    ddp_model.to(device=device, dtype=torch.float32)

    dist.destroy_process_group()
def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3)).to(device)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)

    def weights_init(m):
        if isinstance(m, Linear):
            torch.nn.init.constant_(m.weight.data, 1.0)
            torch.nn.init.constant_(m.bias.data, 1.0)

    model.apply(weights_init)
    model.to(device)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.01, "momentum": 0.99},
        world_size=world_size,
        broadcast_buffers=True,
    )
    optimizer = ddp.optimizer
    model = ddp.module

    # Different input per rank, allows for checking that the gradients have been properly reduced
    input_tensor = (torch.ones((64, 2)) * rank).to(device)
    output = ddp(input_tensor).abs().sum()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.shape == torch.Size([3, 2]):
                assert param.grad[0, 0].cpu() == torch.tensor([32.0])
            if param.shape == torch.Size([3]):
                assert param.grad[0].cpu() == torch.tensor([64.0])

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    for b in model.buffers():
        assert b.cpu().item() == 0.0

    dist.destroy_process_group()
def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
    # Only work with the even ranks, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)
    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # Model not-fitting in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, Momentum is required to get a state to shard
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)
def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
    # - is_multi_device_module
    # - device_type
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")

    dist.destroy_process_group()
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))

    dist.destroy_process_group()
def train(rank, args, model, device, train_loader, num_epochs):
    ##############
    # SETUP
    dist_init(rank, WORLD_SIZE, BACKEND)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.Adadelta,
        optimizer_params={"lr": 1e-4},
        world_size=WORLD_SIZE,
        broadcast_buffers=True,
    )
    ddp.train()
    optimizer = ddp.optimizer

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    loss_fn = nn.CrossEntropyLoss()
    ##############

    model.train()
    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            def closure():
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= WORLD_SIZE
                loss.backward()

                # if dist.get_rank() == 0:
                #     print(f"Loss: {loss.item()}")

                ddp.reduce()  # Send the gradients to the appropriate shards
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    print("Total Time:", training_stop - training_start)
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))

        def forward(self, x, y):
            x1 = self.mlp(x)
            x2 = self.mlp(y)
            return torch.cat((x1, x2), dim=1)

    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
    # Check that the wrapped module can change devices
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)
    ddp_model.to(device)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    loss = outputs.norm().backward()

    dist.destroy_process_group()
def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.01, "momentum": 0.99},
        world_size=world_size,
        broadcast_buffers=True,
    )
    optimizer = ddp.optimizer
    model = ddp.module

    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    if device == torch.device("cuda"):
        torch.cuda.synchronize()  # flush any remaining cuda op, just in case

    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.requires_grad:
                assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    for b in model.buffers():
        assert b.cpu().item() == 0.0
def test_mixed_types():
    # Check that ShardedDDP handles a model with mixed parameter types
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = _get_mlp(tripwire=True)

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    _ = model(input_tensor)

    dist.destroy_process_group()
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
    # DDP
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=world_size,
            broadcast_buffers=False,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the img per second measurements
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")
def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
    )

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_models_across_ranks(
        ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
    )

    # Optim loop
    def closure():
        optimizer.zero_grad()

        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)

        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or grad_accumulation
        check_same_models_across_ranks(
            ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
        )

    dist.destroy_process_group()
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    BATCHS = 20
    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
    optim_settings = {"lr": 1e-3, "momentum": 0.99}

    sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, **optim_settings)
    sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, **optim_settings)

    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    for i in range(BATCHS):
        input_tensor = _get_random_inputs(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()
        torch.cuda.synchronize(device)

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        torch.cuda.synchronize(device)

        check_same_model_params(
            sharded_ddp_model,
            ddp_model,
            f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()
def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)

    def check_same_model_params(same_params: bool):
        # Check that all the params are the same on all ranks
        # This should be true with and without broadcast_buffers, we don't have any real buffer here
        receptacle: List[torch.Tensor] = []

        if dist.get_backend() != "nccl":
            for pg in optimizer.param_groups:
                for p in pg["params"]:
                    # Check the params
                    receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
                    dist.gather(p, receptacle, dst=0)
                    if rank == 0:
                        for sync_p in receptacle[1:]:
                            if same_params:
                                assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
                            else:
                                assert not torch.all(
                                    torch.eq(receptacle[0], sync_p)
                                ), "Gradients should not have been synced"

            # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
            if broadcast_buffers:
                for b in ddp_model.buffers():
                    receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
                    dist.gather(b, receptacle, dst=0)
                    if rank == 0:
                        for sync_b in receptacle[1:]:
                            if same_params:
                                assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
                            else:
                                assert not torch.all(
                                    torch.eq(receptacle[0], sync_b)
                                ), "Gradients should not have been synced"

                    assert b.cpu().item() == 0.0

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_model_params(same_params=True)

    # Optim loop
    def closure():
        optimizer.zero_grad()

        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)

        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or grad_accumulation
        check_same_model_params(same_params=same_params)
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}"
        + f" - change train graph {change_train_graph}"
        + f" - amp {amp}"
        + f" - manual reduction {manual_reduction}"
        + f" - buffers {reduce_buffer_size}"
        + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norms are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0, error_if_nonfinite=False
                )  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore

            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(zip(ddp_model.parameters(), sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank, idx, torch.norm(p_ddp.grad), torch.norm(p_sdp.grad), flush=True)
                            else:
                                print(rank, idx, torch.norm(p_ddp.grad), "not owned", flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()