def build_global_metadata(
        gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
    """Merge per-rank ShardedTensorMetadata into one global metadata object.

    The first rank that reports metadata seeds the global copy; every later
    rank must agree with it on size, dtype, requires_grad and pin_memory,
    and its shards are appended to the global shard list.

    Raises:
        ValueError: if no rank has local shards, if any tensor-wide property
            disagrees between ranks, or if the merged shards overlap or fail
            the global-size compatibility check.
    """
    merged_metadata = None
    # Rank whose metadata seeded `merged_metadata`; reported in mismatch errors.
    seed_rank = 0

    for rank, metadata in enumerate(gathered_metadatas):
        if metadata is None:
            # This rank contributed no local shards.
            continue

        if merged_metadata is None:
            # First non-empty rank seeds the global metadata.
            merged_metadata = copy.deepcopy(metadata)
            seed_rank = rank
            continue

        _raise_if_mismatch(merged_metadata.size, metadata.size,
                           "global_size", [seed_rank, rank], is_local=False)

        # Layout and memory format were already verified during the local
        # shards validation stage, so only the remaining properties matter.
        seed_props = merged_metadata.tensor_properties
        rank_props = metadata.tensor_properties
        _raise_if_mismatch(seed_props.dtype, rank_props.dtype,
                           "dtype", [seed_rank, rank], is_local=False)
        _raise_if_mismatch(seed_props.requires_grad, rank_props.requires_grad,
                           "requires_grad", [seed_rank, rank], is_local=False)
        _raise_if_mismatch(seed_props.pin_memory, rank_props.pin_memory,
                           "pin_memory", [seed_rank, rank], is_local=False)

        # All validations passed; fold this rank's shards into the global list.
        merged_metadata.shards_metadata.extend(metadata.shards_metadata)

    if merged_metadata is None:
        raise ValueError("ShardedTensor have no local shards on all ranks!")

    # The combined shards must not overlap each other...
    validate_non_overlapping_shards_metadata(merged_metadata.shards_metadata)
    # ...and must be compatible with the global size of the sharded tensor.
    check_tensor(merged_metadata.shards_metadata, merged_metadata.size)

    return merged_metadata
def _validate_sharded_tensor(
        tensor_md: ShardedTensorMetadata,
        checkpoint_md: ShardedTensorStorageMetadata) -> None:
    """Validate that a checkpoint can fully satisfy loading a ShardedTensor.

    We assume the incoming tensor has been validated during construction.
    For each shard of the tensor we sum the volume of all overlap regions
    with the checkpoint's stored shards; a shard is loadable only if that
    total equals its own volume (i.e. every element is covered).

    Raises:
        ValueError: if the checkpoint's shards overlap each other, or if any
            tensor shard cannot be fully read from the checkpoint.
    """
    validate_non_overlapping_shards_metadata(
        checkpoint_md.tensor_metadata.shards_metadata)

    for shard_md in tensor_md.shards_metadata:
        # Total number of elements of this shard that the checkpoint covers.
        read_volume = 0
        for storage_md in checkpoint_md.storage_metadata:
            shard_md_from_storage = storage_md.shard_metadata
            if not _check_shard_metadata_pair_overlap(shard_md,
                                                      shard_md_from_storage):
                continue

            # Volume of the overlap region between the stored shard and the
            # current shard (product of the per-dim overlap lengths).
            overlap_volume = 1
            for (
                    _,
                    _,
                    _,
                    length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                    saved_shard=shard_md_from_storage, current_shard=shard_md):
                overlap_volume *= length
            read_volume += overlap_volume

        shard_volume = 1
        for size in shard_md.shard_sizes:
            shard_volume *= size

        if read_volume != shard_volume:
            # BUG FIX: the second literal was a plain string, so
            # "{shard_volume}" was emitted verbatim and the two fragments
            # were concatenated without a space.
            raise ValueError(
                f"Shard {shard_md} only has {read_volume} available "
                f"elements but needs {shard_volume}")
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
             not do cross rank validations, and fully rely on the user
             for the correctness of sharded_tensor_metadata on each rank
    """
    if process_group is None:
        process_group = distributed_c10d._get_default_group()
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    # Basic sanity checks on the supplied global metadata.
    if not shards_metadata:
        raise ValueError("shards_metadata must not be empty!")
    if tensor_properties.layout != torch.strided:
        raise ValueError(
            'Only torch.strided layout is currently supported')

    # Build the instance without running __init__; _prepare_init wires up the
    # process group and rref bookkeeping, then the metadata is attached as-is.
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group,
                                 init_rrefs=init_rrefs)
    sharded_tensor._metadata = sharded_tensor_metadata

    def _raise_if_mismatch(expected, actual, prop_name, rank,
                           is_property=False):
        # Shared formatter for property/metadata disagreement errors.
        kind = "tensor property" if is_property else "local ShardMetadata"
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{kind} on rank {rank}: "
                f"{kind} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # Collect the shard metadatas owned by this rank from the global
    # sharded_tensor_metadata; parsing also validates each placement.
    local_shard_metadatas = []
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_metadata.placement)
        if rank == current_rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) ')

    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_meta.placement)

        # Each local shard's metadata must appear in the global metadata
        # collected above.
        assert shard_meta in local_shard_metadatas, \
            "local shard metadata not in sharded_tensor_metadata!"

        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout,
                           "layout", current_rank, True)
        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                'Only torch.contiguous_format memory_format is currently supported'
            )
        _raise_if_mismatch(shard_meta.shard_sizes,
                           list(local_shard_tensor.size()),
                           "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory,
                           local_shard_tensor.is_pinned(),
                           "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device,
                           "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype,
                           "dtype", current_rank, True)
        _raise_if_mismatch(tensor_properties.requires_grad,
                           local_shard_tensor.requires_grad,
                           "requires_grad", current_rank, True)

    # Cross-shard validations: no overlaps, and the shards must be compatible
    # with the overall size of the sharded tensor.
    validate_non_overlapping_shards_metadata(shards_metadata)
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # Validation done; attach the shards, infer the sharding spec, and run
    # post initialization (map registration, rpc initialization).
    sharded_tensor._local_shards = local_shards
    sharded_tensor._sharding_spec = _infer_sharding_spec_from_shards_metadata(
        shards_metadata)
    sharded_tensor._post_init()
    return sharded_tensor
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    sharding_spec=None,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensorBase with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
             not do cross rank validations, and fully rely on the user
             for the correctness of sharded_tensor_metadata on each rank
    """
    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    # Basic sanity checks on the supplied global metadata.
    if not shards_metadata:
        raise ValueError("shards_metadata must not be empty!")
    if tensor_properties.layout != torch.strided:
        raise ValueError(
            "Only torch.strided layout is currently supported")

    # Use the caller's spec when given; otherwise derive one from the shard
    # layout described by the metadata.
    spec = (sharding_spec if sharding_spec is not None else
            shard_spec._infer_sharding_spec_from_shards_metadata(
                shards_metadata))

    sharded_tensor_base = ShardedTensor.__new__(
        ShardedTensor,
        spec,
        sharded_tensor_metadata.size,
        dtype=tensor_properties.dtype,
        layout=tensor_properties.layout,
        pin_memory=tensor_properties.pin_memory,
        requires_grad=tensor_properties.requires_grad,
    )

    def _raise_if_mismatch(expected, actual, prop_name, rank,
                           is_property=False):
        # Shared formatter for property/metadata disagreement errors.
        kind = ("tensor property" if is_property else "local ShardMetadata")
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{kind} on rank {rank}: "
                f"{kind} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    for shard in local_shards:
        shard_meta = shard.metadata
        shard_tensor = shard.tensor
        placement = shard_meta.placement
        assert placement is not None, "Must specify placement for `Shard`!"
        rank = placement.rank()
        local_device = placement.device()

        _raise_if_mismatch(
            tensor_properties.layout,
            shard_tensor.layout,
            "layout",
            rank,
            True,
        )
        if not shard_tensor.is_contiguous():
            raise ValueError(
                "Only torch.contiguous_format memory_format is currently supported"
            )
        _raise_if_mismatch(
            shard_meta.shard_sizes,
            list(shard_tensor.size()),
            "size",
            rank,
        )
        _raise_if_mismatch(
            tensor_properties.pin_memory,
            shard_tensor.is_pinned(),
            "pin_memory",
            rank,
            True,
        )
        _raise_if_mismatch(local_device, shard_tensor.device, "device", rank)
        _raise_if_mismatch(
            tensor_properties.dtype,
            shard_tensor.dtype,
            "dtype",
            rank,
            True,
        )
        _raise_if_mismatch(
            tensor_properties.requires_grad,
            shard_tensor.requires_grad,
            "requires_grad",
            rank,
            True,
        )

    # Cross-shard validations: no overlaps, and the shards must be compatible
    # with the overall size of the sharded tensor.
    validate_non_overlapping_shards_metadata(shards_metadata)
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # Validation done; attach the local shards.
    sharded_tensor_base._local_shards = local_shards
    return sharded_tensor_base