Example #1
    def load(self, path=ELIT_DEP_BIAFFINE_EN_MIXED, model_root=None, **kwargs):
        """Load from disk

        Parameters
        ----------
        path : str
            path to the directory which typically contains a config.json file and a model.bin file
        model_root : str, optional
            root directory for model storage, passed through to fetch_resource
        kwargs
            not used

        Returns
        -------
        DepParser
            the parser itself
        """
        path = fetch_resource(path, model_root=model_root)
        config = _Config.load_json(os.path.join(path, 'config.json'))
        config = _Config(**config)
        config.save_dir = path  # redirect root path to what user specified
        vocab = ParserVocabulary.load_json(config.save_vocab_path)
        vocab = ParserVocabulary(vocab)
        self._vocab = vocab
        with mx.Context(mxnet_prefer_gpu()):
            self._parser = BiaffineParser(vocab, config.word_dims, config.tag_dims, config.dropout_emb,
                                          config.lstm_layers,
                                          config.lstm_hiddens, config.dropout_lstm_input, config.dropout_lstm_hidden,
                                          config.mlp_arc_size,
                                          config.mlp_rel_size, config.dropout_mlp, True)
            self._parser.load(config.save_model_path)
        return self
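A minimal usage sketch for the loader above. The import path of DepParser is an assumption (only the class name appears in the docstring), and the local model directory is a placeholder.

# Hypothetical usage; the DepParser import path is assumed, not shown in the snippet above.
from elit.component import DepParser  # assumed location

parser = DepParser().load()  # load() returns the parser itself, so the call can be chained
# or point it at a locally trained directory containing config.json and model.bin:
# parser = DepParser().load('data/model/dep/my_model')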
Example #2
    def parse(self, sentence: Sequence[Tuple]) -> ConllSentence:
        """Parse raw sentence into ConllSentence

        Parameters
        ----------
        sentence : list
            a list of (word, tag) tuples

        Returns
        -------
        elit.structure.ConllSentence
            ConllSentence object
        """
        words = np.zeros((len(sentence) + 1, 1), np.int32)
        tags = np.zeros((len(sentence) + 1, 1), np.int32)
        words[0, 0] = ParserVocabulary.ROOT
        tags[0, 0] = ParserVocabulary.ROOT
        vocab = self._vocab

        for i, (word, tag) in enumerate(sentence):
            words[i + 1, 0], tags[i + 1, 0] = vocab.word2id(word.lower()), vocab.tag2id(tag)

        with mx.Context(mxnet_prefer_gpu()):
            outputs = self._parser.forward(words, tags)
        words = []
        for arc, rel, (word, tag) in zip(outputs[0][0], outputs[0][1], sentence):
            words.append(ConllWord(id=len(words) + 1, form=word, pos=tag, head=arc, relation=vocab.id2rel(rel)))
        return ConllSentence(words)
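A short usage sketch for parse(); it assumes parser is an already loaded instance, as in the sketch after Example #1.

# 'parser' is assumed to be a loaded instance (see the sketch after Example #1).
sentence = [('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'), ('?', '.')]
conll = parser.parse(sentence)   # returns an elit.structure.ConllSentence
print(conll)                     # one ConllWord per token: id, form, pos, head, relation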
Example #3
 def fill(self, path):
     super().fill(path)
     for i, second_decoder in enumerate(self.arc_biaffines):
         sd_path = os.path.join(path, 'second_decoder{}.bin'.format(i))
         if os.path.isfile(sd_path):
             second_decoder.load_parameters(sd_path, ctx=mxnet_prefer_gpu())
             freeze(second_decoder)
Example #4
    def evaluate(self, test_file, save_dir=None, logger=None, num_buckets_test=10, test_batch_size=5000):
        """Run evaluation on test set

        Parameters
        ----------
        test_file : str or Sequence
            path to test set
        save_dir : str
            where to store intermediate results and log
        logger : logging.Logger
            logger for printing results
        num_buckets_test : int
            number of clusters for sentences from test set
        test_batch_size : int
            batch size of test set

        Returns
        -------
        tuple
            (UAS, LAS, speed), where speed is the parsing speed in sentences per second
        """
        parser = self._parser
        vocab = self._vocab
        if not save_dir:
            save_dir = tempfile.mkdtemp()
        with mx.Context(mxnet_prefer_gpu()):
            UAS, LAS, speed = evaluate_official_script(parser, vocab, num_buckets_test, test_batch_size,
                                                       test_file, os.path.join(save_dir, 'valid_tmp'))
        if logger is None:
            logger = init_logger(save_dir, 'test.log')
        logger.info('Test: UAS %.2f%% LAS %.2f%% %d sents/s' % (UAS, LAS, speed))

        return UAS, LAS, speed
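A sketch of calling evaluate() on a test file; both paths below are placeholders.

# 'parser' is assumed to be a loaded instance; the paths are placeholders.
UAS, LAS, speed = parser.evaluate(test_file='data/ptb/dep/test.conllx',
                                  save_dir='data/model/dep/eval')
print('UAS=%.2f%% LAS=%.2f%% (%d sents/s)' % (UAS, LAS, speed))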
Example #5
    def load_from_file(cls, model_folder, context: mx.Context = None, model_root=None, **kwargs):
        model_folder = fetch_resource(model_folder, model_root=model_root)
        if context is None:
            context = mxnet_prefer_gpu()
        config_path = os.path.join(model_folder, 'config.json')
        config = load_json(config_path)
        # convert embedding str to type
        embeddings = []
        for classpath, param in config['embeddings']:
            embeddings.append((str_to_type(classpath), param))
        config['embeddings'] = embeddings
        config['tag_dictionary'] = Dictionary.from_dict(config['tag_dictionary'])
        with context:
            embeddings = StackedEmbeddings.from_list(config['embeddings'])
            model = SequenceTagger(
                hidden_size=config['hidden_size'],
                embeddings=embeddings,
                tag_dictionary=config['tag_dictionary'],
                tag_type=config['tag_type'],
                use_crf=config['use_crf'],
                use_rnn=config['use_rnn'],
                rnn_layers=config['rnn_layers'])
            # print(config)
            model.load_parameters(os.path.join(model_folder, 'model.bin'), ctx=context)
            if not model.use_crf:
                model.transitions = pickle_load(os.path.join(model_folder, 'transitions.pkl'))  # type:nd.NDArray
                model.transitions = model.transitions.as_in_context(context)

        return model
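A usage sketch for the loader above, assuming it is exposed as a classmethod of SequenceTagger (the cls parameter suggests so); the SequenceTagger import is omitted because its module path is not shown in these snippets, and the model folder is a placeholder.

# Assumes load_from_file is a classmethod of SequenceTagger; the model folder is a placeholder.
from elit.util.mx import mxnet_prefer_gpu

tagger = SequenceTagger.load_from_file('data/model/pos/jumbo', context=mxnet_prefer_gpu())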
Example #6
    def __init__(self, model, detach: bool = True, context: mx.Context = None):
        """
        Contextual string embeddings of words, as proposed in Akbik et al., 2018.

        Parameters
        ----------
        model : str
            model string, one of 'news-forward', 'news-backward', 'mix-forward', 'mix-backward', 'german-forward',
            'german-backward' depending on which character language model is desired
        detach : bool
            if set to False, the gradient will propagate into the language model; this dramatically slows down
            training and often leads to worse results, so it is not recommended
        context : mx.Context, optional
            context to run on; defaults to the preferred GPU
        """
        super().__init__()
        self.model = model
        self.static_embeddings = detach
        self.context = context if context else mxnet_prefer_gpu()
        self.lm = ContextualStringModel.load_language_model(
            model, context=self.context)
        self.detach = detach
        if detach:
            self.lm.freeze()

        self.is_forward_lm = self.lm.is_forward_lm

        with self.context:
            dummy_sentence = Sentence()
            dummy_sentence.add_token(Token('hello'))
            embedded_dummy = self.embed(dummy_sentence)
            self.__embedding_length = len(
                embedded_dummy[0].get_token(1).get_embedding())
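These embeddings are typically stacked with word embeddings, as in Example #12 below; a sketch of that pattern follows (imports for the embedding classes and the EN_LM_* constants are omitted because their modules are not shown in these snippets).

# Mirrors the stacking pattern of Example #12; class imports are omitted here.
embedding_types = [
    WordEmbeddings(('fasttext', 'crawl-300d-2M-subword')),
    CharLMEmbeddings(EN_LM_FLAIR_FW_WMT11),   # forward character language model
    CharLMEmbeddings(EN_LM_FLAIR_BW_WMT11),   # backward character language model
]
embeddings = StackedEmbeddings(embeddings=embedding_types)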
Example #7
 def __init__(self, context: mx.Context = None) -> None:
     """
     Create a tagger
     :param context: the context under which this component will run
     """
     super().__init__()
     self.tagger = None  # type: SequenceTagger
     self.context = context if context else mxnet_prefer_gpu()
Example #8
 def __init__(self, context=mxnet_prefer_gpu()) -> None:
     """
     Create a parser
     :param context: the context under which this component will run
     """
     super().__init__()
     self.context = context
     self._parser = None  # type: BiaffineSDPParser
Example #9
 def __init__(self, context: mx.Context = None) -> None:
     """
     Create a parser
     :param context: the context under which this component will run
     """
     super().__init__()
     self._parser = None  # type: BiaffineParser
     self._vocab = None  # type: ParserVocabulary
     self.context = context if context else mxnet_prefer_gpu()
Example #10
    def fill(self, path):
        rnn_path = os.path.join(path, 'rnn.bin')
        if os.path.isfile(rnn_path):
            # print('load rnn')
            self.rnn.load_parameters(rnn_path, ctx=mxnet_prefer_gpu())
            freeze(self.rnn)

        for i, (mlp, decoder) in enumerate(zip(self.mlps, self.decoders)):
            mlp_path = os.path.join(path, 'mlp{}.bin'.format(i))
            if os.path.isfile(mlp_path):
                # print('load mlp')
                mlp.load_parameters(mlp_path, ctx=mxnet_prefer_gpu())
                freeze(mlp)

            decoder_path = os.path.join(path, 'decoder{}.bin'.format(i))
            if os.path.isfile(decoder_path):
                # print('load decoder')
                decoder.load_parameters(decoder_path, ctx=mxnet_prefer_gpu())
                freeze(decoder)
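A short sketch of warm-starting from a directory of pre-trained parts; 'model' and the directory path are placeholders. fill() loads and freezes whichever of rnn.bin, mlp{i}.bin and decoder{i}.bin it finds.

# 'model' is assumed to be an initialized instance of the class that defines fill();
# the directory path is a placeholder.
model.fill('data/model/pretrained')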
Example #11
    def load(self, load_path, ctx=None):
        """Load model

        Parameters
        ----------
        load_path : str
            path to model file
        ctx : mx.Context, optional
            context to load the parameters onto; defaults to the preferred GPU
        """
        if not ctx:
            ctx = mxnet_prefer_gpu()
        self.load_parameters(load_path, allow_missing=True, ctx=ctx)
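A minimal call sketch for load(); 'net' and the file path are placeholders.

# 'net' is assumed to be an instance of the block that defines load(); the path is a placeholder.
net.load('data/model/tagger/model.bin')                  # loads onto the preferred GPU by default
# net.load('data/model/tagger/model.bin', ctx=mx.cpu())  # or force a specific context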
Example #12
        train_file='en-pos.trn',
        test_file='en-pos.tst',
        dev_file='en-pos.dev',
        # train_file='train.tsv',
        # test_file='dev.tsv',
        # dev_file='dev.tsv',
    )

    # 2. what tag do we want to predict?
    tag_type = 'pos'

    # 3. make the tag dictionary from the corpus
    tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
    print(tag_dictionary.idx2item)
    model_path = 'data/model/pos/jumbo'
    with mx.Context(mxnet_prefer_gpu()):
        train = False
        print(train)
        use_crf = True
        if train:
            embedding_types = [
                WordEmbeddings(('fasttext', 'crawl-300d-2M-subword')),
                CharLMEmbeddings(EN_LM_FLAIR_FW_WMT11),
                CharLMEmbeddings(EN_LM_FLAIR_BW_WMT11),
            ]

            embeddings = StackedEmbeddings(embeddings=embedding_types)
            # 5. initialize sequence tagger
            tagger = SequenceTagger(hidden_size=256,
                                    embeddings=embeddings,
                                    tag_dictionary=tag_dictionary,
Example #13
    def train(self, trn_docs: Sequence[Document], dev_docs: Sequence[Document], save_dir, pretrained_embeddings=None,
              min_occur_count=2, lstm_layers=3, word_dims=100, tag_dims=100, dropout_emb=0.33, lstm_hiddens=400,
              dropout_lstm_input=0.33, dropout_lstm_hidden=0.33, mlp_arc_size=500, mlp_rel_size=100,
              dropout_mlp=0.33, learning_rate=2e-3, decay=.75, decay_steps=5000, beta_1=.9, beta_2=.9, epsilon=1e-12,
              num_buckets_train=40, num_buckets_valid=10, num_buckets_test=10, train_iters=50000, train_batch_size=5000,
              test_batch_size=5000, validate_every=100, save_after=5000, debug=False, **kwargs) -> float:
        """
        Train a DEP parser
        :param trn_docs: training set
        :param dev_docs: dev set
        :param save_dir: folder for saving model
        :param pretrained_embeddings: pretrained embeddings
        :param min_occur_count: filter out features with frequency less than this threshold
        :param lstm_layers: lstm layers
        :param word_dims: dim for word embeddings
        :param tag_dims: dim for tag embeddings
        :param dropout_emb: dropout on word/tag embeddings
        :param lstm_hiddens: dim for lstm hidden states
        :param dropout_lstm_input: dropout on lstm input
        :param dropout_lstm_hidden: variational dropout
        :param mlp_arc_size: arc representation size
        :param mlp_rel_size: rel representation size
        :param dropout_mlp: dropout on output of the mlp
        :param learning_rate: learning rate
        :param decay: see ExponentialScheduler
        :param decay_steps: see ExponentialScheduler
        :param beta_1: first moment decay rate, passed to the Adam optimizer
        :param beta_2: second moment decay rate, passed to the Adam optimizer
        :param epsilon: small constant for numerical stability, passed to the Adam optimizer
        :param num_buckets_train: cluster training set into this number of groups
        :param num_buckets_valid: cluster dev set into this number of groups
        :param num_buckets_test: cluster test set into this number of groups
        :param train_iters: number of training iterations
        :param train_batch_size: training batch size
        :param test_batch_size: test batch size
        :param validate_every: validate model on dev set every this number of steps
        :param save_after: save after this number of steps
        :param debug: debug mode
        :param kwargs: not used
        :return: best UAS during training
        """
        logger = init_logger(save_dir)
        config = _Config(trn_docs, dev_docs, '', save_dir, pretrained_embeddings, min_occur_count,
                         lstm_layers, word_dims, tag_dims, dropout_emb, lstm_hiddens, dropout_lstm_input,
                         dropout_lstm_hidden, mlp_arc_size, mlp_rel_size, dropout_mlp, learning_rate, decay,
                         decay_steps,
                         beta_1, beta_2, epsilon, num_buckets_train, num_buckets_valid, num_buckets_test, train_iters,
                         train_batch_size, debug)
        config.save_json()
        self._vocab = vocab = ParserVocabulary(trn_docs,
                                               pretrained_embeddings,
                                               min_occur_count)
        vocab.save_json(config.save_vocab_path)
        vocab.log_info(logger)

        with mx.Context(mxnet_prefer_gpu()):

            self._parser = parser = BiaffineParser(vocab, word_dims, tag_dims,
                                                   dropout_emb,
                                                   lstm_layers,
                                                   lstm_hiddens, dropout_lstm_input,
                                                   dropout_lstm_hidden,
                                                   mlp_arc_size,
                                                   mlp_rel_size, dropout_mlp, debug)
            parser.initialize()
            scheduler = ExponentialScheduler(learning_rate, decay, decay_steps)
            optimizer = mx.optimizer.Adam(learning_rate, beta_1, beta_2, epsilon,
                                          lr_scheduler=scheduler)
            trainer = gluon.Trainer(parser.collect_params(), optimizer=optimizer)
            data_loader = DataLoader(trn_docs, num_buckets_train, vocab)
            global_step = 0
            best_UAS = 0.
            batch_id = 0
            epoch = 1
            total_epoch = math.ceil(train_iters / validate_every)
            logger.info("Epoch {} out of {}".format(epoch, total_epoch))
            bar = Progbar(target=min(validate_every, data_loader.samples))
            while global_step < train_iters:
                for words, tags, arcs, rels in data_loader.get_batches(batch_size=train_batch_size,
                                                                       shuffle=True):
                    with autograd.record():
                        arc_accuracy, rel_accuracy, overall_accuracy, loss = parser.forward(words, tags, arcs,
                                                                                            rels)
                        loss_value = loss.asscalar()
                    loss.backward()
                    trainer.step(train_batch_size)
                    batch_id += 1
                    try:
                        bar.update(batch_id,
                                   exact=[("UAS", arc_accuracy, 2),
                                          # ("LAS", rel_accuracy, 2),
                                          # ("ALL", overall_accuracy, 2),
                                          ("loss", loss_value)])
                    except OverflowError:
                        pass  # sometimes loss can be 0 or infinity, crashes the bar

                    global_step += 1
                    if global_step % validate_every == 0:
                        bar = Progbar(target=min(validate_every, train_iters - global_step))
                        batch_id = 0
                        UAS, LAS, speed = evaluate_official_script(parser, vocab, num_buckets_valid,
                                                                   test_batch_size,
                                                                   dev_docs,
                                                                   os.path.join(save_dir, 'valid_tmp'))
                        logger.info('Dev: UAS %.2f%% LAS %.2f%% %d sents/s' % (UAS, LAS, speed))
                        epoch += 1
                        if global_step < train_iters:
                            logger.info("Epoch {} out of {}".format(epoch, total_epoch))
                        if global_step > save_after and UAS > best_UAS:
                            logger.info('- new best score!')
                            best_UAS = UAS
                            parser.save(config.save_model_path)

        # When validate_every is too big
        if not os.path.isfile(config.save_model_path) or best_UAS != UAS:
            parser.save(config.save_model_path)

        return best_UAS
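A hypothetical training sketch for the method above. The DepParser import path is an assumption, and trn_docs / dev_docs stand for Sequence[Document] objects with dependency annotation (how they are built is not shown in these snippets).

# Hypothetical training sketch; DepParser's import path and the data variables are assumptions.
from elit.component import DepParser  # assumed location

parser = DepParser()
best_uas = parser.train(trn_docs, dev_docs, save_dir='data/model/dep/ptb',
                        train_iters=50000, validate_every=100, save_after=5000)
print('best dev UAS: %.2f%%' % best_uas)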
Example #14
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-10-03 16:27
from types import SimpleNamespace

from elit.component.embedding.fasttext import FastText
from elit.component.tagger.corpus import conll_to_documents, label_map_from_conll
from elit.component.token_tagger.cnn import CNNTokenTagger
from elit.util.mx import mxnet_prefer_gpu

label_map = label_map_from_conll('data/ptb/pos/dev.tsv')
print(label_map)
tagger = CNNTokenTagger(ctx=mxnet_prefer_gpu(), key='pos',
                        embs=[FastText('https://elit-models.s3-us-west-2.amazonaws.com/cc.en.300.bin.zip')],
                        input_config=SimpleNamespace(row=100, col=5, dropout=0.5),
                        output_config=SimpleNamespace(num_class=len(label_map), flatten=True),
                        label_map=label_map
                        )
# 94.38
save_path = 'data/model/cnntagger'
tagger.train(conll_to_documents('data/ptb/pos/train.tsv', headers={0: 'text', 1: 'pos'}, gold=True),
             conll_to_documents('data/ptb/pos/dev.tsv', headers={0: 'text', 1: 'pos'}, gold=True),
             save_path)
tagger.load(save_path)
# Parameter 'dense1.weight' is missing in file 'data/model/cnntagger.params', which contains parameters: 'dense0.weight', 'dense0.bias'. Set allow_missing=True to ignore missing parameters.
tagger.evaluate(conll_to_documents('data/ptb/pos/dev.tsv', headers={0: 'text', 1: 'pos'}))
Example #15
    def train(self,
              base_path: str,
              learning_rate: float = 0.1,
              mini_batch_size: int = 32,
              max_epochs: int = 100,
              anneal_factor: float = 0.5,
              patience: int = 2,
              save_model: bool = True,
              embeddings_in_gpu: bool = True,
              train_with_dev: bool = False,
              context: mx.Context = None) -> float:
        """

        :param base_path: a folder to store model, log etc.
        :param learning_rate:
        :param mini_batch_size:
        :param max_epochs:
        :param anneal_factor:
        :param patience:
        :param save_model:
        :param embeddings_in_gpu:
        :param train_with_dev:
        :return: best dev f1
        """
        evaluation_method = 'F1'
        if self.model.tag_type in ['ner', 'np', 'srl']:
            evaluation_method = 'span-F1'
        if self.model.tag_type in ['pos', 'upos']:
            evaluation_method = 'accuracy'
        print(evaluation_method)

        os.makedirs(base_path, exist_ok=True)

        loss_txt = os.path.join(base_path, "loss.txt")
        open(loss_txt, "w", encoding='utf-8').close()

        anneal_mode = 'min' if train_with_dev else 'max'
        train_data = self.corpus.train

        # if training also uses dev data, include in training set
        if train_with_dev:
            train_data.extend(self.corpus.dev)

        # At any point you can hit Ctrl + C to break out of training early.
        try:
            if not context:
                context = mxnet_prefer_gpu()
            with mx.Context(context):
                self.model.initialize()
                if not self.model.use_crf:
                    self.model.count_transition_matrix(train_data)
                scheduler = ReduceLROnPlateau(lr=learning_rate,
                                              verbose=True,
                                              factor=anneal_factor,
                                              patience=patience,
                                              mode=anneal_mode)
                optimizer = mx.optimizer.SGD(learning_rate=learning_rate,
                                             lr_scheduler=scheduler,
                                             clip_gradient=5.0)
                trainer = gluon.Trainer(self.model.collect_params(),
                                        optimizer=optimizer)
                for epoch in range(0, max_epochs):
                    current_loss = 0
                    if not self.test_mode:
                        random.shuffle(train_data)

                    batches = [
                        train_data[x:x + mini_batch_size]
                        for x in range(0, len(train_data), mini_batch_size)
                    ]

                    batch_no = 0

                    for batch in batches:
                        batch_no += 1

                        if batch_no % 100 == 0:
                            print("%d of %d (%f)" %
                                  (batch_no, len(batches),
                                   float(batch_no / len(batches))))

                        # compute the loss and gradients, then update the parameters with trainer.step()
                        with autograd.record():
                            loss = self.model.neg_log_likelihood(
                                batch, None if embeddings_in_gpu else context)

                        current_loss += loss.sum().asscalar()

                        loss.backward()

                        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)

                        # optimizer.step()
                        trainer.step(len(batch))

                        sys.stdout.write('.')
                        sys.stdout.flush()

                    current_loss /= len(train_data)

                    if not train_with_dev:
                        print('.. evaluating... dev... ')
                        dev_score, dev_fp, dev_result = self.evaluate(
                            self.corpus.dev,
                            base_path,
                            evaluation_method=evaluation_method,
                            embeddings_in_gpu=embeddings_in_gpu)
                    else:
                        dev_fp = 0
                        dev_result = '_'

                    # anneal against train loss if training with dev, otherwise anneal against dev score
                    if train_with_dev:
                        scheduler.step(current_loss)
                    else:
                        scheduler.step(dev_score)

                    # save if model is current best and we use dev data for model selection
                    if save_model and not train_with_dev and dev_score == scheduler.best:
                        self.model.save(base_path)
                    summary = '%d' % epoch + '\t({:%H:%M:%S})'.format(datetime.datetime.now()) \
                              + '\t%f\t%d\t%f\tDEV   %d\t' % (
                                  current_loss, scheduler.num_bad_epochs, learning_rate, dev_fp) + dev_result
                    summary = summary.replace('\n', '')
                    # if self.corpus.test and len(self.corpus.test):
                    #     print('test... ')
                    #     test_score, test_fp, test_result = self.evaluate(self.corpus.test, base_path,
                    #                                                      evaluation_method=evaluation_method,
                    #                                                      embeddings_in_memory=embeddings_in_gpu)
                    #     summary += '\tTEST   \t%d\t' % test_fp + test_result
                    with open(loss_txt, "a") as loss_file:
                        loss_file.write('%s\n' % summary)
                    print(summary)

            # if we do not use dev data for model selection, save final model
            if save_model and train_with_dev:
                self.model.save(base_path)

            return scheduler.best  # return maximum dev f1

        except KeyboardInterrupt:
            print('-' * 89)
            print('Exiting from training early')
            print('saving model')
            self.model.save(base_path + "/final-model")
            print('done')

    def train(self,
              base_path: str,
              sequence_length: int,
              learning_rate: float = 20,
              mini_batch_size: int = 100,
              anneal_factor: float = 0.25,
              patience: int = 10,
              clip=0.25,
              max_epochs: int = 10000):

        number_of_splits = len(self.corpus.train_files)
        val_data = self._batchify(self.corpus.valid, mini_batch_size)

        os.makedirs(base_path, exist_ok=True)
        loss_txt = os.path.join(base_path, 'loss.txt')
        savefile = os.path.join(base_path, 'best-lm.pt')

        try:
            with mx.Context(mxnet_prefer_gpu()):
                self.model.initialize()
                best_val_loss = 100000000
                scheduler = ReduceLROnPlateau(lr=learning_rate,
                                              verbose=True,
                                              factor=anneal_factor,
                                              patience=patience)
                optimizer = mx.optimizer.SGD(learning_rate=learning_rate,
                                             lr_scheduler=scheduler)
                trainer = gluon.Trainer(self.model.collect_params(),
                                        optimizer=optimizer)

                for epoch in range(1, max_epochs + 1):

                    print('Split %d' % epoch +
                          '\t - ({:%H:%M:%S})'.format(datetime.datetime.now()))

                    # for group in optimizer.param_groups:
                    #     learning_rate = group['lr']

                    train_slice = self.corpus.get_next_train_slice()

                    train_data = self._batchify(train_slice, mini_batch_size)
                    print('\t({:%H:%M:%S})'.format(datetime.datetime.now()))

                    # go into train mode
                    # self.model.train()

                    # reset variables
                    epoch_start_time = time.time()
                    total_loss = 0
                    start_time = time.time()

                    hidden = self.model.init_hidden(mini_batch_size)
                    cell = hidden.copy()

                    # vocabulary size, used below to reshape the decoder output
                    ntokens = len(self.corpus.dictionary)

                    # do batches
                    for batch, i in enumerate(
                            range(0,
                                  len(train_data) - 1, sequence_length)):

                        data, targets = self._get_batch(
                            train_data, i, sequence_length)

                        # Starting each batch, we detach the hidden state from how it was previously produced.
                        # If we didn't, the model would try backpropagating all the way to start of the dataset.
                        hidden = self._repackage_hidden(hidden)
                        cell = self._repackage_hidden(cell)

                        # self.model.zero_grad()
                        # optimizer.zero_grad()

                        # do the forward pass in the model
                        with autograd.record():
                            output, rnn_output, hidden, cell = self.model.forward(
                                data, hidden, cell)
                            # try to predict the targets
                            loss = self.loss_function(
                                output.reshape(-1, ntokens), targets).mean()
                            loss.backward()

                        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)

                        trainer.step(mini_batch_size)

                        total_loss += loss.asscalar()

                        if batch % self.log_interval == 0 and batch > 0:
                            cur_loss = total_loss.item() / self.log_interval
                            elapsed = time.time() - start_time
                            print(
                                '| split {:3d} /{:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                                'loss {:5.2f} | ppl {:8.2f}'.format(
                                    epoch, number_of_splits, batch,
                                    len(train_data) // sequence_length,
                                    elapsed * 1000 / self.log_interval,
                                    cur_loss, self._safe_exp(cur_loss)))
                            total_loss = 0
                            start_time = time.time()

                    print('epoch {} done! \t({:%H:%M:%S})'.format(
                        epoch, datetime.datetime.now()))
                    scheduler.step(cur_loss)

                    ###############################################################################
                    # TEST
                    ###############################################################################
                    # skip evaluation
                    # val_loss = self.evaluate(val_data, mini_batch_size, sequence_length)
                    # scheduler.step(val_loss)
                    #
                    # # Save the model if the validation loss is the best we've seen so far.
                    # if val_loss < best_val_loss:
                    #     self.model.save(savefile)
                    #     best_val_loss = val_loss
                    #     print('best loss so far {:5.2f}'.format(best_val_loss))
                    val_loss = cur_loss
                    if (self.corpus.current_train_file_index +
                            1) % 100 == 0 or self.corpus.is_last_slice:
                        self.model.save(savefile)

                    ###############################################################################
                    # print info
                    ###############################################################################
                    print('-' * 89)

                    local_split_number = epoch % number_of_splits
                    if local_split_number == 0:
                        local_split_number = number_of_splits

                    summary = '| end of split {:3d} /{:3d} | epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' \
                              'valid ppl {:8.2f} | learning rate {:3.2f}'.format(local_split_number,
                                                                                 number_of_splits,
                                                                                 epoch,
                                                                                 (time.time() - epoch_start_time),
                                                                                 val_loss,
                                                                                 self._safe_exp(val_loss),
                                                                                 learning_rate)

                    with open(loss_txt, "a") as myfile:
                        myfile.write('%s\n' % summary)

                    print(summary)
                    print('-' * 89)

        except KeyboardInterrupt:
            print('-' * 89)
            print('Exiting from training early')
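Both training loops above follow the same Gluon pattern: run the forward pass under autograd.record(), backpropagate, then call trainer.step() with the batch size. A self-contained toy sketch of that pattern follows (the model and data below are stand-ins, not ELIT code).

# Toy illustration of the record / backward / step pattern used in the loops above.
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(1)
net.initialize()
loss_fn = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(),
                        mx.optimizer.SGD(learning_rate=0.1, clip_gradient=5.0))

x = mx.nd.random.uniform(shape=(32, 10))
y = mx.nd.random.uniform(shape=(32, 1))
with autograd.record():
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(batch_size=32)   # gradients are rescaled by 1/batch_size inside step()
print(loss.mean().asscalar())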