Example no. 1
0
    def train(self):
        """Run the full adversarial training loop over all epochs.

        Builds the graph, optionally restores the latest checkpoint, then
        alternates discriminator and generator updates on every batch,
        periodically logging losses, writing image summaries, dumping
        visualization plots, and saving checkpoints to ``self.opts.log_dir``.
        """
        self.allocate_placeholders()
        self.build_model()

        self.sess.run(tf.global_variables_initializer())

        # Background worker that prefetches training batches.
        fetchworker = Fetcher(self.opts)
        fetchworker.start()

        self.saver = tf.train.Saver(max_to_keep=None)
        self.writer = tf.summary.FileWriter(self.opts.log_dir, self.sess.graph)

        restore_epoch = 0
        if self.opts.restore:
            restore_epoch, checkpoint_path = model_utils.pre_load_checkpoint(
                self.opts.log_dir)
            self.saver.restore(self.sess, checkpoint_path)
            # Append to the existing training log when resuming.
            self.LOG_FOUT = open(
                os.path.join(self.opts.log_dir, 'log_train.txt'), 'a')
            # Fast-forward the global step to match the restored epoch.
            # NOTE(review): .eval() requires self.sess to be installed as the
            # default session — confirm it is created with as_default().
            tf.assign(self.global_step,
                      restore_epoch * fetchworker.num_batches).eval()
            restore_epoch += 1

        else:
            os.makedirs(os.path.join(self.opts.log_dir, 'plots'))
            self.LOG_FOUT = open(
                os.path.join(self.opts.log_dir, 'log_train.txt'), 'w')

        # Record every option value for reproducibility.
        with open(os.path.join(self.opts.log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(self.opts)):
                log.write(arg + ': ' + str(getattr(self.opts, arg)) +
                          '\n')  # log of arguments

        step = self.sess.run(self.global_step)
        start = time()
        try:
            for epoch in range(restore_epoch, self.opts.training_epoch):
                logging.info('**** EPOCH %03d ****\t' % (epoch))
                for batch_idx in range(fetchworker.num_batches):

                    batch_input_x, batch_input_y, batch_radius = fetchworker.fetch(
                    )

                    feed_dict = {
                        self.input_x: batch_input_x,
                        self.input_y: batch_input_y,
                        self.pc_radius: batch_radius,
                        self.is_training: True
                    }

                    # Update D network
                    _, d_loss, d_summary = self.sess.run(
                        [self.D_optimizers, self.D_loss, self.d_summary_op],
                        feed_dict=feed_dict)
                    self.writer.add_summary(d_summary, step)

                    # Update G network (possibly several G steps per D step).
                    # NOTE(review): if opts.gen_update is 0, the loss logging
                    # below would hit an undefined g_total_loss — confirm the
                    # option is always >= 1.
                    for i in range(self.opts.gen_update):
                        _, g_total_loss, summary = self.sess.run(
                            [
                                self.G_optimizers, self.total_gen_loss,
                                self.g_summary_op
                            ],
                            feed_dict=feed_dict)
                        self.writer.add_summary(summary, step)

                    if step % self.opts.steps_per_print == 0:
                        self.log_string(
                            '-----------EPOCH %d Step %d:-------------' %
                            (epoch, step))
                        self.log_string('  G_loss   : {}'.format(g_total_loss))
                        self.log_string('  D_loss   : {}'.format(d_loss))
                        self.log_string(' Time Cost : {}'.format(time() - start))
                        start = time()
                        # Render input / generated / ground-truth point clouds
                        # side by side as a single image summary.
                        feed_dict = {
                            self.input_x: batch_input_x,
                            self.is_training: False
                        }

                        fake_y_val = self.sess.run([self.G_y], feed_dict=feed_dict)

                        fake_y_val = np.squeeze(fake_y_val)
                        image_input_x = point_cloud_three_views(batch_input_x[0])
                        image_fake_y = point_cloud_three_views(fake_y_val[0])
                        image_input_y = point_cloud_three_views(batch_input_y[0, :,
                                                                              0:3])
                        image_x_merged = np.concatenate(
                            [image_input_x, image_fake_y, image_input_y], axis=1)
                        image_x_merged = np.expand_dims(image_x_merged, axis=0)
                        image_x_merged = np.expand_dims(image_x_merged, axis=-1)
                        image_x_summary = self.sess.run(
                            self.image_x_summary,
                            feed_dict={self.image_x_merged: image_x_merged})
                        self.writer.add_summary(image_x_summary, step)

                    if self.opts.visulize and (step % self.opts.steps_per_visu
                                               == 0):
                        feed_dict = {
                            self.input_x: batch_input_x,
                            self.input_y: batch_input_y,
                            self.pc_radius: batch_radius,
                            self.is_training: False
                        }
                        pcds = self.sess.run([self.visualize_ops],
                                             feed_dict=feed_dict)
                        pcds = np.squeeze(pcds)
                        plot_path = os.path.join(
                            self.opts.log_dir, 'plots',
                            'epoch_%d_step_%d.png' % (epoch, step))
                        plot_pcd_three_views(plot_path, pcds,
                                             self.visualize_titles)

                    step += 1
                if (epoch % self.opts.epoch_per_save) == 0:
                    self.saver.save(self.sess,
                                    os.path.join(self.opts.log_dir, 'model'),
                                    epoch)
                    print(
                        colored('Model saved at %s' % self.opts.log_dir, 'white',
                                'on_blue'))
        finally:
            # Always stop the prefetch worker, even when training aborts with
            # an exception — otherwise its background thread keeps running.
            fetchworker.shutdown()
Example no. 2
0
  def train_one_epoch(self, epoch=0):
      """Train the network for one epoch and return averaged metrics.

      Args:
          epoch: Epoch index, used only to name visualization plot files.

      Returns:
          Tuple ``(d_loss, g_loss, coarse_cd, coarse_hd, fine_cd, fine_hd,
          duration)`` where the losses are epoch averages and ``duration``
          is the epoch wall-clock time in seconds.
      """
      n_examples = int(len(self.train_dataset))

      # Running averages of every tracked loss for this epoch.
      epoch_g_loss = AverageMeter()
      epoch_d_loss = AverageMeter()
      epoch_coarse_loss = AverageMeter()
      epoch_coarse_hd_loss = AverageMeter()
      epoch_fine_loss = AverageMeter()
      epoch_fine_hd_loss = AverageMeter()

      # NOTE(review): the trailing partial batch is dropped, and one full
      # batch besides because of the extra "- 1" — confirm intentional.
      n_batches = int(n_examples / self.opts.batch_size) - 1
      start_time = time()

      for _ in tqdm(range(n_batches)):

          batch_input_x, batch_input_y, batch_radius = self.train_dataset.next_batch()

          feed_dict = {self.input: batch_input_x,
                       self.gt: batch_input_y,
                       self.pc_radius: batch_radius,
                       self.is_training: True}

          if self.use_gan:
              # Update D network (D_clip suggests WGAN-style weight
              # clipping runs alongside the optimizer — confirm).
              _, _, d_loss, d_summary = self.sess.run([self.D_optimizers, self.D_clip, self.D_loss, self.d_summary_op],
                                                      feed_dict=feed_dict)
              self.writer.add_summary(d_summary, self.step)
              epoch_d_loss.update(d_loss)
          # Update G network
          _, g_loss, coarse_loss, fine_loss, coarse_hd_loss, fine_hd_loss, summary = self.sess.run(
              [self.G_optimizers, self.total_gen_loss, self.dis_coarse_cd, self.dis_fine_cd,
               self.dis_coarse_hd, self.dis_fine_hd,
               self.g_summary_op], feed_dict=feed_dict)
          self.writer.add_summary(summary, self.step)

          epoch_g_loss.update(g_loss)
          epoch_coarse_loss.update(coarse_loss)
          epoch_fine_loss.update(fine_loss)
          epoch_fine_hd_loss.update(fine_hd_loss)
          epoch_coarse_hd_loss.update(coarse_hd_loss)

          # Step is advanced once per batch; the periodic checks below use
          # the post-increment value (same order as the original code).
          self.step += 1
          if self.step % self.opts.steps_per_print == 0:
              # Render sparse input / coarse / fine / ground truth side by
              # side as a single image summary.
              feed_dict = {self.input: batch_input_x,
                           self.is_training: False}

              coarse, fine = self.sess.run([self.coarse, self.fine], feed_dict=feed_dict)

              image_sparse = point_cloud_three_views(batch_input_x[0])
              image_coarse = point_cloud_three_views(coarse[0])
              image_fine = point_cloud_three_views(fine[0])
              image_gt = point_cloud_three_views(batch_input_y[0])
              image_merged = np.concatenate([image_sparse, image_coarse, image_fine, image_gt], axis=1)
              image_merged = np.transpose(image_merged, [1, 0])
              image_merged = np.expand_dims(image_merged, axis=0)
              image_merged = np.expand_dims(image_merged, axis=-1)
              image_summary = self.sess.run(self.image_summary, feed_dict={self.image_merged: image_merged})
              self.writer.add_summary(image_summary, self.step)

          if self.opts.visulize and (self.step % self.opts.steps_per_visu == 0):
              feed_dict = {self.input: batch_input_x,
                           self.gt: batch_input_y,
                           self.pc_radius: batch_radius,
                           self.is_training: False}
              pcds = self.sess.run([self.visualize_ops], feed_dict=feed_dict)
              pcds = np.squeeze(pcds)
              plot_path = os.path.join(self.opts.log_dir, 'plots',
                                       'epoch_%d_step_%d.png' % (epoch, self.step))
              plot_pcd_three_views(plot_path, pcds, self.visualize_titles)

      duration = time() - start_time

      # NOTE(review): when use_gan is False, epoch_d_loss never receives an
      # update — verify what AverageMeter.avg returns with zero samples.
      return (
          epoch_d_loss.avg,
          epoch_g_loss.avg,
          epoch_coarse_loss.avg,
          epoch_coarse_hd_loss.avg,
          epoch_fine_loss.avg,
          epoch_fine_hd_loss.avg,
          duration,
      )