def test_sharded_tensor_layer_norm(self):
    """LayerNorm (module and functional forms) on a sharded tensor matches
    sharding the locally normalized tensor."""
    specs = _chunk_sharding_specs_list_for_test([1, 2], seed=10)
    for spec, affine in itertools.product(specs, [True, False]):
        tensor = torch.rand(16, 35, 26).cuda(self.rank)
        layer_norm = torch.nn.LayerNorm(
            (35, 26), elementwise_affine=affine
        ).cuda(self.rank)
        st = layer_norm(_shard_tensor(tensor, spec))
        with torch.no_grad():
            tensor_normed = layer_norm(tensor)
        st_expected = _shard_tensor(tensor_normed, spec)
        self.assertEqual(st.local_tensor(), st_expected.local_tensor())
        self.assertTrue(torch.allclose(st, st_expected, atol=1e-6))
        # The functional path must agree with the module path.
        st_functional = torch.nn.functional.layer_norm(
            _shard_tensor(tensor, spec),
            (35, 26),
            weight=layer_norm.weight,
            bias=layer_norm.bias,
        )
        self.assertTrue(torch.allclose(st, st_functional, atol=1e-6))
def test_sharded_tensor_view(self):
    """view() on a sharded tensor matches sharding the viewed local tensor,
    in both the expanding and collapsing directions."""
    # NOTE(review): a later method in this file reuses the name
    # test_sharded_tensor_view; if both are in the same class, this
    # definition is shadowed — confirm both are intended to run.
    specs = _chunk_sharding_specs_list_for_test([0, 0], seed=10)
    for spec in specs:
        tensor = torch.rand(16, 35, 26).cuda(self.rank)
        tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
        # Expand dim 0: 16 -> (4, 4).
        st_expected = _shard_tensor(tensor_v, spec)
        st_expected._metadata.shards_metadata.sort(
            key=lambda meta: meta.shard_offsets[0],
        )
        viewed = _shard_tensor(tensor, spec).view(4, 4, 35, 26)
        self.assertTrue(torch.allclose(viewed, st_expected))
        # Collapse back: (4, 4) -> 16.
        st_expected = _shard_tensor(tensor, spec)
        st_expected._metadata.shards_metadata.sort(
            key=lambda meta: meta.shard_offsets[0],
        )
        collapsed = _shard_tensor(tensor_v, spec).view(16, 35, 26)
        self.assertTrue(torch.allclose(collapsed, st_expected))
def test_sharded_tensor_transpose(self):
    """transpose() on a sharded tensor matches sharding the transposed local
    tensor, with the sharding dim remapped accordingly."""
    specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
    for spec in specs:
        tensor = torch.rand(15, 27, 16).cuda(self.rank)

        # Case 1: swap dims 0 and 1 (functional form).
        tensor_t = tensor.transpose(0, 1).contiguous()
        spec_n = copy.deepcopy(spec)
        if spec_n.dim in (0, 1):
            # The sharding dim follows the transpose: 0 <-> 1.
            spec_n.dim = 1 - spec_n.dim
        st_expected = _shard_tensor(tensor_t, spec_n)
        st_expected._metadata.shards_metadata.sort(
            key=lambda meta: meta.shard_offsets[0],
        )
        transposed = torch.transpose(_shard_tensor(tensor, spec), 0, 1)
        self.assertTrue(torch.allclose(transposed, st_expected))

        # Case 2: swap dims 1 and 2 (method form).
        tensor_t = torch.transpose(tensor, 1, 2).contiguous()
        spec_n = copy.deepcopy(spec)
        if spec_n.dim in (1, 2):
            # The sharding dim follows the transpose: 1 <-> 2.
            spec_n.dim = 3 - spec_n.dim
        st_expected = _shard_tensor(tensor_t, spec_n)
        st_expected._metadata.shards_metadata.sort(
            key=lambda meta: meta.shard_offsets[0],
        )
        transposed = _shard_tensor(tensor, spec).transpose(1, 2)
        self.assertTrue(torch.allclose(transposed, st_expected))
def test_sharded_bmm(self):
    """bmm of two sharded tensors matches sharding the local bmm result."""
    # NOTE(review): another method later in this file also defines
    # test_sharded_bmm (with an added metadata sort); if both are in the
    # same class, this definition is shadowed — confirm which should live.
    for spec in generate_chunk_sharding_specs_for_test(0):
        lhs = torch.rand(15, 4, 5).cuda(self.rank)
        rhs = torch.rand(15, 5, 6).cuda(self.rank)
        st_lhs = _shard_tensor(lhs, spec)
        st_rhs = _shard_tensor(rhs, spec)
        st_expected = _shard_tensor(lhs.bmm(rhs), spec)
        # Functional and method forms must both agree.
        self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
        self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected))
def test_sharded_tensor_nan_to_num(self):
    """nan_to_num on a sharded tensor matches sharding the locally
    sanitized tensor."""
    specs = _chunk_sharding_specs_list_for_test([0, 1], seed=10)
    for spec in specs:
        tensor = torch.rand(16, 12).cuda(self.rank)
        # Seed distinct column ranges with nan, +inf and -inf.
        tensor[:, :2] = float('nan')
        tensor[:, 4:5] = float('inf')
        tensor[:, 10:] = -float('inf')
        st = _shard_tensor(tensor, spec)
        st_expected = _shard_tensor(torch.nan_to_num(tensor), spec)
        self.assertTrue(torch.allclose(torch.nan_to_num(st), st_expected))
def test_sharded_bmm(self):
    """bmm of two sharded tensors matches sharding the local bmm result,
    with expected shard metadata sorted by offset for a stable compare."""
    for spec in generate_chunk_sharding_specs_for_test(0):
        lhs = torch.rand(15, 4, 5).cuda(self.rank)
        rhs = torch.rand(15, 5, 6).cuda(self.rank)
        st_lhs = _shard_tensor(lhs, spec)
        st_rhs = _shard_tensor(rhs, spec)
        st_expected = _shard_tensor(lhs.bmm(rhs), spec)
        st_expected._metadata.shards_metadata.sort(
            key=lambda meta: meta.shard_offsets[0],
        )
        # Functional and method forms must both agree.
        self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
        self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected))
def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec, input_size):
    """Reshard a sharded tensor and verify it equals sharding the same data
    directly with reshard_spec (metadata, local tensor, and shard metadata)."""
    torch.manual_seed(0)
    base = torch.rand(*input_size).cuda(self.rank)
    resharded = _shard_tensor(base, sharding_spec)
    expected = _shard_tensor(base, reshard_spec)
    resharded.reshard(reshard_spec)
    # Each rank holds exactly one local shard on both sides.
    self.assertEqual(1, len(resharded.local_shards()))
    self.assertEqual(1, len(expected.local_shards()))
    self.assertEqual(resharded._metadata, expected._metadata)
    self.assertEqual(resharded.local_tensor(), expected.local_tensor())
    self.assertEqual(
        resharded.local_shards()[0].metadata,
        expected.local_shards()[0].metadata,
    )
def _test_masked_fill_with_sizes(self, mask_size, broadcast_style=False):
    """masked_fill on a sharded tensor matches sharding the locally filled
    tensor; optionally exercises broadcasting via an inserted singleton dim."""
    specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
    for spec in specs:
        tensor = torch.rand(35, 17, 26).cuda(self.rank)
        mask = (
            torch.randint(0, 2, mask_size)
            .type(torch.BoolTensor)
            .cuda(self.rank)
        )
        if broadcast_style:
            # Insert a singleton dim at index 1 so the mask broadcasts.
            mask = mask.unsqueeze(1)
        expected = _shard_tensor(tensor.masked_fill(mask, 25.0), spec)
        actual = _shard_tensor(tensor, spec).masked_fill(mask, 25.0)
        self.assertTrue(torch.allclose(actual, expected))
def test_sharded_tensor_softmax(self):
    """softmax over dim 1 of a sharded tensor matches sharding the local
    softmax result."""
    specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
    for spec in specs:
        tensor = torch.rand(15, 27, 16).cuda(self.rank)
        expected = _shard_tensor(
            torch.nn.functional.softmax(tensor, dim=1, dtype=torch.float32),
            spec,
        )
        actual = torch.nn.functional.softmax(
            _shard_tensor(tensor, spec), dim=1, dtype=torch.float32
        )
        self.assertTrue(torch.allclose(actual, expected))
def test_custom_sharding_spec_shard_tensor(self):
    """
    A custom sharding spec can be invoked from the _shard_tensor callsite
    (here it is expected to raise NotImplementedError).
    """
    placements = [
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ]
    grid_spec = GridShardingSpec(grid_size=2, placements=placements)
    with self.assertRaisesRegex(NotImplementedError, 'not implemented'):
        _shard_tensor(torch.randn(8, 8), grid_spec)
def _run_sharded_chunk_test(self, local_tensor_size, shard_spec, chunk_num):
    """torch.chunk and Tensor.chunk on a sharded tensor both match chunking
    the local tensor along the last dim."""
    torch.manual_seed(0)
    local_tensor = torch.rand(*local_tensor_size).cuda(self.rank)
    st_tensor = _shard_tensor(local_tensor.clone().detach(), shard_spec)
    expected_chunks = torch.chunk(local_tensor, chunk_num, dim=-1)
    # Functional form.
    self._compare_chunk_result(
        expected_chunks, torch.chunk(st_tensor, chunk_num, dim=-1)
    )
    # Method form.
    self._compare_chunk_result(
        expected_chunks, st_tensor.chunk(chunk_num, dim=-1)
    )
def _compare_chunk_result(self, chunked_list, chunked_st_list):
    """Each sharded chunk must equal sharding the corresponding local chunk."""
    self.assertEqual(len(chunked_list), len(chunked_st_list))
    for local_chunk, st_chunk in zip(chunked_list, chunked_st_list):
        spec = st_chunk.sharding_spec()
        expected = _shard_tensor(local_chunk.contiguous(), spec)
        # _shard_tensor orders shard metadata by rank; sort by offset on the
        # sharding dim so both sides line up for comparison.
        expected._metadata.shards_metadata.sort(
            key=lambda meta: meta.shard_offsets[spec.dim],
        )
        self.assertTrue(torch.allclose(st_chunk, expected))
def test_sharded_tensor_layer_norm_error(self):
    """LayerNorm over a sharded tensor rejects invalid normalized_shape
    values with the expected error messages."""
    specs = _chunk_sharding_specs_list_for_test([2], seed=10)
    for spec in specs:
        tensor = torch.rand(16, 35, 26).cuda(self.rank)
        # normalized_shape has more dims than the input tensor.
        with self.assertRaisesRegex(
            ValueError,
            "normalized_shape dim must not be greater "
            "than the dim of the sharded tensor.",
        ):
            layer_norm = torch.nn.LayerNorm((14, 55, 35, 26)).cuda(self.rank)
            layer_norm(_shard_tensor(tensor, spec))
        # normalized_shape does not match the input's trailing dims.
        with self.assertRaisesRegex(
            ValueError,
            r"Given normalized_shape=\[35\], expected input with shape "
            r"\[\*, 35\], but got input of size \[16, 35, 26\].",
        ):
            layer_norm = torch.nn.LayerNorm((35)).cuda(self.rank)
            layer_norm(_shard_tensor(tensor, spec))
def test_sharded_tensor_view(self):
    """view() on sharded tensors, including a spec with a negative
    sharding dim that must be remapped for the viewed shape."""
    specs = _chunk_sharding_specs_list_for_test([0, 0, -3], seed=10)
    for spec in specs:
        tensor = torch.rand(16, 35, 26).cuda(self.rank)
        tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
        new_spec = copy.deepcopy(spec)
        if new_spec.dim < 0:
            # The viewed tensor gains one leading dim, so a negative
            # sharding dim shifts by one to keep addressing the same axis.
            new_spec.dim -= 1
        # Expand dim 0: 16 -> (4, 4).
        expected = _shard_tensor(tensor_v, new_spec)
        actual = _shard_tensor(tensor, spec).view(4, 4, 35, 26)
        self.assertTrue(torch.allclose(actual, expected))
        # Collapse back: (4, 4) -> 16.
        expected = _shard_tensor(tensor, spec)
        actual = _shard_tensor(tensor_v, new_spec).view(16, 35, 26)
        self.assertTrue(torch.allclose(actual, expected))
def _test_sharded_softmax(self, softmax_dim, sharding_dim):
    """This rank's shard of the sharded softmax equals the matching chunk
    of the locally computed softmax."""
    torch.manual_seed(0)
    local_tensor = torch.rand(10, 10, device=self.rank)
    local_softmax = torch.nn.functional.softmax(local_tensor, softmax_dim)
    spec = ChunkShardingSpec(
        dim=sharding_dim,
        placements=[
            f'rank:{idx}/cuda:{idx}' for idx in range(self.world_size)
        ],
    )
    sharded_softmax = torch.nn.functional.softmax(
        _shard_tensor(local_tensor, spec), softmax_dim
    )
    expected_shard = local_softmax.chunk(
        self.world_size, dim=sharding_dim
    )[self.rank]
    self.assertEqual(expected_shard, sharded_softmax.local_tensor())
def test_replicated_tensor_inter_op_sharded_tensor(self):
    """Binary ops between a ShardedTensor and a ReplicatedTensor (in both
    operand orders) return a ShardedTensor whose gathered result equals the
    same op applied to the local tensors."""
    torch.manual_seed(self.rank)
    local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
    local_tensor2 = torch.ones(12, 3, device=f"cuda:{self.rank}") * 4
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = _shard_tensor(local_tensor1, spec, src_rank=0)
    replica_tensor = ReplicatedTensor(local_tensor2)
    op_names = [
        "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"
    ]
    for op_name in op_names:
        binary_op = gen_binary_op_func(op_name)
        # Exercise both operand orders: direct and reflected.
        for lhs, rhs, local_lhs, local_rhs in (
            (st, replica_tensor, local_tensor1, local_tensor2),
            (replica_tensor, st, local_tensor2, local_tensor1),
        ):
            res = binary_op(lhs, rhs)
            self.assertIsInstance(res, sharded_tensor.ShardedTensor)
            self.assertNotIsInstance(res, ReplicatedTensor)
            # Only rank 0 provides a gather destination and verifies.
            output = torch.empty((12, 3)) if self.rank == 0 else None
            res.gather(dst=0, out=output)
            if self.rank == 0:
                self.assertEqual(output, binary_op(local_lhs, local_rhs))
def test_replicated_tensor_implicit_broadcasting(self):
    """A 1-D size-(3,) ReplicatedTensor broadcasts implicitly against a
    (12, 3) ShardedTensor for all supported binary ops."""
    # Same seed on every rank.
    torch.manual_seed(self.rank)
    local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
    # Size (3,) forces the implicit broadcasting path; the test would
    # fail if broadcasting did not happen.
    local_tensor2 = torch.ones(3, device=f"cuda:{self.rank}")
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = _shard_tensor(local_tensor1, spec, src_rank=0)
    replica_tensor = ReplicatedTensor(local_tensor2)
    for op_name in (
        "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"
    ):
        binary_op = gen_binary_op_func(op_name)
        # The replicated operand should be broadcast automatically.
        res = binary_op(st, replica_tensor)
        self.assertIsInstance(res, sharded_tensor.ShardedTensor)
        output = torch.empty((12, 3)) if self.rank == 0 else None
        res.gather(dst=0, out=output)
        if self.rank == 0:
            self.assertEqual(output, binary_op(local_tensor1, local_tensor2))