Ejemplo n.º 1
0
    def test_data(self):
        """Check TextDataset's reported line count and its cache size limit.

        A tokenizer is trained on ``sample.txt`` (located next to this test
        file), batches are written into a temporary directory, and a
        ``TextDataset`` with ``max_cache_size=3`` is then read at several
        indices while the size of ``current_cache`` is asserted.
        """
        path_dir_name = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(path_dir_name, "sample.txt")

        with tempfile.TemporaryDirectory() as tmpdirname:
            processor = TextProcessor()
            # Language tags follow the "<xx>" form; the second tag used to be
            # written without its closing ">".
            processor.train_tokenizer([data_path],
                                      vocab_size=1000,
                                      to_save_dir=tmpdirname,
                                      languages={
                                          "<mzn>": 0,
                                          "<glk>": 1
                                      })
            create_batches.write(text_processor=processor,
                                 cache_dir=tmpdirname,
                                 seq_len=512,
                                 txt_file=data_path,
                                 sen_block_size=10)
            dataset = TextDataset(save_cache_dir=tmpdirname, max_cache_size=3)
            assert dataset.line_num == 70

            # Reads near the start keep the cache at its configured maximum.
            dataset.__getitem__(3)
            assert len(dataset.current_cache) == 3

            dataset.__getitem__(9)
            assert len(dataset.current_cache) == 3

            # Reading the last item leaves only two cached entries
            # (presumably the final, partially filled cache window).
            dataset.__getitem__(69)
            assert len(dataset.current_cache) == 2
Ejemplo n.º 2
0
def get_tokenizer(tokenizer_path: Optional[str] = None,
                  train_path: Optional[str] = None,
                  model_path: Optional[str] = None,
                  vocab_size: Optional[int] = None) -> TextProcessor:
    """Load an existing tokenizer, or train a new one from raw text.

    Args:
        tokenizer_path: Location of a saved tokenizer. When given, it is
            loaded and the remaining arguments are ignored.
        train_path: Text file used for training when ``tokenizer_path`` is
            None. Each line may hold several sentences separated by
            ``</s>``; if the first sentence starts with ``<``, its first
            space-separated token is taken as a language tag.
        model_path: Directory in which the trained tokenizer is saved
            (created if missing).
        vocab_size: Vocabulary size passed to the tokenizer trainer.

    Returns:
        The loaded or freshly trained ``TextProcessor``.
    """
    if tokenizer_path is not None:
        return TextProcessor(tokenizer_path)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_path, exist_ok=True)

    print("Training Tokenizer...")
    text_processor = TextProcessor()
    print("Writing raw text...")
    languages = set()
    tmp_path = train_path + ".tmp"
    try:
        # One sentence per line, with the language tag stripped, is written
        # to a temporary file that the tokenizer trainer reads.
        with open(tmp_path, "w") as wf:
            with open(train_path, "r") as rf:
                for i, line in enumerate(rf):
                    spl = [
                        sen.strip() for sen in line.split("</s>")
                        if len(sen.strip()) > 0
                    ]
                    if len(spl) == 0:
                        continue
                    if spl[0].startswith("<"):
                        sen_split = spl[0].strip().split(" ")
                        spl[0] = " ".join(sen_split[1:])
                        languages.add(sen_split[0])
                    wf.write("\n".join(spl))
                    wf.write("\n")
                    if (i + 1) % 1000 == 0:
                        print(i + 1, "\r", end="")
        print("Writing raw text done!")

        # Language ids are assigned in sorted tag order so they are stable
        # across runs.
        text_processor.train_tokenizer(
            paths=[tmp_path],
            vocab_size=vocab_size,
            to_save_dir=model_path,
            languages={l: i
                       for i, l in enumerate(sorted(languages))})
    finally:
        # Remove the temporary file directly instead of shelling out to
        # "rm": no quoting problems with unusual paths, works without a
        # POSIX shell, and the file is cleaned up even if training fails.
        print("Removing temporary file!")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("done!")
    return text_processor
Ejemplo n.º 3
0
    def load(out_dir: str, tok_dir: str):
        """Rebuild a Caption2Image model from files stored in ``out_dir``.

        The pickled ``mt_config`` file supplies the three constructor sizes,
        and ``mt_model.state_dict`` supplies the weights, which are loaded
        non-strictly onto CUDA when available and onto the CPU otherwise.

        NOTE(review): ``pickle.load`` executes arbitrary code from the file;
        only load checkpoints from trusted sources.

        :param out_dir: directory holding ``mt_config`` and ``mt_model.state_dict``
        :param tok_dir: tokenizer location handed to ``TextProcessor``
        :return: the restored model
        """
        processor = TextProcessor(tok_model_path=tok_dir)
        target_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        config_path = os.path.join(out_dir, "mt_config")
        weights_path = os.path.join(out_dir, "mt_model.state_dict")

        with open(config_path, "rb") as config_file:
            enc_layer, embed_dim, intermediate_dim = pickle.load(config_file)

            model = Caption2Image(text_processor=processor,
                                  enc_layer=enc_layer,
                                  embed_dim=embed_dim,
                                  intermediate_dim=intermediate_dim)
            state = torch.load(weights_path, map_location=target_device)
            model.load_state_dict(state, strict=False)
            return model
Ejemplo n.º 4
0
 def load(out_dir: str):
     """Restore a language model saved in ``out_dir``.

     ``out_dir`` is expected to hold the tokenizer files, a pickled
     ``config`` and a ``model.state_dict`` file.

     NOTE(review): ``pickle.load`` executes arbitrary code from the file;
     only load checkpoints from trusted sources.

     :param out_dir: directory the model was saved to
     :return: an ``LM`` instance with the stored weights loaded
     """
     text_processor = TextProcessor(tok_model_path=out_dir)
     with open(os.path.join(out_dir, "config"), "rb") as fp:
         config = pickle.load(fp)
         if isinstance(config, dict):
             # For older configs
             config = BertConfig(**config)
         lm = LM(text_processor=text_processor, config=config)
         # Deserialize onto the CPU so a checkpoint written on a GPU machine
         # can still be opened where CUDA is unavailable; load_state_dict
         # copies the values into the model's own parameters.
         state_dict = torch.load(os.path.join(out_dir, "model.state_dict"),
                                 map_location="cpu")
         lm.load_state_dict(state_dict)
         return lm
Ejemplo n.º 5
0
    def test_train_tokenizer(self):
        """Train a tokenizer, then verify it and a reloaded copy of it.

        Both the freshly trained processor and the one reloaded from the
        save directory must report a vocabulary of 1000 entries and be able
        to tokenize a sample sentence.
        """
        path_dir_name = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(path_dir_name, "sample.txt")

        with tempfile.TemporaryDirectory() as tmpdirname:
            processor = TextProcessor()
            processor.train_tokenizer([data_path],
                                      vocab_size=1000,
                                      to_save_dir=tmpdirname,
                                      languages={"<en>": 0})
            assert processor.tokenizer.get_vocab_size() == 1000
            sen1 = "Obama signed many landmark bills into law during his first two years in office."
            assert processor._tokenize(sen1) is not None

            # Ten newline-separated sentences must yield ten tokenized items.
            many_sens = "\n".join([sen1] * 10)
            assert len(processor.tokenize(many_sens)) == 10

            new_processor = TextProcessor(tok_model_path=tmpdirname)
            assert new_processor.tokenizer.get_vocab_size() == 1000
            sen2 = "Obama signed many landmark bills into law during his first two years in office."
            # Tokenize with the reloaded processor: the previous version
            # re-used `processor` here, so the reloaded tokenizer was never
            # actually exercised.
            assert new_processor._tokenize(sen2) is not None
Ejemplo n.º 6
0
    def load(cls, out_dir: str, tok_dir: str, use_obj: bool = False):
        """Instantiate ``cls`` from the configuration and weights in ``out_dir``.

        The pickled ``mt_config`` file holds a nine-item tuple of constructor
        settings; ``mt_model.state_dict`` holds the weights, which are loaded
        non-strictly onto CUDA when available and onto the CPU otherwise.

        NOTE(review): ``pickle.load`` executes arbitrary code from the file;
        only load checkpoints from trusted sources.

        :param cls: model class to construct
        :param out_dir: directory holding ``mt_config`` and ``mt_model.state_dict``
        :param tok_dir: tokenizer location handed to ``TextProcessor``
        :param use_obj: forwarded unchanged to the ``cls`` constructor
        :return: the restored model
        """
        processor = TextProcessor(tok_model_path=tok_dir)
        target_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        with open(os.path.join(out_dir, "mt_config"), "rb") as config_file:
            (lang_dec, use_proposals, enc_layer, dec_layer, embed_dim,
             intermediate_dim, tie_embed, resnet_depth,
             freeze_image) = pickle.load(config_file)

            model_kwargs = dict(text_processor=processor,
                                lang_dec=lang_dec,
                                use_proposals=use_proposals,
                                tie_embed=tie_embed,
                                enc_layer=enc_layer,
                                dec_layer=dec_layer,
                                embed_dim=embed_dim,
                                intermediate_dim=intermediate_dim,
                                freeze_image=freeze_image,
                                resnet_depth=resnet_depth,
                                use_obj=use_obj)
            model = cls(**model_kwargs)

            weights_path = os.path.join(out_dir, "mt_model.state_dict")
            state = torch.load(weights_path, map_location=target_device)
            model.load_state_dict(state, strict=False)
            return model
Ejemplo n.º 7
0
 def createSecondDataset(self, datasetFileName="df2.tsv"):
     """Build the normalized tf-idf table and write it as a TSV file.

     :param datasetFileName: path of the tab-separated output file
     """
     # Text processor, with its term encoding and term-frequency dictionary.
     self.tp = TextProcessor()
     self.tp.buildEncoding("encoding.pickle")
     self.tp.buildTf(fileName='tf.pickle')

     # Raw tf-idf first, then the Euclidean norms of the descriptions, and
     # finally the normalized tf-idf dictionary.
     self._createtfidf()
     self._computeEuclideanNorm()
     self._createNormalizedtfidf()

     # One row per key of the normalized dictionary, one column per term id.
     term_columns = range(1, self.tp.VOCABULARY_SIZE + 1)
     frame = pd.DataFrame.from_dict(data=self.tfidf_normalized,
                                    columns=term_columns,
                                    orient='index')
     frame.to_csv(datasetFileName, sep='\t')
Ejemplo n.º 8
0
    def test_albert_init(self):
        """An LM built on a fresh tokenizer survives a save/load round trip."""
        here = os.path.dirname(os.path.realpath(__file__))
        sample_file = os.path.join(here, "sample.txt")

        with tempfile.TemporaryDirectory() as work_dir:
            tok = TextProcessor()
            tok.train_tokenizer([sample_file],
                                vocab_size=1000,
                                to_save_dir=work_dir,
                                languages={"<en>": 0})

            original = LM(text_processor=tok)
            # The embedding table must have one row per vocabulary entry.
            word_embeddings = original.encoder.base_model.embeddings.word_embeddings
            assert word_embeddings.num_embeddings == 1000

            original.save(work_dir)
            restored = LM.load(work_dir)

            assert restored.config == original.config
Ejemplo n.º 9
0
    def test_albert_seq2seq_init(self):
        """Seq2Seq output has one row per non-pad target token checked here.

        The forward pass is run with and without ``log_softmax=True``; in
        both cases the output is expected to be ``[5, vocab_size]``.
        """
        here = os.path.dirname(os.path.realpath(__file__))
        sample_file = os.path.join(here, "sample.txt")

        with tempfile.TemporaryDirectory() as work_dir:
            processor = TextProcessor()
            processor.train_tokenizer([sample_file],
                                      vocab_size=1000,
                                      to_save_dir=work_dir,
                                      languages={"<en>": 0, "<fa>": 1})
            model = Seq2Seq(text_processor=processor)

            pad = processor.pad_token_id()
            # Two padded source rows and two padded target rows.
            src_inputs = torch.tensor([[1, 2, 3, 4, 5, pad, pad],
                                       [1, 2, 3, 4, 5, 6, pad]])
            tgt_inputs = torch.tensor([[6, 8, 7, pad, pad],
                                       [6, 8, 7, 8, pad]])
            src_mask = (src_inputs != pad)
            tgt_mask = (tgt_inputs != pad)
            src_langs = torch.tensor([[0], [0]]).squeeze()
            tgt_langs = torch.tensor([[1], [1]]).squeeze()

            expected_shape = [5, processor.vocab_size()]
            # First with log-softmax requested, then with the default output.
            for extra_kwargs in ({"log_softmax": True}, {}):
                seq_output = model(src_inputs, tgt_inputs, src_mask, tgt_mask,
                                   src_langs, tgt_langs, **extra_kwargs)
                assert list(seq_output.size()) == expected_shape
Ejemplo n.º 10
0
    def load_raw_data(self):
        """Read the raw tweet CSV and return a cleaned, one-hot labelled frame.

        The input file (``self.in_dir``/``self.input_file``, no header row)
        holds a tweet column and a comma-separated hashtag column. Tweets are
        cleaned and normalized, rows whose hashtag or tweet becomes empty are
        dropped, and one 0/1 column per entry of ``textprocessor.hashtag`` is
        filled in.

        :return: DataFrame with ``id``, ``tweet`` and one indicator column
            per hashtag (the ``hashtag`` column itself is dropped)
        :raises KeyError: if a row carries a hashtag that has no column
        """
        textprocessor = TextProcessor(self.in_dir, self.dictionary_file,
                                      self.hashtag_file)
        textprocessor.load_dictioanry()
        textprocessor.load_hashtag()
        dat = pd.read_csv(self.in_dir + '/' + self.input_file, header=None)

        dat.columns = ['tweet', 'hashtag']
        dat['id'] = None
        dat = dat[['id', 'tweet', 'hashtag']]

        # Add one zero-filled indicator column per hashtag.
        total = ['id', 'tweet', 'hashtag'] + list(textprocessor.hashtag)
        dat = dat.reindex(columns=total, fill_value=0)
        dat['tweet'] = dat['tweet'].apply(textprocessor.cleanup)
        dat['tweet'] = dat['tweet'].apply(textprocessor.informal_norm)

        # Drop rows left without any hashtag after filtering.
        dat['hashtag'] = dat['hashtag'].apply(textprocessor.del_hashtag)
        dat = dat.drop(
            dat[dat['hashtag'].map(len) < 1].index).reset_index(drop=True)

        # Drop rows left without any tweet text after filtering.
        dat['tweet'] = dat['tweet'].apply(textprocessor.drop_tweet)
        dat = dat.drop(
            dat[dat['tweet'].map(len) < 1].index).reset_index(drop=True)
        dat['id'] = range(len(dat))

        # Assign labels. The index was reset above, so the enumerate()
        # position equals the row label. A single .loc assignment replaces
        # the chained `dat[col][row] = 1`, which may write to a temporary
        # copy and is a no-op under pandas Copy-on-Write.
        for row, tags in enumerate(dat['hashtag'].tolist()):
            for tag in tags.split(","):
                label = tag.replace(' ', '')
                if label not in dat.columns:
                    # .loc would silently create a new column; keep the
                    # KeyError the column lookup used to raise.
                    raise KeyError(label)
                dat.loc[row, label] = 1
        return dat.drop(columns=['hashtag'])
Ejemplo n.º 11
0
            sorted_examples = list(map(lambda len_item: examples[len_item[0]], sorted_lens))
            with open(output_file + "." + str(part_num), "wb") as fw:
                marshal.dump(sorted_examples, fw)


def get_options():
    """Parse the command-line options of the batch-writing script.

    The parsed values are also stored in the module-level ``options``
    variable.

    :return: the ``optparse`` values object with ``src_data_path``,
        ``dst_data_path``, ``output_path``, ``tokenizer_path``,
        ``max_seq_len``, ``min_seq_len``, ``src_lang`` and ``dst_lang``
    """
    global options
    parser = OptionParser()
    parser.add_option("--src", dest="src_data_path", help="Path to the source txt file", metavar="FILE", default=None)
    parser.add_option("--dst", dest="dst_data_path", help="Path to the target txt file", metavar="FILE", default=None)
    parser.add_option("--output", dest="output_path", help="Output marshal file ", metavar="FILE", default=None)
    parser.add_option("--tok", dest="tokenizer_path", help="Path to the tokenizer folder", metavar="FILE", default=None)
    parser.add_option("--max_seq_len", dest="max_seq_len", help="Max sequence length", type="int", default=175)
    # The help text used to read "Max sequence length" (copy-paste slip).
    parser.add_option("--min_seq_len", dest="min_seq_len", help="Min sequence length", type="int", default=1)
    parser.add_option("--src-lang", dest="src_lang", type="str", default=None)
    parser.add_option("--dst-lang", dest="dst_lang", type="str", default=None)
    (options, args) = parser.parse_args()
    return options


if __name__ == "__main__":
    # Script entry point: parse options, load the tokenizer, write batches.
    options = get_options()
    tokenizer = TextProcessor(options.tokenizer_path)

    print(datetime.datetime.now(), "Writing batches")

    # Language codes are looked up in the tokenizer as "<code>" tokens; the
    # target language is optional.
    src_lang = tokenizer.token_id("<" + options.src_lang + ">")
    if options.dst_lang is None:
        dst_lang = None
    else:
        dst_lang = tokenizer.token_id("<" + options.dst_lang + ">")

    write(text_processor=tokenizer,
          output_file=options.output_path,
          src_txt_file=options.src_data_path,
          dst_txt_file=options.dst_data_path,
          src_lang=src_lang,
          dst_lang=dst_lang)
    print(datetime.datetime.now(), "Finished")
Ejemplo n.º 12
0
    def train(options):
        """Train a Caption2Image model with the help of a loaded captioner.

        :param options: parsed command-line options (paths, model sizes,
            optimizer and decoding settings) read throughout this function
        """
        lex_dict = (get_lex_dict(options.dict_path)
                    if options.dict_path is not None else None)
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0

        # The captioning model is loaded from disk; the caption-to-image
        # model is built from the option values.
        captioner = Seq2Seq.load(ImageCaptioning,
                                 options.pretrained_path,
                                 tok_dir=options.tokenizer_path)
        caption2image = Caption2Image(
            text_processor=text_processor,
            enc_layer=options.encoder_layer,
            embed_dim=options.embed_dim,
            intermediate_dim=options.intermediate_layer_dim)

        print("Model initialization done!")

        # The collator is assumed to return a list sized by the number of
        # GPUs; without any GPU the batch count falls back to 1.
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        optimizer = build_optimizer(caption2image,
                                    options.learning_rate,
                                    warump_steps=options.warmup)

        trainer = Caption2ImageTrainer(
            model=caption2image,
            caption_model=captioner,
            mask_prob=options.mask_prob,
            optimizer=optimizer,
            clip=options.clip,
            beam_width=options.beam_width,
            max_len_a=options.max_len_a,
            max_len_b=options.max_len_b,
            len_penalty_ratio=options.len_penalty_ratio,
            fp16=options.fp16,
            mm_mode=options.mm_mode)

        pin_memory = torch.cuda.is_available()

        def build_loader(data_path, **extra):
            # Both loaders share everything except the path and a few flags.
            return ImageMTTrainer.get_img_loader(
                collator,
                dataset.ImageCaptionDatasetwNegSamples,
                data_path,
                caption2image,
                num_batches,
                options,
                pin_memory,
                lex_dict=lex_dict,
                **extra)

        img_train_loader = build_loader(options.train_path)
        img_dev_loader = build_loader(options.dev_path, shuffle=False, denom=2)

        # Run epochs until the trainer reports that the step budget is used.
        step = 0
        train_epoch = 1
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       img_dev_data_iter=img_dev_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       saving_path=options.model_path,
                                       step=step)
            train_epoch += 1
Ejemplo n.º 13
0
    def train(options):
        """Train a masked language model (``LM`` or ``ReformerLM``).

        NOTE(review): with ``options.continue_train`` the optimizer is
        restored via ``pickle.load``, which executes arbitrary code from the
        file; only resume from trusted checkpoints.

        :param options: parsed command-line options (paths, model size,
            dropout, optimizer and batching settings) read throughout
        """
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)

        # Build a fresh model unless a pretrained one is given.
        lm_class = ReformerLM if options.reformer else LM
        if options.pretrained_path is None:
            lm = lm_class(text_processor=text_processor,
                          size=options.model_size)
        else:
            lm = lm_class.load(options.pretrained_path)

        # The Reformer config exposes its dropout rates under separate
        # names; other models go through LMTrainer.config_dropout.
        if options.reformer:
            lm.config.hidden_dropout_prob = options.dropout
            lm.config.local_attention_probs_dropout_prob = options.dropout
            lm.config.lsh_attention_probs_dropout_prob = options.dropout
        else:
            LMTrainer.config_dropout(lm, options.dropout)

        train_data = dataset.TextDataset(save_cache_dir=options.train_path,
                                         max_cache_size=options.cache_size)
        dev_data = dataset.TextDataset(save_cache_dir=options.dev_path,
                                       max_cache_size=options.cache_size,
                                       load_all=True)

        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(lm, options.learning_rate,
                                        options.warmup)

        trainer = LMTrainer(model=lm,
                            mask_prob=options.mask_prob,
                            optimizer=optimizer,
                            clip=options.clip)

        collator = dataset.TextCollator(pad_idx=text_processor.pad_token_id())
        train_sampler, dev_sampler = None, None

        pin_memory = torch.cuda.is_available()
        loader = data_utils.DataLoader(train_data,
                                       batch_size=options.batch,
                                       shuffle=False,
                                       pin_memory=pin_memory,
                                       collate_fn=collator,
                                       sampler=train_sampler)
        dev_loader = data_utils.DataLoader(dev_data,
                                           batch_size=options.batch,
                                           shuffle=False,
                                           pin_memory=pin_memory,
                                           collate_fn=collator,
                                           sampler=dev_sampler)

        step, train_epoch = 0, 1
        while step <= options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(data_iter=loader,
                                       dev_data_iter=dev_loader,
                                       saving_path=options.model_path,
                                       step=step)
            # Advance the epoch counter so the progress message does not
            # report "train epoch 1" on every pass.
            train_epoch += 1
Ejemplo n.º 14
0
import time;
from textprocessor import TextProcessor
from bayes_classifier import BayesClassifier

# Name handed to TextProcessor.create_dictionary as `form_file`.
form_file = 'formy'
# Text names handed to TextProcessor.improve_dictionary as `polish_texts`.
polish_texts = ['dramat', 'popul', 'proza', 'publ', 'wp']
# Data directory passed to both dictionary-building calls.
path_to_files = './data/'
# NOTE(review): not referenced in the visible part of this script.
show_simmilar = True;
# Upper slice bound used when printing the additional suggestions.
num_of_simmilar = 4

if __name__ == '__main__':
    textprocessor = TextProcessor()
    textprocessor.create_dictionary(path_to_file = path_to_files, form_file=form_file)
    textprocessor.improve_dictionary(path_to_files = path_to_files, polish_texts=polish_texts)
    dict_of_words = textprocessor.dict_of_words;

    input_word = input('Napisz pojedyncze slowo:\n')
    start = time.time();
    input_word = textprocessor.map_chars(input_word)
    if not input_word in dict_of_words:
        bayes_classifier = BayesClassifier()
        simmilar_words = bayes_classifier.calculate(f'{input_word}', dict_of_words)
        unmapped_words = []
        for word in simmilar_words:
            unmapped_words.append(textprocessor.unmap_words(word))
        print(f'Slowo nie wystepuje w polskim jezyku.')
        print(f'Moze chodzilo o \'{unmapped_words[0]}\'?')
        show_hints = input('Chcesz zobaczyc inne mozliwosci? t/n\n')
        if(show_hints == 't'):
            print(f'Inne możliwosci {unmapped_words[1:num_of_simmilar]}\n')
    else:
Ejemplo n.º 15
0
    parser.add_option("--tok",
                      dest="tok",
                      help="Path to the tokenizer folder",
                      metavar="FILE",
                      default=None)
    (options, args) = parser.parse_args()
    return options


options = get_options()

# File locations supplied on the command line.
src_file, dst_file = options.src, options.dst
align_file, output_file = options.align, options.output
tokenizer: TextProcessor = TextProcessor(options.tok)

# Accumulators: a nested dict per word, and an integer count per word.
word_translation, word_counter = defaultdict(dict), defaultdict(int)

with open(src_file, "r") as sr, open(dst_file,
                                     "r") as dr, open(align_file, "r") as ar:
    for i, (src_line, dst_line, align_line) in enumerate(zip(sr, dr, ar)):
        src_words = src_line.strip().split(" ")
        dst_words = dst_line.strip().split(" ")
        alignments = filter(
            lambda x: x != None,
            map(
                lambda a: (int(a.split("-")[0]), int(a.split("-")[1]))
                if "-" in a else None,
                align_line.strip().split(" ")))
Ejemplo n.º 16
0
    def train(options):
        """Train an image-captioning model, optionally with MT multi-tasking.

        The function builds or loads the captioning model, optionally swaps
        in the encoder/decoder/output layer of a pretrained MT model, builds
        the optimizer, trainer and data loaders, collects the dev-set
        reference captions, and then runs training epochs until
        ``options.step`` steps have been reached.

        NOTE(review): with ``options.continue_train`` the optimizer is
        restored via ``pickle.load``, which executes arbitrary code from the
        file; only resume from trusted checkpoints.

        :param options: parsed command-line options read throughout
        """
        # Optional lexical dictionary; it also switches on `use_proposals`.
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0

        # Either resume from a saved captioning model or build a new one.
        if options.pretrained_path is not None:
            caption_model = Seq2Seq.load(ImageCaptioning,
                                         options.pretrained_path,
                                         tok_dir=options.tokenizer_path)
        else:
            caption_model = ImageCaptioning(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim,
                use_obj=not options.no_obj)

        if options.lm_path is not None:  # In our case, this is an MT model.
            mt_pret_model = Seq2Seq.load(ImageMassSeq2Seq,
                                         options.lm_path,
                                         tok_dir=options.tokenizer_path)
            # Layer counts must match before the modules are swapped in.
            assert len(caption_model.encoder.encoder.layer) == len(
                mt_pret_model.encoder.encoder.layer)
            assert len(caption_model.decoder.decoder.layer) == len(
                mt_pret_model.decoder.decoder.layer)
            caption_model.encoder = mt_pret_model.encoder
            caption_model.decoder = mt_pret_model.decoder
            caption_model.output_layer = mt_pret_model.output_layer

        print("Model initialization done!")

        # The collator is assumed to return a list sized by the number of
        # GPUs; without any GPU the batch count falls back to 1.
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        # Restore the pickled optimizer when continuing, else build a new one.
        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(caption_model,
                                        options.learning_rate,
                                        warump_steps=options.warmup)
        trainer = ImageCaptionTrainer(
            model=caption_model,
            mask_prob=options.mask_prob,
            optimizer=optimizer,
            clip=options.clip,
            beam_width=options.beam_width,
            max_len_a=options.max_len_a,
            max_len_b=options.max_len_b,
            len_penalty_ratio=options.len_penalty_ratio,
            fp16=options.fp16,
            mm_mode=options.mm_mode)

        pin_memory = torch.cuda.is_available()
        # Shuffling is only requested when local_rank is negative
        # (presumably the non-distributed case -- TODO confirm).
        img_train_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.train_path,
            caption_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict,
            shuffle=(options.local_rank < 0))
        num_processors = max(torch.cuda.device_count(),
                             1) if options.local_rank < 0 else 1
        # Optional MT training data for the multi-task objective.
        mt_train_loader = None
        if options.mt_train_path is not None:
            mt_train_loader = ImageMTTrainer.get_mt_train_data(
                caption_model,
                num_processors,
                options,
                pin_memory,
                lex_dict=lex_dict)

        img_dev_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionTestDataset,
            options.dev_path,
            caption_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict,
            shuffle=False,
            denom=2)

        # Collect the decoded reference captions of the dev set, keyed by
        # the ids found in each batch item's "captions" mapping.
        trainer.caption_reference = None
        if img_dev_loader is not None:
            trainer.caption_reference = defaultdict(list)
            # Use the wrapped object when the generator exposes `.module`.
            generator = (trainer.generator.module if hasattr(
                trainer.generator, "module") else trainer.generator)
            for data in img_dev_loader:
                for batch in data:
                    for b in batch:
                        captions = b["captions"]
                        for id in captions:
                            for caption in captions[id]:
                                # Presumably trims each caption at the
                                # separator token -- verify in the helper.
                                refs = get_outputs_until_eos(
                                    text_processor.sep_token_id(),
                                    caption,
                                    remove_first_token=True)
                                ref = [
                                    generator.seq2seq_model.text_processor.
                                    tokenizer.decode(ref.numpy())
                                    for ref in refs
                                ]
                                trainer.caption_reference[id] += ref
            print("Number of dev image/captions",
                  len(trainer.caption_reference))

        # Optional MT dev data; the trainer is passed in and
        # `trainer.reference` is read right after the call.
        mt_dev_loader = None
        if options.mt_dev_path is not None:
            mt_dev_loader = ImageMTTrainer.get_mt_dev_data(caption_model,
                                                           options,
                                                           pin_memory,
                                                           text_processor,
                                                           trainer,
                                                           lex_dict=lex_dict)
            print("Number of dev sentences", len(trainer.reference))

        # Run epochs until the trainer reports that the step budget is used.
        step, train_epoch = 0, 1
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       img_dev_data_iter=img_dev_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       mt_train_iter=mt_train_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       accum=options.accum,
                                       mt_dev_iter=mt_dev_loader,
                                       mtl_weight=options.mtl_weight)
            train_epoch += 1
Ejemplo n.º 17
0
 def __init__(self):
     """Create the text processor and set the default weights."""
     self.textprocessor = TextProcessor()
     self.low_weight = 0.001
     # The three edit-type weights share one default value.
     self.missclick_weight = self.insert_weight = self.delete_weight = 0.1
Ejemplo n.º 18
0
    def train(options):
        """Train a SenSim model on MT data with source/target negatives.

        :param options: parsed command-line options (paths, model sizes,
            optimizer and batching settings) read throughout this function
        """
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0
        num_processors = max(torch.cuda.device_count(), 1)

        sen_sim = SenSim(text_processor=text_processor,
                         enc_layer=options.encoder_layer,
                         embed_dim=options.embed_dim,
                         intermediate_dim=options.intermediate_layer_dim)

        # Optionally initialize from a pretrained Seq2Seq model.
        if options.pretrained_path is not None:
            pretrained = Seq2Seq.load(Seq2Seq,
                                      options.pretrained_path,
                                      tok_dir=options.tokenizer_path)
            sen_sim.init_from_lm(pretrained)

        print("Model initialization done!")

        optimizer = build_optimizer(sen_sim,
                                    options.learning_rate,
                                    warump_steps=options.warmup)
        trainer = SenSimTrainer(model=sen_sim,
                                mask_prob=options.mask_prob,
                                optimizer=optimizer,
                                clip=options.clip,
                                fp16=options.fp16)

        pin_memory = torch.cuda.is_available()

        mt_train_loader = SenSimTrainer.get_mt_train_data(
            sen_sim, num_processors, options, pin_memory)

        def negative_dataset(pickle_dir):
            # Source and target negatives differ only in their directory.
            return dataset.MassDataset(
                batch_pickle_dir=pickle_dir,
                max_batch_capacity=num_processors * options.total_capacity * 5,
                max_batch=num_processors * options.batch * 5,
                pad_idx=sen_sim.text_processor.pad_token_id(),
                keep_pad_idx=False,
                max_seq_len=options.max_seq_len,
                keep_examples=False)

        def shuffled_loader(data):
            return data_utils.DataLoader(data,
                                         batch_size=1,
                                         shuffle=True,
                                         pin_memory=pin_memory)

        # Both datasets are built before either loader, as before.
        src_neg_data = negative_dataset(options.src_neg)
        dst_neg_data = negative_dataset(options.dst_neg)
        src_neg_loader = shuffled_loader(src_neg_data)
        dst_neg_loader = shuffled_loader(dst_neg_data)

        if options.mt_dev_path is not None:
            mt_dev_loader = SenSimTrainer.get_mt_dev_data(
                sen_sim, options, pin_memory, text_processor, trainer)
        else:
            mt_dev_loader = None

        # Run epochs until the trainer reports that the step budget is used.
        step = 0
        train_epoch = 1
        trainer.best_loss = 1000000
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(mt_train_iter=mt_train_loader,
                                       max_step=options.step,
                                       mt_dev_iter=mt_dev_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       src_neg_iter=src_neg_loader,
                                       dst_neg_iter=dst_neg_loader)
            train_epoch += 1