예제 #1
0
    def _eval_on_full_set(self, sess, epoch_i, step, silent=False):
        """Evaluate on the full data sets, report the result and log it.

        The validation set is always evaluated through
        ``self._eval_on_batches``; the train set is evaluated only when
        ``self.cfg.EVAL_WITH_FULL_TRAIN_SET`` is truthy, otherwise its four
        metrics are passed on as ``None``.  The metrics are appended to
        ``full_set_eval_log.csv`` under ``self.train_log_path`` via
        ``utils.save_log`` regardless of ``silent``.

        Args:
            sess: session object forwarded to ``self._eval_on_batches``.
            epoch_i: epoch index; forwarded as-is to the console report and
                as ``epoch_i + 1`` to the log file.
            step: training step, forwarded to the report and the log file.
            silent: when True, suppress every console message (the log file
                is still written).
        """
        started_at = time.time()
        verbose = not silent

        if verbose:
            utils.thick_line()
            print('Calculating losses using full data set...')

        # Train-set metrics are optional; fall back to None placeholders so
        # the reporting calls below always receive the same argument list.
        if not self.cfg.EVAL_WITH_FULL_TRAIN_SET:
            train_loss = train_clf_loss = train_rec_loss = train_acc = None
        else:
            train_loss, train_clf_loss, train_rec_loss, train_acc = (
                self._eval_on_batches(
                    'train', sess, self.x_train, self.y_train,
                    self.n_batch_train, silent=silent))

        # Validation-set metrics are always computed.
        valid_loss, valid_clf_loss, valid_rec_loss, valid_acc = (
            self._eval_on_batches(
                'valid', sess, self.x_valid, self.y_valid,
                self.n_batch_valid, silent=silent))

        if verbose:
            utils.print_full_set_eval(
                epoch_i, self.cfg.EPOCHS, step, self.start_time,
                train_loss, train_clf_loss, train_rec_loss, train_acc,
                valid_loss, valid_clf_loss, valid_rec_loss, valid_acc,
                self.cfg.EVAL_WITH_FULL_TRAIN_SET,
                self.cfg.WITH_RECONSTRUCTION)

        log_file = join(self.train_log_path, 'full_set_eval_log.csv')
        if verbose:
            utils.thin_line()
            print('Saving {}...'.format(log_file))

        # Elapsed time is measured from self.start_time, not from the start
        # of this evaluation.
        utils.save_log(
            log_file, epoch_i + 1, step, time.time() - self.start_time,
            train_loss, train_clf_loss, train_rec_loss, train_acc,
            valid_loss, valid_clf_loss, valid_rec_loss, valid_acc,
            self.cfg.WITH_RECONSTRUCTION)

        if verbose:
            utils.thin_line()
            eval_seconds = time.time() - started_at
            print('Evaluation done! Using time: {:.2f}'.format(eval_seconds))
예제 #2
0
    def _save_logs(self, sess, train_writer, valid_writer, x_batch, y_batch,
                   imgs_batch, epoch_i, step):
        """Save logs and add summaries to TensorBoard while training.

        Runs the summary/loss/accuracy ops once on the given training batch
        and once on a randomly drawn validation batch of
        ``self.cfg.BATCH_SIZE`` samples, writes both summaries to their
        writers, and appends one row to ``train_log.csv`` under
        ``self.train_log_path`` via ``utils.save_log``.

        Args:
            sess: session object whose ``run`` evaluates the fetched ops.
            train_writer: writer receiving the training summary.
            valid_writer: writer receiving the validation summary.
            x_batch: training inputs fed to ``self.inputs``.
            y_batch: training labels fed to ``self.labels``.
            imgs_batch: training images fed to ``self.input_imgs``.
            epoch_i: epoch index; logged as ``epoch_i + 1``.
            step: training step used for the summaries and the log row.
        """
        # Draw BATCH_SIZE validation indices; np.random.choice samples with
        # replacement by default, so indices may repeat.
        picked = np.random.choice(
            range(len(self.x_valid)), self.cfg.BATCH_SIZE).tolist()
        val_x = self.x_valid[picked]
        val_y = self.y_valid[picked]
        val_imgs = self.imgs_valid[picked]

        # Both passes run with is_training fed as False.
        feed_train = {
            self.inputs: x_batch,
            self.labels: y_batch,
            self.input_imgs: imgs_batch,
            self.is_training: False,
        }
        feed_valid = {
            self.inputs: val_x,
            self.labels: val_y,
            self.input_imgs: val_imgs,
            self.is_training: False,
        }

        if self.cfg.WITH_REC:
            # The separate classification/reconstruction losses are fetched
            # only in this branch.
            fetches = [self.summary, self.loss, self.clf_loss,
                       self.rec_loss, self.accuracy]
            summ_tr, total_tr, clf_tr, rec_tr, acc_tr = sess.run(
                fetches, feed_dict=feed_train)
            summ_va, total_va, clf_va, rec_va, acc_va = sess.run(
                fetches, feed_dict=feed_valid)
        else:
            fetches = [self.summary, self.loss, self.accuracy]
            summ_tr, total_tr, acc_tr = sess.run(
                fetches, feed_dict=feed_train)
            summ_va, total_va, acc_va = sess.run(
                fetches, feed_dict=feed_valid)
            # Placeholders keep the save_log argument list uniform.
            clf_tr = rec_tr = clf_va = rec_va = None

        train_writer.add_summary(summ_tr, step)
        valid_writer.add_summary(summ_va, step)

        log_file = join(self.train_log_path, 'train_log.csv')
        utils.save_log(
            log_file, epoch_i + 1, step, time.time() - self.start_time,
            total_tr, clf_tr, rec_tr, acc_tr,
            total_va, clf_va, rec_va, acc_va,
            self.cfg.WITH_REC)