def _infer_enum_sharding_spec_case(self):
    """Layouts that are not chunk-like must infer as ``EnumerableShardingSpec``.

    Exercises three shard layouts: a 2-D split with uneven shard sizes, a
    1-D split with uneven shard sizes, and an even 2x2 grid placement.
    For each, the inferred spec must be enumerable and echo the input
    shards metadata.
    """
    # 2-D layout whose shards along dim 0 have different sizes (5 vs 10).
    uneven_2d = [
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_sizes=[10, 5],
            placement="cuda:1",
        ),
    ]
    # 1-D layout with uneven shard sizes (16 vs 9).
    uneven_1d = [
        ShardMetadata(
            shard_offsets=[0],
            shard_sizes=[16],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[16],
            shard_sizes=[9],
            placement="cuda:1",
        ),
    ]
    # 2x2 grid: sharded along more than one dimension.
    grid_2x2 = [
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_sizes=[5, 5],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_sizes=[5, 5],
            placement="rank:2/cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_sizes=[5, 5],
            placement="rank:3/cuda:3",
        ),
    ]
    for shards_metadata in (uneven_2d, uneven_1d, grid_2x2):
        spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
        self.assertTrue(isinstance(spec, EnumerableShardingSpec))
        self.assertEqual(spec.shards, shards_metadata)
def _init_from_local_shards(
    cls,
    local_shards: List[Shard],
    *global_size,
    process_group=None,
    init_rrefs=False,
):
    """Build a ShardedTensor from each rank's local shards.

    Validates shard metadata locally, all-gathers per-rank metadata to build
    and cross-check a global view, then constructs the ShardedTensor wrapper
    and attaches the local shards to it.

    Args:
        local_shards: shards owned by the calling rank (may be empty if this
            rank holds no shard).
        *global_size: overall size of the sharded tensor, as ints or a single
            size sequence (flattened by ``_flatten_tensor_size``).
        process_group: process group to operate on; defaults to the default
            process group.
        init_rrefs: whether to run RRef-related post-initialization.

    Returns:
        The initialized ShardedTensor.
    """
    # STEP 1: Validate the Shardmetadatas locally
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)

    local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
    global_tensor_size = _flatten_tensor_size(global_size)

    # Ranks with no local shards contribute ``None`` metadata to the gather.
    if len(local_shards) > 0:
        local_sharded_tensor_metadata = \
            build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

    # STEP 2. Validate metadata across ranks, and build a global sharded tensor
    # metadata by gathering local ShardedTensorMetadata
    gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
    if world_size > 1:
        gathered_metadatas = [None for _ in range(world_size)]

        # Collective: every rank receives every rank's (possibly None) metadata.
        dist.all_gather_object(
            gathered_metadatas,
            local_sharded_tensor_metadata,
            group=process_group
        )
    else:
        gathered_metadatas = [local_sharded_tensor_metadata]

    # Merges the per-rank metadata into one global view (raises on conflicts).
    global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
    tensor_properties = global_sharded_tensor_metadata.tensor_properties

    # STEP 3: Validation done, create the actual ShardedTensor and populate fields

    # prepare initialization
    spec = shard_spec._infer_sharding_spec_from_shards_metadata(
        global_sharded_tensor_metadata.shards_metadata
    )

    # __new__ creates the wrapper tensor without allocating global storage.
    sharded_tensor = cls.__new__(cls,
                                 spec,
                                 global_sharded_tensor_metadata.size,
                                 dtype=tensor_properties.dtype,
                                 layout=tensor_properties.layout,
                                 pin_memory=tensor_properties.pin_memory,
                                 requires_grad=tensor_properties.requires_grad)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    # attach local_shards to the ShardedTensor created
    sharded_tensor._local_shards = local_shards

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
def _infer_chunk_sharding_spec_case(self, placements, sharding_dim, st_size):
    """Chunk-style shard layouts must round-trip back to ChunkShardingSpec.

    Builds the shards metadata that an even chunk split of ``st_size`` along
    ``sharding_dim`` across ``placements`` would produce, then checks that
    the inferred spec is a ChunkShardingSpec with the matching dim and
    placements.
    """
    num_chunks = len(placements)
    full_dim_size = st_size[sharding_dim]
    chunk_size = get_split_size(full_dim_size, num_chunks)
    # Metadata entries are keyed by rank, while ``placements`` may order
    # ranks arbitrarily.
    metadata_by_rank = [None] * num_chunks
    for chunk_idx, placement in enumerate(placements):
        starts = [0] * len(st_size)
        starts[sharding_dim] = chunk_size * chunk_idx
        sizes = copy.deepcopy(st_size)
        sizes[sharding_dim] = get_chunked_dim_size(full_dim_size, chunk_size, chunk_idx)
        metadata_by_rank[placement.rank()] = ShardMetadata(
            shard_offsets=starts,
            shard_sizes=sizes,
            placement=placement,
        )
    inferred = _infer_sharding_spec_from_shards_metadata(metadata_by_rank)
    self.assertTrue(isinstance(inferred, ChunkShardingSpec))
    self.assertEqual(inferred.dim, sharding_dim)
    self.assertEqual(inferred.placements, placements)
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
    not do cross rank validations, and fully rely on the user
    for the correctness of sharded_tensor_metadata on each rank
    """
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError('Only torch.strided layout is currently supported')

    # Create the wrapper object first so _process_group is available for
    # placement parsing below.
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    # Shard metadata entries from the global view that belong to this rank.
    local_shard_metadatas = []

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        # Raise a uniform error for any expected/actual property mismatch;
        # is_property selects the wording for tensor-level vs shard-level checks.
        tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_metadata.placement)

        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    # Each local shard must have a matching entry in the global metadata.
    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) '
        )

    # Validate each local shard tensor against both the global tensor
    # properties and its own shard metadata.
    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_meta.placement)

        # validate if shard_meta in the metadatas collected from sharded_tensor_metadata
        assert shard_meta in local_shard_metadatas, \
            "local shard metadata not in sharded_tensor_metadata!"

        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout, "layout", current_rank, True)
        if not local_shard_tensor.is_contiguous():
            raise ValueError('Only torch.contiguous_format memory_format is currently supported')

        _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(), "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype, "dtype", current_rank, True)
        _raise_if_mismatch(tensor_properties.requires_grad, local_shard_tensor.requires_grad, "requires_grad", current_rank, True)

    # check if shards_metadata have overlap shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check if the shards_metadata is compatible with overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor._local_shards = local_shards
    sharded_tensor._sharding_spec = _infer_sharding_spec_from_shards_metadata(
        shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    sharding_spec=None,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensorBase with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
    not do cross rank validations, and fully rely on the user
    for the correctness of sharded_tensor_metadata on each rank
    """
    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError(
            "Only torch.strided layout is currently supported")

    # Infer the sharding spec from metadata unless the caller supplied one.
    if sharding_spec is None:
        spec = shard_spec._infer_sharding_spec_from_shards_metadata(
            shards_metadata)
    else:
        spec = sharding_spec

    # __new__ creates the wrapper without allocating global storage; note
    # this variant builds a plain ShardedTensor base with no process-group
    # or RPC setup (unlike the ShardedTensor classmethod of the same name).
    sharded_tensor_base = ShardedTensor.__new__(
        ShardedTensor,
        spec,
        sharded_tensor_metadata.size,
        dtype=tensor_properties.dtype,
        layout=tensor_properties.layout,
        pin_memory=tensor_properties.pin_memory,
        requires_grad=tensor_properties.requires_grad,
    )

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        # Raise a uniform error for any expected/actual property mismatch;
        # is_property selects the wording for tensor-level vs shard-level checks.
        tensor_property_or_metadata = ("tensor property"
                                       if is_property else "local ShardMetadata")
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # Validate each local shard tensor against both the global tensor
    # properties and its own shard metadata.
    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        placement = shard_meta.placement
        assert placement is not None, "Must specify placement for `Shard`!"
        rank = placement.rank()
        local_device = placement.device()

        _raise_if_mismatch(
            tensor_properties.layout,
            local_shard_tensor.layout,
            "layout",
            rank,
            True,
        )
        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                "Only torch.contiguous_format memory_format is currently supported"
            )

        _raise_if_mismatch(
            shard_meta.shard_sizes,
            list(local_shard_tensor.size()),
            "size",
            rank,
        )
        _raise_if_mismatch(
            tensor_properties.pin_memory,
            local_shard_tensor.is_pinned(),
            "pin_memory",
            rank,
            True,
        )
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
        _raise_if_mismatch(
            tensor_properties.dtype,
            local_shard_tensor.dtype,
            "dtype",
            rank,
            True,
        )
        _raise_if_mismatch(
            tensor_properties.requires_grad,
            local_shard_tensor.requires_grad,
            "requires_grad",
            rank,
            True,
        )

    # check if shards_metadata have overlap shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check if the shards_metadata is compatible with overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor_base._local_shards = local_shards
    return sharded_tensor_base