Example #1
    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
        """Update the consolidated state_dict list, one per rank.

        Arguments:
            recipient_rank (int): the rank on which to materialize the full state dict.
                -1 is a special value, meaning that all ranks should hold the full state.

        .. warning:: This needs to be called on all replicas"""

        # Sync lr and other attributes in case they have been updated
        OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

        # Pull the sharded state from all the other replicas
        # Store all the states in order, rank by rank
        logging.debug("Pulling the sharded optimizer state from all replicas")

        self._all_states = []
        should_collect_state = self.rank == recipient_rank or recipient_rank == -1
        should_send_state = self.rank != recipient_rank

        # NCCL requires CUDA tensors for all communication primitives
        dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device

        for rank in range(self.world_size):
            if rank == self.rank:
                if should_collect_state:
                    logging.debug("Saving self state")
                    self._all_states.append(
                        recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
                    )

                # Sync with other replicas
                state_to_share = (
                    self.optim.state_dict()
                    if should_send_state
                    else torch.tensor([0], dtype=torch.uint8, device=dist_device)
                )
                broadcast_object(
                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device,
                )
            else:
                # Fetch the optim state from the other replicas
                replica_state = broadcast_object(
                    torch.tensor([0], dtype=torch.uint8, device=dist_device),
                    src_rank=self._local_to_global_rank[rank],
                    group=self.group,
                    dist_device=dist_device,
                )

                if should_collect_state:
                    self._all_states.append(
                        recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                    )

                logging.debug("State from rank %s received", rank)
Example #2
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the global parameter groups as well as the shard.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`
        """

        # Update the state, trusting the ordering in param_groups
        # Apart from the removal of states not owned by this rank, the PyTorch logic is kept
        # (See torch.optim.optimizer)
        id_map = {
            old_id: p
            for old_id, p in zip(
                chain.from_iterable((g["params"] for g in state_dict["param_groups"])),
                chain.from_iterable((g["params"] for g in self.param_groups)),
            )
        }

        for key, value in state_dict["state"].items():
            param = id_map[key]

            # Populate the sharded optimizer state on the fly,
            # remove the params that this rank does not own
            if self._param_to_rank[param] != self.rank:
                state_dict["state"][key] = {}
            else:
                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)

        super().load_state_dict(state_dict)

        # Sync with the optimizer param groups
        OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
        OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
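
A matching restore sketch (the checkpoint path is hypothetical): every rank loads the consolidated checkpoint, and load_state_dict() above drops the entries for parameters this rank does not own before handing the rest to the wrapped optimizer.

import torch
from fairscale.optim import OSS

def load_consolidated_optim(optimizer: OSS, path: str = "oss_optim.pt") -> None:
    # Each rank reads the full consolidated checkpoint (or receives it via a
    # broadcast); load_state_dict() then keeps only the shard this rank owns.
    state = torch.load(path, map_location="cpu")
    optimizer.load_state_dict(state)
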
Example #3
    def _test_consolidated_optimizer(self,
                                     config,
                                     rank,
                                     group,
                                     optim_fn=torch.optim.SGD,
                                     transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            unwrapped_model = TransformerWithSharedParams(
                group, wrapper_config=config).cuda()
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            unwrapped_model = MixtureOfExperts(group,
                                               wrapper_config=None).cuda()
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config)).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            fsdp_optim.step()

            output = unwrapped_model(*x)
            loss = unwrapped_model.get_loss(x, output)
            unwrapped_model.run_backward(loss)
            optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        if not transformer:
            no_broadcast_children = [
                x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state
            ]
            assert len(no_broadcast_children) == 1
            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
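        # Measure the CUDA memory and wall-clock cost of gathering the full optimizer state.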
        torch.cuda.empty_cache()
        cuda_gb_before = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        cuda_gb_after = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        mem_usg_gb = cuda_gb_after - cuda_gb_before
        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

        if fsdp.rank > 0:
            assert sd is None
            return

        # Assert that the whole consolidated state dict lives on CPU
        for k, v in sd["state"].items():
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                    assert t.device == torch.device("cpu"), msg

        unflat_state = sd["state"]
        assert "uncollected_local_ids" in sd
        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
        shard_sd = recursive_copy_to_device(shard_sd,
                                            non_blocking=False,
                                            device="cpu")
        state_after_get_shard = sd["state"]
        assert objects_are_equal(unflat_state,
                                 state_after_get_shard)  # no side effects.

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]),
                     len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in unwrapped_sd["state"].items()
            ]),
        )

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd,
                                                     non_blocking=False,
                                                     device="cpu")
        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
        assert_equal(
            [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
            [
                first_tensor_numel(v)
                for k, v in original_shard_sd["state"].items()
            ],
        )
        assert_equal(
            [v for k, v in shard_sd["param_groups"][0].items()],
            [v for k, v in original_shard_sd["param_groups"][0].items()],
        )
        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
        assert objects_are_equal({k: shard_sd[k]
                                  for k in original_shard_sd},
                                 original_shard_sd)
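
Condensing the round trip exercised by the test above into a sketch (the model/optimizer arguments and the checkpoint path are hypothetical): rank 0 gathers and saves the full optimizer state, and on resume each rank extracts its own shard from the full dict.

import torch
import torch.distributed as dist

def save_full_optim_state(fsdp_model, optimizer, path: str = "optim_full.pt") -> None:
    # Collective gather: rank 0 receives the consolidated dict, the other ranks get None.
    full_sd = fsdp_model.gather_full_optim_state_dict(optimizer, recipient_rank=0)
    if dist.get_rank() == 0:
        torch.save(full_sd, path)

def restore_optim_shard(fsdp_model, optimizer, path: str = "optim_full.pt") -> None:
    # Every rank loads the full dict, then carves out the shard it owns.
    full_sd = torch.load(path, map_location="cpu")
    shard_sd = fsdp_model.get_shard_from_optim_state_dict(full_sd)
    optimizer.load_state_dict(shard_sd)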