示例#1
0
    def model_fn(features, labels, mode):
        """model_fn for keras Estimator.

        Clones and builds `keras_model` for the requested `mode`, then packs
        the clone's outputs, loss, train op and metrics into an
        `EstimatorSpec`.

        Args:
            features: Passed through to `_clone_and_build_model`.
            labels: Passed through to `_clone_and_build_model`.
            mode: A `model_fn_lib.ModeKeys` value. TRAIN and PREDICT are
                handled explicitly; any other value takes the evaluation
                path.

        Returns:
            A `model_fn_lib.EstimatorSpec` whose `predictions` map output
            names to the cloned model's outputs. `loss` and
            `eval_metric_ops` are populated unless predicting; `train_op` is
            populated only when training.

        Raises:
            ValueError: If a distribution strategy is active and the
                optimizer of `keras_model` is not a native TensorFlow
                optimizer.
        """
        # Raise an error when users use DistributionStrategy with native Keras
        # optimizers. Currently we only support native TensorFlow optimizers.
        if (distribution_strategy_context.has_distribution_strategy()
                and not isinstance(
                    keras_model.optimizer,
                    (tf_optimizer_module.Optimizer, optimizers.TFOptimizer))):
            raise ValueError(
                'Only TensorFlow native optimizers are supported with '
                'DistributionStrategy.')

        model = _clone_and_build_model(mode, keras_model, custom_objects,
                                       features, labels)

        # We need to make sure that the output names of the last layer in the
        # model is the same for each of the cloned models. This is required
        # for mirrored strategy when we call regroup.
        if distribution_strategy_context.has_distribution_strategy():
            # Compiled once rather than once per output name. The pattern
            # strips a single trailing `_<digit>` suffix.
            trailing_index = re.compile(r'_\d$')
            model_output_names = [
                trailing_index.sub('', name) for name in model.output_names
            ]
        else:
            model_output_names = model.output_names

        # Get inputs to EstimatorSpec
        predictions = dict(zip(model_output_names, model.outputs))

        loss = None
        train_op = None
        eval_metric_ops = None

        # Modes are compared by value, not identity: an equal mode string
        # that happens to be a distinct object must still pick the right
        # branch instead of silently falling through to evaluation.
        is_training = mode == model_fn_lib.ModeKeys.TRAIN

        # Set loss and metric only during train and evaluate.
        if mode != model_fn_lib.ModeKeys.PREDICT:
            if is_training:
                model._make_train_function()  # pylint: disable=protected-access
            else:
                model._make_test_function()  # pylint: disable=protected-access
            loss = model.total_loss

            eval_metric_ops = _convert_keras_metrics_to_estimator(model)

        # Set train_op only during train.
        if is_training:
            train_op = model.train_function.updates_op

        if not model._is_graph_network:
            # Reset model state to original state,
            # to avoid `model_fn` being destructive for the initial model
            # argument.
            models.in_place_subclassed_model_state_restoration(keras_model)

        export_outputs = {
            _DEFAULT_SERVING_KEY:
            export_lib.export_output.PredictOutput(predictions)
        }
        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs=export_outputs)
示例#2
0
  def model_fn(features, labels, mode):
    """model_fn for keras Estimator.

    Clones and builds `keras_model` for the requested `mode`, then packs the
    clone's outputs, loss, train op and metrics into an `EstimatorSpec`.

    Args:
      features: Passed through to `_clone_and_build_model`.
      labels: Passed through to `_clone_and_build_model`.
      mode: A `model_fn_lib.ModeKeys` value. TRAIN and PREDICT are handled
        explicitly; any other value takes the evaluation path.

    Returns:
      A `model_fn_lib.EstimatorSpec` whose `predictions` map output names to
      the cloned model's outputs. `loss` and `eval_metric_ops` are populated
      unless predicting; `train_op` is populated only when training.

    Raises:
      ValueError: If a distribution strategy is active and the optimizer of
        `keras_model` is not a native TensorFlow optimizer.
    """
    # Raise an error when users use DistributionStrategy with native Keras
    # optimizers. Currently we only support native TensorFlow optimizers.
    if (distribution_strategy_context.has_distribution_strategy() and
        not isinstance(
            keras_model.optimizer,
            (tf_optimizer_module.Optimizer, optimizers.TFOptimizer))):
      raise ValueError('Only TensorFlow native optimizers are supported with '
                       'DistributionStrategy.')

    model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                   labels)

    # We need to make sure that the output names of the last layer in the model
    # is the same for each of the cloned models. This is required for mirrored
    # strategy when we call regroup.
    if distribution_strategy_context.has_distribution_strategy():
      # Compiled once rather than once per output name. The pattern strips a
      # single trailing `_<digit>` suffix.
      trailing_index = re.compile(r'_\d$')
      model_output_names = [
          trailing_index.sub('', name) for name in model.output_names
      ]
    else:
      model_output_names = model.output_names

    # Get inputs to EstimatorSpec
    predictions = dict(zip(model_output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Modes are compared by value, not identity: an equal mode string that
    # happens to be a distinct object must still pick the right branch instead
    # of silently falling through to evaluation.
    is_training = mode == model_fn_lib.ModeKeys.TRAIN

    # Set loss and metric only during train and evaluate.
    if mode != model_fn_lib.ModeKeys.PREDICT:
      if is_training:
        model._make_train_function()  # pylint: disable=protected-access
      else:
        model._make_test_function()  # pylint: disable=protected-access
      loss = model.total_loss

      eval_metric_ops = _convert_keras_metrics_to_estimator(model)

    # Set train_op only during train.
    if is_training:
      train_op = model.train_function.updates_op

    if not model._is_graph_network:
      # Reset model state to original state,
      # to avoid `model_fn` being destructive for the initial model argument.
      models.in_place_subclassed_model_state_restoration(keras_model)

    export_outputs = {
        _DEFAULT_SERVING_KEY:
        export_lib.export_output.PredictOutput(predictions)
    }
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)
示例#3
0
    def model_fn(features, labels, mode):
        """model_fn for keras Estimator.

        Clones and builds `keras_model` for the requested `mode`, then packs
        the clone's outputs, loss, train op, metrics and (optionally) a
        checkpoint scaffold into an `EstimatorSpec`.

        Args:
            features: Passed through to `_clone_and_build_model`.
            labels: Passed through to `_clone_and_build_model`.
            mode: A `ModeKeys` value. TRAIN and PREDICT are handled
                explicitly; any other value takes the evaluation path.

        Returns:
            A `model_fn_lib.EstimatorSpec` whose `predictions` map output
            names to the cloned model's outputs. `loss` and
            `eval_metric_ops` are populated unless predicting; `train_op` is
            populated only when training; `scaffold` is populated only when
            `save_object_ckpt` is truthy.
        """
        model = _clone_and_build_model(mode=mode,
                                       keras_model=keras_model,
                                       custom_objects=custom_objects,
                                       features=features,
                                       labels=labels,
                                       optimizer_config=optimizer_config)

        # We need to make sure that the output names of the last layer in the
        # model is the same for each of the cloned models. This is required
        # for mirrored strategy when we call regroup.
        if tf.distribute.has_strategy():
            # Compiled once rather than once per output name. The pattern
            # strips a single trailing `_<digit>` suffix.
            trailing_index = re.compile(r'_\d$')
            model_output_names = [
                trailing_index.sub('', name) for name in model.output_names
            ]
        else:
            model_output_names = model.output_names

        # Get inputs to EstimatorSpec
        predictions = dict(zip(model_output_names, model.outputs))

        loss = None
        train_op = None
        eval_metric_ops = None

        # Modes are compared by value, not identity: an equal mode string
        # that happens to be a distinct object must still pick the right
        # branch instead of silently falling through to evaluation.
        is_training = mode == ModeKeys.TRAIN

        # Set loss and metric only during train and evaluate.
        if mode != ModeKeys.PREDICT:
            if is_training:
                model._make_train_function()  # pylint: disable=protected-access
            else:
                model._make_test_function()  # pylint: disable=protected-access
            loss = model.total_loss

            eval_metric_ops = _convert_keras_metrics_to_estimator(model)

        # Set train_op only during train.
        if is_training:
            train_op = model.train_function.updates_op

        # Restore only when the clone is not a graph network and the original
        # model actually carries a cached attribute snapshot.
        if (not model._is_graph_network
                and getattr(keras_model, '_original_attributes_cache',
                            None) is not None):
            # To avoid `model_fn` being destructive for the initial model
            # argument.
            models.in_place_subclassed_model_state_restoration(keras_model)

        scaffold = None
        if save_object_ckpt:
            model._track_trackable(tf.compat.v1.train.get_global_step(),
                                   'estimator_global_step')
            # Create saver that maps variable names to object-checkpoint keys.
            object_graph = graph_view.ObjectGraphView(model)
            var_list = object_graph.frozen_saveable_objects()
            saver = tf.compat.v1.train.Saver(var_list=var_list, sharded=True)
            saver._object_restore_saver = trackable_util.frozen_saver(model)
            scaffold = tf.compat.v1.train.Scaffold(saver=saver)

        export_outputs = {
            _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
        }
        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs=export_outputs,
            scaffold=scaffold)