Example 1
 def test_sharded_tensor_layer_norm(self):
     specs = _chunk_sharding_specs_list_for_test([1, 2], seed=10)
     flags = [True, False]
     for spec, flag in itertools.product(specs, flags):
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         layer_norm = torch.nn.LayerNorm(
             (35, 26), elementwise_affine=flag).cuda(self.rank)
         st = layer_norm(_shard_tensor(tensor, spec))
         with torch.no_grad():
             tensor_normed = layer_norm(tensor)
         st_expected = _shard_tensor(tensor_normed, spec)
         self.assertEqual(
             st.local_tensor(),
             st_expected.local_tensor(),
         )
         self.assertTrue(torch.allclose(
             st,
             st_expected,
             atol=1e-6,
         ))
         st_expected = torch.nn.functional.layer_norm(
             _shard_tensor(tensor, spec),
             (35, 26),
             weight=layer_norm.weight,
             bias=layer_norm.bias,
         )
         self.assertTrue(torch.allclose(
             st,
             st_expected,
             atol=1e-6,
         ))
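The examples above and below rely on `_chunk_sharding_specs_list_for_test` (and the closely related `generate_chunk_sharding_specs_for_test`) to build `ChunkShardingSpec` objects. Below is a minimal sketch of what such a helper could look like, assuming a four-rank world and the `rank:<i>/cuda:<i>` placement strings used elsewhere in these examples; the import path, signature, and shuffling behavior are assumptions, not the actual test-utility implementation.

 # Sketch only: the real helper lives in PyTorch's internal test utilities
 # and may differ from this reconstruction.
 import random

 from torch.distributed._shard.sharding_spec import ChunkShardingSpec

 def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0, world_size=4):
     # One ChunkShardingSpec per requested dim, placing shard i on rank i;
     # the placement order is shuffled deterministically from `seed`.
     random.seed(seed)
     specs = []
     for dim in sharding_dims:
         placements = [f"rank:{idx}/cuda:{idx}" for idx in range(world_size)]
         random.shuffle(placements)
         specs.append(ChunkShardingSpec(dim=dim, placements=placements))
     return specs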
Example 2
 def test_sharded_tensor_view(self):
     specs = _chunk_sharding_specs_list_for_test([0, 0], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
         st_expected = _shard_tensor(tensor_v, spec)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor, spec).view(4, 4, 35, 26),
                 st_expected,
             )
         )
         st_expected = _shard_tensor(tensor, spec)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor_v, spec).view(16, 35, 26),
                 st_expected,
             )
         )
Example 3
 def test_sharded_tensor_transpose(self):
     specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
     for spec in specs:
         tensor = torch.rand(15, 27, 16).cuda(self.rank)
         tensor_t = tensor.transpose(0, 1).contiguous()
         spec_n = copy.deepcopy(spec)
         if spec_n.dim in (0, 1):
             spec_n.dim = 1 - spec_n.dim
         st_expected = _shard_tensor(tensor_t, spec_n)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(
                 torch.transpose(_shard_tensor(tensor, spec), 0, 1), st_expected
             )
         )
         tensor_t = torch.transpose(tensor, 1, 2).contiguous()
         spec_n = copy.deepcopy(spec)
         if spec_n.dim in (1, 2):
             spec_n.dim = 3 - spec_n.dim
         st_expected = _shard_tensor(tensor_t, spec_n)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(_shard_tensor(tensor, spec).transpose(1, 2), st_expected)
         )
Example 4
 def test_sharded_bmm(self):
     for spec in generate_chunk_sharding_specs_for_test(0):
         lhs = torch.rand(15, 4, 5).cuda(self.rank)
         rhs = torch.rand(15, 5, 6).cuda(self.rank)
         tensor = lhs.bmm(rhs)
         st_lhs = _shard_tensor(lhs, spec)
         st_rhs = _shard_tensor(rhs, spec)
         st_expected = _shard_tensor(tensor, spec)
         self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
         self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected))
Example 5
 def test_sharded_tensor_nan_to_num(self):
     specs = _chunk_sharding_specs_list_for_test([0, 1], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 12).cuda(self.rank)
         tensor[:, :2] = float('nan')
         tensor[:, 4:5] = float('inf')
         tensor[:, 10:] = -float('inf')
         st = _shard_tensor(tensor, spec)
         st_expected = _shard_tensor(torch.nan_to_num(tensor), spec)
         st = torch.nan_to_num(st)
         self.assertTrue(torch.allclose(st, st_expected))
Example 6
 def test_sharded_bmm(self):
     for spec in generate_chunk_sharding_specs_for_test(0):
         lhs = torch.rand(15, 4, 5).cuda(self.rank)
         rhs = torch.rand(15, 5, 6).cuda(self.rank)
         tensor = lhs.bmm(rhs)
         st_lhs = _shard_tensor(lhs, spec)
         st_rhs = _shard_tensor(rhs, spec)
         st_expected = _shard_tensor(tensor, spec)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
         self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected))
Example 7
 def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec,
                                 input_size):
     torch.manual_seed(0)
     local_tensor = torch.rand(*input_size).cuda(self.rank)
     st = _shard_tensor(local_tensor, sharding_spec)
     st_compare = _shard_tensor(local_tensor, reshard_spec)
     st.reshard(reshard_spec)
     self.assertEqual(1, len(st.local_shards()))
     self.assertEqual(1, len(st_compare.local_shards()))
     self.assertEqual(st._metadata, st_compare._metadata)
     self.assertEqual(st.local_tensor(), st_compare.local_tensor())
     self.assertEqual(st.local_shards()[0].metadata,
                      st_compare.local_shards()[0].metadata)
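The helper above is not a test by itself; it expects a source spec, a target spec, and an input size. A hypothetical driver could look like the sketch below; the dims, seed, and tensor size are illustrative assumptions rather than the original test's parameters.

 def test_sharded_tensor_reshard(self):
     # Hypothetical driver: reshard between every pair of chunk specs.
     specs = _chunk_sharding_specs_list_for_test([0, 1], seed=0)
     for sharding_spec, reshard_spec in itertools.product(specs, specs):
         self._run_sharded_tensor_reshard(sharding_spec, reshard_spec, [13, 21])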
Example 8
 def _test_masked_fill_with_sizes(self, mask_size, broadcast_style=False):
     specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
     for spec in specs:
         tensor = torch.rand(35, 17, 26).cuda(self.rank)
         mask = torch.randint(0, 2, mask_size).type(torch.BoolTensor).cuda(
             self.rank)
         if broadcast_style:
             mask = mask.unsqueeze(1)
         tensor_m = tensor.masked_fill(mask, 25.0)
         st_expected = _shard_tensor(tensor_m, spec)
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor, spec).masked_fill(mask, 25.0),
                 st_expected,
             ))
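A caller of the masked_fill helper above would pass a full-size mask and a broadcastable one; the mask sizes in the sketch below are illustrative assumptions chosen to match the (35, 17, 26) tensor.

 def test_sharded_tensor_masked_fill(self):
     # Hypothetical driver: one full-shape mask, and one (35, 26) mask that the
     # helper unsqueezes to (35, 1, 26) so it broadcasts over dim 1.
     self._test_masked_fill_with_sizes((35, 17, 26))
     self._test_masked_fill_with_sizes((35, 26), broadcast_style=True)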
Example 9
 def test_sharded_tensor_softmax(self):
     specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
     for spec in specs:
         tensor = torch.rand(15, 27, 16).cuda(self.rank)
         tensor_n = torch.nn.functional.softmax(tensor,
                                                dim=1,
                                                dtype=torch.float32)
         st_expected = _shard_tensor(tensor_n, spec)
         self.assertTrue(
             torch.allclose(
                 torch.nn.functional.softmax(_shard_tensor(tensor, spec),
                                             dim=1,
                                             dtype=torch.float32),
                 st_expected,
             ))
Example 10
    def test_custom_sharding_spec_shard_tensor(self):
        """ Test custom spec can be invoked from the
            _shard_tensor callsite.
        """

        ranks = [
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ]

        grid_spec = GridShardingSpec(grid_size=2, placements=ranks)

        with self.assertRaisesRegex(NotImplementedError, 'not implemented'):
            _shard_tensor(torch.randn(8, 8), grid_spec)
Example 11
 def _run_sharded_chunk_test(self, local_tensor_size, shard_spec,
                             chunk_num):
     torch.manual_seed(0)
     local_tensor = torch.rand(*local_tensor_size).cuda(self.rank)
     st_tensor = _shard_tensor(local_tensor.clone().detach(), shard_spec)
     local_tensor_chunked = torch.chunk(local_tensor, chunk_num, dim=-1)
     chunked_st = torch.chunk(st_tensor, chunk_num, dim=-1)
     self._compare_chunk_result(local_tensor_chunked, chunked_st)
     chunked_st = st_tensor.chunk(chunk_num, dim=-1)
     self._compare_chunk_result(local_tensor_chunked, chunked_st)
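The chunk helper above takes a local tensor size, a sharding spec, and a chunk count; a sketch of a driver for it follows, with sizes and chunk counts chosen purely for illustration.

 def test_sharded_chunk(self):
     # Hypothetical driver: chunk the last dim of a dim-0 sharded tensor.
     for spec in _chunk_sharding_specs_list_for_test([0], seed=0):
         self._run_sharded_chunk_test([17, 24], spec, 3)
         self._run_sharded_chunk_test([17, 24], spec, 4)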
Example 12
 def _compare_chunk_result(self, chunked_list, chunked_st_list):
     self.assertEqual(len(chunked_list), len(chunked_st_list))
     for idx, chunked_st in enumerate(chunked_st_list):
         tensor = chunked_list[idx]
         st = _shard_tensor(tensor.contiguous(), chunked_st.sharding_spec())
         # _shard_tensor generates a sharded tensor whose shard metadata is ordered by rank.
         st._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[chunked_st.sharding_spec().dim],
         )
         self.assertTrue(torch.allclose(chunked_st, st))
Example 13
 def test_sharded_tensor_layer_norm_error(self):
     specs = _chunk_sharding_specs_list_for_test([2], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         with self.assertRaisesRegex(
                 ValueError,
                 "normalized_shape dim must not be greater "
                 "than the dim of the sharded tensor.",
         ):
             layer_norm = torch.nn.LayerNorm(
                 (14, 55, 35, 26)).cuda(self.rank)
             layer_norm(_shard_tensor(tensor, spec))
         with self.assertRaisesRegex(
                 ValueError,
                 r"Given normalized_shape=\[35\], expected input with shape "
                 r"\[\*, 35\], but got input of size \[16, 35, 26\].",
         ):
             layer_norm = torch.nn.LayerNorm((35)).cuda(self.rank)
             layer_norm(_shard_tensor(tensor, spec))
Example 14
 def test_sharded_tensor_view(self):
     specs = _chunk_sharding_specs_list_for_test([0, 0, -3], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
         new_spec = copy.deepcopy(spec)
         if new_spec.dim < 0:
             new_spec.dim -= 1
         st_expected = _shard_tensor(tensor_v, new_spec)
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor, spec).view(4, 4, 35, 26),
                 st_expected,
             ))
         st_expected = _shard_tensor(tensor, spec)
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor_v, new_spec).view(16, 35, 26),
                 st_expected,
             ))
Example 15
    def _test_sharded_softmax(self, softmax_dim, sharding_dim):
        torch.manual_seed(0)
        local_tensor = torch.rand(10, 10, device=self.rank)
        local_softmax = torch.nn.functional.softmax(local_tensor, softmax_dim)

        spec = ChunkShardingSpec(dim=sharding_dim,
                                 placements=[
                                     f'rank:{idx}/cuda:{idx}'
                                     for idx in range(self.world_size)
                                 ])
        st = _shard_tensor(local_tensor, spec)
        sharded_softmax = torch.nn.functional.softmax(st, softmax_dim)

        self.assertEqual(
            local_softmax.chunk(self.world_size, dim=sharding_dim)[self.rank],
            sharded_softmax.local_tensor())
Example 16
    def test_replicated_tensor_inter_op_sharded_tensor(self):
        torch.manual_seed(self.rank)

        local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
        local_tensor2 = torch.ones(12, 3, device=f"cuda:{self.rank}") * 4

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        st = _shard_tensor(local_tensor1, spec, src_rank=0)
        replica_tensor = ReplicatedTensor(local_tensor2)

        ops = [
            "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*",
            "/"
        ]

        for op in ops:
            binary_op = gen_binary_op_func(op)
            res = binary_op(st, replica_tensor)
            self.assertIsInstance(res, sharded_tensor.ShardedTensor)
            self.assertNotIsInstance(res, ReplicatedTensor)
            output = torch.empty((12, 3)) if self.rank == 0 else None
            res.gather(dst=0, out=output)

            if self.rank == 0:
                local_output = binary_op(local_tensor1, local_tensor2)
                self.assertEqual(output, local_output)

            # reflective
            reflect_res = binary_op(replica_tensor, st)
            self.assertIsInstance(reflect_res, sharded_tensor.ShardedTensor)
            self.assertNotIsInstance(reflect_res, ReplicatedTensor)
            reflect_output = torch.empty((12, 3)) if self.rank == 0 else None
            reflect_res.gather(dst=0, out=reflect_output)

            if self.rank == 0:
                reflect_local_output = binary_op(local_tensor2, local_tensor1)
                self.assertEqual(reflect_output, reflect_local_output)
Example 17
    def test_replicated_tensor_implicit_broadcasting(self):
        # seed each rank the same way as in the inter-op test above
        torch.manual_seed(self.rank)

        # test implicit broadcasting
        local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
        # we use size (3,) to trigger the implicit broadcasting logic,
        # and the op will fail if implicit broadcasting does not happen.
        local_tensor2 = torch.ones(3, device=f"cuda:{self.rank}")

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        st = _shard_tensor(local_tensor1, spec, src_rank=0)
        replica_tensor = ReplicatedTensor(local_tensor2)

        ops = [
            "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*",
            "/"
        ]

        for op in ops:
            binary_op = gen_binary_op_func(op)
            # the replicated tensor should be broadcast automatically
            res = binary_op(st, replica_tensor)

            self.assertIsInstance(res, sharded_tensor.ShardedTensor)
            output = torch.empty((12, 3)) if self.rank == 0 else None
            res.gather(dst=0, out=output)

            if self.rank == 0:
                local_output = binary_op(local_tensor1, local_tensor2)
                self.assertEqual(output, local_output)