def _construct_graph(self):
  """Actually load and parse the graph.

  This is factored out from __init__ in case we want to support delayed-loads
  in the future.

  Loads the SavedModel at self._path (tag constants.EVAL_TAG) into
  self._session, then, with self._graph as the default graph, populates:
    * self._input_map and self._input_refs_node,
    * self._features_map, self._labels_map and self._predictions_map,
    * the feed list (and its (which_map, key) labels) later used for the
      metrics update callable,
    * the metric bookkeeping lists, registering the signature's default
      metrics when self._include_default_metrics is set, and
    * self._predict_list_fn, a session callable fed with the values of
      self._input_map.

  Raises:
    ValueError: Could not find a signature keyed with
      DEFAULT_EVAL_SIGNATURE_DEF_KEY in the loaded MetaGraphDef. Errors raised
      by the version checks and graph_ref loaders called here (presumably
      including a signature output with the metric prefix but an unrecognised
      suffix) propagate unchanged.
  """
  meta_graph_def = tf.saved_model.loader.load(self._session,
                                              [constants.EVAL_TAG], self._path)

  with self._graph.as_default():
    signature_def = meta_graph_def.signature_def.get(
        constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY)
    if signature_def is None:
      # Report the key that was actually looked up (not the MetaGraph tag) and
      # the signatures that do exist; signature_def itself is always None here.
      raise ValueError(
          'could not find signature with name %s. signatures available in the '
          'MetaGraphDef loaded with tags %s were: %s' %
          (constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY, [constants.EVAL_TAG],
           sorted(meta_graph_def.signature_def.keys())))

    # If features and labels are not stored in the signature_def.inputs then
    # only a single input will be present. We will use this as our flag to
    # indicate whether the features and labels should be read using the legacy
    # collections or using new signature_def.inputs.
    if len(signature_def.inputs) == 1:
      self._legacy_check_version(meta_graph_def)
      self._input_map, self._input_refs_node = graph_ref.load_legacy_inputs(
          meta_graph_def, signature_def, self._graph)
      self._features_map, self._labels_map = (
          graph_ref.load_legacy_features_and_labels(meta_graph_def,
                                                    self._graph))
    else:
      self._check_version(
          graph_ref.load_tfma_version(signature_def, self._graph))
      self._input_map, self._input_refs_node = graph_ref.load_inputs(
          signature_def, self._graph)
      self._features_map, self._labels_map = (
          graph_ref.load_features_and_labels(signature_def, self._graph))

    self._predictions_map = graph_ref.load_predictions(
        signature_def, self._graph)

    # Create feed_list for metrics_reset_update_get_fn
    #
    # We need to save this because we need to update the
    # metrics_reset_update_get_fn when additional metric ops are registered
    # (the feed_list will stay the same though).
    feed_list = []
    feed_list_keys = []
    for which_map, key, map_dict in (
        self._iterate_fpl_maps_in_canonical_order()):
      feed_list.append(map_dict[encoding.NODE_SUFFIX])
      feed_list_keys.append((which_map, key))
    self._perform_metrics_update_fn_feed_list = feed_list
    # We also keep the associated keys for better error messages.
    self._perform_metrics_update_fn_feed_list_keys = feed_list_keys

    self._metric_names = []
    self._metric_value_ops = []
    self._metric_update_ops = []
    self._metric_variable_nodes = []
    self._metric_variable_placeholders = []
    self._metric_variable_assign_ops = []

    if self._include_default_metrics:
      metrics_map = graph_ref.load_metrics(signature_def, self._graph)
      # Each metric is registered as a (value_op, update_op) pair.
      metric_ops = {
          metric_name: (ops[encoding.VALUE_OP_SUFFIX],
                        ops[encoding.UPDATE_OP_SUFFIX])
          for metric_name, ops in metrics_map.items()
      }
      self.register_additional_metric_ops(metric_ops)

    # Make callable for predict_list. The callable for
    # metrics_reset_update_get is updated in register_additional_metric_ops.
    # Repeated calls to a callable made using make_callable are faster than
    # doing repeated calls to session.run.
    self._predict_list_fn = self._session.make_callable(
        fetches=(self._features_map, self._predictions_map, self._labels_map,
                 self._input_refs_node),
        feed_list=list(self._input_map.values()))
def _construct_graph(self):
  """Actually load and parse the graph.

  This is factored out from __init__ in case we want to support delayed-loads
  in the future.

  Loads the SavedModel at self._path with tags self._tags into self._session,
  then, with self._graph as the default graph, populates:
    * self._input_map and self._input_refs_node,
    * self._features_map (minus any self._blacklist_feature_fetches),
      self._labels_map, self._predictions_map and
      self._additional_fetches_map (one entry per self._additional_fetches
      prefix),
    * the feed list later used for the metrics update callable,
    * the metric bookkeeping lists, registering the signature's default
      metrics when self._include_default_metrics is set, and
    * self._iterator_initializer_fn / self._predict_list_fn session callables.

  If the signature provides an iterator initializer, the inputs are fed to
  self._iterator_initializer_fn and self._predict_list_fn takes no feeds.
  Otherwise self._iterator_initializer_fn is None and self._predict_list_fn is
  fed the values of self._input_map directly.

  Raises:
    ValueError: Could not find a signature keyed with
      DEFAULT_EVAL_SIGNATURE_DEF_KEY in the loaded MetaGraphDef. Errors raised
      by the version checks and graph_ref loaders called here (presumably
      including a signature output with the metric prefix but an unrecognised
      suffix) propagate unchanged.
  """
  meta_graph_def = tf.compat.v1.saved_model.loader.load(
      self._session, self._tags, self._path)

  with self._graph.as_default():
    signature_def = meta_graph_def.signature_def.get(
        constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY)
    if signature_def is None:
      # Report the key that was actually looked up and the tags used to load
      # the graph (not constants.EVAL_TAG), plus the signatures that do exist;
      # signature_def itself is always None here.
      raise ValueError(
          'could not find signature with name %s. signatures available in the '
          'MetaGraphDef loaded with tags %s were: %s' %
          (constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY, self._tags,
           sorted(meta_graph_def.signature_def.keys())))

    self._additional_fetches_map = {}
    iterator_initializer = None

    # If features and labels are not stored in the signature_def.inputs then
    # only a single input will be present. We will use this as our flag to
    # indicate whether the features and labels should be read using the legacy
    # collections or using new signature_def.inputs.
    # TODO(b/119308261): Remove once all exported EvalSavedModels are updated.
    if len(signature_def.inputs) == 1:
      # Note: self._blacklist_feature_fetches and self._additional_fetches are
      # not read on this legacy path, and no iterator initializer is loaded.
      self._legacy_check_version(meta_graph_def)
      self._input_map, self._input_refs_node = graph_ref.load_legacy_inputs(
          meta_graph_def, signature_def, self._graph)
      self._features_map, self._labels_map = (
          graph_ref.load_legacy_features_and_labels(meta_graph_def,
                                                    self._graph))
    else:
      self._check_version(
          graph_ref.load_tfma_version(signature_def, self._graph))
      self._input_map, self._input_refs_node = graph_ref.load_inputs(
          signature_def, self._graph)
      self._features_map = graph_ref.load_additional_inputs(
          constants.FEATURES_NAME, signature_def, self._graph)
      if self._blacklist_feature_fetches:
        for feature_name in self._blacklist_feature_fetches:
          self._features_map.pop(feature_name, None)
      self._labels_map = graph_ref.load_additional_inputs(
          constants.LABELS_NAME, signature_def, self._graph)
      if self._additional_fetches:
        for prefix in self._additional_fetches:
          self._additional_fetches_map[prefix] = (
              graph_ref.load_additional_inputs(prefix, signature_def,
                                               self._graph))
      iterator_initializer = self._get_op_from_tensor(
          graph_ref.load_iterator_initializer_name(signature_def,
                                                   self._graph))

    self._predictions_map = graph_ref.load_predictions(
        signature_def, self._graph)

    # Create feed_list for metrics_reset_update_get_fn
    #
    # We need to save this because we need to update the
    # metrics_reset_update_get_fn when additional metric ops are registered
    # (the feed_list will stay the same though).
    self._perform_metrics_update_fn_feed_list = list(self._input_map.values())

    self._metric_names = []
    self._metric_value_ops = []
    self._metric_update_ops = []
    self._metric_variable_nodes = []
    self._metric_variable_placeholders = []
    self._metric_variable_assign_ops = []

    if self._include_default_metrics:
      metrics_map = graph_ref.load_metrics(signature_def, self._graph)
      # Each metric is registered as a (value_op, update_op) pair.
      metric_ops = {
          metric_name: (ops[encoding.VALUE_OP_SUFFIX],
                        ops[encoding.UPDATE_OP_SUFFIX])
          for metric_name, ops in metrics_map.items()
      }
      self.register_additional_metric_ops(metric_ops)

    # Make callable for predict_list. The callable for
    # metrics_reset_update_get is updated in register_additional_metric_ops.
    # Repeated calls to a callable made using make_callable are faster than
    # doing repeated calls to session.run.
    predict_list_fetches = (self._features_map, self._predictions_map,
                            self._labels_map, self._input_refs_node,
                            self._additional_fetches_map)
    # Explicit None check: iterator_initializer is only ever assigned None
    # above or the result of _get_op_from_tensor (presumably an op or None —
    # confirm against that helper).
    if iterator_initializer is not None:
      # When iterator is used, the initializer is used to feed the inputs. The
      # values are then fetched by repeated calls to the predict_list_fn until
      # OutOfRange is thrown.
      self._iterator_initializer_fn = self._session.make_callable(
          fetches=iterator_initializer,
          feed_list=list(self._input_map.values()))
      self._predict_list_fn = self._session.make_callable(
          fetches=predict_list_fetches)
    else:
      self._iterator_initializer_fn = None
      self._predict_list_fn = self._session.make_callable(
          fetches=predict_list_fetches,
          feed_list=list(self._input_map.values()))