def _init_enumerable(self, dims, tensor_init_params: TensorInitParams):
    """Allocate this rank's shards for an enumerable sharding spec and record metadata.

    Every shard in ``self._sharding_spec.shards`` is recorded in the resulting
    ``ShardedTensorMetadata``. Only the shards whose placement resolves to the
    current rank of ``self._process_group`` get a tensor. That tensor is built by
    ``_create_tensor_from_params`` and appended to ``self._local_shards``.

    Args:
        dims: Global tensor dimensions; the shard list is checked against them.
        tensor_init_params: Creation parameters forwarded to
            ``_create_tensor_from_params``. Its ``tensor_properties`` are stored
            in the overall metadata.
    """
    shard_specs = self._sharding_spec.shards  # type: ignore[attr-defined]
    # Reject sharding specs that are incompatible with ``dims`` before allocating anything.
    check_tensor(shard_specs, dims)

    my_rank = dist.get_rank(self._process_group)
    all_shards = []
    for meta in shard_specs:
        owner_rank, device = _parse_and_validate_remote_device(
            self._process_group, meta.placement)
        all_shards.append(meta)
        if owner_rank != my_rank:
            # Owned by another rank: only its metadata is kept here.
            continue
        tensor = _create_tensor_from_params(
            *meta.shard_sizes,
            local_device=device,
            tensor_init_params=tensor_init_params,
        )
        self._local_shards.append(Shard(tensor, meta))

    # Global metadata covers every shard, local or remote.
    self._metadata = ShardedTensorMetadata(
        all_shards,
        dims,
        tensor_init_params.tensor_properties,
    )
def build_global_metadata(
        gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
    """Merge per-rank ShardedTensorMetadata entries into one global metadata object.

    Entries that are ``None`` (ranks without local shards) are skipped. The first
    non-None entry becomes the reference, and every later entry must agree with
    it on global ``size``, ``dtype``, ``requires_grad`` and ``pin_memory``. Each
    later entry's ``shards_metadata`` are appended to the merged result. The
    merged shard list is then checked for overlapping shards and for
    compatibility with the global size.

    The reference entry is deep-copied before merging, so the objects inside
    ``gathered_metadatas`` are not modified.

    Args:
        gathered_metadatas: One entry per rank, indexed by rank position; ``None``
            for ranks that contributed no local shards.

    Returns:
        A new ShardedTensorMetadata describing the shards of all ranks.

    Raises:
        ValueError: If every entry is ``None`` (or the sequence is empty).
        The mismatch checks (``_raise_if_mismatch``), the overlap check
        (``validate_non_overlapping_shards_metadata``) and ``check_tensor`` may
        raise as well.
    """
    import copy  # function-local: keeps this fix self-contained

    global_sharded_tensor_metadata = None
    global_metadata_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            # This rank has no local shards and contributes nothing.
            continue

        if global_sharded_tensor_metadata is None:
            # Copy the first contribution. Its ``shards_metadata`` list is
            # extended below, and doing that in place would mutate the caller's
            # entry in ``gathered_metadatas``.
            global_sharded_tensor_metadata = copy.deepcopy(rank_metadata)
            global_metadata_rank = rank
            continue

        global_props = global_sharded_tensor_metadata.tensor_properties
        rank_props = rank_metadata.tensor_properties
        _raise_if_mismatch(global_sharded_tensor_metadata.size, rank_metadata.size,
                           "global_size", [global_metadata_rank, rank], is_local=False)
        # Layout and memory format are not re-checked here; they were already
        # checked during local shard validation.
        _raise_if_mismatch(
            global_props.dtype,
            rank_props.dtype,
            "dtype",
            [global_metadata_rank, rank],
            is_local=False)
        _raise_if_mismatch(
            global_props.requires_grad,
            rank_props.requires_grad,
            "requires_grad",
            [global_metadata_rank, rank],
            is_local=False)
        _raise_if_mismatch(
            global_props.pin_memory,
            rank_props.pin_memory,
            "pin_memory",
            [global_metadata_rank, rank],
            is_local=False)
        # All checks passed: fold this rank's shards into the global list.
        global_sharded_tensor_metadata.shards_metadata.extend(
            rank_metadata.shards_metadata)

    if global_sharded_tensor_metadata is None:
        raise ValueError("ShardedTensor have no local shards on all ranks!")

    # Check that the merged shards do not overlap.
    validate_non_overlapping_shards_metadata(
        global_sharded_tensor_metadata.shards_metadata)
    # Check that the merged shards are compatible with the global size.
    check_tensor(global_sharded_tensor_metadata.shards_metadata,
                 global_sharded_tensor_metadata.size)

    return global_sharded_tensor_metadata
def _init_enumerable(
    self,
    dims,
    dtype,
    layout,
    requires_grad,
    pin_memory,
    memory_format,
):
    """Allocate local shards for an enumerable sharding spec and build the metadata.

    Each shard in ``self._sharding_spec.shards`` is recorded in the overall
    ``ShardedTensorMetadata``. A shard whose placement resolves to the current
    rank of ``self._process_group`` is allocated with ``torch.empty``, so its
    storage is uninitialized. It is then appended to ``self._local_shards``.

    Args:
        dims: Global tensor dimensions, validated against the shard list.
        dtype, layout, requires_grad, pin_memory, memory_format: Forwarded to
            ``torch.empty`` for every local shard and stored in the metadata.
    """
    shard_list = self._sharding_spec.shards  # type: ignore[attr-defined]
    # Make sure the sharding spec is compatible with ``dims`` before allocating.
    check_tensor(shard_list, dims)

    this_rank = dist.get_rank(self._process_group)

    # Creation options shared by every local shard; only the device varies.
    empty_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "requires_grad": requires_grad,
        "memory_format": memory_format,
        "pin_memory": pin_memory,
    }

    collected = []
    for meta in shard_list:
        owner, device = self._parse_and_validate_remote_device(meta.placement)
        collected.append(meta)
        if owner == this_rank:
            buffer = torch.empty(*meta.shard_lengths, device=device, **empty_kwargs)
            self._local_shards.append(Shard(buffer, meta))

    # Positional order must match the ShardedTensorMetadata constructor.
    self._metadata = ShardedTensorMetadata(
        collected,
        dims,
        dtype,
        layout,
        requires_grad,
        memory_format,
        pin_memory,
    )
def _init_from_local_shards(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
):
    """Build a sharded tensor from this rank's local shards and the global metadata.

    The new instance is created with ``cls.__new__``, bypassing ``__init__``.
    Its metadata is set to ``sharded_tensor_metadata`` and its sharding spec
    to ``None``. Each local shard tensor is validated against its
    ``ShardMetadata`` and against the tensor properties in the global metadata
    before being installed.

    Args:
        local_shards: Shards owned by the current rank.
        sharded_tensor_metadata: Metadata describing all shards of the tensor.
        process_group: Forwarded to ``_prepare_init``.
        init_rrefs: Forwarded to ``_prepare_init``.

    Returns:
        The newly constructed instance of ``cls``.

    Raises:
        ValueError: If ``shards_metadata`` is empty or the layout is not
            ``torch.strided``. Also raised when a local shard tensor's layout,
            contiguity, size, pin_memory, device, dtype or requires_grad
            disagrees with the metadata.
        RuntimeError: If the number of ``local_shards`` differs from the number
            of shards the metadata places on this rank.
        AssertionError: If a local shard's metadata is not among the entries the
            metadata places on this rank.
    """
    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError('Only torch.strided layout is currently supported')

    sharded_tensor = cls.__new__(cls)

    # prepare initialization
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    # No sharding spec for sharded tensors initialized from this API.
    sharded_tensor._sharding_spec = None

    current_rank = dist.get_rank(sharded_tensor._process_group)

    # Collect the metadata entries that the global metadata places on this rank.
    # Every placement goes through _parse_and_validate_remote_device, not only
    # the ones owned by this rank.
    local_shard_metadatas = []
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, _ = sharded_tensor._parse_and_validate_remote_device(shard_metadata.placement)
        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) '
        )

    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        _, local_device = sharded_tensor._parse_and_validate_remote_device(shard_meta.placement)

        # The shard's metadata must be one of the entries collected above.
        # Raised explicitly rather than with ``assert`` so the check still runs
        # under ``python -O``.
        if shard_meta not in local_shard_metadatas:
            raise AssertionError("local shard metadata not in sharded_tensor_metadata!")

        if local_shard_tensor.layout != tensor_properties.layout:
            raise ValueError(
                f'Local shard tensor layout does not match with tensor_properties! '
                f'local shard tensor layout: {local_shard_tensor.layout}, '
                f'tensor_properties layout: {tensor_properties.layout}'
            )

        if not local_shard_tensor.is_contiguous():
            raise ValueError('Only torch.contiguous_format memory_format is currently supported')

        if shard_meta.shard_lengths != list(local_shard_tensor.size()):
            raise ValueError(
                f'Local shard tensor is incompatible with local ShardMetadata! '
                f'local shard tensor size: {local_shard_tensor.size()}, '
                f'local ShardMetadata shard lengths: {shard_meta.shard_lengths}'
            )

        if local_shard_tensor.is_pinned() != tensor_properties.pin_memory:
            raise ValueError(
                f'Local shard tensor pin_memory does not match with tensor_properties! '
                f'local shard tensor pin_memory: {local_shard_tensor.is_pinned()}, '
                f'tensor_properties pin_memory: {tensor_properties.pin_memory}'
            )

        if local_shard_tensor.device != local_device:
            raise ValueError(
                f'Local shard tensor device does not match with local Shard placement! '
                f'local shard tensor device: {local_shard_tensor.device}, '
                f'local shard metadata placement device: {local_device}'
            )

        if local_shard_tensor.dtype != tensor_properties.dtype:
            raise ValueError(
                f'Local shard tensor dtype does not match with tensor_properties! '
                f'local shard tensor dtype: {local_shard_tensor.dtype}, '
                f'tensor_properties dtype: {tensor_properties.dtype}'
            )

        if local_shard_tensor.requires_grad != tensor_properties.requires_grad:
            raise ValueError(
                f'Local shard tensor requires_grad does not match with tensor_properties! '
                f'local shard tensor requires_grad: {local_shard_tensor.requires_grad}, '
                f'tensor_properties requires_grad: {tensor_properties.requires_grad}'
            )

    # Check that the shards do not overlap.
    validate_non_overlapping_shards_metadata(shards_metadata)

    # Check that the shards are compatible with the overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # Validation done: install the local shards.
    sharded_tensor._local_shards = local_shards

    # Post initialization, i.e. map registration and rpc initialization.
    sharded_tensor._post_init()
    return sharded_tensor
def test_enumerable_sharding_spec(self):
    """Exercise EnumerableShardingSpec / ShardMetadata validation and check_tensor."""

    def spec_of(*layouts):
        # Each layout is an (offsets, lengths, placement) triple.
        return EnumerableShardingSpec([
            ShardMetadata(shard_offsets=offs, shard_lengths=lens, placement=where)
            for offs, lens, where in layouts
        ])

    def row_wise_spec():
        # Two 5x5 shards stacked along dim 0 on cuda:0 / cuda:1.
        return spec_of(
            ([0, 0], [5, 5], "cuda:0"),
            ([5, 0], [5, 5], "cuda:1"),
        )

    # --- valid specs ---

    # Row-wise sharding.
    check_tensor(row_wise_spec().shards, torch.rand(10, 5).size())

    # Row and column sharding: a 2x2 grid of 3x3 shards.
    grid = spec_of(
        ([0, 0], [3, 3], "cuda:0"),
        ([0, 3], [3, 3], "cuda:1"),
        ([3, 0], [3, 3], "cuda:2"),
        ([3, 3], [3, 3], "cuda:3"),
    )
    check_tensor(grid.shards, torch.rand(6, 6).size())

    # Uneven shard sizes that still tile a 6x6 tensor.
    uneven = spec_of(
        ([0, 0], [2, 4], "cuda:0"),
        ([0, 4], [4, 2], "cuda:1"),
        ([2, 0], [4, 4], "cuda:2"),
        ([4, 4], [2, 2], "cuda:3"),
    )
    check_tensor(uneven.shards, torch.rand(6, 6).size())

    # --- invalid ShardMetadata ---
    bad_shard_args = [
        ('Could not parse remote_device', [0], [1], "cuda:foo"),
        ('same number of elements', [0, 0], [1], "cuda:0"),
        ('shard_offsets should be >=0', [-1, 0], [1, 1], "cuda:0"),
        ('shard_lengths should be > 0', [0, 0], [0, 1], "cuda:0"),
    ]
    for pattern, offs, lens, where in bad_shard_args:
        with self.assertRaisesRegex(ValueError, pattern):
            ShardMetadata(shard_offsets=offs, shard_lengths=lens, placement=where)

    # --- invalid EnumerableShardingSpec ---
    with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
        EnumerableShardingSpec([])

    with self.assertRaisesRegex(ValueError, 'Found inconsistent ranks for shards'):
        spec_of(
            ([0, 0], [1, 1], "cpu"),
            ([0, 0, 0], [1, 1, 1], "cpu"),
        )

    with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
        spec_of(
            ([0, 0], [3, 3], "cpu"),
            ([2, 0], [3, 3], "cpu"),
        )

    # --- spec / tensor incompatibilities detected by check_tensor ---
    rank_mismatch = row_wise_spec()
    with self.assertRaisesRegex(ValueError, 'Rank of tensor is.*but shards rank'):
        check_tensor(rank_mismatch.shards, torch.rand(10, 10, 10).size())

    too_wide = row_wise_spec()
    with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
        check_tensor(too_wide.shards, torch.rand(10, 3).size())

    # Diagonal shards leave half of the 10x10 tensor uncovered.
    holes = spec_of(
        ([0, 0], [5, 5], "cuda:0"),
        ([5, 5], [5, 5], "cuda:1"),
    )
    with self.assertRaisesRegex(ValueError, 'does not match tensor volume'):
        check_tensor(holes.shards, torch.rand(10, 10).size())
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
):
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does not
             perform cross-rank validation and relies entirely on the user
             for the correctness of ``sharded_tensor_metadata`` on each rank.

    Args:
        local_shards: Shards owned by the current rank. Each shard's
            ``metadata`` must be one of the entries in
            ``sharded_tensor_metadata.shards_metadata`` placed on this rank.
        sharded_tensor_metadata: Metadata for all shards of the tensor. It is
            presumably expected to be identical on every rank (not checked here).
        process_group: Group used to determine the current rank; defaults to
            ``distributed_c10d._get_default_group()``.
        init_rrefs: Forwarded to ``_prepare_init``.

    Returns:
        The newly constructed instance of ``cls``, with an
        ``EnumerableShardingSpec`` built from ``shards_metadata``.

    Raises:
        ValueError: If ``shards_metadata`` is empty, the layout is not
            ``torch.strided``, a local shard tensor is not contiguous, or a
            local shard tensor's layout, size, pin_memory, device, dtype or
            requires_grad disagrees with the metadata.
        RuntimeError: If the number of ``local_shards`` differs from the number
            of shards the metadata places on this rank.
        AssertionError: If a local shard's metadata is not among the entries the
            metadata places on this rank.
    """
    process_group = (process_group if process_group is not None else
                     distributed_c10d._get_default_group())
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError(
            'Only torch.strided layout is currently supported')

    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group,
                                 init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        # Raise ValueError naming the mismatched property. ``is_property``
        # selects whether ``expected`` came from tensor_properties or from the
        # shard's ShardMetadata / placement.
        tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # Collect the metadata entries that the global metadata places on this rank.
    local_shard_metadatas = []
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, _ = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_metadata.placement)
        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) ')

    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        _, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_meta.placement)

        # The shard's metadata must be one of the entries collected above.
        # Raised explicitly rather than with ``assert`` so the check still runs
        # under ``python -O``.
        if shard_meta not in local_shard_metadatas:
            raise AssertionError("local shard metadata not in sharded_tensor_metadata!")

        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout, "layout", current_rank, True)
        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                'Only torch.contiguous_format memory_format is currently supported'
            )

        _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(), "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype, "dtype", current_rank, True)
        _raise_if_mismatch(
            tensor_properties.requires_grad, local_shard_tensor.requires_grad, "requires_grad", current_rank, True)

    # Check that the shards do not overlap.
    validate_non_overlapping_shards_metadata(shards_metadata)

    # Check that the shards are compatible with the overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # Validation done: install the local shards.
    sharded_tensor._local_shards = local_shards

    # Make an EnumerableShardingSpec for sharded tensors initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    #       see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

    # Post initialization, i.e. map registration and rpc initialization.
    sharded_tensor._post_init()
    return sharded_tensor