def test_new_group(self):
    """Create a ShardedTensor on a sub-group (ranks 1-3) and verify local
    shards, global metadata and remote shards.

    Only ranks 1 and 3 own shards; rank 2 is a member of the group but
    holds no local shard.
    """
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cuda:3",
        ),
    ])

    pg = dist.new_group(ranks=[1, 2, 3])

    if self.rank >= 1:
        sharded_tensor = _sharded_tensor.empty(spec, 10, 5, process_group=pg)
        self.assertEqual((10, 5), sharded_tensor.size())

        if self.rank == 1 or self.rank == 3:
            # Verify local shard.
            local_shard = sharded_tensor.local_shards()[0]
            self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())

            # Verify local shard metadata. The placement rank is relative to
            # the sub-group, hence `self.rank - 1` (global rank 1 -> group
            # rank 0, global rank 3 -> group rank 2).
            self.assertEqual((self.rank // 2 * 5, 0), local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}', local_shard.metadata.placement)

        # Verify global metadata: two 5x5 shards stacked row-wise.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(2, len(sharding_metadata))
        for rank, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
            self.assertEqual((5, 5), shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{rank * 2}/cuda:{rank * 2 + 1}', shard_metadata.placement)

        # Validate remote shards: shard owners (ranks 1 and 3) each see one
        # remote shard, the non-owning rank (2) sees both.
        remote_shards = sharded_tensor.remote_shards
        if self.rank == 1 or self.rank == 3:
            self.assertEqual(1, len(remote_shards))
        else:
            self.assertEqual(2, len(remote_shards))

        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                shard = remote_shard.to_here()
                self.assertEqual((5, 5), shard.tensor.size())
def _init_from_local_shards(
    cls,
    local_shards: List[Shard],
    *global_size,
    process_group=None,
    init_rrefs=False,
):
    """
    Build a ShardedTensor from shards already materialized on each rank.

    Args:
        local_shards: the shards owned by the current rank (may be empty).
        *global_size: overall size of the sharded tensor, passed through
            to ``_flatten_tensor_size``.
        process_group: process group to operate on; defaults to the
            default process group.
        init_rrefs: whether post-initialization should set up RRefs to
            remote shards.

    Returns:
        A ShardedTensor whose global metadata has been validated and
        gathered across all ranks of ``process_group``.
    """
    # STEP 1: Validate the ShardMetadata of the local shards locally.
    process_group = (process_group if process_group is not None
                     else distributed_c10d._get_default_group())
    current_rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)

    # A rank with no local shards contributes ``None`` metadata and is
    # assumed to be on CPU for the purposes of the all-gather below.
    local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
    local_shards_device = torch.device("cpu")
    global_tensor_size = _flatten_tensor_size(global_size)

    if len(local_shards) > 0:
        local_sharded_tensor_metadata, local_shards_device = \
            build_metadata_from_local_shards(local_shards, global_tensor_size,
                                             current_rank, process_group)

    # STEP 2: Validate metadata across ranks, and build the global sharded
    # tensor metadata by all-gathering each rank's local ShardedTensorMetadata.
    gathered_metadatas = [None for _ in range(world_size)]

    if local_shards_device.type == "cuda":
        # With GPU/NCCL, we need to set a device for all_gather_object to
        # use, as we need to know which device we should put the serialized
        # tensor on before the NCCL collective.
        with torch.cuda.device(local_shards_device):
            dist.all_gather_object(gathered_metadatas, local_sharded_tensor_metadata, group=process_group)
    else:
        dist.all_gather_object(gathered_metadatas, local_sharded_tensor_metadata, group=process_group)

    global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)

    # STEP 3: Validation done; create the actual ShardedTensor and populate
    # its fields. ``__new__`` is used so ``__init__``'s normal construction
    # path is bypassed — the metadata is already built.
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    # Attach the gathered metadata and this rank's local shards.
    sharded_tensor._metadata = global_sharded_tensor_metadata
    sharded_tensor._local_shards = local_shards
    # Make an EnumerableShardingSpec for sharded tensors initialized from this API.
    # TODO: make the sharding spec a ChunkShardingSpec by inferring it from the
    # metadata list; see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(
        global_sharded_tensor_metadata.shards_metadata)

    # Run post-initialization, i.e. map registration, rpc initialization.
    sharded_tensor._post_init()
    return sharded_tensor
def test_partial_world_size(self):
    """Sharded tensor where only ranks 0 and 1 own shards; the remaining
    ranks hold no local shard but still see the full global metadata."""
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 5)
    self.assertEqual((10, 5), sharded_tensor.size())
    # Only the first two ranks own a shard.
    if self.rank <= 1:
        self.assertEqual(1, len(sharded_tensor.local_shards()))
    else:
        self.assertEqual(0, len(sharded_tensor.local_shards()))

    if self.rank <= 1:
        # Verify local shard.
        local_shard = sharded_tensor.local_shards()[0]
        self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
        self.assertEqual((5, 5), local_shard.tensor.size())

        # Verify local shard metadata.
        self.assertEqual((self.rank * 5, 0), local_shard.metadata.shard_offsets)
        self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
        self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata: two 5x5 shards stacked row-wise.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(2, len(shards_metadata))
    for rank, shard_metadata in enumerate(shards_metadata):
        self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)

    # Validate remote shards: shard owners see one remote shard,
    # non-owning ranks see both.
    remote_shards = sharded_tensor.remote_shards
    if self.rank <= 1:
        self.assertEqual(1, len(remote_shards))
    else:
        self.assertEqual(2, len(remote_shards))

    for rpc_rank, shards in remote_shards.items():
        self.assertEqual(1, len(shards))
        for remote_shard in shards:
            self.assertEqual(rpc_rank, remote_shard.owner().id)
            shard = remote_shard.to_here()
            self.assertEqual((5, 5), shard.tensor.size())
def test_with_rpc_names(self):
    """Sharded tensor creation where shard placements are given as RPC
    worker names ("workerN/cuda:N") rather than "rank:N" strings."""
    # One 5x5 shard per worker, laid out as a 2x2 grid over a 10x10 tensor.
    shard_descriptions = []
    for worker_id in range(4):
        grid_row, grid_col = divmod(worker_id, 2)
        shard_descriptions.append(
            ShardMetadata(
                shard_offsets=[grid_row * 5, grid_col * 5],
                shard_lengths=[5, 5],
                placement=f"worker{worker_id}/cuda:{worker_id}",
            ))
    spec = EnumerableShardingSpec(shard_descriptions)

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    # The single local shard is a 5x5 tensor on this rank's GPU.
    shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), shard.tensor.device)
    self.assertEqual((5, 5), shard.tensor.size())

    # Its metadata matches this rank's slot in the grid.
    my_row, my_col = divmod(self.rank, 2)
    self.assertEqual((my_row * 5, my_col * 5), shard.metadata.shard_offsets)
    self.assertEqual((5, 5), shard.metadata.shard_lengths)
    self.assertEqual(f'worker{self.rank}/cuda:{self.rank}', shard.metadata.placement)

    # Global metadata lists all four shards in worker order.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(4, len(shards_metadata))
    for idx, meta in enumerate(shards_metadata):
        row, col = divmod(idx, 2)
        self.assertEqual((row * 5, col * 5), meta.shard_offsets)
        self.assertEqual((5, 5), meta.shard_lengths)
        self.assertEqual(f'worker{idx}/cuda:{idx}', meta.placement)

    # Each of the other three workers owns exactly one remote shard.
    remote_shards = sharded_tensor.remote_shards
    self.assertEqual(3, len(remote_shards))
    for rpc_rank, shards in remote_shards.items():
        self.assertEqual(1, len(shards))
        for remote_shard in shards:
            self.assertEqual(rpc_rank, remote_shard.owner().id)
            fetched = remote_shard.to_here()
            self.assertEqual((5, 5), fetched.tensor.size())
def test_multiple_local_shards(self):
    """Spec that places two 5x5 shards on each of ranks 0 and 1; the other
    ranks own no shards but still see the full global metadata."""
    self.init_pg()
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        )
    ])
    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())

    if self.rank <= 1:
        self.assertEqual(2, len(sharded_tensor.local_shards()))

        # Verify local shards.
        for idx, local_shard in enumerate(sharded_tensor.local_shards()):
            self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())

            # Verify local shard metadata: shard ``idx`` of this rank sits
            # at row ``idx * 5``, column ``rank * 5``.
            self.assertEqual((idx * 5, self.rank * 5), local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)
    else:
        self.assertEqual(0, len(sharded_tensor.local_shards()))

    # Verify global metadata: four shards alternating between ranks 0 and 1.
    sharding_metadata = sharded_tensor.sharding_metadata()
    self.assertEqual(4, len(sharding_metadata))
    for shard_rank, shard_metadata in enumerate(sharding_metadata):
        self.assertEqual((shard_rank // 2 * 5, (shard_rank % 2) * 5), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{shard_rank % 2}/cuda:{shard_rank % 2}', shard_metadata.placement)
def test_grid_sharding(self):
    """Shard a 10x10 tensor into a 2x2 grid of 5x5 shards, one per rank,
    and check the local shard plus the gathered global metadata."""
    self.init_pg()

    grid_shards = [
        ShardMetadata(
            shard_offsets=[(r // 2) * 5, (r % 2) * 5],
            shard_lengths=[5, 5],
            placement=f"rank:{r}/cuda:{r}",
        )
        for r in range(4)
    ]
    spec = EnumerableShardingSpec(grid_shards)

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    # The single local shard is a 5x5 tensor on this rank's GPU.
    shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), shard.tensor.device)
    self.assertEqual((5, 5), shard.tensor.size())

    # Its metadata matches this rank's slot in the grid.
    row, col = divmod(self.rank, 2)
    self.assertEqual((row * 5, col * 5), shard.metadata.shard_offsets)
    self.assertEqual((5, 5), shard.metadata.shard_lengths)
    self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', shard.metadata.placement)

    # Global metadata covers all four shards in rank order.
    sharding_metadata = sharded_tensor.sharding_metadata()
    self.assertEqual(4, len(sharding_metadata))
    for shard_rank, meta in enumerate(sharding_metadata):
        r, c = divmod(shard_rank, 2)
        self.assertEqual((r * 5, c * 5), meta.shard_offsets)
        self.assertEqual((5, 5), meta.shard_lengths)
        self.assertEqual(f'rank:{shard_rank}/cuda:{shard_rank}', meta.placement)
def test_enumerable_sharding_spec(self):
    """Exercise EnumerableShardingSpec / ShardMetadata validation: valid
    row-wise, grid and uneven layouts first, then a series of specs that
    must be rejected with specific errors."""
    # test valid specs

    # test row-wise sharding
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="cuda:1",
        )
    ])
    spec.check_tensor(torch.rand(10, 5).size())

    # test row and column sharding
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[3, 3],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 3],
            shard_lengths=[3, 3],
            placement="cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[3, 0],
            shard_lengths=[3, 3],
            placement="cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[3, 3],
            shard_lengths=[3, 3],
            placement="cuda:3",
        ),
    ])
    spec.check_tensor(torch.rand(6, 6).size())

    # test uneven shard sizes.
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[2, 4],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 4],
            shard_lengths=[4, 2],
            placement="cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[2, 0],
            shard_lengths=[4, 4],
            placement="cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[4, 4],
            shard_lengths=[2, 2],
            placement="cuda:3",
        ),
    ])
    spec.check_tensor(torch.rand(6, 6).size())

    # test invalid sharding

    # Malformed device string.
    with self.assertRaisesRegex(ValueError, 'not a valid device'):
        ShardMetadata(shard_offsets=[0], shard_lengths=[1], placement="cuda:foo")

    # Offsets/lengths rank mismatch within a single shard.
    with self.assertRaisesRegex(ValueError, 'same number of elements'):
        ShardMetadata(shard_offsets=[0, 0], shard_lengths=[1], placement="cuda:0")

    # Negative offset.
    with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
        ShardMetadata(shard_offsets=[-1, 0], shard_lengths=[1, 1], placement="cuda:0")

    # Zero-size shard.
    with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'):
        ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0")

    # A spec must contain at least one shard.
    with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
        EnumerableShardingSpec([])

    # All shards in a spec must have the same rank.
    with self.assertRaisesRegex(ValueError, 'Found inconsistent ranks for shards'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[1, 1], placement="cpu"),
            ShardMetadata(shard_offsets=[0, 0, 0], shard_lengths=[1, 1, 1], placement="cpu"),
        ])

    # Overlapping shard regions are rejected.
    with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[3, 3], placement="cpu"),
            ShardMetadata(shard_offsets=[2, 0], shard_lengths=[3, 3], placement="cpu"),
        ])

    # 2-D spec applied to a 3-D tensor.
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="cuda:1",
        )
    ])
    with self.assertRaisesRegex(ValueError, 'Rank of tensor is.*but shards rank'):
        spec.check_tensor(torch.rand(10, 10, 10).size())

    # Shards extend past the tensor boundary.
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="cuda:1",
        )
    ])
    with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
        spec.check_tensor(torch.rand(10, 3).size())

    # Shards do not cover the whole tensor.
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="cuda:1",
        )
    ])
    with self.assertRaisesRegex(ValueError, 'does not match tensor volume'):
        spec.check_tensor(torch.rand(10, 10).size())
def test_sharded_linear_errors(self):
    """Error handling of the sharded ``nn.Linear`` forward path: each case
    builds a Linear with an invalid input/weight/bias/spec combination and
    asserts the specific error raised."""
    for spec in generate_chunk_sharding_specs_for_test(0):
        # Non-tensor bias (sharded bias) is rejected.
        fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc1, "bias", spec)
        with self.assertRaisesRegex(
                TypeError, 'input and bias need to be torch.Tensor'):
            fc1(torch.rand(10, 10).cuda(self.rank))

        # 0-dim input is rejected.
        fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc2, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
            fc2(torch.tensor(1).cuda(self.rank))

        # 3-dim weight is rejected.
        fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc3.weight = torch.nn.Parameter(
            torch.rand(10, 10, 10).cuda(self.rank))
        shard_parameter(fc3, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
            fc3(torch.rand(10, 10).cuda(self.rank))

        # 2-dim bias is rejected.
        fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
        shard_parameter(fc4, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
            fc4(torch.rand(10, 10).cuda(self.rank))

        # Input feature dim must match the weight's in-feature dim.
        fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
        shard_parameter(fc5, "weight", spec)
        with self.assertRaisesRegex(
                ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
            fc5(torch.rand(20, 10, 13).cuda(self.rank))

        # Weight sharded with an EnumerableShardingSpec (not ChunkShardingSpec)
        # is rejected by the sharded op.
        fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
        del fc6.weight
        enumerable_spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            )
        ])
        fc6.weight = empty(enumerable_spec, 10, 10)
        with self.assertRaisesRegex(
                ValueError, 'Only ChunkShardingSpec supported for ShardedTensor ops!'):
            fc6(torch.rand(10, 10).cuda(self.rank))

        # A spec placing more than one shard per rank is rejected.
        fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
        multiple_local_shard_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:3/cuda:3",
            ],
        )
        del fc7.weight
        fc7.weight = empty(multiple_local_shard_spec, 80, 10)
        with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
            fc7(torch.rand(10, 10).cuda(self.rank))
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
):
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
    not do cross-rank validation, and fully relies on the user for the
    correctness of ``sharded_tensor_metadata`` on each rank.
    """
    process_group = (process_group if process_group is not None
                     else distributed_c10d._get_default_group())
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError(
            'Only torch.strided layout is currently supported')

    # Bypass __init__; the metadata is supplied by the caller.
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    local_shard_metadatas = []

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        # Uniform error for any property that disagrees between the local
        # shard tensor and the global metadata / tensor properties.
        tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_metadata.placement)

        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) ')

    # Validate each local shard tensor against its metadata and against the
    # global tensor properties.
    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_meta.placement)

        # validate if shard_meta in the metadatas collected from sharded_tensor_metadata
        assert shard_meta in local_shard_metadatas, \
            "local shard metadata not in sharded_tensor_metadata!"

        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout, "layout", current_rank, True)
        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                'Only torch.contiguous_format memory_format is currently supported'
            )

        _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(), "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device, "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype, "dtype", current_rank, True)
        _raise_if_mismatch(tensor_properties.requires_grad, local_shard_tensor.requires_grad, "requires_grad", current_rank, True)

    # check if shards_metadata have overlap shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check if the shards_metadata is compatible with overall size of the sharded tensor.
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor._local_shards = local_shards
    # make an EnumerableShardingSpec for sharded tensors that initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    # see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
def test_uneven_shards(self):
    """Shard a 6x6 tensor into four shards of different sizes, one per
    rank, and verify local and global metadata."""
    self.init_pg()
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[2, 4],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 4],
            shard_lengths=[4, 2],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[2, 0],
            shard_lengths=[4, 4],
            placement="rank:2/cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[4, 4],
            shard_lengths=[2, 2],
            placement="rank:3/cuda:3",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 6, 6)
    self.assertEqual((6, 6), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    def verify_size(rank, tensor_dims):
        # Expected shard size per rank, mirroring the spec above.
        if rank == 0:
            self.assertEqual((2, 4), tensor_dims)
        elif rank == 1:
            self.assertEqual((4, 2), tensor_dims)
        elif rank == 2:
            self.assertEqual((4, 4), tensor_dims)
        elif rank == 3:
            self.assertEqual((2, 2), tensor_dims)

    def verify_offsets(rank, offsets):
        # Expected shard offsets per rank, mirroring the spec above.
        if rank == 0:
            self.assertEqual((0, 0), offsets)
        elif rank == 1:
            self.assertEqual((0, 4), offsets)
        elif rank == 2:
            self.assertEqual((2, 0), offsets)
        elif rank == 3:
            self.assertEqual((4, 4), offsets)

    # Verify local shard.
    local_shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
    verify_size(self.rank, local_shard.tensor.size())

    # Verify local shard metadata.
    verify_offsets(self.rank, local_shard.metadata.shard_offsets)
    verify_size(self.rank, local_shard.metadata.shard_lengths)
    self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(4, len(shards_metadata))
    for rank, shard_metadata in enumerate(shards_metadata):
        verify_offsets(rank, shard_metadata.shard_offsets)
        verify_size(rank, shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)
def test_sharded_tensor_metadata(self):
    """The metadata of a sharded tensor reflects its creation arguments
    (size, dtype, layout, requires_grad, memory_format, pin_memory)."""

    def build_spec(device_of):
        # 2x2 grid of 5x5 shards over a 10x10 tensor, one shard per rank.
        return EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[(r // 2) * 5, (r % 2) * 5],
                shard_lengths=[5, 5],
                placement=f"rank:{r}/{device_of(r)}",
            )
            for r in range(4)
        ])

    cuda_spec = build_spec(lambda r: f"cuda:{r}")

    # Default creation arguments.
    st = _sharded_tensor.empty(cuda_spec, 10, 10)
    meta = st.metadata()
    self.assertEqual(torch.Size([10, 10]), meta.size)
    self.assertEqual(torch.float, meta.dtype)
    self.assertEqual(torch.strided, meta.layout)
    self.assertEqual(False, meta.requires_grad)
    self.assertEqual(torch.contiguous_format, meta.memory_format)
    self.assertEqual(False, meta.pin_memory)

    # requires_grad is propagated.
    st = _sharded_tensor.empty(cuda_spec, 10, 10, requires_grad=True)
    self.assertEqual(True, st.metadata().requires_grad)

    # dtype is propagated.
    st = _sharded_tensor.empty(cuda_spec, 10, 10, dtype=torch.double)
    self.assertEqual(torch.double, st.metadata().dtype)

    # pin_memory requires CPU-resident shards.
    cpu_spec = build_spec(lambda r: "cpu")
    st = _sharded_tensor.empty(cpu_spec, 10, 10, pin_memory=True)
    self.assertEqual(True, st.metadata().pin_memory)