def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
    """Update the consolidated state_dict list, one per rank.

    Arguments:
        recipient_rank (int): on which rank to materialize the full state dict.
            -1 is a special value, which means that all ranks should have the state

    .. warning:: This needs to be called on all replicas"""

    # Sync lr and other attributes in case it's been updated
    OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Pull the sharded state from all the other replicas
    # Store all the states in order, rank by rank
    logging.debug("Pulling the sharded optimizer state from all replicas")
    self._all_states = []
    should_collect_state = self.rank == recipient_rank or recipient_rank == -1
    should_send_state = self.rank != recipient_rank

    # NCCL requires CUDA tensors for all communication primitives
    dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device

    for rank in range(self.world_size):
        if rank == self.rank:
            if should_collect_state:
                logging.debug("Saving self state")
                self._all_states.append(
                    recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
                )

            # Sync with other replicas
            state_to_share = (
                self.optim.state_dict()
                if should_send_state
                else torch.tensor([0], dtype=torch.uint8, device=dist_device)
            )
            broadcast_object(
                state_to_share,
                src_rank=self.global_rank,
                group=self.group,
                dist_device=dist_device,
            )
        else:
            # Fetch the optim state from the other replicas
            replica_state = broadcast_object(
                torch.tensor([0], dtype=torch.uint8, device=dist_device),
                src_rank=self._local_to_global_rank[rank],
                group=self.group,
                dist_device=dist_device,
            )

            if should_collect_state:
                self._all_states.append(
                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                )

            logging.debug("State from rank %s received", rank)
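# --- Usage sketch (illustrative, not part of the implementation above) ---
# A minimal example, assuming a standard torch.distributed setup, of how
# consolidate_state_dict() is typically driven: every rank calls it (it runs
# collective broadcasts internally), and only the recipient rank materializes
# and saves the full optimizer state. The model and path below are placeholders.
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS


def save_full_optimizer_state(model: torch.nn.Module, path: str) -> None:
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
    # ... training steps would happen here ...
    optimizer.consolidate_state_dict(recipient_rank=0)  # collective: call on every rank
    if dist.get_rank() == 0:
        # state_dict() now contains the consolidated, rank-ordered shards copied to CPU
        torch.save(optimizer.state_dict(), path)
    dist.barrier()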
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore the global parameter groups as well as the shard.

    Arguments:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`
    """

    # Update the state, trusting the ordering in param_groups
    # Apart from the removal of states not owned by this rank, the PyTorch logic is kept
    # (See torch.optim.optimizer)
    id_map = {
        old_id: p
        for old_id, p in zip(
            chain.from_iterable((g["params"] for g in state_dict["param_groups"])),
            chain.from_iterable((g["params"] for g in self.param_groups)),
        )
    }

    for key, value in state_dict["state"].items():
        param = id_map[key]

        # Populate the sharded optimizer state on the fly,
        # remove the params that this rank does not own
        if self._param_to_rank[param] != self.rank:
            state_dict["state"][key] = {}
        else:
            self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)

    super().load_state_dict(state_dict)

    # Sync with the optimizer param groups
    OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
    OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
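# --- Restore sketch (illustrative; reuses the imports from the sketch above) ---
# Counterpart to the save sketch: every rank loads the same consolidated
# checkpoint, and load_state_dict() keeps only the shard owned by the local rank
# before re-syncing the param groups. The path is a placeholder.
def restore_full_optimizer_state(optimizer: OSS, path: str) -> None:
    # Load on CPU; the owned shard is copied to the owning parameters' device
    full_state = torch.load(path, map_location="cpu")
    optimizer.load_state_dict(full_state)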
def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
    """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
    # Establish reference behavior.
    if transformer:
        unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
        fsdp = self.get_wrapped_model(group, config=config).cuda()
    else:
        unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
        fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()

    try:
        fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01)
        optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
    except TypeError:  # Adadelta
        fsdp_optim = optim_fn(fsdp.parameters())
        optim_unwrapped = optim_fn(unwrapped_model.parameters())

    fsdp_optim.zero_grad()
    optim_unwrapped.zero_grad()

    with torch.cuda.amp.autocast(enabled=True):
        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        fsdp_optim.step()

        output = unwrapped_model(*x)
        loss = unwrapped_model.get_loss(x, output)
        unwrapped_model.run_backward(loss)
        optim_unwrapped.step()

    unwrapped_sd = optim_unwrapped.state_dict()

    if not transformer:
        no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
        assert len(no_broadcast_children) == 1
        assert fsdp._fsdp_instances[-1].no_broadcast_optim_state

    torch.cuda.empty_cache()
    cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3

    tstart = time()
    sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
    duration = time() - tstart
    assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

    cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3
    mem_usg_gb = cuda_gb_after - cuda_gb_before
    assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
    assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

    if fsdp.rank > 0:
        assert sd is None
        return

    # assert whole state dict on CPU
    for k, v in sd["state"].items():
        for buffer_name, t in v.items():
            if torch.is_tensor(t):
                msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                assert t.device == torch.device("cpu"), msg

    unflat_state = sd["state"]
    assert "uncollected_local_ids" in sd

    shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
    shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")

    state_after_get_shard = sd["state"]
    assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.

    assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
    assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
    assert_equal(
        sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
        sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
    )

    original_shard_sd = fsdp_optim.state_dict()
    assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
    assert_equal(shard_sd.keys(), original_shard_sd.keys())
    original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

    # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
    assert_equal(
        [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
        [first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
    )
    assert_equal(
        [v for k, v in shard_sd["param_groups"][0].items()],
        [v for k, v in original_shard_sd["param_groups"][0].items()],
    )
    assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
    assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)
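# --- Checkpoint round-trip sketch (illustrative; mirrors what the test above checks) ---
# Assuming only the FSDP APIs exercised by the test: gather_full_optim_state_dict()
# is collective and returns the full dict on the recipient rank (None elsewhere),
# while get_shard_from_optim_state_dict() recovers the local shard that the wrapped
# optimizer's load_state_dict() expects. The path is a placeholder.
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel


def save_and_restore_fsdp_optim_state(
    fsdp: FullyShardedDataParallel, optim: torch.optim.Optimizer, path: str
) -> None:
    full_sd = fsdp.gather_full_optim_state_dict(optim, recipient_rank=0)  # call on every rank
    if fsdp.rank == 0:
        torch.save(full_sd, path)
    dist.barrier()

    # On restore, each rank reads the full dict and extracts only its own shard
    full_sd = torch.load(path, map_location="cpu")
    shard_sd = fsdp.get_shard_from_optim_state_dict(full_sd)
    optim.load_state_dict(shard_sd)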