Example #1
  def test_uniform_sample_distribution(self):
    uniform_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(type='uniform'), self._task_weights)
    for step in range(5):
      cumulative_distribution = uniform_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.333333, 0.666666, 1.0],
                          cumulative_distribution.numpy())
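For reference, a uniform sampler assigns every task the same probability regardless of its weight, so the asserted values are just the running sum of [1/3, 1/3, 1/3]. A minimal NumPy sketch of that expectation, assuming three tasks as implied by the three-element assertion:

import numpy as np

# Uniform sampling: each of the (assumed) three tasks gets probability 1/3,
# independent of the configured task weights.
num_tasks = 3
probs = np.full(num_tasks, 1.0 / num_tasks)
print(np.cumsum(probs))  # [0.33333333 0.66666667 1.        ]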
Example #2
  def test_proportional_sample_distribution(self):
    prop_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='proportional',
            proportional=configs.ProportionalSampleConfig(alpha=2.0)),
        self._task_weights)
    # CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
    for step in range(5):
      cumulative_distribution = prop_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.07142857, 0.35714286, 1.0],
                          cumulative_distribution.numpy())
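The expected values follow from the comment above: with task weights of [1.0, 2.0, 3.0] (an assumption consistent with the other tests here) and alpha=2.0, each weight is raised to the power alpha, normalized, and then accumulated. A small sketch of that arithmetic:

import numpy as np

# Proportional sampling with alpha=2.0 over assumed weights [1.0, 2.0, 3.0]:
# the sampling probability is proportional to weight**alpha.
weights = np.array([1.0, 2.0, 3.0])
alpha = 2.0
powered = weights ** alpha           # [1., 4., 9.]
probs = powered / powered.sum()      # [1/14, 4/14, 9/14]
print(np.cumsum(probs))              # [0.07142857 0.35714286 1.        ]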
Example #3
  def test_annealing_sample_distribution(self):
    num_epoch = 3
    step_per_epoch = 6
    anneal_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='annealing',
            annealing=configs.AnnealingSampleConfig(
                steps_per_epoch=step_per_epoch,
                total_steps=step_per_epoch * num_epoch)), self._task_weights)

    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epoch):
      for _ in range(step_per_epoch):
        cumulative_distribution = anneal_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected_cumulative_epochs[epoch],
                            cumulative_distribution.numpy())
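The annealing sampler changes the task distribution as training progresses, which is why each epoch asserts different values. Working backwards from the expected numbers (again assuming task weights [1.0, 2.0, 3.0]), the distribution starts more concentrated on the heavily weighted task and ends closer to uniform; the middle epoch is exactly the weight-proportional distribution, which the following sketch reproduces:

import numpy as np

# At the middle epoch the expected values equal the cumulative sum of the
# plain normalized weights (assumed to be [1.0, 2.0, 3.0]).
weights = np.array([1.0, 2.0, 3.0])
probs = weights / weights.sum()      # [1/6, 2/6, 3/6]
print(np.cumsum(probs))              # [0.16666667 0.5        1.        ]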
Example #4
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel, mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A multitask.MultiTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: A MultiTaskExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
    model: A `base_model.MultiTaskBaseModel` instance.
  """

  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    if params.trainer.trainer_type == 'interleaving':
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    trainer = TRAINERS[params.trainer.trainer_type](
        **kwargs) if is_training else None
    if is_eval:
      eval_steps = task.task_eval_steps
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=task.tasks.values(),
          model=model,
          eval_steps=eval_steps,
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starting to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

    return model
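A hypothetical end-to-end invocation might look like the sketch below. The strategy is real TensorFlow API, but experiment_config, build_multitask, and build_model are placeholder names standing in for whatever the surrounding codebase uses to construct the config, the MultiTask, and the MultiTaskBaseModel; they are not functions defined in this module.

import tensorflow as tf

# Hypothetical usage sketch; the three builders below are placeholders.
strategy = tf.distribute.MirroredStrategy()
params = experiment_config()           # a configs.MultiTaskExperimentConfig
with strategy.scope():
  task = build_multitask(params)       # a multitask.MultiTask
  model = build_model(params)          # a base_model.MultiTaskBaseModel

trained_model = run_experiment(
    distribution_strategy=strategy,
    task=task,
    model=model,
    mode='train_and_eval',
    params=params,
    model_dir='/tmp/multitask_experiment')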