Example #1
    def _construct_graph(self):
        """Actually load and parse the graph.

        This is factored out from __init__ in case we want to support
        delayed-loads in the future.

        Raises:
          ValueError: Could not find signature keyed with
            DEFAULT_EVAL_SIGNATURE_DEF_KEY; or signature_def did not have
            exactly one input; or there was a signature output with the metric
            prefix but an unrecognised suffix.
        """
        meta_graph_def = tf.saved_model.loader.load(self._session,
                                                    [constants.EVAL_TAG],
                                                    self._path)

        with self._graph.as_default():
            signature_def = meta_graph_def.signature_def.get(
                constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY)
            if signature_def is None:
                raise ValueError(
                    'could not find signature with name %s. signature_def '
                    'was %s' %
                    (constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY, signature_def))

            # If features and labels are not stored in the signature_def.inputs then
            # only a single input will be present. We will use this as our flag to
            # indicate whether the features and labels should be read using the legacy
            # collections or using new signature_def.inputs.
            if len(signature_def.inputs) == 1:
                self._legacy_check_version(meta_graph_def)
                self._input_map, self._input_refs_node = graph_ref.load_legacy_inputs(
                    meta_graph_def, signature_def, self._graph)
                self._features_map, self._labels_map = (
                    graph_ref.load_legacy_features_and_labels(
                        meta_graph_def, self._graph))
            else:
                self._check_version(
                    graph_ref.load_tfma_version(signature_def, self._graph))
                self._input_map, self._input_refs_node = graph_ref.load_inputs(
                    signature_def, self._graph)
                self._features_map, self._labels_map = (
                    graph_ref.load_features_and_labels(signature_def,
                                                       self._graph))

            self._predictions_map = graph_ref.load_predictions(
                signature_def, self._graph)

            # Create feed_list for metrics_reset_update_get_fn
            #
            # We need to save this because we need to update the
            # metrics_reset_update_get_fn when additional metric ops are registered
            # (the feed_list will stay the same though).
            feed_list = []
            feed_list_keys = []
            for which_map, key, map_dict in (
                    self._iterate_fpl_maps_in_canonical_order()):
                feed_list.append(map_dict[encoding.NODE_SUFFIX])
                feed_list_keys.append((which_map, key))
            self._perform_metrics_update_fn_feed_list = feed_list
            # We also keep the associated keys for better error messages.
            self._perform_metrics_update_fn_feed_list_keys = feed_list_keys

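            # Bookkeeping for registered metric ops. These lists are populated
            # by register_additional_metric_ops, both for the default metrics
            # loaded below and for any metric ops registered later by callers.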
            self._metric_names = []
            self._metric_value_ops = []
            self._metric_update_ops = []
            self._metric_variable_nodes = []
            self._metric_variable_placeholders = []
            self._metric_variable_assign_ops = []

            if self._include_default_metrics:
                metrics_map = graph_ref.load_metrics(signature_def,
                                                     self._graph)
                metric_ops = {}
                for metric_name, ops in metrics_map.items():
                    metric_ops[metric_name] = (ops[encoding.VALUE_OP_SUFFIX],
                                               ops[encoding.UPDATE_OP_SUFFIX])
                self.register_additional_metric_ops(metric_ops)

            # Make callable for predict_list. The callable for
            # metrics_reset_update_get is updated in register_additional_metric_ops.
            # Repeated calls to a callable made using make_callable are faster than
            # doing repeated calls to session.run.
            self._predict_list_fn = self._session.make_callable(
                fetches=(self._features_map, self._predictions_map,
                         self._labels_map, self._input_refs_node),
                feed_list=list(self._input_map.values()))
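
The make_callable pattern used above builds a callable once from a fixed
feed_list, so repeated invocations skip the per-call fetch/feed setup that
session.run redoes every time. Below is a minimal, self-contained sketch of
that pattern on a hypothetical toy graph (not the loaded EvalSavedModel); the
names x, y and double_fn are illustrative only.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # make_callable is a TF1 Session API.

graph = tf.Graph()
with graph.as_default():
  # Hypothetical toy graph: one placeholder in, one tensor out.
  x = tf.compat.v1.placeholder(tf.float32, shape=[None], name='x')
  y = x * 2.0

with tf.compat.v1.Session(graph=graph) as sess:
  # Build the callable once; feed_list pins which placeholders are fed and in
  # what order, so each call only supplies the values.
  double_fn = sess.make_callable(fetches=y, feed_list=[x])
  # Repeated invocations avoid the per-call fetch/feed setup of session.run.
  for batch in ([1.0, 2.0], [3.0, 4.0]):
    print(double_fn(batch))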
Example #2
  def _construct_graph(self):
    """Actually load and parse the graph.

    This is factored out from __init__ in case we want to support delayed-loads
    in the future.

    Raises:
      ValueError: Could not find signature keyed with
        DEFAULT_EVAL_SIGNATURE_DEF_KEY; or signature_def did not have exactly
        one input; or there was a signature output with the metric prefix but an
        unrecognised suffix.
    """
    meta_graph_def = tf.compat.v1.saved_model.loader.load(
        self._session, self._tags, self._path)

    with self._graph.as_default():
      signature_def = meta_graph_def.signature_def.get(
          constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY)
      if signature_def is None:
        raise ValueError(
            'could not find signature with name %s. signature_def was %s' %
            (constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY, signature_def))

      self._additional_fetches_map = {}
      iterator_initializer = None

      # If features and labels are not stored in the signature_def.inputs then
      # only a single input will be present. We will use this as our flag to
      # indicate whether the features and labels should be read using the legacy
      # collections or using new signature_def.inputs.
      # TODO(b/119308261): Remove once all exported EvalSavedModels are updated.
      if len(signature_def.inputs) == 1:
        self._legacy_check_version(meta_graph_def)
        self._input_map, self._input_refs_node = graph_ref.load_legacy_inputs(
            meta_graph_def, signature_def, self._graph)
        self._features_map, self._labels_map = (
            graph_ref.load_legacy_features_and_labels(meta_graph_def,
                                                      self._graph))
      else:
        self._check_version(
            graph_ref.load_tfma_version(signature_def, self._graph))
        self._input_map, self._input_refs_node = graph_ref.load_inputs(
            signature_def, self._graph)
        self._features_map = graph_ref.load_additional_inputs(
            constants.FEATURES_NAME, signature_def, self._graph)
        if self._blacklist_feature_fetches:
          for feature_name in self._blacklist_feature_fetches:
            self._features_map.pop(feature_name, None)
        self._labels_map = graph_ref.load_additional_inputs(
            constants.LABELS_NAME, signature_def, self._graph)
        if self._additional_fetches:
          for prefix in self._additional_fetches:
            self._additional_fetches_map[prefix] = (
                graph_ref.load_additional_inputs(prefix, signature_def,
                                                 self._graph))
        iterator_initializer = self._get_op_from_tensor(
            graph_ref.load_iterator_initializer_name(signature_def,
                                                     self._graph))

      self._predictions_map = graph_ref.load_predictions(
          signature_def, self._graph)

      # Create feed_list for metrics_reset_update_get_fn
      #
      # We need to save this because we need to update the
      # metrics_reset_update_get_fn when additional metric ops are registered
      # (the feed_list will stay the same though).
      self._perform_metrics_update_fn_feed_list = list(self._input_map.values())

      self._metric_names = []
      self._metric_value_ops = []
      self._metric_update_ops = []
      self._metric_variable_nodes = []
      self._metric_variable_placeholders = []
      self._metric_variable_assign_ops = []

      if self._include_default_metrics:
        metrics_map = graph_ref.load_metrics(signature_def, self._graph)
        metric_ops = {}
        for metric_name, ops in metrics_map.items():
          metric_ops[metric_name] = (ops[encoding.VALUE_OP_SUFFIX],
                                     ops[encoding.UPDATE_OP_SUFFIX])
        self.register_additional_metric_ops(metric_ops)

      # Make callable for predict_list. The callable for
      # metrics_reset_update_get is updated in register_additional_metric_ops.
      # Repeated calls to a callable made using make_callable are faster than
      # doing repeated calls to session.run.
      if iterator_initializer:
        # When iterator is used, the initializer is used to feed the inputs. The
        # values are then fetched by repeated calls to the predict_list_fn until
        # OutOfRange is thrown.
        self._iterator_initializer_fn = self._session.make_callable(
            fetches=(iterator_initializer),
            feed_list=list(self._input_map.values()))
        self._predict_list_fn = self._session.make_callable(
            fetches=(self._features_map, self._predictions_map,
                     self._labels_map, self._input_refs_node,
                     self._additional_fetches_map))
      else:
        self._iterator_initializer_fn = None
        self._predict_list_fn = self._session.make_callable(
            fetches=(self._features_map, self._predictions_map,
                     self._labels_map, self._input_refs_node,
                     self._additional_fetches_map),
            feed_list=list(self._input_map.values()))
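
When an iterator initializer is present (Example #2), the raw inputs are fed
once through the initializer callable and batches are then pulled with a
fetch-only callable until the iterator raises OutOfRangeError. Below is a
minimal sketch of that control flow; the placeholder, dataset and names
(serialized_inputs, initializer_fn, predict_fn) are hypothetical stand-ins for
the tensors loaded from the signature_def.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
  # Hypothetical stand-ins for the loaded graph's input placeholder, iterator
  # and fetched tensors.
  serialized_inputs = tf.compat.v1.placeholder(tf.string, shape=[None])
  dataset = tf.compat.v1.data.Dataset.from_tensor_slices(
      serialized_inputs).batch(2)
  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
  next_batch = iterator.get_next()

with tf.compat.v1.Session(graph=graph) as sess:
  # One callable seeds the iterator from the raw inputs (mirrors
  # _iterator_initializer_fn above) ...
  initializer_fn = sess.make_callable(
      fetches=iterator.initializer, feed_list=[serialized_inputs])
  # ... and a fetch-only callable is invoked repeatedly (mirrors
  # _predict_list_fn) until the iterator is exhausted.
  predict_fn = sess.make_callable(fetches=next_batch)

  initializer_fn([b'a', b'b', b'c'])
  while True:
    try:
      print(predict_fn())
    except tf.errors.OutOfRangeError:
      break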