예제 #1
0
File: model.py  Project: ericxsun/BiDTree
    def run_epoch(self, sess, train, train_deps, dev, dev_deps, vocab_words,
                  vocab_tags, epoch):
        """Performs one complete pass over the train set and evaluate on dev

        Args:
            sess: tensorflow session
            train / train_deps: training sentences and their dependency data
            dev / dev_deps: dev sentences and their dependency data
            vocab_words: {word: index} dictionary
            vocab_tags: {tag: index} dictionary
            epoch: (int) number of the epoch

        Returns:
            (acc, recall, f1, train_loss) where train_loss is the loss of
            the last minibatch of the epoch.
        """
        self.config.istrain = True  # set to train first, #batch normalization#
        # BUGFIX: use floor division; on Python 3 "/" yields a float and
        # Progbar expects an integer target.
        nbatches = (len(train_deps) + self.config.batch_size -
                    1) // self.config.batch_size
        prog = Progbar(target=nbatches)
        for i, (words, poss, chunks, labels, btup_idx_list, btup_words_list,
                btup_depwords_list, btup_deprels_list,
                btup_depwords_length_list, upbt_idx_list, upbt_words_list,
                upbt_depwords_list, upbt_deprels_list,
                upbt_depwords_length_list, btup_formidx_list,
                upbt_formidx_list) in enumerate(
                    minibatches(train, train_deps, self.config.batch_size)):
            fd, sequence_lengths = self.get_feed_dict(
                words, poss, chunks, labels, btup_idx_list, btup_words_list,
                btup_depwords_list, btup_deprels_list,
                btup_depwords_length_list, upbt_idx_list, upbt_words_list,
                upbt_depwords_list, upbt_deprels_list,
                upbt_depwords_length_list, btup_formidx_list,
                upbt_formidx_list, self.config.lr, self.config.dropout)

            _, train_loss, logits = sess.run(
                [self.train_op, self.loss, self.logits], feed_dict=fd)
            prog.update(i + 1, [("train loss", train_loss)])

        acc, recall, f1, test_acc = self.run_evaluate(sess, dev, dev_deps,
                                                      vocab_words, vocab_tags)
        self.logger.info(
            "- dev acc {:04.2f} - dev recall {:04.2f} - f1 {:04.2f} - test acc {:04.2f}"
            .format(100 * acc, 100 * recall, 100 * f1, 100 * test_acc))
        return acc, recall, f1, train_loss
예제 #2
0
    def run_epoch(self, sess, train, dev, tags, epoch):
        """Run a single training epoch and evaluate on the dev set.

        Args:
            sess: tensorflow session
            train: dataset that yields tuple of sentences, tags
            dev: dataset
            tags: {tag: index} dictionary
            epoch: (int) number of the epoch
        """
        bs = self.config.batch_size
        nbatches = (len(train) + bs - 1) // bs  # ceil(len(train) / bs)
        progress = Progbar(target=nbatches)

        for batch_idx, (words, labels) in enumerate(minibatches(train, bs)):
            feed, _ = self.get_feed_dict(words, labels, self.config.lr,
                                         self.config.dropout)

            _, batch_loss, summary = sess.run(
                [self.train_op, self.loss, self.merged], feed_dict=feed)

            progress.update(batch_idx + 1, [("train loss", batch_loss)])

            # write a tensorboard summary every 10th batch
            if batch_idx % 10 == 0:
                self.file_writer.add_summary(summary,
                                             epoch * nbatches + batch_idx)

        acc, f1 = self.run_evaluate(sess, dev, tags)
        self.logger.info("- dev acc {:04.2f} - f1 {:04.2f}".format(
            100 * acc, 100 * f1))
        return acc, f1
    def run_epoch(self, session, train):
        """Perform one complete pass over the training data.

        Args:
            session: tensorflow session
            train: training dataset consumable by ``minibatches``

        Side effects:
            Appends the accumulated epoch loss to ``self.losses``.
        """
        # BUGFIX: floor division — on Python 3 "/" returns a float and
        # Progbar expects an integer target.
        nbatches = (len(train) + self.config.batch_size -
                    1) // self.config.batch_size
        prog = Progbar(target=nbatches)

        total_loss = 0

        for i, (q_batch, c_batch, a_batch) in enumerate(
                minibatches(train, self.config.batch_size)):

            # at training time, dropout needs to be on.
            input_feed = self.get_feed_dict(q_batch, c_batch, a_batch,
                                            self.config.dropout_val)

            _, train_loss = session.run([self.train_op, self.loss],
                                        feed_dict=input_feed)

            prog.update(i + 1, [("train loss", train_loss)])
            total_loss += train_loss

        self.losses.append(total_loss)
        print(">>>>>>>", str(total_loss))
예제 #4
0
    def run_epoch(self, train, dev, train_eval, epoch):
        """Performs one complete pass over the train set and evaluate on dev

        Args:
            train: dataset that yields tuple of sentences, tags
            dev: dataset used for the "dev acc" evaluation
            train_eval: dataset used for the "train acc" evaluation
            epoch: (int) index of the current epoch

        Returns:
            (acc_train, acc_test, train_loss): accuracies from
            ``self.evaluate`` plus the loss of the last minibatch.
        """
        # progbar stuff for logging
        batch_size = self.config.batch_size
        nbatches = (len(train) + batch_size - 1) // batch_size
        prog = Progbar(target=nbatches)

        # iterate over dataset
        for i, (words, labels) in enumerate(minibatches(train, batch_size)):
            fd, _ = self.get_feed_dict(True, words, labels, lr=self.config.lr)
            _, train_loss = self.sess.run([self.train_op, self.loss],
                                          feed_dict=fd)
            prog.update(i + 1, values=[("train loss", train_loss)])

        acc_train = self.evaluate(train_eval)
        acc_test = self.evaluate(dev)

        # NOTE(review): `i` and `train_loss` leak out of the loop above and
        # would raise NameError on an empty training set; also this update
        # call passes `epoch` positionally — confirm it matches this
        # project's Progbar.update signature.
        prog.update(i + 1,
                    epoch, [("train loss", train_loss)],
                    exact=[("dev acc", acc_test), ("train acc", acc_train),
                           ("lr", self.config.lr)])

        return acc_train, acc_test, train_loss
예제 #5
0
    def run_epoch(self, sess, train, dev, tags, epoch):
        """Train for one epoch, log summaries, then score the dev set.

        Args:
            sess: tensorflow session
            train: dataset that yields tuple of sentences, tags
            dev: dataset
            tags: {tag: index} dictionary
            epoch: (int) number of the epoch
        """
        cfg = self.config
        nbatches = (len(train) + cfg.batch_size - 1) // cfg.batch_size
        prog = Progbar(target=nbatches)

        for i, (words, labels) in enumerate(
                minibatches(train, cfg.batch_size)):
            fd, _ = self.get_feed_dict(words, labels, cfg.lr, cfg.dropout)

            fetched = sess.run([self.train_op, self.loss, self.merged],
                               feed_dict=fd)
            train_loss, summary = fetched[1], fetched[2]

            prog.update(i + 1, [("train loss", train_loss)])

            # tensorboard: one summary every 10 batches
            if i % 10 == 0:
                self.file_writer.add_summary(summary, epoch * nbatches + i)

        acc, f1 = self.run_evaluate(sess, dev, tags)
        self.logger.info("- dev acc {:04.2f} - f1 {:04.2f}".format(100 * acc,
                                                                   100 * f1))
        return acc, f1
예제 #6
0
    def run_epoch(self, sess, train, dev, epoch):
        """
        Performs one complete pass over the train set and evaluate on dev
        Args:
            sess: tensorflow session
            train: dataset that yields tuple of sentences, tags
            dev: dataset
            epoch: (int) number of the epoch
        Returns:
            (acc, per): dev accuracy and phone error rate.
        """
        # BUGFIX: floor division — "/" returns a float on Python 3, which
        # breaks both the integer Progbar target and the integer summary
        # step "epoch * nbatches + i" used below.
        nbatches = (len(train) + self.config.batch_size -
                    1) // self.config.batch_size
        prog = Progbar(target=nbatches, verbose=False)
        for i, (framelist,
                phones) in enumerate(minibatches(train,
                                                 self.config.batch_size)):

            fd, _ = self.get_feed_dict(framelist, phones, self.config.lr,
                                       self.config.keep_prob)

            _, train_loss, summary = sess.run(
                [self.train_op, self.loss, self.merged], feed_dict=fd)

            prog.update(i + 1, [("train loss", train_loss)])

            # tensorboard
            if i % 10 == 0:
                self.file_writer.add_summary(summary, epoch * nbatches + i)

        acc, per = self.run_evaluate(sess, dev)
        self.logger.info(" - dev accuracy {:04.2f} - PER {:04.2f}".format(
            100 * acc, 100 * per))
        return acc, per
예제 #7
0
 def run_evaluate(self, test):
     """Evaluate on the test set and report accuracy.

     :param test: list of (sequence, tags) pairs
     :return: dict holding the accuracy in percent
     """
     nbatches = (len(test) + batch_size - 1) // batch_size  # batch count
     prog = Progbar(target=nbatches)
     accs = []
     cnt_eq = 0
     for i, (words1, words2, labels) in enumerate(
             minibatches(test, batch_size)):
         logits = self.predict_batch(words1, words2)
         for logit, label in zip(logits[0], labels):
             hit = np.argmax(logit) == label
             accs.append(hit)
             if hit:
                 cnt_eq += 1
         prog.update(i + 1, [("evaluate acc", 100 * np.mean(accs))])
         print('cnt_eq=', cnt_eq)
     return {"acc": 100 * np.mean(accs)}
예제 #8
0
    def run_epoch(self, sess, config, dataset, train_writer, merged):
        """Train for one epoch over ``dataset.train_inputs``.

        Args:
            sess: tensorflow session
            config: provides ``batch_size``
            dataset: exposes ``train_inputs`` and ``train_targets``
            train_writer: summary writer (currently unused here)
            merged: merged-summary op passed to ``train_on_batch``

        Returns:
            (summary, loss) of the LAST minibatch, or (None, None) for an
            empty dataset.
        """
        # BUGFIX: floor division — "/" produces a float on Python 3 and
        # Progbar expects an integer target.
        prog = Progbar(
            target=1 + len(dataset.train_inputs[0]) // config.batch_size)
        summary, loss = None, None  # guard: no NameError on empty dataset
        for i, (train_x, train_y) in enumerate(
                get_minibatches([dataset.train_inputs, dataset.train_targets],
                                config.batch_size,
                                is_multi_feature_input=True)):

            summary, loss = self.train_on_batch(sess, train_x, train_y, merged)
            prog.update(i + 1, [("train loss", loss)])
        return summary, loss  # Last batch
예제 #9
0
    def run_epoch(self, sess, parser, train_examples, dev_set):
        """Train for one epoch and report dev-set UAS (Python 2 code).

        Args:
            sess: tensorflow session
            parser: object exposing ``parse`` for dev evaluation
            train_examples: training data consumable by ``minibatches``
            dev_set: dev data handed to ``parser.parse``

        Returns:
            dev_UAS: unlabeled attachment score on the dev set.
        """
        # NOTE(review): "/" is integer division only under Python 2 (the
        # print statements below imply Py2); under Py3 this would hand
        # Progbar a float target.
        prog = Progbar(target=1 + len(train_examples) / self.config.batch_size)
        for i, (train_x, train_y) in enumerate(minibatches(train_examples, self.config.batch_size)):
            loss = self.train_on_batch(sess, train_x, train_y)
            prog.update(i + 1, [("train loss", loss)])

        print "Evaluating on dev set",
        dev_UAS, _ = parser.parse(dev_set)
        print "- dev UAS: {:.2f}".format(dev_UAS * 100.0)
        return dev_UAS
예제 #10
0
    def run_epoch(self, sess, train, dev, tags, epoch):
        """
        Performs one complete pass over the train set and evaluate on dev
        Args:
            sess: tensorflow session
            train: dataset that yields tuple of sentences, tags
            dev: dataset
            tags: {tag: index} dictionary
            epoch: (int) number of the epoch
        Returns:
            (acc, f1) on the dev set.
        """

        # trie setting: (re)load the gazetteer dictionaries consumed by
        # self.dict_trie below.  NOTE(review): these files are re-read on
        # every epoch — consider hoisting to __init__ if they are static.
        self.lis1 = []
        self.lis2 = []
        self.lis3 = []
        self.lis4 = []
        self.lis5 = []

        trie.gazette(self.lis1, "data/dic/gazette.txt")
        trie.gazette(self.lis2, "data/dic/thres3.txt")
        trie.gazette_DTTI(self.lis3, "data/dic/DT_analysis.txt")
        trie.gazette_DTTI(self.lis4, "data/dic/TI_analysis.txt")
        trie.gazette(self.lis5, "data/dic/wiki_PS.txt")

        nbatches = (len(train) + self.config.batch_size -
                    1) // self.config.batch_size
        prog = Progbar(target=nbatches)
        for i, (words, fw_words, bw_words, labels, postags, sentences,
                _) in enumerate(minibatches(train, self.config.batch_size)):

            # dictionary-derived label features for this batch of sentences
            dict_labels = self.dict_trie(sentences)

            fd, _ = self.get_feed_dict(words,
                                       fw_words,
                                       bw_words,
                                       dict_labels,
                                       labels,
                                       self.config.lr,
                                       self.config.dropout,
                                       test_flag=0)

            _, train_loss, summary = sess.run(
                [self.train_op, self.loss, self.merged], feed_dict=fd)

            prog.update(i + 1, [("train loss", train_loss)])

            # tensorboard
            if i % 10 == 0:
                self.file_writer.add_summary(summary, epoch * nbatches + i)

        acc, f1, p, r = self.run_evaluate(sess, dev, tags, test_flag=0)
        self.logger.info(
            "- dev acc {:04.2f} - f1 {:04.2f} - p {:04.2f} - r {:04.2f}".
            format(100 * acc, 100 * f1, 100 * p, 100 * r))
        return acc, f1
예제 #11
0
def run_epoch(session, config, model, data, eval_op, keep_prob, is_training):
    """Run one pass over ``data`` and return the classification accuracy.

    Args:
        session: tensorflow session
        config: provides ``batch_size`` and ``hidden_size``
        model: graph handle exposing accuracy/cost/to_print and placeholders
        data: (inputs, labels, masks) arrays indexable by minibatch indices
        eval_op: op run alongside the metrics (train op or a no-op)
        keep_prob: dropout keep probability fed to the model
        is_training: when False, ``model.to_print`` outputs are accumulated

    Returns:
        accuracy: fraction of correctly classified samples.
    """
    n_samples = len(data[0])
    print("Running %d samples:" % (n_samples))
    # NOTE(review): this local shadows any module-level `minibatches`.
    minibatches = get_minibatches_idx(n_samples,
                                      config.batch_size,
                                      shuffle=False)

    correct = 0.
    total = 0
    total_cost = 0
    prog = Progbar(target=len(minibatches))
    #dummynode_hidden_states_collector=np.array([[0]*config.hidden_size])

    to_print_total = np.array([[0] * 2])
    for i, inds in enumerate(minibatches):
        x = data[0][inds]
        # NOTE(review): model choice is read from sys.argv[4] here — CNNs
        # get a fixed 700-step padding, other models pad to the batch max.
        if sys.argv[4] == 'cnn':
            x = pad_sequences(x,
                              maxlen=700,
                              dtype='int32',
                              padding='post',
                              truncating='post',
                              value=0.)
        else:
            x = pad_sequences(x,
                              maxlen=None,
                              dtype='int32',
                              padding='post',
                              truncating='post',
                              value=0.)
        y = data[1][inds]
        mask = data[2][inds]

        count, _, cost, to_print= \
        session.run([model.accuracy, eval_op,model.cost, model.to_print],\
            {model.input_data: x, model.labels: y, model.mask:mask, model.keep_prob:keep_prob})
        # collect per-sample outputs only during evaluation
        if not is_training:
            to_print_total = np.concatenate((to_print_total, to_print), axis=0)
        correct += count
        total += len(inds)
        total_cost += cost
        prog.update(i + 1, [("train loss", cost)])
    #if not is_training:
    #    print(to_print_total[:, 0].tolist())
    #    print(data[1].tolist())
    #    print(data[2].tolist())
    print("Total loss:")
    print(total_cost)
    accuracy = correct / total
    return accuracy
예제 #12
0
def load_model(data):
    """Restore the saved SentenceRelationModel and evaluate accuracy on data.

    Args:
        data: iterable of (words1, words2, labels) samples.

    Returns:
        dict holding the accuracy in percent.
    """
    embeddings = get_embedding_data()
    nbatches = (len(data) + batch_size - 1) // batch_size  # batch count
    prog = Progbar(target=nbatches)
    accs = []
    i = 0
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(
            'save/SentenceRelationModel1/SentenceRelationModel1_nl_3_hsl_196__.cpkt.meta'
        )
        saver.restore(
            sess, tf.train.latest_checkpoint("save/SentenceRelationModel1/"))
        print(tf.get_collection('logit'))
        y = tf.get_collection('logit')
        graph = tf.get_default_graph()
        seq1_word_embeddings = graph.get_operation_by_name(
            'seq1_word_embeddings').outputs[0]
        seq2_word_embeddings = graph.get_operation_by_name(
            'seq2_word_embeddings').outputs[0]
        sequence1_lengths1 = graph.get_operation_by_name(
            'sequence1_lengths').outputs[0]
        sequence2_lengths1 = graph.get_operation_by_name(
            'sequence2_lengths').outputs[0]
        for words1, words2, labels in minibatches(data, batch_size):
            words1, sequence1_lengths = pad_sequences(words1)
            words2, sequence2_lengths = pad_sequences(words2)
            # OOV words fall back to the embedding of ","; empty sentences
            # are replaced by the single embedding of "。".
            words1_embeddings = [[
                embeddings[w1] if w1 in embeddings.keys() else embeddings[","]
                for w1 in ws1
            ] if len(ws1) > 0 else [embeddings["。"]] for ws1 in words1]
            words2_embeddings = [[
                embeddings[w2] if w2 in embeddings.keys() else embeddings[","]
                for w2 in ws2
            ] if len(ws2) > 0 else [embeddings["。"]] for ws2 in words2]
            # BUGFIX 1: the graph placeholders (sequence*_lengths1) must be
            # the feed_dict KEYS, fed with the computed length lists — the
            # original had the mapping reversed.
            # BUGFIX 2: capture sess.run's result; the original discarded it
            # and then iterated the uncomputed graph tensors in `y`.
            logits = sess.run(y,
                              feed_dict={
                                  seq1_word_embeddings: words1_embeddings,
                                  seq2_word_embeddings: words2_embeddings,
                                  sequence1_lengths1: sequence1_lengths,
                                  sequence2_lengths1: sequence2_lengths
                              })
            for logit, label in zip(logits[0], labels):
                accs += [np.argmax(logit) == label]
            prog.update(i + 1, [("evaluate acc", 100 * np.mean(accs))])
            i += 1
        acc = np.mean(accs)
        return {"acc": 100 * acc}
예제 #13
0
    def run_epoch(self, session, train):
        """
        Perform one complete pass over the training data.

        Args:
            session: tensorflow session
            train: dataset consumable by ``minibatches``

        Side effects: writes a merged summary to ``self.writer`` per batch.
        """
        # BUGFIX: floor division — "/" yields a float on Python 3 and
        # Progbar expects an integer target.
        nbatches = (len(train) + self.config.batch_size - 1) // self.config.batch_size
        prog = Progbar(target=nbatches)

        for i, (q_batch, c_batch, a_batch) in enumerate(minibatches(train, self.config.batch_size)):
            # at training time, dropout needs to be on.
            input_feed = self.get_feed_dict(q_batch, c_batch, a_batch, self.config.dropout_val)

            # FIX: fetch the summary in the same run as the train op instead
            # of feeding the batch through the network a second time.
            _, train_loss, summary = session.run(
                [self.train_op, self.loss, self.merged], feed_dict=input_feed)
            prog.update(i + 1, [("train loss", train_loss)])
            self.writer.add_summary(summary, i)
예제 #14
0
    def run_epoch(self, session, train):
        """
        Perform one complete pass over the training data.

        Args:
            session: tensorflow session
            train: dataset consumable by ``minibatches``
        """
        # BUGFIX: floor division — "/" yields a float on Python 3 and
        # Progbar expects an integer target.
        nbatches = (len(train) + config.BATCH_SIZE - 1) // config.BATCH_SIZE
        prog = Progbar(target=nbatches)

        for i, q_batch in enumerate(minibatches(train, config.BATCH_SIZE)):

            # at training time, dropout needs to be on.
            input_feed = self.get_feed_dict(q_batch, config.DROPOUT_VAL)

            _, train_loss = session.run([self.train_op, self.loss],
                                        feed_dict=input_feed)
            prog.update(i + 1, [("train loss", train_loss)])
예제 #15
0
    def train(self, train, dev):
        """Full training loop with per-epoch evaluation, LR decay and
        patience-based early stopping.

        Args:
            train: dataset yielding (words, labels, intent, all_tags)
            dev: dev dataset scored by ``self.run_evaluate`` every epoch

        Side effects: saves the best session, writes tensorboard summaries,
        and may unfreeze the word embeddings before finally stopping.
        """
        best_score = 0
        nepoch_no_imprv = 0 # for early stopping
        self.add_summary() # tensorboard

        for epoch in range(self.config.nepochs):
            self.logger.info("Epoch {:} out of {:}".format(epoch + 1,
                self.config.nepochs))
            batch_size = self.config.batch_size
            nbatches = (len(train) + batch_size - 1) // batch_size
            prog = Progbar(target=nbatches)
            #self.config.lr *= self.config.lr_decay
            for i, (words, labels, intent, all_tags) in enumerate(minibatches(train, batch_size)):
                fd, _ = self.get_feed_dict(words, all_tags, labels, intent, self.config.lr,\
                        self.config.dropout)
                _, train_loss, summary, intent_loss, slot_loss= self.sess.run(
                        [self.train_op, self.loss, self.merged, self.intent_loss, self.slot_loss], feed_dict=fd)
                prog.update(i + 1, [("train loss", train_loss), \
                        ("intent_loss", intent_loss), ("slot_loss", slot_loss)])
                # tensorboard summary every 10 batches
                if i % 10 == 0:
                    self.file_writer.add_summary(summary, epoch*nbatches + i)
            metrics = self.run_evaluate(dev)
            msg = " - ".join(["{} {:04.2f}".format(k, v)
                for k, v in metrics.items()])
            self.logger.info(msg)
            # model-selection score combines slot f1 and intent accuracy
            score = metrics["f1"] + metrics["intent_acc"]
            self.config.lr *= self.config.lr_decay
            if score >= best_score:
                nepoch_no_imprv = 0
                self.save_session()
                best_score = score
                self.logger.info("- new best score!")
            else:
                nepoch_no_imprv += 1
                if nepoch_no_imprv >= self.config.nepoch_no_imprv:
                    # When patience first runs out, unfreeze the word
                    # embeddings, reset the LR and keep training; only stop
                    # early if patience runs out again afterwards.
                    if not self.embedding_trainable:
                        self.logger.info("fine tuning word embedding")
                        for i in range(10):
                            self.logger.info("######################")
                        self.set_word_embeddings_trainable()
                        self.config.lr = 0.001
                        nepoch_no_imprv = 0
                    else:
                        self.logger.info("- early stopping {} epochs without "\
                                "improvement".format(nepoch_no_imprv))
                        break
예제 #16
0
def run_epoch(session,
              config,
              model,
              data,
              eval_op,
              keep_prob,
              is_training,
              saver=None):
    """Run one pass over ``data``; return (accuracy, total cost).

    Args:
        session: tensorflow session
        config: provides batch_size, shuffle, max_seq_len, max_word_len,
            save_step and save_path
        model: graph handle with placeholders and metric/step tensors
        data: indexable arrays; data[0] inputs, data[1] labels, data[3] chars
        eval_op: op run alongside the metrics (train op or a no-op)
        keep_prob: dropout keep probability
        is_training: progress bar is only updated while training
        saver: optional tf.train.Saver; checkpoints every save_step steps
    """
    n_samples = len(data[0])
    minibatches = get_minibatches_idx(n_samples,
                                      config.batch_size,
                                      shuffle=config.shuffle)
    correct = 0.
    total = 0
    total_cost = 0
    prog = Progbar(target=len(minibatches))
    for i, idx in enumerate(minibatches):
        x = data[0][idx]
        chars = data[3][idx]
        x, mask = pad_sequences(x, 0, max_seq_len=config.max_seq_len)
        # Skip batches containing a word with no characters; note these
        # samples are also excluded from the accuracy denominator.
        if [] in chars:
            continue
        chars, char_mask, max_word_len = pad_sequences(
            chars,
            pad_tok=0,
            nlevels=2,
            max_seq_len=config.max_seq_len,
            max_word_len=config.max_word_len)

        y = data[1][idx]
        # Cap the padded word length at the configured maximum.
        max_word_len = max_word_len if max_word_len < config.max_word_len else config.max_word_len

        global_step, count, _, cost = session.run([model.global_step, model.accuracy, eval_op, model.cost],\
            {model.input_data: x, model.labels: y, model.mask: mask, model.keep_prob: keep_prob, model.chars: chars, model.char_mask: char_mask, model.max_word_len :  max_word_len})
        correct += count
        total += len(idx)
        total_cost += cost
        if is_training:
            prog.update(i + 1, [("train loss", cost), ("step", global_step)])
        # periodic checkpointing by global step
        if global_step % config.save_step == 0 and saver:
            saver.save(session, config.save_path, global_step=global_step)
        del cost
    # NOTE(review): raises ZeroDivisionError if every batch was skipped.
    accuracy = correct / total
    return accuracy, total_cost,
예제 #17
0
    def run_epoch(self, sess, train, dev, epoch):
        """Train for one epoch, then evaluate on the dev set.

        Returns the metrics dict produced by ``self.run_evaluate``.
        """
        bs = self.config.batch_size
        nbatches = (len(train) + bs - 1) // bs
        bar = Progbar(target=nbatches)
        for step, (words, imp_labels) in enumerate(minibatches(train, bs)):

            # the CRF variant consumes digitized labels
            if self.config.model == "lstm_crf":
                imp_labels = [self.config.digitize_labels(lab)
                              for lab in imp_labels]

            feed, _ = self.get_feed_dict(words, imp_labels, self.config.lr,
                                         self.config.dropout)
            _, loss_value = sess.run([self.optimize_, self.loss],
                                     feed_dict=feed)
            bar.update(step + 1, [("train loss", loss_value)])

        result = self.run_evaluate(sess, dev)
        self.logger.info(
            "- dev acc {:04.4f} - f {:04.4f} - rms {:04.4f}".format(
                100 * result['accuracy'], 100 * result['f-score'],
                -1 * result['rms']))
        return result
예제 #18
0
 def run_epoch(self, train, dev, epoch):
     """Run one full pass over the training set, then evaluate on dev.

     :param train: (sentences, tags) pairs; sentences not yet padded
     :param dev: dev dataset with the same structure as train
     :param epoch: index of the current epoch
     :return: accuracy on the dev set
     """
     nbatches = (len(train) + batch_size - 1) // batch_size  # batch count
     prog = Progbar(target=nbatches)
     # First sweep the entire training set.
     for i, (words1, words2, labels) in enumerate(
             minibatches(train, batch_size)):
         fd, _, _ = self.get_feed_dict(words1, words2, labels, lr, dropout)
         _, train_loss = self.sess.run([self.train_op, self.loss],
                                       feed_dict=fd)
         prog.update(i + 1, [("train loss", train_loss)])
     # Then score the dev set.
     metrics = self.run_evaluate(dev)
     return metrics["acc"]  # dev accuracy
예제 #19
0
File: ner_model.py  Project: utkrist/simple
    def run_epoch(self, train, dev, epoch):
        """Performs one complete pass over the train set and evaluate on dev

        NOTE: the ``yield None`` below makes this a GENERATOR — the caller
        must drive it once per minibatch, and the final score is delivered
        as the value of the ``StopIteration`` raised on exhaustion.

        Args:
            train: dataset that yields tuple of sentences, tags
            dev: dataset
            epoch: (int) index of the current epoch

        Returns:
            signed early-stopping score (``early_stop_metric_sign`` times
            the configured dev metric; higher is better after the sign flip)

        """
        # progbar stuff for logging
        batch_size = self.config.batch_size
        nbatches = (len(train) + batch_size - 1) // batch_size
        prog = Progbar(target=nbatches)

        # iterate over dataset
        for i, (words, labels, pred_flags) in enumerate(minibatches(train, batch_size)):
            fd, _ = self.get_feed_dict(words, labels, pred_flags, self.config.lr,
                    self.config.dropout)

            _, train_loss, train_summary = self.sess.run(
                    [self.train_op, self.loss, self.train_merged], feed_dict=fd)

            prog.update(i + 1, [("{}_train loss".format(self.config.name), train_loss)])

            # tensorboard
            if i % 10 == 0:
                self.train_file_writer.add_summary(train_summary, epoch*nbatches + i)
            # hand control back to the caller after every batch
            yield None

        metrics = self.run_evaluate(dev, log_step=epoch*nbatches+i)
        msg = " - ".join(["{} {:04.2f}".format(k, v)
                for k, v in metrics.items()])
        self.logger.info(msg)

        # Score for early stopping
        return self.config.early_stop_metric_sign * metrics[self.config.early_stopping_metric]
예제 #20
0
    def run_epoch(self, sess, train, dev, tags, epoch):
        """One training epoch over ``train`` followed by evaluation on dev.

        Returns (acc, f1) from ``self.evaluate`` on the dev set.
        """
        bs = self.config.batch_size
        nbatches = (len(train) + bs - 1) // bs
        bar = Progbar(target=nbatches)
        for step, (words, labels) in enumerate(minibatches(train, bs)):
            feed, _ = self.get_feed_dict(words, labels, self.config.lr,
                                         self.config.dropout)

            _, loss_value, summary = sess.run(
                [self.train_op, self.loss, self.merged], feed_dict=feed)

            bar.update(step + 1, [("train loss", loss_value)])

            # periodic tensorboard logging
            if step % 10 == 0:
                self.file_writer.add_summary(summary, epoch * nbatches + step)

        acc, f1 = self.evaluate(sess, dev, tags)
        self.logger.info("- dev acc {:04.2f} - f1 {:04.2f}".format(
            100 * acc, 100 * f1))
        return acc, f1
    def run_epoch(self, train, dev, epoch):
        """One training epoch over ``train`` followed by dev evaluation.

        Returns (acc, p, r, f1) from ``self.run_evaluate``.
        """
        nbatches = (len(train) + self.batch_size - 1) // self.batch_size

        progress = Progbar(target=nbatches)  # progress bar

        for step, (sentence, label) in enumerate(
                batch_yield(train, self.batch_size)):
            feed, _ = self.get_feed_dict(sentence, self.weight_dropout_list,
                                         label, self.lr, self.dropout)

            fetched = self.sess.run([self.train_op, self.loss, self.merged],
                                    feed_dict=feed)
            loss_value, summary = fetched[1], fetched[2]

            progress.update(step + 1, [("train loss", loss_value)])

            # tensorboard logging every 10 batches
            if step % 10 == 0:
                self.file_writer.add_summary(summary, epoch * nbatches + step)

        acc, p, r, f1 = self.run_evaluate(dev)
        self.logger.info(
            "- dev acc {:04.2f} - p {:04.2f}- r {:04.2f}- f1 {:04.2f}".format(
                100 * acc, 100 * p, 100 * r, 100 * f1))
        return acc, p, r, f1
    def run_epoch(self, sess, train, dev, epoch):
        '''
        Performs one complete pass over the train set and evaluates on dev

        Args:
            sess: tensorflow session
            train: large BIODataSentence list (training set)
            dev: large BIODataSentence list (dev set)
            epoch: (int) number of the epoch
        '''
        # ceil(len(train) / batch_size) without importing math
        nbatches = (len(train) + self.config.batch_size -
                    1) // self.config.batch_size

        prog = Progbar(target=nbatches)
        for batch_idx, sent_batch in enumerate(
                minibatches(train, self.config.batch_size)):
            fd, _ = self.prepare_feed_dict_optimized(
                bio_data_sentence_batch=sent_batch,
                dropout_keep_prob=self.config.dropout_keep_prob,
                learning_rate=self.config.learning_rate)

            _, train_loss = sess.run([self.train_op, self.loss], feed_dict=fd)

            prog.update(batch_idx + 1, [('train loss', train_loss)])

        acc, f1, mod_p = self.run_evaluate(sess, dev)
        self.logger.info(
            '- dev acc {:04.2f} - f1 {:04.2f} - mod prec {:04.2f}'.format(
                100 * acc, 100 * f1, 100 * mod_p))
        return acc, f1
예제 #23
0
    def run_epoch(self, sess, train_data, dev_data, test_data, epoch):
        """
        :param train_data: contains concatenated sentence(user and system list type) and ground_labels(O, T, X)
        :return: accuracy and f1 score, evaluated on test_data[200:]
        """
        num_batches = (len(train_data) + self.config.batch_size -
                       1) // self.config.batch_size
        prog = Progbar(target=num_batches)

        # BUGFIX: create the summary writer ONCE per epoch; the original
        # constructed a new tf.summary.FileWriter on every minibatch,
        # leaking writers/file handles.
        self.file_writer = tf.summary.FileWriter(self.config.output_path,
                                                 sess.graph)

        # NOTE(review): training iterates over train + dev + the first 200
        # test dialogues while evaluation uses test_data[200:] — confirm
        # the dev/test leakage into the training pass is intentional.
        for i, (concat_utter_list, ground_label) in enumerate(
                minibatches(train_data + dev_data + test_data[:200],
                            self.config.batch_size)):
            input_features = []
            for each_utter_list in concat_utter_list:
                # Embed the user and system turns and stack them into a
                # single feature vector per exchange.
                user_embedding = self.utter_embed.embed_utterance(
                    each_utter_list[0])
                system_embedding = self.utter_embed.embed_utterance(
                    each_utter_list[1])
                input_features.append(
                    np.concatenate((user_embedding, system_embedding),
                                   axis=0))

            input_features = np.array([input_features])

            # Map raw label strings to category indices.
            ground_label_list = np.array(
                [[self.cate_mapping_dict[label.strip()]
                  for label in ground_label]])

            feed_dict = {
                self.input_features: input_features,
                self.ground_label: ground_label_list,
                self.dropout_keep_prob: 0.8  # keep probability for training
            }

            _, train_loss = sess.run([self.train_step, self.loss],
                                     feed_dict=feed_dict)

            prog.update(i + 1, [("train loss", train_loss)])

        js_divergence_value, accuracy, precision_X, recall_X, f1_score_X, precision_B_T, recall_B_T, f1_score_B_T = self.run_evaluate(
            sess, test_data[200:])

        self.logger.info("JS_divergence : {:f}".format(js_divergence_value))
        self.logger.info("accuracy : {:f}".format(accuracy))
        self.logger.info("precision_X : {:f}".format(precision_X))
        self.logger.info("recall_X : {:f}".format(recall_X))
        self.logger.info("f1_score_X : {:f}".format(f1_score_X))

        self.logger.info("precision X + T : {:f}".format(precision_B_T))
        self.logger.info("recall X + T : {:f}".format(recall_B_T))
        self.logger.info("f1_score X + T : {:f}".format(f1_score_B_T))

        return accuracy, f1_score_X
예제 #24
0
File: train.py  Project: johndpope/etagger
def do_train(model, config, train_data, dev_data, test_data):
    """Train `model`, evaluating on dev each epoch and checkpointing the best.

    Each epoch runs minibatch SGD over `train_data`, then evaluates on
    `dev_data`. Whenever the dev token-based F1 improves, the model is saved
    to `config.checkpoint_dir` and immediately evaluated on `test_data`.

    Args:
        model: graph object exposing the placeholders and ops fed/run below
            (input_data_*, output_data, learning_rate, train_op, loss, ...).
        config: settings object (batch_size, epoch, restore, summary_dir,
            checkpoint_dir, class_size, use_crf are read here).
        train_data / dev_data / test_data: datasets exposing parallel
            sentence_* id arrays plus tag-conversion helpers
            (logits_indices_to_tags_seq).
    """
    learning_rate_init=0.001  # initial
    learning_rate_final=0.0001 # final
    learning_rate=learning_rate_init
    intermid_epoch = 20       # after this epoch, change learning rate
    maximum = 0  # best token-based dev F1 seen so far
    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        if config.restore is not None:
            saver.restore(sess, config.restore)
            print('model restored')
        # summary setting
        loss_summary = tf.summary.scalar('loss', model.loss)
        acc_summary = tf.summary.scalar('accuracy', model.accuracy)
        train_summary_op = tf.summary.merge([loss_summary, acc_summary])
        train_summary_dir = os.path.join(config.summary_dir, 'summaries', 'train')
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(config.summary_dir, 'summaries', 'dev')
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)
        # training steps
        for e in range(config.epoch):
            # run epoch: iterate the training set in contiguous batch_size slices
            idx = 0
            nbatches = (len(train_data.sentence_word_ids) + config.batch_size - 1) // config.batch_size
            prog = Progbar(target=nbatches)
            for ptr in range(0, len(train_data.sentence_word_ids), config.batch_size):
                feed_dict={model.input_data_word_ids: train_data.sentence_word_ids[ptr:ptr + config.batch_size],
                           model.input_data_wordchr_ids: train_data.sentence_wordchr_ids[ptr:ptr + config.batch_size],
                           model.input_data_pos_ids: train_data.sentence_pos_ids[ptr:ptr + config.batch_size],
                           model.input_data_etc: train_data.sentence_etc[ptr:ptr + config.batch_size],
                           model.output_data: train_data.sentence_tag[ptr:ptr + config.batch_size],
                           model.learning_rate:learning_rate}
                step, train_summaries, _, train_loss, train_accuracy = \
                           sess.run([model.global_step, train_summary_op, model.train_op, model.loss, model.accuracy], feed_dict=feed_dict)
                prog.update(idx + 1, [('train loss', train_loss), ('train accuracy', train_accuracy)])
                train_summary_writer.add_summary(train_summaries, step)
                idx += 1
            # evaluate on dev data (the whole dev set in one feed)
            feed_dict={model.input_data_word_ids: dev_data.sentence_word_ids,
                       model.input_data_wordchr_ids: dev_data.sentence_wordchr_ids,
                       model.input_data_pos_ids: dev_data.sentence_pos_ids,
                       model.input_data_etc: dev_data.sentence_etc,
                       model.output_data: dev_data.sentence_tag}
            step, dev_summaries, logits, logits_indices, trans_params, output_data_indices, length, dev_loss, dev_accuracy = \
                       sess.run([model.global_step, dev_summary_op, model.logits, model.logits_indices, model.trans_params, model.output_data_indices, model.length, model.loss, model.accuracy], feed_dict=feed_dict)
            print('epoch: %d / %d, step: %d, dev loss: %s, dev accuracy: %s' % (e, config.epoch, step, dev_loss, dev_accuracy))
            dev_summary_writer.add_summary(dev_summaries, step)
            print('dev precision, recall, f1(token): ')
            token_f1 = TokenEval.compute_f1(config.class_size, logits, dev_data.sentence_tag, length)
            # With a CRF head, decode tag paths by Viterbi; otherwise take argmax indices.
            if config.use_crf:
                viterbi_sequences = viterbi_decode(logits, trans_params, length)
                tag_preds = dev_data.logits_indices_to_tags_seq(viterbi_sequences, length)
            else:
                tag_preds = dev_data.logits_indices_to_tags_seq(logits_indices, length)
            tag_corrects = dev_data.logits_indices_to_tags_seq(output_data_indices, length)
            dev_prec, dev_rec, dev_f1 = ChunkEval.compute_f1(tag_preds, tag_corrects)
            print('dev precision, recall, f1(chunk): ', dev_prec, dev_rec, dev_f1)
            chunk_f1 = dev_f1
            # save best model (model selection metric: token-based f1)
            '''
            m = chunk_f1 # slightly lower than token-based f1 for test
            '''
            m = token_f1
            if m > maximum:
                print('new best f1 score!')
                maximum = m
                save_path = saver.save(sess, config.checkpoint_dir + '/' + 'model_max.ckpt')
                print('max model saved in file: %s' % save_path)
                # evaluate the new best model on the test set right away
                feed_dict={model.input_data_word_ids: test_data.sentence_word_ids,
                           model.input_data_wordchr_ids: test_data.sentence_wordchr_ids,
                           model.input_data_pos_ids: test_data.sentence_pos_ids,
                           model.input_data_etc: test_data.sentence_etc,
                           model.output_data: test_data.sentence_tag}
                step, logits, logits_indices, trans_params, output_data_indices, length, test_loss, test_accuracy = \
                           sess.run([model.global_step, model.logits, model.logits_indices, model.trans_params, model.output_data_indices, model.length, model.loss, model.accuracy], feed_dict=feed_dict)
                print('epoch: %d / %d, step: %d, test loss: %s, test accuracy: %s' % (e, config.epoch, step, test_loss, test_accuracy))
                print('test precision, recall, f1(token): ')
                TokenEval.compute_f1(config.class_size, logits, test_data.sentence_tag, length)
                if config.use_crf:
                    viterbi_sequences = viterbi_decode(logits, trans_params, length)
                    tag_preds = test_data.logits_indices_to_tags_seq(viterbi_sequences, length)
                else:
                    tag_preds = test_data.logits_indices_to_tags_seq(logits_indices, length)
                tag_corrects = test_data.logits_indices_to_tags_seq(output_data_indices, length)
                test_prec, test_rec, test_f1 = ChunkEval.compute_f1(tag_preds, tag_corrects)
                print('test precision, recall, f1(chunk): ', test_prec, test_rec, test_f1)
            # learning rate change (switch to the final lr after intermid_epoch)
            if e > intermid_epoch: learning_rate=learning_rate_final
예제 #25
0
def run_epoch(session, config, model, data, eval_op, keep_prob, is_training):
    """Run one pass over `data` and return (accuracy, target_accuracy).

    Args:
        session: active TensorFlow session.
        config: configuration object; `batch_size` and `num_label` are read,
            `sentence_size` is written per batch.
        model: model exposing accuracy/cost/to_print/prediction/text_alphas
            tensors and the placeholders fed below.
        data: tuple of parallel lists; indices 0..6 appear to be documents,
            labels, multi, length, rank, corr and target (inferred from
            usage — confirm against the caller).
        eval_op: extra op run alongside the metrics (e.g. a train step).
        keep_prob: dropout keep probability to feed.
        is_training: currently unused in the active code path.
    """
    n_samples = len(data[0])
    print("Running %d samples:" % (n_samples))
    minibatches = get_minibatches_idx(n_samples,
                                      config.batch_size,
                                      shuffle=False)

    correct = 0.
    target_correct = 0.0
    total = 0
    total_cost = 0
    prog = Progbar(target=len(minibatches))

    corr = data[5]
    for i, inds in enumerate(minibatches):
        x = [data[0][j] for j in inds]
        y = np.array([data[1][j] for j in inds])

        multi = [data[2][j] for j in inds]
        length = [data[3][j] for j in inds]
        rank = [data[4][j] for j in inds]
        target = [data[6][j] for j in inds]

        ranksignal, prosignal = signalMatrix(multi, rank, length,
                                             config.num_label)

        # Actual (unpadded) number of sentences in each document.
        text_mask = [len(s) for s in x]

        # Pad every document to the same number of sentences.
        text_padding = data_helper.padding_text(x, text_mask, padding_word=[0])

        # Flatten the batch into one list of sentences.
        x = []
        for j in range(len(inds)):
            x = x + text_padding[j]
        # BUG FIX: floor division — under Python 3 "/" yields a float, but
        # sentence_size is a count used as an integer size downstream.
        # ("//" is identical to "/" for ints under Python 2, so this is safe.)
        config.sentence_size = len(x) // len(inds)

        # Actual (unpadded) length of each sentence.
        sentence_mask = [len(s) for s in x]
        # Pad every sentence to the same number of tokens.
        sentence_padding = data_helper.padding_text(x,
                                                    sentence_mask,
                                                    padding_word=0)

        x = np.array([np.array(s) for s in sentence_padding])

        count, _, cost, to_print, prediction, text_alphas = session.run(
            [
                model.accuracy, eval_op, model.cost, model.to_print,
                model.prediction, model.text_alphas
            ], {
                model.input_data: x,
                model.labels: y,
                model.sentence_mask: sentence_mask,
                model.text_mask: text_mask,
                model.keep_prob: keep_prob,
                model.signal: prosignal,
                model.length: length,
                model.corr: corr
            })

        # Target prediction accuracy.
        # NOTE(review): np.argmax with no axis returns one flat index for the
        # whole batch's attention matrix; this looks like it should be
        # per-example (axis=1) — confirm the intended shape of text_alphas.
        target_predict = np.argmax(text_alphas)
        target_count = np.equal(target_predict, target)
        target_correct = target_correct + target_count

        correct += count
        total += len(inds)
        total_cost += cost
        prog.update(i + 1, [("train loss", cost)])

    print("Total loss:")
    print(total_cost)
    accuracy = correct / total
    target_accuracy = target_correct / total
    return accuracy, target_accuracy
예제 #26
0
def train():
    """Build the seq2seq motion-prediction graph and train it.

    Reads the module-level globals `config`, `checkpoint_dir`, `train_set`
    and `test_set`, plus the `models` / `loss_functions` / `data_utils`
    helpers. Saves a checkpoint whenever the mean validation loss improves
    and stops early after `config.early_stop` epochs without improvement.
    """
    print("Training")

    # tf Graph input
    x = tf.placeholder(dtype=tf.float32, shape=[None, config.input_window_size - 1, config.input_size], name="input_sequence")
    y = tf.placeholder(dtype=tf.float32, shape=[None, config.output_window_size, config.input_size], name="raw_labels")
    dec_in = tf.placeholder(dtype=tf.float32, shape=[None, config.output_window_size, config.input_size], name="decoder_input")

    # Split labels into a per-output-timestep list, as the loss expects.
    labels = tf.transpose(y, [1, 0, 2])
    labels = tf.reshape(labels, [-1, config.input_size])
    labels = tf.split(labels, config.output_window_size, axis=0, name='labels')

    # Define model
    prediction = models.seq2seq(x, dec_in, config, True)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.6
    sess_config.allow_soft_placement = True
    sess_config.log_device_placement = False
    sess = tf.Session(config=sess_config)

    # Define cost function. getattr-based dispatch replaces the original
    # eval() call: identical behavior for valid config.loss names, without
    # executing an arbitrary string.
    loss = getattr(loss_functions, config.loss + '_loss')(prediction, labels, config)

    # Add a summary for the loss
    train_loss = tf.summary.scalar('train_loss', loss)
    valid_loss = tf.summary.scalar('valid_loss', loss)

    # Defining training parameters
    optimizer = tf.train.AdamOptimizer(config.learning_rate)

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Gradient Clipping. BUG FIX: actually apply the *clipped* gradients.
    # The original discarded the apply_gradients() op and built train_op via
    # optimizer.minimize(), which recomputes unclipped gradients, so the
    # clipping had no effect.
    grads = tf.gradients(loss, tf.trainable_variables())
    grads, _ = tf.clip_by_global_norm(grads, config.max_grad_norm)
    train_op = optimizer.apply_gradients(zip(grads, tf.trainable_variables()),
                                         global_step=global_step)

    saver = tf.train.Saver(max_to_keep=10)
    train_writer = tf.summary.FileWriter("./log", sess.graph)

    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    # Obtain total training parameters
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('Total training parameters: ' + str(total_parameters))

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    saved_epoch = 0
    train_size = config.training_size
    valid_size = config.validation_size
    best_val_loss = float('inf')

    # BUG FIX: logical "and" instead of bitwise "&" — "&" does not
    # short-circuit and raises TypeError when config.restore is not bool/int.
    if config.restore and os.path.exists(checkpoint_dir + 'checkpoint'):
        with open(checkpoint_dir + 'checkpoint') as f:
            content = f.readlines()
        # The first line of TF's checkpoint index names the latest model;
        # recover its epoch number from it.
        saved_epoch = int(re.search(r'\d+', content[0]).group())
        model_name = checkpoint_dir + "Epoch_" + str(saved_epoch)
        saver.restore(sess, model_name)

        # Re-establish the restored model's mean validation loss so the
        # "best so far" comparison stays meaningful after a restart.
        v_loss_mean = 0.0
        for i in range(valid_size):
            batch_x, batch_dec_in, batch_y = data_utils.get_batch(config, test_set)
            v_loss, valid_summary = sess.run([loss, valid_loss], feed_dict={x: batch_x, y: batch_y, dec_in: batch_dec_in})
            v_loss_mean = v_loss_mean*i/(i+1) + v_loss/(i+1)  # running mean
        best_val_loss = v_loss_mean

        print("Restored session from Epoch ", str(saved_epoch))
        print("Best Validation Loss: ", best_val_loss, "\n")

    print("________________________________________________________________")

    best_val_epoch = saved_epoch

    for j in range(saved_epoch, config.max_epoch):

        print("Epoch ", j+1)
        prog = Progbar(target=train_size)
        prog_valid = Progbar(target=valid_size)

        for i in range(train_size):
            batch_x, batch_dec_in, batch_y = data_utils.get_batch(config, train_set)
            current_cost, train_summary, _ = sess.run([loss, train_loss, train_op], feed_dict={x: batch_x, y: batch_y, dec_in: batch_dec_in})

            train_writer.add_summary(train_summary, j*train_size+i)
            prog.update(i+1, [("Training Loss", current_cost)])

        # Validation: incremental running mean of the per-batch losses.
        v_loss_mean = 0.0
        for i in range(valid_size):
            batch_x, batch_dec_in, batch_y = data_utils.get_batch(config, test_set)
            v_loss, valid_summary = sess.run([loss, valid_loss], feed_dict={x: batch_x, y: batch_y, dec_in: batch_dec_in})
            v_loss_mean = v_loss_mean*i/(i+1) + v_loss/(i+1)
            prog_valid.update(i + 1, [("Validation Loss", v_loss)])
            train_writer.add_summary(valid_summary, j*valid_size+i)

        # Checkpoint when validation loss improves.
        if v_loss_mean < best_val_loss:
            model_name = checkpoint_dir + "Epoch_" + str(j+1)
            best_val_loss = v_loss_mean
            best_val_epoch = j+1
            saver.save(sess, model_name)

        print("Current Best Epoch: ", best_val_epoch, ", Best Validation Loss: ", best_val_loss, "\n")

        # Early stopping after config.early_stop epochs without improvement.
        if j+1 - best_val_epoch > config.early_stop:
            break
예제 #27
0
파일: model.py 프로젝트: vutran0230/DOER
    def run_epoch(self, sess, train, dev, vocab_words, vocab_poss,
                  vocab_chunks, vocab_aspect_tags, vocab_polarity_tags,
                  vocab_joint_tags, epoch):
        """
        Performs one complete pass over the train set and evaluate on dev
        Args:
                sess: tensorflow session
                train: dataset that yields tuple of sentences, tags
                dev: dataset
                vocab_aspect_tags: {tag: index} dictionary
                epoch: (int) number of the epoch
        Returns:
                tuple of aspect metrics, mean train loss, polarity metrics
                and dev loss (zeros when validating is skipped for twitter
                datasets).
        """
        self.config.istrain = True  # set to train first, #batch normalization#
        losses = []
        # BUG FIX: floor division — under Python 3 "/" yields a float, but
        # nbatches is the Progbar target and part of the integer summary step
        # "epoch * nbatches + i". ("//" matches "/" for ints on Python 2.)
        nbatches = (len(train) + self.config.batch_size -
                    1) // self.config.batch_size
        prog = Progbar(target=nbatches)
        for i, (words, poss, chunks, labels_aspect, labels_polarity,
                labels_joint) in enumerate(
                    minibatches_for_sequence(train, self.config.batch_size)):
            fd, sequence_lengths = self.get_feed_dict(
                words,
                poss,
                chunks,
                labels_aspect,
                labels_polarity,
                labels_joint,
                self.config.lr,
                self.config.dropout,
                vocab_aspect_tags=vocab_aspect_tags)

            _, lr, train_loss, summary = sess.run(
                [self.train_op, self.learning_rate, self.loss, self.merged],
                feed_dict=fd)
            losses.append(train_loss)

            # Optional per-batch progress logging of the auxiliary losses.
            if self.config.show_process_logs:
                print_mess = [("train loss", train_loss)]
                if self.config.use_labels_length:
                    label_length_loss = sess.run(self.label_length_loss,
                                                 feed_dict=fd)
                    print_mess.append(("label_length_loss", label_length_loss))
                if self.config.use_mpqa:
                    mpqa_loss = sess.run(self.mpqa_loss, feed_dict=fd)
                    print_mess.append(("mpqa_loss", mpqa_loss))
                print_mess.append(("lr", lr))

                prog.update(i + 1, print_mess)

            # tensorboard: log every other batch
            if i % 2 == 0:
                self.file_writer.add_summary(summary, epoch * nbatches + i)

        if self.config.data_sets.startswith("twitter"):
            # No dev split for twitter datasets — return zeroed metrics.
            aspect_p, aspect_r, aspect_f1, aspect_test_acc, \
            polarity_p, polarity_r, polarity_f1, polarity_test_acc, dev_loss = 0, 0, 0, 0, 0, 0, 0, 0, 0
            self.logger.info("Ignore validating without corresponding data~")
        else:
            aspect_p, aspect_r, aspect_f1, aspect_test_acc, \
            polarity_p, polarity_r, polarity_f1, polarity_test_acc, dev_loss = self.run_evaluate(
                sess, dev, vocab_aspect_tags, vocab_polarity_tags, vocab_joint_tags, vocab_words, is_dev=True)

            self.logger.info(
                "- aspect_dev precision {:04.2f} - aspect_dev recall {:04.2f} - "
                "aspect_dev f1 {:04.2f} - aspect_dev acc {:04.2f}".format(
                    100 * aspect_p, 100 * aspect_r, 100 * aspect_f1,
                    100 * aspect_test_acc))
            self.logger.info(
                "- polarity_dev precision {:04.2f} - polarity_dev recall {:04.2f} - "
                "polarity_dev f1 {:04.2f} - polarity_dev acc {:04.2f}".format(
                    100 * polarity_p, 100 * polarity_r, 100 * polarity_f1,
                    100 * polarity_test_acc))

        return aspect_p, aspect_r, aspect_f1, sum(losses) / len(
            losses), polarity_p, polarity_r, polarity_f1, dev_loss
예제 #28
0
    def train_stepwise(self, train, dev, train_eval):
        """Performs training with early stopping and lr exponential decay

        Runs until `self.config.max_updates` minibatch updates, evaluating
        and checkpointing every `self.config.updates_per_epoch` updates.

        Args:
            train: dataset that yields tuple of (sentences, tags)
            dev: dataset
            train_eval: dataset used to report training-set accuracy

        Returns:
            the best dev accuracy achieved
        """
        best_score = 0
        nepoch_no_imprv = 0
        updates = 0
        epoch_train_loss = 0

        prog = Progbar(target=self.config.updates_per_epoch)

        while True:
            for words, labels in minibatches(train, self.config.batch_size):
                fd, _ = self.get_feed_dict(True,
                                           words,
                                           labels,
                                           lr=self.config.lr)
                _, train_loss = self.sess.run([self.train_op, self.loss],
                                              feed_dict=fd)
                prog.update((updates % self.config.updates_per_epoch) + 1,
                            values=[("train loss", train_loss)])
                epoch_train_loss += train_loss
                updates += 1

                # One "epoch" boundary every updates_per_epoch updates.
                if updates % self.config.updates_per_epoch == 0:
                    acc_train = self.evaluate(train_eval)
                    acc_test = self.evaluate(dev)

                    # BUG FIX: floor division — "/" produces a float under
                    # Python 3, but epoch is an integer counter used for
                    # logging and step-decay arithmetic.
                    epoch = updates // self.config.updates_per_epoch
                    prog.update(self.config.updates_per_epoch,
                                epoch, [("train loss", train_loss)],
                                exact=[("dev acc", acc_test),
                                       ("train acc", acc_train),
                                       ("lr", self.config.lr)])
                    self.write_epoch_results(
                        epoch, acc_train, acc_test,
                        epoch_train_loss / self.config.updates_per_epoch)

                    epoch_train_loss = 0

                    # early stopping and saving best parameters
                    if acc_test >= best_score:
                        nepoch_no_imprv = 0
                        self.save_session()
                        best_score = acc_test
                    else:
                        nepoch_no_imprv += 1
                        if nepoch_no_imprv >= self.config.nepoch_no_imprv:
                            return best_score

                    # apply decay
                    if self.config.lr_decay_strategy == "on-no-improvement":
                        if acc_test < best_score:
                            self.config.lr *= self.config.lr_decay
                    elif self.config.lr_decay_strategy == "exponential":
                        self.config.lr *= self.config.lr_decay
                    elif self.config.lr_decay_strategy == "step":
                        self.config.lr = self.config.step_decay_init_lr * math.pow(
                            self.config.step_decay_drop,
                            math.floor(
                                epoch / self.config.step_decay_epochs_drop))
                    elif self.config.lr_decay_strategy is None:
                        pass
                    else:
                        raise ValueError("Invalid 'decay_strategy' setting: " +
                                         self.config.lr_decay_strategy)

                    # Fresh progress bar for the next pseudo-epoch.
                    if updates < self.config.max_updates:
                        prog = Progbar(target=self.config.updates_per_epoch)

                if updates >= self.config.max_updates:
                    return best_score
예제 #29
0
    def train_epochwise(self, train, dev, train_eval):
        """Performs training with early stopping and lr decay

        Runs up to `self.config.max_epochs` full passes over `train`,
        evaluating on `dev` after each epoch, checkpointing on improvement
        and stopping early after `self.config.nepoch_no_imprv` epochs
        without improvement.

        Returns:
            the best dev accuracy achieved
        """
        updates, epoch, best_score, nepoch_no_imprv = 0, 0, 0, 0
        batch_size = self.config.batch_size
        max_epochs = self.config.max_epochs
        nbatches = (len(train) + batch_size - 1) // batch_size

        # BUG FIX: acc_test was previously referenced by the
        # "on-no-improvement" decay branch before the first evaluation ever
        # ran, raising NameError. Initialise it so the check is always
        # well-defined (equal to best_score -> no decay until evaluated).
        acc_test = best_score

        while epoch < max_epochs:
            # Run one epoch

            epoch_time = time()
            train_time = time()

            epoch_train_loss = 0
            batch_num = 0  # renamed from "iter", which shadowed the builtin
            prog = Progbar(target=nbatches)

            for feed_dict in self.iter_prebuilt_feed_dict_batches(
                    train, batch_size):
                fd, _ = self.get_final_feed_dict(True,
                                                 feed_dict,
                                                 lr=self.config.lr)
                _, train_loss = self.sess.run([self.train_op, self.loss],
                                              feed_dict=fd)
                epoch_train_loss += train_loss
                updates += 1

                # Learning-rate decay every lr_decay_step updates.
                if updates % self.config.lr_decay_step == 0:
                    if self.config.lr_decay_strategy == "on-no-improvement":
                        if acc_test < best_score:
                            self.config.lr *= self.config.lr_decay
                    elif self.config.lr_decay_strategy == "exponential":
                        self.config.lr *= self.config.lr_decay
                    elif self.config.lr_decay_strategy == "step":
                        self.config.lr = self.config.step_decay_init_lr * \
                                         math.pow(self.config.step_decay_drop, math.floor(
                                             (epoch) / self.config.step_decay_epochs_drop))
                    elif self.config.lr_decay_strategy is None:
                        pass
                    else:
                        raise ValueError("Invalid 'decay_strategy' setting: " +
                                         self.config.lr_decay_strategy)

                prog.update(batch_num + 1, values=[("train loss", train_loss)])
                batch_num += 1

            train_time = time() - train_time

            # evaluate epoch
            acc_train = self.evaluate(train_eval)

            eval_time = time()
            acc_test = self.evaluate(dev)
            eval_time = time() - eval_time

            epoch_time = time() - epoch_time

            # log epoch
            prog.update(batch_num + 1,
                        epoch, [("train loss", train_loss)],
                        exact=[("dev acc", acc_test), ("train acc", acc_train),
                               ("lr", self.config.lr)])
            self.write_epoch_results(epoch,
                                     acc_train,
                                     acc_test,
                                     epoch_train_loss / batch_num,
                                     nbatches,
                                     epoch_time=epoch_time,
                                     train_time=train_time,
                                     eval_time=eval_time)

            # early stopping and saving checkpoint
            if acc_test >= best_score:
                nepoch_no_imprv = 0
                self.save_session()
                best_score = acc_test
            else:
                nepoch_no_imprv += 1
                if nepoch_no_imprv >= self.config.nepoch_no_imprv:
                    self.logger.info(
                        "- early stopping {} epochs without improvement".
                        format(nepoch_no_imprv))
                    break
            epoch += 1
        return best_score