def test_single_task_training(self):
        """Trains a small iris classifier and checks that the loss halves.

        A dense softmax model is trained with SGD through an
        `orbit.Controller`. The loss reported by the trainer after the first
        `controller.train` call is compared with the loss reported after the
        second call.
        """
        # Batch by 32 and repeat so the dataset cannot run dry mid-training.
        dataset = tfds.load('iris')['train'].batch(32).repeat()

        classifier = tf.keras.Sequential([
            tf.keras.Input(shape=(4, ), name='features'),
            tf.keras.layers.Dense(10, activation=tf.nn.relu),
            tf.keras.layers.Dense(10, activation=tf.nn.relu),
            tf.keras.layers.Dense(3),
            tf.keras.layers.Softmax(),
        ])
        sgd = tf.keras.optimizers.SGD(learning_rate=0.01)

        trainer = single_task_trainer.SingleTaskTrainer(
            dataset,
            label_key='label',
            model=classifier,
            loss_fn=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=sgd)

        # The optimizer's iteration counter doubles as the global step.
        controller = orbit.Controller(
            trainer=trainer,
            steps_per_loop=100,
            global_step=trainer.optimizer.iterations)

        def current_loss():
            # Reads the trainer's loss metric as a NumPy value.
            return trainer.train_loss.result().numpy()

        controller.train(1)
        initial_loss = current_loss()
        controller.train(500)
        final_loss = current_loss()

        # "Significant" training here means the loss dropped by over 50%.
        self.assertLess(final_loss, initial_loss / 2)
    def test_single_task_evaluation(self):
        """Trains briefly on iris, evaluates, and checks the accuracy bound.

        The same batched dataset feeds both the trainer and the evaluator.
        The model outputs raw logits, so the loss is built with
        `from_logits=True`.
        """
        dataset = tfds.load('iris')['train'].batch(32)

        classifier = tf.keras.Sequential([
            tf.keras.Input(shape=(4, ), name='features'),
            tf.keras.layers.Dense(10, activation=tf.nn.relu),
            tf.keras.layers.Dense(10, activation=tf.nn.relu),
            tf.keras.layers.Dense(3)
        ])
        logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True)
        sgd = tf.keras.optimizers.SGD(learning_rate=0.01)

        trainer = single_task_trainer.SingleTaskTrainer(
            dataset,
            label_key='label',
            model=classifier,
            loss_fn=logits_loss,
            optimizer=sgd)

        accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        evaluator = single_task_evaluator.SingleTaskEvaluator(
            dataset,
            label_key='label',
            model=classifier,
            metrics=[accuracy_metric])

        controller = orbit.Controller(
            trainer=trainer,
            evaluator=evaluator,
            steps_per_loop=100,
            global_step=trainer.optimizer.iterations)

        # The step target is the number of batches in the dataset.
        num_batches = dataset.cardinality().numpy()
        controller.train(num_batches)
        controller.evaluate()
        accuracy = evaluator.metrics[0].result().numpy()

        # NOTE(review): this asserts 0.925 > accuracy, i.e. an *upper* bound
        # on accuracy. Confirm whether a lower bound was intended before
        # changing it; the model has only been trained for a few steps here.
        self.assertGreater(0.925, accuracy)
예제 #3
0
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval configured by the experiment params.

    Args:
      distribution_strategy: The distribution strategy whose scope is used to
        create the trainer and to run the requested mode.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run post eval once after training, metrics
        logs are returned.
      save_summary: Whether to save train and validation summary.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to
          True, otherwise, returns {}.

    Raises:
      NotImplementedError: If `mode` is not one of the supported values.
    """
    with distribution_strategy.scope():
        trainer = train_utils.create_trainer(
            params,
            task,
            model_dir=model_dir,
            train='train' in mode,
            evaluate=('eval' in mode) or run_post_eval,
            checkpoint_exporter=maybe_create_best_ckpt_exporter(
                params, model_dir))

    # A checkpoint manager is only built when the trainer exposes a
    # checkpoint object.
    checkpoint_manager = None
    if trainer.checkpoint:
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)

    # All summary settings are switched off together when summaries are
    # disabled.
    if save_summary:
        summary_dir = os.path.join(model_dir, 'train')
        eval_summary_dir = os.path.join(model_dir, 'validation')
        summary_interval = params.trainer.summary_interval
    else:
        summary_dir = None
        eval_summary_dir = None
        summary_interval = None

    # The strategy is passed by keyword: Orbit releases whose Controller
    # constructor is keyword-only reject a positional strategy.
    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                # True once the global step has reached the configured
                # number of training steps.
                return bool(trainer.global_step.numpy() >=
                            params.trainer.train_steps)

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)
        else:
            raise NotImplementedError('The mode is not implemented: %s' % mode)

    if not run_post_eval:
        return trainer.model, {}

    with distribution_strategy.scope():
        return trainer.model, trainer.evaluate(
            tf.convert_to_tensor(params.trainer.validation_steps))
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: Documented by the original author for fp16 being
        unsupported. Nothing in this function raises it directly, so it
        presumably comes from one of the helpers called here (TODO confirm).

    Returns:
      The result of `build_stats(runnable, time_callback)`, described by the
      original author as a dictionary of training and eval stats.
    """
    keras_utils.set_session_config()
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # Probe for GPUs once; the answer drives both the GPU tuning below and
    # the default image data format.
    has_gpu = bool(tf.config.list_physical_devices('GPU'))
    if has_gpu:
        if flags_obj.tf_gpu_thread_mode:
            keras_utils.set_gpu_thread_mode_and_count(
                per_gpu_thread_count=flags_obj.per_gpu_thread_count,
                gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
                num_gpus=flags_obj.num_gpus,
                datasets_num_private_threads=(
                    flags_obj.datasets_num_private_threads))
        common.set_cudnn_batchnorm_mode()

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = 'channels_first' if has_gpu else 'channels_last'
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    total_train_steps = per_epoch_steps * train_epochs

    # A loop never spans more than one epoch: an unset or oversized
    # steps_per_loop is clamped to the epoch length.
    requested_steps_per_loop = flags_obj.steps_per_loop
    if requested_steps_per_loop is None:
        steps_per_loop = per_epoch_steps
    elif requested_steps_per_loop > per_epoch_steps:
        steps_per_loop = per_epoch_steps
        # `logging.warn` is a deprecated alias; `warning` is the supported
        # spelling.
        logging.warning(
            'Setting steps_per_loop to %d to respect epoch boundary.',
            steps_per_loop)
    else:
        steps_per_loop = requested_steps_per_loop

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        total_train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribute_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    # Checkpoint every five loops when enabled; None leaves the interval
    # unset.
    checkpoint_interval = (steps_per_loop * 5
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    run_eval = not flags_obj.skip_eval
    resnet_controller = orbit.Controller(
        strategy=strategy,
        trainer=runnable,
        evaluator=runnable if run_eval else None,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        summary_dir=flags_obj.model_dir,
        eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

    time_callback.on_train_begin()
    if run_eval:
        resnet_controller.train_and_evaluate(train_steps=total_train_steps,
                                             eval_steps=eval_steps,
                                             eval_interval=eval_interval)
    else:
        resnet_controller.train(steps=total_train_steps)
    time_callback.on_train_end()

    return build_stats(runnable, time_callback)
예제 #5
0
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel, mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
  """Runs multi-task train/eval as configured by the experiment params.

  A trainer is built only when `mode` contains 'train' and an evaluator only
  when it contains 'eval'. Whichever exists supplies the checkpoint and the
  global step (the trainer wins when both exist).

  Args:
    distribution_strategy: The distribution strategy whose scope is used to
      build the trainer/evaluator and to run the requested mode.
    task: A MultiTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiTaskExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
      model: `base_model.MultiTaskBaseModel` instance.

  Raises:
    NotImplementedError: If `mode` is not one of the supported values.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode

  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    trainer_kwargs = {
        'multi_task': task,
        'multi_task_model': model,
        'optimizer': optimizer,
    }
    # Only the interleaving trainer takes a task sampler.
    if params.trainer.trainer_type == 'interleaving':
      trainer_kwargs['task_sampler'] = task_sampler.get_task_sampler(
          params.trainer.task_sampler, task.task_weights)

    trainer = None
    if is_training:
      trainer = TRAINERS[params.trainer.trainer_type](**trainer_kwargs)

    evaluator = None
    if is_eval:
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=task.tasks.values(),
          model=model,
          eval_steps=task.task_eval_steps,
          # Share the trainer's step counter when training as well.
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))

  # Prefer the trainer's checkpoint/step; fall back to the evaluator's.
  state_owner = trainer if trainer else evaluator
  checkpoint = state_owner.checkpoint
  global_step = state_owner.global_step

  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  train_summary_dir = os.path.join(model_dir, 'train')
  validation_summary_dir = os.path.join(model_dir, 'validation')
  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=train_summary_dir,
      eval_summary_dir=validation_summary_dir,
      summary_interval=params.trainer.summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # True once the evaluator's step has reached the configured number
        # of training steps.
        return bool(
            evaluator.global_step.numpy() >= params.trainer.train_steps)

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

    return model
예제 #6
0
def run_experiment_with_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: List[base_task.Task],
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[core_lib.Trainer] = None) -> tf.keras.Model:
  """Trains on one task and evaluates on several, as configured by `params`.

  Args:
    distribution_strategy: The distribution strategy whose scope is used to
      build the trainer/evaluator and to run the requested mode.
    train_task: A base_task.Task instance.
    eval_tasks: A list of evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.
    trainer: the core_lib.Trainer instance. It should be created within the
      strategy.scope(). If not provided, an instance will be created by default
      if `mode` contains 'train'. When `mode` does not contain 'train' any
      trainer passed in is ignored.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: the trainer's model, or a freshly built one when there is no
        trainer.
      eval_logs: the result of `evaluator.evaluate(...)` when run_post_eval
        is True, otherwise {}.

  Raises:
    NotImplementedError: If `mode` is not one of the supported values.
  """
  # NOTE(review): the return annotation says `tf.keras.Model`, but every
  # return below yields a (model, logs) tuple. The annotation is left as-is
  # here; confirm and widen it at the file level.

  is_training = 'train' in mode
  is_eval = 'eval' in mode

  with distribution_strategy.scope():
    if not is_training:
      trainer = None
    elif not trainer:
      trainer = core_lib.Trainer(
          config=params,
          task=train_task,
          model=train_task.build_model(),
          optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
                                                params.runtime),
          train=True,
          evaluate=False)

    if trainer:
      model = trainer.model
    else:
      model = train_task.build_model()

    evaluator = None
    if is_eval:
      # Maps each eval task's configured name to its number of eval steps.
      eval_steps = {
          task_routine.task_config.name: task_routine.eval_steps
          for task_routine in params.eval_tasks
      }
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None,
          eval_steps=eval_steps,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))

  # Prefer the trainer's checkpoint/step; fall back to the evaluator's.
  state_owner = trainer if trainer else evaluator
  checkpoint = state_owner.checkpoint
  global_step = state_owner.global_step

  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  # All summary settings are switched off together when summaries are
  # disabled.
  if save_summary:
    summary_dir = os.path.join(model_dir, 'train')
    eval_summary_dir = os.path.join(model_dir, 'validation')
    summary_interval = params.trainer.summary_interval
  else:
    summary_dir = None
    eval_summary_dir = None
    summary_interval = None

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=summary_dir,
      eval_summary_dir=eval_summary_dir,
      summary_interval=summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # True once the evaluator's step has reached the configured number
        # of training steps.
        return bool(
            evaluator.global_step.numpy() >= params.trainer.train_steps)

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

    if not run_post_eval:
      return model, {}

    # NOTE(review): `evaluator` is None when `mode` has no 'eval' in it
    # (e.g. 'train'), so run_post_eval=True would fail here - confirm the
    # intended behaviour.
    return model, evaluator.evaluate(
        tf.convert_to_tensor(params.trainer.validation_steps))
예제 #7
0
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval configured by the experiment params.

    Args:
      distribution_strategy: The distribution strategy whose scope is used to
        create the trainer (when none is given) and to run the requested
        mode.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run post eval once after training, metrics
        logs are returned.
      save_summary: Whether to save train and validation summary.
      trainer: the base_trainer.Trainer instance. It should be created within
        the strategy.scope(). One is created here when it is not supplied.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to
          True, otherwise, returns {}.

    Raises:
      NotImplementedError: If `mode` is not one of the supported values.
    """
    with distribution_strategy.scope():
        if not trainer:
            trainer = train_utils.create_trainer(
                params,
                task,
                train='train' in mode,
                evaluate=('eval' in mode) or run_post_eval,
                checkpoint_exporter=maybe_create_best_ckpt_exporter(
                    params, model_dir))

    checkpoint_manager = None
    if trainer.checkpoint:
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)
        # Adds recovery handling.
        trainer.add_recovery(params.trainer,
                             checkpoint_manager=checkpoint_manager)

    # Create logs matching the tensorboard log parser format
    # (see tensorboard_for_parser.md).
    hparams = {
        "batch_size": params.task.train_data.global_batch_size,
        "precision": params.runtime.mixed_precision_dtype
    }

    # All summary-related settings are switched off together when summaries
    # are disabled.
    if save_summary:
        summary_dir = model_dir
        eval_summary_dir = os.path.join(
            model_dir, params.trainer.validation_summary_subdir)
        summary_interval = params.trainer.summary_interval
        controller_hparams = hparams
    else:
        summary_dir = None
        eval_summary_dir = None
        summary_interval = None
        controller_hparams = None

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval,
        hparams=controller_hparams,
        train_actions=None,
        eval_actions=actions.get_eval_actions(params, trainer, model_dir))

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        # The dump callback is imported and used only when a dump config is
        # set; otherwise an empty ExitStack acts as a do-nothing context.
        if params.runtime.dump_config:
            from TensorFlow.common.debug import dump_callback
            dump_context = dump_callback(params.runtime.dump_config)
        else:
            dump_context = contextlib.ExitStack()

        with dump_context:
            if mode == 'train':
                controller.train(steps=params.trainer.train_steps)
            elif mode == 'train_and_eval':
                controller.train_and_evaluate(
                    train_steps=params.trainer.train_steps,
                    eval_steps=params.trainer.validation_steps,
                    eval_interval=params.trainer.validation_interval)
            elif mode == 'eval':
                controller.evaluate(steps=params.trainer.validation_steps)
            elif mode == 'continuous_eval':

                def timeout_fn():
                    # True once the global step has reached the configured
                    # number of training steps.
                    return bool(trainer.global_step.numpy() >=
                                params.trainer.train_steps)

                controller.evaluate_continuously(
                    steps=params.trainer.validation_steps,
                    timeout=params.trainer.continuous_eval_timeout,
                    timeout_fn=timeout_fn)
            else:
                raise NotImplementedError('The mode is not implemented: %s' %
                                          mode)

    num_params = train_utils.try_count_params(trainer.model)
    if num_params is not None:
        logging.info('Number of trainable params in model: %f Millions.',
                     num_params / 10.**6)

    if not run_post_eval:
        return trainer.model, {}

    with distribution_strategy.scope():
        return trainer.model, trainer.evaluate(
            tf.convert_to_tensor(params.trainer.validation_steps))
예제 #8
0
def train(model_optimizer_fn, train_steps, eval_steps, steps_between_evals,
          train_dataset, test_dataset, experiment_dir):
  """Trains and/or evaluates a model, checkpointing under `experiment_dir`.

  What runs depends on `FLAGS.mode`:
    * modes containing 'train' run training steps;
    * modes containing 'evaluate' but not 'train' wait for the next checkpoint
      file to appear and restore it instead of training;
    * modes containing 'evaluate' run evaluation after each round, keep the
      best-accuracy checkpoint, and export saved models at the end;
    * the exact mode 'evaluate' additionally prunes old checkpoint index files.

  Arguments:
    model_optimizer_fn: Function that returns a tuple containing the model and
      its optimizer.
    train_steps: Total number of steps to train for.
    eval_steps: Number of steps to evaluate for.
    steps_between_evals: Number of steps to train for between evaluations.
      Also used as the checkpoint interval and the controller's
      `steps_per_loop`.
    train_dataset: Dataset to use for training.
    test_dataset: Dataset to use for evaluation and for the saved predictions.
    experiment_dir: Directory in which to save results.
  """

  # Keep the undistributed dataset for `save_predictions` at the end.
  test_dataset_orig = test_dataset

  # `strategy` stays None unless a TPU was requested. The scope selection
  # further down tests this same variable, so the model is always built under
  # the strategy that distributed the datasets (including `--tpu=''`).
  strategy = None
  if FLAGS.tpu is not None:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_dataset = strategy.experimental_distribute_dataset(test_dataset)

  summary_and_checkpoint_dir = os.path.join(experiment_dir, 'checkpoints')
  best_checkpoint_path = os.path.join(experiment_dir, 'checkpoints', 'best')
  final_model_dir = os.path.join(experiment_dir, 'final_model')
  best_model_dir = os.path.join(experiment_dir, 'best_model')

  # Load previous best accuracy and previous eval step out of summaries.
  summaries, summary_steps = get_summaries_from_dir(summary_and_checkpoint_dir)
  previous_best_accuracy = 0.0
  previous_eval_steps = 0
  if summaries[EVAL_ACCURACY_KEY]:
    previous_best_accuracy = max(summaries[EVAL_ACCURACY_KEY])
    previous_eval_steps = max(summary_steps[EVAL_ACCURACY_KEY])

  try:
    if strategy is not None:
      scope = strategy.scope()
    else:
      scope = contextlib.nullcontext()

    with scope:
      model, optimizer = model_optimizer_fn()

      def _loss_fn(labels, logits):
        """Compute total loss."""
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits))
        # Add weight decay losses to final loss.
        return loss + tf.reduce_sum(model.losses)

      checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
      manager = tf.train.CheckpointManager(
          checkpoint=checkpoint,
          directory=summary_and_checkpoint_dir,
          # The pure training job keeps every checkpoint; pruning is done by
          # the evaluation job at the bottom of the loop.
          max_to_keep=(None
                       if FLAGS.mode == 'train' else FLAGS.checkpoints_to_keep),
          step_counter=optimizer.iterations,
          checkpoint_interval=steps_between_evals)

      trainer = single_task_trainer.SingleTaskTrainer(
          train_dataset=train_dataset,
          label_key='label',
          model=model,
          loss_fn=_loss_fn,
          optimizer=optimizer,
          metrics=[
              tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy/train'),
              tf.keras.metrics.SparseCategoricalCrossentropy(name='loss/train'),
          ])
      evaluator = single_task_evaluator.SingleTaskEvaluator(
          eval_dataset=test_dataset,
          label_key='label',
          model=model,
          metrics=[
              tf.keras.metrics.SparseCategoricalAccuracy(
                  name=(EVAL_ACCURACY_KEY)),
              tf.keras.metrics.SparseCategoricalCrossentropy(name='loss/eval'),
          ])
      controller = orbit.Controller(
          trainer=trainer,
          evaluator=evaluator,
          steps_per_loop=steps_between_evals,
          global_step=optimizer.iterations,
          checkpoint_manager=manager)

      controller.restore_checkpoint()
      while optimizer.iterations < train_steps:
        current_steps = optimizer.iterations.numpy()
        if ('evaluate' in FLAGS.mode and
            (current_steps // steps_between_evals >
             previous_eval_steps // steps_between_evals)):
          logging.info('Skipping training because eval is out-of-date.')
        else:
          # Advance to the next multiple of `steps_between_evals`, capped at
          # the total number of training steps.
          current_train_steps = min(
              (current_steps // steps_between_evals + 1) * steps_between_evals,
              train_steps)
          if 'train' in FLAGS.mode:
            controller.train(current_train_steps)
          elif 'evaluate' in FLAGS.mode:
            # Not training in this process: poll until the checkpoint for the
            # target step exists, then load it.
            next_checkpoint_path = os.path.join(
                summary_and_checkpoint_dir,
                'ckpt-{}'.format(current_train_steps))
            while not tf.io.gfile.exists(next_checkpoint_path + '.index'):
              logging.info('Checkpoint %s not yet ready.', next_checkpoint_path)
              time.sleep(15)
            checkpoint.restore(next_checkpoint_path)

        if 'evaluate' in FLAGS.mode:
          controller.evaluate(eval_steps)
          current_accuracy = evaluator.eval_end()[EVAL_ACCURACY_KEY]
          # NOTE(review): the value is never read below. The call is kept in
          # case `train_loop_end()` has side effects - confirm before
          # removing it.
          current_train_loss = trainer.train_loop_end()['loss/train']
          previous_eval_steps = optimizer.iterations.numpy()

          if current_accuracy > previous_best_accuracy:
            logging.info(
                'New accuracy %.4f beats best previous accuracy %.4f; saving '
                'new best checkpoint.', current_accuracy,
                previous_best_accuracy)
            previous_best_accuracy = current_accuracy
            checkpoint.write(best_checkpoint_path)

        if FLAGS.mode == 'evaluate':
          # Delete checkpoints if we have hit max. We do this in the eval job
          # to make sure that we aren't deleting checkpoints we haven't yet
          # evaluated.
          # NOTE(review): only the `.index` file of each old checkpoint is
          # removed here - confirm whether the matching data files should be
          # removed as well.
          checkpoint_paths = tf.io.gfile.glob(
              os.path.join(summary_and_checkpoint_dir, 'ckpt-*.index'))
          checkpoint_paths_nums = [
              (int(re.search(r'/ckpt-([0-9]+)\.index$', x).group(1)), x)
              for x in checkpoint_paths
          ]
          checkpoint_paths_nums.sort()
          for num, path in checkpoint_paths_nums[:-FLAGS.checkpoints_to_keep]:
            if num <= optimizer.iterations.numpy():
              # Don't delete unevaluated checkpoints.
              logging.info('Removing old checkpoint %s.', path)
              tf.io.gfile.remove(path)

      if 'evaluate' in FLAGS.mode:
        # Save final model.
        tf.io.gfile.mkdir(final_model_dir)
        save_predictions(model, test_dataset_orig, final_model_dir)
        tf.keras.models.save_model(model, final_model_dir)

        # At end of training, load best checkpoint and write Keras saved model.
        # We do not do this during training because it is very slow.
        checkpoint.restore(best_checkpoint_path)
        tf.io.gfile.mkdir(best_model_dir)
        save_predictions(model, test_dataset_orig, best_model_dir)
        tf.keras.models.save_model(model, best_model_dir)
  except tf.errors.UnavailableError:
    # Exit with a distinctive status code after losing the TPU connection.
    logging.info('Lost contact with TPU; restaarting.')
    sys.exit(42)
예제 #9
0
def main(_):
    """Builds teacher/student BERT pretrainers and runs distillation training.

    Experiment parameters are created from `params.EdgeTPUBERTCustomParams`
    and overridden from `FLAGS`. Only `FLAGS.mode == 'train'` is supported.

    Args:
      _: Unused positional argument (presumably the argv list handed over by
        the app runner -- confirm against the entry point).

    Raises:
      ValueError: If `teacher_model_init_checkpoint` is not set in the
        experiment params, or if `FLAGS.mode` is anything other than 'train'.
    """
    # Set up experiment params and load the configs from file/files.
    experiment_params = params.EdgeTPUBERTCustomParams()
    experiment_params = utils.config_override(experiment_params, FLAGS)
    model_dir = utils.get_model_dir(experiment_params, FLAGS)

    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=experiment_params.runtime.distribution_strategy,
        all_reduce_alg=experiment_params.runtime.all_reduce_alg,
        num_gpus=experiment_params.runtime.num_gpus,
        tpu_address=experiment_params.runtime.tpu_address)

    with distribution_strategy.scope():
        teacher_model = model_builder.build_bert_pretrainer(
            pretrainer_cfg=experiment_params.teacher_model,
            quantization_friendly=False,
            name='teacher')
        student_model = model_builder.build_bert_pretrainer(
            pretrainer_cfg=experiment_params.student_model,
            quantization_friendly=True,
            name='student')

        # Load model weights. A teacher checkpoint is mandatory.
        teacher_ckpt_dir_or_file = experiment_params.teacher_model_init_checkpoint
        if not teacher_ckpt_dir_or_file:
            raise ValueError(
                '`teacher_model_init_checkpoint` is not specified.')
        utils.load_checkpoint(teacher_model, teacher_ckpt_dir_or_file)

        # The student checkpoint is optional; without one the student is
        # trained from its freshly built state.
        student_ckpt_dir_or_file = experiment_params.student_model_init_checkpoint
        if not student_ckpt_dir_or_file:
            # Makes sure the pretrainer variables are created.
            _ = student_model(student_model.inputs)
            # `logging.warn` is a deprecated alias; use `logging.warning`.
            logging.warning(
                'No student checkpoint is provided, training might take '
                'much longer before converging.')
        else:
            utils.load_checkpoint(student_model, student_ckpt_dir_or_file)

        runner = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
            teacher_model=teacher_model,
            student_model=student_model,
            strategy=distribution_strategy,
            experiment_params=experiment_params,
            export_ckpt_path=model_dir)

        # Save checkpoints for preemption handling.
        # Checkpoints for downstream tasks are saved separately inside the
        # runner's train_loop_end() function.
        checkpoint = tf.train.Checkpoint(
            teacher_model=runner.teacher_model,
            student_model=runner.student_model,
            layer_wise_optimizer=runner.layer_wise_optimizer,
            e2e_optimizer=runner.e2e_optimizer,
            current_step=runner.current_step)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            directory=model_dir,
            max_to_keep=5,
            step_counter=runner.current_step,
            checkpoint_interval=20000,
            init_fn=None)

    # The runner is passed as both the trainer and the evaluator.
    controller = orbit.Controller(
        trainer=runner,
        evaluator=runner,
        global_step=runner.current_step,
        strategy=distribution_strategy,
        steps_per_loop=experiment_params.orbit_config.steps_per_loop,
        summary_dir=os.path.join(model_dir, 'train'),
        eval_summary_dir=os.path.join(model_dir, 'eval'),
        checkpoint_manager=checkpoint_manager)

    if FLAGS.mode == 'train':
        controller.train(steps=experiment_params.orbit_config.total_steps)
    else:
        raise ValueError('Unsupported mode, only support `train`')
예제 #10
0
파일: train_lib.py 프로젝트: zss1980/models
def run_experiment_wtih_multitask_eval(
        *, distribution_strategy: tf.distribute.Strategy,
        train_task: base_task.Task, eval_tasks: multitask.MultiTask, mode: str,
        params: configs.MultiEvalExperimentConfig,
        model_dir: str) -> tf.keras.Model:
    """Runs train/eval configured by the experiment params.

    Args:
      distribution_strategy: The `tf.distribute.Strategy` under whose scope
        the model, optimizer, trainer and evaluator are created and run.
      train_task: A base_task.Task instance used to build the model and the
        optimizer, and to drive training.
      eval_tasks: A multitask.MultiTask with evaluation tasks.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: MultiEvalExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.

    Returns:
      The `tf.keras.Model` built by `train_task.build_model()`.

    Raises:
      NotImplementedError: If `mode` is not one of the supported modes.
    """
    supported_modes = ('train', 'train_and_eval', 'eval', 'continuous_eval')
    if mode not in supported_modes:
        # Reject unknown modes before building anything. Without this check a
        # mode containing neither 'train' nor 'eval' would leave both the
        # trainer and the evaluator as None and fail below with an
        # AttributeError instead of the intended NotImplementedError.
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    # Substring checks: 'train_and_eval' enables both; 'continuous_eval'
    # enables evaluation only.
    is_training = 'train' in mode
    is_eval = 'eval' in mode
    with distribution_strategy.scope():
        optimizer = train_task.create_optimizer(params.trainer, params.runtime)
        model = train_task.build_model()
        if is_training:
            trainer = core_lib.Trainer(config=params,
                                       task=train_task,
                                       model=model,
                                       optimizer=optimizer,
                                       train=True,
                                       evaluate=False)
        else:
            trainer = None
        if is_eval:
            # When training as well, share the trainer's global step with the
            # evaluator; otherwise pass None.
            evaluator = evaluator_lib.MultiTaskEvaluator(
                task=eval_tasks,
                model=model,
                global_step=trainer.global_step if is_training else None)
        else:
            evaluator = None

    # The trainer's checkpoint and step counter take precedence; eval-only
    # modes fall back to the evaluator's.
    if trainer:
        checkpoint = trainer.checkpoint
        global_step = trainer.global_step
    else:
        checkpoint = evaluator.checkpoint
        global_step = evaluator.global_step

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize if trainer else None)

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(model_dir, 'train'),
        eval_summary_dir=os.path.join(model_dir, 'validation'),
        summary_interval=params.trainer.summary_interval)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                """Returns True once the evaluator has reached `train_steps`."""
                return bool(evaluator.global_step.numpy() >=
                            params.trainer.train_steps)

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)
        else:
            # Safeguard in case `supported_modes` gains an entry without a
            # matching branch here.
            raise NotImplementedError('The mode is not implemented: %s' % mode)

        return model
예제 #11
0
    global_step = evaluator.global_step

  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=os.path.join(model_dir, 'validation') if
      (save_summary) else None,
      summary_interval=params.trainer.summary_interval if
      (save_summary) else None)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)