Example 1
def slice_state_dict(config, loaded_state_dict):
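    """Slice a full-model state dict down to the shard owned by this rank.

    Keeps only the layers assigned to the current pipeline stage (renumbered to
    local indices) and, within each kept layer, takes this model-parallel rank's
    slice of the partitioned parameters: attn.in_proj and fc1 are split along
    dim 0, attn.out_proj.weight and fc2.weight along dim 1, and the remaining
    parameters (including out_proj/fc2 biases) are copied whole.
    """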
    sliced_state_dict = OrderedDict()
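    # First global layer owned by this pipeline stage: each stage gets
    # n_total_layers // world_size layers, and the first
    # n_total_layers % world_size stages take one extra layer.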
    start_layer_id = (
        config.n_total_layers // mpu.get_pipeline_parallel_world_size() *
        mpu.get_pipeline_parallel_group_rank() +
        min(mpu.get_pipeline_parallel_group_rank(),
            config.n_total_layers % mpu.get_pipeline_parallel_world_size()))
    end_layer_id = start_layer_id + config.n_layers
    for key, value in loaded_state_dict.items():
        keys = key.split('.')
        global_layer_id = int(keys[2])
        if start_layer_id <= global_layer_id < end_layer_id:
            local_layer_id = global_layer_id - start_layer_id
            new_key = '.'.join(keys[:2] + [str(local_layer_id)] + keys[3:])
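            # Column-parallel parameters (attn.in_proj, fc1) are sliced along
            # dim 0; row-parallel weights (attn.out_proj, fc2) along dim 1,
            # with the corresponding biases kept whole.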
            if keys[3] == 'attn' and keys[4] == 'in_proj':
                in_size = mpu.divide(value.size(0),
                                     mpu.get_model_parallel_world_size())
                if keys[5] in ('weight', 'bias'):
                    new_value = value[mpu.get_model_parallel_rank() *
                                      in_size:(mpu.get_model_parallel_rank() +
                                               1) * in_size]
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            elif keys[3] == 'attn' and keys[4] == 'out_proj':
                if keys[5] == 'weight':
                    out_size = mpu.divide(value.size(1),
                                          mpu.get_model_parallel_world_size())
                    new_value = value[:,
                                      mpu.get_model_parallel_rank() *
                                      out_size:(mpu.get_model_parallel_rank() +
                                                1) * out_size]
                elif keys[5] == 'bias':
                    new_value = value
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            elif keys[3] == 'fc1':
                in_size = mpu.divide(value.size(0),
                                     mpu.get_model_parallel_world_size())
                if keys[4] in ('weight', 'bias'):
                    new_value = value[mpu.get_model_parallel_rank() *
                                      in_size:(mpu.get_model_parallel_rank() +
                                               1) * in_size]
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            elif keys[3] == 'fc2':
                if keys[4] == 'weight':
                    out_size = mpu.divide(value.size(1),
                                          mpu.get_model_parallel_world_size())
                    new_value = value[:,
                                      mpu.get_model_parallel_rank() *
                                      out_size:(mpu.get_model_parallel_rank() +
                                                1) * out_size]
                elif keys[4] == 'bias':
                    new_value = value
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            else:
                new_value = value
            sliced_state_dict[new_key] = new_value
    return sliced_state_dict
Example 2
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
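    """Run forward/backward through a BertParallelSelfAttention layer with a
    fixed seed and return the layer, inputs, and loss so the calling test can
    compare runs across model-parallel sizes."""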
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
                    torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length,
                               hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, model_parallel_size, loss, \
        attention_layer, identity_layer
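Example 3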
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
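        # Model-parallel groups are carved out of the default process group,
        # so this has to run after init_process_group above.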
        mpu.initialize_model_parallel(args.model_parallel_size)

    # create model
    conf_dict = EasyDict(yaml.load(open(args.cfg, "r"), Loader=yaml.Loader))
    conf_dict.world_size = mpu.get_model_parallel_world_size()
    conf_dict.gpu = args.gpu
    conf_dict.device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    solver = Solver(conf_dict)
    solver.train()
Example 4
def prepare_tokenizer(args):
    tokenizer_args = {
        'tokenizer_type': args.tokenizer_type,
        'corpus': None,
        'model_path': args.tokenizer_path,
        'vocab_size': args.vocab_size,
        'model_type': args.tokenizer_model_type,
        'cache_dir': args.cache_dir}
    tokenizer = make_tokenizer(**tokenizer_args)

    num_tokens = tokenizer.num_tokens
    before = num_tokens
    after = before
    multiple = args.make_vocab_size_divisible_by * \
               mpu.get_model_parallel_world_size()
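    # Pad the vocab size up to a multiple of this value so the embedding can
    # be sharded evenly across model-parallel ranks.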
    while (after % multiple) != 0:
        after += 1
    print_rank_0('> padded vocab (size: {}) with {} dummy '
                 'tokens (new size: {})'.format(before, after - before,
                                                after))

    args.tokenizer_num_tokens = after
    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
    args.eod_token = tokenizer.get_command('eos').Id

    # after = tokenizer.num_tokens
    # while after % mpu.get_model_parallel_world_size() != 0:
    #     after += 1

    args.vocab_size = after
    print("prepare tokenizer done", flush=True)

    return tokenizer
Example 5
def prepare_tokenizer(args):

    tokenizer_args = {
        'tokenizer_type': args.tokenizer_type,
        'corpus': None,
        'model_path': args.tokenizer_path,
        'vocab_size': args.vocab_size,
        'model_type': args.tokenizer_model_type,
        'cache_dir': args.cache_dir
    }
    tokenizer = make_tokenizer(**tokenizer_args)

    args.tokenizer_num_tokens = tokenizer.num_tokens
    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
    args.eod_token = tokenizer.get_command('eos').Id

    after = tokenizer.num_tokens
    multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
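    # Skip padding when the multiple is 0; otherwise the modulo below would
    # raise ZeroDivisionError.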
    if multiple != 0:
        while (after % multiple) != 0:
            after += 1

    args.vocab_size = after
    print("prepare tokenizer done", flush=True)

    return tokenizer
Example 6
def test_boradcast_data(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print(
            '> testing boradcast_data with model parallel size {} ...'.format(
                model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    torch.manual_seed(1234 + mpu.get_data_parallel_rank())
    model_parallel_size = mpu.get_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12]
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
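    # Every rank in a model-parallel group builds the same data_t (the seed
    # depends only on the data-parallel rank); only rank 0 of the group keeps
    # `data`, and broadcast_data must reproduce it on the other ranks.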
    if mpu.get_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, \
        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 7
    def __init__(self, embedding_dim, ffn_embedding_dim, num_attention_heads, device='cpu',
                checkpoint_gradients=False):
        nn.Module.__init__(self)
        self.model_parallel_size = mpu.get_model_parallel_world_size()
        self.checkpoint_gradients = checkpoint_gradients
        assert ffn_embedding_dim % self.model_parallel_size == 0

        # TODO: write a custom inplace LayerNorm layer
        self.attn_ln = nn.LayerNorm(embedding_dim).to(device)
        self.attn = ModelParallelMultiheadLMAttentionWithCache(embedding_dim, num_attention_heads, device=device)
        self.fc_ln = nn.LayerNorm(embedding_dim).to(device)
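        # fc1 is column-parallel and keeps its output split (gather_output=False);
        # fc2 is row-parallel and consumes that split input (input_is_parallel=True),
        # so no all-gather is needed between the two.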
        self.fc1 = mpu.ColumnParallelLinear(embedding_dim, ffn_embedding_dim, gather_output=False, device=device)
        self.fc2 = mpu.RowParallelLinear(ffn_embedding_dim, embedding_dim, input_is_parallel=True, device=device)
Example 8
def get_train_val_test_data(args):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.use_npy_data_loader:
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        else:
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data,
             test_data), tokenizer = data_config.apply(args)
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([
            after, eod_token,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
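        # Placeholder; the broadcast below overwrites it with rank 0's counts.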

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, eod_token
Example 9
def evaluate_tnews(args, model, dataloader, device, mode="dev"):
    model.eval()
    all_truth, all_preds = [], []
    with torch.no_grad():
        for batch, no_model_batch in tqdm(dataloader, desc="Evaluating {}".format(mode),
                                          disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / torch.sum(
                no_model_batch["loss_mask"], -1).unsqueeze(-1)

            # gather the output logits from other gpus
            tensor_list = [torch.zeros_like(output) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output, mpu.get_data_parallel_group())

            # gather the truth labels from other gpus
            tensor_list_truth = [torch.zeros_like(no_model_batch["truth"], dtype=torch.long) for _ in
                                 range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list_truth, no_model_batch["truth"], mpu.get_data_parallel_group())

            if args.model_parallel_size == 1:
                scores = torch.stack(tensor_list, 0).view(-1, 30000)
            else:
                assert args.model_parallel_size == 2, "Only model parallel sizes <= 2 are supported"
                # Convenience shortcut: the target labels all fall within the first 15000
                # vocabulary entries, i.e. the logits shard held by model-parallel rank 0
                # (global ranks 0, 2, 4, ...).
                scores = torch.stack(tensor_list, 0).view(-1, 15000)

            truth = torch.stack(tensor_list_truth, 0)
            truth = truth.view(-1)
            # scores = scores[:, cand_ids]

            preds = torch.argmax(scores, dim=-1)

            all_truth.extend(truth.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())

    acc = sum([int(p == l) for p, l in zip(all_preds, all_truth)]) / len(all_truth)
    acc = torch.tensor(acc).to(device)

    acc_list = [torch.zeros_like(acc) for _ in range(mpu.get_model_parallel_world_size())]
    torch.distributed.all_gather(acc_list, acc, mpu.get_model_parallel_group())

    return acc_list[0].item(), all_truth, all_preds
Example 10
def get_train_val_test_data(args):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        ds_type = 'BERT'
        data_config.set_defaults(data_set_type=ds_type, transpose=False)
        (train_data, val_data, test_data), tokenizer = data_config.apply(args)
        before = tokenizer.num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        # Need to broadcast num_tokens and num_type_tokens.
        token_counts = torch.cuda.LongTensor([
            after, tokenizer.num_type_tokens,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    num_type_tokens = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, num_type_tokens
Example 11
def prepare_tokenizer(args):
    add_sentinel_token = 0
    if args.sentinel_token:
        add_sentinel_token = args.max_position_embeddings
    tokenizer = make_tokenizer(args.tokenizer_type,
                               None,
                               args.tokenizer_path,
                               args.vocab_size,
                               args.tokenizer_model_type,
                               add_block_symbols=args.block_lm,
                               cache_dir=args.cache_dir,
                               add_sentinel_token=add_sentinel_token,
                               add_task_mask=args.task_mask,
                               add_decoder_mask=args.block_mask_prob > 0.0
                               or args.context_mask_ratio > 0.0)
    if mpu.get_model_parallel_rank() == 0:
        num_tokens = tokenizer.num_tokens
        eod_token = tokenizer.get_command('eos').Id
        assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([after, eod_token])
    else:
        token_counts = torch.cuda.LongTensor([0, 0])
    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.vocab_size, args.eod_token = num_tokens, eod_token
    return tokenizer
Example 12
def test_cross_entropy(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * model_parallel_size
    seed = 1234
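    # Both implementations draw the same logits from this seed, so the loss
    # and gradient from mpu_cross_entropy should match the plain torch
    # reference computed below.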

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale,
                                                 seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size,
                                           logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 13
    def __init__(self, embed_dim, num_heads, bias=True, device='cpu'):
        nn.Module.__init__(self)

        self.embed_dim = embed_dim

        self.in_proj = mpu.ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
                                                gather_output=False, device=device)
        self.out_proj = mpu.RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                              input_is_parallel=True, device=device)
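        # in_proj is column-parallel (the fused q/k/v projection is split along
        # its output dimension, gather_output=False); out_proj is row-parallel
        # and consumes that partitioned input (input_is_parallel=True).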

        self.model_parallel_size = mpu.get_model_parallel_world_size()

        self.num_total_heads = num_heads
        self.num_heads = self.num_total_heads // self.model_parallel_size
        assert (
                self.num_heads * self.model_parallel_size == num_heads
        ), "Number of heads must be divisble by model parallel size"

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
Example 14
def test_initialize_model_parallel(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            model_parallel_size))
    model_parallel_size_ = min(model_parallel_size,
                               torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)


    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 15
def test_model_parallel_cuda_manual_seed(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    mpu.model_parallel_cuda_manual_seed(12345)
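    # Seeds the default CUDA generator with 12345 and registers a
    # model-parallel generator seeded with 12345 + 2718 + model_parallel_rank;
    # the assertions below verify both.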
    assert torch.cuda.initial_seed() == 12345
    with mpu.get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 +
                                             mpu.get_model_parallel_rank())

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 16
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size,
                         sequence_length):
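    """Forward/backward a BertParallelTransformerLayer under model parallelism
    with a fixed seed; returns handles for the calling test to compare runs."""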

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
                    torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length,
                               hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, model_parallel_size, loss, \
        transformer_layer, identity_layer
Example 17
def test_initialize_affine_weight(model_parallel_size):

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     output_size_coeff, 0,
                                     torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    mpu.layers._initialize_affine_weight(weight, output_size, input_size,
                                         input_size_coeff, 1,
                                         torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example 18
def test_row_parallel_linear(model_parallel_size):

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)
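    # RowParallelLinear holds this rank's column slice of A (the input
    # dimension is split), so its weight gradient is checked against the
    # matching slice of dLdA; the bias and input gradients are full-size.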

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
Example 19
def test_parallel_embedding(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(
        0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()
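    # All three embeddings draw their weights from the same seed, so both
    # parallel losses should match the single-GPU reference, and each rank's
    # weight gradient should equal its slice of the reference gradient.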

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print('   error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print('   error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   hidden_size // model_parallel_size,
                                   1)[mpu.get_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print('   error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   vocab_size // model_parallel_size,
                                   0)[mpu.get_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print('   error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 20
def test_cuda_rng_tracker(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print('   max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(
        result_11.sub(target_11).abs().max(),
        result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 21
def test_set_cuda_rng_state(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        '   max diff in rng state (should be non-zero) on global rank {}: {}'.
        format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print('   max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example 22
def verify_one_step(args):
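    """In "save" mode (a single process with no parallelism), run one
    forward/backward step and dump the input, model state dict, and gradients;
    in "load" mode, rerun the step under pipeline/model parallelism from the
    saved checkpoint and check each rank's gradients against the sliced
    reference."""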
    if args.verify == "save":
        assert dist.get_world_size() == 1
        assert mpu.get_pipeline_parallel_world_size() == 1
        assert mpu.get_model_parallel_world_size() == 1
        assert args.n_input_slices == 1
        assert args.n_batch_slices == 1
        os.makedirs(args.verify_path, exist_ok=True)
        config, layers, pipelined_layers = initialize_model(args)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            x = layers.create_inputs(config.batch_size,
                                     config.seq_len,
                                     random=True)
            torch.save(x, os.path.join(args.verify_path, 'input.pt'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if mpu.get_pipeline_parallel_group_rank(
            ) == mpu.get_pipeline_parallel_world_size() - 1:
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        torch.save(pipelined_layers.state_dict(),
                   os.path.join(args.verify_path, 'model.ckpt'))
        grad_dic = OrderedDict(
            (name, p.grad) for name, p in pipelined_layers.named_parameters())
        torch.save(grad_dic, os.path.join(args.verify_path, 'model.grad.ckpt'))
    else:
        assert args.verify == "load"
        config, layers, pipelined_layers = initialize_model(args)
        with FileLock(os.path.join(args.verify_path, 'model.ckpt.lock')):
            loaded_state_dict = torch.load(os.path.join(
                args.verify_path, 'model.ckpt'),
                                           map_location=torch.device('cuda'))
        sliced_state_dict = slice_state_dict(config, loaded_state_dict)
        pipelined_layers.load_state_dict(sliced_state_dict)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            with FileLock(os.path.join(args.verify_path, 'input.pt.lock')):
                x = torch.load(os.path.join(args.verify_path, 'input.pt'),
                               map_location=torch.device('cuda'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if mpu.get_pipeline_parallel_group_rank(
            ) == mpu.get_pipeline_parallel_world_size() - 1:
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        grad_dic = OrderedDict(
            (name, p.grad) for name, p in pipelined_layers.named_parameters())
        with FileLock(os.path.join(args.verify_path, 'model.grad.ckpt.lock')):
            loaded_grad_dic = torch.load(os.path.join(args.verify_path,
                                                      'model.grad.ckpt'),
                                         map_location=torch.device('cuda'))
        sliced_grad_dic = slice_state_dict(config, loaded_grad_dic)
        assert grad_dic.keys() == sliced_grad_dic.keys()
        for k in grad_dic.keys():
            assert torch.allclose(grad_dic[k], sliced_grad_dic[k])
def get_eval_data(args):
    val_dataloader = None
    if mpu.get_model_parallel_rank() == 0:
        eval_batch_size = args.eval_batch_size
        eval_batch_size = args.batch_size if eval_batch_size is None else eval_batch_size
        seq_len = args.seq_length
        valid_data = args.valid_data
        valid_data = valid_data[0] if isinstance(valid_data, list) else valid_data

        tokenizer = get_tokenizer(args)

        if not args.cloze_eval:

            with open(valid_data, "rb") as reader:
                entire_data = reader.read().decode('utf-8')
            num_original_tokens = len(entire_data.strip().split(" "))
            entire_data = get_detokenizer(valid_data)(entire_data)
            tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization
            num_tokenized_tokens = len(tokenized_data)
            string = 'Original Tokens: %d, Detokenized tokens: %d' % (
                num_original_tokens, num_tokenized_tokens)
            print_rank_0(string)

            eod_token = tokenizer.get_command('pad').Id
            val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token,
                                          args.overlapping_eval)
        else:
            val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len)
            num_tokenized_tokens = 0
            num_original_tokens = 0
        val_dataloader = torch.utils.data.DataLoader(
            val_dataset, batch_size=eval_batch_size, drop_last=False)

        before = tokenizer.num_tokens
        after = before
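        # Pad the vocab so it divides evenly across model-parallel ranks.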
        while after % mpu.get_model_parallel_world_size() != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy tokens '
                     '(new size: {})'.format(before, after - before, after))
        eod_token = tokenizer.get_command('pad').Id
        num_examples = len(val_dataset)
        token_counts = torch.cuda.LongTensor([after, eod_token, num_examples,
                                              num_original_tokens,
                                              num_tokenized_tokens])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.vocab_size = token_counts[0].item()
    args.eod_token = token_counts[1].item()
    args.num_examples = token_counts[2].item()
    args.num_original_tokens = token_counts[3].item()
    args.num_tokenized_tokens = token_counts[4].item()

    print('global rank: {} | vocab size: {} | eod token: {} | '
          'num_examples: {} | num_original_tokens: {} | '
          'num_tokenized_tokens: {}'.format(
              torch.distributed.get_rank(), args.vocab_size,
              args.eod_token, args.num_examples, args.num_original_tokens,
              args.num_tokenized_tokens ))
    return val_dataloader