Пример #1
0
 def add_input(
     self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator,
     element: metric_types.StandardMetricInputs
 ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator:
   """Adds one element's label/prediction/weight triples to the accumulator.

   For every output name, each (label, prediction, example_weight) triple
   yielded by ``metric_util.to_label_prediction_example_weight`` is added to
   the accumulator under that output's index. When sub keys are handled by
   the Keras config and the label shape differs from the prediction shape,
   the label is first converted with ``metric_util.one_hot``. If the
   accumulator reports that it should flush, the buffered inputs are
   processed as a batch via ``self._process_batch``.

   Args:
     accumulator: Accumulator receiving the per-output inputs.
     element: Standard metric inputs for a single example.

   Returns:
     The accumulator that was passed in.
   """
   # Neither setting depends on the output being processed, so compute them
   # once instead of on every iteration of the loop below.
   # When micro averaging is being used, flatten should be set to True so
   # that each class is treated as though it was an independent example.
   micro_average = (
       self._aggregation_type and self._aggregation_type.micro_average)
   # Skip sub_key processing if it is part of the keras config.
   sub_key = self._sub_key if not self._sub_key_in_config else None
   for i, output_name in enumerate(self._output_names):
     for label, prediction, example_weight in (
         metric_util.to_label_prediction_example_weight(
             element,
             eval_config=self._eval_config,
             model_name=self._model_name,
             output_name=output_name,
             sub_key=sub_key,
             aggregation_type=self._aggregation_type,
             class_weights=self._class_weights,
             flatten=micro_average)):
       # Keras requires non-sparse labels for its calculations, so a label
       # whose shape differs from the prediction is one-hot encoded.
       if self._sub_key_in_config and label.shape != prediction.shape:
         label = metric_util.one_hot(label, prediction)
       accumulator.add_input(i, label, prediction, example_weight)
   if accumulator.should_flush():
     self._process_batch(accumulator)
   return accumulator
Пример #2
0
 def _update_state(
         self,
         accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator
 ):
     """Feeds the buffered inputs to the model's compiled metrics and loss.

     With exactly one output name, the inputs at index 0 are passed as plain
     tensors. Otherwise they are passed as dicts keyed by output name, and
     any empty output name is skipped.

     Args:
       accumulator: Accumulator whose buffered inputs are converted with
         ``tf.convert_to_tensor`` before being passed to the model.
     """
     if len(self._output_names) != 1:
         labels, predictions, example_weights = {}, {}, {}
         for idx, name in enumerate(self._output_names):
             # Multi-output models keep an empty output name that has no
             # inputs of its own.
             if name:
                 lbl, pred, wt = accumulator.get_inputs(idx)
                 labels[name] = tf.convert_to_tensor(lbl)
                 predictions[name] = tf.convert_to_tensor(pred)
                 example_weights[name] = tf.convert_to_tensor(wt)
     else:
         # Single-output models don't use dicts.
         lbl, pred, wt = accumulator.get_inputs(0)
         labels = tf.convert_to_tensor(lbl)
         predictions = tf.convert_to_tensor(pred)
         example_weights = tf.convert_to_tensor(wt)
     self._model.compiled_metrics.update_state(
         labels, predictions, sample_weight=example_weights)
     self._model.compiled_loss(
         labels, predictions, sample_weight=example_weights)
Пример #3
0
 def _add_input(
     self,
     accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator,
     element: metric_types.StandardMetricInputs
 ) -> tf_metric_accumulators.TFCompilableMetricsAccumulator:
     """Buffers the first label/prediction/weight triple for each output.

     For each output name, only the first triple yielded by
     ``metric_util.to_label_prediction_example_weight`` (called with
     ``flatten=False``) is added to the accumulator. In a multi-output model,
     an empty output name is added with ``None`` for all three values.

     Args:
       accumulator: Accumulator receiving one input triple per output index.
       element: Standard metric inputs for a single example.

     Returns:
       The accumulator that was passed in.
     """
     for idx, name in enumerate(self._output_names):
         if name or len(self._output_names) <= 1:
             lbl, pred, wt = next(
                 metric_util.to_label_prediction_example_weight(
                     element,
                     self._eval_config,
                     self._model_name,
                     name,
                     flatten=False))
         else:
             # For multi-output models the '' output name is used to store
             # combined metric weights for all outputs, not inputs.
             lbl = pred = wt = None
         accumulator.add_input(idx, lbl, pred, wt)
     return accumulator
Пример #4
0
 def _process_batch(
     self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator):
   """Runs the buffered inputs through every metric and stores the weights.

   Does nothing when the accumulator holds no buffered inputs. Otherwise it
   reports the batch size and the estimated input byte size to the Beam
   metrics. Each metric is then reset, updated with its output's inputs, and
   its resulting weights are added to the accumulator. Finally the input
   buffer is cleared.

   Args:
     accumulator: Accumulator holding buffered inputs and per-metric weights.
   """
   if accumulator.len_inputs() == 0:
     return
   self._batch_size_beam_metric.update(accumulator.len_inputs())
   byte_estimate = accumulator.get_size_estimate()
   self._total_input_byte_size_beam_metric.update(byte_estimate)
   for out_idx, out_name in enumerate(self._output_names):
     batch_inputs = accumulator.get_inputs(out_idx)
     for met_idx, met in enumerate(self._metrics[out_name]):
       # Reset first so the weights read back reflect only this batch.
       met.reset_states()
       met.update_state(*batch_inputs)
       accumulator.add_weights(out_idx, met_idx, met.get_weights())
   accumulator.clear_inputs()
Пример #5
0
 def extract_output(
     self, accumulator: tf_metric_accumulators.TFCompilableMetricsAccumulator
 ) -> Dict[metric_types.MetricKey, Any]:
   """Computes the final metric values from the accumulated weights.

   Any inputs still buffered in the accumulator are processed first. Then,
   for each metric of each output, the stored weights are loaded into the
   metric, or the metric is reset when no weights were stored. The metric's
   result is converted with ``.numpy()``.

   Args:
     accumulator: Accumulator holding per-output, per-metric weights.

   Returns:
     A dict mapping each MetricKey (metric name, model name, output name and
     sub key) to that metric's result.
   """
   # Flush inputs that have not yet been folded into metric weights.
   self._process_batch(accumulator)
   output = {}
   for out_idx, out_name in enumerate(self._output_names):
     for met_idx, met in enumerate(self._metrics[out_name]):
       metric_key = metric_types.MetricKey(
           name=met.name,
           model_name=self._model_name,
           output_name=out_name,
           sub_key=self._sub_key)
       stored = accumulator.get_weights(out_idx, met_idx)
       if stored is None:
         # No weights were recorded for this metric; report its reset state.
         met.reset_states()
       else:
         met.set_weights(stored)
       output[metric_key] = met.result().numpy()
   return output