import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

# Assumed import path: the early fairscale API in which ShardedDataParallel
# wraps the module together with a sharded (OSS) optimizer.
from fairscale.nn.data_parallel import ShardedDataParallel


def run_one_step(rank, world_size, backend, device, temp_file_name):
    """Check that the reduce step populates the local shard's gradients and that one optimizer step lowers the loss."""
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.1, "momentum": 0.99},
        world_size=world_size,
    )
    optimizer = ddp.optimizer

    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    if device == torch.device("cuda"):
        torch.cuda.synchronize()  # flush any remaining cuda op, just in case
    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.requires_grad:
                assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

    # Check that the optimization process makes sense (ie. loss goes down for the same data)
    optimizer.step()
    new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
    assert new_eval.item() < output.item(), "The loss should decrease after one optimizer step"

    dist.destroy_process_group()
def run_one_step_grad_values(rank, world_size, backend, device, temp_file_name):
    """Check the exact values of the reduced gradients, and that buffers are synced from rank 0."""
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3)).to(device)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)

    def weights_init(m):
        if isinstance(m, Linear):
            torch.nn.init.constant_(m.weight.data, 1.0)
            torch.nn.init.constant_(m.bias.data, 1.0)

    model.apply(weights_init)
    model.to(device)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.01, "momentum": 0.99},
        world_size=world_size,
        broadcast_buffers=True,
    )
    optimizer = ddp.optimizer
    model = ddp.module

    # Different input per rank, allows for checking that the gradients have been properly reduced
    input_tensor = (torch.ones((64, 2)) * rank).to(device)
    output = ddp(input_tensor).abs().sum()
    output.backward()
    ddp.reduce()

    # Check that the grads on the shard hold the expected averaged values
    # (known in closed form for world_size == 2 and the constant init above)
    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.shape == torch.Size([3, 2]):
                assert param.grad[0, 0].cpu() == torch.tensor([32.0])
            if param.shape == torch.Size([3]):
                assert param.grad[0].cpu() == torch.tensor([64.0])

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    for b in model.buffers():
        assert b.cpu().item() == 0.0

    dist.destroy_process_group()
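# The values asserted above follow from the constant init: with weights and biases
# set to 1 and per-rank input rank * ones(64, 2), each weight gradient entry is
# 64 * rank and each bias gradient entry is 64, so averaging over two ranks gives
# 32.0 and 64.0. A minimal single-process sanity check, assuming world_size == 2
# (this sketch is illustrative and not part of the original tests):

def check_grad_values_locally():
    """Reproduce the averaged gradients of run_one_step_grad_values by hand, without distributed setup."""

    def manual_grads(rank):
        m = Linear(2, 3)
        torch.nn.init.constant_(m.weight.data, 1.0)
        torch.nn.init.constant_(m.bias.data, 1.0)
        m(torch.ones((64, 2)) * rank).abs().sum().backward()
        return m.weight.grad, m.bias.grad

    w0, b0 = manual_grads(0)
    w1, b1 = manual_grads(1)
    assert ((w0 + w1) / 2)[0, 0] == 32.0  # matches the weight gradient checked above
    assert ((b0 + b1) / 2)[0] == 64.0  # matches the bias gradient checked above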
def run_one_step_buffer_sync(rank, world_size, backend, device, temp_file_name):
    """Check that the reduce step populates the local shard's gradients and that buffers are synced from rank 0."""
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.01, "momentum": 0.99},
        world_size=world_size,
        broadcast_buffers=True,
    )
    optimizer = ddp.optimizer
    model = ddp.module

    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard
    if device == torch.device("cuda"):
        torch.cuda.synchronize()  # flush any remaining cuda op, just in case
    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.requires_grad:
                assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    for b in model.buffers():
        assert b.cpu().item() == 0.0

    dist.destroy_process_group()
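# All three step functions share the launch signature
# (rank, world_size, backend, device, temp_file_name), so a single driver can
# spawn each of them. A sketch, assuming a CPU run with the gloo backend and
# two processes; spawn_test and its defaults are illustrative, not part of the
# original tests:

import tempfile

import torch.multiprocessing as mp


def spawn_test(fn, world_size=2, backend="gloo", device=torch.device("cpu")):
    # mkstemp() returns (fd, path); the path serves as the file:// rendezvous
    # for init_process_group, fresh for each test so ranks cannot collide.
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(fn, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


if __name__ == "__main__":
    check_grad_values_locally()
    for test in (run_one_step, run_one_step_grad_values, run_one_step_buffer_sync):
        spawn_test(test)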