Example 1
def imagenet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator."""

    # Warmup and higher lr may not be valid for fine tuning with small batches
    # and smaller numbers of training images.
    if params['fine_tune'] or ('disable_warmup' in params
                               and params['disable_warmup']):
        warmup = False
        base_lr = .1
    else:
        warmup = True
        base_lr = .128

    # According to https://arxiv.org/abs/1706.02677 and our internal experiments,
    # the best accuracy results for more than 16 devices are achieved when base_lr == 0.1
    if horovod_enabled() and hvd.size() > 16:
        base_lr = .1

    # Used for ResNeXt101-32x4d
    if params['use_cosine_lr']:
        base_lr = .256

    if horovod_enabled():
        total_batch_size = params['batch_size'] * hvd.size()
    else:
        total_batch_size = params['batch_size'] * params.get('num_workers', 1)

    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=total_batch_size,
        batch_denom=256,
        num_images=NUM_IMAGES['train'],
        boundary_epochs=[30, 60, 80, 90],
        train_epochs=params['train_epochs'],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
        warmup=warmup,
        warmup_epochs=params['warmup_epochs'],
        base_lr=base_lr,
        use_cosine_lr=params['use_cosine_lr'])

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=flags.FLAGS.weight_decay,
        learning_rate_fn=learning_rate_fn,
        momentum=flags.FLAGS.momentum,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=None,
        model_type=params['model_type'],
        dtype=params['dtype'],
        fine_tune=params['fine_tune'],
        label_smoothing=flags.FLAGS.label_smoothing)
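The schedule built above scales the base learning rate linearly with the global batch size (relative to a reference batch of 256), optionally ramps it up over the warmup epochs, and then decays it at fixed epoch boundaries. A minimal pure-Python sketch of that rule, with illustrative names and defaults rather than the actual resnet_run_loop.learning_rate_with_decay implementation:

def piecewise_lr(global_step, batch_size, num_images, base_lr=0.128,
                 batch_denom=256, boundary_epochs=(30, 60, 80, 90),
                 decay_rates=(1, 0.1, 0.01, 0.001, 1e-4),
                 warmup=True, warmup_epochs=5):
    """Illustrative step-wise LR with linear warmup and linear batch scaling."""
    initial_lr = base_lr * batch_size / batch_denom   # linear scaling rule
    steps_per_epoch = num_images // batch_size
    if warmup and global_step < steps_per_epoch * warmup_epochs:
        # Linear ramp from 0 up to the scaled initial learning rate.
        return initial_lr * global_step / (steps_per_epoch * warmup_epochs)
    lr = initial_lr * decay_rates[0]
    for epoch_boundary, rate in zip(boundary_epochs, decay_rates[1:]):
        if global_step >= epoch_boundary * steps_per_epoch:
            lr = initial_lr * rate
    return lr

# 8 workers x per-replica batch 256 -> global batch 2048, initial LR ~1.024
print(piecewise_lr(global_step=30 * 625, batch_size=2048, num_images=1281167))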
Example 2
    def input_fn(params):
        """The actual input function."""
        if use_tpu:
            batch_size = params["batch_size"]
        else:
            batch_size = bsz

        if FLAGS.deterministic_run:
            d = tf.data.TFRecordDataset(input_file)
            d = d.apply(
                tf.data.experimental.map_and_batch(
                    lambda record: _decode_record(record, name_to_features),
                    batch_size=batch_size,
                    num_parallel_calls=1,
                    drop_remainder=True))
            return d
        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            if horovod_enabled():
                d = d.shard(hvd.size(), hvd.rank())
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            tf.data.experimental.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d
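When Horovod is active, the record dataset is sharded by rank before repeat/shuffle/batch, so every worker trains on a disjoint slice of the data. A toy sketch of that ordering using an integer dataset in place of the TFRecord files (assumes a TF2 install with eager execution; num_workers/worker_rank stand in for hvd.size()/hvd.rank()):

import tensorflow as tf

num_workers, worker_rank = 2, 0        # stand-ins for hvd.size(), hvd.rank()

d = tf.data.Dataset.range(16)          # stand-in for TFRecordDataset(input_file)
d = d.shard(num_workers, worker_rank)  # rank 0 keeps elements 0, 2, 4, ...
d = d.repeat()
d = d.shuffle(buffer_size=8)
d = d.batch(4, drop_remainder=True)

for batch in d.take(2):
    print(batch.numpy())               # two batches of even numbers, shuffled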
Example 3
    def after_run(self, run_context, run_values):
        self._run_end = time.time()
        self._duration += self._run_end - self._run_begin
        # Do not use step 0; treat it as warmup.
        if self._step > 0 and self._step % self._every_n_steps == 0:
            results = run_values.results
            global_step = results['global_step']

            images = get_key_or_none(results, self._images_name)
            labels = get_key_or_none(results, self._labels_name)
            filenames = get_key_or_none(results, self._filenames_name)
            raw_images = get_key_or_none(results, self._raw_images_name)
            heat_map_features = get_key_or_none(results,
                                                self._heat_map_features_name)
            probs = get_key_or_none(results, self._probs_name)

            self._total_batch_size = len(images) * hvd.size()

            self._log_and_record(self._step + global_step)
            show_images(filenames,
                        images,
                        raw_images,
                        heat_map_features,
                        labels,
                        probs,
                        self._step + global_step,
                        self._max_images,
                        self._summary_writer,
                        prefix='eval')

        self._step += 1
Example 4
def hvd_info(msg):
    hvd_try_init()
    if horovod_enabled():
        head = 'hvd rank{}/{} in {}'.format(hvd.rank(), hvd.size(),
                                            socket.gethostname())
    else:
        head = '{}'.format(socket.gethostname())
    tf.logging.info('{}: {}'.format(head, msg))
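hvd_info prefixes each log line with the caller's rank, world size, and hostname so interleaved output from many processes stays attributable. A standalone variant using plain horovod.tensorflow and the standard logging module (a sketch that assumes Horovod is installed; it also works in a single process):

import logging
import socket

import horovod.tensorflow as hvd

hvd.init()
logging.basicConfig(level=logging.INFO)

def hvd_log(msg):
    # "rank/size on host" prefix keeps multi-process logs attributable.
    head = 'rank {}/{} on {}'.format(hvd.rank(), hvd.size(), socket.gethostname())
    logging.info('%s: %s', head, msg)

hvd_log('input pipeline ready')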
Example 5
    def _calculate_mean_and_var(self, x, axes, keep_dims):

        with ops.name_scope('moments', values=[x, axes]):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x

            if horovod_enabled():
                num_shards = hvd.size()
            else:
                num_shards = 1

            if num_shards > 1:
                local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes, keepdims=True)
                batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
                # y_sum, y_squared_sum, global_batch_size = (
                #     replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [
                #         local_sum, local_squared_sum, batch_size]))

                # hvd_info(f'local_sum {local_sum.shape}, local_squared_sum {local_squared_sum.shape}')

                y_sum = hvd.allreduce(local_sum, average=False)
                y_squared_sum = hvd.allreduce(local_squared_sum, average=False)

                global_batch_size = batch_size * num_shards
                axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
                multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals), dtypes.float32)
                multiplier = multiplier * global_batch_size

                mean = y_sum / multiplier
                y_squared_mean = y_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - math_ops.square(mean)
            else:
                # Compute true mean while keeping the dims for proper broadcasting.
                mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
                # sample variance, not unbiased variance
                # Note: stop_gradient does not change the gradient that gets
                #       backpropagated to the mean from the variance calculation,
                #       because that gradient is zero
                variance = math_ops.reduce_mean(
                    math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
                    axes,
                    keepdims=True,
                    name='variance')
            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                variance = array_ops.squeeze(variance, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16))
            else:
                return (mean, variance)
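The multi-shard branch rebuilds the global mean and variance from per-shard sums and sums of squares via Var[X] = E[X^2] - E[X]^2. A NumPy check of that identity with no Horovod involved (the np.sum over workers plays the role of hvd.allreduce(..., average=False)):

import numpy as np

rng = np.random.default_rng(0)
shards = [rng.normal(size=(8, 4)) for _ in range(4)]     # 4 workers, batch 8, 4 features

local_sums = [s.sum(axis=0) for s in shards]             # per-worker partial sums
local_sq_sums = [(s ** 2).sum(axis=0) for s in shards]
global_count = sum(s.shape[0] for s in shards)

y_sum = np.sum(local_sums, axis=0)                       # what allreduce(SUM) returns
y_squared_sum = np.sum(local_sq_sums, axis=0)

mean = y_sum / global_count
variance = y_squared_sum / global_count - mean ** 2      # E[X^2] - E[X]^2

full = np.concatenate(shards, axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(variance, full.var(axis=0))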
Example 6
def _configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if `FLAGS.learning_rate_decay_type` is not recognized.
  """
    # Note: when num_clones is > 1, this will actually have each clone to go
    # over each epoch FLAGS.num_epochs_per_decay times. This is different
    # behavior from sync replicas and is expected to produce different results.
    steps_per_epoch = num_samples_per_epoch / FLAGS.batch_size / FLAGS.num_workers
    if FLAGS.sync_replicas:
        steps_per_epoch /= FLAGS.replicas_to_aggregate

    decay_steps = int(steps_per_epoch * FLAGS.num_epochs_per_decay)

    if FLAGS.learning_rate_decay_type == 'exponential':
        learning_rate = tf.train.exponential_decay(
            FLAGS.learning_rate,
            global_step,
            decay_steps,
            FLAGS.learning_rate_decay_factor,
            staircase=True,
            name='exponential_decay_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'fixed':
        learning_rate = tf.constant(FLAGS.learning_rate,
                                    name='fixed_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'polynomial':
        learning_rate = tf.train.polynomial_decay(
            FLAGS.learning_rate,
            global_step,
            decay_steps,
            FLAGS.end_learning_rate,
            power=1.0,
            cycle=False,
            name='polynomial_decay_learning_rate')
    else:
        raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                         FLAGS.learning_rate_decay_type)

    if FLAGS.warmup_epochs:
        warmup_lr = (FLAGS.learning_rate * tf.cast(global_step, tf.float32) /
                     (steps_per_epoch * FLAGS.warmup_epochs))
        learning_rate = tf.minimum(warmup_lr, learning_rate)
    if horovod_enabled():
        learning_rate = learning_rate * hvd.size()
    return learning_rate
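The warmup clamp at the end (tf.minimum(warmup_lr, learning_rate)) makes the learning rate ramp up linearly and then hand off to the decay schedule once the ramp crosses it. A pure-Python sketch of that interaction for the staircase exponential case (illustrative constants, not the FLAGS defaults):

import math

def lr_at_step(step, base_lr=0.01, steps_per_epoch=1000,
               warmup_epochs=2, epochs_per_decay=10, decay_factor=0.94):
    # Staircase decay, as in tf.train.exponential_decay(..., staircase=True).
    decay_steps = steps_per_epoch * epochs_per_decay
    decayed = base_lr * decay_factor ** math.floor(step / decay_steps)
    # Linear warmup; the minimum picks the ramp while it is below the decayed LR.
    warmup = base_lr * step / (steps_per_epoch * warmup_epochs)
    return min(warmup, decayed)

for step in (0, 500, 2000, 25000):
    print(step, lr_at_step(step))   # 0.0, 0.0025, 0.01, then decayed values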
Example 7
def hvd_info_rank0(msg, with_head=True):
    hvd_try_init()
    if is_rank0():
        if with_head:
            if horovod_enabled():
                head = 'hvd only rank{}/{} in {}'.format(
                    hvd.rank(), hvd.size(), socket.gethostname())
            else:
                head = '{}'.format(socket.gethostname())
            tf.logging.info('{}: {}'.format(head, msg))
        else:
            tf.logging.info(msg)
Example 8
  def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute the mean and variance: it overrides the original _moments."""
    shard_mean, shard_variance = super(SyncBatchNormalization, self)._moments(
      inputs, reduction_axes, keep_dims=keep_dims)

    num_shards = hvd.size() if horovod_enabled() else 1
    if num_shards > 1:
      # Compute variance using: Var[X]= E[X^2] - E[X]^2.
      shard_square_of_mean = tf.math.square(shard_mean)
      shard_mean_of_square = shard_variance + shard_square_of_mean
      group_mean = hvd.allreduce(shard_mean)
      group_mean_of_square = hvd.allreduce(shard_mean_of_square)
      group_variance = group_mean_of_square - tf.math.square(group_mean)
      return (group_mean, group_variance)
    else:
      return (shard_mean, shard_variance)
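With equal per-shard batch sizes, the group statistics can be recovered exactly from the per-shard moments: average the shard means, average the shard E[X^2] = var + mean^2, and apply Var[X] = E[X^2] - E[X]^2. A NumPy check of that bookkeeping (the plain mean over shards stands in for hvd.allreduce, which averages by default):

import numpy as np

rng = np.random.default_rng(1)
shards = [rng.normal(loc=i, size=(16, 3)) for i in range(4)]  # equal-sized shards

shard_means = np.stack([s.mean(axis=0) for s in shards])
shard_vars = np.stack([s.var(axis=0) for s in shards])

group_mean = shard_means.mean(axis=0)                          # allreduce(average)
group_mean_of_square = (shard_vars + shard_means ** 2).mean(axis=0)
group_variance = group_mean_of_square - group_mean ** 2

full = np.concatenate(shards, axis=0)
assert np.allclose(group_mean, full.mean(axis=0))
assert np.allclose(group_variance, full.var(axis=0))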
Example 9
def input_fn(is_training,
             data_dir,
             batch_size,
             num_epochs=1,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False,
             tf_data_experimental_slack=False,
             experimental_preloading=False,
             dataset_fn=None):
    """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.
    dtype: Data type to use for images/features
    datasets_num_private_threads: Number of private threads for tf.data.
    parse_record_fn: Function to use for parsing the records.
    input_context: A `tf.distribute.InputContext` object passed in by
      `tf.distribute.Strategy`.
    drop_remainder: A boolean indicates whether to drop the remainder of the
      batches. If True, the batch dimension will be static.
    tf_data_experimental_slack: Whether to enable tf.data's
      `experimental_slack` option.
    experimental_preloading: Whether to enable the experimental data
      preloading path.
    dataset_fn: Optional callable returning the base dataset; when set, the
      file listing from `data_dir` is skipped.

  Returns:
    A dataset that can be used for iteration.
  """
    if dataset_fn is None:
        filenames = get_filenames(is_training, data_dir)
        dataset = tf.data.Dataset.from_tensor_slices(filenames)
    else:
        dataset = dataset_fn()

    if is_training and horovod_enabled():
        dataset = dataset.shard(hvd.size(), hvd.rank())

    if input_context:
        tf.compat.v1.logging.info(
            'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d'
            % (input_context.input_pipeline_id,
               input_context.num_input_pipelines))
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)

    if is_training:
        # Shuffle the input files
        dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

    # Convert to individual records.
    # cycle_length = 10 means that up to 10 files will be read and deserialized in
    # parallel. You may want to increase this number if you have a large number of
    # CPU cores.
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    return resnet_run_loop.process_record_dataset(
        dataset=dataset,
        is_training=is_training,
        batch_size=batch_size,
        shuffle_buffer=_SHUFFLE_BUFFER,
        parse_record_fn=parse_record_fn,
        num_epochs=num_epochs,
        dtype=dtype,
        datasets_num_private_threads=datasets_num_private_threads,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=tf_data_experimental_slack,
        experimental_preloading=experimental_preloading)
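interleave with cycle_length=10 keeps several TFRecord files open at once and mixes their records in round-robin order. A toy demonstration of that ordering with integer datasets in place of files (assumes a TF2 install with eager execution):

import tensorflow as tf

files = tf.data.Dataset.range(3)                  # pretend these are 3 input files
records = files.interleave(
    lambda f: tf.data.Dataset.range(f * 10, f * 10 + 4),  # "records" of each file
    cycle_length=2)                               # read 2 "files" at a time

print(list(records.as_numpy_iterator()))
# [0, 10, 1, 11, 2, 12, 3, 13, 20, 21, 22, 23]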
Example 10
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
     Dict of results of the run.  Contains the keys `eval_results` and
    `train_hooks`. `eval_results` contains accuracy (top_1) and accuracy_top_5.
    `train_hooks` is a list of the instances of hooks used during training.
  """

  experimental_preloading = flags_obj.experimental_preloading

  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=not experimental_preloading)

  if horovod_enabled():
    # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
    if not condition_env_var('TF_DISABLE_SCOPED_ALLOCATOR', default=False):
      from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error
      session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON
      enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
      del enable_op[:]
      enable_op.append("HorovodAllreduce")

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Creates a `RunConfig` that checkpoints every 24 hours which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      log_step_count_steps=flags_obj.display_steps,
      save_checkpoints_secs=None,
      save_checkpoints_steps=flags_obj.save_checkpoint_steps)

  # Initializes model with all but the dense layer from pretrained ResNet.
  # if flags_obj.pretrained_model_checkpoint_path is not None:
  #   warm_start_settings = tf.estimator.WarmStartSettings(
  #       flags_obj.pretrained_model_checkpoint_path,
  #       vars_to_warm_start='^(?!.*dense)')
  # else:
  #   warm_start_settings = None
  warm_start_settings = None

  model_dir=flags_obj.model_dir

  if horovod_enabled():
    model_dir="{}/rank_{}".format(flags_obj.model_dir, hvd.rank())

  if experimental_preloading:
    SelectedEstimator = HabanaEstimator
  else:
    SelectedEstimator = tf.estimator.Estimator

  if flags.FLAGS.is_mlperf_enabled:
    for eval_batch_size in range(flags_obj.batch_size, 1, -1):
      if imagenet_main.NUM_IMAGES['validation'] % eval_batch_size == 0:
        break
  else:
    eval_batch_size = flags_obj.batch_size

  classifier = SelectedEstimator(
      model_fn=model_function, model_dir=model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'model_type': flags_obj.model_type,
          'loss_scale': flags_core.get_loss_scale(flags_obj,
                                                  default_for_fp16=128),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune,
          'num_workers': num_workers,
          'train_epochs': flags_obj.train_epochs,
          'warmup_epochs': flags_obj.warmup_epochs,
          'use_cosine_lr': flags_obj.use_cosine_lr,
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'model_type': flags_obj.model_type,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags.FLAGS.is_mlperf_enabled:
    run_params['eval_batch_size'] = eval_batch_size

  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=model_dir,
      batch_size=flags_obj.batch_size)

  if flags.FLAGS.is_mlperf_enabled:
    _log_cache = []
    def formatter(x):
      """Abuse side effects to get tensors out of the model_fn."""
      if _log_cache:
        _log_cache.pop()
      _log_cache.append(x.copy())
      return str(x)

    compliance_hook = tf.estimator.LoggingTensorHook(
      tensors={_NUM_EXAMPLES_NAME: _NUM_EXAMPLES_NAME},
      every_n_iter=int(1e10),
      at_end=True,
      formatter=formatter)
  else:
    compliance_hook = None

  if horovod_enabled():

    if "tf_profiler_hook" not in flags_obj.hooks and os.environ.get("TF_RANGE_TRACE", False):
      from TensorFlow.common.utils import RangeTFProfilerHook
      begin = (imagenet_main.NUM_IMAGES["train"] // (flags_obj.batch_size * hvd.size()) + 100)
      train_hooks.append(RangeTFProfilerHook(begin, 20, "./rank-{}".format(hvd.rank())))

    if "synapse_logger_hook" not in flags_obj.hooks and "range" == os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower():
      from TensorFlow.common.horovod_helpers import SynapseLoggerHook
      begin = (imagenet_main.NUM_IMAGES["train"] // (flags_obj.batch_size * hvd.size()) + 100)
      end = begin + 100
      print("Begin: {}".format(begin))
      print("End: {}".format(end))
      train_hooks.append(SynapseLoggerHook(list(range(begin, end)), False))
    train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))


  def input_fn_train(num_epochs, input_context=None):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_dl_type(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        input_context=input_context, experimental_preloading=experimental_preloading)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            eval_batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_dl_type(flags_obj), experimental_preloading=experimental_preloading)

  train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                  flags_obj.train_epochs)

  max_train_steps = flags_obj.max_train_steps
  global_batch_size = flags_obj.batch_size * (hvd.size() if horovod_enabled() else 1)
  steps_per_epoch = (imagenet_main.NUM_IMAGES['train'] // global_batch_size)
  if max_train_steps is None:
    max_train_steps = steps_per_epoch * (train_epochs + flags_obj.train_offset)

  max_eval_steps = flags_obj.max_eval_steps
  if max_eval_steps is None:
    max_eval_steps = (imagenet_main.NUM_IMAGES['validation'] + eval_batch_size - 1) // eval_batch_size

  use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
  if use_train_and_evaluate:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    tf.compat.v1.logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evaluate doesn't return anything in multi-worker
    # case.
    eval_results = {}
  else:
    if train_epochs == 0:
      # If --eval_only is set, perform a single loop with zero train epochs.
      schedule, n_loops = [0], 1
    else:
      # Compute the number of times to loop while training. All but the last
      # pass will train for `epochs_between_evals` epochs, while the last will
      # train for the number needed to reach `training_epochs`. For instance if
      #   train_epochs = 25 and epochs_between_evals = 10
      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
      #   Train for 10 epochs and then evaluate.
      #   Train for another 10 epochs and then evaluate.
      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
      n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
      schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
      schedule[-1] = train_epochs - sum(schedule[:-1])  # Trim the final chunk.

    if flags.FLAGS.is_mlperf_enabled:
      mllogger.event(key=mllog.constants.CACHE_CLEAR)
      mllogger.start(key=mllog.constants.RUN_START)
      mllogger.event(key=mllog.constants.GLOBAL_BATCH_SIZE,
                     value=global_batch_size)

    final_step = 0

    if flags.FLAGS.is_mlperf_enabled:
      success = False
      if flags_obj.train_offset > 0:
        final_step += flags_obj.train_offset * steps_per_epoch
        mllogger.event(key=mllog.constants.FIRST_EPOCH_NUM, value=1, metadata={'number of epochs before main loop: ': flags_obj.train_offset})
        for i in range(flags_obj.train_offset):
          mllogger.event(key=mllog.constants.EPOCH_NUM, value=i+1)
        classifier.train(
              input_fn=lambda input_context=None: input_fn_train(
              flags_obj.train_offset, input_context=input_context),
              hooks=train_hooks + [compliance_hook],
              max_steps=max_train_steps if max_train_steps < final_step else final_step)

    for cycle_index, num_train_epochs in enumerate(schedule):
      tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                                int(n_loops))
      if flags.FLAGS.is_mlperf_enabled:
        mllogger.start(key=mllog.constants.BLOCK_START, value=cycle_index+1)
        mllogger.event(key=mllog.constants.FIRST_EPOCH_NUM, value=cycle_index*flags_obj.epochs_between_evals + flags_obj.train_offset + 1)
        mllogger.event(key=mllog.constants.EPOCH_COUNT, value=flags_obj.epochs_between_evals)

        for j in range(flags_obj.epochs_between_evals):
          mllogger.event(key=mllog.constants.EPOCH_NUM,
                         value=cycle_index  * flags_obj.epochs_between_evals + j +  flags_obj.train_offset + 1)

      if num_train_epochs:
        # Since we are calling classifier.train immediately in each loop, the
        # value of num_train_epochs in the lambda function will not be changed
        # before it is used. So it is safe to ignore the pylint error here
        # pylint: disable=cell-var-from-loop
        final_step += num_train_epochs * steps_per_epoch
        classifier.train(
            input_fn=lambda input_context=None: input_fn_train(
                num_train_epochs, input_context=input_context),
            hooks=train_hooks + [compliance_hook] if compliance_hook is not None else train_hooks,
            max_steps=max_train_steps if max_train_steps < final_step else final_step)
        if flags.FLAGS.is_mlperf_enabled:
            mllogger.end(key=mllog.constants.BLOCK_STOP, value=cycle_index+1)

      if flags.FLAGS.is_mlperf_enabled:
        mllogger.start(key=mllog.constants.EVAL_START)
      # max_eval_steps is associated with testing and profiling.
      # As a result it is frequently called with synthetic data,
      # which will iterate forever. Passing steps=max_eval_steps
      # allows the eval (which is generally unimportant in those circumstances)
      # to terminate. Note that eval will run for max_eval_steps each loop,
      # regardless of the global_step count.
      if flags_obj.get_flag_value("return_before_eval", False):
        return {}
      if flags_obj.get_flag_value("disable_eval", False):
        eval_results = None
        continue
      tf.compat.v1.logging.info('Starting to evaluate.')
      eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                         steps=max_eval_steps)

      if flags.FLAGS.is_mlperf_enabled:
        mllogger.event(key=mllog.constants.EVAL_SAMPLES, value=int(eval_results[_NUM_EXAMPLES_NAME]))
        validation_epoch = (cycle_index + 1) * flags_obj.epochs_between_evals + flags_obj.train_offset
        mllogger.event(key=mllog.constants.EVAL_ACCURACY, value=float(eval_results['accuracy']), metadata={'epoch_num: ': validation_epoch})
        mllogger.end(key=mllog.constants.EVAL_STOP, metadata={'epoch_num: ' : validation_epoch})
        if flags_obj.stop_threshold:
          success = bool(eval_results['accuracy'] >= flags_obj.stop_threshold)

      benchmark_logger.log_evaluation_result(eval_results)

      if flags_obj.stop_threshold:
        if horovod_enabled():
          past_threshold = tf.cast(model_helpers.past_stop_threshold(
              flags_obj.stop_threshold, eval_results['accuracy']), tf.float32)
          global_past_threshold = tf.math.greater(
              hvd.allreduce(past_threshold, op=hvd.Sum), tf.zeros(1, tf.float32))
          if global_past_threshold.eval(session=tf.compat.v1.Session()):
            break
        else:
          if model_helpers.past_stop_threshold(
              flags_obj.stop_threshold, eval_results['accuracy']):
            break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

  stats = {}
  stats['eval_results'] = eval_results
  stats['train_hooks'] = train_hooks

  if flags.FLAGS.is_mlperf_enabled:
    mllogger.event(key=mllog.constants.RUN_STOP, value={"success": success})
    mllogger.end(key=mllog.constants.RUN_STOP)

  return stats
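The comment in the training loop describes how train_epochs is split into chunks of epochs_between_evals with a shorter final chunk. The same computation as a tiny standalone helper (illustrative; it mirrors the lines in the loop rather than adding new behavior):

import math

def build_schedule(train_epochs, epochs_between_evals):
    n_loops = math.ceil(train_epochs / epochs_between_evals)
    schedule = [epochs_between_evals] * n_loops
    schedule[-1] = train_epochs - sum(schedule[:-1])   # trim the final chunk
    return schedule

print(build_schedule(25, 10))   # [10, 10, 5]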
Example 11
def main(_):

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    model_dir = FLAGS.output_dir
    if horovod_enabled():
        model_dir = os.path.join(FLAGS.output_dir, "worker_" + str(hvd.rank()))

    # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
    if condition_env_var('TF_DISABLE_SCOPED_ALLOCATOR', default=False):
        session_config = None
    else:
        from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

        session_config = tf.compat.v1.ConfigProto()
        session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

        enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
        del enable_op[:]
        enable_op.append("HorovodAllreduce")

    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=model_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host),
        session_config=session_config)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None

    train_batch_size = FLAGS.train_batch_size
    if horovod_enabled():
        train_batch_size = train_batch_size * hvd.size()

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    start_index = 0
    end_index = len(train_examples)
    per_worker_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]
    worker_id = 0

    if horovod_enabled():
        per_worker_filenames = [
            os.path.join(FLAGS.output_dir, "train.tf_record_{}".format(i))
            for i in range(hvd.size())
        ]
        num_examples_per_rank = len(train_examples) // hvd.size()
        remainder = len(train_examples) % hvd.size()
        worker_id = hvd.rank()
        if worker_id < remainder:
            start_index = worker_id * (num_examples_per_rank + 1)
            end_index = start_index + num_examples_per_rank + 1
        else:
            start_index = worker_id * num_examples_per_rank + remainder
            end_index = start_index + (num_examples_per_rank)

    learning_rate = FLAGS.learning_rate
    if horovod_enabled():
        learning_rate = learning_rate * hvd.size()

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        dropout_before_logits=FLAGS.dropout_before_logits)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(
            train_examples[start_index:end_index], label_list,
            FLAGS.max_seq_length, tokenizer, per_worker_filenames[worker_id])
        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
        tf.compat.v1.logging.info("  Per-worker batch size = %d",
                                  FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Total batch size = %d", train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=per_worker_filenames,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)

        train_hooks = [
            habana_hooks.PerfLoggingHook(batch_size=train_batch_size,
                                         mode="train")
        ]
        if horovod_enabled():
            train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        if "range" == os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower():
            from TensorFlow.common.horovod_helpers import SynapseLoggerHook
            begin = 30
            end = begin + 10
            print("Begin: {}".format(begin))
            print("End: {}".format(end))
            train_hooks.append(
                SynapseLoggerHook(list(range(begin, end)), False))

        estimator.train(input_fn=train_input_fn,
                        max_steps=num_train_steps,
                        hooks=train_hooks)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(PaddingInputExample())

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, eval_file)

        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info(
            "  Num examples = %d (%d actual, %d padding)", len(eval_examples),
            num_actual_eval_examples,
            len(eval_examples) - num_actual_eval_examples)
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        eval_hooks = [
            habana_hooks.PerfLoggingHook(batch_size=FLAGS.eval_batch_size,
                                         mode="eval")
        ]
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=eval_steps,
                                    hooks=eval_hooks)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, predict_file)

        tf.compat.v1.logging.info("***** Running prediction*****")
        tf.compat.v1.logging.info(
            "  Num examples = %d (%d actual, %d padding)",
            len(predict_examples), num_actual_predict_examples,
            len(predict_examples) - num_actual_predict_examples)
        tf.compat.v1.logging.info("  Batch size = %d",
                                  FLAGS.predict_batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        with tf.io.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.compat.v1.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
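Under Horovod, the training examples are split into contiguous, near-equal slices, with the first `remainder` ranks taking one extra example. A standalone check that this partition covers every example exactly once (pure Python; the MRPC-sized 3668 is just an illustrative value):

def worker_slice(num_examples, num_workers, worker_id):
    per_rank = num_examples // num_workers
    remainder = num_examples % num_workers
    if worker_id < remainder:
        start = worker_id * (per_rank + 1)
        end = start + per_rank + 1
    else:
        start = worker_id * per_rank + remainder
        end = start + per_rank
    return start, end

num_examples, num_workers = 3668, 8
slices = [worker_slice(num_examples, num_workers, w) for w in range(num_workers)]
assert slices[0][0] == 0 and slices[-1][1] == num_examples
assert all(a[1] == b[0] for a, b in zip(slices, slices[1:]))
print(slices)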
Example 12
 def __init__(self, summary_dir, batch_size, every_n_steps=100):
     super(SpeedHook, self).__init__(every_n_steps=every_n_steps,
                                     output_dir=summary_dir)
     self._total_batch_size = batch_size * hvd.size()
Example 13
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if horovod_enabled():
        FLAGS.output_dir = (FLAGS.output_dir if hvd_rank() == 0 else
                            os.path.join(FLAGS.output_dir, str(hvd_rank())))

    albert_config = modeling.AlbertConfig.from_json_file(
        FLAGS.albert_config_file)
    if FLAGS.deterministic_run and (albert_config.attention_probs_dropout_prob
                                    or albert_config.hidden_dropout_prob):
        albert_config.attention_probs_dropout_prob = 0.0
        albert_config.hidden_dropout_prob = 0.0

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    if FLAGS.use_horovod and len(input_files) < hvd.size():
        input_files = [input_files[0] for i in range(hvd.size())]

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)

    eval_files = []
    for eval_pattern in FLAGS.eval_file.split(","):
        eval_files.extend(tf.gfile.Glob(eval_pattern))

    if FLAGS.use_horovod and len(eval_files) < hvd.size():
        eval_files = [eval_files[0] for i in range(hvd.size())]

    tf.logging.info("*** Eval Files ***")
    for eval_file in eval_files:
        tf.logging.info("  %s" % eval_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        save_summary_steps=FLAGS.save_summary_steps,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    num_train_steps = FLAGS.num_train_steps
    num_warmup_steps = FLAGS.num_warmup_steps
    if FLAGS.do_train and horovod_enabled():
        num_train_steps //= hvd_size()
        num_warmup_steps //= hvd_size()

    model_fn = model_fn_builder(
        albert_config=albert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate
        if not FLAGS.use_horovod else FLAGS.learning_rate * hvd_size(),
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        optimizer=FLAGS.optimizer,
        poly_power=FLAGS.poly_power,
        start_warmup_step=FLAGS.start_warmup_step,
        use_einsum=FLAGS.use_einsum)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    write_hparams_v1(
        FLAGS.output_dir, {
            'batch_size': FLAGS.train_batch_size,
            'batch_size_per_pu': FLAGS.train_batch_size,
            **{x: getattr(FLAGS, x)
               for x in FLAGS}
        })

    if FLAGS.do_train:
        training_hooks = []
        if horovod_enabled():
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)
        with dump_callback():
            estimator.train(input_fn=train_input_fn,
                            hooks=training_hooks,
                            max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval and (not FLAGS.use_horovod or hvd_rank() == 0):
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
        global_step = -1
        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        writer = tf.gfile.GFile(output_eval_file, "w")
        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)
        best_perf = 0
        key_name = "masked_lm_accuracy"
        while global_step < FLAGS.num_train_steps:
            if estimator.latest_checkpoint() is None:
                tf.logging.info("No checkpoint found yet. Sleeping.")
                time.sleep(1)
            else:
                result = estimator.evaluate(input_fn=eval_input_fn,
                                            steps=FLAGS.max_eval_steps)
                global_step = result["global_step"]
                tf.logging.info("***** Eval results *****")
                checkpoint_path = estimator.latest_checkpoint()
                for key in sorted(result.keys()):
                    tf.logging.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
                    if result[key_name] > best_perf:
                        best_perf = result[key_name]
                        for ext in ["meta", "data-00000-of-00001", "index"]:
                            src_ckpt = checkpoint_path + ".{}".format(ext)
                            tgt_ckpt = checkpoint_path.rsplit(
                                "-", 1)[0] + "-best.{}".format(ext)
                            tf.logging.info("saving {} to {}".format(
                                src_ckpt, tgt_ckpt))
                            tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
                            writer.write("saved {} to {}\n".format(
                                src_ckpt, tgt_ckpt))
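The eval loop keeps the best checkpoint by copying the three files that make up a TF1 checkpoint (.meta, .index, .data-00000-of-00001) under a "-best" name whenever the tracked metric improves. A simplified local-filesystem stand-in for that tf.gfile.Copy loop (hypothetical helper name, local paths only):

import shutil

def copy_best_checkpoint(checkpoint_path):
    """Copy 'model.ckpt-1234.*' to 'model.ckpt-best.*' on a local filesystem."""
    for ext in ("meta", "data-00000-of-00001", "index"):
        src = "{}.{}".format(checkpoint_path, ext)
        dst = "{}-best.{}".format(checkpoint_path.rsplit("-", 1)[0], ext)
        shutil.copyfile(src, dst)

# e.g. copy_best_checkpoint("/tmp/albert_pretrain/model.ckpt-2000")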
Example 14
def main(argv):
  tf.disable_v2_behavior()
  tf.enable_resource_variables()

  if FLAGS.use_hpu and FLAGS.recipe_cache:
    prepare_recipe_cache()

  if FLAGS.use_horovod:
    if FLAGS.use_hpu:
      from TensorFlow.common.horovod_helpers import hvd_init, horovod_enabled, hvd
      hvd_init()
      assert horovod_enabled()
      if FLAGS.recipe_cache:
        # Other ranks should wait for recipe cache to be removed.
        # This operation can't be done before hvd_init.
        from mpi4py import MPI
        MPI.COMM_WORLD.Barrier()
    else:
      import horovod.tensorflow as hvd
      hvd.init()
      assert hvd.size() > 1
      os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())

  if FLAGS.use_hpu:
    if FLAGS.use_bf16:
      os.environ['TF_BF16_CONVERSION'] = FLAGS.bf16_config_path

    dyn_shapes_flag = 'TF_ENABLE_DYNAMIC_SHAPES'
    if dyn_shapes_flag not in os.environ:
        os.environ[dyn_shapes_flag] = 'false'

    from habana_frameworks.tensorflow import load_habana_module  # noqa
    load_habana_module()

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  if FLAGS.schedule != "run_std_server":
    hparams = create_hparams()
  if FLAGS.gpu_automatic_mixed_precision:
    setattr(hparams, "gpu_automatic_mixed_precision", True)
  if FLAGS.deterministic_dataset:
    hparams.add_hparam("deterministic_dataset", True)

  hparams.add_hparam("use_horovod", FLAGS.use_horovod)
  hparams.add_hparam("use_hpu", FLAGS.use_hpu)
  if FLAGS.use_horovod:
    hparams.add_hparam("hvd_worker_id", hvd.rank())
    hparams.add_hparam("hvd_size", hvd.size())

  if FLAGS.schedule == "run_std_server":
    run_std_server()
  trainer_lib.set_random_seed(FLAGS.random_seed)

  if FLAGS.generate_data:
    generate_data()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)

  with dump_callback():
    execute_schedule(exp)
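For the plain GPU path, the script follows the standard Horovod recipe: hvd.init(), pin one device per process, scale the learning rate by the world size, wrap the optimizer, and broadcast the initial variables from rank 0. A minimal sketch of that recipe in TF1-style graph mode (assumes horovod.tensorflow and tensorflow.compat.v1 are available; this is not the t2t trainer itself):

import os

import numpy as np
import horovod.tensorflow as hvd
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
hvd.init()
os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())  # one GPU per process

x = tf.placeholder(tf.float32, [None, 1])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.get_variable('w', [1, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# Scale the LR with the number of workers and average gradients via allreduce.
opt = tf.train.GradientDescentOptimizer(0.01 * hvd.size())
opt = hvd.DistributedOptimizer(opt)
train_op = opt.minimize(loss)

hooks = [hvd.BroadcastGlobalVariablesHook(0)]  # sync initial weights from rank 0
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
    data = np.random.rand(32, 1).astype('float32')
    for _ in range(10):
        sess.run(train_op, feed_dict={x: data, y: 2 * data})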
Example 15
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  if FLAGS.deterministic_run and (albert_config.attention_probs_dropout_prob or albert_config.hidden_dropout_prob):
        albert_config.attention_probs_dropout_prob = 0.0
        albert_config.hidden_dropout_prob = 0.0

  validate_flags_or_throw(albert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)
  model_dir = FLAGS.output_dir
  if horovod_enabled():
    model_dir = os.path.join(FLAGS.output_dir, "worker_" + str(hvd.rank()))

  tokenizer = fine_tuning_utils.create_vocab(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file,
      hub_module=FLAGS.albert_hub_module_handle)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
  if FLAGS.do_train:
    iterations_per_loop = int(min(FLAGS.iterations_per_loop,
                                  FLAGS.save_checkpoints_steps))
  else:
    iterations_per_loop = FLAGS.iterations_per_loop

  # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
  if FLAGS.enable_scoped_allocator:
    from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

    session_config = tf.compat.v1.ConfigProto()
    session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

    enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
    del enable_op[:]
    enable_op.append("HorovodAllreduce")
  else:
    session_config = None

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=model_dir,
      keep_checkpoint_max=0,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      save_summary_steps=FLAGS.save_summary_steps,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host),
      session_config=session_config)

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None

  train_batch_size = FLAGS.train_batch_size
  if horovod_enabled():
    train_batch_size = train_batch_size * hvd.size()

  if FLAGS.do_train:
    train_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / train_batch_size * FLAGS.num_train_epochs)
    if FLAGS.train_steps > 0:
      num_train_steps = FLAGS.train_steps
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

  start_index = 0
  end_index = len(train_examples)
  per_worker_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]
  worker_id = 0

  if horovod_enabled():
    per_worker_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record_{}".format(i)) for i in range(hvd.local_size())]
    num_examples_per_rank = len(train_examples) // hvd.size()
    remainder = len(train_examples) % hvd.size()
    worker_id = hvd.rank()
    if worker_id < remainder:
      start_index = worker_id * (num_examples_per_rank + 1)
      end_index = start_index + num_examples_per_rank + 1
    else:
      start_index = worker_id * num_examples_per_rank + remainder
      end_index = start_index + (num_examples_per_rank)

  learning_rate = FLAGS.learning_rate

  model_fn = squad_utils.v1_model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      use_einsum=FLAGS.use_einsum,
      hub_module=FLAGS.albert_hub_module_handle)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  write_hparams_v1(FLAGS.output_dir, {
    'batch_size': FLAGS.train_batch_size,
    **{x: getattr(FLAGS, x) for x in FLAGS}
  })

  if FLAGS.do_train:
    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    tf.logging.info("  Num steps = %d", num_train_steps)
    tf.logging.info("  Per-worker batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Total batch size = %d", train_batch_size)

    ## use pre-generated tf_record as input
    if FLAGS.input_file:
      if horovod_enabled():
        per_worker_filenames_temp = [os.path.join(FLAGS.input_file, "train.tf_record") for i in range(hvd.local_size())]
      else:
        per_worker_filenames_temp = [os.path.join(FLAGS.input_file, "train.tf_record")]

      if tf.gfile.Exists(per_worker_filenames_temp[hvd.local_rank() if horovod_enabled() else worker_id]):
        per_worker_filenames = per_worker_filenames_temp

    if not tf.gfile.Exists(per_worker_filenames[hvd.local_rank() if horovod_enabled() else worker_id]):
      train_writer = squad_utils.FeatureWriter(
          filename=per_worker_filenames[hvd.local_rank() if horovod_enabled() else worker_id], is_training=True)
      squad_utils.convert_examples_to_features(
          examples=train_examples[start_index:end_index],
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=True,
          output_fn=train_writer.process_feature,
          do_lower_case=FLAGS.do_lower_case)
      tf.logging.info("  Num split examples = %d", train_writer.num_features)
      train_writer.close()

      del train_examples

    train_input_fn = squad_utils.input_fn_builder(
        input_file=per_worker_filenames,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.train_batch_size,
        is_v2=False)

    train_hooks = [habana_hooks.PerfLoggingHook(batch_size=train_batch_size, mode="train")]
    if horovod_enabled():
      train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

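    # When HABANA_SYNAPSE_LOGGER=range, capture Synapse logs only for a fixed
    # window of training steps (670-679 here).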
    if "range" == os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower():
      from habana_frameworks.tensorflow.synapse_logger_helpers import SynapseLoggerHook
      begin = 670
      end = begin + 10
      print("Begin: {}".format(begin))
      print("End: {}".format(end))
      train_hooks.append(SynapseLoggerHook(list(range(begin, end)), False))

    with dump_callback():
      estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=train_hooks)

  if FLAGS.do_predict:
    with tf.gfile.Open(FLAGS.predict_file) as predict_file:
      prediction_json = json.load(predict_file)["data"]

    eval_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    eval_writer = squad_utils.FeatureWriter(
        filename=os.path.join(model_dir, "eval.tf_record"), is_training=False)
    eval_features = []

    def append_feature(feature):
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    squad_utils.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature,
        do_lower_case=FLAGS.do_lower_case)
    eval_writer.close()

    with tf.gfile.Open(os.path.join(model_dir, "eval_left.tf_record"), "wb") as fout:
      pickle.dump(eval_features, fout)

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = squad_utils.input_fn_builder(
        input_file=os.path.join(model_dir, "eval.tf_record"),
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.predict_batch_size,
        is_v2=False)

    eval_hooks = [habana_hooks.PerfLoggingHook(batch_size=FLAGS.predict_batch_size, mode="eval")]

    def get_result(checkpoint):
      """Evaluate the checkpoint on SQuAD 1.0."""
      # If running eval on the TPU, you will need to specify the number of
      # steps.
      reader = tf.train.NewCheckpointReader(checkpoint)
      global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
      all_results = []
      for result in estimator.predict(
          predict_input_fn, yield_single_examples=True,
          checkpoint_path=checkpoint, hooks=eval_hooks):
        if len(all_results) % 1000 == 0:
          tf.logging.info("Processing example: %d" % (len(all_results)))
        unique_id = int(result["unique_ids"])
        start_log_prob = [float(x) for x in result["start_log_prob"].flat]
        end_log_prob = [float(x) for x in result["end_log_prob"].flat]
        all_results.append(
            squad_utils.RawResult(
                unique_id=unique_id,
                start_log_prob=start_log_prob,
                end_log_prob=end_log_prob))

      output_prediction_file = os.path.join(
          model_dir, "predictions.json")
      output_nbest_file = os.path.join(
          model_dir, "nbest_predictions.json")

      result_dict = {}
      squad_utils.accumulate_predictions_v1(
          result_dict, eval_examples, eval_features,
          all_results, FLAGS.n_best_size, FLAGS.max_answer_length)
      predictions = squad_utils.write_predictions_v1(
          result_dict, eval_examples, eval_features, all_results,
          FLAGS.n_best_size, FLAGS.max_answer_length,
          output_prediction_file, output_nbest_file)

      return squad_utils.evaluate_v1(
          prediction_json, predictions), int(global_step)

    def _find_valid_cands(curr_step):
      filenames = tf.gfile.ListDirectory(model_dir)
      candidates = []
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          idx = ckpt_name.split("-")[-1]
          if idx != "best" and int(idx) > curr_step:
            candidates.append(filename)
      return candidates

    output_eval_file = os.path.join(model_dir, "eval_results.txt")
    checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
    key_name = "f1"
    writer = tf.gfile.GFile(output_eval_file, "w")
    if tf.gfile.Exists(checkpoint_path + ".index"):
      result = get_result(checkpoint_path)
      exact_match = result[0]["exact_match"]
      f1 = result[0]["f1"]
      with TBSummary(os.path.join(model_dir, 'eval')) as summary_writer:
          summary_writer.add_scalar('f1', f1, 0)
          summary_writer.add_scalar('exact_match', exact_match, 0)
      best_perf = result[0][key_name]
      global_step = result[1]
    else:
      global_step = -1
      best_perf = -1
      checkpoint_path = None
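    # Poll model_dir for new checkpoints, evaluate each one on SQuAD, keep the
    # checkpoint with the best F1 as model.ckpt-best, and prune stale checkpoints.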
    while global_step < num_train_steps:
      steps_and_files = {}
      filenames = tf.gfile.ListDirectory(model_dir)
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          cur_filename = os.path.join(model_dir, ckpt_name)
          if cur_filename.split("-")[-1] == "best":
            continue
          gstep = int(cur_filename.split("-")[-1])
          if gstep not in steps_and_files:
            tf.logging.info("Add {} to eval list.".format(cur_filename))
            steps_and_files[gstep] = cur_filename
      tf.logging.info("found {} files.".format(len(steps_and_files)))
      if not steps_and_files:
        tf.logging.info("found 0 file, global step: {}. Sleeping."
                        .format(global_step))
        time.sleep(60)
      else:
        for ele in sorted(steps_and_files.items()):
          step, checkpoint_path = ele
          if global_step >= step:
            if len(_find_valid_cands(step)) > 1:
              for ext in ["meta", "data-00000-of-00001", "index"]:
                src_ckpt = checkpoint_path + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)
            continue
          result, global_step = get_result(checkpoint_path)
          exact_match = result["exact_match"]
          f1 = result["f1"]
          with TBSummary(os.path.join(model_dir, 'eval')) as summary_writer:
            summary_writer.add_scalar('f1', f1, 0)
            summary_writer.add_scalar('exact_match', exact_match, 0)
          tf.logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tgt_ckpt = checkpoint_path.rsplit(
                  "-", 1)[0] + "-best.{}".format(ext)
              tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
              tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
              writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
          writer.write("best {} = {}\n".format(key_name, best_perf))
          tf.logging.info("  best {} = {}\n".format(key_name, best_perf))

          if len(_find_valid_cands(global_step)) > 2:
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tf.logging.info("removing {}".format(src_ckpt))
              tf.gfile.Remove(src_ckpt)
          writer.write("=" * 50 + "\n")

    checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
    result, global_step = get_result(checkpoint_path)
    tf.logging.info("***** Final Eval results *****")
    for key in sorted(result.keys()):
      tf.logging.info("  %s = %s", key, str(result[key]))
      writer.write("%s = %s\n" % (key, str(result[key])))
    writer.write("best perf happened at step: {}".format(global_step))

  if FLAGS.export_dir:
    tf.gfile.MakeDirs(FLAGS.export_dir)
    squad_serving_input_fn = (
        build_squad_serving_input_fn(FLAGS.max_seq_length))
    tf.logging.info("Starting to export model.")
    subfolder = estimator.export_saved_model(
        export_dir_base=os.path.join(FLAGS.export_dir, "saved_model"),
        serving_input_receiver_fn=squad_serving_input_fn)

    tf.logging.info("Starting to export TFLite.")
    converter = tf.lite.TFLiteConverter.from_saved_model(
        subfolder,
        input_arrays=["input_ids", "input_mask", "segment_ids"],
        output_arrays=["start_logits", "end_logits"])
    float_model = converter.convert()
    tflite_file = os.path.join(FLAGS.export_dir, "albert_model.tflite")
    with tf.gfile.GFile(tflite_file, "wb") as f:
      f.write(float_model)
Ejemplo n.º 16
0
load_habana_module()
hvd.init()

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import time

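# Local rank 0 downloads MNIST first; the dummy hvd.broadcast(0, 0) acts as a
# barrier so the remaining ranks only read the already-cached copy.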
if hvd.local_rank() == 0:
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    hvd.broadcast(0, 0)
else:
    hvd.broadcast(0, 0)
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

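# Give each rank an equally sized, contiguous slice of the training set.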
num_pics_per_rank = x_train.shape[0] // hvd.size()
pic_begin = num_pics_per_rank * hvd.rank()
pic_end = pic_begin + num_pics_per_rank
x_train = x_train[pic_begin:pic_end, ]
y_train = y_train[pic_begin:pic_end, ]

x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(32)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
Ejemplo n.º 17
0
        assert(hvd.is_initialized())
        os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())

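    # With a recipe cache in use under Horovod, wait for all ranks here before
    # proceeding to training.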
    if ARGS.recipe_cache and ARGS.use_horovod:
        MPI.COMM_WORLD.Barrier()

    tf.disable_eager_execution()
    tf.logging.set_verbosity(tf.logging.INFO)

    train_hooks = []
    global_batch_size = ARGS.batch_size
    distributed_optimizer = None
    model_dir = ARGS.model_dir

    if ARGS.use_horovod:
        global_batch_size = ARGS.batch_size * hvd.size()
        num_shards = hvd.size()
        shard_index = hvd.rank()
        train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
        distributed_optimizer = hvd.DistributedOptimizer

        if hvd.rank() > 0:
            model_dir = os.path.join(
                ARGS.model_dir, 'worker_' + str(hvd.rank()))

    run_config, params = construct_run_config(
        global_batch_size, model_dir, distributed_optimizer, num_shards, shard_index)

    tf.logging.info('steps_per_epoch: %s' % params['steps_per_epoch'])

    if ARGS.mode == 'train':
Ejemplo n.º 18
0
def main(_):
    init_bert_flags()
    os.environ[
        "TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"  #causes memory fragmentation for bert leading to OOM

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if horovod_enabled() and hvd_rank() != 0:
        FLAGS.output_dir = os.path.join(FLAGS.output_dir, str(hvd_rank()))

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.io.gfile.makedirs(FLAGS.output_dir)

    input_files = []
    for input_file_dir in FLAGS.input_files_dir.split(","):
        input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

    if FLAGS.horovod and len(input_files) < hvd.size():
        input_files = [input_files[0] for i in range(hvd.size())]

    if FLAGS.amp and FLAGS.manual_fp16:
        raise ValueError(
            "AMP and Manual Mixed Precision Training are both activated! Error"
        )

    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host))

    num_train_steps = FLAGS.num_train_steps
    num_warmup_steps = FLAGS.num_warmup_steps
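    # Each Horovod worker runs 1/world_size of the configured steps, so the total
    # number of examples processed stays roughly the same as single-card training.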
    if FLAGS.do_train and horovod_enabled():
        num_train_steps //= hvd_size()
        num_warmup_steps //= hvd_size()

    model_fn = model_fn_builder(bert_config=bert_config,
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate
                                if not FLAGS.horovod else FLAGS.learning_rate *
                                hvd_size(),
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_one_hot_embeddings=False)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=False,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:

        training_hooks = []
        if horovod_enabled():
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
        if (not FLAGS.horovod or hvd_rank() == 0):
            global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
            if FLAGS.horovod:
                global_batch_size *= hvd.size()
            training_hooks.append(
                _LogSessionRunHook(global_batch_size,
                                   FLAGS.num_accumulation_steps, dllogging,
                                   FLAGS.display_loss_steps,
                                   FLAGS.save_checkpoints_steps,
                                   FLAGS.report_loss))

        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        hooks=training_hooks,
                        max_steps=num_train_steps)
        train_time_elapsed = time.time() - train_start_time

        if (not FLAGS.horovod or hvd_rank() == 0):
            train_time_wo_overhead = training_hooks[-1].total_time
            avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
            ss_sentences_per_second = (
                num_train_steps - training_hooks[-1].skipped
            ) * global_batch_size * 1.0 / train_time_wo_overhead

            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info(
                "Total Training Time = %0.2f for Sentences = %d",
                train_time_elapsed, num_train_steps * global_batch_size)
            tf.compat.v1.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (num_train_steps - training_hooks[-1].skipped) *
                global_batch_size)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) = %0.2f",
                ss_sentences_per_second)
            dllogging.logger.log(
                step=(),
                data={"throughput_train": ss_sentences_per_second},
                verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and (not FLAGS.horovod or hvd_rank() == 0):
        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_files = []
        for eval_file_dir in FLAGS.eval_files_dir.split(","):
            eval_files.extend(
                tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
        eval_start_time = time.time()
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps,
                                    hooks=eval_hooks)

        eval_time_elapsed = time.time() - eval_start_time
        time_list = eval_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
        num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size

        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info(
            "Total Inference Time = %0.2f for Sentences = %d",
            eval_time_elapsed, eval_hooks[-1].count * FLAGS.eval_batch_size)
        tf.compat.v1.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            eval_time_wo_overhead, num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s",
                                  "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f",
                                  ss_sentences_per_second)
        dllogging.logger.log(step=(),
                             data={"throughput_val": ss_sentences_per_second},
                             verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Ejemplo n.º 19
0
def main(_):
    os.environ[
        "TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"  #causes memory fragmentation for bert leading to OOM

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # In the multi-node scenario, each HLS machine must have a checkpoint directly in output_dir (read by Phase 2).
    # Exactly one worker per machine has comm_local_rank() == 0, and that worker writes its checkpoints there.
    # All other workers keep their checkpoints in sub-directories.
    if horovod_enabled() and comm_local_rank() != 0:
        FLAGS.output_dir = os.path.join(FLAGS.output_dir, str(hvd_rank()))

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.io.gfile.makedirs(FLAGS.output_dir)

    input_files = []
    for input_file_dir in FLAGS.input_files_dir.split(","):
        input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

    if FLAGS.horovod and len(input_files) < hvd.size():
        tf.compat.v1.logging.warning(
            "Input files count lower then expected. Using single file for OVERFIT test."
        )
        input_files = [input_files[0] for i in range(hvd.size())]
    if FLAGS.amp and FLAGS.manual_fp16:
        raise ValueError(
            "AMP and Manual Mixed Precision Training are both activated! Error"
        )

    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2

    # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
    if condition_env_var('TF_DISABLE_SCOPED_ALLOCATOR', default=False):
        session_config = tf.compat.v1.ConfigProto()
    else:
        from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

        session_config = tf.compat.v1.ConfigProto()
        session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

        enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
        del enable_op[:]
        enable_op.append("HorovodAllreduce")

    if FLAGS.horovod:
        session_config.gpu_options.visible_device_list = str(hvd.local_rank())
        if hvd.rank() == 0:
            tf.compat.v1.logging.info("***** Configuaration *****")
            for key in FLAGS.__flags.keys():
                tf.compat.v1.logging.info('  {}: {}'.format(
                    key, getattr(FLAGS, key)))
            tf.compat.v1.logging.info("**************************")


#    config.gpu_options.per_process_gpu_memory_fraction = 0.7
    if FLAGS.use_xla:
        session_config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        session_config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
        if FLAGS.amp:
            tf.compat.v1.enable_resource_variables()

    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,
        session_config=session_config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        save_summary_steps=FLAGS.save_checkpoints_steps
        if not FLAGS.horovod else None,
        log_step_count_steps=1)

    model_fn = model_fn_builder(bert_config=bert_config,
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate
                                if not FLAGS.horovod else FLAGS.learning_rate *
                                hvd_size(),
                                num_train_steps=FLAGS.num_train_steps,
                                num_warmup_steps=FLAGS.num_warmup_steps,
                                use_one_hot_embeddings=False)

    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if FLAGS.do_train:

        training_hooks = []
        if horovod_enabled():
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

            if os.environ.get("FORCE_WEIGHT_SYNC",
                              "False").lower() in ["true", "1"]:
                # Use this hook to allreduce trainable variables before the optimizer run
                training_hooks.append(
                    TrainableVarsAllreducingHookPreOpt(
                        FLAGS.num_accumulation_steps))

        global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
        if FLAGS.horovod:
            global_batch_size *= hvd.size()
        training_hooks.append(
            _LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps,
                               dllogging, FLAGS.display_loss_steps,
                               FLAGS.save_checkpoints_steps,
                               FLAGS.report_loss))

        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            batch_size=FLAGS.train_batch_size,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        hooks=training_hooks,
                        max_steps=FLAGS.num_train_steps)
        train_time_elapsed = time.time() - train_start_time

        if (not FLAGS.horovod or hvd_rank() == 0):
            train_time_wo_overhead = training_hooks[-1].total_time
            avg_sentences_per_second = FLAGS.num_train_steps * global_batch_size * 1.0 / train_time_elapsed
            try:
                ss_sentences_per_second = (
                    FLAGS.num_train_steps - training_hooks[-1].skipped
                ) * global_batch_size * 1.0 / train_time_wo_overhead
                throughput_avg_wo_overhead_msg = [
                    "Throughput Average (sentences/sec) = %0.2f",
                    ss_sentences_per_second
                ]
            except:
                ss_sentences_per_second = float('nan')
                throughput_avg_wo_overhead_msg = [
                    f"Throughput Average W/O Overhead is not logged when num_train_steps < {training_hooks[-1].skip_iters}"
                ]

            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info(
                "Total Training Time = %0.2f for Sentences = %d",
                train_time_elapsed, FLAGS.num_train_steps * global_batch_size)
            tf.compat.v1.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (FLAGS.num_train_steps - training_hooks[-1].skipped) *
                global_batch_size)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.compat.v1.logging.info(*throughput_avg_wo_overhead_msg)
            dllogging.logger.log(
                step=(),
                data={"throughput_train": ss_sentences_per_second},
                verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and (not FLAGS.horovod or hvd_rank() == 0):
        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_files = []
        for eval_file_dir in FLAGS.eval_files_dir.split(","):
            eval_files.extend(
                tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            batch_size=FLAGS.eval_batch_size,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
        eval_start_time = time.time()
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps,
                                    hooks=eval_hooks)

        eval_time_elapsed = time.time() - eval_start_time
        time_list = eval_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
        num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size

        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info(
            "Total Inference Time = %0.2f for Sentences = %d",
            eval_time_elapsed, eval_hooks[-1].count * FLAGS.eval_batch_size)
        tf.compat.v1.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            eval_time_wo_overhead, num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s",
                                  "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f",
                                  ss_sentences_per_second)
        dllogging.logger.log(step=(),
                             data={"throughput_val": ss_sentences_per_second},
                             verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Ejemplo n.º 20
0
def input_fn(is_training,
             data_dir,
             batch_size,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False,
             tf_data_experimental_slack=False,
             training_dataset_cache=False,
             filenames=None,
             experimental_preloading=False,
             use_distributed_eval=False):
    """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    dtype: Data type to use for images/features
    datasets_num_private_threads: Number of private threads for tf.data.
    parse_record_fn: Function to use for parsing the records.
    input_context: A `tf.distribute.InputContext` object passed in by
      `tf.distribute.Strategy`.
    drop_remainder: A boolean indicates whether to drop the remainder of the
      batches. If True, the batch dimension will be static.
    tf_data_experimental_slack: Whether to enable tf.data's
      `experimental_slack` option.
    training_dataset_cache: Whether to cache the training dataset on workers.
       Typically used to improve training performance when training data is in
       remote storage and can fit into worker memory.
    filenames: Optional field for providing the file names of the TFRecords.
    experimental_preloading: Whether to enable experimental data preloading;
      forwarded to `process_record_dataset`.
    use_distributed_eval: Whether to also shard the eval dataset across
      Horovod workers.

  Returns:
    A dataset that can be used for iteration.
  """
    if filenames is None:
        filenames = get_filenames(is_training, data_dir)
    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    if horovod_enabled() and (is_training or use_distributed_eval):
        logging.info(
            'HVD sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
            hvd.rank(), hvd.size())
        dataset = dataset.shard(hvd.size(), hvd.rank())

    if input_context:
        logging.info(
            'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
            input_context.input_pipeline_id, input_context.num_input_pipelines)
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)

    if is_training:
        # Shuffle the input files
        dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES,
                                  seed=flags.FLAGS.shuffle_seed)

    # Convert to individual records.
    # cycle_length = 10 means that up to 10 files will be read and deserialized in
    # parallel. You may want to increase this number if you have a large number of
    # CPU cores.
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=flags.FLAGS.interleave_cycle_length,
        num_parallel_calls=flags.FLAGS.dataset_parallel_calls)

    if is_training and training_dataset_cache:
        # Improve training performance when training data is in remote storage and
        # can fit into worker memory.
        dataset = dataset.cache()

    return process_record_dataset(
        dataset=dataset,
        is_training=is_training,
        batch_size=batch_size,
        shuffle_buffer=_SHUFFLE_BUFFER,
        parse_record_fn=parse_record_fn,
        dtype=dtype,
        datasets_num_private_threads=datasets_num_private_threads,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=tf_data_experimental_slack,
        experimental_preloading=experimental_preloading)
Ejemplo n.º 21
0
  def __call__(self, params=None):
    if params is None:
      params = self._params
    input_anchors = anchors.Anchors(params['min_level'], params['max_level'],
                                    params['num_scales'],
                                    params['aspect_ratios'],
                                    params['anchor_scale'],
                                    params['image_size'])
    anchor_labeler = anchors.AnchorLabeler(input_anchors, params['num_classes'])
    example_decoder = tf_example_decoder.TfExampleDecoder()

    def _dataset_parser(value):
      """Parse data to a fixed dimension input image and learning targets.

      Args:
        value: A dictionary contains an image and groundtruth annotations.

      Returns:
        image: Image tensor that is preprocessed to have normalized value and
          fixed dimension [image_size, image_size, 3]
        cls_targets_dict: ordered dictionary with keys
          [min_level, min_level+1, ..., max_level]. The values are tensor with
          shape [height_l, width_l, num_anchors]. The height_l and width_l
          represent the dimension of class logits at l-th level.
        box_targets_dict: ordered dictionary with keys
          [min_level, min_level+1, ..., max_level]. The values are tensor with
          shape [height_l, width_l, num_anchors * 4]. The height_l and
          width_l represent the dimension of bounding box regression output at
          l-th level.
        num_positives: Number of positive anchors in the image.
        source_id: Source image id. Default value -1 if the source id is empty
          in the groundtruth annotation.
        image_scale: Scale of the processed image to the original image.
        boxes: Groundtruth bounding box annotations. The box is represented in
          [y1, x1, y2, x2] format. The tensor is padded with -1 to the fixed
          dimension [self._max_num_instances, 4].
        is_crowds: Groundtruth annotations to indicate if an annotation
          represents a group of instances by value {0, 1}. The tensor is
          padded with 0 to the fixed dimension [self._max_num_instances].
        areas: Groundtruth areas annotations. The tensor is padded with -1
          to the fixed dimension [self._max_num_instances].
        classes: Groundtruth classes annotations. The tensor is padded with -1
          to the fixed dimension [self._max_num_instances].
      """
      with tf.name_scope('parser'):
        data = example_decoder.decode(value)
        source_id = data['source_id']
        image = data['image']
        boxes = data['groundtruth_boxes']
        classes = data['groundtruth_classes']
        classes = tf.reshape(tf.cast(classes, dtype=tf.float32), [-1, 1])
        areas = data['groundtruth_area']
        is_crowds = data['groundtruth_is_crowd']

        if params['skip_crowd_during_training'] and self._is_training:
          indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
          classes = tf.gather_nd(classes, indices)
          boxes = tf.gather_nd(boxes, indices)

        # NOTE: The autoaugment method works best when used alongside the
        # standard horizontal flipping of images along with size jittering
        # and normalization.
        if params.get('autoaugment_policy', None) and self._is_training:
          from aug import autoaugment  # pylint: disable=g-import-not-at-top
          image, boxes = autoaugment.distort_image_with_autoaugment(
              image, boxes, params['autoaugment_policy'])

        input_processor = DetectionInputProcessor(
            image, params['image_size'], boxes, classes)
        input_processor.normalize_image()
        if self._is_training and params['input_rand_hflip']:
          input_processor.random_horizontal_flip()
        if self._is_training:
          input_processor.set_training_random_scale_factors(
              params['train_scale_min'], params['train_scale_max'])
        else:
          input_processor.set_scale_factors_to_output_size()
        image = input_processor.resize_and_crop_image()
        boxes, classes = input_processor.resize_and_crop_boxes()

        # Assign anchors.
        (cls_targets, box_targets,
         num_positives) = anchor_labeler.label_anchors(boxes, classes)

        source_id = tf.where(tf.equal(source_id, tf.constant('')), '-1',
                             source_id)
        source_id = tf.string_to_number(source_id)

        # Pad groundtruth data for evaluation.
        image_scale = input_processor.image_scale_to_original
        boxes *= image_scale
        is_crowds = tf.cast(is_crowds, dtype=tf.float32)
        boxes = pad_to_fixed_size(boxes, -1, [self._max_num_instances, 4])
        is_crowds = pad_to_fixed_size(is_crowds, 0,
                                      [self._max_num_instances, 1])
        areas = pad_to_fixed_size(areas, -1, [self._max_num_instances, 1])
        classes = pad_to_fixed_size(classes, -1, [self._max_num_instances, 1])
        if params['use_bfloat16']:
          image = tf.cast(image, dtype=tf.bfloat16)
        return (image, cls_targets, box_targets, num_positives, source_id,
                image_scale, boxes, is_crowds, areas, classes)

    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)

    if horovod_enabled() and self._is_training:  # multi-card eval is not supported yet
      # Shard the dataset evenly according to the number of devices.
      dataset = dataset.shard(hvd.size(), hvd.rank())

    if self._is_training:
      dataset = dataset.repeat()

    # Prefetch data from files.
    def _prefetch_dataset(filename):
      dataset = tf.data.TFRecordDataset(filename).prefetch(1)
      return dataset

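    # Deterministic runs read files serially; otherwise interleave reads from up
    # to 32 files in parallel.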
    cycle_length = 1 if self._is_deterministic else 32
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            _prefetch_dataset, cycle_length=cycle_length, sloppy=self._is_training))
    if self._is_training:
      dataset = dataset.shuffle(64)

    # Parse the fetched records to input tensors for model function.
    num_parallel_calls = 1 if self._is_deterministic else 64
    dataset = dataset.map(_dataset_parser, num_parallel_calls=num_parallel_calls)
    batch_size = params['batch_size']
    dataset = dataset.prefetch(batch_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    def _process_example(images, cls_targets, box_targets, num_positives,
                         source_ids, image_scales, boxes, is_crowds, areas,
                         classes):
      """Processes one batch of data."""
      labels = {}
      # Count num_positives in a batch.
      num_positives_batch = tf.reduce_mean(num_positives)
      labels['mean_num_positives'] = tf.reshape(
          tf.tile(tf.expand_dims(num_positives_batch, 0), [
              batch_size,
          ]), [batch_size, 1])

      for level in range(params['min_level'], params['max_level'] + 1):
        labels['cls_targets_%d' % level] = cls_targets[level]
        labels['box_targets_%d' % level] = box_targets[level]
      # Concatenate groundtruth annotations to a tensor.
      groundtruth_data = tf.concat([boxes, is_crowds, areas, classes], axis=2)
      labels['source_ids'] = source_ids
      labels['groundtruth_data'] = groundtruth_data
      labels['image_scales'] = image_scales
      return images, labels

    dataset = dataset.map(_process_example)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    if self._use_fake_data:
      # Turn this dataset into a semi-fake dataset which always loop at the
      # first batch. This reduces variance in performance and is useful in
      # testing.
      dataset = dataset.take(1).cache().repeat()
    return dataset
Ejemplo n.º 22
0
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": classifier_utils.ColaProcessor,
        "mnli": classifier_utils.MnliProcessor,
        "mismnli": classifier_utils.MisMnliProcessor,
        "mrpc": classifier_utils.MrpcProcessor,
        "rte": classifier_utils.RteProcessor,
        "sst-2": classifier_utils.Sst2Processor,
        "sts-b": classifier_utils.StsbProcessor,
        "qqp": classifier_utils.QqpProcessor,
        "qnli": classifier_utils.QnliProcessor,
        "wnli": classifier_utils.WnliProcessor,
    }

    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict
            or FLAGS.export_dir):
        raise ValueError(
            "At least one of `do_train`, `do_eval`, `do_predict' or `export_dir` "
            "must be True.")

    if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle:
        raise ValueError("At least one of `--albert_config_file` and "
                         "`--albert_hub_module_handle` must be set")

    if FLAGS.albert_config_file:
        albert_config = modeling.AlbertConfig.from_json_file(
            FLAGS.albert_config_file)
        if FLAGS.max_seq_length > albert_config.max_position_embeddings:
            raise ValueError(
                "Cannot use sequence length %d because the ALBERT model "
                "was only trained up to sequence length %d" %
                (FLAGS.max_seq_length, albert_config.max_position_embeddings))
    else:
        albert_config = None  # Get the config from TF-Hub.

    if FLAGS.deterministic_run and (albert_config.attention_probs_dropout_prob
                                    or albert_config.hidden_dropout_prob):
        albert_config.attention_probs_dropout_prob = 0.0
        albert_config.hidden_dropout_prob = 0.0

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name](
        use_spm=True if FLAGS.spm_model_file else False,
        do_lower_case=FLAGS.do_lower_case)

    label_list = processor.get_labels()

    tokenizer = fine_tuning_utils.create_vocab(
        vocab_file=FLAGS.vocab_file,
        do_lower_case=FLAGS.do_lower_case,
        spm_model_file=FLAGS.spm_model_file,
        hub_module=FLAGS.albert_hub_module_handle)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    model_dir = FLAGS.output_dir
    if horovod_enabled():
        model_dir = os.path.join(FLAGS.output_dir, "worker_" + str(hvd.rank()))

    # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
    if FLAGS.enable_scoped_allocator:
        from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

        session_config = tf.compat.v1.ConfigProto()
        session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

        enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
        del enable_op[:]
        enable_op.append("HorovodAllreduce")
    else:
        session_config = None

    is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    if FLAGS.do_train:
        iterations_per_loop = int(
            min(FLAGS.iterations_per_loop, FLAGS.save_checkpoints_steps))
    else:
        iterations_per_loop = FLAGS.iterations_per_loop

    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=model_dir,
        save_checkpoints_steps=int(FLAGS.save_checkpoints_steps),
        keep_checkpoint_max=0,
        save_summary_steps=FLAGS.save_summary_steps,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host),
        session_config=session_config)

    train_examples = None

    train_batch_size = FLAGS.train_batch_size
    if horovod_enabled():
        train_batch_size = train_batch_size * hvd.size()

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)

        start_index = 0
        end_index = len(train_examples)
        worker_id = 0
        per_worker_filenames = [
            os.path.join(FLAGS.output_dir, "train.tf_record")
        ]

        if horovod_enabled():
            per_worker_filenames = [
                os.path.join(FLAGS.output_dir, "train.tf_record_{}".format(i))
                for i in range(hvd.size())
            ]
            num_examples_per_rank = len(train_examples) // hvd.size()
            remainder = len(train_examples) % hvd.size()
            worker_id = hvd.rank()
            if worker_id < remainder:
                start_index = worker_id * (num_examples_per_rank + 1)
                end_index = start_index + num_examples_per_rank + 1
            else:
                start_index = worker_id * num_examples_per_rank + remainder
                end_index = start_index + (num_examples_per_rank)

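    # Scale the learning rate linearly with the number of workers, the usual rule
    # for synchronous data-parallel training.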
    learning_rate = FLAGS.learning_rate
    if horovod_enabled():
        learning_rate = learning_rate * hvd.size()

    model_fn = classifier_utils.model_fn_builder(
        albert_config=albert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=learning_rate,
        num_train_steps=FLAGS.train_step,
        num_warmup_steps=FLAGS.warmup_step,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        task_name=task_name,
        hub_module=FLAGS.albert_hub_module_handle,
        optimizer=FLAGS.optimizer)

    if not math.isnan(FLAGS.threshold_to_export):
        model_fn = _add_threshold_to_model_fn(model_fn,
                                              FLAGS.threshold_to_export)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size,
        export_to_tpu=False)  # http://yaqs/4707241341091840

    write_hparams_v1(
        FLAGS.output_dir, {
            'batch_size': FLAGS.train_batch_size,
            **{x: getattr(FLAGS, x)
               for x in FLAGS}
        })

    if FLAGS.do_train:
        if FLAGS.deterministic_run and not horovod_enabled() and FLAGS.input_file:
            per_worker_filenames = [
                os.path.join(FLAGS.input_file, "train.tf_record")
            ]
        if not tf.gfile.Exists(per_worker_filenames[worker_id]):
            classifier_utils.file_based_convert_examples_to_features(
                train_examples[start_index:end_index], label_list,
                FLAGS.max_seq_length, tokenizer,
                per_worker_filenames[worker_id], task_name)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Per-worker batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Total batch size = %d", train_batch_size)
        tf.logging.info("  Num steps = %d", FLAGS.train_step)
        train_input_fn = classifier_utils.file_based_input_fn_builder(
            input_file=per_worker_filenames,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            task_name=task_name,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.train_batch_size)

        train_hooks = [
            habana_hooks.PerfLoggingHook(batch_size=train_batch_size,
                                         mode="train")
        ]
        if horovod_enabled():
            train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        if "range" == os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower():
            from habana_frameworks.tensorflow.synapse_logger_helpers import SynapseLoggerHook
            begin = 30
            end = begin + 10
            print("Begin: {}".format(begin))
            print("End: {}".format(end))
            train_hooks.append(
                SynapseLoggerHook(list(range(begin, end)), False))

        with dump_callback():
            estimator.train(input_fn=train_input_fn,
                            max_steps=FLAGS.train_step,
                            hooks=train_hooks)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(classifier_utils.PaddingInputExample())

        cached_dir = FLAGS.cached_dir
        if not cached_dir:
            cached_dir = FLAGS.output_dir
        eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record")
        if not tf.gfile.Exists(eval_file):
            classifier_utils.file_based_convert_examples_to_features(
                eval_examples, label_list, FLAGS.max_seq_length, tokenizer,
                eval_file, task_name)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = classifier_utils.file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder,
            task_name=task_name,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.eval_batch_size)

        eval_hooks = [
            habana_hooks.PerfLoggingHook(batch_size=FLAGS.eval_batch_size,
                                         mode="eval")
        ]
        best_trial_info_file = os.path.join(FLAGS.output_dir, "best_trial.txt")

        def _best_trial_info():
            """Returns information about which checkpoints have been evaled so far."""
            if tf.gfile.Exists(best_trial_info_file):
                with tf.gfile.GFile(best_trial_info_file, "r") as best_info:
                    global_step, best_metric_global_step, metric_value = (
                        best_info.read().split(":"))
                    global_step = int(global_step)
                    best_metric_global_step = int(best_metric_global_step)
                    metric_value = float(metric_value)
            else:
                metric_value = -1
                best_metric_global_step = -1
                global_step = -1
            tf.logging.info(
                "Best trial info: Step: %s, Best Value Step: %s, "
                "Best Value: %s", global_step, best_metric_global_step,
                metric_value)
            return global_step, best_metric_global_step, metric_value

        def _remove_checkpoint(checkpoint_path):
            for ext in ["meta", "data-00000-of-00001", "index"]:
                src_ckpt = checkpoint_path + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)

        def _find_valid_cands(curr_step):
            filenames = tf.gfile.ListDirectory(model_dir)
            candidates = []
            for filename in filenames:
                if filename.endswith(".index"):
                    ckpt_name = filename[:-6]
                    idx = ckpt_name.split("-")[-1]
                    if int(idx) > curr_step:
                        candidates.append(filename)
            return candidates

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")

        if task_name == "sts-b":
            key_name = "pearson"
        elif task_name == "cola":
            key_name = "matthew_corr"
        else:
            key_name = "eval_accuracy"

        global_step, best_perf_global_step, best_perf = _best_trial_info()
        writer = tf.gfile.GFile(output_eval_file, "w")
        while global_step < FLAGS.train_step:
            steps_and_files = {}
            filenames = tf.gfile.ListDirectory(model_dir)
            for filename in filenames:
                if filename.endswith(".index"):
                    ckpt_name = filename[:-6]
                    cur_filename = os.path.join(model_dir, ckpt_name)
                    if cur_filename.split("-")[-1] == "best":
                        continue
                    gstep = int(cur_filename.split("-")[-1])
                    if gstep not in steps_and_files:
                        tf.logging.info(
                            "Add {} to eval list.".format(cur_filename))
                        steps_and_files[gstep] = cur_filename
            tf.logging.info("found {} files.".format(len(steps_and_files)))
            if not steps_and_files:
                tf.logging.info(
                    "found 0 file, global step: {}. Sleeping.".format(
                        global_step))
                time.sleep(60)
            else:
                for checkpoint in sorted(steps_and_files.items()):
                    step, checkpoint_path = checkpoint
                    if global_step >= step:
                        if (best_perf_global_step != step
                                and len(_find_valid_cands(step)) > 1):
                            _remove_checkpoint(checkpoint_path)
                        continue
                    result = estimator.evaluate(
                        input_fn=eval_input_fn,
                        steps=eval_steps,
                        checkpoint_path=checkpoint_path,
                        hooks=eval_hooks)
                    global_step = result["global_step"]
                    tf.logging.info("***** Eval results *****")
                    for key in sorted(result.keys()):
                        tf.logging.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))
                    writer.write("best = {}\n".format(best_perf))
                    if result[key_name] > best_perf:
                        best_perf = result[key_name]
                        best_perf_global_step = global_step
                    elif len(_find_valid_cands(global_step)) > 1:
                        _remove_checkpoint(checkpoint_path)
                    writer.write("=" * 50 + "\n")
                    writer.flush()
                    with tf.gfile.GFile(best_trial_info_file,
                                        "w") as best_info:
                        best_info.write("{}:{}:{}".format(
                            global_step, best_perf_global_step, best_perf))
        writer.close()

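        # Copy the best-performing checkpoint to a stable "model.ckpt-best"
        # name so the prediction and export paths below can reference it.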
        for ext in ["meta", "data-00000-of-00001", "index"]:
            src_ckpt = "model.ckpt-{}.{}".format(best_perf_global_step, ext)
            tgt_ckpt = "model.ckpt-best.{}".format(ext)
            tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
            tf.io.gfile.rename(os.path.join(model_dir, src_ckpt),
                               os.path.join(model_dir, tgt_ckpt),
                               overwrite=True)

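    # Run inference on the test set with the best checkpoint and write both
    # the raw class probabilities and a submission-style TSV.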
    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(classifier_utils.PaddingInputExample())

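        # Serialize the (possibly padded) test examples into a TFRecord file
        # of features for the prediction input pipeline.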
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        classifier_utils.file_based_convert_examples_to_features(
            predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
            predict_file, task_name)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = FLAGS.use_tpu
        predict_input_fn = classifier_utils.file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder,
            task_name=task_name,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.predict_batch_size)

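        # Always predict from the best checkpoint selected during evaluation.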
        checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
        result = estimator.predict(input_fn=predict_input_fn,
                                   checkpoint_path=checkpoint_path)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        output_submit_file = os.path.join(FLAGS.output_dir,
                                          "submit_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\
            tf.gfile.GFile(output_submit_file, "w") as sub_writer:
            sub_writer.write("index" + "\t" + "prediction\n")
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for i, (example, prediction) in enumerate(
                    zip(predict_examples, result)):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                pred_writer.write(output_line)

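                # Classification tasks map the predicted index back to its
                # label; STS-B is a regression task, so the raw score is kept.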
                if task_name != "sts-b":
                    actual_label = label_list[int(prediction["predictions"])]
                else:
                    actual_label = str(prediction["predictions"])
                sub_writer.write(example.guid + "\t" + actual_label + "\n")
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples

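    # Optionally export the best checkpoint as a SavedModel for serving.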
    if FLAGS.export_dir:
        tf.gfile.MakeDirs(FLAGS.export_dir)
        checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
        tf.logging.info("Starting to export model.")
        subfolder = estimator.export_saved_model(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=_serving_input_receiver_fn,
            checkpoint_path=checkpoint_path)
        tf.logging.info("Model exported to %s.", subfolder)