def testGetNodeInGraph(self):
  """A tensor stored in a MetaGraphDef collection resolves back to itself."""
  graph = tf.Graph()
  with graph.as_default():
    expected_node = tf.constant(1.0)

  # Store the encoded tensor reference under the 'fruit_node' collection.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  fruit_collection = meta_graph_def.collection_def['fruit_node']
  encoded_node = encoding.encode_tensor_node(expected_node)
  fruit_collection.any_list.value.extend([encoded_node])

  # Look the node back up in the same graph and compare.
  resolved_node = graph_ref.get_node_in_graph(meta_graph_def, 'fruit_node',
                                              graph)
  self.assertEqual(expected_node, resolved_node)
def _load_and_parse_graph(self):
  """Actually load and parse the graph.

  This is factored out from __init__ in case we want to support delayed-loads
  in the future.

  Loads the SavedModel at self._path, tagged SERVING, into self._session.
  Inside self._graph, it then resolves the input example node and the
  labels, predictions, features and metrics node maps from their MetaGraphDef
  collections. Finally it resets the metric bookkeeping lists and registers
  the metric (value, update) op pairs.
  """
  serving_tags = [tf.saved_model.tag_constants.SERVING]
  meta_graph_def = tf.saved_model.loader.load(self._session, serving_tags,
                                              self._path)

  graph = self._graph
  with graph.as_default():
    # Resolve the "named" nodes recorded in the MetaGraphDef collections.
    self._input_example_node = graph_ref.get_node_in_graph(
        meta_graph_def, encoding.INPUT_EXAMPLE_COLLECTION, graph)
    self._labels_map = graph_ref.get_node_map_in_graph(
        meta_graph_def, encoding.LABELS_COLLECTION, [encoding.NODE_SUFFIX],
        graph)
    self._predictions_map = graph_ref.get_node_map_in_graph(
        meta_graph_def, encoding.PREDICTIONS_COLLECTION,
        [encoding.NODE_SUFFIX], graph)
    self._features_map = graph_ref.get_node_map_in_graph(
        meta_graph_def, encoding.FEATURES_COLLECTION, [encoding.NODE_SUFFIX],
        graph)

    # Each metric entry carries both a value op and an update op.
    metric_nodes = graph_ref.get_node_map_in_graph(
        meta_graph_def, encoding.METRICS_COLLECTION,
        [encoding.VALUE_OP_SUFFIX, encoding.UPDATE_OP_SUFFIX], graph)
    metric_ops = {
        name: (nodes[encoding.VALUE_OP_SUFFIX],
               nodes[encoding.UPDATE_OP_SUFFIX])
        for name, nodes in metric_nodes.items()
    }

    # Start from empty bookkeeping; presumably register_additional_metric_ops
    # fills these in (it is defined elsewhere -- confirm).
    (self._metric_names, self._metric_value_ops, self._metric_update_ops,
     self._metric_variable_nodes, self._metric_variable_placeholders,
     self._metric_variable_assign_ops) = [], [], [], [], [], []
    self.register_additional_metric_ops(metric_ops)
def _load_and_parse_graph(self):
  """Actually load and parse the graph.

  This is factored out from __init__ in case we want to support delayed-loads
  in the future.

  Loads the SavedModel at self._path, tagged EVAL_TAG, into self._session.
  The features and labels are read from TFMA-specific collections. The input
  example node, predictions and metrics are read from the signature keyed by
  EVAL_TAG. The method then registers the metric ops and builds the callable
  used for predict_list.

  Raises:
    ValueError: Could not find signature keyed with EVAL_TAG; or signature_def
      did not have exactly one input; or there was a signature output with the
      metric prefix but an unrecognised suffix; or a metric was missing either
      its value output or its update output.
  """
  meta_graph_def = tf.saved_model.loader.load(self._session,
                                              [constants.EVAL_TAG], self._path)

  self._check_version(meta_graph_def)
  with self._graph.as_default():
    signature_def = meta_graph_def.signature_def.get(constants.EVAL_TAG)
    if signature_def is None:
      # signature_def is None here, so report the signatures that *are*
      # present instead; that is what a user needs to diagnose the mismatch.
      raise ValueError(
          'could not find signature with name %s. available signatures '
          'were: %s' % (constants.EVAL_TAG,
                        sorted(meta_graph_def.signature_def.keys())))

    # Note that there are two different encoding schemes in use here:
    #
    # 1. The scheme used by TFMA for the TFMA-specific extra collections
    #    for the features and labels.
    # 2. The scheme used by TensorFlow Estimators in the SignatureDefs for the
    #    input example node, predictions, metrics and so on.

    # Features and labels are in TFMA-specific extra collections.
    #
    # We use OrderedDict because the ordering of the keys matters:
    # we need to fix a canonical ordering for passing feed_list arguments
    # into make_callable.
    self._features_map = collections.OrderedDict(
        graph_ref.get_node_map_in_graph(meta_graph_def,
                                        encoding.FEATURES_COLLECTION,
                                        [encoding.NODE_SUFFIX], self._graph))
    self._labels_map = collections.OrderedDict(
        graph_ref.get_node_map_in_graph(meta_graph_def,
                                        encoding.LABELS_COLLECTION,
                                        [encoding.NODE_SUFFIX], self._graph))

    if len(signature_def.inputs) != 1:
      raise ValueError('there should be exactly one input. signature_def '
                       'was: %s' % signature_def)

    # The input node, predictions and metrics are in the signature.
    input_node = next(iter(signature_def.inputs.values()))
    self._input_example_node = (
        tf.saved_model.utils.get_tensor_from_tensor_info(
            input_node, self._graph))

    # The example reference node. If not defined in the graph, use the
    # input examples as example references.
    try:
      self._example_ref_tensor = graph_ref.get_node_in_graph(
          meta_graph_def, encoding.EXAMPLE_REF_COLLECTION, self._graph)
    except KeyError:
      # If we can't find the ExampleRef collection, then this is probably a
      # model created before we introduced the ExampleRef parameter to
      # EvalInputReceiver. In that case, we default to a tensor of range(0,
      # len(input_example)).
      self._example_ref_tensor = tf.range(tf.size(self._input_example_node))

    # We use OrderedDict because the ordering of the keys matters:
    # we need to fix a canonical ordering for passing feed_list arguments
    # into make_callable.
    #
    # The canonical ordering we use here is simply the ordering we get
    # from the predictions collection.
    predictions = graph_ref.extract_signature_outputs_with_prefix(
        constants.PREDICTIONS_NAME, signature_def.outputs)
    predictions_map = collections.OrderedDict()
    for k, v in predictions.items():
      # Extract to dictionary with a single key for consistency with
      # how features and labels are extracted.
      predictions_map[k] = {
          encoding.NODE_SUFFIX:
              tf.saved_model.utils.get_tensor_from_tensor_info(
                  v, self._graph)
      }
    self._predictions_map = predictions_map

    # Metric outputs are keyed '<metric_name>/<value-or-update suffix>'.
    # Group them by metric name.
    value_suffix = '/' + constants.METRIC_VALUE_SUFFIX
    update_suffix = '/' + constants.METRIC_UPDATE_SUFFIX
    metrics = graph_ref.extract_signature_outputs_with_prefix(
        constants.METRICS_NAME, signature_def.outputs)
    metrics_map = collections.defaultdict(dict)
    for k, v in metrics.items():
      node = tf.saved_model.utils.get_tensor_from_tensor_info(v, self._graph)
      if k.endswith(value_suffix):
        metrics_map[k[:-len(value_suffix)]][encoding.VALUE_OP_SUFFIX] = node
      elif k.endswith(update_suffix):
        metrics_map[k[:-len(update_suffix)]][encoding.UPDATE_OP_SUFFIX] = node
      else:
        raise ValueError('unrecognised suffix for metric. key was: %s' % k)

    required_suffixes = (encoding.VALUE_OP_SUFFIX, encoding.UPDATE_OP_SUFFIX)
    metric_ops = {}
    for metric_name, ops in metrics_map.items():
      # Without this check, a half-exported metric surfaces as a bare KeyError
      # with no hint about which metric is malformed.
      missing = [s for s in required_suffixes if s not in ops]
      if missing:
        raise ValueError(
            'metric %s must have both a value output and an update output in '
            'the signature; missing: %s' % (metric_name, missing))
      metric_ops[metric_name] = (ops[encoding.VALUE_OP_SUFFIX],
                                 ops[encoding.UPDATE_OP_SUFFIX])

    # Create feed_list for metrics_reset_update_get_fn
    #
    # We need to save this because we need to update the
    # metrics_reset_update_get_fn when additional metric ops are registered
    # (the feed_list will stay the same though).
    feed_list = []
    feed_list_keys = []
    for which_map, key, map_dict in (
        self._iterate_fpl_maps_in_canonical_order()):
      feed_list.append(map_dict[encoding.NODE_SUFFIX])
      feed_list_keys.append((which_map, key))
    self._metrics_reset_update_get_fn_feed_list = feed_list
    # We also keep the associated keys for better error messages.
    self._metrics_reset_update_get_fn_feed_list_keys = feed_list_keys

    self._metric_names = []
    self._metric_value_ops = []
    self._metric_update_ops = []
    self._metric_variable_nodes = []
    self._metric_variable_placeholders = []
    self._metric_variable_assign_ops = []
    self.register_additional_metric_ops(metric_ops)

    # Make callable for predict_list. The callable for
    # metrics_reset_update_get is updated in register_additional_metric_ops.
    # Repeated calls to a callable made using make_callable are faster than
    # doing repeated calls to session.run.
    self._predict_list_fn = self._session.make_callable(
        fetches=(self._features_map, self._predictions_map, self._labels_map,
                 self._example_ref_tensor),
        feed_list=[self._input_example_node])