def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
    # relative imports to avoid circular dependency
    from torch.distributed._shard.sharded_tensor import (
        ShardedTensor
    )
    tensor_properties = sharded_tensor_meta.TensorProperties(
        dtype=tensor.dtype,
        layout=tensor.layout,
        requires_grad=tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=tensor.is_pinned()
    )
    tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
    local_shards = []
    current_rank = dist.get_rank(process_group)

    # Scatter the shards (use broadcast since NCCL doesn't support scatter;
    # this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=process_group)

    for shard_meta in tensor_meta.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
        if rank == current_rank:
            # Narrow to get the shard for this rank. We don't want autograd
            # to record the narrow op, and the local shard should be a
            # leaf variable in the autograd graph.
            local_tensor = narrow_tensor(tensor, shard_meta).clone().detach().contiguous()
            # Sync requires_grad to local_shard.
            local_tensor.requires_grad = tensor.requires_grad
            local_shards.append(
                Shard(
                    tensor=local_tensor,
                    metadata=shard_meta
                )
            )

    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=process_group)

    # Manually set sharding_spec
    st._sharding_spec = self

    return st
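# The helper below is a hedged sketch (not the library's exact implementation)
# of what narrow_tensor is assumed to do in the shard() above: take a view of
# the full tensor covering exactly the region described by one shard's metadata.
import torch

def narrow_tensor_sketch(tensor: torch.Tensor, shard_meta) -> torch.Tensor:
    # Narrow dimension by dimension using the shard's offsets and sizes;
    # narrow() returns a view, so no data is copied here.
    narrowed = tensor
    for dim, (offset, size) in enumerate(
            zip(shard_meta.shard_offsets, shard_meta.shard_sizes)):
        narrowed = narrowed.narrow(dim, offset, size)
    return narrowed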
def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
    """
    Handles ``__torch_function__`` dispatch for binary math ops such as
    ``torch.add``, ``torch.mul``, ``torch.div``, etc.
    This method computes ShardedTensor op ReplicatedTensor (in either order)
    and ShardedTensor op scalar.
    """
    if len(args) != 2:
        raise ValueError(
            "Only binary math ops are supported on ShardedTensor for now!")
    lhs = args[0]
    rhs = args[1]
    # Validate types
    if isinstance(lhs, ReplicatedTensor):
        assert isinstance(rhs, ShardedTensor)
        st_size = rhs.size()
        st_meta = rhs.local_shards()[0].metadata
        if st_size != lhs.size():
            # try to broadcast the replicated tensor to the sharded tensor's size
            lhs = lhs.expand(st_size)

        replica_part = narrow_tensor(lhs, st_meta)
        res = op(replica_part, rhs.local_tensor())

        return ShardedTensor._init_from_local_tensor(
            res,
            rhs.sharding_spec(),
            rhs.size(),  # type: ignore[arg-type]
            process_group=pg)
    elif isinstance(rhs, ReplicatedTensor):
        assert isinstance(lhs, ShardedTensor)
        st_size = lhs.size()
        st_meta = lhs.local_shards()[0].metadata
        if st_size != rhs.size():
            # try to broadcast the replicated tensor to the sharded tensor's size
            rhs = rhs.expand(st_size)

        replica_part = narrow_tensor(rhs, st_meta)
        res = op(lhs.local_tensor(), replica_part)

        return ShardedTensor._init_from_local_tensor(
            res,
            lhs.sharding_spec(),
            lhs.size(),  # type: ignore[arg-type]
            process_group=pg)
    elif isinstance(lhs, (int, float)):
        assert isinstance(rhs, ShardedTensor)
        res = op(lhs, rhs.local_tensor())
        return ShardedTensor._init_from_local_tensor(
            res,
            rhs.sharding_spec(),
            rhs.size(),  # type: ignore[arg-type]
            process_group=pg)
    elif isinstance(rhs, (int, float)):
        assert isinstance(lhs, ShardedTensor)
        res = op(lhs.local_tensor(), rhs)
        return ShardedTensor._init_from_local_tensor(
            res,
            lhs.sharding_spec(),
            lhs.size(),  # type: ignore[arg-type]
            process_group=pg)
    else:
        raise RuntimeError(
            f"torch function '{op.__name__}', with args: {args} and "
            f"kwargs: {kwargs} not supported yet for ShardedTensor!")
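# Minimal usage sketch of the dispatch path above. It assumes a 2-rank process
# group is already initialized (e.g. via torch.distributed.init_process_group)
# and that ops such as torch.mul are registered to route into
# binary_math_op_impl; run on every rank. Illustrative only.
import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
st = sharded_tensor.rand(spec, 8, 4)   # each rank holds one (4, 4) local shard
doubled = torch.mul(st, 2.0)           # hits the "ShardedTensor op scalar" branch,
                                       # scaling only the local shard on each rank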
def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
    # relative imports to avoid circular dependency
    from torch.distributed._shard.sharded_tensor import (
        ShardedTensor
    )
    tensor_properties = sharded_tensor_meta.TensorProperties(
        dtype=tensor.dtype,
        layout=tensor.layout,
        requires_grad=tensor.requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=tensor.is_pinned()
    )
    current_rank = dist.get_rank(process_group)
    tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
    local_shards = []
    local_tensor = None
    local_metadata = None
    tensors_to_scatter = [None] * dist.get_world_size(process_group)

    sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
    chunks = len(self.placements)
    split_size = get_split_size(sharding_dim_size, chunks)
    scatter_shape = list(tensor.size())
    scatter_shape[self.dim] = split_size  # type: ignore[index]

    for shard_meta in tensor_meta.shards_metadata:
        rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
        if current_rank == src_rank:
            # Narrow to get the shard for this rank. We don't want autograd
            # to record the narrow op, and the local shard should be a
            # leaf variable in the autograd graph.
            narrowed_tensor = narrow_tensor(tensor, shard_meta)
            if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
                # The last shard may be smaller than the others: resize the
                # narrowed tensor to the common scatter shape and use it for
                # the scatter collective, since dist.scatter requires
                # same-size inputs on every rank.
                tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
            else:
                tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

            tensors_to_scatter[rank] = tensor_to_scatter

        if current_rank == rank:
            local_tensor = torch.empty(
                scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
            local_metadata = shard_meta

    # Each rank should have local_tensor and local_metadata initialized if the
    # metadata list was built correctly.
    assert local_tensor is not None
    assert local_metadata is not None

    # Scatter the shards to all ranks in the pg
    dist.scatter(
        local_tensor,
        scatter_list=tensors_to_scatter if current_rank == src_rank else None,
        src=src_rank,
        group=process_group
    )

    if list(local_tensor.size()) != local_metadata.shard_sizes:
        # Detach again after receiving to ensure local shards remain leaf nodes.
        local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

    # Sync requires_grad to local_shard.
    local_tensor.requires_grad = tensor.requires_grad
    local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

    st = ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        tensor_meta,
        process_group=process_group)

    # Manually set sharding_spec
    st._sharding_spec = self

    return st
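# Hedged sketch of the chunk-sizing helpers that the scatter-based shard()
# above relies on. get_split_size and a per-chunk size helper are assumed to
# behave as below (ceil-division, with a possibly smaller final chunk); this
# is what forces the resize_() of the last shard before and after dist.scatter.
# Not the library's exact implementation.
def get_split_size_sketch(dim_size: int, chunks: int) -> int:
    # Size of every chunk except possibly the last one.
    return (dim_size + chunks - 1) // chunks

def get_chunked_dim_size_sketch(dim_size: int, split_size: int, idx: int) -> int:
    # Actual size of chunk `idx`; the final chunk may be smaller, never negative.
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)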