def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return [data]
    rank = dist.get_rank(group=group)

    if rank == dst:
        output = [None for _ in range(world_size)]
        dist.gather_object(data, output, dst=dst, group=group)
        return output
    else:
        dist.gather_object(data, None, dst=dst, group=group)
        return []
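# A minimal usage sketch for the helper above. `main` is a hypothetical entry
# point; it assumes a process group was already initialized (e.g. by torchrun).
import torch.distributed as dist

def main():
    data = {"rank": dist.get_rank(), "loss": 0.5}  # any picklable object
    results = gather(data, dst=0)
    if dist.get_rank() == 0:
        # On dst, `results` holds one entry per rank; other ranks received [].
        print(results)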
def gather(
    self,
    dst: int = 0,
    out: Optional[torch.Tensor] = None,
) -> None:
    """
    Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
    sharded tensor.

    The API needs to be called on all ranks in SPMD fashion. All ranks should have
    the same ``dst``. ``out`` should be a tensor of the same size as the overall
    size of the sharded tensor on ``dst`` and ``None`` on all other ranks.

    Args:
        dst(int): The rank where full tensor is constructed.
            Default: 0
        out (:class:`torch.Tensor`, optional): The output full tensor.
            Must be provided ONLY on ``dst`` rank.
            Default: ``None``
    """
    rank = dist.get_rank(self._process_group)
    full_size = self.metadata().size
    _validate_output_tensor_for_gather(rank, dst, full_size, out)

    local_shards = self.local_shards()

    world_size = dist.get_world_size(self._process_group)

    gathered_shards: List[Optional[List[Shard]]] = [None] * world_size if rank == dst else []
    # TODO: see how we could use dist.gather() instead of dist.gather_object
    # as the latter one involves pickling on CPU, see more context
    # https://github.com/pytorch/pytorch/issues/73935
    dist.gather_object(
        obj=local_shards,
        object_gather_list=gathered_shards,
        dst=dst,
        group=self._process_group,
    )
    if rank == dst:
        if out is None:
            raise ValueError("`out` Tensor must be provided on dst rank!")
        dims = len(full_size)
        for shards in gathered_shards:
            if shards is None:
                raise RuntimeError(
                    f'Gathered shards cannot be None on dst rank {dst}'
                )
            for shard in shards:
                metadata = shard.metadata
                tensor = shard.tensor
                out_narrow_view = out
                for dim in range(dims):
                    out_narrow_view = out_narrow_view.narrow(
                        dim,
                        metadata.shard_offsets[dim],
                        metadata.shard_sizes[dim],
                    )
                out_narrow_view.copy_(tensor)
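# Hedged usage sketch for the method above: only the destination rank
# allocates the output buffer. `st` is an assumed, already-constructed
# ShardedTensor; device placement is left to the caller.
import torch
import torch.distributed as dist

def gather_full_tensor(st, dst=0):
    out = None
    if dist.get_rank() == dst:
        out = torch.empty(st.metadata().size)  # full-size buffer on dst only
    st.gather(dst=dst, out=out)
    return out  # full tensor on dst, None elsewhere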
def test_gather_object(self):
    output = [None] * dist.get_world_size() if self.rank == 0 else None
    dist.gather_object(obj=self.rank, object_gather_list=output)
    if self.rank == 0:
        for i, v in enumerate(output):
            self.assertEqual(i, v, f"rank: {self.rank}")
def gather_object(self, object: T) -> Optional[List[T]]:
    """
    Same as c10d::gather_object but works without distributed enabled.
    """
    if self.use_dist:
        gather_objs = cast(
            List[T], [None] * dist.get_world_size(self.group)
        ) if self.is_coordinator else None

        dist.gather_object(
            obj=object,
            object_gather_list=gather_objs if self.is_coordinator else None,
            dst=self.coordinator_rank,
            group=self.group,
        )
        result = gather_objs
    else:
        result = [object]
    return result
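# Hedged sketch of the owning object this method assumes: a thin wrapper
# around process-group state. The class name and constructor are illustrative,
# not the source's actual class.
import torch.distributed as dist

class _DistWrapperSketch:
    gather_object = gather_object  # reuse the method defined above

    def __init__(self, group=None, coordinator_rank=0):
        self.group = group
        self.use_dist = dist.is_initialized()
        self.coordinator_rank = coordinator_rank
        rank = dist.get_rank(group) if self.use_dist else 0
        self.is_coordinator = rank == coordinator_rank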
def synchronize_between_processes(self):
    if dist.is_initialized():
        # Bypass NCCL (which forces CUDA-only sync)
        if dist.get_backend() == "nccl":
            group = dist.new_group(backend="gloo")
        else:
            group = dist.group.WORLD

        my_rank = dist.get_rank()
        output = [None for _ in range(dist.get_world_size())]
        dist.gather_object(
            self.predictions, output if my_rank == 0 else None, dst=0, group=group
        )
        # `output` is only populated on rank 0; flattening it on other ranks
        # would chain over None entries and raise, so guard on the rank.
        if my_rank == 0:
            return list(itertools.chain.from_iterable(output)), True
        return [], False
    else:
        return self.predictions, True
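# Hedged usage sketch (illustrative names): each rank accumulates its own
# predictions, then only the main process receives the merged list.
import itertools
import torch.distributed as dist

class Evaluator:
    synchronize_between_processes = synchronize_between_processes  # reuse helper

    def __init__(self):
        self.predictions = []

ev = Evaluator()
ev.predictions.append(dist.get_rank())  # one local result per rank
merged, is_main = ev.synchronize_between_processes()
if is_main:
    print(merged)  # flattened predictions from all ranks, rank 0 only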
def save_state_dict(
    state_dict: Dict[str, Any],
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False
) -> None:
    """
    Save a distributed model in SPMD style.

    This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` by having each rank only save their local shards.

    To produce a state_dict with ShardedTensor instances you must call
    ``_register_state_dict_hook`` on the top module with value
    `torch.distributed._shard.sharded_tensor.state_dict_hook` prior to
    calling `state_dict()` on the top module.

    There are no guarantees of Backwards Compatibility across PyTorch versions
    for saved state_dicts.

    If using the `process_group` argument, make sure that only its ranks
    call `save_state_dict` and that all data in state_dict belong to it.

    This function can be used to save a state_dict without an initialized
    process group by passing ``no_dist=True``. This can be used to produce a
    checkpoint that can be consumed by load_state_dict in a SPMD fashion.

    Args:
        state_dict (Dict[str, Any]) : A state_dict
        storage_writer (StorageWriter): Instance of StorageWriter used to
            perform writes.
        process_group (ProcessGroup): ProcessGroup to be used for cross-rank
            synchronization.
        coordinator_rank (int): Rank to use to coordinate the checkpoint,
            rank0 is used by default.
        no_dist (bool): Don't attempt to save in SPMD style. Defaults to False.

    Example:
        >>> my_model = MyModule()
        >>> # We must call this function prior to state_dict()
        >>> my_model._register_state_dict_hook(state_dict_hook)

        >>> model_state_dict = my_model.state_dict()

        >>> fs_storage_writer = torch.distributed._shard.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed._shard.checkpoint.save_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )

    .. note:: save_state_dict uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of objects
        must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that
        each rank has an individual GPU, via ``torch.cuda.set_device()``.
    """
    is_coordinator = no_dist or dist.get_rank(process_group) == coordinator_rank

    exceptions: List[Optional[BaseException]] = [None]
    if is_coordinator:
        try:
            storage_writer.prepare()
        except BaseException as e:
            exceptions = [e]

    # Writing can only start once prepare has finished
    if not no_dist:
        dist.broadcast_object_list(exceptions, group=process_group, src=coordinator_rank)

    if exceptions[0] is not None:
        raise CheckpointException("failed to prepare storage", {coordinator_rank: exceptions[0]})

    rank_write_error: Optional[BaseException]
    try:
        (
            metadata,
            bytes_write_requests,
            tensor_write_requests,
        ) = _prepare(state_dict, is_coordinator, process_group)

        combined_writes: List[Union[TensorWriteRequest, BytesWriteRequest]] = []
        combined_writes.extend(tensor_write_requests)
        combined_writes.extend(bytes_write_requests)

        storage_writer.prepare_storage(combined_writes)
        bytes_futures = storage_writer.write_bytes(bytes_write_requests)
        tensor_futures = storage_writer.write_tensors(tensor_write_requests)
        torch.futures.wait_all([bytes_futures, tensor_futures])
        rank_write_error = None
    except BaseException as e:
        rank_write_error = e

    all_errors: List[Optional[BaseException]]
    # collect all write errors
    if not no_dist:
        all_errors = [None] * dist.get_world_size(process_group)
        dist.gather_object(
            obj=rank_write_error,
            object_gather_list=all_errors if is_coordinator else None,
            dst=coordinator_rank
        )
    else:
        all_errors = [rank_write_error]

    result: List[Optional[CheckpointException]] = [None]
    if is_coordinator:
        message: Optional[str] = None
        if any(all_errors):
            message = "Failed to write data"
        else:
            try:
                storage_writer.finish(metadata=metadata)
            except BaseException as e:
                all_errors[coordinator_rank] = e
                message = "Failed to finish checkpoint"

        if message is not None:
            node_failures = {i: err for i, err in enumerate(all_errors) if err is not None}
            result[0] = CheckpointException(message, node_failures)

    if not no_dist:
        dist.broadcast_object_list(result, group=process_group, src=coordinator_rank)

    if result[0] is not None:
        raise result[0]
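# Hedged sketch of the single-process path described in the docstring: with
# ``no_dist=True`` no process group is needed. Import paths follow the
# docstring above; the function name and checkpoint path are illustrative.
from torch.distributed._shard.checkpoint import FileSystemWriter, save_state_dict
from torch.distributed._shard.sharded_tensor import state_dict_hook

def save_without_process_group(model, path="/checkpoint/1"):
    # The hook must be registered before state_dict() is called.
    model._register_state_dict_hook(state_dict_hook)
    save_state_dict(
        state_dict=model.state_dict(),
        storage_writer=FileSystemWriter(path),
        no_dist=True,  # skip collectives entirely
    )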
import torch.distributed as dist

if dist.get_rank() == 0:
    objects = ["f", 1]
else:
    objects = [None, None]

# definitions for the calls below: one picklable object per rank, and a
# gather buffer that is only consumed on rank 0
gather_objects = ["foo", 12, {1: 2}, "bar"]
output = [None] * dist.get_world_size()
output_list = [None]

# ruleid: pickles-in-torch-distributed
dist.broadcast_object_list(objects, src=0)

# ruleid: pickles-in-torch-distributed
dist.all_gather_object(output, gather_objects[dist.get_rank()])

# ruleid: pickles-in-torch-distributed
dist.gather_object(
    gather_objects[dist.get_rank()],
    output if dist.get_rank() == 0 else None,
    dst=0,
)

# ruleid: pickles-in-torch-distributed
dist.scatter_object_list(output_list, objects, src=0)
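# Hedged sketch of a tensor-based alternative the rule points toward:
# encoding the payload as a tensor avoids pickling entirely.
import torch

rank_tensor = torch.tensor([dist.get_rank()])
gathered = [torch.zeros_like(rank_tensor) for _ in range(dist.get_world_size())]
# ok: pickles-in-torch-distributed
dist.all_gather(gathered, rank_tensor)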