def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
    # Local import to avoid a circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor

    tensor_properties = sharded_tensor_meta.TensorProperties(
        dtype=tensor.dtype,
        layout=tensor.layout,
        requires_grad=tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=tensor.is_pinned()
    )
    tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
    local_shards = []
    current_rank = dist.get_rank(process_group)

    # Scatter the shards (use broadcast since NCCL doesn't support scatter;
    # this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=process_group)

    for shard_meta in tensor_meta.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
        if rank == current_rank:
            shard_offsets = shard_meta.shard_offsets
            shard_sizes = shard_meta.shard_sizes
            local_tensor = tensor
            for idx, (offset, size) in enumerate(zip(shard_offsets, shard_sizes)):
                if size < tensor.size(idx):
                    # Narrow out this rank's shard. clone().detach() keeps the
                    # narrow op out of autograd so 'local_tensor' stays a leaf
                    # variable in the autograd graph.
                    local_tensor = local_tensor.narrow(
                        idx, shard_offsets[idx], shard_sizes[idx]
                    ).clone().detach().contiguous()

            # Sync requires_grad to the local shard.
            local_tensor.requires_grad = tensor.requires_grad
            local_shards.append(Shard(tensor=local_tensor, metadata=shard_meta))

    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=process_group)

    # Manually set the sharding spec.
    st._sharding_spec = self

    return st
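# Minimal usage sketch for the broadcast-based shard() above. This is an
# illustrative example, not part of the original code: it assumes the method
# lives on a ChunkShardingSpec-style spec (as in
# torch.distributed._shard.sharding_spec) and that the script is launched with
# two ranks, e.g. `torchrun --nproc_per_node=2 example.py`. The gloo backend
# and "rank:N/cpu" placements are arbitrary choices for the sketch.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def _example_shard():
    dist.init_process_group("gloo")
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
    # Every rank passes a same-sized tensor; the broadcast overwrites the
    # non-src copies, so only src_rank's values end up in the shards.
    full = torch.arange(12, dtype=torch.float32).reshape(6, 2)
    st = spec.shard(full, src_rank=0)
    shard = st.local_shards()[0]
    print(dist.get_rank(), shard.metadata.shard_offsets, shard.tensor.size())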
def _create_shard_for(tensor: Tensor) -> Shard:
    return Shard(
        tensor=tensor,
        metadata=_create_shard_metadata(tensor.size()),
    )
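# _create_shard_metadata is not shown in this snippet. The sketch below is a
# hypothetical stand-in (a guess at its shape, not the real helper): it
# describes the given size as a single shard starting at offset zero, placed
# on the current rank.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ShardMetadata

def _create_shard_metadata_sketch(size: torch.Size) -> ShardMetadata:
    return ShardMetadata(
        shard_offsets=[0] * len(size),
        shard_sizes=list(size),
        placement=f"rank:{dist.get_rank()}/cpu",  # illustrative CPU placement
    )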
def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
    # Local import to avoid a circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor

    tensor_properties = sharded_tensor_meta.TensorProperties(
        dtype=tensor.dtype,
        layout=tensor.layout,
        requires_grad=tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=tensor.is_pinned()
    )
    current_rank = dist.get_rank(process_group)
    tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
    local_shards = []
    local_tensor = None
    local_metadata = None
    tensors_to_scatter = [None] * dist.get_world_size(process_group)

    sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
    chunks = len(self.placements)
    split_size = get_split_size(sharding_dim_size, chunks)
    scatter_shape = list(tensor.size())
    scatter_shape[self.dim] = split_size  # type: ignore[index]

    for shard_meta in tensor_meta.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
        if current_rank == src_rank:
            # Narrow out this rank's shard. detach().clone() keeps the narrow
            # op out of autograd so the local shard stays a leaf variable in
            # the autograd graph.
            narrowed_tensor = narrow_tensor(tensor, shard_meta)
            if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
                # The last shard may be smaller than the others; resize the
                # narrowed tensor to the full split size, since dist.scatter
                # requires same-size inputs on every rank.
                tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
            else:
                tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

            tensors_to_scatter[rank] = tensor_to_scatter

        if current_rank == rank:
            local_tensor = torch.empty(
                scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
            local_metadata = shard_meta

    # Each rank should have local_tensor and local_metadata initialized if the
    # metadata list was built correctly.
    assert local_tensor is not None
    assert local_metadata is not None

    # Scatter the shards to all ranks in the pg.
    dist.scatter(
        local_tensor,
        scatter_list=tensors_to_scatter if current_rank == src_rank else None,
        src=src_rank,
        group=process_group
    )

    if list(local_tensor.size()) != local_metadata.shard_sizes:
        # Trim the padding off the received shard; detach again so the local
        # shard remains a leaf node.
        local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

    # Sync requires_grad to the local shard.
    local_tensor.requires_grad = tensor.requires_grad

    local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

    st = ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, tensor_meta, process_group=process_group)

    # Manually set the sharding spec.
    st._sharding_spec = self

    return st
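# Why the resize_ to scatter_shape is needed: dist.scatter requires every
# entry of scatter_list to match the shape of the output tensor, so uneven
# shards are padded up to split_size before the collective and trimmed back to
# their true shard_sizes afterwards. A small worked example, assuming
# get_split_size / get_chunked_dim_size are the usual ceil-division helpers
# from torch.distributed._shard.sharding_spec._internals:
from torch.distributed._shard.sharding_spec._internals import (
    get_split_size,
    get_chunked_dim_size,
)

split = get_split_size(10, 4)  # ceil(10 / 4) == 3
sizes = [get_chunked_dim_size(10, split, i) for i in range(4)]
assert split == 3 and sizes == [3, 3, 3, 1]  # the last shard is padded 1 -> 3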
def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
    # Local import to avoid a circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor

    tensor_properties = sharded_tensor_meta.TensorProperties(
        dtype=tensor.dtype,
        layout=tensor.layout,
        requires_grad=tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=tensor.is_pinned()
    )
    current_rank = dist.get_rank(process_group)
    tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
    local_shards = []
    local_tensor = None
    local_metadata = None
    tensors_to_scatter = []

    for shard_meta in tensor_meta.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
        shard_offsets = shard_meta.shard_offsets
        shard_sizes = shard_meta.shard_sizes
        if current_rank == src_rank:
            narrowed_tensor = tensor
            for idx, (offset, size) in enumerate(zip(shard_offsets, shard_sizes)):
                if size < tensor.size(idx):
                    # Narrow out this shard. clone().detach() keeps the narrow
                    # op out of autograd so the local shard stays a leaf
                    # variable in the autograd graph.
                    narrowed_tensor = narrowed_tensor.narrow(
                        idx, shard_offsets[idx], shard_sizes[idx]
                    ).clone().detach().contiguous()
            # Note: this relies on shards_metadata being ordered by rank, and
            # dist.scatter requires same-size inputs on every rank, so all
            # shards must share one shape here (no padding in this variant).
            tensors_to_scatter.append(narrowed_tensor)

        if current_rank == rank:
            local_tensor = torch.empty(
                shard_sizes, dtype=tensor.dtype, layout=tensor.layout, device=device)
            local_metadata = shard_meta

    # Each rank must have its output buffer before the collective runs.
    assert local_tensor is not None
    assert local_metadata is not None

    # Scatter the shards to all ranks in the pg.
    dist.scatter(
        local_tensor,
        scatter_list=tensors_to_scatter if current_rank == src_rank else None,
        src=src_rank,
        group=process_group
    )

    # Sync requires_grad to the local shard.
    local_tensor.requires_grad = tensor.requires_grad
    local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=process_group)

    # Manually set the sharding spec.
    st._sharding_spec = self

    return st
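# Hedged round-trip check for the unpadded scatter variant above. Because this
# version scatters shards without padding, it only works when the sharding dim
# divides evenly across ranks so all shards share one shape. Assumes two ranks
# and a ChunkShardingSpec entry point as in the earlier sketch, and that
# ShardedTensor.gather reassembles the full tensor on the destination rank.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def _check_roundtrip():
    dist.init_process_group("gloo")
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
    full = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # 4 rows split 2/2
    st = spec.shard(full, src_rank=0)
    out = torch.empty_like(full) if dist.get_rank() == 0 else None
    st.gather(dst=0, out=out)  # reassemble the full tensor on rank 0
    if dist.get_rank() == 0:
        assert torch.equal(out, full)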