def sim(options):
    # Load the sentence-similarity model together with its tokenizer.
    mt_model, text_processor = SenSim.load(options.model_path,
                                           tok_dir=options.tokenizer_path)
    print("Model initialization done!")

    # NB: 'warump_steps' (sic) must match build_optimizer's parameter name.
    optimizer = build_optimizer(mt_model,
                                options.learning_rate,
                                warump_steps=options.warmup)

    trainer = SenSimEval(model=mt_model,
                         mask_prob=options.mask_prob,
                         optimizer=optimizer,
                         clip=options.clip,
                         fp16=options.fp16)

    # MTDataset yields pre-built batches, so the DataLoader uses
    # batch_size=1 and simply iterates over them.
    pin_memory = torch.cuda.is_available()
    mt_dev_data = dataset.MTDataset(
        batch_pickle_dir=options.mt_dev_path,
        max_batch_capacity=options.total_capacity,
        max_batch=int(options.batch / (options.beam_width * 2)),
        pad_idx=mt_model.text_processor.pad_token_id(),
        keep_pad_idx=False)
    dl = data_utils.DataLoader(mt_dev_data,
                               batch_size=1,
                               shuffle=False,
                               pin_memory=pin_memory)

    trainer.eval(mt_dev_iter=dl, saving_path=options.output)
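A pattern that repeats through these snippets: MTDataset builds the batches itself, so every DataLoader is created with batch_size=1 and only iterates (and optionally pins) already-formed batches. Below is a self-contained toy sketch of that pattern; PreBatchedDataset and the random tensors are illustrative stand-ins, not part of the project:

import torch
from torch.utils.data import Dataset, DataLoader

class PreBatchedDataset(Dataset):
    # Toy stand-in: each item is already a complete batch.
    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]

batches = [{"src_texts": torch.randint(0, 100, (16, 32)),
            "dst_texts": torch.randint(0, 100, (16, 30))}
           for _ in range(4)]
loader = DataLoader(PreBatchedDataset(batches),
                    batch_size=1,   # one pre-built batch per step
                    shuffle=False,
                    pin_memory=torch.cuda.is_available())
for batch in loader:
    # default_collate adds a leading dim of 1; squeeze it away, just as
    # later snippets call batch["dst_texts"].squeeze().
    src = batch["src_texts"].squeeze(0)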
def build_data_loader(options, text_processor):
    print(datetime.datetime.now(), "Binarizing test data")
    assert options.src_lang is not None
    assert options.target_lang is not None
    src_lang = "<" + options.src_lang + ">"
    src_lang_id = text_processor.languages[src_lang]
    dst_lang = "<" + options.target_lang + ">"
    target_lang_id = text_processor.languages[dst_lang]
    # Decoding starts from the target-language token.
    fixed_output = [text_processor.token_id(dst_lang)]
    examples = []
    with open(options.input_path, "r") as s_fp:
        for src_line in s_fp:
            if len(src_line.strip()) == 0:
                continue
            # Wrap the source sentence with its language token and an
            # end-of-sentence marker before tokenization.
            src_line = " ".join([src_lang, src_line, "</s>"])
            src_tok_line = text_processor.tokenize_one_sentence(
                src_line.strip().replace(" </s> ", " "))
            examples.append(
                (src_tok_line, fixed_output, src_lang_id, target_lang_id))
    print(datetime.datetime.now(), "Loaded %d examples" % len(examples))
    test_data = dataset.MTDataset(examples=examples,
                                  max_batch_capacity=options.total_capacity,
                                  max_batch=options.batch,
                                  pad_idx=text_processor.pad_token_id(),
                                  max_seq_len=10000)
    pin_memory = torch.cuda.is_available()
    return data_utils.DataLoader(test_data,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=pin_memory)
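A plausible call site for build_data_loader, assuming the surrounding project modules are importable. The attribute names mirror exactly what the function reads; the argparse.Namespace construction and the concrete values are illustrative, not the project's actual CLI:

import argparse

# Hypothetical options object; attributes follow build_data_loader's reads.
options = argparse.Namespace(
    src_lang="en",             # wrapped as "<en>" inside the function
    target_lang="fr",          # wrapped as "<fr>"
    input_path="test.en.txt",  # plain text, one source sentence per line
    total_capacity=600,        # placeholder capacity
    batch=4000)                # placeholder batch budget

# text_processor would come from the project's model loader, e.g. the
# SenSim.load(...) call in sim() above:
# test_loader = build_data_loader(options, text_processor)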
Example #3
def get_mt_train_data(mt_model,
                      num_processors,
                      options,
                      pin_memory,
                      lex_dict=None):
    mt_train_loader = []
    train_paths = options.mt_train_path.split(",")
    for train_path in train_paths:
        # One pre-batched dataset per comma-separated training path;
        # batch capacity and size scale with the number of processors.
        mt_train_data = dataset.MTDataset(
            batch_pickle_dir=train_path,
            max_batch_capacity=int(num_processors *
                                   options.total_capacity / 2),
            max_batch=int(num_processors * options.batch / 2),
            pad_idx=mt_model.text_processor.pad_token_id(),
            lex_dict=lex_dict,
            keep_pad_idx=False)
        mtl = data_utils.DataLoader(mt_train_data,
                                    batch_size=1,
                                    shuffle=True,
                                    pin_memory=pin_memory)
        mt_train_loader.append(mtl)
    return mt_train_loader
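get_mt_train_data returns one shuffled DataLoader per comma-separated path in options.mt_train_path. The snippets here never show how those loaders are consumed; one reasonable pattern (a sketch, not taken from the source) is to interleave batches round-robin so every corpus contributes throughout an epoch:

def round_robin(loaders):
    # Take one batch from each loader in turn, dropping loaders as
    # they run out, until every loader is exhausted.
    iterators = [iter(dl) for dl in loaders]
    while iterators:
        alive = []
        for it in iterators:
            try:
                yield next(it)
                alive.append(it)
            except StopIteration:
                pass
        iterators = alive

# Hypothetical usage with the list returned above:
# for batch in round_robin(mt_train_loader):
#     train_step(batch)  # train_step is a placeholder, not a project function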
Example #4
def get_mt_dev_data(mt_model,
                    options,
                    pin_memory,
                    text_processor,
                    trainer,
                    lex_dict=None):
    mt_dev_loader = []
    dev_paths = options.mt_dev_path.split(",")
    trainer.reference = []
    for dev_path in dev_paths:
        mt_dev_data = dataset.MTDataset(
            batch_pickle_dir=dev_path,
            max_batch_capacity=options.total_capacity,
            keep_pad_idx=True,
            max_batch=int(options.batch / (options.beam_width * 2)),
            pad_idx=mt_model.text_processor.pad_token_id(),
            lex_dict=lex_dict)
        dl = data_utils.DataLoader(mt_dev_data,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=pin_memory)
        mt_dev_loader.append(dl)

        print("creating reference")

        # Unwrap DataParallel/DistributedDataParallel if necessary.
        generator = (trainer.generator.module if hasattr(
            trainer.generator, "module") else trainer.generator)

        # Decode the gold target sentences once so they can be reused
        # as references during evaluation.
        for batch in dl:
            tgt_inputs = batch["dst_texts"].squeeze()
            refs = get_outputs_until_eos(text_processor.sep_token_id(),
                                         tgt_inputs,
                                         remove_first_token=True)
            ref = [
                generator.seq2seq_model.text_processor.tokenizer.decode(
                    ref.numpy()) for ref in refs
            ]
            trainer.reference += ref
    return mt_dev_loader
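get_outputs_until_eos is not defined in these snippets. Judging only from its call site, it appears to cut each row of the padded target batch at the first separator/eos id and, with remove_first_token=True, to drop the leading language token. A minimal stand-in under that assumption (the name outputs_until_eos_sketch and the exact truncation rule are guesses, not the project's implementation):

import torch

def outputs_until_eos_sketch(eos_id, batch, remove_first_token=False):
    # Truncate each row at the first occurrence of eos_id; keep the
    # whole row if eos_id never appears.
    outputs = []
    for row in batch:
        start = 1 if remove_first_token else 0
        hits = (row == eos_id).nonzero(as_tuple=True)[0]
        end = int(hits[0]) if len(hits) > 0 else len(row)
        outputs.append(row[start:end])
    return outputs

batch = torch.tensor([[5, 11, 12, 2, 0, 0],
                      [5, 21, 22, 23, 24, 2]])
print(outputs_until_eos_sketch(2, batch, remove_first_token=True))
# -> [tensor([11, 12]), tensor([21, 22, 23, 24])]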