def check(optimizer, model):
    # Just run a couple of epochs, check that the model is properly updated
    for _ in range(epochs):
        target = torch.rand((batch, target_width), device=device)
        inputs = torch.rand((batch, input_width), device=device)

        def closure():
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            return loss

        _ = optimizer.step(closure=closure)

    # Check that all the params are the same on all ranks
    check_same_models_across_ranks(
        model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
    )
def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
    )

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_models_across_ranks(
        ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
    )

    # Optim loop
    def closure():
        optimizer.zero_grad()
        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)

        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or grad_accumulation
        check_same_models_across_ranks(
            ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
        )

    dist.destroy_process_group()
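# A minimal launcher sketch for the worker above, assuming the usual torch.multiprocessing
# entry point; the helper name `run_sharded_ddp_workers`, the two-process world size and the
# reduce buffer size are illustrative and not part of the original test suite.
import tempfile

import torch.multiprocessing as mp


def run_sharded_ddp_workers(world_size: int = 2) -> None:
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Every worker rendezvouses on the same file via the "file://" init method
    temp_file = tempfile.NamedTemporaryFile(delete=False)

    # mp.spawn passes the rank as the first positional argument, the tuple below fills the rest
    mp.spawn(
        run_one_step,
        args=(world_size, backend, device, temp_file.name, True, False, 2 ** 20),
        nprocs=world_size,
        join=True,
    )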
def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
    dist_init(rank, world_size, tempfile_name)
    device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)  # guard so that the CPU fallback path does not crash

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 3, 3, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
    model.to(device)

    loss_fn = torch.nn.L1Loss()
    loss_fn.to(device)

    # With SGD, Momentum is required to get a state to shard
    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        loss.backward()
        return loss

    _ = optimizer.step(closure=closure)

    # Update the optimizer state on the reference rank
    optimizer.consolidate_state_dict(recipient_rank=reference_rank)

    # Fetch the state on the reference rank
    # - check that it has the correct size
    # - load it again
    if rank == reference_rank:
        optimizer_state_dict = optimizer.state_dict()
        assert len(optimizer_state_dict["state"]) == len(list(model.parameters()))
    else:
        optimizer_state_dict = {}

    # distribute to the other ranks
    optimizer_state_dict = sync_object_ranks(optimizer_state_dict, reference_rank, device)

    # Load the optimizer state dict
    optimizer.load_state_dict(optimizer_state_dict)

    # Check that the loaded states are iterable dicts ({}), not None
    for state in optimizer.state.values():
        for _, _ in state.items():
            pass

    # Test the state dict materialization on all ranks
    _ = optimizer.step(closure=closure)
    optimizer_state_dict = optimizer.state_dict(all_ranks=True)  # one per rank
    optimizer.load_state_dict(optimizer_state_dict)
    _ = optimizer.step(closure=closure)
    check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)

    dist.destroy_process_group()
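# A minimal checkpointing sketch built on the same OSS calls exercised above; the function
# name and the default path are illustrative, and `rank` / `reference_rank` are assumed to
# come from the surrounding distributed setup, as in run_test_collect_shards.
def save_consolidated_checkpoint(optimizer, rank, reference_rank, path="oss_checkpoint.pt"):
    # Gather every rank's shard onto the reference rank first
    optimizer.consolidate_state_dict(recipient_rank=reference_rank)

    # Only the reference rank holds the full, consolidated state dict, so only it writes to disk
    if rank == reference_rank:
        torch.save(optimizer.state_dict(), path)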