def test_sharded_tensor_reshard_errors(self):
    specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
    spec, reshard_spec = specs[0], specs[1]
    enumerable_sharding_spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_sizes=[5, 5],
            placement="rank:1/cuda:1",
        ),
    ])
    st = sharded_tensor.rand(spec, 24, 12)
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec supported for reshard."
    ):
        st.reshard(enumerable_sharding_spec)
    st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
    with self.assertRaisesRegex(
        NotImplementedError, "Only single local shard supported for reshard."
    ):
        st.reshard(reshard_spec)
def test_partial_tensor_reshard_errors(self):
    enumerable_sharding_spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
        ]
    )
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec supported for reshard."
    ):
        self._run_partial_tensor_n_reshard(
            enumerable_sharding_spec, [13, 21], 4, dist.ReduceOp.SUM
        )
        self._run_partial_tensor_n_reshard(
            enumerable_sharding_spec, [12, 22], 4, dist.ReduceOp.MAX
        )
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    spec = specs[0]
    with self.assertRaisesRegex(
        NotImplementedError, "Only real partial tensor supported for reshard."
    ):
        self._run_partial_tensor_n_reshard(
            spec, [13, 21], 4, dist.ReduceOp.SUM, dtype=torch.cfloat
        )
        self._run_partial_tensor_n_reshard(
            spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat
        )
def generate_enumerable_sharding_specs_for_test():
    return [
        EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ])
    ]
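
# A minimal usage sketch (hypothetical test body; assumes four ranks with one
# GPU each and an initialized process group). The spec above splits a 10x10
# tensor into four 5x5 blocks, one per rank:
#
#   spec = generate_enumerable_sharding_specs_for_test()[0]
#   st = sharded_tensor.rand(spec, 10, 10)  # each rank holds one 5x5 shard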
def _create_enumerate_spec(self, tensor):
    # Since placement is not used, always set placement to rank0 to mimic
    # the actual usage.
    metadata = [
        ShardMetadata([0], [101], placement="rank0/cuda:0"),
        ShardMetadata([101], [900], placement="rank0/cuda:0"),
    ]
    return EnumerableShardingSpec(metadata)
def _init_from_local_shards(
    cls,
    local_shards: List[Shard],
    *global_size,
    process_group=None,
    init_rrefs=False,
):
    # STEP 1: Validate the ShardMetadatas locally
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)

    local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
    global_tensor_size = _flatten_tensor_size(global_size)

    if len(local_shards) > 0:
        local_sharded_tensor_metadata = build_metadata_from_local_shards(
            local_shards, global_tensor_size, current_rank, process_group)

    # STEP 2: Validate metadata across ranks, and build a global sharded tensor
    # metadata by gathering local ShardedTensorMetadata
    gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
    if world_size > 1:
        gathered_metadatas = [None for _ in range(world_size)]

        dist.all_gather_object(
            gathered_metadatas,
            local_sharded_tensor_metadata,
            group=process_group
        )
    else:
        gathered_metadatas = [local_sharded_tensor_metadata]

    global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)

    # STEP 3: Validation done, create the actual ShardedTensor and populate fields
    # prepare initialization
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    # attach metadata and local_shards
    sharded_tensor._metadata = global_sharded_tensor_metadata
    sharded_tensor._local_shards = local_shards

    # make an EnumerableShardingSpec for sharded tensors initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    # see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(
        global_sharded_tensor_metadata.shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
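
# A minimal usage sketch (hypothetical sizes and devices; assumes an
# initialized 2-rank process group with one GPU per rank). Each rank builds
# its own local shard of a 10x10 tensor and all ranks call the classmethod
# above collectively:
#
#   rank = dist.get_rank()
#   local_shard = Shard(
#       tensor=torch.rand(5, 10, device=f"cuda:{rank}"),
#       metadata=ShardMetadata(
#           shard_offsets=[5 * rank, 0],
#           shard_sizes=[5, 10],
#           placement=f"rank:{rank}/cuda:{rank}",
#       ),
#   )
#   st = ShardedTensor._init_from_local_shards([local_shard], 10, 10)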
def test_math_ops_errors(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    sharded_lhs = sharded_tensor.rand(spec, (20, 3))
    sharded_rhs = sharded_tensor.rand(spec, (12, 3))

    with self.assertRaisesRegex(
        RuntimeError, "Implicit broadcasting not supported"
    ):
        torch.add(sharded_lhs, sharded_rhs)

    spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ]
    )
    st = sharded_tensor.rand(spec, 10, 10)

    with self.assertRaisesRegex(RuntimeError, "not supported"):
        torch.add(st, sharded_rhs)
def test_switch_between_sharded_tensor_to_tensor(self) -> None:
    path = self.get_file_path()
    tensor_size = 32

    specs = [
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        ),
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
                "rank:1",
                "rank:0",
            ],
        ),
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[8],
                placement="rank:1",
            ),
            ShardMetadata(
                shard_offsets=[8],
                shard_sizes=[tensor_size - 8],
                placement="rank:0",
            ),
        ]),
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[10],
                placement="rank:0",
            ),
            ShardMetadata(
                shard_offsets=[10],
                shard_sizes=[tensor_size - 10],
                placement="rank:1",
            ),
        ]),
    ]

    for save_spec in specs:
        for load_spec in specs:
            save_dict = {
                'sharded': sharded_tensor.rand(save_spec, tensor_size),
                'replicated': torch.rand(tensor_size, device=f"cpu:{self.rank}")
            }

            fs_writer = FileSystemWriter(path=path)
            save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

            # Freaky Friday the tensors
            load_dict = {
                'sharded': torch.zeros(tensor_size, device=f"cpu:{self.rank}"),
                'replicated': sharded_tensor.zeros(load_spec, tensor_size)
            }

            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

            save_dict_sharded = self.load_tensor(save_dict['sharded'])
            load_dict_replicated = self.load_tensor(load_dict['replicated'])

            if dist.get_rank() == 0:
                self.assertTrue(
                    torch.allclose(save_dict_sharded, load_dict['sharded']),
                    f"save-spec {save_spec} load-spec {load_spec}")
                self.assertTrue(
                    torch.allclose(save_dict['replicated'], load_dict_replicated),
                    f"save-spec {save_spec} load-spec {load_spec}")
def test_load_with_different_shard_plan(self) -> None:
    path = self.get_file_path()

    # We hardcode the assumption of how many shards are around
    self.assertEqual(self.world_size, dist.get_world_size())

    specs = [
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        ),
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
                "rank:1",
                "rank:0",
            ],
        ),
        # This requires the tensors to be [10, 20]
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[2, 20],
                placement="rank:0",
            ),
            ShardMetadata(
                shard_offsets=[2, 0],
                shard_sizes=[1, 20],
                placement="rank:1",
            ),
            ShardMetadata(
                shard_offsets=[3, 0],
                shard_sizes=[3, 20],
                placement="rank:0",
            ),
            ShardMetadata(
                shard_offsets=[6, 0],
                shard_sizes=[3, 20],
                placement="rank:1",
            ),
            ShardMetadata(
                shard_offsets=[9, 0],
                shard_sizes=[1, 20],
                placement="rank:0",
            ),
        ]),
        # This requires the tensors to be [10, 20]
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[8, 20],
                placement="rank:1",
            ),
            ShardMetadata(
                shard_offsets=[8, 0],
                shard_sizes=[2, 20],
                placement="rank:0",
            ),
        ]),
    ]

    for s0 in specs:
        for s1 in specs:
            if s0 == s1:
                continue

            if dist.get_rank() == 0:
                shutil.rmtree(path, ignore_errors=True)
                os.makedirs(path)
            dist.barrier()

            model_to_save = MyShardedModel3(s0)
            model_to_save._register_state_dict_hook(state_dict_hook)
            state_dict_to_save = model_to_save.state_dict()

            fs_writer = FileSystemWriter(path=path)
            save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
            dist.barrier()

            model_to_load = MyShardedModel3(s1)
            model_to_load._register_state_dict_hook(state_dict_hook)
            state_dict_to_load_to = model_to_load.state_dict()
            dist.barrier()

            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
            dist.barrier()

            store_tensor = self.load_tensor(model_to_save.sharded_tensor)
            dist.barrier()
            load_tensor = self.load_tensor(model_to_load.sharded_tensor)

            if dist.get_rank() == 0:
                self.assertTrue(
                    torch.allclose(store_tensor, load_tensor),
                    msg=f"{s0} vs {s1}")
def test_sharded_linear_errors(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc1, "bias", spec)
        with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
            fc1(torch.rand(10, 10).cuda(self.rank))

        fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc2, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
            fc2(torch.tensor(1).cuda(self.rank))

        fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
        shard_parameter(fc3, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
            fc3(torch.rand(10, 10).cuda(self.rank))

        fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
        shard_parameter(fc4, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
            fc4(torch.rand(10, 10).cuda(self.rank))

        fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
        shard_parameter(fc5, "weight", spec)
        with self.assertRaisesRegex(
                ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
            fc5(torch.rand(20, 10, 13).cuda(self.rank))

        fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
        del fc6.weight
        enumerable_spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            )
        ])

        fc6.weight = empty(enumerable_spec, 10, 10)
        # ShardedTensor metadata has a parenthesis imbalance issue when
        # compiled with re.compile, so match around it.
        error_msg = (
            r"torch function 'linear', with args: (?s).* "
            r"and kwargs: None not supported for ShardedTensor!"
        )
        with self.assertRaisesRegex(RuntimeError, error_msg):
            fc6(torch.rand(10, 10).cuda(self.rank))

        fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
        multiple_local_shard_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:3/cuda:3",
            ],
        )
        del fc7.weight
        fc7.weight = empty(multiple_local_shard_spec, 80, 10)
        with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
            fc7(torch.rand(10, 10).cuda(self.rank))
def test_enumerable_sharding_spec(self):
    # test valid specs

    # test row-wise sharding
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_sizes=[5, 5],
            placement="cuda:1",
        )
    ])
    check_tensor(spec.shards, torch.rand(10, 5).size())

    # test row and column sharding
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[3, 3],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 3],
            shard_sizes=[3, 3],
            placement="cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[3, 0],
            shard_sizes=[3, 3],
            placement="cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[3, 3],
            shard_sizes=[3, 3],
            placement="cuda:3",
        ),
    ])
    check_tensor(spec.shards, torch.rand(6, 6).size())

    # test uneven shard sizes.
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[2, 4],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 4],
            shard_sizes=[4, 2],
            placement="cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[2, 0],
            shard_sizes=[4, 4],
            placement="cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[4, 4],
            shard_sizes=[2, 2],
            placement="cuda:3",
        ),
    ])
    check_tensor(spec.shards, torch.rand(6, 6).size())

    # test invalid sharding
    with self.assertRaisesRegex(ValueError, 'Could not parse remote_device'):
        ShardMetadata(shard_offsets=[0], shard_sizes=[1], placement="cuda:foo")

    with self.assertRaisesRegex(ValueError, 'same number of elements'):
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
        ShardMetadata(shard_offsets=[-1, 0], shard_sizes=[1, 1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'shard_sizes should be >= 0'):
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[-1, 1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
        EnumerableShardingSpec([])

    with self.assertRaisesRegex(ValueError, 'Found inconsistent ranks for shards'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 1], placement="cpu"),
            ShardMetadata(shard_offsets=[0, 0, 0], shard_sizes=[1, 1, 1], placement="cpu"),
        ])

    with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[3, 3], placement="cpu"),
            ShardMetadata(shard_offsets=[2, 0], shard_sizes=[3, 3], placement="cpu"),
        ])

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_sizes=[5, 5],
            placement="cuda:1",
        )
    ])
    with self.assertRaisesRegex(ValueError, 'Rank of tensor is.*but shards rank'):
        check_tensor(spec.shards, torch.rand(10, 10, 10).size())

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_sizes=[5, 5],
            placement="cuda:1",
        )
    ])
    with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
        check_tensor(spec.shards, torch.rand(10, 3).size())

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_sizes=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_sizes=[5, 5],
            placement="cuda:1",
        )
    ])
    with self.assertRaisesRegex(ValueError, 'does not match tensor volume'):
        check_tensor(spec.shards, torch.rand(10, 10).size())
def _init_from_local_shards_and_global_metadata(
    cls,
    local_shards: List[Shard],
    sharded_tensor_metadata: ShardedTensorMetadata,
    process_group=None,
    init_rrefs=False,
) -> "ShardedTensor":
    """
    Initialize a ShardedTensor with local shards and a global
    ShardedTensorMetadata built on each rank.

    Warning: This API is experimental and subject to change. It does
    not do cross rank validation, and fully relies on the user
    for the correctness of sharded_tensor_metadata on each rank.
    """
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)

    shards_metadata = sharded_tensor_metadata.shards_metadata
    tensor_properties = sharded_tensor_metadata.tensor_properties

    if len(shards_metadata) == 0:
        raise ValueError("shards_metadata must not be empty!")

    if tensor_properties.layout != torch.strided:
        raise ValueError('Only torch.strided layout is currently supported')

    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    sharded_tensor._metadata = sharded_tensor_metadata

    local_shard_metadatas = []

    def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
        tensor_property_or_metadata = (
            "tensor property" if is_property else "local ShardMetadata"
        )
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property is incompatible with "
                f"{tensor_property_or_metadata} on rank {rank}: "
                f"{tensor_property_or_metadata} {prop_name}={expected}, "
                f"local shard tensor {prop_name}={actual}.")

    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_metadata.placement)

        if current_rank == rank:
            local_shard_metadatas.append(shard_metadata)

    if len(local_shards) != len(local_shard_metadatas):
        raise RuntimeError(
            f'Number of local shards ({len(local_shards)}) does not match number of local '
            f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
            f'on rank ({current_rank}) ')

    for shard in local_shards:
        shard_meta = shard.metadata
        local_shard_tensor = shard.tensor
        rank, local_device = _parse_and_validate_remote_device(
            sharded_tensor._process_group, shard_meta.placement)

        # validate that shard_meta is among the metadatas collected from
        # sharded_tensor_metadata
        assert shard_meta in local_shard_metadatas, \
            "local shard metadata not in sharded_tensor_metadata!"

        _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout,
                           "layout", current_rank, True)
        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                'Only torch.contiguous_format memory_format is currently supported')

        _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()),
                           "size", current_rank)
        _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(),
                           "pin_memory", current_rank, True)
        _raise_if_mismatch(local_device, local_shard_tensor.device,
                           "device", current_rank)
        _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype,
                           "dtype", current_rank, True)
        _raise_if_mismatch(tensor_properties.requires_grad, local_shard_tensor.requires_grad,
                           "requires_grad", current_rank, True)

    # check that shards_metadata does not contain overlapping shards
    validate_non_overlapping_shards_metadata(shards_metadata)

    # check that shards_metadata is compatible with the overall size of the
    # sharded tensor
    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

    # done validation, add local_shards
    sharded_tensor._local_shards = local_shards

    # make an EnumerableShardingSpec for sharded tensors initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    # see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
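
# A minimal usage sketch (hypothetical sizes and devices; assumes a 2-rank
# process group, and that the ShardedTensorMetadata/TensorProperties
# constructors are available as in this module). Every rank builds identical
# global metadata, but passes only its own local shards:
#
#   rank = dist.get_rank()
#   shards_metadata = [
#       ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 10],
#                     placement="rank:0/cuda:0"),
#       ShardMetadata(shard_offsets=[5, 0], shard_sizes=[5, 10],
#                     placement="rank:1/cuda:1"),
#   ]
#   metadata = ShardedTensorMetadata(
#       shards_metadata=shards_metadata,
#       size=torch.Size([10, 10]),
#       tensor_properties=TensorProperties(dtype=torch.float32),
#   )
#   local_shards = [Shard(torch.rand(5, 10, device=f"cuda:{rank}"),
#                         shards_metadata[rank])]
#   st = ShardedTensor._init_from_local_shards_and_global_metadata(
#       local_shards, metadata)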