Пример #1
0
def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
  """Builds an inference graph for `estimator` and exports it.

  A new graph is created containing a string placeholder named
  'input_example_tensor'. The placeholder is run through `input_fn` and the
  estimator's predict ops, signatures are built from the results, and
  everything is handed to `_export_graph` together with the latest checkpoint
  found in the estimator's model directory.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Callable invoked as `input_fn(estimator, examples)`, where
      `examples` is the `Tensor` of `Example` strings; it returns the features
      that are then passed to the model.
    signature_fn: Callable invoked as
      `signature_fn(examples, features, predictions)`; it returns the default
      and the named exporting signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep. `None` is forwarded to the
      exporter unchanged.
  """
  latest_checkpoint = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as export_graph:
    contrib_variables.create_global_step(export_graph)
    example_placeholder = array_ops.placeholder(
        dtype=dtypes.string,
        shape=[default_batch_size],
        name='input_example_tensor')
    parsed_features = input_fn(estimator, example_placeholder)
    predict_ops = estimator._get_predict_ops(parsed_features)
    signatures = signature_fn(example_placeholder, parsed_features, predict_ops)
    default_signature, named_graph_signatures = signatures

    # Turn the plain count into the filter object the exporter expects.
    retention_filter = (
        None if exports_to_keep is None
        else gc.largest_export_versions(exports_to_keep))
    _export_graph(
        export_graph,
        _get_saver(),
        latest_checkpoint,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=retention_filter)
Пример #2
0
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep):
  """Exports an inference graph for `estimator` into `export_dir`.

  Builds a fresh graph containing a string placeholder named
  'input_example_tensor', runs `input_fn` and the estimator's predict ops on
  it, chooses export signatures and hands everything to `_export_graph`
  together with the latest checkpoint found in the estimator's model dir.

  Args:
    estimator: Estimator to export. Its `_model_dir` and `_get_predict_ops`
      are used; `_get_target_column` is used when no `signature_fn` is given.
    export_dir: Forwarded to `_export_graph` as the export destination.
    signature_fn: Callable invoked as
      `signature_fn(examples, features, predictions)` that returns a
      `(default_signature, named_graph_signatures)` pair. When falsy, a
      signature function is picked from the problem type of the estimator's
      target column, falling back to `generic_signature_fn`.
    input_fn: Callable invoked as `input_fn(estimator, examples)` that returns
      the features passed to the model. When falsy, `_default_input_fn` is
      used.
    default_batch_size: Size of the only dimension of the examples
      placeholder.
    exports_to_keep: Number of exports to keep. When not `None` it is turned
      into a `gc.largest_export_versions` filter before being forwarded.

  Raises:
    ValueError: If no `signature_fn` is given and the estimator's target
      column has a problem type without a standard export signature.
  """
  input_fn = input_fn or _default_input_fn
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      # NOTE(review): the AttributeError handler below also covers
      # AttributeErrors raised while the selected signature function runs,
      # not only a missing `_get_target_column`/`problem_type`.
      try:
        # Some estimators provide a target_column of known type
        target_column = estimator._get_target_column()
        problem_type = target_column.problem_type

        if problem_type == layers.ProblemType.CLASSIFICATION:
          signature_fn = classification_signature_fn
        elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
          signature_fn = regression_signature_fn
        elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
          signature_fn = logistic_regression_signature_fn
        else:
          raise ValueError(
              'signature_fn must be provided because the TargetColumn is a %s, '
              'which does not have a standard problem type and so cannot use a '
              'standard export signature.' % type(target_column).__name__)

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        # The trailing space after "after" matters: adjacent literals are
        # concatenated verbatim, and without it the date is glued to the word.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      # The exporter expects a filter object rather than a plain count.
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
Пример #3
0
 def testUnion(self):
   """Checks the paths kept by the union of two export-version filters.

   Versions 0..9 are filtered through the union of
   `gc.largest_export_versions(3)` and `gc.mod_export_version(3)`; the
   expected survivors are versions 0, 3, 6, 7, 8 and 9.
   """
   paths = [gc.Path("/foo", i) for i in range(10)]
   f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
   # `assertEqual` replaces the deprecated `assertEquals` alias.
   self.assertEqual(
       f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
                  gc.Path("/foo", 6), gc.Path("/foo", 7),
                  gc.Path("/foo", 8), gc.Path("/foo", 9)])
Пример #4
0
def _export_estimator(estimator, export_dir, signature_fn, input_fn,
                      default_batch_size, exports_to_keep):
    """Exports an inference graph for `estimator` into `export_dir`.

    Builds a fresh graph containing a string placeholder named
    'input_example_tensor', runs `input_fn` and the estimator's predict ops
    on it, chooses export signatures and hands everything to `_export_graph`
    together with the latest checkpoint found in the estimator's model dir.

    Args:
        estimator: Estimator to export. Its `_model_dir` and
            `_get_predict_ops` are used; `_get_target_column` is used when no
            `signature_fn` is given.
        export_dir: Forwarded to `_export_graph` as the export destination.
        signature_fn: Callable invoked as
            `signature_fn(examples, features, predictions)` that returns a
            `(default_signature, named_graph_signatures)` pair. When falsy, a
            signature function is picked from the problem type of the
            estimator's target column, falling back to
            `generic_signature_fn`.
        input_fn: Callable invoked as `input_fn(estimator, examples)` that
            returns the features passed to the model. When falsy,
            `_default_input_fn` is used.
        default_batch_size: Size of the only dimension of the examples
            placeholder.
        exports_to_keep: Number of exports to keep. When not `None` it is
            turned into a `gc.largest_export_versions` filter before being
            forwarded.

    Raises:
        ValueError: If no `signature_fn` is given and the estimator's target
            column has a problem type without a standard export signature.
    """
    input_fn = input_fn or _default_input_fn
    checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)
        examples = array_ops.placeholder(dtype=dtypes.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
        features = input_fn(estimator, examples)
        predictions = estimator._get_predict_ops(features)

        # Explicit signature_fn takes priority
        if signature_fn:
            default_signature, named_graph_signatures = signature_fn(
                examples, features, predictions)
        else:
            # NOTE(review): the AttributeError handler below also covers
            # AttributeErrors raised while the selected signature function
            # runs, not only a missing `_get_target_column`/`problem_type`.
            try:
                # Some estimators provide a target_column of known type
                target_column = estimator._get_target_column()
                problem_type = target_column.problem_type

                if problem_type == layers.ProblemType.CLASSIFICATION:
                    signature_fn = classification_signature_fn
                elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
                    signature_fn = regression_signature_fn
                elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
                    signature_fn = logistic_regression_signature_fn
                else:
                    raise ValueError(
                        'signature_fn must be provided because the TargetColumn is a %s, '
                        'which does not have a standard problem type and so cannot use a '
                        'standard export signature.' %
                        type(target_column).__name__)

                default_signature, named_graph_signatures = (signature_fn(
                    examples, features, predictions))
            except AttributeError:
                # The trailing space after "after" matters: adjacent literals
                # are concatenated verbatim, and without it the date is glued
                # to the word.
                logging.warn(
                    'Change warning: `signature_fn` will be required after '
                    '2016-08-01.\n'
                    'Using generic signatures for now.  To maintain this behavior, '
                    'pass:\n'
                    '  signature_fn=export.generic_signature_fn\n'
                    'Also consider passing a regression or classification signature; '
                    'see cl/126430915 for an example.')
                default_signature, named_graph_signatures = generic_signature_fn(
                    examples, features, predictions)
        if exports_to_keep is not None:
            # The exporter expects a filter object rather than a plain count.
            exports_to_keep = gc.largest_export_versions(exports_to_keep)
        _export_graph(g,
                      _get_saver(),
                      checkpoint_path,
                      export_dir,
                      default_graph_signature=default_signature,
                      named_graph_signatures=named_graph_signatures,
                      exports_to_keep=exports_to_keep)
Пример #5
0
 def testUnion(self):
   """Checks the paths kept by the union of two export-version filters.

   Versions 0..9 are filtered through the union of
   `gc.largest_export_versions(3)` and `gc.mod_export_version(3)`; the
   expected survivors are versions 0, 3, 6, 7, 8 and 9.
   """
   paths = [gc.Path("/foo", i) for i in range(10)]
   f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
   # `assertEqual` replaces the deprecated `assertEquals` alias.
   self.assertEqual(
       f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
                  gc.Path("/foo", 6), gc.Path("/foo", 7),
                  gc.Path("/foo", 8), gc.Path("/foo", 9)])
Пример #6
0
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
    """Exports inference graph into given dir.

    Args:
        estimator: Estimator to export.
        export_dir: A string containing a directory to write the exported
            graph and checkpoints.
        signature_fn: Function that, given a `Tensor` of `Example` strings, a
            `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
            predictions, returns default and named exporting signatures.
            When falsy, a warning is logged and `generic_signature_fn` is
            used instead.
        input_fn: Function that, given the estimator and a `Tensor` of
            `Example` strings, parses it into features that are then passed
            to the model.
        default_batch_size: Default batch size of the `Example` placeholder.
        exports_to_keep: Number of exports to keep. `None` is forwarded to
            the exporter unchanged.
    """
    latest_checkpoint = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as export_graph:
        contrib_variables.create_global_step(export_graph)
        example_strings = array_ops.placeholder(
            dtype=dtypes.string,
            shape=[default_batch_size],
            name='input_example_tensor')
        model_features = input_fn(estimator, example_strings)
        model_predictions = estimator._get_predict_ops(model_features)

        # Pick the signature builder first, then call it in a single place.
        build_signatures = signature_fn
        if not build_signatures:
            logging.warn(
                'Change warning: `signature_fn` will be required after 2016-08-01.\n'
                'Using generic signatures for now.  To maintain this behavior, '
                'pass:\n'
                '  signature_fn=export.generic_signature_fn\n'
                'Also consider passing a regression or classification signature; see '
                'cl/126430915 for an example.')
            build_signatures = generic_signature_fn
        default_signature, named_graph_signatures = build_signatures(
            example_strings, model_features, model_predictions)

        # Turn the plain count into the filter object the exporter expects.
        retention_filter = exports_to_keep
        if retention_filter is not None:
            retention_filter = gc.largest_export_versions(retention_filter)
        _export_graph(export_graph,
                      _get_saver(),
                      latest_checkpoint,
                      export_dir,
                      default_graph_signature=default_signature,
                      named_graph_signatures=named_graph_signatures,
                      exports_to_keep=retention_filter)
Пример #7
0
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Exports inference graph into given dir.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    signature_fn: Function that, given a `Tensor` of `Example` strings, a
      `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
      predictions, returns default and named exporting signatures. When
      falsy, a warning is logged and `generic_signature_fn` is used instead.
    input_fn: Function that, given the estimator and a `Tensor` of `Example`
      strings, parses it into features that are then passed to the model.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep. `None` is forwarded to the
      exporter unchanged.
  """
  newest_checkpoint = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as graph:
    contrib_variables.create_global_step(graph)
    serialized_examples = array_ops.placeholder(
        dtype=dtypes.string,
        shape=[default_batch_size],
        name='input_example_tensor')
    feature_tensors = input_fn(estimator, serialized_examples)
    prediction_tensors = estimator._get_predict_ops(feature_tensors)
    signature_args = (serialized_examples, feature_tensors, prediction_tensors)

    if not signature_fn:
      # No explicit signature function: warn, then use the generic one.
      logging.warn(
          'Change warning: `signature_fn` will be required after 2016-08-01.\n'
          'Using generic signatures for now.  To maintain this behavior, '
          'pass:\n'
          '  signature_fn=export.generic_signature_fn\n'
          'Also consider passing a regression or classification signature; see '
          'cl/126430915 for an example.')
      signatures = generic_signature_fn(*signature_args)
    else:
      signatures = signature_fn(*signature_args)
    default_signature, named_graph_signatures = signatures

    # Turn the plain count into the filter object the exporter expects.
    keep_filter = (
        None if exports_to_keep is None
        else gc.largest_export_versions(exports_to_keep))
    _export_graph(
        graph,
        _get_saver(),
        newest_checkpoint,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=keep_filter)
Пример #8
0
 def testLargestExportVersions(self):
   """Checks that `gc.largest_export_versions(2)` keeps versions 9 and 10.

   The input holds versions 8, 9 and 10 under "/foo".
   """
   paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
   newest = gc.largest_export_versions(2)
   n = newest(paths)
   # `assertEqual` replaces the deprecated `assertEquals` alias.
   self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
Пример #9
0
 def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
   """Checks that version 0 survives `gc.largest_export_versions(2)`.

   With only versions 0 and 3 present, both paths are expected back.
   """
   paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
   newest = gc.largest_export_versions(2)
   n = newest(paths)
   # `assertEqual` replaces the deprecated `assertEquals` alias.
   self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
Пример #10
0
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  """Exports an inference graph for `estimator` into `export_dir`.

  Builds a fresh graph, obtains features (and optionally an examples tensor)
  from `input_fn`, runs the estimator's predict ops, chooses export
  signatures and hands everything to `_export_graph`.

  Args:
    estimator: Estimator to export. Its `_model_dir` and `_get_predict_ops`
      are used; `_create_signature_fn` is used when no `signature_fn` is
      given.
    export_dir: Forwarded to `_export_graph` as the export destination.
    signature_fn: Callable invoked as
      `signature_fn(examples, features, predictions)` that returns a
      `(default_signature, named_graph_signatures)` pair. When falsy, the
      estimator's `_create_signature_fn()` result is used, falling back to
      `generic_signature_fn` on `AttributeError`.
    input_fn: With `use_deprecated_input_fn=True`, a callable invoked as
      `input_fn(estimator, examples)` returning the features (defaults to
      `_default_input_fn` when falsy). Otherwise a zero-argument callable
      returning a pair whose first element is the features.
    default_batch_size: Size of the only dimension of the examples
      placeholder (deprecated input_fn mode only).
    exports_to_keep: Number of exports to keep. When not `None` it is turned
      into a `gc.largest_export_versions` filter before being forwarded.
    input_feature_key: In non-deprecated mode, the key popped from the
      features to serve as the examples tensor. `None` leaves examples unset.
    use_deprecated_input_fn: Selects which `input_fn` calling convention is
      used (see `input_fn`).
    prediction_key: When not `None`, only `predictions[prediction_key]` is
      passed to the signature function.
    checkpoint_path: Checkpoint to export. When falsy, the latest checkpoint
      in the estimator's model dir is used.

  Returns:
    Whatever `_export_graph` returns.

  Raises:
    ValueError: If `use_deprecated_input_fn` is false and `input_fn` is
      `None`, or if neither features nor examples end up defined.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     tf_saver.latest_checkpoint(estimator._model_dir))
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        # Note: this removes the key from the mapping returned by input_fn.
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      # NOTE(review): the AttributeError handler below also covers
      # AttributeErrors raised while the signature function itself runs, not
      # only a missing `_create_signature_fn`.
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        # The trailing space after "after" matters: adjacent literals are
        # concatenated verbatim, and without it the date is glued to the word.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      # The exporter expects a filter object rather than a plain count.
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
Пример #11
0
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
    """Exports an inference graph for `estimator` into `export_dir`.

    Builds a fresh graph, obtains features (and optionally an examples
    tensor) from `input_fn`, runs the estimator's predict ops, chooses export
    signatures and hands everything to `_export_graph`.

    Args:
        estimator: Estimator to export. Its `_model_dir` and
            `_get_predict_ops` are used; `_create_signature_fn` is used when
            no `signature_fn` is given.
        export_dir: Forwarded to `_export_graph` as the export destination.
        signature_fn: Callable invoked as
            `signature_fn(examples, features, predictions)` that returns a
            `(default_signature, named_graph_signatures)` pair. When falsy,
            the estimator's `_create_signature_fn()` result is used, falling
            back to `generic_signature_fn` on `AttributeError`.
        input_fn: With `use_deprecated_input_fn=True`, a callable invoked as
            `input_fn(estimator, examples)` returning the features (defaults
            to `_default_input_fn` when falsy). Otherwise a zero-argument
            callable returning a pair whose first element is the features.
        default_batch_size: Size of the only dimension of the examples
            placeholder (deprecated input_fn mode only).
        exports_to_keep: Number of exports to keep. When not `None` it is
            turned into a `gc.largest_export_versions` filter before being
            forwarded.
        input_feature_key: In non-deprecated mode, the key popped from the
            features to serve as the examples tensor. `None` leaves examples
            unset.
        use_deprecated_input_fn: Selects which `input_fn` calling convention
            is used (see `input_fn`).
        prediction_key: When not `None`, only `predictions[prediction_key]`
            is passed to the signature function.
        checkpoint_path: Checkpoint to export. When falsy, the latest
            checkpoint in the estimator's model dir is used.

    Returns:
        Whatever `_export_graph` returns.

    Raises:
        ValueError: If `use_deprecated_input_fn` is false and `input_fn` is
            `None`, or if neither features nor examples end up defined.
    """
    if use_deprecated_input_fn:
        input_fn = input_fn or _default_input_fn
    elif input_fn is None:
        raise ValueError('input_fn must be defined.')

    # If checkpoint_path is specified, use the specified checkpoint path.
    checkpoint_path = (checkpoint_path
                       or tf_saver.latest_checkpoint(estimator._model_dir))
    with ops.Graph().as_default() as g:
        training_util.create_global_step(g)

        if use_deprecated_input_fn:
            examples = array_ops.placeholder(dtype=dtypes.string,
                                             shape=[default_batch_size],
                                             name='input_example_tensor')
            features = input_fn(estimator, examples)
        else:
            features, _ = input_fn()
            examples = None
            if input_feature_key is not None:
                # Note: this removes the key from the mapping returned by
                # input_fn.
                examples = features.pop(input_feature_key)

        if (not features) and (examples is None):
            raise ValueError('Either features or examples must be defined.')

        predictions = estimator._get_predict_ops(features).predictions

        if prediction_key is not None:
            predictions = predictions[prediction_key]

        # Explicit signature_fn takes priority
        if signature_fn:
            default_signature, named_graph_signatures = signature_fn(
                examples, features, predictions)
        else:
            # NOTE(review): the AttributeError handler below also covers
            # AttributeErrors raised while the signature function itself
            # runs, not only a missing `_create_signature_fn`.
            try:
                # Some estimators provide a signature function.
                # TODO (zakaria): check if the estimator has this function, id:1300 gh:1301
                #   raise helpful error if not
                signature_fn = estimator._create_signature_fn()

                default_signature, named_graph_signatures = (signature_fn(
                    examples, features, predictions))
            except AttributeError:
                # The trailing space after "after" matters: adjacent literals
                # are concatenated verbatim, and without it the date is glued
                # to the word.
                logging.warn(
                    'Change warning: `signature_fn` will be required after '
                    '2016-08-01.\n'
                    'Using generic signatures for now.  To maintain this behavior, '
                    'pass:\n'
                    '  signature_fn=export.generic_signature_fn\n'
                    'Also consider passing a regression or classification signature; '
                    'see cl/126430915 for an example.')
                default_signature, named_graph_signatures = generic_signature_fn(
                    examples, features, predictions)
        if exports_to_keep is not None:
            # The exporter expects a filter object rather than a plain count.
            exports_to_keep = gc.largest_export_versions(exports_to_keep)
        return _export_graph(g,
                             _get_saver(),
                             checkpoint_path,
                             export_dir,
                             default_graph_signature=default_signature,
                             named_graph_signatures=named_graph_signatures,
                             exports_to_keep=exports_to_keep)
 def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
     """Checks that version 0 survives `gc.largest_export_versions(2)`.

     With only versions 0 and 3 present, both paths are expected back.
     """
     paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
     newest = gc.largest_export_versions(2)
     n = newest(paths)
     # `assertEqual` replaces the deprecated `assertEquals` alias.
     self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
Пример #13
0
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None):
  """Exports an inference graph for `estimator` into `export_dir`.

  Builds a fresh graph, obtains features (and optionally an examples tensor)
  from `input_fn`, runs the estimator's predict ops, chooses export
  signatures and hands everything to `_export_graph` together with the latest
  checkpoint found in the estimator's model dir.

  Args:
    estimator: Estimator to export. Its `_model_dir` and `_get_predict_ops`
      are used; `_create_signature_fn` is used when no `signature_fn` is
      given.
    export_dir: Forwarded to `_export_graph` as the export destination.
    signature_fn: Callable invoked as
      `signature_fn(examples, features, predictions)` that returns a
      `(default_signature, named_graph_signatures)` pair. When falsy, the
      estimator's `_create_signature_fn()` result is used, falling back to
      `generic_signature_fn` on `AttributeError`.
    input_fn: With `use_deprecated_input_fn=True`, a callable invoked as
      `input_fn(estimator, examples)` returning the features (defaults to
      `_default_input_fn` when falsy). Otherwise a zero-argument callable
      returning a pair whose first element is the features.
    default_batch_size: Size of the only dimension of the examples
      placeholder (deprecated input_fn mode only).
    exports_to_keep: Number of exports to keep. When not `None` it is turned
      into a `gc.largest_export_versions` filter before being forwarded.
    input_feature_key: In non-deprecated mode, the key popped from the
      features to serve as the examples tensor. `None` leaves examples unset.
    use_deprecated_input_fn: Selects which `input_fn` calling convention is
      used (see `input_fn`).
    prediction_key: When not `None`, only `predictions[prediction_key]` is
      passed to the signature function.

  Returns:
    Whatever `_export_graph` returns.

  Raises:
    ValueError: If `use_deprecated_input_fn` is false and `input_fn` is
      `None`, or if neither features nor examples end up defined.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        # Note: this removes the key from the mapping returned by input_fn.
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    # The default return type of _get_predict_ops is ModelFnOps. But there are
    # some subclasses of tf.contrib.learn.Estimator which override this
    # method and use the legacy signature, namely _get_predict_ops returns a
    # `predictions` Tensor or dict of Tensors. The following else-statement
    # code covers these cases, but will soon be deleted after the subclasses
    # are updated.
    # TODO(b/32664904): Update subclasses and delete the else-statement.
    infer_ops = estimator._get_predict_ops(features)
    if isinstance(infer_ops, model_fn.ModelFnOps):  # Default signature
      predictions = infer_ops.predictions
    else:  # Legacy signature
      predictions = infer_ops

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      # NOTE(review): the AttributeError handler below also covers
      # AttributeErrors raised while the signature function itself runs, not
      # only a missing `_create_signature_fn`.
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        # The trailing space after "after" matters: adjacent literals are
        # concatenated verbatim, and without it the date is glued to the word.
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      # The exporter expects a filter object rather than a plain count.
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
Пример #14
0
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
    """Exports inference graph into given dir.

    Args:
        estimator: Estimator to export.
        export_dir: A string containing a directory to write the exported
            graph and checkpoints.
        signature_fn: Function that returns a default signature and a named
            signature map, given `Tensor` of `Example` strings, `dict` of
            `Tensor`s for features and `Tensor` or `dict` of `Tensor`s for
            predictions. When falsy, a signature function is picked from the
            problem type of the estimator's target column, falling back to
            `generic_signature_fn`.
        input_fn: Function that, given the estimator and a `Tensor` of
            `Example` strings, parses it into features that are then passed
            to the model.
        default_batch_size: Default batch size of the `Example` placeholder.
        exports_to_keep: Number of exports to keep.

    Raises:
        ValueError: If no `signature_fn` is given and the estimator's target
            column has a problem type without a standard export signature.
    """
    checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)
        examples = array_ops.placeholder(dtype=dtypes.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
        features = input_fn(estimator, examples)
        predictions = estimator._get_predict_ops(features)

        # Explicit signature_fn takes priority
        if signature_fn:
            default_signature, named_graph_signatures = signature_fn(
                examples, features, predictions)
        else:
            # NOTE(review): the AttributeError handler below also covers
            # AttributeErrors raised while the selected signature function
            # runs, not only a missing `_get_target_column`/`problem_type`.
            try:
                # Some estimators provide a target_column of known type
                target_column = estimator._get_target_column()
                problem_type = target_column.problem_type

                if problem_type == layers.ProblemType.CLASSIFICATION:
                    signature_fn = classification_signature_fn
                elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
                    signature_fn = regression_signature_fn
                elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
                    signature_fn = logistic_regression_signature_fn
                else:
                    raise ValueError(
                        'signature_fn must be provided because the TargetColumn is a %s, '
                        'which does not have a standard problem type and so cannot use a '
                        'standard export signature.' %
                        type(target_column).__name__)

                default_signature, named_graph_signatures = (signature_fn(
                    examples, features, predictions))
            except AttributeError:
                # The trailing space after "after" matters: adjacent literals
                # are concatenated verbatim, and without it the date is glued
                # to the word.
                logging.warn(
                    'Change warning: `signature_fn` will be required after '
                    '2016-08-01.\n'
                    'Using generic signatures for now.  To maintain this behavior, '
                    'pass:\n'
                    '  signature_fn=export.generic_signature_fn\n'
                    'Also consider passing a regression or classification signature; '
                    'see cl/126430915 for an example.')
                default_signature, named_graph_signatures = generic_signature_fn(
                    examples, features, predictions)
        if exports_to_keep is not None:
            # The exporter expects a filter object rather than a plain count.
            exports_to_keep = gc.largest_export_versions(exports_to_keep)
        _export_graph(g,
                      _get_saver(),
                      checkpoint_path,
                      export_dir,
                      default_graph_signature=default_signature,
                      named_graph_signatures=named_graph_signatures,
                      exports_to_keep=exports_to_keep)
Пример #15
0
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None):
    """Builds the prediction graph for `estimator` and exports it.

    Args:
      estimator: Estimator to export. Its `_model_dir` is searched for the
        latest checkpoint and its `_get_predict_ops` builds the predictions.
      export_dir: Destination handed to `_export_graph`.
      signature_fn: Callable invoked as
        `signature_fn(examples, features, predictions)` and returning a
        `(default_signature, named_graph_signatures)` pair. When falsy, the
        estimator's `_create_signature_fn()` is tried first and
        `generic_signature_fn` is used as the fallback.
      input_fn: With `use_deprecated_input_fn=True`, called as
        `input_fn(estimator, examples)` and returning the features; replaced
        by `_default_input_fn` when falsy. Otherwise called with no arguments
        and returning a 2-tuple whose first element is the features.
      default_batch_size: Shape of the string `examples` placeholder (only
        used with `use_deprecated_input_fn=True`).
      exports_to_keep: If not None, wrapped with `gc.largest_export_versions`
        before being passed to `_export_graph`.
      input_feature_key: Only used with `use_deprecated_input_fn=False`; if
        not None, this key is popped from the features and used as
        `examples`.
      use_deprecated_input_fn: Selects which `input_fn` calling convention is
        used (see `input_fn`).
      prediction_key: If not None, `predictions[prediction_key]` is exported
        instead of the full predictions.

    Returns:
      The value returned by `_export_graph`.

    Raises:
      ValueError: If `use_deprecated_input_fn` is false and `input_fn` is
        None, or if there are neither features nor examples.
    """
    if use_deprecated_input_fn:
        input_fn = input_fn or _default_input_fn
    elif input_fn is None:
        raise ValueError('input_fn must be defined.')

    checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)

        if use_deprecated_input_fn:
            examples = array_ops.placeholder(dtype=dtypes.string,
                                             shape=[default_batch_size],
                                             name='input_example_tensor')
            features = input_fn(estimator, examples)
        else:
            features, _ = input_fn()
            examples = None
            if input_feature_key is not None:
                examples = features.pop(input_feature_key)

        if (not features) and (examples is None):
            raise ValueError('Either features or examples must be defined.')

        # The default return type of _get_predict_ops is ModelFnOps. But there
        # are some subclasses of tf.contrib.learn.Estimator which override this
        # method and use the legacy signature, namely _get_predict_ops returns
        # a `predictions` Tensor or dict or Tensors. The following
        # else-statement code covers these cases, but will soon be deleted
        # after the subclasses are updated.
        # TODO(b/32664904): Update subclasses and delete the else-statement.
        infer_ops = estimator._get_predict_ops(features)
        if isinstance(infer_ops, model_fn.ModelFnOps):  # Default signature
            predictions = infer_ops.predictions
        else:  # Legacy signature
            predictions = infer_ops

        if prediction_key is not None:
            predictions = predictions[prediction_key]

        # Explicit signature_fn takes priority
        if signature_fn:
            default_signature, named_graph_signatures = signature_fn(
                examples, features, predictions)
        else:
            # NOTE(review): the try covers both the lookup of
            # `_create_signature_fn` and the call of the function it returns,
            # so an AttributeError raised inside either one also triggers the
            # generic fallback below.
            try:
                # Some estimators provide a signature function.
                # TODO(zakaria): check if the estimator has this function,
                #   raise helpful error if not
                signature_fn = estimator._create_signature_fn()

                default_signature, named_graph_signatures = (signature_fn(
                    examples, features, predictions))
            except AttributeError:
                # The trailing space after "after" keeps the adjacent string
                # literals from running together in the logged message.
                logging.warn(
                    'Change warning: `signature_fn` will be required after '
                    '2016-08-01.\n'
                    'Using generic signatures for now.  To maintain this behavior, '
                    'pass:\n'
                    '  signature_fn=export.generic_signature_fn\n'
                    'Also consider passing a regression or classification signature; '
                    'see cl/126430915 for an example.')
                default_signature, named_graph_signatures = generic_signature_fn(
                    examples, features, predictions)
        if exports_to_keep is not None:
            exports_to_keep = gc.largest_export_versions(exports_to_keep)
        return _export_graph(g,
                             _get_saver(),
                             checkpoint_path,
                             export_dir,
                             default_graph_signature=default_signature,
                             named_graph_signatures=named_graph_signatures,
                             exports_to_keep=exports_to_keep)
Пример #16
0
    def doBasicsOneExportPath(self,
                              export_path,
                              clear_devices=False,
                              global_step=GLOBAL_STEP,
                              sharded=True):
        """Exports a small graph to `export_path`, re-imports it and checks it.

        Args:
          export_path: Base directory handed to `Exporter.export`.
          clear_devices: Forwarded to `Exporter.init`; when true, the expected
            graph def has its node devices blanked before the comparison.
          global_step: Initial value of the `global_step` variable; also
            formatted with `constants.VERSION_FORMAT_SPECIFIER` to locate the
            exported version directory.
          sharded: Forwarded to `tf.train.Saver`; selects which variables
            filename is used when restoring.
        """
        # Build a graph with 2 parameter nodes on different devices.
        tf.reset_default_graph()
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            # v2 is an unsaved variable derived from v0 and v1.  It is used to
            # exercise the ability to run an init op when restoring a graph.
            with sess.graph.device("/cpu:0"):
                v0 = tf.Variable(10, name="v0")
            with sess.graph.device("/cpu:1"):
                v1 = tf.Variable(20, name="v1")
            v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
            assign_v2 = tf.assign(v2, tf.add(v0, v1))
            init_op = tf.group(assign_v2, name="init_op")

            tf.add_to_collection("v", v0)
            tf.add_to_collection("v", v1)
            tf.add_to_collection("v", v2)

            global_step_tensor = tf.Variable(global_step, name="global_step")
            named_tensor_bindings = {
                "logical_input_A": v0,
                "logical_input_B": v1
            }
            signatures = {
                "foo":
                exporter.regression_signature(input_tensor=v0,
                                              output_tensor=v1),
                "generic":
                exporter.generic_signature(named_tensor_bindings)
            }

            asset_filepath_orig = os.path.join(tf.test.get_temp_dir(),
                                               "hello42.txt")
            asset_file = tf.constant(asset_filepath_orig, name="filename42")
            tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)

            with gfile.FastGFile(asset_filepath_orig, "w") as f:
                f.write("your data here")
            assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

            # This file is written but never added to the assets collection,
            # so the export is expected to leave it out (checked below).
            ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
            with gfile.FastGFile(ignored_asset, "w") as f:
                f.write("additional data here")

            tf.initialize_all_variables().run()

            # Run an export.
            save = tf.train.Saver({
                "v0": v0,
                "v1": v1
            },
                                  restore_sequentially=True,
                                  sharded=sharded)
            export = exporter.Exporter(save)
            export.init(
                sess.graph.as_graph_def(),
                init_op=init_op,
                clear_devices=clear_devices,
                default_graph_signature=exporter.classification_signature(
                    input_tensor=v0),
                named_graph_signatures=signatures,
                assets_collection=assets_collection)
            export.export(export_path,
                          global_step_tensor,
                          sess,
                          exports_to_keep=gc.largest_export_versions(2))

        # Restore graph.
        compare_def = tf.get_default_graph().as_graph_def()
        tf.reset_default_graph()
        # Every path checked below lives under this per-version directory.
        version_dir = os.path.join(
            export_path, constants.VERSION_FORMAT_SPECIFIER % global_step)
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            save = tf.train.import_meta_graph(
                os.path.join(version_dir, constants.META_GRAPH_DEF_FILENAME))
            self.assertIsNotNone(save)
            meta_graph_def = save.export_meta_graph()
            collection_def = meta_graph_def.collection_def

            # Validate custom graph_def.
            graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
            self.assertEqual(len(graph_def_any), 1)
            graph_def = tf.GraphDef()
            graph_def_any[0].Unpack(graph_def)
            if clear_devices:
                for node in compare_def.node:
                    node.device = ""
            self.assertProtoEquals(compare_def, graph_def)

            # Validate init_op.
            init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
            self.assertEqual(len(init_ops), 1)
            self.assertEqual(init_ops[0], "init_op")

            # Validate signatures.
            signatures_any = collection_def[
                constants.SIGNATURES_KEY].any_list.value
            self.assertEqual(len(signatures_any), 1)
            signatures = manifest_pb2.Signatures()
            signatures_any[0].Unpack(signatures)
            default_signature = signatures.default_signature
            self.assertEqual(
                default_signature.classification_signature.input.tensor_name,
                "v0:0")
            bindings = signatures.named_signatures[
                "generic"].generic_signature.map
            self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
            self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
            read_foo_signature = (
                signatures.named_signatures["foo"].regression_signature)
            self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
            self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

            # Validate the assets.
            assets_any = collection_def[constants.ASSETS_KEY].any_list.value
            self.assertEqual(len(assets_any), 1)
            asset = manifest_pb2.AssetFile()
            assets_any[0].Unpack(asset)
            assets_path = os.path.join(version_dir, constants.ASSETS_DIRECTORY,
                                       "hello42.txt")
            # Read through a context manager so the handle is closed.
            with gfile.GFile(assets_path) as f:
                asset_contents = f.read()
            self.assertEqual(asset_contents, "your data here")
            self.assertEqual("hello42.txt", asset.filename)
            self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)
            ignored_asset_path = os.path.join(version_dir,
                                              constants.ASSETS_DIRECTORY,
                                              "ignored.txt")
            self.assertFalse(gfile.Exists(ignored_asset_path))

            # Validate graph restoration.
            if sharded:
                variables_filename = constants.VARIABLES_FILENAME_PATTERN
            else:
                variables_filename = constants.VARIABLES_FILENAME
            save.restore(sess, os.path.join(version_dir, variables_filename))
            self.assertEqual(10, tf.get_collection("v")[0].eval())
            self.assertEqual(20, tf.get_collection("v")[1].eval())
            # v2 is not saved; running the exported init op recomputes it.
            tf.get_collection(constants.INIT_OP_KEY)[0].run()
            self.assertEqual(30, tf.get_collection("v")[2].eval())
Пример #17
0
  def doBasicsOneExportPath(self,
                            export_path,
                            clear_devices=False,
                            global_step=GLOBAL_STEP,
                            sharded=True):
    """Exports a small graph to `export_path`, re-imports it and checks it.

    Args:
      export_path: Base directory handed to `Exporter.export`.
      clear_devices: Forwarded to `Exporter.init`; when true, the expected
        graph def has its node devices blanked before the comparison.
      global_step: Initial value of the `global_step` variable; also formatted
        with `constants.VERSION_FORMAT_SPECIFIER` to locate the exported
        version directory.
      sharded: Forwarded to `tf.train.Saver`; selects which variables filename
        is used when restoring.
    """
    # Build a graph with 2 parameter nodes on different devices.
    tf.reset_default_graph()
    with tf.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      # v2 is an unsaved variable derived from v0 and v1.  It is used to
      # exercise the ability to run an init op when restoring a graph.
      with sess.graph.device("/cpu:0"):
        v0 = tf.Variable(10, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = tf.Variable(20, name="v1")
      v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
      assign_v2 = tf.assign(v2, tf.add(v0, v1))
      init_op = tf.group(assign_v2, name="init_op")

      tf.add_to_collection("v", v0)
      tf.add_to_collection("v", v1)
      tf.add_to_collection("v", v2)

      global_step_tensor = tf.Variable(global_step, name="global_step")
      named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
      signatures = {
          "foo": exporter.regression_signature(input_tensor=v0,
                                               output_tensor=v1),
          "generic": exporter.generic_signature(named_tensor_bindings)
      }

      asset_filepath_orig = os.path.join(tf.test.get_temp_dir(), "hello42.txt")
      asset_file = tf.constant(asset_filepath_orig, name="filename42")
      tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)

      with gfile.FastGFile(asset_filepath_orig, "w") as f:
        f.write("your data here")
      assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

      # This file is written but never added to the assets collection, so the
      # export is expected to leave it out (checked below).
      ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
      with gfile.FastGFile(ignored_asset, "w") as f:
        f.write("additional data here")

      tf.initialize_all_variables().run()

      # Run an export.
      save = tf.train.Saver({"v0": v0,
                             "v1": v1},
                            restore_sequentially=True,
                            sharded=sharded)
      export = exporter.Exporter(save)
      export.init(sess.graph.as_graph_def(),
                  init_op=init_op,
                  clear_devices=clear_devices,
                  default_graph_signature=exporter.classification_signature(
                      input_tensor=v0),
                  named_graph_signatures=signatures,
                  assets_collection=assets_collection)
      export.export(export_path,
                    global_step_tensor,
                    sess,
                    exports_to_keep=gc.largest_export_versions(2))

    # Restore graph.
    compare_def = tf.get_default_graph().as_graph_def()
    tf.reset_default_graph()
    # Every path checked below lives under this per-version directory.
    version_dir = os.path.join(
        export_path, constants.VERSION_FORMAT_SPECIFIER % global_step)
    with tf.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      save = tf.train.import_meta_graph(
          os.path.join(version_dir, constants.META_GRAPH_DEF_FILENAME))
      self.assertIsNotNone(save)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      # Validate custom graph_def.
      graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
      self.assertEqual(len(graph_def_any), 1)
      graph_def = tf.GraphDef()
      graph_def_any[0].Unpack(graph_def)
      if clear_devices:
        for node in compare_def.node:
          node.device = ""
      self.assertProtoEquals(compare_def, graph_def)

      # Validate init_op.
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      self.assertEqual(len(init_ops), 1)
      self.assertEqual(init_ops[0], "init_op")

      # Validate signatures.
      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEqual(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      self.assertEqual(
          default_signature.classification_signature.input.tensor_name, "v0:0")
      bindings = signatures.named_signatures["generic"].generic_signature.map
      self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
      self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
      read_foo_signature = (
          signatures.named_signatures["foo"].regression_signature)
      self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
      self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

      # Validate the assets.
      assets_any = collection_def[constants.ASSETS_KEY].any_list.value
      self.assertEqual(len(assets_any), 1)
      asset = manifest_pb2.AssetFile()
      assets_any[0].Unpack(asset)
      assets_path = os.path.join(version_dir, constants.ASSETS_DIRECTORY,
                                 "hello42.txt")
      # Read through a context manager so the handle is closed.
      with gfile.GFile(assets_path) as f:
        asset_contents = f.read()
      self.assertEqual(asset_contents, "your data here")
      self.assertEqual("hello42.txt", asset.filename)
      self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)
      ignored_asset_path = os.path.join(version_dir,
                                        constants.ASSETS_DIRECTORY,
                                        "ignored.txt")
      self.assertFalse(gfile.Exists(ignored_asset_path))

      # Validate graph restoration.
      if sharded:
        variables_filename = constants.VARIABLES_FILENAME_PATTERN
      else:
        variables_filename = constants.VARIABLES_FILENAME
      save.restore(sess, os.path.join(version_dir, variables_filename))
      self.assertEqual(10, tf.get_collection("v")[0].eval())
      self.assertEqual(20, tf.get_collection("v")[1].eval())
      # v2 is not saved; running the exported init op recomputes it.
      tf.get_collection(constants.INIT_OP_KEY)[0].run()
      self.assertEqual(30, tf.get_collection("v")[2].eval())
 def testLargestExportVersions(self):
     """Checks that `largest_export_versions(2)` keeps the paths for 9 and 10.

     The three input paths share the directory "/foo" and differ only in
     their second field (8, 9, 10); the filter is expected to drop the entry
     for 8 and return the other two in ascending order.
     """
     paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
     newest = gc.largest_export_versions(2)
     n = newest(paths)
     # assertEqual rather than the deprecated alias assertEquals, which was
     # removed from unittest in Python 3.12.
     self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])