def _init_enumerable(self, dims, tensor_init_params: TensorInitParams):
    """Create this rank's shards for an enumerable sharding spec.

    Every shard listed in ``self._sharding_spec.shards`` is recorded in the
    overall metadata. Only shards whose placement resolves to the current
    rank of ``self._process_group`` get a local tensor. That tensor is built by
    ``_create_tensor_from_params`` and appended to ``self._local_shards``.

    Args:
        dims: overall size of the sharded tensor; validated against the
            shards with ``check_tensor`` before anything is allocated.
        tensor_init_params: parameters used to create each local shard
            tensor; its ``tensor_properties`` are stored in ``self._metadata``.
    """
    all_shards = self._sharding_spec.shards  # type: ignore[attr-defined]
    # Reject specs that are incompatible with ``dims`` before allocating.
    check_tensor(all_shards, dims)

    my_rank = dist.get_rank(self._process_group)
    collected = []
    for meta in all_shards:
        owner, device = _parse_and_validate_remote_device(
            self._process_group, meta.placement)
        collected.append(meta)
        if owner != my_rank:
            # Shard is placed on another rank: keep only its metadata.
            continue
        shard_tensor = _create_tensor_from_params(
            *meta.shard_sizes,
            local_device=device,
            tensor_init_params=tensor_init_params,
        )
        self._local_shards.append(Shard(shard_tensor, meta))

    # Overall metadata covers the shards of every rank, not only local ones.
    self._metadata = ShardedTensorMetadata(
        collected,
        dims,
        tensor_init_params.tensor_properties,
    )
def build_global_metadata(
        gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
    """Merge per-rank ``ShardedTensorMetadata`` into one global metadata.

    ``gathered_metadatas[i]`` is the metadata contributed by rank ``i``, or
    ``None`` when that rank contributed nothing.

    The first non-``None`` entry is deep-copied and used as the reference.
    Every later entry must agree with it on ``size``, ``dtype``,
    ``requires_grad`` and ``pin_memory``, checked with ``_raise_if_mismatch``
    and ``is_local=False``. Each later entry's ``shards_metadata`` is then
    appended to the reference copy.

    Returns:
        The merged ``ShardedTensorMetadata``.

    Raises:
        ValueError: if every entry is ``None``.

    The merged shards are also passed to
    ``validate_non_overlapping_shards_metadata`` and ``check_tensor``, whose
    errors propagate to the caller.
    """
    merged = None
    base_rank = 0
    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue
        if merged is None:
            # Work on a deep copy so extending its shard list below does not
            # mutate the caller's metadata object.
            merged = copy.deepcopy(rank_metadata)
            base_rank = rank
            continue

        _raise_if_mismatch(merged.size, rank_metadata.size, "global_size",
                           [base_rank, rank], is_local=False)
        # Layout and memory format are not compared here: they were already
        # checked during local shard validation.
        for prop in ("dtype", "requires_grad", "pin_memory"):
            _raise_if_mismatch(
                getattr(merged.tensor_properties, prop),
                getattr(rank_metadata.tensor_properties, prop),
                prop,
                [base_rank, rank],
                is_local=False)
        # All checks passed for this rank: adopt its shards.
        merged.shards_metadata.extend(rank_metadata.shards_metadata)

    if merged is None:
        raise ValueError("ShardedTensor have no local shards on all ranks!")

    # The union of all ranks' shards must not overlap and must be compatible
    # with the global size of the sharded tensor.
    validate_non_overlapping_shards_metadata(merged.shards_metadata)
    check_tensor(merged.shards_metadata, merged.size)
    return merged
def test_custom_sharding_spec(self):
    """Metadata built by a GridShardingSpec must fit an 8x8 tensor."""
    placements = [
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ]
    spec = GridShardingSpec(grid_size=4, placements=placements)
    props = TensorProperties(
        dtype=torch.get_default_dtype(),
        layout=torch.strided,
        requires_grad=False,
        memory_format=torch.contiguous_format,
        pin_memory=False,
    )
    global_size = torch.Size((8, 8))
    metadata = spec.build_metadata(global_size, props)
    # The generated shards must be compatible with the full tensor size.
    check_tensor(metadata.shards_metadata, global_size)
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
             not do cross rank validations, and fully rely on the user
             for the correctness of sharded_tensor_metadata on each rank

    The entries of ``sharded_tensor_metadata.shards_metadata`` whose
    placement resolves to the current rank are the expected local shard
    metadata. ``local_shards`` must match those entries one-to-one. Each local
    tensor must also agree with:
      - its own ShardMetadata (size, device), and
      - ``sharded_tensor_metadata.tensor_properties`` (layout, pin_memory,
        dtype, requires_grad).
    Each local tensor must be contiguous.

    Args:
        local_shards: shards held by the current rank.
        sharded_tensor_metadata: metadata of all shards on all ranks plus
            the overall size and tensor properties.
        process_group: process group to use; the default group when None.
        init_rrefs: forwarded to ``_prepare_init``.

    Returns:
        A new instance of ``cls`` holding ``local_shards``.

    Raises:
        ValueError: if any of the following holds:
            - ``shards_metadata`` is empty;
            - the layout is not strided;
            - a local tensor is not contiguous;
            - a property or metadata value does not match;
            - one expected ShardMetadata is claimed by more than one local
              shard.
        RuntimeError: if the number of local shards differs from the number
            of metadata entries placed on the current rank.
        AssertionError: if a local shard's metadata is not among the current
            rank's entries in ``sharded_tensor_metadata``.
    """
    process_group = (process_group if process_group is not None
                     else distributed_c10d._get_default_group())
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError(
            'Only torch.strided layout is currently supported')

    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    local_shard_metadatas = []

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        # Raise a ValueError describing which property of a local shard
        # disagrees with the tensor properties or the shard's metadata.
        tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, _ = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_metadata.placement)
        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) '
        )

    # Each expected metadata entry may be claimed by at most one local shard.
    # Together with the equal-count check above, this makes local shards and
    # expected local metadata a one-to-one match. Duplicates would otherwise
    # pass and leave another expected shard without data.
    unclaimed_metadatas = list(local_shard_metadatas)
    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        # Also validates the placement; only the device is needed here.
        _, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_meta.placement)

        # validate if shard_meta in the metadatas collected from sharded_tensor_metadata.
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the AssertionError type and message are kept for
        # backward compatibility.
        if shard_meta not in local_shard_metadatas:
            raise AssertionError("local shard metadata not in sharded_tensor_metadata!")
        if shard_meta not in unclaimed_metadatas:
            raise ValueError(
                f"Local shard metadata {shard_meta} is used by more than one "
                f"local shard on rank {current_rank}.")
        unclaimed_metadatas.remove(shard_meta)

        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout, "layout", current_rank, True)
        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                'Only torch.contiguous_format memory_format is currently supported'
            )

        _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(), "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype, "dtype", current_rank, True)
        _raise_if_mismatch(
            tensor_properties.requires_grad, local_shard_tensor.requires_grad, "requires_grad", current_rank, True)

    # check if shards_metadata have overlap shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check if the shards_metadata is compatible with overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor._local_shards = local_shards
    sharded_tensor._sharding_spec = _infer_sharding_spec_from_shards_metadata(
        shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
def test_enumerable_sharding_spec(self):
    """Check EnumerableShardingSpec/ShardMetadata validation and check_tensor."""

    def _two_stacked_shards(second_offsets):
        # Two 5x5 shards on cuda:0 and cuda:1; the first sits at the origin.
        return EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="cuda:0",
            ),
            ShardMetadata(
                shard_offsets=second_offsets,
                shard_sizes=[5, 5],
                placement="cuda:1",
            ),
        ])

    # ---- valid specs ----

    # Row-wise: two 5x5 shards stacked along dim 0 cover a 10x5 tensor.
    spec = _two_stacked_shards([5, 0])
    check_tensor(spec.shards, torch.rand(10, 5).size())

    # Row and column: a 2x2 grid of 3x3 shards covers a 6x6 tensor.
    grid_layout = [
        ([0, 0], "cuda:0"),
        ([0, 3], "cuda:1"),
        ([3, 0], "cuda:2"),
        ([3, 3], "cuda:3"),
    ]
    spec = EnumerableShardingSpec([
        ShardMetadata(shard_offsets=offsets, shard_sizes=[3, 3], placement=device)
        for offsets, device in grid_layout
    ])
    check_tensor(spec.shards, torch.rand(6, 6).size())

    # Uneven shard sizes that still cover a 6x6 tensor.
    uneven_layout = [
        ([0, 0], [2, 4], "cuda:0"),
        ([0, 4], [4, 2], "cuda:1"),
        ([2, 0], [4, 4], "cuda:2"),
        ([4, 4], [2, 2], "cuda:3"),
    ]
    spec = EnumerableShardingSpec([
        ShardMetadata(shard_offsets=offsets, shard_sizes=sizes, placement=device)
        for offsets, sizes, device in uneven_layout
    ])
    check_tensor(spec.shards, torch.rand(6, 6).size())

    # ---- invalid ShardMetadata ----
    with self.assertRaisesRegex(ValueError, 'Could not parse remote_device'):
        ShardMetadata(shard_offsets=[0], shard_sizes=[1], placement="cuda:foo")

    with self.assertRaisesRegex(ValueError, 'same number of elements'):
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
        ShardMetadata(shard_offsets=[-1, 0], shard_sizes=[1, 1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'shard_sizes should be >= 0'):
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[-1, 1], placement="cuda:0")

    # ---- invalid EnumerableShardingSpec ----
    with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
        EnumerableShardingSpec([])

    with self.assertRaisesRegex(ValueError, 'Found inconsistent ranks for shards'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 1], placement="cpu"),
            ShardMetadata(shard_offsets=[0, 0, 0], shard_sizes=[1, 1, 1], placement="cpu"),
        ])

    with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[3, 3], placement="cpu"),
            ShardMetadata(shard_offsets=[2, 0], shard_sizes=[3, 3], placement="cpu"),
        ])

    # ---- valid spec checked against an incompatible tensor ----
    spec = _two_stacked_shards([5, 0])
    with self.assertRaisesRegex(ValueError, 'Rank of tensor is.*but shards rank'):
        check_tensor(spec.shards, torch.rand(10, 10, 10).size())

    spec = _two_stacked_shards([5, 0])
    with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
        check_tensor(spec.shards, torch.rand(10, 3).size())

    spec = _two_stacked_shards([5, 5])
    with self.assertRaisesRegex(ValueError, 'does not match tensor volume'):
        check_tensor(spec.shards, torch.rand(10, 10).size())
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    sharding_spec=None,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensorBase with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
             not do cross rank validations, and fully rely on the user
             for the correctness of sharded_tensor_metadata on each rank

    Each local shard is checked only against its own ShardMetadata (size,
    device taken from its placement) and against
    ``sharded_tensor_metadata.tensor_properties`` (layout, pin_memory, dtype,
    requires_grad). Local tensors must also be contiguous. Membership of a
    local shard's metadata in ``shards_metadata`` is not verified here.

    Args:
        local_shards: shards held by this process.
        sharded_tensor_metadata: metadata of all shards plus the overall
            size and tensor properties.
        sharding_spec: spec to attach; when None it is inferred from
            ``shards_metadata`` via
            ``shard_spec._infer_sharding_spec_from_shards_metadata``.

    Returns:
        A ``ShardedTensor`` created via ``ShardedTensor.__new__`` whose
        ``_local_shards`` is ``local_shards``.

    Raises:
        ValueError: if any of the following holds:
            - ``shards_metadata`` is empty;
            - the layout is not strided;
            - a local tensor is not contiguous;
            - a property or metadata value does not match.
        AssertionError: if a local shard has no placement.
    """
    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError(
            "Only torch.strided layout is currently supported")

    if sharding_spec is None:
        spec = shard_spec._infer_sharding_spec_from_shards_metadata(
            shards_metadata)
    else:
        spec = sharding_spec

    sharded_tensor_base = ShardedTensor.__new__(
        ShardedTensor,
        spec,
        sharded_tensor_metadata.size,
        dtype=tensor_properties.dtype,
        layout=tensor_properties.layout,
        pin_memory=tensor_properties.pin_memory,
        requires_grad=tensor_properties.requires_grad,
    )

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        # Raise a ValueError describing which property of a local shard
        # disagrees with the tensor properties or the shard's metadata.
        tensor_property_or_metadata = ("tensor property"
                                       if is_property else "local ShardMetadata")
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        placement = shard_meta.placement
        # Explicit raise rather than ``assert``. Under ``python -O`` the assert
        # would be stripped, and a missing placement would fail below as an
        # AttributeError. The AssertionError type and message are kept for
        # backward compatibility.
        if placement is None:
            raise AssertionError("Must specify placement for `Shard`!")
        rank = placement.rank()
        local_device = placement.device()

        _raise_if_mismatch(
            tensor_properties.layout,
            local_shard_tensor.layout,
            "layout",
            rank,
            True,
        )
        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                "Only torch.contiguous_format memory_format is currently supported"
            )

        _raise_if_mismatch(
            shard_meta.shard_sizes,
            list(local_shard_tensor.size()),
            "size",
            rank,
        )
        _raise_if_mismatch(
            tensor_properties.pin_memory,
            local_shard_tensor.is_pinned(),
            "pin_memory",
            rank,
            True,
        )
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
        _raise_if_mismatch(
            tensor_properties.dtype,
            local_shard_tensor.dtype,
            "dtype",
            rank,
            True,
        )
        _raise_if_mismatch(
            tensor_properties.requires_grad,
            local_shard_tensor.requires_grad,
            "requires_grad",
            rank,
            True,
        )

    # check if shards_metadata have overlap shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check if the shards_metadata is compatible with overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor_base._local_shards = local_shards
    return sharded_tensor_base