Example #1
def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = ['SimpleDataSet', 'LMDBDataSet']
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
        'DataSet only support {}'.format(support_dict))
    assert mode in ['Train', 'Eval', 'Test'
                    ], "Mode should be Train, Eval or Test."

    dataset = eval(module_name)(config, mode, logger, seed)
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = 1

    use_shared_memory = False

    batch_sampler = BatchSampler(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last)

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=use_shared_memory)

    return data_loader
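
A minimal, self-contained sketch of the same wiring (not taken from the examples above; ToyDataset and all values are illustrative, assuming paddle >= 2.0): BatchSampler yields lists of indices over the dataset, and DataLoader consumes those index batches to produce tensor batches.

import numpy as np
from paddle.io import Dataset, BatchSampler, DataLoader

class ToyDataset(Dataset):
    """A tiny map-style dataset: __getitem__ and __len__ are all that is needed."""

    def __init__(self, num_samples=32):
        self.data = np.random.rand(num_samples, 4).astype("float32")

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset = ToyDataset()
batch_sampler = BatchSampler(dataset=dataset,
                             batch_size=8,
                             shuffle=True,
                             drop_last=False)
data_loader = DataLoader(dataset,
                         batch_sampler=batch_sampler,
                         return_list=True)
for batch in data_loader:
    pass  # each batch is a list holding one Tensor of shape [8, 4]
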
Example #2
    def init_batch_sampler(self):
        dataset = RandomDataset(1000, 10)
        sampler = SequenceSampler(dataset)
        bs = BatchSampler(sampler=sampler,
                          batch_size=self.batch_size,
                          drop_last=self.drop_last)
        return bs
Example #3
    def init_batch_sampler(self):
        dataset = RandomDataset(self.num_samples, self.num_classes)
        bs = BatchSampler(dataset=dataset,
                          batch_size=self.batch_size,
                          shuffle=self.shuffle,
                          drop_last=self.drop_last)
        return bs
Example #4
def create_data_loader(dataset, tokenizer, args, mode):
    trans_func1 = partial(preprocess_examples, mode=mode)
    trans_func2 = partial(convert_example,
                          tokenizer=tokenizer,
                          max_seq_len=args.max_seq_len,
                          max_response_len=args.max_response_len,
                          max_knowledge_len=args.max_knowledge_len,
                          mode=mode)
    dataset = dataset.map(trans_func1, batched=True).map(trans_func2,
                                                         lazy=True)
    if mode == 'train':
        batch_sampler = DistributedBatchSampler(dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True)
    else:
        batch_sampler = BatchSampler(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False)
    collate_fn = partial(batchify_fn,
                         pad_val=tokenizer.pad_token_id,
                         mode=mode)
    data_loader = DataLoader(dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=collate_fn,
                             return_list=True)
    return dataset, data_loader
Example #5
def load_squad_dataset(args):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    features_fn = prepare_train_features if args.is_training else prepare_validation_features
    if args.is_training:
        raw_dataset = load_dataset('squad', split='train')
    else:
        raw_dataset = load_dataset('squad', split='validation')
    column_names = raw_dataset.column_names
    dataset = raw_dataset.map(partial(
        features_fn, tokenizer=tokenizer, args=args),
                              batched=True,
                              remove_columns=column_names,
                              num_proc=4)

    bs = args.micro_batch_size * args.grad_acc_factor * args.batches_per_step * args.num_replica
    args.batch_size = bs
    if args.is_training:
        train_batch_sampler = BatchSampler(
            dataset, batch_size=bs, shuffle=args.shuffle, drop_last=True)
    else:
        train_batch_sampler = BatchSampler(
            dataset, batch_size=bs, shuffle=args.shuffle, drop_last=False)

    if args.is_training:
        collate_fn = lambda samples, fn=Dict({
            "input_ids": Stack(),
            "token_type_ids": Stack(),
            "position_ids": Stack(),
            "input_mask": Stack(),
            "start_positions": Stack(),
            "end_positions": Stack()
        }): fn(samples)
    else:
        collate_fn = lambda samples, fn=Dict({
            "input_ids": Stack(),
            "token_type_ids": Stack(),
            "position_ids": Stack(),
            "input_mask": Stack()}): fn(samples)

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=collate_fn,
        return_list=True)
    return raw_dataset, data_loader
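
Several examples here build collate_fn as a lambda whose default argument freezes a paddlenlp.data batching function. A hedged sketch of that idiom with illustrative field names: Dict maps each key of a dict-style sample to a batching function and returns the batched fields in key order.

from paddlenlp.data import Dict, Pad, Stack

samples = [
    {"input_ids": [1, 2, 3], "start_positions": 0, "end_positions": 2},
    {"input_ids": [4, 5], "start_positions": 1, "end_positions": 1},
]
batchify_fn = lambda batch, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=0),      # pad variable-length ids
    "start_positions": Stack(dtype="int64"),  # stack scalars into one array
    "end_positions": Stack(dtype="int64"),
}): fn(batch)

input_ids, start_positions, end_positions = batchify_fn(samples)
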
Example #6
    def test_main(self):
        try:
            dataset = RandomDataset(self.num_samples, self.num_classes)
            bs = BatchSampler(dataset=dataset,
                              indices=list(range(self.num_samples)),
                              batch_size=self.batch_size,
                              drop_last=self.drop_last)
            self.assertTrue(False)
        except AssertionError:
            pass
Example #7
    def test_main(self):
        place = fluid.cpu_places()[0]
        with fluid.dygraph.guard(place):
            dataset = RandomDataset(100)
            batch_sampler = BatchSampler(dataset=dataset, batch_size=4)

            # dataset is not instance of Dataset
            try:
                loader = DataLoader(dataset=batch_sampler, places=place)
                self.assertTrue(False)
            except AssertionError:
                pass

            # places is None
            try:
                loader = DataLoader(dataset=dataset, places=None)
                self.assertTrue(False)
            except AssertionError:
                pass

            # num_workers < 0
            try:
                loader = DataLoader(dataset=dataset,
                                    places=place,
                                    num_workers=-1)
                self.assertTrue(False)
            except AssertionError:
                pass

            # timeout < 0
            try:
                loader = DataLoader(dataset=dataset, places=place, timeout=-1)
                self.assertTrue(False)
            except AssertionError:
                pass

            # set batch_sampler and shuffle/batch_size/drop_last
            try:
                loader = DataLoader(dataset=dataset,
                                    places=place,
                                    batch_sampler=batch_sampler,
                                    shuffle=True,
                                    drop_last=True)
                self.assertTrue(False)
            except AssertionError:
                pass

            # set batch_sampler correctly
            try:
                loader = DataLoader(dataset=dataset,
                                    places=place,
                                    batch_sampler=batch_sampler)
                self.assertTrue(True)
            except AssertionError:
                self.assertTrue(False)
Example #8
    def test_main(self):
        try:
            dataset = RandomDataset(self.num_samples, self.num_classes)
            sampler = RandomSampler(dataset)
            bs = BatchSampler(sampler=sampler,
                              shuffle=self.shuffle,
                              batch_size=self.batch_size,
                              drop_last=self.drop_last)
            self.assertTrue(False)
        except AssertionError:
            pass
Example #9
def get_test_dataloader(args, language, batchify_fn, trans_func):
    test_ds = load_dataset("xnli", language, splits="test")
    test_ds = test_ds.map(trans_func, lazy=True)
    test_batch_sampler = BatchSampler(test_ds,
                                      batch_size=args.batch_size,
                                      shuffle=False)
    test_data_loader = DataLoader(dataset=test_ds,
                                  batch_sampler=test_batch_sampler,
                                  collate_fn=batchify_fn,
                                  num_workers=0,
                                  return_list=True)
    return test_data_loader
Example #10
def create_data_loader(args, dataset_class, trans_func, batchify_fn, mode):
    dataset = dataset_class(args.data_dir, mode)
    dataset = MapDataset(dataset).map(trans_func, lazy=True)
    if mode == 'train':
        batch_sampler = DistributedBatchSampler(
            dataset, batch_size=args.batch_size, shuffle=True)
    else:
        batch_sampler = BatchSampler(
            dataset, batch_size=args.test_batch_size, shuffle=False)
    data_loader = DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
    return data_loader
Example #11
def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = [
        'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
    ]
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
        'DataSet only support {}'.format(support_dict))
    assert mode in ['Train', 'Eval',
                    'Test'], "Mode should be Train, Eval or Test."

    dataset = eval(module_name)(config, mode, logger, seed)
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']
    if 'use_shared_memory' in loader_config.keys():
        use_shared_memory = loader_config['use_shared_memory']
    else:
        use_shared_memory = True
    if mode == "Train":
        # Distribute data to multiple cards
        batch_sampler = DistributedBatchSampler(dataset=dataset,
                                                batch_size=batch_size,
                                                shuffle=shuffle,
                                                drop_last=drop_last)
    else:
        # Distribute data to single card
        batch_sampler = BatchSampler(dataset=dataset,
                                     batch_size=batch_size,
                                     shuffle=shuffle,
                                     drop_last=drop_last)

    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             places=device,
                             num_workers=num_workers,
                             return_list=True,
                             use_shared_memory=use_shared_memory)

    # support exit using ctrl+c
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader
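
The Train/Eval branch above is the pattern most of these examples share: DistributedBatchSampler shards index batches across cards for training, while a plain BatchSampler keeps evaluation on a single card. A hedged sketch with an illustrative helper name (make_loader is not part of the original code):

from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler

def make_loader(dataset, batch_size, is_train, collate_fn=None):
    if is_train:
        # shards index batches across replicas (cards)
        sampler = DistributedBatchSampler(dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)
    else:
        # single-card evaluation keeps every sample, in order
        sampler = BatchSampler(dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               drop_last=False)
    return DataLoader(dataset,
                      batch_sampler=sampler,
                      collate_fn=collate_fn,
                      return_list=True)
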
Example #12
def get_eval_dataloader(model, tokenizer, eval_filename, record_schema, args):
    """ Get evaluation dataloader
    """

    logger.info(f'Load data from {eval_filename} ...')

    schema = RecordSchema.read_from_file(record_schema)

    dataset = load_dataset(read_func,
                           tokenizer=tokenizer,
                           data_file=eval_filename,
                           max_source_length=args.max_source_length,
                           is_train=False,
                           lazy=False)

    batch_sampler = BatchSampler(dataset=dataset,
                                 batch_size=args.per_device_eval_batch_size,
                                 shuffle=False)

    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    collate_fn = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        max_source_length=args.max_source_length,
        max_prefix_length=args.max_prefix_length,
        max_target_length=args.max_target_length,
        ssi_generator=DynamicSSIGenerator(
            tokenizer=tokenizer,
            schema=schema,
            positive_rate=1,
            negative=-1,
            ordered_prompt=True,
        ),
        spot_asoc_nosier=None,
    )

    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=collate_fn,
                             num_workers=args.dataloader_num_workers,
                             return_list=True)

    return data_loader
Example #13
    def __init__(self,
                 args,
                 place,
                 phase="train",
                 shuffle=False,
                 num_workers=0,
                 drop_last=False):
        assert phase in [
            "train", "test", "predict"
        ], "phase should be in [train, test, predict], but get %s" % phase

        if phase == "train":
            file_name = args.train_file
        elif phase == "test":
            file_name = args.test_file
        elif phase == "predict":
            file_name = args.predict_file

        self.dataset = LacDataset(args)
        self.dataset.file_reader(file_name, phase=phase)

        if phase == "train":
            self.sampler = DistributedBatchSampler(dataset=self.dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=shuffle,
                                                   drop_last=drop_last)
        else:
            self.sampler = BatchSampler(dataset=self.dataset,
                                        batch_size=args.batch_size,
                                        shuffle=shuffle,
                                        drop_last=drop_last)

        self.dataloader = DataLoader(dataset=self.dataset,
                                     batch_sampler=self.sampler,
                                     places=place,
                                     collate_fn=partial(
                                         create_lexnet_data_generator,
                                         args,
                                         phase=phase),
                                     num_workers=num_workers,
                                     return_list=True)
Example #14
def get_mnli_dev_dataloader(tokenizer, args, matched=True):
    if matched:
        split = "dev_matched"
    else:
        split = "dev_mismatched"
    filename = os.path.join("caches", args.task_name + f"_{split}" + ".pkl")
    if os.path.exists(filename):
        ds = load_pickle(filename)
    else:
        ds = load_dataset("glue", args.task_name, splits=split)
        ds.map(
            partial(trans_func, tokenizer=tokenizer, args=args),
            batched=False,
            lazy=False,
        )
        save_pickle(ds, filename)

    batch_sampler = BatchSampler(ds,
                                 batch_size=args.train_batch_size,
                                 shuffle=False)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
            ),  # input_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
            ),  # attention_mask
        Pad(axis=0, pad_val=-100, dtype="int64"),  # lm_labels
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
            ),  # decoder_attention_mask
    ): fn(samples)

    data_loader = DataLoader(
        dataset=ds,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        num_workers=args.num_workers,
        return_list=True,
    )

    return data_loader
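
Tuple is the positional counterpart of the Dict pattern: it applies one batching function per field of a tuple-style sample, in order. A hedged sketch with illustrative values (the example above pads with tokenizer.pad_token_id instead of 0):

from paddlenlp.data import Pad, Tuple

samples = [
    ([1, 2, 3], [1, 1, 1]),  # (input_ids, attention_mask) for one sample
    ([4, 5], [1, 1]),
]
batchify_fn = lambda batch, fn=Tuple(
    Pad(axis=0, pad_val=0, dtype="int64"),  # input_ids
    Pad(axis=0, pad_val=0, dtype="int64"),  # attention_mask
): fn(batch)

input_ids, attention_mask = batchify_fn(samples)
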
Example #15
def get_train_dataloader(tokenizer, args):
    splits = "train"
    data_dir = args.data_dir
    filename = os.path.join(data_dir, "cmrc2018_" + splits + ".pkl")

    if os.path.exists(filename):
        ds = load_pickle(filename)
    else:
        ds = load_dataset("cmrc2018", splits=splits)
        ds.map(
            partial(prepare_train_features_paddlenlp,
                    tokenizer=tokenizer,
                    args=args),
            batched=True,
            lazy=False,
        )
        save_pickle(ds, filename)

    batch_sampler = BatchSampler(ds,
                                 batch_size=args.train_batch_size,
                                 shuffle=True)

    batchify_fn = lambda samples, fn=Dict(
        {
            "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
            "token_type_ids": Pad(axis=0, pad_val=0),
            "pinyin_ids": Pad(axis=0, pad_val=0),
            "start_positions": Stack(dtype="int64"),
            "end_positions": Stack(dtype="int64"),
        }): fn(samples)

    data_loader = DataLoader(
        dataset=ds,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        num_workers=args.num_workers,
        return_list=True,
    )

    return data_loader
Example #16
def create_data_loader(dataset, tokenizer, args, mode):
    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         max_seq_len=args.max_seq_len,
                         max_target_len=args.max_target_len,
                         max_title_len=args.max_title_len,
                         mode=mode)
    dataset = dataset.map(trans_func, lazy=True)
    if mode == 'train':
        batch_sampler = DistributedBatchSampler(dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True)
    else:
        batch_sampler = BatchSampler(dataset,
                                     batch_size=args.batch_size // 2,
                                     shuffle=False)
    collate_fn = partial(batchify_fn,
                         pad_val=tokenizer.pad_token_id,
                         mode=mode)
    data_loader = DataLoader(dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=collate_fn,
                             return_list=True)
    return dataset, data_loader
Example #17
    def __init__(self,
                 input_file,
                 tokenizer,
                 label_list,
                 max_seq_length,
                 batch_size,
                 shuffle=False,
                 drop_last=False,
                 mode="all_in_memory",
                 leveldb_file="./leveldb",
                 line_processor=None,
                 delimiter="\t",
                 quotechar=None,
                 device=fluid.CPUPlace(),
                 num_workers=0,
                 return_list=True,
                 phase="train"):

        assert phase in [
            "train", "predict", "test"
        ], "phase of BertDataLoader should be in [train, predict, test], but get %s" % phase

        self.dataset = SingleSentenceDataset(tokenizer, label_list,
                                             max_seq_length, mode)

        if mode == "all_in_memory":
            self.dataset.load_all_data_in_memory(input_file, label_list,
                                                 max_seq_length, tokenizer,
                                                 line_processor, delimiter,
                                                 quotechar)
        elif mode == "leveldb":
            self.dataset.prepare_leveldb(input_file, leveldb_file, label_list,
                                         max_seq_length, tokenizer,
                                         line_processor, delimiter, quotechar)
        else:
            raise ValueError("mode should be in [all_in_memory, leveldb]")

        if phase == "train":
            self.sampler = DistributedBatchSampler(self.dataset,
                                                   batch_size,
                                                   shuffle=shuffle,
                                                   drop_last=drop_last)
        elif phase == "test" or phase == "predict":
            self.sampler = BatchSampler(dataset=self.dataset,
                                        batch_size=batch_size,
                                        shuffle=shuffle,
                                        drop_last=drop_last)

        self.dataloader = DataLoader(dataset=self.dataset,
                                     batch_sampler=self.sampler,
                                     places=device,
                                     collate_fn=partial(
                                         _prepare_train_batch,
                                         vocab_size=-1,
                                         pad_id=tokenizer.vocab["[PAD]"],
                                         cls_id=tokenizer.vocab["[CLS]"],
                                         sep_id=tokenizer.vocab["[SEP]"],
                                         mask_id=-1,
                                         return_input_mask=True,
                                         return_max_len=False,
                                         return_num_token=False),
                                     num_workers=num_workers,
                                     return_list=return_list)
Example #18
def generate(args):
    paddle.set_device(args.device)
    tokenizer = ProphetNetTokenizer(vocab_file=args.vocab_file)
    model = ProphetNetModel(vocab_size=30522)
    model = ProphetNetForConditionalGeneration(model)

    ckpt = paddle.load("./ckpt/" + args.dataset + "/model_best.pdparams")

    model.load_dict(ckpt['model'])

    test_data_src = 'data/' + args.dataset + '_data/uncased_tok_data/test.src'
    test_data_tgt = 'data/' + args.dataset + '_data/uncased_tok_data/test.tgt'

    test_dataset = load_dataset(read,
                                data_path=[test_data_src, test_data_tgt],
                                lazy=False)

    trunc = convert_example(is_test=True)

    test_dataset = test_dataset.map(trunc)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=0),  # attn mask
        Pad(axis=0, pad_val=tokenizer.pad_token_id)  # labels
    ): fn(samples)

    batch_sampler = BatchSampler(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_sampler=batch_sampler,
                                  num_workers=0,
                                  collate_fn=batchify_fn,
                                  return_list=True)

    model.eval()
    total_time = 0.0
    start_time = time.time()
    all_preds = []
    all_labels = []
    for step, batch in tqdm(enumerate(test_data_loader),
                            total=len(test_data_loader)):
        input_ids, attention_mask, labels = batch
        preds, _ = model.generate(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  max_length=args.max_target_length,
                                  min_length=args.min_target_length,
                                  decode_strategy=args.decode_strategy,
                                  top_k=args.top_k,
                                  top_p=args.top_p,
                                  num_beams=args.num_beams,
                                  length_penalty=args.length_penalty,
                                  early_stopping=args.early_stopping,
                                  diversity_rate=args.diversity_rate,
                                  num_beam_groups=args.num_beam_groups,
                                  repetition_penalty=args.repetition_penalty)
        total_time += (time.time() - start_time)
        all_preds.extend(preds.numpy())
        all_labels.extend(labels.numpy())
        if step % args.logging_steps == 0:
            print('step %d - %.3fs/step' %
                  (step, total_time / args.logging_steps))
            total_time = 0.0
        start_time = time.time()
    decoded_preds, _ = compute_metrics(all_preds,
                                       all_labels,
                                       tokenizer,
                                       args.ignore_pad_token_for_loss,
                                       compute_rouge_=False)
    if not os.path.exists(
            os.path.abspath(
                os.path.dirname(args.output_path) + os.path.sep + ".")):
        os.makedirs(
            os.path.abspath(
                os.path.dirname(args.output_path) + os.path.sep + "."))
    with open(args.output_path, 'w', encoding='utf-8') as fout:
        for decoded_pred in decoded_preds:
            fout.write(decoded_pred + '\n')
    print('Save generated result into: %s' % args.output_path)
Example #19
def generate(args):
    paddle.set_device(args.device)
    set_seed(args)
    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
    model = BartForConditionalGeneration.from_pretrained(
        args.model_name_or_path)
    dataset = load_dataset(args.dataset_name, splits=["dev"])
    trans_func = partial(
        convert_example,
        text_column=summarization_name_mapping[args.dataset_name][0],
        summary_column=summarization_name_mapping[args.dataset_name][1],
        tokenizer=tokenizer,
        decoder_start_token_id=model.bart.decoder_start_token_id,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
        ignore_pad_token_for_loss=args.ignore_pad_token_for_loss,
        is_train=False)
    batchify_fn = lambda samples, fn=Tuple(
        Stack(dtype="int64"),  # input_ids
        Stack(dtype="int64"),  # attention mask
        Stack(dtype="int32"),  # mem_seq_lens
        Stack(dtype="int64"),  # decoder_input_ids
        Stack(dtype="int64"),  # labels
    ): fn(samples)

    dataset = dataset.map(trans_func, lazy=True)
    batch_sampler = BatchSampler(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             num_workers=0,
                             collate_fn=batchify_fn,
                             return_list=True)
    data_loader.pin_memory = False

    model.eval()
    total_time = 0.0
    start_time = time.time()
    all_preds = []
    all_labels = []
    for step, batch in enumerate(data_loader):
        input_ids, _, mem_seq_lens, _, labels = batch
        preds, _ = model.generate(input_ids=input_ids,
                                  seq_lens=mem_seq_lens,
                                  max_length=args.max_target_length,
                                  min_length=args.min_target_length,
                                  decode_strategy=args.decode_strategy,
                                  top_k=args.top_k,
                                  top_p=args.top_p,
                                  num_beams=args.num_beams,
                                  length_penalty=args.length_penalty,
                                  early_stopping=args.early_stopping,
                                  diversity_rate=args.diversity_rate,
                                  use_faster=args.faster)
        total_time += (time.time() - start_time)
        if step % args.logging_steps == 0:
            print('step %d - %.3fs/step' %
                  (step, total_time / args.logging_steps))
            total_time = 0.0
        all_preds.extend(preds.numpy())
        all_labels.extend(labels.numpy())
        start_time = time.time()

    rouge_result, decoded_preds = compute_metrics(
        all_preds, all_labels, tokenizer, args.ignore_pad_token_for_loss)
    print("Rouge result: ", rouge_result)
    with open(args.output_path, 'w', encoding='utf-8') as fout:
        for decoded_pred in decoded_preds:
            fout.write(decoded_pred + '\n')
    print('Save generated result into: %s' % args.output_path)
Example #20
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)
    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
    model = BartForConditionalGeneration.from_pretrained(
        args.model_name_or_path)
    trans_func = partial(
        convert_example,
        text_column=summarization_name_mapping[args.dataset_name][0],
        summary_column=summarization_name_mapping[args.dataset_name][1],
        tokenizer=tokenizer,
        decoder_start_token_id=model.bart.decoder_start_token_id,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
        ignore_pad_token_for_loss=args.ignore_pad_token_for_loss)
    logger.info("Loading train and dev dataset: %s" % args.dataset_name)
    train_set, dev_set = load_dataset(args.dataset_name,
                                      splits=["train", "dev"])
    logger.info("Loaded train and dev dataset: %s" % args.dataset_name)
    train_set = train_set.map(trans_func, lazy=True)
    train_batch_sampler = DistributedBatchSampler(
        train_set, batch_size=args.train_batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Stack(dtype="int64"),  # input_ids
        Stack(dtype="int64"),  # attention mask
        Stack(dtype="int64"),  # decoder_input_ids
        Stack(dtype="int64"),  # labels
    ): fn(samples)
    train_data_loader = DataLoader(dataset=train_set,
                                   batch_sampler=train_batch_sampler,
                                   num_workers=0,
                                   collate_fn=batchify_fn,
                                   return_list=True)
    dev_set = dev_set.map(trans_func, lazy=True)
    dev_batch_sampler = BatchSampler(dev_set,
                                     batch_size=args.eval_batch_size,
                                     shuffle=False)
    dev_data_loader = DataLoader(dataset=dev_set,
                                 batch_sampler=dev_batch_sampler,
                                 num_workers=0,
                                 collate_fn=batchify_fn,
                                 return_list=True)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else (
        len(train_data_loader) * args.num_train_epochs)
    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, warmup)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        beta1=0.9,
        beta2=0.999,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    loss_fct = nn.CrossEntropyLoss()
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
    global_step = 0
    tic_train = time.time()
    for epoch in tqdm(range(args.num_train_epochs), desc="Epoch"):
        for step, batch in tqdm(enumerate(train_data_loader),
                                desc="Train step",
                                total=len(train_data_loader)):
            global_step += 1
            input_ids, attention_mask, decoder_input_ids, labels = batch
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=["layer_norm", "softmax", "gelu"]):
                logits = model(input_ids, attention_mask, decoder_input_ids)
                loss = loss_fct(logits, labels)
            if args.use_amp:
                scaled_loss = scaler.scale(loss)
                scaled_loss.backward()
                scaler.minimize(optimizer, scaled_loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.logging_steps == 0:
                logger.info(
                    "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
                    % (global_step, num_training_steps, epoch, step,
                       paddle.distributed.get_rank(), loss, optimizer.get_lr(),
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                tic_eval = time.time()
                evaluate(model, dev_data_loader, tokenizer,
                         args.ignore_pad_token_for_loss,
                         args.min_target_length, args.max_target_length)
                logger.info("eval done total : %s s" %
                            (time.time() - tic_eval))
                if paddle.distributed.get_rank() == 0:
                    output_dir = os.path.join(
                        args.output_dir,
                        "bart_model_%d.pdparams" % global_step)
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Need better way to get inner model of DataParallel
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
            if global_step >= num_training_steps:
                return
    if paddle.distributed.get_rank() == 0:
        output_dir = os.path.join(args.output_dir,
                                  "bart_model_final_%d.pdparams" % global_step)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Need better way to get inner model of DataParallel
        model_to_save = model._layers if isinstance(
            model, paddle.DataParallel) else model
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
Example #21
    def init_batch_sampler(self):
        bs = BatchSampler(indices=list(range(self.num_samples)),
                          batch_size=self.batch_size,
                          drop_last=self.drop_last)
        return bs