Example #1
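All of the methods in these examples are excerpted from a single distributed ShardedTensor ops test class; run on their own they would need roughly the imports sketched below. The module paths are assumptions based on recent PyTorch layouts, not something stated in the excerpts, and the test-only helpers (`_chunk_sharding_specs_list_for_test`, `_run_partial_tensor_n_reshard`, etc.) live in internal testing modules that are not shown here.

# Imports the excerpted test methods rely on; module paths are assumptions
# and may differ between PyTorch releases.
import copy
import itertools
from itertools import product

import torch
import torch.distributed as dist
from torch.distributed._shard import _shard_tensor, sharded_tensor
from torch.distributed._shard.sharding_spec import (
    EnumerableShardingSpec,
    ShardMetadata,
)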
 def test_partial_tensor_reshard_errors(self):
     enumerable_sharding_spec = EnumerableShardingSpec(
         [
             ShardMetadata(
                 shard_offsets=[0, 0],
                 shard_sizes=[5, 5],
                 placement="rank:0/cuda:0",
             ),
             ShardMetadata(
                 shard_offsets=[5, 0],
                 shard_sizes=[5, 5],
                 placement="rank:1/cuda:1",
             ),
         ]
     )
     # Resharding a partial tensor is only implemented for ChunkShardingSpec.
     with self.assertRaisesRegex(
         NotImplementedError, "Only ChunkShardingSpec supported for reshard."
     ):
         self._run_partial_tensor_n_reshard(
             enumerable_sharding_spec, [13, 21], 4, dist.ReduceOp.SUM
         )
     with self.assertRaisesRegex(
         NotImplementedError, "Only ChunkShardingSpec supported for reshard."
     ):
         self._run_partial_tensor_n_reshard(
             enumerable_sharding_spec, [12, 22], 4, dist.ReduceOp.MAX
         )
     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     spec = specs[0]
     # Complex dtypes are rejected: only real-valued partial tensors can be resharded.
     with self.assertRaisesRegex(
         NotImplementedError, "Only real partial tensor supported for reshard."
     ):
         self._run_partial_tensor_n_reshard(
             spec, [13, 21], 4, dist.ReduceOp.SUM, dtype=torch.cfloat
         )
     with self.assertRaisesRegex(
         NotImplementedError, "Only real partial tensor supported for reshard."
     ):
         self._run_partial_tensor_n_reshard(
             spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat
         )
 def test_sharded_tensor_transpose(self):
     specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
     for spec in specs:
         tensor = torch.rand(15, 27, 16).cuda(self.rank)
         tensor_t = tensor.transpose(0, 1).contiguous()
         # Transposing dims 0 and 1 also swaps the sharding dim when it is one of them.
         spec_n = copy.deepcopy(spec)
         if spec_n.dim in (0, 1):
             spec_n.dim = 1 - spec_n.dim
         st_expected = _shard_tensor(tensor_t, spec_n)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(
                 torch.transpose(_shard_tensor(tensor, spec), 0, 1), st_expected
             )
         )
         tensor_t = torch.transpose(tensor, 1, 2).contiguous()
         # Same check for a transpose of dims 1 and 2, again remapping the sharding dim.
         spec_n = copy.deepcopy(spec)
         if spec_n.dim in (1, 2):
             spec_n.dim = 3 - spec_n.dim
         st_expected = _shard_tensor(tensor_t, spec_n)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(_shard_tensor(tensor, spec).transpose(1, 2), st_expected)
         )
 def test_sharded_tensor_reshard_errors(self):
     specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
     spec, reshard_spec = specs[0], specs[1]
     enumerable_sharding_spec = EnumerableShardingSpec([
         ShardMetadata(
             shard_offsets=[0, 0],
             shard_sizes=[5, 5],
             placement="rank:0/cuda:0",
         ),
         ShardMetadata(
             shard_offsets=[5, 0],
             shard_sizes=[5, 5],
             placement="rank:1/cuda:1",
         ),
     ])
     st = sharded_tensor.rand(spec, 24, 12)
     with self.assertRaisesRegex(
             NotImplementedError,
             "Only ChunkShardingSpec supported for reshard."):
         st.reshard(enumerable_sharding_spec)
     # Force two local shards so the single-local-shard restriction is exercised.
     st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
     with self.assertRaisesRegex(
             NotImplementedError,
             "Only single local shard supported for reshard."):
         st.reshard(reshard_spec)
 def test_sharded_tensor_view(self):
     specs = _chunk_sharding_specs_list_for_test([0, 0], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
         st_expected = _shard_tensor(tensor_v, spec)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor, spec).view(4, 4, 35, 26),
                 st_expected,
             )
         )
         st_expected = _shard_tensor(tensor, spec)
         st_expected._metadata.shards_metadata.sort(
             key=lambda x: x.shard_offsets[0],
         )
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor_v, spec).view(16, 35, 26),
                 st_expected,
             )
         )
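Every example here builds its ChunkShardingSpec through the test helper `_chunk_sharding_specs_list_for_test`, which is not reproduced in these excerpts. As a rough, non-authoritative sketch of what that helper produces, an equivalent list of specs could be built directly as below; the module path, the fixed 4-rank world size, and the unshuffled placement order are assumptions (the real helper also randomizes placements using the given seed).

# Hypothetical stand-in for _chunk_sharding_specs_list_for_test.
# Assumes 4 ranks, one GPU per rank, and a fixed placement order.
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


def _make_chunk_specs(dims, world_size=4):
    # One ChunkShardingSpec per requested sharding dim; shards are placed
    # round-robin on ranks 0..world_size-1, matching the "rank:i/cuda:i"
    # placement strings used in the tests above.
    return [
        ChunkShardingSpec(
            dim=dim,
            placements=[f"rank:{r}/cuda:{r}" for r in range(world_size)],
        )
        for dim in dims
    ]


# Example: specs sharded on dims 0, 1 and 2, as in test_sharded_tensor_transpose.
specs = _make_chunk_specs([0, 1, 2])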
Example #5
 def test_sharded_tensor_layer_norm(self):
     specs = _chunk_sharding_specs_list_for_test([1, 2], seed=10)
     flags = [True, False]
     for spec, flag in itertools.product(specs, flags):
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         layer_norm = torch.nn.LayerNorm(
             (35, 26), elementwise_affine=flag).cuda(self.rank)
         st = layer_norm(_shard_tensor(tensor, spec))
         with torch.no_grad():
             tensor_normed = layer_norm(tensor)
         st_expected = _shard_tensor(tensor_normed, spec)
         self.assertEqual(
             st.local_tensor(),
             st_expected.local_tensor(),
         )
         self.assertTrue(torch.allclose(
             st,
             st_expected,
             atol=1e-6,
         ))
         st_expected = torch.nn.functional.layer_norm(
             _shard_tensor(tensor, spec),
             (35, 26),
             weight=layer_norm.weight,
             bias=layer_norm.bias,
         )
         self.assertTrue(torch.allclose(
             st,
             st_expected,
             atol=1e-6,
         ))
Example #6
 def test_partial_tensor_reshard(self):
     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     spec = specs[0]
     self._run_partial_tensor_n_reshard(spec, [13, 21], 4, dist.ReduceOp.SUM)
     self._run_partial_tensor_n_reshard(spec, [12, 22], 4, dist.ReduceOp.MAX)
     self._run_partial_tensor_n_reshard(spec, [13, 21], 3, dist.ReduceOp.SUM)
     self._run_partial_tensor_n_reshard(spec, [17, 21], 2, dist.ReduceOp.MAX)
 def test_sharded_tensor_contiguous(self):
     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     for spec in specs:
         st = sharded_tensor.rand(spec, 10, 22, 5, init_rrefs=True)
         st = st.transpose(1, 0)
         st = st.contiguous()
         self.assertTrue(st.is_contiguous())
         self.assertTrue(st.local_tensor().is_contiguous())
Example #8
 def test_infer_sharding_spec_from_shards_metadata(self):
     self._infer_enum_sharding_spec_case()
     chunk_specs = _chunk_sharding_specs_list_for_test([0, 0, 1, 1], seed=31)
     for spec in chunk_specs:
         self._infer_chunk_sharding_spec_case(spec.placements, 0, [4, 16])
         self._infer_chunk_sharding_spec_case(spec.placements, 0, [5, 15, 16])
         self._infer_chunk_sharding_spec_case(spec.placements, 1, [12, 16])
         self._infer_chunk_sharding_spec_case(spec.placements, 2, [4, 18, 15])
         self._infer_chunk_sharding_spec_case(spec.placements, 3, [7, 12, 16, 37])
         self._infer_chunk_sharding_spec_case(spec.placements, 4, [50, 4, 18, 15, 77])
 def test_sharded_tensor_reshard(self):
     dims = [0, 1]
     for sharding_dim, reshard_dim in product(dims, dims):
         specs = _chunk_sharding_specs_list_for_test(
             [sharding_dim, reshard_dim], seed=5)
         spec, reshard_spec = specs[0], specs[1]
         self._run_sharded_tensor_reshard(spec, reshard_spec, [13, 21])
         self._run_sharded_tensor_reshard(spec, reshard_spec, [14, 23])
         self._run_sharded_tensor_reshard(spec, reshard_spec, [15, 26])
         self._run_sharded_tensor_reshard(spec, reshard_spec, [12, 24])
 def test_sharded_tensor_nan_to_num(self):
     specs = _chunk_sharding_specs_list_for_test([0, 1], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 12).cuda(self.rank)
         tensor[:, :2] = float('nan')
         tensor[:, 4:5] = float('inf')
         tensor[:, 10:] = -float('inf')
         st = _shard_tensor(tensor, spec)
         st_expected = _shard_tensor(torch.nan_to_num(tensor), spec)
         st = torch.nan_to_num(st)
         self.assertTrue(torch.allclose(st, st_expected))
Example #11
 def test_sharded_tensor_softmax_error(self):
     specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
     for spec in specs:
         st = sharded_tensor.rand(
             spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
         )
         with self.assertRaisesRegex(
             NotImplementedError,
             "Only support performing softmax on non-sharding dim now.",
         ):
             torch.nn.functional.softmax(
                 st, dim=st.sharding_spec().dim, dtype=torch.float32
             )
Example #12
 def _test_masked_fill_with_sizes(self, mask_size, broadcast_style=False):
     specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
     for spec in specs:
         tensor = torch.rand(35, 17, 26).cuda(self.rank)
         mask = (
             torch.randint(0, 2, mask_size)
             .type(torch.BoolTensor)
             .cuda(self.rank)
         )
         if broadcast_style:
             mask = mask.unsqueeze(1)
         tensor_m = tensor.masked_fill(mask, 25.0)
         st_expected = _shard_tensor(tensor_m, spec)
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor, spec).masked_fill(mask, 25.0),
                 st_expected,
             ))
Example #13
 def test_sharded_tensor_softmax(self):
     specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
     for spec in specs:
         tensor = torch.rand(15, 27, 16).cuda(self.rank)
         tensor_n = torch.nn.functional.softmax(tensor,
                                                dim=1,
                                                dtype=torch.float32)
         st_expected = _shard_tensor(tensor_n, spec)
         self.assertTrue(
             torch.allclose(
                 torch.nn.functional.softmax(_shard_tensor(tensor, spec),
                                             dim=1,
                                             dtype=torch.float32),
                 st_expected,
             ))
Example #14
 def test_sharded_tensor_type_as(self):
     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     for spec in specs:
         st = sharded_tensor.rand(
             spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
         )
         st_2 = sharded_tensor.rand(
             spec, 16, 30, 5, init_rrefs=True, dtype=torch.float
         )
         st_3 = st.type_as(st_2)
         self.assertEqual(torch.float, st_3.dtype)
         self.assertEqual(torch.float, st_3.local_tensor().dtype)
         st_3 = st.type_as(torch.zeros(10).type(torch.BoolTensor).cuda())
         self.assertEqual(torch.bool, st_3.dtype)
         self.assertEqual(torch.bool, st_3.local_tensor().dtype)
Example #15
 def test_sharded_tensor_layer_norm_error(self):
     specs = _chunk_sharding_specs_list_for_test([2], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         with self.assertRaisesRegex(
                 ValueError,
                 "normalized_shape dim must not be greater "
                 "than the dim of the sharded tensor.",
         ):
             layer_norm = torch.nn.LayerNorm(
                 (14, 55, 35, 26)).cuda(self.rank)
             layer_norm(_shard_tensor(tensor, spec))
         with self.assertRaisesRegex(
                 ValueError,
                 r"Given normalized_shape=\[35\], expected input with shape "
                 r"\[\*, 35\], but got input of size \[16, 35, 26\].",
         ):
             layer_norm = torch.nn.LayerNorm((35,)).cuda(self.rank)
             layer_norm(_shard_tensor(tensor, spec))
Example #16
 def test_sharded_tensor_view(self):
     specs = _chunk_sharding_specs_list_for_test([0, 0, -3], seed=10)
     for spec in specs:
         tensor = torch.rand(16, 35, 26).cuda(self.rank)
         tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
         # The view splits the leading 16 into 4 x 4, so a negative sharding dim
         # (here -3) sits one step further from the end in the viewed shape.
         new_spec = copy.deepcopy(spec)
         if new_spec.dim < 0:
             new_spec.dim -= 1
         st_expected = _shard_tensor(tensor_v, new_spec)
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor, spec).view(4, 4, 35, 26),
                 st_expected,
             ))
         st_expected = _shard_tensor(tensor, spec)
         self.assertTrue(
             torch.allclose(
                 _shard_tensor(tensor_v, new_spec).view(16, 35, 26),
                 st_expected,
             ))
Example #17
 def test_sharded_tensor_view_error(self):
     for spec in _chunk_sharding_specs_list_for_test([2], seed=7):
         st = sharded_tensor.rand(
             spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
         )
         with self.assertRaisesRegex(
             NotImplementedError,
             "Shape having dim 2 is not supported "
             "for sharded tensor sharded on dim 2.",
         ):
             st.view(35 * 17, 26)
         with self.assertRaisesRegex(
             ValueError,
             r"Shape '\[5, 7, 35, 17, 26\]' is invalid for sharded tensor size 15470.",
         ):
             st.view(5, 7, 35, 17, 26)
         with self.assertRaisesRegex(
             ValueError,
             "Only one dimension can be inferred for sharded view op.",
         ):
             st.view(5, 7, -1, -1)
Example #18
 def test_partial_tensor_reshard(self):
     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     spec = specs[0]
     self._run_partial_tensor_n_reshard(spec, [13, 21], 4,
                                        dist.ReduceOp.SUM)
     self._run_partial_tensor_n_reshard(spec, [12, 22], 4,
                                        dist.ReduceOp.MAX)
     self._run_partial_tensor_n_reshard(spec, [13, 21], 3,
                                        dist.ReduceOp.SUM)
     self._run_partial_tensor_n_reshard(spec, [17, 21], 2,
                                        dist.ReduceOp.MAX)
     # Repeat the reshard checks on 2-rank subgroups: ranks 0/1 and ranks 2/3
     # each reshard within their own process group.
     sub_pgs = [dist.new_group([0, 1]), dist.new_group([2, 3])]
     pg = sub_pgs[self.rank // 2]
     spec = self._reshard_spec_for_subgroup(self.rank)
     self._run_partial_tensor_n_reshard(spec, [12, 22],
                                        4,
                                        dist.ReduceOp.MAX,
                                        pg=pg)
     self._run_partial_tensor_n_reshard(spec, [13, 22],
                                        3,
                                        dist.ReduceOp.SUM,
                                        pg=pg)
Example #19
 def test_sharded_tensor_masked_fill_error(self):
     specs = _chunk_sharding_specs_list_for_test([1, 2], seed=7)
     for spec in specs:
         st = sharded_tensor.rand(
             spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
         )
         mask = (
             torch.randint(0, 2, (2, 35, 17, 26))
             .type(torch.BoolTensor)
             .cuda(self.rank)
         )
         with self.assertRaisesRegex(
             ValueError,
             "mask dim must not greater than the dim of the sharded tensor.",
         ):
             st.masked_fill(mask, 25.0)
         mask = torch.randint(0, 2, (16, 26)).type(torch.BoolTensor).cuda(self.rank)
         with self.assertRaisesRegex(
             ValueError,
             "The size of mask 0 must match the size of sharded tensor 1 "
             "at non-singleton dimension 0",
         ):
             st.masked_fill(mask, 25.0)
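For context, a minimal sketch of exercising one of these ops outside the test harness follows. Everything in it is an assumption for illustration: a single 4-GPU node launched with torchrun, an NCCL backend, and the same module paths assumed earlier; it is not taken from the test suite itself.

# Minimal standalone sketch (assumed launch: torchrun --nproc_per_node=4).
import torch
import torch.distributed as dist
from torch.distributed._shard import _shard_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


def main():
    dist.init_process_group("nccl")  # rank/world size come from torchrun env vars
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    spec = ChunkShardingSpec(
        dim=0,
        placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
    )
    tensor = torch.rand(16, 12).cuda(rank)
    tensor[:, :2] = float("nan")

    st = _shard_tensor(tensor, spec)   # ShardedTensor with one local shard per rank
    st = torch.nan_to_num(st)          # dispatches to the sharded nan_to_num op
    print(rank, st.local_tensor().shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()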