コード例 #1
0
def _copy_weights_to_distributed_model(model):
  """Copies weights from original model to distributed models."""
  if model._distribution_strategy:
    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    distributed_model = model._distribution_strategy.unwrap(
        model._grouped_model)[0]
    distributed_training_utils.set_weights(
        model._distribution_strategy, distributed_model, orig_model_weights)
コード例 #2
0
def _copy_weights_to_distributed_model(model):
  """Copies weights from original model to distributed models."""
  if model._distribution_strategy:
    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    distributed_model = model._distribution_strategy.unwrap(
        model._grouped_model)[0]
    distributed_training_utils.set_weights(
        model._distribution_strategy, distributed_model, orig_model_weights)
コード例 #3
0
def predict_loop(model, iterator, verbose=0, steps=None):
  """Predict loop for predicting with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. A progress bar is displayed
          only when `verbose == 1`; any other value runs silently.
      steps: Total number of steps (batches of samples) to run before
          declaring the prediction loop finished. Must not be `None`.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).

  Raises:
      AssertionError: if `steps` is `None` (non-TPU code path).
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_predict_loop(model, iterator, verbose, steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy)

  def _per_device_predict_function(model):
    """Builds the predict function of `model` and returns its components."""
    model._make_predict_function()
    return (model.predict_function.inputs,
            model.predict_function.outputs,
            model.predict_function.updates_op,
            model.predict_function.session_kwargs)

  inputs, _, _ = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_predict_function, args=(model._grouped_model,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)

    distributed_predict_function = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_predict_function',
        **all_session_args)

    # When the learning phase is not a plain int, an extra `0` is fed after
    # the dataset inputs (presumably the inference-phase value; verify
    # against the backend).
    if not isinstance(K.learning_phase(), int):
      ins = dataset_inputs + [0]
    else:
      ins = dataset_inputs

    # The progress bar exists only for `verbose == 1`. Updates below are
    # guarded on its existence so other non-zero verbosity values no longer
    # hit an unbound `progbar`.
    progbar = Progbar(target=steps) if verbose == 1 else None

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

    num_replicas = current_strategy.num_replicas_in_sync
    num_outputs = len(model.outputs)
    # Since we do not know how many samples we will see, we cannot
    # pre-allocate the returned Numpy arrays. Instead, we store one array per
    # batch seen and concatenate them upon returning.
    unconcatenated_outs = []
    assert steps is not None
    for step in range(steps):
      batch_outs = distributed_predict_function(ins)
      if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]
      if step == 0:
        # batch_outs gives you the number of model outputs. In the distributed
        # case this will be number of model_outputs * num_replicas.
        # The per-output buckets are created lazily here so that `steps == 0`
        # still returns an empty list.
        for _ in range(num_outputs):
          unconcatenated_outs.append([])
      for i in range(num_outputs):
        # Values for output `i` occupy a contiguous run of `num_replicas`
        # entries in `batch_outs`.
        nested_outs = batch_outs[i * num_replicas:(i + 1) * num_replicas]
        unconcatenated_outs[i].extend(nest.flatten(nested_outs))
      if progbar is not None:
        progbar.update(step + 1)
    if len(unconcatenated_outs) == 1:
      return np.concatenate(unconcatenated_outs[0], axis=0)
    return [np.concatenate(outs, axis=0) for outs in unconcatenated_outs]
コード例 #4
0
def _experimental_test_loop(model, iterator, verbose=0, steps=None,
                            initialize_finalize_strategy=True):
  """Test loop for evaluating with TPU DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. A progress bar is displayed
          only when `verbose == 1`; any other value runs silently.
      steps: Total number of steps (batches of samples) to evaluate. Must not
          be `None`; the accumulated outputs are divided by this value.
      initialize_finalize_strategy: Should the strategy initialize and finalize
          functions be called.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the outputs.

  Raises:
      AssertionError: if `steps` is `None`.
  """
  current_strategy = model._distribution_strategy
  if initialize_finalize_strategy:
    K.get_session().run(current_strategy.initialize())

  def _per_device_eval_function(model):
    """Builds the eval function of `model` and returns its components."""
    model._make_eval_function()
    return (model._eval_function.inputs, model._eval_function.outputs,
            model._eval_function.updates_op,
            model._eval_function.session_kwargs)

  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
  K.set_learning_phase(0)

  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_eval_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=False,
        inputs=inputs,
        targets=targets,
        mode=_Mode.TEST)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_eval_function, args=(model._grouped_model_test,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_test_function',
        **all_session_args)

    # Register every output under its metric name so it can be read back
    # from `ctx.last_step_outputs` after the run.
    for label, output in zip(model.metrics_names, combined_fn.outputs):
      if label == 'loss':
        aggregation = distribute_lib.get_loss_reduction()
      else:
        # We aggregate all other metrics using mean for now. This is temporary
        # workaround until new metrics are in place.
        aggregation = variable_scope.VariableAggregation.MEAN
      ctx.set_last_step_output(label, output, aggregation)

    return combined_fn.updates_op

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

  with current_strategy.scope():
    # TODO(priyag): Use steps_per_run when we use new metrics as they will
    # allow handling metric computation at each step using variables.
    ctx = current_strategy.run_steps_on_dataset(
        step_fn, iterator, iterations=1,
        initial_loop_values=initial_loop_values)

  test_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  # The progress bar exists only for `verbose == 1`. Updates below are
  # guarded on its existence so other non-zero verbosity values no longer
  # hit an unbound `progbar`.
  progbar = Progbar(target=steps) if verbose == 1 else None

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model_test)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  assert steps is not None
  # Sum each metric over all steps, then divide by the step count below.
  outs = [0.] * len(model.metrics_names)
  for step in range(steps):
    _, batch_outs = K.get_session().run([test_op, output_tensors])
    for i, label in enumerate(model.metrics_names):
      outs[i] += batch_outs[label]
    if progbar is not None:
      progbar.update(step + 1)
  for i in range(len(outs)):
    outs[i] /= (steps)

  if initialize_finalize_strategy:
    K.get_session().run(current_strategy.finalize())

  if len(outs) == 1:
    return outs[0]
  return outs
コード例 #5
0
def fit_loop(
    model,
    iterator,
    epochs=100,
    verbose=1,
    callbacks=None,
    val_iterator=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None):
  """Fit loop for training with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      epochs: Number of times to iterate over the data
      verbose: Integer, Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_iterator: Iterator for validation data.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Must not be `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Validation runs after each epoch only when this is truthy.

  Returns:
      `History` object (`model.history`). When the strategy is a
      `TPUStrategy`, whatever `_experimental_fit_loop` returns.

  Raises:
      AssertionError: if `steps_per_epoch` is `None`.
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_fit_loop(
        model, iterator, epochs, verbose, callbacks, initial_epoch,
        steps_per_epoch, val_iterator, validation_steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy, make_callback_model=True)

  def _per_device_fit_function(model):
    """Builds the fit function of `model` and returns its components."""
    model._make_fit_function()
    return (model._fit_function.inputs, model._fit_function.outputs,
            model._fit_function.updates_op, model._fit_function.session_kwargs)

  inputs, targets, sample_weights = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_device_fit_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_fit_function, args=(model._grouped_model,))
    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args, with_loss_tensor=True)

    # Dataset inputs and targets are also per devices values that need to be
    # unwrapped.
    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

    # Create a train function that is composed of all the parameters above.
    distributed_fit_function = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_fit_function',
        **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [None] * (
        len(model.outputs) * current_strategy.num_replicas_in_sync)
    if not isinstance(K.learning_phase(), int):
      # The trailing `1` is fed after the sample weights (presumably the
      # training-phase value; verify against the backend).
      ins = dataset_inputs + dataset_targets + sample_weights + [1]
    else:
      ins = dataset_inputs + dataset_targets

    do_validation = bool(validation_steps)

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        val_inputs=None,
        val_targets=None,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=verbose)
    out_labels = model.metrics_names or []
    callbacks.on_train_begin()

    assert steps_per_epoch is not None

    for epoch in range(initial_epoch, epochs):
      # Reset stateful metrics
      for m in model.stateful_metric_functions:
        m.reset_states()
      callbacks.on_epoch_begin(epoch)
      epoch_logs = {}
      for step_index in range(steps_per_epoch):
        batch_logs = {'batch': step_index, 'size': 1}
        callbacks.on_batch_begin(step_index, batch_logs)
        try:
          outs = distributed_fit_function(ins)
        except errors.OutOfRangeError:
          # The product must be parenthesised: `%` and `*` share precedence
          # and group left-to-right, so without the parentheses the message
          # would be formatted with `steps_per_epoch` alone and then repeated
          # `epochs` times.
          logging.warning('Your dataset iterator ran out of data; '
                          'interrupting training. Make sure that your dataset '
                          'can generate at least `steps_per_epoch * epochs` '
                          'batches (in this case, %d batches).' %
                          (steps_per_epoch * epochs))
          break

        if not isinstance(outs, list):
          outs = [outs]
        for l, o in zip(out_labels, outs):
          batch_logs[l] = o
        callbacks.on_batch_end(step_index, batch_logs)
        if callbacks.model.stop_training:
          break
      if do_validation:
        val_outs = test_loop(
            model,
            val_iterator,
            steps=validation_steps,
            verbose=0)
        if not isinstance(val_outs, list):
          val_outs = [val_outs]
        # Same labels assumed.
        for l, o in zip(out_labels, val_outs):
          epoch_logs['val_' + l] = o

      callbacks.on_epoch_end(epoch, epoch_logs)
      if callbacks.model.stop_training:
        break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    updated_weights = current_strategy.unwrap(
        model._grouped_model)[0].get_weights()
    model.set_weights(updated_weights)
    return model.history
コード例 #6
0
def _experimental_fit_loop(
    model,
    iterator,
    epochs=100,
    initial_epoch=0,
    steps_per_epoch=None):
  """fit function when using TPU DistributionStrategy for training.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator that returns inputs and targets
      epochs: Number of times to iterate over the data
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Must not be `None`.

  Returns:
      Returns `None`.

  Raises:
      AssertionError: if `steps_per_epoch` is `None`.
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag): Add validation that shapes are fully defined for TPU case.

  # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
  K.get_session().run(current_strategy.initialize())

  def _per_device_train_function(model):
    """Builds the train function of `model` and returns its components."""
    model._make_train_function()
    return (model.train_function.inputs,
            model.train_function.outputs,
            model.train_function.updates_op,
            model.train_function.session_kwargs)

  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
  K.set_learning_phase(1)

  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_train_function."""
    # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
    clone_model_on_towers(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args, with_loss_tensor=True)
    combined_fn = K.Function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_train_function',
        **all_session_args)

    # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
    # something else for different outputs.
    out_labels = model.metrics_names or []
    for label, output in zip(out_labels, combined_fn.outputs):
      ctx.set_last_step_output(label, output,
                               aggregation=distribute_lib.get_loss_reduction())

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately
    return combined_fn.updates_op

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

  with current_strategy.scope():
    # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
    # steps_per_epoch and number of epochs.
    ctx = current_strategy.run_steps_on_dataset(
        step_fn, iterator, iterations=current_strategy.steps_per_run,
        initial_loop_values=initial_loop_values)

  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  assert steps_per_epoch is not None

  # TODO(priyag, sourabhbajaj): Add callbacks support.
  # TODO(priyag, sourabhbajaj): Add validation.
  for epoch in range(initial_epoch, epochs):
    # Each session run executes `steps_per_run` iterations, so the step index
    # advances by that stride.
    for step_index in range(
        0, steps_per_epoch, current_strategy.steps_per_run):
      try:
        _, outs = K.get_session().run([train_op, output_tensors])
        # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
        # summaries through callbacks.
        print('Epoch: {}, step_index: {}, loss: {}'.format(
            epoch, step_index, outs['loss']))
        for label, out in outs.items():
          print(label, ': ', out)
      except errors.OutOfRangeError:
        # The product must be parenthesised: `%` and `*` share precedence and
        # group left-to-right, so without the parentheses the message would be
        # formatted with `steps_per_epoch` alone and then repeated `epochs`
        # times.
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

  # Copy the weights back from the replicated model to the original model.
  with current_strategy.scope():
    updated_weights = current_strategy.unwrap(
        model._grouped_model)[0].get_weights()
    model.set_weights(updated_weights)

  K.get_session().run(current_strategy.finalize())
コード例 #7
0
def predict_loop(model, iterator, verbose=0, steps=None):
    """Predict loop for predicting with DistributionStrategy.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        verbose: Integer, Verbosity mode 0 or 1. A progress bar is displayed
            only when `verbose == 1`; any other value runs silently.
        steps: Total number of steps (batches of samples) to run before
            declaring the prediction loop finished. Must not be `None`.

    Returns:
        Array of predictions (if the model has a single output)
        or list of arrays of predictions
        (if the model has multiple outputs).

    Raises:
        AssertionError: if `steps` is `None` (non-TPU code path).
    """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_predict_loop(model, iterator, verbose, steps)

    if not model._grouped_model:
        clone_model_on_replicas(model, current_strategy)

    def _per_device_predict_function(model):
        """Builds the predict function of `model`; returns its components."""
        model._make_predict_function()
        return (model.predict_function.inputs, model.predict_function.outputs,
                model.predict_function.updates_op,
                model.predict_function.session_kwargs)

    inputs, _, _ = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_predict_function, args=(model._grouped_model, ))

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)

        distributed_predict_function = K.function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name='distributed_predict_function',
            **all_session_args)

        # The progress bar exists only for `verbose == 1`. Updates below are
        # guarded on its existence so other non-zero verbosity values no
        # longer hit an unbound `progbar`.
        progbar = Progbar(target=steps) if verbose == 1 else None

        # Copy the weights from the original model to each of the replicated
        # models.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

        num_replicas = current_strategy.num_replicas_in_sync
        num_outputs = len(model.outputs)
        # Since we do not know how many samples we will see, we cannot
        # pre-allocate the returned Numpy arrays. Instead, we store one array
        # per batch seen and concatenate them upon returning.
        unconcatenated_outs = []
        assert steps is not None
        for step in range(steps):
            batch_outs = distributed_predict_function(dataset_inputs)
            if not isinstance(batch_outs, list):
                batch_outs = [batch_outs]
            if step == 0:
                # batch_outs gives you the number of model outputs. In the
                # distributed case this will be number of
                # model_outputs * num_replicas.
                # The per-output buckets are created lazily here so that
                # `steps == 0` still returns an empty list.
                for _ in range(num_outputs):
                    unconcatenated_outs.append([])
            for i in range(num_outputs):
                # Values for output `i` occupy a contiguous run of
                # `num_replicas` entries in `batch_outs`.
                nested_outs = batch_outs[i * num_replicas:
                                         (i + 1) * num_replicas]
                unconcatenated_outs[i].extend(nest.flatten(nested_outs))
            if progbar is not None:
                progbar.update(step + 1)
        if len(unconcatenated_outs) == 1:
            return np.concatenate(unconcatenated_outs[0], axis=0)
        return [
            np.concatenate(outs, axis=0) for outs in unconcatenated_outs
        ]
コード例 #8
0
def fit_loop(model,
             iterator,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_iterator=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Fit loop for training with DistributionStrategy.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        epochs: Number of times to iterate over the data
        verbose: Integer, Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        val_iterator: Iterator for validation data.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Must not be `None`.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Validation runs after each epoch only when this is truthy.

    Returns:
        `History` object (`model.history`). When the strategy is a
        `TPUStrategy`, whatever `_experimental_fit_loop` returns.

    Raises:
        AssertionError: if `steps_per_epoch` is `None`.
    """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_fit_loop(model, iterator, epochs, verbose,
                                      callbacks, initial_epoch,
                                      steps_per_epoch, val_iterator,
                                      validation_steps)

    if not model._grouped_model:
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=True)

    def _per_device_fit_function(model):
        """Builds the fit function of `model` and returns its components."""
        model._make_fit_function()
        return (model._fit_function.inputs, model._fit_function.outputs,
                model._fit_function.updates_op,
                model._fit_function.session_kwargs)

    inputs, targets, sample_weights = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_fit_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_fit_function, args=(model._grouped_model, ))

        # Initialize the variables in the replicated model. This is necessary for
        # multi-worker training because on some workers, initialization is not
        # needed. This method does initialization or waiting for initialization
        # according to the context object of distribute coordinator.
        distributed_training_utils.init_restore_or_wait_for_variables()

        # Unwrap all the per device values returned from `call_for_each_replica`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to be
        # unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

        # Create a train function that is composed of all the parameters above.
        distributed_fit_function = K.function(all_inputs,
                                              all_outputs,
                                              updates=all_updates,
                                              name='distributed_fit_function',
                                              **all_session_args)

        # We need to set sample_weights to None since there are sample weight
        # placeholders that are created with default values.
        sample_weights = [None] * (
            len(model.outputs) * current_strategy.num_replicas_in_sync)
        if not isinstance(K.learning_phase(), int):
            # The trailing `1` is fed after the sample weights (presumably
            # the training-phase value; verify against the backend).
            ins = dataset_inputs + dataset_targets + sample_weights + [1]
        else:
            ins = dataset_inputs + dataset_targets

        do_validation = bool(validation_steps)

        # Copy the weights from the original model to each of the replicated
        # models.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

        callbacks = cbks.configure_callbacks(callbacks,
                                             model,
                                             do_validation=do_validation,
                                             val_inputs=None,
                                             val_targets=None,
                                             epochs=epochs,
                                             steps_per_epoch=steps_per_epoch,
                                             verbose=verbose)
        out_labels = model.metrics_names or []
        callbacks.on_train_begin()

        assert steps_per_epoch is not None

        for epoch in range(initial_epoch, epochs):
            # Reset stateful metrics
            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            epoch_logs = {}
            for step_index in range(steps_per_epoch):
                batch_logs = {'batch': step_index, 'size': 1}
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_fit_function(ins)
                except errors.OutOfRangeError:
                    # The product must be parenthesised: `%` and `*` share
                    # precedence and group left-to-right, so without the
                    # parentheses the message would be formatted with
                    # `steps_per_epoch` alone and then repeated `epochs`
                    # times.
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o
                callbacks.on_batch_end(step_index, batch_logs)
                if callbacks.model.stop_training:
                    break
            if do_validation:
                val_outs = test_loop(model,
                                     val_iterator,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

            callbacks.on_epoch_end(epoch, epoch_logs)
            if callbacks.model.stop_training:
                break
        callbacks.on_train_end()

        # Copy the weights back from the replicated model to the original model.
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
        return model.history
コード例 #9
0
def _experimental_fit_loop(model,
                           iterator,
                           epochs=100,
                           verbose=1,
                           callbacks=None,
                           initial_epoch=0,
                           steps_per_epoch=None,
                           val_iterator=None,
                           validation_steps=None):
    """Fit loop for training with TPU DistributionStrategy.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator that returns inputs and targets
        epochs: Number of times to iterate over the data
        verbose: Integer, Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Must be specified; `None` raises `ValueError`.
        val_iterator: Iterator for validation data.
        validation_steps: Number of steps to run validation for.
            Validation is only run when this value is truthy.

    Returns:
        `model.history`, after the trained weights have been copied back
        from the replicated model to `model`.

    Raises:
        ValueError: if `steps_per_epoch` is `None`.
    """
    current_strategy = model._distribution_strategy

    K.get_session().run(current_strategy.initialize())

    def _per_device_fit_function(model):
        # Build the fit function on this replica and expose its pieces so
        # they can be merged into a single function across replicas.
        model._make_fit_function()
        return (model._fit_function.inputs, model._fit_function.outputs,
                model._fit_function.updates_op,
                model._fit_function.session_kwargs)

    # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
    K.set_learning_phase(1)
    out_labels = model.metrics_names or []

    def step_fn(ctx, inputs):
        """Clones the model and calls make_fit_function."""
        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        inputs, targets = inputs
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=True,
                                inputs=inputs,
                                targets=targets,
                                mode=_Mode.TRAIN)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_fit_function, args=(model._grouped_model_train, ))
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)
        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_fit_function',
                                 **all_session_args)

        for label, output in zip(out_labels, combined_fn.outputs):
            if label == 'loss':
                reduce_op = distribute_lib.get_loss_reduction()
            else:
                # We reduce all other metrics using mean for now. This is temporary
                # workaround until new metrics are in place.
                reduce_op = ds_reduce_util.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately
        return combined_fn.updates_op

    # Add initial dummy values for loss and other metric tensors.
    initial_loop_values = {}
    initial_loop_values['loss'] = constant_op.constant(1e7)
    for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
        initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

    if steps_per_epoch is None:
        raise ValueError('`steps_per_epoch` should be specified when calling '
                         '`fit` on the model.')
    # Backend variable holding the number of steps per session run; it is
    # reloaded below whenever the required step count changes.
    steps_per_run = K.variable(value=min(
        steps_per_epoch, current_strategy.extended.steps_per_run),
                               dtype='int32',
                               name='steps_per_run')

    with current_strategy.scope():
        ctx = current_strategy.run_steps_on_dataset(
            step_fn,
            iterator,
            iterations=steps_per_run,
            initial_loop_values=initial_loop_values)

    train_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    do_validation = bool(validation_steps)

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(
            model._grouped_model_train)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)
    callbacks = cbks.configure_callbacks(callbacks,
                                         model,
                                         do_validation=do_validation,
                                         val_inputs=None,
                                         val_targets=None,
                                         epochs=epochs,
                                         steps_per_epoch=steps_per_epoch,
                                         verbose=verbose)

    # Calculate the steps each time on the device: as many full chunks of
    # `steps_per_run` as fit in an epoch, plus one shorter chunk for any
    # remainder.
    steps_to_run = [current_strategy.extended.steps_per_run] * (
        steps_per_epoch // current_strategy.extended.steps_per_run)
    if steps_per_epoch % current_strategy.extended.steps_per_run:
        steps_to_run.append(steps_per_epoch %
                            current_strategy.extended.steps_per_run)

    callbacks.on_train_begin()
    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        step_index = 0
        prev_step_count = None
        for step_count in steps_to_run:
            batch_logs = {
                'batch': step_index,
                'size': 1,
                'num_steps': step_count
            }
            callbacks.on_batch_begin(step_index, batch_logs)
            # Only reload the variable when the chunk size actually changes.
            if prev_step_count is None or step_count != prev_step_count:
                steps_per_run.load(step_count, K.get_session())
                prev_step_count = step_count
            try:
                _, outputs = K.get_session().run([train_op, output_tensors])
            except errors.OutOfRangeError:
                # The product must be parenthesized: `%` and `*` share the
                # same precedence and group left-to-right, so without the
                # parentheses the message would be formatted with
                # `steps_per_epoch` and then repeated `epochs` times.
                logging.warning(
                    'Your dataset iterator ran out of data; '
                    'interrupting training. Make sure that your dataset '
                    'can generate at least `steps_per_epoch * epochs` '
                    'batches (in this case, %d batches).' %
                    (steps_per_epoch * epochs))
                break

            batch_logs.update(outputs)
            callbacks.on_batch_end(step_index, batch_logs)
            step_index = step_index + step_count
            if callbacks.model.stop_training:
                break

        if do_validation:
            logging.info('Running validation at fit epoch: %s', epoch)

            # Since we create a new clone from the original model we need to copy
            # the weights back to the original model before we can run validation.
            with current_strategy.scope():
                updated_weights = current_strategy.unwrap(
                    model._grouped_model_train)[0].get_weights()
                model.set_weights(updated_weights)

            val_outs = _experimental_test_loop(
                model,
                val_iterator,
                steps=validation_steps,
                verbose=verbose,
                initialize_finalize_strategy=False)
            if not isinstance(val_outs, list):
                val_outs = [val_outs]
            # Same labels assumed.
            for label, val_out in zip(out_labels, val_outs):
                epoch_logs['val_' + label] = val_out

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model_train)[0].get_weights()
        model.set_weights(updated_weights)

    K.get_session().run(current_strategy.finalize())
    return model.history
コード例 #10
0
def _experimental_fit_loop(model,
                           iterator,
                           epochs=100,
                           initial_epoch=0,
                           steps_per_epoch=None):
    """fit function when using TPU DistributionStrategy for training.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator that returns inputs and targets
        epochs: Number of times to iterate over the data
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Must not be `None` (checked with an `assert`).

    Returns:
        Returns `None`. The trained weights are copied back onto `model`.
    """
    current_strategy = model._distribution_strategy

    # TODO(priyag): Add validation that shapes are fully defined for TPU case.

    # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
    K.get_session().run(current_strategy.initialize())

    def _per_device_train_function(model):
        # Build the train function on this tower and expose its pieces so
        # they can be merged into a single function across towers.
        model._make_train_function()
        return (model.train_function.inputs, model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
    K.set_learning_phase(1)

    def step_fn(ctx, inputs, targets):
        """Clones the model and calls make_train_function."""
        # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
        clone_model_on_towers(model,
                              current_strategy,
                              make_callback_model=True,
                              inputs=inputs,
                              targets=targets)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)
        combined_fn = K.Function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_train_function',
                                 **all_session_args)

        # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
        # something else for different outputs.
        out_labels = model.metrics_names or []
        for label, output in zip(out_labels, combined_fn.outputs):
            ctx.set_last_step_output(
                label, output, aggregation=distribute_lib.get_loss_reduction())

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately
        return combined_fn.updates_op

    # Add initial dummy values for loss and other metric tensors.
    initial_loop_values = {}
    initial_loop_values['loss'] = constant_op.constant(1e7)
    for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
        initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

    with current_strategy.scope():
        # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
        # steps_per_epoch and number of epochs.
        ctx = current_strategy.run_steps_on_dataset(
            step_fn,
            iterator,
            iterations=current_strategy.steps_per_run,
            initial_loop_values=initial_loop_values)

    train_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    assert steps_per_epoch is not None

    # TODO(priyag, sourabhbajaj): Add callbacks support.
    # TODO(priyag, sourabhbajaj): Add validation.
    for epoch in range(initial_epoch, epochs):
        # `step_index` advances by `steps_per_run` per session run.
        for step_index in range(0, steps_per_epoch,
                                current_strategy.steps_per_run):
            try:
                _, outs = K.get_session().run([train_op, output_tensors])
                # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
                # summaries through callbacks.
                print('Epoch: {}, step_index: {}, loss: {}'.format(
                    epoch, step_index, outs['loss']))
                for label, out in outs.items():
                    print(label, ': ', out)
            except errors.OutOfRangeError:
                # The product must be parenthesized: `%` and `*` share the
                # same precedence and group left-to-right, so without the
                # parentheses the message would be formatted with
                # `steps_per_epoch` and then repeated `epochs` times.
                logging.warning(
                    'Your dataset iterator ran out of data; '
                    'interrupting training. Make sure that your dataset '
                    'can generate at least `steps_per_epoch * epochs` '
                    'batches (in this case, %d batches).' %
                    (steps_per_epoch * epochs))
                break

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)

    K.get_session().run(current_strategy.finalize())
コード例 #11
0
def fit_loop(model,
             iterator,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_iterator=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """fit function when using DistributionStrategy for training.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        epochs: Number of times to iterate over the data
        verbose: Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        val_iterator: Iterator for validation data.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. When `None`, no training steps are run and only
            the epoch-level callbacks fire.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Ignored with the default value of `None`.

    Returns:
        `History` object.

    Raises:
        ValueError: if `validation_steps` is set while `steps_per_epoch`
            is `None`.
    """
    current_strategy = model._distribution_strategy

    def _per_device_train_function(model):
        # Build the train function on this tower and expose its pieces so
        # they can be merged into a single function across towers.
        model._make_train_function()
        return (model.train_function.inputs, model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    inputs, targets = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_train_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        # Unwrap all the per device values returned from `call_for_each_tower`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to be
        # unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

    # Create a train function that is composed of all the parameters above.
    distributed_train_function = K.Function(all_inputs,
                                            all_outputs,
                                            updates=all_updates,
                                            name='distributed_train_function',
                                            **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [
        None for _ in range(len(model.outputs) * current_strategy.num_towers)
    ]
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = dataset_inputs + dataset_targets + sample_weights + [1]
    else:
        ins = dataset_inputs + dataset_targets

    do_validation = False
    if validation_steps:
        do_validation = True
        if steps_per_epoch is None:
            raise ValueError('Can only use `validation_steps` '
                             'when doing step-wise '
                             'training, i.e. `steps_per_epoch` '
                             'must be set.')

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    callbacks = cbks.configure_callbacks(callbacks,
                                         model,
                                         do_validation=do_validation,
                                         val_inputs=None,
                                         val_targets=None,
                                         epochs=epochs,
                                         steps_per_epoch=steps_per_epoch,
                                         verbose=verbose)
    out_labels = model.metrics_names or []
    callbacks.on_train_begin()
    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        # Bind `epoch_logs` unconditionally: `on_epoch_end` below reads it
        # even when `steps_per_epoch` is None, which previously raised a
        # NameError.
        epoch_logs = {}
        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {'batch': step_index, 'size': 1}
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_train_function(ins)
                except errors.OutOfRangeError:
                    # The product must be parenthesized: `%` and `*` share
                    # the same precedence and group left-to-right, so
                    # without the parentheses the message would be formatted
                    # with `steps_per_epoch` and then repeated `epochs`
                    # times.
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]

                outs = _aggregate_metrics_across_towers(
                    current_strategy.num_towers, out_labels, outs)
                for label, out in zip(out_labels, outs):
                    batch_logs[label] = out
                callbacks.on_batch_end(step_index, batch_logs)
                if callbacks.model.stop_training:
                    break
            if do_validation:
                val_outs = test_loop(model,
                                     val_iterator,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for label, val_out in zip(out_labels, val_outs):
                    epoch_logs['val_' + label] = val_out

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
    return model.history
コード例 #12
0
def fit_loop(
    model,
    inputs,
    targets,
    epochs=100,
    verbose=1,
    callbacks=None,
    val_inputs=None,
    val_targets=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None):
  """fit function when using DistributionStrategy for training.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      epochs: Number of times to iterate over the data
      verbose: Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_inputs: List of input arrays.
      val_targets: List of target arrays.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. When `None`, no training steps are run and only
          the epoch-level callbacks fire.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: if `validation_steps` is set while `steps_per_epoch`
          is `None`.
  """
  current_strategy = model._distribution_strategy
  def _per_device_train_function(model):
    # Build the train function on this tower and expose its pieces so they
    # can be merged into a single function across towers.
    model._make_train_function()
    return (model.train_function.inputs,
            model.train_function.outputs,
            model.train_function.updates_op,
            model.train_function.session_kwargs)

  with current_strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_device_train_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    # Unwrap all the per device values returned from `call_for_each_tower`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args, with_loss_tensor=True)

    # Dataset inputs and targets are also per devices values that need to be
    # unwrapped.
    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  # Create a train function that is composed of all the parameters above.
  distributed_train_function = K.Function(
      all_inputs, all_outputs,
      updates=all_updates,
      name='distributed_train_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [None for _ in range(len(model.outputs) *
                                        current_strategy.num_towers)]
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + dataset_targets + sample_weights + [1]
  else:
    ins = dataset_inputs + dataset_targets

  do_validation = False
  if validation_steps:
    do_validation = True
    if steps_per_epoch is None:
      raise ValueError('Can only use `validation_steps` '
                       'when doing step-wise '
                       'training, i.e. `steps_per_epoch` '
                       'must be set.')

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      val_inputs=None,
      val_targets=None,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose)
  out_labels = model.metrics_names or []
  callbacks.on_train_begin()
  for epoch in range(initial_epoch, epochs):
    callbacks.on_epoch_begin(epoch)
    # Bind `epoch_logs` unconditionally: `on_epoch_end` below reads it even
    # when `steps_per_epoch` is None, which previously raised a NameError.
    epoch_logs = {}
    if steps_per_epoch is not None:
      for step_index in range(steps_per_epoch):
        batch_logs = {'batch': step_index, 'size': 1}
        callbacks.on_batch_begin(step_index, batch_logs)
        try:
          outs = distributed_train_function(ins)
        except errors.OutOfRangeError:
          # The product must be parenthesized: `%` and `*` share the same
          # precedence and group left-to-right, so without the parentheses
          # the message would be formatted with `steps_per_epoch` and then
          # repeated `epochs` times.
          logging.warning('Your dataset iterator ran out of data; '
                          'interrupting training. Make sure that your dataset '
                          'can generate at least `steps_per_epoch * epochs` '
                          'batches (in this case, %d batches).' %
                          (steps_per_epoch * epochs))
          break

        if not isinstance(outs, list):
          outs = [outs]

        outs = _aggregate_metrics_across_towers(
            len(current_strategy._devices), out_labels, outs)
        for label, out in zip(out_labels, outs):
          batch_logs[label] = out
        callbacks.on_batch_end(step_index, batch_logs)
        if callbacks.model.stop_training:
          break
      if do_validation:
        val_outs = test_loop(
            model,
            val_inputs,
            val_targets,
            steps=validation_steps,
            verbose=0)
        if not isinstance(val_outs, list):
          val_outs = [val_outs]
        # Same labels assumed.
        for label, val_out in zip(out_labels, val_outs):
          epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks.on_train_end()

  # Copy the weights back from the replicated model to the original model.
  with current_strategy.scope():
    updated_weights = current_strategy.unwrap(
        model._grouped_model)[0].get_weights()
    model.set_weights(updated_weights)
  return model.history
コード例 #13
0
def predict_loop(model, inputs, verbose=0, steps=None):
  """Runs distributed prediction for a fixed number of steps.

  Builds a single predict function from the per-tower predict functions of
  the model's DistributionStrategy, copies the original model's weights to
  the distributed model, then calls that function `steps` times and joins
  the per-step results along axis 0.

  Arguments:
      model: Keras Model instance.
      inputs: per-device input values; they are flattened before being fed
          to the combined predict function.
      verbose: verbosity mode; a progress bar is shown when equal to 1.
      steps: Total number of steps (batches of samples) to run. With the
          default value of `None` the prediction loop is skipped and
          `None` is returned.

  Returns:
      Array of predictions (if the predict function has a single output)
      or list of arrays of predictions (otherwise).
  """
  strategy = model._distribution_strategy

  def _predict_fn_parts(tower_model):
    # Build the predict function on this tower and hand back its pieces.
    tower_model._make_predict_function()
    predict_fn = tower_model.predict_function
    return (predict_fn.inputs, predict_fn.outputs, predict_fn.updates_op,
            predict_fn.session_kwargs)

  with strategy.scope():
    per_tower_parts = strategy.call_for_each_tower(
        _predict_fn_parts, model._grouped_model)
    (tower_inputs, tower_outputs, tower_updates,
     tower_session_args) = per_tower_parts

    # Merge the per-tower values into flat collections for one function.
    unwrapped = distributed_training_utils.unwrap_values(
        strategy, tower_inputs, tower_outputs, tower_updates,
        tower_session_args)
    fn_inputs, fn_outputs, fn_updates, fn_session_args = unwrapped

    flat_inputs = distributed_training_utils.flatten_perdevice_values(
        strategy, inputs)

  run_prediction = K.Function(
      fn_inputs, fn_outputs,
      updates=fn_updates,
      name='distributed_predict_function',
      **fn_session_args)

  # Append the learning-phase value (0) only when the model uses a learning
  # phase that is not already a plain int.
  feed = flat_inputs
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    feed = flat_inputs + [0]

  show_progress = verbose == 1
  if show_progress:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  source_weights = model.get_weights()
  with strategy.scope():
    replica = strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(strategy, replica, source_weights)

  if steps is None:
    return None

  # The number of samples is not known up front, so the results cannot be
  # pre-allocated: keep one list of per-step arrays per output and
  # concatenate at the end.
  per_output_batches = []
  for step in range(steps):
    step_outs = run_prediction(feed)
    if not isinstance(step_outs, list):
      step_outs = [step_outs]
    if step == 0:
      per_output_batches.extend([] for _ in step_outs)
    for position, step_out in enumerate(step_outs):
      per_output_batches[position].append(step_out)
    if show_progress:
      progbar.update(step + 1)

  if len(per_output_batches) == 1:
    return np.concatenate(per_output_batches[0], axis=0)
  return [np.concatenate(batches, axis=0) for batches in per_output_batches]
コード例 #14
0
def test_loop(model, inputs, targets, verbose=0, steps=None):
  """Evaluates a model that uses DistributionStrategy.

  A test function is built on every tower and merged into a single
  `K.Function`. The weights of the original model are copied to the
  replicated model, the merged function is run `steps` times and the
  aggregated per-batch results are averaged over `steps`.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      verbose: Verbosity mode. A progress bar is displayed when it is 1.
      steps: Total number of steps (batches of samples) to evaluate. With
          the default value of `None` no batch is run and an empty list is
          returned.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.
  """
  current_strategy = model._distribution_strategy

  def _per_device_test_function(model):
    # Build the test function on this tower's model and expose its pieces so
    # they can be merged across towers.
    model._make_test_function()
    return (model.test_function.inputs,
            model.test_function.outputs,
            model.test_function.updates_op,
            model.test_function.session_kwargs)

  with current_strategy.scope():
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_test_function, model._grouped_model)

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args, with_loss_tensor=True)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  distributed_test_function = K.Function(
      all_inputs, all_outputs,
      updates=all_updates,
      name='distributed_test_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [None for _ in range(len(model.outputs) *
                                        current_strategy.num_towers)]
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    # The trailing 0 is the value fed for the learning phase.
    ins = dataset_inputs + dataset_targets + sample_weights + [0]
  else:
    ins = dataset_inputs + dataset_targets

  outs = []
  if verbose == 1:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  if steps is not None:
    for step in range(steps):
      batch_outs = distributed_test_function(ins)
      # Use the public `num_towers` property (already relied on for the
      # sample weights above) instead of the private `_devices` attribute.
      batch_outs = _aggregate_metrics_across_towers(
          current_strategy.num_towers, model.metrics_names, batch_outs)
      # Treat a single result as a one-element list so both cases share the
      # same accumulation code.
      if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]
      if step == 0:
        outs = [0.] * len(batch_outs)
      for i, batch_out in enumerate(batch_outs):
        outs[i] += batch_out
      if verbose == 1:
        progbar.update(step + 1)
    # Turn the accumulated sums into per-step averages.
    for i, _ in enumerate(outs):
      outs[i] /= steps

  if len(outs) == 1:
    return outs[0]
  return outs
コード例 #15
0
def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
  """Predict loop for predicting with TPU DistributionStrategy.

  Initializes the strategy, builds the prediction step through
  `run_steps_on_dataset`, copies the original model's weights to the
  replicated prediction model, runs `steps` prediction steps in the Keras
  session and finalizes the strategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. Any value >= 1 displays a
          progress bar.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished. Must not be `None`
          (checked with an `assert`).

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
  current_strategy = model._distribution_strategy
  K.get_session().run(current_strategy.initialize())

  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
  K.set_learning_phase(0)

  def _per_device_predict_function(model):
    # Build the predict function on this replica's model and expose its
    # pieces so they can be merged across replicas.
    model._make_predict_function()
    return (model.predict_function.inputs,
            model.predict_function.outputs,
            model.predict_function.updates_op,
            model.predict_function.session_kwargs)

  def step_fn(ctx, *inputs):
    """Clones the model and calls make_predict_function."""

    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=False,
        inputs=inputs,
        mode=_Mode.PREDICT)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_predict_function, args=(model._grouped_model_predict,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_predict_function',
        **all_session_args)

    # Register every model output under its name so it can be fetched from
    # `ctx.last_step_outputs` later.
    for label, output in zip(model.output_names, combined_fn.outputs):
      ctx.set_last_step_output(label, output)

    return combined_fn.updates_op

  # Add initial dummy values for outputs.
  initial_loop_values = {}
  batch_dimension = distributed_training_utils.get_batch_dimension(iterator)
  for name, tensor in zip(model.output_names, model.outputs):
    # TODO(priyag): This is a workaround as we do not know the batch dimension
    # of the model's output at this point.
    shape = tensor_shape.TensorShape(tensor.shape.dims)
    shape.dims = [batch_dimension] + shape.dims[1:]
    initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)

  with current_strategy.scope():
    # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
    ctx = current_strategy.run_steps_on_dataset(
        step_fn, iterator, iterations=1,
        initial_loop_values=initial_loop_values)

  predict_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  # Create the progress bar under the same condition that later updates it.
  # Previously it was created only for `verbose == 1` but updated for
  # `verbose >= 1`, so e.g. `verbose=2` failed on an undefined `progbar`.
  show_progress = verbose >= 1
  if show_progress:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model_predict)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  assert steps is not None
  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  unconcatenated_outs = [[] for _ in model.outputs]
  for step in range(steps):
    _, batch_outs = K.get_session().run([predict_op, output_tensors])
    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
    for i, label in enumerate(model.output_names):
      unconcatenated_outs[i].extend(batch_outs[label])
    if show_progress:
      progbar.update(step + 1)

  K.get_session().run(current_strategy.finalize())

  if len(unconcatenated_outs) == 1:
    return np.concatenate(unconcatenated_outs[0], axis=0)
  return [
      np.concatenate(per_output_batches, axis=0)
      for per_output_batches in unconcatenated_outs
  ]
コード例 #16
0
def test_loop(model, iterator, verbose=0, steps=None):
    """Run evaluation for a model that uses DistributionStrategy.

    The model is cloned onto the towers, a test function is built on each
    tower and merged into one `K.Function`, the original weights are copied
    to the replicated model and the aggregated per-batch outputs are
    averaged over `steps` batches.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        verbose: verbosity mode; a progress bar is shown when it equals 1.
        steps: Total number of steps (batches of samples) to evaluate.
            When left at the default `None`, no batch is run and an empty
            list is returned.

    Returns:
        Scalar loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the scalar outputs.
    """
    strategy = model._distribution_strategy

    clone_model_on_towers(model, strategy)

    def _per_device_test_function(model):
        # Build the per-tower test function and hand back its components.
        model._make_test_function()
        return (model.test_function.inputs,
                model.test_function.outputs,
                model.test_function.updates_op,
                model.test_function.session_kwargs)

    inputs, targets = _get_input_from_iterator(iterator, model)
    with strategy.scope():
        (per_tower_inputs, per_tower_outputs, per_tower_updates,
         per_tower_session_args) = strategy.call_for_each_tower(
             _per_device_test_function, model._grouped_model)

        (fn_inputs, fn_outputs, fn_updates,
         fn_session_args) = distributed_training_utils.unwrap_values(
             strategy,
             per_tower_inputs,
             per_tower_outputs,
             per_tower_updates,
             per_tower_session_args,
             with_loss_tensor=True)

        flat_inputs = distributed_training_utils.flatten_perdevice_values(
            strategy, inputs)
        flat_targets = distributed_training_utils.flatten_perdevice_values(
            strategy, targets)

    merged_test_fn = K.Function(
        fn_inputs,
        fn_outputs,
        updates=fn_updates,
        name='distributed_test_function',
        **fn_session_args)

    # The sample weight placeholders are created with default values, so a
    # `None` is supplied for each of them.
    weight_slots = [None] * (len(model.outputs) * strategy.num_towers)
    feed_values = flat_inputs + flat_targets
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        # Learning phase is a tensor here: append the weight slots and a 0
        # for the learning phase.
        feed_values = feed_values + weight_slots + [0]

    show_bar = verbose == 1
    if show_bar:
        progbar = Progbar(target=steps)

    # Push the original model's weights onto the replicated model.
    source_weights = model.get_weights()
    with strategy.scope():
        replica_model = strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(
            strategy, replica_model, source_weights)

    # Nothing to evaluate without a step count.
    if steps is None:
        return []

    totals = []
    for step in range(steps):
        step_result = merged_test_fn(feed_values)
        step_result = _aggregate_metrics_across_towers(
            strategy.num_towers, model.metrics_names, step_result)
        # Handle a lone value the same way as a list of values.
        if not isinstance(step_result, list):
            step_result = [step_result]
        if step == 0:
            totals.extend(0. for _ in step_result)
        for position, value in enumerate(step_result):
            totals[position] += value
        if show_bar:
            progbar.update(step + 1)

    # Convert accumulated sums into averages over the steps.
    for position, _ in enumerate(totals):
        totals[position] /= steps

    return totals[0] if len(totals) == 1 else totals
コード例 #17
0
def _experimental_fit_loop(model,
                           iterator,
                           epochs=100,
                           verbose=1,
                           callbacks=None,
                           initial_epoch=0,
                           steps_per_epoch=None):
    """Fit loop for training with TPU DistributionStrategy.

    Initializes the strategy, builds the training step through
    `run_steps_on_dataset`, copies the original model's weights to the
    replicated model, runs the training op for every epoch, copies the
    trained weights back to the original model and finalizes the strategy.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator that returns inputs and targets
        epochs: Number of times to iterate over the data
        verbose: Integer, Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training. Must be
            empty or `None`; see Raises.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Must not be `None` (checked with an `assert`).

    Returns:
        The `history` attribute of `model`.

    Raises:
        NotImplementedError: if a non-empty `callbacks` is passed.
    """
    current_strategy = model._distribution_strategy

    # TODO(priyag): Add validation that shapes are fully defined for TPU case.

    K.get_session().run(current_strategy.initialize())

    def _per_device_train_function(model):
        # Build the train function on this tower's model and expose its
        # pieces so they can be merged across towers.
        model._make_train_function()
        return (model.train_function.inputs, model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
    K.set_learning_phase(1)

    def step_fn(ctx, inputs, targets):
        """Clones the model and calls make_train_function."""
        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        clone_model_on_towers(model,
                              current_strategy,
                              make_callback_model=True,
                              inputs=inputs,
                              targets=targets)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)
        combined_fn = K.Function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_train_function',
                                 **all_session_args)

        out_labels = model.metrics_names or []
        for label, output in zip(out_labels, combined_fn.outputs):
            if label == 'loss':
                aggregation = distribute_lib.get_loss_reduction()
            else:
                # We aggregate all other metrics using mean for now. This is temporary
                # workaround until new metrics are in place.
                aggregation = variable_scope.VariableAggregation.MEAN
            ctx.set_last_step_output(label, output, aggregation)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately
        return combined_fn.updates_op

    # Add initial dummy values for loss and other metric tensors.
    initial_loop_values = {}
    initial_loop_values['loss'] = constant_op.constant(1e7)
    for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
        initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

    with current_strategy.scope():
        # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
        # steps_per_epoch and number of epochs.
        ctx = current_strategy.run_steps_on_dataset(
            step_fn,
            iterator,
            iterations=current_strategy.steps_per_run,
            initial_loop_values=initial_loop_values)

    train_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    assert steps_per_epoch is not None

    # TODO(sourabhbajaj): Convert this into a proper validation function
    if callbacks:
        raise NotImplementedError(
            'Callbacks are not supported with TPUStrategy right now.')

    callbacks = cbks.configure_callbacks(callbacks,
                                         model,
                                         do_validation=False,
                                         val_inputs=None,
                                         val_targets=None,
                                         epochs=epochs,
                                         steps_per_epoch=steps_per_epoch,
                                         verbose=verbose)
    # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
    # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
    # TODO(priyag, sourabhbajaj): Add validation.
    callbacks.on_train_begin()
    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        for step_index in range(0, steps_per_epoch,
                                current_strategy.steps_per_run):
            # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
            # and batch_size
            batch_logs = {'batch': step_index, 'size': 1}
            callbacks.on_batch_begin(step_index, batch_logs)
            try:
                _, outputs = K.get_session().run([train_op, output_tensors])
            except errors.OutOfRangeError:
                # The product must be parenthesized: `%` and `*` share the
                # same precedence, so without the parentheses the message
                # would be formatted with `steps_per_epoch` alone and then
                # repeated `epochs` times.
                logging.warning(
                    'Your dataset iterator ran out of data; '
                    'interrupting training. Make sure that your dataset '
                    'can generate at least `steps_per_epoch * epochs` '
                    'batches (in this case, %d batches).' %
                    (steps_per_epoch * epochs))
                break

            batch_logs.update(outputs)
            callbacks.on_batch_end(step_index, batch_logs)
            if callbacks.model.stop_training:
                break

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)

    K.get_session().run(current_strategy.finalize())
    return model.history
コード例 #18
0
def predict_loop(model, iterator, verbose=0, steps=None):
    """Run prediction for a model that uses DistributionStrategy.

    The model is cloned onto the towers, a predict function is built on each
    tower and merged into one `K.Function`, the original weights are copied
    to the replicated model and the merged function is run `steps` times.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        verbose: verbosity mode; a progress bar is shown when it equals 1.
        steps: Total number of steps (batches of samples) to predict.
            When left at the default `None`, no batch is run and `None` is
            returned.

    Returns:
        Array of predictions (if the model has a single output)
        or list of arrays of predictions
        (if the model has multiple outputs).
    """
    strategy = model._distribution_strategy

    clone_model_on_towers(model, strategy)

    def _per_device_predict_function(model):
        # Build the per-tower predict function and hand back its components.
        model._make_predict_function()
        return (model.predict_function.inputs,
                model.predict_function.outputs,
                model.predict_function.updates_op,
                model.predict_function.session_kwargs)

    inputs, _ = _get_input_from_iterator(iterator, model)
    with strategy.scope():
        (per_tower_inputs, per_tower_outputs, per_tower_updates,
         per_tower_session_args) = strategy.call_for_each_tower(
             _per_device_predict_function, model._grouped_model)

        (fn_inputs, fn_outputs, fn_updates,
         fn_session_args) = distributed_training_utils.unwrap_values(
             strategy, per_tower_inputs, per_tower_outputs,
             per_tower_updates, per_tower_session_args)

        flat_inputs = distributed_training_utils.flatten_perdevice_values(
            strategy, inputs)

    merged_predict_fn = K.Function(
        fn_inputs,
        fn_outputs,
        updates=fn_updates,
        name='distributed_predict_function',
        **fn_session_args)

    feed_values = flat_inputs
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        # Learning phase is a tensor here: feed 0 for it.
        feed_values = flat_inputs + [0]

    show_bar = verbose == 1
    if show_bar:
        progbar = Progbar(target=steps)

    # Push the original model's weights onto the replicated model.
    source_weights = model.get_weights()
    with strategy.scope():
        replica_model = strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(
            strategy, replica_model, source_weights)

    # Nothing to predict without a step count.
    if steps is None:
        return None

    # The number of samples is not known up front, so the result arrays
    # cannot be pre-allocated: collect one array per batch for every output
    # and concatenate them at the end.
    collected = []
    for step in range(steps):
        step_result = merged_predict_fn(feed_values)
        if not isinstance(step_result, list):
            step_result = [step_result]
        if step == 0:
            collected.extend([] for _ in step_result)
        for position, value in enumerate(step_result):
            collected[position].append(value)
        if show_bar:
            progbar.update(step + 1)

    if len(collected) == 1:
        return np.concatenate(collected[0], axis=0)
    return [np.concatenate(batches, axis=0) for batches in collected]
コード例 #19
0
def test_loop(model, iterator, verbose=0, steps=None):
    """Test loop for evaluating with DistributionStrategy.

    For a strategy whose class is named `TPUStrategy` the work is delegated
    to `_experimental_test_loop`. Otherwise an eval function is built on
    every replica and merged into a single Keras function, the original
    model's weights are copied to the replicated model, and the merged
    function is run `steps` times. The loss (index 0) is averaged over the
    steps; every other output keeps the value from the last step.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        verbose: Integer, Verbosity mode 0 or 1. Any value >= 1 displays a
            progress bar.
        steps: Total number of steps (batches of samples)
            before declaring predictions finished. Must not be `None`
            on the non-TPU path (checked with an `assert`).

    Returns:
        Scalar loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the outputs.
    """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_test_loop(model, iterator, verbose, steps)

    if not model._grouped_model:
        clone_model_on_replicas(model, current_strategy)

    def _per_device_eval_function(model):
        # Build the eval function on this replica's model and expose its
        # pieces so they can be merged across replicas.
        model._make_eval_function()
        return (model._eval_function.inputs, model._eval_function.outputs,
                model._eval_function.updates_op,
                model._eval_function.session_kwargs)

    # The sample weights coming from the iterator are not used; `None`
    # placeholders are fed instead (see below).
    inputs, targets, _ = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_eval_function, args=(model._grouped_model, ))

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

        distributed_test_function = K.function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name='distributed_test_function',
            **all_session_args)

        # We need to set sample_weights to None since there are sample weight
        # placeholders that are created with default values.
        sample_weights = [
            None for _ in range(
                len(model.outputs) * current_strategy.num_replicas_in_sync)
        ]
        ins = dataset_inputs + dataset_targets + sample_weights

        for m in model.stateful_metric_functions:
            m.reset_states()

        outs = []
        # Create the progress bar under the same condition that later
        # updates it. Previously it was created only for `verbose == 1` but
        # updated for `verbose >= 1`, so e.g. `verbose=2` failed on an
        # undefined `progbar`.
        show_progress = verbose >= 1
        if show_progress:
            progbar = Progbar(target=steps)

        # Copy the weights from the original model to each of the replicated models.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

        assert steps is not None
        for step in range(steps):
            batch_outs = distributed_test_function(ins)
            if isinstance(batch_outs, list):
                if step == 0:
                    outs = [0.] * len(batch_outs)
                outs[0] += batch_outs[0]  # index 0 = 'loss'
                # All remaining outputs are overwritten with the latest
                # values rather than accumulated.
                outs[1:] = batch_outs[1:]
            else:
                if step == 0:
                    outs.append(0.)
                outs[0] += batch_outs  # index 0 = 'loss'
            if show_progress:
                progbar.update(step + 1)
        outs[0] /= steps  # index 0 = 'loss'

        if len(outs) == 1:
            return outs[0]
        return outs
コード例 #20
0
def _experimental_fit_loop(
    model,
    iterator,
    epochs=100,
    verbose=1,
    callbacks=None,
    initial_epoch=0,
    steps_per_epoch=None,
    val_iterator=None,
    validation_steps=None):
  """Fit loop for training with TPU DistributionStrategy.

  Initializes the strategy, builds the training step through
  `run_steps_on_dataset`, copies the original model's weights to the
  replicated training model, runs the training op for every epoch
  (optionally followed by validation), copies the trained weights back to
  the original model and finalizes the strategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator that returns inputs and targets
      epochs: Number of times to iterate over the data
      verbose: Integer, Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Must not be `None`.
      val_iterator: Iterator for validation data.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors). Validation is
          skipped when this is `None` or 0.

  Returns:
      The `history` attribute of `model`.

  Raises:
      ValueError: if `steps_per_epoch` is `None`.
  """
  current_strategy = model._distribution_strategy

  K.get_session().run(current_strategy.initialize())

  def _per_device_fit_function(model):
    # Build the fit function on this replica's model and expose its pieces
    # so they can be merged across replicas.
    model._make_fit_function()
    return (model._fit_function.inputs, model._fit_function.outputs,
            model._fit_function.updates_op, model._fit_function.session_kwargs)

  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
  K.set_learning_phase(1)
  out_labels = model.metrics_names or []

  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_fit_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets,
        mode=_Mode.TRAIN)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_fit_function, args=(model._grouped_model_train,))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_fit_function',
        **all_session_args)

    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        aggregation = distribute_lib.get_loss_reduction()
      else:
        # We aggregate all other metrics using mean for now. This is temporary
        # workaround until new metrics are in place.
        aggregation = variable_scope.VariableAggregation.MEAN
      ctx.set_last_step_output(label, output, aggregation)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately
    return combined_fn.updates_op

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

  if steps_per_epoch is None:
    raise ValueError('`steps_per_epoch` should be specified when calling '
                     '`fit` on the model.')
  # The number of iterations per run is held in a variable so it can be
  # reloaded with a different value before a run (see the loop below).
  steps_per_run = K.variable(
      value=min(steps_per_epoch, current_strategy.steps_per_run),
      dtype='int32',
      name='steps_per_run')

  with current_strategy.scope():
    ctx = current_strategy.run_steps_on_dataset(
        step_fn, iterator, iterations=steps_per_run,
        initial_loop_values=initial_loop_values)

  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model_train)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      val_inputs=None,
      val_targets=None,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose)

  # Calculate the steps each time on the device.
  steps_to_run = [current_strategy.steps_per_run] * (
      steps_per_epoch // current_strategy.steps_per_run)
  if steps_per_epoch % current_strategy.steps_per_run:
    steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)

  callbacks.on_train_begin()
  for epoch in range(initial_epoch, epochs):
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    for step_count in steps_to_run:
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks.on_batch_begin(step_index, batch_logs)
      # Only reload the variable when the step count actually changes.
      if prev_step_count is None or step_count != prev_step_count:
        steps_per_run.load(step_count, K.get_session())
        prev_step_count = step_count
      try:
        _, outputs = K.get_session().run([train_op, output_tensors])
      except errors.OutOfRangeError:
        # The product must be parenthesized: `%` and `*` share the same
        # precedence, so without the parentheses the message would be
        # formatted with `steps_per_epoch` alone and then repeated `epochs`
        # times.
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      batch_logs.update(outputs)
      callbacks.on_batch_end(step_index, batch_logs)
      step_index = step_index + step_count
      if callbacks.model.stop_training:
        break

    if do_validation:
      logging.info('Running validation at fit epoch: %s', epoch)

      # Since we create a new clone from the original model we need to copy
      # the weights back to the original model before we can run validation.
      with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model_train)[0].get_weights()
        model.set_weights(updated_weights)

      val_outs = _experimental_test_loop(
          model,
          val_iterator,
          steps=validation_steps,
          verbose=verbose,
          initialize_finalize_strategy=False)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks.on_train_end()

  # Copy the weights back from the replicated model to the original model.
  with current_strategy.scope():
    updated_weights = current_strategy.unwrap(
        model._grouped_model_train)[0].get_weights()
    model.set_weights(updated_weights)

  K.get_session().run(current_strategy.finalize())
  return model.history
コード例 #21
0
def _experimental_test_loop(model,
                            iterator,
                            verbose=0,
                            steps=None,
                            initialize_finalize_strategy=True):
    """Test loop for evaluating with TPU DistributionStrategy.

    Builds a one-iteration evaluation op with `run_steps_on_dataset`, copies
    the weights of `model` onto the replicated test model, runs the op
    `steps` times and returns the per-step average of every entry in
    `model.metrics_names`.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        verbose: Integer, Verbosity mode 0 or 1. A progress bar is shown
            only when this is exactly 1; any other value runs silently.
        steps: Total number of steps (batches of samples) to evaluate.
            Must not be `None`.
        initialize_finalize_strategy: Should the strategy initialize and
            finalize functions be called.

    Returns:
        Scalar loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the outputs.

    Raises:
        AssertionError: if `steps` is `None`.
    """
    current_strategy = model._distribution_strategy
    if initialize_finalize_strategy:
        K.get_session().run(current_strategy.initialize())

    def _per_device_eval_function(model):
        model._make_eval_function()
        return (model._eval_function.inputs, model._eval_function.outputs,
                model._eval_function.updates_op,
                model._eval_function.session_kwargs)

    # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
    K.set_learning_phase(0)

    def step_fn(ctx, inputs):
        """Clones the model and calls make_eval_function."""
        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        inputs, targets = inputs
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=False,
                                inputs=inputs,
                                targets=targets,
                                mode=_Mode.TEST)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_eval_function, args=(model._grouped_model_test, ))

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_test_function',
                                 **all_session_args)

        for label, output in zip(model.metrics_names, combined_fn.outputs):
            if label == 'loss':
                reduce_op = distribute_lib.get_loss_reduction()
            else:
                # We reduce all other metrics using mean for now. This is temporary
                # workaround until new metrics are in place.
                reduce_op = ds_reduce_util.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        return combined_fn.updates_op

    # Add initial dummy values for loss and other metric tensors.
    initial_loop_values = {'loss': constant_op.constant(1e7)}
    for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
        initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

    with current_strategy.scope():
        # TODO(priyag): Use steps_per_run when we use new metrics as they will
        # allow handling metric computation at each step using variables.
        ctx = current_strategy.run_steps_on_dataset(
            step_fn,
            iterator,
            iterations=1,
            initial_loop_values=initial_loop_values)

    test_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    # The bar is created and updated under the same condition. Previously it
    # was created for `verbose == 1` but updated for `verbose >= 1`, so a
    # caller passing `verbose=2` hit a NameError on the first step.
    progbar = Progbar(target=steps) if verbose == 1 else None

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(
            model._grouped_model_test)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    assert steps is not None
    # Running sums, one slot per entry of `model.metrics_names`.
    outs = [0.] * len(model.metrics_names)
    for step in range(steps):
        _, batch_outs = K.get_session().run([test_op, output_tensors])
        for i, label in enumerate(model.metrics_names):
            outs[i] += batch_outs[label]
        if progbar is not None:
            progbar.update(step + 1)
    # Turn the sums into per-step averages.
    outs = [total / steps for total in outs]

    if initialize_finalize_strategy:
        K.get_session().run(current_strategy.finalize())

    if len(outs) == 1:
        return outs[0]
    return outs
コード例 #22
0
def test_loop(model, iterator, verbose=0, steps=None):
  """Test loop for evaluating with DistributionStrategy.

  Delegates to `_experimental_test_loop` when the model's strategy class is
  named `TPUStrategy`. Otherwise builds one combined test function over all
  replicas, copies the weights of `model` onto the replicated model and runs
  the function `steps` times.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. A progress bar is shown only
          when this is exactly 1; any other value runs silently.
      steps: Total number of steps (batches of samples) to evaluate.
          Must not be `None`.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The loss (index 0) is averaged over `steps`; the
      remaining entries are the values reported by the last step. The
      attribute `model.metrics_names` will give you the display labels for
      the outputs.

  Raises:
      AssertionError: if `steps` is `None`.
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_test_loop(model, iterator, verbose, steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy)

  def _per_device_eval_function(model):
    model._make_eval_function()
    return (model._eval_function.inputs, model._eval_function.outputs,
            model._eval_function.updates_op,
            model._eval_function.session_kwargs)

  # The sample weights returned by the iterator helper are not used; the
  # placeholders are fed `None` below instead.
  inputs, targets, _ = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_eval_function, args=(model._grouped_model,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args, with_loss_tensor=True)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

    distributed_test_function = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_test_function',
        **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [None for _ in range(
        len(model.outputs) * current_strategy.num_replicas_in_sync)]
    if not isinstance(K.learning_phase(), int):
      ins = dataset_inputs + dataset_targets + sample_weights + [0]
    else:
      ins = dataset_inputs + dataset_targets

    for m in model.stateful_metric_functions:
      m.reset_states()

    outs = []
    # The bar is created and updated under the same condition. Previously it
    # was created for `verbose == 1` but updated for `verbose >= 1`, so
    # `verbose=2` raised a NameError on the first step.
    progbar = Progbar(target=steps) if verbose == 1 else None

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

    assert steps is not None
    for step in range(steps):
      batch_outs = distributed_test_function(ins)
      if isinstance(batch_outs, list):
        if step == 0:
          outs = [0.] * len(batch_outs)
        outs[0] += batch_outs[0]  # index 0 = 'loss'
        # Everything after the loss is overwritten with this step's values.
        outs[1:] = batch_outs[1:]
      else:
        if step == 0:
          outs.append(0.)
        outs[0] += batch_outs  # index 0 = 'loss'
      if progbar is not None:
        progbar.update(step + 1)
    outs[0] /= steps  # index 0 = 'loss'

    if len(outs) == 1:
      return outs[0]
    return outs
コード例 #23
0
def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
    """Predict loop for predicting with TPU DistributionStrategy.

    Builds a one-iteration predict op with `run_steps_on_dataset`, copies the
    weights of `model` onto the replicated predict model, runs the op `steps`
    times and concatenates the collected outputs along axis 0.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        verbose: Integer, Verbosity mode 0 or 1. A progress bar is shown
            only when this is exactly 1; any other value runs silently.
        steps: Total number of steps (batches of samples) to predict.
            Must not be `None`.

    Returns:
        Array of predictions (if the model has a single output)
        or list of arrays of predictions
        (if the model has multiple outputs).

    Raises:
        AssertionError: if `steps` is `None`.
    """
    current_strategy = model._distribution_strategy
    K.get_session().run(current_strategy.initialize())

    # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
    K.set_learning_phase(0)

    def _per_device_predict_function(model):
        model._make_predict_function()
        return (model.predict_function.inputs, model.predict_function.outputs,
                model.predict_function.updates_op,
                model.predict_function.session_kwargs)

    def step_fn(ctx, inputs):
        """Clones the model and calls make_predict_function."""

        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=False,
                                inputs=inputs,
                                mode=_Mode.PREDICT)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_predict_function,
             args=(model._grouped_model_predict, ))

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_predict_function',
                                 **all_session_args)

        for label, output in zip(model.output_names, combined_fn.outputs):
            ctx.set_last_step_output(label, output)

        return combined_fn.updates_op

    # Add initial dummy values for outputs.
    initial_loop_values = {}
    batch_dimension = distributed_training_utils.get_batch_dimension(iterator)
    for name, tensor in zip(model.output_names, model.outputs):
        # TODO(priyag): This is a workaround as we do not know the batch dimension
        # of the model's output at this point.
        shape = tensor_shape.TensorShape(tensor.shape.dims)
        shape.dims = [batch_dimension] + shape.dims[1:]
        initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)

    with current_strategy.scope():
        # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
        ctx = current_strategy.run_steps_on_dataset(
            step_fn,
            iterator,
            iterations=1,
            initial_loop_values=initial_loop_values)

    predict_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    # The bar is created and updated under the same condition. Previously it
    # was created for `verbose == 1` but updated for `verbose >= 1`, so
    # `verbose=2` raised a NameError on the first step.
    progbar = Progbar(target=steps) if verbose == 1 else None

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(
            model._grouped_model_predict)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    assert steps is not None
    # Since we do not know how many samples we will see, we cannot pre-allocate
    # the returned Numpy arrays. Instead, we store one array per batch seen
    # and concatenate them upon returning.
    unconcatenated_outs = [[] for _ in model.outputs]
    for step in range(steps):
        _, batch_outs = K.get_session().run([predict_op, output_tensors])
        # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
        for i, label in enumerate(model.output_names):
            unconcatenated_outs[i].extend(batch_outs[label])
        if progbar is not None:
            progbar.update(step + 1)

    K.get_session().run(current_strategy.finalize())

    if len(unconcatenated_outs) == 1:
        return np.concatenate(unconcatenated_outs[0], axis=0)
    return [
        np.concatenate(per_output, axis=0)
        for per_output in unconcatenated_outs
    ]
コード例 #24
0
def fit_loop(model,
             inputs,
             targets,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_inputs=None,
             val_targets=None,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Fit function when using DistributionStrategy for training.

    Builds one combined train function over all towers, copies the weights
    of `model` onto the replicated model, runs the epoch/step loop while
    driving the callbacks, and finally copies the trained weights back onto
    `model`.

    Arguments:
        model: Keras Model instance.
        inputs: List of input arrays.
        targets: List of target arrays.
        epochs: Number of times to iterate over the data
        verbose: Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        val_inputs: List of input arrays.
        val_targets: List of target arrays.
        callback_metrics: List of strings, the display names of the metrics
            passed to the callbacks. The value passed in is not used; it is
            recomputed from `model.metrics_names`.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. When `None`, no training steps are run.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Ignored with the default value of `None`.

    Returns:
        `History` object.

    Raises:
        ValueError: if `validation_steps` is set while `steps_per_epoch` is
            `None`.
    """
    current_strategy = model._distribution_strategy

    def _per_device_train_function(model):
        model._make_train_function()
        return (model.train_function.inputs, model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_train_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        # Unwrap all the per device values returned from `call_for_each_tower`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to be
        # unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

    # Create a train function that is composed of all the parameters above.
    distributed_train_function = K.Function(all_inputs,
                                            all_outputs,
                                            updates=all_updates,
                                            name='distributed_train_function',
                                            **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [
        None for _ in range(len(model.outputs) * current_strategy.num_towers)
    ]
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = dataset_inputs + dataset_targets + sample_weights + [1]
    else:
        ins = dataset_inputs + dataset_targets

    do_validation = bool(validation_steps)
    if do_validation and steps_per_epoch is None:
        raise ValueError('Can only use `validation_steps` '
                         'when doing step-wise '
                         'training, i.e. `steps_per_epoch` '
                         'must be set.')

    out_labels = model.metrics_names
    callback_metrics = copy.copy(out_labels)
    if do_validation:
        callback_metrics = callback_metrics + ['val_' + n for n in out_labels]

    model.history = cbks.History()
    all_callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        # We assume that `steps_per_epoch` is always set since we have to use
        # Datasets.
        count_mode = 'steps'

        all_callbacks.append(
            cbks.ProgbarLogger(count_mode,
                               stateful_metrics=model.stateful_metric_names))
    all_callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(all_callbacks)
    out_labels = out_labels or []

    # We set the callback model to an instance of the `DistributedModel` that we
    # create in the  `compile` call. The `DistributedModel` is initialized with
    # the first replicated model. We need to set the callback model to a
    # DistributedModel to allow us to override saving and loading weights when
    # we checkpoint the model during training.
    callback_model = model._replicated_model

    callbacks.set_model(callback_model)

    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': None,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks.on_train_begin()
    callback_model.stop_training = False

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        # Created unconditionally so `on_epoch_end` below always receives a
        # dict, even when `steps_per_epoch` is None and no step is run.
        epoch_logs = {}
        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {}
                batch_logs['batch'] = step_index
                batch_logs['size'] = 1
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_train_function(ins)
                except errors.OutOfRangeError:
                    # The batch count is passed as a logging argument. The
                    # previous `fmt % steps_per_epoch * epochs` parsed as
                    # `(fmt % steps_per_epoch) * epochs`, which printed the
                    # wrong number and repeated the message `epochs` times.
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).',
                        steps_per_epoch * epochs)
                    break

                if not isinstance(outs, list):
                    outs = [outs]

                # TODO(anjalisridhar): Temporary workaround for aggregating metrics
                # across towers. Replace with the new metrics module eventually.
                # The first output is the total loss.
                merged_output = [outs[0]]
                current_index = 1
                num_devices = len(current_strategy._devices)
                # Each label in `out_labels` corresponds to one set of metrics. The
                # number of metric values corresponds to the number of devices. We
                # currently take the mean of the values.
                for _ in out_labels[1:]:
                    m = np.mean(outs[current_index:current_index +
                                     num_devices])
                    merged_output.append(m)
                    current_index += num_devices

                # Log the merged values, one per label. Zipping against the
                # raw `outs` would pair labels with individual per-device
                # values and discard the means computed above.
                for l, o in zip(out_labels, merged_output):
                    batch_logs[l] = o
                callbacks.on_batch_end(step_index, batch_logs)
                if callback_model.stop_training:
                    break
            if do_validation:
                # NOTE(review): this call assumes a
                # `test_loop(model, inputs, targets, verbose, steps)`
                # signature -- confirm against the `test_loop` in scope.
                val_outs = test_loop(model,
                                     val_inputs,
                                     val_targets,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callback_model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
    return model.history