Ejemplo n.º 1
0
def create_optimizer(
    task: base_task.Task, params: config_definitions.ExperimentConfig
) -> tf.keras.optimizers.Optimizer:
    """Creates the optimizer for `task`, with or without a `dp_config` arg.

    The signature of `task.create_optimizer` is inspected: when it declares a
    `dp_config` parameter the differential privacy config is forwarded to it,
    otherwise the method is called with the two positional arguments only.

    Args:
      task: Object whose `create_optimizer` method builds the optimizer.
      params: Experiment config. `params.trainer.optimizer_config` and
        `params.runtime` are always read;
        `params.task.differential_privacy_config` is read when it exists.

    Returns:
      The value returned by `task.create_optimizer`.

    Raises:
      ValueError: If a non-None differential privacy config is present but
        `task.create_optimizer` has no `dp_config` parameter.
    """
    supports_dp = (
        'dp_config' in inspect.signature(task.create_optimizer).parameters)
    # A task config without the attribute is treated like an unset (None) one.
    privacy_config = getattr(params.task, 'differential_privacy_config', None)

    if supports_dp:
        return task.create_optimizer(
            params.trainer.optimizer_config,
            params.runtime,
            dp_config=privacy_config)

    if privacy_config is not None:
        raise ValueError('Differential privacy config is specified but '
                         'task.create_optimizer api does not accept it.')
    return task.create_optimizer(params.trainer.optimizer_config,
                                 params.runtime)
Ejemplo n.º 2
0
def create_trainer(params: config_definitions.ExperimentConfig,
                   task: base_task.Task,
                   train: bool,
                   evaluate: bool,
                   checkpoint_exporter: Optional[BestCheckpointExporter] = None,
                   trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
  """Builds a trainer for `task` from the experiment config.

  Args:
    params: Experiment config; forwarded to `trainer_cls`, and its `trainer`
      and `runtime` fields are passed to `task.create_optimizer`.
    task: Task that provides the model (`build_model`) and the optimizer
      (`create_optimizer`).
    train: Forwarded to `trainer_cls` as `train`.
    evaluate: Forwarded to `trainer_cls` as `evaluate`.
    checkpoint_exporter: Optional exporter forwarded to `trainer_cls`.
    trainer_cls: Callable used to construct the trainer.

  Returns:
    The object produced by calling `trainer_cls`.
  """
  logging.info('Running default trainer.')
  # The model is built before the optimizer, in that order.
  trainer_kwargs = dict(
      model=task.build_model(),
      optimizer=task.create_optimizer(params.trainer, params.runtime),
      train=train,
      evaluate=evaluate,
      checkpoint_exporter=checkpoint_exporter)
  return trainer_cls(params, task, **trainer_kwargs)
Ejemplo n.º 3
0
def run_experiment_with_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: List[base_task.Task],
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[core_lib.Trainer] = None) -> tf.keras.Model:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: The distribution strategy; the trainer, the
      evaluator and the selected mode are all built/run inside its scope.
    train_task: A base_task.Task instance used to build the model and the
      optimizer.
    eval_tasks: A list of evaluation tasks handed to the multi-task evaluator.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run one more evaluation after the selected mode
      finishes. Requires a mode that contains 'eval', because only those modes
      create an evaluator.
    save_summary: Whether to save train and validation summary. When False the
      controller receives no summary directories and no summary interval.
    trainer: the core_lib.Trainer instance. It should be created within the
      strategy.scope(). If not provided, an instance will be created by default
      if `mode` contains 'train'. It is ignored when `mode` does not contain
      'train'.

  Returns:
    A `(model, eval_logs)` tuple. `model` is the trainer's model when training,
    otherwise a model freshly built by `train_task`. `eval_logs` is the result
    of `evaluator.evaluate(...)` when `run_post_eval` is True and an empty dict
    otherwise. Note that the `tf.keras.Model` return annotation is narrower
    than what is actually returned.

  Raises:
    NotImplementedError: If `mode` is not one of the supported modes.
    ValueError: If `run_post_eval` is True but `mode` does not contain 'eval'.
  """
  # Validate up front so a bad mode fails with a clear error instead of an
  # AttributeError on a missing trainer/evaluator further down.
  supported_modes = ('train', 'eval', 'train_and_eval', 'continuous_eval')
  if mode not in supported_modes:
    raise NotImplementedError('The mode is not implemented: %s' % mode)

  is_training = 'train' in mode
  is_eval = 'eval' in mode

  # Post-eval goes through the evaluator, which only exists in eval modes.
  # Fail before any training happens rather than after it has finished.
  if run_post_eval and not is_eval:
    raise ValueError(
        'run_post_eval requires an evaluator, but mode %r does not run '
        'evaluation.' % mode)

  with distribution_strategy.scope():
    if is_training:
      trainer = trainer or core_lib.Trainer(
          config=params,
          task=train_task,
          model=train_task.build_model(),
          optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
                                                params.runtime),
          train=True,
          evaluate=False)
    else:
      # A caller-supplied trainer is not used outside of training modes.
      trainer = None
    model = trainer.model if trainer else train_task.build_model()

    if is_eval:
      # Map each eval task's name to its configured number of eval steps.
      eval_steps = {
          task_routine.task_config.name: task_routine.eval_steps
          for task_routine in params.eval_tasks
      }
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=eval_tasks,
          model=model,
          # Share the trainer's step counter when both run together.
          global_step=trainer.global_step if is_training else None,
          eval_steps=eval_steps,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  # The trainer's checkpoint/step take precedence; eval-only modes fall back
  # to the evaluator's. Mode validation guarantees at least one of them exists.
  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=(
          os.path.join(model_dir, 'validation') if save_summary else None),
      summary_interval=(
          params.trainer.summary_interval if save_summary else None))

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    else:  # mode == 'continuous_eval'; every other value was rejected above.

      def timeout_fn():
        # True once the evaluator's global step has reached `train_steps`.
        return bool(
            evaluator.global_step.numpy() >= params.trainer.train_steps)

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)

    if run_post_eval:
      return model, evaluator.evaluate(
          tf.convert_to_tensor(params.trainer.validation_steps))
    return model, {}
Ejemplo n.º 4
0
def run_experiment_wtih_multitask_eval(
        *, distribution_strategy: tf.distribute.Strategy,
        train_task: base_task.Task, eval_tasks: multitask.MultiTask, mode: str,
        params: configs.MultiEvalExperimentConfig,
        model_dir: str) -> tf.keras.Model:
    """Runs train/eval configured by the experiment params.

    The misspelled function name ("wtih") is kept unchanged for backward
    compatibility.

    Args:
      distribution_strategy: The distribution strategy; the model, trainer and
        evaluator are built inside its scope and the selected mode runs
        inside it.
      train_task: A base_task.Task instance used to build the model and, when
        training, the optimizer.
      eval_tasks: A multitask.MultiTask with evaluation tasks.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: MultiEvalExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.

    Returns:
      model: `tf.keras.Model` instance.

    Raises:
      NotImplementedError: If `mode` is not one of the supported modes.
    """
    # Validate up front so a bad mode fails with a clear error instead of an
    # AttributeError on a missing trainer/evaluator further down.
    supported_modes = ('train', 'eval', 'train_and_eval', 'continuous_eval')
    if mode not in supported_modes:
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    is_training = 'train' in mode
    is_eval = 'eval' in mode
    with distribution_strategy.scope():
        # Only the trainer consumes the optimizer, so eval-only modes skip
        # building one. It is still created before the model when training.
        optimizer = (
            train_task.create_optimizer(params.trainer, params.runtime)
            if is_training else None)
        model = train_task.build_model()
        if is_training:
            trainer = core_lib.Trainer(config=params,
                                       task=train_task,
                                       model=model,
                                       optimizer=optimizer,
                                       train=True,
                                       evaluate=False)
        else:
            trainer = None
        if is_eval:
            evaluator = evaluator_lib.MultiTaskEvaluator(
                task=eval_tasks,
                model=model,
                # Share the trainer's step counter when both run together.
                global_step=trainer.global_step if is_training else None)
        else:
            evaluator = None

    # The trainer's checkpoint/step take precedence; eval-only modes fall back
    # to the evaluator's. Mode validation guarantees at least one exists.
    if trainer:
        checkpoint = trainer.checkpoint
        global_step = trainer.global_step
    else:
        checkpoint = evaluator.checkpoint
        global_step = evaluator.global_step

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize if trainer else None)

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(model_dir, 'train'),
        eval_summary_dir=os.path.join(model_dir, 'validation'),
        summary_interval=params.trainer.summary_interval)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        else:  # mode == 'continuous_eval'; other values were rejected above.

            def timeout_fn():
                # True once the evaluator's global step has reached
                # `train_steps`.
                return bool(evaluator.global_step.numpy() >=
                            params.trainer.train_steps)

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)

        return model