Example #1
0
    def save_bottleneck_features(self,
                                 inputs,
                                 file_path,
                                 batch_size=None,
                                 data_type=np.float32):
        """Extract bottleneck features and pickle them to `file_path`.

        Args:
          inputs: image array of shape (n, h, w, c); c is 1 (grayscale)
            or 3 (RGB).
          file_path: path of the binary file to write pickled features to.
          batch_size: if given, process the images batch by batch and write
            one pickled feature block per batch; otherwise process all at
            once and write a single block.
          data_type: dtype the images are cast to before extraction.
        """
        inputs_shape = inputs.shape
        # PIL mode used when resizing: 'L' for single-channel images.
        img_mode = 'L' if inputs_shape[3] == 1 else 'RGB'

        with open(file_path, 'wb') as f:

            if batch_size:

                batch_generator = utils.get_batches(inputs,
                                                    batch_size=batch_size,
                                                    keep_last=True)
                # Ceiling division: the last (smaller) batch is kept.
                if len(inputs) % batch_size == 0:
                    n_batch = len(inputs) // batch_size
                else:
                    n_batch = len(inputs) // batch_size + 1

                for _ in tqdm(range(n_batch),
                              total=n_batch,
                              ncols=100,
                              unit='batch'):
                    inputs_batch = next(batch_generator)
                    inputs_batch = utils.imgs_scale_to_255(
                        inputs_batch).astype(data_type)

                    # Transfer-learning models expect 224x224 inputs.
                    if inputs_shape[1:3] != (224, 224):
                        inputs_batch = utils.img_resize(
                            inputs_batch,
                            (224, 224),
                            img_mode=img_mode,
                            resize_filter=Image.ANTIALIAS,
                            verbose=False,
                        ).astype(data_type)
                    # Replicate the grayscale channel to get 3 channels.
                    if inputs_shape[3] == 1:
                        inputs_batch = np.concatenate(
                            [inputs_batch, inputs_batch, inputs_batch],
                            axis=-1)

                    assert inputs_batch.shape[1:] == (224, 224, 3)
                    bf_batch = self._extract_features(inputs_batch)
                    f.write(pickle.dumps(bf_batch))

                    # Release memory
                    K.clear_session()
                    tf.reset_default_graph()

            else:
                inputs = utils.imgs_scale_to_255(inputs).astype(data_type)

                # Fix: resize here too, mirroring the batched branch above;
                # previously the assert below failed for non-224x224 inputs
                # that the batched path handled fine.
                if inputs_shape[1:3] != (224, 224):
                    inputs = utils.img_resize(
                        inputs,
                        (224, 224),
                        img_mode=img_mode,
                        resize_filter=Image.ANTIALIAS,
                        verbose=False,
                    ).astype(data_type)

                # Replicate the grayscale channel to get 3 channels.
                if inputs_shape[3] == 1:
                    inputs = np.concatenate([inputs, inputs, inputs], axis=-1)

                assert inputs.shape[1:] == (224, 224, 3)
                bottleneck_features = self._extract_features(inputs)
                f.write(pickle.dumps(bottleneck_features))
Example #2
0
    def _eval_on_batches(self, sess, inputs, labels, loss, accuracy, clf_loss,
                         rec_loss, rec_images, x, y, n_batch):
        """Compute mean losses and accuracy of the full set, batch by batch.

        Returns (loss, clf_loss, rec_loss, accuracy); clf_loss and rec_loss
        are None when reconstruction testing is disabled.
        """
        losses = []
        accs = []
        clf_losses = []
        rec_losses = []
        generator = utils.get_batches(x, y, self.cfg.TEST_BATCH_SIZE)
        progress = tqdm(range(n_batch),
                        total=n_batch,
                        ncols=100,
                        unit=' batches')

        if self.cfg.TEST_WITH_RECONSTRUCTION:
            batch_idx = 0
            for _ in progress:
                batch_idx += 1
                x_batch, y_batch = next(generator)
                feed = {inputs: x_batch, labels: y_batch}
                loss_i, clf_loss_i, rec_loss_i, acc_i = sess.run(
                    [loss, clf_loss, rec_loss, accuracy], feed_dict=feed)
                losses.append(loss_i)
                clf_losses.append(clf_loss_i)
                rec_losses.append(rec_loss_i)
                accs.append(acc_i)

                # Periodically save reconstructed images.
                save_step = self.cfg.TEST_SAVE_IMAGE_STEP
                if save_step is not None and batch_idx % save_step == 0:
                    self._save_images(sess, rec_images, inputs, labels,
                                      x_batch, y_batch, batch_idx)

            clf_loss_mean = sum(clf_losses) / len(clf_losses)
            rec_loss_mean = sum(rec_losses) / len(rec_losses)
        else:
            for _ in progress:
                x_batch, y_batch = next(generator)
                feed = {inputs: x_batch, labels: y_batch}
                loss_i, acc_i = sess.run([loss, accuracy], feed_dict=feed)
                losses.append(loss_i)
                accs.append(acc_i)
            clf_loss_mean, rec_loss_mean = None, None

        return (sum(losses) / len(losses), clf_loss_mean, rec_loss_mean,
                sum(accs) / len(accs))
  def get_bottleneck_features(self,
                              inputs,
                              batch_size=None,
                              data_type=np.float32,
                              pooling='avg'):
    """Extract and return bottleneck features for `inputs`.

    Args:
      inputs: image array of shape (n, 224, 224, c); c is 1 or 3.
      batch_size: if given, extract features batch by batch.
      data_type: dtype the images are cast to before extraction.
      pooling: pooling mode forwarded to the feature extractor.

    Returns:
      Array of bottleneck features, one row per input image.
    """
    # Check image size for transfer learning models
    inputs_shape = inputs.shape
    assert inputs_shape[1:3] == (224, 224)

    # Scale to 0-255 and extract features
    if batch_size:
      batch_generator = utils.get_batches(
          inputs, batch_size=batch_size, keep_last=True)
      # Ceiling division. The previous `len(inputs) // batch_size + 1`
      # over-counted by one batch whenever len(inputs) was an exact
      # multiple of batch_size, making the loop call next() on an
      # exhausted generator.
      n_batch = (len(inputs) + batch_size - 1) // batch_size
      bottleneck_features = []
      for _ in tqdm(range(n_batch), total=n_batch, ncols=100, unit='batch'):

        inputs_batch = next(batch_generator)
        inputs_batch = utils.imgs_scale_to_255(inputs_batch).astype(data_type)

        # Replicate the grayscale channel to get a 3-channel input.
        if inputs_shape[3] == 1:
          inputs_batch = np.concatenate(
              [inputs_batch, inputs_batch, inputs_batch], axis=-1)

        assert inputs_batch.shape[1:] == (224, 224, 3)
        bf_batch = self._extract_features(inputs_batch, pooling=pooling)
        bottleneck_features.append(bf_batch)

        # Release memory
        K.clear_session()
        tf.reset_default_graph()

      bottleneck_features = np.concatenate(bottleneck_features, axis=0)
    else:
      inputs = utils.imgs_scale_to_255(inputs).astype(data_type)
      if inputs_shape[3] == 1:
        inputs = np.concatenate([inputs, inputs, inputs], axis=-1)
      bottleneck_features = self._extract_features(inputs, pooling=pooling)

    # Check data shape
    assert len(bottleneck_features) == len(inputs)
    assert bottleneck_features.shape[1:] == \
        self._get_bottleneck_feature_shape(pooling=pooling)

    return bottleneck_features
Example #4
0
    def _get_preds_vector(self, sess, inputs, preds, is_training):
        """Get prediction vectors of the full test set.

        The last batch may be smaller than TEST_BATCH_SIZE; it is padded
        with zero images so the graph sees a full batch, and the padded
        predictions are discarded afterwards.

        Returns:
          np.ndarray of predictions, one row per example in self.x_test.
        """
        utils.thin_line()
        print('Getting prediction vectors...')
        pred_all = []
        _batch_generator = utils.get_batches(
            self.x_test, batch_size=self.cfg.TEST_BATCH_SIZE, keep_last=True)

        # Ceiling division: the partial last batch counts as one batch.
        if len(self.x_test) % self.cfg.TEST_BATCH_SIZE == 0:
            n_batch = (len(self.x_test) // self.cfg.TEST_BATCH_SIZE)
        else:
            n_batch = (len(self.x_test) // self.cfg.TEST_BATCH_SIZE) + 1

        for _ in tqdm(range(n_batch), total=n_batch, ncols=100, unit=' batch'):
            x_batch = np.asarray(next(_batch_generator))

            # Pad a short last batch with zero rows in one concatenate
            # (the former per-row np.append loop was quadratic).
            len_batch = len(x_batch)
            if len_batch != self.cfg.TEST_BATCH_SIZE:
                pad = np.zeros(
                    (self.cfg.TEST_BATCH_SIZE - len_batch,)
                    + x_batch.shape[1:],
                    dtype=x_batch.dtype)
                x_batch = np.concatenate([x_batch, pad], axis=0)
                assert len(x_batch) == self.cfg.TEST_BATCH_SIZE

            pred_i = sess.run(preds,
                              feed_dict={
                                  inputs: x_batch,
                                  is_training: False
                              })
            # Drop predictions produced for the zero padding.
            if len_batch != self.cfg.TEST_BATCH_SIZE:
                pred_i = pred_i[:len_batch]
            pred_all.extend(list(pred_i))

        assert len(pred_all) == len(self.x_test), (len(pred_all),
                                                   len(self.x_test))
        return np.array(pred_all)
Example #5
0
    def _trainer(self, sess):
        """Run the full training loop over self.cfg.EPOCHS epochs.

        Initializes variables, then per batch: runs the optimizer and, as
        configured, displays status, saves summaries/images/checkpoints and
        evaluates on the full set. If cfg.DISPLAY_STEP is set, batches are
        iterated with periodic status printing; otherwise a tqdm progress
        bar is shown and sub-tasks run with silent=True. Optionally runs a
        test pass after training (cfg.TEST_AFTER_TRAINING).
        """

        utils.thick_line()
        print('Training...')

        # Merge all the summaries and create writers
        train_summary_path = join(self.summary_path, 'train')
        valid_summary_path = join(self.summary_path, 'valid')
        utils.check_dir([train_summary_path, valid_summary_path])
        train_writer = tf.summary.FileWriter(train_summary_path, sess.graph)
        valid_writer = tf.summary.FileWriter(valid_summary_path)

        sess.run(tf.global_variables_initializer())
        # Global optimizer-step counter, carried across epochs.
        step = 0

        for epoch_i in range(self.cfg.EPOCHS):

            epoch_start_time = time.time()
            utils.thick_line()
            print('Training on epoch: {}/{}'.format(epoch_i + 1,
                                                    self.cfg.EPOCHS))

            if self.cfg.DISPLAY_STEP is not None:

                # Interactive mode: plain loop with periodic status display.
                for x_batch, y_batch in utils.get_batches(
                        self.x_train, self.y_train, self.cfg.BATCH_SIZE):
                    step += 1

                    # Training optimizer
                    # NOTE: the graph's step placeholder is fed step - 1
                    # (zero-based), while the bookkeeping below uses the
                    # one-based `step`.
                    sess.run(self.optimizer,
                             feed_dict={
                                 self.inputs: x_batch,
                                 self.labels: y_batch,
                                 self.step: step - 1,
                                 self.is_training: True
                             })

                    # Display training information
                    if step % self.cfg.DISPLAY_STEP == 0:
                        self._display_status(sess, x_batch, y_batch, epoch_i,
                                             step)

                    # Save training logs
                    if self.cfg.SAVE_LOG_STEP is not None:
                        if step % self.cfg.SAVE_LOG_STEP == 0:
                            self._save_logs(sess, train_writer, valid_writer,
                                            x_batch, y_batch, epoch_i, step)

                    # Save reconstruction images
                    if self.cfg.SAVE_IMAGE_STEP is not None:
                        if self.cfg.WITH_RECONSTRUCTION:
                            if step % self.cfg.SAVE_IMAGE_STEP == 0:
                                self._save_images(sess,
                                                  self.train_image_path,
                                                  x_batch,
                                                  y_batch,
                                                  step,
                                                  epoch_i=epoch_i)

                    # Save models
                    if self.cfg.SAVE_MODEL_MODE == 'per_batch':
                        if step % self.cfg.SAVE_MODEL_STEP == 0:
                            self._save_model(sess, self.saver, step)

                    # Evaluate on full set
                    if self.cfg.FULL_SET_EVAL_MODE == 'per_batch':
                        if step % self.cfg.FULL_SET_EVAL_STEP == 0:
                            self._eval_on_full_set(sess, epoch_i, step)
                            utils.thick_line()
            else:
                # Silent mode: tqdm progress bar; sub-tasks run silent=True.
                utils.thin_line()
                train_batch_generator = utils.get_batches(
                    self.x_train, self.y_train, self.cfg.BATCH_SIZE)
                for _ in tqdm(range(self.n_batch_train),
                              total=self.n_batch_train,
                              ncols=100,
                              unit=' batches'):

                    step += 1
                    x_batch, y_batch = next(train_batch_generator)

                    # Training optimizer
                    sess.run(self.optimizer,
                             feed_dict={
                                 self.inputs: x_batch,
                                 self.labels: y_batch,
                                 self.step: step - 1,
                                 self.is_training: True
                             })

                    # Save training logs
                    if self.cfg.SAVE_LOG_STEP is not None:
                        if step % self.cfg.SAVE_LOG_STEP == 0:
                            self._save_logs(sess, train_writer, valid_writer,
                                            x_batch, y_batch, epoch_i, step)

                    # Save reconstruction images
                    if self.cfg.SAVE_IMAGE_STEP is not None:
                        if self.cfg.WITH_RECONSTRUCTION:
                            if step % self.cfg.SAVE_IMAGE_STEP == 0:
                                self._save_images(sess,
                                                  self.train_image_path,
                                                  x_batch,
                                                  y_batch,
                                                  step,
                                                  silent=True,
                                                  epoch_i=epoch_i)

                    # Save models
                    if self.cfg.SAVE_MODEL_MODE == 'per_batch':
                        if step % self.cfg.SAVE_MODEL_STEP == 0:
                            self._save_model(sess,
                                             self.saver,
                                             step,
                                             silent=True)

                    # Evaluate on full set
                    if self.cfg.FULL_SET_EVAL_MODE == 'per_batch':
                        if step % self.cfg.FULL_SET_EVAL_STEP == 0:
                            self._eval_on_full_set(sess,
                                                   epoch_i,
                                                   step,
                                                   silent=True)

            # Per-epoch checkpoints are indexed by epoch, not global step.
            if self.cfg.SAVE_MODEL_MODE == 'per_epoch':
                if (epoch_i + 1) % self.cfg.SAVE_MODEL_STEP == 0:
                    self._save_model(sess, self.saver, epoch_i)
            if self.cfg.FULL_SET_EVAL_MODE == 'per_epoch':
                if (epoch_i + 1) % self.cfg.FULL_SET_EVAL_STEP == 0:
                    self._eval_on_full_set(sess, epoch_i, step)

            utils.thin_line()
            print('Epoch done! Using time: {:.2f}'.format(time.time() -
                                                          epoch_start_time))

        utils.thick_line()
        print('Training finished! Using time: {:.2f}'.format(time.time() -
                                                             self.start_time))
        utils.thick_line()

        # Evaluate on test set after training
        if self.cfg.TEST_AFTER_TRAINING:
            self._test_after_training(sess)

        utils.thick_line()
        print('All task finished! Total time: {:.2f}'.format(time.time() -
                                                             self.start_time))
        utils.thick_line()
Example #6
0
    def _test_after_training(self, sess):
        """Evaluate on the test set after training.

        Loads the preprocessed test set from disk, computes mean loss and
        accuracy over full batches (the partial last batch is dropped by
        the floor division below), prints the results, and writes them to
        the test log. With reconstruction enabled, classifier and
        reconstruction losses are tracked too and reconstructed images are
        saved every TEST_SAVE_IMAGE_STEP batches.
        """
        test_start_time = time.time()

        utils.thick_line()
        print('Testing...')

        # Check directory of paths
        utils.check_dir([self.test_log_path])
        if self.cfg.WITH_RECONSTRUCTION:
            if self.cfg.TEST_SAVE_IMAGE_STEP is not None:
                utils.check_dir([self.test_image_path])

        # Load data
        utils.thin_line()
        print('Loading test set...')
        utils.thin_line()
        x_test = utils.load_data_from_pkl(
            join(self.preprocessed_path, 'x_test.p'))
        y_test = utils.load_data_from_pkl(
            join(self.preprocessed_path, 'y_test.p'))
        # Floor division: examples beyond the last full batch are skipped.
        n_batch_test = len(y_test) // self.cfg.BATCH_SIZE

        utils.thin_line()
        print('Calculating loss and accuracy on test set...')
        loss_test_all = []
        acc_test_all = []
        clf_loss_test_all = []
        rec_loss_test_all = []
        step = 0
        _test_batch_generator = utils.get_batches(x_test, y_test,
                                                  self.cfg.BATCH_SIZE)

        if self.cfg.WITH_RECONSTRUCTION:
            for _ in tqdm(range(n_batch_test),
                          total=n_batch_test,
                          ncols=100,
                          unit=' batches'):
                step += 1
                test_batch_x, test_batch_y = next(_test_batch_generator)
                loss_test_i, clf_loss_i, rec_loss_i, acc_test_i = sess.run(
                    [self.loss, self.clf_loss, self.rec_loss, self.accuracy],
                    feed_dict={
                        self.inputs: test_batch_x,
                        self.labels: test_batch_y,
                        self.is_training: False
                    })
                loss_test_all.append(loss_test_i)
                acc_test_all.append(acc_test_i)
                clf_loss_test_all.append(clf_loss_i)
                rec_loss_test_all.append(rec_loss_i)

                # Save reconstruct images
                if self.cfg.TEST_SAVE_IMAGE_STEP is not None:
                    if step % self.cfg.TEST_SAVE_IMAGE_STEP == 0:
                        self._save_images(sess,
                                          self.test_image_path,
                                          test_batch_x,
                                          test_batch_y,
                                          step,
                                          silent=False)

            clf_loss_test = sum(clf_loss_test_all) / len(clf_loss_test_all)
            rec_loss_test = sum(rec_loss_test_all) / len(rec_loss_test_all)

        else:
            for _ in tqdm(range(n_batch_test),
                          total=n_batch_test,
                          ncols=100,
                          unit=' batches'):
                test_batch_x, test_batch_y = next(_test_batch_generator)
                loss_test_i, acc_test_i = sess.run(
                    [self.loss, self.accuracy],
                    feed_dict={
                        self.inputs: test_batch_x,
                        self.labels: test_batch_y,
                        self.is_training: False
                    })
                loss_test_all.append(loss_test_i)
                acc_test_all.append(acc_test_i)
            clf_loss_test, rec_loss_test = None, None

        loss_test = sum(loss_test_all) / len(loss_test_all)
        acc_test = sum(acc_test_all) / len(acc_test_all)

        # Print losses and accuracy
        utils.thin_line()
        print('Test_Loss: {:.4f}\n'.format(loss_test),
              'Test_Accuracy: {:.2f}%'.format(acc_test * 100))
        if self.cfg.WITH_RECONSTRUCTION:
            utils.thin_line()
            # NOTE(review): the 'Test_Train_Loss' label prints clf_loss_test
            # (presumably the classifier loss) — confirm the label is
            # intended.
            print('Test_Train_Loss: {:.4f}\n'.format(clf_loss_test),
                  'Test_Reconstruction_Loss: {:.4f}'.format(rec_loss_test))

        # Save test log
        utils.save_test_log(self.test_log_path, loss_test, acc_test,
                            clf_loss_test, rec_loss_test,
                            self.cfg.WITH_RECONSTRUCTION)

        utils.thin_line()
        print('Testing finished! Using time: {:.2f}'.format(time.time() -
                                                            test_start_time))
Example #7
0
    def _eval_on_batches(self, mode, sess, x, y, n_batch, silent=False):
        """Compute mean loss and accuracy of the full `mode` set.

        Returns (loss, clf_loss, rec_loss, accuracy); clf_loss and rec_loss
        are None when reconstruction is disabled.
        """
        losses = []
        accs = []
        clf_losses = []
        rec_losses = []

        if not silent:
            utils.thin_line()
            print(
                'Calculating loss and accuracy of full {} set...'.format(mode))
            generator = utils.get_batches(x, y, self.cfg.BATCH_SIZE)
            progress = tqdm(range(n_batch),
                            total=n_batch,
                            ncols=100,
                            unit=' batches')

            if self.cfg.WITH_RECONSTRUCTION:
                for _ in progress:
                    x_batch, y_batch = next(generator)
                    feed = {
                        self.inputs: x_batch,
                        self.labels: y_batch,
                        self.is_training: False
                    }
                    loss_i, clf_loss_i, rec_loss_i, acc_i = sess.run(
                        [
                            self.loss, self.clf_loss, self.rec_loss,
                            self.accuracy
                        ],
                        feed_dict=feed)
                    losses.append(loss_i)
                    clf_losses.append(clf_loss_i)
                    rec_losses.append(rec_loss_i)
                    accs.append(acc_i)
                clf_loss = sum(clf_losses) / len(clf_losses)
                rec_loss = sum(rec_losses) / len(rec_losses)
            else:
                for _ in progress:
                    x_batch, y_batch = next(generator)
                    feed = {
                        self.inputs: x_batch,
                        self.labels: y_batch,
                        self.is_training: False
                    }
                    loss_i, acc_i = sess.run([self.loss, self.accuracy],
                                             feed_dict=feed)
                    losses.append(loss_i)
                    accs.append(acc_i)
                clf_loss, rec_loss = None, None

        else:
            # Silent mode: no printing or progress bar; iterate the batch
            # generator directly until it is exhausted.
            if self.cfg.WITH_RECONSTRUCTION:
                for x_batch, y_batch in utils.get_batches(
                        x, y, self.cfg.BATCH_SIZE):
                    feed = {
                        self.inputs: x_batch,
                        self.labels: y_batch,
                        self.is_training: False
                    }
                    loss_i, clf_loss_i, rec_loss_i, acc_i = sess.run(
                        [
                            self.loss, self.clf_loss, self.rec_loss,
                            self.accuracy
                        ],
                        feed_dict=feed)
                    losses.append(loss_i)
                    clf_losses.append(clf_loss_i)
                    rec_losses.append(rec_loss_i)
                    accs.append(acc_i)
                clf_loss = sum(clf_losses) / len(clf_losses)
                rec_loss = sum(rec_losses) / len(rec_losses)
            else:
                for x_batch, y_batch in utils.get_batches(
                        x, y, self.cfg.BATCH_SIZE):
                    feed = {
                        self.inputs: x_batch,
                        self.labels: y_batch,
                        self.is_training: False
                    }
                    loss_i, acc_i = sess.run([self.loss, self.accuracy],
                                             feed_dict=feed)
                    losses.append(loss_i)
                    accs.append(acc_i)
                clf_loss, rec_loss = None, None

        return (sum(losses) / len(losses), clf_loss, rec_loss,
                sum(accs) / len(accs))
Example #8
0
    def _eval_on_batches(self, sess, inputs, labels, input_imgs, is_training,
                         preds, loss, acc, clf_loss, rec_loss, rec_imgs):
        """Calculate losses and accuracies of the full test set.

        Returns:
          (preds_vec, loss, clf_loss, rec_loss, accuracy). clf_loss and
          rec_loss are None when TEST_WITH_REC is disabled. Losses and
          accuracy are averaged over full-size batches only; the padded
          last batch contributes predictions but no losses.
        """

        def _pad_to_batch_size(batch):
            # Pad a short last batch with zero rows in a single concatenate
            # (the former per-row np.append loop was quadratic).
            batch = np.asarray(batch)
            n_pad = self.cfg.TEST_BATCH_SIZE - len(batch)
            pad = np.zeros((n_pad,) + batch.shape[1:], dtype=batch.dtype)
            return np.concatenate([batch, pad], axis=0)

        pred_all = []
        loss_all = []
        acc_all = []
        clf_loss_all = []
        rec_loss_all = []
        step = 0
        batch_generator = utils.get_batches(
            x=self.x_test,
            y=self.y_test,
            imgs=self.imgs_test,
            batch_size=self.cfg.TEST_BATCH_SIZE,
            keep_last=True)

        # Ceiling division: the partial last batch counts as one batch.
        if len(self.x_test) % self.cfg.TEST_BATCH_SIZE == 0:
            n_batch = (len(self.x_test) // self.cfg.TEST_BATCH_SIZE)
        else:
            n_batch = (len(self.x_test) // self.cfg.TEST_BATCH_SIZE) + 1

        if self.cfg.TEST_WITH_REC:
            for _ in tqdm(range(n_batch),
                          total=n_batch,
                          ncols=100,
                          unit=' batch'):
                step += 1
                x_batch, y_batch, imgs_batch = next(batch_generator)
                len_batch = len(x_batch)

                if len_batch == self.cfg.TEST_BATCH_SIZE:
                    pred_i, loss_i, clf_loss_i, rec_loss_i, acc_i = \
                        sess.run([preds, loss, clf_loss, rec_loss, acc],
                                 feed_dict={inputs: x_batch,
                                            labels: y_batch,
                                            input_imgs: imgs_batch,
                                            is_training: False})
                    loss_all.append(loss_i)
                    clf_loss_all.append(clf_loss_i)
                    rec_loss_all.append(rec_loss_i)
                    acc_all.append(acc_i)

                    # Save reconstruct images
                    if self.cfg.TEST_SAVE_IMAGE_STEP:
                        if step % self.cfg.TEST_SAVE_IMAGE_STEP == 0:
                            self._save_images(sess,
                                              rec_imgs,
                                              inputs,
                                              labels,
                                              is_training,
                                              x_batch,
                                              y_batch,
                                              imgs_batch,
                                              step=step)
                else:
                    # The last batch which has less examples: pad it, run
                    # predictions only, then drop the padded rows.
                    x_batch = _pad_to_batch_size(x_batch)
                    assert len(x_batch) == self.cfg.TEST_BATCH_SIZE
                    pred_i = sess.run(preds,
                                      feed_dict={
                                          inputs: x_batch,
                                          is_training: False
                                      })
                    pred_i = pred_i[:len_batch]

                pred_all.extend(list(pred_i))

            clf_loss_ = sum(clf_loss_all) / len(clf_loss_all)
            rec_loss_ = sum(rec_loss_all) / len(rec_loss_all)

        else:
            for _ in tqdm(range(n_batch),
                          total=n_batch,
                          ncols=100,
                          unit=' batches'):
                x_batch, y_batch, imgs_batch = next(batch_generator)
                len_batch = len(x_batch)

                if len_batch == self.cfg.TEST_BATCH_SIZE:
                    pred_i, loss_i, acc_i = \
                        sess.run([preds, loss, acc],
                                 feed_dict={inputs: x_batch,
                                            labels: y_batch,
                                            input_imgs: imgs_batch,
                                            is_training: False})
                    loss_all.append(loss_i)
                    acc_all.append(acc_i)
                else:
                    # The last batch which has less examples: pad it, run
                    # predictions only, then drop the padded rows.
                    x_batch = _pad_to_batch_size(x_batch)
                    assert len(x_batch) == self.cfg.TEST_BATCH_SIZE
                    pred_i = sess.run(preds,
                                      feed_dict={
                                          inputs: x_batch,
                                          is_training: False
                                      })
                    pred_i = pred_i[:len_batch]

                pred_all.extend(list(pred_i))

            clf_loss_, rec_loss_ = None, None

        loss_ = sum(loss_all) / len(loss_all)
        acc_ = sum(acc_all) / len(acc_all)

        assert len(pred_all) == len(self.x_test), (len(pred_all),
                                                   len(self.x_test))
        preds_vec = np.array(pred_all)

        return preds_vec, loss_, clf_loss_, rec_loss_, acc_
Example #9
0
    def _trainer(self, sess):
        """Run the full training loop over self.cfg.EPOCHS epochs.

        Initializes variables, then per batch: runs the optimizer and, as
        configured, displays status, saves summaries, reconstruction
        images and checkpoints, and evaluates on the full set. When
        cfg.DISPLAY_STEP is unset a tqdm progress bar is shown and
        sub-tasks run with silent=True. After each epoch, and again after
        training, single-object and multi-object test passes run as
        configured by TEST_SO_MODE / TEST_MO_MODE.
        """

        utils.thick_line()
        print('Training...')

        # Merge all the summaries and create writers
        train_summary_path = join(self.summary_path, 'train')
        valid_summary_path = join(self.summary_path, 'valid')
        utils.check_dir([train_summary_path, valid_summary_path])

        utils.thin_line()
        print('Generating TensorFLow summary writer...')
        train_writer = tf.summary.FileWriter(train_summary_path, sess.graph)
        valid_writer = tf.summary.FileWriter(valid_summary_path)

        sess.run(tf.global_variables_initializer())
        # Global optimizer-step counter, carried across epochs.
        step = 0

        for epoch_i in range(self.cfg.EPOCHS):

            epoch_start_time = time.time()
            utils.thick_line()
            print('Training on epoch: {}/{}'.format(epoch_i + 1,
                                                    self.cfg.EPOCHS))

            utils.thin_line()
            train_batch_generator = utils.get_batches(
                x=self.x_train,
                y=self.y_train,
                imgs=self.imgs_train,
                batch_size=self.cfg.BATCH_SIZE)

            # With DISPLAY_STEP set, iterate plainly and print status;
            # otherwise show a tqdm bar and keep sub-tasks silent.
            if self.cfg.DISPLAY_STEP:
                iterator = range(self.n_batch_train)
                silent = False
            else:
                iterator = tqdm(range(self.n_batch_train),
                                total=self.n_batch_train,
                                ncols=100,
                                unit=' batch')
                silent = True

            for _ in iterator:

                step += 1
                x_batch, y_batch, imgs_batch = next(train_batch_generator)

                # Training optimizer
                # NOTE(review): sub-tasks below receive `step - 1`
                # (zero-based) while the per-epoch tests further down
                # receive `step` — confirm the off-by-one is intended.
                sess.run(self.optimizer,
                         feed_dict={
                             self.inputs: x_batch,
                             self.labels: y_batch,
                             self.input_imgs: imgs_batch,
                             self.step: step - 1,
                             self.is_training: True
                         })

                # Display training information
                if self.cfg.DISPLAY_STEP:
                    if step % self.cfg.DISPLAY_STEP == 0:
                        self._display_status(sess, x_batch, y_batch,
                                             imgs_batch, epoch_i, step - 1)

                # Save training logs
                if self.cfg.SAVE_LOG_STEP:
                    if step % self.cfg.SAVE_LOG_STEP == 0:
                        self._save_logs(sess, train_writer, valid_writer,
                                        x_batch, y_batch, imgs_batch, epoch_i,
                                        step - 1)

                # Save reconstruction images
                if self.cfg.SAVE_IMAGE_STEP:
                    if self.cfg.WITH_REC:
                        if step % self.cfg.SAVE_IMAGE_STEP == 0:
                            self._save_images(sess,
                                              self.train_image_path,
                                              x_batch,
                                              y_batch,
                                              imgs_batch,
                                              step - 1,
                                              epoch_i=epoch_i,
                                              silent=silent)

                # Save models
                if self.cfg.SAVE_MODEL_MODE == 'per_batch':
                    if step % self.cfg.SAVE_MODEL_STEP == 0:
                        self._save_model(sess,
                                         self.saver,
                                         step - 1,
                                         silent=silent)

                # Evaluate on full set
                if self.cfg.FULL_SET_EVAL_MODE == 'per_batch':
                    if step % self.cfg.FULL_SET_EVAL_STEP == 0:
                        self._eval_on_full_set(sess,
                                               epoch_i,
                                               step - 1,
                                               silent=silent)

            # Save model per epoch
            if self.cfg.SAVE_MODEL_MODE == 'per_epoch':
                if (epoch_i + 1) % self.cfg.SAVE_MODEL_STEP == 0:
                    self._save_model(sess, self.saver, epoch_i)

            # Evaluate on valid set per epoch
            if self.cfg.FULL_SET_EVAL_MODE == 'per_epoch':
                if (epoch_i + 1) % self.cfg.FULL_SET_EVAL_STEP == 0:
                    self._eval_on_full_set(sess, epoch_i, step - 1)

            # Evaluate on test set per epoch
            if self.cfg.TEST_SO_MODE == 'per_epoch':
                self._test(sess,
                           during_training=True,
                           epoch=epoch_i,
                           step=step,
                           mode='single')

            # Evaluate on multi-objects test set per epoch
            if self.cfg.TEST_MO_MODE == 'per_epoch':
                self._test(sess,
                           during_training=True,
                           epoch=epoch_i,
                           step=step,
                           mode='multi_obj')

            utils.thin_line()
            print('Epoch {}/{} done! Using time: {:.2f}'.format(
                epoch_i + 1, self.cfg.EPOCHS,
                time.time() - epoch_start_time))

        utils.thick_line()
        print('Training finished! Using time: {:.2f}'.format(time.time() -
                                                             self.start_time))
        utils.thick_line()

        # Evaluate on test set after training
        if self.cfg.TEST_SO_MODE == 'after_training':
            self._test(sess, during_training=True, epoch='end', mode='single')

        # Evaluate on multi-objects test set after training
        if self.cfg.TEST_MO_MODE == 'after_training':
            self._test(sess,
                       during_training=True,
                       epoch='end',
                       mode='multi_obj')

        utils.thick_line()
        print('All task finished! Total time: {:.2f}'.format(time.time() -
                                                             self.start_time))
        utils.thick_line()