Example #1
0
    def train(self,
              train_input_fn,
              run_eval_after_train=False,
              eval_input_fn=None):
        """Run distributed training on Mask RCNN model.

        Args:
            train_input_fn: input_fn that supplies the training dataset.
            run_eval_after_train: when True, run one evaluation pass once
                training has finished.
            eval_input_fn: input_fn that supplies the eval dataset; required
                when run_eval_after_train is True.

        Returns:
            The eval metrics dict when evaluation runs, otherwise None.

        Raises:
            ValueError: if evaluation is requested without an eval_input_fn.
        """
        self._save_config()

        # Build the training-side estimator.
        run_config = self.build_strategy_configuration('train')
        params = self.build_model_parameters('train')
        estimator = self.build_mask_rcnn_estimator(params, run_config, 'train')

        with dump_callback():
            hooks = get_training_hooks(
                mode="train",
                model_dir=self._runtime_config.model_dir,
                checkpoint_path=self._runtime_config.checkpoint,
                skip_checkpoint_variables=self._runtime_config.skip_checkpoint_variables,
                batch_size=params['batch_size'],
                save_summary_steps=self._runtime_config.save_summary_steps,
            )
            estimator.train(
                input_fn=train_input_fn,
                max_steps=self._runtime_config.total_steps,
                hooks=hooks)

        if not run_eval_after_train:
            return None

        if eval_input_fn is None:
            raise ValueError(
                'Eval input_fn must be passed to conduct evaluation after training.'
            )

        # Build a separate estimator for evaluation and restore the latest
        # checkpoint produced by the training run above.
        eval_run_config = self.build_strategy_configuration('eval')
        eval_params = self.build_model_parameters('eval')
        eval_estimator = self.build_mask_rcnn_estimator(
            eval_params, eval_run_config, 'eval')

        last_ckpt = self.get_last_checkpoint_path()
        logging.info("Restoring parameters from %s\n" % last_ckpt)

        eval_results, predictions = evaluation.evaluate(
            eval_estimator,
            eval_input_fn,
            self._runtime_config.eval_samples,
            self._runtime_config.eval_batch_size,
            self._runtime_config.include_mask,
            self._runtime_config.val_json_file,
            report_frequency=self._runtime_config.report_frequency,
            checkpoint_path=last_ckpt)

        # Eval metrics are written as TF summaries under <model_dir>/eval.
        summary_dir = os.path.join(self._runtime_config.model_dir, 'eval')
        tf.io.gfile.makedirs(summary_dir)
        self._write_summary(summary_dir, eval_results, predictions,
                            self._runtime_config.total_steps)

        return eval_results
Example #2
0
    def eval(self, eval_input_fn):
        """Run distributed eval on Mask RCNN model.

        Args:
            eval_input_fn: input_fn that supplies the evaluation dataset.

        Returns:
            The eval metrics dict produced by evaluation.evaluate.
        """
        # Eval summaries are written under <model_dir>/eval.
        summary_dir = os.path.join(self._runtime_config.model_dir, 'eval')
        tf.io.gfile.makedirs(summary_dir)

        run_config = self.build_strategy_configuration('eval')
        params = self.build_model_parameters('eval')
        estimator = self.build_mask_rcnn_estimator(params, run_config, 'eval')

        logging.info('Starting to evaluate.')

        ckpt_path = self.get_last_checkpoint_path()

        if ckpt_path is None:
            # No checkpoint yet: evaluate from freshly-initialized weights.
            logging.warning(
                "Could not find trained model in model_dir: `%s`, running initialization to predict\n"
                % self._runtime_config.model_dir)
            step = 0
        else:
            logging.info("Restoring parameters from %s\n" % ckpt_path)
            # Checkpoint basenames look like "model.ckpt-<step>"; recover the
            # global step from the suffix after the '-'.
            step = int(os.path.basename(ckpt_path).split('-')[1])

        metrics, predictions = evaluation.evaluate(
            estimator,
            eval_input_fn,
            self._runtime_config.eval_samples,
            self._runtime_config.eval_batch_size,
            self._runtime_config.include_mask,
            self._runtime_config.val_json_file,
            checkpoint_path=ckpt_path)

        self._write_summary(summary_dir, metrics, predictions, step)

        if step >= self._runtime_config.total_steps:
            logging.info('Evaluation finished after training step %d' % step)

        return metrics
Example #3
0
  def train_and_eval(self, train_input_fn, eval_input_fn):
    """Run distributed train and eval on Mask RCNN model.

    Alternates training and evaluation: trains for num_steps_per_eval steps,
    then (on the chief rank only) restores the latest checkpoint and runs an
    evaluation pass, repeating until total_steps is reached.

    Args:
      train_input_fn: input_fn that supplies the training dataset.
      eval_input_fn: input_fn that supplies the evaluation dataset.

    Returns:
      The eval metrics dict from the last evaluation cycle, or None if no
      evaluation ran on this rank.
    """
    self._save_config()
    output_dir = os.path.join(self._runtime_config.model_dir, 'eval')
    tf.io.gfile.makedirs(output_dir)

    train_run_config = self.build_strategy_configuration('train')
    train_params = self.build_model_parameters('train')
    train_estimator = self.build_mask_rcnn_estimator(train_params, train_run_config, 'train')

    eval_estimator = None
    eval_results = None

    num_cycles = math.ceil(self._runtime_config.total_steps / self._runtime_config.num_steps_per_eval)

    training_hooks = get_training_hooks(
        mode="train",
        model_dir=self._runtime_config.model_dir,
        checkpoint_path=self._runtime_config.checkpoint,
        skip_checkpoint_variables=self._runtime_config.skip_checkpoint_variables
    )

    # The chief-rank check and the profiler choice are loop-invariant: decide
    # them once instead of re-deciding (and re-importing contextlib.suppress)
    # on every cycle.
    # NOTE(review): tf.contrib is removed in TF 2.x — the profiling path only
    # works on TF 1.x and is left behind the flag unchanged.
    PROFILER_ENABLED = False
    is_chief = not MPI_is_distributed() or MPI_rank() == 0

    if is_chief and PROFILER_ENABLED:
        profiler_context_manager = tf.contrib.tfprof.ProfileContext
    else:
        from contextlib import suppress
        profiler_context_manager = lambda *args, **kwargs: suppress()  # No-Op context manager

    for cycle in range(1, num_cycles + 1):

      if is_chief:
        print()  # Visual Spacing
        logging.info("=================================")
        logging.info('     Start training cycle %02d' % cycle)
        logging.info("=================================\n")

      # Train up to the end of this cycle, clamped to total_steps.
      max_cycle_step = min(int(cycle * self._runtime_config.num_steps_per_eval), self._runtime_config.total_steps)

      with profiler_context_manager(
              '/workspace/profiling/',
              trace_steps=range(100, 200, 3),
              dump_steps=[200]
      ) as pctx:

          if is_chief and PROFILER_ENABLED:
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.time_and_memory()
            pctx.add_auto_profiling('op', opts, [150, 200])

          train_estimator.train(
              input_fn=train_input_fn,
              max_steps=max_cycle_step,
              hooks=training_hooks,
          )

      if is_chief:

          print()  # Visual Spacing
          logging.info("=================================")
          logging.info('    Start evaluation cycle %02d' % cycle)
          logging.info("=================================\n")

          # Lazily build the eval estimator on first use.
          if eval_estimator is None:
              eval_run_config = self.build_strategy_configuration('eval')
              eval_params = self.build_model_parameters('eval')
              eval_estimator = self.build_mask_rcnn_estimator(eval_params, eval_run_config, 'eval')

          last_ckpt = tf.train.latest_checkpoint(self._runtime_config.model_dir, latest_filename=None)
          logging.info("Restoring parameters from %s\n" % last_ckpt)

          eval_results, predictions = evaluation.evaluate(
              eval_estimator,
              eval_input_fn,
              self._runtime_config.eval_samples,
              self._runtime_config.eval_batch_size,
              self._runtime_config.include_mask,
              self._runtime_config.val_json_file,
              report_frequency=self._runtime_config.report_frequency
          )

          self._write_summary(output_dir, eval_results, predictions, max_cycle_step)

      if MPI_is_distributed():
          # Keep all MPI ranks in lock-step before the next training cycle.
          # (The previous `from mpi4py import MPI` import was unused — the
          # communicator comes from horovod.)
          comm = hvd.get_worker_comm()
          comm.Barrier()  # Waiting for all MPI processes to sync

    return eval_results