Example #1
def test_column_parallel_linear_with_async_allreduce_custom_amp(
        tensor_model_parallel_size):
    # Run in fp16, and additionally in bf16 when the device supports it.
    dtypes = ((torch.half, torch.bfloat16)
              if torch.cuda.is_bf16_supported() else (torch.half,))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = (
        parallel_state.get_tensor_model_parallel_world_size())

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    for dtype in dtypes:
        # Network
        identity_layer = IdentityLayer3D(
            batch_size, batch_size, input_size,
        ).to(device="cuda", dtype=dtype)
        linear_layer = layers.ColumnParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=global_vars.get_args().params_dtype,
            use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        ).to(device="cuda", dtype=dtype)
        # Forward
        loss_weight = torch.randn([batch_size, output_size]).cuda()
        output, _ = linear_layer(identity_layer())
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        torch.distributed.barrier()

        assert output.dtype == dtype

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
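
A hedged sketch of one way to drive the test above: each process initializes torch.distributed, pins a GPU, and calls the test for every tensor-parallel size that divides the world size. The torchrun launcher and the size-sweep loop are assumptions here, not part of the original harness, and global_vars is assumed to have been populated (e.g. by argument parsing) before the test runs.

# Assumed launch: torchrun --nproc_per_node=<num_gpus> this_file.py
import torch

if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    # Pin one GPU per rank (single-node assumption).
    torch.cuda.set_device(
        torch.distributed.get_rank() % torch.cuda.device_count())
    world_size = torch.distributed.get_world_size()
    tp_size = 1
    while tp_size <= world_size:
        if world_size % tp_size == 0:
            test_column_parallel_linear_with_async_allreduce_custom_amp(tp_size)
        tp_size *= 2
    torch.distributed.destroy_process_group()
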
Example #2
    def _column_parallel_linear_test_impl(
        self,
        no_async_tensor_model_parallel_allreduce: bool,
        gradient_accumulation_fusion: bool,
    ):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            with self.subTest(
                    tensor_model_parallel_world_size=tensor_model_parallel_world_size):
                if self.world_size % tensor_model_parallel_world_size:
                    continue
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size,
                )

                feature_size_coeff = self.INPUT_SIZE_COEFF
                feature_size = feature_size_coeff * tensor_model_parallel_world_size
                hidden_size = feature_size

                set_random_seed(self.SEED)
                input_tensor = torch.randn(
                    self.BATCH_SIZE,
                    hidden_size,
                    feature_size,
                    device="cuda",
                    requires_grad=True,
                )
                input_tensor.retain_grad()
                loss_weight = torch.randn(
                    (
                        self.BATCH_SIZE,
                        hidden_size,
                        feature_size,
                    ),
                    device="cuda",
                )
                linear = layers.ColumnParallelLinear(
                    feature_size,
                    feature_size,
                    bias=False,
                    keep_master_weight_for_test=True,
                    params_dtype=torch.float32,
                    use_cpu_initialization=True,
                    no_async_tensor_model_parallel_allreduce=(
                        no_async_tensor_model_parallel_allreduce),
                    gradient_accumulation_fusion=gradient_accumulation_fusion,
                ).cuda()
                if gradient_accumulation_fusion:
                    # The fused path accumulates the weight gradient into
                    # `main_grad`, which therefore must exist before backward.
                    with torch.no_grad():
                        linear.weight.main_grad = torch.randn_like(linear.weight)
                output, _ = linear(input_tensor)
                self.assertEqual(
                    output.shape,
                    (
                        self.BATCH_SIZE,
                        hidden_size,
                        feature_size,
                    ),
                )
                loss = torch.mul(output, loss_weight).sum()
                loss.backward()

                with torch.no_grad():
                    dldy = loss_weight.clone()
                    x = input_tensor.clone()
                    a = linear.master_weight.cuda().clone()
                # For y = x @ a^T: dL/dx = dL/dy @ a.
                dldx = torch.matmul(dldy, a)
                self.assertEqual(input_tensor.grad, dldx)
                # TODO(mkozuki): Cover the other cases.
                if (tensor_model_parallel_world_size == 1
                        and not gradient_accumulation_fusion):
                    # dL/da = sum_b (dL/dy_b)^T @ x_b; rank r owns rows
                    # [r * feature_size_coeff : (r + 1) * feature_size_coeff].
                    dlda = torch.matmul(dldy.transpose(1, 2), x).sum(dim=0)
                    curr_dlda = torch.split(
                        dlda, feature_size_coeff, dim=0,
                    )[parallel_state.get_tensor_model_parallel_rank()]
                    self.assertEqual(linear.weight.grad, curr_dlda)

                parallel_state.destroy_model_parallel()
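
The manual check above relies on the closed-form gradients of y = x·Aᵀ for a 3D input: dL/dx = dL/dy·A and dL/dA = Σ_b (dL/dy_b)ᵀ·x_b. A minimal plain-PyTorch sketch (no tensor parallelism; shapes and names are chosen here only for illustration) that confirms those identities against autograd:

import torch

torch.manual_seed(0)
batch, hidden, feat = 4, 5, 6
x = torch.randn(batch, hidden, feat, requires_grad=True)
a = torch.randn(feat, feat, requires_grad=True)  # plays the role of linear.weight
dldy = torch.randn(batch, hidden, feat)

# Same loss construction as the tests: L = sum(y * loss_weight), so dL/dy = loss_weight.
y = torch.matmul(x, a.t())
torch.mul(y, dldy).sum().backward()

assert torch.allclose(x.grad, torch.matmul(dldy, a), atol=1e-5)
assert torch.allclose(a.grad, torch.matmul(dldy.transpose(1, 2), x).sum(dim=0), atol=1e-5)
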
Example #3
def test_column_parallel_linear(tensor_model_parallel_size):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = (
        parallel_state.get_tensor_model_parallel_world_size())

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7
    hidden_size = 9

    # Network
    gradient_accumulation_fusion = True
    identity_layer = IdentityLayer3D(batch_size, hidden_size,
                                     input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
    ).cuda()
    with torch.no_grad():
        linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)

    loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    assert list(output.shape) == [batch_size, hidden_size, output_size]
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # TODO (mkozuki): Fix the commented-out checks below, since
    # `gradient_accumulation_fusion` only takes 3D tensors.
    # Values.
    # dLdY = loss_weight  # (7, 9, 17)
    # X = identity_layer.weight  # (7, 9, 13)
    # A = linear_layer.master_weight.cuda()  # (17, 13)
    # print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
    # dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
    # print(f"dLdA.shape = {dLdA.shape}")
    # ones = torch.ones(batch_size, hidden_size, 1).cuda()
    # print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
    # dLdb = torch.matmul(ones, dLdY).view(-1)
    # dLdX = torch.matmul(dLdY, A)

    # rank = parallel_state.get_tensor_model_parallel_rank()
    # my_dLdA = torch.split(dLdA, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdA on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # my_dLdb = torch.split(dLdb, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdb on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # error = dLdX.sub(identity_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdX on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
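
One way the commented-out reference check above could be completed: flatten the (batch, hidden) dimensions so the math reduces to the 2D case of Example #4 below, and compare against the change in weight.main_grad rather than weight.grad. This is only a sketch; it assumes the fused gradient-accumulation path writes the weight gradient into main_grad (which is why the test pre-populates it), and the tolerance may need loosening for low-precision params_dtype.

# Sketch only: capture main_grad before backward.
main_grad_before = linear_layer.weight.main_grad.detach().clone()
# ... run the forward/backward pass exactly as in the test above, then:
dLdY = loss_weight                               # (batch, hidden, output_size)
X = identity_layer.weight.detach()               # (batch, hidden, input_size)
dLdA = torch.matmul(dLdY.reshape(-1, output_size).t(),
                    X.reshape(-1, input_size))   # (output_size, input_size)
rank = parallel_state.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff, dim=0)[rank].contiguous()
accumulated = linear_layer.weight.main_grad - main_grad_before
assert accumulated.sub(my_dLdA).abs().max() < 1.0e-5
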
Example #4
def test_column_parallel_linear(tensor_model_parallel_size):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = (
        parallel_state.get_tensor_model_parallel_world_size())

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = parallel_state.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
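
For reference, the dim-0 split used in the checks above mirrors how ColumnParallelLinear partitions its (output_size, input_size) weight across tensor-parallel ranks: rank r owns rows [r * output_size_coeff, (r + 1) * output_size_coeff). A hedged illustration (the helper below is not part of the test file) of reassembling the full weight from the per-rank shards, which can then be compared against linear_layer.master_weight:

def gather_column_parallel_weight(local_weight):
    """All-gather the per-rank row shards and stitch the full weight back together."""
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    shards = [torch.empty_like(local_weight) for _ in range(world_size)]
    torch.distributed.all_gather(
        shards,
        local_weight.detach().contiguous(),
        group=parallel_state.get_tensor_model_parallel_group(),
    )
    return torch.cat(shards, dim=0)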