    def test_sharded_dropout(self):
        def _reset_random_seed():
            torch.manual_seed(self.rank + 4)

        specs = generate_chunk_sharding_specs_for_test(
            0
        ) + generate_chunk_sharding_specs_for_test(1)
        for spec in specs:
            self._run_sharded_elementwise_ops(
                spec,
                [12, 17],
                torch.nn.functional.dropout,
                p=0.4,
                reset_seed=_reset_random_seed,
            )
            self._run_sharded_elementwise_ops(
                spec,
                [18, 21],
                torch.nn.functional.dropout,
                p=0.5,
                reset_seed=_reset_random_seed,
            )
            _reset_random_seed()
            dropout = torch.nn.Dropout(p=0.8)
            self._run_sharded_elementwise_ops(
                spec, [17, 23], dropout, reset_seed=_reset_random_seed
            )
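
For context, these tests only consume whatever generate_chunk_sharding_specs_for_test returns. A minimal sketch of the kind of ChunkShardingSpec such a helper is assumed to produce is shown below; the 4-GPU placement strings are illustrative and mirror the ones used in the linear-errors example further down.

# Sketch of an assumed spec from generate_chunk_sharding_specs_for_test(0);
# the placements presume a hypothetical 4-GPU run.
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

colwise_spec = ChunkShardingSpec(
    dim=0,  # chunk along dim 0; dim=1 would give the row-wise variant
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)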
Example #2
 def test_megatron_two_layer_prototype(self):
     colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
     rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
     for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
         self._run_megatron_linear(spec, [22, 17], [[17, 12], [12, 29]])
         self._run_megatron_linear(spec, [28, 21], [[21, 11], [11, 29]])
         self._run_megatron_linear(spec, [37, 23], [[23, 13], [13, 24]])
         self._run_megatron_linear(spec, [24, 15], [[15, 14], [14, 20]])
Example #3
 def test_sharded_relu(self):
     specs = generate_chunk_sharding_specs_for_test(
         0
     ) + generate_chunk_sharding_specs_for_test(1)
     for spec in specs:
         self._run_sharded_elementwise_ops(spec, [12, 17], torch.nn.functional.relu)
         self._run_sharded_elementwise_ops(spec, [18, 21], torch.nn.functional.relu)
         self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.relu)
         self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.relu)
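
_run_sharded_elementwise_ops is defined on the test base class and not shown in these snippets. A rough sketch of the comparison it presumably performs follows; the body, the seeding, and the final shard-by-shard check are assumptions, and _shard_tensor is the same utility used in the bmm examples below.

    # Hypothetical sketch of the helper used by the dropout/relu tests above.
    def _run_sharded_elementwise_ops(self, spec, input_size, op, reset_seed=None, **kwargs):
        # Build a local tensor and a sharded copy of it.
        local_tensor = torch.rand(*input_size).cuda(self.rank)
        st = _shard_tensor(local_tensor, spec)
        # Reset the RNG (if requested) so dropout masks are reproducible, then
        # apply the op to both the sharded and the local tensor.
        if reset_seed is not None:
            reset_seed()
        sharded_out = op(st, **kwargs)
        if reset_seed is not None:
            reset_seed()
        local_out = op(local_tensor, **kwargs)
        # The real helper compares this rank's shard of sharded_out against the
        # matching slice of local_out; that bookkeeping is elided in this sketch.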
Example #4
    def test_reshard_to_ddp_sharding_plan(self):
        colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)[0]
        rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]

        # build the output resharding spec (row-wise spec re-chunked along dim 0)
        output_spec = copy.deepcopy(rowwise_sharding_spec)
        output_spec.placements.sort(key=lambda placement: placement.rank())
        output_spec.dim = 0

        # new module with megatron as submodule
        class MyModule(nn.Module):
            def __init__(self, rank=None):
                super().__init__()
                self.megatron = SimpleMegatronLM([[17, 12], [12, 29]],
                                                 rank=rank)
                self.relu = nn.ReLU()

            def forward(self, input):
                return self.relu(self.megatron(input))

        sharding_plan = ShardingPlan(plan={
            "megatron.fc1.weight":
            colwise_sharding_spec,
            "megatron.fc2.weight":
            rowwise_sharding_spec,
        },
                                     output_plan={"megatron": output_spec},
                                     return_local_tensor=["megatron"])

        # Use same seed.
        torch.manual_seed(0)
        local_module = MyModule().cuda(self.rank)
        sharded_module = copy.deepcopy(local_module)

        # shard the module with the provided sharding plan
        shard_module(sharded_module, sharding_plan)

        # check to make sure the module has been sharded
        self.assertTrue(
            isinstance(sharded_module.megatron.fc1.weight, ShardedTensor))
        self.assertTrue(
            isinstance(sharded_module.megatron.fc2.weight, ShardedTensor))
        self.assertEqual(sharded_module.megatron.fc1.weight.sharding_spec(),
                         colwise_sharding_spec)
        self.assertEqual(sharded_module.megatron.fc2.weight.sharding_spec(),
                         rowwise_sharding_spec)

        # make sure we can run sharded computation
        input = torch.rand(22, 17).cuda(self.rank)
        sharded_output = sharded_module(input)
        local_output = local_module(input)

        # verify that the local and sharded outputs match
        self.assertEqual(local_output, sharded_output)
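
SimpleMegatronLM is a shared test fixture that is not reproduced in these examples. A hypothetical stand-in that matches the surface the tests touch (fc1, fc2, the get_* accessors, and the [[in1, out1], [in2, out2]] size argument) could look like the following; the GELU activation and the rank handling are assumptions.

import torch
import torch.nn as nn


class SimpleMegatronLMSketch(nn.Module):
    # Hypothetical stand-in for SimpleMegatronLM: only the attribute and
    # accessor names are taken from the tests; the rest is assumed.
    def __init__(self, linear_sizes, rank=None):
        super().__init__()
        (in1, out1), (in2, out2) = linear_sizes
        self.fc1 = nn.Linear(in1, out1)
        self.fc2 = nn.Linear(in2, out2)
        if rank is not None:
            self.cuda(rank)

    def forward(self, x):
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

    def get_weights(self):
        return self.fc1.weight, self.fc2.weight

    def get_biases(self):
        return self.fc1.bias, self.fc2.bias

    def get_weight_grads(self):
        return self.fc1.weight.grad, self.fc2.weight.grad

    def get_bias_grads(self):
        return self.fc1.bias.grad, self.fc2.bias.grad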
Example #5
 def test_sharded_embedding_colwise(self):
     for spec in generate_chunk_sharding_specs_for_test(1):
         self._run_sharded_embedding(spec, [5, 4], 17, 12)
         self._run_sharded_embedding(spec, [6, 7, 6], 21, 11)
         self._run_sharded_embedding(spec, [8, 6, 5, 4], 23, 13)
         self._run_sharded_embedding(spec, [8, 6, 5, 4, 7], 23, 16)
         self._run_sharded_embedding(spec, [4], 15, 14)
         self._run_sharded_embedding(spec, [34], 15, 14, padding_idx=10)
         self._run_sharded_embedding(spec, [8, 6, 5, 4],
                                     23,
                                     13,
                                     padding_idx=12)
         self._run_sharded_embedding(
             spec,
             [4, 5, 6],
             23,
             13,
             max_norm=2.5,
         )
         self._run_sharded_embedding(
             spec,
             [12, 7, 16],
             23,
             13,
             max_norm=2.5,
         )
         self._run_sharded_embedding(
             spec,
             [8, 16, 20],
             12,
             12,
             max_norm=1.25,
             norm_type=1.0,
         )
         self._run_sharded_embedding(spec, [30], 15, 14, max_norm=2.0)
Example #6
 def test_sharded_bmm_errors(self):
     specs = generate_chunk_sharding_specs_for_test(0)
     st_lhs = sharded_tensor.rand(specs[0], (15, 5, 6))
     st_rhs = sharded_tensor.rand(specs[1], (15, 5, 6))
     with self.assertRaisesRegex(
             NotImplementedError,
             'Both st and st2 need to have same placements for bmm',
     ):
         torch.bmm(st_lhs, st_rhs)
     for spec in specs:
         st_lhs = sharded_tensor.rand(spec, (20, 3))
         st_rhs = sharded_tensor.rand(spec, (20, 3))
         with self.assertRaisesRegex(
                 TypeError,
                 'both st and st2 need to be a 3D ShardedTensor',
         ):
             torch.bmm(st_lhs, st_rhs)
         rhs = torch.rand(15, 5, 6).cuda(self.rank)
         with self.assertRaisesRegex(
                 TypeError,
                 'st2 needs to be a ShardedTensor for torch.bmm',
         ):
             torch.bmm(st_lhs, rhs)
         spec.dim = 1
         st_lhs = sharded_tensor.rand(spec, (15, 5, 6))
         st_rhs = sharded_tensor.rand(spec, (15, 5, 6))
         with self.assertRaisesRegex(
                 NotImplementedError,
                 'Only support performing bmm on tensors sharded on dim 0 now',
         ):
             torch.bmm(st_lhs, st_rhs)
Example #7
 def test_sharded_bmm(self):
     for spec in generate_chunk_sharding_specs_for_test(0):
         lhs = torch.rand(15, 4, 5).cuda(self.rank)
         rhs = torch.rand(15, 5, 6).cuda(self.rank)
         tensor = lhs.bmm(rhs)
         st_lhs = _shard_tensor(lhs, spec)
         st_rhs = _shard_tensor(rhs, spec)
         st_expected = _shard_tensor(tensor, spec)
         self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
         self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected))
Example #8
    def test_sharding_plan_errors(self):
        rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]
        sharding_plan_wrong_plan = ShardingPlan(
            plan={
                "fc1.weight": torch.randn(3, 4),
            },
            output_plan={
                "": rowwise_sharding_spec
            },
        )

        megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)

        with self.assertRaisesRegex(
            TypeError, "Only `ShardingSpec` is supported to shard"
        ):
            # shard the module with the provided sharding plan
            shard_module(megatron_lm, sharding_plan_wrong_plan)

        sharding_plan_wrong_output_plan = ShardingPlan(
            plan={
                "fc1.weight": rowwise_sharding_spec,
            },
            output_plan={
                "": torch.randn(3, 4)
            },
        )

        with self.assertRaisesRegex(
            TypeError, "Only `ShardingSpec` is supported as output_plan"
        ):
            # shard the module with the provided sharding plan
            shard_module(megatron_lm, sharding_plan_wrong_output_plan)

        sharding_plan_wrong_module_path = ShardingPlan(
            plan={
                "fc3.weight": rowwise_sharding_spec,
            },
        )
        with self.assertRaisesRegex(
            AttributeError, "has no attribute"
        ):
            # shard the module with the provided sharding plan
            shard_module(megatron_lm, sharding_plan_wrong_module_path)

        sharding_plan_wrong_param_path = ShardingPlan(
            plan={
                "fc1.biass": rowwise_sharding_spec,
            },
        )
        with self.assertRaisesRegex(
            AttributeError, "has no attribute"
        ):
            # shard the module with the provided sharding plan
            shard_module(megatron_lm, sharding_plan_wrong_param_path)
Example #9
    def test_megatron_two_layer_prototype(self):
        colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
        rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
        for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
            self._run_megatron_linear(spec, [22, 17], [[17, 12], [12, 29]], torch.float16)
            self._run_megatron_linear(spec, [28, 21], [[21, 11], [11, 29]], torch.float32)
            self._run_megatron_linear(spec, [37, 23], [[23, 13], [13, 24]], torch.float64)
            self._run_megatron_linear(spec, [24, 15], [[15, 14], [14, 20]], torch.float16)

            # Test multiple input dims
            self._run_megatron_linear(spec, [10, 22, 17], [[17, 12], [12, 29]], torch.float32)
            self._run_megatron_linear(spec, [13, 28, 21], [[21, 11], [11, 29]], torch.float16)
            self._run_megatron_linear(spec, [27, 37, 23], [[23, 13], [13, 24]], torch.float32)
            self._run_megatron_linear(spec, [100, 24, 15], [[15, 14], [14, 20]], torch.float64)

            # Test single input dim
            self._run_megatron_linear(spec, [17], [[17, 12], [12, 29]], torch.float16)
            self._run_megatron_linear(spec, [21], [[21, 11], [11, 29]], torch.float32)
            self._run_megatron_linear(spec, [23], [[23, 13], [13, 24]], torch.float64)
            self._run_megatron_linear(spec, [15], [[15, 14], [14, 20]], torch.float16)
Example #10
 def test_sharded_chunk_error(self):
     chunk_spec = generate_chunk_sharding_specs_for_test(-1)
     with self.assertRaisesRegex(NotImplementedError,
                                 "Chunk by sharding dim is not supported."):
         st = sharded_tensor.rand(chunk_spec[0], [17, 24])
         torch.chunk(st, 5, dim=-1)
     enumerable_spec = generate_enumerable_sharding_specs_for_test()
     with self.assertRaisesRegex(
             NotImplementedError,
             "Only ChunkShardingSpec is supported for chunk."):
         st = sharded_tensor.rand(enumerable_spec[0], [10, 10])
         torch.chunk(st, 5, dim=-1)
Example #11
 def test_sharded_chunk(self):
     sharding_dims = [0]
     specs = []
     for dim in sharding_dims:
         specs.extend(generate_chunk_sharding_specs_for_test(dim))
     for spec in specs:
         self._run_sharded_chunk_test([17, 14], spec, 3)
         self._run_sharded_chunk_test([17, 15, 20], spec, 5)
         self._run_sharded_chunk_test([17, 16], spec, 2)
         # Large matrix case.
         self._run_sharded_chunk_test([128, 512], spec, 8)
         self._run_sharded_chunk_test([1024, 2048], spec, 4)
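
Like the other helpers, _run_sharded_chunk_test is not shown here. A minimal sketch of what it is assumed to do: chunk both the sharded and the local tensor along the last (non-sharded) dim and compare the pieces; only a size check is spelled out below.

    # Hypothetical sketch of the chunk-test helper.
    def _run_sharded_chunk_test(self, sizes, spec, num_chunks):
        local_tensor = torch.rand(*sizes).cuda(self.rank)
        st = _shard_tensor(local_tensor, spec)
        # Chunk along the last dim; chunking along the sharding dim itself is
        # rejected, as the error test above demonstrates.
        sharded_chunks = torch.chunk(st, num_chunks, dim=-1)
        local_chunks = torch.chunk(local_tensor, num_chunks, dim=-1)
        self.assertEqual(len(sharded_chunks), len(local_chunks))
        for st_chunk, local_chunk in zip(sharded_chunks, local_chunks):
            # Each sharded piece should line up with the local chunk; only the
            # overall sizes are compared in this sketch.
            self.assertEqual(st_chunk.size(), local_chunk.size())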
Example #12
 def test_sharded_bmm(self):
     for spec in generate_chunk_sharding_specs_for_test(0):
         lhs = torch.rand(15, 4, 5).cuda(self.rank)
         rhs = torch.rand(15, 5, 6).cuda(self.rank)
         tensor = lhs.bmm(rhs)
         st_lhs = _shard_tensor(lhs, spec)
         st_rhs = _shard_tensor(rhs, spec)
         st_expected = _shard_tensor(tensor, spec)
         st_expected._metadata.shards_metadata.sort(key=lambda x: x.shard_offsets[0])
         self.assertTrue(
             torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
         self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected))
Example #13
    def test_sharded_linear_rowwise(self):
        for spec in generate_chunk_sharding_specs_for_test(1):
            # Test even split.
            self._run_sharded_linear(spec, [8, 16], [16, 11], 1)

            # Test uneven split.
            self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
            self._run_sharded_linear(spec, [10, 21], [21, 11], 1)

            # Test multiple input dims
            self._run_sharded_linear(spec, [13, 8, 16], [16, 11], 1)
            self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1)
            self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1)

            # Test single input dim
            self._run_sharded_linear(spec, [16], [16, 11], 1)
            self._run_sharded_linear(spec, [19], [19, 11], 1)
            self._run_sharded_linear(spec, [21], [21, 11], 1)
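
_run_sharded_linear is also part of the base class. A condensed sketch of the pattern it presumably follows is below, reusing shard_parameter from the linear-errors example further down; the seeding and the exact output comparison are assumptions.

    # Hypothetical sketch of the linear-test helper.
    def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
        # Build identical local and sharded linear layers.
        torch.manual_seed(0)
        local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)
        sharded_linear = copy.deepcopy(local_linear)

        # spec is expected to chunk dim `sharded_dim` of the weight:
        # dim 0 for the column-wise tests, dim 1 for the row-wise ones.
        shard_parameter(sharded_linear, "weight", spec)

        inp = torch.rand(*input_size).cuda(self.rank)
        # Both layers should produce the same result for the same input.
        self.assertEqual(local_linear(inp), sharded_linear(inp))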
Example #14
    def test_sharded_linear_colwise(self):
        for spec in generate_chunk_sharding_specs_for_test(0):
            self._run_sharded_linear(spec, [2, 17], [17, 12], 0)
            self._run_sharded_linear(spec, [8, 21], [21, 11], 0)
            self._run_sharded_linear(spec, [7, 23], [23, 13], 0)
            self._run_sharded_linear(spec, [4, 15], [15, 14], 0)

            # Test multiple input dims
            self._run_sharded_linear(spec, [10, 2, 17], [17, 12], 0)
            self._run_sharded_linear(spec, [13, 8, 21], [21, 11], 0)
            self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0)
            self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0)

            # Test single input dim
            self._run_sharded_linear(spec, [17], [17, 12], 0)
            self._run_sharded_linear(spec, [21], [21, 11], 0)
            self._run_sharded_linear(spec, [23], [23, 13], 0)
            self._run_sharded_linear(spec, [15], [15, 14], 0)
Example #15
    def test_sharded_embedding_rowwise(self):
        for spec in generate_chunk_sharding_specs_for_test(0):
            # Test even split.
            self._run_sharded_embedding(spec, [5, 12], 16, 22)
            self._run_sharded_embedding(spec, [5, 4], 32, 12)
            self._run_sharded_embedding(spec, [6, 7, 6], 64, 11)
            self._run_sharded_embedding(
                spec,
                [5, 12],
                16,
                22,
                max_norm=2.5,
            )
            self._run_sharded_embedding(spec, [6, 7, 6],
                                        64,
                                        11,
                                        padding_idx=30)
            self._run_sharded_embedding(
                spec,
                [6, 5, 3],
                26,
                11,
                max_norm=2.0,
            )

            # Test uneven split.
            self._run_sharded_embedding(spec, [8, 6, 5, 4], 19, 11)
            self._run_sharded_embedding(spec, [6, 7, 6], 21, 11)
            self._run_sharded_embedding(spec, [4], 21, 11)
            self._run_sharded_embedding(spec, [8, 6, 5, 4],
                                        21,
                                        11,
                                        padding_idx=10)
            self._run_sharded_embedding(
                spec,
                [6, 5, 8],
                28,
                5,
                max_norm=2.0,
            )
            self._run_sharded_embedding(spec, [4], 14, 11, max_norm=2.5)
Example #16
    def test_sharding_plan_simple_megatron(self):
        colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
        rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
        for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
            # test each sharding spec pair and see if we can apply sharding
            reshard_spec = copy.deepcopy(spec[1])
            reshard_spec.placements.sort(key=lambda placement: placement.rank())
            reshard_spec.dim = 0

            sharding_plan = ShardingPlan(
                plan={
                    "fc1.weight": spec[0],
                    "fc2.weight": spec[1]
                },
                output_plan={
                    "": reshard_spec
                },
                return_local_tensor=[""])

            # Use same seed.
            torch.manual_seed(0)
            local_megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)
            megatron_lm = copy.deepcopy(local_megatron_lm)

            # shard the module with the provided sharding plan
            shard_module(megatron_lm, sharding_plan)

            # check to make sure the module has been sharded
            self.assertTrue(isinstance(megatron_lm.fc1.weight, ShardedTensor))
            self.assertTrue(isinstance(megatron_lm.fc2.weight, ShardedTensor))
            self.assertEqual(megatron_lm.fc1.weight.sharding_spec(), spec[0])
            self.assertEqual(megatron_lm.fc2.weight.sharding_spec(), spec[1])

            # make sure we can run sharded computation
            input = torch.rand(22, 17).cuda(self.rank)
            sharded_output = megatron_lm(input)
            local_output = local_megatron_lm(input)

            # verify that the local and sharded outputs match
            self.assertEqual(local_output, sharded_output)

            # Compute loss and run backward pass.
            local_output.sum().backward()
            sharded_output.sum().backward()
            (
                local_weight_grad_fc1,
                local_weight_grad_fc2,
            ) = local_megatron_lm.get_weight_grads()
            local_bias_grad_fc1, local_bias_grad_fc2 = local_megatron_lm.get_bias_grads()

            # Verify that the weights in both layers and the biases in the sharded linear have non-None grads.
            (
                sharded_weight_fc1,
                sharded_weight_fc2,
            ) = megatron_lm.get_weights()
            bias_grad_fc1, bias_grad_fc2 = megatron_lm.get_bias_grads()
            self.assertNotEqual(sharded_weight_fc1.grad, None)
            self.assertNotEqual(sharded_weight_fc2.grad, None)
            self.assertNotEqual(bias_grad_fc1, None)
            self.assertNotEqual(bias_grad_fc2, None)

            # All-reduce the local grads, then narrow the local weight grads to this rank's shard so we can compare.
            dist.all_reduce(local_weight_grad_fc1)
            dist.all_reduce(local_weight_grad_fc2)
            dist.all_reduce(local_bias_grad_fc1)
            dist.all_reduce(local_bias_grad_fc2)
            local_weight_fc1, local_weight_fc2 = local_megatron_lm.get_weights()
            (
                start_pos_fc1,
                chunk_size_fc1,
            ) = generate_local_weight_sharding_params_for_test(
                local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
            )
            local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
                0, start_pos_fc1, chunk_size_fc1
            )
            (
                start_pos_fc2,
                chunk_size_fc2,
            ) = generate_local_weight_sharding_params_for_test(
                local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
            )
            local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
                1, start_pos_fc2, chunk_size_fc2
            )

            # Test backward gradient calculation.
            self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
            self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
            self.assertEqual(bias_grad_fc1, local_bias_grad_fc1)
            self.assertEqual(bias_grad_fc2, local_bias_grad_fc2)

            # Test optimizer.
            bias_fc1, bias_fc2 = megatron_lm.get_biases()
            local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
            self.assertEqual(bias_fc1, local_bias_fc1)
            self.assertEqual(bias_fc2, local_bias_fc2)
            self.assertEqual(bias_fc1.grad, local_bias_fc1.grad)
            self.assertEqual(bias_fc2.grad, local_bias_fc2.grad)
            previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
            previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
            previous_bias_fc1 = bias_fc1.clone()
            previous_bias_fc2 = bias_fc2.clone()
            optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
            optim.step()
            sharded_optim = ShardedOptimizer(
                dict(named_params_with_sharded_tensor(megatron_lm)),
                torch.optim.SGD,
                lr=0.1,
            )
            sharded_optim.step()
            local_weight_fc1_narrowed = local_weight_fc1.narrow(
                0, start_pos_fc1, chunk_size_fc1
            )
            local_weight_fc2_narrowed = local_weight_fc2.narrow(
                1, start_pos_fc2, chunk_size_fc2
            )

            # Test weight value after optimizer.
            self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
            self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
            self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
            self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
            self.assertEqual(sharded_weight_fc1, local_weight_fc1_narrowed)
            self.assertEqual(sharded_weight_fc2, local_weight_fc2_narrowed)

            # Test bias value after optimizer.
            local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
            self.assertNotEqual(previous_bias_fc1, bias_fc1)
            self.assertEqual(bias_fc1, local_bias_fc1)
            self.assertNotEqual(previous_bias_fc2, bias_fc2)
            self.assertEqual(bias_fc2, local_bias_fc2)
Example #17
    def test_sharded_linear_errors(self):
        for spec in generate_chunk_sharding_specs_for_test(0):
            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc1, "bias", spec)
            with self.assertRaisesRegex(TypeError,
                                        'bias needs to be torch.Tensor'):
                fc1(torch.rand(10, 10).cuda(self.rank))

            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc2, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Input needs to have at least 1 dim'):
                fc2(torch.tensor(1).cuda(self.rank))

            fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc3.weight = torch.nn.Parameter(
                torch.rand(10, 10, 10).cuda(self.rank))
            shard_parameter(fc3, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Weight needs to have exactly 2 dims'):
                fc3(torch.rand(10, 10).cuda(self.rank))

            fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
            shard_parameter(fc4, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Bias needs to have exactly 1 dim'):
                fc4(torch.rand(10, 10).cuda(self.rank))

            fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
            shard_parameter(fc5, "weight", spec)
            with self.assertRaisesRegex(
                    ValueError,
                    'Input dim: 13 does not match appropriate weight dim: 7'):
                fc5(torch.rand(20, 10, 13).cuda(self.rank))

            fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
            del fc6.weight
            enumerable_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                )
            ])

            fc6.weight = empty(enumerable_spec, 10, 10)
            # The ShardedTensor metadata in the message has unbalanced
            # parentheses, which breaks re.compile, so match it with a
            # wildcard instead of matching it literally.
            error_msg = (
                r"(?s)torch function 'linear', with args: .* "
                r"and kwargs: None not supported for ShardedTensor!"
            )
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fc6(torch.rand(10, 10).cuda(self.rank))

            fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
            multiple_local_shard_spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                    "rank:3/cuda:3",
                ],
            )
            del fc7.weight
            fc7.weight = empty(multiple_local_shard_spec, 80, 10)
            with self.assertRaisesRegex(ValueError,
                                        'Only one local shard supported!'):
                fc7(torch.rand(10, 10).cuda(self.rank))
Example #18
 def test_sharded_embedding_bag_rowwise(self):
     for spec in generate_chunk_sharding_specs_for_test(0):
         self._test_sharded_embedding_bag_with_test_cases(spec, 0)