Example #1
  def __new__(cls, mode, predictions=None, loss=None, training_op=None,
              default_metrics=None, signature_fn=None):
    # Assert all ops are from the same graph.
    get_graph_from_inputs((predictions, loss, training_op))

    # Validate training_op.
    if training_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing training_op.')
    elif not isinstance(training_op, ops.Operation):
      # TODO(ptucker): Should this be allowed? Consider raising error.
      training_op = ops.convert_to_tensor(training_op).op

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = ops.convert_to_tensor(loss)
      loss_shape = loss.get_shape()
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar: %s.' % loss)
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode == ModeKeys.INFER or mode == ModeKeys.EVAL:
        raise ValueError('Missing predictions.')
    else:
      if isinstance(predictions, dict):
        predictions = {
            k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
            for k, v in six.iteritems(predictions)
        }
      else:
        predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
            predictions)

    # Validate default_metrics.
    if default_metrics is None:
      default_metrics = {}
    else:
      if not isinstance(default_metrics, dict):
        raise ValueError('default_metrics must be a dict.')
      for k, v in default_metrics.items():
        if not isinstance(v, metric_spec.MetricSpec):
          raise ValueError('Metric with key=%s is not a MetricSpec.' % k)

    # Validate signature_fn.
    if signature_fn:
      if not callable(signature_fn):
        raise ValueError('signature_fn is not callable.')

    return super(ModelFnOps, cls).__new__(cls, predictions, loss, training_op,
                                          default_metrics, signature_fn)
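
For context, a hedged sketch of how a model_fn might call this constructor. The layer and optimizer choices are illustrative rather than taken from the source, and `ModelFnOps` / `ModeKeys` refer to the classes in the surrounding module:

import tensorflow as tf

def my_model_fn(features, targets, mode):
  # Illustrative TF 1.x graph construction (an assumption, not from the source).
  logits = tf.layers.dense(features['x'], units=1)
  loss = tf.reduce_mean(tf.squared_difference(logits, targets))  # scalar loss
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
  # A [1]-shaped loss would be reshaped to a scalar by __new__ above, and a
  # Tensor train op would be converted via ops.convert_to_tensor(...).op.
  return ModelFnOps(mode=mode,
                    predictions={'y': logits},
                    loss=loss,
                    training_op=train_op)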
Example #2
def _set_operation(a, b, set_operation, validate_indices=True):
  """Compute set operation of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
        `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
        sorted in row-major order.
    set_operation: String indicating set operation. See
        SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
        in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a = framework.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a.dtype)
  b = framework.convert_to_tensor_or_sparse_tensor(b, name="b")
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
  # pylint: disable=protected-access
  if isinstance(a, ops.SparseTensor):
    if isinstance(b, ops.SparseTensor):
      indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
          a.indices, a.values, a.shape, b.indices, b.values, b.shape,
          set_operation, validate_indices)
    else:
      raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                       "Please flip the order of your inputs.")
  elif isinstance(b, ops.SparseTensor):
    indices, values, shape = _set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.shape, set_operation, validate_indices)
  else:
    indices, values, shape = _set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)
  # pylint: enable=protected-access
  return ops.SparseTensor(indices, values, shape)
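
The dense/sparse dispatch above is easiest to see through the public wrapper. A hedged usage sketch (TF 1.x graph mode; tf.contrib.metrics.set_intersection delegates to _set_operation with the "intersection" operation string):

import tensorflow as tf

a = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64)
b = tf.constant([[2, 3, 9], [5, 0, 0]], dtype=tf.int64)
# Dense,Dense is supported (as is Dense,Sparse); the result is always a
# SparseTensor whose last dimension holds each row's intersection.
result = tf.contrib.metrics.set_intersection(a, b)
with tf.Session() as sess:
  print(sess.run(result))  # row 0 -> {2, 3}, row 1 -> {5}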
Example #3
  def _call_model_fn(self, features, targets, mode):
    """Calls model function with support of 2, 3 or 4 arguments."""
    features, targets = self._feature_engineering_fn(features, targets)
    model_fn_args = _get_arguments(self._model_fn)
    if 'mode' in model_fn_args:
      if 'params' in model_fn_args:
        predictions, loss, train_op = self._model_fn(
            features, targets, mode=mode, params=self.params)
      else:
        predictions, loss, train_op = self._model_fn(
            features, targets, mode=mode)
    else:
      predictions, loss, train_op = self._model_fn(features, targets)

    # Validate train_op.
    if train_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing train_op.')
    elif not isinstance(train_op, ops.Operation):
      train_op = ops.convert_to_tensor(train_op).op

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = ops.convert_to_tensor(loss)
      loss_shape = loss.get_shape()
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar: %s.' % loss)
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode == ModeKeys.INFER:
        raise ValueError('Missing predictions.')
    else:
      if isinstance(predictions, dict):
        predictions = {
            k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
            for k, v in six.iteritems(predictions)
        }
      else:
        predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
            predictions)

    return predictions, loss, train_op
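
The signature-sniffing dispatch at the top of this method generalizes beyond the class. A standalone sketch of the same pattern; the real _get_arguments helper is internal, so inspect.getfullargspec is assumed here as a stand-in with equivalent behavior for plain functions:

import inspect

def call_model_fn(model_fn, features, targets, mode, params=None):
  args = inspect.getfullargspec(model_fn).args
  if 'mode' in args:
    if 'params' in args:
      return model_fn(features, targets, mode=mode, params=params)  # 4 args
    return model_fn(features, targets, mode=mode)  # 3 args
  return model_fn(features, targets)  # 2 args

def two_arg_model_fn(features, targets):
  # A model_fn returns (predictions, loss, train_op); None is fine for a demo.
  return features, targets, None

predictions, loss, train_op = call_model_fn(
    two_arg_model_fn, {'x': 1}, [0.0], 'train')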
Example #4
def set_size(a, validate_indices=True):
  """Compute number of unique elements along last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
        in `a`.

  Returns:
    For `a` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st
    `n-1` dimensions as `a`. Each value is the number of unique elements in
    the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` has an invalid type.
  """
  a = framework.convert_to_tensor_or_sparse_tensor(a, name="a")
  if not isinstance(a, ops.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)
  if a.values.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % a.values.dtype)
  # pylint: disable=protected-access
  return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices)
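
A hedged usage sketch (TF 1.x graph mode; tf.contrib.metrics.set_size is the public entry point for the function above). Note that newer SparseTensor releases spell the shape attribute `dense_shape` where the code above reads `a.shape`:

import tensorflow as tf

a = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # row-major order
    values=tf.constant([1, 2, 2, 3, 3], dtype=tf.int64),
    dense_shape=[2, 3])
sizes = tf.contrib.metrics.set_size(a)  # Tensor of shape [2]
with tf.Session() as sess:
  print(sess.run(sizes))  # [2 1]: row 0 holds {1, 2}, row 1 holds {3}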
Example #5
  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              signature_fn=None,
              output_alternatives=None):
    """Creates a validated `ModelFnOps` instance.

    For a multi-headed model, the predictions dict here will contain the outputs
    of all of the heads.  However, at serving time, requests will be made
    specifically for one or more heads, and the RPCs used for these requests may
    differ by problem type (i.e., regression, classification, other).  The
    purpose of the output_alternatives dict is to aid in exporting a SavedModel
    from which such head-specific queries can be served.  These
    output_alternatives will be combined with input_alternatives (see
    `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying
    the valid requests that can be served from this model.

    For a single-headed model, it is still advisable to provide
    output_alternatives with a single entry, because this is how the problem
    type is communicated for export and serving.  If output_alternatives is not
    given, the resulting SavedModel will support only one head of unspecified
    type.

    Args:
      mode: One of `ModeKeys`. Specifies if this is training, evaluation, or
        prediction.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Training loss `Tensor`.
      train_op: Op for the training step.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, such as `Tensor`.
      signature_fn: The signature_fn used for exporting.
      output_alternatives: A dict of
        `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where
        `submodel_name` is a submodel identifier that should be consistent
        across the pipeline (here likely taken from the name of each `Head`,
        for models that use them), `problem_type` is a `ProblemType`,
        `tensor_name` is a symbolic name for an output Tensor possibly but not
        necessarily taken from `PredictionKey`, and `Tensor` is the
        corresponding output Tensor itself.

    Returns:
      A validated `ModelFnOps` object.

    Raises:
      ValueError: If validation fails.
    """
    # Assert all ops are from the same graph.
    get_graph_from_inputs((predictions, loss, train_op))

    # Validate train_op.
    if train_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing train_op.')
    elif not isinstance(train_op, ops.Operation):
      # TODO(ptucker): Should this be allowed? Consider raising error.
      train_op = ops.convert_to_tensor(train_op).op

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = ops.convert_to_tensor(loss)
      loss_shape = loss.get_shape()
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar: %s.' % loss)
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode == ModeKeys.INFER or mode == ModeKeys.EVAL:
        raise ValueError('Missing predictions.')
    else:
      if isinstance(predictions, dict):
        predictions = {
            k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
            for k, v in six.iteritems(predictions)
        }
      else:
        predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
            predictions)

    # Validate eval_metric_ops.
    if eval_metric_ops is None:
      eval_metric_ops = {}
    else:
      if not isinstance(eval_metric_ops, dict):
        raise ValueError('eval_metric_ops must be a dict.')

    # Validate signature_fn.
    if signature_fn:
      if not callable(signature_fn):
        raise ValueError('signature_fn is not callable.')

    return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op,
                                          eval_metric_ops, signature_fn,
                                          output_alternatives)
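
The output_alternatives structure described in the docstring is plain nested Python. A hedged sketch for a hypothetical two-headed model; the import path, `ProblemType` member names, and head names are assumptions from the contrib-era layout, not taken from the source above:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import constants

# Stand-in head outputs; a real model_fn would compute these.
income_predictions = tf.zeros([8, 1])
class_ids = tf.zeros([8], dtype=tf.int64)
probabilities = tf.zeros([8, 3])

# `{submodel_name: (problem_type, {tensor_name: Tensor})}`
output_alternatives = {
    'income_head': (constants.ProblemType.LINEAR_REGRESSION,
                    {'predictions': income_predictions}),
    'class_head': (constants.ProblemType.CLASSIFICATION,
                   {'classes': class_ids, 'probabilities': probabilities}),
}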
Example #6
  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              output_alternatives=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None):
    """Creates a validated `ModelFnOps` instance.

    For a multi-headed model, the predictions dict here will contain the outputs
    of all of the heads.  However, at serving time, requests will be made
    specifically for one or more heads, and the RPCs used for these requests may
    differ by problem type (i.e., regression, classification, other).  The
    purpose of the output_alternatives dict is to aid in exporting a SavedModel
    from which such head-specific queries can be served.  These
    output_alternatives will be combined with input_alternatives (see
    `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying
    the valid requests that can be served from this model.

    For a single-headed model, it is still advisable to provide
    output_alternatives with a single entry, because this is how the problem
    type is communicated for export and serving.  If output_alternatives is not
    given, the resulting SavedModel will support only one head of unspecified
    type.

    Args:
      mode: One of `ModeKeys`. Specifies if this is training, evaluation, or
        prediction.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Training loss `Tensor`.
      train_op: Op for the training step.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, such as `Tensor`.
      output_alternatives: A dict of
        `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where
        `submodel_name` is a submodel identifier that should be consistent
        across the pipeline (here likely taken from the name of each `Head`,
        for models that use them), `problem_type` is a `ProblemType`,
        `tensor_name` is a symbolic name for an output Tensor possibly but not
        necessarily taken from `PredictionKey`, and `Tensor` is the
        corresponding output Tensor itself.
      training_chief_hooks: A list of `SessionRunHook` objects that will be
        run on the chief worker during training.
      training_hooks: A list of `SessionRunHook` objects that will be run on
        all workers during training.
      scaffold: A `tf.train.Scaffold` object that can be used to set
        initialization, saver, and more to be used in training.

    Returns:
      A validated `ModelFnOps` object.

    Raises:
      ValueError: If validation fails.
    """
    # Assert all ops are from the same graph.
    get_graph_from_inputs((predictions, loss, train_op))

    # Validate train_op.
    if train_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing train_op.')
    elif not isinstance(train_op, ops.Operation):
      # TODO(ptucker): Should this be allowed? Consider raising error.
      train_op = ops.convert_to_tensor(train_op).op

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = ops.convert_to_tensor(loss)
      loss_shape = loss.get_shape()
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar: %s.' % loss)
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode == ModeKeys.INFER or mode == ModeKeys.EVAL:
        raise ValueError('Missing predictions.')
    else:
      if isinstance(predictions, dict):
        predictions = {
            k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
            for k, v in six.iteritems(predictions)
        }
      else:
        predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
            predictions)

    # Validate eval_metric_ops.
    if eval_metric_ops is None:
      eval_metric_ops = {}
    else:
      if not isinstance(eval_metric_ops, dict):
        raise ValueError('eval_metric_ops must be a dict.')

    # Validate hooks.
    if training_chief_hooks is None:
      training_chief_hooks = []
    if training_hooks is None:
      training_hooks = []
    for hook in training_hooks + training_chief_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError('All hooks returned from model_fn must be '
                        'SessionRunHook instances, got instance of %s: %s' %
                        (type(hook), hook))

    return super(ModelFnOps, cls).__new__(
        cls,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        output_alternatives=output_alternatives,
        training_chief_hooks=training_chief_hooks,
        training_hooks=training_hooks,
        scaffold=scaffold)
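
The hook validation above accepts any tf.train.SessionRunHook. A minimal sketch of a custom hook that would pass the isinstance check (TF 1.x API; the counting/logging behavior is illustrative):

import tensorflow as tf

class StepLoggerHook(tf.train.SessionRunHook):
  """Counts session runs and logs every 100th one."""

  def __init__(self):
    self._step = 0

  def after_run(self, run_context, run_values):
    self._step += 1
    if self._step % 100 == 0:
      print('completed %d session runs' % self._step)

# A model_fn would return it via training_hooks=[StepLoggerHook()], or
# training_chief_hooks to run it only on the chief worker.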
Example #7
  def __new__(cls, mode, predictions=None, loss=None, train_op=None,
              eval_metric_ops=None, signature_fn=None):
    """Creates a validated `ModelFnOps` instance.

    Args:
      mode: One of `ModeKeys`. Specifies if this is training, evaluation, or
        prediction.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Training loss `Tensor`.
      train_op: Op for the training step.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, such as `Tensor`.
      signature_fn: The signature_fn used for exporting.

    Returns:
      A validated `ModelFnOps` object.

    Raises:
      ValueError: If validation fails.
    """
    # Assert all ops are from the same graph.
    get_graph_from_inputs((predictions, loss, train_op))

    # Validate train_op.
    if train_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing train_op.')
    elif not isinstance(train_op, ops.Operation):
      # TODO(ptucker): Should this be allowed? Consider raising error.
      train_op = ops.convert_to_tensor(train_op).op

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = ops.convert_to_tensor(loss)
      loss_shape = loss.get_shape()
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar: %s.' % loss)
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode == ModeKeys.INFER or mode == ModeKeys.EVAL:
        raise ValueError('Missing predictions.')
    else:
      if isinstance(predictions, dict):
        predictions = {
            k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
            for k, v in six.iteritems(predictions)
        }
      else:
        predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
            predictions)

    # Validate eval_metric_ops.
    if eval_metric_ops is None:
      eval_metric_ops = {}
    else:
      if not isinstance(eval_metric_ops, dict):
        raise ValueError('eval_metric_ops must be a dict.')

    # Validate signature_fn.
    if signature_fn:
      if not callable(signature_fn):
        raise ValueError('signature_fn is not callable.')

    return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op,
                                          eval_metric_ops, signature_fn)
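
Every variant above normalizes loss the same way: an unknown or single-element shape passes, and a [1]-shaped loss is reshaped to a true scalar. A small self-contained sketch of that shape logic, using the TF 1.x public API as a stand-in for the internal ops/tensor_shape modules:

import tensorflow as tf

scalar_loss = tf.constant(0.5)        # shape [];  accepted unchanged
vector_loss = tf.constant([0.5])      # shape [1]; num_elements() == 1
batch_loss = tf.constant([0.5, 0.7])  # shape [2]; num_elements() == 2

# The constructor's two checks, restated:
assert vector_loss.get_shape().num_elements() in (None, 1)      # passes
assert batch_loss.get_shape().num_elements() not in (None, 1)   # would raise
# [1] is not compatible with the scalar shape, so it gets reshaped:
normalized = tf.reshape(vector_loss, [])  # shape [] from here on
assert scalar_loss.get_shape().is_compatible_with(tf.TensorShape([]))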