def _init_rpc(self):
    self._rpc_initialized = True
    self._remote_shards = {}

    # Gather all the sharded tensor ids.
    world_size = dist.get_world_size(self._process_group)
    worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
    rank_to_name = {}
    name_to_rank = {}

    for worker_info in worker_infos:
        rank_to_name[worker_info.id] = worker_info.name
        name_to_rank[worker_info.name] = worker_info.id

    rpc_workers = set()
    for rank in range(world_size):
        if self._process_group == distributed_c10d._get_default_group():
            global_rank = rank
        else:
            global_rank = distributed_c10d._get_global_rank(self._process_group, rank)
        rpc_workers.add(rank_to_name[global_rank])

    all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id, rpc_workers)

    # Share the local shards to the entire world.
    futs = []
    rpc_rank = rpc.get_worker_info().id
    for rank in range(world_size):
        # Skip self.
        if rank == dist.get_rank(self._process_group):
            continue

        if self._process_group == distributed_c10d._get_default_group():
            global_rank = rank
        else:
            global_rank = distributed_c10d._get_global_rank(self._process_group, rank)

        if len(self.local_shards()) != 0:
            rrefs: List[rpc.RRef[Shard]] = [
                rpc.RRef(shard) for shard in self.local_shards()
            ]
            fut = rpc.rpc_async(
                global_rank,
                _register_remote_shards,
                args=(all_tensor_ids[rank_to_name[global_rank]], rrefs, rpc_rank))
            futs.append(fut)

    torch.futures.wait_all(futs)

    # Barrier for all RPCs to finish on all ranks.
    rpc.api._barrier(rpc_workers)
def irecv(self, tensor, src=None, tag=0):  # pylint: disable=protected-access
    # The original irecv doesn't support receiving from any source,
    # but the original recv does. They are essentially the same,
    # except that recv has a wait() call.
    dist_c10d._check_single_tensor(tensor, "tensor")
    if dist_c10d._rank_not_in_group(self.group):
        return -1

    if self.group == dist_c10d.GroupMember.WORLD:
        dist_c10d._check_default_pg()
        pg = dist_c10d._default_pg
    else:
        pg = self.group

    if src is None:
        work = pg.recv_anysource([tensor], tag)
        src_rank = work.source_rank()
        if self.group == dist_c10d.GroupMember.WORLD:
            return src_rank
        else:
            return dist_c10d._get_global_rank(pg, src_rank)
    else:
        if self.group == dist_c10d.GroupMember.WORLD:
            pg.recv([tensor], src, tag).wait()
        else:
            group_src_rank = dist_c10d._get_group_rank(pg, src)
            pg.recv([tensor], group_src_rank, tag).wait()
        return src
def wait(self):
    nonlocal work, pg
    work.wait()
    if _torch_version_less_than(1, 7):
        src_rank = work.source_rank()
    else:
        src_rank = work._source_rank()
    return dist_c10d._get_global_rank(pg, src_rank)
def forward(self, tensor: TensorOrTensors) -> TensorOrTensors:  # type: ignore
    shape = get_shapes(tensor)
    dtype = get_dtype(tensor)

    if isinstance(tensor, torch.Tensor):
        num_tensors = 1
    else:
        num_tensors = len(tensor)

    futures = [
        rpc.rpc_async(self._get_rpc_name(rank),
                      self._model_forward,
                      args=(self.model.training, shape, dtype))
        for rank in range(1, self.group.size())
    ]

    if self.model.final_stage:
        return self.model(tensor)
    else:
        event = Event()
        t = Thread(target=self._model_forward_first_stage, args=(tensor, event))
        t.start()

        shape, dtype = futures.pop().wait()
        dest_rank = self.group.size() - 1
        dest = self._get_rpc_name(dest_rank)
        dest_global_rank = _get_global_rank(self.group, dest_rank)
        src_global_rank = torch.distributed.get_rank()
        queue = EVENT_LOOP_QUEUE

        activations = PipeMessage(dest_global_rank,
                                  src_global_rank,
                                  queue_name=queue,
                                  tensor_count=num_tensors)
        grads = PipeMessage(src_global_rank,
                            dest_global_rank,
                            queue_name=queue,
                            tensor_count=num_tensors)

        back_fut = rpc.rpc_async(dest,
                                 self._send_result_and_do_backwards,
                                 args=(self.model.training, activations, grads))
        futures.append(back_fut)

        result = self._recv_result(self.model, shape, dtype, activations)
        if isinstance(result, torch.Tensor):
            result.requires_grad_()
        else:
            for r in result:
                r.requires_grad_()

        assert self.model.pipeline
        return PipeBackRedirect.apply(result, dest_global_rank, event, grads,
                                      self.model.pipeline.transport, futures)
def _configure_distributed_model(self, model):
    self.module = model
    if self.fp16_enabled():
        self.module.half()
    self.module.to(self.device)

    if self.mpu is None:
        self.data_parallel_group = _initialize_parameter_parallel_groups()
        self.dp_world_size = dist.get_world_size()
        src_rank = 0
    else:
        self.data_parallel_group = self.mpu.get_data_parallel_group()
        self.dp_world_size = self.mpu.get_data_parallel_world_size()
        src_rank = _get_global_rank(self.mpu.get_data_parallel_group(), 0)
        logger.info(f"global src_rank={src_rank}")

    # Broadcast the initial parameters from src_rank so that every replica in
    # the data-parallel group starts from identical weights.
    for p in self.module.parameters():
        if torch.is_tensor(p):
            dist.broadcast(p, src_rank, group=self.data_parallel_group)
def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
    """A context that collects parameters that were partitioned via a
    :class:`deepspeed.zero.Init` context. The parameters are partitioned
    again upon exit.

    Args:
        params (``torch.nn.Parameter``): A single parameter or a list of parameters to collect.
            It's assumed that all parameters are zero params.
        modifier_rank (int, optional): If specified, this rank's parameter will be
            broadcasted on exit from the context. This argument is required if ``params`` are
            modified, so that all processes have a consistent view of the data. Defaults
            to ``None``.
        fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
            registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
        enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

    Examples
    ========

    #. Allocate a partitioned module, initialize its weight on rank 0, and update
       all the processes.

        .. code-block:: python

            with deepspeed.zero.Init():
                linear = torch.nn.Linear(1000,1000)

            with deepspeed.zero.GatheredParameters(linear.weight,
                                                   modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    linear.weight.zero_()

    #. Collect a partitioned weight to pass to another module during training.
       The parameter will be registered as an external parameter and made
       available during the backward pass.

        .. code-block:: python
            :emphasize-lines: 6

            def forward(self, input):
                x = self.layer1(input)

                # self.layer1.weight is required by self.layer2.forward
                with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                                       fwd_module=self):
                    y = self.layer2(x, self.layer1.weight)
                return y

    #. Pretrained model loading

        .. code-block:: python

            with deepspeed.zero.Init():
                model = MyModel()

            state_dict = torch.load(model_path, map_location="cpu")

            def load(module: nn.Module, prefix=""):
                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(state_dict, prefix)

                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            load(model, prefix="")

    If this approach is not used, then the full model will first get copied to each
    GPU. For models bigger than the memory of a single GPU, this approach is required.
    """
    self.enabled = enabled
    if not enabled:
        return

    if not isinstance(params, list):
        params = [params]

    # enable only if at least one param is a zero param, otherwise this is a no-op
    if not any(is_zero_param(p) for p in params):
        self.enabled = False
        return

    self.params = [p for p in params if hasattr(p, "ds_id")]
    self.src_rank = None
    if modifier_rank is not None:
        if self.params[0].ds_process_group == torch.distributed.group.WORLD:
            self.src_rank = modifier_rank
        else:
            # A group was specified; convert DP rank to global rank
            self.src_rank = _get_global_rank(self.params[0].ds_process_group,
                                             modifier_rank)

    self.fwd_module = fwd_module
    if self.fwd_module is not None:
        # is a no-op if already registered
        for p in self.params:
            register_external_parameter(self.fwd_module, p)
def step(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute the norm for all groups so we know if there is overflow
    self.overflow = self.overflow_checker.check()

    prev_scale = self.loss_scale
    self._update_scale(self.overflow)
    if self.overflow:
        self.zero_grad()
        if self.verbose:
            print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                  "scale: {}, reducing to {}".format(prev_scale, self.loss_scale))
        return self.overflow

    norm_groups = []
    single_partition_grad_groups = []

    partition_id = dist.get_rank(group=self.dp_process_group)
    for i, group in enumerate(self.fp16_groups):
        norm_groups.append(get_grad_norm(group, mpu=self.mpu))

        # free gradients for all the parameters that are not updated by this process
        self.free_grad_in_param_list(self.params_not_in_partition[i])

        # create a flat gradient for the parameters updated by this process
        single_grad_partition = self.get_flat_partition(
            self.params_in_partition[i],
            self.first_offset[i],
            self.partition_size[i],
            dtype=self.single_partition_of_fp32_groups[i].dtype)
        self.single_partition_of_fp32_groups[i].grad = single_grad_partition

        # release all the gradients since we have already created a necessary copy in dp_grad_partition
        self.free_grad_in_param_list(self.params_in_partition[i])

        single_partition_grad_groups.append(single_grad_partition)

    self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)

    self.optimizer.step()

    # get rid of the fp32 gradients. Not needed anymore
    for group in self.single_partition_of_fp32_groups:
        group.grad = None

    for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
                                               self.single_partition_of_fp32_groups):
        fp16_partitions[partition_id].data.copy_(fp32_partition.data)

    dp_world_size = dist.get_world_size(group=self.dp_process_group)

    # gather the updated weights from everyone
    for _, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
        if self.all_gather_partitions:
            # controllable memory-time tradeoff
            num_shards = max(
                1,
                partitioned_params[partition_id].numel() * dp_world_size //
                self.allgather_size)
            shard_size = partitioned_params[partition_id].numel() // num_shards
            num_elements = shard_size
            for shard_id in range(num_shards + 1):
                if shard_id == num_shards:
                    if shard_size * num_shards >= partitioned_params[partition_id].numel():
                        break
                    else:
                        num_elements = partitioned_params[partition_id].numel() \
                            - shard_id * shard_size
                shard_list = []
                for dp_id in range(dp_world_size):
                    curr_shard = partitioned_params[dp_id].narrow(
                        0, shard_id * shard_size, num_elements)
                    shard_list.append(curr_shard)
                dist.all_gather(shard_list,
                                shard_list[partition_id],
                                group=self.dp_process_group)
        else:
            # this should require less memory but should be faster
            for src, partitioned_param in enumerate(partitioned_params):
                global_src = _get_global_rank(self.dp_process_group, src)
                dist.broadcast(partitioned_param,
                               global_src,
                               group=self.dp_process_group)

    # TODO: we probably don't need this? just to be safe
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
def get_global_ranks_from_group(group: ProcessGroup) -> List[int]:
    return [_get_global_rank(group, r) for r in range(group.size())]
def _get_rpc_name(self, rank: int) -> str:
    return self.worker_map[_get_global_rank(self.group, rank)]
def __init__(self, param, modifier_rank=None, fwd_module=None, enabled=True):
    """A context that collects a parameter that was partitioned via a
    :class:`deepspeed.zero.Init` context. The parameter is partitioned
    again upon exit.

    Args:
        param (``torch.nn.Parameter``): The parameter to collect.
        modifier_rank (int, optional): If specified, this rank's parameter will be
            broadcasted after the context. This argument is required if ``param`` is
            modified, so that all processes have a consistent view of the data. Defaults
            to ``None``.
        fwd_module (``torch.nn.Module``, optional): If specified, ``param`` will be
            registered as an external parameter of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
        enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

    Examples
    ========

    #. Allocate a partitioned module, initialize its weight on rank 0, and update
       all the processes.

        .. code-block:: python

            with deepspeed.zero.Init():
                linear = torch.nn.Linear(1000,1000)

            with deepspeed.zero.GatheredParameters(linear.weight,
                                                   modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    linear.weight.zero_()

    #. Collect a partitioned weight to pass to another module during training.
       The parameter will be registered as an external parameter and made
       available during the backward pass.

        .. code-block:: python
            :emphasize-lines: 6

            def forward(self, input):
                x = self.layer1(input)

                # self.layer1.weight is required by self.layer2.forward
                with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                                       fwd_module=self):
                    y = self.layer2(x, self.layer1.weight)
                return y
    """
    self.enabled = enabled
    if not enabled:
        return

    # This is a no-op if the parameter is not a zero param; just return.
    if not is_zero_param(param):
        self.enabled = False
        return

    self.param = param
    self.src_rank = None
    if modifier_rank is not None:
        if self.param.ds_process_group == torch.distributed.group.WORLD:
            self.src_rank = modifier_rank
        else:
            # A group was specified; convert DP rank to global rank
            self.src_rank = _get_global_rank(self.param.ds_process_group,
                                             modifier_rank)

    self.fwd_module = fwd_module
    if self.fwd_module is not None:
        # is a no-op if already registered
        register_external_parameter(self.fwd_module, self.param)
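# Usage note (not from the original sources): a minimal sketch of reading a
# ZeRO-3 partitioned weight with the context above. No modifier_rank is passed
# because the parameter is only inspected, not modified, so nothing is
# broadcast on exit. `model` and its `fc` submodule are hypothetical and are
# assumed to have been constructed under a deepspeed.zero.Init() context.
import torch
import deepspeed

def inspect_partitioned_weight(model):
    with deepspeed.zero.GatheredParameters(model.fc.weight):
        # Inside the context the full, unpartitioned weight is materialized.
        if torch.distributed.get_rank() == 0:
            print(model.fc.weight.norm())
    # On exit the weight is partitioned again across the data-parallel group.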
def wait(self):
    nonlocal work, pg
    work.wait()
    src_rank = work.source_rank()
    return dist_c10d._get_global_rank(pg, src_rank)
def get_global_rank(group, group_rank):
    from torch.distributed.distributed_c10d import _get_global_rank
    return _get_global_rank(group, group_rank)
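# Usage note (not from the original sources): a minimal sketch of the helper
# above with a sub-group. It assumes the default process group has already been
# initialized (e.g. via torchrun) with world_size >= 4, and that new_group()
# is called on every rank.
import torch.distributed as dist

def demo_global_rank_lookup():
    subgroup = dist.new_group(ranks=[2, 3])
    # Group rank 1 inside `subgroup` is global rank 3 in the default group.
    assert get_global_rank(subgroup, 1) == 3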
def _get_expert_broadcast_src_rank(group_name):
    return _get_global_rank(_get_expert_data_parallel_group(group_name), 0)
def _get_broadcast_src_rank():
    return _get_global_rank(_get_data_parallel_group(), 0)
def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
    # relative imports to avoid circular dependency
    from torch.distributed._shard.sharded_tensor import (
        ShardedTensor
    )
    tensor_properties = sharded_tensor_meta.TensorProperties(
        dtype=tensor.dtype,
        layout=tensor.layout,
        requires_grad=tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=tensor.is_pinned()
    )
    current_rank = dist.get_rank(process_group)
    tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
    local_shards = []
    local_tensor = None
    local_metadata = None
    tensors_to_scatter = [None] * dist.get_world_size(process_group)

    sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
    chunks = len(self.placements)
    split_size = get_split_size(sharding_dim_size, chunks)
    scatter_shape = list(tensor.size())
    scatter_shape[self.dim] = split_size  # type: ignore[index]

    for shard_meta in tensor_meta.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
        if current_rank == src_rank:
            # Reshape to get the shard for this rank. We don't want autograd
            # recording here for the narrow op, and 'local_shard' should be a
            # leaf variable in the autograd graph.
            narrowed_tensor = narrow_tensor(tensor, shard_meta)
            if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
                # for the last shard that might be smaller than the other shards,
                # resize the narrowed tensor to the same size and use it for
                # the scatter collective, as dist.scatter requires same-size
                # inputs on every rank
                tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
            else:
                tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

            tensors_to_scatter[rank] = tensor_to_scatter

        if current_rank == rank:
            local_tensor = torch.empty(
                scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
            local_metadata = shard_meta

    # each rank should have local_tensor and local_metadata initialized if we build
    # the metadata list in a correct way.
    assert local_tensor is not None
    assert local_metadata is not None

    # Scatter the shards to all ranks in the pg
    # scatter takes the global rank as ``src``
    src_for_scatter = src_rank
    if process_group is not None and process_group is not distributed_c10d._get_default_group():
        src_for_scatter = distributed_c10d._get_global_rank(process_group, src_for_scatter)

    dist.scatter(
        local_tensor,
        scatter_list=tensors_to_scatter if current_rank == src_rank else None,
        src=src_for_scatter,
        group=process_group
    )

    if list(local_tensor.size()) != local_metadata.shard_sizes:
        # detach again after receiving to ensure local shards remain a leaf node
        local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

    # Sync requires_grad to local_shard.
    local_tensor.requires_grad = tensor.requires_grad

    local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

    st = ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        tensor_meta,
        process_group=process_group)

    # Manually set sharding_spec
    st._sharding_spec = self

    return st
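# Usage note (not from the original sources): a minimal sketch of how the
# shard() method above is typically reached, assuming the experimental
# torch.distributed._shard API is available in this PyTorch build and that
# two ranks with one GPU each have already called init_process_group().
import torch
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def shard_linear_weight():
    linear = torch.nn.Linear(32, 32).cuda()
    spec = ChunkShardingSpec(
        dim=0,  # shard rows of the weight matrix
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    # Replaces linear.weight with a ShardedTensor; on the source rank this
    # ends up calling ChunkShardingSpec.shard() on the full tensor.
    shard_parameter(linear, "weight", spec, src_rank=0)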