def test_basic_math_ops(self):
    ops = ["torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"]

    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    sharded_lhs = sharded_tensor.rand(spec, (12, 3))
    sharded_rhs = sharded_tensor.rand(spec, (12, 3))
    current_rank = dist.get_rank()
    global_lhs = torch.empty((12, 3)) if current_rank == 0 else None
    global_rhs = torch.empty((12, 3)) if current_rank == 0 else None
    sharded_lhs.gather(dst=0, out=global_lhs)
    sharded_rhs.gather(dst=0, out=global_rhs)

    for op in ops:
        binary_op = gen_binary_op_func(op)
        binary_op_ = gen_binary_op_func(op, inplace=True)

        # test basic math ops between ShardedTensors
        sharded_output = binary_op(sharded_lhs, sharded_rhs)
        output = torch.empty((12, 3)) if current_rank == 0 else None
        sharded_output.gather(dst=0, out=output)

        if current_rank == 0:
            global_output = binary_op(global_lhs, global_rhs)
            self.assertEqual(output, global_output)

        # test basic math ops between ShardedTensor and scalar
        scalars = [3, 1.8]
        for scalar in scalars:
            sharded_output_lhs = binary_op(sharded_lhs, scalar)

            sharded_output_lhs_ = binary_op_(sharded_lhs, scalar)
            self.assertTrue(torch.allclose(sharded_output_lhs, sharded_output_lhs_))

            output_lhs = torch.empty((12, 3)) if current_rank == 0 else None
            sharded_output_lhs.gather(dst=0, out=output_lhs)

            sharded_output_rhs = binary_op(scalar, sharded_lhs)
            output_rhs = torch.empty((12, 3)) if current_rank == 0 else None
            sharded_output_rhs.gather(dst=0, out=output_rhs)

            if current_rank == 0:
                global_output_lhs = binary_op(global_lhs, scalar)
                global_output_rhs = binary_op(scalar, global_lhs)
                self.assertEqual(output_lhs, global_output_lhs)
                self.assertEqual(output_rhs, global_output_rhs)
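# The test above relies on a `gen_binary_op_func` helper that is not reproduced in this
# section. A minimal sketch of such a helper follows, assuming it maps an operator string
# ("torch.add", "+", ...) to a callable on (lhs, rhs), with `inplace=True` selecting the
# in-place variant; the exact implementation in the shared test utilities may differ.
import torch

def gen_binary_op_func(python_op, inplace=False):
    # Build a small function whose body applies the requested operator.
    src_lines = ["def f(lhs, rhs):"]
    if "torch" in python_op:
        if inplace:
            # e.g. "torch.add" with inplace=True -> lhs.add_(rhs)
            src_lines.append(f"    return lhs.{python_op.split('.')[1]}_(rhs)")
        else:
            # e.g. "torch.add" -> torch.add(lhs, rhs)
            src_lines.append(f"    return {python_op}(lhs, rhs)")
    else:
        if inplace:
            # e.g. "+" with inplace=True -> lhs += rhs
            src_lines.append(f"    lhs {python_op}= rhs")
            src_lines.append("    return lhs")
        else:
            # e.g. "+" -> lhs + rhs
            src_lines.append(f"    return lhs {python_op} rhs")
    scope = {"torch": torch}
    exec("\n".join(src_lines), scope)
    return scope["f"]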
def test_sharded_chunk_error(self):
    chunk_spec = generate_chunk_sharding_specs_for_test(-1)
    with self.assertRaisesRegex(
        NotImplementedError, "Chunk by sharding dim is not supported."
    ):
        st = sharded_tensor.rand(chunk_spec[0], [17, 24])
        torch.chunk(st, 5, dim=-1)

    enumerable_spec = generate_enumerable_sharding_specs_for_test()
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec is supported for chunk."
    ):
        st = sharded_tensor.rand(enumerable_spec[0], [10, 10])
        torch.chunk(st, 5, dim=-1)
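# The spec generators used above come from shared test utilities that are not shown in
# this section. A rough, assumed sketch of what they return: a few ChunkShardingSpec
# variants over four GPUs for a given sharding dim, and an EnumerableShardingSpec that
# splits a 10x10 tensor into four 5x5 shards. The real helpers may return more variants.
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)

def generate_chunk_sharding_specs_for_test(sharding_dim):
    return [
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        ),
        # A second spec with a different placement order, so tests also cover
        # shards that are not laid out in rank order.
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        ),
    ]

def generate_enumerable_sharding_specs_for_test():
    return [
        EnumerableShardingSpec(
            [
                ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cuda:0"),
                ShardMetadata(shard_offsets=[0, 5], shard_sizes=[5, 5], placement="rank:1/cuda:1"),
                ShardMetadata(shard_offsets=[5, 0], shard_sizes=[5, 5], placement="rank:2/cuda:2"),
                ShardMetadata(shard_offsets=[5, 5], shard_sizes=[5, 5], placement="rank:3/cuda:3"),
            ]
        )
    ]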
def test_sharded_tensor_type_as(self):
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    for spec in specs:
        st = sharded_tensor.rand(
            spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
        )
        st_2 = sharded_tensor.rand(
            spec, 16, 30, 5, init_rrefs=True, dtype=torch.float
        )
        st_3 = st.type_as(st_2)
        self.assertEqual(torch.float, st_3.dtype)
        self.assertEqual(torch.float, st_3.local_tensor().dtype)
        st_3 = st.type_as(torch.zeros(10).type(torch.BoolTensor).cuda())
        self.assertEqual(torch.bool, st_3.dtype)
        self.assertEqual(torch.bool, st_3.local_tensor().dtype)
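# `_chunk_sharding_specs_list_for_test`, used throughout these tests, is another helper
# that is only referenced here. A minimal sketch, assuming it builds one ChunkShardingSpec
# per requested sharding dim with the four GPU placements shuffled deterministically by the
# given seed (so every rank builds the same specs); the real utility may differ in detail.
import copy
import random

from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
    placements = [
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ]
    specs = []
    for i, dim in enumerate(sharding_dims):
        # Seed per spec so the shuffle is identical on every rank.
        random.Random(seed + i).shuffle(placements)
        specs.append(ChunkShardingSpec(dim=dim, placements=copy.deepcopy(placements)))
    return specs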
def test_sharded_tensor_reshard_errors(self):
    specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
    spec, reshard_spec = specs[0], specs[1]
    enumerable_sharding_spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
        ]
    )
    st = sharded_tensor.rand(spec, 24, 12)
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec supported for reshard."
    ):
        st.reshard(enumerable_sharding_spec)
    st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
    with self.assertRaisesRegex(
        NotImplementedError, "Only single local shard supported for reshard."
    ):
        st.reshard(reshard_spec)
def get_random_tensors(self, spec1, spec2, *sizes, pg1=None, pg2=None, seed_offset=0):
    pg1 = _get_default_group() if pg1 is None else pg1
    pg2 = _get_default_group() if pg2 is None else pg2
    torch.manual_seed(TestShardedTensorBinaryOps.seed)
    st1 = sharded_tensor.rand(spec1, sizes, process_group=pg1)
    torch.manual_seed(TestShardedTensorBinaryOps.seed + seed_offset)
    st2 = sharded_tensor.rand(spec2, sizes, process_group=pg2)

    TestShardedTensorBinaryOps.seed += 1
    return st1, st2
def __init__(
    self,
    spec: ShardingSpec,
) -> None:
    super(MyShardedModel3, self).__init__()
    self.sharded_tensor: ShardedTensor = sharded_tensor.rand(
        spec, 10, 20, init_rrefs=False
    )
def test_math_ops_errors(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    sharded_lhs = sharded_tensor.rand(spec, (20, 3))
    sharded_rhs = sharded_tensor.rand(spec, (12, 3))

    with self.assertRaisesRegex(
        RuntimeError, "Implicit broadcasting not supported"
    ):
        torch.add(sharded_lhs, sharded_rhs)

    spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ]
    )
    st = sharded_tensor.rand(spec, 10, 10)

    with self.assertRaisesRegex(RuntimeError, "not supported"):
        torch.add(st, sharded_rhs)
def test_dummy_writer_works(self) -> None:
    state_dict = {
        'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
        'replicated': torch.rand(10, 10),
        'bytes': [1, 2, 3, 4],
    }

    save_state_dict(state_dict, FaultyStorageWriter({}))
def test_sharded_tensor_contiguous(self):
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    for spec in specs:
        st = sharded_tensor.rand(spec, 10, 22, 5, init_rrefs=True)
        st = st.transpose(1, 0)
        st = st.contiguous()
        self.assertTrue(st.is_contiguous())
        self.assertTrue(st.local_tensor().is_contiguous())
def __init__(self, spec=None, group=None):
    super(MyShardedModel, self).__init__()
    # Use same seed.
    torch.manual_seed(0)
    self.param = torch.nn.Parameter(torch.rand(5, 10))
    if spec is not None:
        self.sharded_param = sharded_tensor.rand(
            spec, 20, 10, requires_grad=True, process_group=group
        )
    else:
        self.sharded_param = torch.nn.Parameter(torch.rand(5, 10))
def test_sharded_tensor_transpose_error(self):
    enumerable_spec = generate_enumerable_sharding_specs_for_test()[0]
    st = sharded_tensor.rand(
        enumerable_spec, 10, 10, init_rrefs=True, dtype=torch.double
    )
    with self.assertRaisesRegex(
        NotImplementedError,
        "Only ChunkShardingSpec supported for 'transpose'",
    ):
        st.transpose(1, 0)
def _run_sharded_elementwise_ops(self, spec, input_size, op):
    torch.manual_seed(self.rank)
    st = sharded_tensor.rand(spec, *input_size)
    new_st = op(st)
    local_shard = st.local_tensor()
    new_st_local_shard = new_st.local_tensor()
    self.assertEqual(
        op(local_shard),
        new_st_local_shard,
    )
def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
    super(MyShardedModel2, self).__init__()
    if spec is not None:
        self.sharded_tensor2 = sharded_tensor.rand(
            spec, 10, 20, process_group=group, init_rrefs=init_rrefs
        )
    else:
        self.sharded_tensor2 = None
    self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))
def test_storage_key_mapping(self) -> None:
    device = f"cuda:{dist.get_rank()}"
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )

    state_dict = {
        'sharded': sharded_tensor.rand(spec, (10, 10)),
        'replicated': torch.rand(4, device=device),
        'bytes': [1, 2, 3, 4],
    }

    metadata, bytes_reqs, tensor_reqs = _prepare(
        state_dict, write_replicated_data=self.rank == 0
    )

    if self.rank == 0:
        self.assertEqual(1, len(bytes_reqs))
        self.assertEqual(2, len(tensor_reqs))

        self.assertTrue('bytes' in metadata.state_dict_metadata)
        self.assertEqual(
            bytes_reqs[0].storage_key,
            metadata.state_dict_metadata['bytes'].storage_key,
        )

        # tensor ordering is unspecified
        if len(tensor_reqs[0].tensor.size()) == 1:
            replicated = tensor_reqs[0]
            shard = tensor_reqs[1]
        else:
            replicated = tensor_reqs[1]
            shard = tensor_reqs[0]

        self.assertTrue('replicated' in metadata.state_dict_metadata)
        self.assertEqual(
            replicated.storage_key,
            metadata.state_dict_metadata['replicated'].storage_key,
        )
    else:
        self.assertEqual(0, len(bytes_reqs))
        self.assertEqual(1, len(tensor_reqs))
        shard = tensor_reqs[0]

        self.assertTrue('sharded' in metadata.state_dict_metadata)
        shard_keys = [
            sm.storage_key
            for sm in metadata.state_dict_metadata['sharded'].storage_metadata
        ]
        self.assertTrue(shard.storage_key in shard_keys)
def test_sharded_tensor_transpose_error(self):
    enumerable_spec = generate_enumerable_sharding_specs_for_test()[0]
    st = sharded_tensor.rand(
        enumerable_spec, 10, 10, init_rrefs=False, dtype=torch.double
    )
    with self.assertRaisesRegex(
        RuntimeError,
        "not supported",
    ):
        st.transpose(1, 0)
def test_sharded_bmm_errors(self):
    specs = generate_chunk_sharding_specs_for_test(0)
    st_lhs = sharded_tensor.rand(specs[0], (15, 5, 6))
    st_rhs = sharded_tensor.rand(specs[1], (15, 5, 6))
    with self.assertRaisesRegex(
        NotImplementedError,
        'Both st and st2 need to have same placements for bmm',
    ):
        torch.bmm(st_lhs, st_rhs)
    for spec in specs:
        st_lhs = sharded_tensor.rand(spec, (20, 3))
        st_rhs = sharded_tensor.rand(spec, (20, 3))
        with self.assertRaisesRegex(
            TypeError,
            'both st and st2 need to be a 3D ShardedTensor',
        ):
            torch.bmm(st_lhs, st_rhs)
        rhs = torch.rand(15, 5, 6).cuda(self.rank)
        with self.assertRaisesRegex(
            TypeError,
            'st2 needs to be a ShardedTensor for torch.bmm',
        ):
            torch.bmm(st_lhs, rhs)
        spec.dim = 1
        st_lhs = sharded_tensor.rand(spec, (15, 5, 6))
        st_rhs = sharded_tensor.rand(spec, (15, 5, 6))
        with self.assertRaisesRegex(
            NotImplementedError,
            'Only support performing bmm on tensors sharded on dim 0 now',
        ):
            torch.bmm(st_lhs, st_rhs)
def test_sharded_tensor_softmax_error(self):
    specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
    for spec in specs:
        st = sharded_tensor.rand(
            spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
        )
        with self.assertRaisesRegex(
            NotImplementedError,
            "Only support performing softmax on non-sharding dim now.",
        ):
            torch.nn.functional.softmax(
                st, dim=st.sharding_spec().dim, dtype=torch.float32
            )
def test_replicated_tensor_inter_op_sharded_tensor_errors(self):
    local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
    replica_tensor = ReplicatedTensor(local_tensor)

    torch.manual_seed(self.rank)
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    st1 = sharded_tensor.rand(spec, (20, 3, 3))
    st2 = sharded_tensor.rand(spec, (30, 3, 3))

    with self.assertRaisesRegex(RuntimeError, 'Implicit broadcasting'):
        st1 + st2

    with self.assertRaisesRegex(RuntimeError, 'not supported for ShardedTensor'):
        st1 % replica_tensor
def _run_sharded_elementwise_ops(
    self, spec, input_size, op, reset_seed=None, **kwargs
):
    torch.manual_seed(self.rank)
    st = sharded_tensor.rand(spec, *input_size)
    if reset_seed:
        reset_seed()
    new_st = op(st, **kwargs)
    local_shard = st.local_tensor()
    new_st_local_shard = new_st.local_tensor()
    if reset_seed:
        reset_seed()
    self.assertEqual(
        op(local_shard, **kwargs),
        new_st_local_shard,
    )
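# Hypothetical usage of the helper above (not part of the original tests): for a
# randomized op such as dropout, `reset_seed` re-seeds the RNG to the same state right
# before the sharded op and again before the local reference op, so the two results stay
# comparable. The spec and seed offset here are illustrative assumptions only.
def test_sharded_dropout_example(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    reset_seed = lambda: torch.manual_seed(self.rank + 5)
    self._run_sharded_elementwise_ops(
        spec, [12, 3], torch.nn.functional.dropout, reset_seed=reset_seed, p=0.4
    )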
def test_inplace_copy(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = sharded_tensor.rand(spec, (12, 5))
    ones_st = sharded_tensor.ones(spec, (12, 5))
    self.assertFalse(torch.equal(ones_st, st))
    st.copy_(ones_st)
    self.assertTrue(torch.equal(st, ones_st))

    # under no_grad, in-place copy should work between two tensors with different requires_grad
    st_with_grad = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
    self.assertTrue(st_with_grad.requires_grad)
    self.assertFalse(ones_st.requires_grad)

    with torch.no_grad():
        st_with_grad.copy_(ones_st)
        self.assertEqual(st_with_grad.local_tensor(), ones_st.local_tensor())
def test_clone(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = sharded_tensor.rand(spec, (12, 5))
    copied_st = st.clone()
    self.assertTrue(type(copied_st) is type(st))
    self.assertEqual(copied_st.local_tensor(), st.local_tensor())
    self.assertFalse(copied_st is st)
def test_storage_key_mapping(self) -> None:
    device = f"cuda:{dist.get_rank()}"
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )

    state_dict = {
        'sharded': sharded_tensor.rand(spec, (10, 10)),
        'replicated': torch.rand(4, device=device),
        'bytes': [1, 2, 3, 4],
    }

    metadata, bytes_reqs, tensor_reqs = _prepare(
        state_dict, write_replicated_data=self.rank == 0
    )

    if self.rank == 0:
        self.assertEqual(1, len(bytes_reqs))
        self.assertEqual(2, len(tensor_reqs))

        self.assertTrue('bytes' in metadata.state_dict_metadata)
        self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)

        # tensor ordering is unspecified
        if len(tensor_reqs[0].tensor.size()) == 1:
            replicated = tensor_reqs[0]
            shard = tensor_reqs[1]
        else:
            replicated = tensor_reqs[1]
            shard = tensor_reqs[0]

        self.assertTrue('replicated' in metadata.state_dict_metadata)
        storage_key = MetadataIndex('replicated', torch.Size([0]))
        self.assertTrue(storage_key in metadata.storage_data)
        self.assertEqual(metadata.storage_data[storage_key], replicated.storage_key)
    else:
        self.assertEqual(0, len(bytes_reqs))
        self.assertEqual(1, len(tensor_reqs))
        shard = tensor_reqs[0]
        local_shard = state_dict["sharded"].local_shards()[0]

        self.assertTrue('sharded' in metadata.state_dict_metadata)
        storage_key = MetadataIndex(
            'sharded', torch.Size(local_shard.metadata.shard_offsets)
        )
        self.assertTrue(storage_key in metadata.storage_data)
        self.assertEqual(metadata.storage_data[storage_key], shard.storage_key)
def test_load_error_handling(self) -> None:
    state_dict = {
        'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
        'replicated': torch.rand(10, 10),
        'bytes': [1, 2, 3, 4],
    }

    self._test_load(state_dict)
    self._test_load(state_dict, fail_read_metadata=[0])
    self._test_load(state_dict, fail_read_bytes=[1])
    self._test_load(state_dict, fail_read_bytes_async=[2])
    self._test_load(state_dict, fail_read_tensors=[3])
    self._test_load(state_dict, fail_read_tensors_async=[1])
    # We don't want to depend on the actual exception raised by pickle
    self._test_load(state_dict, fail_deser_bytes=[2], ignore_exception_type=True)

    self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
    self._test_load(state_dict, coordinator=2, fail_read_bytes=[0])
    self._test_load(state_dict, coordinator=3, fail_read_tensors_async=[2])
def test_save_error_handling(self) -> None:
    state_dict = {
        'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
        'replicated': torch.rand(10, 10),
        'bytes': [1, 2, 3, 4],
    }

    self._test_save(state_dict, fail_prepare=[0])
    self._test_save(state_dict, fail_finish=[0])
    self._test_save(state_dict, fail_prepare_storage=[0])

    self._test_save(state_dict, fail_write_tensors_on_ranks=[1])
    self._test_save(state_dict, fail_write_tensors_on_ranks_async=[2])
    self._test_save(state_dict, fail_write_bytes_on_ranks=[3])
    self._test_save(state_dict, fail_write_bytes_on_ranks_async=[1])

    self._test_save(state_dict, fail_write_tensors_on_ranks_async=[1, 3])

    self._test_save(state_dict, coordinator=1, fail_prepare=[1])
    self._test_save(state_dict, coordinator=1, fail_finish=[1])
def test_sharded_tensor_view_error(self):
    for spec in _chunk_sharding_specs_list_for_test([2], seed=7):
        st = sharded_tensor.rand(
            spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
        )
        with self.assertRaisesRegex(
            NotImplementedError,
            "Shape having dim 2 is not supported "
            "for sharded tensor sharded on dim 2.",
        ):
            st.view(35 * 17, 26)
        with self.assertRaisesRegex(
            ValueError,
            r"Shape '\[5, 7, 35, 17, 26\]' is invalid for sharded tensor size 15470.",
        ):
            st.view(5, 7, 35, 17, 26)
        with self.assertRaisesRegex(
            ValueError,
            "Only one dimension can be inferred for sharded view op.",
        ):
            st.view(5, 7, -1, -1)
def test_detach(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
    local_shards = st.local_shards()
    # the tensor was created with requires_grad=True, so before detaching
    # every local shard should require grad
    for local_shard in local_shards:
        self.assertTrue(local_shard.tensor.requires_grad)

    detached_st = st.detach()
    self.assertFalse(detached_st.requires_grad)
    for local_shard in detached_st.local_shards():
        self.assertFalse(local_shard.tensor.requires_grad)
def test_sharded_tensor_masked_fill_error(self):
    specs = _chunk_sharding_specs_list_for_test([1, 2], seed=7)
    for spec in specs:
        st = sharded_tensor.rand(
            spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
        )
        mask = (
            torch.randint(0, 2, (2, 35, 17, 26))
            .type(torch.BoolTensor)
            .cuda(self.rank)
        )
        with self.assertRaisesRegex(
            ValueError,
            "mask dim must not greater than the dim of the sharded tensor.",
        ):
            st.masked_fill(mask, 25.0)
        mask = torch.randint(0, 2, (16, 26)).type(torch.BoolTensor).cuda(self.rank)
        with self.assertRaisesRegex(
            ValueError,
            "The size of mask 0 must match the size of sharded tensor 1 "
            "at non-singleton dimension 0",
        ):
            st.masked_fill(mask, 25.0)
def test_switch_between_sharded_tensor_to_tensor(self) -> None:
    path = self.get_file_path()
    tensor_size = 32

    specs = [
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        ),
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
                "rank:1",
                "rank:0",
            ],
        ),
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[8],
                placement="rank:1",
            ),
            ShardMetadata(
                shard_offsets=[8],
                shard_sizes=[tensor_size - 8],
                placement="rank:0",
            ),
        ]),
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[10],
                placement="rank:0",
            ),
            ShardMetadata(
                shard_offsets=[10],
                shard_sizes=[tensor_size - 10],
                placement="rank:1",
            ),
        ]),
    ]

    for save_spec in specs:
        for load_spec in specs:
            save_dict = {
                'sharded': sharded_tensor.rand(save_spec, tensor_size),
                'replicated': torch.rand(tensor_size, device=f"cpu:{self.rank}"),
            }

            fs_writer = FileSystemWriter(path=path)
            save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

            # Freaky Friday: swap which entry is sharded and which is a plain tensor on load
            load_dict = {
                'sharded': torch.zeros(tensor_size, device=f"cpu:{self.rank}"),
                'replicated': sharded_tensor.zeros(load_spec, tensor_size),
            }

            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

            save_dict_sharded = self.load_tensor(save_dict['sharded'])
            load_dict_replicated = self.load_tensor(load_dict['replicated'])

            if dist.get_rank() == 0:
                self.assertTrue(
                    torch.allclose(save_dict_sharded, load_dict['sharded']),
                    f"save-spec {save_spec} load-spec {load_spec}",
                )
                self.assertTrue(
                    torch.allclose(save_dict['replicated'], load_dict_replicated),
                    f"save-spec {save_spec} load-spec {load_spec}",
                )