Exemple #1
0
def predict_distributed(model,
                        x=None,
                        batch_size=None,
                        verbose=0,
                        steps=None,
                        callbacks=None):
  """Run prediction for a model that uses a distribution strategy.

  The inputs are validated, `steps`/`batch_size` are recomputed through
  `get_input_params` when the first flattened input is a NumPy array, and the
  data is converted with `model._distribution_standardize_user_data` (partial
  batches allowed) before being handed to the TPU or the generic predict loop.

  Args:
    model: Model whose `_distribution_strategy` selects the loop to run.
    x: Input data.
    batch_size: Batch size, or None to let the model infer it.
    verbose: Verbosity forwarded to the predict loop.
    steps: Number of prediction steps, or None.
    callbacks: Callbacks forwarded to the predict loop.

  Returns:
    The value returned by the selected predict loop.
  """
  distributed_training_utils.validate_inputs(x, None)

  leading_input = nest.flatten(x)[0]
  if isinstance(leading_input, np.ndarray):
    input_params = distributed_training_utils.get_input_params(
        model._distribution_strategy,
        leading_input,
        steps,
        batch_size,
        mode=ModeKeys.PREDICT)
    steps, batch_size = input_params

  batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
  dataset = model._distribution_standardize_user_data(
      x, batch_size=batch_size, allow_partial_batch=True)

  # Keyword arguments accepted by both loop implementations.
  loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)
  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_predict_loop(model, dataset, **loop_kwargs)
  return training_arrays.predict_loop(
      model, dataset, batch_size=batch_size, **loop_kwargs)
Exemple #2
0
def evaluate_distributed(model,
                         x=None,
                         y=None,
                         batch_size=None,
                         verbose=1,
                         sample_weight=None,
                         steps=None,
                         callbacks=None):
    """Run evaluation for a model that uses a distribution strategy.

    The inputs are validated against the model's strategy, `steps` and
    `batch_size` are recomputed through `get_input_params` when the first
    flattened input is a NumPy array, and the standardized dataset is handed
    to the TPU or the generic test loop.

    Args:
      model: Model whose `_distribution_strategy` selects the loop to run.
      x: Input data.
      y: Target data.
      batch_size: Batch size, or None to let the model infer it.
      verbose: Verbosity forwarded to the test loop.
      sample_weight: Sample weights forwarded to data standardization.
      steps: Number of evaluation steps, or None.
      callbacks: Callbacks forwarded to the test loop.

    Returns:
      The value returned by the selected test loop.
    """
    distributed_training_utils.validate_inputs(
        x, y, model._distribution_strategy)

    leading_input = nest.flatten(x)[0]
    if isinstance(leading_input, np.ndarray):
        input_params = distributed_training_utils.get_input_params(
            model._distribution_strategy, leading_input, steps, batch_size)
        steps, batch_size = input_params

    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size)

    # Keyword arguments accepted by both loop implementations.
    loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)
    if distributed_training_utils.is_tpu_strategy(
            model._distribution_strategy):
        return experimental_tpu_test_loop(model, dataset, **loop_kwargs)
    return training_arrays.test_loop(
        model, inputs=dataset, batch_size=batch_size, **loop_kwargs)
    def evaluate(self,
                 model,
                 x=None,
                 y=None,
                 batch_size=None,
                 verbose=1,
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 **kwargs):
        """Run evaluation for a model that uses a distribution strategy.

        Args:
          model: Model whose `_distribution_strategy` selects the loop to run.
          x: Input data.
          y: Target data.
          batch_size: Batch size, or None to let the model infer it.
          verbose: Verbosity forwarded to the test loop.
          sample_weight: Sample weights forwarded to data standardization.
          steps: Number of evaluation steps, or None.
          callbacks: Callbacks forwarded to the test loop.
          **kwargs: Accepted but not used by this method.

        Returns:
          The value returned by the selected test loop.
        """
        dist_utils.validate_inputs(x, y)

        processed = self._process_batch_and_step_size(
            model, x, batch_size, steps, ModeKeys.TEST)
        batch_size, steps = processed
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps, x)

        # Partial batches are allowed when building the evaluation dataset.
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            allow_partial_batch=True)

        # Keyword arguments accepted by both loop implementations.
        loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)
        if dist_utils.is_tpu_strategy(model._distribution_strategy):
            return experimental_tpu_test_loop(model, dataset, **loop_kwargs)
        return training_arrays.test_loop(
            model, inputs=dataset, batch_size=batch_size, **loop_kwargs)
Exemple #4
0
def predict_distributed(model,
                        x=None,
                        batch_size=None,
                        verbose=0,
                        steps=None,
                        callbacks=None):
  """Run prediction for a model that uses a distribution strategy.

  The inputs are validated against the model's strategy with partial batches
  allowed, `steps`/`batch_size` are recomputed through `get_input_params` when
  the first flattened input is a NumPy array, and the standardized dataset is
  handed to the TPU or the generic predict loop.

  Args:
    model: Model whose `_distribution_strategy` selects the loop to run.
    x: Input data.
    batch_size: Batch size, or None to let the model infer it.
    verbose: Verbosity forwarded to the predict loop.
    steps: Number of prediction steps, or None.
    callbacks: Callbacks forwarded to the predict loop.

  Returns:
    The value returned by the selected predict loop.
  """
  distributed_training_utils.validate_inputs(
      x,
      None,
      model._distribution_strategy,
      allow_partial_batch=True)

  leading_input = nest.flatten(x)[0]
  if isinstance(leading_input, np.ndarray):
    input_params = distributed_training_utils.get_input_params(
        model._distribution_strategy,
        leading_input,
        steps,
        batch_size,
        mode=ModeKeys.PREDICT)
    steps, batch_size = input_params

  batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
  dataset = model._distribution_standardize_user_data(
      x, batch_size=batch_size, allow_partial_batch=True)

  # Keyword arguments accepted by both loop implementations.
  loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)
  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_predict_loop(model, dataset, **loop_kwargs)
  return training_arrays.predict_loop(
      model, dataset, batch_size=batch_size, **loop_kwargs)
def unwrap_output_dict(strategy, grouped_outputs, mode):
  """Unwrap the PerReplica outputs produced by a distributed step.

  Args:
    strategy: Distribution strategy used to reduce and flatten the outputs.
    grouped_outputs: In predict mode, the outputs as produced by the model.
      Otherwise a dict with the keys 'total_loss', 'output_losses', 'metrics'
      and 'batch_size'.
    mode: One of the `ModeKeys` values.

  Returns:
    In predict mode, the flattened per-replica values. Otherwise a dict with
    the same four keys as `grouped_outputs`, holding the reduced or flattened
    values.
  """
  # Predict outputs follow the model's output structure rather than the
  # fit/eval dict layout, so they are only flattened.
  if mode == ModeKeys.PREDICT:
    return flatten_per_replica_values(strategy, grouped_outputs)

  summed_loss = strategy.reduce(
      reduce_util.ReduceOp.SUM, grouped_outputs['total_loss'][0], axis=None)
  per_output_losses = flatten_per_replica_values(
      strategy, grouped_outputs['output_losses'])
  metric_values = flatten_per_replica_values(
      strategy, grouped_outputs['metrics'])
  summed_batch_size = strategy.reduce(
      reduce_util.ReduceOp.SUM, grouped_outputs['batch_size'], axis=None)

  unwrapped = {
      'total_loss': [summed_loss],
      'output_losses': per_output_losses,
      'metrics': metric_values,
      'batch_size': summed_batch_size,
  }

  if (dist_utils.is_tpu_strategy(strategy) and
      ops.executing_eagerly_outside_functions()):
    # On TPU every replica yields the same output, so a single value per
    # replica group is kept. This is limited to eager mode because the graph
    # code path does not use experimental_run yet; revisit once both paths
    # converge.
    stride = strategy.num_replicas_in_sync
    unwrapped['output_losses'] = per_output_losses[::stride]
    unwrapped['metrics'] = metric_values[::stride]
  return unwrapped
Exemple #6
0
def evaluate_distributed(model,
                         x=None,
                         y=None,
                         batch_size=None,
                         verbose=1,
                         sample_weight=None,
                         steps=None,
                         callbacks=None):
  """Run evaluation for a model that uses a distribution strategy.

  The inputs are validated against the model's strategy, `steps` and
  `batch_size` are recomputed through `get_input_params` when the first
  flattened input is a NumPy array, and the standardized dataset is handed to
  the TPU or the generic test loop.

  Args:
    model: Model whose `_distribution_strategy` selects the loop to run.
    x: Input data.
    y: Target data.
    batch_size: Batch size, or None to let the model infer it.
    verbose: Verbosity forwarded to the test loop.
    sample_weight: Sample weights forwarded to data standardization.
    steps: Number of evaluation steps, or None.
    callbacks: Callbacks forwarded to the test loop.

  Returns:
    The value returned by the selected test loop.
  """
  distributed_training_utils.validate_inputs(
      x, y, model._distribution_strategy)

  leading_input = nest.flatten(x)[0]
  if isinstance(leading_input, np.ndarray):
    input_params = distributed_training_utils.get_input_params(
        model._distribution_strategy, leading_input, steps, batch_size)
    steps, batch_size = input_params

  batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
  dataset = model._distribution_standardize_user_data(
      x, y, sample_weight=sample_weight, batch_size=batch_size)

  # Keyword arguments accepted by both loop implementations.
  loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)
  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_test_loop(model, dataset, **loop_kwargs)
  return training_arrays.test_loop(
      model, inputs=dataset, batch_size=batch_size, **loop_kwargs)
 def predict(self,
             model,
             x,
             batch_size=None,
             verbose=0,
             steps=None,
             callbacks=None,
             **kwargs):
     """Run prediction for a model that uses a distribution strategy.

     Args:
       model: Model whose `_distribution_strategy` selects the loop to run.
       x: Input data.
       batch_size: Batch size, or None to let the model infer it.
       verbose: Verbosity forwarded to the predict loop.
       steps: Number of prediction steps, or None.
       callbacks: Callbacks forwarded to the predict loop.
       **kwargs: Accepted but not used by this method.

     Returns:
       The value returned by the selected predict loop.
     """
     dist_utils.validate_inputs(x=x, y=None)

     processed = self._process_batch_and_step_size(
         model, x, batch_size, steps, ModeKeys.PREDICT)
     batch_size, steps = processed
     batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)

     # Partial batches are allowed when building the prediction dataset.
     dataset = model._distribution_standardize_user_data(
         x,
         batch_size=batch_size,
         allow_partial_batch=True)

     # Keyword arguments accepted by both loop implementations.
     loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)
     if dist_utils.is_tpu_strategy(model._distribution_strategy):
         return experimental_tpu_predict_loop(model, dataset, **loop_kwargs)
     return training_arrays.predict_loop(
         model, dataset, batch_size=batch_size, **loop_kwargs)
 def predict(self,
             model,
             x,
             batch_size=None,
             verbose=0,
             steps=None,
             callbacks=None,
             **kwargs):
     """Run prediction for a model that uses a distribution strategy.

     Args:
       model: Model whose `_distribution_strategy` selects the loop to run.
       x: Input data.
       batch_size: Batch size, or None to let the model infer it.
       verbose: Verbosity forwarded to the predict loop.
       steps: Number of prediction steps, or None.
       callbacks: Callbacks forwarded to the predict loop.
       **kwargs: Accepted but not used by this method.

     Returns:
       The value returned by the selected predict loop.

     Raises:
       ValueError: On a TPU strategy, when the number of steps cannot be
         inferred from the dataset.
     """
     dist_utils.validate_inputs(x=x, y=None)

     processed = dist_utils.process_batch_and_step_size(
         model._distribution_strategy,
         x,
         batch_size,
         steps,
         ModeKeys.PREDICT)
     batch_size, steps = processed
     batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
     dataset = model._distribution_standardize_user_data(
         x,
         batch_size=batch_size,
         allow_partial_batch=True)

     on_tpu = dist_utils_v2.is_tpu_strategy(model._distribution_strategy)
     if on_tpu:
         steps = training_utils_v1.infer_steps_for_dataset(
             model, dataset, steps, steps_name='steps')
         if steps is None:
             raise ValueError(
                 'Number of steps could not be inferred from the data, '
                 'please pass the steps argument.')

     # Built only after `steps` has been finalized above.
     loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)

     # The TPU-specific loop is used only outside eager execution; every
     # other case goes through the generic predict loop.
     if on_tpu and not context.executing_eagerly():
         return experimental_tpu_predict_loop(model, dataset, **loop_kwargs)
     return training_arrays_v1.predict_loop(
         model, dataset, batch_size=batch_size, **loop_kwargs)
def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
    """Prepare feed values to the model execution function.

    Note that the `targets` and `sample_weights` arguments are replaced by
    the values obtained from `_get_input_from_iterator(inputs, model)` before
    they are used.

    Arguments:
      model: Model to prepare feed values for.
      inputs: List or dict of model inputs.
      targets: Optional list of model targets.
      sample_weights: Optional list of sample weight arrays.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

    Returns:
      A tuple `(inputs, targets, sample_weights)` of feed values for the
      model in the given mode.

    Raises:
      ValueError: If sample weights are present with a TPU strategy.
      NotImplementedError: If sample weights are present while executing
        eagerly with cloning and `model._compile_distribution` is falsy.
    """
    strategy = model._distribution_strategy
    inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)

    if dist_utils.is_tpu_strategy(strategy) and sample_weights is not None:
        raise ValueError('TPUStrategy does not support sample weights.')

    # Dict inputs are flattened in the order of the model's feed input names
    # so that each value reaches the matching input layer.
    if isinstance(inputs, dict):
        inputs = [inputs[name] for name in model._feed_input_names]

    if is_distributing_by_cloning(model):
        inputs = flatten_per_replica_values(strategy, inputs)
        targets = flatten_per_replica_values(strategy, targets)
        # Expand 1-dimensional inputs.
        # TODO(b/124535720): Remove once this standardize data logic is shared
        # with main flow.
        inputs, targets = nest.map_structure(
            training_utils.standardize_single_array, (inputs, targets))
    else:
        inputs = training_utils.ModelInputs(inputs).as_list()

    if mode == ModeKeys.PREDICT:
        # Prediction feeds neither targets nor sample weights.
        targets, sample_weights = [], []
    elif sample_weights is not None and is_distributing_by_cloning(model):
        if context.executing_eagerly() and not model._compile_distribution:
            raise NotImplementedError(
                '`sample_weight` is not supported when using '
                'tf.distribute.Strategy in eager mode and '
                'cloning=True.')
        sample_weights = flatten_per_replica_values(strategy, sample_weights)

    return (inputs, targets, sample_weights)
Exemple #10
0
    def evaluate(self,
                 model,
                 x=None,
                 y=None,
                 batch_size=None,
                 verbose=1,
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 **kwargs):
        """Run evaluation for a model that uses a distribution strategy.

        Args:
          model: Model whose `_distribution_strategy` selects the loop to run.
          x: Input data.
          y: Target data.
          batch_size: Batch size, or None to let the model infer it.
          verbose: Verbosity forwarded to the test loop.
          sample_weight: Sample weights forwarded to data standardization.
          steps: Number of evaluation steps, or None.
          callbacks: Callbacks forwarded to the test loop.
          **kwargs: Accepted but not used by this method.

        Returns:
          The value returned by the selected test loop.

        Raises:
          ValueError: On a TPU strategy, when the number of steps cannot be
            inferred from the dataset.
        """
        dist_utils.validate_inputs(x, y)

        processed = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
        batch_size, steps = processed
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps, x)

        # Partial batches are allowed when building the evaluation dataset.
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            allow_partial_batch=True)

        on_tpu = dist_utils.is_tpu_strategy(model._distribution_strategy)
        if on_tpu:
            steps = training_utils.infer_steps_for_dataset(
                model, dataset, steps, steps_name='steps')
            if steps is None:
                raise ValueError(
                    'Number of steps could not be inferred from the data, '
                    'please pass the steps argument.')

        # Built only after `steps` has been finalized above.
        loop_kwargs = dict(verbose=verbose, steps=steps, callbacks=callbacks)

        # Run TPU evaluation in a custom loop in graph mode; everything else
        # goes through the generic test loop.
        if on_tpu and not context.executing_eagerly():
            return experimental_tpu_test_loop(model, dataset, **loop_kwargs)
        return training_arrays_v1.test_loop(
            model, inputs=dataset, batch_size=batch_size, **loop_kwargs)
def is_distributing_by_cloning(model):
  """Decide whether this model is going to be distributed via cloning.

  We are going to distribute the model by cloning in graph mode.

  Args:
    model: Keras model to distribute. Its `_distribution_strategy` and
      `_compile_distribution` attributes are read.

  Returns:
    True if the `model` is going to be distributed using cloning and False
    otherwise. Specifically: False for a TPU strategy while executing eagerly;
    otherwise `bool(model._compile_distribution)` when executing eagerly
    outside functions; otherwise True.
  """
  # `executing_eagerly` has to be *called*. The bare function object is always
  # truthy, which previously made this branch return False for every TPU
  # strategy regardless of the execution mode.
  if (dist_utils.is_tpu_strategy(model._distribution_strategy) and
      context.executing_eagerly()):  # b/137580852
    return False
  elif ops.executing_eagerly_outside_functions():
    return bool(model._compile_distribution)
  return True
def unwrap_outputs(distribution_strategy,
                   grouped_outputs,
                   with_loss_tensor=False):
    """Unwrap the list of outputs contained in the PerReplica parameters.

    The outputs are parsed with `flatten_per_replica_values` into a list of
    values on the different devices. When `with_loss_tensor` is True, the
    first grouped output is treated as the loss and is additionally reduced
    (summed) through the strategy into a single loss tensor.

    Args:
      distribution_strategy: DistributionStrategy used to distribute training
        and validation.
      grouped_outputs: PerReplica outputs returned from the train or test
        function that we ran on each device.
      with_loss_tensor: Boolean that indicates if we need to add the reduced
        loss tensor as one of the outputs.

    Returns:
      Values of each of the PerReplica outputs. With `with_loss_tensor`, a
      list whose first element is the reduced loss.
    """
    if not with_loss_tensor:
        return flatten_per_replica_values(distribution_strategy,
                                          grouped_outputs)

    if not isinstance(grouped_outputs, list):
        grouped_outputs = [grouped_outputs]

    # Reduce the loss tensor before adding it to the list of fetches.
    per_replica_loss = grouped_outputs[0]
    loss = distribution_strategy.reduce(
        reduce_util.ReduceOp.SUM, per_replica_loss, axis=None)

    remaining_outputs = grouped_outputs[1:]
    all_outputs = flatten_per_replica_values(distribution_strategy,
                                             remaining_outputs)

    on_tpu_eager = (dist_utils.is_tpu_strategy(distribution_strategy)
                    and ops.executing_eagerly_outside_functions())
    if on_tpu_eager:
        # On TPU every replica yields the same output, so a single value per
        # replica group is kept. This is limited to eager mode because the
        # graph code path does not use experimental_run yet; revisit once
        # both paths converge.
        stride = distribution_strategy.num_replicas_in_sync
        all_outputs = all_outputs[::stride]
    return [loss] + all_outputs
Exemple #13
0
def _should_add_batch_index_to_element(strategy, mode):
    """Whether or not the batch index should be added to the input dataset.

    See docstring of _add_batch_index_to_element() for more details. So far
    the batch index is only needed when using TPUStrategy with a multi-worker
    setting. Adding it is avoided in the other cases because of its
    performance implications.

    Args:
      strategy: the current distribution strategy for the model.
      mode: the current mode (Training/Eval/Predict) for the model.

    Returns:
      Boolean, whether the batch index should be added for the input data to
        preserve the ordering.
    """
    # TODO(priyag, rxsang): Come up with a better way to determine when the
    # batch index should be added.
    # Each guard returns the falsy operand itself, exactly as a chained `and`
    # expression would.
    is_predict = mode == ModeKeys.PREDICT
    if not is_predict:
        return is_predict

    on_tpu = dist_utils.is_tpu_strategy(strategy)
    if not on_tpu:
        return on_tpu

    return strategy.extended.num_hosts > 1
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Run training for a model that uses a distribution strategy.

        Args:
          model: Model whose `_distribution_strategy` selects the loop to run.
          x: Training input data.
          y: Training target data.
          batch_size: Batch size, or None to let the model infer it.
          epochs: Number of epochs, forwarded to the fit loop.
          verbose: Verbosity forwarded to the fit loop.
          callbacks: Callbacks; validated and forwarded to the fit loop.
          validation_split: Not supported on its own; see Raises.
          validation_data: Validation data, unpacked into inputs, targets and
            sample weights.
          shuffle: Shuffle flag forwarded to data standardization and to the
            generic fit loop.
          class_weight: Class weights forwarded to data standardization.
          sample_weight: Sample weights forwarded to data standardization.
          initial_epoch: Epoch at which to start, forwarded to the fit loop.
          steps_per_epoch: Steps per epoch, or None.
          validation_steps: Validation steps, or None.
          validation_freq: Validation frequency forwarded to the fit loop.
          **kwargs: Accepted but not used by this method.

        Returns:
          The value returned by the selected fit loop.

        Raises:
          ValueError: If `validation_split` is set without `validation_data`,
            or if, on a TPU strategy, `steps_per_epoch` cannot be inferred
            from the dataset.
        """
        dist_utils.validate_callbacks(input_callbacks=callbacks,
                                      optimizer=model.optimizer)
        dist_utils.validate_inputs(x, y)

        processed = dist_utils.process_batch_and_step_size(
            model._distribution_strategy,
            x,
            batch_size,
            steps_per_epoch,
            ModeKeys.TRAIN,
            validation_split=validation_split)
        batch_size, steps_per_epoch = processed
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)

        # Keyword arguments shared by the training standardization calls.
        train_data_kwargs = dict(sample_weight=sample_weight,
                                 class_weight=class_weight,
                                 batch_size=batch_size,
                                 validation_split=validation_split,
                                 shuffle=shuffle)
        dataset = model._distribution_standardize_user_data(
            x, y, epochs=epochs, **train_data_kwargs)
        if not dist_utils.is_distributing_by_cloning(model):
            with model._distribution_strategy.scope():
                dataset, _, _ = model._standardize_user_data(
                    dataset, **train_data_kwargs)

        val_dataset = None
        if validation_data:
            unpacked = training_utils_v1.unpack_validation_data(
                validation_data)
            val_x, val_y, val_sample_weights = unpacked
            dist_utils.validate_inputs(val_x, val_y)
            _, validation_steps = dist_utils.process_batch_and_step_size(
                model._distribution_strategy,
                val_x,
                batch_size,
                validation_steps,
                ModeKeys.TEST)

            val_dataset = model._distribution_standardize_user_data(
                val_x,
                val_y,
                sample_weight=val_sample_weights,
                class_weight=None,
                batch_size=batch_size,
                validation_split=validation_split,
                shuffle=shuffle,
                allow_partial_batch=True)
        elif validation_split:
            raise ValueError('validation_split argument is not supported with '
                             'distribution strategies.')

        on_tpu = dist_utils_v2.is_tpu_strategy(model._distribution_strategy)
        if on_tpu:
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                dataset,
                steps_per_epoch,
                epochs,
                steps_name='steps_per_epoch')
            if steps_per_epoch is None:
                raise ValueError(
                    'Number of steps could not be inferred from the data, '
                    'please pass the steps_per_epoch argument.')

        # Keyword arguments accepted by both loop implementations; built only
        # after the step counts have been finalized above.
        loop_kwargs = dict(epochs=epochs,
                           verbose=verbose,
                           callbacks=callbacks,
                           initial_epoch=initial_epoch,
                           steps_per_epoch=steps_per_epoch,
                           validation_steps=validation_steps,
                           validation_freq=validation_freq)

        # Run TPU training in a custom loop in graph mode; everything else
        # goes through the generic fit loop.
        if on_tpu and not context.executing_eagerly():
            return experimental_tpu_fit_loop(model,
                                             dataset,
                                             val_dataset=val_dataset,
                                             **loop_kwargs)
        return training_arrays_v1.fit_loop(model,
                                           dataset,
                                           batch_size=batch_size,
                                           val_inputs=val_dataset,
                                           shuffle=shuffle,
                                           steps_name='steps_per_epoch',
                                           **loop_kwargs)
Exemple #15
0
def fit_distributed(model,
                    x=None,
                    y=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_split=0.,
                    validation_data=None,
                    shuffle=True,
                    class_weight=None,
                    sample_weight=None,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1):
  """Run training for a model that uses a distribution strategy.

  Args:
    model: Model whose `_distribution_strategy` selects the loop to run.
    x: Training input data.
    y: Training target data.
    batch_size: Batch size, or None to let the model infer it.
    epochs: Number of epochs, forwarded to the fit loop.
    verbose: Verbosity forwarded to the fit loop.
    callbacks: Callbacks; validated and forwarded to the fit loop.
    validation_split: Not supported on its own; see Raises.
    validation_data: Validation data, unpacked into inputs, targets and sample
      weights.
    shuffle: Shuffle flag forwarded to data standardization and to the generic
      fit loop.
    class_weight: Class weights forwarded to data standardization.
    sample_weight: Sample weights forwarded to data standardization.
    initial_epoch: Epoch at which to start, forwarded to the fit loop.
    steps_per_epoch: Steps per epoch, or None.
    validation_steps: Validation steps, or None.
    validation_freq: Validation frequency forwarded to the fit loop.

  Returns:
    The value returned by the selected fit loop.

  Raises:
    ValueError: If `validation_split` is set without `validation_data`.
  """
  distributed_training_utils.validate_callbacks(callbacks, model.optimizer)
  distributed_training_utils.validate_inputs(x, y)

  leading_input = nest.flatten(x)[0]
  if isinstance(leading_input, np.ndarray):
    # Until support for partial batch is implemented across all
    # functions and distribution strategy, we pass `mode` to selectively
    # relax the constraint to consume all the training samples.
    input_params = distributed_training_utils.get_input_params(
        model._distribution_strategy,
        leading_input,
        steps_per_epoch,
        batch_size,
        mode=ModeKeys.TRAIN)
    steps_per_epoch, batch_size = input_params
  batch_size = model._validate_or_infer_batch_size(
      batch_size, steps_per_epoch, x)

  # Keyword arguments shared by the training standardization calls.
  train_data_kwargs = dict(
      sample_weight=sample_weight,
      class_weight=class_weight,
      batch_size=batch_size,
      validation_split=validation_split,
      shuffle=shuffle)
  dataset = model._distribution_standardize_user_data(
      x, y, epochs=epochs, **train_data_kwargs)
  if not distributed_training_utils.is_distributing_by_cloning(model):
    with model._distribution_strategy.scope():
      dataset, _, _ = model._standardize_user_data(
          dataset, **train_data_kwargs)

  val_dataset = None
  if validation_data:
    unpacked = model._unpack_validation_data(validation_data)
    val_x, val_y, val_sample_weights = unpacked
    distributed_training_utils.validate_inputs(val_x, val_y)

    leading_val_input = nest.flatten(val_x)[0]
    if isinstance(leading_val_input, np.ndarray):
      validation_steps, _ = distributed_training_utils.get_input_params(
          model._distribution_strategy,
          leading_val_input,
          validation_steps,
          batch_size,
          mode=ModeKeys.TEST)

    val_dataset = model._distribution_standardize_user_data(
        val_x,
        val_y,
        sample_weight=val_sample_weights,
        class_weight=None,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle,
        allow_partial_batch=True)
  elif validation_split:
    raise ValueError('validation_split argument is not supported with '
                     'distribution strategies.')

  # Keyword arguments accepted by both loop implementations; built only after
  # the step counts have been finalized above.
  loop_kwargs = dict(
      epochs=epochs,
      verbose=verbose,
      callbacks=callbacks,
      initial_epoch=initial_epoch,
      steps_per_epoch=steps_per_epoch,
      validation_steps=validation_steps,
      validation_freq=validation_freq)

  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_fit_loop(
        model, dataset, val_dataset=val_dataset, **loop_kwargs)
  return training_arrays.fit_loop(
      model,
      dataset,
      batch_size=batch_size,
      val_inputs=val_dataset,
      shuffle=shuffle,
      steps_name='steps_per_epoch',
      **loop_kwargs)
Exemple #16
0
def fit_distributed(model,
                    x=None,
                    y=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_split=0.,
                    validation_data=None,
                    shuffle=True,
                    class_weight=None,
                    sample_weight=None,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1):
  """Run training for a model that uses a distribution strategy.

  Args:
    model: Model whose `_distribution_strategy` selects the loop to run.
    x: Training input data.
    y: Training target data.
    batch_size: Batch size, or None to let the model infer it.
    epochs: Number of epochs, forwarded to the fit loop.
    verbose: Verbosity forwarded to the fit loop.
    callbacks: Callbacks; validated and forwarded to the fit loop.
    validation_split: Not supported on its own; see Raises.
    validation_data: Validation data, unpacked into inputs, targets and sample
      weights.
    shuffle: Shuffle flag forwarded to data standardization and to the generic
      fit loop.
    class_weight: Class weights forwarded to data standardization.
    sample_weight: Sample weights forwarded to data standardization.
    initial_epoch: Epoch at which to start, forwarded to the fit loop.
    steps_per_epoch: Steps per epoch, or None.
    validation_steps: Validation steps, or None.
    validation_freq: Validation frequency forwarded to the fit loop.

  Returns:
    The value returned by the selected fit loop.

  Raises:
    ValueError: If `validation_split` is set without `validation_data`.
  """
  distributed_training_utils.validate_callbacks(callbacks, model.optimizer)
  distributed_training_utils.validate_inputs(
      x, y, model._distribution_strategy)

  leading_input = nest.flatten(x)[0]
  if isinstance(leading_input, np.ndarray):
    # Until support for partial batch is implemented across all
    # functions and distribution strategy, we pass `mode` to selectively
    # relax the constraint to consume all the training samples.
    input_params = distributed_training_utils.get_input_params(
        model._distribution_strategy,
        leading_input,
        steps_per_epoch,
        batch_size,
        mode=ModeKeys.TRAIN)
    steps_per_epoch, batch_size = input_params
  batch_size = model._validate_or_infer_batch_size(
      batch_size, steps_per_epoch, x)

  # The training dataset is requested with `repeat=True`.
  dataset = model._distribution_standardize_user_data(
      x,
      y,
      sample_weight=sample_weight,
      class_weight=class_weight,
      batch_size=batch_size,
      validation_split=validation_split,
      shuffle=shuffle,
      repeat=True)

  val_dataset = None
  if validation_data:
    unpacked = model._unpack_validation_data(validation_data)
    val_x, val_y, val_sample_weights = unpacked
    distributed_training_utils.validate_inputs(
        val_x, val_y, model._distribution_strategy)

    leading_val_input = nest.flatten(val_x)[0]
    if isinstance(leading_val_input, np.ndarray):
      validation_steps, _ = distributed_training_utils.get_input_params(
          model._distribution_strategy,
          leading_val_input,
          validation_steps,
          batch_size)

    val_dataset = model._distribution_standardize_user_data(
        val_x,
        val_y,
        sample_weight=val_sample_weights,
        class_weight=None,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle)
  elif validation_split:
    raise ValueError('validation_split argument is not supported with '
                     'distribution strategies.')

  # Keyword arguments accepted by both loop implementations; built only after
  # the step counts have been finalized above.
  loop_kwargs = dict(
      epochs=epochs,
      verbose=verbose,
      callbacks=callbacks,
      initial_epoch=initial_epoch,
      steps_per_epoch=steps_per_epoch,
      validation_steps=validation_steps,
      validation_freq=validation_freq)

  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_fit_loop(
        model, dataset, val_dataset=val_dataset, **loop_kwargs)
  return training_arrays.fit_loop(
      model,
      dataset,
      batch_size=batch_size,
      val_inputs=val_dataset,
      shuffle=shuffle,
      steps_name='steps_per_epoch',
      **loop_kwargs)
def get_input_params(distribution_strategy,
                     num_samples,
                     steps,
                     batch_size,
                     mode=None):
  """Calculate the number of batches and steps/steps_per_epoch.

  Whichever of `steps` / `batch_size` is missing is derived from the other
  and from `num_samples`. When both are given they are only checked against
  `num_samples`.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.
    num_samples: The number of samples from which we determine the batch size
      and steps.
    steps:  The specified number of steps, or None to derive it.
    batch_size: The specified batch_size, or None to derive it. It is
      multiplied by `num_replicas_in_sync` to obtain the global batch size
      when the strategy does not support global batch sizes.
    mode: ModeKey representing whether input will be used for training,
      evaluation, or prediction. This is used to relax the constraints on
      consuming all the training samples to keep compatibility till we support
      partial batches. With `mode=None`, partial batches are rejected in graph
      mode; in eager mode they are accepted.

  Returns:
    steps: The steps or steps_per_epoch argument depending on if a user is
        calling `fit`, `evaluate` or `predict`. A caller-supplied value is
        returned unchanged; otherwise it is the number of batches needed to
        cover `num_samples` (the last one may be partial when partial batches
        are allowed).
    batch_size: The batch size to be used in model iterations: per-replica
        when the strategy does not support global batch sizes, global
        otherwise.

  Raises:
    ValueError: If the batch size evaluates to 0 while `steps` has to be
      derived, if `steps` is 0 while the batch size has to be derived, if
      `num_samples` is not divisible by the batch size (partial batches not
      allowed) or by `steps`, if there are fewer samples than the given
      `batch_size` and `steps` require, or if the global batch size cannot be
      sharded evenly across the replicas.

  """
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
  use_per_replica_batch = not dist_utils.global_batch_size_supported(
      distribution_strategy)

  # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except for
  # `fit()` on TPUStrategy.
  # In graph mode, the zero batch case in batch norm is not handled due to
  # XLA-GPU regression. Uneven batch sizes are not allowed except
  # for `test()` and `predict()` on TPUStrategy.
  # Training is additionally exempted in graph mode: see `mode` in the
  # docstring (training need not consume every sample).
  if context.executing_eagerly():
    allow_partial_batch = (
        mode != ModeKeys.TRAIN or
        not dist_utils.is_tpu_strategy(distribution_strategy))
  else:
    allow_partial_batch = (
        mode == ModeKeys.TRAIN or
        ((mode == ModeKeys.PREDICT or mode == ModeKeys.TEST) and
         dist_utils.is_tpu_strategy(distribution_strategy)))

  if steps is None:
    if batch_size is None:
      # If neither the batch size or number of steps are set. We choose the
      # global batch size as the minimum of number of samples and 32. 32 is
      # chosen to provide backward compatibility.
      global_batch_size = min(num_samples, 32)
    else:
      # If the user provided the batch size we need to handle the case
      # between different strategies that use the global/per-replica batch size
      global_batch_size = batch_size
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync

    # Both ways of deriving `steps` below divide by the global batch size;
    # report the documented ValueError instead of a ZeroDivisionError.
    if global_batch_size == 0:
      raise ValueError('The batch size evaluates to 0 (number of samples: %s, '
                       'batch_size: %s).' % (num_samples, batch_size))

    if allow_partial_batch:
      # Round up so that a trailing partial batch gets its own step.
      steps = np.ceil(num_samples / global_batch_size).astype(int)
    else:
      if num_samples % global_batch_size:
        raise ValueError('The number of samples %s is not divisible by '
                         'batch size %s.' % (num_samples, global_batch_size))
      steps = num_samples // global_batch_size
  else:
    if batch_size is None:
      # The batch size is `num_samples / steps`, which is undefined for zero
      # steps; report the documented ValueError instead of a
      # ZeroDivisionError.
      if steps == 0:
        raise ValueError('Cannot infer a batch size from %s samples when the '
                         'number of steps is 0.' % (num_samples,))
      # We calculate the batch size based on the number of steps specified
      if num_samples % steps:
        raise ValueError('The number of samples %s is not divisible by '
                         'steps %s. Please change the number of steps to a '
                         'value that can consume all the samples' % (
                             num_samples, steps))
      global_batch_size = num_samples // steps
    else:
      # If the user provided the batch size we need to handle the case
      # between different strategies that use the global/per-replica batch size
      global_batch_size = batch_size
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync

      # Without partial batches every step needs a full batch. With them, only
      # the first `steps - 1` batches must be full and the last one needs at
      # least one sample; a single (or zero) step imposes no minimum.
      if not allow_partial_batch:
        min_num_samples = global_batch_size * steps
      elif steps > 1:
        min_num_samples = global_batch_size * (steps - 1) + 1
      else:
        min_num_samples = 0

      if num_samples < min_num_samples:
        raise ValueError('Number of samples %s is less than samples required '
                         'for specified batch_size %s and steps %s' % (
                             num_samples, global_batch_size, steps))

  # We need to return the per replica or global batch size based on the strategy
  if use_per_replica_batch:
    if global_batch_size % distribution_strategy.num_replicas_in_sync:
      raise ValueError(
          'The batch size (%s) could not be sharded evenly across the sync '
          'replicas (%s) in the distribution strategy.' % (
              global_batch_size, distribution_strategy.num_replicas_in_sync))
    batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync
  else:
    batch_size = global_batch_size

  return steps, batch_size
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Fit loop for Distribution Strategies.

        Validates the callbacks and inputs, resolves the batch size and the
        number of steps per epoch, turns the training (and optional
        validation) inputs into datasets and hands them to the TPU fit loop
        or to `training_arrays.fit_loop`, depending on the model's strategy.

        Args:
            model: The model to train; its `_distribution_strategy` selects
                the loop that is run.
            x: Training inputs.
            y: Training targets.
            batch_size: Requested batch size, or None to have it resolved.
            epochs: Forwarded to dataset creation and to the fit loop.
            verbose: Verbosity, forwarded to the fit loop.
            callbacks: Callbacks, validated against the model's optimizer and
                forwarded to the fit loop.
            validation_split: Forwarded to dataset creation. A truthy value
                without `validation_data` is rejected.
            validation_data: Optional validation data, unpacked with
                `model._unpack_validation_data`.
            shuffle: Forwarded to dataset creation (and to the non-TPU loop).
            class_weight: Class weights for the training data.
            sample_weight: Sample weights for the training data.
            initial_epoch: Forwarded to the fit loop.
            steps_per_epoch: Requested steps per epoch, or None to have it
                resolved.
            validation_steps: Requested validation steps; re-resolved when
                `validation_data` is given.
            validation_freq: Forwarded to the fit loop.
            **kwargs: Accepted but not used by this method.

        Returns:
            Whatever the selected fit loop returns.

        Raises:
            ValueError: If `validation_split` is set and `validation_data` is
                not.
        """
        dist_utils.validate_callbacks(input_callbacks=callbacks,
                                      optimizer=model.optimizer)
        dist_utils.validate_inputs(x, y)

        # Settle the training batch size / step count before building data.
        batch_size, steps_per_epoch = self._process_batch_and_step_size(
            model, x, batch_size, steps_per_epoch, ModeKeys.TRAIN)
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)

        # Keyword arguments shared by every data standardization call below.
        data_options = dict(
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle)

        train_data = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            epochs=epochs,
            **data_options)
        if not dist_utils.is_distributing_by_cloning(model):
            # Standardize once more inside the strategy scope when the model
            # is not distributed by cloning.
            with model._distribution_strategy.scope():
                train_data, _, _ = model._standardize_user_data(
                    train_data,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                    **data_options)

        if validation_data:
            val_x, val_y, val_weights = model._unpack_validation_data(
                validation_data)
            dist_utils.validate_inputs(val_x, val_y)
            # Only the step count is taken from here; the batch size stays
            # the one resolved for training.
            _, validation_steps = self._process_batch_and_step_size(
                model, val_x, batch_size, validation_steps, ModeKeys.TEST)

            eval_data = model._distribution_standardize_user_data(
                val_x,
                val_y,
                sample_weight=val_weights,
                class_weight=None,
                allow_partial_batch=True,
                **data_options)
        elif validation_split:
            raise ValueError('validation_split argument is not supported with '
                             'distribution strategies.')
        else:
            eval_data = None

        # Keyword arguments accepted by both fit loops.
        loop_options = dict(
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq)

        if dist_utils.is_tpu_strategy(model._distribution_strategy):
            return experimental_tpu_fit_loop(
                model, train_data, val_dataset=eval_data, **loop_options)

        return training_arrays.fit_loop(
            model,
            train_data,
            batch_size=batch_size,
            val_inputs=eval_data,
            shuffle=shuffle,
            steps_name='steps_per_epoch',
            **loop_options)