Example No. 1
def parallel_self_attention(tensor_model_parallel_size,
                            num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(
        hidden_size, num_att_heads, dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length,
                               hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer
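
Test helpers like this one assume torch.distributed is already initialized on every rank before parallel_state.initialize_model_parallel is called. Below is a minimal launcher sketch; run_worker, the environment variables, and the argument values are illustrative assumptions, not part of the original test suite.

# Minimal per-rank driver sketch (assumes the imports used by the example,
# e.g. parallel_state and the IdentityLayer helpers, are available;
# run_worker is a hypothetical helper).
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # Every rank runs the same test body; the collectives inside the
    # parallel layers keep the ranks in sync.
    parallel_self_attention(
        tensor_model_parallel_size=world_size,
        num_att_heads_per_partition=2,
        hidden_size_per_att_head=4,
        dropout_prob=0.0,
        batch_size=3,
        sequence_length=5,
    )
    dist.destroy_process_group()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    mp.spawn(run_worker, args=(n_gpus,), nprocs=n_gpus)
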
Example No. 2
    def test_row_parallel_linear(self) -> None:
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
                output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

                set_random_seed(self.SEED)
                linear_layer = layers.RowParallelLinear(
                    input_size,
                    output_size,
                    keep_master_weight_for_test=True,
                    params_dtype=torch.float32,
                    use_cpu_initialization=True,
                ).cuda()
                loss_weight = torch.randn(
                    (self.BATCH_SIZE, output_size)).cuda()

                # Forward and backward
                input_tensor = torch.randn(self.BATCH_SIZE,
                                           input_size,
                                           requires_grad=True).cuda()
                input_tensor.retain_grad()
                output, _ = linear_layer(input_tensor)
                loss = torch.mul(output, loss_weight).sum()
                loss.backward()
                self.assertIsNotNone(input_tensor.grad)

                with torch.no_grad():
                    dldy = loss_weight.clone()
                    x = input_tensor.clone()
                    a = linear_layer.master_weight.cuda()
                dlda = torch.matmul(dldy.t(), x)
                dldb = torch.matmul(
                    torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy).view(-1)
                dldx = torch.matmul(dldy, a)

                with torch.no_grad():
                    curr_dlda = torch.split(
                        dlda, self.INPUT_SIZE_COEFF, dim=1
                    )[parallel_state.get_tensor_model_parallel_rank()].clone()
                self.assertEqual(linear_layer.weight.grad, curr_dlda)
                self.assertEqual(input_tensor.grad, dldx)
                self.assertEqual(linear_layer.bias.grad, dldb)

                parallel_state.destroy_model_parallel()
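
The manual gradients checked above follow from y = x @ a.T + b with loss = sum(y * dldy): dL/da = dldy.T @ x, dL/db is dldy summed over the batch, and dL/dx = dldy @ a. The sketch below verifies these identities against autograd in a single process, using plain torch.nn.functional.linear as a stand-in for the parallel layer (an assumption, not the layer under test).

# Reference-gradient check for y = x @ a.T + b, matching the formulas above.
import torch

batch, in_features, out_features = 7, 13, 17
x = torch.randn(batch, in_features, requires_grad=True)
a = torch.randn(out_features, in_features, requires_grad=True)
b = torch.zeros(out_features, requires_grad=True)
dldy = torch.randn(batch, out_features)

y = torch.nn.functional.linear(x, a, b)
loss = torch.mul(y, dldy).sum()
loss.backward()

torch.testing.assert_close(a.grad, torch.matmul(dldy.t(), x))   # dL/da
torch.testing.assert_close(b.grad, dldy.sum(dim=0))             # dL/db
torch.testing.assert_close(x.grad, torch.matmul(dldy, a))       # dL/dx
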
Example No. 3
    def _affine_weight_init_test_impl(self, init_device: str,
                                      is_column_parallel: bool) -> None:
        dim = int(not is_column_parallel)
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )
                input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
                output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

                weight_shape = ((self.OUTPUT_SIZE_COEFF,
                                 input_size) if is_column_parallel else
                                (output_size, self.INPUT_SIZE_COEFF))
                weight = torch.empty(weight_shape)
                set_random_seed(self.SEED)

                sharding_dim_size = (self.OUTPUT_SIZE_COEFF
                                     if is_column_parallel else
                                     self.INPUT_SIZE_COEFF)

                if init_device == "cpu":
                    layers._initialize_affine_weight_cpu(
                        weight,
                        output_size,
                        input_size,
                        sharding_dim_size,
                        dim,
                        nn.init.normal_,
                        params_dtype=torch.float32,
                    )
                else:
                    layers._initialize_affine_weight_gpu(
                        weight, torch.nn.init.normal_, dim)
                # Target
                set_random_seed(self.SEED)
                if init_device == "cpu":
                    main_weight = torch.empty(output_size, input_size)
                    nn.init.normal_(main_weight)
                    curr_weight = torch.split(
                        main_weight, sharding_dim_size, dim=dim)[
                            parallel_state.get_tensor_model_parallel_rank()]
                else:
                    curr_weight = torch.empty(*weight_shape)
                    nn.init.normal_(curr_weight)
                self.assertEqual(curr_weight, weight)
                parallel_state.destroy_model_parallel()
Example No. 4
def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale,
                        seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(size=(batch_size,
                                         seq_length)).random_(0, vocab_size)
    loss = F.cross_entropy(logits.view(-1,
                                       logits.size()[-1]),
                           target.view(-1),
                           reduction='none').view_as(target).mean()
    loss.backward()
    return loss, identity.weight.grad
Example No. 5
def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size,
                                 logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(
        logits)
    target = torch.cuda.LongTensor(size=(batch_size,
                                         seq_length)).random_(0, vocab_size)
    logits_parallel_ = logits_parallel.clone().detach()
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()
    # check for mutation
    assert torch.equal(logits_parallel_, logits_parallel)
    return loss, identity.weight.grad
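
With the same seed, the sharded loss and gradient from Example No. 5 should match the plain F.cross_entropy reference from Example No. 4. The driver sketch below is a hedged illustration of that comparison; it assumes model parallelism has already been initialized on the current rank and that the vocabulary size is divisible by the tensor-parallel world size.

# Hypothetical comparison of the two helpers above (runs on each rank after
# parallel_state.initialize_model_parallel(...) has been called).
world_size = parallel_state.get_tensor_model_parallel_world_size()
vocab_size = 11 * world_size  # keep the vocabulary divisible across ranks
reference_loss, reference_grad = torch_cross_entropy(
    13, 17, vocab_size, 1000.0, 1234)
sharded_loss, sharded_grad = tensor_sharded_cross_entropy(
    13, 17, vocab_size, 1000.0, 1234)
torch.testing.assert_close(reference_loss, sharded_loss)
torch.testing.assert_close(reference_grad, sharded_grad)
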
Example No. 6
def test_column_parallel_linear_with_async_allreduce_custom_amp(
        tensor_model_parallel_size):
    if torch.cuda.is_bf16_supported():
        dtypes = (torch.half, torch.bfloat16)
    else:
        dtypes = (torch.half,)

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    for dtype in dtypes:
        # Network
        identity_layer = IdentityLayer3D(batch_size, batch_size,
                                         input_size).to(device="cuda",
                                                        dtype=dtype)
        linear_layer = layers.ColumnParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=global_vars.get_args().params_dtype,
            use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        ).to(device="cuda", dtype=dtype)
        # Forward
        loss_weight = torch.randn([batch_size, output_size]).cuda()
        output, _ = linear_layer(identity_layer())
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        torch.distributed.barrier()

        assert output.dtype == dtype

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example No. 7
def parallel_transformer(tensor_model_parallel_size,
                         num_att_heads_per_partition, hidden_size_per_att_head,
                         batch_size, sequence_length):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length,
                               hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer
Example No. 8
def test_row_parallel_linear(tensor_model_parallel_size):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = parallel_state.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example No. 9
def test_parallel_embedding(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(
        0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print('   error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print('   error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        hidden_size // tensor_model_parallel_size,
        1)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print('   error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        vocab_size // tensor_model_parallel_size,
        0)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print('   error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
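
The last two gradient checks differ because the two layers shard along different axes: ParallelEmbedding splits the hidden (embedding) dimension across ranks, while VocabParallelEmbedding splits the vocabulary dimension, so the reference gradient is split along dim 1 in one check and dim 0 in the other. Below is a standalone illustration of the per-rank shard selection; the sizes match the test above and the rank value is an arbitrary assumption.

# How the reference gradient is sharded per rank in the checks above.
import torch

vocab_size, hidden_size, world_size, rank = 48, 16, 4, 1
full_grad = torch.randn(vocab_size, hidden_size)

# ParallelEmbedding: each rank owns a slice of the hidden dimension (dim=1).
hidden_shard = torch.split(full_grad, hidden_size // world_size, dim=1)[rank]
assert hidden_shard.shape == (vocab_size, hidden_size // world_size)

# VocabParallelEmbedding: each rank owns a block of the vocabulary (dim=0).
vocab_shard = torch.split(full_grad, vocab_size // world_size, dim=0)[rank]
assert vocab_shard.shape == (vocab_size // world_size, hidden_size)
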
Example No. 10
def test_column_parallel_linear(tensor_model_parallel_size):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7
    hidden_size = 9

    # Network
    gradient_accumulation_fusion = True
    identity_layer = IdentityLayer3D(batch_size, hidden_size,
                                     input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
    ).cuda()
    with torch.no_grad():
        linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)

    loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    assert list(output.shape) == [batch_size, hidden_size, output_size]
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # TODO (mkozuki): Fix the following commented out lines
    # as `gradient_accumulation_fusion` only takes 3D tensors.
    # Values.
    # dLdY = loss_weight  # (7, 9, 17)
    # X = identity_layer.weight  # (7, 9, 13)
    # A = linear_layer.master_weight.cuda()  # (17, 13)
    # print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
    # dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
    # print(f"dLdA.shape = {dLdA.shape}")
    # ones = torch.ones(batch_size, hidden_size, 1).cuda()
    # print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
    # dLdb = torch.matmul(ones, dLdY).view(-1)
    # dLdX = torch.matmul(dLdY, A)

    # rank = parallel_state.get_tensor_model_parallel_rank()
    # my_dLdA = torch.split(dLdA, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdA on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # my_dLdb = torch.split(dLdb, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdb on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # error = dLdX.sub(identity_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdX on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example No. 11
def test_initialize_affine_weight(tensor_model_parallel_size, device):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight,
            output_size,
            input_size,
            output_size_coeff,
            0,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype,
        )
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0)

    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight,
            output_size,
            input_size,
            input_size_coeff,
            1,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype,
        )
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1)

    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example No. 12
    def test_parallel_embedding(self) -> None:
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )
                set_random_seed(self.SEED + 1)
                input_tensor = torch.randint(
                    0,
                    self.VOCAB_SIZE,
                    (
                        self.BATCH_SIZE,
                        self.SEQUENCE_LENGTH,
                    ),
                    device="cuda",
                )
                loss_weight = torch.randn(
                    (
                        self.BATCH_SIZE,
                        self.SEQUENCE_LENGTH,
                        self.HIDDEN_SIZE,
                    ),
                    device="cuda",
                )

                set_random_seed(self.SEED)
                embedding_torch = nn.Embedding(
                    self.VOCAB_SIZE,
                    self.HIDDEN_SIZE,
                ).cuda()
                output_torch = embedding_torch(input_tensor)
                loss_torch = torch.mul(output_torch, loss_weight).sum()
                loss_torch.backward()

                # N.B. (mkozuki): With affine weight initialization on GPU,
                # it's super difficult to keep the consistency with nn.Embedding.
                # Thus, turning on `use_cpu_initialization`.
                set_random_seed(self.SEED)
                embedding_vocab_parallel = layers.VocabParallelEmbedding(
                    self.VOCAB_SIZE,
                    self.HIDDEN_SIZE,
                    init_method=nn.init.normal_,
                    use_cpu_initialization=True,
                ).cuda()
                output_vocab_parallel = embedding_vocab_parallel(input_tensor)
                loss_vocab_parallel = torch.mul(output_vocab_parallel,
                                                loss_weight).sum()
                loss_vocab_parallel.backward()

                self.assertEqual(output_torch, output_vocab_parallel)
                self.assertEqual(loss_torch, loss_vocab_parallel)

                splitted_weight_torch = torch.split(
                    embedding_torch.weight.grad,
                    self.VOCAB_SIZE // tensor_model_parallel_world_size,
                    0,
                )[parallel_state.get_tensor_model_parallel_rank()]
                self.assertEqual(splitted_weight_torch,
                                 embedding_vocab_parallel.weight.grad)

                parallel_state.destroy_model_parallel()
Example No. 13
    def _column_parallel_linear_test_impl(
        self,
        no_async_tensor_model_parallel_allreduce: bool,
        gradient_accumulation_fusion: bool,
    ):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            with self.subTest(tensor_model_parallel_world_size=
                              tensor_model_parallel_world_size):
                if self.world_size % tensor_model_parallel_world_size:
                    continue
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                feature_size_coeff = self.INPUT_SIZE_COEFF
                feature_size = feature_size_coeff * tensor_model_parallel_world_size
                hidden_size = feature_size

                set_random_seed(self.SEED)
                input_tensor = torch.randn(
                    self.BATCH_SIZE,
                    hidden_size,
                    feature_size,
                    device="cuda",
                    requires_grad=True,
                )
                input_tensor.retain_grad()
                loss_weight = torch.randn(
                    (
                        self.BATCH_SIZE,
                        hidden_size,
                        feature_size,
                    ),
                    device="cuda",
                )
                linear = layers.ColumnParallelLinear(
                    feature_size,
                    feature_size,
                    bias=False,
                    keep_master_weight_for_test=True,
                    params_dtype=torch.float32,
                    use_cpu_initialization=True,
                    no_async_tensor_model_parallel_allreduce=
                    no_async_tensor_model_parallel_allreduce,
                    gradient_accumulation_fusion=gradient_accumulation_fusion,
                ).cuda()
                if gradient_accumulation_fusion:
                    with torch.no_grad():
                        linear.weight.main_grad = torch.randn_like(
                            linear.weight)
                output, _ = linear(input_tensor)
                self.assertEqual(
                    output.shape,
                    (
                        self.BATCH_SIZE,
                        hidden_size,
                        feature_size,
                    ),
                )
                loss = torch.mul(output, loss_weight).sum()
                loss.backward()

                with torch.no_grad():
                    dldy = loss_weight.clone()
                    x = input_tensor.clone()
                    a = linear.master_weight.cuda().clone()
                dldx = torch.matmul(dldy, a)
                self.assertEqual(input_tensor.grad, dldx)
                # TODO(mkozuki): Cover the other cases.
                if (tensor_model_parallel_world_size == 1
                        and not gradient_accumulation_fusion):
                    dlda = torch.matmul(torch.transpose(dldy, 1, 2),
                                        x).sum(dim=0)
                    curr_dlda = torch.split(
                        dlda, feature_size_coeff,
                        dim=0)[parallel_state.get_tensor_model_parallel_rank()]
                    self.assertEqual(linear.weight.grad, curr_dlda)

                parallel_state.destroy_model_parallel()
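
For the 3D inputs used here, the reference weight gradient in the world_size == 1 branch is the batched product summed over the batch, dL/da = sum_b dldy[b].T @ x[b]. The sketch below checks that identity against autograd in a single process, again with torch.nn.functional.linear standing in for the parallel layer (an assumption, not the layer under test).

# 3D gradient identity used in the world_size == 1 check above.
import torch

batch, hidden, feature = 7, 13, 13
x = torch.randn(batch, hidden, feature, requires_grad=True)
a = torch.randn(feature, feature, requires_grad=True)
dldy = torch.randn(batch, hidden, feature)

y = torch.nn.functional.linear(x, a)  # y[b] = x[b] @ a.T, no bias
loss = torch.mul(y, dldy).sum()
loss.backward()

torch.testing.assert_close(
    a.grad, torch.matmul(dldy.transpose(1, 2), x).sum(dim=0))
torch.testing.assert_close(x.grad, torch.matmul(dldy, a))
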