def test_collect_shards(self):
    """Check the state consolidation mechanism and the state dict exposed
    by ZeroRedundancyOptimizer."""
    self.dist_init(self.rank)
    RECIPIENT_RANK = 0

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=self.device)
    inputs = torch.rand((batch, input_width), device=self.device)
    model = torch.nn.Sequential(
        torch.nn.Linear(input_width, hidden),
        torch.nn.Linear(hidden, target_width),
    )
    model.to(self.device)
    loss_fn = torch.nn.L1Loss()
    loss_fn.to(self.device)

    # With SGD, momentum is required to get a state to shard
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(), optimizer_class=SGD, lr=0.1, momentum=0.99
    )

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        loss.backward()
        return loss

    _ = optimizer.step(closure=closure)

    # Update the optimizer state on the reference rank
    optimizer.consolidate_state_dict(to=RECIPIENT_RANK)

    # Fetch the state on the reference rank:
    # - check that it has the correct size
    # - load it again
    if self.rank == RECIPIENT_RANK:
        optimizer_state_dict = optimizer.state_dict()
        self.assertEqual(
            len(optimizer_state_dict["state"]), len(list(model.parameters()))
        )
    else:
        optimizer_state_dict = {}

    optimizer_state_dict = _broadcast_object(
        optimizer_state_dict,
        src_rank=RECIPIENT_RANK,
        group=dist.group.WORLD,
        device=self.device,
    )

    # Load the optimizer state dict and check that no exception is raised
    optimizer.load_state_dict(optimizer_state_dict)
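# Illustrative sketch (not part of the original test): the checkpointing flow
# that test_collect_shards exercises. The file name below is hypothetical.
#
#   optimizer.consolidate_state_dict(to=0)          # gather all shards on rank 0
#   if dist.get_rank() == 0:
#       torch.save(optimizer.state_dict(), "zero_ckpt.pt")   # full state lives only on rank 0
#   ...
#   optimizer.load_state_dict(torch.load("zero_ckpt.pt"))    # each rank keeps only its shard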
def _test_zero_join(self, device):
    r"""Check that the ZeRO join hook allows training with uneven inputs
    when using the given device.

    Arguments:
        device (torch.device): device used to store parameters and to
            perform collective communications.
    """
    NUM_INPUTS = 3
    NUM_EPOCHS = 2
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    rank = self.rank
    world_size = self.world_size
    is_gpu = device.type == "cuda"
    backend = dist.Backend.NCCL if is_gpu else dist.Backend.GLOO
    self.dist_init(rank, world_size, backend)
    if backend == dist.Backend.NCCL and is_gpu:
        torch.cuda.set_device(self.device)

    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Linear(3, 3),
        torch.nn.Linear(3, 3),
    )
    model.to(device)

    # DDP ensures correct gradients in data-parallel training, so DDP with
    # local optimizers on uneven inputs should be equivalent to ZeRO on
    # uneven inputs with gradients being set manually
    ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
    local_optim = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
    zero_model = copy.deepcopy(model)
    zero_model.to(device)
    zero_optim = ZeroRedundancyOptimizer(
        zero_model.parameters(), torch.optim.Adam, lr=0.01
    )
    loss_fn = torch.nn.MSELoss()

    # Use uneven inputs: rank i has i extra inputs
    inputs = [torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)]
    labels = torch.randn(20, 3).to(device)

    # Save the gradients and parameters from DDP as the ground truth; do so
    # on the last-joining rank (in this case, the largest rank)
    grads_at_each_iter = []
    params_at_each_iter = []
    with ddp_model.join():
        for _ in range(NUM_EPOCHS):
            for input in inputs:
                output = ddp_model(input)
                loss_fn(output, labels).backward()
                if rank == world_size - 1:
                    grads = []
                    for p in ddp_model.parameters():
                        grads.append(p.grad.detach().clone().to(device))
                local_optim.step()
                if rank == world_size - 1:
                    params = []
                    for p in ddp_model.parameters():
                        params.append(p.detach().clone().to(device))
                    grads_at_each_iter.append(grads)
                    params_at_each_iter.append(params)

    # Broadcast the saved gradients and parameters to all of the other
    # ranks (which joined early)
    grads_and_params = [grads_at_each_iter, params_at_each_iter]
    grads_and_params = _broadcast_object(
        grads_and_params,
        src_rank=world_size - 1,
        group=dist.group.WORLD,
        device=device,
    )
    grads_at_each_iter = grads_and_params[0]
    params_at_each_iter = grads_and_params[1]
    # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
    # once the latter supports loading to the destination device instead
    # of the source device

    # A process must still set the remaining gradients after joining, so we
    # define a join hook to do this before the ZeRO join hook
    class _JoinGradInfo():
        def __init__(self, grads):
            self.grads = grads  # remaining gradients to set (in order)
            self.index = 0

    class _SetGradsJoinHook(JoinHook):
        def __init__(self, zero_optim, grads):
            zero_optim._join_grad_info = _JoinGradInfo(grads)
            self.zero = zero_optim
            super().__init__()

        def main_hook(self):
            grads = self.zero._join_grad_info.grads[self.zero._join_grad_info.index]
            self.zero._join_grad_info.index += 1
            for p, grad in zip(self.zero._all_params, grads):
                p.grad = grad.detach().clone().to(device)

    class _GradientSetter(Joinable):
        def __init__(self):
            super().__init__()

        def join_hook(self, **kwargs):
            assert "zero_optim" in kwargs
            assert "grads" in kwargs
            zero_optim = kwargs["zero_optim"]
            grads = kwargs["grads"]
            return _SetGradsJoinHook(zero_optim, grads)

        @property
        def join_device(self):
            return device

        @property
        def join_process_group(self):
            return dist.group.WORLD

    num_grads_after_joining = NUM_EPOCHS * (world_size - rank - 1)
    grads = grads_at_each_iter[-num_grads_after_joining:]
    gradient_setter = _GradientSetter()
    iter = 0
    with Join(
        [gradient_setter, zero_optim],
        zero_optim=zero_optim,
        grads=grads,
    ):
        for _ in range(NUM_EPOCHS):
            for input in inputs:
                # Notify the join context that this process has not yet joined
                Join.notify_join_context(gradient_setter)

                # Set the gradients manually
                for p, grad in zip(zero_model.parameters(), grads_at_each_iter[iter]):
                    p.grad = grad.detach().clone().to(device)

                # Perform the optimizer step and check parity
                zero_optim.step()
                for p, ddp_p in zip(zero_model.parameters(), params_at_each_iter[iter]):
                    assert torch.allclose(p, ddp_p), \
                        "Parameters differ between using ZeRO and local optimizer"
                iter += 1
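# Illustrative only (not from the original source): wrappers of roughly this
# form could drive `_test_zero_join` on CPU and GPU. The method names and the
# GPU guard are assumptions; any skip/requires decorators used upstream are
# omitted here.
def test_zero_join_cpu(self):
    """Check that the ZeRO join hook allows training with uneven inputs on CPU."""
    self._test_zero_join(torch.device("cpu"))

def test_zero_join_gpu(self):
    """Check that the ZeRO join hook allows training with uneven inputs on GPU."""
    if not torch.cuda.is_available():
        return  # assumption: skip when no GPU is available
    self._test_zero_join(torch.device("cuda", self.rank))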