import datetime

import torch
import torch.utils.data as data_utils

import dataset

# SenSim, SenSimEval, build_optimizer, and get_outputs_until_eos are assumed to be
# provided by sibling modules of this project; their import paths are not shown here.


def sim(options):
    # Load the sentence-similarity model together with its text processor.
    mt_model, text_processor = SenSim.load(options.model_path, tok_dir=options.tokenizer_path)
    print("Model initialization done!")

    # "warump_steps" is kept as-is: it must match build_optimizer's keyword argument.
    optimizer = build_optimizer(mt_model, options.learning_rate, warump_steps=options.warmup)
    trainer = SenSimEval(model=mt_model, mask_prob=options.mask_prob, optimizer=optimizer,
                         clip=options.clip, fp16=options.fp16)

    pin_memory = torch.cuda.is_available()
    # Dev batches are pre-built pickles; the DataLoader streams them one at a time.
    mt_dev_data = dataset.MTDataset(batch_pickle_dir=options.mt_dev_path,
                                    max_batch_capacity=options.total_capacity,
                                    max_batch=int(options.batch / (options.beam_width * 2)),
                                    pad_idx=mt_model.text_processor.pad_token_id(),
                                    keep_pad_idx=False)
    dl = data_utils.DataLoader(mt_dev_data, batch_size=1, shuffle=False, pin_memory=pin_memory)
    trainer.eval(mt_dev_iter=dl, saving_path=options.output)
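# Minimal sketch of how sim() might be driven from the command line. The flag
# names and defaults are assumptions; only the attribute names on `options`
# (model_path, tokenizer_path, ...) are dictated by sim() itself.
import argparse


def parse_sim_options():
    parser = argparse.ArgumentParser(description="Evaluate a SenSim model on a dev set")
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--tokenizer_path", required=True)
    parser.add_argument("--mt_dev_path", required=True, help="pickled dev batches")
    parser.add_argument("--output", required=True, help="path for saved eval results")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--warmup", type=int, default=4000)
    parser.add_argument("--mask_prob", type=float, default=0.15)
    parser.add_argument("--clip", type=float, default=1.0)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--total_capacity", type=int, default=600)
    parser.add_argument("--batch", type=int, default=4000)
    parser.add_argument("--beam_width", type=int, default=4)
    return parser.parse_args()

# Usage: sim(parse_sim_options())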
def build_data_loader(options, text_processor):
    print(datetime.datetime.now(), "Binarizing test data")
    assert options.src_lang is not None
    assert options.target_lang is not None

    src_lang = "<" + options.src_lang + ">"
    src_lang_id = text_processor.languages[src_lang]
    dst_lang = "<" + options.target_lang + ">"
    dst_lang_id = text_processor.languages[dst_lang]
    # Decoding always starts from the target-language token.
    fixed_output = [text_processor.token_id(dst_lang)]

    examples = []
    with open(options.input_path, "r") as s_fp:
        for src_line in s_fp:
            if len(src_line.strip()) == 0:
                continue
            # Wrap each source line in its language token and an end-of-sentence marker.
            src_line = " ".join([src_lang, src_line.strip(), "</s>"])
            src_tok_line = text_processor.tokenize_one_sentence(
                src_line.strip().replace(" </s> ", " "))
            examples.append((src_tok_line, fixed_output, src_lang_id, dst_lang_id))

    print(datetime.datetime.now(), "Loaded %d examples" % len(examples))
    test_data = dataset.MTDataset(examples=examples,
                                  max_batch_capacity=options.total_capacity,
                                  max_batch=options.batch,
                                  pad_idx=text_processor.pad_token_id(),
                                  max_seq_len=10000)
    pin_memory = torch.cuda.is_available()
    return data_utils.DataLoader(test_data, batch_size=1, shuffle=False, pin_memory=pin_memory)
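# Standalone illustration of the string pipeline above (hypothetical helper, not
# part of this module): each raw line is stripped, wrapped in its language token
# and </s>, and any interior " </s> " markers are collapsed.
def wrap_source_line(src_lang_token, raw_line):
    line = " ".join([src_lang_token, raw_line.strip(), "</s>"])
    return line.strip().replace(" </s> ", " ")

assert wrap_source_line("<en>", "Hello world!\n") == "<en> Hello world! </s>"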
def get_mt_train_data(mt_model, num_processors, options, pin_memory, lex_dict=None):
    # Build one shuffled loader per comma-separated training pickle.
    mt_train_loader = []
    train_paths = options.mt_train_path.split(",")
    for train_path in train_paths:
        mt_train_data = dataset.MTDataset(batch_pickle_dir=train_path,
                                          max_batch_capacity=int(num_processors * options.total_capacity / 2),
                                          max_batch=int(num_processors * options.batch / 2),
                                          pad_idx=mt_model.text_processor.pad_token_id(),
                                          lex_dict=lex_dict,
                                          keep_pad_idx=False)
        mtl = data_utils.DataLoader(mt_train_data, batch_size=1, shuffle=True, pin_memory=pin_memory)
        mt_train_loader.append(mtl)
    return mt_train_loader
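# Hedged sketch of consuming the loaders returned above: this file does not show
# the training loop, so the round-robin policy here is an assumption. Batches
# from the per-corpus loaders are interleaved; loaders that finish early are
# skipped until the longest one is exhausted.
from itertools import zip_longest


def round_robin(loaders):
    sentinel = object()
    for group in zip_longest(*loaders, fillvalue=sentinel):
        for batch in group:
            if batch is not sentinel:
                yield batch

# Usage: for batch in round_robin(mt_train_loader): ...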
def get_mt_dev_data(mt_model, options, pin_memory, text_processor, trainer, lex_dict=None):
    mt_dev_loader = []
    dev_paths = options.mt_dev_path.split(",")
    trainer.reference = []
    for dev_path in dev_paths:
        mt_dev_data = dataset.MTDataset(batch_pickle_dir=dev_path,
                                        max_batch_capacity=options.total_capacity,
                                        keep_pad_idx=True,
                                        max_batch=int(options.batch / (options.beam_width * 2)),
                                        pad_idx=mt_model.text_processor.pad_token_id(),
                                        lex_dict=lex_dict)
        dl = data_utils.DataLoader(mt_dev_data, batch_size=1, shuffle=False, pin_memory=pin_memory)
        mt_dev_loader.append(dl)

        print("creating reference")
        # Unwrap DataParallel/DistributedDataParallel if the generator is wrapped.
        generator = (trainer.generator.module
                     if hasattr(trainer.generator, "module") else trainer.generator)
        # Decode the gold target sentences once so the trainer can score
        # generated translations against them later.
        for batch in dl:
            tgt_inputs = batch["dst_texts"].squeeze()
            refs = get_outputs_until_eos(text_processor.sep_token_id(), tgt_inputs,
                                         remove_first_token=True)
            ref = [generator.seq2seq_model.text_processor.tokenizer.decode(ref.numpy())
                   for ref in refs]
            trainer.reference += ref
    return mt_dev_loader
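# Hedged sketch of how trainer.reference might be consumed: the metric is not
# shown in this file, so scoring with sacrebleu is an assumption. Hypotheses
# must be one decoded string per dev example, in the same order as the loaders.
import sacrebleu


def score_hypotheses(hypotheses, trainer):
    # sacrebleu expects a list of hypothesis strings and a list of reference streams.
    return sacrebleu.corpus_bleu(hypotheses, [trainer.reference]).score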