Example #1
    def step_fn(ctx, inputs, targets):
        """Clones the model and calls make_predict_function."""

        # TODO(anjalisridhar): Support predict input correctly as it will not
        # contain targets, only inputs.
        del targets

        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        clone_model_on_towers(model,
                              current_strategy,
                              make_callback_model=False,
                              inputs=inputs)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_predict_function, model._grouped_model)

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        combined_fn = K.Function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_predict_function',
                                 **all_session_args)

        for label, output in zip(model.output_names, combined_fn.outputs):
            ctx.set_last_step_output(label, output)

        return combined_fn.updates_op
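
A step function like this is not called directly; in the TF 1.x DistributionStrategy API it is handed to the strategy's step runner, which builds a single op and exposes the per-label outputs through the context. A minimal sketch, assuming the `run_steps_on_dataset` API of that era (`current_strategy`, `step_fn`, and `iterator` are names from the example above; everything else is illustrative):

    # Sketch only: assumes TF 1.x DistributionStrategy's step-runner API.
    ctx = current_strategy.run_steps_on_dataset(
        step_fn, iterator, iterations=1)
    predict_op = ctx.run_op                  # runs step_fn once
    output_tensors = ctx.last_step_outputs   # filled by ctx.set_last_step_output

    session = K.get_session()
    session.run(predict_op)
    batch_predictions = session.run(output_tensors)  # dict keyed by output name
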
Example #2
    def step_fn(ctx, *inputs):
        """Clones the model and calls make_predict_function."""

        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=False,
                                inputs=inputs,
                                mode=_Mode.PREDICT)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_predict_function, model._grouped_model_predict)

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        combined_fn = K.Function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_predict_function',
                                 **all_session_args)

        for label, output in zip(model.output_names, combined_fn.outputs):
            ctx.set_last_step_output(label, output)

        return combined_fn.updates_op
Example #3
  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_train_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_towers(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.Function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_train_function',
        **all_session_args)

    out_labels = model.metrics_names or []
    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        aggregation = distribute_lib.get_loss_reduction()
      else:
        # We aggregate all other metrics using mean for now. This is a
        # temporary workaround until new metrics are in place.
        aggregation = variable_scope.VariableAggregation.MEAN
      ctx.set_last_step_output(label, output, aggregation)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These
    # should be handled appropriately.
    return combined_fn.updates_op
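
Training step functions like this one are driven the same way, except that the step runner's loop needs initial values for the last-step outputs so the loop state is well typed from the first iteration. A hedged sketch under the same era's API (`steps_per_run`, the `1e7` loss seed, and the use of `model.metrics_tensors` are assumptions, not taken from the example):

    import tensorflow as tf

    # Sketch only: seed the last-step outputs, then build the train op.
    initial_loop_values = {'loss': tf.constant(1e7)}
    for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
        initial_loop_values[name] = tf.zeros(tensor.shape, tensor.dtype)

    ctx = current_strategy.run_steps_on_dataset(
        step_fn, iterator, iterations=steps_per_run,
        initial_loop_values=initial_loop_values)
    train_op = ctx.run_op                 # runs steps_per_run training steps
    output_tensors = ctx.last_step_outputs
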
Example #4
    def step_fn(ctx, inputs, targets):
        """Clones the model and calls make_test_function."""
        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=False,
                                inputs=inputs,
                                targets=targets,
                                mode=_Mode.TEST)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_test_function, model._grouped_model_test)

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        combined_fn = K.Function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_test_function',
                                 **all_session_args)

        for label, output in zip(model.metrics_names, combined_fn.outputs):
            if label == 'loss':
                aggregation = distribute_lib.get_loss_reduction()
            else:
                # We aggregate all other metrics using mean for now. This is a
                # temporary workaround until new metrics are in place.
                aggregation = variable_scope.VariableAggregation.MEAN
            ctx.set_last_step_output(label, output, aggregation)

        return combined_fn.updates_op
Example #5
    def step_fn(ctx, inputs, targets):
        """Clones the model and calls make_train_function."""
        # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
        clone_model_on_towers(model,
                              current_strategy,
                              make_callback_model=True,
                              inputs=inputs,
                              targets=targets)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)
        combined_fn = K.Function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_train_function',
                                 **all_session_args)

        # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
        # something else for different outputs.
        out_labels = model.metrics_names or []
        for label, output in zip(out_labels, combined_fn.outputs):
            ctx.set_last_step_output(
                label, output, aggregation=distribute_lib.get_loss_reduction())

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These
        # should be handled appropriately.
        return combined_fn.updates_op
Example #6
def predict_loop(model, iterator, verbose=0, steps=None):
    """Predict loop for predicting with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer. Verbosity mode, 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished.
          Ignored with the default value of `None`.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_predict_loop(model, iterator, verbose, steps)

    if not model._grouped_model:
        clone_model_on_replicas(model, current_strategy)

    def _per_device_predict_function(model):
        model._make_predict_function()
        return (model.predict_function.inputs, model.predict_function.outputs,
                model.predict_function.updates_op,
                model.predict_function.session_kwargs)

    inputs, _, _ = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_predict_function, model._grouped_model)

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)

        distributed_predict_function = K.Function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name='distributed_predict_function',
            **all_session_args)

        if model.uses_learning_phase and not isinstance(
                K.learning_phase(), int):
            ins = dataset_inputs + [0]
        else:
            ins = dataset_inputs

        if verbose == 1:
            progbar = Progbar(target=steps)

        # Copy the weights from the original model to each of the replicated models.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

        num_towers = current_strategy.num_towers
        # Since we do not know how many samples we will see, we cannot
        # pre-allocate the returned Numpy arrays. Instead, we store one array per
        # batch seen and concatenate them upon returning.
        unconcatenated_outs = []
        assert steps is not None
        for step in range(steps):
            batch_outs = distributed_predict_function(ins)
            if not isinstance(batch_outs, list):
                batch_outs = [batch_outs]
            if step == 0:
                # len(batch_outs) is the number of model outputs. In the
                # distributed case this will be num_model_outputs * num_towers.
                for _ in range(len(model.outputs)):
                    unconcatenated_outs.append([])
            for i in range(len(model.outputs)):
                nested_outs = batch_outs[i * num_towers:i * num_towers +
                                         num_towers]
                outs = nest.flatten(nested_outs)
                unconcatenated_outs[i].extend(outs)
            if verbose == 1:
                progbar.update(step + 1)
        if len(unconcatenated_outs) == 1:
            return np.concatenate(unconcatenated_outs[0], axis=0)
        return [
            np.concatenate(unconcatenated_outs[i], axis=0)
            for i in range(len(unconcatenated_outs))
        ]
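
A usage sketch for this loop, assuming the TF 1.x-era integration where the strategy is passed to `compile(distribute=...)` and datasets expose `make_one_shot_iterator()`; the tiny model and random data are placeholders:

    import numpy as np
    import tensorflow as tf

    strategy = tf.contrib.distribute.MirroredStrategy()
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse', distribute=strategy)

    dataset = tf.data.Dataset.from_tensor_slices(
        np.random.rand(32, 4).astype('float32')).batch(8)
    iterator = dataset.make_one_shot_iterator()

    # 32 samples at batch size 8 gives exactly 4 steps.
    predictions = predict_loop(model, iterator, verbose=1, steps=4)
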
Example #7
def fit_loop(model,
             iterator,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_iterator=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Fit loop for training with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      epochs: Number of times to iterate over the data.
      verbose: Integer. Verbosity mode: 0, 1, or 2.
      callbacks: List of callbacks to be called during training.
      val_iterator: Iterator for validation data.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_fit_loop(model, iterator, epochs, verbose,
                                      callbacks, initial_epoch,
                                      steps_per_epoch, val_iterator,
                                      validation_steps)

    if not model._grouped_model:
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=True)

    def _per_device_train_function(model):
        model._make_train_function()
        return (model.train_function.inputs, model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    inputs, targets, sample_weights = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_train_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_train_function, model._grouped_model)
        # Unwrap all the per device values returned from `call_for_each_replica`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to be
        # unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

        # Create a train function that is composed of all the parameters above.
        distributed_train_function = K.Function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name='distributed_train_function',
            **all_session_args)

        # We need to set sample_weights to None since there are sample weight
        # placeholders that are created with default values.
        sample_weights = [
            None
            for _ in range(len(model.outputs) * current_strategy.num_replicas)
        ]
        if model.uses_learning_phase and not isinstance(
                K.learning_phase(), int):
            ins = dataset_inputs + dataset_targets + sample_weights + [1]
        else:
            ins = dataset_inputs + dataset_targets

        do_validation = False
        if validation_steps:
            do_validation = True

        # Copy the weights from the original model to each of the replicated models.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

        callbacks = cbks.configure_callbacks(callbacks,
                                             model,
                                             do_validation=do_validation,
                                             val_inputs=None,
                                             val_targets=None,
                                             epochs=epochs,
                                             steps_per_epoch=steps_per_epoch,
                                             verbose=verbose)
        out_labels = model.metrics_names or []
        callbacks.on_train_begin()

        assert steps_per_epoch is not None

        for epoch in range(initial_epoch, epochs):
            # Reset stateful metrics
            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            epoch_logs = {}
            for step_index in range(steps_per_epoch):
                batch_logs = {'batch': step_index, 'size': 1}
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_train_function(ins)
                except errors.OutOfRangeError:
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]

                outs = _aggregate_metrics_across_replicas(
                    current_strategy.num_replicas, out_labels,
                    model.stateful_metric_names, outs)
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o
                callbacks.on_batch_end(step_index, batch_logs)
                if callbacks.model.stop_training:
                    break
            if do_validation:
                val_outs = test_loop(model,
                                     val_iterator,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

            callbacks.on_epoch_end(epoch, epoch_logs)
            if callbacks.model.stop_training:
                break
        callbacks.on_train_end()

        # Copy the weights back from the replicated model to the original model.
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
        return model.history
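
The training loop can be exercised the same way. A sketch under the same assumptions as the predict example above, with a repeated dataset so at least `steps_per_epoch * epochs` batches are available (exactly the condition the OutOfRangeError warning above is about):

    import numpy as np
    import tensorflow as tf

    strategy = tf.contrib.distribute.MirroredStrategy()
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse', distribute=strategy)

    x = np.random.rand(64, 4).astype('float32')
    y = np.random.rand(64, 1).astype('float32')
    train_iterator = tf.data.Dataset.from_tensor_slices(
        (x, y)).repeat().batch(8).make_one_shot_iterator()
    val_iterator = tf.data.Dataset.from_tensor_slices(
        (x, y)).repeat().batch(8).make_one_shot_iterator()

    history = fit_loop(model, train_iterator, epochs=2, steps_per_epoch=8,
                       val_iterator=val_iterator, validation_steps=2)
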
Example #8
def test_loop(model, iterator, verbose=0, steps=None):
    """Test loop for evaluating with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer. Verbosity mode, 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring predictions finished.
          Ignored with the default value of `None`.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the outputs.
  """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_test_loop(model, iterator, verbose, steps)

    if not model._grouped_model:
        clone_model_on_replicas(model, current_strategy)

    def _per_device_test_function(model):
        model._make_test_function()
        return (model.test_function.inputs, model.test_function.outputs,
                model.test_function.updates_op,
                model.test_function.session_kwargs)

    inputs, targets, sample_weights = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_test_function, model._grouped_model)

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

        distributed_test_function = K.Function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name='distributed_test_function',
            **all_session_args)

        # We need to set sample_weights to None since there are sample weight
        # placeholders that are created with default values.
        sample_weights = [
            None
            for _ in range(len(model.outputs) * current_strategy.num_replicas)
        ]
        if model.uses_learning_phase and not isinstance(
                K.learning_phase(), int):
            ins = dataset_inputs + dataset_targets + sample_weights + [0]
        else:
            ins = dataset_inputs + dataset_targets

        for m in model.stateful_metric_functions:
            m.reset_states()
        stateful_metric_indices = [
            i for i, name in enumerate(model.metrics_names)
            if str(name) in model.stateful_metric_names
        ]

        outs = []
        if verbose == 1:
            progbar = Progbar(target=steps)

        # Copy the weights from the original model to each of the replicated models.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

        assert steps is not None
        for step in range(steps):
            batch_outs = distributed_test_function(ins)
            batch_outs = _aggregate_metrics_across_replicas(
                current_strategy.num_replicas, model.metrics_names,
                model.stateful_metric_names, batch_outs)
            if isinstance(batch_outs, list):
                if step == 0:
                    outs = [0.] * len(batch_outs)
                for i, batch_out in enumerate(batch_outs):
                    if i in stateful_metric_indices:
                        outs[i] = batch_out
                    else:
                        outs[i] += batch_out
            else:
                if step == 0:
                    outs.append(0.)
                outs[0] += batch_outs
            if verbose == 1:
                progbar.update(step + 1)
        for i in range(len(outs)):
            if i not in stateful_metric_indices:
                outs[i] /= steps

        if len(outs) == 1:
            return outs[0]
        return outs
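
The accumulation at the end of this loop treats stateful metrics differently from per-batch ones: a stateful metric already carries its running value, so the value from the last step is kept, while everything else is summed and divided by `steps`. A self-contained numeric sketch of that rule (all values made up):

    # Made-up values: 'loss' is averaged over steps, 'acc' (stateful) keeps
    # the value reported on the last step.
    metrics_names = ['loss', 'acc']
    stateful_metric_indices = [1]
    per_step = [[0.9, 0.50], [0.7, 0.55], [0.5, 0.60]]  # 3 steps

    outs = [0.] * len(metrics_names)
    for batch_outs in per_step:
        for i, batch_out in enumerate(batch_outs):
            outs[i] = batch_out if i in stateful_metric_indices else outs[i] + batch_out
    outs = [o if i in stateful_metric_indices else o / len(per_step)
            for i, o in enumerate(outs)]
    # outs == [0.7, 0.6]
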
Example #9
def predict_loop(model, iterator, verbose=0, steps=None):
    """Abstract method to loop over some data in batches.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: verbosity mode.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished.
          Ignored with the default value of `None`.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
    current_strategy = model._distribution_strategy

    clone_model_on_towers(model, current_strategy)

    def _per_device_predict_function(model):
        model._make_predict_function()
        return (model.predict_function.inputs, model.predict_function.outputs,
                model.predict_function.updates_op,
                model.predict_function.session_kwargs)

    inputs, _ = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_predict_function, model._grouped_model)

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)

    distributed_predict_function = K.Function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_predict_function',
        **all_session_args)

    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = dataset_inputs + [0]
    else:
        ins = dataset_inputs

    if verbose == 1:
        progbar = Progbar(target=steps)

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    if steps is not None:
        # Since we do not know how many samples we will see, we cannot pre-allocate
        # the returned Numpy arrays. Instead, we store one array per batch seen
        # and concatenate them upon returning.
        unconcatenated_outs = []
        for step in range(steps):
            batch_outs = distributed_predict_function(ins)
            if not isinstance(batch_outs, list):
                batch_outs = [batch_outs]
            if step == 0:
                for _ in batch_outs:
                    unconcatenated_outs.append([])
            for i, batch_out in enumerate(batch_outs):
                unconcatenated_outs[i].append(batch_out)
            if verbose == 1:
                progbar.update(step + 1)
        if len(unconcatenated_outs) == 1:
            return np.concatenate(unconcatenated_outs[0], axis=0)
        return [
            np.concatenate(unconcatenated_outs[i], axis=0)
            for i in range(len(unconcatenated_outs))
        ]
Example #10
def test_loop(model, iterator, verbose=0, steps=None):
    """evaluate method to validate a model that uses DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: verbosity mode.
      steps: Total number of steps (batches of samples)
          before declaring predictions finished.
          Ignored with the default value of `None`.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.
  """
    current_strategy = model._distribution_strategy

    clone_model_on_towers(model, current_strategy)

    def _per_device_test_function(model):
        model._make_test_function()
        return (model.test_function.inputs, model.test_function.outputs,
                model.test_function.updates_op,
                model.test_function.session_kwargs)

    inputs, targets = _get_input_from_iterator(iterator, model)
    with current_strategy.scope():
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_test_function, model._grouped_model)

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

    distributed_test_function = K.Function(all_inputs,
                                           all_outputs,
                                           updates=all_updates,
                                           name='distributed_test_function',
                                           **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [
        None for _ in range(len(model.outputs) * current_strategy.num_towers)
    ]
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = dataset_inputs + dataset_targets + sample_weights + [0]
    else:
        ins = dataset_inputs + dataset_targets

    outs = []
    if verbose == 1:
        progbar = Progbar(target=steps)

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    if steps is not None:
        for step in range(steps):
            batch_outs = distributed_test_function(ins)
            batch_outs = _aggregate_metrics_across_towers(
                current_strategy.num_towers, model.metrics_names, batch_outs)
            if isinstance(batch_outs, list):
                if step == 0:
                    outs = [0.] * len(batch_outs)
                for i, batch_out in enumerate(batch_outs):
                    outs[i] += batch_out
            else:
                if step == 0:
                    outs.append(0.)
                outs[0] += batch_outs
            if verbose == 1:
                progbar.update(step + 1)
        for i in range(len(outs)):
            outs[i] /= steps

    if len(outs) == 1:
        return outs[0]
    return outs
Example #11
def fit_loop(model,
             inputs,
             targets,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_inputs=None,
             val_targets=None,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """fit function when using DistributionStrategy for training.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      epochs: Number of times to iterate over the data.
      verbose: Verbosity mode: 0, 1, or 2.
      callbacks: List of callbacks to be called during training.
      val_inputs: List of input arrays.
      val_targets: List of target arrays.
      callback_metrics: List of strings, the display names of the metrics
          passed to the callbacks. They should be the concatenation of the
          display names of the outputs of `f` and the display names of the
          outputs of `f_val`.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
    current_strategy = model._distribution_strategy

    def _per_device_train_function(model):
        model._make_train_function()
        return (model.train_function.inputs, model.train_function.outputs,
                model.train_function.updates_op,
                model.train_function.session_kwargs)

    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_train_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        # Unwrap all the per device values returned from `call_for_each_tower`.
        # Unwrapping per device values gives you a list of values that can be
        # used to construct a new train function that is composed of update ops on
        # all the devices over which the model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need to be
        # unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

    # Create a train function that is composed of all the parameters above.
    distributed_train_function = K.Function(all_inputs,
                                            all_outputs,
                                            updates=all_updates,
                                            name='distributed_train_function',
                                            **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [
        None for _ in range(len(model.outputs) * current_strategy.num_towers)
    ]
    if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = dataset_inputs + dataset_targets + sample_weights + [1]
    else:
        ins = dataset_inputs + dataset_targets

    do_validation = False
    if validation_steps:
        do_validation = True
        if steps_per_epoch is None:
            raise ValueError('Can only use `validation_steps` '
                             'when doing step-wise '
                             'training, i.e. `steps_per_epoch` '
                             'must be set.')
    out_labels = model.metrics_names
    if do_validation:
        callback_metrics = copy.copy(out_labels) + [
            'val_' + n for n in out_labels
        ]
    else:
        callback_metrics = copy.copy(out_labels)

    model.history = cbks.History()
    all_callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        # We assume that `steps_per_epoch` is always set since we have to use
        # Datasets.
        count_mode = 'steps'

        all_callbacks.append(
            cbks.ProgbarLogger(count_mode,
                               stateful_metrics=model.stateful_metric_names))
    all_callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(all_callbacks)
    out_labels = out_labels or []

    # We set the callback model to an instance of the `DistributedModel` that we
    # create in the  `compile` call. The `DistributedModel` is initialized with
    # the first replicated model. We need to set the callback model to a
    # DistributedModel to allow us to override saving and loading weights when
    # we checkpoint the model during training.
    callback_model = model._replicated_model

    callbacks.set_model(callback_model)

    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': None,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks.on_train_begin()
    callback_model.stop_training = False

    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = model.get_weights()
    with current_strategy.scope():
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_training_utils.set_weights(current_strategy,
                                               distributed_model,
                                               orig_model_weights)

    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {'batch': step_index, 'size': 1}
                callbacks.on_batch_begin(step_index, batch_logs)
                try:
                    outs = distributed_train_function(ins)
                except errors.OutOfRangeError:
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
                    break

                if not isinstance(outs, list):
                    outs = [outs]

                # TODO(anjalisridhar): Temporary workaround for aggregating metrics
                # across towers. Replace with the new metrics module eventually.
                merged_output = []
                # The first output is the total loss.
                merged_output.append(outs[0])
                current_index = 1
                num_devices = len(current_strategy._devices)
                # Each label in `out_labels` corresponds to one set of metrics. The
                # number of metric values corresponds to the number of devices. We
                # currently take the mean of the values.
                for _ in out_labels[1:]:
                    m = np.mean(outs[current_index:current_index +
                                     num_devices])
                    merged_output.append(m)
                    current_index += num_devices

                for l, o in zip(out_labels, merged_output):
                    batch_logs[l] = o
                callbacks.on_batch_end(step_index, batch_logs)
                if callback_model.stop_training:
                    break
            if do_validation:
                val_outs = test_loop(model,
                                     val_inputs,
                                     val_targets,
                                     steps=validation_steps,
                                     verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callback_model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    with current_strategy.scope():
        updated_weights = current_strategy.unwrap(
            model._grouped_model)[0].get_weights()
        model.set_weights(updated_weights)
    return model.history
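
The temporary metric merge in this loop averages each metric's per-device values while passing the total loss through unchanged. A self-contained numeric illustration (two devices, made-up values):

    import numpy as np

    out_labels = ['loss', 'acc']
    num_devices = 2
    outs = [0.70, 0.58, 0.62]    # [total_loss, acc_device0, acc_device1]

    merged_output = [outs[0]]    # the total loss is already reduced
    current_index = 1
    for _ in out_labels[1:]:
        merged_output.append(
            float(np.mean(outs[current_index:current_index + num_devices])))
        current_index += num_devices
    # merged_output == [0.7, 0.6]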