def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    n_half_params = len(list(model.parameters())) // 2

    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    # The models should stay the same in between the ranks
    check_same_model_params()

    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()

        check_same_model_params()

    dist.destroy_process_group()
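# The helper below is not part of the original tests: it is a minimal sketch of how a per-rank
# function such as run_ddp_parity_two_optim could be launched locally, assuming a CUDA-capable
# machine and the NCCL backend. The temp file only serves as the rendezvous point for
# init_process_group.
def _spawn_parity_test(world_size: int = 2) -> None:
    import tempfile

    import torch.multiprocessing as mp

    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(run_ddp_parity_two_optim, args=(world_size, "nccl", temp_file_name), nprocs=world_size, join=True)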
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}"
        + f" - change train graph {change_train_graph}"
        + f" - amp {amp}"
        + f" - manual reduction {manual_reduction}"
        + f" - buffers {reduce_buffer_size}"
        + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norms are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0, error_if_nonfinite=False
                )  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore

            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(zip(ddp_model.parameters(), sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank, idx, torch.norm(p_ddp.grad), torch.norm(p_sdp.grad), flush=True)
                            else:
                                print(rank, idx, torch.norm(p_ddp.grad), "not owned", flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()
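# check_same_model_params(model_a, model_b, message="") is used throughout these tests but is defined
# elsewhere in the test module. The version below is only an illustrative sketch of the expected
# behaviour (pairwise allclose over parameters and buffers), not the original helper.
def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ in between the two models\n{message}"

    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Model buffers differ in between the two models\n{message}"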
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    BATCHS = 20

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
    optim_settings = {"lr": 1e-3, "momentum": 0.99}

    sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, **optim_settings)
    sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, **optim_settings)

    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    for i in range(BATCHS):
        input_tensor = _get_random_inputs(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()
        torch.cuda.synchronize(device)

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        torch.cuda.synchronize(device)

        check_same_model_params(
            sharded_ddp_model,
            ddp_model,
            f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()
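# _get_mlp_emb() and _get_random_inputs() are helpers from the surrounding test module. The
# placeholders below are assumptions kept deliberately simple (a plain MLP and random float inputs);
# the real helpers may include an embedding path and a multiple-forward variant.
def _get_mlp_emb(multiple_fw: bool = False) -> torch.nn.Module:
    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))


def _get_random_inputs(device: torch.device) -> torch.Tensor:
    # Illustrative batch shape only; it just has to match the first Linear layer above
    return torch.rand((8, 2), device=device)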
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer], change_train_graph: bool = False):
    # Any model works. Add one different buffer per rank
    trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
    trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
    trunk.to(device)

    head = torch.nn.Linear(hidden, out_channels).to(device)

    # Define a model to be trained by OSS
    oss_module = torch.nn.Sequential(trunk, head)

    oss_trainable_params = [
        {"params": trunk.parameters(), "lr": 1e-5},
        {"params": head.parameters(), "lr": 1e-4},
    ]

    optimizer_settings: Dict[Any, Any] = {}
    if issubclass(optimizer, torch.optim.SGD):
        optimizer_settings["momentum"] = 0.9

    sharded_optimizer = optim.OSS(
        params=oss_trainable_params,
        optim=optimizer,
        group=None,
        broadcast_buffer_size=2**10,
        **optimizer_settings,
    )

    oss_ddp_model = DDP(module=oss_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    # Define a model to be trained by normal pytorch + DDP
    ddp_trunk = copy.deepcopy(trunk)
    ddp_head = copy.deepcopy(head)
    ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)

    ddp_trainable_params = [
        {"params": ddp_trunk.parameters(), "lr": 1e-5},
        {"params": ddp_head.parameters(), "lr": 1e-4},
    ]
    ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings)  # type: ignore
    ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    def check_step():
        input_tensor = torch.rand((batch, in_channels)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = oss_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim
        ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"

        check_same_model_params(oss_ddp_model, ddp_model)

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(oss_ddp_model, ddp_model)

    # The models should stay the same in between ddp and sharded optimizer
    for i in range(5):
        check_step()

        # Check that altering the trainable parameters does not cause DDP and OSS to diverge
        if change_train_graph:
            # Flip the first parameter from trainable to non-trainable and vice-versa
            next(ddp_module.parameters()).requires_grad = not next(ddp_module.parameters()).requires_grad
            next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
            # sharded_optimizer.refresh_trainable()

    # Check that the checkpoints are compatible
    # - get states
    ddp_state_dict = ddp_optimizer.state_dict()
    sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
    sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
    sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)

    # - cross load the states
    ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
    sharded_optimizer.load_state_dict(ddp_state_dict)

    # - run one step and check that the models are still the same
    check_step()
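# sync_object_ranks() broadcasts a picklable object (here the consolidated optimizer state) from
# RECIPIENT_RANK to every rank. The sketch below is an assumption, not the original helper: it relies
# on torch.distributed.broadcast_object_list (available since PyTorch 1.8), and RECIPIENT_RANK is
# assumed to be a small module-level constant such as 0.
RECIPIENT_RANK = 0


def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
    package = [something_to_sync]
    dist.broadcast_object_list(package, src=reference_rank)
    return package[0]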
def run_state_dict_distributed(rank, world_size, tempfile_name):
    dist_init(rank, world_size, tempfile_name, backend="gloo")

    device = torch.device(rank)
    torch.manual_seed(rank)  # make sure that the different ranks get different data

    # Setup two problems in parallel, we'll make sure that the second track (with save/load) follows the first one (untouched)
    # We split the model in two to test the multiple param groups support
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden)).to(device)
    head_oss1 = torch.nn.Linear(hidden, target_width).to(device)

    model_oss2 = copy.deepcopy(model_oss1)
    head_oss2 = copy.deepcopy(head_oss1)

    # For this test the gradients are (all) reduced in the same way in between the torch reference and fairscale.
    # Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
    # gradient norm computation from OSS and adds a dependency.
    # To keep the comparison apples-to-apples DDP is used in both cases
    model_oss1 = DDP(module=model_oss1, device_ids=[rank])
    sharded_optimizer1 = optim.OSS(model_oss1.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer1.add_param_group({"params": head_oss1.parameters()})

    model_oss2 = DDP(module=model_oss2, device_ids=[rank])
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})

    loss_fn = torch.nn.L1Loss().to(device)

    def run_grad_step(model, head, optimizer):
        model.zero_grad()
        outputs = head(model(inputs))
        # forward pass above, then loss, backward and optimizer step
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()

    # pull the current state, broadcast it to all ranks
    sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)  # all ranks
    state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
    state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)

    # re-create a new optimizer from scratch with absurd values, load the previous state
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
    sharded_optimizer2.load_state_dict(state_dict2)
    check_same_model_params(
        model_oss1, model_oss2, "parameters of the two identical models have diverged (before any steps)"
    )

    # now take a step and check that parameters are equal
    run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
    check_same_model_params(
        model_oss1, model_oss2, "parameters of the two identical models have diverged (after stepping)"
    )

    # save the state dict for one model only, then distribute to the other ranks
    sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)  # all ranks
    state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
    state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)

    # Check that the pulled state and the .param_groups attribute are in sync
    for replica in range(len(state_dict2["param_groups"])):
        for k in state_dict2["param_groups"][replica].keys():
            if k != "params":
                assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]

    # take a step
    run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
    check_same_model_params(
        model_oss1, model_oss2, "parameters of the two identical models have diverged (after consolidating)"
    )

    # save again for one rank, then distribute to the others
    sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)  # all ranks
    state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
    state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)

    # reload the state_dict
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
    sharded_optimizer2.load_state_dict(state_dict2)

    # take a step
    run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
    check_same_model_params(
        model_oss1, model_oss2, "parameters of the two identical models have diverged (after reloading)"
    )

    dist.destroy_process_group()
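# dist_init() is the rendezvous helper used by run_state_dict_distributed. The sketch below is an
# assumption: it simply wraps init_process_group with a file:// init_method, mirroring the inline
# initialization done by the other tests in this section.
def dist_init(rank: int, world_size: int, tempfile_name: str, backend: str = "nccl") -> None:
    url = "file://" + tempfile_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)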
def check_parity(manual_reduction: bool):
    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norms are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        if clip_grad_norm:
            total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
                assert torch.allclose(
                    oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)

    def check_same_model_params(same_params: bool):
        # Check that all the params are the same on all ranks
        # This should be true with and without broadcast_buffers, we don't have any real buffer here
        receptacle: List[torch.Tensor] = []

        if dist.get_backend() != "nccl":
            for pg in optimizer.param_groups:
                for p in pg["params"]:
                    # Check the params
                    receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
                    dist.gather(p, receptacle, dst=0)
                    if rank == 0:
                        for sync_p in receptacle[1:]:
                            if same_params:
                                assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
                            else:
                                assert not torch.all(
                                    torch.eq(receptacle[0], sync_p)
                                ), "Gradients should not have been synced"

            # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
            if broadcast_buffers:
                for b in ddp_model.buffers():
                    receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
                    dist.gather(b, receptacle, dst=0)
                    if rank == 0:
                        for sync_b in receptacle[1:]:
                            if same_params:
                                assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
                            else:
                                assert not torch.all(
                                    torch.eq(receptacle[0], sync_b)
                                ), "Gradients should not have been synced"

                    assert b.cpu().item() == 0.0

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_model_params(same_params=True)

    # Optim loop
    def closure():
        optimizer.zero_grad()

        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)
        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or grad_accumulation
        check_same_model_params(same_params=same_params)
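# A possible driver for the check() helper above, shown only as an assumption about how the enclosing
# per-rank test might exercise its two knobs; the original harness and parametrization may differ.
def _run_checks() -> None:
    for broadcast in (False, True):
        check(broadcast_buffers=broadcast)
    check(broadcast_buffers=True, grad_accumulation=True)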
def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        step()

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_ddp_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, accumulate)

        def closure_sharded(input_tensor=input_tensor):
            return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)

        # Step/scale both
        if ddp_scaler is not None:
            _ = closure_ddp(input_tensor)
            ddp_scaler.step(ddp_optimizer)
            ddp_scaler.update()
        else:
            ddp_optimizer.step(closure=closure_ddp)

        if sharded_ddp_scaler is not None:
            _ = closure_sharded(input_tensor)
            sharded_ddp_scaler.step(sharded_optimizer)
            sharded_ddp_scaler.update()
        else:
            sharded_optimizer.step(closure=closure_sharded)

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
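# A possible way for the enclosing test to sweep check_parity() over its three boolean knobs. This loop
# is an illustrative assumption, not the original parametrization; AMP is only exercised when CUDA is
# available.
def _sweep_parity_configs() -> None:
    for amp_ in (False, True):
        if amp_ and not torch.cuda.is_available():
            continue
        for accumulate_ in (False, True):
            for change_graph_ in (False, True):
                check_parity(amp=amp_, accumulate=accumulate_, change_train_graph=change_graph_)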