def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    """Smoke-test ShardedDataParallel over a two-input model driven by two OSS optimizers.

    The model's parameter list is split so that its last 10 tensors belong to one
    ``OSS`` optimizer and all preceding tensors to another. Both optimizers are
    registered with a single ``ShardedDataParallel`` wrapper. The test then runs
    five iterations, each of which zeroes both optimizers and calls ``step`` on
    each of them with a closure.

    Args:
        rank: rank of this process in the process group.
        world_size: total number of processes in the group.
        backend: ``torch.distributed`` backend name.
        device: device the model is moved to. The CUDA device is only pinned to
            ``rank`` when this equals ``torch.device("cuda")``.
        temp_file_name: path of the file used for the ``file://`` rendezvous.
    """
    url = f"file://{temp_file_name}"
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Per-rank seeds, so each rank draws different random batches.
    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    params = list(model.parameters())
    sgd_settings = dict(lr=1e-3, momentum=0.99)
    optimizers = (
        OSS(params=params[:-10], optim=torch.optim.SGD, **sgd_settings),
        OSS(params=params[-10:], optim=torch.optim.SGD, **sgd_settings),
    )
    ddp_model = ShardedDataParallel(model, list(optimizers))

    def closure():
        # Draw a fresh batch, feed it as both inputs, backprop the L1 loss and return it.
        batch = torch.rand((64, 2)).to(device)
        loss = ddp_model(batch, batch).abs().sum()
        loss.backward()
        return loss

    for _ in range(5):
        for optimizer in optimizers:
            optimizer.zero_grad()
        for optimizer in optimizers:
            optimizer.step(closure=closure)

    dist.destroy_process_group()
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    """Check that ShardedDDP driven by two OSS optimizers stays in lockstep with plain DDP.

    The model's parameters are split in two halves, and each half is driven by its
    own optimizer:

    - on the sharded side, two ``OSS`` instances, both registered with
      ``ShardedDataParallel``;
    - on the reference side, two ``torch.optim.SGD`` instances with DDP.

    A rank-dependent buffer is registered so that buffer broadcasting is also
    exercised. Every parameter (both halves) and every buffer is compared before
    training and after each of 20 steps.

    Args:
        rank: rank of this process in the process group.
        world_size: total number of processes in the group.
        backend: ``torch.distributed`` backend name.
        temp_file_name: path of the file used for the ``file://`` rendezvous.
    """
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    n_half_params = len(list(model.parameters())) // 2
    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    # Register BOTH sharded optimizers with the wrapper. Previously only the first
    # one was passed, so the second half of the parameters was not handled by
    # ShardedDDP.
    # NOTE(review): presumably this means the second half's gradients were never
    # reduced across ranks. The list form matches the later version of this test.
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
    )

    # The reference copy is taken after wrapping, so it starts from the same
    # weights the sharded wrapper ended up with.
    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        # Compare both parameter halves. Previously only the first optimizer pair
        # was checked, so a drift in the second half went unnoticed.
        optimizer_pairs = (
            (sharded_optimizer, ddp_optimizer),
            (sharded_optimizer_2, ddp_optimizer_2),
        )
        for sharded_optim, ddp_optim in optimizer_pairs:
            for pg, ddp_pg in zip(sharded_optim.param_groups, ddp_optim.param_groups):
                for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                    assert torch.allclose(
                        p, ddp_p, atol=1e-3
                    ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    check_same_model_params()  # The models should stay the same in between the ranks

    for _ in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()

        check_same_model_params()

    dist.destroy_process_group()
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    """Smoke-test ShardedDataParallel over a locally defined two-input MLP with two OSS optimizers.

    The model runs one shared 6-layer MLP on each of its two inputs and
    concatenates the outputs along dim 1. Its last 10 parameter tensors go to
    one ``OSS`` optimizer and the rest to another. Both optimizers are
    registered with ``ShardedDataParallel``, and five iterations each zero both
    optimizers and call ``step`` on each of them with a closure.

    Args:
        rank: rank of this process in the process group.
        world_size: total number of processes in the group.
        backend: ``torch.distributed`` backend name.
        device: device the model is moved to. The CUDA device is only pinned to
            ``rank`` when this equals ``torch.device("cuda")``.
        temp_file_name: path of the file used for the ``file://`` rendezvous.
    """
    dist.init_process_group(
        init_method=f"file://{temp_file_name}", backend=backend, rank=rank, world_size=world_size
    )
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInputNet(torch.nn.Module):
        """Apply one shared MLP to two inputs and concatenate the results on dim 1."""

        def __init__(self):
            super().__init__()
            self.mlp = Sequential(
                Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3)
            )

        def forward(self, x, y):
            return torch.cat([self.mlp(x), self.mlp(y)], dim=1)

    model = _DoubleInputNet().to(device)
    params = list(model.parameters())
    sgd_settings = dict(lr=0.01, momentum=0.99)
    optimizers = (
        OSS(params=params[:-10], optim=torch.optim.SGD, **sgd_settings),
        OSS(params=params[-10:], optim=torch.optim.SGD, **sgd_settings),
    )
    ddp_model = ShardedDataParallel(model, list(optimizers))

    def closure():
        # Draw a fresh batch, feed it as both inputs, backprop the L1 loss and return it.
        batch = torch.rand((64, 2)).to(device)
        loss = ddp_model(batch, batch).abs().sum()
        loss.backward()
        return loss

    for _ in range(5):
        for optimizer in optimizers:
            optimizer.zero_grad()
        for optimizer in optimizers:
            optimizer.step(closure=closure)

    dist.destroy_process_group()
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
    """Check ShardedDDP/DDP parity when the model is split across two optimizers.

    The model from ``_get_mlp_emb`` gets a rank-dependent buffer and is moved to
    this rank's GPU. Its parameter list is halved, and each half is driven by its
    own optimizer:

    - sharded side: two ``OSS`` instances, both registered with
      ``ShardedDataParallel`` using ``reduce_buffer_size``;
    - reference side: two ``torch.optim.SGD`` instances on a deep copy wrapped
      in DDP.

    The two wrapped models are compared with ``check_same_model_params`` at
    startup and after each of 20 training steps.

    Args:
        rank: rank of this process in the process group.
        world_size: total number of processes in the group.
        backend: ``torch.distributed`` backend name.
        temp_file_name: path of the file used for the ``file://`` rendezvous.
        reduce_buffer_size: forwarded unchanged to ``ShardedDataParallel``.
    """
    dist.init_process_group(
        init_method=f"file://{temp_file_name}", backend=backend, rank=rank, world_size=world_size
    )
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    n_steps = 20

    # Any model works; the per-rank buffer value exercises buffer broadcasting.
    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    sgd_kwargs = {"lr": 1e-3, "momentum": 0.99}
    model_params = list(model.parameters())
    half = len(model_params) // 2
    sharded_optims = [
        OSS(params=model_params[:half], optim=torch.optim.SGD, **sgd_kwargs),
        OSS(params=model_params[half:], optim=torch.optim.SGD, **sgd_kwargs),
    ]
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optims,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    # The reference copy is taken after wrapping, mirroring the sharded model's start state.
    reference_module = copy.deepcopy(model)
    reference_params = list(reference_module.parameters())
    ddp_optims = [
        torch.optim.SGD(reference_params[:half], **sgd_kwargs),
        torch.optim.SGD(reference_params[half:], **sgd_kwargs),
    ]
    ddp_model = DDP(reference_module, device_ids=[rank], broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    def _train_step(wrapped_model, optimizers, batch):
        # Zero every optimizer, run one forward/backward on the L1 loss, step every
        # optimizer in order, then wait for the GPU to finish.
        for optim in optimizers:
            optim.zero_grad()
        wrapped_model(batch).abs().sum().backward()
        for optim in optimizers:
            optim.step()
        torch.cuda.synchronize(device)

    for step in range(n_steps):
        batch = _get_random_inputs(device)
        _train_step(ddp_model, ddp_optims, batch)
        _train_step(sharded_ddp_model, sharded_optims, batch)
        check_same_model_params(
            sharded_ddp_model,
            ddp_model,
            f"DDP parity two optim test failing, step {step}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()