Example #1
 def train(self, dataset):
     self.logger.info("Start training...")
     best_score, no_imprv_epoch, cur_step = -np.inf, 0, 0
     for epoch in range(1, self.cfg.epochs + 1):
         self.logger.info('Epoch {}/{}:'.format(epoch, self.cfg.epochs))
         prog = Progbar(target=dataset.num_batches)
         for i in range(dataset.num_batches):
             cur_step += 1
             data = dataset.next_batch()
             feed_dict = self._get_feed_dict(data, training=True)
             _, train_loss = self.sess.run([self.train_op, self.loss],
                                           feed_dict=feed_dict)
             prog.update(i + 1, [("Global Step", int(cur_step)),
                                 ("Train Loss", train_loss)])
         # evaluate
         score = self.evaluate_data(dataset.dev_batches(), name="dev")
         self.evaluate_data(dataset.test_batches(), name="test")
         if score > best_score:
             best_score, no_imprv_epoch = score, 0
             self.save_session(epoch)
             self.logger.info(
                 ' -- new BEST score on dev dataset: {:04.2f}'.format(
                     best_score))
         else:
             no_imprv_epoch += 1
             if self.cfg.no_imprv_tolerance is not None and no_imprv_epoch >= self.cfg.no_imprv_tolerance:
                 self.logger.info(
                     'early stop at {}th epoch without improvement'.format(
                         epoch))
                 self.logger.info(
                     'best score on dev set: {}'.format(best_score))
                 break
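
The patience-based early stopping used above (and in most of the examples below) is worth isolating. A minimal, self-contained sketch of the same pattern; the names should_stop and state are illustrative, not from the original code:

import numpy as np

def should_stop(score, state, patience):
    # track the best dev score; stop after `patience` epochs without improvement
    if score > state["best"]:
        state["best"], state["bad_epochs"] = score, 0
        return False
    state["bad_epochs"] += 1
    return patience is not None and state["bad_epochs"] >= patience

state = {"best": -np.inf, "bad_epochs": 0}
for epoch, dev_score in enumerate([0.71, 0.74, 0.73, 0.72], start=1):
    if should_stop(dev_score, state, patience=2):
        break  # triggers at epoch 4: two consecutive epochs without improvement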
Example #2
 def train(self, dataset):
     self.logger.info("Start training...")
     best_f1, no_imprv_epoch, init_lr, lr, cur_step = -np.inf, 0, self.cfg.lr, self.cfg.lr, 0
     for epoch in range(1, self.cfg.epochs + 1):
         self.logger.info("Epoch {}/{}:".format(epoch, self.cfg.epochs))
         prog = Progbar(target=dataset.get_num_batches())
         for i, data in enumerate(dataset.get_data_batches()):
             cur_step += 1
             feed_dict = self._get_feed_dict(data, is_train=True, lr=lr)
             _, train_loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
             prog.update(i + 1, [("Global Step", int(cur_step)), ("Train Loss", train_loss)])
         # learning rate decay
         if self.cfg.use_lr_decay:
             if self.cfg.decay_step:
                 lr = max(init_lr / (1.0 + self.cfg.lr_decay * epoch / self.cfg.decay_step), self.cfg.minimal_lr)
         # evaluate
         score = self.evaluate(dataset.get_data_batches("dev"), name="dev")
         self.evaluate(dataset.get_data_batches("test"), name="test")
         if score["FB1"] > best_f1:
             best_f1, no_imprv_epoch = score["FB1"], 0
             self.save_session(epoch)
             self.logger.info(" -- new BEST score on dev dataset: {:04.2f}".format(best_f1))
         else:
             no_imprv_epoch += 1
             if self.cfg.no_imprv_tolerance is not None and no_imprv_epoch >= self.cfg.no_imprv_tolerance:
                 self.logger.info("early stop at {}th epoch without improvement".format(epoch))
                 self.logger.info("best score on dev set: {}".format(best_f1))
                 break
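
The decay rule above is inverse-time decay with a floor. A small sketch of the formula with made-up values, assuming the same cfg field semantics:

def inverse_time_decay(init_lr, epoch, lr_decay, decay_step, minimal_lr):
    # lr shrinks as 1 / (1 + lr_decay * epoch / decay_step), clipped at minimal_lr
    return max(init_lr / (1.0 + lr_decay * epoch / decay_step), minimal_lr)

# init_lr=0.01, lr_decay=0.05, decay_step=1, minimal_lr=1e-4:
# epoch 1 -> ~0.00952, epoch 10 -> ~0.00667, epoch 100 -> ~0.00167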
Example #3
    def train(self, train_words, train_chars, train_labels, test_words,
              test_chars, test_labels):
        global_test_acc, global_step, lr = 0.0, 0, self.cfg.lr
        num_batches = math.ceil(float(len(train_words)) / self.cfg.batch_size)

        self.logger.info("start training...")
        for epoch in range(1, self.cfg.epochs + 1):
            self.logger.info("Epoch {}/{}:".format(epoch, self.cfg.epochs))
            train_words, train_chars, train_labels = sklearn.utils.shuffle(
                train_words, train_chars, train_labels)
            prog = Progbar(target=num_batches)

            for i, (b_words, b_seq_len, b_chars, b_char_seq_len,
                    b_labels) in enumerate(
                        batch_iter(train_words, train_chars, train_labels,
                                   self.cfg.batch_size)):
                global_step += 1
                batch_labels = []
                for j in range(self.num_classifier):
                    ecoc_array = self.nary_ecoc[:, j]
                    b_lbs = remap_labels(b_labels.copy(), ecoc_array)
                    b_lbs = dense_to_one_hot(b_lbs, self.num_class)
                    batch_labels.append(b_lbs)
                feed_dict = self.get_feed_dict(b_words,
                                               b_seq_len,
                                               b_chars,
                                               b_char_seq_len,
                                               batch_labels,
                                               lr=lr,
                                               training=True)
                _, pred_labels, loss = self.sess.run(
                    [self.train_op, self.pred_labels, self.loss],
                    feed_dict=feed_dict)
                acc = compute_ensemble_accuracy(pred_labels, self.nary_ecoc,
                                                b_labels)
                prog.update(i + 1, [("Global Step", global_step),
                                    ("Train Loss", loss),
                                    ("Train Acc", acc * 100)])
            accuracy, _ = self.test(test_words,
                                    test_chars,
                                    test_labels,
                                    batch_size=200,
                                    print_info=True,
                                    restore=False)

            if accuracy > global_test_acc:
                global_test_acc = accuracy
                self.save_session(epoch)
            lr = self.cfg.lr / (1 + epoch * self.cfg.lr_decay)
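
Example #3 (and #11 below) trains an N-ary ECOC ensemble: column j of the code matrix nary_ecoc remaps the original labels into a smaller num_class-way problem for classifier j. The repo's remap_labels is not shown; a hypothetical equivalent to illustrate the idea:

import numpy as np

def remap_labels(labels, ecoc_column):
    # replace each original label y with its code word ecoc_column[y]
    return ecoc_column[labels]

nary_ecoc = np.array([[0, 1], [1, 0], [2, 2], [0, 2]])  # 4 classes -> 2 ternary classifiers
labels = np.array([0, 2, 3])
print(remap_labels(labels, nary_ecoc[:, 0]))  # [0 2 0]
print(remap_labels(labels, nary_ecoc[:, 1]))  # [1 2 2]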
Example #4
    def train_epoch(self, train_set, valid_data, epoch):
        num_batches = len(train_set)
        prog = Progbar(target=num_batches)
        total_cost, total_samples = 0, 0

        for i, batch in enumerate(train_set):
            feed_dict = self._get_feed_dict(batch, is_train=True, keep_prob=self.cfg["keep_prob"], lr=self.cfg["lr"])
            _, train_loss, summary = self.sess.run([self.train_op, self.loss, self.summary], feed_dict=feed_dict)
            cur_step = (epoch - 1) * num_batches + (i + 1)
            total_cost += train_loss
            total_samples += np.array(batch["words"]).shape[0]
            prog.update(i + 1, [("Global Step", int(cur_step)), ("Train Loss", train_loss),
                                ("Perplexity", np.exp(total_cost / total_samples))])
            self.train_writer.add_summary(summary, cur_step)

            if i % 100 == 0:
                valid_feed_dict = self._get_feed_dict(valid_data)
                valid_summary = self.sess.run(self.summary, feed_dict=valid_feed_dict)
                self.test_writer.add_summary(valid_summary, cur_step)
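
The perplexity reported above is exp(total loss / total samples), which is only a true perplexity if self.loss is a sum over the batch (a per-batch mean would need to be weighted by batch size instead). The arithmetic itself:

import numpy as np

total_cost, total_samples = 415.9, 100  # made-up running sums
print(np.exp(total_cost / total_samples))  # ~64.0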
Example #5
 def train(self, train_dataset, test_dataset):
     global_test_acc = 0.0
     global_step = 0
     test_imgs, test_labels = test_dataset.images, test_dataset.labels
     self.logger.info("start training...")
     for epoch in range(1, self.epochs + 1):
         self.logger.info("Epoch {}/{}:".format(epoch, self.epochs))
         num_batches = train_dataset.num_examples // self.batch_size
         prog = Progbar(target=num_batches)
         prog.update(0, [("Global Step", 0), ("Train Loss", 0.0), ("Train Acc", 0.0), ("Test Loss", 0.0),
                         ("Test Acc", 0.0)])
         for i in range(num_batches):
             global_step += 1
             train_imgs, train_labels = train_dataset.next_batch(self.batch_size)
             feed_dict = {self.inputs: train_imgs, self.labels: train_labels, self.training: True}
             _, loss, acc = self.sess.run([self.train_op, self.cost, self.accuracy], feed_dict=feed_dict)
             if global_step % 100 == 0:
                 feed_dict = {self.inputs: test_imgs, self.labels: test_labels, self.training: False}
                 test_loss, test_acc = self.sess.run([self.cost, self.accuracy], feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(global_step)), ("Train Loss", loss), ("Train Acc", acc),
                                     ("Test Loss", test_loss), ("Test Acc", test_acc)])
                 if test_acc > global_test_acc:
                     global_test_acc = test_acc
                     self.save_session(global_step)
             else:
                 prog.update(i + 1, [("Global Step", int(global_step)), ("Train Loss", loss), ("Train Acc", acc)])
         feed_dict = {self.inputs: test_imgs, self.labels: test_labels, self.training: False}
         test_loss, test_acc = self.sess.run([self.cost, self.accuracy], feed_dict=feed_dict)
         self.logger.info("Epoch: {}, Global Step: {}, Test Loss: {}, Test Accuracy: {}".format(
             epoch, global_step, test_loss, test_acc))
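
Examples #5, #7, and #11 feed the whole test set through sess.run in one call, which is fine for small images but can exhaust memory on larger datasets. A hedged alternative that scores in mini-batches (attribute names mirror the example above; the simple mean slightly over-weights a final partial batch):

def evaluate_in_batches(sess, model, imgs, labels, batch_size=200):
    losses, accs = [], []
    for start in range(0, len(imgs), batch_size):
        feed_dict = {model.inputs: imgs[start:start + batch_size],
                     model.labels: labels[start:start + batch_size],
                     model.training: False}
        loss, acc = sess.run([model.cost, model.accuracy], feed_dict=feed_dict)
        losses.append(loss)
        accs.append(acc)
    return sum(losses) / len(losses), sum(accs) / len(accs)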
Example #6
 def train(self, label_dataset, partial_dataset, unlabeled_dataset):
     best_f1, no_imprv_epoch, lr, cur_step = -np.inf, 0, self.cfg.lr, 0
     loss_weight = self._compute_loss_weights(label_dataset.get_dataset_size(), partial_dataset.get_dataset_size(),
                                              unlabeled_dataset.get_dataset_size(), method=self.cfg.loss_method)
     self.logger.info("Start training...")
     for epoch in range(1, self.cfg.epochs + 1):
         self.logger.info("Epoch {}/{}:".format(epoch, self.cfg.epochs))
         batches = self._create_batches(label_dataset.get_num_batches(), partial_dataset.get_num_batches(),
                                        unlabeled_dataset.get_num_batches())
         prog = Progbar(target=len(batches))
         prog.update(0, [("Global Step", int(cur_step)), ("Label Loss", 0.0), ("Partial Loss", 0.0),
                         ("AE Loss", 0.0), ("AE Acc", 0.0)])
         for i, batch_name in enumerate(batches):
             cur_step += 1
             if batch_name == "label":
                 data = label_dataset.get_next_train_batch()
                 feed_dict = self._get_feed_dict(data, loss_weight=loss_weight[0], is_train=True, lr=lr)
                 _, cost = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(cur_step)), ("Label Loss", cost)])
             if batch_name == "partial":
                 data = partial_dataset.get_next_train_batch()
                 feed_dict = self._get_feed_dict(data, loss_weight=loss_weight[1], is_train=True, lr=lr)
                 _, cost = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(cur_step)), ("Partial Loss", cost)])
             if batch_name == "unlabeled":
                 data = unlabeled_dataset.get_next_train_batch()
                 feed_dict = self._get_feed_dict(data, loss_weight=loss_weight[2], is_train=True, lr=lr)
                 _, cost, acc = self.sess.run([self.ae_train_op, self.ae_loss, self.ae_acc], feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(cur_step)), ("AE Loss", cost), ("AE Acc", acc)])
         # learning rate decay
         if self.cfg.use_lr_decay:
             if self.cfg.decay_step:
                 lr = max(self.cfg.lr / (1.0 + self.cfg.lr_decay * epoch / self.cfg.decay_step), self.cfg.minimal_lr)
         self.evaluate(label_dataset.get_data_batches("dev"), name="dev")
         score = self.evaluate(label_dataset.get_data_batches("test"), name="test")
         if score["FB1"] > best_f1:
             best_f1, no_imprv_epoch = score["FB1"], 0
             self.save_session(epoch)
             self.logger.info(" -- new BEST score on test dataset: {:04.2f}".format(best_f1))
         else:
             no_imprv_epoch += 1
             if self.cfg.no_imprv_tolerance is not None and no_imprv_epoch >= self.cfg.no_imprv_tolerance:
                 self.logger.info("early stop at {}th epoch without improvement".format(epoch))
                 self.logger.info("best score on test set: {}".format(best_f1))
                 break
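
_create_batches is not shown in Example #6, but interleaving the three sources in proportion to their batch counts is the natural construction. A hypothetical sketch under that assumption:

import random

def create_batches(n_label, n_partial, n_unlabeled, seed=None):
    # one schedule entry per batch, shuffled so the three sources interleave
    schedule = (["label"] * n_label + ["partial"] * n_partial +
                ["unlabeled"] * n_unlabeled)
    random.Random(seed).shuffle(schedule)
    return schedule

# e.g. create_batches(2, 1, 1) -> ['unlabeled', 'label', 'partial', 'label'] (order varies)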
Example #7
    def train(self, x_train, y_train, x_test, y_test, batch_size=200, epochs=10):
        x_train, x_test = self.normalize(x_train, x_test)
        y_train = keras.utils.to_categorical(y_train, self.num_class)
        y_test = keras.utils.to_categorical(y_test, self.num_class)

        self.logger.info("data augmentation...")
        datagen = ImageDataGenerator(featurewise_center=True, samplewise_center=False, horizontal_flip=True, cval=0.0,
                                     featurewise_std_normalization=False, preprocessing_function=None, rescale=None,
                                     samplewise_std_normalization=False, zca_whitening=False, zca_epsilon=1e-06,
                                     rotation_range=15, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.0,
                                     zoom_range=0.0, channel_shift_range=0.0, fill_mode='nearest', vertical_flip=False,
                                     data_format="channels_last")
        datagen.fit(x_train)
        x_aug, y_aug = x_train.copy(), y_train.copy()
        x_aug = datagen.flow(x_aug, np.zeros(x_train.shape[0]), batch_size=x_train.shape[0], shuffle=False).next()[0]
        x_train, y_train = np.concatenate((x_train, x_aug)), np.concatenate((y_train, y_aug))
        self.logger.info("start training...")
        global_step, lr, global_test_acc = 0, self.learning_rate, 0.0
        num_batches = x_train.shape[0] // batch_size
        for epoch in range(1, epochs + 1):
            self.logger.info("Epoch {}/{}:".format(epoch, epochs))
            x_train, y_train = utils.shuffle(x_train, y_train, random_state=epoch)  # reshuffle each epoch (a fixed seed would repeat the same permutation)
            prog = Progbar(target=num_batches)
            prog.update(0, [("Global Step", 0), ("Train Loss", 0.0), ("Train Acc", 0.0), ("Test Loss", 0.0),
                            ("Test Acc", 0.0)])
            for i, (batch_imgs, batch_labels) in enumerate(batch_dataset(x_train, y_train, batch_size)):
                global_step += 1
                feed_dict = {self.inputs: batch_imgs, self.labels: batch_labels, self.training: True, self.lr: lr}
                _, loss, acc = self.sess.run([self.train_op, self.cost, self.accuracy], feed_dict=feed_dict)
                if global_step % 200 == 0:
                    feed_dict = {self.inputs: x_test, self.labels: y_test, self.training: False}
                    test_loss, test_acc = self.sess.run([self.cost, self.accuracy], feed_dict=feed_dict)
                    prog.update(i + 1, [("Global Step", int(global_step)), ("Train Loss", loss), ("Train Acc", acc),
                                        ("Test Loss", test_loss), ("Test Acc", test_acc)])
                    if test_acc > global_test_acc:
                        global_test_acc = test_acc
                        self.save_session(global_step)
                else:
                    prog.update(i + 1, [("Global Step", int(global_step)), ("Train Loss", loss), ("Train Acc", acc)])
            if epoch > 10:
                lr = self.learning_rate / (1 + (epoch - 10) * self.lr_decay)
            feed_dict = {self.inputs: x_test, self.labels: y_test, self.training: False}
            test_loss, test_acc = self.sess.run([self.cost, self.accuracy], feed_dict=feed_dict)
            self.logger.info("Epoch: {}, Global Step: {}, Test Loss: {}, Test Accuracy: {}".format(
                epoch, global_step, test_loss, test_acc))
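
Example #7 augments offline: one augmented copy of every training image is materialized in a single flow pass and concatenated with the originals (labels reused unchanged), doubling the dataset before training starts. A minimal sketch of the same idea with made-up shapes:

import numpy as np
from keras.preprocessing.image import ImageDataGenerator

x = np.random.rand(100, 32, 32, 3).astype("float32")
y = np.eye(10)[np.random.randint(0, 10, 100)]

datagen = ImageDataGenerator(rotation_range=15, horizontal_flip=True)
x_aug = datagen.flow(x, batch_size=len(x), shuffle=False).next()  # one full augmented pass, order preserved
x_all, y_all = np.concatenate((x, x_aug)), np.concatenate((y, y))
print(x_all.shape)  # (200, 32, 32, 3)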
Example #8
 def train(self, src_dataset, tgt_dataset):
     self.logger.info("Start training...")
     best_score, no_imprv_epoch, src_lr, tgt_lr, cur_step = -np.inf, 0, self.cfg.lr, self.cfg.lr, 0
     for epoch in range(1, self.cfg.epochs + 1):
         self.logger.info("Epoch {}/{}:".format(epoch, self.cfg.epochs))
         batches = self._arrange_batches(src_dataset.num_batches, tgt_dataset.num_batches, self.cfg.mix_rate)
         prog = Progbar(target=len(batches))
         prog.update(0, [("Global Step", int(cur_step)), ("Source Train Loss", 0.0), ("Target Train Loss", 0.0)])
         for i, batch_name in enumerate(batches):
             cur_step += 1
             if batch_name == "src":
                 data = src_dataset.next_batch()
                 domain_labels = [[1, 0]] * data["batch_size"]
                 feed_dict = self._get_feed_dict(src_data=data, tgt_data=None, domain_labels=domain_labels,
                                                 training=True)
                 _, src_cost = self.sess.run([self.src_train_op, self.src_loss], feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(cur_step)), ("Source Train Loss", src_cost)])
             else:  # "tgt"
                 data = tgt_dataset.next_batch()
                 domain_labels = [[0, 1]] * data["batch_size"]
                 feed_dict = self._get_feed_dict(src_data=None, tgt_data=data, domain_labels=domain_labels,
                                                 training=True)
                 _, tgt_cost = self.sess.run([self.tgt_train_op, self.tgt_loss], feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(cur_step)), ("Target Train Loss", tgt_cost)])
         score = self.evaluate_data(tgt_dataset.dev_batches(), "target_dev", resource="target")
         self.evaluate_data(tgt_dataset.test_batches(), "target_test", resource="target")
         if score > best_score:
             best_score, no_imprv_epoch = score, 0
             self.save_session(epoch)
             self.logger.info(' -- new BEST score on target dev dataset: {:04.2f}'.format(best_score))
         else:
             no_imprv_epoch += 1
             if self.cfg.no_imprv_tolerance is not None and no_imprv_epoch >= self.cfg.no_imprv_tolerance:
                 self.logger.info('early stop at {}th epoch without improvement'.format(epoch))
                 self.logger.info('best score on target dev set: {}'.format(best_score))
                 break
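
_arrange_batches (used here and in Example #10) is also not shown; one plausible reading of mix_rate is that all target batches are kept and source batches are subsampled at that rate before shuffling. A hypothetical sketch under that assumption:

import random

def arrange_batches(num_src_batches, num_tgt_batches, mix_rate, seed=None):
    # keep every target batch; mix in a mix_rate fraction of source batches
    schedule = ["tgt"] * num_tgt_batches + ["src"] * int(num_src_batches * mix_rate)
    random.Random(seed).shuffle(schedule)
    return schedule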
Example #9
    def _train_epoch(
        self,
        dataloader_train,
        dataloader_dev,
        optimizer,
        criterion,
        eval_every,
        train_step,
        best_score,
        best_loss,
        use_prog_bar=False,
    ):

        prog = Progbar(len(dataloader_train))

        tr_loss = 0

        for batch_idx, batch in enumerate(dataloader_train):
            torch.cuda.empty_cache()
            train_step += 1
            optimizer.zero_grad()

            (sents, target, doc_encoding) = batch
            if not self.binary_class:
                target = target.squeeze(1)
            if self.use_doc_encoding:  # Capsule based models
                (
                    preds,
                    word_attention_scores,
                    sent_attention_scores,
                    rec_loss,
                ) = self.model(sents, doc_encoding)
            else:  # Other models
                (
                    preds,
                    word_attention_scores,
                    sent_attention_scores,
                    rec_loss,
                ) = self.model(
                    sents
                )  # rec loss defaults to 0 for non-CapsNet models

            if torch.cuda.device_count() > 1:
                rec_loss = rec_loss.mean()
            loss = criterion(preds, target)
            loss += rec_loss
            tr_loss += loss.item()

            # if APEX_AVAILABLE:
            #     with amp.scale_loss(loss, optimizer) as scaled_loss:
            #         scaled_loss.backward()
            # else:
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
            optimizer.step()
            self.scheduler.step()  # PyTorch >= 1.1: step the scheduler after the optimizer
            if use_prog_bar:
                prog.update(
                    batch_idx + 1,
                    values=[("train loss", loss.item()), ("recon loss", rec_loss)],
                )
            torch.cuda.empty_cache()

            if train_step % eval_every == 0:
                best_score, best_loss = self._eval_model(
                    dataloader_train, dataloader_dev, best_score, best_loss, train_step
                )

        return best_score, best_loss, train_step
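
Example #9 follows the standard PyTorch step order: zero the gradients, forward, backward, clip, then step the optimizer and scheduler. A minimal runnable skeleton of one such step:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
criterion = torch.nn.CrossEntropyLoss()

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)  # cap the global gradient norm at 5
optimizer.step()
scheduler.step()  # scheduler steps after the optimizer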
Example #10
 def train(self, src_dataset, tgt_dataset):
     self.logger.info("Start training...")
     best_f1, no_imprv_epoch, src_lr, tgt_lr, cur_step = -np.inf, 0, self.cfg.lr, self.cfg.lr, 0
     for epoch in range(1, self.cfg.epochs + 1):
         self.logger.info("Epoch {}/{}:".format(epoch, self.cfg.epochs))
         batches = self._arrange_batches(src_dataset, tgt_dataset,
                                         self.cfg.mix_rate,
                                         self.cfg.train_ratio)
         prog = Progbar(target=len(batches))
         prog.update(0, [("Global Step", int(cur_step)),
                         ("Source Train Loss", 0.0),
                         ("Target Train Loss", 0.0)])
         for i, batch_name in enumerate(batches):
             cur_step += 1
             if batch_name == "src":
                 data = src_dataset.get_next_train_batch()
                 domain_labels = [[1, 0]] * data["batch_size"]
                 feed_dict = self._get_feed_dict(
                     src_data=data,
                     tgt_data=None,
                     domain_labels=domain_labels,
                     is_train=True,
                     src_lr=src_lr)
                 _, src_cost = self.sess.run(
                     [self.src_train_op, self.src_loss],
                     feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(cur_step)),
                                     ("Source Train Loss", src_cost)])
             else:  # "tgt"
                 data = tgt_dataset.get_next_train_batch()
                 domain_labels = [[0, 1]] * data["batch_size"]
                 feed_dict = self._get_feed_dict(
                     src_data=None,
                     tgt_data=data,
                     domain_labels=domain_labels,
                     is_train=True,
                     tgt_lr=tgt_lr)
                 _, tgt_cost = self.sess.run(
                     [self.tgt_train_op, self.tgt_loss],
                     feed_dict=feed_dict)
                 prog.update(i + 1, [("Global Step", int(cur_step)),
                                     ("Target Train Loss", tgt_cost)])
         if self.cfg.use_lr_decay:  # learning rate decay
             src_lr = max(self.cfg.lr / (1.0 + self.cfg.lr_decay * epoch),
                          self.cfg.minimal_lr)
             if epoch % self.cfg.decay_step == 0:
                 tgt_lr = max(self.cfg.lr / (1.0 + self.cfg.lr_decay * epoch / self.cfg.decay_step),
                              self.cfg.minimal_lr)
         if not self.cfg.dev_for_train:
             self.evaluate(tgt_dataset.get_data_batches("dev"),
                           "target_dev",
                           self._tgt_predict_op,
                           rev_word_dict=self.rev_tw_dict,
                           rev_label_dict=self.rev_tl_dict)
         score = self.evaluate(tgt_dataset.get_data_batches("test"),
                               "target_test",
                               self._tgt_predict_op,
                               rev_word_dict=self.rev_tw_dict,
                               rev_label_dict=self.rev_tl_dict)
         if score["FB1"] > best_f1:
             best_f1, no_imprv_epoch = score["FB1"], 0
             self.save_session(epoch)
             self.logger.info(' -- new BEST score on target test dataset: {:04.2f}'.format(best_f1))
         else:
             no_imprv_epoch += 1
             if self.cfg.no_imprv_tolerance is not None and no_imprv_epoch >= self.cfg.no_imprv_tolerance:
                 self.logger.info(
                     'early stop at {}th epoch without improvement'.format(
                         epoch))
                 self.logger.info(
                     'best score on target test set: {}'.format(best_f1))
                 break
Example #11
 def train(self, train_dataset, test_dataset):
     global_test_acc = 0.0
     global_step = 0
     test_imgs, test_labels = test_dataset.images, test_dataset.labels
     self.logger.info("start training...")
     for epoch in range(1, self.epochs + 1):
         self.logger.info("Epoch {}/{}:".format(epoch, self.epochs))
         num_batches = train_dataset.num_examples // self.batch_size
         prog = Progbar(target=num_batches)
         prog.update(0, [("Global Step", 0), ("Train Loss", 0.0),
                         ("Train Acc", 0.0), ("Test Loss", 0.0),
                         ("Test Acc", 0.0)])
         for i in range(num_batches):
             global_step += 1
             train_imgs, train_labels = train_dataset.next_batch(
                 self.batch_size)
             b_labels = []
             for j in range(self.num_classifier):
                 ecoc_array = self.nary_ecoc[:, j]
                 b_lbs = remap_labels(train_labels.copy(), ecoc_array)
                 b_lbs = dense_to_one_hot(b_lbs, self.num_class)
                 b_labels.append(b_lbs)
             feed_dict = self._get_feed_dict(train_imgs, b_labels, True)
             _, pred_labels, loss = self.sess.run(
                 [self.train_op, self.pred_labels, self.cost],
                 feed_dict=feed_dict)
             acc = compute_ensemble_accuracy(pred_labels, self.nary_ecoc,
                                             train_labels)
             if global_step % 100 == 0:
                 y_labels = []
                 for j in range(self.num_classifier):
                     ecoc_array = self.nary_ecoc[:, j]
                     b_lbs = remap_labels(test_labels.copy(), ecoc_array)
                     b_lbs = dense_to_one_hot(b_lbs, self.num_class)
                     y_labels.append(b_lbs)
                 feed_dict = self._get_feed_dict(test_imgs, y_labels)
                 test_pred_labels, test_loss = self.sess.run(
                     [self.pred_labels, self.cost], feed_dict=feed_dict)
                 test_acc = compute_ensemble_accuracy(
                     test_pred_labels, self.nary_ecoc, test_labels)
                 prog.update(i + 1, [("Global Step", int(global_step)),
                                     ("Train Loss", loss),
                                     ("Train Acc", acc),
                                     ("Test Loss", test_loss),
                                     ("Test Acc", test_acc)])
                 if test_acc > global_test_acc:
                     global_test_acc = test_acc
                     self.save_session(global_step)
             else:
                 prog.update(i + 1, [("Global Step", int(global_step)),
                                     ("Train Loss", loss),
                                     ("Train Acc", acc)])
         y_labels = []
         for j in range(self.num_classifier):
             ecoc_array = self.nary_ecoc[:, j]
             b_lbs = remap_labels(test_labels.copy(), ecoc_array)
             b_lbs = dense_to_one_hot(b_lbs, self.num_class)
             y_labels.append(b_lbs)
         feed_dict = self._get_feed_dict(test_imgs, y_labels)
         test_pred_labels, test_loss = self.sess.run(
             [self.pred_labels, self.cost], feed_dict=feed_dict)
         test_acc = compute_ensemble_accuracy(test_pred_labels,
                                              self.nary_ecoc, test_labels)
         self.logger.info("Epoch: {}, Global Step: {}, Test Loss: {}, Test Accuracy: {}".format(
             epoch, global_step, test_loss, test_acc))