def build_model(cls, args, task):
    """Build a new model instance."""
    # make sure all arguments are present in older models
    base_architecture(args)
    if not hasattr(args, 'max_source_positions'):
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if not hasattr(args, 'max_target_positions'):
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
    src_berttokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
    tgt_berttokenizer = BertTokenizer.from_pretrained(args.decoder_bert_model_name)
    assert src_berttokenizer.pad() == tgt_berttokenizer.pad()

    bertdecoder = BertAdapterDecoderFull.from_pretrained(
        args.decoder_bert_model_name, args, from_scratch=args.train_from_scratch)
    enc_top_layer_adapter = getattr(args, 'enc_top_layer_adapter', -1)
    adapter_dimension = getattr(args, 'adapter_dimension', 2048)
    bertencoder = BertModelWithAdapter.from_pretrained(
        args.bert_model_name, adapter_dimension, enc_top_layer_adapter,
        from_scratch=args.train_from_scratch)
    return cls(bertencoder, bertdecoder, src_berttokenizer, tgt_berttokenizer, args)
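# --- Illustrative sketch (not part of the original source). The .pad() calls above come from the
# repo's patched BertTokenizer; with stock Hugging Face tokenizers the same encoder/decoder
# pad-token consistency check compares pad_token_id. The model names below are placeholders for
# args.bert_model_name / args.decoder_bert_model_name.
from transformers import BertTokenizer as HFBertTokenizer

src_tok = HFBertTokenizer.from_pretrained('bert-base-uncased')
tgt_tok = HFBertTokenizer.from_pretrained('bert-base-multilingual-cased')
assert src_tok.pad_token == tgt_tok.pad_token == '[PAD]'
assert src_tok.pad_token_id == tgt_tok.pad_token_id  # both vocabularies place [PAD] at the same id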
def main(args):
    # filename = '/mnt/yardcephfs/mmyard/g_wxg_td_prc/mt/v_xyvhuang/bert-nmt/examples/copy_translation/headtest.bert.en'
    filename = args.file_name
    tokenizer = []
    # bert_tokenizer = AutoTokenizer.from_pretrained(args.bert_tokenizer, do_lower_case=False)
    # bart_tokenizer = AutoTokenizer.from_pretrained(args.bart_tokenizer, do_lower_case=False)
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer, do_lower_case=False)
    bert_tokenizer.name_or_path = args.bert_tokenizer
    bart_tokenizer = BartTokenizer.from_pretrained(args.bart_tokenizer, do_lower_case=False)
    electra_tokenizer = ElectraTokenizer.from_pretrained(args.electra_tokenizer)
    # xlnet_tokenizer = AutoTokenizer.from_pretrained('xlnet-base-cased', cache_dir='/mnt/yardcephfs/mmyard/g_wxg_td_prc/mt/v_xyvhuang/ft_local/bart-base')
    tokenizer.append(bert_tokenizer)
    tokenizer.append(bart_tokenizer)
    tokenizer.append(electra_tokenizer)
    # tokenizer.append(xlnet_tokenizer)

    encoder_inputs, sentence_splits, extra_outs, drop_list = add_line(
        filename, tokenizer, add_extra_outs=args.extra_outs, n_process=args.n_process)

    # save_path = '/mnt/yardcephfs/mmyard/g_wxg_td_prc/mt/v_xyvhuang/bert-nmt/examples/copy_translation/encoder.en'
    save_path = args.output_path
    drop_path = args.drop_path
    # dict_save_path = '/mnt/yardcephfs/mmyard/g_wxg_td_prc/mt/v_xyvhuang/data/bert-nmt/destdir-mult/encoder_dict'
    save_txt(save_path, sentence_splits)
    save_txt(drop_path, drop_list)
    if len(extra_outs) != 0:
        # one extra output stream per tokenizer: bert, bart, electra
        save_txt(save_path + '.bert', extra_outs[0])
        save_txt(save_path + '.bart', extra_outs[1])
        save_txt(save_path + '.electra', extra_outs[2])
    output_dict_save_path = args.output_path + '.data_dict'
    save_input_dict(output_dict_save_path, encoder_inputs)
def make_batches(lines, args, task, max_positions, encode_fn):
    oldlines = lines
    lines = oldlines[0::2]
    bertlines = oldlines[1::2]
    tokens = [
        task.source_dictionary.encode_line(encode_fn(src_str), add_if_not_exist=False).long()
        for src_str in lines
    ]
    bertdict = BertTokenizer.from_pretrained(args.bert_model_name)

    def getbert(line):
        line = line.strip()
        line = '{} {} {}'.format('[CLS]', line, '[SEP]')
        tokenizedline = bertdict.tokenize(line)
        if len(tokenizedline) > bertdict.max_len:
            tokenizedline = tokenizedline[:bertdict.max_len - 1]
            tokenizedline.append('[SEP]')
        words = bertdict.convert_tokens_to_ids(tokenizedline)
        nwords = len(words)
        ids = torch.IntTensor(nwords)
        for i, word in enumerate(words):
            ids[i] = word
        return ids.long()

    berttokens = [getbert(x) for x in bertlines]
    lengths = torch.LongTensor([t.numel() for t in tokens])
    bertlengths = torch.LongTensor([t.numel() for t in berttokens])
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths, berttokens, bertlengths, bertdict),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            ids=batch['id'],
            src_tokens=batch['net_input']['src_tokens'],
            src_lengths=batch['net_input']['src_lengths'],
            bert_input=batch['net_input']['bert_input'],
        )
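# --- Illustrative sketch (not part of the original source). It shows what getbert() above produces
# for a single sentence, written against a stock Hugging Face BertTokenizer (the repo's patched
# tokenizer exposes max_len, which stock tokenizers call model_max_length). 'bert-base-uncased'
# is a placeholder for args.bert_model_name.
import torch
from transformers import BertTokenizer as HFBertTokenizer

tok = HFBertTokenizer.from_pretrained('bert-base-uncased')
line = '[CLS] {} [SEP]'.format('hello world , this is a test .')
pieces = tok.tokenize(line)                       # wordpieces, with [CLS]/[SEP] kept intact
ids = torch.tensor(tok.convert_tokens_to_ids(pieces), dtype=torch.long)
print(pieces)
print(ids)  # the per-sentence tensor from which the bert_input batches are assembled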
from dataloader import get_chABSA_DataLoaders_and_TEXT
from bert import BertTokenizer

train_dl, val_dl, TEXT, dataloaders_dict = get_chABSA_DataLoaders_and_TEXT(max_length=256, batch_size=32)
# print(train_dl)

# Sanity check: inspect one mini-batch from the training DataLoader
batch = next(iter(train_dl))
print("Text shape =", batch.Text[0].shape)
print("Label shape =", batch.Label.shape)
print(batch.Text)
print(batch.A_label)
print(batch.Label)

# Look at the first sentence of the mini-batch
tokenizer_bert = BertTokenizer(vocab_file="./vocab/vocab.txt", do_lower_case=False)
text_minibatch_1 = (batch.Text[0][0]).numpy()

# Convert the IDs back to tokens
text = tokenizer_bert.convert_ids_to_tokens(text_minibatch_1)
print(text)
def main(args):
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    logger.addHandler(logging.FileHandler(filename=os.path.join(args.destdir, "preprocess.log")))
    logger.info(args)

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    target = not args.only_source

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {train_path(lang) for lang in [args.source_lang, args.target_lang]},
                src=True,
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)
        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        # BERT-side inputs live in *.bert-suffixed files and are binarized with the BERT tokenizer
        output_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        input_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix, ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (args, input_file, vocab, prefix, lang, offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        merge_result(
            Binarizer.binarize(input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        logger.info("[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
            100 * sum(replaced.values()) / n_seq_tok[1],
            vocab.unk_word,
        ))

    def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (args, input_file, utils.parse_alignment, prefix,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl)
        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=0,
                end=offsets[1],
            ))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0]))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab):
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)

    def make_all_alignments():
        if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers)
        if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.validpref + "." + args.align_suffix, "valid.align", num_workers=args.workers)
        if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
            make_binary_alignment_dataset(
                args.testpref + "." + args.align_suffix, "test.align", num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    berttokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
    make_all(args.source_lang, berttokenizer)
    if args.align_suffix:
        make_all_alignments()

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding="utf-8") as align_file:
            with open(src_file_name, "r", encoding="utf-8") as src_file:
                with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()
                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
                os.path.join(
                    args.destdir,
                    "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
                ),
                "w",
                encoding="utf-8",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def get_chABSA_DataLoaders_and_TEXT(max_length=256, batch_size=32):
    """Build the DataLoaders and the TEXT object for the chABSA dataset."""

    # Fix the random seeds
    torch.manual_seed(1234)
    np.random.seed(1234)
    random.seed(1234)

    # Prepare the tokenizer used for word segmentation
    tokenizer_bert = BertTokenizer(vocab_file=VOCAB_FILE, do_lower_case=False)

    def preprocessing_text(text):
        # Normalize half-width characters to full-width
        text = mojimoji.han_to_zen(text)
        # Remove line breaks, half-width spaces and full-width spaces
        text = re.sub('\r', '', text)
        text = re.sub('\n', '', text)
        text = re.sub('　', '', text)
        text = re.sub(' ', '', text)
        # Map every digit sequence to a single '0'
        text = re.sub(r'[0-9０-９]+', '0', text)
        # Replace every symbol except commas and periods with a space
        for p in string.punctuation:
            if (p == ".") or (p == ","):
                continue
            else:
                text = text.replace(p, " ")
        return text

    # Combine preprocessing and word segmentation in one function.
    # Note: pass the tokenization function tokenizer_bert.tokenize, not the tokenizer_bert object itself.
    def tokenizer_with_preprocessing(text, tokenizer=tokenizer_bert.tokenize):
        text = preprocessing_text(text)
        ret = tokenizer(text)  # tokenizer_bert
        return ret

    # Define how each column of the data is processed when it is read, using "torchtext.data.Field".
    # "is_target=True" marks a field as a label field (default: False).
    TEXT = torchtext.data.Field(sequential=True,
                                tokenize=tokenizer_with_preprocessing,
                                use_vocab=True,
                                lower=False,
                                include_lengths=True,
                                batch_first=True,
                                fix_length=max_length,
                                init_token="[CLS]",
                                eos_token="[SEP]",
                                pad_token='[PAD]',
                                unk_token='[UNK]')
    A_LABEL = torchtext.data.Field(sequential=True, batch_first=True)
    LABEL = torchtext.data.Field(sequential=False, is_target=True)

    # Read the tsv files from the data folder and define the datasets.
    # Because the text is processed for BERT, this takes a little under 10 minutes.
    train_val_ds, test_ds = torchtext.data.TabularDataset.splits(
        path=DATA_PATH,
        train='train_slot2_v2.tsv',
        test='test_slot2_v2.tsv',
        format='tsv',
        fields=[('Text', TEXT), ('A_label', A_LABEL), ('Label', LABEL)])

    vocab_bert, ids_to_tokens_bert = load_vocab(vocab_file=VOCAB_FILE)

    # Build the vocabularies so the labels are recognized, then replace the TEXT vocabulary
    # with the BERT vocabulary.
    A_LABEL.build_vocab(train_val_ds, min_freq=1)
    LABEL.build_vocab(train_val_ds, min_freq=1)
    TEXT.build_vocab(train_val_ds, min_freq=1)
    TEXT.vocab.stoi = vocab_bert

    # For BERT, batch sizes around 16 or 32 are typical
    train_dl = torchtext.data.Iterator(train_val_ds, batch_size=batch_size, train=True)
    val_dl = torchtext.data.Iterator(test_ds, batch_size=batch_size, train=False, sort=False)

    # Collect the loaders in a dictionary
    dataloaders_dict = {"train": train_dl, "val": val_dl}

    return train_dl, val_dl, TEXT, dataloaders_dict
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--type", default=None, type=str, required=True, help=".")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    # Other parameters
    parser.add_argument("--train_file", default=None, type=str,
                        help="CoQA json for training. E.g., coqa-train-v1.0.json")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="CoQA json for predictions. E.g., coqa-dev-v1.0.json")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                        "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                        "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
    # parser.add_argument("--do_F1", action='store_true', help="Whether to calculate the F1 score")
    # removed: please use the official evaluation scripts instead
    parser.add_argument("--train_batch_size", default=48, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=48, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=2.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.06, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                        "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the "
                        "start and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                        "A number of warnings are expected for a normal CoQA evaluation.")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--weight_decay', type=float, default=0, help="")
    parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--logfile', type=str, default=None, help='Which file to keep log.')
    parser.add_argument('--logmode', type=str, default=None, help='logging mode, `w` or `a`')
    parser.add_argument('--tensorboard', action='store_true', help='use tensorboard logging')
    parser.add_argument('--qa_tag', action='store_true', help='add qa tag or not')
    parser.add_argument('--history_len', type=int, default=2, help='length of history')
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument('--logging_steps', type=int, default=50, help="Log every X updates steps.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
        filename=args.logfile,
        filemode=args.logmode)

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError("If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError("If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    if args.do_train or args.do_predict:
        tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
        model = BertForCoQA.from_pretrained(args.bert_model)
        if args.local_rank == 0:
            torch.distributed.barrier()
        model.to(device)

    if args.do_train:
        if args.local_rank in [-1, 0] and args.tensorboard:
            from tensorboardX import SummaryWriter
            tb_writer = SummaryWriter()

        # Prepare data loader
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(
            args.type, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_answer_length),
            str(args.history_len), str(args.qa_tag))
        cached_train_examples_file = args.train_file + '_examples_{0}_{1}.pk'.format(
            str(args.history_len), str(args.qa_tag))

        # try train_examples
        try:
            with open(cached_train_examples_file, "rb") as reader:
                train_examples = pickle.load(reader)
        except:
            logger.info(" No cached file %s", cached_train_examples_file)
            train_examples = read_coqa_examples(input_file=args.train_file,
                                                history_len=args.history_len,
                                                add_QA_tag=args.qa_tag)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info(" Saving train examples into cached file %s", cached_train_examples_file)
                with open(cached_train_examples_file, "wb") as writer:
                    pickle.dump(train_examples, writer)
        # print('DEBUG')
        # exit()

        # try train_features
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except:
            logger.info(" No cached file %s", cached_train_features_file)
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
            )
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info(" Saving train features into cached file %s", cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        # print('DEBUG')
        # exit()

        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        all_rational_mask = torch.tensor([f.rational_mask for f in train_features], dtype=torch.long)
        all_cls_idx = torch.tensor([f.cls_idx for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions,
                                   all_rational_mask, all_cls_idx)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
        num_train_optimization_steps = len(
            train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        # if args.local_rank != -1:
        #     num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        # hack to remove pooler, which is not used
        # thus it produce None grad that break apex
        # param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay
        }, {
            'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=int(args.warmup_proportion * num_train_optimization_steps),
            t_total=num_train_optimization_steps)
        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        model.train()
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                batch = tuple(t.to(device) for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions, rational_mask, cls_idx = batch
                loss = model(input_ids, segment_ids, input_mask, start_positions,
                             end_positions, rational_mask, cls_idx)
                # loss = gather(loss, 0)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1
                    if args.local_rank in [-1, 0] and args.logging_steps > 0 \
                            and global_step % args.logging_steps == 0:
                        if args.tensorboard:
                            tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                            tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps,
                                                 global_step)
                        else:
                            logger.info('Step: {}\tLearning rate: {}\tLoss: {}\t'.format(
                                global_step, scheduler.get_lr()[0],
                                (tr_loss - logging_loss) / args.logging_steps))
                        logging_loss = tr_loss

    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = BertForCoQA.from_pretrained(args.output_dir)
        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)

        # Good practice: save your training arguments together with the trained model
        output_args_file = os.path.join(args.output_dir, 'training_args.bin')
        torch.save(args, output_args_file)
    else:
        model = BertForCoQA.from_pretrained(args.bert_model)

    model.to(device)

    if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        cached_eval_features_file = args.predict_file + '_{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(
            args.type, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_answer_length),
            str(args.history_len), str(args.qa_tag))
        cached_eval_examples_file = args.predict_file + '_examples_{0}_{1}.pk'.format(
            str(args.history_len), str(args.qa_tag))

        # try eval_examples
        try:
            with open(cached_eval_examples_file, 'rb') as reader:
                eval_examples = pickle.load(reader)
        except:
            logger.info("No cached file: %s", cached_eval_examples_file)
            eval_examples = read_coqa_examples(input_file=args.predict_file,
                                               history_len=args.history_len,
                                               add_QA_tag=args.qa_tag)
            logger.info(" Saving eval examples into cached file %s", cached_eval_examples_file)
            with open(cached_eval_examples_file, 'wb') as writer:
                pickle.dump(eval_examples, writer)

        # try eval_features
        try:
            with open(cached_eval_features_file, "rb") as reader:
                eval_features = pickle.load(reader)
        except:
            logger.info("No cached file: %s", cached_eval_features_file)
            eval_features = convert_examples_to_features(
                examples=eval_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
            )
            logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
            with open(cached_eval_features_file, "wb") as writer:
                pickle.dump(eval_features, writer)
        # print('DEBUG')
        # exit()

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]):
            # if len(all_results) % 1000 == 0:
            #     logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits, batch_yes_logits, batch_no_logits, batch_unk_logits = model(
                    input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                yes_logits = batch_yes_logits[i].detach().cpu().tolist()
                no_logits = batch_no_logits[i].detach().cpu().tolist()
                unk_logits = batch_unk_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              yes_logits=yes_logits,
                              no_logits=no_logits,
                              unk_logits=unk_logits))

        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results, args.n_best_size,
                          args.max_answer_length, args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                          args.null_score_diff_threshold)
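# --- Illustrative sketch (not part of the original source). It isolates the weight-decay grouping
# used when the optimizer is built above, reduced to a toy module with made-up hyperparameters.
import torch

toy_model = torch.nn.Linear(8, 2)
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped_parameters = [
    {'params': [p for n, p in toy_model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},   # decayed: weight matrices
    {'params': [p for n, p in toy_model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},    # not decayed: biases and LayerNorm parameters
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=5e-5, eps=1e-8)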
def __init__(
    self,
    models,
    tgt_dict,
    beam_size=1,
    max_len_a=0,
    max_len_b=200,
    max_len=0,
    min_len=1,
    normalize_scores=True,
    len_penalty=1.0,
    unk_penalty=0.0,
    temperature=1.0,
    match_source_len=False,
    no_repeat_ngram_size=0,
    search_strategy=None,
    eos=None,
    symbols_to_strip_from_output=None,
    lm_model=None,
    lm_weight=1.0,
    args=None,
):
    """Generates translations of a given source sentence.

    Args:
        models (List[~fairseq.models.FairseqModel]): ensemble of models,
            currently support fairseq.models.TransformerModel for scripting
        beam_size (int, optional): beam width (default: 1)
        max_len_a/b (int, optional): generate sequences of maximum length
            ax + b, where x is the source length
        max_len (int, optional): the maximum length of the generated output
            (not including end-of-sentence)
        min_len (int, optional): the minimum length of the generated output
            (not including end-of-sentence)
        normalize_scores (bool, optional): normalize scores by the length
            of the output (default: True)
        len_penalty (float, optional): length penalty, where <1.0 favors
            shorter, >1.0 favors longer sentences (default: 1.0)
        unk_penalty (float, optional): unknown word penalty, where <0
            produces more unks, >0 produces fewer (default: 0.0)
        temperature (float, optional): temperature, where values
            >1.0 produce more uniform samples and values <1.0 produce
            sharper samples (default: 1.0)
        match_source_len (bool, optional): outputs should match the source
            length (default: False)
    """
    super().__init__()
    if isinstance(models, EnsembleModel):
        self.model = models
    else:
        self.model = EnsembleModel(models)
    self.tgt_dict = tgt_dict
    self.pad = tgt_dict.pad()
    self.unk = tgt_dict.unk()
    self.eos = tgt_dict.eos() if eos is None else eos
    self.symbols_to_strip_from_output = (
        symbols_to_strip_from_output.union({self.eos})
        if symbols_to_strip_from_output is not None else {self.eos})
    self.vocab_size = len(tgt_dict)
    self.beam_size = beam_size
    # the max beam size is the dictionary size - 1, since we never select pad
    self.beam_size = min(beam_size, self.vocab_size - 1)
    self.max_len_a = max_len_a
    self.max_len_b = max_len_b
    self.min_len = min_len
    self.max_len = max_len or self.model.max_decoder_positions()

    self.normalize_scores = normalize_scores
    self.len_penalty = len_penalty
    self.unk_penalty = unk_penalty
    self.temperature = temperature
    self.match_source_len = match_source_len

    self.use_bertinput = args.use_bertinput
    self.berttokenizer = BertTokenizer.from_pretrained(args.bert_model_name, do_lower_case=False)
    self.use_bartinput = args.use_bartinput
    if self.use_bartinput:
        self.barttokenizer = BartTokenizer.from_pretrained(args.bart_model_name, do_lower_case=False)
    self.use_electrainput = args.use_electrainput
    if self.use_electrainput:
        self.electratokenizer = ElectraTokenizer.from_pretrained(args.electra_model_name)
    # not implemented yet.
    # self.use_bertinput = args.use_bertinput
    # self.mask_lm = args.mask_lm
    # self.bert_ner = args.bert_ner
    # self.bert_sst = args.bert_sst

    if no_repeat_ngram_size > 0:
        self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
    else:
        self.repeat_ngram_blocker = None

    assert temperature > 0, "--temperature must be greater than 0"

    self.search = (search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy)
    # We only need to set src_lengths in LengthConstrainedBeamSearch.
    # As a module attribute, setting it would break in multithread
    # settings when the model is shared.
    self.should_set_src_lengths = (
        hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths)

    self.model.eval()

    self.lm_model = lm_model
    self.lm_weight = lm_weight
    if self.lm_model is not None:
        self.lm_model.eval()
def __init__(self, args):
    self.data_dir = args.data_dir
    self.nega_num = args.nega_num
    self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_dir, do_lower_case=True)
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    ratio, pred_probs, bert_model_name,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    srcbert_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(indexed_dataset.make_dataset(
            prefix + src, impl=dataset_impl, fix_lua_indexing=True, dictionary=src_dict))
        tgt_datasets.append(indexed_dataset.make_dataset(
            prefix + tgt, impl=dataset_impl, fix_lua_indexing=True, dictionary=tgt_dict))
        srcbert_datasets.append(indexed_dataset.make_dataset(
            bertprefix + src, impl=dataset_impl, fix_lua_indexing=True))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        srcbert_datasets = srcbert_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    berttokenizer = BertTokenizer.from_pretrained(bert_model_name)
    if split == 'test':
        return BertLanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            srcbert=srcbert_datasets,
            srcbert_sizes=srcbert_datasets.sizes if srcbert_datasets is not None else None,
            berttokenizer=berttokenizer,
        )
    else:
        return BertXYNoisyLanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=True,
            ratio=ratio,
            pred_probs=pred_probs,
            srcbert=srcbert_datasets,
            srcbert_sizes=srcbert_datasets.sizes if srcbert_datasets is not None else None,
            berttokenizer=berttokenizer,
        )
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    bert_model_name=None,
    bart_model_name=None,
    electra_model_name=None,
    electra_pretrain=False,
    denoising=False,
    masking=False,
    extra_data=False,
    input_mapping=False,
    mask_ratio=None,
    random_ratio=None,
    insert_ratio=None,
    rotate_ratio=None,
    permute_sentence_ratio=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=False)
    if denoising:
        bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_name, do_lower_case=False)
        # bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name, do_lower_case=False)
    if electra_pretrain:
        electra_tokenizer = ElectraTokenizer.from_pretrained(electra_model_name)

    srcbert_datasets = []
    extra_datasets = []
    extra_bert_datasets = []
    extra_bert_mapping_datasets = []
    extra_bart_datasets = []
    extra_bart_mapping_datasets = []
    if denoising:
        srcbart_datasets = []
    if electra_pretrain:
        srcelectra_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt))
            bert_mapping_prefix = os.path.join(data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt))
            if denoising:
                bartprefix = os.path.join(data_path, '{}.bart.{}-{}.'.format(split_k, src, tgt))
                bart_mapping_prefix = os.path.join(data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt))
            if electra_pretrain:
                electraprefix = os.path.join(data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt))
                electra_mapping_prefix = os.path.join(data_path, '{}.electra.map.{}-{}.'.format(split_k, src, tgt))
            if extra_data:
                extraprefix = os.path.join(data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt))
                extra_bert_prefix = os.path.join(data_path, '{}.extra.bert.{}-{}.'.format(split_k, src, tgt))
                extra_bert_mapping_prefix = os.path.join(data_path, '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt))
                extra_bart_prefix = os.path.join(data_path, '{}.extra.bart.{}-{}.'.format(split_k, src, tgt))
                extra_bart_mapping_prefix = os.path.join(data_path, '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src))
            bert_mapping_prefix = os.path.join(data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt))
            if denoising:
                bartprefix = os.path.join(data_path, '{}.bart.{}-{}.'.format(split_k, tgt, src))
                bart_mapping_prefix = os.path.join(data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt))
            if electra_pretrain:
                electraprefix = os.path.join(data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt))
                electra_mapping_prefix = os.path.join(data_path, '{}.electra.map.{}-{}.'.format(split_k, src, tgt))
            if extra_data:
                extraprefix = os.path.join(data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt))
                extra_bert_prefix = os.path.join(data_path, '{}.extra.bert.{}-{}.'.format(split_k, src, tgt))
                extra_bert_mapping_prefix = os.path.join(data_path, '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt))
                extra_bart_prefix = os.path.join(data_path, '{}.extra.bart.{}-{}.'.format(split_k, src, tgt))
                extra_bart_mapping_prefix = os.path.join(data_path, '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        # srcbert_datasets.append(indexed_dataset.make_dataset(bertprefix + src, impl=dataset_impl, fix_lua_indexing=True))
        # if denoising:
        #     srcbart_datasets.append(indexed_dataset.make_dataset(bartprefix + src, impl=dataset_impl, fix_lua_indexing=True))
        # if extra_data:
        #     extra_datasets.append(indexed_dataset.make_dataset(extraprefix + src, impl=dataset_impl, fix_lua_indexing=True))
        srcbert_datasets.append(
            data_utils.load_indexed_dataset(bertprefix + src, dataset_impl=dataset_impl))
        if denoising:
            srcbart_datasets.append(
                data_utils.load_indexed_dataset(bartprefix + src, dataset_impl=dataset_impl))
        if electra_pretrain:
            srcelectra_datasets.append(
                data_utils.load_indexed_dataset(electraprefix + src, dataset_impl=dataset_impl))
        if extra_data and split == 'train':
            extra_datasets.append(
                data_utils.load_indexed_dataset(extraprefix + src, dataset_impl=dataset_impl))
            extra_bert_datasets.append(
                data_utils.load_indexed_dataset(extra_bert_prefix + src, dataset_impl=dataset_impl))
            extra_bert_mapping_datasets.append(
                data_utils.load_indexed_dataset(extra_bert_mapping_prefix + src, dataset_impl=dataset_impl))
            extra_bart_datasets.append(
                data_utils.load_indexed_dataset(extra_bart_prefix + src, dataset_impl=dataset_impl))
            extra_bart_mapping_datasets.append(
                data_utils.load_indexed_dataset(extra_bart_mapping_prefix + src, dataset_impl=dataset_impl))
            # import pdb; pdb.set_trace()
            assert extra_datasets != [] or extra_bert_datasets != [] or extra_bert_mapping_datasets != [] \
                or extra_bart_datasets != [] or extra_bart_mapping_datasets != []
            # extra_datasets = extra_datasets[0]

        # import pdb; pdb.set_trace()
        src_datasets[-1] = PrependTokenDataset(src_datasets[-1], token=src_dict.bos_index)
        if extra_data and split == 'train':
            extra_datasets[-1] = PrependTokenDataset(extra_datasets[-1], token=src_dict.bos_index)

        if denoising is True:
            if input_mapping is True and split == 'train':
                bart_mapping_dataset = data_utils.load_indexed_dataset(
                    bart_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bart_mapping_dataset = None
            src_datasets[-1] = DenoisingBartDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                srcbart_datasets[-1], srcbart_datasets[-1].sizes,
                bart_tokenizer,
                map_dataset=bart_mapping_dataset,
                mask_ratio=mask_ratio,
                random_ratio=random_ratio,
                insert_ratio=insert_ratio,
                rotate_ratio=rotate_ratio,
                permute_sentence_ratio=permute_sentence_ratio,
            )

        if electra_pretrain is True:
            if input_mapping is True and split == 'train':
                electra_mapping_dataset = data_utils.load_indexed_dataset(
                    electra_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                electra_mapping_dataset = None
            src_datasets[-1] = ElectrapretrainDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                srcelectra_datasets[-1], srcelectra_datasets[-1].sizes,
                electra_tokenizer,
                map_dataset=electra_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

        if masking is True:
            if input_mapping is True and split == 'train':
                # bert_mapping_dataset = indexed_dataset.make_dataset(bert_mapping_prefix + src, impl=dataset_impl, fix_lua_indexing=True)
                bert_mapping_dataset = data_utils.load_indexed_dataset(
                    bert_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bert_mapping_dataset = None
            src_datasets[-1] = MaskingDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                srcbert_datasets[-1], srcbert_datasets[-1].sizes,
                bert_tokenizer,
                map_dataset=bert_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

        if extra_data is True and split == 'train':
            assert input_mapping is True
            src_datasets[-1] = MaskingExtraDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                extra_datasets[-1], extra_datasets[-1].sizes,
                extra_bert_datasets[-1], extra_bert_datasets[-1].sizes,
                bert_tokenizer,
                map_dataset=extra_bert_mapping_datasets[-1],
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )
            src_datasets[-1] = DenoisingBartExtraDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                extra_datasets[-1], extra_datasets[-1].sizes,
                extra_bart_datasets[-1], extra_bart_datasets[-1].sizes,
                bart_tokenizer,
                map_dataset=extra_bart_mapping_datasets[-1],
            )

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
        # srcbert_datasets = srcbert_datasets[0]
        # if denoising:
        #     srcbart_datasets = srcbart_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    src_bart_dataset = None
    src_bert_dataset = None
    src_electra_dataset = None
    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        masking, src_bert_dataset,
        denoising, src_bart_dataset,
        src_electra_dataset,
        # extra_datasets,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )