Example #1
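
# Imports used by the class below. The standard-library and third-party
# imports are the ones the code actually calls; the project-local modules
# (utils, Model, ModelDistribute, ModelMultiTasks, Test, TestMultiObjects)
# are referenced but their import paths are not shown in this listing, so
# they are left as commented placeholders rather than guessed.
import gc
import time
from os.path import isdir, join

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# import utils
# from <project models module> import Model, ModelDistribute, ModelMultiTasks
# from <project test module> import Test, TestMultiObjects

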
class Main(object):
    def __init__(self,
                 cfg,
                 model_architecture,
                 mode='normal',
                 fine_tune=False):
        """Load data and initialize models."""
        # Global start time
        self.start_time = time.time()

        # Config
        self.cfg = cfg
        self.fine_tune = fine_tune
        self.multi_gpu = True

        # Use 'encode' transfer learning
        self.tl_encode = self.cfg.TRANSFER_LEARNING == 'encode'

        # Get paths from configuration
        self.preprocessed_path, self.train_log_path, \
            self.summary_path, self.checkpoint_path, \
            self.train_image_path = self._get_paths()

        # Load data
        self.x_train, self.y_train, self.imgs_train, \
            self.x_valid, self.y_valid, self.imgs_valid = self._load_data()

        # Number of full batches per set (integer division drops any remainder)
        self.n_batch_train = len(self.y_train) // cfg.BATCH_SIZE
        self.n_batch_valid = len(self.y_valid) // cfg.BATCH_SIZE

        # Restore a pre-trained model
        if self.fine_tune:
            restore_vars_dict = self._get_restore_vars_dict()
        else:
            restore_vars_dict = None

        # Get model
        if mode == 'multi-tasks':
            self.multi_gpu = True
            self.model = ModelMultiTasks(cfg, model_architecture,
                                         restore_vars_dict)
        elif mode == 'multi-gpu':
            self.multi_gpu = True
            self.model = ModelDistribute(cfg, model_architecture,
                                         restore_vars_dict)
        else:
            self.multi_gpu = False
            self.model = Model(cfg, model_architecture, restore_vars_dict)

        # Build graph
        utils.thick_line()
        print('Building graph...')
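        # TF1.x style: clear any existing default graph before the model
        # builds a fresh training graph.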
        tf.reset_default_graph()
        self.step, self.train_graph, self.inputs, self.labels, self.input_imgs,\
            self.is_training, self.optimizer, self.saver, self.summary, \
            self.loss, self.accuracy, self.clf_loss, self.clf_preds, \
            self.rec_loss, self.rec_imgs = self.model.build_graph(
                input_size=self.x_train.shape[1:],
                image_size=self.imgs_train.shape[1:],
                num_class=self.y_train.shape[1])

        # Save config
        utils.save_config_log(self.train_log_path, cfg,
                              self.model.model_arch_info)

    def _get_paths(self):
        """Get paths from configuration."""
        if self.cfg.DATABASE_MODE is not None:
            preprocessed_path = join(
                '../data/{}'.format(self.cfg.DATABASE_MODE),
                self.cfg.DATABASE_NAME)
        else:
            preprocessed_path = join(self.cfg.DPP_DATA_PATH,
                                     self.cfg.DATABASE_NAME)

        # Fine-tuning
        if self.fine_tune:
            preprocessed_path = join(self.cfg.DPP_DATA_PATH,
                                     self.cfg.FT_DATABASE_NAME)
            self.restore_checkpoint_path = join(self.cfg.CHECKPOINT_PATH,
                                                self.cfg.VERSION)
            self.cfg.VERSION += '_ft'
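            # Appending '_ft' keeps the fine-tuning run's logs, summaries and
            # checkpoints (all derived from cfg.VERSION below) separate from
            # the pre-trained run being restored.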

        train_log_path_ = join(self.cfg.TRAIN_LOG_PATH, self.cfg.VERSION)
        summary_path_ = join(self.cfg.SUMMARY_PATH, self.cfg.VERSION)
        checkpoint_path_ = join(self.cfg.CHECKPOINT_PATH, self.cfg.VERSION)

        # Get log paths; append an index if the directory already exists.
        train_log_path = train_log_path_
        i_append_info = 0
        while isdir(train_log_path):
            i_append_info += 1
            train_log_path = train_log_path_ + '({})'.format(i_append_info)

        if i_append_info > 0:
            summary_path = summary_path_ + '({})'.format(i_append_info)
            checkpoint_path = checkpoint_path_ + '({})'.format(i_append_info)
        else:
            summary_path = summary_path_
            checkpoint_path = checkpoint_path_

        # Image saving path
        train_image_path = join(train_log_path, 'images')

        # Check directory of paths
        utils.check_dir([train_log_path, checkpoint_path])
        if self.cfg.WITH_REC:
            if self.cfg.SAVE_IMAGE_STEP:
                utils.check_dir([train_image_path])

        return preprocessed_path, train_log_path, \
            summary_path, checkpoint_path, train_image_path

    def _load_data(self):
        """Load preprocessed data."""
        utils.thick_line()
        print('Loading data...')

        x_train = utils.load_pkls(self.preprocessed_path,
                                  'x_train',
                                  tl=self.tl_encode)
        x_valid = utils.load_pkls(self.preprocessed_path,
                                  'x_valid',
                                  tl=self.tl_encode,
                                  add_n_batch=1)

        imgs_train = utils.load_pkls(self.preprocessed_path, 'imgs_train')
        imgs_valid = utils.load_pkls(self.preprocessed_path, 'imgs_valid')

        if imgs_train.shape == x_train.shape:
            print('[W] imgs_train.shape == x_train.shape')
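            # The image targets match the inputs' shape, so drop the separate
            # copies and reuse the input arrays to save memory.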
            del imgs_train
            del imgs_valid
            gc.collect()
            imgs_train = x_train
            imgs_valid = x_valid

        y_train = utils.load_pkls(self.preprocessed_path, 'y_train')
        y_valid = utils.load_pkls(self.preprocessed_path, 'y_valid')

        utils.thin_line()
        print('Data info:')
        utils.thin_line()
        print('x_train: {}\ny_train: {}\nx_valid: {}\ny_valid: {}'.format(
            x_train.shape, y_train.shape, x_valid.shape, y_valid.shape))

        print('imgs_train: {}\nimgs_valid: {}'.format(imgs_train.shape,
                                                      imgs_valid.shape))

        return x_train, y_train, imgs_train, x_valid, y_valid, imgs_valid

    def _get_restore_vars_dict(self):
        """Load pre-trained variables."""
        utils.thick_line()
        print('Loading pre-trained variables from:\n',
              self.restore_checkpoint_path)
        utils.thin_line()

        tf.reset_default_graph()
        loaded_graph = tf.Graph()

        with tf.Session(graph=loaded_graph) as sess:
            ckp_path = tf.train.latest_checkpoint(self.restore_checkpoint_path)
            loader = tf.train.import_meta_graph(ckp_path + '.meta')
            loader.restore(sess, ckp_path)

            restore_vars_dict = dict()
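            # Each pre-trained variable is read out as a plain numpy array;
            # the resulting dict is passed to the Model constructors in
            # __init__ as restore_vars_dict.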
            restore_vars_dict['w_conv_0'] = sess.run(
                loaded_graph.get_tensor_by_name('classifier/conv_0/weights:0'))
            restore_vars_dict['b_conv_0'] = sess.run(
                loaded_graph.get_tensor_by_name('classifier/conv_0/biases:0'))
            restore_vars_dict['w_caps_0'] = sess.run(
                loaded_graph.get_tensor_by_name('classifier/caps_0/weights:0'))
            restore_vars_dict['b_caps_0'] = sess.run(
                loaded_graph.get_tensor_by_name('classifier/caps_0/biases:0'))
            restore_vars_dict['w_caps_1'] = sess.run(
                loaded_graph.get_tensor_by_name('classifier/caps_1/weights:0'))
            # restore_vars_dict['b_caps_1'] = sess.run(
            #     loaded_graph.get_tensor_by_name('classifier/caps_1/biases:0'))

            return restore_vars_dict

    def _display_status(self, sess, x_batch, y_batch, imgs_batch, epoch_i,
                        step):
        """Display information during training."""
        valid_batch_idx = np.random.choice(range(len(self.x_valid)),
                                           self.cfg.BATCH_SIZE).tolist()
        x_valid_batch = self.x_valid[valid_batch_idx]
        y_valid_batch = self.y_valid[valid_batch_idx]
        imgs_valid_batch = self.imgs_valid[valid_batch_idx]

        if self.cfg.WITH_REC:
            loss_train, clf_loss_train, rec_loss_train, acc_train = \
                sess.run([self.loss, self.clf_loss,
                          self.rec_loss, self.accuracy],
                         feed_dict={self.inputs: x_batch,
                                    self.labels: y_batch,
                                    self.input_imgs: imgs_batch,
                                    self.is_training: False})
            loss_valid, clf_loss_valid, rec_loss_valid, acc_valid = \
                sess.run([self.loss, self.clf_loss,
                          self.rec_loss, self.accuracy],
                         feed_dict={self.inputs: x_valid_batch,
                                    self.labels: y_valid_batch,
                                    self.input_imgs: imgs_valid_batch,
                                    self.is_training: False})
        else:
            loss_train, acc_train = \
                sess.run([self.loss, self.accuracy],
                         feed_dict={self.inputs: x_batch,
                                    self.labels: y_batch,
                                    self.input_imgs: imgs_batch,
                                    self.is_training: False})
            loss_valid, acc_valid = \
                sess.run([self.loss, self.accuracy],
                         feed_dict={self.inputs: x_valid_batch,
                                    self.labels: y_valid_batch,
                                    self.input_imgs: imgs_valid_batch,
                                    self.is_training: False})
            clf_loss_train, rec_loss_train, clf_loss_valid, rec_loss_valid = \
                None, None, None, None

        utils.print_status(epoch_i, self.cfg.EPOCHS, step, self.start_time,
                           loss_train, clf_loss_train, rec_loss_train,
                           acc_train, loss_valid, clf_loss_valid,
                           rec_loss_valid, acc_valid, self.cfg.WITH_REC)

    def _save_logs(self, sess, train_writer, valid_writer, x_batch, y_batch,
                   imgs_batch, epoch_i, step):
        """Save logs and ddd summaries to TensorBoard while training."""
        valid_batch_idx = np.random.choice(range(len(self.x_valid)),
                                           self.cfg.BATCH_SIZE).tolist()
        x_valid_batch = self.x_valid[valid_batch_idx]
        y_valid_batch = self.y_valid[valid_batch_idx]
        imgs_valid_batch = self.imgs_valid[valid_batch_idx]

        if self.cfg.WITH_REC:
            summary_train, loss_train, clf_loss_train, rec_loss_train, acc_train = \
                sess.run([self.summary, self.loss, self.clf_loss,
                          self.rec_loss, self.accuracy],
                         feed_dict={self.inputs: x_batch,
                                    self.labels: y_batch,
                                    self.input_imgs: imgs_batch,
                                    self.is_training: False})
            summary_valid, loss_valid, clf_loss_valid, rec_loss_valid, acc_valid = \
                sess.run([self.summary, self.loss, self.clf_loss,
                          self.rec_loss, self.accuracy],
                         feed_dict={self.inputs: x_valid_batch,
                                    self.labels: y_valid_batch,
                                    self.input_imgs: imgs_valid_batch,
                                    self.is_training: False})
        else:
            summary_train, loss_train, acc_train = \
                sess.run([self.summary, self.loss, self.accuracy],
                         feed_dict={self.inputs: x_batch,
                                    self.labels: y_batch,
                                    self.input_imgs: imgs_batch,
                                    self.is_training: False})
            summary_valid, loss_valid, acc_valid = \
                sess.run([self.summary, self.loss, self.accuracy],
                         feed_dict={self.inputs: x_valid_batch,
                                    self.labels: y_valid_batch,
                                    self.input_imgs: imgs_valid_batch,
                                    self.is_training: False})
            clf_loss_train, rec_loss_train, clf_loss_valid, rec_loss_valid = \
                None, None, None, None

        train_writer.add_summary(summary_train, step)
        valid_writer.add_summary(summary_valid, step)
        utils.save_log(join(self.train_log_path,
                            'train_log.csv'), epoch_i + 1, step,
                       time.time() - self.start_time, loss_train,
                       clf_loss_train, rec_loss_train, acc_train, loss_valid,
                       clf_loss_valid, rec_loss_valid, acc_valid,
                       self.cfg.WITH_REC)

    def _eval_on_batches(self, mode, sess, x, y, imgs, n_batch, silent=False):
        """Calculate losses and accuracies of full train set."""
        loss_all = []
        acc_all = []
        clf_loss_all = []
        rec_loss_all = []

        batch_generator = utils.get_batches(x=x,
                                            y=y,
                                            imgs=imgs,
                                            batch_size=self.cfg.BATCH_SIZE)

        if not silent:
            utils.thin_line()
            print(
                'Calculating loss and accuracy of full {} set...'.format(mode))
            iterator = tqdm(range(n_batch),
                            total=n_batch,
                            ncols=100,
                            unit=' batches')
        else:
            iterator = range(n_batch)

        if self.cfg.WITH_REC:
            for _ in iterator:

                x_batch, y_batch, imgs_batch = next(batch_generator)

                loss_i, clf_loss_i, rec_loss_i, acc_i = sess.run(
                    [self.loss, self.clf_loss, self.rec_loss, self.accuracy],
                    feed_dict={
                        self.inputs: x_batch,
                        self.labels: y_batch,
                        self.input_imgs: imgs_batch,
                        self.is_training: False
                    })
                loss_all.append(loss_i)
                clf_loss_all.append(clf_loss_i)
                rec_loss_all.append(rec_loss_i)
                acc_all.append(acc_i)
            clf_loss = sum(clf_loss_all) / len(clf_loss_all)
            rec_loss = sum(rec_loss_all) / len(rec_loss_all)
        else:
            for _ in iterator:

                x_batch, y_batch, imgs_batch = next(batch_generator)

                loss_i, acc_i = sess.run(
                    [self.loss, self.accuracy],
                    feed_dict={
                        self.inputs: x_batch,
                        self.labels: y_batch,
                        self.input_imgs: imgs_batch,
                        self.is_training: False
                    })
                loss_all.append(loss_i)
                acc_all.append(acc_i)
            clf_loss, rec_loss = None, None

        loss = sum(loss_all) / len(loss_all)
        accuracy = sum(acc_all) / len(acc_all)

        return loss, clf_loss, rec_loss, accuracy

    def _eval_on_full_set(self, sess, epoch_i, step, silent=False):
        """Evaluate on the full data set and print information."""
        eval_start_time = time.time()

        if not silent:
            utils.thick_line()
            print('Calculating losses using full data set...')

        # Calculate losses and accuracies of full train set
        if self.cfg.EVAL_WITH_FULL_TRAIN_SET:
            loss_train, clf_loss_train, rec_loss_train, acc_train = \
                self._eval_on_batches(
                    'train', sess, self.x_train, self.y_train,
                    self.imgs_train, self.n_batch_train, silent=silent)
        else:
            loss_train, clf_loss_train, rec_loss_train, acc_train = \
                None, None, None, None

        # Calculate losses and accuracies of full valid set
        loss_valid, clf_loss_valid, rec_loss_valid, acc_valid = \
            self._eval_on_batches(
                'valid', sess, self.x_valid, self.y_valid,
                self.imgs_valid, self.n_batch_valid, silent=silent)

        if not silent:
            utils.print_full_set_eval(
                epoch_i, self.cfg.EPOCHS, step, self.start_time, loss_train,
                clf_loss_train, rec_loss_train, acc_train, loss_valid,
                clf_loss_valid, rec_loss_valid, acc_valid,
                self.cfg.EVAL_WITH_FULL_TRAIN_SET, self.cfg.WITH_REC)

        file_path = join(self.train_log_path, 'full_set_eval_log.csv')
        if not silent:
            utils.thin_line()
            print('Saving {}...'.format(file_path))
        utils.save_log(file_path, epoch_i + 1, step,
                       time.time() - self.start_time, loss_train,
                       clf_loss_train, rec_loss_train, acc_train, loss_valid,
                       clf_loss_valid, rec_loss_valid, acc_valid,
                       self.cfg.WITH_REC)

        if not silent:
            utils.thin_line()
            print(
                'Evaluation done! Using time: {:.2f}'.format(time.time() -
                                                             eval_start_time))

    def _save_images(self,
                     sess,
                     img_path,
                     x,
                     y,
                     imgs,
                     step,
                     silent=False,
                     epoch_i=None,
                     test_flag=False):
        """Save reconstructed images."""
        rec_images_ = sess.run(self.rec_imgs,
                               feed_dict={
                                   self.inputs: x,
                                   self.labels: y,
                                   self.is_training: False
                               })

        # rec_images_ shape: [128, 28, 28, 1] for MNIST
        utils.save_imgs(real_imgs=imgs,
                        rec_imgs=rec_images_,
                        img_path=img_path,
                        database_name=self.cfg.DATABASE_NAME,
                        max_img_in_col=self.cfg.MAX_IMAGE_IN_COL,
                        step=step,
                        silent=silent,
                        epoch_i=epoch_i,
                        test_flag=test_flag)

    def _save_model(self, sess, saver, step, silent=False):
        """Save models."""
        save_path = join(self.checkpoint_path, 'models.ckpt')
        if not silent:
            utils.thin_line()
            print('Saving models to {}...'.format(save_path))
        saver.save(sess, save_path, global_step=step)

    def _test(self,
              sess,
              during_training=False,
              epoch=None,
              step=None,
              mode='single'):
        """Evaluate on the test set."""
        utils.thick_line()
        start_time_test = time.time()

        test_params = dict(
            cfg=self.cfg,
            multi_gpu=self.multi_gpu,
            version=self.cfg.VERSION,
            during_training=during_training,
            epoch_train=epoch,
            step_train=step,
            model_arch_info=self.model.model_arch_info,
        )

        if mode == 'single':
            print('Testing on Single-object test set...')
            tester_ = Test
        elif mode == 'multi_obj':
            print('Testing on Multi-object test set...')
            tester_ = TestMultiObjects
        else:
            raise ValueError('Wrong mode name')

        tester_(**test_params).tester(sess, self.inputs, self.labels,
                                      self.input_imgs, self.is_training,
                                      self.clf_preds, self.rec_imgs,
                                      start_time_test, self.loss,
                                      self.accuracy, self.clf_loss,
                                      self.rec_loss)

    def _trainer(self, sess):
        """Run the training loop: optimize per batch, log, save and evaluate."""
        utils.thick_line()
        print('Training...')

        # Merge all the summaries and create writers
        train_summary_path = join(self.summary_path, 'train')
        valid_summary_path = join(self.summary_path, 'valid')
        utils.check_dir([train_summary_path, valid_summary_path])

        utils.thin_line()
        print('Generating TensorFlow summary writers...')
        train_writer = tf.summary.FileWriter(train_summary_path, sess.graph)
        valid_writer = tf.summary.FileWriter(valid_summary_path)

        sess.run(tf.global_variables_initializer())
        step = 0

        for epoch_i in range(self.cfg.EPOCHS):

            epoch_start_time = time.time()
            utils.thick_line()
            print('Training on epoch: {}/{}'.format(epoch_i + 1,
                                                    self.cfg.EPOCHS))

            utils.thin_line()
            train_batch_generator = utils.get_batches(
                x=self.x_train,
                y=self.y_train,
                imgs=self.imgs_train,
                batch_size=self.cfg.BATCH_SIZE)

            if self.cfg.DISPLAY_STEP:
                iterator = range(self.n_batch_train)
                silent = False
            else:
                iterator = tqdm(range(self.n_batch_train),
                                total=self.n_batch_train,
                                ncols=100,
                                unit=' batch')
                silent = True

            for _ in iterator:

                step += 1
                x_batch, y_batch, imgs_batch = next(train_batch_generator)

                # Run one optimizer step (feeding step - 1 keeps the fed
                # global step zero-based, since step was incremented above)
                sess.run(self.optimizer,
                         feed_dict={
                             self.inputs: x_batch,
                             self.labels: y_batch,
                             self.input_imgs: imgs_batch,
                             self.step: step - 1,
                             self.is_training: True
                         })

                # Display training information
                if self.cfg.DISPLAY_STEP:
                    if step % self.cfg.DISPLAY_STEP == 0:
                        self._display_status(sess, x_batch, y_batch,
                                             imgs_batch, epoch_i, step - 1)

                # Save training logs
                if self.cfg.SAVE_LOG_STEP:
                    if step % self.cfg.SAVE_LOG_STEP == 0:
                        self._save_logs(sess, train_writer, valid_writer,
                                        x_batch, y_batch, imgs_batch, epoch_i,
                                        step - 1)

                # Save reconstruction images
                if self.cfg.SAVE_IMAGE_STEP:
                    if self.cfg.WITH_REC:
                        if step % self.cfg.SAVE_IMAGE_STEP == 0:
                            self._save_images(sess,
                                              self.train_image_path,
                                              x_batch,
                                              y_batch,
                                              imgs_batch,
                                              step - 1,
                                              epoch_i=epoch_i,
                                              silent=silent)

                # Save models
                if self.cfg.SAVE_MODEL_MODE == 'per_batch':
                    if step % self.cfg.SAVE_MODEL_STEP == 0:
                        self._save_model(sess,
                                         self.saver,
                                         step - 1,
                                         silent=silent)

                # Evaluate on full set
                if self.cfg.FULL_SET_EVAL_MODE == 'per_batch':
                    if step % self.cfg.FULL_SET_EVAL_STEP == 0:
                        self._eval_on_full_set(sess,
                                               epoch_i,
                                               step - 1,
                                               silent=silent)

            # Save model per epoch
            if self.cfg.SAVE_MODEL_MODE == 'per_epoch':
                if (epoch_i + 1) % self.cfg.SAVE_MODEL_STEP == 0:
                    self._save_model(sess, self.saver, epoch_i)

            # Evaluate on valid set per epoch
            if self.cfg.FULL_SET_EVAL_MODE == 'per_epoch':
                if (epoch_i + 1) % self.cfg.FULL_SET_EVAL_STEP == 0:
                    self._eval_on_full_set(sess, epoch_i, step - 1)

            # Evaluate on test set per epoch
            if self.cfg.TEST_SO_MODE == 'per_epoch':
                self._test(sess,
                           during_training=True,
                           epoch=epoch_i,
                           step=step,
                           mode='single')

            # Evaluate on multi-objects test set per epoch
            if self.cfg.TEST_MO_MODE == 'per_epoch':
                self._test(sess,
                           during_training=True,
                           epoch=epoch_i,
                           step=step,
                           mode='multi_obj')

            utils.thin_line()
            print('Epoch {}/{} done! Using time: {:.2f}'.format(
                epoch_i + 1, self.cfg.EPOCHS,
                time.time() - epoch_start_time))

        utils.thick_line()
        print('Training finished! Using time: {:.2f}'.format(time.time() -
                                                             self.start_time))
        utils.thick_line()

        # Evaluate on test set after training
        if self.cfg.TEST_SO_MODE == 'after_training':
            self._test(sess, during_training=True, epoch='end', mode='single')

        # Evaluate on multi-objects test set after training
        if self.cfg.TEST_MO_MODE == 'after_training':
            self._test(sess,
                       during_training=True,
                       epoch='end',
                       mode='multi_obj')

        utils.thick_line()
        print('All tasks finished! Total time: {:.2f}'.format(time.time() -
                                                             self.start_time))
        utils.thick_line()

    def train(self):
        """Training models."""
        session_cfg = tf.ConfigProto(allow_soft_placement=True)
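        # allow_growth makes TensorFlow allocate GPU memory on demand instead
        # of reserving the whole device up front.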
        session_cfg.gpu_options.allow_growth = True

        if self.cfg.VAR_ON_CPU:
            with tf.Session(graph=self.train_graph,
                            config=session_cfg) as sess:
                with tf.device('/cpu:0'):
                    self._trainer(sess)
        else:
            with tf.Session(graph=self.train_graph,
                            config=session_cfg) as sess:
                self._trainer(sess)
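
A minimal usage sketch follows. The config and model_arch import paths are hypothetical (the listing does not show where cfg and the architecture description come from); mode may be 'normal', 'multi-gpu' or 'multi-tasks', matching the branches in __init__.

# Hypothetical driver script -- the two imports below are assumptions,
# not part of the original listing.
from config import cfg                  # assumed: config object with the fields used above
from model_arch import model_arch       # assumed: architecture description for the Model classes

if __name__ == '__main__':
    trainer = Main(cfg, model_arch, mode='normal', fine_tune=False)
    trainer.train()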