def test_state_dict(self):
    """Check that the ZeroRedundancyOptimizer exposes the expected state dict
    interface, irrespective of the sharding.
    """
    self.dist_init(self.rank)
    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.1, momentum=0.9)
    x.backward()
    o.step()
    self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.0], device=DEVICE))

    o.zero_grad()
    o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
    state_dict = o.state_dict()

    # Check that the state dict is pytorch-compliant key wise
    self.assertIn("param_groups", state_dict.keys())
    self.assertIn("state", state_dict.keys())

    # Check that the pulled state is what we expect, and that we have all the expected keys
    self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
    self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
    self.assertFalse(state_dict["param_groups"][0]["nesterov"])
    self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
    self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

    # Check that the pulled state and the .param_groups attribute are in sync
    for k in state_dict["param_groups"][0].keys():
        if k != "params":
            self.assertEqual(state_dict["param_groups"][0][k], o.param_groups[0][k])

    # Check that it's correctly loaded
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
    o.load_state_dict(state_dict)

    # Check that state is correct and on proper device
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.0], device=DEVICE))

    # We should now be using a lr of 0.1, both within the optimizer
    # and as exposed by the .param_groups attribute
    assert o.param_groups[0]["lr"] == 0.1
    x.backward()
    o.step()
    self.assertEqual(x, torch.tensor([0.71], device=DEVICE))
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.9], device=DEVICE))

    # Check that the exposed param_groups are on the proper device
    self.assertEqual(o.param_groups[0]["params"][0].device, x.device)

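# Why the assertions above expect x == 0.9 and then x == 0.71: with SGD(lr=0.1, momentum=0.9)
# the gradient of x w.r.t. itself is 1.0 on every backward pass, so the first step sets the
# momentum buffer to 1.0 and moves x to 1.0 - 0.1 * 1.0 = 0.9; after reloading that state,
# the second step sets the buffer to 0.9 * 1.0 + 1.0 = 1.9 and moves x to 0.9 - 0.1 * 1.9 = 0.71.
# A single-process sketch of the same arithmetic with plain SGD (illustrative only, not part
# of the test):
def _sgd_momentum_reference_sketch():
    x = torch.tensor([1.0], requires_grad=True)
    opt = SGD([x], lr=0.1, momentum=0.9)
    x.backward()  # d(x)/dx == 1.0
    opt.step()    # buffer = 1.0, x = 1.0 - 0.1 * 1.0 = 0.9
    opt.zero_grad()
    x.backward()  # gradient is 1.0 again
    opt.step()    # buffer = 0.9 * 1.0 + 1.0 = 1.9, x = 0.9 - 0.1 * 1.9 = 0.71
    return x      # tensor([0.7100])
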
def test_collect_shards(self):
    """Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer."""
    self.dist_init(self.rank)
    RECIPIENT_RANK = 0

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=self.device)
    inputs = torch.rand((batch, input_width), device=self.device)
    model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
    model.to(self.device)
    loss_fn = torch.nn.L1Loss()
    loss_fn.to(self.device)

    # With SGD, momentum is required to get a state to shard
    optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=SGD, lr=0.1, momentum=0.99)

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        loss.backward()
        return loss

    _ = optimizer.step(closure=closure)

    # Update the optimizer state on the reference rank
    optimizer.consolidate_state_dict(to=RECIPIENT_RANK)

    # Fetch the state on the reference rank
    # - check that it has the correct size
    # - load it again
    if self.rank == RECIPIENT_RANK:
        optimizer_state_dict = optimizer.state_dict()
        self.assertEqual(len(optimizer_state_dict["state"]), len(list(model.parameters())))
    else:
        optimizer_state_dict = {}

    optimizer_state_dict = _broadcast_object(
        optimizer_state_dict,
        src_rank=RECIPIENT_RANK,
        group=dist.group.WORLD,
        device=self.device,
    )

    # Load the optimizer state dict, check that no exception is raised
    optimizer.load_state_dict(optimizer_state_dict)

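# The `_broadcast_object` helper used in `test_collect_shards` above is not defined in
# this excerpt. As a rough, purely illustrative sketch (an assumption, not the actual
# PyTorch helper), it could pickle the object on the source rank and broadcast the
# resulting byte buffer to every other rank:
def _broadcast_object_sketch(obj, src_rank, group, device):
    import io

    if dist.get_rank() == src_rank:
        # Serialize the object, then broadcast first its length and then its payload
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = torch.tensor(list(buffer.getvalue()), dtype=torch.uint8, device=device)
        length = torch.tensor([data.numel()], dtype=torch.long, device=device)
        dist.broadcast(length, src=src_rank, group=group)
        dist.broadcast(data, src=src_rank, group=group)
    else:
        # Receive the length first so that a matching buffer can be allocated
        length = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(length, src=src_rank, group=group)
        data = torch.empty(int(length.item()), dtype=torch.uint8, device=device)
        dist.broadcast(data, src=src_rank, group=group)
        obj = torch.load(io.BytesIO(data.cpu().numpy().tobytes()), map_location=device)
    return obj
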
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
    # Any model works. Add one different buffer per rank
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Linear(3, 3),
        torch.nn.Linear(3, 3),
    )
    model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
    model.to(self.device)

    sharded_optimizer = ZeroRedundancyOptimizer(
        params=model.parameters(), optimizer_class=optimizer, lr=1e-3
    )
    sharded_ddp_model = DDP(
        module=model,
        device_ids=[self.rank],
        broadcast_buffers=True,
        find_unused_parameters=True,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_model_single.to(self.device)

    ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
    ddp_model = DDP(
        ddp_model_single,
        device_ids=[self.rank],
        broadcast_buffers=True,
        find_unused_parameters=True,
    )

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model, "Models differ from the start")

    def check_step():
        input_tensor = torch.rand((64, 2))

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim
        ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

        check_same_model_params(sharded_ddp_model, ddp_model, "Models differ after a step")

    # The models should stay the same in between the ranks
    for i in range(BATCHS):
        check_step()

        # Change the model's trainability, check that parity is maintained;
        # only check after a couple of constant batches to go through both regimes
        if i > BATCHS // 2:
            next(ddp_model.parameters()).requires_grad = bool(i % 2)
            next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

    # Check that the checkpoints are compatible
    reference_rank = 0
    # - get states
    ddp_state_dict = ddp_optimizer.state_dict()
    sharded_optimizer.consolidate_state_dict(to=reference_rank)
    sharded_optim_state_dict = [
        sharded_optimizer.state_dict() if self.rank == reference_rank else {}
    ]
    dist.broadcast_object_list(
        sharded_optim_state_dict, src=reference_rank, group=dist.group.WORLD
    )
    sharded_optim_state_dict = sharded_optim_state_dict[0]

    # - cross load the states
    #   run one step and check that the models are still the same
    ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
    ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mix up on purpose!
    sharded_optimizer.load_state_dict(ddp_state_dict)
    check_step()

    # - self load, rewind, check no problem
    #   run one step and check that the models are still the same
    ddp_optimizer.load_state_dict(ddp_state_dict_ref)
    sharded_optimizer.load_state_dict(sharded_optim_state_dict)
    check_step()

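# `check_optimizer_equivalence` above relies on a module-level
# `check_same_model_params(model_a, model_b, message)` helper that is not part of this
# excerpt. A plausible sketch, adapted from the inline version kept in the older variants
# below (the exact signature is assumed):
def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    # Parameters are compared with a small tolerance to absorb floating point noise
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ:\n{p_a} vs {p_b}\n" + message
    # Buffers are broadcast by DDP at construction, so they should match exactly
    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Model buffers differ:\n{b_a} vs {b_b}\n" + message
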
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
    # Any model works. Add one different buffer per rank
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Linear(3, 3),
        torch.nn.Linear(3, 3),
    )
    model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
    model.to(self.device)

    sharded_optimizer = ZeroRedundancyOptimizer(params=model.parameters(), optim=optimizer, lr=1e-3)
    sharded_ddp_model = DDP(module=model, device_ids=[self.rank], broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_model_single.to(self.device)

    ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
    ddp_model = DDP(ddp_model_single, device_ids=[self.rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(
                b, ddp_b
            ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params()

    # The models should stay the same across multiple steps, losses should stay the same
    def check_step():
        input_tensor = torch.rand((64, 2))

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim
        ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

        check_same_model_params()

    for i in range(20):
        check_step()

    # Test state dict save/load/equivalence with pytorch
    # - save state for both
    sharded_optimizer.consolidate_state_dict()
    sharded_optimizer_state_dict = (
        sharded_optimizer.state_dict() if self.rank == RECIPIENT_RANK else torch.zeros(1)
    )
    ddp_state_dict = ddp_optimizer.state_dict()

    # - sync the saved state with all the ranks
    exchange_list = [sharded_optimizer_state_dict]
    dist.broadcast_object_list(
        exchange_list,
        src=RECIPIENT_RANK,
        group=dist.group.WORLD,
    )
    sharded_optimizer_state_dict = exchange_list[0]

    # - cross load the states
    ddp_optimizer.load_state_dict(sharded_optimizer_state_dict)
    sharded_optimizer.load_state_dict(ddp_state_dict)

    # - run one step, and check that the models are still the same
    check_step()

def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
    # Any model works. Add one different buffer per rank
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Linear(3, 3),
        torch.nn.Linear(3, 3),
    )
    model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
    model.to(self.device)

    sharded_optimizer = ZeroRedundancyOptimizer(params=model.parameters(), optim=optimizer, lr=1e-3)
    sharded_ddp_model = DDP(module=model, device_ids=[self.rank], broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_model_single.to(self.device)

    ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
    ddp_model = DDP(ddp_model_single, device_ids=[self.rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(
                b, ddp_b
            ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params()

    def check_step():
        input_tensor = torch.rand((64, 2))

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim
        ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

        check_same_model_params()

    # The models should stay the same in between the ranks
    for i in range(20):
        check_step()

    # Check that the checkpoints are compatible
    reference_rank = 0
    # - get states
    ddp_state_dict = ddp_optimizer.state_dict()
    sharded_optimizer.consolidate_state_dict(recipient_rank=reference_rank)
    sharded_optim_state_dict = [
        sharded_optimizer.state_dict() if self.rank == reference_rank else {}
    ]
    dist.broadcast_object_list(sharded_optim_state_dict, src=reference_rank, group=dist.group.WORLD)
    sharded_optim_state_dict = sharded_optim_state_dict[0]

    # - cross load the states
    #   run one step and check that the models are still the same
    ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
    ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mix up on purpose!
    sharded_optimizer.load_state_dict(ddp_state_dict)
    check_step()

    # - self load, rewind, check no problem
    #   run one step and check that the models are still the same
    ddp_optimizer.load_state_dict(ddp_state_dict_ref)
    sharded_optimizer.load_state_dict(sharded_optim_state_dict)
    check_step()