Example #1
    def __init__(self,
                 model,
                 mask_prob: float = 0.15,
                 clip: int = 1,
                 optimizer=None):
        self.model = model
        self.clip = clip
        self.optimizer = optimizer

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        self.mask_prob = mask_prob
        self.criterion = nn.NLLLoss(
            ignore_index=model.text_processor.pad_token_id())

        num_gpu = torch.cuda.device_count()
        if num_gpu > 1:
            print("Let's use", num_gpu, "GPUs!")
            self.model = DataParallelModel(self.model)
            self.criterion = DataParallelCriterion(self.criterion)

        self.best_dev_loss = float("inf")
        self.best_train_loss = float("inf")
        self.last_train_loss = float("inf")
Example #2
    def __init__(self,
                 model,
                 mask_prob: float = 0.3,
                 clip: int = 1,
                 optimizer=None,
                 beam_width: int = 5,
                 max_len_a: float = 1.1,
                 max_len_b: int = 5,
                 len_penalty_ratio: float = 0.8,
                 nll_loss: bool = False,
                 fp16: bool = False,
                 mm_mode="mixed"):
        self.model = model

        self.clip = clip
        self.optimizer = optimizer

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.num_gpu = torch.cuda.device_count()

        self.mask_prob = mask_prob
        if nll_loss:
            self.criterion = nn.NLLLoss(
                ignore_index=model.text_processor.pad_token_id())
        else:
            self.criterion = SmoothedNLLLoss(
                ignore_index=model.text_processor.pad_token_id())

        self.fp16 = False
        if self.num_gpu == 1 and fp16:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O2")
            self.fp16 = True
        self.generator = BeamDecoder(self.model,
                                     beam_width=beam_width,
                                     max_len_a=max_len_a,
                                     max_len_b=max_len_b,
                                     len_penalty_ratio=len_penalty_ratio)
        if self.num_gpu > 1:
            print("Let's use", self.num_gpu, "GPUs!")
            self.model = DataParallelModel(self.model)
            self.criterion = DataParallelCriterion(self.criterion)
            self.generator = DataParallelModel(self.generator)

        self.reference = None
        self.best_bleu = -1.0
        self.mm_mode = mm_mode
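
The decoder bounds above follow the usual linear rule: output length for a source of length n is capped at max_len_a * n + max_len_b tokens, and finished hypotheses are rescored with a length penalty. A toy sketch of both (the GNMT-style penalty is an assumption; BeamDecoder may normalize differently):

def max_output_len(src_len, max_len_a=1.1, max_len_b=5):
    # Cap on decoded length as a linear function of source length.
    return int(max_len_a * src_len + max_len_b)

def length_penalty(length, ratio=0.8):
    # GNMT-style normalizer; longer hypotheses are divided by a larger value.
    return ((5.0 + length) / 6.0) ** ratio

print(max_output_len(20))    # 27
print(length_penalty(27))    # ~3.8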
Example #3
def build_model(options):
    model = Caption2Image.load(options.model_path, options.tokenizer_path)
    caption_model = Seq2Seq.load(ImageCaptioning,
                                 options.caption_model_path,
                                 tok_dir=options.tokenizer_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    caption_model = caption_model.to(device)
    num_gpu = torch.cuda.device_count()
    generator = BeamDecoder(caption_model,
                            beam_width=options.beam_width,
                            max_len_a=options.max_len_a,
                            max_len_b=options.max_len_b,
                            len_penalty_ratio=options.len_penalty_ratio)
    if options.fp16:
        model = amp.initialize(model, opt_level="O2")
        generator = amp.initialize(generator, opt_level="O2")
    text_processor = model.text_processor  # grab before any DataParallel wrap
    if num_gpu > 1:
        model = DataParallelModel(model)
        generator = DataParallelModel(generator)
    return model, generator, text_processor
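
A hypothetical call site for build_model; the real script fills these options with argparse, and every value below is a placeholder:

from types import SimpleNamespace

options = SimpleNamespace(model_path="c2i_model",
                          caption_model_path="caption_model",
                          tokenizer_path="tokenizer",
                          beam_width=5, max_len_a=1.1, max_len_b=5,
                          len_penalty_ratio=0.8, fp16=False)
model, generator, text_processor = build_model(options)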
Example #4
    def __init__(self,
                 model,
                 caption_model,
                 mask_prob: float = 0.3,
                 clip: int = 1,
                 optimizer=None,
                 beam_width: int = 5,
                 max_len_a: float = 1.1,
                 max_len_b: int = 5,
                 len_penalty_ratio: float = 0.8,
                 nll_loss: bool = False,
                 fp16: bool = False,
                 mm_mode="mixed"):
        super().__init__(model, mask_prob, clip, optimizer, beam_width,
                         max_len_a, max_len_b, len_penalty_ratio, nll_loss,
                         fp16, mm_mode)
        self.caption_model = caption_model
        self.caption_model.eval()
        self.caption_model = self.caption_model.to(self.device)

        if self.num_gpu == 1 and fp16:
            self.caption_model = amp.initialize(self.caption_model, opt_level="O2")

        if self.num_gpu > 1:
            print("Let's use", self.num_gpu, "GPUs!")
            self.caption_model = DataParallelModel(self.caption_model)
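
The subclass keeps the caption model frozen: eval() disables dropout, and its forward passes should run under torch.no_grad(). A self-contained sketch of the same pattern, with an nn.Linear standing in for the caption model:

import torch
import torch.nn as nn

aux = nn.Linear(4, 4)    # stand-in for the frozen caption model
aux.eval()               # turn off dropout / batch-norm updates
with torch.no_grad():    # no gradient tracking for the frozen scorer
    scores = aux(torch.randn(2, 4))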
Example #5
def build_model(options):
    model = Seq2Seq.load(Seq2Seq,
                         options.model_path,
                         tok_dir=options.tokenizer_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    num_gpu = torch.cuda.device_count()
    generator = BeamDecoder(model,
                            beam_width=options.beam_width,
                            max_len_a=options.max_len_a,
                            max_len_b=options.max_len_b,
                            len_penalty_ratio=options.len_penalty_ratio)
    if options.fp16 and torch.cuda.is_available():
        from apex import amp
        generator = amp.initialize(generator, opt_level="O2")
    if num_gpu > 1:
        generator = DataParallelModel(generator)
    return generator, model.text_processor
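
Callers must remember that generator may come back wrapped in DataParallelModel. The unwrap idiom used throughout these trainers (Example #6 applies it before reading generator.seq2seq_model):

generator = (generator.module
             if hasattr(generator, "module") else generator)
tokenizer = generator.seq2seq_model.text_processor.tokenizer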
Example #6
class ImageMTTrainer:
    def __init__(self,
                 model,
                 mask_prob: float = 0.3,
                 clip: int = 1,
                 optimizer=None,
                 beam_width: int = 5,
                 max_len_a: float = 1.1,
                 max_len_b: int = 5,
                 len_penalty_ratio: float = 0.8,
                 nll_loss: bool = False,
                 fp16: bool = False,
                 mm_mode="mixed"):
        self.model = model

        self.clip = clip
        self.optimizer = optimizer

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.num_gpu = torch.cuda.device_count()

        self.mask_prob = mask_prob
        if nll_loss:
            self.criterion = nn.NLLLoss(
                ignore_index=model.text_processor.pad_token_id())
        else:
            self.criterion = SmoothedNLLLoss(
                ignore_index=model.text_processor.pad_token_id())

        self.fp16 = False
        if self.num_gpu == 1 and fp16:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O2")
            self.fp16 = True
        self.generator = BeamDecoder(self.model,
                                     beam_width=beam_width,
                                     max_len_a=max_len_a,
                                     max_len_b=max_len_b,
                                     len_penalty_ratio=len_penalty_ratio)
        if self.num_gpu > 1:
            print("Let's use", self.num_gpu, "GPUs!")
            self.model = DataParallelModel(self.model)
            self.criterion = DataParallelCriterion(self.criterion)
            self.generator = DataParallelModel(self.generator)

        self.reference = None
        self.best_bleu = -1.0
        self.mm_mode = mm_mode

    def train_epoch(self,
                    img_data_iter: List[data_utils.DataLoader] = None,
                    step: int = 10,
                    saving_path: str = None,
                    mass_data_iter: List[data_utils.DataLoader] = None,
                    mt_dev_iter: List[data_utils.DataLoader] = None,
                    mt_train_iter: List[data_utils.DataLoader] = None,
                    max_step: int = 300000,
                    fine_tune: bool = False,
                    lang_directions: dict = None,
                    lex_dict=None,
                    save_opt: bool = False,
                    **kwargs):
        "Standard Training and Logging Function"
        start = time.time()
        total_tokens, total_loss, tokens, cur_loss = 0, 0, 0, 0
        cur_loss = 0
        batch_zip, shortest = self.get_batch_zip(img_data_iter, mass_data_iter,
                                                 mt_train_iter)

        model = (self.model.module
                 if hasattr(self.model, "module") else self.model)
        for i, batches in enumerate(batch_zip):
            for batch in batches:
                self.optimizer.zero_grad()
                try:
                    src_inputs = batch["src_texts"].squeeze(0)
                    src_mask = batch["src_pad_mask"].squeeze(0)
                    tgt_inputs = batch["dst_texts"].squeeze(0)
                    tgt_mask = batch["dst_pad_mask"].squeeze(0)
                    src_langs = batch["src_langs"].squeeze(0)
                    dst_langs = batch["dst_langs"].squeeze(0)
                    proposals = batch["proposal"].squeeze(
                        0) if lex_dict is not None else None
                    if src_inputs.size(0) < self.num_gpu:
                        continue
                    predictions = self.model(
                        src_inputs=src_inputs,
                        tgt_inputs=tgt_inputs,
                        src_pads=src_mask,
                        tgt_mask=tgt_mask,
                        src_langs=src_langs,
                        tgt_langs=dst_langs,
                        proposals=proposals,
                        pad_idx=model.text_processor.pad_token_id(),
                        log_softmax=True)
                    targets = tgt_inputs[:, 1:].contiguous().view(-1)
                    tgt_mask_flat = tgt_mask[:, 1:].contiguous().view(-1)
                    targets = targets[tgt_mask_flat]
                    ntokens = targets.size(0)

                    if self.num_gpu == 1:
                        targets = targets.to(predictions.device)

                    loss = self.criterion(predictions, targets).mean()
                    backward(loss, self.optimizer, self.fp16)

                    loss = float(loss.data) * ntokens
                    tokens += ntokens
                    total_tokens += ntokens
                    total_loss += loss
                    cur_loss += loss

                    # We accumulate the gradients for both tasks!
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.clip)
                    self.optimizer.step()
                    step += 1

                    if step % 50 == 0 and tokens > 0:
                        elapsed = time.time() - start
                        print(
                            datetime.datetime.now(),
                            "Epoch Step: %d Loss: %f Tokens per Sec: %f " %
                            (step, cur_loss / tokens, tokens / elapsed))

                        if step % 500 == 0:
                            if mt_dev_iter is not None and step % 5000 == 0:
                                bleu = self.eval_bleu(mt_dev_iter, saving_path)
                                print("BLEU:", bleu)

                            model.save(saving_path)
                            if save_opt:
                                with open(os.path.join(saving_path, "optim"),
                                          "wb") as fp:
                                    pickle.dump(self.optimizer, fp)

                        start, tokens, cur_loss = time.time(), 0, 0

                except RuntimeError as err:
                    print(repr(err))
                    torch.cuda.empty_cache()

            if i == shortest - 1:
                break
            if step >= max_step:
                break

        try:
            print("Total loss in this epoch: %f" % (total_loss / total_tokens))
            model.save(saving_path)

            if mt_dev_iter is not None:
                bleu = self.eval_bleu(mt_dev_iter, saving_path)
                print("BLEU:", bleu)
        except RuntimeError as err:
            print(repr(err))

        return step

    def get_batch_zip(self, img_data_iter, mass_data_iter, mt_train_iter):
        # if img_data_iter is not None and mt_train_iter is not None:
        #     img_data_iter *= 5
        # if mass_data_iter is not None and mt_train_iter is not None:
        #     mass_data_iter *= 5
        iters = list(
            chain(*filter(lambda x: x is not None,
                          [img_data_iter, mass_data_iter, mt_train_iter])))
        shortest = min(len(l) for l in iters)
        return zip(*iters), shortest

    def eval_bleu(self, dev_data_iter, saving_path, save_opt: bool = False):
        mt_output = []
        src_text = []
        model = (self.model.module
                 if hasattr(self.model, "module") else self.model)
        model.eval()

        with torch.no_grad():
            for loader in dev_data_iter:
                for batch in loader:
                    src_inputs = batch["src_texts"].squeeze(0)
                    src_mask = batch["src_pad_mask"].squeeze(0)
                    tgt_inputs = batch["dst_texts"].squeeze(0)
                    src_langs = batch["src_langs"].squeeze(0)
                    dst_langs = batch["dst_langs"].squeeze(0)
                    src_pad_idx = batch["pad_idx"].squeeze(0)
                    proposal = batch["proposal"].squeeze(
                        0) if batch["proposal"] is not None else None

                    src_ids = get_outputs_until_eos(
                        model.text_processor.sep_token_id(),
                        src_inputs,
                        remove_first_token=True)
                    src_text += [
                        model.text_processor.tokenizer.decode(src.numpy())
                        for src in src_ids
                    ]

                    outputs = self.generator(
                        src_inputs=src_inputs,
                        src_sizes=src_pad_idx,
                        first_tokens=tgt_inputs[:, 0],
                        src_mask=src_mask,
                        src_langs=src_langs,
                        tgt_langs=dst_langs,
                        pad_idx=model.text_processor.pad_token_id(),
                        proposals=proposal)
                    if self.num_gpu > 1:
                        new_outputs = []
                        for output in outputs:
                            new_outputs += output
                        outputs = new_outputs

                    mt_output += [
                        model.text_processor.tokenizer.decode(x[1:].numpy())
                        for x in outputs
                    ]

            model.train()
        bleu = sacrebleu.corpus_bleu(mt_output,
                                     [self.reference[:len(mt_output)]],
                                     lowercase=True,
                                     tokenize="intl")

        with open(os.path.join(saving_path, "bleu.output"), "w") as writer:
            writer.write("\n".join([
                src + "\n" + ref + "\n" + o + "\n\n***************\n"
                for src, ref, o in zip(src_text, mt_output,
                                       self.reference[:len(mt_output)])
            ]))

        if bleu.score > self.best_bleu:
            self.best_bleu = bleu.score
            print("Saving best BLEU", self.best_bleu)
            with open(os.path.join(saving_path, "bleu.best.output"),
                      "w") as writer:
                writer.write("\n".join([
                    src + "\n" + ref + "\n" + o + "\n\n***************\n"
                    for src, ref, o in zip(src_text, mt_output,
                                           self.reference[:len(mt_output)])
                ]))

            model.save(saving_path)
            if save_opt:
                with open(os.path.join(saving_path, "optim"), "wb") as fp:
                    pickle.dump(self.optimizer, fp)

        return bleu.score

    @staticmethod
    def train(options):
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0
        num_processors = max(torch.cuda.device_count(), 1)

        if options.pretrained_path is not None:
            print("Loading pretrained path", options.pretrained_path)
            mt_model = Seq2Seq.load(ImageMassSeq2Seq,
                                    options.pretrained_path,
                                    tok_dir=options.tokenizer_path)
        else:
            mt_model = ImageMassSeq2Seq(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)

        if options.lm_path is not None:
            lm = LM(text_processor=text_processor,
                    enc_layer=options.encoder_layer,
                    embed_dim=options.embed_dim,
                    intermediate_dim=options.intermediate_layer_dim)
            mt_model.init_from_lm(lm)

        print("Model initialization done!")

        # We assume that the collator function returns a list with the size of
        # the number of GPUs (in the CPU case, a list of size one).
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(mt_model, options.learning_rate,
                                        options.warmup)
        trainer = ImageMTTrainer(model=mt_model,
                                 mask_prob=options.mask_prob,
                                 optimizer=optimizer,
                                 clip=options.clip,
                                 beam_width=options.beam_width,
                                 max_len_a=options.max_len_a,
                                 max_len_b=options.max_len_b,
                                 len_penalty_ratio=options.len_penalty_ratio,
                                 fp16=options.fp16,
                                 mm_mode=options.mm_mode)

        pin_memory = torch.cuda.is_available()

        mt_train_loader = None
        if options.mt_train_path is not None:
            mt_train_loader = ImageMTTrainer.get_mt_train_data(
                mt_model,
                num_processors,
                options,
                pin_memory,
                lex_dict=lex_dict)

        mt_dev_loader = None
        if options.mt_dev_path is not None:
            mt_dev_loader = ImageMTTrainer.get_mt_dev_data(mt_model,
                                                           options,
                                                           pin_memory,
                                                           text_processor,
                                                           trainer,
                                                           lex_dict=lex_dict)

        step, train_epoch = 0, 1
        while options.step > 0 and step < options.step and train_epoch <= 10:
            print("train epoch", train_epoch, "step:", step)
            step = trainer.train_epoch(mt_train_iter=mt_train_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       mt_dev_iter=mt_dev_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       save_opt=False)
            train_epoch += 1

    @staticmethod
    def get_mt_dev_data(mt_model,
                        options,
                        pin_memory,
                        text_processor,
                        trainer,
                        lex_dict=None):
        mt_dev_loader = []
        dev_paths = options.mt_dev_path.split(",")
        trainer.reference = []
        for dev_path in dev_paths:
            mt_dev_data = dataset.MTDataset(
                batch_pickle_dir=dev_path,
                max_batch_capacity=options.total_capacity,
                keep_pad_idx=True,
                max_batch=int(options.batch / (options.beam_width * 2)),
                pad_idx=mt_model.text_processor.pad_token_id(),
                lex_dict=lex_dict)
            dl = data_utils.DataLoader(mt_dev_data,
                                       batch_size=1,
                                       shuffle=False,
                                       pin_memory=pin_memory)
            mt_dev_loader.append(dl)

            print("creating reference")

            generator = (trainer.generator.module if hasattr(
                trainer.generator, "module") else trainer.generator)

            for batch in dl:
                tgt_inputs = batch["dst_texts"].squeeze()
                refs = get_outputs_until_eos(text_processor.sep_token_id(),
                                             tgt_inputs,
                                             remove_first_token=True)
                ref = [
                    generator.seq2seq_model.text_processor.tokenizer.decode(
                        ref.numpy()) for ref in refs
                ]
                trainer.reference += ref
        return mt_dev_loader

    @staticmethod
    def get_mt_train_data(mt_model,
                          num_processors,
                          options,
                          pin_memory,
                          lex_dict=None):
        mt_train_loader = []
        train_paths = options.mt_train_path.split(",")
        for train_path in train_paths:
            mt_train_data = dataset.MTDataset(
                batch_pickle_dir=train_path,
                max_batch_capacity=int(num_processors *
                                       options.total_capacity / 2),
                max_batch=int(num_processors * options.batch / 2),
                pad_idx=mt_model.text_processor.pad_token_id(),
                lex_dict=lex_dict,
                keep_pad_idx=False)
            mtl = data_utils.DataLoader(mt_train_data,
                                        batch_size=1,
                                        shuffle=True,
                                        pin_memory=pin_memory)
            mt_train_loader.append(mtl)
        return mt_train_loader
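
get_batch_zip drives the multi-task loop: every non-empty loader list is chained into one flat list, zip(*iters) yields one batch per loader per step, and the epoch stops once the shortest loader is exhausted. A toy run of the same mechanics with plain lists standing in for DataLoaders:

from itertools import chain

img_iters = [[{"task": "img", "i": i} for i in range(3)]]  # 3 batches
mt_iters = [[{"task": "mt", "i": i} for i in range(5)]]    # 5 batches
iters = list(chain(*filter(lambda x: x is not None,
                           [img_iters, None, mt_iters])))
shortest = min(len(l) for l in iters)                      # 3
for i, batches in enumerate(zip(*iters)):
    print(i, [b["task"] for b in batches])                 # one batch per task
    if i == shortest - 1:
        break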
Example #7
class LMTrainer:
    def __init__(self,
                 model,
                 mask_prob: float = 0.15,
                 clip: int = 1,
                 optimizer=None):
        self.model = model
        self.clip = clip
        self.optimizer = optimizer

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        self.mask_prob = mask_prob
        self.criterion = nn.NLLLoss(
            ignore_index=model.text_processor.pad_token_id())

        num_gpu = torch.cuda.device_count()
        if num_gpu > 1:
            print("Let's use", num_gpu, "GPUs!")
            self.model = DataParallelModel(self.model)
            self.criterion = DataParallelCriterion(self.criterion)

        self.best_dev_loss = float("inf")
        self.best_train_loss = float("inf")
        self.last_train_loss = float("inf")

    def train_epoch(self, data_iter: data_utils.DataLoader,
                    dev_data_iter: data_utils.DataLoader, saving_path: str,
                    step: int):
        "Standard Training and Logging Function"
        start = time.time()
        total_tokens, total_loss, tokens, cur_loss = 0, 0, 0, 0
        cur_loss = 0
        model = self.model.module if hasattr(self.model,
                                             "module") else self.model

        for i, batch in enumerate(data_iter):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            mask, target, texts = mask_text(self.mask_prob, batch["pad_mask"],
                                            batch["texts"],
                                            model.text_processor)
            try:
                predictions = self.model(mask=mask,
                                         texts=texts,
                                         pads=batch["pad_mask"],
                                         langs=batch["langs"])
                ntokens = target.size(0)

                if ntokens == 0:  # Nothing to predict!
                    continue

                loss = self.criterion(predictions, target).mean()
                loss.backward()

                unmask_text(mask, target, texts)

                if self.optimizer is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.clip)

                    self.optimizer.step()
                    step += 1

                loss = float(loss.data) * ntokens
                total_loss += loss
                cur_loss += loss
                total_tokens += ntokens
                tokens += ntokens

                if step % 50 == 0:
                    elapsed = time.time() - start
                    print(
                        datetime.datetime.now(),
                        "Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                        (step, cur_loss / tokens, tokens / elapsed))

                    if step % 500 == 0:
                        self.validate_and_save(saving_path, dev_data_iter)

                    start, tokens, cur_loss = time.time(), 0, 0
            except RuntimeError as err:
                print("Problem with batch item", texts.size(), repr(err))
                torch.cuda.empty_cache()

        current_loss = total_loss / total_tokens
        print("Total loss in this epoch: %f" % current_loss)
        if current_loss < self.best_train_loss:
            self.best_train_loss = current_loss
            model_to_save = (self.model.module if hasattr(
                self.model, "module") else self.model)
            model_to_save.save(saving_path + ".latest")
            with open(os.path.join(saving_path + ".latest", "optim"),
                      "wb") as fp:
                pickle.dump(self.optimizer, fp)
        self.last_train_loss = current_loss

        self.validate_and_save(saving_path, dev_data_iter)
        return step

    def validate_and_save(self, saving_path, dev_data_iter):
        with torch.no_grad():
            model = self.model.module if hasattr(self.model,
                                                 "module") else self.model
            model.eval()
            total_dev_loss, total_dev_tokens = 0, 0
            for batch in dev_data_iter:
                mask, target, texts = mask_text(self.mask_prob,
                                                batch["pad_mask"],
                                                batch["texts"].clone(),
                                                model.text_processor)
                predictions = self.model(mask=mask,
                                         texts=texts,
                                         pads=batch["pad_mask"],
                                         langs=batch["langs"])
                ntokens = target.size(0)

                if ntokens == 0:  # Nothing to predict!
                    continue
                loss = self.criterion(predictions,
                                      target).mean().data * ntokens
                total_dev_loss += float(loss)
                total_dev_tokens += ntokens

            dev_loss = total_dev_loss / total_dev_tokens
            print("Current dev loss", dev_loss)
            if self.best_dev_loss > float(dev_loss):
                self.best_dev_loss = float(dev_loss)
                print("saving best dev loss", self.best_dev_loss)
                model_to_save = (self.model.module if hasattr(
                    self.model, "module") else self.model)
                model_to_save.save(saving_path)
                with open(os.path.join(saving_path, "optim"), "wb") as fp:
                    pickle.dump(self.optimizer, fp)
            model.train()

    @staticmethod
    def config_dropout(model, dropout):
        model.encoder.config.hidden_dropout_prob = dropout
        model.encoder.config.attention_probs_dropout_prob = dropout

    @staticmethod
    def train(options):
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)

        lm_class = ReformerLM if options.reformer else LM
        if options.pretrained_path is None:
            lm = lm_class(text_processor=text_processor,
                          size=options.model_size)
        else:
            lm = lm_class.load(options.pretrained_path)

        if options.reformer:
            lm.config.hidden_dropout_prob = options.dropout
            lm.config.local_attention_probs_dropout_prob = options.dropout
            lm.config.lsh_attention_probs_dropout_prob = options.dropout
        else:
            LMTrainer.config_dropout(lm, options.dropout)

        train_data = dataset.TextDataset(save_cache_dir=options.train_path,
                                         max_cache_size=options.cache_size)
        dev_data = dataset.TextDataset(save_cache_dir=options.dev_path,
                                       max_cache_size=options.cache_size,
                                       load_all=True)

        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(lm, options.learning_rate,
                                        options.warmup)

        trainer = LMTrainer(model=lm,
                            mask_prob=options.mask_prob,
                            optimizer=optimizer,
                            clip=options.clip)

        collator = dataset.TextCollator(pad_idx=text_processor.pad_token_id())
        train_sampler, dev_sampler = None, None

        pin_memory = torch.cuda.is_available()
        loader = data_utils.DataLoader(train_data,
                                       batch_size=options.batch,
                                       shuffle=False,
                                       pin_memory=pin_memory,
                                       collate_fn=collator,
                                       sampler=train_sampler)
        dev_loader = data_utils.DataLoader(dev_data,
                                           batch_size=options.batch,
                                           shuffle=False,
                                           pin_memory=pin_memory,
                                           collate_fn=collator,
                                           sampler=dev_sampler)

        step, train_epoch = 0, 1
        while step <= options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(data_iter=loader,
                                       dev_data_iter=dev_loader,
                                       saving_path=options.model_path,
                                       step=step)
            train_epoch += 1
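
mask_text and unmask_text are this repo's helpers; a minimal BERT-style stand-in that shows the contract they appear to follow (integer token ids, a boolean pad mask, a dedicated mask id; all details here are assumptions):

import torch

def toy_mask_text(mask_prob, pad_mask, texts, mask_id):
    # Pick roughly mask_prob of the non-pad positions.
    mask = (torch.rand(texts.shape) < mask_prob) & pad_mask
    targets = texts[mask].clone()  # labels: original ids at the masked slots
    texts[mask] = mask_id          # corrupt the inputs in place
    return mask, targets, texts

def toy_unmask_text(mask, targets, texts):
    texts[mask] = targets          # restore the batch for later reuse

texts = torch.randint(5, 100, (2, 8))
pad_mask = torch.ones_like(texts, dtype=torch.bool)
mask, targets, texts = toy_mask_text(0.15, pad_mask, texts, mask_id=4)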