Beispiel #1
0
def define_device(tpu=True):
    """Register device specific flags.

    Args:
      tpu: If True, create the flags that configure TPU operation
        (``tpu``, ``tpu_zone``, ``tpu_gcp_project`` and ``num_tpu_shards``).

    Returns:
      A list of flags for core.py to mark as key flags. Only ``tpu`` is
      reported; the other TPU flags are registered but not marked as key.
    """

    key_flags = []

    if tpu:
        flags.DEFINE_string(
            name="tpu", default=None, help=help_wrap(
                "The Cloud TPU to use for training. This should be either the name "
                "used when creating the Cloud TPU, or a "
                "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
                "CPU of the local instance instead. (Good for debugging.)"))
        key_flags.append("tpu")

        # The help text previously said "GCE project" here (copied from the
        # tpu_gcp_project flag below); this flag is the zone.
        flags.DEFINE_string(
            name="tpu_zone", default=None, help=help_wrap(
                "[Optional] GCE zone where the Cloud TPU is located in. If not "
                "specified, we will attempt to automatically detect the GCE "
                "zone from metadata."))

        flags.DEFINE_string(
            name="tpu_gcp_project", default=None, help=help_wrap(
                "[Optional] Project name for the Cloud TPU-enabled project. If not "
                "specified, we will attempt to automatically detect the GCE "
                "project from metadata."))

        flags.DEFINE_integer(name="num_tpu_shards", default=8,
                             help=help_wrap("Number of shards (TPU chips)."))

    return key_flags
Beispiel #2
0
def define_distribution(worker_hosts=True, task_index=True):
    """Register the flags used for distributed (multi-worker) execution.

    Args:
      worker_hosts: When True, define ``--worker_hosts``, a comma-separated
        list of worker ip:port pairs.
      task_index: When True, define ``--task_index``, the index of this
        worker's task.

    Returns:
      A list of flags for core.py to mark as key flags. Neither flag defined
      here is marked as key, so the list is always empty.
    """
    if worker_hosts:
        worker_hosts_help = help_wrap(
            'Comma-separated list of worker ip:port pairs for running '
            'multi-worker models with DistributionStrategy.  The user would '
            'start the program on each host with identical value for this '
            'flag.')
        flags.DEFINE_string(name='worker_hosts', default=None,
                            help=worker_hosts_help)

    if task_index:
        task_index_help = help_wrap(
            'If multi-worker training, the task_index of this worker.')
        flags.DEFINE_integer(name='task_index', default=-1,
                             help=task_index_help)

    return []
Beispiel #3
0
def define_image(data_format=True):
    """Register flags that only apply to image models.

    Args:
      data_format: When True, define ``--data_format`` (short name ``df``),
        an enum choosing between channels_first and channels_last.

    Returns:
      A list of flags for core.py to mark as key flags: ``["data_format"]``
      when the flag was defined, otherwise an empty list.
    """
    if not data_format:
        return []

    flags.DEFINE_enum(
        name="data_format", short_name="df", default=None,
        enum_values=["channels_first", "channels_last"],
        help=help_wrap(
            "A flag to override the data format used in the model. "
            "channels_first provides a performance boost on GPU but is not "
            "always compatible with CPU. If left unspecified, the data format "
            "will be chosen automatically based on whether TensorFlow was "
            "built for CPU or GPU."))
    return ["data_format"]
Beispiel #4
0
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                       synthetic_data=True, max_train_steps=True, dtype=True,
                       all_reduce_alg=True):
  """Register flags for specifying performance tuning arguments.

  Args:
    num_parallel_calls: Create a flag to specify parallelism of input loading.
    inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    synthetic_data: Create a flag to allow the use of synthetic input.
    max_train_steps: Create a flag to allow specification of maximum number
      of training steps.
    dtype: Create flags for specifying dtype, plus a ``loss_scale`` flag that
      is validated to be a positive integer when set.
    all_reduce_alg: Create a flag to choose the all-reduce algorithm.

  Returns:
    A list of flags for core.py to mark as key flags. No flag defined here is
    marked as key, so the list is always empty.
  """

  key_flags = []
  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls", short_name="npc",
        # Evaluated once, when the flag is defined.
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "input set but for generally homogeneous input sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads", short_name="inter", default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads", short_name="intra", default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data", short_name="synth", default=False,
        help=help_wrap(
            "If set, use fake input (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps", short_name="mts", default=None, help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."
        ))

  if dtype:
    flags.DEFINE_enum(
        name="dtype", short_name="dt", default="fp32",
        # absl documents enum_values as a list of strings; materialize the
        # dict keys view instead of passing it through.
        enum_values=list(DTYPE_MAP.keys()),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "Variables may be cast to a higher precision on a "
                       "case-by-case basis for numerical stability."))

    flags.DEFINE_integer(
        name="loss_scale", short_name="ls", default=None,
        help=help_wrap(
            "The amount to scale the loss by when the model is run. Before "
            "gradients are computed, the loss is multiplied by the loss scale, "
            "making all gradients loss_scale times larger. To adjust for this, "
            "gradients are divided by the loss scale before being applied to "
            "variables. This is mathematically equivalent to training without "
            "a loss scale, but the loss scale helps avoid some intermediate "
            "gradients from underflowing to zero. If not provided the default "
            "for fp16 is 128 and 1 for all other dtypes."))

    loss_scale_val_msg = "loss_scale should be a positive integer."

    @flags.validator(flag_name="loss_scale", message=loss_scale_val_msg)
    def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
      """Accept an unset loss_scale, otherwise require it to be positive."""
      if loss_scale is None:
        return True  # null case is handled in get_loss_scale()

      return loss_scale > 0

  if all_reduce_alg:
    flags.DEFINE_string(
        name="all_reduce_alg", short_name="ara", default=None,
        help=help_wrap("Defines the algorithm to use for performing all-reduce. "
                       "See tf.contrib.distribute.AllReduceCrossTowerOps for "
                       "more details and available options."))

  return key_flags
Beispiel #5
0
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
    """Register benchmarking flags.

    ``benchmark_logger_type``, ``benchmark_test_id`` and ``log_steps`` are
    always defined; the remaining flags depend on the arguments below.

    Args:
      benchmark_log_dir: Create a flag to specify location for benchmark
        logging, plus a validator that requires it when
        ``--benchmark_logger_type=BenchmarkFileLogger``.
      bigquery_uploader: Create flags for uploading results to BigQuery.

    Returns:
      A list of flags for core.py to mark as key flags. No flag defined here
      is marked as key, so the list is always empty.
    """

    key_flags = []

    flags.DEFINE_enum(
        name="benchmark_logger_type",
        default="BaseBenchmarkLogger",
        enum_values=[
            "BaseBenchmarkLogger",
            "BenchmarkFileLogger",
            "BenchmarkBigQueryLogger"],
        help=help_wrap(
            "The type of benchmark logger to use. Defaults to using "
            "BaseBenchmarkLogger which logs to STDOUT. Different "
            "loggers will require other flags to be able to work."))
    flags.DEFINE_string(
        name="benchmark_test_id", short_name="bti", default=None, help=help_wrap(
            "The unique test ID of the benchmark run. It could be the "
            "combination of key parameters. It is hardware "
            "independent and could be used to compare the performance "
            "between different test runs. This flag is designed for "
            "human consumption, and does not have any impact within "
            "the system."))

    flags.DEFINE_integer(
        name='log_steps', default=100,
        help=help_wrap(
            'For every log_steps, we log the timing information such as '
            'examples per second. Besides, for every log_steps, we store the '
            'timestamp of a batch end.'))

    if benchmark_log_dir:
        flags.DEFINE_string(
            name="benchmark_log_dir", short_name="bld", default=None,
            help=help_wrap("The location of the benchmark logging.")
        )

        # Registered only when benchmark_log_dir exists: absl rejects a
        # multi-flag validator that names an undefined flag, which previously
        # made define_benchmark(benchmark_log_dir=False) fail.
        @flags.multi_flags_validator(
            ["benchmark_logger_type", "benchmark_log_dir"],
            message="--benchmark_logger_type=BenchmarkFileLogger will require "
                    "--benchmark_log_dir being set")
        def _check_benchmark_log_dir(flags_dict):
            """The file logger needs a non-empty benchmark_log_dir."""
            benchmark_logger_type = flags_dict["benchmark_logger_type"]
            if benchmark_logger_type == "BenchmarkFileLogger":
                return bool(flags_dict["benchmark_log_dir"])
            return True

    if bigquery_uploader:
        flags.DEFINE_string(
            name="gcp_project", short_name="gp", default=None,
            help=help_wrap(
                "The GCP project name where the benchmark will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_data_set", short_name="bds", default="test_benchmark",
            help=help_wrap(
                "The Bigquery dataset name where the benchmark will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_run_table",
            short_name="brt",
            default="benchmark_run",
            help=help_wrap(
                "The Bigquery table name where the benchmark run "
                "information will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_run_status_table", short_name="brst",
            default="benchmark_run_status",
            help=help_wrap("The Bigquery table name where the benchmark run "
                           "status information will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_metric_table",
            short_name="bmt",
            default="benchmark_metric",
            help=help_wrap(
                "The Bigquery table name where the benchmark metric "
                "information will be uploaded."))

    return key_flags
def define_base(data_dir=True,
                model_dir=True,
                clean=True,
                train_epochs=True,
                epochs_between_evals=True,
                stop_threshold=True,
                batch_size=True,
                num_gpu=True,
                hooks=True,
                export_dir=True):
    """Register base flags.

    Args:
      data_dir: Create a flag for specifying the input data directory.
      model_dir: Create a flag for specifying the model file directory.
      clean: Create a flag for removing the model_dir.
      train_epochs: Create a flag to specify the number of training epochs.
      epochs_between_evals: Create a flag to specify the frequency of testing.
      stop_threshold: Create a flag to specify a threshold accuracy or other
        eval metric which should trigger the end of training.
      batch_size: Create a flag to specify the batch size.
      num_gpu: Create a flag to specify the number of GPUs used.
      hooks: Create a flag to specify hooks for logging.
      export_dir: Create a flag to specify where a SavedModel should be exported.

    Returns:
      A list of flags for core.py to mark as key flags.
    """
    key_flags = []

    if data_dir:
        flags.DEFINE_string(name="data_dir",
                            short_name="dd",
                            default="/tmp",
                            help=help_wrap("The location of the input data."))
        key_flags.append("data_dir")

    if model_dir:
        flags.DEFINE_string(
            name="model_dir",
            short_name="md",
            default="/tmp",
            help=help_wrap("The location of the model checkpoint files."))
        key_flags.append("model_dir")

    if clean:
        flags.DEFINE_boolean(
            name="clean",
            default=False,
            help=help_wrap("If set, model_dir will be removed if it exists."))
        key_flags.append("clean")

    if train_epochs:
        flags.DEFINE_integer(
            name="train_epochs",
            short_name="te",
            default=1,
            help=help_wrap("The number of epochs used to train."))
        key_flags.append("train_epochs")

    if epochs_between_evals:
        flags.DEFINE_integer(
            name="epochs_between_evals",
            short_name="ebe",
            default=1,
            help=help_wrap("The number of training epochs to run between "
                           "evaluations."))
        key_flags.append("epochs_between_evals")

    if stop_threshold:
        flags.DEFINE_float(
            name="stop_threshold",
            short_name="st",
            default=None,
            help=help_wrap("If passed, training will stop at the earlier of "
                           "train_epochs and when the evaluation metric is "
                           "greater than or equal to stop_threshold."))

    if batch_size:
        flags.DEFINE_integer(
            name="batch_size",
            short_name="bs",
            default=32,
            help=help_wrap(
                "Batch size for training and evaluation. When using "
                "multiple gpus, this is the global batch size for "
                "all devices. For example, if the batch size is 32 "
                "and there are 4 GPUs, each GPU will get 8 examples on "
                "each step."))
        key_flags.append("batch_size")

    if num_gpu:
        flags.DEFINE_integer(
            name="num_gpus",
            short_name="ng",
            # GPU detection runs once, at flag-definition time.
            default=1 if tf.test.is_gpu_available() else 0,
            help=help_wrap(
                "How many GPUs to use with the DistributionStrategies API. The "
                "default is 1 if TensorFlow can detect a GPU, and 0 otherwise."
            ))

    if hooks:
        # Construct a pretty summary of hooks.
        hook_list_str = (u"\ufeff  Hook:\n" + u"\n".join(
            [u"\ufeff    {}".format(key) for key in hooks_helper.HOOKS]))
        flags.DEFINE_list(
            name="hooks",
            short_name="hk",
            default="LoggingTensorHook",
            help=help_wrap(
                u"A list of (case insensitive) strings to specify the names of "
                u"training hooks.\n{}\n\ufeff  Example: `--hooks ProfilerHook,"
                u"ExamplesPerSecondHook`\n See utils.logs.hooks_helper "
                u"for details.".format(hook_list_str)))
        key_flags.append("hooks")

    if export_dir:
        flags.DEFINE_string(
            name="export_dir",
            short_name="ed",
            default=None,
            help=help_wrap(
                "If set, a SavedModel serialization of the model will "
                "be exported to this directory at the end of training. "
                "See the README for more details and relevant links."))
        key_flags.append("export_dir")

    return key_flags
Beispiel #7
0
def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
    """Register base flags.

    Args:
      data_dir: Create a flag for specifying the input data directory.
      model_dir: Create a flag for specifying the model file directory.
      clean: Create a flag for removing the model_dir.
      train_epochs: Create a flag to specify the number of training epochs.
      epochs_between_evals: Create a flag to specify the frequency of testing.
      stop_threshold: Create a flag to specify a threshold accuracy or other
        eval metric which should trigger the end of training.
      batch_size: Create a flag to specify the batch size.
      num_gpu: Create a flag to specify the number of GPUs used.
      hooks: Create a flag to specify hooks for logging.
      export_dir: Create a flag to specify where a SavedModel should be exported.
      distribution_strategy: Create a flag to specify which Distribution Strategy
        to use.
      run_eagerly: Create a flag to specify to run eagerly op by op.

    Returns:
      A list of flags for core.py to mark as key flags. ``stop_threshold``,
      ``num_gpus``, ``run_eagerly`` and ``distribution_strategy`` are
      registered but not reported as key flags.
    """
    key_flags = []

    if data_dir:
        flags.DEFINE_string(name="data_dir",
                            short_name="dd",
                            default="/tmp",
                            help=help_wrap("The location of the input data."))
        key_flags.append("data_dir")

    if model_dir:
        flags.DEFINE_string(
            name="model_dir",
            short_name="md",
            default="/tmp",
            help=help_wrap("The location of the model checkpoint files."))
        key_flags.append("model_dir")

    if clean:
        flags.DEFINE_boolean(
            name="clean",
            default=False,
            help=help_wrap("If set, model_dir will be removed if it exists."))
        key_flags.append("clean")

    if train_epochs:
        flags.DEFINE_integer(
            name="train_epochs",
            short_name="te",
            default=1,
            help=help_wrap("The number of epochs used to train."))
        key_flags.append("train_epochs")

    if epochs_between_evals:
        flags.DEFINE_integer(
            name="epochs_between_evals",
            short_name="ebe",
            default=1,
            help=help_wrap("The number of training epochs to run between "
                           "evaluations."))
        key_flags.append("epochs_between_evals")

    if stop_threshold:
        flags.DEFINE_float(
            name="stop_threshold",
            short_name="st",
            default=None,
            help=help_wrap("If passed, training will stop at the earlier of "
                           "train_epochs and when the evaluation metric is "
                           "greater than or equal to stop_threshold."))

    if batch_size:
        flags.DEFINE_integer(
            name="batch_size",
            short_name="bs",
            default=32,
            help=help_wrap(
                "Batch size for training and evaluation. When using "
                "multiple gpus, this is the global batch size for "
                "all devices. For example, if the batch size is 32 "
                "and there are 4 GPUs, each GPU will get 8 examples on "
                "each step."))
        key_flags.append("batch_size")

    if num_gpu:
        flags.DEFINE_integer(
            name="num_gpus",
            short_name="ng",
            default=1,
            help=help_wrap("How many GPUs to use at each worker with the "
                           "DistributionStrategies API. The default is 1."))

    if run_eagerly:
        # Wrapped like every other flag's help text for consistent --help output.
        flags.DEFINE_boolean(
            name="run_eagerly",
            default=False,
            help=help_wrap(
                "Run the model op by op without building a model function."))

    if hooks:
        # Construct a pretty summary of hooks.
        hook_list_str = (u"\ufeff  Hook:\n" + u"\n".join(
            [u"\ufeff    {}".format(key) for key in hooks_helper.HOOKS]))
        flags.DEFINE_list(
            name="hooks",
            short_name="hk",
            default="LoggingTensorHook",
            help=help_wrap(
                u"A list of (case insensitive) strings to specify the names of "
                u"training hooks.\n{}\n\ufeff  Example: `--hooks ProfilerHook,"
                u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
                u"for details.".format(hook_list_str)))
        key_flags.append("hooks")

    if export_dir:
        flags.DEFINE_string(
            name="export_dir",
            short_name="ed",
            default=None,
            help=help_wrap(
                "If set, a SavedModel serialization of the model will "
                "be exported to this directory at the end of training. "
                "See the README for more details and relevant links."))
        key_flags.append("export_dir")

    if distribution_strategy:
        # NOTE(review): the help describes a 'default' value that is not in its
        # own list of accepted values; confirm against the strategy parser.
        flags.DEFINE_string(
            name="distribution_strategy",
            short_name="ds",
            default="mirrored",
            help=help_wrap("The Distribution Strategy to use for training. "
                           "Accepted values are 'off', 'one_device', "
                           "'mirrored', 'parameter_server', 'collective', "
                           "case insensitive. 'off' means not to use "
                           "Distribution Strategy; 'default' means to choose "
                           "from `MirroredStrategy` or `OneDeviceStrategy` "
                           "according to the number of GPUs."))

    return key_flags
Beispiel #8
0
def define_performance(num_parallel_calls=True,
                       inter_op=True,
                       intra_op=True,
                       synthetic_data=True,
                       max_train_steps=True,
                       dtype=True,
                       all_reduce_alg=True,
                       num_packs=True,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
                       dynamic_loss_scale=False,
                       fp16_implementation=False,
                       loss_scale=False,
                       tf_data_experimental_slack=False,
                       enable_xla=False,
                       force_v2_in_keras_compile=False):
    """Register flags for specifying performance tuning arguments.

    Args:
      num_parallel_calls: Create a flag to specify parallelism of data loading.
      inter_op: Create a flag to allow specification of inter op threads.
      intra_op: Create a flag to allow specification of intra op threads.
      synthetic_data: Create a flag to allow the use of synthetic data.
      max_train_steps: Create a flag to allow specification of maximum number
        of training steps.
      dtype: Create flags for specifying dtype. ``loss_scale`` and
        ``fp16_implementation`` flags are only created when this is True.
      all_reduce_alg: If set forces a specific algorithm for multi-gpu.
      num_packs: If set provides number of packs for MirroredStrategy's cross
        device ops.
      tf_gpu_thread_mode: gpu_private triggers use of private thread pool.
        Also creates the ``per_gpu_thread_count`` flag.
      datasets_num_private_threads: Number of private threads for datasets.
      datasets_num_parallel_batches: Determines how many batches to process in
        parallel when using map and batch from tf.data.
      dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
        "dynamic". Only valid if `dtype` is True.
      fp16_implementation: Create fp16_implementation flag.
      loss_scale: Controls the loss scaling, normally for mixed-precision
        training. Can only be turned on if dtype is also True.
      tf_data_experimental_slack: Determines whether to enable tf.data's
        `experimental_slack` option.
      enable_xla: Determines if XLA (auto clustering) is turned on.
      force_v2_in_keras_compile: Forces the use of run_distributed path even if
        not using a `strategy`. This is not the same as
        `tf.distribute.OneDeviceStrategy`

    Returns:
      A list of flags for core.py to mark as key flags. No flag defined here
      is marked as key, so the list is always empty.
    """

    key_flags = []
    if num_parallel_calls:
        flags.DEFINE_integer(
            name="num_parallel_calls",
            short_name="npc",
            # Evaluated once, when the flag is defined.
            default=multiprocessing.cpu_count(),
            help=help_wrap(
                "The number of records that are processed in parallel "
                "during input processing. This can be optimized per "
                "data set but for generally homogeneous data sets, "
                "should be approximately the number of available CPU "
                "cores. (default behavior)"))

    if inter_op:
        flags.DEFINE_integer(
            name="inter_op_parallelism_threads",
            short_name="inter",
            default=0,
            help=help_wrap(
                "Number of inter_op_parallelism_threads to use for CPU. "
                "See TensorFlow config.proto for details."))

    if intra_op:
        flags.DEFINE_integer(
            name="intra_op_parallelism_threads",
            short_name="intra",
            default=0,
            help=help_wrap(
                "Number of intra_op_parallelism_threads to use for CPU. "
                "See TensorFlow config.proto for details."))

    if synthetic_data:
        flags.DEFINE_bool(
            name="use_synthetic_data",
            short_name="synth",
            default=False,
            help=help_wrap(
                "If set, use fake data (zeroes) instead of a real dataset. "
                "This mode is useful for performance debugging, as it removes "
                "input processing steps, but will not learn anything."))

    if max_train_steps:
        flags.DEFINE_integer(
            name="max_train_steps",
            short_name="mts",
            default=None,
            help=help_wrap(
                "The model will stop training if the global_step reaches this "
                "value. If not set, training will run until the specified number "
                "of epochs have run as usual. It is generally recommended to set "
                "--train_epochs=1 when using this flag."))

    if dtype:
        flags.DEFINE_enum(
            name="dtype",
            short_name="dt",
            default="fp32",
            # absl documents enum_values as a list of strings; materialize the
            # dict keys view instead of passing it through.
            enum_values=list(DTYPE_MAP.keys()),
            help=help_wrap("The TensorFlow datatype used for calculations. "
                           "Variables may be cast to a higher precision on a "
                           "case-by-case basis for numerical stability."))

        if loss_scale:
            loss_scale_help_text = (
                "The amount to scale the loss by when the model is run. {}. Before "
                "gradients are computed, the loss is multiplied by the loss scale, "
                "making all gradients loss_scale times larger. To adjust for this, "
                "gradients are divided by the loss scale before being applied to "
                "variables. This is mathematically equivalent to training without "
                "a loss scale, but the loss scale helps avoid some intermediate "
                "gradients from underflowing to zero. If not provided the default "
                "for fp16 is 128 and 1 for all other dtypes.{}")
            if dynamic_loss_scale:
                loss_scale_help_text = loss_scale_help_text.format(
                    "This can be an int/float or the string 'dynamic'",
                    " The string 'dynamic' can be used to dynamically determine the "
                    "optimal loss scale during training, but currently this "
                    "significantly slows down performance")
                loss_scale_validation_msg = (
                    "loss_scale should be a positive int/float "
                    "or the string 'dynamic'.")
            else:
                loss_scale_help_text = loss_scale_help_text.format(
                    "This must be an int/float", "")
                loss_scale_validation_msg = "loss_scale should be a positive int/float."

            flags.DEFINE_string(name="loss_scale",
                                short_name="ls",
                                default=None,
                                help=help_wrap(loss_scale_help_text))

            @flags.validator(flag_name="loss_scale",
                             message=loss_scale_validation_msg)
            def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
                """Validator to check the loss scale flag is valid."""
                if loss_scale is None:
                    return True  # null case is handled in get_loss_scale()

                if loss_scale == "dynamic" and dynamic_loss_scale:
                    return True

                try:
                    loss_scale = float(loss_scale)
                except ValueError:
                    return False

                return loss_scale > 0

        if fp16_implementation:
            # Currently, this flag is only defined for the estimator resnet model.
            flags.DEFINE_enum(
                name="fp16_implementation",
                default="casting",
                # A real tuple of choices. The previous value,
                # ("casting', 'graph_rewrite"), was one string, which turned
                # the enum membership check into a substring test.
                enum_values=("casting", "graph_rewrite"),
                help=help_wrap(
                    "When --dtype=fp16, how fp16 should be implemented. This has no "
                    "impact on correctness. 'casting' will cause manual tf.casts to "
                    "be inserted in the model. 'graph_rewrite' means "
                    "tf.train.experimental.enable_mixed_precision_graph_rewrite will "
                    "be used to automatically use fp16 without any manual casts."
                ))

            # Only name loss_scale in the validator when that flag was defined;
            # absl rejects a multi-flag validator that names an undefined flag.
            fp16_validated_flags = ["fp16_implementation", "dtype"]
            if loss_scale:
                fp16_validated_flags.append("loss_scale")

            @flags.multi_flags_validator(fp16_validated_flags)
            def _check_fp16_implementation(flags_dict):
                """Validator to check fp16_implementation flag is valid."""
                if (flags_dict["fp16_implementation"] == "graph_rewrite"
                        and flags_dict["dtype"] != "fp16"):
                    raise flags.ValidationError(
                        "--fp16_implementation should not be "
                        "specified unless --dtype=fp16")
                if (flags_dict["fp16_implementation"] != "graph_rewrite"
                        and flags_dict.get("loss_scale") == "dynamic"):
                    raise flags.ValidationError(
                        "--loss_scale=dynamic is only supported "
                        "when "
                        "--fp16_implementation=graph_rewrite")
                return True

    if all_reduce_alg:
        flags.DEFINE_string(
            name="all_reduce_alg",
            short_name="ara",
            default=None,
            help=help_wrap(
                "Defines the algorithm to use for performing all-reduce. "
                "When specified with MirroredStrategy for single "
                "worker, this controls "
                "tf.contrib.distribute.AllReduceCrossTowerOps.  When "
                "specified with MultiWorkerMirroredStrategy, this "
                "controls "
                "tf.distribute.experimental.CollectiveCommunication; "
                "valid options are `ring` and `nccl`."))

    if num_packs:
        flags.DEFINE_integer(
            name="num_packs",
            default=1,
            help=help_wrap("Sets `num_packs` in the cross device ops used in "
                           "MirroredStrategy.  For details, see "
                           "tf.distribute.NcclAllReduce."))

    if tf_gpu_thread_mode:
        flags.DEFINE_string(
            name="tf_gpu_thread_mode",
            short_name="gt_mode",
            default=None,
            help=help_wrap(
                "Whether and how the GPU device uses its own threadpool."))

        flags.DEFINE_integer(
            name="per_gpu_thread_count",
            short_name="pgtc",
            default=0,
            help=help_wrap(
                "The number of threads to use for GPU. Only valid when "
                "tf_gpu_thread_mode is not global."))

    if datasets_num_private_threads:
        flags.DEFINE_integer(
            name="datasets_num_private_threads",
            default=None,
            help=help_wrap(
                "Number of threads for a private threadpool created for all "
                "datasets computation."))

    if datasets_num_parallel_batches:
        flags.DEFINE_integer(
            name="datasets_num_parallel_batches",
            default=None,
            help=help_wrap(
                "Determines how many batches to process in parallel when using "
                "map and batch from tf.data."))

    if tf_data_experimental_slack:
        flags.DEFINE_boolean(
            name="tf_data_experimental_slack",
            default=False,
            help=help_wrap(
                "Whether to enable tf.data's `experimental_slack` option."))

    if enable_xla:
        flags.DEFINE_boolean(name="enable_xla",
                             default=False,
                             help="Whether to enable XLA auto jit compilation")

    if force_v2_in_keras_compile:
        flags.DEFINE_boolean(
            name="force_v2_in_keras_compile",
            default=False,
            help="Forces the use of run_distributed path even if not "
            "using a `strategy`. This is not the same as "
            "`tf.distribute.OneDeviceStrategy`")

    return key_flags
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
    """Register benchmarking flags.

    ``--benchmark_logger_type`` is always defined; every other flag is
    defined only when the corresponding argument is true.

    Args:
      benchmark_log_dir: Create a flag to specify location for benchmark
        logging, together with a validator that requires it whenever
        ``--benchmark_logger_type=BenchmarkFileLogger`` is selected.
      bigquery_uploader: Create flags for uploading results to BigQuery.

    Returns:
      A list of flags for core.py to mark as key flags. No flag is appended
      here, so the list is currently always empty.
    """
    key_flags = []

    flags.DEFINE_enum(
        name="benchmark_logger_type",
        default="BaseBenchmarkLogger",
        enum_values=[
            "BaseBenchmarkLogger", "BenchmarkFileLogger",
            "BenchmarkBigQueryLogger"
        ],
        help=help_wrap(
            "The type of benchmark logger to use. Defaults to using "
            "BaseBenchmarkLogger which logs to STDOUT. Different "
            "loggers will require other flags to be able to work."))

    if benchmark_log_dir:
        flags.DEFINE_string(
            name="benchmark_log_dir",
            short_name="bld",
            default=None,
            help=help_wrap("The location of the benchmark logging."))

        # The validator names --benchmark_log_dir, and absl resolves every
        # named flag when the validator is registered. It may therefore only
        # be registered after that flag exists; registering it
        # unconditionally fails when benchmark_log_dir=False.
        @flags.multi_flags_validator(
            ["benchmark_logger_type", "benchmark_log_dir"],
            message="--benchmark_logger_type=BenchmarkFileLogger will require "
            "--benchmark_log_dir being set")
        def _check_benchmark_log_dir(flags_dict):
            """Require a non-empty --benchmark_log_dir for the file logger."""
            if flags_dict["benchmark_logger_type"] == "BenchmarkFileLogger":
                return bool(flags_dict["benchmark_log_dir"])
            return True

    if bigquery_uploader:
        flags.DEFINE_string(
            name="gcp_project",
            short_name="gp",
            default=None,
            help=help_wrap(
                "The GCP project name where the benchmark will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_data_set",
            short_name="bds",
            default="test_benchmark",
            help=help_wrap(
                "The Bigquery dataset name where the benchmark will be uploaded."
            ))

        flags.DEFINE_string(
            name="bigquery_run_table",
            short_name="brt",
            default="benchmark_run",
            help=help_wrap("The Bigquery table name where the benchmark run "
                           "information will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_metric_table",
            short_name="bmt",
            default="benchmark_metric",
            help=help_wrap(
                "The Bigquery table name where the benchmark metric "
                "information will be uploaded."))

    return key_flags