def test_mirrored_strategy(self):
    """Requesting 5 GPUs should give a 5-tower strategy placed on GPUs.

    Uses `assertEqual`: the `assertEquals` alias is deprecated and was removed
    from `unittest` in Python 3.12.
    """
    ds = distribution_utils.get_distribution_strategy(5)
    self.assertFalse(ds.is_single_tower)
    self.assertEqual(ds.num_towers, 5)
    self.assertEqual(len(ds.worker_devices), 5)
    # Every worker device string must name a GPU device.
    for device in ds.worker_devices:
        self.assertIn('GPU', device)
Пример #2
0
 def __init__(self, strategy_type=None, strategy_config=None):
     """Configure the cluster and build the distribution strategy.

     Args:
       strategy_type: forwarded as `distribution_strategy` to
         `distribution_utils.get_distribution_strategy`.
       strategy_config: object whose `worker_hosts`, `task_index`,
         `num_gpus`, `all_reduce_alg`, `num_packs` and `tpu` attributes are
         read. Although it defaults to `None`, it is dereferenced
         unconditionally, so callers must pass it.
     """
     cfg = strategy_config
     # The result of configure_cluster is not used here; presumably it is
     # called for its side effects on the cluster setup.
     distribution_utils.configure_cluster(cfg.worker_hosts, cfg.task_index)
     strategy_kwargs = dict(
         distribution_strategy=strategy_type,
         num_gpus=cfg.num_gpus,
         all_reduce_alg=cfg.all_reduce_alg,
         num_packs=cfg.num_packs,
         tpu_address=cfg.tpu,
     )
     self._strategy = distribution_utils.get_distribution_strategy(
         **strategy_kwargs)
Пример #3
0
def construct_estimator(flags_obj, params, schedule_manager):
    """Construct an estimator from either Estimator or TPUEstimator.

    Args:
      flags_obj: The FLAGS object parsed from command line.
      params: A dict of run specific parameters.
      schedule_manager: A schedule.Manager object containing the run schedule.

    Returns:
      An estimator object to be used for training and eval.
    """
    # Report the all-reduce algorithm through the logging system rather than
    # with ad-hoc debug prints on stdout.
    tf.logging.info('all_reduce_alg: %s', flags_obj.all_reduce_alg)

    if not params["use_tpu"]:
        # CPU/GPU path: a plain Estimator with an (optional) distribution
        # strategy for training.
        distribution_strategy = distribution_utils.get_distribution_strategy(
            flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
        return tf.estimator.Estimator(
            model_fn=model_fn,
            model_dir=flags_obj.model_dir,
            params=params,
            config=tf.estimator.RunConfig(
                train_distribute=distribution_strategy))

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=flags_obj.tpu,
        zone=flags_obj.tpu_zone,
        project=flags_obj.tpu_gcp_project)

    tpu_config = tf.contrib.tpu.TPUConfig(
        iterations_per_loop=schedule_manager.single_iteration_train_steps,
        num_shards=flags_obj.num_tpu_shards)

    run_config = tf.contrib.tpu.RunConfig(cluster=tpu_cluster_resolver,
                                          model_dir=flags_obj.model_dir,
                                          session_config=tf.ConfigProto(
                                              allow_soft_placement=True,
                                              log_device_placement=True),
                                          tpu_config=tpu_config)

    return tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        # params["use_tpu"] is truthy here (the non-TPU case returned above),
        # so the TPU is actually used unless flags_obj.tpu is tpu_util.LOCAL.
        use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
        train_batch_size=schedule_manager.batch_size,
        eval_batch_size=schedule_manager.batch_size,
        params={
            # TPUEstimator needs to populate batch_size itself due to sharding.
            key: value
            for key, value in params.items() if key != "batch_size"
        },
        config=run_config)
Пример #4
0
 def test_one_device_strategy_gpu(self):
     """Requesting one GPU should give a single-replica strategy on a GPU.

     Uses `assertEqual`: the `assertEquals` alias is deprecated and was removed
     from `unittest` in Python 3.12.

     NOTE(review): a second method with this exact name follows immediately;
     if both live in the same class body, the later definition replaces this
     one and this test never runs. Consider renaming one of them.
     """
     ds = distribution_utils.get_distribution_strategy(num_gpus=1)
     self.assertEqual(ds.num_replicas_in_sync, 1)
     self.assertEqual(len(ds.extended.worker_devices), 1)
     self.assertIn('GPU', ds.extended.worker_devices[0])
 def test_one_device_strategy_gpu(self):
     """Requesting one GPU should give a single-tower strategy on a GPU.

     Uses `assertEqual`: the `assertEquals` alias is deprecated and was removed
     from `unittest` in Python 3.12.

     NOTE(review): a method with this exact name is defined immediately
     before this one; if both live in the same class body, this definition
     shadows the earlier one. Consider renaming one of them.
     """
     ds = distribution_utils.get_distribution_strategy(1)
     self.assertTrue(ds.is_single_tower)
     self.assertEqual(ds.num_towers, 1)
     self.assertEqual(len(ds.worker_devices), 1)
     self.assertIn('GPU', ds.worker_devices[0])
Пример #6
0
def run_mnist(flags_obj):
    """Run MNIST training and eval loop.

    Training proceeds in cycles of `epochs_between_evals` epochs with an
    evaluation after each cycle. If `train_epochs` is not a multiple of
    `epochs_between_evals`, a final shorter cycle trains the remaining epochs,
    so exactly `train_epochs` epochs are trained. Training stops early once
    the evaluation accuracy passes `stop_threshold`.

    Args:
      flags_obj: An object containing parsed flag values.
    """
    model_helpers.apply_clean(flags_obj)
    model_function = model_fn

    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        all_reduce_alg=flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    data_format = flags_obj.data_format
    if data_format is None:
        # Default to channels_first when TensorFlow is built with CUDA.
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    mnist_classifier = tf.estimator.Estimator(model_fn=model_function,
                                              model_dir=flags_obj.model_dir,
                                              config=run_config,
                                              params={
                                                  'data_format': data_format,
                                              })

    # Set up training and evaluation input functions.
    def train_input_fn(num_epochs):
        """Prepare training data, repeated `num_epochs` times."""

        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes use less memory. MNIST is a small
        # enough dataset that we can easily shuffle the full epoch.
        ds = dataset.train(flags_obj.data_dir)
        ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

        # Iterate through the dataset for this cycle's number of epochs.
        ds = ds.repeat(num_epochs)
        return ds

    def eval_input_fn():
        return dataset.test(flags_obj.data_dir).batch(
            flags_obj.batch_size).make_one_shot_iterator().get_next()

    # Set up hook that outputs training logs every 100 steps.
    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    # Build the per-cycle epoch schedule. Plain floor division would drop the
    # remainder (and skip training when train_epochs < epochs_between_evals),
    # so the remainder becomes a final, shorter cycle. For example
    # train_epochs=25, epochs_between_evals=10 gives [10, 10, 5].
    full_cycles, remainder = divmod(flags_obj.train_epochs,
                                    flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals] * full_cycles
    if remainder:
        schedule.append(remainder)

    # Train and evaluate model.
    for num_train_epochs in schedule:
        mnist_classifier.train(
            input_fn=lambda: train_input_fn(num_train_epochs),
            hooks=train_hooks)
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    # Export the model
    if flags_obj.export_dir is not None:
        image = tf.placeholder(tf.float32, [None, 28, 28])
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'image':
            image,
        })
        mnist_classifier.export_savedmodel(flags_obj.export_dir,
                                           input_fn,
                                           strip_default_attrs=True)
Пример #7
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared train/eval/export loop for ResNet models.

    Builds an Estimator (warm-started from `pretrained_model_checkpoint_path`
    when that flag is set), logs the run to the benchmark logger, alternates
    training and evaluation according to `train_epochs` and
    `epochs_between_evals`, and finally exports a SavedModel when `export_dir`
    is set.

    Args:
      flags_obj: An object containing parsed flags.
      model_function: Model function handed directly to the Estimator.
      input_function: Function that builds the input dataset; it is wrapped
        with the flag-derived arguments for training and evaluation.
      dataset_name: Dataset name used for benchmark logging.
      shape: Image shape; only used when `flags_obj.export_dir` is set.
    """
    model_helpers.apply_clean(flags.FLAGS)

    # Opt in to the Winograd non-fused convolution algorithms.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # allow_soft_placement lets ops fall back to a supported device.
    session_cfg = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)
    num_gpus = flags_core.get_num_gpus(flags_obj)
    strategy = distribution_utils.get_distribution_strategy(
        num_gpus, flags_obj.all_reduce_alg)
    run_config = tf.estimator.RunConfig(train_distribute=strategy,
                                        session_config=session_cfg)

    checkpoint_path = flags_obj.pretrained_model_checkpoint_path
    warm_start = None
    if checkpoint_path is not None:
        # Warm-start every variable whose name does not contain "dense".
        warm_start = tf.estimator.WarmStartSettings(
            checkpoint_path, vars_to_warm_start='^(?!.*dense)')

    tf_dtype = flags_core.get_tf_dtype(flags_obj)
    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start,
        params=dict(
            resnet_size=int(flags_obj.resnet_size),
            data_format=flags_obj.data_format,
            batch_size=flags_obj.batch_size,
            resnet_version=int(flags_obj.resnet_version),
            loss_scale=flags_core.get_loss_scale(flags_obj),
            dtype=tf_dtype,
            fine_tune=flags_obj.fine_tune,
        ))

    run_params = dict(
        batch_size=flags_obj.batch_size,
        dtype=tf_dtype,
        resnet_size=flags_obj.resnet_size,
        resnet_version=flags_obj.resnet_version,
        synthetic_data=flags_obj.use_synthetic_data,
        train_epochs=flags_obj.train_epochs,
    )
    if flags_obj.use_synthetic_data:
        dataset_name += '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def _batch_per_device():
        # Computed lazily, when an input function is actually invoked.
        return distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus)

    def input_fn_train(num_epochs):
        return input_function(is_training=True,
                              data_dir=flags_obj.data_dir,
                              batch_size=_batch_per_device(),
                              num_epochs=num_epochs,
                              num_gpus=num_gpus,
                              dtype=tf_dtype)

    def input_fn_eval():
        return input_function(is_training=False,
                              data_dir=flags_obj.data_dir,
                              batch_size=_batch_per_device(),
                              num_epochs=1,
                              dtype=tf_dtype)

    if flags_obj.eval_only or not flags_obj.train_epochs:
        # Evaluation only: a single cycle with zero training epochs.
        n_loops = 1
        schedule = [0]
    else:
        # Full cycles of epochs_between_evals epochs; the last cycle is
        # trimmed so the total equals train_epochs.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [flags_obj.epochs_between_evals] * int(n_loops)
        schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')
        # steps=max_train_steps bounds evaluation when the input never ends
        # (e.g. synthetic data); it applies to every cycle's evaluation.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)
        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is None:
        return
    receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size, dtype=tf_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, receiver_fn)
Пример #8
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      dataset_name: the name of the dataset for training and evaluation. This
        is used for logging purposes.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.
    """
    model_helpers.apply_clean(flags.FLAGS)

    # Winograd non-fused convolution algorithms give a small speed-up.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Thread-pool sizes come from the flags; soft placement is required for
    # multi-GPU and harmless otherwise.
    session_cfg = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)
    num_gpus = flags_core.get_num_gpus(flags_obj)
    strategy = distribution_utils.get_distribution_strategy(
        num_gpus, flags_obj.all_reduce_alg)
    run_config = tf.estimator.RunConfig(train_distribute=strategy,
                                        session_config=session_cfg)

    # Initialize from a pretrained checkpoint, skipping variables whose name
    # contains "dense".
    checkpoint_path = flags_obj.pretrained_model_checkpoint_path
    warm_start = (None if checkpoint_path is None else
                  tf.estimator.WarmStartSettings(
                      checkpoint_path, vars_to_warm_start='^(?!.*dense)'))

    tf_dtype = flags_core.get_tf_dtype(flags_obj)
    estimator_params = dict(
        resnet_size=int(flags_obj.resnet_size),
        data_format=flags_obj.data_format,
        batch_size=flags_obj.batch_size,
        resnet_version=int(flags_obj.resnet_version),
        loss_scale=flags_core.get_loss_scale(flags_obj),
        dtype=tf_dtype,
        fine_tune=flags_obj.fine_tune,
    )
    classifier = tf.estimator.Estimator(model_fn=model_function,
                                        model_dir=flags_obj.model_dir,
                                        config=run_config,
                                        warm_start_from=warm_start,
                                        params=estimator_params)

    run_params = dict(
        batch_size=flags_obj.batch_size,
        dtype=tf_dtype,
        resnet_size=flags_obj.resnet_size,
        resnet_version=flags_obj.resnet_version,
        synthetic_data=flags_obj.use_synthetic_data,
        train_epochs=flags_obj.train_epochs,
    )
    if flags_obj.use_synthetic_data:
        dataset_name += '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        per_device = distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus)
        return input_function(is_training=True,
                              data_dir=flags_obj.data_dir,
                              batch_size=per_device,
                              num_epochs=num_epochs,
                              num_gpus=num_gpus,
                              dtype=tf_dtype)

    def input_fn_eval():
        per_device = distribution_utils.per_device_batch_size(
            flags_obj.batch_size, num_gpus)
        return input_function(is_training=False,
                              data_dir=flags_obj.data_dir,
                              batch_size=per_device,
                              num_epochs=1,
                              dtype=tf_dtype)

    def _plan_cycles():
        """Return (epochs to train in each cycle, number of cycles).

        With --eval_only (or no train epochs) there is one cycle of zero
        epochs. Otherwise every cycle but the last trains
        epochs_between_evals epochs and the last trains what is left to reach
        train_epochs, e.g. 25 and 10 give [10, 10, 5].
        """
        if flags_obj.eval_only or not flags_obj.train_epochs:
            return [0], 1
        loops = math.ceil(flags_obj.train_epochs /
                          flags_obj.epochs_between_evals)
        plan = [flags_obj.epochs_between_evals] * int(loops)
        plan[-1] = flags_obj.train_epochs - sum(plan[:-1])
        return plan, loops

    schedule, n_loops = _plan_cycles()

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # max_train_steps is mostly used for testing/profiling, often with
        # synthetic data that never ends; bounding eval by the same number of
        # steps lets it terminate. This bound applies to every cycle,
        # independent of the global step.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        reached = model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                                    eval_results['accuracy'])
        if reached:
            break

    if flags_obj.export_dir is not None:
        # Export a SavedModel for the trained classifier.
        receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size, dtype=tf_dtype)
        classifier.export_savedmodel(flags_obj.export_dir, receiver_fn)
Пример #9
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Training runs in cycles of `epochs_between_evals` epochs, each followed by
    an evaluation. When `train_epochs` is not a multiple of
    `epochs_between_evals`, a final shorter cycle trains the remaining epochs,
    so exactly `train_epochs` epochs are trained.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      dataset_name: the name of the dataset for training and evaluation. This
        is used for logging purposes.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj)
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            num_gpus=flags_core.get_num_gpus(flags_obj))

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1)

    # Per-cycle epoch schedule. Plain floor division of train_epochs by
    # epochs_between_evals would drop the remainder (and train nothing when
    # train_epochs < epochs_between_evals); the remainder instead becomes a
    # final, shorter cycle. E.g. 25 and 10 give [10, 10, 5].
    full_cycles, leftover_epochs = divmod(flags_obj.train_epochs,
                                          flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals] * full_cycles
    if leftover_epochs:
        schedule.append(leftover_epochs)
    total_training_cycle = len(schedule)

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting a training cycle: %d/%d', cycle_index,
                        total_training_cycle)

        classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                         hooks=train_hooks,
                         max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
Пример #10
0
def _warm_start_settings(flags_obj, optimizer):
    """Return WarmStartSettings for the pretrained checkpoint, or None if unset.

    Every pattern is a single negative-lookahead regex. It restores all
    variables except `global_step` plus the head/optimizer variables listed in
    the pattern. Which names are excluded depends on `fine_tune`,
    `no_dense_init` and whether `optimizer` (already lower-cased) is 'adam'.

    Args:
      flags_obj: parsed flags; reads `pretrained_model_checkpoint_path`,
        `fine_tune` and `no_dense_init`.
      optimizer: lower-cased optimizer name.
    """
    checkpoint = flags_obj.pretrained_model_checkpoint_path
    if checkpoint is None:
        return None

    use_adam = optimizer == 'adam'
    if flags_obj.fine_tune:
        if use_adam:
            if flags_obj.no_dense_init:
                pattern = '^(?!.*(resnet_model/dense|beta1_power|beta2_power|Adam|global_step))'
            else:
                pattern = '^(?!.*(resnet_model/dense/kernel/Momentum|resnet_model/dense/bias/Momentum|beta1_power|beta2_power|Adam|global_step))'
        elif flags_obj.no_dense_init:
            pattern = '^(?!.*(resnet_model/dense|Momentum|global_step))'
        else:
            pattern = '^(?!.*(resnet_model/dense/kernel/Momentum|resnet_model/dense/bias/Momentum|global_step))'
    elif use_adam:
        # NOTE(review): unlike the fine-tune Adam patterns, this one excludes
        # `Momentum` but not `Adam` slot variables. Confirm this is intended.
        pattern = '^(?!.*(endecoder|Momentum|beta1_power|beta2_power|global_step))'
    else:
        pattern = '^(?!.*(endecoder|global_step))'

    return tf.estimator.WarmStartSettings(checkpoint,
                                          vars_to_warm_start=[pattern])


def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      dataset_name: the name of the dataset for training and evaluation. This
        is used for logging purposes.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # Creates a `RunConfig` that checkpoints every 24 hours which essentially
    # results in checkpoints determined only by `epochs_between_evals`.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=60 * 60 * 24)

    # Normalize the optimizer name once. `string.lower()` only exists on
    # Python 2; `str.lower()` works on both.
    optimizer = flags_obj.optimizer.lower()

    # Initializes the model from the pretrained checkpoint (if any), skipping
    # the head/optimizer variables selected by the flags.
    warm_start_settings = _warm_start_settings(flags_obj, optimizer)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune,
            'reconst_loss_scale': flags_obj.reconst_loss_scale,
            'use_ce': flags_obj.use_ce,
            'optimizer': optimizer,
            'clip_grad': flags_obj.clip_grad,
            'spectral_norm': flags_obj.spectral_norm,
            'ce_scale': flags_obj.ce_scale,
            # Key spelling mirrors the existing flag name `sep_grad_nrom`.
            'sep_grad_nrom': flags_obj.sep_grad_nrom,
            'norm_teach_feature': flags_obj.norm_teach_feature,
            'no_dense_init': flags_obj.no_dense_init,
            'compress_ratio': flags_obj.compress_ratio
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
        'fine_tune': flags_obj.fine_tune,
        'reconst_loss_scale': flags_obj.reconst_loss_scale,
        'use_ce': flags_obj.use_ce,
        'optimizer': optimizer,
        'clip_grad': flags_obj.clip_grad,
        'spectral_norm': flags_obj.spectral_norm,
        'ce_scale': flags_obj.ce_scale,
        'sep_grad_nrom': flags_obj.sep_grad_nrom,
        'norm_teach_feature': flags_obj.norm_teach_feature,
        'no_dense_init': flags_obj.no_dense_init,
        'compress_ratio': flags_obj.compress_ratio,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=flags_obj.
            datasets_num_private_threads,
            num_parallel_batches=flags_obj.datasets_num_parallel_batches)

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set, perform a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `training_epochs`. For instance if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # over counting.

    # Report the plan through logging instead of a raw debug print.
    tf.logging.info('schedule: %s (epochs_between_evals=%s, max_train_steps=%s)',
                    schedule, flags_obj.epochs_between_evals,
                    flags_obj.max_train_steps)
    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)
def run_keras_model_benchmark(_):
    """Run the benchmark on a Keras application model.

    Builds the model named by ``--model`` (looked up in ``MODELS``), trains and
    validates it on a synthetic dataset, and logs the run info and per-epoch
    validation metrics through the benchmark logger.

    Args:
      _: Unused positional argument.

    Raises:
      AssertionError: If ``--model`` is not a key of ``MODELS``.
      ValueError: If real (non-synthetic) data is requested; only synthetic
        data is supported.
    """
    # Ensure a valid model name was supplied via command line argument.
    if FLAGS.model not in MODELS:
        raise AssertionError("The --model command line argument should "
                             "be a key in the `MODELS` dictionary.")

    if FLAGS.eager:
        tf.logging.info("Eager execution is enabled...")
        tf.enable_eager_execution()

    tf.logging.info("Benchmark on {} model...".format(FLAGS.model))
    keras_model = MODELS[FLAGS.model]
    model = keras_model(weights=None)

    dataset_name = "ImageNet"
    if FLAGS.use_synthetic_data:
        tf.logging.info("Using synthetic dataset...")
        dataset_name += "_Synthetic"
        train_dataset = dataset.generate_synthetic_input_dataset(
            FLAGS.model, FLAGS.batch_size)
        val_dataset = dataset.generate_synthetic_input_dataset(
            FLAGS.model, FLAGS.batch_size)
    else:
        raise ValueError("Only synthetic dataset is supported!")

    num_gpus = flags_core.get_num_gpus(FLAGS)

    distribution = None
    if FLAGS.dist_strat:
        distribution = distribution_utils.get_distribution_strategy(
            num_gpus=num_gpus)
    elif num_gpus > 1:
        # Without a distribution strategy, fall back to multi_gpu_model.
        # If eager execution is enabled, only one GPU is utilized even if
        # multiple GPUs are provided.
        if FLAGS.eager:
            tf.logging.warning(
                "{} GPUs are provided, but only one GPU is utilized as "
                "eager execution is enabled.".format(num_gpus))
        model = tf.keras.utils.multi_gpu_model(model, gpus=num_gpus)

    # Adam and some other optimizers don't work well with distribution
    # strategy (b/113076709), so use GradientDescentOptimizer here.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    model.compile(loss="categorical_crossentropy",
                  optimizer=optimizer,
                  metrics=["accuracy"],
                  distribute=distribution)

    run_params = {
        "batch_size": FLAGS.batch_size,
        "synthetic_data": FLAGS.use_synthetic_data,
        "train_epochs": FLAGS.train_epochs,
        "num_train_images": FLAGS.num_train_images,
        "num_eval_images": FLAGS.num_eval_images,
    }

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info(model_name=FLAGS.model,
                                  dataset_name=dataset_name,
                                  run_params=run_params,
                                  test_id=FLAGS.benchmark_test_id)

    # Callbacks that log metric values about training and evaluation.
    callbacks = model_callbacks.get_model_callbacks(
        FLAGS.callbacks,
        batch_size=FLAGS.batch_size,
        metric_logger=benchmark_logger)

    # Round up so a final partial batch still counts as a step.
    train_steps_per_epoch = int(
        np.ceil(FLAGS.num_train_images / FLAGS.batch_size))
    eval_steps = int(np.ceil(FLAGS.num_eval_images / FLAGS.batch_size))

    history = model.fit(train_dataset,
                        epochs=FLAGS.train_epochs,
                        callbacks=callbacks,
                        validation_data=val_dataset,
                        steps_per_epoch=train_steps_per_epoch,
                        validation_steps=eval_steps)

    tf.logging.info("Logging the evaluation results...")
    # Iterate over the epochs that actually ran (a callback may stop fit
    # early), and report the global step as the number of *training* steps
    # completed at the end of each epoch.
    val_history = zip(history.history["val_acc"], history.history["val_loss"])
    for epoch, (val_accuracy, val_loss) in enumerate(val_history):
        eval_results = {
            "accuracy": val_accuracy,
            "loss": val_loss,
            tf.GraphKeys.GLOBAL_STEP: (epoch + 1) * train_steps_per_epoch,
        }
        benchmark_logger.log_evaluation_result(eval_results)

    # Clear the session explicitly to avoid session delete error.
    tf.keras.backend.clear_session()
Пример #12
0
def platform_main(train_data, val_data, class_num):
    """Train and periodically evaluate an Estimator built from model_spec.

    Each cycle trains for ``FLAGS.epochs_per_eval`` epochs and then evaluates
    on the validation data with a batch size of 1.

    Args:
      train_data: Training data; ``len(train_data[0])`` is used as the number
        of training examples (presumably a (samples, ...) pair — confirm).
      val_data: Validation data, indexed the same way as ``train_data``.
      class_num: Number of classes, passed to the model as ``num_classes``.
    """
    train_num = len(train_data[0])
    val_num = len(val_data[0])

    batch_size = int(FLAGS.param_batch_size)
    # Global batch split across the GPUs; shared by the model params and the
    # training input pipeline.
    per_device_batch = distribution_utils.per_device_batch_size(
        batch_size, _GPU_NUM)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.8
    # NCCL is not supported on Windows (see
    # https://github.com/tensorflow/tensorflow/issues/21470), so use the
    # hierarchical-copy all-reduce. TensorFlow spells the algorithm name
    # 'hierarchical_copy'.
    distribution_strategy = distribution_utils.get_distribution_strategy(
        _GPU_NUM, 'hierarchical_copy')

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)
    model = tf.estimator.Estimator(
        model_fn=model_spec.model_fn,
        model_dir=FLAGS.Log_dir,
        config=run_config,
        params={
            'model_name': FLAGS.model_name,
            'Log_dir': FLAGS.Log_dir,
            'base_architecture': FLAGS.base_architecture,
            'num_classes': class_num,
            'crop_width': FLAGS.crop_size[0],
            'crop_height': FLAGS.crop_size[1],
            'batch_size': per_device_batch,
            'tensorboard_images_max_outputs':
            FLAGS.tensorboard_images_max_outputs,
            'weight_decay': FLAGS.weight_decay,
            'learning_rate_policy': FLAGS.learning_rate_policy,
            'num_train': train_num,
            'initial_learning_rate': FLAGS.initial_learning_rate,
            'max_iter': FLAGS.max_iter,
            'end_learning_rate': FLAGS.end_learning_rate,
            'power': FLAGS.decay_power,
            'freeze_batch_norm': FLAGS.freeze_batch_norm,
            'initial_global_step': FLAGS.initial_global_step,
            'class_weights': FLAGS.class_weights
        })

    # Tensor names logged every 10 training iterations. The hook is built once
    # and reused for every training cycle.
    tensors_to_log = {
        'learning_rate': 'learning_rate',
        'cross_entropy': 'cross_entropy',
        'train_px_accuracy': 'train_px_accuracy',
        'train_mean_iou': 'train_mean_iou',
    }
    train_hooks = [
        tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=10)
    ]

    # NOTE(review): integer division drops any remainder, so when train_epochs
    # is not a multiple of epochs_per_eval the final partial cycle is skipped.
    num_cycles = FLAGS.train_epochs // FLAGS.epochs_per_eval
    for _ in range(num_cycles):
        tf.logging.info("Start training.")
        model.train(
            input_fn=lambda: input_fn(True, train_data, per_device_batch,
                                      train_num, FLAGS.epochs_per_eval),
            hooks=train_hooks)

        tf.logging.info("Start evaluation.")
        # Batch size must be 1 for evaluation because the images' sizes differ.
        eval_results = model.evaluate(
            input_fn=lambda: input_fn(
                False, val_data, batch_size=1, buffer_size=val_num))
        tf.logging.info(eval_results)
Пример #13
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    import contextlib  # Function-scope import keeps this block self-contained.

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance.
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    resnet_common.set_cudnn_batchnorm_mode()

    # get_tf_dtype returns a TensorFlow dtype (e.g. tf.float16), not the raw
    # flag string, so compare against tf.float16. Comparing with 'fp16' was
    # always False and the documented ValueError was never raised.
    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether the
        # DistributedIterator uses get_next_as_optional. If true, a last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = resnet_common.get_synth_input_fn(
            height=cifar10_preprocessing.HEIGHT,
            width=cifar10_preprocessing.WIDTH,
            num_channels=cifar10_preprocessing.NUM_CHANNELS,
            num_classes=cifar10_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar10_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar10_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in
        # normalization layer, which triggers tf.where and leads to extra
        # memory copy of input sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=cifar10_preprocessing.parse_record)

    steps_per_epoch = (cifar10_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    # Without --use_tensor_lr, the learning rate is adjusted per batch by the
    # LearningRateBatchScheduler callback added below.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        initial_learning_rate = (resnet_common.BASE_LEARNING_RATE *
                                 flags_obj.batch_size / 128)
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
            values=[initial_learning_rate] +
            [p[0] * initial_learning_rate for p in LR_SCHEDULE])

    with strategy_scope:
        optimizer = resnet_common.get_optimizer(lr_schedule)
        model = resnet_cifar_model.resnet56(
            classes=cifar10_preprocessing.NUM_CLASSES)
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = resnet_common.get_callbacks(steps_per_epoch)

    if not flags_obj.use_tensor_lr:
        lr_callback = LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            batch_size=flags_obj.batch_size,
            steps_per_epoch=steps_per_epoch)
        callbacks.append(lr_callback)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar10_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack guarantees the optional device scope is exited even if fit or
    # evaluate raises, and calls __exit__ with the arguments it requires.
    with contextlib.ExitStack() as device_stack:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_stack.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=1)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

    stats = resnet_common.build_stats(history, eval_output, callbacks)
    return stats
Пример #14
0
 def test_mirrored_strategy(self):
     """num_gpus=5 yields a strategy with 5 replicas, all on GPU devices."""
     ds = distribution_utils.get_distribution_strategy(num_gpus=5)
     # assertEquals is a deprecated alias (removed in Python 3.12).
     self.assertEqual(ds.num_replicas_in_sync, 5)
     self.assertEqual(len(ds.extended.worker_devices), 5)
     for device in ds.extended.worker_devices:
         self.assertIn('GPU', device)
Пример #15
0
def train_and_eval(
        params: base_configs.ExperimentConfig,
        strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
    """Runs the train and eval path using compile/fit.

    Args:
      params: The experiment configuration.
      strategy_override: If truthy, used instead of the strategy built from
        ``params.runtime``.

    Returns:
      The stats produced by ``resnet_common.build_stats`` from the fit
      history, the final validation output (None when evaluation is skipped)
      and the callbacks.
    """
    logging.info('Running train and eval.')

    # Note: for TPUs, strategy and scope should be created before the dataset.
    strategy = strategy_override or distribution_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    logging.info('Detected %d devices.',
                 strategy.num_replicas_in_sync if strategy else 1)

    label_smoothing = params.model.loss.label_smoothing
    # Coerce to a real bool; the bare `and` expression could yield None or 0.0.
    one_hot = bool(label_smoothing and label_smoothing > 0)

    builders = _get_dataset_builders(params, strategy, one_hot)
    datasets = [builder.build() if builder else None for builder in builders]

    # Unpack datasets and builders based on train/val/test splits.
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    # A builder may be None (see the list above); without a validation
    # builder there is nothing to evaluate on.
    skip_eval = params.evaluation.skip_eval
    if validation_builder is None and not skip_eval:
        logging.warning('No validation dataset available; skipping eval.')
        skip_eval = True

    train_epochs = params.train.epochs
    train_steps = params.train.steps or train_builder.num_steps
    validation_steps = params.evaluation.steps or (
        validation_builder.num_steps if validation_builder else None)

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    with strategy_scope:
        model_params = params.model.model_params.as_dict()
        model = get_models()[params.model.name](**model_params)
        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict())
        optimizer = performance.configure_optimizer(
            optimizer,
            use_float16=train_builder.dtype == 'float16',
            loss_scale=get_loss_scale(params))

        metrics_map = _get_metrics(one_hot)
        metrics = [metrics_map[metric] for metric in params.train.metrics]

        if one_hot:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=params.model.loss.label_smoothing)
        else:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(optimizer=optimizer, loss=loss_obj, metrics=metrics)

        initial_epoch = 0
        if params.train.resume_checkpoint:
            initial_epoch = resume_from_checkpoint(model=model,
                                                   model_dir=params.model_dir,
                                                   train_steps=train_steps)

    serialize_config(params=params, model_dir=params.model_dir)
    # TODO(dankondratyuk): callbacks significantly slow down training
    callbacks = custom_callbacks.get_callbacks(
        model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
        include_tensorboard=params.train.callbacks.enable_tensorboard,
        time_history=params.train.callbacks.enable_time_history,
        track_lr=params.train.tensorboard.track_lr,
        write_model_weights=params.train.tensorboard.write_model_weights,
        initial_step=initial_epoch * train_steps,
        batch_size=train_builder.global_batch_size,
        log_steps=params.train.time_history.log_steps,
        model_dir=params.model_dir)

    if skip_eval:
        validation_kwargs = {}
    else:
        validation_kwargs = {
            'validation_data': validation_dataset,
            'validation_steps': validation_steps,
            'validation_freq': params.evaluation.epochs_between_evals,
        }

    model.summary()

    history = model.fit(train_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        initial_epoch=initial_epoch,
                        callbacks=callbacks,
                        **validation_kwargs)

    validation_output = None
    if not skip_eval:
        validation_output = model.evaluate(validation_dataset,
                                           steps=validation_steps,
                                           verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    stats = resnet_common.build_stats(history, validation_output, callbacks)
    return stats
Пример #16
0
def run_mnist(flags_obj):
    """Train and evaluate the MNIST Estimator, then optionally export it.

    Training and evaluation run through ``tf.estimator.train_and_evaluate``.
    A SavedModel is exported to ``flags_obj.export_dir`` when that flag is
    set, except on non-master nodes of a multi-node job (``PS_CONFIG`` set and
    ``TYPE`` != 'master').

    Args:
      flags_obj: An object containing parsed flag values.
    """
    model_helpers.apply_clean(flags_obj)

    sess_cfg = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)
    strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
    estimator_cfg = tf.estimator.RunConfig(
        train_distribute=strategy,
        session_config=sess_cfg,
        save_checkpoints_steps=flags_obj.ckpt_steps,
        keep_checkpoint_max=flags_obj.max_ckpts,
        save_summary_steps=flags_obj.save_summary_steps,
        log_step_count_steps=flags_obj.log_step_count_steps)

    # Default layout: channels_first on CUDA builds, channels_last otherwise.
    fmt = flags_obj.data_format
    if fmt is None:
        fmt = ('channels_first' if tf.test.is_built_with_cuda()
               else 'channels_last')

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=flags_obj.model_dir,
                                       config=estimator_cfg,
                                       params={'data_format': fmt})

    def _training_data():
        """Cached, fully shuffled, batched training set, repeated per cycle."""
        # MNIST is small enough to shuffle the whole epoch (buffer of 50000).
        # The repeat count makes each training session run
        # `epochs_between_evals` passes over the data.
        return (dataset.train(flags_obj.data_dir)
                .cache()
                .shuffle(buffer_size=50000)
                .batch(flags_obj.batch_size)
                .repeat(flags_obj.epochs_between_evals))

    def _evaluation_data():
        """Next-batch tensors from a one-shot iterator over the test set."""
        test_batches = dataset.test(flags_obj.data_dir).batch(
            flags_obj.batch_size)
        return test_batches.make_one_shot_iterator().get_next()

    # Training hooks selected by the --hooks flag.
    hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                         model_dir=flags_obj.model_dir,
                                         batch_size=flags_obj.batch_size)

    tf.estimator.train_and_evaluate(
        estimator,
        tf.estimator.TrainSpec(input_fn=_training_data,
                               hooks=hooks,
                               max_steps=flags_obj.max_steps),
        tf.estimator.EvalSpec(input_fn=_evaluation_data,
                              steps=None,
                              start_delay_secs=10,
                              throttle_secs=flags_obj.eval_secs))

    # In a multi-node job (PS_CONFIG set) only the master node exports.
    is_worker_node = (os.environ.get('PS_CONFIG')
                      and os.environ.get('TYPE') != 'master')
    if is_worker_node:
        tf.logging.debug('No model was exported')
        return

    if not flags_obj.export_dir:
        return

    tf.logging.debug('Starting to Export model to {}'.format(
        str(flags_obj.export_dir)))
    image_input = tf.placeholder(tf.float32, [None, 28, 28])
    serving_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        {'image': image_input})
    estimator.export_savedmodel(flags_obj.export_dir, serving_fn,
                                strip_default_attrs=True)
    tf.logging.debug('Model Exported')
Пример #17
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      dataset_name: the name of the dataset for training and evaluation. This
        is used for logging purpose.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.

    Returns:
      Dict of results of the run. Contains the keys `eval_results` and
      `train_hooks`. `eval_results` holds the results of the last evaluation
      (accuracy (top_1) and accuracy_top_5), and is an empty dict when
      `tf.estimator.train_and_evaluate` is used. `train_hooks` is a list of the
      hook instances used during training.
    """
    # Use the flags object we were given, not the global flags.FLAGS, so a
    # caller-supplied flags object controls cleaning the model directory.
    model_helpers.apply_clean(flags_obj)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    num_gpus = flags_core.get_num_gpus(flags_obj)
    dtype = flags_core.get_tf_dtype(flags_obj)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    # Time-based checkpointing is disabled; checkpoints are written every
    # 2000 steps.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=None,
                                        save_checkpoints_steps=2000)

    # Initializes model with all but the dense layer from pretrained ResNet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj,
                                                    default_for_fp16=128),
            'dtype': dtype,
            'fine_tune': flags_obj.fine_tune,
            'num_workers': num_workers,
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': dtype,
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
        'num_workers': num_workers,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs, input_context=None):
        """Training input: per-replica batches for `num_epochs` epochs."""
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_replica_batch_size(
                flags_obj.batch_size, num_gpus),
            num_epochs=num_epochs,
            dtype=dtype,
            datasets_num_private_threads=flags_obj.
            datasets_num_private_threads,
            input_context=input_context)

    def input_fn_eval():
        """Evaluation input: a single epoch of per-replica batches."""
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_replica_batch_size(
                flags_obj.batch_size, num_gpus),
            num_epochs=1,
            dtype=dtype)

    train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                    flags_obj.train_epochs)

    use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
    if use_train_and_evaluate:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda input_context=None: input_fn_train(
                train_epochs, input_context=input_context),
            hooks=train_hooks,
            max_steps=flags_obj.max_train_steps)
        eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
        tf.compat.v1.logging.info('Starting to train and evaluate.')
        tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
        # tf.estimator.train_and_evaluate doesn't return anything in
        # multi-worker case.
        eval_results = {}
    else:
        if train_epochs == 0:
            # If --eval_only is set, perform a single loop with zero train
            # epochs.
            schedule, n_loops = [0], 1
        else:
            # Compute the number of times to loop while training. All but the
            # last pass will train for `epochs_between_evals` epochs, while
            # the last will train for the number needed to reach
            # `train_epochs`. For instance if
            #   train_epochs = 25 and epochs_between_evals = 10
            # schedule will be set to [10, 10, 5]. That is to say, the loop
            # will:
            #   Train for 10 epochs and then evaluate.
            #   Train for another 10 epochs and then evaluate.
            #   Train for a final 5 epochs (to reach 25 epochs) and then
            #   evaluate.
            n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
            schedule = [flags_obj.epochs_between_evals] * int(n_loops)
            schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.

        for cycle_index, num_train_epochs in enumerate(schedule):
            tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                                      int(n_loops))

            if num_train_epochs:
                # Since we are calling classifier.train immediately in each
                # loop, the value of num_train_epochs in the lambda function
                # will not be changed before it is used. So it is safe to
                # ignore the pylint error here.
                # pylint: disable=cell-var-from-loop
                classifier.train(
                    input_fn=lambda input_context=None: input_fn_train(
                        num_train_epochs, input_context=input_context),
                    hooks=train_hooks,
                    max_steps=flags_obj.max_train_steps)

            # flags_obj.max_train_steps is generally associated with testing
            # and profiling. As a result it is frequently called with
            # synthetic data, which will iterate forever. Passing
            # steps=flags_obj.max_train_steps allows the eval (which is
            # generally unimportant in those circumstances) to terminate.
            # Note that eval will run for max_train_steps each loop,
            # regardless of the global_step count.
            tf.compat.v1.logging.info('Starting to evaluate.')
            eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                               steps=flags_obj.max_train_steps)

            benchmark_logger.log_evaluation_result(eval_results)

            if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                                 eval_results['accuracy']):
                break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = dtype
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)

    stats = {}
    stats['eval_results'] = eval_results
    stats['train_hooks'] = train_hooks

    return stats