Example #1
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = parallel_state.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer
Example #2
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with tensor_parallel.random.get_cuda_rng_tracker().fork():
        assert (torch.cuda.initial_seed() == 12345 + 2718 +
                parallel_state.get_tensor_model_parallel_rank())

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
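The assertion above documents the seeding convention: inside the tracker's fork, every tensor-parallel rank draws from a stream seeded with the base seed plus a fixed offset of 2718 plus its rank. A standalone sketch of that arithmetic (plain Python; the rank count of 4 is arbitrary):

base_seed = 12345
offset = 2718  # fixed offset separating the model-parallel stream from the default one
seeds = [base_seed + offset + tp_rank for tp_rank in range(4)]
# Every tensor-parallel rank gets a distinct dropout/init stream inside fork().
assert len(set(seeds)) == 4 and seeds[0] == 12345 + 2718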
Example #3
def parallel_self_attention(tensor_model_parallel_size,
                            num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = parallel_state.BertParallelSelfAttention(
        hidden_size, num_att_heads, dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length,
                               hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = parallel_state.get_tensor_model_parallel_rank()
    parallel_state.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer
Example #4
    def forward(ctx, vocab_parallel_logits, target):

        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_max,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=get_tensor_model_parallel_group())
        # Subtract the maximum value.
        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size)

        # Mask target ids that fall outside this partition's vocab range
        # (True means the id lives on another partition and must be masked).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0,
                                 end=logits_2d.size()[0],
                                 device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(predicted_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=get_tensor_model_parallel_group())

        # Sum of exponential of logits along vocab dimension across all GPUs.
        # The exp is computed in place, reusing the logits buffer to save memory.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(sum_exp_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=get_tensor_model_parallel_group())

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Store softmax, target-mask and masked-target for backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss
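With a single partition (world size 1), the forward pass above reduces to ordinary cross entropy: loss = log(sum(exp(logits - max))) - (logits - max)[target]. A minimal single-process sketch (plain PyTorch on CPU, no model parallelism; sizes are arbitrary) checking that identity against torch.nn.functional.cross_entropy:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)   # [tokens, vocab]
target = torch.randint(0, 10, (4,))

# Stable log-sum-exp minus the predicted logit, mirroring the kernel above
# with world_size == 1 (no all-reduces needed).
shifted = logits - logits.max(dim=-1, keepdim=True)[0]
predicted = shifted[torch.arange(4), target]
loss = torch.log(torch.exp(shifted).sum(dim=-1)) - predicted

assert torch.allclose(loss, F.cross_entropy(logits, target, reduction='none'))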
Example #5
File: utils.py Project: kexinyu/apex
def split_tensor_into_1d_equal_chunks(tensor):
    """Break a tensor into equal 1D chunks."""
    data = tensor.view(-1)
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    partition_size = torch.numel(data) // world_size
    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    return data[start_index:end_index]
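A local illustration of the slicing above, with torch.distributed replaced by a hypothetical fixed world size: rank r keeps the r-th contiguous slice of the flattened tensor, and the slices tile it exactly (the function assumes numel is divisible by the world size):

import torch

world_size = 4
tensor = torch.arange(32, dtype=torch.float32).reshape(8, 4)
data = tensor.view(-1)
partition_size = data.numel() // world_size
chunks = [data[r * partition_size:(r + 1) * partition_size]
          for r in range(world_size)]
assert all(chunk.numel() == 8 for chunk in chunks)
assert torch.equal(torch.cat(chunks), data)  # the slices tile the flat tensor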
Example #6
File: layers.py Project: NVIDIA/apex
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        init_method=init.xavier_normal_,
        *,
        params_dtype=torch.float32,
        use_cpu_initialization=False,
    ):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.0
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_tensor_model_parallel_rank(),
            self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights and initialize.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            self.embedding_dim,
                            dtype=params_dtype))
            _initialize_affine_weight_cpu(
                self.weight,
                self.num_embeddings,
                self.embedding_dim,
                self.num_embeddings_per_partition,
                0,
                init_method,
                params_dtype=params_dtype,
            )
        else:
            self.weight = Parameter(
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                ))
            _initialize_affine_weight_gpu(self.weight,
                                          init_method,
                                          partition_dim=0,
                                          stride=1)
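For reference, the vocab-range helper used above just hands each rank a contiguous, equally sized slice of vocabulary ids. A standalone sketch of that arithmetic (an illustrative reimplementation, not the apex source):

def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    # Rank r owns the contiguous id range [r * per_partition, (r + 1) * per_partition).
    per_partition_vocab_size = global_vocab_size // world_size
    index_first = rank * per_partition_vocab_size
    return index_first, index_first + per_partition_vocab_size

# With a vocab of 48 split over 4 ranks, rank 1 owns ids [12, 24).
assert vocab_range_from_global_vocab_size(48, 1, 4) == (12, 24)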
Example #7
def test__reduce(args, tensor_model_parallel_size):
    print("Testing reduction size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    assert torch.equal(
        mappings._reduce(torch.full((10, 10, 10, 10), 50)),
        torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")
Example #8
def test_broadcast_data(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print(
            '> testing broadcast_data with model parallel size {} ...'.format(
                tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12],
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, \
        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example #9
def _reduce(input_):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_,
                                 group=get_tensor_model_parallel_group())

    return input_
Example #10
def test__gather(args, tensor_model_parallel_size):

    print("Testing gathering size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    assert torch.equal(
        mappings._gather(
            torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
        torch.tensor(list(range(tensor_model_parallel_size))),
    )
    parallel_state.destroy_model_parallel()
    print("Passed!")
Example #11
File: utils.py Project: kexinyu/apex
def gather_split_1d_tensor(tensor):
    """Opposite of above function, gather values from model parallel ranks."""
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(numel_gathered,
                           dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    torch.distributed.all_gather(
        chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
    return gathered
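Note the buffer trick above: the per-rank chunks are views into the preallocated gathered tensor, so all_gather writes each rank's slice directly into the final buffer and no extra concatenation is needed. A single-process sketch of the same aliasing (copy_ stands in for the collective):

import torch

gathered = torch.empty(8)
chunks = [gathered[i * 4:(i + 1) * 4] for i in range(2)]
chunks[0].copy_(torch.ones(4))    # what rank 0's contribution would fill in
chunks[1].copy_(torch.zeros(4))   # what rank 1's contribution would fill in
assert torch.equal(gathered, torch.tensor([1., 1., 1., 1., 0., 0., 0., 0.]))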
Example #12
def test__split(args, tensor_model_parallel_size):
    print("Testing splitting size =", tensor_model_parallel_size)
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    tensors = [torch.randn(10, 1) for _ in range(tensor_model_parallel_size)]
    x = torch.cat(tensors, dim=1)
    out = mappings._split(x)
    assert torch.equal(
        out, tensors[parallel_state.get_tensor_model_parallel_rank()])
    parallel_state.destroy_model_parallel()
    print("Passed!")
Example #13
def _split(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output
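split_tensor_along_last_dim is essentially torch.split with an even chunk size along the last dimension. A minimal local sketch of what _split does on one rank (hypothetical world_size and rank, no process group):

import torch

world_size, rank = 2, 1
x = torch.randn(3, 8)
last_dim_chunk = x.size(-1) // world_size
input_list = torch.split(x, last_dim_chunk, dim=-1)
# torch.split returns views that may be non-contiguous, hence .contiguous().
output = input_list[rank].contiguous()
assert output.shape == (3, 4) and output.is_contiguous()
assert torch.equal(output, x[:, 4:])  # rank 1 keeps the second half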
Example #14
def test_column_parallel_linear_with_async_allreduce_custom_amp(
        tensor_model_parallel_size):
    if torch.cuda.is_bf16_supported():
        dtypes = (torch.half, torch.bfloat16)
    else:
        dtypes = (torch.half,)

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    for dtype in dtypes:
        # Network
        identity_layer = IdentityLayer3D(
            batch_size, batch_size, input_size).to(device="cuda", dtype=dtype)
        linear_layer = layers.ColumnParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=global_vars.get_args().params_dtype,
            use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        ).to(device="cuda", dtype=dtype)
        # Forward
        loss_weight = torch.randn([batch_size, output_size]).cuda()
        output, _ = linear_layer(identity_layer())
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        torch.distributed.barrier()

        assert output.dtype == dtype

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example #15
File: layers.py Project: NVIDIA/apex
def _initialize_affine_weight_cpu(
    weight,
    output_size,
    input_size,
    per_partition_size,
    partition_dim,
    init_method,
    stride=1,
    return_master_weight=False,
    *,
    params_dtype=torch.float32,
):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size,
                                input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight,
                              per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None
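The weight_list[rank::world_size] indexing realizes the stride: the master weight is cut into stride * world_size pieces along the partition dimension, and each rank takes every world_size-th piece. A toy illustration with 8 rows, world_size=2, stride=2:

import torch

master = torch.arange(8.).reshape(8, 1)
world_size, stride = 2, 2
per_partition_size = master.size(0) // world_size             # 4 rows per rank
per_partition_per_stride_size = per_partition_size // stride  # 2 rows per piece
pieces = torch.split(master, per_partition_per_stride_size, dim=0)
rank0 = torch.cat(pieces[0::world_size], dim=0)  # pieces 0 and 2 -> rows 0,1,4,5
rank1 = torch.cat(pieces[1::world_size], dim=0)  # pieces 1 and 3 -> rows 2,3,6,7
assert rank0.squeeze().tolist() == [0., 1., 4., 5.]
assert rank1.squeeze().tolist() == [2., 3., 6., 7.]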
Example #16
    def test_initialize_model_parallel(self) -> None:

        self.assertFalse(parallel_state.model_parallel_is_initialized())

        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            with self.subTest(
                    tensor_model_parallel_world_size=tensor_model_parallel_world_size):
                if self.world_size % tensor_model_parallel_world_size:
                    continue

                pipeline_model_parallel_world_size = (
                    self.world_size // tensor_model_parallel_world_size)

                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size,
                    pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                )
                self.assertEqual(
                    tensor_model_parallel_world_size,
                    parallel_state.get_tensor_model_parallel_world_size(),
                )
                expected_tensor_model_parallel_rank = (
                    calc_expected_tensor_model_paralell_rank(
                        self.rank, tensor_model_parallel_world_size))
                self.assertEqual(
                    expected_tensor_model_parallel_rank,
                    parallel_state.get_tensor_model_parallel_rank(),
                )

                expected_tensor_model_parallel_src_rank = (
                    self.rank // tensor_model_parallel_world_size
                ) * tensor_model_parallel_world_size
                self.assertEqual(
                    expected_tensor_model_parallel_src_rank,
                    parallel_state.get_tensor_model_parallel_src_rank(),
                )

                parallel_state.destroy_model_parallel()
                self.assertFalse(
                    parallel_state.model_parallel_is_initialized())
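The expected-rank formulas in this test assume tensor-parallel groups are formed from consecutive global ranks: the group's source rank is (rank // tp_size) * tp_size, and a rank's position within its group is the remainder. A quick standalone check of that arithmetic (hypothetical world size of 8, tp_size of 2):

world_size, tp_size = 8, 2
for rank in range(world_size):
    tp_rank = rank % tp_size                # position inside the tensor-parallel group
    src_rank = (rank // tp_size) * tp_size  # first (source) rank of that group
    assert src_rank <= rank < src_rank + tp_size
    assert rank == src_rank + tp_rank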
Example #17
def test_cross_entropy(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale,
                                                 seed)
    loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length,
                                                      vocab_size, logits_scale,
                                                      seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example #18
def _gather(input_):
    """Gather tensors and concatinate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list,
                                 input_,
                                 group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
Example #19
def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example #20
def test_cuda_rng_tracker(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print('   max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(
        result_11.sub(target_11).abs().max(),
        result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    tensor_parallel.random.get_cuda_rng_tracker().reset()

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example #21
def test_set_cuda_rng_state(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.format(
            tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(seed)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        '   max diff in rng state (should be non-zero) on global rank {}: {}'.
        format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    tensor_parallel.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print('   max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example #22
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: torch.dtype = torch.float,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
    """Base function for communication of tensors between stages.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains fewer tokens than the default sequence length
        override_scatter_gather_tensors_in_pipeline:
            optional, used together with tensor_shape to bypass the scatter-gather optimization
        dtype_: the dtype of the tensors to communicate; used when tensor_shape is provided

    Keyword args:
        scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
            your model deliberately, pass this argument.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
    else:
        tensor_chunk_shape = tensor_shape

    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    # dtype = params_dtype or torch.float
    # if fp32_residual_connection:
    #     dtype = torch.float
    # if dtype_ is not None:
    #     dtype = dtype_
    #     requires_grad = False
    if dtype_ != torch.float32 or params_dtype is not None:
        if torch.distributed.get_rank() == 0:
            warnings.warn("Tensor P2P communications are executed in FP32")
    dtype = torch.float32
    requires_grad = True

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = (
                gather_split_1d_tensor(tensor_recv_prev)
                .view(tensor_shape)
                .requires_grad_()
            )

        if recv_next:
            tensor_recv_next = (
                gather_split_1d_tensor(tensor_recv_next)
                .view(tensor_shape)
                .requires_grad_()
            )

    return tensor_recv_prev, tensor_recv_next
Example #23
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: Optional[torch.dtype] = None,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
    async_comm: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor,
                                                          FutureTensor, None]]:
    """Base function for communication of tensors between stages.

    dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
    torch.float32 is used.

    See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
    for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains fewer tokens than the default sequence length
        override_scatter_gather_tensors_in_pipeline:
            optional, used together with tensor_shape to bypass the scatter-gather optimization
        dtype_: the dtype of the tensors to communicate; used when tensor_shape is provided

    Keyword args:
        scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
            your model deliberately, pass this argument.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
        )
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1) //
            parallel_state.get_tensor_model_parallel_world_size(), )
    else:
        tensor_chunk_shape = tensor_shape

    # The dtype logic below is copied from NVIDIA/Megatron-LM repo:
    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
    # NOTE (mkozuki): Currently NeMo is implementing APEX AMP O2 style using PyTorch. In O2 style, forcing p2p comm to
    # use FP32 would be a perf killer, so I decided to reanimate the `dtype_` argument with a default value of `None`.
    # NOTE (mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    dtype = params_dtype or torch.float
    if fp32_residual_connection:
        dtype = torch.float
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(
                tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(
                tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(
        tensor_send_prev,
        tensor_send_next,
        tensor_recv_prev,
        tensor_recv_next,
        async_comm=async_comm)

    if async_comm:
        tensor_recv_prev_waitfunc = None
        tensor_recv_next_waitfunc = None
        # TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642)
        # see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
        if tensor_recv_prev_req is not None:

            def tensor_recv_prev_wait():
                tensor_recv_prev_req.wait()
                torch.cuda.synchronize()

            tensor_recv_prev_waitfunc = tensor_recv_prev_wait
        if tensor_recv_next_req is not None:

            def tensor_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()

            tensor_recv_next_waitfunc = tensor_recv_next_wait
    else:
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if not async_comm:
            if recv_prev:
                tensor_recv_prev = (
                    gather_split_1d_tensor(tensor_recv_prev)
                    .view(tensor_shape)
                    .requires_grad_()
                )

            if recv_next:
                tensor_recv_next = (
                    gather_split_1d_tensor(tensor_recv_next)
                    .view(tensor_shape)
                    .requires_grad_()
                )
        else:

            def gather_recv_prev_wait():
                tensor_recv_prev_req.wait()
                # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
                # A sync seems to be needed before gather otherwise losses jump around e.g., in run_gpt_minimal_test
                torch.cuda.synchronize()
                return (
                    gather_split_1d_tensor(tensor_recv_prev)
                    .view(tensor_shape)
                    .requires_grad_()
                )

            def gather_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
                return (
                    gather_split_1d_tensor(tensor_recv_next)
                    .view(tensor_shape)
                    .requires_grad_()
                )

            tensor_recv_prev_waitfunc = gather_recv_prev_wait
            tensor_recv_next_waitfunc = gather_recv_next_wait
    if async_comm:
        future_tensor_recv_prev = None
        future_tensor_recv_next = None
        if tensor_recv_prev is not None:
            future_tensor_recv_prev = FutureTensor(tensor_recv_prev,
                                                   tensor_recv_prev_waitfunc)
        if tensor_recv_next is not None:
            future_tensor_recv_next = FutureTensor(tensor_recv_next,
                                                   tensor_recv_next_waitfunc)
        return future_tensor_recv_prev, future_tensor_recv_next

    return tensor_recv_prev, tensor_recv_next
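With the scatter-gather optimization, each rank communicates only numel(tensor) / world_size elements and the full tensor is reassembled afterwards by gather_split_1d_tensor. A quick check of the chunk-shape arithmetic above (hypothetical sizes):

from functools import reduce
import operator

tensor_shape = (1024, 4, 1536)   # (seq_length, micro_batch_size, hidden_size)
world_size = 8
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // world_size,)
assert tensor_chunk_shape == (786432,)  # 1024 * 4 * 1536 // 8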
Example #24
def test_row_parallel_linear(tensor_model_parallel_size):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.RowParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
    ).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = parallel_state.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example #25
def test_parallel_embedding(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(
        0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print('   error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print('   error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        hidden_size // tensor_model_parallel_size,
        1)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print('   error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(
        embedding_original.weight.grad,
        vocab_size // tensor_model_parallel_size,
        0)[parallel_state.get_tensor_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print('   error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example #26
def test_column_parallel_linear(tensor_model_parallel_size):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7
    hidden_size = 9

    # Network
    gradient_accumulation_fusion = True
    identity_layer = IdentityLayer3D(batch_size, hidden_size,
                                     input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(
        input_size,
        output_size,
        keep_master_weight_for_test=True,
        params_dtype=global_vars.get_args().params_dtype,
        use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
    ).cuda()
    with torch.no_grad():
        linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)

    loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output, _ = linear_layer(input_)
    assert list(output.shape) == [batch_size, hidden_size, output_size]
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # TODO (mkozuki): Fix the following commented out lines
    # as `gradient_accumulation_fusion` only takes 3D tensors.
    # Values.
    # dLdY = loss_weight  # (7, 9, 17)
    # X = identity_layer.weight  # (7, 9, 13)
    # A = linear_layer.master_weight.cuda()  # (17, 13)
    # print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
    # dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
    # print(f"dLdA.shape = {dLdA.shape}")
    # ones = torch.ones(batch_size, hidden_size, 1).cuda()
    # print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
    # dLdb = torch.matmul(ones, dLdY).view(-1)
    # dLdX = torch.matmul(dLdY, A)

    # rank = parallel_state.get_tensor_model_parallel_rank()
    # my_dLdA = torch.split(dLdA, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdA on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # my_dLdb = torch.split(dLdb, output_size_coeff,
    #                       dim=0)[rank].contiguous().clone()
    # error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdb on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # error = dLdX.sub(identity_layer.weight.grad).abs().max()
    # torch.distributed.barrier()
    # print('   error in dLdX on global rank {}: {}'.format(
    #     torch.distributed.get_rank(), error))
    # assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example #27
def test_initialize_affine_weight(tensor_model_parallel_size, device):

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight,
            output_size,
            input_size,
            output_size_coeff,
            0,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype,
        )
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0)

    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    if device == 'cpu':
        layers._initialize_affine_weight_cpu(
            weight,
            output_size,
            input_size,
            input_size_coeff,
            1,
            torch.nn.init.normal_,
            params_dtype=global_vars.get_args().params_dtype,
        )
    else:
        layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1)

    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = parallel_state.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example #28
    def __init__(
        self,
        init_method,
        output_layer_init_method,
        layer_number,
        num_attention_heads,
        hidden_size,
        attention_type=AttnType.self_attn,
        attn_mask_type=AttnMaskType.padding,
        precision=16,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        use_cpu_initialization=False,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
    ):
        super(ParallelAttention, self).__init__()

        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = False
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
        projection_size = kv_channels * num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = safe_divide(projection_size,
                                                     world_size)
        self.hidden_size_per_attention_head = safe_divide(
            projection_size, num_attention_heads)
        self.num_attention_heads_per_partition = safe_divide(
            num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method,
                use_cpu_initialization=use_cpu_initialization,
            )
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        fused_fp16 = precision == 16
        fused_bf16 = precision == 'bf16'
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            fused_fp16,
            fused_bf16,
            self.attn_mask_type,
            masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
        )
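To make the partitioning arithmetic above concrete, here is a standalone sketch with made-up sizes; the names mirror the attributes computed in __init__, but nothing here is part of the module itself:

import math

hidden_size, num_attention_heads = 1024, 16
world_size, layer_number = 4, 3   # hypothetical tensor-parallel setup
kv_channels = hidden_size // num_attention_heads                       # 64
projection_size = kv_channels * num_attention_heads                    # 1024
hidden_size_per_partition = projection_size // world_size              # 256
num_attention_heads_per_partition = num_attention_heads // world_size  # 4
# With apply_query_key_layer_scaling, the softmax scale absorbs the layer
# number, which the fused softmax later undoes via `coeff`.
norm_factor = math.sqrt(kv_channels) * layer_number                    # 8.0 * 3 = 24.0
print(hidden_size_per_partition, num_attention_heads_per_partition, norm_factor)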
Example #29
    def __init__(
        self,
        input_size,
        output_size,
        bias=True,
        gather_output=True,
        init_method=init.xavier_normal_,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        *,
        no_async_tensor_model_parallel_allreduce=False,
        params_dtype=torch.float32,
        use_cpu_initialization=False,
        gradient_accumulation_fusion=False,
        accumulation_in_fp16: bool = False,
    ):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.output_size_per_partition, self.input_size, dtype=params_dtype
                )
            )
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight,
                self.output_size,
                self.input_size,
                self.output_size_per_partition,
                0,
                init_method,
                stride=stride,
                return_master_weight=keep_master_weight_for_test,
                params_dtype=params_dtype,
            )
        else:
            self.weight = Parameter(
                torch.empty(
                    self.output_size_per_partition,
                    self.input_size,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                )
            )
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=0, stride=stride
            )

        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(
                    torch.empty(self.output_size_per_partition, dtype=params_dtype)
                )
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size_per_partition,
                        device=torch.cuda.current_device(),
                        dtype=params_dtype,
                    )
                )
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        self.async_tensor_model_parallel_allreduce = (
            not no_async_tensor_model_parallel_allreduce and world_size > 1
        )
        if gradient_accumulation_fusion:
            if not _grad_accum_fusion_available:
                # Basically, apex.transformer module users are expected to install
                # APEX with both `--cpp_ext` and `--cuda_ext`. The example installation
                # command is as follows:
                # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .`
                # at the root of the APEX repository.
                import warnings

                warnings.warn(
                    "`gradient_accumulation_fusion` is set to `True` but "
                    "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
                    "found. Thus `gradient_accumulation_fusion` set to `False`. "
                    "Note that the extension requires CUDA>=11."
                )
                gradient_accumulation_fusion = False
        self.gradient_accumulation_fusion = gradient_accumulation_fusion

        self._forward_impl = (
            linear_with_grad_accumulation_and_async_allreduce_in16bit
            if accumulation_in_fp16
            else linear_with_grad_accumulation_and_async_allreduce
        )
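A quick single-process sketch of the column-parallel math: plain torch.chunk stands in for the sharded initialization, and the final concatenation corresponds to what gather_output=True does. The sizes are hypothetical:

import torch
import torch.nn.functional as F

input_size, output_size, world_size = 8, 12, 3
x = torch.randn(2, input_size)
full_weight = torch.randn(output_size, input_size)   # allocated as the transpose
# Column parallelism: shard the output dimension, i.e. dim 0 of the weight.
shards = torch.chunk(full_weight, world_size, dim=0)
partials = [F.linear(x, w) for w in shards]          # each partial is (2, 4)
gathered = torch.cat(partials, dim=-1)               # what gather_output=True does
assert torch.allclose(gathered, F.linear(x, full_weight), atol=1e-5)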
Example #30
    def __init__(
        self,
        input_size,
        output_size,
        bias=True,
        input_is_parallel=False,
        init_method=init.xavier_normal_,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        *,
        params_dtype=torch.float32,
        use_cpu_initialization=False,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.output_size, self.input_size_per_partition, dtype=params_dtype
                )
            )
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight,
                self.output_size,
                self.input_size,
                self.input_size_per_partition,
                1,
                init_method,
                stride=stride,
                return_master_weight=keep_master_weight_for_test,
                params_dtype=params_dtype,
            )
        else:
            self.weight = Parameter(
                torch.empty(
                    self.output_size,
                    self.input_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                )
            )
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=1, stride=stride
            )
        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size,
                        device=torch.cuda.current_device(),
                        dtype=params_dtype,
                    )
                )
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
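And the complementary row-parallel sketch: the input dimension is sharded (dim 1 of the weight), each partition computes a partial product, and a plain Python sum stands in for the all-reduce the real layer performs in its forward pass. Again, the sizes are hypothetical:

import torch
import torch.nn.functional as F

input_size, output_size, world_size = 12, 8, 3
x = torch.randn(2, input_size)
full_weight = torch.randn(output_size, input_size)
w_shards = torch.chunk(full_weight, world_size, dim=1)  # shard dim 1 (input)
x_shards = torch.chunk(x, world_size, dim=1)            # input_is_parallel=True
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
output = sum(partials)                                  # stands in for the all-reduce
assert torch.allclose(output, F.linear(x, full_weight), atol=1e-5)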