Code Example #1
File: api.py Project: ydcjeff/pytorch
    def _init_rpc(self):
        self._rpc_initialized = True
        self._remote_shards = {}

        # Gather all the sharded tensor ids.
        world_size = dist.get_world_size(self._process_group)
        worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
        rank_to_name = {}
        name_to_rank = {}

        for worker_info in worker_infos:
            rank_to_name[worker_info.id] = worker_info.name
            name_to_rank[worker_info.name] = worker_info.id

        rpc_workers = set()
        for rank in range(world_size):
            if self._process_group == distributed_c10d._get_default_group():
                global_rank = rank
            else:
                global_rank = distributed_c10d._get_global_rank(
                    self._process_group, rank)
            rpc_workers.add(rank_to_name[global_rank])

        all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id,
                                             rpc_workers)

        # Share the local shards to the entire world.
        futs = []
        rpc_rank = rpc.get_worker_info().id
        for rank in range(world_size):
            # Skip self.
            if rank == dist.get_rank(self._process_group):
                continue

            if self._process_group == distributed_c10d._get_default_group():
                global_rank = rank
            else:
                global_rank = distributed_c10d._get_global_rank(
                    self._process_group, rank)

            if len(self.local_shards()) != 0:
                rrefs: List[rpc.RRef[Shard]] = [
                    rpc.RRef(shard) for shard in self.local_shards()
                ]
                fut = rpc.rpc_async(
                    global_rank,
                    _register_remote_shards,
                    args=(all_tensor_ids[rank_to_name[global_rank]], rrefs,
                          rpc_rank))
                futs.append(fut)

        torch.futures.wait_all(futs)

        # Barrier for all RPCs to finish on all ranks.
        rpc.api._barrier(rpc_workers)
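
Note the guard that appears twice in the snippet above: a rank inside self._process_group is only passed through distributed_c10d._get_global_rank when the group is not the default group. Below is a minimal sketch of that conversion, assuming init_process_group has already been called and a PyTorch build that still ships the private _get_global_rank helper (recent releases expose the same mapping publicly as torch.distributed.get_global_rank); the name to_global_rank is illustrative, not part of any API.

from torch.distributed import distributed_c10d

def to_global_rank(process_group, group_rank):
    # Ranks in the default (WORLD) group are already global ranks.
    if process_group == distributed_c10d._get_default_group():
        return group_rank
    # Otherwise translate the group-local rank to its global rank.
    return distributed_c10d._get_global_rank(process_group, group_rank)
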
Code Example #2
File: world.py Project: mrshenli/machin
    def irecv(self, tensor, src=None, tag=0):
        # pylint: disable=protected-access

        # The original irecv doesn't support receiving from any source,
        # but the original recv does. They are essentially the same,
        # except that recv has a wait() call.
        dist_c10d._check_single_tensor(tensor, "tensor")
        if dist_c10d._rank_not_in_group(self.group):
            return -1

        if self.group == dist_c10d.GroupMember.WORLD:
            dist_c10d._check_default_pg()
            pg = dist_c10d._default_pg
        else:
            pg = self.group

        if src is None:
            work = pg.recv_anysource([tensor], tag)
            src_rank = work.source_rank()
            if self.group == dist_c10d.GroupMember.WORLD:
                return src_rank
            else:
                return dist_c10d._get_global_rank(pg, src_rank)
        else:
            if self.group == dist_c10d.GroupMember.WORLD:
                pg.recv([tensor], src, tag).wait()
            else:
                group_src_rank = dist_c10d._get_group_rank(pg, src)
                pg.recv([tensor], group_src_rank, tag).wait()
            return src
Code Example #3
 def wait(self):
     nonlocal work, pg
     work.wait()
     if _torch_version_less_than(1, 7):
         src_rank = work.source_rank()
     else:
         src_rank = work._source_rank()
     return dist_c10d._get_global_rank(pg, src_rank)
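
Code examples #2 and #3 hinge on the same detail: the source rank reported by a recv-from-any-source work object is relative to the process group, so it must be mapped back to a global rank before being returned. A rough sketch of that step under assumed names (pg is an already-created subgroup the caller belongs to, buf a pre-allocated tensor; whether the work object exposes source_rank() or _source_rank() depends on the torch version, as example #3 shows).

from torch.distributed import distributed_c10d

def recv_from_any(buf, pg, tag=0):
    # recv_anysource reports the sender as a rank *within* pg.
    work = pg.recv_anysource([buf], tag)
    work.wait()
    group_src = work._source_rank()  # source_rank() on older builds
    # Translate the group-local source rank into a global rank.
    return distributed_c10d._get_global_rank(pg, group_src)
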
Code Example #4
File: rpc.py Project: hulaba/fairscale
    def forward(self,
                tensor: TensorOrTensors) -> TensorOrTensors:  # type: ignore
        shape = get_shapes(tensor)
        dtype = get_dtype(tensor)

        if isinstance(tensor, torch.Tensor):
            num_tensors = 1
        else:
            num_tensors = len(tensor)

        futures = [
            rpc.rpc_async(self._get_rpc_name(rank),
                          self._model_forward,
                          args=(self.model.training, shape, dtype))
            for rank in range(1, self.group.size())
        ]

        if self.model.final_stage:
            return self.model(tensor)
        else:
            event = Event()
            t = Thread(target=self._model_forward_first_stage,
                       args=(tensor, event))
            t.start()

            shape, dtype = futures.pop().wait()
            dest_rank = self.group.size() - 1
            dest = self._get_rpc_name(dest_rank)
            dest_global_rank = _get_global_rank(self.group, dest_rank)
            src_global_rank = torch.distributed.get_rank()
            queue = EVENT_LOOP_QUEUE

            activations = PipeMessage(dest_global_rank,
                                      src_global_rank,
                                      queue_name=queue,
                                      tensor_count=num_tensors)
            grads = PipeMessage(src_global_rank,
                                dest_global_rank,
                                queue_name=queue,
                                tensor_count=num_tensors)

            back_fut = rpc.rpc_async(dest,
                                     self._send_result_and_do_backwards,
                                     args=(self.model.training, activations,
                                           grads))
            futures.append(back_fut)

            result = self._recv_result(self.model, shape, dtype, activations)
            if isinstance(result, torch.Tensor):
                result.requires_grad_()
            else:
                for r in result:
                    r.requires_grad_()

            assert self.model.pipeline
            return PipeBackRedirect.apply(result, dest_global_rank, event,
                                          grads, self.model.pipeline.transport,
                                          futures)
Code Example #5
File: deepspeed_light.py Project: tgs266/DeepSpeed
 def _configure_distributed_model(self, model):
     self.module = model
     if self.fp16_enabled():
         self.module.half()
     self.module.to(self.device)
     if self.mpu is None:
         self.data_parallel_group = _initialize_parameter_parallel_groups()
         self.dp_world_size = dist.get_world_size()
         src_rank = 0
     else:
         self.data_parallel_group = self.mpu.get_data_parallel_group()
         self.dp_world_size = self.mpu.get_data_parallel_world_size()
         src_rank = _get_global_rank(self.mpu.get_data_parallel_group(), 0)
         logger.info(f"global src_rank={src_rank}")
     for p in self.module.parameters():
         if torch.is_tensor(p):
             dist.broadcast(p, src_rank, group=self.data_parallel_group)
Code Example #6
    def __init__(self,
                 params,
                 modifier_rank=None,
                 fwd_module=None,
                 enabled=True):
        """A context that collects parameters that were partitioned via a
        :class:`deepspeed.zero.Init` context. The parameters are partitioned
        again upon exit.

        Args:
            params (``torch.nn.Parameter``): A single parameter or a list of parameters to collect.
                It's assumed that all parameters are zero params.
            modifier_rank (int, optional): If specified, this rank's parameter will be
                broadcasted on exit from the context. This argument is required if ``params`` are
                modified, so that all processes have a consistent view of the data. Defaults
                to ``None``.
            fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
                registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
            enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

        Examples
        ========

        #. Allocate a partitioned module, initialize its weight on rank 0, and update all
           processes.

            .. code-block:: python

                with deepspeed.zero.Init():
                    linear = torch.nn.Linear(1000,1000)

                with deepspeed.zero.GatheredParameters(linear.weight,
                                                       modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        linear.weight.zero_()


        #. Collect a partitioned weight to pass to another module during
           training. The parameter will be registered as an external parameter
           and made available during the backward pass.

            .. code-block:: python
                :emphasize-lines: 6

                def forward(self, input):
                    x = self.layer1(input)

                    # self.layer1.weight is required by self.layer2.forward
                    with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                                           fwd_module=self):
                        y = self.layer2(x, self.layer1.weight)
                    return y


        #. Pretrained model loading

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyModel()

                state_dict = torch.load(model_path, map_location="cpu")

                def load(module: nn.Module, prefix=""):
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(state_dict, prefix)

                    for name, child in module._modules.items():
                        if child is not None:
                            load(child, prefix + name + ".")

                load(model, prefix="")

        If this approach is not used, the full model will first be copied to each GPU. For
        models bigger than the memory of a single GPU, this approach is required.
        """

        self.enabled = enabled
        if not enabled:
            return

        if not isinstance(params, list):
            params = [params]

        # enable if at least one is zero-param, otherwise a noop
        if not any(is_zero_param(p) for p in params):
            self.enabled = False
            return

        self.params = [p for p in params if hasattr(p, "ds_id")]
        self.src_rank = None
        if modifier_rank is not None:
            if self.params[
                    0].ds_process_group == torch.distributed.group.WORLD:
                self.src_rank = modifier_rank
            else:
                # A group was specified; convert DP rank to global rank
                self.src_rank = _get_global_rank(
                    self.params[0].ds_process_group, modifier_rank)
        self.fwd_module = fwd_module
        if self.fwd_module is not None:
            # is a no-op if already registered
            for p in self.params:
                register_external_parameter(self.fwd_module, p)
Code Example #7
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute norm for all group so we know if there is overflow

        self.overflow = self.overflow_checker.check()

        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:
            self.zero_grad()
            if self.verbose:
                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                      "scale: {}, reducing to {}".format(
                          prev_scale, self.loss_scale))
            return self.overflow

        norm_groups = []
        single_partition_grad_groups = []

        partition_id = dist.get_rank(group=self.dp_process_group)
        for i, group in enumerate(self.fp16_groups):

            norm_groups.append(get_grad_norm(group, mpu=self.mpu))

            #free gradients for all the parameters that are not updated by this process
            self.free_grad_in_param_list(self.params_not_in_partition[i])

            #create a flat gradient partition for the parameters updated by this process
            single_grad_partition = self.get_flat_partition(
                self.params_in_partition[i],
                self.first_offset[i],
                self.partition_size[i],
                dtype=self.single_partition_of_fp32_groups[i].dtype)

            self.single_partition_of_fp32_groups[
                i].grad = single_grad_partition

            #release all the gradients since we have already created a necessary copy in dp_grad_partition
            self.free_grad_in_param_list(self.params_in_partition[i])

            single_partition_grad_groups.append(single_grad_partition)

        self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)

        self.optimizer.step()

        #get rid of the fp32 gradients. Not needed anymore
        for group in self.single_partition_of_fp32_groups:
            group.grad = None

        for fp16_partitions, fp32_partition in zip(
                self.parallel_partitioned_fp16_groups,
                self.single_partition_of_fp32_groups):
            fp16_partitions[partition_id].data.copy_(fp32_partition.data)

        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        #gather the updated weights from everyone
        for _, partitioned_params in enumerate(
                self.parallel_partitioned_fp16_groups):
            if self.all_gather_partitions:
                # controllable memory-time tradeoff
                num_shards = max(
                    1, partitioned_params[partition_id].numel() *
                    dp_world_size // self.allgather_size)
                shard_size = partitioned_params[partition_id].numel(
                ) // num_shards
                num_elements = shard_size
                for shard_id in range(num_shards + 1):
                    if shard_id == num_shards:
                        if shard_size * num_shards >= partitioned_params[
                                partition_id].numel():
                            break
                        else:
                            num_elements = partitioned_params[
                                partition_id].numel() - shard_id * shard_size
                    shard_list = []
                    for dp_id in range(dp_world_size):
                        curr_shard = partitioned_params[dp_id].narrow(
                            0, shard_id * shard_size, num_elements)
                        shard_list.append(curr_shard)
                    dist.all_gather(shard_list,
                                    shard_list[partition_id],
                                    group=self.dp_process_group)
            else:
                #this should require less memory but should be faster
                for src, partitioned_param in enumerate(partitioned_params):
                    global_src = _get_global_rank(self.dp_process_group, src)
                    dist.broadcast(partitioned_param,
                                   global_src,
                                   group=self.dp_process_group)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow
Code Example #8
File: rpc.py Project: hulaba/fairscale
def get_global_ranks_from_group(group: ProcessGroup) -> List[int]:
    return [_get_global_rank(group, r) for r in range(group.size())]
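
A hedged usage sketch for this one-liner, restated without the fairscale typing import so it is self-contained; the subgroup ranks are illustrative and assume init_process_group was called with a world size of at least four.

import torch.distributed as dist
from torch.distributed import distributed_c10d

def get_global_ranks_from_group(group):
    return [distributed_c10d._get_global_rank(group, r) for r in range(group.size())]

# Every process must call new_group, even those outside the subgroup.
subgroup = dist.new_group(ranks=[1, 3])
if dist.get_rank() in (1, 3):
    # Group ranks 0 and 1 map back to global ranks 1 and 3.
    assert get_global_ranks_from_group(subgroup) == [1, 3]
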
Code Example #9
File: rpc.py Project: hulaba/fairscale
 def _get_rpc_name(self, rank: int) -> str:
     return self.worker_map[_get_global_rank(self.group, rank)]
Code Example #10
    def __init__(self,
                 param,
                 modifier_rank=None,
                 fwd_module=None,
                 enabled=True):
        """A context that collects a parameter that was partitioned via a
        :class:`deepspeed.zero.Init` context. The parameter is partitioned
        again upon exit.

        Args:
            param (``torch.nn.Parameter``): The parameter to collect.
            modifier_rank (int, optional): If specified, this rank's parameter will be
                broadcasted after the context. This argument is required if ``param`` is
                modified, so that all processes have a consistent view of the data. Defaults
                to ``None``.
            fwd_module (``torch.nn.Module``, optional): If specified, ``param`` will be
                registered as an external parameter of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
            enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

        Examples
        ========

        #. Allocate a partitioned module, initialize its weight on rank 0, and update all
           processes.

            .. code-block:: python

                with deepspeed.zero.Init():
                    linear = torch.nn.Linear(1000,1000)

                with deepspeed.zero.GatheredParameters(linear.weight,
                                                       modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        linear.weight.zero_()


        #. Collect a partitioned weight to pass to another module during
           training. The parameter will be registered as an external parameter
           and made available during the backward pass.

            .. code-block:: python
                :emphasize-lines: 6

                def forward(self, input):
                    x = self.layer1(input)

                    # self.layer1.weight is required by self.layer2.forward
                    with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                                           fwd_module=self):
                        y = self.layer2(x, self.layer1.weight)
                    return y
        """

        self.enabled = enabled
        if not enabled:
            return

        # This is a no-op, just return.
        if not is_zero_param(param):
            self.enabled = False
            return

        self.param = param
        self.src_rank = None
        if modifier_rank is not None:
            if self.param.ds_process_group == torch.distributed.group.WORLD:
                self.src_rank = modifier_rank
            else:
                # A group was specified; convert DP rank to global rank
                self.src_rank = _get_global_rank(self.param.ds_process_group,
                                                 modifier_rank)
        self.fwd_module = fwd_module
        if self.fwd_module is not None:
            # is a no-op if already registered
            register_external_parameter(self.fwd_module, self.param)
Code Example #11
 def wait(self):
     nonlocal work, pg
     work.wait()
     src_rank = work.source_rank()
     return dist_c10d._get_global_rank(pg, src_rank)
Code Example #12
File: __init__.py Project: microsoft/DeepSpeed
 def get_global_rank(group, group_rank):
     from torch.distributed.distributed_c10d import _get_global_rank
     return _get_global_rank(group, group_rank)
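
A short usage sketch for this wrapper under assumed conditions (init_process_group already called, world size of at least four); on recent PyTorch releases the same mapping is also available publicly as torch.distributed.get_global_rank.

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank

subgroup = dist.new_group(ranks=[1, 3])  # global ranks 1 and 3 become group ranks 0 and 1
if dist.get_rank() in (1, 3):
    t = torch.zeros(1)
    # dist.broadcast expects a *global* src rank, so translate group rank 0.
    src = _get_global_rank(subgroup, 0)  # -> global rank 1
    dist.broadcast(t, src, group=subgroup)
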
Code Example #13
File: groups.py Project: jeffra/DeepSpeed
def _get_expert_broadcast_src_rank(group_name):
    return _get_global_rank(_get_expert_data_parallel_group(group_name), 0)
Code Example #14
File: groups.py Project: jeffra/DeepSpeed
def _get_broadcast_src_rank():
    return _get_global_rank(_get_data_parallel_group(), 0)
Code Example #15
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        # relative imports to avoid circular dependency
        from torch.distributed._shard.sharded_tensor import (
            ShardedTensor
        )
        tensor_properties = sharded_tensor_meta.TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned()
        )
        current_rank = dist.get_rank(process_group)
        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
        local_shards = []
        local_tensor = None
        local_metadata = None
        tensors_to_scatter = [None] * dist.get_world_size(process_group)

        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
        chunks = len(self.placements)
        split_size = get_split_size(sharding_dim_size, chunks)
        scatter_shape = list(tensor.size())
        scatter_shape[self.dim] = split_size  # type: ignore[index]

        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
            if current_rank == src_rank:
                # Narrow to get the shard for this rank. We don't want autograd
                # recording the narrow op here, and 'local_shard' should be a
                # leaf variable in the autograd graph.
                narrowed_tensor = narrow_tensor(tensor, shard_meta)
                if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
                    # The last shard might be smaller than the other shards;
                    # resize the narrowed tensor to the common size and use it
                    # for the scatter collective, since dist.scatter requires
                    # same-size inputs on every rank.
                    tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
                else:
                    tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

                tensors_to_scatter[rank] = tensor_to_scatter

            if current_rank == rank:
                local_tensor = torch.empty(
                    scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
                local_metadata = shard_meta

        # Each rank should have local_tensor and local_metadata initialized if we built
        # the metadata list correctly.
        assert local_tensor is not None
        assert local_metadata is not None

        # Scatter the shards to all ranks in the pg
        # scatter takes the global rank as ``src``
        src_for_scatter = src_rank
        if process_group is not None and process_group is not distributed_c10d._get_default_group():
            src_for_scatter = distributed_c10d._get_global_rank(process_group, src_for_scatter)

        dist.scatter(
            local_tensor,
            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
            src=src_for_scatter,
            group=process_group
        )

        if list(local_tensor.size()) != local_metadata.shard_sizes:
            # detach again after receiving to ensure local shards remain a leaf node
            local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

        # Sync requires_grad to local_shard.
        local_tensor.requires_grad = tensor.requires_grad

        local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

        st = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards,
            tensor_meta,
            process_group=process_group)

        # Manually set sharding_spec
        st._sharding_spec = self

        return st
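
The src_for_scatter conversion above is the detail worth remembering: collectives such as dist.scatter take src as a global rank even when group is a subgroup. A stripped-down sketch of just that step (names are illustrative; scatter_list must be non-None only on the source rank, and the process group is assumed to be initialized).

import torch.distributed as dist
from torch.distributed import distributed_c10d

def scatter_in_group(output, scatter_list, group=None, group_src=0):
    src = group_src
    # dist.scatter expects a global src rank; translate the group-local
    # source rank when a non-default group is used.
    if group is not None and group is not distributed_c10d._get_default_group():
        src = distributed_c10d._get_global_rank(group, group_src)
    dist.scatter(output, scatter_list=scatter_list, src=src, group=group)
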