コード例 #1
0
    def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
        """Runs the experiment according to the configured mode.

        Supported modes are 'train', 'train_and_post_eval', 'train_and_eval',
        'eval' and 'continuous_eval'. Afterwards, the number of trainable
        parameters and the FLOPs of `self.trainer.model` are logged whenever
        `train_utils` is able to compute them.

        Returns:
          A 2-tuple of (model, eval_logs).
            model: `tf.keras.Model` instance.
            eval_logs: eval metrics logs from a final `controller.evaluate`
              call when `self._run_post_eval` is True or the mode is
              'train_and_post_eval'; otherwise, {}.

        Raises:
          NotImplementedError: If the mode is not one of the supported modes.
        """
        mode = self._mode
        params = self.params
        logging.info('Starts to execute mode: %s', mode)
        with self.strategy.scope():
            if mode in ('train', 'train_and_post_eval'):
                # 'train_and_post_eval' trains here; its eval runs below.
                self.controller.train(steps=params.trainer.train_steps)
            elif mode == 'train_and_eval':
                self.controller.train_and_evaluate(
                    train_steps=params.trainer.train_steps,
                    eval_steps=params.trainer.validation_steps,
                    eval_interval=params.trainer.validation_interval)
            elif mode == 'eval':
                self.controller.evaluate(steps=params.trainer.validation_steps)
            elif mode == 'continuous_eval':

                def timeout_fn():
                    # Tells evaluate_continuously that no further checkpoints
                    # are expected once training has reached train_steps.
                    return bool(self.trainer.global_step.numpy() >=
                                params.trainer.train_steps)

                self.controller.evaluate_continuously(
                    steps=params.trainer.validation_steps,
                    timeout=params.trainer.continuous_eval_timeout,
                    timeout_fn=timeout_fn)
            else:
                raise NotImplementedError('The mode is not implemented: %s' %
                                          mode)

        num_params = train_utils.try_count_params(self.trainer.model)
        if num_params is not None:
            logging.info('Number of trainable params in model: %f Millions.',
                         num_params / 10.**6)

        flops = train_utils.try_count_flops(self.trainer.model)
        if flops is not None:
            # The count is reported as multiply-adds, hence the halving.
            logging.info('FLOPs (multi-adds) in model: %f Billions.',
                         flops / 10.**9 / 2)

        if self._run_post_eval or mode == 'train_and_post_eval':
            with self.strategy.scope():
                return self.trainer.model, self.controller.evaluate(
                    steps=params.trainer.validation_steps)
        return self.trainer.model, {}
コード例 #2
0
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval configured by the experiment params.

    Args:
      distribution_strategy: A distribution strategy.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
        Must not be None when the trainer has a checkpoint. When save_summary
        is True it is also joined with the summary subdirectory names.
      run_post_eval: Whether to run post eval once after training, metrics logs
        are returned.
      save_summary: Whether to save train and validation summary.
      trainer: the base_trainer.Trainer instance. It should be created within
        the strategy.scope(). If None, one is created inside
        `distribution_strategy.scope()` via `train_utils.create_trainer`.
      controller_cls: The controller class to manage the train and eval process.
        Must be a orbit.Controller subclass.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to True,
          otherwise, returns {}.

    Raises:
      ValueError: If the trainer has a checkpoint but model_dir is None.
      NotImplementedError: If `mode` is not one of the supported modes.
    """
    with distribution_strategy.scope():
        if trainer is None:
            # Only enable the train/eval parts the requested mode needs.
            trainer = train_utils.create_trainer(
                params,
                task,
                train='train' in mode,
                evaluate=('eval' in mode) or run_post_eval,
                checkpoint_exporter=maybe_create_best_ckpt_exporter(
                    params, model_dir))

    checkpoint_manager = None
    if trainer.checkpoint:
        if model_dir is None:
            raise ValueError('model_dir must be specified, but got None')
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)

    if save_summary:
        summary_dir = os.path.join(model_dir, 'train')
        eval_summary_dir = os.path.join(
            model_dir, params.trainer.validation_summary_subdir)
        summary_interval = params.trainer.summary_interval
    else:
        summary_dir = eval_summary_dir = summary_interval = None

    controller = controller_cls(
        strategy=distribution_strategy,
        # Only training modes hand the trainer over as the controller's
        # trainer; it is always passed as the evaluator.
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval,
        train_actions=actions.get_train_actions(
            params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
        eval_actions=actions.get_eval_actions(params, trainer, model_dir))

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                # Tells evaluate_continuously that no further checkpoints are
                # expected once training has reached train_steps.
                return bool(
                    trainer.global_step.numpy() >= params.trainer.train_steps)

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)
        else:
            raise NotImplementedError('The mode is not implemented: %s' % mode)

    num_params = train_utils.try_count_params(trainer.model)
    if num_params is not None:
        logging.info('Number of trainable params in model: %f Millions.',
                     num_params / 10.**6)

    flops = train_utils.try_count_flops(trainer.model)
    if flops is not None:
        # The count is reported as multiply-adds, hence the halving.
        logging.info('FLOPs (multi-adds) in model: %f Billions.',
                     flops / 10.**9 / 2)

    if run_post_eval:
        with distribution_strategy.scope():
            # Post eval calls trainer.evaluate directly (bypassing the
            # controller), passing the step count as a tensor.
            return trainer.model, trainer.evaluate(
                tf.convert_to_tensor(params.trainer.validation_steps))
    return trainer.model, {}
コード例 #3
0
def export_inference_graph(
        input_type: str,
        batch_size: Optional[int],
        input_image_size: List[int],
        params: cfg.ExperimentConfig,
        checkpoint_path: str,
        export_dir: str,
        num_channels: Optional[int] = 3,
        export_module: Optional[export_base.ExportModule] = None,
        export_checkpoint_subdir: Optional[str] = None,
        export_saved_model_subdir: Optional[str] = None,
        save_options: Optional[tf.saved_model.SaveOptions] = None,
        log_model_flops_and_params: bool = False,
        checkpoint: Optional[tf.train.Checkpoint] = None,
        input_name: Optional[str] = None):
    """Exports inference graph for the model specified in the exp config.

    The saved model is written to `export_dir`, or to
    `export_dir/<export_saved_model_subdir>` when that subdirectory is given.
    When `export_checkpoint_subdir` is given, the export module's model is also
    saved as a checkpoint (prefix `ckpt`) under
    `export_dir/<export_checkpoint_subdir>`. The params are always serialized
    into `export_dir` via `train_utils.serialize_config`.

    Args:
      input_type: One of `image_tensor`, `image_bytes`, `tf_example` or
        `tflite`. It is also the single function key passed to
        `export_base.export`.
      batch_size: 'int', or None.
      input_image_size: List or Tuple of height and width.
      params: Experiment params.
      checkpoint_path: Trained checkpoint path or directory.
      export_dir: Export directory path.
      num_channels: The number of input image channels.
      export_module: Optional export module to be used instead of using params
        to create one. If None, the params will be used to create an export
        module.
      export_checkpoint_subdir: Optional subdirectory under export_dir
        to store checkpoint.
      export_saved_model_subdir: Optional subdirectory under export_dir
        to store saved model.
      save_options: `SaveOptions` for `tf.saved_model.save`.
      log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
        and model parameters to model_params.txt under export_dir. This is
        only implemented for RetinaNet and Mask R-CNN tasks; for other tasks
        a message is logged instead.
      checkpoint: An optional tf.train.Checkpoint. If provided, the export module
        will use it to read the weights.
      input_name: The input tensor name, default at `None` which produces input
        tensor name `inputs`.

    Raises:
      ValueError: If `export_module` is None and no export module exists for
        the type of `params.task`.
    """
    if export_checkpoint_subdir:
        output_checkpoint_directory = os.path.join(export_dir,
                                                   export_checkpoint_subdir)
    else:
        output_checkpoint_directory = None

    if export_saved_model_subdir:
        output_saved_model_directory = os.path.join(export_dir,
                                                    export_saved_model_subdir)
    else:
        output_saved_model_directory = export_dir

    # TODO(arashwan): Offers a direct path to use ExportModule with Task objects.
    if export_module is None:
        # Every built-in export module takes the same constructor arguments.
        module_kwargs = {
            'params': params,
            'batch_size': batch_size,
            'input_image_size': input_image_size,
            'input_type': input_type,
            'num_channels': num_channels,
            'input_name': input_name,
        }
        task_config = params.task
        if isinstance(task_config,
                      configs.image_classification.ImageClassificationTask):
            export_module = image_classification.ClassificationModule(
                **module_kwargs)
        elif isinstance(task_config, (configs.retinanet.RetinaNetTask,
                                      configs.maskrcnn.MaskRCNNTask)):
            export_module = detection.DetectionModule(**module_kwargs)
        elif isinstance(
                task_config,
                configs.semantic_segmentation.SemanticSegmentationTask):
            export_module = semantic_segmentation.SegmentationModule(
                **module_kwargs)
        elif isinstance(task_config,
                        configs.video_classification.VideoClassificationTask):
            export_module = video_classification.VideoClassificationModule(
                **module_kwargs)
        else:
            raise ValueError(
                'Export module not implemented for {} task.'.format(
                    type(task_config)))

    export_base.export(export_module,
                       function_keys=[input_type],
                       export_savedmodel_dir=output_saved_model_directory,
                       checkpoint=checkpoint,
                       checkpoint_path=checkpoint_path,
                       timestamped=False,
                       save_options=save_options)

    if output_checkpoint_directory:
        ckpt = tf.train.Checkpoint(model=export_module.model)
        ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
    train_utils.serialize_config(params, export_dir)

    if not log_model_flops_and_params:
        return

    if not isinstance(
            params.task,
            (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
        logging.info(
            'Logging model flops and params not implemented for %s task.',
            type(params.task))
        return

    # We need to create inputs_kwargs argument to specify the input shapes for
    # subclass model that overrides model.call to take multiple inputs,
    # e.g., RetinaNet model. `input_image_size` is unpacked (rather than
    # list-concatenated) so that a tuple works as well as a list.
    inputs_kwargs = {
        'images':
            tf.TensorSpec([1, *input_image_size, num_channels], tf.float32),
        'image_shape':
            tf.TensorSpec([1, 2], tf.float32),
    }
    dummy_inputs = {
        k: tf.ones(v.shape.as_list(), tf.float32)
        for k, v in inputs_kwargs.items()
    }
    # Must do forward pass to build the model.
    export_module.model(**dummy_inputs)
    train_utils.try_count_flops(
        export_module.model, inputs_kwargs,
        os.path.join(export_dir, 'model_flops.txt'))
    train_utils.write_model_params(
        export_module.model, os.path.join(export_dir, 'model_params.txt'))