Example #1
def get_token_stream(model, context_tokens, tokenizer, args):
    pad_id = tokenizer.get_command('pad').Id
    # context_length = len(context_tokens)
    # if context_length < args.seq_length:
    #     context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer,
                                                args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    # context_length_tensor = torch.cuda.LongTensor([context_length])

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor,
                                                     args)

    counter = 0
    org_context_length = context_length

    layer_past = None

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
                                                 tokenizer, args)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths
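Example #1 relies on a pad_batch helper that is not shown in this excerpt. A minimal sketch of what such a helper might look like (the real signature and padding rules in the source repository may differ): it pads each context in place to args.seq_length with the tokenizer's pad id and records the original lengths.

from types import SimpleNamespace

def pad_batch(batch, tokenizer, args):
    # Hypothetical sketch: pad every context to args.seq_length and
    # keep the original (unpadded) lengths for later use.
    pad_id = tokenizer.get_command('pad').Id
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

# Quick self-test with stand-in tokenizer/args objects (not the real classes).
class _FakeTokenizer:
    def get_command(self, name):
        return SimpleNamespace(Id=0)

args = SimpleNamespace(seq_length=8)
print(pad_batch([[5, 6, 7], [1, 2]], _FakeTokenizer(), args))
# -> ([[5, 6, 7, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0]], [3, 2])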
Example #2
def read_context(tokenizer, args, output):
    terminate_runs, skip_run = 0, 0
    if mpu.get_model_parallel_rank() == 0:
        while True:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break
            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
            if args.block_lm and 'MASK]' not in raw_text:
                raw_text += ' ' + generation_mask
            output.write(raw_text)
            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            if args.block_lm:
                context_tokens = [tokenizer.get_command('ENC').Id
                                  ] + context_tokens
                if not raw_text.endswith('MASK]'):
                    context_tokens = context_tokens + [
                        tokenizer.get_command('eos').Id
                    ]
            context_length = len(context_tokens)

            if context_length >= args.seq_length:
                print("\nContext length", context_length,
                      "\nPlease give smaller context than the window length!")
                continue
            break
    else:
        context_length = 0

    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()

    if terminate_runs == 1:
        return terminate_runs, None, None, None

    context_length_tensor = torch.cuda.LongTensor([context_length])

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()
    if mpu.get_model_parallel_rank() == 0:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    if mpu.get_model_parallel_rank() != 0:
        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
    return terminate_runs, raw_text, context_tokens_tensor, context_length
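The pattern in read_context, where rank 0 reads the prompt and then broadcasts the length followed by the token tensor so every model-parallel rank sees identical inputs, recurs throughout these examples. Below is a small self-contained sketch of the same broadcast idiom using plain torch.distributed on CPU with a locally spawned gloo group; the port, world size, and tensors are arbitrary illustration values, not the real mpu setup.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 "reads" the prompt; the other ranks only know they will receive something.
    if rank == 0:
        context_tokens = torch.tensor([11, 22, 33, 44], dtype=torch.long)
        context_length = torch.tensor([4], dtype=torch.long)
    else:
        context_length = torch.tensor([0], dtype=torch.long)

    # First broadcast the length, then allocate a buffer and broadcast the tokens.
    dist.broadcast(context_length, src=0)
    if rank != 0:
        context_tokens = torch.zeros(context_length.item(), dtype=torch.long)
    dist.broadcast(context_tokens, src=0)

    print(f"rank {rank} received {context_tokens.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)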
Example #3
def generate_samples(model, tokenizer, args, device):
    model.eval()
    output_path = "./samples"
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    with torch.no_grad(), open(output_path, "w") as output:
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())

            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
            if terminate_runs == 1:
                return
            start_time = time.time()
            if args.block_lm:
                mems = []
                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)
                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
                end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
                mask_positions = []
                for token in mask_tokens:
                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
                mask_positions.sort()
                if args.no_block_position:
                    for mask_position in mask_positions:
                        position_ids[0, mask_position + 1:] += args.out_seq_length
                _, *mems = model(tokens, position_ids, attention_mask, *mems)
                for mask_position in mask_positions:
                    if args.no_block_position:
                        position = position_ids[0, mask_position].item()
                    else:
                        position = mask_position
                    tokens, mems = sample_sequence(model, tokenizer, tokens, position,
                                                   args, device, mems=mems, end_tokens=end_tokens)
            else:
                tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device)
            output_tokens_list = tokens.view(-1).contiguous()
            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens
                print("\nGLM:", trim_decode_tokens, flush=True)
                output.write(trim_decode_tokens + "\n")

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
Example #4
def prepare_tokenizer(args):
    add_sentinel_token = 0
    if args.sentinel_token:
        add_sentinel_token = args.max_position_embeddings
    tokenizer = make_tokenizer(args.tokenizer_type, None, args.tokenizer_path, args.vocab_size,
                               args.tokenizer_model_type, add_block_symbols=args.block_lm, cache_dir=args.cache_dir,
                               add_sentinel_token=add_sentinel_token, add_task_mask=args.task_mask,
                               add_decoder_mask=args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0,
                               fix_command_token=args.fix_command_token)
    if mpu.get_model_parallel_rank() == 0:
        num_tokens = tokenizer.num_tokens
        eod_token = tokenizer.get_command('eos').Id
        assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before, after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([after, eod_token])
    else:
        token_counts = torch.cuda.LongTensor([0, 0])
    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.vocab_size, args.eod_token = num_tokens, eod_token
    return tokenizer
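The vocabulary-padding loop in prepare_tokenizer simply rounds the token count up to the next multiple of make_vocab_size_divisible_by. The same arithmetic in closed form, as a small runnable sketch equivalent to the while loop above:

def pad_vocab_size(before, multiple):
    # Round up to the next multiple, exactly what the while loop computes.
    return ((before + multiple - 1) // multiple) * multiple

# e.g. a 50,257-token vocabulary padded to a multiple of 128:
print(pad_vocab_size(50257, 128))  # 50304, i.e. 47 dummy tokens added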
Example #5
def get_train_val_test_data(args, tokenizer):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        if args.block_lm:
            data_set_type = "Block"
        elif args.transformer_xl:
            data_set_type = "GPT-XL"
        else:
            data_set_type = "GPT2"
        data_config.set_defaults(data_set_type=data_set_type, transpose=False)
        train_data, val_data, test_data = data_config.apply(args, tokenizer)

        data_counts = torch.cuda.LongTensor(
            [int(args.do_train),
             int(args.do_valid),
             int(args.do_test)])
    else:
        data_counts = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do-train/valid/test flags.
    torch.distributed.broadcast(data_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = data_counts[0].item()
    args.do_valid = data_counts[1].item()
    args.do_test = data_counts[2].item()

    return train_data, val_data, test_data
Example #6
def generate_samples(model, tokenizer, args, device):
    model.eval()
    output_path = "./samples"
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    with torch.no_grad(), open(output_path, "w") as output:
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())

            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
            if terminate_runs == 1:
                return
            start_time = time.time()
            output_tokens_list, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args,
                                                    device)
            if args.hierarchical:
                eop_token = tokenizer.get_command('eop').Id
                if output_tokens_list[-1] == eop_token:
                    output_tokens_list = output_tokens_list[:-1]
                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[9:]
                print("Summary:", trim_decode_tokens)
                keys = nltk.tokenize.sent_tokenize(trim_decode_tokens)
                context, mems = "", []
                for i, key in enumerate(keys):
                    if i > 0 and not context.endswith(" "):
                        key = " " + key
                    context_tokens = tokenizer.EncodeAsIds(key).tokenization
                    context_length = len(context_tokens)
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    output_tokens_list, mems = sample_sequence(model, tokenizer, context_tokens_tensor, context_length,
                                                               args, device, end_token=eop_token, mems=mems)
                    decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
                    context += decode_tokens
                print(context)
            else:
                if mpu.get_model_parallel_rank() == 0:
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                    print("\nContext:", raw_text, flush=True)
                    decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
                    trim_decode_tokens = decode_tokens[len(raw_text):]
                    print("\nGPT2:", trim_decode_tokens, flush=True)
                    output.write(trim_decode_tokens + "\n")

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
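The hierarchical branch above splits the generated summary into sentences with nltk.tokenize.sent_tokenize, which needs NLTK's sentence-tokenizer data installed once per machine. A short usage sketch (assuming NLTK is installed; older releases use the 'punkt' package, newer ones may ask for 'punkt_tab'):

import nltk

nltk.download('punkt', quiet=True)  # one-time download of the sentence tokenizer data

summary = "Freedom of speech is a principle. It supports open debate."
print(nltk.tokenize.sent_tokenize(summary))
# -> ['Freedom of speech is a principle.', 'It supports open debate.']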
Example #7
def read_context(tokenizer, args, output):
    terminate_runs, skip_run = 0, 0
    if mpu.get_model_parallel_rank() == 0:
        while True:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break
            if args.hierarchical:
                raw_text = "Summary: " + raw_text
            output.write(raw_text)
            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            context_length = len(context_tokens)

            if context_length >= args.seq_length:
                print("\nContext length", context_length,
                      "\nPlease give smaller context than the window length!")
                continue
            break
    else:
        context_length = 0

    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()

    if terminate_runs == 1:
        return terminate_runs, raw_text, None, None

    context_length_tensor = torch.cuda.LongTensor([context_length])

    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()
    if mpu.get_model_parallel_rank() == 0:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    return terminate_runs, raw_text, context_tokens_tensor, context_length
Example #8
def evaluate(data_loader, model, args, timers,
             num_iterations=None):
    """Evaluation."""

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_lm_loss = 0
    if num_iterations is not None:
        max_iters = num_iterations
    else:
        if mpu.get_model_parallel_rank() == 0:
            max_iters_gpu = torch.cuda.LongTensor([len(data_loader)])
        else:
            max_iters_gpu = torch.cuda.LongTensor([0])
        torch.distributed.broadcast(max_iters_gpu,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        max_iters = max_iters_gpu[0].item()
        print_rank_0('global rank: {} | max iters: {}'.format(
            torch.distributed.get_rank(), max_iters))

    if data_loader is not None:
        data_iterator = iter(data_loader)
    else:
        data_iterator = None

    with torch.no_grad():
        iteration = 0
        while iteration < max_iters:
            if iteration % args.log_interval == 0:
                print_rank_0('global rank: {} | iteration: {}'.format(
                    torch.distributed.get_rank(), iteration))
            # Forward evaluation.
            lm_loss = forward_step(data_iterator, model, args, timers)
            if lm_loss is None:
                break
            # Reduce across processes.
            if isinstance(model, DDP):
                torch.distributed.all_reduce(lm_loss.data)
                if args.cloze_eval:
                    lm_loss.data = lm_loss.data / args.world_size
                else:
                    lm_loss.data = lm_loss.data / args.model_parallel_size

            if not args.cloze_eval:
                total_lm_loss += lm_loss.data.detach().float().item()/(args.num_tokenized_tokens-1)
            else:
                total_lm_loss += lm_loss.data.detach().float().item()

            iteration += 1

    # Move model back to the train mode.
    model.train()

    return total_lm_loss
Example #9
def has_overflow(self, params):
    overflow = self.has_overflow_serial(params)
    # Since each model parallel GPU carries only part of the model,
    # make sure overflow flag is synced across all the model parallel GPUs
    overflow_gpu = torch.cuda.ByteTensor([overflow])
    torch.distributed.all_reduce(overflow_gpu,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=mpu.get_model_parallel_group())
    overflow = overflow_gpu[0].item()
    return bool(overflow)
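has_overflow_serial is not included in this excerpt; conceptually it scans the local gradients for inf or nan before the cross-rank MAX reduction above. A rough CPU-only sketch of that check (a plausible illustration, not the exact fp16-optimizer implementation):

import torch

def has_overflow_serial(params):
    # Hypothetical sketch: True if any local gradient contains inf or nan.
    for p in params:
        if p.grad is not None and not torch.isfinite(p.grad).all():
            return True
    return False

# Quick check with a deliberately overflowed gradient.
w = torch.nn.Parameter(torch.ones(3))
w.grad = torch.tensor([1.0, float('inf'), 0.0])
print(has_overflow_serial([w]))  # True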
Example #10
def backward(ctx, grad_x):
    rank = torch.distributed.get_rank()
    dst_rank = mpu.get_model_parallel_dst_rank()
    next_src_rank = mpu.get_model_parallel_next_src_rank()
    pipeline_group = mpu.get_pipeline_parallel_pred_group()
    model_parallel_group = mpu.get_model_parallel_group()
    if next_src_rank is not None:
        if rank == dst_rank:
            assert pipeline_group is not None
            dist.broadcast(grad_x, next_src_rank, group=pipeline_group)
        dist.broadcast(grad_x, dst_rank, group=model_parallel_group)
    return grad_x
Example #11
def forward(ctx, x):
    rank = torch.distributed.get_rank()
    prev_dst_rank = mpu.get_model_parallel_prev_dst_rank()
    src_rank = mpu.get_model_parallel_src_rank()
    pipeline_group = mpu.get_pipeline_parallel_succ_group()
    model_parallel_group = mpu.get_model_parallel_group()
    if prev_dst_rank is not None:
        if rank == src_rank:
            assert pipeline_group is not None
            dist.broadcast(x, prev_dst_rank, group=pipeline_group)
        dist.broadcast(x, src_rank, group=model_parallel_group)
    return x
Example #12
def get_train_val_test_data(args):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.use_npy_data_loader:
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        else:
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data,
             test_data), tokenizer = data_config.apply(args)
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([
            after, eod_token,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, eod_token
Example #13
def evaluate_tnews(args, model, dataloader, device, mode="dev"):
    model.eval()
    all_truth, all_preds = [], []
    with torch.no_grad():
        for batch, no_model_batch in tqdm(dataloader, desc="Evaluating {}".format(mode),
                                          disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / torch.sum(
                no_model_batch["loss_mask"], -1).unsqueeze(-1)

            # gather the output logits from other gpus
            tensor_list = [torch.zeros_like(output) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output, mpu.get_data_parallel_group())

            # gather the truth labels from other gpus
            tensor_list_truth = [torch.zeros_like(no_model_batch["truth"], dtype=torch.long) for _ in
                                 range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list_truth, no_model_batch["truth"], mpu.get_data_parallel_group())

            if args.model_parallel_size == 1:
                scores = torch.stack(tensor_list, 0).view(-1, 30000)
            else:
                assert args.model_parallel_size == 2, "Now, we only support model parallel <= 2"
                # For convenience of implementation. Note that the truth labels only appear in the first 15000 entries of the logits, e.g. on ranks 0, 2, 4, ...
                scores = torch.stack(tensor_list, 0).view(-1, 15000)

            truth = torch.stack(tensor_list_truth, 0)
            truth = truth.view(-1)
            # scores = scores[:, cand_ids]

            preds = torch.argmax(scores, dim=-1)

            all_truth.extend(truth.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())

    acc = sum([int(p == l) for p, l in zip(all_preds, all_truth)]) / len(all_truth)
    acc = torch.tensor(acc).to(device)

    acc_list = [torch.zeros_like(acc) for _ in range(mpu.get_model_parallel_world_size())]
    torch.distributed.all_gather(acc_list, acc, mpu.get_model_parallel_group())

    return acc_list[0].item(), all_truth, all_preds
Example #14
def mix_forward_step(batch_and_dataloader, model, args, times, mems):
    use_blocklm = 0
    if args.block_lm_ratio > 0.0:
        if mpu.get_model_parallel_rank() == 0:
            if random.random() > 1 / (1 + args.block_lm_ratio):
                use_blocklm = 1
        use_blocklm = torch.cuda.LongTensor([use_blocklm])
        torch.distributed.broadcast(use_blocklm,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        use_blocklm = use_blocklm.item()
    if use_blocklm:
        return lm_forward_step((batch_and_dataloader[1], None), model, args,
                               times, mems)
    else:
        return finetune_forward_step(batch_and_dataloader[0], model, args,
                                     times, mems)
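In mix_forward_step, rank 0 chooses the block-LM objective with probability block_lm_ratio / (1 + block_lm_ratio) and broadcasts the decision so every model-parallel rank runs the same branch. A tiny standalone simulation of that probability (illustration only, no distributed setup):

import random

def pick_blocklm(block_lm_ratio):
    # Same test as in Example #14: True with probability ratio / (1 + ratio).
    return random.random() > 1 / (1 + block_lm_ratio)

random.seed(0)
ratio = 0.5
draws = [pick_blocklm(ratio) for _ in range(100_000)]
print(sum(draws) / len(draws))  # close to 0.5 / 1.5 ≈ 0.333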
Example #15
def get_train_val_test_data(args):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        ds_type = 'BERT'
        data_config.set_defaults(data_set_type=ds_type, transpose=False)
        (train_data, val_data, test_data), tokenizer = data_config.apply(args)
        before = tokenizer.num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        # Need to broadcast num_tokens and num_type_tokens.
        token_counts = torch.cuda.LongTensor([
            after, tokenizer.num_type_tokens,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    num_type_tokens = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, num_type_tokens
Example #16
def test_initialize_model_parallel(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            model_parallel_size))
    model_parallel_size_ = min(model_parallel_size,
                               torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)


    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Example #17
def generate_samples_input_from_file(model, tokenizer, args):

    if args.sample_input_file == "":
        if mpu.get_model_parallel_rank() == 0:
            print("args.sample_input_file CAN NOT BE empty!\n")
        return

    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file == "":
            print(
                "Argument: sample-output-file can't be empty, setting it to\n")
            print("\t args.sample_input_file.out")
            args.sample_output_file = args.sample_input_file + ".out"
        fname_out = open(args.sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)

                    # if context_length >=args.seq_length//2:
                    #     print("\nContext length", context_length, \
                    #         "\nPlease give smaller context (half of the sequence length)!")
                    #     continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens], tokenizer,
                                            args)
            for counter, decode_tokens in enumerate(token_stream):
                # token_end = decode_tokens.find("<|endoftext|>")
                # if token_end > 0:
                #     break
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0 and decode_tokens:
                os.system('clear')
                #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)
                #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                #fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
                fname_out.write("\n")

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
Example #18
def generate_samples(model, tokenizer, args, device):

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                if args.input_text:
                    raw_text = open(args.input_text).read().strip()
                else:
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                    while not raw_text:
                        print('Prompt should not be empty!')
                        raw_text = input(
                            "\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    #context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
                    context_tokens = tokenizer.encode(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length, \
                            "\nPlease give smaller context (half of the sequence length)!")
                        continue
            else:
                #context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
                context_tokens = tokenizer.encode("空文本")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            pad_id = tokenizer.encoder['<pad>']
            args.eod_token = tokenizer.encoder['<eod>']
            if context_length < args.seq_length:
                context_tokens.extend([pad_id] *
                                      (args.seq_length - context_length))

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            context_length = context_length_tensor[0].item()
            tokens, attention_mask, position_ids = get_batch(
                context_tokens_tensor, device, args)

            start_time = time.time()

            counter = 0
            org_context_length = context_length

            past_key_values = None
            while counter < (org_context_length + args.out_seq_length):
                if counter == 0:
                    logits, past_key_values = model(
                        tokens[:, :context_length],
                        position_ids[:, :context_length],
                        attention_mask[:, :, :context_length, :context_length],
                        past_key_values=past_key_values,
                        use_cache=True)
                    logits = logits[:, context_length - 1, :]
                else:
                    logits, past_key_values = model(
                        tokens[:, context_length - 1:context_length],
                        position_ids[:, context_length - 1:context_length],
                        attention_mask[:, :,
                                       context_length - 1, :context_length],
                        past_key_values=past_key_values,
                        use_cache=True)
                    logits = logits[:, 0, :]
                past_key_values = [x.half() for x in past_key_values]
                logits = top_k_logits(logits,
                                      top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1)
                tokens[0, context_length] = prev[0]
                torch.distributed.broadcast(
                    tokens,
                    mpu.get_model_parallel_src_rank(),
                    group=mpu.get_model_parallel_group())
                context_length += 1
                counter += 1

                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.decode(output_tokens_list.tolist())
                token_end = decode_tokens.find("<eod>")

                if mpu.get_model_parallel_rank() == 0 and (counter % 16 == 0
                                                           or token_end != -1):
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() -
                                                         start_time),
                          flush=True)
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = decode_tokens[
                        len(raw_text):decode_tokens.find("<eod>")]
                    print("\nCPM:", trim_decode_tokens, flush=True)
                if token_end != -1:
                    #print(token_end)
                    break

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time),
                      flush=True)
                print("\nContext:", raw_text, flush=True)
                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.decode(output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.
                                                   find("<eod>")]
                print("\nCPM:", trim_decode_tokens, flush=True)
                #print(token_end)
            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if args.input_text:
                break
Example #19
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider, args):
    """Build the train, validation, and test data iterators."""

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the do_train/do_valid/do_test flags for broadcasting.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
Example #20
def generate_samples(model, tokenizer, args, device):

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = 'what is freedom of speech?'
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = 'what is freedom of speech?'
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)

                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length, \
                            "\nPlease give smaller context (half of the sequence length)!")
                        continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            pad_id = tokenizer.get_command('pad').Id
            if context_length < args.seq_length:
                context_tokens.extend([pad_id] *
                                      (args.seq_length - context_length))

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            context_length = context_length_tensor[0].item()
            tokens, attention_mask, position_ids = get_batch(
                context_tokens_tensor, device, args)

            start_time = time.time()

            counter = 0
            org_context_length = context_length

            while counter < (org_context_length + args.out_seq_length):
                logits = model(tokens, position_ids, attention_mask)
                logits = logits[:, context_length - 1, :] / args.temperature
                logits = top_k_logits(logits,
                                      top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1)
                tokens[0, context_length] = prev[0]
                context_length += 1
                counter += 1

                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())
                token_end = decode_tokens.find("<|endoftext|>")

                if mpu.get_model_parallel_rank() == 0 and (counter % 16 == 0
                                                           or token_end != -1):
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() -
                                                         start_time),
                          flush=True)
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = decode_tokens[
                        len(raw_text):decode_tokens.find("<|endoftext|>")]
                    print("\nGPT2:", trim_decode_tokens, flush=True)
                if token_end != -1:
                    break

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time),
                      flush=True)
                print("\nContext:", raw_text, flush=True)
                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.
                                                   find("<|endoftext|>")]
                print("\nGPT2:", trim_decode_tokens, flush=True)
            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
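Several of these sampling loops call top_k_logits, which is not defined in the excerpts. A common top-k / top-p (nucleus) filter looks roughly like the sketch below; this is an assumed implementation, not necessarily the repository's exact code.

import torch
import torch.nn.functional as F

def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    # Assumed sketch of a standard top-k / top-p (nucleus) filter.
    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens beyond the top_p cumulative-probability cut,
        # always keeping at least the single most likely token.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_k_logits(logits, top_k=2))   # only the two largest logits survive
print(top_k_logits(logits, top_p=0.9)) # tokens past the 0.9 mass cut are masked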
Example #21
    def __call__(self, text=None, input_ids=None, labels=None, **kwargs):
        if input_ids is None:
            if text is None:
                text = ""
            input_ids = torch.cuda.LongTensor(
                [self.tokenizer(text)['input_ids']])
        if isinstance(input_ids, list):
            input_ids = torch.cuda.LongTensor(input_ids)
        if isinstance(labels, list):
            labels = torch.cuda.LongTensor(labels)
        res = []
        if labels is not None:
            lbls = labels
        else:
            lbls = [None] * len(input_ids)
        loss = None
        original_context_length = 0
        for tokens, lbl in zip(input_ids, lbls):
            context_tokens = tokens.tolist()
            context_length = len(context_tokens)
            original_context_length = len(context_tokens)
            if context_length < self.seq_len:
                context_tokens.extend([self.pad_token_id] *
                                      (self.seq_len - context_length))
                if labels is not None:
                    lbl = lbl.tolist()
                    lbl.extend([self.pad_token_id] *
                               (self.seq_len - context_length))
                    lbl = torch.cuda.LongTensor(lbl)
            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            # context_length = context_length_tensor[0].item()

            tokens = context_tokens_tensor
            tokens = tokens.view(1, -1).contiguous()
            tokens = tokens.to(torch.cuda.current_device())
            attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
                tokens, self.pad_token_id, False, False)
            lm_logits = self.model(tokens, position_ids, attention_mask)
            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = lbl[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss(ignore_index=self.pad_token_id)
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                shift_labels.view(-1))
            res.append((lm_logits, loss))
        logits = torch.cat([x[0] for x in res],
                           dim=0)[:, :original_context_length, :]
        if loss is not None:
            loss = [x[1] for x in res]
        return ModelOutput(logits, loss)
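The loss computation in Example #21 follows the usual causal-LM convention: the logits at position i are scored against the label at position i + 1, and pad positions are ignored. A minimal standalone illustration of that shift, using random logits and a hypothetical pad id:

import torch
from torch.nn import CrossEntropyLoss

pad_token_id = 0  # hypothetical pad id for illustration
vocab, seq_len = 10, 6

lm_logits = torch.randn(1, seq_len, vocab)
labels = torch.tensor([[5, 7, 3, 2, pad_token_id, pad_token_id]])

# Shift so that the logits at position i predict the token at position i + 1.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss(ignore_index=pad_token_id)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
print(loss.item())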
Example #22
def finetune(args,
             train_valid_datasets_provider,
             model_kwargs,
             forward_step=finetune_forward_step,
             end_of_epoch_callback_provider=None):
    """Main finetune function used across all tasks."""
    global tokenizer
    timers = Timers()
    tokenizer = prepare_tokenizer(args)
    pretrain_glm.tokenizer = tokenizer
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloader').start()
    train_dataloader, valid_dataloader = None, None
    train_block_dataloader, valid_block_dataloader = None, None
    if train_valid_datasets_provider is not None and args.epochs > 0:
        if mpu.get_model_parallel_rank() == 0:
            train_dataset, valid_dataset = train_valid_datasets_provider(
                args, tokenizer)
            train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
                train_dataset, valid_dataset, args)
            if args.no_validation:
                valid_dataloader = None
            train_iters = torch.cuda.LongTensor([len(train_dataloader)])
        else:
            train_iters = torch.cuda.LongTensor([0])
        torch.distributed.broadcast(train_iters,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        if mpu.get_model_parallel_rank() != 0:
            args.train_iters_per_epoch = train_iters[0].item()
            args.train_iters = args.epochs * args.train_iters_per_epoch

            train_dataloader = FakeDataloader(args.train_iters_per_epoch)
            if args.no_validation:
                valid_dataloader = None
            else:
                valid_dataloader = FakeDataloader(None)
        if args.block_lm_ratio > 0.0:
            if mpu.get_model_parallel_rank() == 0:
                train_block_dataset, valid_block_dataset = train_valid_datasets_provider(
                    args, tokenizer, pattern_text=True)
                train_block_dataloader = make_data_loader(
                    train_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    args.train_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
                valid_block_dataloader = make_data_loader(
                    valid_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    (args.train_iters // args.eval_interval + 1) *
                    args.eval_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
            else:
                train_block_dataloader = FakeDataloader(args.train_iters)
                valid_block_dataloader = FakeDataloader(None)
            train_block_dataloader, valid_block_dataloader = iter(
                train_block_dataloader), iter(valid_block_dataloader)

    timers('train/valid/test dataset/dataloader').stop()
    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback, end_of_train_callback = None, None
    if end_of_epoch_callback_provider is not None:
        if train_valid_datasets_provider is not None and args.epochs > 0 and not args.no_validation:
            end_of_epoch_callback = end_of_epoch_callback_provider(
                args, tokenizer, is_test=False)
        end_of_train_callback = end_of_epoch_callback_provider(args,
                                                               tokenizer,
                                                               is_test=True)
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        args, **model_kwargs)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.load_pretrained is not None and not args.pretrained_bert:
        task_tokens = None
        if args.continuous_prompt and args.prompt_init:
            if mpu.get_model_parallel_rank() == 0:
                dataset = train_dataloader.dataset
                processor, pvp = dataset.processor, dataset.pvp
                task_tokens = []
                for label in processor.get_labels():
                    verbalizer = pvp.verbalize(label)[0]
                    verbalizer_ids = tokenizer.EncodeAsIds(
                        verbalizer).tokenization
                    task_tokens += verbalizer_ids
                print_rank_0("Task tokens: " +
                             tokenizer.DecodeIds(task_tokens))
                num_task_tokens = len(task_tokens)
            else:
                num_task_tokens, task_tokens = 0, []
            num_task_tokens = torch.cuda.LongTensor([num_task_tokens])
            torch.distributed.broadcast(num_task_tokens,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            num_task_tokens = num_task_tokens.item()
            if num_task_tokens > 0:
                if mpu.get_model_parallel_rank() == 0:
                    task_tokens = torch.cuda.LongTensor(task_tokens)
                else:
                    task_tokens = torch.empty(
                        num_task_tokens,
                        device=torch.cuda.current_device(),
                        dtype=torch.long)
                torch.distributed.broadcast(
                    task_tokens,
                    mpu.get_model_parallel_src_rank(),
                    group=mpu.get_model_parallel_group())
                task_tokens = task_tokens.tolist()
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_pretrained(model,
                            args.load_pretrained,
                            args,
                            task_tokens=task_tokens)
        # This is critical when only the model is loaded. We should make sure
        # the master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    if args.load is not None:
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_checkpoint(model,
                            optimizer,
                            lr_scheduler,
                            args,
                            no_deepspeed=args.no_deepspeed_load)
        # This is critical when only the model is loaded. We should make sure
        # the master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    torch.distributed.barrier()
    timers('pretrained checkpoint').stop()
    args.iteration = 0
    summary_writer = None
    if torch.distributed.get_rank() == 0:
        args.log_dir = get_log_dir(base=args.summary_dir,
                                   name=args.experiment_name)
        if os.path.exists(os.path.join(args.log_dir, "test_results.json")
                          ) and args.load is None and not args.overwrite:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.log_dir))
        summary_writer = get_sample_writer(log_dir=args.log_dir,
                                           iteration=args.iteration)
        print_and_save_args(args, verbose=True, log_dir=args.log_dir)

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log([
        'train/valid/test dataset/dataloder', 'callback function',
        'model and optimizer', 'pretrained checkpoint'
    ])
    print_rank_0('training ...')

    # Finetune the model.
    score_dict = None
    if train_dataloader is not None and args.epochs > 0:
        if args.block_lm_ratio > 0.0:
            forward_step = mix_forward_step
        best_iteration = _train(model,
                                optimizer,
                                lr_scheduler,
                                forward_step,
                                (train_dataloader, train_block_dataloader),
                                (valid_dataloader, valid_block_dataloader),
                                end_of_epoch_callback,
                                args,
                                timers,
                                summary_writer=summary_writer)
        if end_of_train_callback is not None and best_iteration is not None:
            with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                          timeout=-1):
                args.load = os.path.join(args.save, "best")
                load_checkpoint(model,
                                optimizer,
                                lr_scheduler,
                                args,
                                no_load_optim=True,
                                no_deepspeed=True)
                args.load = None
        torch.distributed.barrier()
        if end_of_train_callback is not None:
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    # Or just evaluate.
    else:
        if end_of_train_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    if score_dict is not None and torch.distributed.get_rank() == 0:
        score_dict.update({"type": "test"})
        with open(os.path.join(args.log_dir, "test_results.json"),
                  "w") as output:
            output.write(json.dumps(score_dict) + "\n")

    print_rank_0('done :-)')
def generate_samples_input_from_file(model, tokenizer, args):
    if mpu.get_model_parallel_rank() == 0:
        all_raw_text = args.text
        input_count = len(all_raw_text)
        input_pos = 0

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)

                    if context_length >= args.seq_length // 2:
                        print(
                            "\nContext length", context_length,
                            "\nPlease give smaller context (half of the sequence length)!"
                        )
                        continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

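            # All ranks must agree on whether to stop, so broadcast the flag
            # from the model-parallel source rank.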
            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens], tokenizer,
                                            args)
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time),
                      flush=True)
                trim_decode_tokens = tokenizer.DecodeIds(
                    decode_tokens)[len(raw_text):]
                # print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                yield trim_decode_tokens

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def generate_samples(model, tokenizer, args, device):

    context_count = 0
    keys = ['text', 'types', 'mask', 'mask_labels', 'pad_mask']
    datatype = torch.int64
    keys2 = ['clickscores', 'hrsscores']
    datatype2 = torch.float64
    model.eval()
    fout = open(args.output_path, 'w', encoding='utf-8')
    with torch.no_grad():
        data_iterator = binglr_iterator_dataset([args.valid_data],
                                                run_once=True,
                                                max_seq_len=args.seq_length,
                                                mask_lm_prob=0.15,
                                                max_preds_per_seq=20,
                                                tokenizer=tokenizer,
                                                train=False)
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0
            if mpu.get_model_parallel_rank() == 0:
                data = next(data_iterator)

                if data is None:
                    terminate_runs = 1
                else:
                    # Unpack.
                    tokens = data['text']
                    types = data['types']
                    loss_mask = data['mask']
                    lm_labels = data['mask_labels']
                    padding_mask = data['pad_mask']
                    clickscores = data['clickscores']
                    hrsscores = data['hrsscores']
                    sample_id = data['sample_id']
                    # Get the masks and position ids.

            else:
                tokens = np.array([0] * args.seq_length)
                types = np.array([0] * args.seq_length)
                loss_mask = np.array([0] * args.seq_length)
                lm_labels = np.array([0] * args.seq_length)
                padding_mask = np.array([0] * args.seq_length)
                clickscores = np.array([0.0])
                hrsscores = np.array([0.0])
                sample_id = np.array([0])

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            tokens_tensor = torch.cuda.LongTensor(tokens).view(1, -1)
            types_tensor = torch.cuda.LongTensor(types).view(1, -1)
            loss_mask_tensor = torch.cuda.LongTensor(loss_mask).view(1, -1)
            lm_labels_tensor = torch.cuda.LongTensor(lm_labels).view(1, -1)
            padding_mask_tensor = torch.cuda.LongTensor(padding_mask).view(
                1, -1)
            clickscores_tensor = torch.cuda.FloatTensor(clickscores)
            hrsscores_tensor = torch.cuda.FloatTensor(hrsscores)
            batch_size, seq_length = tokens_tensor.size()
            # Positions that are padding on either the query or the key side
            # are zeroed out of the (batch, 1, seq, seq) attention mask.
            attention_mask_tensor = (
                torch.ones_like(padding_mask_tensor) -
                padding_mask_tensor).view(batch_size, 1, seq_length, 1) * (
                    torch.ones_like(padding_mask_tensor) -
                    padding_mask_tensor).view(batch_size, 1, 1, seq_length)
            position_ids = torch.arange(seq_length,
                                        dtype=torch.long,
                                        device=tokens_tensor.device)
            position_ids_tensor = position_ids.unsqueeze(0).expand_as(
                tokens_tensor)

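            # Replicate the inputs from the source rank to the rest of the
            # model-parallel group before the forward pass.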
            torch.distributed.broadcast(tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(types_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(clickscores_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(hrsscores_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(attention_mask_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(position_ids_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            start_time = time.time()

            counter = 0
            _, hrs_scores, _, _ = model(tokens_tensor.contiguous(),
                                        position_ids_tensor.contiguous(),
                                        attention_mask_tensor.contiguous(),
                                        types_tensor.contiguous())

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                fout.write('\t'.join([
                    str(sample_id),
                    str(hrs_scores.detach().clone().cpu().numpy())
                ]) + '\n')

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
def generate_samples(model, tokenizer, args, device):

    context_count = 0
    model.eval()
    output_path = "./samples"
    print("We're in.1")
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    print("We're in.2")
    output_path = os.path.join(
        output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    with torch.no_grad(), open(output_path, "w") as output:
        while True:
            print("We're in.")
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    output.write(raw_text)
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)

                    if context_length >= args.seq_length:
                        print("\nContext length", context_length,
                              "\nPlease give smaller context than the window length!")
                        continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            # pad_id = tokenizer.get_command('pad').Id
            # if context_length < args.out_seq_length:
            #     context_tokens.extend([pad_id] * (args.out_seq_length - context_length))

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            context_length = context_length_tensor[0].item()
            tokens, attention_mask, position_ids = get_batch(
                context_tokens_tensor, device, args)

            start_time = time.time()

            counter, mems = 0, []
            org_context_length = context_length
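            # Autoregressive decoding: run the full prompt once, then feed one
            # token at a time, reusing the cached memories (mems) returned by
            # the previous step.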
            while counter < (args.out_seq_length - org_context_length):
                if counter == 0:
                    logits, *mems = model(tokens, position_ids, attention_mask,
                                          *mems)
                else:
                    index = org_context_length + counter
                    logits, *mems = model(
                        tokens[:, index - 1:index],
                        tokens.new_ones((1, 1)) * (index - 1),
                        tokens.new_ones(1,
                                        1,
                                        1,
                                        args.mem_length + 1,
                                        device=tokens.device,
                                        dtype=torch.float), *mems)
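                # Sample the next token: scale by temperature, keep only the
                # top-k / nucleus (top-p) candidates, then draw from the
                # softmax distribution.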
                logits = logits[:, -1]
                logits /= args.temperature
                logits = top_k_logits(logits,
                                      top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1)[0]
                tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
                context_length += 1
                counter += 1

                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())

                is_end = prev == args.eod_token
                if mpu.get_model_parallel_rank() == 0 and (counter % 128 == 0
                                                           or is_end):
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() -
                                                         start_time),
                          flush=True)
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = decode_tokens[
                        len(raw_text):decode_tokens.find("<|endoftext|>")]
                    print("\nGPT2:", trim_decode_tokens, flush=True)
                if is_end:
                    break

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time),
                      flush=True)
                print("\nContext:", raw_text, flush=True)
                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.
                                                   find("<|endoftext|>")]
                print("\nGPT2:", trim_decode_tokens, flush=True)
                output.write(trim_decode_tokens + "\n")
            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
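The top_k_logits helper called in the sampling loop above is defined elsewhere in the repository; a minimal sketch of what such a filter usually does (assuming standard top-k / nucleus filtering, with torch and torch.nn.functional as F imported as in the snippets above, not the repository's exact implementation) might look like:

def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    # Keep only the top_k highest-scoring logits per row.
    if top_k > 0:
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits = torch.where(logits < kth_value,
                             torch.full_like(logits, filter_value), logits)
    # Nucleus (top-p) filtering: drop the lowest-probability tail once the
    # cumulative probability exceeds top_p, always keeping the best token.
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits,
                                                   descending=True,
                                                   dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        # Map the mask back from sorted order to the original vocab order.
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits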
def get_eval_data(args):
    val_dataloader = None
    if mpu.get_model_parallel_rank() == 0:
        eval_batch_size = args.eval_batch_size
        eval_batch_size = args.batch_size if eval_batch_size is None else eval_batch_size
        seq_len = args.seq_length
        valid_data = args.valid_data
        valid_data = valid_data[0] if isinstance(valid_data, list) else valid_data

        tokenizer = get_tokenizer(args)

        if not args.cloze_eval:

            with open(valid_data, "rb") as reader:
                entire_data = reader.read().decode('utf-8')
            num_original_tokens = len(entire_data.strip().split(" "))
            entire_data = get_detokenizer(valid_data)(entire_data)
            tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization
            num_tokenized_tokens = len(tokenized_data)
            string = 'Original Tokens: %d, Detokenized tokens: %d' % (num_original_tokens, num_tokenized_tokens)
            print_rank_0(string)

            eod_token = tokenizer.get_command('pad').Id
            val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token,
                                          args.overlapping_eval)
        else:
            val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len)
            num_tokenized_tokens = 0
            num_original_tokens = 0
        val_dataloader = torch.utils.data.DataLoader(
            val_dataset, batch_size=eval_batch_size, drop_last=False)

        before = tokenizer.num_tokens
        after = before
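        # Pad the vocabulary so its size divides evenly across the
        # model-parallel partitions.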
        while after % mpu.get_model_parallel_world_size() != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'.
              format(before, after - before, after))
        eod_token = tokenizer.get_command('pad').Id
        num_examples = len(val_dataset)
        token_counts = torch.cuda.LongTensor([after, eod_token, num_examples,
                                              num_original_tokens,
                                              num_tokenized_tokens])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.vocab_size = token_counts[0].item()
    args.eod_token = token_counts[1].item()
    args.num_examples = token_counts[2].item()
    args.num_original_tokens = token_counts[3].item()
    args.num_tokenized_tokens = token_counts[4].item()

    print('global rank: {} | vocab size: {} | eod token: {} | '
          'num_examples: {} | num_original_tokens: {} | '
          'num_tokenized_tokens: {}'.format(
              torch.distributed.get_rank(), args.vocab_size,
              args.eod_token, args.num_examples, args.num_original_tokens,
              args.num_tokenized_tokens ))
    return val_dataloader
Exemple #27
0
def generate_samples_interactive(model, tokenizer, args):

    print_frequency = 24

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)

                    # if context_length >=args.seq_length//2:
                    #     print("\nContext length", context_length, \
                    #         "\nPlease give smaller context (half of the sequence length)!")
                    #     continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens], tokenizer,
                                            args)
            for counter, decode_tokens in enumerate(token_stream):
                # token_end = decode_tokens.find("<|endoftext|>")
                # if token_end > 0:
                #     break
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank(
                ) == 0 and counter % print_frequency == 0:
                    os.system('clear')
                    #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)
                    #print("\nGPT2:", trim_decode_tokens, flush=True)
                    #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0 and decode_tokens:
                os.system('clear')
                #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)
                #print("\nGPT2:", trim_decode_tokens, flush=True)
                #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")
    def __init__(self,
                 config,
                 batch_slices,
                 seq_slices,
                 distributed_init_method,
                 world_size,
                 data_parallel_size,
                 model_parallel_size,
                 pipeline_parallel_size,
                 rank,
                 local_rank,
                 mixed_precision=False,
                 use_mpi=False,
                 init_process_group=False,
                 checkpoint_gradients=False):
        self.config = config
        self.batch_slices = batch_slices
        self.seq_slices = seq_slices
        torch.cuda.set_device(local_rank)
        if init_process_group:
            dist.init_process_group(
                backend='nccl',
                init_method=distributed_init_method,
                world_size=world_size,
                rank=rank,
            )
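        # Throwaway all-reduce to make sure every rank has joined the default
        # process group before the parallel groups are created.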
        dist.all_reduce(torch.zeros(1).cuda())
        mpu.initialize_model_parallel(model_parallel_size,
                                      pipeline_parallel_size)
        set_random_seed(0)
        mpu.model_parallel_cuda_manual_seed(0)
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.data_parallel_size = data_parallel_size
        self.model_parallel_size = model_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.pipeline_parallel_group_rank = mpu.get_pipeline_parallel_group_rank(
        )
        self.data_parallel_group = mpu.get_data_parallel_group()
        self.model_parallel_group = mpu.get_model_parallel_group()
        self.pipeline_parallel_pred_group = mpu.get_pipeline_parallel_pred_group(
        )
        self.pipeline_parallel_succ_group = mpu.get_pipeline_parallel_succ_group(
        )
        self.model_parallel_src_rank = mpu.get_model_parallel_src_rank()
        self.model_parallel_dst_rank = mpu.get_model_parallel_dst_rank()
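        # Source rank of the next pipeline stage and destination rank of the
        # previous one (None at the pipeline boundaries).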
        self.model_parallel_next_src_rank = (
            self.model_parallel_src_rank + self.model_parallel_size if
            self.pipeline_parallel_group_rank < self.pipeline_parallel_size - 1
            else None)
        self.model_parallel_prev_dst_rank = (
            self.model_parallel_dst_rank - self.model_parallel_size
            if self.pipeline_parallel_group_rank > 0 else None)

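        # Split the transformer layers across pipeline stages; ranks below
        # n_layers % pipeline_parallel_size get one extra layer.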
        self.n_layers = (config.n_layers // pipeline_parallel_size +
                         int(rank < config.n_layers % pipeline_parallel_size))
        self.config = config
        self.mixed_precision = mixed_precision
        self.checkpoint_gradients = checkpoint_gradients

        self.layers = []
        for _ in range(self.n_layers):
            l = ModelParallelTransformerLayer(
                self.config.embedding_dim,
                self.config.ffn_embedding_dim,
                self.config.num_attention_heads,
                device="cuda",
                checkpoint_gradients=self.checkpoint_gradients)
            self.layers.append(l.half() if self.mixed_precision else l)

        self.all_parameters = []
        for layer in self.layers:
            self.all_parameters.extend(layer.parameters())
        self.n_params = len(self.all_parameters)

        if self.mixed_precision:
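            # Keep fp32 master copies of the fp16 parameters and let the
            # optimizer update the master copies.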
            self.master_parameters = [
                p.clone().detach().float() for p in self.all_parameters
            ]
            for p in self.master_parameters:
                p.requires_grad_()
            self.optimizer = optimizers.FusedAdam(self.master_parameters,
                                                  lr=1e-10)
        else:
            self.optimizer = torch.optim.Adam(self.all_parameters, lr=1e-10)