def test_partial_tensor_reshard_errors(self):
    enumerable_sharding_spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
        ]
    )
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec supported for reshard."
    ):
        self._run_partial_tensor_n_reshard(
            enumerable_sharding_spec, [13, 21], 4, dist.ReduceOp.SUM
        )
        self._run_partial_tensor_n_reshard(
            enumerable_sharding_spec, [12, 22], 4, dist.ReduceOp.MAX
        )
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    spec = specs[0]
    with self.assertRaisesRegex(
        NotImplementedError, "Only real partial tensor supported for reshard."
    ):
        self._run_partial_tensor_n_reshard(
            spec, [13, 21], 4, dist.ReduceOp.SUM, dtype=torch.cfloat
        )
        self._run_partial_tensor_n_reshard(
            spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat
        )

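# The partial-tensor tests in this class drive a _run_partial_tensor_n_reshard
# helper whose definition is outside this section. The sketch below is one
# plausible shape for it, assuming the _PartialTensor wrapper from
# torch.distributed._shard.partial_tensor and a _shard_tensor signature that
# accepts process_group; the real helper may differ, for instance in how it
# uses the world_size argument, which is kept here only for signature
# compatibility.
def _run_partial_tensor_n_reshard(
    self, reshard_spec, input_size, world_size, reduce_op,
    dtype=torch.float, pg=None,
):
    from torch.distributed._shard.partial_tensor import _PartialTensor

    # Each rank contributes a distinct, deterministic partial value so the
    # reduced result is easy to validate.
    local_tensor = torch.full(
        input_size, float(self.rank + 1), dtype=dtype
    ).cuda(self.rank)
    partial = _PartialTensor(local_tensor.clone(), pg, reduce_op)
    st = partial.reshard(reshard_spec)
    # Reference result: reduce the replicated tensors with the same op,
    # then shard the reduced tensor with the target spec.
    expected = local_tensor.clone()
    dist.all_reduce(expected, op=reduce_op, group=pg)
    st_expected = _shard_tensor(expected, reshard_spec, process_group=pg)
    self.assertEqual(st.local_tensor(), st_expected.local_tensor())
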
def test_sharded_tensor_transpose(self):
    specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
    for spec in specs:
        tensor = torch.rand(15, 27, 16).cuda(self.rank)
        tensor_t = tensor.transpose(0, 1).contiguous()
        # transpose(0, 1) swaps sharding dims 0 and 1, hence dim -> 1 - dim.
        spec_n = copy.deepcopy(spec)
        if spec_n.dim in (0, 1):
            spec_n.dim = 1 - spec_n.dim
        st_expected = _shard_tensor(tensor_t, spec_n)
        st_expected._metadata.shards_metadata.sort(
            key=lambda x: x.shard_offsets[0],
        )
        self.assertTrue(
            torch.allclose(
                torch.transpose(_shard_tensor(tensor, spec), 0, 1), st_expected
            )
        )
        tensor_t = torch.transpose(tensor, 1, 2).contiguous()
        # transpose(1, 2) swaps sharding dims 1 and 2, hence dim -> 3 - dim.
        spec_n = copy.deepcopy(spec)
        if spec_n.dim in (1, 2):
            spec_n.dim = 3 - spec_n.dim
        st_expected = _shard_tensor(tensor_t, spec_n)
        st_expected._metadata.shards_metadata.sort(
            key=lambda x: x.shard_offsets[0],
        )
        self.assertTrue(
            torch.allclose(_shard_tensor(tensor, spec).transpose(1, 2), st_expected)
        )

def test_sharded_tensor_reshard_errors(self):
    specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
    spec, reshard_spec = specs[0], specs[1]
    enumerable_sharding_spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
        ]
    )
    st = sharded_tensor.rand(spec, 24, 12)
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec supported for reshard."
    ):
        st.reshard(enumerable_sharding_spec)
    # Duplicate the local shard to trigger the single-shard restriction.
    st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
    with self.assertRaisesRegex(
        NotImplementedError, "Only single local shard supported for reshard."
    ):
        st.reshard(reshard_spec)

def test_sharded_tensor_layer_norm(self):
    specs = _chunk_sharding_specs_list_for_test([1, 2], seed=10)
    flags = [True, False]
    for spec, flag in itertools.product(specs, flags):
        tensor = torch.rand(16, 35, 26).cuda(self.rank)
        layer_norm = torch.nn.LayerNorm((35, 26), elementwise_affine=flag).cuda(
            self.rank
        )
        st = layer_norm(_shard_tensor(tensor, spec))
        with torch.no_grad():
            tensor_normed = layer_norm(tensor)
        st_expected = _shard_tensor(tensor_normed, spec)
        self.assertEqual(
            st.local_tensor(),
            st_expected.local_tensor(),
        )
        self.assertTrue(
            torch.allclose(
                st,
                st_expected,
                atol=1e-6,
            )
        )
        # The functional form should agree with the module form.
        st_expected = torch.nn.functional.layer_norm(
            _shard_tensor(tensor, spec),
            (35, 26),
            weight=layer_norm.weight,
            bias=layer_norm.bias,
        )
        self.assertTrue(
            torch.allclose(
                st,
                st_expected,
                atol=1e-6,
            )
        )

def test_sharded_tensor_contiguous(self):
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    for spec in specs:
        st = sharded_tensor.rand(spec, 10, 22, 5, init_rrefs=True)
        st = st.transpose(1, 0)
        st = st.contiguous()
        self.assertTrue(st.is_contiguous())
        self.assertTrue(st.local_tensor().is_contiguous())

def test_infer_sharding_spec_from_shards_metadata(self):
    self._infer_enum_sharding_spec_case()
    chunk_specs = _chunk_sharding_specs_list_for_test([0, 0, 1, 1], seed=31)
    for spec in chunk_specs:
        self._infer_chunk_sharding_spec_case(spec.placements, 0, [4, 16])
        self._infer_chunk_sharding_spec_case(spec.placements, 0, [5, 15, 16])
        self._infer_chunk_sharding_spec_case(spec.placements, 1, [12, 16])
        self._infer_chunk_sharding_spec_case(spec.placements, 2, [4, 18, 15])
        self._infer_chunk_sharding_spec_case(spec.placements, 3, [7, 12, 16, 37])
        self._infer_chunk_sharding_spec_case(spec.placements, 4, [50, 4, 18, 15, 77])

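# Sketch of the _infer_chunk_sharding_spec_case helper exercised above; the
# real helper is defined elsewhere in this class. It hand-builds the shard
# metadata that chunking st_size along sharding_dim across the given
# placements would produce, then checks that inference recovers an equivalent
# ChunkShardingSpec. The import paths below are assumptions.
def _infer_chunk_sharding_spec_case(self, placements, sharding_dim, st_size):
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec
    from torch.distributed._shard.sharding_spec.api import (
        _infer_sharding_spec_from_shards_metadata,
    )
    from torch.distributed._shard.sharding_spec._internals import (
        get_chunked_dim_size,
        get_split_size,
    )

    world_size = len(placements)
    split_size = get_split_size(st_size[sharding_dim], world_size)
    shards_metadata = [None] * world_size
    for idx, placement in enumerate(placements):
        shard_size = copy.deepcopy(st_size)
        offsets = [0] * len(st_size)
        offsets[sharding_dim] = split_size * idx
        shard_size[sharding_dim] = get_chunked_dim_size(
            st_size[sharding_dim], split_size, idx
        )
        shards_metadata[placement.rank()] = ShardMetadata(
            shard_offsets=offsets,
            shard_sizes=shard_size,
            placement=placement,
        )
    spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
    self.assertTrue(isinstance(spec, ChunkShardingSpec))
    self.assertEqual(spec.dim, sharding_dim)
    self.assertEqual(spec.placements, placements)
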
def test_sharded_tensor_reshard(self):
    dims = [0, 1]
    for sharding_dim, reshard_dim in product(dims, dims):
        specs = _chunk_sharding_specs_list_for_test(
            [sharding_dim, reshard_dim], seed=5
        )
        spec, reshard_spec = specs[0], specs[1]
        self._run_sharded_tensor_reshard(spec, reshard_spec, [13, 21])
        self._run_sharded_tensor_reshard(spec, reshard_spec, [14, 23])
        self._run_sharded_tensor_reshard(spec, reshard_spec, [15, 26])
        self._run_sharded_tensor_reshard(spec, reshard_spec, [12, 24])

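# A minimal sketch of the _run_sharded_tensor_reshard helper driven above,
# assuming ShardedTensor.reshard returns the resharded tensor (the error
# tests in this section call it for its side effects). Seeding makes every
# rank generate the same source tensor, so the scatter inside _shard_tensor
# is consistent across ranks. The real helper may perform additional
# metadata checks.
def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec, input_size):
    torch.manual_seed(0)
    tensor = torch.rand(*input_size).cuda(self.rank)
    st = _shard_tensor(tensor, sharding_spec)
    st_compare = _shard_tensor(tensor, reshard_spec)
    st = st.reshard(reshard_spec)
    self.assertEqual(st.local_tensor(), st_compare.local_tensor())
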
def test_sharded_tensor_nan_to_num(self):
    specs = _chunk_sharding_specs_list_for_test([0, 1], seed=10)
    for spec in specs:
        tensor = torch.rand(16, 12).cuda(self.rank)
        tensor[:, :2] = float("nan")
        tensor[:, 4:5] = float("inf")
        tensor[:, 10:] = -float("inf")
        st = _shard_tensor(tensor, spec)
        st_expected = _shard_tensor(torch.nan_to_num(tensor), spec)
        st = torch.nan_to_num(st)
        self.assertTrue(torch.allclose(st, st_expected))

def test_sharded_tensor_softmax_error(self):
    specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
    for spec in specs:
        st = sharded_tensor.rand(
            spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
        )
        with self.assertRaisesRegex(
            NotImplementedError,
            "Only support performing softmax on non-sharding dim now.",
        ):
            torch.nn.functional.softmax(
                st, dim=st.sharding_spec().dim, dtype=torch.float32
            )

def _test_masked_fill_with_sizes(self, mask_size, broadcast_style=False):
    specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
    for spec in specs:
        tensor = torch.rand(35, 17, 26).cuda(self.rank)
        mask = (
            torch.randint(0, 2, mask_size).type(torch.BoolTensor).cuda(self.rank)
        )
        if broadcast_style:
            mask = mask.unsqueeze(1)
        tensor_m = tensor.masked_fill(mask, 25.0)
        st_expected = _shard_tensor(tensor_m, spec)
        self.assertTrue(
            torch.allclose(
                _shard_tensor(tensor, spec).masked_fill(mask, 25.0),
                st_expected,
            )
        )

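# A hypothetical driver for the helper above (the name and the exact mask
# sizes are illustrative, not taken from the suite): it exercises full-size,
# trailing-dim, and broadcast-style masks against the (35, 17, 26) input.
def test_sharded_tensor_masked_fill(self):
    self._test_masked_fill_with_sizes((35, 17, 26))
    self._test_masked_fill_with_sizes((17, 26))
    self._test_masked_fill_with_sizes((26,))
    # (35, 26) is unsqueezed to (35, 1, 26) inside the helper and then
    # broadcasts against (35, 17, 26).
    self._test_masked_fill_with_sizes((35, 26), broadcast_style=True)
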
def test_sharded_tensor_softmax(self):
    specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
    for spec in specs:
        tensor = torch.rand(15, 27, 16).cuda(self.rank)
        tensor_n = torch.nn.functional.softmax(tensor, dim=1, dtype=torch.float32)
        st_expected = _shard_tensor(tensor_n, spec)
        self.assertTrue(
            torch.allclose(
                torch.nn.functional.softmax(
                    _shard_tensor(tensor, spec), dim=1, dtype=torch.float32
                ),
                st_expected,
            )
        )

def test_sharded_tensor_type_as(self):
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    for spec in specs:
        st = sharded_tensor.rand(
            spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
        )
        st_2 = sharded_tensor.rand(
            spec, 16, 30, 5, init_rrefs=True, dtype=torch.float
        )
        st_3 = st.type_as(st_2)
        self.assertEqual(torch.float, st_3.dtype)
        self.assertEqual(torch.float, st_3.local_tensor().dtype)
        # type_as also accepts a plain (non-sharded) tensor as the target.
        st_3 = st.type_as(torch.zeros(10).type(torch.BoolTensor).cuda())
        self.assertEqual(torch.bool, st_3.dtype)
        self.assertEqual(torch.bool, st_3.local_tensor().dtype)

def test_sharded_tensor_layer_norm_error(self):
    specs = _chunk_sharding_specs_list_for_test([2], seed=10)
    for spec in specs:
        tensor = torch.rand(16, 35, 26).cuda(self.rank)
        with self.assertRaisesRegex(
            ValueError,
            "normalized_shape dim must not be greater "
            "than the dim of the sharded tensor.",
        ):
            layer_norm = torch.nn.LayerNorm((14, 55, 35, 26)).cuda(self.rank)
            layer_norm(_shard_tensor(tensor, spec))
        with self.assertRaisesRegex(
            ValueError,
            r"Given normalized_shape=\[35\], expected input with shape "
            r"\[\*, 35\], but got input of size \[16, 35, 26\].",
        ):
            layer_norm = torch.nn.LayerNorm((35)).cuda(self.rank)
            layer_norm(_shard_tensor(tensor, spec))

def test_sharded_tensor_view(self):
    specs = _chunk_sharding_specs_list_for_test([0, 0, -3], seed=10)
    for spec in specs:
        tensor = torch.rand(16, 35, 26).cuda(self.rank)
        tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
        # Viewing 16 as 4 x 4 adds a leading dim, so a negative sharding
        # dim (e.g. -3 for a 3-d tensor) shifts one place further left.
        new_spec = copy.deepcopy(spec)
        if new_spec.dim < 0:
            new_spec.dim -= 1
        st_expected = _shard_tensor(tensor_v, new_spec)
        self.assertTrue(
            torch.allclose(
                _shard_tensor(tensor, spec).view(4, 4, 35, 26),
                st_expected,
            )
        )
        st_expected = _shard_tensor(tensor, spec)
        self.assertTrue(
            torch.allclose(
                _shard_tensor(tensor_v, new_spec).view(16, 35, 26),
                st_expected,
            )
        )

def test_sharded_tensor_view_error(self):
    for spec in _chunk_sharding_specs_list_for_test([2], seed=7):
        st = sharded_tensor.rand(
            spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
        )
        with self.assertRaisesRegex(
            NotImplementedError,
            "Shape having dim 2 is not supported "
            "for sharded tensor sharded on dim 2.",
        ):
            st.view(35 * 17, 26)
        with self.assertRaisesRegex(
            ValueError,
            r"Shape '\[5, 7, 35, 17, 26\]' is invalid for sharded tensor size 15470.",
        ):
            st.view(5, 7, 35, 17, 26)
        with self.assertRaisesRegex(
            ValueError,
            "Only one dimension can be inferred for sharded view op.",
        ):
            st.view(5, 7, -1, -1)

def test_partial_tensor_reshard(self):
    specs = _chunk_sharding_specs_list_for_test([0], seed=7)
    spec = specs[0]
    self._run_partial_tensor_n_reshard(spec, [13, 21], 4, dist.ReduceOp.SUM)
    self._run_partial_tensor_n_reshard(spec, [12, 22], 4, dist.ReduceOp.MAX)
    self._run_partial_tensor_n_reshard(spec, [13, 21], 3, dist.ReduceOp.SUM)
    self._run_partial_tensor_n_reshard(spec, [17, 21], 2, dist.ReduceOp.MAX)
    # Repeat the reshard on two-rank subgroups: ranks {0, 1} and {2, 3}.
    # dist.new_group must be called by all ranks for every group created.
    sub_pgs = [dist.new_group([0, 1]), dist.new_group([2, 3])]
    pg = sub_pgs[self.rank // 2]
    spec = self._reshard_spec_for_subgroup(self.rank)
    self._run_partial_tensor_n_reshard(spec, [12, 22], 4, dist.ReduceOp.MAX, pg=pg)
    self._run_partial_tensor_n_reshard(spec, [13, 22], 3, dist.ReduceOp.SUM, pg=pg)

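# _reshard_spec_for_subgroup is assumed to build a ChunkShardingSpec whose
# placements cover only the caller's two-rank subgroup. A sketch, assuming
# placement ranks are relative to the subgroup while cuda devices stay
# global; that detail is an assumption, not confirmed by this section.
def _reshard_spec_for_subgroup(self, rank):
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    base = rank // 2 * 2  # first global rank/device of this subgroup
    return ChunkShardingSpec(
        dim=0,
        placements=[
            f"rank:0/cuda:{base}",
            f"rank:1/cuda:{base + 1}",
        ],
    )
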
def test_sharded_tensor_masked_fill_error(self):
    specs = _chunk_sharding_specs_list_for_test([1, 2], seed=7)
    for spec in specs:
        st = sharded_tensor.rand(
            spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
        )
        mask = (
            torch.randint(0, 2, (2, 35, 17, 26))
            .type(torch.BoolTensor)
            .cuda(self.rank)
        )
        with self.assertRaisesRegex(
            ValueError,
            "mask dim must not greater than the dim of the sharded tensor.",
        ):
            st.masked_fill(mask, 25.0)
        mask = torch.randint(0, 2, (16, 26)).type(torch.BoolTensor).cuda(self.rank)
        with self.assertRaisesRegex(
            ValueError,
            "The size of mask 0 must match the size of sharded tensor 1 "
            "at non-singleton dimension 0",
        ):
            st.masked_fill(mask, 25.0)