Example #1
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return [data]
    rank = dist.get_rank(group=group)

    if rank == dst:
        output = [None for _ in range(world_size)]
        dist.gather_object(data, output, dst=dst, group=group)
        return output
    else:
        dist.gather_object(data, None, dst=dst, group=group)
        return []
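A minimal usage sketch for the helper above. It assumes torch.distributed is already initialized and that this gather helper is importable; local_metrics is a made-up payload, not part of the source.

local_metrics = {"rank": dist.get_rank(), "loss": 0.42}    # any picklable object
all_metrics = gather(local_metrics, dst=0)
if dist.get_rank() == 0:
    assert len(all_metrics) == dist.get_world_size()       # one entry per rank
else:
    assert all_metrics == []                                # non-dst ranks get []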
Example #2
    def gather(
        self,
        dst: int = 0,
        out: Optional[torch.Tensor] = None,
    ) -> None:
        """
        Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
        sharded tensor.

        The API needs to be called on all ranks in SPMD fashion. All ranks should have
        the same ``dst``. ``out`` should be a tensor of the same size as the overall
        size of the sharded tensor on ``dst`` and ``None`` on all other ranks.

        Args:
            dst (int): The rank where the full tensor is constructed.
                Default: 0
            out (:class:`torch.Tensor`, optional): The output full tensor.
                Must be provided ONLY on ``dst`` rank.
                Default: ``None``
        """
        rank = dist.get_rank(self._process_group)
        full_size = self.metadata().size
        _validate_output_tensor_for_gather(rank, dst, full_size, out)

        local_shards = self.local_shards()

        world_size = dist.get_world_size(self._process_group)

        gathered_shards: List[Optional[List[Shard]]] = [None] * world_size if rank == dst else []
        # TODO: see how we could use dist.gather() instead of dist.gather_object
        # as the latter one involves pickling on CPU, see more context
        # https://github.com/pytorch/pytorch/issues/73935
        dist.gather_object(
            obj=local_shards,
            object_gather_list=gathered_shards,
            dst=dst,
            group=self._process_group,
        )
        if rank == dst:
            if out is None:
                raise ValueError("`out` Tensor must be provided on dst rank!")
            dims = len(full_size)
            for shards in gathered_shards:
                if shards is None:
                    raise RuntimeError(
                        f'Gathered shards cannot be None on dst rank {dst}'
                    )
                for shard in shards:
                    metadata = shard.metadata
                    tensor = shard.tensor

                    out_narrow_view = out
                    for dim in range(dims):
                        out_narrow_view = out_narrow_view.narrow(
                            dim,
                            metadata.shard_offsets[dim],
                            metadata.shard_sizes[dim],
                        )

                    out_narrow_view.copy_(tensor)
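A hedged usage sketch for the gather() API above; `st` stands in for an already-constructed ShardedTensor and is an assumption, not part of the source.

rank = dist.get_rank()
full_size = st.metadata().size
# Only the destination rank allocates the full output tensor; others pass None.
out = torch.empty(full_size) if rank == 0 else None
st.gather(dst=0, out=out)   # collective call: every rank must participate
if rank == 0:
    print(out.shape)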
Example #3
    def test_gather_object(self):
        output = [None] * dist.get_world_size() if self.rank == 0 else None
        dist.gather_object(obj=self.rank, object_gather_list=output)

        if self.rank == 0:
            for i, v in enumerate(output):
                self.assertEqual(i, v, f"rank: {self.rank}")
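The test above runs inside a distributed test harness. Below is a standalone sketch of the same pattern; the gloo backend, the TCP address, and the world size of 2 are arbitrary choices, not from the source.

import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    # gloo keeps the demo CPU-only
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    output = [None] * world_size if rank == 0 else None
    dist.gather_object(obj=rank, object_gather_list=output, dst=0)
    if rank == 0:
        assert output == list(range(world_size))
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)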
Example #4
    def gather_object(self, object: T) -> Optional[List[T]]:
        """
        Same as c10d::gather_object but works without distributed enabled.
        """
        if self.use_dist:
            gather_objs = cast(List[T], [None] * dist.get_world_size(
                self.group)) if self.is_coordinator else None

            dist.gather_object(obj=object,
                               object_gather_list=gather_objs
                               if self.is_coordinator else None,
                               dst=self.coordinator_rank,
                               group=self.group)
            result = gather_objs
        else:
            result = [object]
        return result
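A hypothetical usage sketch for the wrapper above; `helper` stands for an instance of the surrounding class with `use_dist`, `is_coordinator`, `coordinator_rank`, and `group` already configured.

local_stats = {"bytes_written": 1024}          # made-up per-rank payload
gathered = helper.gather_object(local_stats)
if gathered is not None:                       # only the coordinator gets the list
    total = sum(s["bytes_written"] for s in gathered)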
Example #5
    def synchronize_between_processes(self):
        if dist.is_initialized():
            # Bypass NCCL (which forces CUDA-only sync)
            if dist.get_backend() == "nccl":
                group = dist.new_group(backend="gloo")
            else:
                group = dist.group.WORLD

            my_rank = dist.get_rank()
            output = [None for _ in range(dist.get_world_size())]
            dist.gather_object(self.predictions,
                               output if my_rank == 0 else None,
                               dst=0,
                               group=group)

            if my_rank == 0:
                # Flatten the per-rank prediction lists gathered on rank 0.
                return list(itertools.chain.from_iterable(output)), True
            # Non-destination ranks did not receive any gathered data.
            return [], False
        else:
            return self.predictions, True
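A hypothetical driver for the method above; `evaluator` is an instance of the class that owns `self.predictions`, and `save_results` is a made-up helper.

predictions, is_main = evaluator.synchronize_between_processes()
if is_main:
    # Only the main process sees the predictions merged from every rank.
    save_results(predictions)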
Example #6
def save_state_dict(
    state_dict: Dict[str, Any],
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False
) -> None:
    """
    Save a distributed model in SPMD style.

    This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` by having each rank only save their local shards.

    To produce a state_dict with ShardedTensor instances you must call
    ``_register_state_dict_hook`` on the top module with value
    `torch.distributed._shard.sharded_tensor.state_dict_hook` prior to
    calling `state_dict()` on the top module.

    There are no guarantees of backwards compatibility across PyTorch versions
    for saved state_dicts.

    If using the `process_group` argument, make sure that only its ranks
    call `save_state_dict` and that all data in state_dict belongs to it.

    This function can be used to save a state_dict without an initialized
    process group by passing ``no_dist=True``. This can be used to produce a
    checkpoint that can be consumed by load_state_dict in an SPMD fashion.

    Args:
        state_dict (Dict[str, Any]): A state_dict
        storage_writer (StorageWriter): Instance of StorageWriter used to perform writes.
        process_group (ProcessGroup): ProcessGroup to be used for cross-rank synchronization
        coordinator_rank (int): Rank used to coordinate the checkpoint; rank 0 is used by default
        no_dist (bool): Don't attempt to save in SPMD style. Defaults to False

    Example:
        >>> my_model = MyModule()
        >>> # We must call this function prior to state_dict()
        >>> my_model._register_state_dict_hook(state_dict_hook)

        >>> model_state_dict = my_model.state_dict()

        >>> fs_storage_writer = torch.distributed._shard.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed._shard.checkpoint.save_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )

    .. note:: save_state_dict uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of objects
        must be moved to the GPU device before communication takes place. In this
        case, the device used is given by ``torch.cuda.current_device()`` and it
        is the user's responsibility to ensure that this is set so that each rank
        has an individual GPU, via ``torch.cuda.set_device()``.
    """
    is_coordinator = no_dist or dist.get_rank(process_group) == coordinator_rank

    exceptions: List[Optional[BaseException]] = [None]
    if is_coordinator:
        try:
            storage_writer.prepare()
        except BaseException as e:
            exceptions = [e]

    # Writing can only start once prepare has finished
    if not no_dist:
        dist.broadcast_object_list(exceptions, group=process_group, src=coordinator_rank)

    if exceptions[0] is not None:
        raise CheckpointException("failed to prepare storage", {coordinator_rank : exceptions[0]})

    rank_write_error: Optional[BaseException]
    try:
        (
            metadata,
            bytes_write_requests,
            tensor_write_requests,
        ) = _prepare(state_dict, is_coordinator, process_group)

        combined_writes: List[Union[TensorWriteRequest, BytesWriteRequest]] = []
        combined_writes.extend(tensor_write_requests)
        combined_writes.extend(bytes_write_requests)

        storage_writer.prepare_storage(combined_writes)
        bytes_futures = storage_writer.write_bytes(bytes_write_requests)
        tensor_futures = storage_writer.write_tensors(tensor_write_requests)
        torch.futures.wait_all([bytes_futures, tensor_futures])
        rank_write_error = None
    except BaseException as e:
        rank_write_error = e

    all_errors: List[Optional[BaseException]]
    # collect all write errors
    if not no_dist:
        all_errors = [None] * dist.get_world_size(process_group)
        dist.gather_object(
            obj=rank_write_error,
            object_gather_list=all_errors if is_coordinator else None,
            dst=coordinator_rank,
            group=process_group,
        )
    else:
        all_errors = [rank_write_error]

    result: List[Optional[CheckpointException]] = [None]
    if is_coordinator:
        message: Optional[str] = None
        # check whether any rank reported a write error
        if any(all_errors):
            message = "Failed to write data"
        else:
            try:
                storage_writer.finish(metadata=metadata)
            except BaseException as e:
                all_errors[coordinator_rank] = e
                message = "Failed to finish checkpoint"

        if message is not None:
            node_failures = {i: err for i, err in enumerate(all_errors) if err is not None}
            result[0] = CheckpointException(message, node_failures)

    if not no_dist:
        dist.broadcast_object_list(
            result,
            group=process_group,
            src=coordinator_rank)

    if result[0] is not None:
        raise result[0]
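The note above asks each rank to pin its own GPU before NCCL-based collectives run. A sketch of that setup, assuming a torchrun-style launcher that exports LOCAL_RANK:

import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])   # provided by torchrun (assumption)
torch.cuda.set_device(local_rank)            # one GPU per rank, per the note
dist.init_process_group("nccl")
# ... build the model, register state_dict_hook, then call save_state_dict(...)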
Example #7
import torch.distributed as dist

# Placeholder values so the snippets below are self-contained; `gather_objects`,
# `output`, and `output_list` were left undefined in the original file.
gather_objects = ["foo", 12, {"key": "value"}, "bar"]
output = [None] * dist.get_world_size()
output_list = [None]

if dist.get_rank() == 0:
    objects = ["f", 1]
else:
    objects = [None, None]

# ruleid: pickles-in-torch-distributed
dist.broadcast_object_list(objects, src=0)

# ruleid: pickles-in-torch-distributed
dist.all_gather_object(output, gather_objects[dist.get_rank()])

# ruleid: pickles-in-torch-distributed
dist.gather_object(gather_objects[dist.get_rank()],
                   output if dist.get_rank() == 0 else None,
                   dst=0)

# ruleid: pickles-in-torch-distributed
dist.scatter_object_list(output_list, objects, src=0)
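One way to satisfy the rule is to avoid the pickle-based object collectives entirely and move the payload into tensors. A sketch under the assumption of a fixed-size numeric payload; it is an addition, not part of the original rule file.

import torch
import torch.distributed as dist

payload = torch.tensor([float(dist.get_rank())])            # fixed-size tensor payload
gather_list = (
    [torch.zeros_like(payload) for _ in range(dist.get_world_size())]
    if dist.get_rank() == 0 else None
)
# ok: pickles-in-torch-distributed
dist.gather(payload, gather_list=gather_list, dst=0)        # no pickling involved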