Ejemplo n.º 1
0
    def _build_controller(self,
                          trainer,
                          evaluator,
                          save_summary: bool = True,
                          train_actions: Optional[List[orbit.Action]] = None,
                          eval_actions: Optional[List[orbit.Action]] = None,
                          controller_cls=orbit.Controller) -> orbit.Controller:
        """Builds an Orbit controller.

        Args:
          trainer: Trainer passed to the controller. When truthy, the default
            train actions from `actions.get_train_actions` are appended.
          evaluator: Evaluator passed to the controller. When truthy, the
            default eval actions from `actions.get_eval_actions` are appended.
          save_summary: Whether to configure train/eval summary directories
            under `self.model_dir` and a summary interval. When False, all
            three are passed to the controller as None.
          train_actions: Optional extra train actions. The caller's list is
            copied and never modified.
          eval_actions: Optional extra eval actions. The caller's list is
            copied and never modified.
          controller_cls: The controller class to instantiate (defaults to
            `orbit.Controller`).

        Returns:
          The controller instance created by `controller_cls`.
        """
        # Copy rather than extend in place: `+=` on a caller-supplied list
        # would otherwise push the default actions back into the caller's
        # object (and accumulate duplicates if that list is reused).
        train_actions = list(train_actions) if train_actions else []
        if trainer:
            train_actions += actions.get_train_actions(
                self.params,
                trainer,
                self.model_dir,
                checkpoint_manager=self.checkpoint_manager)

        eval_actions = list(eval_actions) if eval_actions else []
        if evaluator:
            eval_actions += actions.get_eval_actions(self.params, evaluator,
                                                     self.model_dir)

        if save_summary:
            summary_dir = os.path.join(self.model_dir, 'train')
            eval_summary_dir = os.path.join(
                self.model_dir, self.params.trainer.validation_summary_subdir)
            summary_interval = self.params.trainer.summary_interval
        else:
            summary_dir = eval_summary_dir = summary_interval = None

        # The step counter is read from `self.trainer`, not from the `trainer`
        # argument, which may be falsy (see the `if trainer:` check above).
        controller = controller_cls(
            strategy=self.strategy,
            trainer=trainer,
            evaluator=evaluator,
            global_step=self.trainer.global_step,
            steps_per_loop=self.params.trainer.steps_per_loop,
            checkpoint_manager=self.checkpoint_manager,
            summary_dir=summary_dir,
            eval_summary_dir=eval_summary_dir,
            summary_interval=summary_interval,
            train_actions=train_actions,
            eval_actions=eval_actions)
        return controller
Ejemplo n.º 2
0
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval configured by the experiment params.

    Args:
      distribution_strategy: A `tf.distribute.Strategy`; the trainer is created
        and train/eval is run under its scope.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
        Must not be None when the trainer has a checkpoint or when
        `save_summary` is True.
      run_post_eval: Whether to run post eval once after training, metrics logs
        are returned.
      save_summary: Whether to save train and validation summary.
      trainer: the base_trainer.Trainer instance. It should be created within
        the strategy.scope(). If not provided, one is created here via
        `train_utils.create_trainer`.
      controller_cls: The controller class to manage the train and eval process.
        Must be a orbit.Controller subclass.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to True,
          otherwise, returns {}.

    Raises:
      NotImplementedError: If `mode` is not a supported mode. This is checked
        before any trainer, checkpoint manager or controller is built.
      ValueError: If `model_dir` is None but checkpoints or summaries need it.
    """
    supported_modes = ('train', 'train_and_eval', 'eval', 'continuous_eval')
    # Reject bad modes before building anything: constructing the checkpoint
    # manager and controller below can have side effects on `model_dir`.
    if mode not in supported_modes:
        raise NotImplementedError('The mode is not implemented: %s' % mode)
    if save_summary and model_dir is None:
        raise ValueError(
            'model_dir must be specified when save_summary is True, '
            'but got None')

    with distribution_strategy.scope():
        if not trainer:
            trainer = train_utils.create_trainer(
                params,
                task,
                train='train' in mode,
                evaluate=('eval' in mode) or run_post_eval,
                checkpoint_exporter=maybe_create_best_ckpt_exporter(
                    params, model_dir))

    if trainer.checkpoint:
        if model_dir is None:
            raise ValueError('model_dir must be specified, but got None')
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)
    else:
        checkpoint_manager = None

    if save_summary:
        summary_dir = os.path.join(model_dir, 'train')
        eval_summary_dir = os.path.join(
            model_dir, params.trainer.validation_summary_subdir)
        summary_interval = params.trainer.summary_interval
    else:
        summary_dir = eval_summary_dir = summary_interval = None

    controller = controller_cls(
        strategy=distribution_strategy,
        # Only attach the trainer for training when the mode trains; it always
        # serves as the evaluator.
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval,
        train_actions=actions.get_train_actions(
            params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
        eval_actions=actions.get_eval_actions(params, trainer, model_dir))

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        else:  # 'continuous_eval'; other modes were rejected up front.

            def timeout_fn():
                # Stop waiting for new checkpoints once training has reached
                # the configured number of steps.
                return bool(
                    trainer.global_step.numpy() >= params.trainer.train_steps)

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)

    num_params = train_utils.try_count_params(trainer.model)
    if num_params is not None:
        logging.info('Number of trainable params in model: %f Millions.',
                     num_params / 10.**6)

    flops = train_utils.try_count_flops(trainer.model)
    if flops is not None:
        # The extra division by 2 reports multiply-adds rather than raw FLOPs.
        logging.info('FLOPs (multi-adds) in model: %f Billions.',
                     flops / 10.**9 / 2)

    if run_post_eval:
        with distribution_strategy.scope():
            return trainer.model, trainer.evaluate(
                tf.convert_to_tensor(params.trainer.validation_steps))
    return trainer.model, {}