Пример #1
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Wires a BERT encoder (with adapters) and a BERT adapter decoder
        to their respective pre-trained tokenizers and returns the model.
        """
        # Fill in defaults that older checkpoints may be missing.
        base_architecture(args)

        for attr, default in (
            ('max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS),
            ('max_target_positions', DEFAULT_MAX_TARGET_POSITIONS),
        ):
            if not hasattr(args, attr):
                setattr(args, attr, default)

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        src_berttokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
        tgt_berttokenizer = BertTokenizer.from_pretrained(
            args.decoder_bert_model_name)
        # Encoder and decoder tokenizers must agree on the padding id.
        assert src_berttokenizer.pad() == tgt_berttokenizer.pad()

        bertdecoder = BertAdapterDecoderFull.from_pretrained(
            args.decoder_bert_model_name,
            args,
            from_scratch=args.train_from_scratch)
        top_layer_adapter = getattr(args, 'enc_top_layer_adapter', -1)
        adapter_dim = getattr(args, 'adapter_dimension', 2048)
        bertencoder = BertModelWithAdapter.from_pretrained(
            args.bert_model_name,
            adapter_dim,
            top_layer_adapter,
            from_scratch=args.train_from_scratch)
        return cls(bertencoder, bertdecoder, src_berttokenizer,
                   tgt_berttokenizer, args)
Пример #2
0
def main(args):
    """Tokenize ``args.file_name`` with BERT/BART/ELECTRA tokenizers and save results.

    Writes the split sentences to ``args.output_path``, dropped lines to
    ``args.drop_path``, optional per-tokenizer extra outputs next to the
    output path, and the encoder-input dict to ``<output_path>.data_dict``.
    """
    filename = args.file_name
    tokenizers = []
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer,
                                                   do_lower_case=False)
    bert_tokenizer.name_or_path = args.bert_tokenizer
    bart_tokenizer = BartTokenizer.from_pretrained(args.bart_tokenizer,
                                                   do_lower_case=False)
    electra_tokenizer = ElectraTokenizer.from_pretrained(
        args.electra_tokenizer)
    # Order matters: extra_outs entries mirror this tokenizer order.
    tokenizers.append(bert_tokenizer)
    tokenizers.append(bart_tokenizer)
    tokenizers.append(electra_tokenizer)
    encoder_inputs, sentence_splits, extra_outs, drop_list = add_line(
        filename,
        tokenizers,
        add_extra_outs=args.extra_outs,
        n_process=args.n_process)
    save_path = args.output_path
    drop_path = args.drop_path
    save_txt(save_path, sentence_splits)
    save_txt(drop_path, drop_list)
    # BUG FIX: the original tested `len(extra_outs) != []`, comparing an int
    # to a list, which is always True; test the list's truthiness instead.
    if extra_outs:
        save_txt(save_path + '.bert', extra_outs[0])
        save_txt(save_path + '.bart', extra_outs[1])
        # BUG FIX: the electra file previously re-saved extra_outs[1] (the
        # BART output). Index 2 matches the third appended tokenizer —
        # TODO(review): confirm against add_line's output layout.
        save_txt(save_path + '.electra', extra_outs[2])
    output_dict_save_path = args.output_path + '.data_dict'
    save_input_dict(output_dict_save_path, encoder_inputs)
Пример #3
0
def make_batches(lines, args, task, max_positions, encode_fn):
    """Yield inference batches from interleaved (plain, BERT) input lines.

    ``lines`` alternates plain source strings (even indices) with their
    BERT-side counterparts (odd indices). Each yielded ``Batch`` carries the
    dictionary-encoded tokens plus the BERT token ids.
    """
    oldlines = lines
    lines = oldlines[0::2]      # plain source sentences
    bertlines = oldlines[1::2]  # BERT-tokenizer inputs
    tokens = [
        task.source_dictionary.encode_line(encode_fn(src_str),
                                           add_if_not_exist=False).long()
        for src_str in lines
    ]
    bertdict = BertTokenizer.from_pretrained(args.bert_model_name)

    def getbert(line):
        """Convert one raw line into a LongTensor of BERT token ids."""
        line = line.strip()
        line = '{} {} {}'.format('[CLS]', line, '[SEP]')
        tokenizedline = bertdict.tokenize(line)
        # Truncate over-long inputs while keeping a closing [SEP].
        if len(tokenizedline) > bertdict.max_len:
            tokenizedline = tokenizedline[:bertdict.max_len - 1]
            tokenizedline.append('[SEP]')
        words = bertdict.convert_tokens_to_ids(tokenizedline)
        # Build the tensor directly instead of filling it element by element.
        return torch.tensor(words, dtype=torch.long)

    berttokens = [getbert(x) for x in bertlines]
    lengths = torch.LongTensor([t.numel() for t in tokens])
    bertlengths = torch.LongTensor([t.numel() for t in berttokens])
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths, berttokens,
                                                 bertlengths, bertdict),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(ids=batch['id'],
                    src_tokens=batch['net_input']['src_tokens'],
                    src_lengths=batch['net_input']['src_lengths'],
                    bert_input=batch['net_input']['bert_input'])
Пример #4
0
from dataloader import get_chABSA_DataLoaders_and_TEXT
from bert import BertTokenizer
train_dl, val_dl, TEXT, dataloaders_dict = get_chABSA_DataLoaders_and_TEXT(
    max_length=256, batch_size=32)
#print(train_dl)

# Sanity check: inspect one mini-batch from the training DataLoader.
batch = next(iter(train_dl))
print("Textの形状=", batch.Text[0].shape)
print("Labelの形状=", batch.Label.shape)
print(batch.Text)
print(batch.A_label)
print(batch.Label)

# Inspect the first sentence of the mini-batch.
tokenizer_bert = BertTokenizer(vocab_file="./vocab/vocab.txt",
                               do_lower_case=False)
# NOTE(review): this reads batch.Label, not batch.Text — presumably the
# intent was to decode a Text row back to tokens; confirm against the data.
text_minibatch_1 = (batch.Label).numpy()

# Convert ids back to words.
text = tokenizer_bert.convert_ids_to_tokens(text_minibatch_1)

print(text)
Пример #5
0
def main(args):
    """Preprocess parallel text into binarized fairseq datasets.

    Builds (or loads) source/target dictionaries, binarizes the train/
    valid/test splits for both languages, additionally binarizes a
    BERT-tokenized copy of the source side, optionally processes alignment
    files, and writes everything under ``args.destdir``.
    """
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    # Mirror all log output into destdir/preprocess.log.
    logger.addHandler(
        logging.FileHandler(filename=os.path.join(args.destdir,
                                                  "preprocess.log"), ))
    logger.info(args)

    task = tasks.get_task(args.task)

    def train_path(lang):
        """Path of the training file for *lang* (e.g. trainpref.en)."""
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        """Append the language suffix to *prefix* when *lang* is given."""
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        """Destination path under args.destdir for *prefix*/*lang*."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Path of the dictionary file for *lang*."""
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        """Build a dictionary over *filenames* for exactly one side."""
        assert src ^ tgt  # exactly one of src/tgt must be set
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    target = not args.only_source

    # Refuse to clobber pre-existing dictionaries unless they were supplied.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        # One shared dictionary for both sides.
        assert (
            not args.srcdict or not args.tgtdict
        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert (args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True,
            )
        tgt_dict = src_dict
    else:
        # Separate dictionaries per side.
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert (args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        """Binarize one input file with *vocab*, fanning out to workers.

        When *vocab* is a BertTokenizer, both the input and output paths
        get a '.bert' suffix so the BERT-tokenized copy lives alongside
        the regular dataset.
        """
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        output_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        input_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        n_seq_tok = [0, 0]  # [num sequences, num tokens] accumulator
        replaced = Counter()  # tokens replaced by <unk>

        def merge_result(worker_result):
            """Fold one worker's counters into the shared totals."""
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        # Split the file into byte ranges, one per worker.
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Workers 1..N-1 binarize their chunk to temp files; chunk 0
            # is handled in-process below.
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        # Binarize the first chunk in the main process.
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            # Concatenate worker outputs in order, then delete the temps.
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            ))

    def make_binary_alignment_dataset(input_prefix, output_prefix,
                                      num_workers):
        """Binarize an alignment file, mirroring make_binary_dataset."""
        nseq = [0]  # number of parsed alignments

        def merge_result(worker_result):
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=0,
                end=offsets[1],
            ))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(
            input_file, nseq[0]))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Produce one split: raw copies the text, otherwise binarize."""
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)

    def make_all(lang, vocab):
        """Build train/valid/test datasets for *lang* with *vocab*."""
        if args.trainpref:
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            # Multiple comma-separated valid sets become valid, valid1, ...
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    def make_all_alignments():
        """Binarize whichever alignment files exist for the three splits."""
        if args.trainpref and os.path.exists(args.trainpref + "." +
                                             args.align_suffix):
            make_binary_alignment_dataset(
                args.trainpref + "." + args.align_suffix,
                "train.align",
                num_workers=args.workers,
            )
        if args.validpref and os.path.exists(args.validpref + "." +
                                             args.align_suffix):
            make_binary_alignment_dataset(
                args.validpref + "." + args.align_suffix,
                "valid.align",
                num_workers=args.workers,
            )
        if args.testpref and os.path.exists(args.testpref + "." +
                                            args.align_suffix):
            make_binary_alignment_dataset(
                args.testpref + "." + args.align_suffix,
                "test.align",
                num_workers=args.workers,
            )

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    # Additionally binarize the BERT-tokenized source copy (see the
    # '.bert' suffix handling in make_binary_dataset).
    berttokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
    make_all(args.source_lang, berttokenizer)
    if args.align_suffix:
        make_all_alignments()

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        # Build a src->tgt word alignment dictionary from the align file:
        # for each source token keep the most frequently aligned target.
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding="utf-8") as align_file:
            with open(src_file_name, "r", encoding="utf-8") as src_file:
                with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        # Alignments come as "srcpos-tgtpos" pairs.
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            # Skip unknowns; special symbols must not align.
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        # Keep only the most frequent target for each source token.
        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)

        with open(
                os.path.join(
                    args.destdir,
                    "alignment.{}-{}.txt".format(args.source_lang,
                                                 args.target_lang),
                ),
                "w",
                encoding="utf-8",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Пример #6
0
def get_chABSA_DataLoaders_and_TEXT(max_length=256, batch_size=32):
    """Build the chABSA DataLoaders and torchtext fields.

    Args:
        max_length: fixed token length per example (padding/truncation).
        batch_size: mini-batch size for both iterators.

    Returns:
        (train_dl, val_dl, TEXT, dataloaders_dict) where dataloaders_dict
        maps "train"/"val" to the corresponding iterator.
    """
    # Fix random seeds for reproducibility.
    torch.manual_seed(1234)
    np.random.seed(1234)
    random.seed(1234)
    # Tokenizer used for word segmentation.
    tokenizer_bert = BertTokenizer(vocab_file=VOCAB_FILE, do_lower_case=False)

    def preprocessing_text(text):
        """Normalize one raw sentence before BERT tokenization."""
        # Unify half-width characters to full-width.
        text = mojimoji.han_to_zen(text)
        # Remove newlines and spaces.
        text = re.sub('\r', '', text)
        text = re.sub('\n', '', text)
        text = re.sub(' ', '', text)
        # NOTE(review): duplicate of the line above — presumably meant to
        # also remove full-width spaces; confirm the intended pattern.
        text = re.sub(' ', '', text)
        # Collapse every digit run to a single '0'.
        text = re.sub(r'[0-9 0-9]+', '0', text)

        # Replace punctuation (other than ',' and '.') with spaces.
        for p in string.punctuation:
            if (p == ".") or (p == ","):
                continue
            text = text.replace(p, " ")
        # BUG FIX: `return text` used to sit inside the loop body, so only
        # the first punctuation character was ever processed.
        return text

    # Combined preprocessing + word segmentation. Note we pass
    # tokenizer_bert.tokenize (the function), not the tokenizer object.
    def tokenizer_with_preprocessing(text, tokenizer=tokenizer_bert.tokenize):
        text = preprocessing_text(text)
        ret = tokenizer(text)
        return ret

    # Field definitions for the TSV columns ("is_target=True" marks the
    # label field). BUG FIX: max_length was previously re-assigned to the
    # constant 256 here, silently ignoring the function argument.
    TEXT = torchtext.data.Field(sequential=True,
                                tokenize=tokenizer_with_preprocessing,
                                use_vocab=True,
                                lower=False,
                                include_lengths=True,
                                batch_first=True,
                                fix_length=max_length,
                                init_token="[CLS]",
                                eos_token="[SEP]",
                                pad_token='[PAD]',
                                unk_token='[UNK]')
    A_LABEL = torchtext.data.Field(sequential=True, batch_first=True)
    LABEL = torchtext.data.Field(sequential=False, is_target=True)

    # Load the train/test TSV files from the data folder (slow — BERT
    # tokenization runs on every row).
    train_val_ds, test_ds = torchtext.data.TabularDataset.splits(
        path=DATA_PATH,
        train='train_slot2_v2.tsv',
        test='test_slot2_v2.tsv',
        format='tsv',
        fields=[('Text', TEXT), ('A_label', A_LABEL), ('Label', LABEL)])

    vocab_bert, ids_to_tokens_bert = load_vocab(vocab_file=VOCAB_FILE)
    # Build label vocabularies, then force TEXT to use the BERT vocab ids.
    A_LABEL.build_vocab(train_val_ds, min_freq=1)
    LABEL.build_vocab(train_val_ds, min_freq=1)
    TEXT.build_vocab(train_val_ds, min_freq=1)
    TEXT.vocab.stoi = vocab_bert

    # BUG FIX: batch_size was previously re-assigned to the constant 32
    # here, silently ignoring the function argument (defaults keep the
    # old behavior; 16/32 are typical for BERT).
    train_dl = torchtext.data.Iterator(train_val_ds,
                                       batch_size=batch_size,
                                       train=True)
    val_dl = torchtext.data.Iterator(test_ds,
                                     batch_size=batch_size,
                                     train=False,
                                     sort=False)
    # Bundle the iterators into a dict for the training loop.
    dataloaders_dict = {"train": train_dl, "val": val_dl}

    return train_dl, val_dl, TEXT, dataloaders_dict
Пример #7
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--type",
                        default=None,
                        type=str,
                        required=True,
                        help=".")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="CoQA json for training. E.g., coqa-train-v1.0.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="CoQA json for predictions. E.g., coqa-dev-v1.0.json")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    # parser.add_argument("--do_F1",
    #                     action='store_true',
    #                     help="Whether to calculating F1 score") # we don't talk anymore. please use official evaluation scripts
    parser.add_argument("--train_batch_size",
                        default=48,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=48,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.06,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal CoQA evaluation.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--weight_decay', type=float, default=0, help="")
    parser.add_argument(
        '--null_score_diff_threshold',
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null."
    )
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--logfile',
                        type=str,
                        default=None,
                        help='Which file to keep log.')
    parser.add_argument('--logmode',
                        type=str,
                        default=None,
                        help='logging mode, `w` or `a`')
    parser.add_argument('--tensorboard',
                        action='store_true',
                        help='no tensor board')
    parser.add_argument('--qa_tag',
                        action='store_true',
                        help='add qa tag or not')
    parser.add_argument('--history_len',
                        type=int,
                        default=2,
                        help='length of history')
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
        filename=args.logfile,
        filemode=args.logmode)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory () already exists and is not empty.")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    if args.do_train or args.do_predict:
        tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
        model = BertForCoQA.from_pretrained(args.bert_model)
        if args.local_rank == 0:
            torch.distributed.barrier()

        model.to(device)

    if args.do_train:
        if args.local_rank in [-1, 0] and args.tensorboard:
            from tensorboardX import SummaryWriter
            tb_writer = SummaryWriter()
        # Prepare data loader
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(
            args.type, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_answer_length),
            str(args.history_len), str(args.qa_tag))
        cached_train_examples_file = args.train_file + '_examples_{0}_{1}.pk'.format(
            str(args.history_len), str(args.qa_tag))

        # try train_examples
        try:
            with open(cached_train_examples_file, "rb") as reader:
                train_examples = pickle.load(reader)
        except:
            logger.info("  No cached file %s", cached_train_examples_file)
            train_examples = read_coqa_examples(input_file=args.train_file,
                                                history_len=args.history_len,
                                                add_QA_tag=args.qa_tag)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train examples into cached file %s",
                            cached_train_examples_file)
                with open(cached_train_examples_file, "wb") as writer:
                    pickle.dump(train_examples, writer)

        # print('DEBUG')
        # exit()

        # try train_features
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except:
            logger.info("  No cached file %s", cached_train_features_file)
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
            )
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        # print('DEBUG')
        # exit()

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)
        all_rational_mask = torch.tensor(
            [f.rational_mask for f in train_features], dtype=torch.long)
        all_cls_idx = torch.tensor([f.cls_idx for f in train_features],
                                   dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions, all_rational_mask,
                                   all_cls_idx)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        num_train_optimization_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        # if args.local_rank != -1:
        #     num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
        # thus it produce None grad that break apex
        # param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=int(args.warmup_proportion *
                             num_train_optimization_steps),
            t_total=num_train_optimization_steps)

        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        model.train()
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Iteration",
                         disable=args.local_rank not in [-1, 0])):
                batch = tuple(
                    t.to(device)
                    for t in batch)  # multi-gpu does scattering it-self
                input_ids, input_mask, segment_ids, start_positions, end_positions, rational_mask, cls_idx = batch
                loss = model(input_ids, segment_ids, input_mask,
                             start_positions, end_positions, rational_mask,
                             cls_idx)
                # loss = gather(loss, 0)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                tr_loss += loss.item()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1
                    if args.local_rank in [
                            -1, 0
                    ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        if args.tensorboard:
                            tb_writer.add_scalar('lr',
                                                 scheduler.get_lr()[0],
                                                 global_step)
                            tb_writer.add_scalar('loss',
                                                 (tr_loss - logging_loss) /
                                                 args.logging_steps,
                                                 global_step)
                        else:
                            logger.info(
                                'Step: {}\tLearning rate: {}\tLoss: {}\t'.
                                format(global_step,
                                       scheduler.get_lr()[0],
                                       (tr_loss - logging_loss) /
                                       args.logging_steps))
                        logging_loss = tr_loss

    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = BertForCoQA.from_pretrained(args.output_dir)
        tokenizer = BertTokenizer.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)

        # Good practice: save your training arguments together with the trained model
        output_args_file = os.path.join(args.output_dir, 'training_args.bin')
        torch.save(args, output_args_file)
    else:
        model = BertForCoQA.from_pretrained(args.bert_model)

    model.to(device)

    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
        cached_eval_features_file = args.predict_file + '_{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(
            args.type, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_answer_length),
            str(args.history_len), str(args.qa_tag))
        cached_eval_examples_file = args.predict_file + '_examples_{0}_{1}.pk'.format(
            str(args.history_len), str(args.qa_tag))

        # try eval_examples
        try:
            with open(cached_eval_examples_file, 'rb') as reader:
                eval_examples = pickle.load(reader)
        except:
            logger.info("No cached file: %s", cached_eval_examples_file)
            eval_examples = read_coqa_examples(input_file=args.predict_file,
                                               history_len=args.history_len,
                                               add_QA_tag=args.qa_tag)
            logger.info("  Saving eval examples into cached file %s",
                        cached_eval_examples_file)
            with open(cached_eval_examples_file, 'wb') as writer:
                pickle.dump(eval_examples, writer)

        # try eval_features
        try:
            with open(cached_eval_features_file, "rb") as reader:
                eval_features = pickle.load(reader)
        except:
            logger.info("No cached file: %s", cached_eval_features_file)
            eval_features = convert_examples_to_features(
                examples=eval_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
            )
            logger.info("  Saving eval features into cached file %s",
                        cached_eval_features_file)
            with open(cached_eval_features_file, "wb") as writer:
                pickle.dump(eval_features, writer)

        # print('DEBUG')
        # exit()

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader,
                desc="Evaluating",
                disable=args.local_rank not in [-1, 0]):
            # if len(all_results) % 1000 == 0:
            #     logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits, batch_yes_logits, batch_no_logits, batch_unk_logits = model(
                    input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                yes_logits = batch_yes_logits[i].detach().cpu().tolist()
                no_logits = batch_no_logits[i].detach().cpu().tolist()
                unk_logits = batch_unk_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              yes_logits=yes_logits,
                              no_logits=no_logits,
                              unk_logits=unk_logits))
        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(args.output_dir,
                                         "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir,
                                                 "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file,
                          args.verbose_logging, args.null_score_diff_threshold)
Пример #8
0
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        max_len=0,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        search_strategy=None,
        eos=None,
        symbols_to_strip_from_output=None,
        lm_model=None,
        lm_weight=1.0,
        args=None,
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            tgt_dict (~fairseq.data.Dictionary): target-side vocabulary
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence); 0 means use the models' limit
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
            no_repeat_ngram_size (int, optional): block repeated n-grams of
                this size (default: 0, disabled)
            search_strategy (optional): custom search object; defaults to
                plain beam search over ``tgt_dict``
            eos (int, optional): index used as end-of-sentence, overriding
                ``tgt_dict.eos()`` when given
            symbols_to_strip_from_output (Set[int], optional): extra symbol
                ids stripped from the output (eos is always stripped)
            lm_model (optional): language model used for fusion scoring
            lm_weight (float, optional): weight of the fusion LM (default: 1.0)
            args: namespace that must provide ``use_bertinput``,
                ``bert_model_name``, ``use_bartinput`` (and then
                ``bart_model_name``) and ``use_electrainput`` (and then
                ``electra_model_name``).  NOTE(review): although declared with
                a default of None, ``args`` is required in practice — it is
                dereferenced unconditionally below.
        """
        super().__init__()
        # Wrap a raw list of models in an EnsembleModel exactly once.
        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.tgt_dict = tgt_dict
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos() if eos is None else eos
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None else {self.eos})
        self.vocab_size = len(tgt_dict)
        # The max usable beam size is the dictionary size - 1, since we never
        # select pad.  (The previous plain `self.beam_size = beam_size`
        # assignment was dead code and has been removed.)
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len
        self.max_len = max_len or self.model.max_decoder_positions()

        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.temperature = temperature
        self.match_source_len = match_source_len

        # BERT tokenizer is always built; BART/ELECTRA only when enabled.
        self.use_bertinput = args.use_bertinput
        self.berttokenizer = BertTokenizer.from_pretrained(
            args.bert_model_name, do_lower_case=False)

        self.use_bartinput = args.use_bartinput
        if self.use_bartinput:
            self.barttokenizer = BartTokenizer.from_pretrained(
                args.bart_model_name, do_lower_case=False)
        self.use_electrainput = args.use_electrainput
        if self.use_electrainput:
            self.electratokenizer = ElectraTokenizer.from_pretrained(
                args.electra_model_name)

        # not implemented yet.
        # self.use_bertinput = args.use_bertinput
        # self.mask_lm = args.mask_lm
        # self.bert_ner = args.bert_ner
        # self.bert_sst = args.bert_sst

        if no_repeat_ngram_size > 0:
            self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
        else:
            self.repeat_ngram_blocker = None

        assert temperature > 0, "--temperature must be greater than 0"

        self.search = (search.BeamSearch(tgt_dict)
                       if search_strategy is None else search_strategy)
        # We only need to set src_lengths in LengthConstrainedBeamSearch.
        # As a module attribute, setting it would break in multithread
        # settings when the model is shared.
        self.should_set_src_lengths = (hasattr(self.search,
                                               "needs_src_lengths")
                                       and self.search.needs_src_lengths)

        self.model.eval()

        self.lm_model = lm_model
        self.lm_weight = lm_weight
        if self.lm_model is not None:
            self.lm_model.eval()
Пример #9
0
 def __init__(self, args):
     """Hold dataset settings and an uncased BERT WordPiece tokenizer.

     Args:
         args: namespace providing ``data_dir``, ``nega_num`` and
             ``bert_model_dir``.
     """
     self.data_dir, self.nega_num = args.data_dir, args.nega_num
     # Lower-casing matches an uncased BERT checkpoint.
     self.tokenizer = BertTokenizer.from_pretrained(
         args.bert_model_dir, do_lower_case=True)
Пример #10
0
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions, max_target_positions,
    ratio, pred_probs, bert_model_name,
):
    """Load a parallel <src, tgt> dataset plus a BERT-tokenized source copy.

    Shards named ``split``, ``split1``, ... are loaded (and concatenated when
    ``combine`` is True).  Returns a ``BertLanguagePairDataset`` for the test
    split, otherwise a ``BertXYNoisyLanguagePairDataset`` whose noising is
    controlled by ``ratio``/``pred_probs``.

    Raises:
        FileNotFoundError: if no shard exists for ``split`` in ``data_path``.
    """
    def split_exists(split, src, tgt, lang, data_path):
        # e.g. <data_path>/train.en-de.en
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    srcbert_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode: files may be stored as src-tgt or tgt-src
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
        src_datasets.append(indexed_dataset.make_dataset(prefix + src, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=src_dict))
        tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=tgt_dict))
        srcbert_datasets.append(indexed_dataset.make_dataset(bertprefix + src, impl=dataset_impl,
                                                         fix_lua_indexing=True, ))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        srcbert_datasets = srcbert_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        # BUG FIX: previously the BERT-side shards were left as a plain
        # Python list in this branch, which crashed later on the
        # `srcbert_datasets.sizes` accesses below; concatenate them too so
        # they mirror src/tgt.
        srcbert_datasets = ConcatDataset(srcbert_datasets, sample_ratios)

    berttokenizer = BertTokenizer.from_pretrained(bert_model_name)
    if split == 'test':
        return BertLanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            srcbert=srcbert_datasets,
            srcbert_sizes=srcbert_datasets.sizes if srcbert_datasets is not None else None,
            berttokenizer=berttokenizer,
        )
    else:
        return BertXYNoisyLanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=True,
            ratio=ratio,
            pred_probs=pred_probs,
            srcbert=srcbert_datasets,
            srcbert_sizes=srcbert_datasets.sizes if srcbert_datasets is not None else None,
            berttokenizer=berttokenizer,
        )
Пример #11
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    bert_model_name=None,
    bart_model_name=None,
    electra_model_name=None,
    electra_pretrain=False,
    denoising=False,
    masking=False,
    extra_data=False,
    input_mapping=False,
    mask_ratio=None,
    random_ratio=None,
    insert_ratio=None,
    rotate_ratio=None,
    permute_sentence_ratio=None,
):
    """Load a ``src``-``tgt`` language-pair dataset from ``data_path``.

    On top of the plain fairseq indexed source/target datasets, the source
    side can be wrapped with auxiliary pre-tokenized views:

    * ``masking`` -- wrap with ``MaskingDataset`` using a BERT tokenization
      of the source (``*.bert.*`` files).
    * ``denoising`` -- wrap with ``DenoisingBartDataset`` using a BART
      tokenization (``*.bart.*`` files).
    * ``electra_pretrain`` -- wrap with ``ElectrapretrainDataset``
      (``*.electra.*`` files).
    * ``extra_data`` (train split only) -- additionally wrap with
      ``MaskingExtraDataset`` and ``DenoisingBartExtraDataset`` built from
      ``*.extra.*`` files.

    ``input_mapping`` (train split only) loads the ``*.map.*`` datasets that
    align fairseq tokens with the external tokenizers' subwords.

    Shards named ``split``, ``split1``, ``split2``, ... are concatenated when
    ``combine`` is True; the primary shard is upsampled by
    ``upsample_primary``.

    Returns:
        LanguagePairDataset: the (possibly wrapped) pair dataset.

    Raises:
        FileNotFoundError: if no shard for ``split`` exists in ``data_path``.
    """
    def split_exists(split, src, tgt, lang, data_path):
        # A shard exists iff its indexed-dataset file for `lang` is on disk.
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name,
                                                   do_lower_case=False)
    # BUGFIX: the BART tokenizer is needed not only by the `denoising`
    # wrapper but also by DenoisingBartExtraDataset on the `extra_data`
    # path; previously `extra_data and not denoising` raised a NameError.
    if denoising or extra_data:
        bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_name,
                                                       do_lower_case=False)
    if electra_pretrain:
        electra_tokenizer = ElectraTokenizer.from_pretrained(
            electra_model_name)
    srcbert_datasets = []
    extra_datasets = []
    extra_bert_datasets = []
    extra_bert_mapping_datasets = []
    extra_bart_datasets = []
    extra_bart_mapping_datasets = []
    if denoising:
        srcbart_datasets = []
    if electra_pretrain:
        srcelectra_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # Infer langcode: files on disk may be named either src-tgt or
        # tgt-src; every derived prefix must follow the same direction.
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, src, tgt))
            bertprefix = os.path.join(
                data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt))
            bert_mapping_prefix = os.path.join(
                data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt))

            if denoising:
                bartprefix = os.path.join(
                    data_path, '{}.bart.{}-{}.'.format(split_k, src, tgt))
                bart_mapping_prefix = os.path.join(
                    data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt))

            if electra_pretrain:
                electraprefix = os.path.join(
                    data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt))
                electra_mapping_prefix = os.path.join(
                    data_path,
                    '{}.electra.map.{}-{}.'.format(split_k, src, tgt))

            if extra_data:
                extraprefix = os.path.join(
                    data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt))
                extra_bert_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.{}-{}.'.format(split_k, src, tgt))
                extra_bert_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt))
                extra_bart_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.{}-{}.'.format(split_k, src, tgt))
                extra_bart_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt))

        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  "{}.{}-{}.".format(split_k, tgt, src))
            bertprefix = os.path.join(
                data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src))
            # BUGFIX: the derived prefixes below previously used the src-tgt
            # direction even though this branch is only entered when files
            # are named tgt-src; keep them consistent with `prefix`.
            bert_mapping_prefix = os.path.join(
                data_path, '{}.bert.map.{}-{}.'.format(split_k, tgt, src))

            if denoising:
                bartprefix = os.path.join(
                    data_path, '{}.bart.{}-{}.'.format(split_k, tgt, src))
                bart_mapping_prefix = os.path.join(
                    data_path, '{}.bart.map.{}-{}.'.format(split_k, tgt, src))

            if electra_pretrain:
                electraprefix = os.path.join(
                    data_path, '{}.electra.{}-{}.'.format(split_k, tgt, src))
                electra_mapping_prefix = os.path.join(
                    data_path,
                    '{}.electra.map.{}-{}.'.format(split_k, tgt, src))

            if extra_data:
                extraprefix = os.path.join(
                    data_path, '{}.extra.{}-{}.'.format(split_k, tgt, src))
                extra_bert_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.{}-{}.'.format(split_k, tgt, src))
                extra_bert_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bert.map.{}-{}.'.format(split_k, tgt, src))
                extra_bart_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.{}-{}.'.format(split_k, tgt, src))
                extra_bart_mapping_prefix = os.path.join(
                    data_path,
                    '{}.extra.bart.map.{}-{}.'.format(split_k, tgt, src))

        else:
            if k > 0:
                # No shard k: stop collecting after the last existing shard.
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            # Strip EOS, truncate to the position budget, then re-append EOS
            # so the truncated sample still ends with a proper sentinel.
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        # Auxiliary tokenizations of the source side (may be None if the
        # corresponding files are absent).
        srcbert_datasets.append(
            data_utils.load_indexed_dataset(
                bertprefix + src,
                dataset_impl=dataset_impl,
            ))
        if denoising:
            srcbart_datasets.append(
                data_utils.load_indexed_dataset(
                    bartprefix + src,
                    dataset_impl=dataset_impl,
                ))
        if electra_pretrain:
            srcelectra_datasets.append(
                data_utils.load_indexed_dataset(
                    electraprefix + src,
                    dataset_impl=dataset_impl,
                ))
        if extra_data and split == 'train':
            extra_datasets.append(
                data_utils.load_indexed_dataset(
                    extraprefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bert_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bert_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bert_mapping_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bert_mapping_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bart_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bart_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            extra_bart_mapping_datasets.append(
                data_utils.load_indexed_dataset(
                    extra_bart_mapping_prefix + src,
                    dataset_impl=dataset_impl,
                ))
            # NOTE(review): this assert is vacuous -- all five lists were
            # just appended to (possibly with None entries), so at least one
            # is always non-empty. Kept for parity with the original;
            # consider asserting `is not None` on each loaded dataset.
            assert extra_datasets != [] or extra_bert_datasets != [] or extra_bert_mapping_datasets != [] or extra_bart_datasets != [] or extra_bart_mapping_datasets != []

        # Every source sample starts with BOS before any wrapper is applied.
        src_datasets[-1] = PrependTokenDataset(src_datasets[-1],
                                               token=src_dict.bos_index)
        if extra_data and split == 'train':
            extra_datasets[-1] = PrependTokenDataset(extra_datasets[-1],
                                                     token=src_dict.bos_index)
        if denoising:
            # Subword-alignment maps are only produced for training data.
            if input_mapping and split == 'train':
                bart_mapping_dataset = data_utils.load_indexed_dataset(
                    bart_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bart_mapping_dataset = None

            src_datasets[-1] = DenoisingBartDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                srcbart_datasets[-1],
                srcbart_datasets[-1].sizes,
                bart_tokenizer,
                map_dataset=bart_mapping_dataset,
                mask_ratio=mask_ratio,
                random_ratio=random_ratio,
                insert_ratio=insert_ratio,
                rotate_ratio=rotate_ratio,
                permute_sentence_ratio=permute_sentence_ratio,
            )

        if electra_pretrain:
            if input_mapping and split == 'train':
                electra_mapping_dataset = data_utils.load_indexed_dataset(
                    electra_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                electra_mapping_dataset = None

            src_datasets[-1] = ElectrapretrainDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                srcelectra_datasets[-1],
                srcelectra_datasets[-1].sizes,
                electra_tokenizer,
                map_dataset=electra_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

        if masking:
            if input_mapping and split == 'train':
                bert_mapping_dataset = data_utils.load_indexed_dataset(
                    bert_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bert_mapping_dataset = None
            src_datasets[-1] = MaskingDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                srcbert_datasets[-1],
                srcbert_datasets[-1].sizes,
                bert_tokenizer,
                map_dataset=bert_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

        if extra_data and split == 'train':
            # The extra-data wrappers require the alignment maps.
            assert input_mapping
            src_datasets[-1] = MaskingExtraDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                extra_datasets[-1],
                extra_datasets[-1].sizes,
                extra_bert_datasets[-1],
                extra_bert_datasets[-1].sizes,
                bert_tokenizer,
                map_dataset=extra_bert_mapping_datasets[-1],
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

            src_datasets[-1] = DenoisingBartExtraDataset(
                src_datasets[-1],
                src_datasets[-1].sizes,
                src_dict,
                extra_datasets[-1],
                extra_datasets[-1].sizes,
                extra_bart_datasets[-1],
                extra_bart_datasets[-1].sizes,
                bart_tokenizer,
                map_dataset=extra_bart_mapping_datasets[-1],
            )

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    # Monolingual use is allowed: target may be entirely absent.
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        # Multiple shards: concatenate, upsampling the primary shard.
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        # Append language-id sentinels (e.g. "[en]") and use the target one
        # as the EOS passed to the pair dataset.
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    # The wrapper datasets above already carry the auxiliary views, so the
    # pair dataset itself receives no separate bert/bart/electra datasets.
    src_bart_dataset = None
    src_bert_dataset = None
    src_electra_dataset = None

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        masking,
        src_bert_dataset,
        denoising,
        src_bart_dataset,
        src_electra_dataset,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )