Example #1
0
    def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
        """Check a weight-sharded ``nn.Linear`` against an unsharded reference.

        Builds a reference ``nn.Linear`` on this rank's GPU, gives a second
        ``nn.Linear`` independent copies of its weight and bias, shards that
        second module's ``weight`` with ``spec``, and asserts that both the
        module forward and ``torch.nn.functional.linear`` produce the same
        output as the reference on a per-rank random input.

        Args:
            spec: sharding spec passed to ``shard_parameter`` for ``weight``.
            input_size: shape of the random input tensor.
            linear_size: positional arguments for ``torch.nn.Linear``.
            sharded_dim: accepted but not used by this forward-only check.
        """
        # Use same seed so every rank builds the same reference parameters.
        torch.manual_seed(0)
        local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)

        sharded_linear = torch.nn.Linear(*linear_size)

        # Copy (rather than alias) the reference parameters, so the module
        # under test never shares a Parameter object with the reference.
        sharded_linear.weight = torch.nn.Parameter(
            local_linear.weight.detach().clone())
        sharded_linear.bias = torch.nn.Parameter(
            local_linear.bias.detach().clone())

        # Shard the parameter.
        shard_parameter(sharded_linear, "weight", spec)

        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size).cuda(self.rank)
        sharded_output = sharded_linear(inp)

        # Run local computation
        local_output = local_linear(inp)

        # Verify
        self.assertEqual(local_output, sharded_output)

        # Validate for torch.nn.functional.linear version.
        local_output = torch.nn.functional.linear(
            inp, local_linear.weight, local_linear.bias
        )
        sharded_output = torch.nn.functional.linear(
            inp, sharded_linear.weight, sharded_linear.bias
        )
        self.assertEqual(local_output, sharded_output)
Example #2
0
    def _run_sharded_embedding(self, spec, input_size, num_embeddings,
                               embedding_dim):
        """Check a weight-sharded ``nn.Embedding`` against an unsharded reference.

        Builds a reference ``nn.Embedding`` on this rank's GPU, gives a second
        ``nn.Embedding`` an independent copy of its weight, shards that copy
        with ``spec``, and asserts that both the module forward and
        ``torch.nn.functional.embedding`` match the reference on a per-rank
        random index tensor.

        Args:
            spec: sharding spec passed to ``shard_parameter`` for ``weight``.
            input_size: shape of the random index tensor.
            num_embeddings: number of rows in the embedding table.
            embedding_dim: size of each embedding vector.
        """
        # Use same seed so every rank builds the same reference weight.
        torch.manual_seed(0)
        local_embedding = torch.nn.Embedding(num_embeddings,
                                             embedding_dim).cuda(self.rank)

        sharded_embedding = torch.nn.Embedding(num_embeddings, embedding_dim)

        # Copy (rather than alias) the reference weight, so the module under
        # test never shares a Parameter object with the reference.
        sharded_embedding.weight = torch.nn.Parameter(
            local_embedding.weight.detach().clone())

        # Shard the parameter.
        shard_parameter(sharded_embedding, "weight", spec)

        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.randint(num_embeddings, tuple(input_size)).cuda(self.rank)
        sharded_output = sharded_embedding(inp)

        # Run local computation
        local_output = local_embedding(inp)

        # Verify
        self.assertEqual(local_output, sharded_output)

        # Validate for torch.nn.functional.embedding version.
        local_output = torch.nn.functional.embedding(inp,
                                                     local_embedding.weight)
        sharded_output = torch.nn.functional.embedding(
            inp, sharded_embedding.weight)

        self.assertEqual(local_output, sharded_output)
Example #3
0
    def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
        """Check forward and backward of a weight-sharded ``nn.Linear``.

        Builds a reference ``nn.Linear`` on this rank's GPU and a second one
        holding independent copies of its weight and bias, shards the second
        module's ``weight`` with ``spec``, then:

        * compares module and ``torch.nn.functional.linear`` outputs;
        * runs backward on both ``F.linear`` outputs and checks that the
          sharded bias and the local weight shard received gradients;
        * compares the sharded bias grad with the reference bias grad, and
          the local weight-shard grad with the matching slice of the
          all-reduced reference weight grad.

        Args:
            spec: sharding spec passed to ``shard_parameter`` for ``weight``.
            input_size: shape of the random input tensor.
            linear_size: positional arguments for ``torch.nn.Linear``.
            sharded_dim: dim along which the reference weight grad is
                narrowed to line up with this rank's shard.
        """
        # Use same seed so every rank builds the same reference parameters.
        torch.manual_seed(0)
        local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)

        sharded_linear = torch.nn.Linear(*linear_size)

        # Copy the weights and bias from local linear as independent tensors,
        # so gradients of the two modules never land in the same Parameter.
        sharded_linear.weight = torch.nn.Parameter(
            local_linear.weight.detach().clone())
        sharded_linear.bias = torch.nn.Parameter(
            local_linear.bias.detach().clone())

        # Shard the parameter.
        shard_parameter(sharded_linear, "weight", spec)

        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size).cuda(self.rank)
        sharded_output = sharded_linear(inp)

        # Run local computation
        local_output = local_linear(inp)

        # Verify
        self.assertEqual(local_output, sharded_output)

        # Validate for torch.nn.functional.linear version.
        local_output = torch.nn.functional.linear(inp, local_linear.weight,
                                                  local_linear.bias)
        sharded_output = torch.nn.functional.linear(inp, sharded_linear.weight,
                                                    sharded_linear.bias)
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        local_grad = local_linear.weight.grad

        # Verify that both weight and bias in the sharded linear have a grad.
        sharded_weight = sharded_linear.weight.local_shards()[0].tensor
        self.assertIsNotNone(sharded_linear.bias.grad)
        self.assertIsNotNone(sharded_weight.grad)

        # The weight-shard grad is compared against the reference weight grad
        # summed across ranks (all_reduce, in place), narrowed to this rank's
        # chunk; the bias grad below is compared without any reduction.
        dist.all_reduce(local_grad)
        start_pos, chunk_size = generate_local_weight_sharding_params_for_test(
            local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank)
        local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos,
                                                chunk_size)

        # Test backward gradient calculation.
        self.assertEqual(sharded_linear.bias.grad, local_linear.bias.grad)
        self.assertEqual(sharded_weight.grad, local_grad_narrowed)
Example #4
0
    def shard_parameter(self):
        """Shard ``linear1.weight`` along dim 0 and ``linear2.weight`` along dim 1.

        Both chunk sharding specs spread the weight over the same four
        placements, rank ``i`` on ``cuda:i`` for ``i`` in 0..3.
        """
        # Shared placement strings; each spec receives its own fresh list.
        placements = (
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        )

        rowwise_sharding_spec = ChunkShardingSpec(
            dim=0, placements=list(placements))
        colwise_sharding_spec = ChunkShardingSpec(
            dim=1, placements=list(placements))

        sharded_tensor.shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
        sharded_tensor.shard_parameter(self.linear2, "weight", colwise_sharding_spec)
Example #5
0
    def test_sharded_linear_errors(self):
        """Sharded ``nn.Linear`` forwards must reject malformed setups.

        For every spec from ``generate_chunk_sharding_specs_for_test(0)``
        this checks, in order, the error raised for:
        a sharded bias; a 0-dim input; a 3-D weight; a 2-D bias; an input
        whose last dim does not match the weight; a weight sharded with an
        ``EnumerableShardingSpec``; and a weight with two local shards per rank.
        """
        def make_linear(in_features, out_features):
            # Fresh nn.Linear placed on this rank's GPU.
            return torch.nn.Linear(in_features, out_features).cuda(self.rank)

        for spec in generate_chunk_sharding_specs_for_test(0):
            # Bias (not weight) is the sharded parameter.
            bias_sharded = make_linear(10, 10)
            shard_parameter(bias_sharded, "bias", spec)
            with self.assertRaisesRegex(
                    TypeError, 'input and bias need to be torch.Tensor'):
                bias_sharded(torch.rand(10, 10).cuda(self.rank))

            # Scalar (0-dim) input.
            scalar_input = make_linear(10, 10)
            shard_parameter(scalar_input, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Input needs to have at least 1 dim'):
                scalar_input(torch.tensor(1).cuda(self.rank))

            # Weight replaced by a 3-D tensor before sharding.
            weight_3d = make_linear(10, 10)
            weight_3d.weight = torch.nn.Parameter(
                torch.rand(10, 10, 10).cuda(self.rank))
            shard_parameter(weight_3d, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Weight needs to have exactly 2 dims'):
                weight_3d(torch.rand(10, 10).cuda(self.rank))

            # Bias replaced by a 2-D tensor.
            bias_2d = make_linear(10, 10)
            bias_2d.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
            shard_parameter(bias_2d, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Bias needs to have exactly 1 dim'):
                bias_2d(torch.rand(10, 10).cuda(self.rank))

            # Input's last dim (13) differs from in_features (7).
            dim_mismatch = make_linear(7, 10)
            shard_parameter(dim_mismatch, "weight", spec)
            with self.assertRaisesRegex(
                    ValueError,
                    'Input dim: 13 does not match appropriate weight dim: 7'):
                dim_mismatch(torch.rand(20, 10, 13).cuda(self.rank))

            # Weight built from an EnumerableShardingSpec: four 5x5 blocks.
            enumerable = make_linear(10, 10)
            del enumerable.weight
            enumerable_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=list(offsets),
                    shard_sizes=[5, 5],
                    placement=placement,
                )
                for offsets, placement in (
                    ((0, 0), "rank:0/cuda:0"),
                    ((0, 5), "rank:1/cuda:1"),
                    ((5, 0), "rank:2/cuda:2"),
                    ((5, 5), "rank:3/cuda:3"),
                )
            ])
            enumerable.weight = empty(enumerable_spec, 10, 10)
            with self.assertRaisesRegex(
                    ValueError,
                    'Only ChunkShardingSpec supported for ShardedTensor ops!'):
                enumerable(torch.rand(10, 10).cuda(self.rank))

            # Each rank listed twice -> two local shards per rank.
            multi_shard = make_linear(10, 80)
            multiple_local_shard_spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    device
                    for device in (
                        "rank:0/cuda:0",
                        "rank:1/cuda:1",
                        "rank:2/cuda:2",
                        "rank:3/cuda:3",
                    )
                    for _ in range(2)
                ],
            )
            del multi_shard.weight
            multi_shard.weight = empty(multiple_local_shard_spec, 80, 10)
            with self.assertRaisesRegex(ValueError,
                                        'Only one local shard supported!'):
                multi_shard(torch.rand(10, 10).cuda(self.rank))
Example #6
0
    def _run_sharded_embedding(
        self,
        spec,
        input_size,
        num_embeddings,
        embedding_dim,
        sharded_dim=None,
        max_norm=None,
        norm_type=2.0,
        padding_idx=None,
    ):
        """Compare a weight-sharded ``nn.Embedding`` with an unsharded reference.

        A reference embedding lives on this rank's GPU; a second embedding,
        built with the same options, receives a detached clone of its weight,
        which is then sharded with ``spec``. Outputs of the module forward and
        of ``torch.nn.functional.embedding`` must match on a per-rank random
        index tensor.

        When ``max_norm`` is set, the reference is first run on the union of
        all ranks' indices (gathered with ``dist.all_gather``). The slice of
        the reference weight along ``sharded_dim`` is then compared with this
        rank's local shard. ``sharded_dim`` is only used in that case.
        """
        # Options shared by every embedding module/functional call below.
        embedding_kwargs = {
            "max_norm": max_norm,
            "norm_type": norm_type,
            "padding_idx": padding_idx,
        }

        # Identical seed on every rank -> identical reference weights.
        torch.manual_seed(0)
        reference = torch.nn.Embedding(
            num_embeddings, embedding_dim, **embedding_kwargs
        ).cuda(self.rank)
        under_test = torch.nn.Embedding(
            num_embeddings, embedding_dim, **embedding_kwargs
        )

        # Independent copy of the reference weight, then shard it.
        under_test.weight = torch.nn.Parameter(
            reference.weight.detach().clone())
        shard_parameter(under_test, "weight", spec)

        # Different indices on each rank.
        torch.manual_seed(self.rank)
        inp = torch.randint(
            0, num_embeddings, tuple(input_size)
        ).cuda(self.rank)
        sharded_output = under_test(inp)

        renorm_enabled = max_norm is not None
        if renorm_enabled:
            # Touch every index used on any rank so the reference weight
            # receives the same renormalization as the sharded one.
            gathered = [torch.zeros_like(inp) for _ in range(TEST_GPU_NUM)]
            dist.all_gather(gathered, inp)
            reference(torch.unique(torch.cat(gathered)))

        local_output = reference(inp)

        if renorm_enabled:
            # The sharded weight's local shard must equal the matching slice
            # of the (renormalized) reference weight.
            start_pos, chunk_size = generate_local_weight_sharding_params_for_test(
                reference.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank)
            self.assertEqual(
                reference.weight.narrow(sharded_dim, start_pos, chunk_size),
                under_test.weight.local_shards()[0].tensor,
            )

        self.assertEqual(local_output, sharded_output)

        # Same comparison through the functional API.
        local_output = torch.nn.functional.embedding(
            inp, reference.weight, **embedding_kwargs)
        sharded_output = torch.nn.functional.embedding(
            inp, under_test.weight, **embedding_kwargs)
        self.assertEqual(local_output, sharded_output)
Example #7
0
    def _run_sharded_embedding_bag(self,
                                   spec,
                                   input_size,
                                   num_embeddings,
                                   embedding_dim,
                                   mode,
                                   offset_size=None):
        """Check a weight-sharded ``nn.EmbeddingBag`` against an unsharded reference.

        Builds a reference ``nn.EmbeddingBag`` on this rank's GPU, gives a
        second one an independent copy of its weight, and shards that copy
        with ``spec``. It then asserts that both the module forward and
        ``torch.nn.functional.embedding_bag`` match the reference on a
        per-rank random index tensor.

        Args:
            spec: sharding spec passed to ``shard_parameter`` for ``weight``.
            input_size: shape of the random index tensor. When it has exactly
                one dim, explicit ``offsets`` are generated as well.
            num_embeddings: number of rows in the embedding table.
            embedding_dim: size of each embedding vector.
            mode: ``EmbeddingBag`` reduction mode. For ``"sum"``, random
                ``per_sample_weights`` are also passed.
            offset_size: number of offsets to generate for 1-D input.
                Required in that case, and must be in ``[1, input_size[0]]``.

        Raises:
            ValueError: if ``input_size`` is 1-D and ``offset_size`` is None
                or outside ``[1, input_size[0]]``.
        """
        if len(input_size) == 1 and (
                offset_size is None or not 1 <= offset_size <= input_size[0]):
            # Offsets are drawn from range(input_size[0]) and deduplicated, so
            # at most input_size[0] distinct values can ever be collected; a
            # larger target would make the resampling loop below spin forever.
            raise ValueError(
                f"offset_size must be in [1, {input_size[0]}] for 1-D input, "
                f"got {offset_size}")

        # Use same seed so every rank builds the same reference weight.
        torch.manual_seed(0)
        local_embedding_bag = torch.nn.EmbeddingBag(num_embeddings,
                                                    embedding_dim,
                                                    mode=mode).cuda(self.rank)

        sharded_embedding_bag = torch.nn.EmbeddingBag(num_embeddings,
                                                      embedding_dim,
                                                      mode=mode)

        # Copy (rather than alias) the reference weight, so the module under
        # test never shares a Parameter object with the reference.
        sharded_embedding_bag.weight = torch.nn.Parameter(
            local_embedding_bag.weight.detach().clone())

        # Shard the parameter.
        shard_parameter(sharded_embedding_bag, "weight", spec)

        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.randint(0, num_embeddings,
                            tuple(input_size)).cuda(self.rank)
        per_sample_weights = None
        if mode == "sum":
            per_sample_weights = torch.rand(*input_size).cuda(self.rank)

        offsets = None
        if len(input_size) == 1:
            # Every rank needs exactly offset_size offsets: the current
            # implementation and dist API do not support offsets of different
            # lengths across ranks. Resample until deduplication leaves
            # exactly offset_size values; with input_size[0] >> offset_size
            # the loop will not run for long.
            while offsets is None or offsets.size(0) != offset_size:
                offsets = torch.randint(input_size[0], (offset_size, ))
                offsets[0] = 0  # first offset is always pinned to 0
                offsets = (torch.unique(
                    offsets, sorted=True).contiguous().cuda(self.rank))

        sharded_output = sharded_embedding_bag(
            inp, offsets=offsets, per_sample_weights=per_sample_weights)

        # Run local computation
        local_output = local_embedding_bag(
            inp, offsets=offsets, per_sample_weights=per_sample_weights)

        # Verify
        self.assertEqual(local_output, sharded_output)

        # Validate for torch.nn.functional.embedding_bag version.
        local_output = torch.nn.functional.embedding_bag(
            inp,
            local_embedding_bag.weight,
            offsets=offsets,
            mode=mode,
            per_sample_weights=per_sample_weights,
        )
        sharded_output = torch.nn.functional.embedding_bag(
            inp,
            sharded_embedding_bag.weight,
            offsets=offsets,
            mode=mode,
            per_sample_weights=per_sample_weights,
        )

        self.assertEqual(local_output, sharded_output)