示例#1
0
def define_distribution(worker_hosts=True, task_index=True):
    """Register the flags that control multi-worker (distributed) execution.

    Args:
      worker_hosts: If True, define `--worker_hosts`, a comma-separated list
        of worker ip:port pairs.
      task_index: If True, define `--task_index`, the index of this worker's
        task (defaults to -1).

    Returns:
      The list of flag names for core.py to mark as key flags. Neither flag
      defined here is added to it, so the returned list is always empty.
    """
    marked = []

    if worker_hosts:
        hosts_help = help_wrap(
            'Comma-separated list of worker ip:port pairs for running '
            'multi-worker models with DistributionStrategy.  The user would '
            'start the program on each host with identical value for this '
            'flag.')
        flags.DEFINE_string(name='worker_hosts', default=None, help=hosts_help)

    if task_index:
        index_help = help_wrap('If multi-worker training, the task_index of this '
                               'worker.')
        flags.DEFINE_integer(name='task_index', default=-1, help=index_help)

    return marked
示例#2
0
def define_device(tpu=True):
    """Register device-specific flags.

    Args:
      tpu: If True, define the TPU flags `--tpu`, `--tpu_zone`,
        `--tpu_gcp_project` and `--num_tpu_shards`.

    Returns:
      A list of flag names for core.py to mark as key flags. Only "tpu" is
      added (when `tpu` is True); otherwise the list is empty.
    """
    key_flags = []

    if tpu:
        flags.DEFINE_string(
            name="tpu",
            default=None,
            help=help_wrap(
                # Trailing space after "the" is required: adjacent literals
                # are concatenated verbatim.
                "The Cloud TPU to use for training. This should be either the name "
                "used when creating the Cloud TPU, or a "
                "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
                "CPU of the local instance instead. (Good for debugging.)"))
        key_flags.append("tpu")

        flags.DEFINE_string(
            name="tpu_zone",
            default=None,
            help=help_wrap(
                "[Optional] GCE zone where the Cloud TPU is located in. If not "
                "specified, we will attempt to automatically detect the GCE "
                "zone from metadata."))

        flags.DEFINE_string(
            name="tpu_gcp_project",
            default=None,
            help=help_wrap(
                "[Optional] Project name for the Cloud TPU-enabled project. If not "
                "specified, we will attempt to automatically detect the GCE "
                "project from metadata."))

        flags.DEFINE_integer(name="num_tpu_shards",
                             default=8,
                             help=help_wrap("Number of shards (TPU chips)."))

    return key_flags
示例#3
0
def define_image(data_format=True):
    """Register flags specific to image models.

    Args:
      data_format: If True, define `--data_format` (short name `-df`), an enum
        of "channels_first"/"channels_last" that overrides the image axis
        convention used by the model.

    Returns:
      A list of flag names for core.py to mark as key flags: ["data_format"]
      when that flag is defined, otherwise an empty list.
    """
    if not data_format:
        return []

    flags.DEFINE_enum(
        name="data_format",
        short_name="df",
        default=None,
        enum_values=["channels_first", "channels_last"],
        help=help_wrap(
            "A flag to override the data format used in the model. "
            "channels_first provides a performance boost on GPU but is not "
            "always compatible with CPU. If left unspecified, the data format "
            "will be chosen automatically based on whether TensorFlow was "
            "built for CPU or GPU."))
    return ["data_format"]
示例#4
0
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
    """Register benchmarking flags.

    `--benchmark_logger_type` and `--benchmark_test_id` are always defined,
    and `define_log_steps()` is always called.

    Args:
      benchmark_log_dir: If True, define `--benchmark_log_dir` together with a
        validator requiring it to be set when
        `--benchmark_logger_type=BenchmarkFileLogger`.
      bigquery_uploader: If True, define the flags used to upload results to
        BigQuery (`--gcp_project`, `--bigquery_data_set`,
        `--bigquery_run_table`, `--bigquery_run_status_table`,
        `--bigquery_metric_table`).

    Returns:
      A list of flag names for core.py to mark as key flags. No flag defined
      here is added, so the returned list is empty.
    """
    key_flags = []

    flags.DEFINE_enum(
        name="benchmark_logger_type", default="BaseBenchmarkLogger",
        enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger",
                     "BenchmarkBigQueryLogger"],
        help=help_wrap("The type of benchmark logger to use. Defaults to using "
                       "BaseBenchmarkLogger which logs to STDOUT. Different "
                       "loggers will require other flags to be able to work."))
    flags.DEFINE_string(
        name="benchmark_test_id", short_name="bti", default=None,
        help=help_wrap("The unique test ID of the benchmark run. It could be the "
                       "combination of key parameters. It is hardware "
                       "independent and could be used to compare the performance "
                       "between different test runs. This flag is designed for "
                       "human consumption, and does not have any impact within "
                       "the system."))

    define_log_steps()

    if benchmark_log_dir:
        flags.DEFINE_string(
            name="benchmark_log_dir", short_name="bld", default=None,
            help=help_wrap("The location of the benchmark logging.")
        )

    if bigquery_uploader:
        flags.DEFINE_string(
            name="gcp_project", short_name="gp", default=None,
            help=help_wrap(
                "The GCP project name where the benchmark will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_data_set", short_name="bds", default="test_benchmark",
            help=help_wrap(
                "The Bigquery dataset name where the benchmark will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_run_table", short_name="brt", default="benchmark_run",
            help=help_wrap("The Bigquery table name where the benchmark run "
                           "information will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_run_status_table", short_name="brst",
            default="benchmark_run_status",
            help=help_wrap("The Bigquery table name where the benchmark run "
                           "status information will be uploaded."))

        flags.DEFINE_string(
            name="bigquery_metric_table", short_name="bmt",
            default="benchmark_metric",
            help=help_wrap("The Bigquery table name where the benchmark metric "
                           "information will be uploaded."))

    # The validator below reads --benchmark_log_dir, so it may only be
    # registered when that flag has actually been defined; registering a
    # validator on an undefined flag raises at registration time.
    if benchmark_log_dir:
        @flags.multi_flags_validator(
            ["benchmark_logger_type", "benchmark_log_dir"],
            message="--benchmark_logger_type=BenchmarkFileLogger will require "
                    "--benchmark_log_dir being set")
        def _check_benchmark_log_dir(flags_dict):  # pylint: disable=unused-variable
            """Require a log dir when the file logger is selected."""
            if flags_dict["benchmark_logger_type"] == "BenchmarkFileLogger":
                return bool(flags_dict["benchmark_log_dir"])
            return True

    return key_flags
示例#5
0
def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
    """Register base flags.

    Args:
      data_dir: Create a flag for specifying the input data directory.
      model_dir: Create a flag for specifying the model file directory.
      clean: Create a flag for removing the model_dir.
      train_epochs: Create a flag to specify the number of training epochs.
      epochs_between_evals: Create a flag to specify the frequency of testing.
      stop_threshold: Create a flag to specify a threshold accuracy or other
        eval metric which should trigger the end of training.
      batch_size: Create a flag to specify the batch size.
      num_gpu: Create a flag (`--num_gpus`) to specify the number of GPUs used.
      hooks: Create a flag to specify hooks for logging.
      export_dir: Create a flag to specify where a SavedModel should be exported.
      distribution_strategy: Create a flag to specify which Distribution Strategy
        to use.
      run_eagerly: Create a flag to specify to run eagerly op by op.

    Returns:
      A list of flags for core.py to mark as key flags. Only data_dir,
      model_dir, clean, train_epochs, epochs_between_evals, batch_size, hooks
      and export_dir are added, each only when its flag is created.
    """
    key_flags = []

    if data_dir:
        flags.DEFINE_string(name="data_dir",
                            short_name="dd",
                            default="/tmp",
                            help=help_wrap("The location of the input data."))
        key_flags.append("data_dir")

    if model_dir:
        flags.DEFINE_string(
            name="model_dir",
            short_name="md",
            default="/tmp",
            help=help_wrap("The location of the model checkpoint files."))
        key_flags.append("model_dir")

    if clean:
        flags.DEFINE_boolean(
            name="clean",
            default=False,
            help=help_wrap("If set, model_dir will be removed if it exists."))
        key_flags.append("clean")

    if train_epochs:
        flags.DEFINE_integer(
            name="train_epochs",
            short_name="te",
            default=1,
            help=help_wrap("The number of epochs used to train."))
        key_flags.append("train_epochs")

    if epochs_between_evals:
        flags.DEFINE_integer(
            name="epochs_between_evals",
            short_name="ebe",
            default=1,
            help=help_wrap("The number of training epochs to run between "
                           "evaluations."))
        key_flags.append("epochs_between_evals")

    if stop_threshold:
        flags.DEFINE_float(
            name="stop_threshold",
            short_name="st",
            default=None,
            help=help_wrap("If passed, training will stop at the earlier of "
                           "train_epochs and when the evaluation metric is "
                           "greater than or equal to stop_threshold."))

    if batch_size:
        flags.DEFINE_integer(
            name="batch_size",
            short_name="bs",
            default=32,
            help=help_wrap(
                "Batch size for training and evaluation. When using "
                "multiple gpus, this is the global batch size for "
                "all devices. For example, if the batch size is 32 "
                "and there are 4 GPUs, each GPU will get 8 examples on "
                "each step."))
        key_flags.append("batch_size")

    if num_gpu:
        flags.DEFINE_integer(
            name="num_gpus",
            short_name="ng",
            default=1,
            help=help_wrap("How many GPUs to use at each worker with the "
                           "DistributionStrategies API. The default is 1."))

    if run_eagerly:
        # Wrapped like every other flag's help text for consistent formatting.
        flags.DEFINE_boolean(
            name="run_eagerly",
            default=False,
            help=help_wrap(
                "Run the model op by op without building a model function."))

    if hooks:
        # Construct a pretty summary of hooks. The leading U+FEFF characters
        # presumably keep help_wrap from stripping the indentation — confirm.
        hook_list_str = (u"\ufeff  Hook:\n" + u"\n".join(
            [u"\ufeff    {}".format(key) for key in hooks_helper.HOOKS]))
        flags.DEFINE_list(
            name="hooks",
            short_name="hk",
            default="LoggingTensorHook",
            help=help_wrap(
                u"A list of (case insensitive) strings to specify the names of "
                u"training hooks.\n{}\n\ufeff  Example: `--hooks ProfilerHook,"
                u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
                u"for details.".format(hook_list_str)))
        key_flags.append("hooks")

    if export_dir:
        flags.DEFINE_string(
            name="export_dir",
            short_name="ed",
            default=None,
            help=help_wrap(
                "If set, a SavedModel serialization of the model will "
                "be exported to this directory at the end of training. "
                "See the README for more details and relevant links."))
        key_flags.append("export_dir")

    if distribution_strategy:
        # NOTE(review): the help text describes a 'default' value that is not
        # in its own accepted-values list; confirm against the code that
        # resolves this flag into a strategy.
        flags.DEFINE_string(
            name="distribution_strategy",
            short_name="ds",
            default="mirrored",
            help=help_wrap("The Distribution Strategy to use for training. "
                           "Accepted values are 'off', 'one_device', "
                           "'mirrored', 'parameter_server', 'collective', "
                           "case insensitive. 'off' means not to use "
                           "Distribution Strategy; 'default' means to choose "
                           "from `MirroredStrategy` or `OneDeviceStrategy` "
                           "according to the number of GPUs."))

    return key_flags
示例#6
0
def define_performance(num_parallel_calls=False,
                       inter_op=False,
                       intra_op=False,
                       synthetic_data=False,
                       max_train_steps=False,
                       dtype=False,
                       all_reduce_alg=False,
                       num_packs=False,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
                       dynamic_loss_scale=False,
                       fp16_implementation=False,
                       loss_scale=False,
                       tf_data_experimental_slack=False,
                       enable_xla=False,
                       training_dataset_cache=False):
    """Register flags for specifying performance tuning arguments.

    Args:
      num_parallel_calls: Create a flag to specify parallelism of data loading.
        Its default is `multiprocessing.cpu_count()`.
      inter_op: Create a flag to allow specification of inter op threads.
      intra_op: Create a flag to allow specification of intra op threads.
      synthetic_data: Create a flag to allow the use of synthetic data.
      max_train_steps: Create a flag to allow specification of the maximum
        number of training steps.
      dtype: Create flags for specifying dtype. `loss_scale` and
        `fp16_implementation` only take effect when this is True.
      all_reduce_alg: If set, forces a specific algorithm for multi-gpu.
      num_packs: If set, provides the number of packs for MirroredStrategy's
        cross device ops.
      tf_gpu_thread_mode: Create `--tf_gpu_thread_mode` (gpu_private triggers
        use of a private thread pool) and `--per_gpu_thread_count`.
      datasets_num_private_threads: Number of private threads for datasets.
      datasets_num_parallel_batches: Determines how many batches to process in
        parallel when using map and batch from tf.data.
      dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
        "dynamic". Only valid if `dtype` is True.
      fp16_implementation: Create fp16_implementation flag. Only valid if
        `dtype` is True.
      loss_scale: Controls the loss scaling, normally for mixed-precision
        training. Can only be turned on if `dtype` is also True.
      tf_data_experimental_slack: Determines whether to enable tf.data's
        `experimental_slack` option.
      enable_xla: Determines if XLA (auto clustering) is turned on.
      training_dataset_cache: Whether to cache the training dataset on workers.
        Typically used to improve training performance when training data is
        in remote storage and can fit into worker memory.

    Returns:
      A list of flags for core.py to mark as key flags. No flag defined here
      is added, so the returned list is always empty.
    """

    key_flags = []
    if num_parallel_calls:
        flags.DEFINE_integer(
            name="num_parallel_calls",
            short_name="npc",
            default=multiprocessing.cpu_count(),
            help=help_wrap(
                "The number of records that are processed in parallel "
                "during input processing. This can be optimized per "
                "data set but for generally homogeneous data sets, "
                "should be approximately the number of available CPU "
                "cores. (default behavior)"))

    if inter_op:
        flags.DEFINE_integer(
            name="inter_op_parallelism_threads",
            short_name="inter",
            default=0,
            help=help_wrap(
                "Number of inter_op_parallelism_threads to use for CPU. "
                "See TensorFlow config.proto for details."))

    if intra_op:
        flags.DEFINE_integer(
            name="intra_op_parallelism_threads",
            short_name="intra",
            default=0,
            help=help_wrap(
                "Number of intra_op_parallelism_threads to use for CPU. "
                "See TensorFlow config.proto for details."))

    if synthetic_data:
        flags.DEFINE_bool(
            name="use_synthetic_data",
            short_name="synth",
            default=False,
            help=help_wrap(
                "If set, use fake data (zeroes) instead of a real dataset. "
                "This mode is useful for performance debugging, as it removes "
                "input processing steps, but will not learn anything."))

    if max_train_steps:
        flags.DEFINE_integer(
            name="max_train_steps",
            short_name="mts",
            default=None,
            help=help_wrap(
                "The model will stop training if the global_step reaches this "
                "value. If not set, training will run until the specified number "
                "of epochs have run as usual. It is generally recommended to set "
                "--train_epochs=1 when using this flag."))

    if dtype:
        flags.DEFINE_enum(
            name="dtype",
            short_name="dt",
            default="fp32",
            enum_values=DTYPE_MAP.keys(),
            help=help_wrap("The TensorFlow datatype used for calculations. "
                           "Variables may be cast to a higher precision on a "
                           "case-by-case basis for numerical stability."))

        loss_scale_help_text = (
            "The amount to scale the loss by when the model is run. {}. Before "
            "gradients are computed, the loss is multiplied by the loss scale, "
            "making all gradients loss_scale times larger. To adjust for this, "
            "gradients are divided by the loss scale before being applied to "
            "variables. This is mathematically equivalent to training without "
            "a loss scale, but the loss scale helps avoid some intermediate "
            "gradients from underflowing to zero. If not provided the default "
            "for fp16 is 128 and 1 for all other dtypes.{}")
        if dynamic_loss_scale:
            loss_scale_help_text = loss_scale_help_text.format(
                "This can be an int/float or the string 'dynamic'",
                " The string 'dynamic' can be used to dynamically determine the "
                "optimal loss scale during training, but currently this "
                "significantly slows down performance")
            loss_scale_validation_msg = (
                "loss_scale should be a positive int/float "
                "or the string 'dynamic'.")
        else:
            loss_scale_help_text = loss_scale_help_text.format(
                "This must be an int/float", "")
            loss_scale_validation_msg = "loss_scale should be a positive int/float."
        if loss_scale:
            flags.DEFINE_string(name="loss_scale",
                                short_name="ls",
                                default=None,
                                help=help_wrap(loss_scale_help_text))

            @flags.validator(flag_name="loss_scale",
                             message=loss_scale_validation_msg)
            def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
                """Validator to check the loss scale flag is valid."""
                if loss_scale is None:
                    return True  # null case is handled in get_loss_scale()

                if loss_scale == "dynamic" and dynamic_loss_scale:
                    return True

                try:
                    loss_scale = float(loss_scale)
                except ValueError:
                    return False

                return loss_scale > 0

        if fp16_implementation:
            flags.DEFINE_enum(
                name="fp16_implementation",
                default="keras",
                # A real two-element tuple: the previous single string
                # "keras', 'graph_rewrite" made membership a substring test.
                enum_values=("keras", "graph_rewrite"),
                help=help_wrap(
                    "When --dtype=fp16, how fp16 should be implemented. This has no "
                    "impact on correctness. 'keras' uses the "
                    "tf.keras.mixed_precision API. 'graph_rewrite' uses the "
                    "tf.train.experimental.enable_mixed_precision_graph_rewrite "
                    "API."))

            # Only the flags the checker actually reads are listed; "loss_scale"
            # is not read and is undefined when loss_scale=False, which would
            # make validator registration fail.
            @flags.multi_flags_validator(["fp16_implementation", "dtype"])
            def _check_fp16_implementation(flags_dict):  # pylint: disable=unused-variable
                """Validator to check fp16_implementation flag is valid."""
                if (flags_dict["fp16_implementation"] == "graph_rewrite"
                        and flags_dict["dtype"] != "fp16"):
                    raise flags.ValidationError(
                        "--fp16_implementation should not be "
                        "specified unless --dtype=fp16")
                return True

    if all_reduce_alg:
        flags.DEFINE_string(
            name="all_reduce_alg",
            short_name="ara",
            default=None,
            help=help_wrap(
                "Defines the algorithm to use for performing all-reduce. "
                "When specified with MirroredStrategy for single "
                "worker, this controls "
                "tf.contrib.distribute.AllReduceCrossTowerOps.  When "
                "specified with MultiWorkerMirroredStrategy, this "
                "controls "
                "tf.distribute.experimental.CollectiveCommunication; "
                "valid options are `ring` and `nccl`."))

    if num_packs:
        flags.DEFINE_integer(
            name="num_packs",
            default=1,
            help=help_wrap("Sets `num_packs` in the cross device ops used in "
                           "MirroredStrategy.  For details, see "
                           "tf.distribute.NcclAllReduce."))

    if tf_gpu_thread_mode:
        flags.DEFINE_string(
            name="tf_gpu_thread_mode",
            short_name="gt_mode",
            default=None,
            help=help_wrap(
                "Whether and how the GPU device uses its own threadpool."))

        flags.DEFINE_integer(
            name="per_gpu_thread_count",
            short_name="pgtc",
            default=0,
            help=help_wrap(
                "The number of threads to use for GPU. Only valid when "
                "tf_gpu_thread_mode is not global."))

    if datasets_num_private_threads:
        flags.DEFINE_integer(
            name="datasets_num_private_threads",
            default=None,
            help=help_wrap(
                "Number of threads for a private threadpool created for all "
                "datasets computation."))

    if datasets_num_parallel_batches:
        flags.DEFINE_integer(
            name="datasets_num_parallel_batches",
            default=None,
            help=help_wrap(
                "Determines how many batches to process in parallel when using "
                "map and batch from tf.data."))

    if training_dataset_cache:
        flags.DEFINE_boolean(
            name="training_dataset_cache",
            default=False,
            help=help_wrap(
                "Determines whether to cache the training dataset on workers. "
                "Typically used to improve training performance when training "
                "data is in remote storage and can fit into worker memory."))

    if tf_data_experimental_slack:
        flags.DEFINE_boolean(
            name="tf_data_experimental_slack",
            default=False,
            help=help_wrap(
                "Whether to enable tf.data's `experimental_slack` option."))

    if enable_xla:
        flags.DEFINE_boolean(
            name="enable_xla",
            default=False,
            help=help_wrap("Whether to enable XLA auto jit compilation"))

    return key_flags