    def test_evaluate_only(self, strategy):
        with strategy.scope():
            test_runnable = TestRunnable()

        checkpoint = tf.train.Checkpoint(model=test_runnable.model)
        checkpoint.save(os.path.join(self.model_dir, "ckpt"))

        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            self.model_dir,
            max_to_keep=None,
            step_counter=test_runnable.global_step)
        test_controller = controller.Controller(
            strategy=strategy,
            eval_fn=test_runnable.evaluate,
            global_step=test_runnable.global_step,
            checkpoint_manager=checkpoint_manager,
            summary_dir=os.path.join(self.model_dir, "summaries/train"),
            eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
            eval_steps=2,
            eval_interval=5)
        test_controller.evaluate()

        # Only eval summaries are written
        self.assertFalse(
            tf.io.gfile.exists(os.path.join(self.model_dir,
                                            "summaries/train")))
        self.assertNotEmpty(
            tf.io.gfile.listdir(os.path.join(self.model_dir,
                                             "summaries/eval")))
        self.assertTrue(
            check_eventfile_for_keyword(
                "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    def test_no_checkpoint(self):
        test_runnable = TestRunnable()
        # No checkpoint manager and no strategy.
        test_controller = controller.Controller(
            train_fn=test_runnable.train,
            eval_fn=test_runnable.evaluate,
            global_step=test_runnable.global_step,
            train_steps=10,
            steps_per_loop=2,
            summary_dir=os.path.join(self.model_dir, "summaries/train"),
            summary_interval=2,
            eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
            eval_steps=2,
            eval_interval=5)
        test_controller.train(evaluate=True)
        self.assertEqual(test_runnable.global_step.numpy(), 10)
        # Loss and accuracy values should be written into summaries.
        self.assertNotEmpty(
            tf.io.gfile.listdir(os.path.join(self.model_dir,
                                             "summaries/train")))
        self.assertTrue(
            check_eventfile_for_keyword(
                "loss", os.path.join(self.model_dir, "summaries/train")))
        self.assertNotEmpty(
            tf.io.gfile.listdir(os.path.join(self.model_dir,
                                             "summaries/eval")))
        self.assertTrue(
            check_eventfile_for_keyword(
                "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
        # No checkpoint, so global step starts from 0.
        test_runnable.global_step.assign(0)
        test_controller.train(evaluate=True)
        self.assertEqual(test_runnable.global_step.numpy(), 10)

    def test_no_checkpoint_and_summaries(self):
        test_runnable = TestRunnable()
        # Neither a checkpoint manager nor summary directories are passed.
        test_controller = controller.Controller(
            train_fn=test_runnable.train,
            eval_fn=test_runnable.evaluate,
            global_step=test_runnable.global_step,
            train_steps=10,
            steps_per_loop=2,
            eval_steps=2,
            eval_interval=5)
        test_controller.train(evaluate=True)
        self.assertEqual(test_runnable.global_step.numpy(), 10)
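
# NOTE: `TestRunnable` is not shown in this excerpt either. A minimal sketch
# consistent with how the tests use it (a `model`, an `optimizer`, a
# `global_step`, and train/eval functions returning metric dicts for the
# controller to summarize); the toy model and loss are assumptions:
import tensorflow as tf


class TestRunnable:
    """Toy runnable exposing the attributes the Controller tests exercise."""

    def __init__(self):
        self.global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
        self.model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        self.optimizer = tf.keras.optimizers.SGD(0.1)

    def train(self, num_steps):
        x, y = tf.ones([4, 2]), tf.zeros([4, 1])
        for _ in range(int(num_steps)):
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(tf.square(self.model(x) - y))
            grads = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_variables))
            self.global_step.assign_add(1)
        return {"loss": loss}

    def evaluate(self, num_steps):
        del num_steps  # The toy evaluation ignores the step count.
        x, y = tf.ones([4, 2]), tf.zeros([4, 1])
        return {"eval_loss": tf.reduce_mean(tf.square(self.model(x) - y))}
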
    def test_train_and_evaluate(self, strategy):
        with strategy.scope():
            test_runnable = TestRunnable()

        checkpoint = tf.train.Checkpoint(model=test_runnable.model,
                                         optimizer=test_runnable.optimizer)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            self.model_dir,
            max_to_keep=None,
            step_counter=test_runnable.global_step,
            checkpoint_interval=10)
        test_controller = controller.Controller(
            strategy=strategy,
            train_fn=test_runnable.train,
            eval_fn=test_runnable.evaluate,
            global_step=test_runnable.global_step,
            train_steps=10,
            steps_per_loop=2,
            summary_dir=os.path.join(self.model_dir, "summaries/train"),
            summary_interval=2,
            checkpoint_manager=checkpoint_manager,
            eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
            eval_steps=2,
            eval_interval=5)
        test_controller.train(evaluate=True)

        # Checkpoints are saved.
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

        # Loss and accuracy values should be written into summaries.
        self.assertNotEmpty(
            tf.io.gfile.listdir(os.path.join(self.model_dir,
                                             "summaries/train")))
        self.assertTrue(
            check_eventfile_for_keyword(
                "loss", os.path.join(self.model_dir, "summaries/train")))
        self.assertNotEmpty(
            tf.io.gfile.listdir(os.path.join(self.model_dir,
                                             "summaries/eval")))
        self.assertTrue(
            check_eventfile_for_keyword(
                "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    checkpoint_interval = (per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=per_epoch_steps * train_epochs,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
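
# NOTE: `get_num_train_iterations` is defined elsewhere in the model garden.
# A sketch of what it plausibly computes, assuming the standard
# `imagenet_preprocessing.NUM_IMAGES` dataset-size constants:
import math


def get_num_train_iterations(flags_obj):
    """Returns (steps per epoch, train epochs, eval steps) from the flags."""
    per_epoch_steps = (
        imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
    eval_steps = math.ceil(
        imagenet_preprocessing.NUM_IMAGES['validation'] / flags_obj.batch_size)
    return per_epoch_steps, flags_obj.train_epochs, eval_steps
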
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    # Initialize Horovod.
    hvd.init()

    # Pin each process to the GPU matching its local rank.
    # TF1 equivalent:
    #   config = tf.ConfigProto()
    #   config.gpu_options.visible_device_list = str(hvd.local_rank())
    # TF2:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                   'GPU')

    # keras_utils.set_session_config(
    #     enable_eager=flags_obj.enable_eager,
    #     enable_xla=flags_obj.enable_xla)
    # performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # if tf.config.list_physical_devices('GPU'):
    #   if flags_obj.tf_gpu_thread_mode:
    #     keras_utils.set_gpu_thread_mode_and_count(
    #         per_gpu_thread_count=flags_obj.per_gpu_thread_count,
    #         gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
    #         num_gpus=flags_obj.num_gpus,
    #         datasets_num_private_threads=flags_obj.datasets_num_private_threads)
    #   common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        # num_gpus=flags_obj.num_gpus,
        num_gpus=1,  # Force single-GPU, non-distributed mode; Horovod handles scaling.
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    checkpoint_interval = (per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate if not flags_obj.skip_eval else None,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=per_epoch_steps * train_epochs,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
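
# NOTE: This variant only initializes Horovod and pins one GPU per process;
# the gradient averaging itself has to happen inside the runnable's train
# step, which is not shown here. A minimal sketch of that pattern with
# horovod.tensorflow (`train_step` and its arguments are hypothetical):
import horovod.tensorflow as hvd
import tensorflow as tf


@tf.function
def train_step(model, optimizer, loss_fn, images, labels, first_batch):
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(images, training=True))
    # Wrap the tape so gradients are allreduce-averaged across all workers.
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # After the first optimizer step, sync every worker to rank 0's state.
    if first_batch:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(optimizer.variables(), root_rank=0)
    return loss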