Example 1
  def testGetNodeInGraph(self):
    g = tf.Graph()
    with g.as_default():
      apple = tf.constant(1.0)

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
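    # encode_tensor_node wraps a reference to `apple` in a proto Any so it can
    # be stored in a MetaGraphDef collection and resolved back to the tensor by
    # get_node_in_graph below.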
    meta_graph_def.collection_def['fruit_node'].any_list.value.extend(
        [encoding.encode_tensor_node(apple)])

    self.assertEqual(
        apple, graph_ref.get_node_in_graph(meta_graph_def, 'fruit_node', g))
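
A companion sketch (not from the original test file) for the missing-collection path: Example 3 below wraps get_node_in_graph in a KeyError handler, which suggests that looking up a collection name that was never written raises KeyError. The test name and that assumption are illustrative only.

  def testGetNodeInGraphMissingCollection(self):
    g = tf.Graph()
    with g.as_default():
      apple = tf.constant(1.0)

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_def.collection_def['fruit_node'].any_list.value.extend(
        [encoding.encode_tensor_node(apple)])

    # Assumption: a lookup under a collection name that was never written
    # raises KeyError, mirroring the fallback handling in Example 3.
    with self.assertRaises(KeyError):
      graph_ref.get_node_in_graph(meta_graph_def, 'missing_node', g)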
Example 2
    def _load_and_parse_graph(self):
        """Actually load and parse the graph.

        This is factored out from __init__ in case we want to support
        delayed-loads in the future.
        """
        meta_graph_def = tf.saved_model.loader.load(
            self._session, [tf.saved_model.tag_constants.SERVING], self._path)

        with self._graph.as_default():
            # Get references to "named" nodes.
            self._input_example_node = graph_ref.get_node_in_graph(
                meta_graph_def, encoding.INPUT_EXAMPLE_COLLECTION, self._graph)
            self._labels_map = graph_ref.get_node_map_in_graph(
                meta_graph_def, encoding.LABELS_COLLECTION,
                [encoding.NODE_SUFFIX], self._graph)
            self._predictions_map = graph_ref.get_node_map_in_graph(
                meta_graph_def, encoding.PREDICTIONS_COLLECTION,
                [encoding.NODE_SUFFIX], self._graph)
            self._features_map = graph_ref.get_node_map_in_graph(
                meta_graph_def, encoding.FEATURES_COLLECTION,
                [encoding.NODE_SUFFIX], self._graph)

            metrics = graph_ref.get_node_map_in_graph(
                meta_graph_def, encoding.METRICS_COLLECTION,
                [encoding.VALUE_OP_SUFFIX, encoding.UPDATE_OP_SUFFIX],
                self._graph)
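            # Each entry in `metrics` holds two nodes keyed by suffix: a value
            # op that reads the current metric value and an update op that
            # folds in a new batch. Pair them as (value_op, update_op) tuples,
            # matching the convention used by tf.metrics.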
            metric_ops = {}
            for metric_name, ops in metrics.items():
                metric_ops[metric_name] = (ops[encoding.VALUE_OP_SUFFIX],
                                           ops[encoding.UPDATE_OP_SUFFIX])

            self._metric_names = []
            self._metric_value_ops = []
            self._metric_update_ops = []
            self._metric_variable_nodes = []
            self._metric_variable_placeholders = []
            self._metric_variable_assign_ops = []
            self.register_additional_metric_ops(metric_ops)
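
For context, a hedged sketch of how (value_op, update_op) pairs like the ones built above are typically driven in a TF 1.x session. The names `session`, `metric_ops` and `batches` are placeholders, and this is not the implementation of register_additional_metric_ops.

# Metric accumulators are local variables in TF 1.x, so initialize them first.
session.run(tf.local_variables_initializer())
for feed_dict in batches:
    # Fold each batch into the accumulators via the update ops only.
    session.run([update_op for _, update_op in metric_ops.values()],
                feed_dict=feed_dict)
# Read the final values once every batch has been accumulated.
results = {name: session.run(value_op)
           for name, (value_op, _) in metric_ops.items()}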
Example 3
    def _load_and_parse_graph(self):
        """Actually load and parse the graph.

        This is factored out from __init__ in case we want to support
        delayed-loads in the future.

        Raises:
          ValueError: Could not find signature keyed with EVAL_TAG; or
            signature_def did not have exactly one input; or there was a
            signature output with the metric prefix but an unrecognised suffix.
        """
        meta_graph_def = tf.saved_model.loader.load(self._session,
                                                    [constants.EVAL_TAG],
                                                    self._path)

        self._check_version(meta_graph_def)
        with self._graph.as_default():
            signature_def = meta_graph_def.signature_def.get(
                constants.EVAL_TAG)
            if signature_def is None:
                raise ValueError(
                    'could not find signature with name %s. signature_def '
                    'was %s' % (constants.EVAL_TAG, signature_def))

            # Note that there are two different encoding schemes in use here:
            #
            # 1. The scheme used by TFMA for the TFMA-specific extra collections
            #    for the features and labels.
            # 2. The scheme used by TensorFlow Estimators in the SignatureDefs for the
            #    input example node, predictions, metrics and so on.

            # Features and labels are in TFMA-specific extra collections.
            #
            # We use OrderedDict because the ordering of the keys matters:
            # we need to fix a canonical ordering for passing feed_list arguments
            # into make_callable.
            self._features_map = collections.OrderedDict(
                graph_ref.get_node_map_in_graph(meta_graph_def,
                                                encoding.FEATURES_COLLECTION,
                                                [encoding.NODE_SUFFIX],
                                                self._graph))
            self._labels_map = collections.OrderedDict(
                graph_ref.get_node_map_in_graph(meta_graph_def,
                                                encoding.LABELS_COLLECTION,
                                                [encoding.NODE_SUFFIX],
                                                self._graph))

            if len(signature_def.inputs) != 1:
                raise ValueError(
                    'there should be exactly one input. signature_def '
                    'was: %s' % signature_def)

            # The input node, predictions and metrics are in the signature.
            input_node = list(signature_def.inputs.values())[0]
            self._input_example_node = (
                tf.saved_model.utils.get_tensor_from_tensor_info(
                    input_node, self._graph))

            # The example reference node. If not defined in the graph, use the
            # input examples as example references.
            try:
                self._example_ref_tensor = graph_ref.get_node_in_graph(
                    meta_graph_def, encoding.EXAMPLE_REF_COLLECTION,
                    self._graph)
            except KeyError:
                # If we can't find the ExampleRef collection, then this is probably a
                # model created before we introduced the ExampleRef parameter to
                # EvalInputReceiver. In that case, we default to a tensor of range(0,
                # len(input_example)).
                self._example_ref_tensor = tf.range(
                    tf.size(self._input_example_node))

            # We use OrderedDict because the ordering of the keys matters:
            # we need to fix a canonical ordering for passing feed_dict arguments
            # into make_callable.
            #
            # The canonical ordering we use here is simply the ordering we get
            # from the predictions collection.
            predictions = graph_ref.extract_signature_outputs_with_prefix(
                constants.PREDICTIONS_NAME, signature_def.outputs)
            predictions_map = collections.OrderedDict()
            for k, v in predictions.items():
                # Extract to dictionary with a single key for consistency with
                # how features and labels are extracted.
                predictions_map[k] = {
                    encoding.NODE_SUFFIX:
                    tf.saved_model.utils.get_tensor_from_tensor_info(
                        v, self._graph)
                }
            self._predictions_map = predictions_map

            metrics = graph_ref.extract_signature_outputs_with_prefix(
                constants.METRICS_NAME, signature_def.outputs)
            metrics_map = collections.defaultdict(dict)
            for k, v in metrics.items():
                node = tf.saved_model.utils.get_tensor_from_tensor_info(
                    v, self._graph)

                if k.endswith('/' + constants.METRIC_VALUE_SUFFIX):
                    key = k[:-len(constants.METRIC_VALUE_SUFFIX) - 1]
                    metrics_map[key][encoding.VALUE_OP_SUFFIX] = node
                elif k.endswith('/' + constants.METRIC_UPDATE_SUFFIX):
                    key = k[:-len(constants.METRIC_UPDATE_SUFFIX) - 1]
                    metrics_map[key][encoding.UPDATE_OP_SUFFIX] = node
                else:
                    raise ValueError(
                        'unrecognised suffix for metric. key was: %s' % k)

            metric_ops = {}
            for metric_name, ops in metrics_map.items():
                metric_ops[metric_name] = (ops[encoding.VALUE_OP_SUFFIX],
                                           ops[encoding.UPDATE_OP_SUFFIX])

            # Create feed_list for metrics_reset_update_get_fn
            #
            # We need to save this because we need to update the
            # metrics_reset_update_get_fn when additional metric ops are registered
            # (the feed_list will stay the same though).
            feed_list = []
            feed_list_keys = []
            for which_map, key, map_dict in (
                    self._iterate_fpl_maps_in_canonical_order()):
                feed_list.append(map_dict[encoding.NODE_SUFFIX])
                feed_list_keys.append((which_map, key))
            self._metrics_reset_update_get_fn_feed_list = feed_list
            # We also keep the associated keys for better error messages.
            self._metrics_reset_update_get_fn_feed_list_keys = feed_list_keys

            self._metric_names = []
            self._metric_value_ops = []
            self._metric_update_ops = []
            self._metric_variable_nodes = []
            self._metric_variable_placeholders = []
            self._metric_variable_assign_ops = []
            self.register_additional_metric_ops(metric_ops)

            # Make callable for predict_list. The callable for
            # metrics_reset_update_get is updated in register_additional_metric_ops.
            # Repeated calls to a callable made using make_callable are faster than
            # doing repeated calls to session.run.
            self._predict_list_fn = self._session.make_callable(
                fetches=(self._features_map, self._predictions_map,
                         self._labels_map, self._example_ref_tensor),
                feed_list=[self._input_example_node])
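
Finally, a sketch of how the callable created above might be invoked from elsewhere in the same class. `serialized_examples` (a batch of serialized tf.Example protos) is a hypothetical name, and the tuple unpacking simply mirrors the fetches passed to make_callable.

# One batched call replaces repeated session.run() invocations.
features, predictions, labels, example_refs = self._predict_list_fn(
    serialized_examples)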