def _make_execution_function(model, mode):
    """Makes function to run one step of distributed model execution."""
    if context.executing_eagerly():
        return _make_eager_execution_function(model, mode)

    strategy = model._distribution_strategy
    if not model._distributed_model:
        if model._compile_distribution:
            clone_model_on_replicas(model,
                                    strategy,
                                    make_callback_model=(mode == 'train'))
        else:
            _build_distributed_network(model, strategy)

    def _per_device_function(model):
        f = model._make_execution_function(mode)
        return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)

    with strategy.scope():
        # Create the per-mode execution ops (train/test/predict) on each of the
        # devices when we call `_per_device_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = strategy.extended.call_for_each_replica(
             _per_device_function, args=(model._distributed_model, ))

        if mode == 'train':
            # Initialize the variables in the replicated model. This is necessary for
            # multi-worker training because on some workers, initialization is not
            # needed. This method does initialization or waiting for initialization
            # according to the context object of distribute coordinator.
            distributed_training_utils.init_restore_or_wait_for_variables()

        # Unwrap all the per device values returned from `call_for_each_replica`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=(mode != 'predict'))

        return K.function(all_inputs,
                          all_outputs,
                          updates=all_updates,
                          name='distributed_{}_function'.format(mode),
                          **all_session_args)
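
For orientation, here is a minimal sketch of the public entry point that ends up in `_make_execution_function`. It assumes a TF 1.x build where `tf.keras.Model.compile` still accepts a `distribute=` strategy argument (removed in TF 2.x) and where `tf.contrib.distribute.MirroredStrategy` is available; the strategy constructor, device count, and toy data below are illustrative placeholders, not part of the snippet above.

import numpy as np
import tensorflow as tf

# Assumption: a TF 1.x release exposing the contrib MirroredStrategy and the
# `distribute=` argument on `Model.compile`.
strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
# Passing the strategy at compile time is what routes fit/evaluate/predict
# through the distributed code paths shown in this file.
model.compile(optimizer='sgd', loss='mse', distribute=strategy)

dataset = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(64, 8).astype('float32'),
     np.random.rand(64, 1).astype('float32'))).batch(8).repeat()

# `fit` on a dataset builds one step of distributed execution via
# `_make_execution_function(model, 'train')` (or `fit_loop` in older releases).
model.fit(dataset, epochs=2, steps_per_epoch=8)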
def _make_execution_function(model, mode):
  """Makes function to run one step of distributed model execution."""
  if context.executing_eagerly():
    return _make_eager_execution_function(model, mode)

  strategy = model._distribution_strategy
  if not model._grouped_model:
    clone_model_on_replicas(
        model, strategy, make_callback_model=(mode == 'train'))

  def _per_device_function(model):
    f = model._make_execution_function(mode)
    return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)

  with strategy.scope():
    # Create the per-mode execution ops (train/test/predict) on each of the
    # devices when we call `_per_device_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_device_function, args=(model._grouped_model,))

    if mode == 'train':
      # Initialize the variables in the replicated model. This is necessary for
      # multi-worker training because on some workers, initialization is not
      # needed. This method does initialization or waiting for initialization
      # according to the context object of distribute coordinator.
      distributed_training_utils.init_restore_or_wait_for_variables()

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         strategy,
         grouped_inputs,
         grouped_outputs,
         grouped_updates,
         grouped_session_args,
         with_loss_tensor=(mode != 'predict'))

    return K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_{}_function'.format(mode),
        **all_session_args)
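
The unwrapping step above is easier to see with a toy example. The sketch below is not library code: the single-CPU device list and constant return value are arbitrary choices, and it assumes a TF 1.13+ release where the strategy exposes `.extended` and `.unwrap`. It shows that `call_for_each_replica` produces one value per replica and that unwrapping flattens those values into an ordinary list, which is what `unwrap_values` does for every group of inputs, outputs, updates, and session arguments before the combined `K.function` is built.

import tensorflow as tf

# Assumption: TF 1.13+ DistributionStrategy with `.extended` and `.unwrap`.
strategy = tf.contrib.distribute.MirroredStrategy(devices=['/cpu:0'])

def replica_fn():
    # In the real code this returns the per-replica execution function's
    # inputs, outputs, update ops and session kwargs.
    return tf.constant(1.0)

with strategy.scope():
    per_replica_value = strategy.extended.call_for_each_replica(replica_fn)
    # `unwrap` flattens the per-replica container into a plain list with one
    # entry per device over which the model is mirrored.
    flat_values = strategy.unwrap(per_replica_value)
    print(flat_values)  # one tensor per replica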
Example #3
def fit_loop(model,
             iterator,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_iterator=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Fit loop for training with DistributionStrategy.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        epochs: Number of times to iterate over the data.
        verbose: Integer. Verbosity mode, 0, 1, or 2.
        callbacks: List of callbacks to be called during training.
        val_iterator: Iterator for validation data.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run).
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Ignored with the default value of `None`.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Ignored with the default value of `None`.

    Returns:
        `History` object.

    Raises:
        ValueError: in case of invalid arguments.
    """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_fit_loop(model, iterator, epochs, verbose,
                                      callbacks, initial_epoch,
                                      steps_per_epoch, val_iterator,
                                      validation_steps)

    if not model._grouped_model:
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=True)

    def _per_device_fit_function(model):
        model._make_fit_function()
        return (model._fit_function.inputs, model._fit_function.outputs,
                model._fit_function.updates_op,
                model._fit_function.session_kwargs)

    inputs, targets, sample_weights = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_fit_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_fit_function, args=(model._grouped_model, ))

        # Initialize the variables in the replicated model. This is necessary for
        # multi-worker training because on some workers, initialization is not
        # needed. This method does initialization or waiting for initialization
        # according to the context object of distribute coordinator.
        distributed_training_utils.init_restore_or_wait_for_variables()

        # Unwrap all the per device values returned from `call_for_each_replica`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to be
        # unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

        # Create a train function that is composed of all the parameters above.
        distributed_fit_function = K.function(all_inputs,
                                              all_outputs,
                                              updates=all_updates,
                                              name='distributed_fit_function',
                                              **all_session_args)

        # We need to set sample_weights to None since there are sample weight
        # placeholders that are created with default values.
        sample_weights = [
            None for _ in range(
                len(model.outputs) * current_strategy.num_replicas_in_sync)
        ]
        if not isinstance(K.learning_phase(), int):
            ins = dataset_inputs + dataset_targets + sample_weights + [1]
        else:
            ins = dataset_inputs + dataset_targets

        do_validation = False
        if validation_steps:
            do_validation = True

        # Copy the weights from the original model to each of the replicated models.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

        callbacks = cbks.configure_callbacks(callbacks,
                                             model,
                                             do_validation=do_validation,
                                             val_inputs=None,
                                             val_targets=None,
                                             epochs=epochs,
                                             steps_per_epoch=steps_per_epoch,
                                             verbose=verbose)
        out_labels = model.metrics_names or []
        callbacks.on_train_begin()

        assert steps_per_epoch is not None

        for epoch in range(initial_epoch, epochs):
            # Reset stateful metrics
            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            epoch_logs = {}
            for step_index in range(steps_per_epoch):
                batch_logs = {'batch': step_index, 'size': 1}
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_fit_function(ins)
                except errors.OutOfRangeError:
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o
                callbacks.on_batch_end(step_index, batch_logs)
                if callbacks.model.stop_training:
                    break
            if do_validation:
                val_outs = test_loop(model,
                                     val_iterator,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

            callbacks.on_epoch_end(epoch, epoch_logs)
            if callbacks.model.stop_training:
                break
        callbacks.on_train_end()

        # Copy the weights back from the replicated model to the original model.
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
        return model.history
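
Normally `model.fit` reaches this loop on its own once the model has been compiled with a `DistributionStrategy`, but a direct call would look roughly like the sketch below. The iterator construction is the shakiest part: the helper for turning a `tf.data.Dataset` into a per-device iterator changed names across TF 1.x releases (`distribute_dataset` in older ones, `make_dataset_iterator` later), so treat those lines as an assumption rather than canonical usage.

import numpy as np
import tensorflow as tf

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)  # assumption: 2 local GPUs

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse', distribute=strategy)

dataset = tf.data.Dataset.from_tensor_slices(
    (np.zeros((32, 4), np.float32),
     np.zeros((32, 1), np.float32))).batch(8).repeat()

# Assumption: an older 1.x API where the strategy wraps the dataset and the
# wrapper exposes `make_one_shot_iterator`; later releases use
# `strategy.make_dataset_iterator(dataset)` instead.
iterator = strategy.distribute_dataset(lambda: dataset).make_one_shot_iterator()

history = fit_loop(model, iterator, epochs=2, steps_per_epoch=4)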