Code example #1
File: test_sharded_optim.py  Project: yinghai/pytorch
    def test_named_params_with_sharded_tensor(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)
        sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
        param_keys = list(sharded_model_params.keys())
        self.assertEqual(len(param_keys), 2)
        self.assertTrue("param" in param_keys)
        self.assertTrue("sharded_param" in param_keys)

        sharded_linear = MyShardedLinear(rank=self.rank).cuda(self.rank)
        sharded_linear.shard_parameter()
        sharded_linear_params = dict(named_params_with_sharded_tensor(sharded_linear))
        param_keys = list(sharded_linear_params.keys())
        self.assertEqual(len(param_keys), 4)
        self.assertTrue("linear1.bias" in param_keys)
        self.assertTrue("linear2.bias" in param_keys)
        self.assertTrue("linear1.weight" in param_keys)
        self.assertTrue("linear2.weight" in param_keys)
        self.assertFalse("bias" in param_keys)
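For context, `named_params_with_sharded_tensor` walks a module much like `Module.named_parameters()`, but it also yields attributes that became `ShardedTensor`s when a parameter was sharded (which `named_parameters()` no longer reports). The snippet below is a minimal, hypothetical sketch of that contrast; it is not part of the test file, and it assumes an already-initialized process group, one GPU per rank, and the prototype import paths used in the examples on this page (these may move between PyTorch releases).

# Minimal sketch (hypothetical; assumes dist.init_process_group() has run and
# that each rank owns cuda:<rank>).
import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharded_optim import named_params_with_sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def inspect_sharded_params(rank):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
    )
    linear = torch.nn.Linear(16, 16).cuda(rank)
    shard_parameter(linear, "weight", spec)  # weight is now a ShardedTensor

    # named_parameters() only sees the remaining dense parameter ("bias") ...
    print(sorted(name for name, _ in linear.named_parameters()))
    # ... while named_params_with_sharded_tensor() also yields "weight".
    print(sorted(name for name, _ in named_params_with_sharded_tensor(linear)))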
Code example #2
def _get_toy_module_optim(module, lr):
    """
    Create an optimizer for sharded tensors using ShardedOptimizer.
    """
    return ShardedOptimizer(
        dict(named_params_with_sharded_tensor(module)),
        torch.optim.SGD,  # SGD is used here only for demo purposes; other optimizers work too.
        lr=lr,
    )
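As a usage note, a helper like `_get_toy_module_optim` is meant to be called after the module's parameters have been sharded. The step below is a hypothetical sketch of that flow, not code from the test files above: it assumes an initialized process group, one GPU per rank, and it reuses the helper defined in code example #2 together with `shard_parameter` and `ChunkShardingSpec` from the prototype `torch.distributed._shard` package.

# Hypothetical single training step; assumes the imports below resolve and that
# _get_toy_module_optim from code example #2 is in scope.
import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def toy_training_step(rank):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
    )
    model = torch.nn.Linear(10, 10).cuda(rank)
    shard_parameter(model, "weight", spec)  # weight becomes a ShardedTensor

    optimizer = _get_toy_module_optim(model, lr=0.1)  # ShardedOptimizer wrapping SGD

    optimizer.zero_grad()
    inp = torch.rand(5, 10).cuda(rank)
    model(inp).sum().backward()  # same forward/backward pattern as the tests below
    optimizer.step()             # updates dense and sharded params alike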
Code example #3
File: test_sharded_optim.py  Project: yinghai/pytorch
    def test_sharded_optim(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        local_model = MyShardedModel().cuda(self.rank)
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)

        # copy the parameters from the local model
        sharded_model.sharded_param.local_shards()[0].tensor = \
            local_model.sharded_param.detach().clone().requires_grad_()

        local_optim = optim.SGD(local_model.parameters(), lr=0.1)
        sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
        sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)

        local_optim.zero_grad()
        sharded_optim.zero_grad()

        before_update = deepcopy(sharded_optim.named_params)

        inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_()

        # run forward
        local_output = local_model(inp)
        sharded_output = sharded_model(inp)
        # backward
        local_output.sum().backward()
        sharded_output.sum().backward()

        # optimizer update
        local_optim.step()
        sharded_optim.step()

        # make sure the parameters (including sharded param)
        # get updated by the optimizer, and the updated
        # local params are the same as the sharded params
        for key, val in before_update.items():
            new_val = sharded_optim.named_params[key]
            if isinstance(val, sharded_tensor.ShardedTensor):
                self.assertNotEqual(
                    val.local_shards()[0].tensor,
                    new_val.local_shards()[0].tensor
                )
                self.assertEqual(
                    new_val.local_shards()[0].tensor,
                    local_model.sharded_param
                )
            else:
                self.assertNotEqual(val, new_val)
                self.assertEqual(new_val, local_model.param)
Code example #4
File: test_linear.py  Project: yanboliang/pytorch
    def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim,
                            dtype):
        # Use same seed.
        torch.manual_seed(0)
        local_linear = torch.nn.Linear(*linear_size,
                                       dtype=dtype).cuda(self.rank)
        sharded_linear = torch.nn.Linear(*linear_size, dtype=dtype)

        # Copy the weights and bias from local linear
        sharded_linear.weight = clone_module_parameter(local_linear, "weight")
        sharded_linear.bias = clone_module_parameter(local_linear, "bias")

        # Shard the parameter.
        shard_parameter(sharded_linear, "weight", spec)

        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
        reshard_spec = copy.deepcopy(spec)
        reshard_spec.dim = 0
        reshard_spec.placements.sort(key=lambda placement: placement.rank())
        sharded_linear = _collect_local_shard(
            _reshard_output(sharded_linear, reshard_spec))
        sharded_output = sharded_linear(inp)

        # Run local computation
        local_output = local_linear(inp)

        # Verify
        self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3)

        # Validate for torch.nn.functional.linear version.
        local_output = torch.nn.functional.linear(inp, local_linear.weight,
                                                  local_linear.bias)
        sharded_output = torch.nn.functional.linear(inp, sharded_linear.weight,
                                                    sharded_linear.bias)
        sharded_output = sharded_output.reshard(reshard_spec).local_tensor()
        # When the local tensor has only one dimension, an extra dimension is added
        # for resharding, so we need to squeeze that dimension out manually.
        if inp.dim() == 1:
            sharded_output = sharded_output.squeeze(reshard_spec.dim)
        self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        local_grad = local_linear.weight.grad

        # Verify that both the weight and the bias in the sharded linear have non-None grads.
        sharded_weight = sharded_linear.weight.local_tensor()
        self.assertNotEqual(sharded_linear.bias.grad, None)
        self.assertNotEqual(sharded_weight.grad, None)

        # Shard the local linear's weight grad so that we can compare.
        dist.all_reduce(local_grad)
        (start_pos,
         chunk_size) = generate_local_weight_sharding_params_for_test(
             local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank)
        local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos,
                                                chunk_size)
        local_bias_grad = local_linear.bias.grad
        dist.all_reduce(local_bias_grad)

        # Test backward gradient calculation.
        self.assertEqual(sharded_linear.bias.grad,
                         local_bias_grad,
                         atol=1e-3,
                         rtol=1e-3)
        self.assertEqual(sharded_weight.grad,
                         local_grad_narrowed,
                         atol=1e-3,
                         rtol=1e-3)

        # Test optimizer.
        previous = local_linear.weight.clone().detach()
        optim = torch.optim.SGD(local_linear.parameters(), lr=0.1)
        optim.step()
        self.assertNotEqual(previous, local_linear.weight)
        previous_sharded_weight = sharded_weight.clone()
        previous_sharded_bias = sharded_linear.bias.clone()
        sharded_optim = ShardedOptimizer(
            dict(named_params_with_sharded_tensor(sharded_linear)),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()
        sharded_weight = sharded_linear.weight.local_tensor()
        local_weight_narrowed = local_linear.weight.narrow(
            sharded_dim, start_pos, chunk_size)
        self.assertEqual(sharded_weight.size(), local_weight_narrowed.size())
        self.assertNotEqual(previous_sharded_weight, sharded_weight)
        self.assertEqual(sharded_weight,
                         local_weight_narrowed,
                         atol=1e-3,
                         rtol=1e-3)
        self.assertNotEqual(previous_sharded_bias, sharded_linear.bias)
        self.assertEqual(sharded_linear.bias,
                         local_linear.bias,
                         atol=1e-3,
                         rtol=1e-3)
Code example #5
    def test_sharding_plan_simple_megatron(self):
        colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
        rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
        for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
            # test each sharding spec pair and see if we can apply sharding
            reshard_spec = copy.deepcopy(spec[1])
            reshard_spec.placements.sort(key=lambda placement: placement.rank())
            reshard_spec.dim = 0

            sharding_plan = ShardingPlan(
                plan={
                    "fc1.weight": spec[0],
                    "fc2.weight": spec[1]
                },
                output_plan={
                    "": reshard_spec
                },
                return_local_tensor=[""])

            # Use same seed.
            torch.manual_seed(0)
            local_megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)
            megatron_lm = copy.deepcopy(local_megatron_lm)

            # shard the module with the provided sharding plan
            shard_module(megatron_lm, sharding_plan)

            # check to make sure the module has already been sharded
            self.assertTrue(isinstance(megatron_lm.fc1.weight, ShardedTensor))
            self.assertTrue(isinstance(megatron_lm.fc2.weight, ShardedTensor))
            self.assertEqual(megatron_lm.fc1.weight.sharding_spec(), spec[0])
            self.assertEqual(megatron_lm.fc2.weight.sharding_spec(), spec[1])

            # make sure we can run sharded computation
            input = torch.rand(22, 17).cuda(self.rank)
            sharded_output = megatron_lm(input)
            local_output = local_megatron_lm(input)

            # verify that the local and sharded outputs match
            self.assertEqual(local_output, sharded_output)

            # Compute loss and run backward pass.
            local_output.sum().backward()
            sharded_output.sum().backward()
            (
                local_weight_grad_fc1,
                local_weight_grad_fc2,
            ) = local_megatron_lm.get_weight_grads()
            local_bias_grad_fc1, local_bias_grad_fc2 = local_megatron_lm.get_bias_grads()

            # Verify that the weights and biases in both sharded linear layers have non-None grads.
            (
                sharded_weight_fc1,
                sharded_weight_fc2,
            ) = megatron_lm.get_weights()
            bias_grad_fc1, bias_grad_fc2 = megatron_lm.get_bias_grads()
            self.assertNotEqual(sharded_weight_fc1.grad, None)
            self.assertNotEqual(sharded_weight_fc2.grad, None)
            self.assertNotEqual(bias_grad_fc1, None)
            self.assertNotEqual(bias_grad_fc2, None)

            # Shard the local linear's weight grad so that we can compare.
            dist.all_reduce(local_weight_grad_fc1)
            dist.all_reduce(local_weight_grad_fc2)
            dist.all_reduce(local_bias_grad_fc1)
            dist.all_reduce(local_bias_grad_fc2)
            local_weight_fc1, local_weight_fc2 = local_megatron_lm.get_weights()
            (
                start_pos_fc1,
                chunk_size_fc1,
            ) = generate_local_weight_sharding_params_for_test(
                local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
            )
            local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
                0, start_pos_fc1, chunk_size_fc1
            )
            (
                start_pos_fc2,
                chunk_size_fc2,
            ) = generate_local_weight_sharding_params_for_test(
                local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
            )
            local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
                1, start_pos_fc2, chunk_size_fc2
            )

            # Test backward gradient calculation.
            self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
            self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
            self.assertEqual(bias_grad_fc1, local_bias_grad_fc1)
            self.assertEqual(bias_grad_fc2, local_bias_grad_fc2)

            # Test optimizer.
            bias_fc1, bias_fc2 = megatron_lm.get_biases()
            local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
            self.assertEqual(bias_fc1, local_bias_fc1)
            self.assertEqual(bias_fc2, local_bias_fc2)
            self.assertEqual(bias_fc1.grad, local_bias_fc1.grad)
            self.assertEqual(bias_fc2.grad, local_bias_fc2.grad)
            previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
            previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
            previous_bias_fc1 = bias_fc1.clone()
            previous_bias_fc2 = bias_fc2.clone()
            optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
            optim.step()
            sharded_optim = ShardedOptimizer(
                dict(named_params_with_sharded_tensor(megatron_lm)),
                torch.optim.SGD,
                lr=0.1,
            )
            sharded_optim.step()
            local_weight_fc1_narrowed = local_weight_fc1.narrow(
                0, start_pos_fc1, chunk_size_fc1
            )
            local_weight_fc2_narrowed = local_weight_fc2.narrow(
                1, start_pos_fc2, chunk_size_fc2
            )

            # Test weight value after optimizer.
            self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
            self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
            self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
            self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
            self.assertEqual(sharded_weight_fc1, local_weight_fc1_narrowed)
            self.assertEqual(sharded_weight_fc2, local_weight_fc2_narrowed)

            # Test bias value after optimizer.
            local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
            self.assertNotEqual(previous_bias_fc1, bias_fc1)
            self.assertEqual(bias_fc1, local_bias_fc1)
            self.assertNotEqual(previous_bias_fc2, bias_fc2)
            self.assertEqual(bias_fc2, local_bias_fc2)
Code example #6
    def _run_megatron_linear(self, spec, input_size, linear_size, dtype):
        def _weight_override(module_dst, module_src):
            module_dst.fc1.weight = clone_module_parameter(module_src.fc1, "weight")
            module_dst.fc1.bias = clone_module_parameter(module_src.fc1, "bias")
            module_dst.fc2.weight = clone_module_parameter(module_src.fc2, "weight")
            module_dst.fc2.bias = clone_module_parameter(module_src.fc2, "bias")

        def _shard_parameter(module, spec):
            shard_parameter(module.fc1, "weight", spec[0])
            shard_parameter(module.fc2, "weight", spec[1])

        # Use same seed.
        torch.manual_seed(0)
        local_megatron_lm = SimpleMegatronLM(linear_size, rank=self.rank, dtype=dtype)
        sharded_megatron_lm = SimpleMegatronLM(linear_size, dtype=dtype)
        _weight_override(sharded_megatron_lm, local_megatron_lm)

        # Shard the parameter. First col-wise sharding and then row-wise
        _shard_parameter(sharded_megatron_lm, spec)

        # Setup resharding of output.
        reshard_spec = copy.deepcopy(spec[1])
        reshard_spec.placements.sort(key=lambda placement: placement.rank())
        reshard_spec.dim = 0

        sharded_megatron_lm = _collect_local_shard(
            _reshard_output(sharded_megatron_lm, reshard_spec)
        )

        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size, requires_grad=True, device=self.rank, dtype=dtype)

        # Run local computation
        local_output = local_megatron_lm(inp)

        # Compute loss and run backward pass.
        local_output.sum().backward()

        # Save and reset input grads.
        local_input_grad = inp.grad
        self.assertIsNotNone(inp.grad)
        inp.grad = None

        # Run sharded computation
        sharded_output = sharded_megatron_lm(inp)

        # Verify local and sharded results
        self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-6)

        sharded_output.sum().backward()
        sharded_input_grad = inp.grad
        self.assertIsNotNone(inp.grad)

        # Verify sharded and local grads.
        self.assertEqual(local_input_grad, sharded_input_grad, atol=1e-3, rtol=1e-6)

        (
            local_weight_grad_fc1,
            local_weight_grad_fc2,
        ) = local_megatron_lm.get_weight_grads()
        local_bias_grad_fc1, local_bias_grad_fc2 = local_megatron_lm.get_bias_grads()

        # Verify that the weights and biases in both sharded linear layers have non-None grads.
        (
            sharded_weight_fc1,
            sharded_weight_fc2,
        ) = sharded_megatron_lm.get_weights()
        bias_grad_fc1, bias_grad_fc2 = sharded_megatron_lm.get_bias_grads()
        self.assertNotEqual(sharded_weight_fc1.grad, None)
        self.assertNotEqual(sharded_weight_fc2.grad, None)
        self.assertNotEqual(bias_grad_fc1, None)
        self.assertNotEqual(bias_grad_fc2, None)

        # Shard the local linear's weight grad so that we can compare.
        dist.all_reduce(local_weight_grad_fc1)
        dist.all_reduce(local_weight_grad_fc2)
        dist.all_reduce(local_bias_grad_fc1)
        dist.all_reduce(local_bias_grad_fc2)
        local_weight_fc1, local_weight_fc2 = local_megatron_lm.get_weights()
        (
            start_pos_fc1,
            chunk_size_fc1,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
        )
        local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        (
            start_pos_fc2,
            chunk_size_fc2,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
        )
        local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test backward gradient calculation.
        self.assertEdistNorm(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
        self.assertEdistNorm(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
        self.assertEdistNorm(bias_grad_fc1, local_bias_grad_fc1)
        self.assertEdistNorm(bias_grad_fc2, local_bias_grad_fc2)

        # Test optimizer.
        bias_fc1, bias_fc2 = sharded_megatron_lm.get_biases()
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertEdistNorm(bias_fc1, local_bias_fc1)
        self.assertEdistNorm(bias_fc2, local_bias_fc2)
        self.assertEdistNorm(bias_fc1.grad, local_bias_fc1.grad)
        self.assertEdistNorm(bias_fc2.grad, local_bias_fc2.grad)
        previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
        previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
        previous_bias_fc1 = bias_fc1.clone()
        previous_bias_fc2 = bias_fc2.clone()
        optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
        optim.step()
        sharded_optim = ShardedOptimizer(
            dict(named_params_with_sharded_tensor(sharded_megatron_lm)),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()
        local_weight_fc1_narrowed = local_weight_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        local_weight_fc2_narrowed = local_weight_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test weight value after optimizer.
        self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
        self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
        self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
        self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
        self.assertEdistNorm(sharded_weight_fc1, local_weight_fc1_narrowed)
        self.assertEdistNorm(sharded_weight_fc2, local_weight_fc2_narrowed)

        # Test bias value after optimizer.
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertNotEqual(previous_bias_fc1, bias_fc1)
        self.assertEdistNorm(bias_fc1, local_bias_fc1)
        self.assertNotEqual(previous_bias_fc2, bias_fc2)
        self.assertEdistNorm(bias_fc2, local_bias_fc2)
Code example #7
    def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
        # Use same seed.
        torch.manual_seed(0)
        local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)

        sharded_linear = torch.nn.Linear(*linear_size)

        # Copy the weights and bias from local linear
        sharded_linear.weight = torch.nn.Parameter(
            local_linear.weight.detach().clone())
        sharded_linear.bias = torch.nn.Parameter(
            local_linear.bias.detach().clone())

        # Shard the parameter.
        shard_parameter(sharded_linear, "weight", spec)

        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size).cuda(self.rank)
        sharded_output = sharded_linear(inp)

        # Run local computation
        local_output = local_linear(inp)

        # Verify
        self.assertEqual(local_output, sharded_output)

        # Validate for torch.nn.functional.linear version.
        local_output = torch.nn.functional.linear(inp, local_linear.weight,
                                                  local_linear.bias)
        sharded_output = torch.nn.functional.linear(inp, sharded_linear.weight,
                                                    sharded_linear.bias)
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        local_grad = local_linear.weight.grad

        # Verify that both the weight and the bias in the sharded linear have non-None grads.
        sharded_weight = sharded_linear.weight.local_shards()[0].tensor
        self.assertNotEqual(sharded_linear.bias.grad, None)
        self.assertNotEqual(sharded_weight.grad, None)

        # Shard the local linear's weight grad so that we can compare.
        dist.all_reduce(local_grad)
        (start_pos,
         chunk_size) = generate_local_weight_sharding_params_for_test(
             local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank)
        local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos,
                                                chunk_size)

        # Test backward gradient calculation.
        self.assertEqual(sharded_linear.bias.grad, local_linear.bias.grad)
        self.assertEqual(sharded_weight.grad, local_grad_narrowed)

        # Test optimizer.
        previous = local_linear.weight.clone().detach()
        optim = torch.optim.SGD(local_linear.parameters(), lr=0.1)
        optim.step()
        self.assertNotEqual(previous, local_linear.weight)
        previous_sharded_weight = sharded_weight.clone()
        previous_sharded_bias = sharded_linear.bias.clone()
        sharded_optim = ShardedOptimizer(dict(
            named_params_with_sharded_tensor(sharded_linear)),
                                         torch.optim.SGD,
                                         lr=0.1)
        sharded_optim.step()
        sharded_weight = sharded_linear.weight.local_shards()[0].tensor
        local_weight_narrowed = local_linear.weight.narrow(
            sharded_dim, start_pos, chunk_size)
        self.assertEqual(sharded_weight.size(), local_weight_narrowed.size())
        self.assertNotEqual(previous_sharded_weight, sharded_weight)
        self.assertEqual(sharded_weight, local_weight_narrowed)
        self.assertNotEqual(previous_sharded_bias, sharded_linear.bias)
        self.assertEqual(sharded_linear.bias, local_linear.bias)