Example no. 1
def evaluate_ocnli(model, dev_dataloader, device, args):
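    # Zero-shot OCNLI: score each of the three candidate sequences by its
    # masked LM loss and predict the label whose candidate has the lowest loss.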
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm.tqdm(dev_dataloader):
            tokens_1, masks_1, tokens_2, masks_2, tokens_3, masks_3, labels = [x.to(device) for x in batch]

            tokens, attention_mask, position_ids = get_batch(tokens_1, args)
            output, _ = model(tokens, position_ids, attention_mask)

            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])

            output_1 = torch.sum(losses * masks_1, 1) / torch.sum(masks_1, -1)

            tensor_list = [torch.zeros_like(output_1) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output_1, mpu.get_data_parallel_group())
            output_1 = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()

            # --------------
            tokens, attention_mask, position_ids = get_batch(tokens_2, args)
            output, _ = model(tokens, position_ids, attention_mask)
            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])

            output_2 = torch.sum(losses * masks_2, 1) / torch.sum(masks_2, -1)

            tensor_list = [torch.zeros_like(output_2) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output_2, mpu.get_data_parallel_group())
            output_2 = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()

            # ---------------

            tokens, attention_mask, position_ids = get_batch(tokens_3, args)
            output, _ = model(tokens, position_ids, attention_mask)
            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])

            output_3 = torch.sum(losses * masks_3, 1) / torch.sum(masks_3, -1)

            tensor_list = [torch.zeros_like(output_3) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output_3, mpu.get_data_parallel_group())
            output_3 = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()


            # --------------

            tensor_list_labels = [torch.zeros_like(labels) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list_labels, labels, mpu.get_data_parallel_group())

            if torch.distributed.get_rank() == 0:
                labels = torch.stack(tensor_list_labels, 0)
                labels = labels.view(-1).cpu().detach().numpy()
                res = [np.argmin(np.array(x)) for x in zip(output_1, output_2, output_3)]
                res = [x==y for x, y in zip(res, labels)]
                correct += sum(res)
                total += len(res)
    
    if torch.distributed.get_rank() == 0:
        print("EVAL", correct, total)
Example no. 2
def evaluate(model, dev_dataloader, all_labels, device, args):
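    # Generic zero-shot evaluation: compute the LM loss of every candidate and,
    # for each instance, predict the candidate with the lowest loss.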
    model.eval()

    if torch.distributed.get_rank() == 0:
        res = []

    with torch.no_grad():
        for batch in tqdm.tqdm(dev_dataloader):
            tokens, masks = [x.to(device) for x in batch]

            tokens, attention_mask, position_ids = get_batch(tokens, args)
            output, _ = model(tokens, position_ids, attention_mask)
            losses = mpu.vocab_parallel_cross_entropy(output[:, :-1, :].contiguous().float(), tokens[:, 1:])

            output = torch.sum(losses * masks, 1) / torch.sum(masks, -1)

            tensor_list = [torch.zeros_like(output) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output, mpu.get_data_parallel_group())
            output = torch.stack(tensor_list, 0).view(-1).cpu().detach().numpy()

            if torch.distributed.get_rank() == 0:
                for v in output:
                    res.append(v)

    if torch.distributed.get_rank() == 0:
        cnt = 0
        label_size = max(all_labels) + 1
        num_inst = len(res) // label_size
        for x in range(num_inst):
            label = all_labels[x]
            cur_res = res[x*label_size:(x+1)*label_size]
            pos = np.argmin(cur_res)
            if pos == label:
                cnt += 1
        print("EVAL", cnt, num_inst)
Example no. 3
def load_data(args, data_type, tokenizer, ratio=1):
    data_path = args.data_dir
    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Dataset
    filename = os.path.join(data_path, data_type + '.json')
    dataset = CHIDDataset(args, filename, data_type, tokenizer, ratio=ratio)

    # Use a random sampler for training (sequential otherwise) with a distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True,
                                       collate_fn=dataset.collate), dataset
Example no. 4
def build_data_loader(dataset,
                      batch_size,
                      num_workers,
                      drop_last,
                      shuffle=True,
                      only_rank0=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    if only_rank0:
        rank, world_size = 0, 1
    else:
        world_size = mpu.get_data_parallel_world_size()
        rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True,
                                              collate_fn=my_collate)

    return data_loader
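
A minimal usage sketch for build_data_loader (dataset and num_epochs are illustrative names, not part of the original code): because the loader wraps a DistributedSampler, set_epoch should be called once per epoch when shuffle=True so that every epoch reshuffles differently across ranks.

# hypothetical call site for build_data_loader
loader = build_data_loader(dataset, batch_size=8, num_workers=2, drop_last=True)
for epoch in range(num_epochs):  # num_epochs is illustrative
    loader.sampler.set_epoch(epoch)  # draw a different shuffle every epoch
    for batch in loader:
        ...  # each batch is already collated by my_collate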
Example no. 5
def evaluate_tnews(args, model, dataloader, device, mode="dev"):
    model.eval()
    all_truth, all_preds = [], []
    with torch.no_grad():
        for batch, no_model_batch in tqdm(dataloader, desc="Evaluating {}".format(mode),
                                          disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / torch.sum(
                no_model_batch["loss_mask"], -1).unsqueeze(-1)

            # gather the output logits from other gpus
            tensor_list = [torch.zeros_like(output) for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output, mpu.get_data_parallel_group())

            # gather the truth labels from other gpus
            tensor_list_truth = [torch.zeros_like(no_model_batch["truth"], dtype=torch.long) for _ in
                                 range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list_truth, no_model_batch["truth"], mpu.get_data_parallel_group())

            if args.model_parallel_size == 1:
                scores = torch.stack(tensor_list, 0).view(-1, 30000)
            else:
                assert args.model_parallel_size == 2, "For now, we only support model parallel size <= 2"
                # For convenience of implementation. Note that the truth labels only appear in the first 15000 entries of the logits, i.e. on ranks 0, 2, 4, ...
                scores = torch.stack(tensor_list, 0).view(-1, 15000)

            truth = torch.stack(tensor_list_truth, 0)
            truth = truth.view(-1)
            # scores = scores[:, cand_ids]

            preds = torch.argmax(scores, dim=-1)

            all_truth.extend(truth.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())

    acc = sum([int(p == l) for p, l in zip(all_preds, all_truth)]) / len(all_truth)
    acc = torch.tensor(acc).to(device)

    acc_list = [torch.zeros_like(acc) for _ in range(mpu.get_model_parallel_world_size())]
    torch.distributed.all_gather(acc_list, acc, mpu.get_model_parallel_group())
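    # Under model parallelism the label tokens live in the first vocabulary
    # partition, so the accuracy computed on model-parallel rank 0 is reported.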

    return acc_list[0].item(), all_truth, all_preds
Example no. 6
def make_gpt2_dataloaders(args):

    # Input parameters.
    input_data_sizes_file = args.input_data_sizes_file
    seq_length = args.seq_length
    initial_seed = args.seed

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    def make_data_loader_(data_path):
        # Build the dataset.
        dataset = GPT2Dataset(data_path, input_data_sizes_file, seq_length,
                              initial_seed)
        # Use a simple sampler with distributed batch sampler.
        sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = DistributedBatchSampler(sampler=sampler,
                                                batch_size=global_batch_size,
                                                drop_last=True,
                                                rank=rank,
                                                world_size=world_size)
        # Torch dataloader.
        return torch.utils.data.DataLoader(dataset,
                                           batch_sampler=batch_sampler,
                                           num_workers=num_workers,
                                           pin_memory=True)

    train = make_data_loader_(args.train_data_path)
    valid = make_data_loader_(args.val_data_path)
    test = make_data_loader_(args.test_data_path)

    args.do_train = False
    args.do_valid = False
    args.do_test = False

    if train is not None:
        args.do_train = True
    if valid is not None:
        args.do_valid = True
    if test is not None:
        args.do_test = True

    # Tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir)
    eod_token = tokenizer.encoder['<|endoftext|>']
    num_tokens = eod_token + 1

    return (train, valid, test), num_tokens, eod_token
Example no. 7
 def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5,
              gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15, average_block_length=3,
              max_block_length=40, block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3,
              short_seq_prob=0.0, single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False,
              shuffle_blocks=True, sentinel_token=False, task_mask=False, random_position=False, masked_lm=False):
     self.eod_token = args.eod_token
     self.tokenizer = tokenizer
     self.count = 0
     self.max_seq_length = max_seq_length
     self.rank = mpu.get_data_parallel_rank()
     self.world_size = mpu.get_data_parallel_world_size()
     # self.rank = 0
     # self.world_size = 1
     assert 0.0 <= bert_prob <= 1.0
     self.bert_prob = bert_prob
     self.gap_sentence_prob = gap_sentence_prob
     self.gpt_prob = 1 - bert_prob - gap_sentence_prob
     assert self.gpt_prob >= -1e-10
     self.infill_prob = gpt_infill_prob
     self.gpt_min_ratio = gpt_min_ratio
     self.bert_ratio = bert_ratio
     self.gap_sentence_ratio = gap_sentence_ratio
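     # Span lengths will be sampled from a Poisson(average_block_length)
     # distribution truncated at max_block_length.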
     self.block_length_distribution = [poisson.pmf(i, average_block_length) for i in range(1, max_block_length)]
     self.block_mask_prob = block_mask_prob
     self.context_mask_ratio = context_mask_ratio
     self.context_mask_range = context_mask_range
     self.short_seq_prob = short_seq_prob
     self.single_span_prob = single_span_prob
     self.block_position_encoding = block_position_encoding
     self.encoder_decoder = encoder_decoder
     self.shuffle_blocks = shuffle_blocks
     self.sentinel_token = sentinel_token
     self.generation_mask = 'gMASK' if task_mask else 'MASK'
     self.generation_mask = self.tokenizer.get_command(self.generation_mask).Id
     self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK'
     self.gap_sentence_mask = self.tokenizer.get_command(self.gap_sentence_mask).Id
     self.random_position = random_position
     self.masked_lm = masked_lm
     print_rank_0(
         f"BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}")
     print_rank_0(
         f"generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}")
     print_rank_0(f"block length distribution {self.block_length_distribution}")
     print_rank_0(f"block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}")
Example no. 8
def test_initialize_model_parallel(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            model_parallel_size))
    model_parallel_size_ = min(model_parallel_size,
                               torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)


    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
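
A quick worked example of the rank layout this test checks, assuming a world size of 8 and model_parallel_size = 2 (values chosen only for illustration):

# global rank:          0 1 2 3 4 5 6 7
# model-parallel rank:  0 1 0 1 0 1 0 1   (rank % model_parallel_size)
# data-parallel rank:   0 0 1 1 2 2 3 3   (rank // model_parallel_size)
# data-parallel world size = 8 // 2 = 4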
Example no. 9
def make_data_loader(dataset):
    """Buld dataloader given an input dataset."""
    if dataset is None:
        return None
    args = get_args()

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a simple sampler with distributed batch sampler.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
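
A note on the batching convention above (a sketch of how a Megatron-style DistributedBatchSampler typically splits work; the sizes are illustrative):

# args.batch_size = 4 per GPU, data-parallel world size = 8
# global_batch_size = 4 * 8 = 32 sample indices are drawn per step by the batch
# sampler, and each data-parallel rank keeps only its own slice of 4 of them.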
Example no. 10
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = GPT2Tokenizer(
        os.path.join(args.tokenizer_path, 'vocab.json'),
        os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # load data
    test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
    # Set an arbitrary positive integer since the optimizer and the scheduler will not be used when doing eval.
    args.train_iters = 1

    # Model
    model, _, _ = setup_model_and_optimizer(args)

    device = torch.cuda.current_device()

    # give a time stamp to the model
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir,
                               "{}-{}".format(args.model_name, cur_time))

    if torch.distributed.get_rank() == 0:
        os.makedirs(results_dir, exist_ok=True)

    model.eval()
    all_sids = []
    all_cids = []
    all_losses = []
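    # Zero-shot scoring: every (sample id, candidate id) pair gets an LM loss;
    # after gathering across data-parallel ranks, the candidate with the lowest
    # loss for each sample id becomes the prediction.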
    with torch.no_grad():
        for batch, no_model_batch in tqdm(
                test_dataloader,
                desc="Evaluating",
                disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), no_model_batch["labels"])
            loss_mask = no_model_batch["loss_mask"]
            loss = torch.sum(losses * loss_mask,
                             dim=-1) / loss_mask.sum(dim=-1)

            loss_tensor_list = [
                torch.zeros_like(loss).to(device)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(loss_tensor_list,
                                         loss.data,
                                         group=mpu.get_data_parallel_group())
            all_losses.extend(loss_tensor_list)

            sids = no_model_batch["sids"]
            sid_tensor_list = [
                torch.zeros_like(sids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(sid_tensor_list,
                                         sids.data,
                                         group=mpu.get_data_parallel_group())
            all_sids.extend(sid_tensor_list)

            cids = no_model_batch["cids"]
            cid_tensor_list = [
                torch.zeros_like(cids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(cid_tensor_list,
                                         cids.data,
                                         group=mpu.get_data_parallel_group())
            all_cids.extend(cid_tensor_list)

    if torch.distributed.get_rank() == 0:
        all_losses = torch.stack(all_losses).view(-1).cpu().detach().numpy()
        all_sids = torch.stack(all_sids).view(-1).cpu().detach().numpy()
        all_cids = torch.stack(all_cids).view(-1).cpu().detach().numpy()

        truth_labels = test_dataset.truth_labels
        preds = [[] for _ in truth_labels]

        for sid, cid, loss in zip(all_sids, all_cids, all_losses):
            preds[sid].append((cid, loss))

        preds = [min(p, key=lambda x: x[1])[0] for p in preds if len(p) > 0]

        yprint("Acc: {}".format(
            sum([int(p == l)
                 for p, l in zip(preds, truth_labels)]) / len(truth_labels)))
        with open(os.path.join(results_dir, "zero-shot_result.txt"), "w") as f:
            f.write("Acc: {}\n".format(
                sum([int(p == l) for p, l in zip(preds, truth_labels)]) /
                len(truth_labels)))

    torch.distributed.barrier()
Example no. 11
def finetune(args,
             train_valid_datasets_provider,
             model_kwargs,
             forward_step=finetune_forward_step,
             end_of_epoch_callback_provider=None):
    """Main finetune function used across all tasks."""
    global tokenizer
    timers = Timers()
    tokenizer = prepare_tokenizer(args)
    pretrain_glm.tokenizer = tokenizer
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloader').start()
    train_dataloader, valid_dataloader = None, None
    train_block_dataloader, valid_block_dataloader = None, None
    if train_valid_datasets_provider is not None and args.epochs > 0:
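        # Only model-parallel rank 0 builds the real dataloaders; the number of
        # iterations per epoch is broadcast so the other ranks can step a
        # FakeDataloader in lockstep.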
        if mpu.get_model_parallel_rank() == 0:
            train_dataset, valid_dataset = train_valid_datasets_provider(
                args, tokenizer)
            train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
                train_dataset, valid_dataset, args)
            if args.no_validation:
                valid_dataloader = None
            train_iters = torch.cuda.LongTensor([len(train_dataloader)])
        else:
            train_iters = torch.cuda.LongTensor([0])
        torch.distributed.broadcast(train_iters,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        if mpu.get_model_parallel_rank() != 0:
            args.train_iters_per_epoch = train_iters[0].item()
            args.train_iters = args.epochs * args.train_iters_per_epoch

            train_dataloader = FakeDataloader(args.train_iters_per_epoch)
            if args.no_validation:
                valid_dataloader = None
            else:
                valid_dataloader = FakeDataloader(None)
        if args.block_lm_ratio > 0.0:
            if mpu.get_model_parallel_rank() == 0:
                train_block_dataset, valid_block_dataset = train_valid_datasets_provider(
                    args, tokenizer, pattern_text=True)
                train_block_dataloader = make_data_loader(
                    train_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    args.train_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
                valid_block_dataloader = make_data_loader(
                    valid_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    (args.train_iters // args.eval_interval + 1) *
                    args.eval_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
            else:
                train_block_dataloader = FakeDataloader(args.train_iters)
                valid_block_dataloader = FakeDataloader(None)
            train_block_dataloader, valid_block_dataloader = iter(
                train_block_dataloader), iter(valid_block_dataloader)

    timers('train/valid/test dataset/dataloader').stop()
    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback, end_of_train_callback = None, None
    if end_of_epoch_callback_provider is not None:
        if train_valid_datasets_provider is not None and args.epochs > 0 and not args.no_validation:
            end_of_epoch_callback = end_of_epoch_callback_provider(
                args, tokenizer, is_test=False)
        end_of_train_callback = end_of_epoch_callback_provider(args,
                                                               tokenizer,
                                                               is_test=True)
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        args, **model_kwargs)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.load_pretrained is not None and not args.pretrained_bert:
        task_tokens = None
        if args.continuous_prompt and args.prompt_init:
            if mpu.get_model_parallel_rank() == 0:
                dataset = train_dataloader.dataset
                processor, pvp = dataset.processor, dataset.pvp
                task_tokens = []
                for label in processor.get_labels():
                    verbalizer = pvp.verbalize(label)[0]
                    verbalizer_ids = tokenizer.EncodeAsIds(
                        verbalizer).tokenization
                    task_tokens += verbalizer_ids
                print_rank_0("Task tokens: " +
                             tokenizer.DecodeIds(task_tokens))
                num_task_tokens = len(task_tokens)
            else:
                num_task_tokens, task_tokens = 0, []
            num_task_tokens = torch.cuda.LongTensor([num_task_tokens])
            torch.distributed.broadcast(num_task_tokens,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            num_task_tokens = num_task_tokens.item()
            if num_task_tokens > 0:
                if mpu.get_model_parallel_rank() == 0:
                    task_tokens = torch.cuda.LongTensor(task_tokens)
                else:
                    task_tokens = torch.empty(
                        num_task_tokens,
                        device=torch.cuda.current_device(),
                        dtype=torch.long)
                torch.distributed.broadcast(
                    task_tokens,
                    mpu.get_model_parallel_src_rank(),
                    group=mpu.get_model_parallel_group())
                task_tokens = task_tokens.tolist()
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_pretrained(model,
                            args.load_pretrained,
                            args,
                            task_tokens=task_tokens)
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    if args.load is not None:
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_checkpoint(model,
                            optimizer,
                            lr_scheduler,
                            args,
                            no_deepspeed=args.no_deepspeed_load)
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    torch.distributed.barrier()
    timers('pretrained checkpoint').stop()
    args.iteration = 0
    summary_writer = None
    if torch.distributed.get_rank() == 0:
        args.log_dir = get_log_dir(base=args.summary_dir,
                                   name=args.experiment_name)
        if os.path.exists(os.path.join(args.log_dir, "test_results.json")
                          ) and args.load is None and not args.overwrite:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.log_dir))
        summary_writer = get_sample_writer(log_dir=args.log_dir,
                                           iteration=args.iteration)
        print_and_save_args(args, verbose=True, log_dir=args.log_dir)

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log([
        'train/valid/test dataset/dataloader', 'callback function',
        'model and optimizer', 'pretrained checkpoint'
    ])
    print_rank_0('training ...')

    # Finetune the model.
    score_dict = None
    if train_dataloader is not None and args.epochs > 0:
        if args.block_lm_ratio > 0.0:
            forward_step = mix_forward_step
        best_iteration = _train(model,
                                optimizer,
                                lr_scheduler,
                                forward_step,
                                (train_dataloader, train_block_dataloader),
                                (valid_dataloader, valid_block_dataloader),
                                end_of_epoch_callback,
                                args,
                                timers,
                                summary_writer=summary_writer)
        if end_of_train_callback is not None and best_iteration is not None:
            with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                          timeout=-1):
                args.load = os.path.join(args.save, "best")
                load_checkpoint(model,
                                optimizer,
                                lr_scheduler,
                                args,
                                no_load_optim=True,
                                no_deepspeed=True)
                args.load = None
        torch.distributed.barrier()
        if end_of_train_callback is not None:
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    # Or just evaluate.
    else:
        if end_of_train_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    if score_dict is not None and torch.distributed.get_rank() == 0:
        score_dict.update({"type": "test"})
        with open(os.path.join(args.log_dir, "test_results.json"),
                  "w") as output:
            output.write(json.dumps(score_dict) + "\n")

    print_rank_0('done :-)')
Example no. 12
    def evaluate(self, model, dataloader, example_dict, args):
        model.eval()
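        # A TCP key-value store shared across data-parallel ranks: each rank
        # writes its own predictions keyed by uid, and all of them are read
        # back once evaluation finishes.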
        store = torch.distributed.TCPStore(args.master_ip,
                                           18931 + random.randint(0, 10000),
                                           mpu.get_data_parallel_world_size(),
                                           torch.distributed.get_rank() == 0,
                                           datetime.timedelta(seconds=30))
        print_rank_0("Distributed store created")

        with torch.no_grad():
            for idx, data in enumerate(dataloader):
                tokens, attention_mask, position_ids = process_batch(
                    data, args)
                src_tokens = tokens
                batch_size = tokens.size(0)
                mask_positions = []
                current_mask = []
                for text in tokens.tolist():
                    mask_positions.append([
                        i for i, x in enumerate(text) if x == self.mask_token
                    ])
                    current_mask.append(0)
                    # print(self.tokenizer.DecodeIds(text))
                    # print(mask_positions[-1])
                counter = 0
                done = [False] * batch_size
                while counter < args.tgt_seq_length:
                    if counter == 0:
                        # print(tokens)
                        # print(position_ids)
                        next_token_logits, *mems = model(tokens,
                                                         position_ids,
                                                         attention_mask,
                                                         return_memory=True)
                        next_token_logits = next_token_logits[:, -1]
                        position_ids = tokens.new_ones(batch_size, 2, 1)
                        for i, text in enumerate(tokens.tolist()):
                            mask_pos = mask_positions[i][current_mask[i]]
                            position_ids[i, 0] = mask_pos
                        tokens = tokens.new_zeros(batch_size, 0)
                        attention_mask = tokens.new_zeros(batch_size)
                    else:
                        position_ids[:, 1] = position_ids[:, 1] + 1
                        last_token = tokens[:, -1:]
                        next_token_logits, *mems = model(last_token,
                                                         position_ids,
                                                         attention_mask,
                                                         *mems,
                                                         return_memory=True)
                        next_token_logits = next_token_logits[:, -1]
                    next_token_scores = F.log_softmax(next_token_logits,
                                                      dim=-1)
                    next_token_scores = self.processors(
                        tokens, next_token_scores)
                    next_tokens = next_token_scores.max(dim=-1)[1]
                    # print(self.tokenizer.DecodeIds(next_tokens.tolist()))
                    for i, next_token in enumerate(next_tokens.tolist()):
                        if next_token == self.end_token:
                            if current_mask[i] + 1 < len(mask_positions[i]):
                                current_mask[i] += 1
                                next_tokens[i] = self.start_token
                                position_ids[i, 0] = mask_positions[i][
                                    current_mask[i]]
                                position_ids[i, 1] = 0
                            else:
                                done[i] = True
                        if done[i]:
                            next_tokens[i] = self.pad_token
                    if all(done):
                        break
                    tokens = torch.cat(
                        [tokens, next_tokens.unsqueeze(-1)], dim=-1)
                    counter += 1
                predictions = []
                for i, text in enumerate(tokens.tolist()):
                    text = [
                        token for token in text
                        if token not in [self.end_token, self.pad_token]
                    ]
                    blanks = [[]]
                    for token in text:
                        if token == self.start_token:
                            blanks.append([])
                        else:
                            blanks[-1].append(token)
                    output_tokens = []
                    current_blank = 0
                    for token in src_tokens[i].tolist():
                        if token == self.mask_token:
                            if current_blank < len(blanks):
                                output_tokens += blanks[current_blank]
                            current_blank += 1
                        else:
                            if token not in [self.pad_token]:
                                output_tokens.append(token)
                    text = self.tokenizer.DecodeIds(output_tokens[:-1])
                    text = blanklm_fix_tokenization(text)
                    predictions.append(text)
                    # print(text)
                uid_list = data['uid']
                if isinstance(uid_list, torch.Tensor):
                    uid_list = uid_list.cpu().numpy().tolist()
                for uid, prediction in zip(uid_list, predictions):
                    store.set(uid, prediction)
                if (idx + 1) % args.log_interval == 0:
                    print_rank_0(f"Iteration {idx + 1} / {len(dataloader)}")

        model.train()
        torch.distributed.barrier()
        print_rank_0("Evaluation completed")
        predictions, examples = [], []
        for uid, example in example_dict.items():
            predictions.append(store.get(uid).decode('utf-8'))
            examples.append(example)
        torch.distributed.barrier()
        return predictions, [], examples
Example no. 13
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
    rand = random.Random(args.iteration * mpu.get_data_parallel_world_size() +
                         mpu.get_data_parallel_rank())
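    # Deterministic per-rank choice between single-task and multi-task data:
    # seeding with iteration * world_size + rank keeps the draw reproducible
    # while still differing across data-parallel ranks.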
    if data_iterator[1] and rand.random() < args.multi_task_ratio:
        data = next(data_iterator[1]) if data_iterator[1] else None
        data["mode"] = "multi-task"
    else:
        data = next(data_iterator[0]) if data_iterator[0] else None
    # print_rank_0("data iterator")
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)
    timers('batch generator').stop()

    # print_rank_0("get batch")

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        sep = attention_mask.item() if torch.numel(
            attention_mask) == 1 else attention_mask[batch_id].item()
        text, last_segment = "", []
        for i, token_id in enumerate(tokens[batch_id, :sep].tolist()):
            token = tokenizer.IdToToken(token_id)
            if token.startswith('[MASK') or token.endswith('MASK]'):
                if last_segment:
                    text += tokenizer.DecodeIds(last_segment)
                    last_segment = []
                text += f" [{position_ids_[batch_id, i].item()}, {token}]"
            else:
                last_segment.append(token_id)
        if last_segment:
            text += tokenizer.DecodeIds(last_segment)
        print(text.encode('utf-8'))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if tokenizer.IdToToken(
                    tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(
                        tokenizer.DecodeIds(
                            tokens[batch_id,
                                   last_index:i].tolist()).encode('utf-8'),
                        "|",
                        tokenizer.DecodeIds(
                            labels[batch_id,
                                   last_index:i].tolist()).encode('utf-8'),
                        position_ids_[batch_id, last_index:i].tolist(),
                        block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(
                tokenizer.DecodeIds(
                    tokens[batch_id,
                           last_index:].tolist()).encode('utf-8'), "|",
                tokenizer.DecodeIds(
                    labels[batch_id, last_index:].tolist()).encode('utf-8'),
                position_ids_[batch_id, last_index:].tolist(),
                block_position_ids[batch_id, last_index:].tolist())

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()

    return loss, mems, mode
Example no. 14
    def evaluate(self, model, dataloader, example_dict, args):
        """Calculate correct over total answers and return prediction if the
        `output_predictions` is true."""
        model.eval()
        store = torch.distributed.TCPStore(args.master_ip,
                                           18931 + random.randint(0, 10000),
                                           mpu.get_data_parallel_world_size(),
                                           torch.distributed.get_rank() == 0,
                                           datetime.timedelta(seconds=30))
        print_rank_0("Distributed store created")
        with torch.no_grad():
            # For all the batches in the dataset.
            for idx, data in enumerate(dataloader):
                tokens, attention_mask, position_ids = process_batch(
                    data, args)
                batch_size = tokens.size(0)
                beam_scorer = BeamSearchScorer(
                    batch_size=batch_size,
                    max_length=args.out_seq_length,
                    num_beams=args.num_beams,
                    device=tokens.device,
                    length_penalty=args.length_penalty,
                    do_early_stopping=False,
                )
                beam_scores = torch.zeros((batch_size, args.num_beams),
                                          dtype=torch.float,
                                          device=tokens.device)
                beam_scores[:, 1:] = -1e9
                beam_scores = beam_scores.view((batch_size * args.num_beams, ))
                # Run the model forward.
                counter = 0
                while counter < args.tgt_seq_length:
                    if counter == 0:
                        next_token_logits, *mems = model(tokens,
                                                         position_ids,
                                                         attention_mask,
                                                         return_memory=True)
                        seq_length = next_token_logits.size(1)
                        next_token_logits = next_token_logits[:, -1]
                        next_token_logits = next_token_logits.unsqueeze(
                            1).repeat(1, args.num_beams,
                                      1).view(batch_size * args.num_beams, -1)
                        mems = [
                            mem.unsqueeze(1).repeat(
                                1, args.num_beams, 1,
                                1).view(batch_size * args.num_beams,
                                        seq_length, -1) for mem in mems
                        ]
                        position_ids = tokens.new_ones(batch_size,
                                                       args.num_beams, 2, 1)
                        for i, text in enumerate(tokens.tolist()):
                            mask_pos = text.index(self.mask_token)
                            position_ids[i, :, 0] = mask_pos
                        position_ids = position_ids.reshape(
                            batch_size * args.num_beams, 2, 1)
                        tokens = tokens.new_zeros(batch_size * args.num_beams,
                                                  0)
                        attention_mask = tokens.new_zeros(
                            [batch_size * args.num_beams])
                    else:
                        if not args.no_block_position:
                            position_ids[:, 1] = counter + 1
                        last_token = tokens[:, -1:]
                        next_token_logits, *mems = model(last_token,
                                                         position_ids,
                                                         attention_mask,
                                                         *mems,
                                                         return_memory=True)
                        next_token_logits = next_token_logits[:, -1]
                    next_token_scores = F.log_softmax(next_token_logits,
                                                      dim=-1)
                    next_token_scores = self.processors(
                        tokens, next_token_scores)
                    next_token_scores = next_token_scores + beam_scores[:, None].expand_as(
                        next_token_scores)
                    vocab_size = next_token_scores.shape[-1]
                    next_token_scores = next_token_scores.view(
                        batch_size, args.num_beams * vocab_size)

                    probs = F.softmax(next_token_scores, dim=-1)
                    if args.select_topk:
                        _, next_tokens = torch.topk(probs,
                                                    k=2 * args.num_beams,
                                                    dim=-1,
                                                    largest=True)
                    else:
                        next_tokens = torch.multinomial(probs,
                                                        num_samples=2 *
                                                        args.num_beams)
                    next_token_scores = torch.gather(next_token_scores, -1,
                                                     next_tokens)
                    next_token_scores, _indices = torch.sort(next_token_scores,
                                                             descending=True,
                                                             dim=1)
                    next_tokens = torch.gather(next_tokens, -1, _indices)

                    next_indices = next_tokens // vocab_size
                    next_tokens = next_tokens % vocab_size
                    # stateless
                    beam_outputs = beam_scorer.process(
                        tokens,
                        next_token_scores,
                        next_tokens,
                        next_indices,
                        eos_token_id=self.end_token,
                        pad_token_id=self.pad_token)
                    beam_scores = beam_outputs["next_beam_scores"]
                    beam_next_tokens = beam_outputs["next_beam_tokens"]
                    beam_idx = beam_outputs["next_beam_indices"]
                    beam_next_tokens = beam_next_tokens.unsqueeze(-1)
                    tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens],
                                       dim=-1)
                    mems = [mem[beam_idx] for mem in mems] if mems else []
                    if beam_scorer.is_done:
                        break
                    counter += 1
                tokens, _ = beam_scorer.finalize(tokens,
                                                 beam_scores,
                                                 next_tokens,
                                                 next_indices,
                                                 eos_token_id=self.end_token,
                                                 pad_token_id=self.pad_token)
                predictions = []
                for text in tokens.tolist():
                    text = [
                        token for token in text
                        if token not in [self.end_token, self.pad_token]
                    ]
                    text = self.tokenizer.DecodeIds(text)
                    predictions.append(text)
                uid_list = data['uid']
                if isinstance(uid_list, torch.Tensor):
                    uid_list = uid_list.cpu().numpy().tolist()
                for uid, prediction in zip(uid_list, predictions):
                    store.set(uid, prediction)
                if (idx + 1) % args.log_interval == 0:
                    print_rank_0(f"Iteration {idx + 1} / {len(dataloader)}")
        model.train()
        torch.distributed.barrier()
        print_rank_0("Evaluation completed")
        predictions, examples = [], []
        for uid, example in example_dict.items():
            predictions.append(store.get(uid).decode('utf-8'))
            examples.append(example)
        torch.distributed.barrier()
        return predictions, [], examples
Example no. 15
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider, args):
    """XXX"""

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # These flags need to be broadcast to the other model-parallel ranks.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
Example no. 16
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = GPT2Tokenizer(os.path.join(args.tokenizer_path, 'vocab.json'), os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # load train data
    if args.do_train:
        train_dataloader, _ = load_data(args, 'train', tokenizer, 1)
        dev_dataloader, dev_dataset = load_data(args, 'dev', tokenizer, 1)

        with open(args.deepspeed_config, "r") as f:
            deepspeed_conf = json.load(f)

        epoch = args.epoch
        grad_acc = deepspeed_conf["gradient_accumulation_steps"]
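        # args.train_iters counts optimizer steps: batches per epoch times the
        # number of epochs, divided by the gradient accumulation factor.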
        args.train_iters = len(train_dataloader) * epoch // grad_acc

        # Model, optimizer, and learning rate.
        # TODO: maybe need to reinitialize optimizer
    elif args.do_eval:
        # Set an arbitrary positive integer since the optimizer and the scheduler will not be used when doing eval.
        args.train_iters = 1

    model, optimizer, lr_scheduler = setup_model_and_optimizer_C(args)
    device = torch.cuda.current_device()

    # give a time stamp to the model
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir, "{}-{}".format(args.model_name, cur_time))
    os.makedirs(results_dir, exist_ok=True)

    if args.do_train and torch.distributed.get_rank() == 0:

        with open(os.path.join(results_dir, "train_log.txt"), "w") as f:
            f.write("Train losses:\n")

        with open(os.path.join(results_dir, "dev_log.txt"), "w") as f:
            f.write("Dev accs:\n")

    torch.distributed.barrier()

    if args.do_train:
        # cand_ids = torch.tensor(dev_dataset.cand_ids).to(device)
        total_loss, logging_loss, best_acc = 0.0, 0.0, 0.0
        global_step, total_step, best_step = 0, 0, 0
        
        for e in range(epoch):
            model.train()
            for batch, no_model_batch in tqdm(train_dataloader, disable=(torch.distributed.get_rank() != 0)):
                for k in batch:
                    batch[k] = batch[k].to(device)
                for k in no_model_batch:
                    no_model_batch[k] = no_model_batch[k].to(device)

                output = model(**batch)
                # get the loss of the last token
                output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / torch.sum(no_model_batch["loss_mask"], -1).unsqueeze(-1)
                # get the label of the last token (class indices must be long for cross-entropy)
                # labels = no_model_batch["labels"].float()
                labels = no_model_batch["truth"].long()
                # labels = (torch.sum(labels * no_model_batch["loss_mask"], 1) / torch.sum(no_model_batch["loss_mask"], -1)).long()
                # cross_entropy loss (nn.CrossEntropyLoss must be instantiated before it is called)
                # losses = mpu.vocab_parallel_cross_entropy(output.unsqueeze(1).contiguous().float(), labels.unsqueeze(1))
                losses = CrossEntropyLoss(reduction='none')(output.contiguous().float(), labels)
                loss = torch.mean(losses)

                model.backward(loss)
                model.step()

                torch.distributed.all_reduce(loss.data, group=mpu.get_data_parallel_group())
                loss.data = loss.data / mpu.get_data_parallel_world_size()
                total_loss += loss.item() / grad_acc

                if total_step % grad_acc == 0:
                    global_step += 1
                    if global_step != 0 and global_step % args.log_interval == 0:
                        # logging
                        if torch.distributed.get_rank() == 0:
                            train_log = "Epoch {}, global step {}, total step {}, train lm loss: {}".format(e, global_step, epoch * len(train_dataloader), (total_loss - logging_loss) / args.log_interval)
                            yprint(train_log)
                            with open(os.path.join(results_dir, "train_log.txt"), "a") as f:
                                f.write(train_log + "\n")

                        logging_loss = total_loss
    
                    if global_step != 0 and global_step % args.eval_interval == 0:
                        # evaluate on the dev
                        acc, _, _ = evaluate_tnews(args, model, dev_dataloader, device, mode="dev")
                        dev_results_dir = os.path.join(results_dir, "dev_step-{}".format(global_step))

                        if acc > best_acc:
                            best_acc = acc
                            best_step = global_step

                        if torch.distributed.get_rank() == 0:
                            # only rank 0 writes the dev log files
                            dev_log = "Epoch: {}, Global step: {}, Acc: {}".format(e, global_step, acc)
                            yprint(dev_log)
                            os.makedirs(dev_results_dir, exist_ok=True)
                            with open(os.path.join(dev_results_dir, "dev_result.txt"), "w") as f:
                                f.write(dev_log + "\n")
                            with open(os.path.join(results_dir, "dev_log.txt"), "a") as f:
                                f.write(dev_log + "\n")

                        torch.distributed.barrier()
                        
                        args.save = dev_results_dir
                        save_checkpoint(global_step, model, optimizer, lr_scheduler, args)

                total_step += 1

        with open(os.path.join(results_dir, "dev_log.txt"), "a") as f:
            f.write("Best acc: {} Best step: {}\n".format(best_acc, best_step))

    if args.do_eval:
        # evaluate on the test
        test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
        cand_ids = torch.tensor(test_dataset.cand_ids).to(device)

        if args.do_train:
            # if do training, then evaluate the one with the max acc on dev set.
            eval_ckpt_path = os.path.join(results_dir, "dev_step-{}".format(best_step))
            args.load = eval_ckpt_path
        else:
            # if only do eval, then evaluate the one specified by the user.
            args.load = args.eval_ckpt_path            
        
        load_checkpoint(model=model, optimizer=None, lr_scheduler=None, args=args)
        acc, _, _ = evaluate(args, model, test_dataloader, cand_ids, device, mode="test")

        if torch.distributed.get_rank() == 0:
            eval_log = "Checkpoint from {}: Acc: {}".format(args.load, acc)
            yprint(eval_log)
            with open(os.path.join(results_dir, "eval_log"), "w") as f:
                f.write(eval_log + "\n")

        torch.distributed.barrier()
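The training step above pools the per-position model outputs with a mask-weighted mean over the answer span and then applies a cross-entropy loss to the pooled vector. Below is a minimal, self-contained sketch of that pattern on toy tensors; the shapes, mask, and label values are purely illustrative and not taken from the repository.

import torch
from torch.nn import CrossEntropyLoss

# Toy shapes: batch of 2 sequences, 4 positions, 3 candidate classes.
logits = torch.randn(2, 4, 3)                 # per-position model outputs
loss_mask = torch.tensor([[0., 1., 1., 0.],
                          [0., 0., 1., 1.]])  # 1 on the answer span
truth = torch.tensor([2, 0])                  # gold class index per example

# Mask-weighted mean over the answer span, mirroring the pooling in the training loop.
pooled = torch.sum(logits * loss_mask.unsqueeze(-1), dim=1) / torch.sum(loss_mask, dim=-1, keepdim=True)  # -> [2, 3]

# Per-example cross-entropy against the gold labels, then the batch mean.
losses = CrossEntropyLoss(reduction="none")(pooled, truth)
loss = losses.mean()
print(pooled.shape, loss.item())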
Esempio n. 17
0
def load_tnews_data(data_path, data_type, tokenizer, few_shot=False):
    args = get_args()

    filename = os.path.join(data_path, data_type+'.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    labels = []
    label_map = {}
    label_reverse = {}
    with open(os.path.join(data_path, 'labels.json')) as fin:
        for i, line in enumerate(fin):
            obj = json.loads(line.strip())
            labels.append(obj['label_desc'])
            label_map[obj['label_desc']] = i
            label_reverse[obj['label']] = obj['label_desc']

    all_tokens = []
    all_masks = []
    all_labels = []
    for obj in objs:
        sentence = obj['sentence']
        tokenized_sentence = tokenizer.encode(sentence)[:args.seq_length-20]
        obj['label_desc'] = label_reverse[obj['label']]

        if few_shot:
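            # Few-shot setting: sample three distractor labels and append the gold
            # label, so each sentence is scored against four candidate prompts
            # (the gold label is always the last candidate).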
            cur_labels = random.sample(labels, 3)
            while obj['label_desc'] in cur_labels:
                cur_labels = random.sample(labels, 3)
            cur_labels.append(obj['label_desc'])
            cur_label = cur_labels.index(obj['label_desc'])
            assert cur_label != -1
        else:
            cur_labels = labels
            cur_label = label_map[obj['label_desc']]

        all_labels.append(cur_label)

        for label in cur_labels:
            prompt = "这是关于{}的文章:".format(label)
            prompt_tokens = tokenizer.encode(prompt)
            prompt_len = len(prompt_tokens)
            tokens = prompt_tokens + tokenized_sentence
            second_mask = [0] * (args.seq_length-1)
            for idx in range(prompt_len-1, len(tokens)-1):
                second_mask[idx] = 1
            all_masks.append(second_mask)
            token_length = len(tokens)
            assert token_length < args.seq_length
            tokens.extend([pad_id] * (args.seq_length - token_length))
            all_tokens.append(tokens)
    
    all_tokens = torch.tensor(all_tokens, dtype=torch.long)
    all_masks = torch.tensor(all_masks, dtype=torch.float)
    dataset = TensorDataset(all_tokens, all_masks)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True), all_labels
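The second_mask built in load_tnews_data covers positions prompt_len-1 through len(tokens)-2 and has length seq_length-1 because the per-position losses score the next token: the loss at position i predicts the token at position i+1. A minimal sketch of that alignment, with a short seq_length and illustrative token ids:

# Illustrative values only; real ids come from the tokenizer.
seq_length = 10
prompt_tokens = [11, 12, 13]        # stands in for the encoded label prompt
sentence_tokens = [21, 22, 23, 24]  # stands in for the encoded news sentence
tokens = prompt_tokens + sentence_tokens

second_mask = [0] * (seq_length - 1)
for idx in range(len(prompt_tokens) - 1, len(tokens) - 1):
    second_mask[idx] = 1

# Positions 2..5 are marked: exactly the losses that predict sentence tokens 21..24.
print(second_mask)  # [0, 0, 1, 1, 1, 1, 0, 0, 0]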
Esempio n. 18
0
def load_ocnli_data(data_path, data_type, tokenizer):
    args = get_args()

    filename = os.path.join(data_path, data_type+'.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    all_tokens_1 = []
    all_masks_1 = []
    all_tokens_2 = []
    all_masks_2 = []    
    all_tokens_3 = []
    all_masks_3 = [] 
    all_labels = []
    for obj in objs:

        if obj['label'] == '-':
            continue
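        # Build three verbalized copies of each premise/hypothesis pair:
        # "对" for entailment, "错" for contradiction, "也许" for neutral,
        # matching the label ids assigned at the end of the loop.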

        prompt = "{}?对,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length-1)
        for idx in range(prompt_len-1, len(tokens)-1):
            second_mask[idx] = 1
        all_masks_1.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_1.append(tokens)

        prompt = "{}?错,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length-1)
        for idx in range(prompt_len-1, len(tokens)-1):
            second_mask[idx] = 1
        all_masks_2.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_2.append(tokens)

        prompt = "{}?也许,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length-1)
        for idx in range(prompt_len-1, len(tokens)-1):
            second_mask[idx] = 1
        all_masks_3.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_3.append(tokens)

        if obj['label'] == 'entailment':
            all_labels.append([0])
        elif obj['label'] == 'contradiction':
            all_labels.append([1])
        else:
            all_labels.append([2])

    all_tokens_1 = torch.tensor(all_tokens_1, dtype=torch.long)
    all_masks_1 = torch.tensor(all_masks_1, dtype=torch.float)
    all_tokens_2 = torch.tensor(all_tokens_2, dtype=torch.long)
    all_masks_2 = torch.tensor(all_masks_2, dtype=torch.float)
    all_tokens_3 = torch.tensor(all_tokens_3, dtype=torch.long)
    all_masks_3 = torch.tensor(all_masks_3, dtype=torch.float)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    dataset = TensorDataset(all_tokens_1, all_masks_1, all_tokens_2, all_masks_2, all_tokens_3, all_masks_3, all_labels)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a random sampler with distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
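A minimal usage sketch for load_ocnli_data; the data directory is hypothetical, and the tokenizer is assumed to be the repo's GPT-style tokenizer with '<pad>' and '<eod>' entries in its encoder (with get_args() already initialized), so this is illustrative rather than runnable on its own.

dev_dataloader = load_ocnli_data("data/ocnli", "dev", tokenizer)

for batch in dev_dataloader:
    # Each batch holds three (tokens, mask) pairs, one per verbalizer prompt,
    # followed by the gold label: 0 = entailment ("对"), 1 = contradiction ("错"),
    # 2 = neutral ("也许"), matching the TensorDataset order above.
    tokens_1, masks_1, tokens_2, masks_2, tokens_3, masks_3, labels = batch
    print(tokens_1.shape, labels.view(-1)[:5])
    break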