예제 #1
0
  def _perform_metrics_update_list(self, examples_list: List[Any]) -> None:
    """Run a metrics update on a list of examples.

    The whole list is handed to `self._perform_metrics_update_fn` as a single
    positional argument.

    Args:
      examples_list: The examples to feed to the metrics update callable.

    Any RuntimeError, TypeError, ValueError or tf.errors.OpError raised by the
    update callable is passed to `general_util.reraise_augmented` together
    with a description of the raw input.
    """
    try:
      # Same call as the former `*[examples_list]`: exactly one positional
      # argument, without building and unpacking a throwaway list.
      self._perform_metrics_update_fn(examples_list)
    except (RuntimeError, TypeError, ValueError,
            tf.errors.OpError) as exception:
      # Wrap the input in a 1-tuple so `%` never tries to unpack it; a tuple
      # input would otherwise raise "not all arguments converted during string
      # formatting" here and hide the original exception.
      general_util.reraise_augmented(exception,
                                     'raw_input = %s' % (examples_list,))
예제 #2
0
 def __init__(self, model_agnostic_config: ModelAgnosticConfig):
     """Sets up a private graph and session, then builds the agnostic model.

     Args:
       model_agnostic_config: Configuration stored on the instance (as
         `self._config`) before `_create_graph` is invoked.

     A RuntimeError or ValueError raised by `_create_graph` is passed to
     `general_util.reraise_augmented` with an explanatory message.
     """
     graph = tf.Graph()
     self._graph = graph
     self._session = tf.compat.v1.Session(graph=graph)
     self._config = model_agnostic_config
     try:
         self._create_graph()
     except (ValueError, RuntimeError) as err:
         general_util.reraise_augmented(
             err, 'Failed to initialize agnostic model')
예제 #3
0
 def __init__(self, path):
     """Loads and parses the saved model found at `path`.

     A fresh tf.Graph and a session bound to it are created for this
     instance before `_load_and_parse_graph` runs.

     Args:
       path: Location of the saved model; kept on `self._path`.

     A RuntimeError or ValueError raised while loading is passed to
     `general_util.reraise_augmented` together with the offending path.
     """
     self._path = path
     self._graph = tf.Graph()
     # The tf.compat.v1 alias exists in late TF 1.x releases and in TF 2.x,
     # whereas the top-level tf.Session symbol was removed in TF 2.x.
     self._session = tf.compat.v1.Session(graph=self._graph)
     try:
         self._load_and_parse_graph()
     except (RuntimeError, ValueError) as exception:
         # 1-tuple operand so `%` never unpacks a tuple-valued path.
         general_util.reraise_augmented(
             exception, 'for saved_model at path %s' % (self._path,))
예제 #4
0
 def perform_metrics_update(self, features_predictions_labels):
     """Runs the metric update ops once, fed from a single FPL.

     Args:
       features_predictions_labels: The FPL that
         `_create_feed_for_features_predictions_labels` turns into the feed
         dict for `self._session.run`.

     A RuntimeError, TypeError or ValueError from the session run is passed to
     `general_util.reraise_augmented` along with the FPL and the feed dict.
     """
     feeds = self._create_feed_for_features_predictions_labels(
         features_predictions_labels)
     try:
         self._session.run(
             fetches=self._all_metric_update_ops, feed_dict=feeds)
     except (RuntimeError, TypeError, ValueError) as err:
         # Attach the offending input to the error to ease debugging.
         context = 'features_predictions_labels = %s, feed_dict = %s' % (
             features_predictions_labels, feeds)
         general_util.reraise_augmented(err, context)
예제 #5
0
 def metrics_reset_update_get(self, features_predictions_labels):
     """Resets the metric variables, applies one FPL, and reads them back.

     Args:
       features_predictions_labels: The FPL used to build the feed dict.

     Returns:
       The value fetched for `self._metric_variable_nodes` by the same
       `session.run` call that applies `self._all_metric_update_ops`.

     A RuntimeError, TypeError or ValueError from the session run is passed to
     `general_util.reraise_augmented` along with the FPL and the feed dict.
     """
     self.reset_metric_variables()
     feeds = self._create_feed_for_features_predictions_labels(
         features_predictions_labels)
     try:
         # Update and read in one run so the values reflect this FPL only.
         _, metric_values = self._session.run(
             fetches=[self._all_metric_update_ops, self._metric_variable_nodes],
             feed_dict=feeds)
     except (RuntimeError, TypeError, ValueError) as err:
         general_util.reraise_augmented(
             err, 'features_predictions_labels = %s, feed_dict = %s' %
             (features_predictions_labels, feeds))
     return metric_values
예제 #6
0
    def __init__(self):
        """Initializes this class and attempts to create the graph.

        This method attempts to create the graph through _construct_graph and
        also creates all class variables that need to be populated by the
        override function _construct_graph.

        Any RuntimeError, TypeError, ValueError or tf.errors.OpError raised by
        _construct_graph is passed to general_util.reraise_augmented with the
        message 'Failed to create graph.'.
        """
        self._graph = tf.Graph()
        # The tf.compat.v1 alias exists in late TF 1.x releases and in TF 2.x,
        # whereas the top-level tf.Session symbol was removed in TF 2.x.
        self._session = tf.compat.v1.Session(graph=self._graph)

        # Variables that need to be populated.

        # The names of the metrics.
        self._metric_names = []

        # Ops associated with reading and writing the metric variables.
        self._metric_value_ops = []
        self._metric_update_ops = []
        self._metric_variable_assign_ops = []

        # Nodes associated with the metric variables.
        self._metric_variable_nodes = []

        # Placeholders and feed input for the metric variables.
        self._metric_variable_placeholders = []
        self._metrics_reset_update_get_fn_feed_list = []
        self._metrics_reset_update_get_fn_feed_list_keys = []
        # Feed list (and its keys) for _perform_metrics_update_fn. Given empty
        # defaults so code that reads them (e.g. error reporting around the
        # update callable) never hits AttributeError; presumably overwritten
        # by _construct_graph — confirm against the overriding subclass.
        self._perform_metrics_update_fn_feed_list = []
        self._perform_metrics_update_fn_feed_list_keys = []

        # Dict that maps Features Predictions Label keys to their tensors.
        self._features_map = {}
        self._predictions_map = {}
        self._labels_map = {}

        # Ops to update/reset all metric variables.
        self._all_metric_update_ops = None
        self._reset_variables_op = None

        # Callables to perform the above ops.
        self._perform_metrics_update_fn = None
        self._metrics_reset_update_get_fn = None

        try:
            self._construct_graph()
        except (RuntimeError, TypeError, ValueError,
                tf.errors.OpError) as exception:
            general_util.reraise_augmented(exception,
                                           'Failed to create graph.')
예제 #7
0
 def _perform_metrics_update_list(self, features_predictions_labels_list):
     """Runs a single metrics update over a batch of FPLs.

     Args:
       features_predictions_labels_list: FPLs converted by
         `_create_feed_for_features_predictions_labels_list` into the
         positional feed values of `self._perform_metrics_update_fn`.

     On RuntimeError, TypeError, ValueError or tf.errors.OpError, a debug trace
     of the feeds is logged and the error is passed to
     `general_util.reraise_augmented` with the inputs and a keyed feed dict.
     """
     feeds = self._create_feed_for_features_predictions_labels_list(
         features_predictions_labels_list)
     try:
         self._perform_metrics_update_fn(*feeds)
     except (RuntimeError, TypeError, ValueError,
             tf.errors.OpError) as err:
         # Pair each positional feed value with its key for a readable message.
         keyed_feeds = {
             key: value for key, value in zip(
                 self._perform_metrics_update_fn_feed_list_keys, feeds)
         }
         # NOTE(review): the traced fetches include the metric variable nodes;
         # confirm this matches the fetches _perform_metrics_update_fn was
         # actually built with (not visible here).
         self._log_debug_message_for_tracing_feed_errors(
             fetches=([self._all_metric_update_ops] +
                      self._metric_variable_nodes),
             feed_list=self._perform_metrics_update_fn_feed_list)
         general_util.reraise_augmented(
             err, 'features_predictions_labels_list = %s, feed_dict = %s' %
             (features_predictions_labels_list, keyed_feeds))
예제 #8
0
 def metrics_reset_update_get(self, features_predictions_labels):
     """Resets, updates and reads back the metric variables for one FPL.

     Args:
       features_predictions_labels: The FPL converted into the positional feed
         values of `self._metrics_reset_update_get_fn`.

     Returns:
       The second element of the pair returned by
       `self._metrics_reset_update_get_fn`.

     On RuntimeError, TypeError, ValueError or tf.errors.OpError, a debug trace
     of the feeds is logged and the error is passed to
     `general_util.reraise_augmented` with the FPL and a keyed feed dict.
     """
     self.reset_metric_variables()
     feeds = self._create_feed_for_features_predictions_labels(
         features_predictions_labels)
     try:
         _, metric_values = self._metrics_reset_update_get_fn(*feeds)
     except (RuntimeError, TypeError, ValueError,
             tf.errors.OpError) as err:
         # Pair each positional feed value with its key for a readable message.
         keyed_feeds = {
             key: value for key, value in zip(
                 self._metrics_reset_update_get_fn_feed_list_keys, feeds)
         }
         # NOTE(review): the callable yields two results, but only the update
         # ops are traced here; confirm whether the metric variable nodes
         # should be included as well.
         self._log_debug_message_for_tracing_feed_errors(
             fetches=[self._all_metric_update_ops],
             feed_list=self._metrics_reset_update_get_fn_feed_list)
         general_util.reraise_augmented(
             err, 'features_predictions_labels = %s, feed_dict = %s' %
             (features_predictions_labels, keyed_feeds))
     return metric_values
예제 #9
0
  def __init__(self):
    """Creates the graph/session pair and constructs the metrics graph.

    Every attribute that the overriding `_construct_graph` is expected to fill
    in is first given an empty default; `_construct_graph` is then invoked.
    A RuntimeError, TypeError, ValueError or tf.errors.OpError coming out of it
    is passed to `general_util.reraise_augmented` with the message
    'Failed to create graph.'.
    """
    graph = tf.Graph()
    self._graph = graph
    self._session = tf.compat.v1.Session(graph=graph)

    # Serializes use of the shared session. Threads that share one
    # EvalSavedModel / EvalMetricsGraph also share this session and therefore
    # its metric variables; without the lock, one thread's
    # "reset-update-get" sequence could interleave with another's and race.
    #
    # Giving every thread its own session would also avoid the race, but
    # would require a bigger refactor.
    # TODO(b/131727905): Investigate whether it's possible / better to have
    # each thread have its own session.
    self._lock = threading.Lock()

    # ---- Attributes populated by _construct_graph ----

    # Metric names.
    self._metric_names = []

    # Ops that read, update and assign the metric variables.
    self._metric_value_ops = []
    self._metric_update_ops = []
    self._metric_variable_assign_ops = []

    # Graph nodes of the metric variables.
    self._metric_variable_nodes = []

    # Placeholders for the metric variables, plus the feed list (and its keys)
    # used by the metric update callable.
    self._metric_variable_placeholders = []
    self._perform_metrics_update_fn_feed_list = []
    self._perform_metrics_update_fn_feed_list_keys = []

    # Maps from feature / prediction / label keys to their tensors
    # (initialized here as plain dicts).
    self._features_map = {}
    self._predictions_map = {}
    self._labels_map = {}

    # Aggregate ops that set, update and reset every metric variable.
    self._all_metric_variable_assign_ops = None
    self._all_metric_update_ops = None
    self._reset_variables_op = None

    # Callable that runs the metric update.
    self._perform_metrics_update_fn = None

    # OrderedDict produced by graph_ref's load_(legacy_)inputs, mapping input
    # key to tensor value.
    self._input_map = None

    # Beam distribution metrics named 'batch_size' and 'batch_size_failed'.
    namespace = constants.METRICS_NAMESPACE
    self._batch_size = beam.metrics.Metrics.distribution(
        namespace, 'batch_size')
    self._batch_size_failed = beam.metrics.Metrics.distribution(
        namespace, 'batch_size_failed')

    try:
      self._construct_graph()
    except (RuntimeError, TypeError, ValueError,
            tf.errors.OpError) as err:
      general_util.reraise_augmented(err, 'Failed to create graph.')