def get_metric_ops(self, features_dict, predictions_dict, labels_dict):
  ref_tensor = _get_prediction_tensor(predictions_dict)
  if ref_tensor is None:
    # Note that if predictions_dict is a Tensor and not a dict,
    # _get_prediction_tensor will return predictions_dict, so if we get
    # here, it means that predictions_dict is a dict without any of the
    # standard keys.
    #
    # If we can't get any of the standard keys, then pick the first key
    # in alphabetical order if the predictions dict is non-empty.
    # If the predictions dict is empty, try the labels dict.
    # If that is empty too, default to the empty Tensor.
    tf.logging.info(
        'ExampleCount post export metric: could not find any of '
        'the standard keys in predictions_dict (keys were: %s)',
        predictions_dict.keys())
    if predictions_dict is not None and predictions_dict.keys():
      first_key = sorted(predictions_dict.keys())[0]
      ref_tensor = predictions_dict[first_key]
      tf.logging.info('Using the first key from predictions_dict: %s',
                      first_key)
    elif labels_dict is not None:
      if types.is_tensor(labels_dict):
        ref_tensor = labels_dict
        tf.logging.info('Using the labels Tensor')
      elif labels_dict.keys():
        first_key = sorted(labels_dict.keys())[0]
        ref_tensor = labels_dict[first_key]
        tf.logging.info('Using the first key from labels_dict: %s', first_key)

    if ref_tensor is None:
      tf.logging.info('Could not find a reference Tensor for example count. '
                      'Defaulting to the empty Tensor.')
      ref_tensor = tf.constant([])

  return {
      metric_keys.EXAMPLE_COUNT: metrics.total(tf.shape(ref_tensor)[0])
  }
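

# A minimal sketch of the helper behaviour described in the comment above,
# included for context only. This is an assumption about what
# _get_prediction_tensor does, not the module's actual implementation, and
# the list of "standard" prediction keys is illustrative.
def _get_prediction_tensor_sketch(predictions_dict):
  """Returns a reference prediction Tensor, or None if none can be found."""
  if types.is_tensor(predictions_dict):
    # predictions_dict may itself be a single Tensor rather than a dict.
    return predictions_dict
  for key in ('logistic', 'predictions', 'probabilities', 'logits'):
    if key in predictions_dict:
      return predictions_dict[key]
  return None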
def get_metric_ops(self, features_dict, predictions_dict, labels_dict):
  value = features_dict[self._example_weight_key]
  return {metric_keys.EXAMPLE_WEIGHT: metrics.total(value)}
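

# A minimal sketch, for illustration only, of a tf.metrics-style "total"
# aggregation like the metrics.total() used above: it returns a
# (value_op, update_op) pair that accumulates the sum of `values` across
# batches. The variable name and collections here are assumptions; the real
# helper may differ.
def total_sketch(values):
  values = tf.cast(values, tf.float64)
  total_var = tf.Variable(
      initial_value=0.0,
      dtype=tf.float64,
      trainable=False,
      collections=[
          tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES
      ],
      name='total')
  update_op = tf.assign_add(total_var, tf.reduce_sum(values))
  value_op = tf.identity(total_var)
  return value_op, update_op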
def testVariablePredictionLengths(self):
  # Check that we can handle cases where the model produces predictions of
  # different lengths for different examples.
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_classifier.simple_fixed_prediction_classifier(
          None, temp_eval_export_dir))

  eval_saved_model = load.EvalSavedModel(eval_export_dir)
  _, prediction_dict, _ = (
      eval_saved_model.get_features_predictions_labels_dicts())
  with eval_saved_model.graph_as_default():
    eval_saved_model.register_additional_metric_ops({
        'total_non_trivial_classes':
            metrics.total(
                tf.reduce_sum(
                    tf.cast(
                        tf.logical_and(
                            tf.not_equal(prediction_dict['classes'], '?'),
                            tf.not_equal(prediction_dict['classes'], '')),
                        tf.int32))),
        'example_count':
            metrics.total(tf.shape(prediction_dict['classes'])[0]),
        'total_score':
            metrics.total(prediction_dict['probabilities']),
    })

  example1 = self._makeExample(classes=['apple'], scores=[100.0])
  example2 = self._makeExample()
  example3 = self._makeExample(
      classes=['durian', 'elderberry', 'fig', 'grape'],
      scores=[300.0, 301.0, 302.0, 303.0])
  example4 = self._makeExample(
      classes=['banana', 'cherry'], scores=[400.0, 401.0])

  fpl_list1 = eval_saved_model.predict_list([
      example1.SerializeToString(),
      example2.SerializeToString(),
  ])
  fpl_list2 = eval_saved_model.predict_list([
      example3.SerializeToString(),
      example4.SerializeToString(),
  ])

  # Note that the '?' and 0 default values come from the model.
  self.assertAllEqual(
      np.array([['apple']]),
      fpl_list1[0].predictions['classes'][encoding.NODE_SUFFIX])
  self.assertAllEqual(
      np.array([[100]]),
      fpl_list1[0].predictions['probabilities'][encoding.NODE_SUFFIX])
  self.assertAllEqual(
      np.array([['?']]),
      fpl_list1[1].predictions['classes'][encoding.NODE_SUFFIX])
  self.assertAllEqual(
      np.array([[0]]),
      fpl_list1[1].predictions['probabilities'][encoding.NODE_SUFFIX])

  self.assertAllEqual(
      np.array([['durian', 'elderberry', 'fig', 'grape']]),
      fpl_list2[0].predictions['classes'][encoding.NODE_SUFFIX])
  self.assertAllEqual(
      np.array([[300, 301, 302, 303]]),
      fpl_list2[0].predictions['probabilities'][encoding.NODE_SUFFIX])
  self.assertAllEqual(
      np.array([['banana', 'cherry', '?', '?']]),
      fpl_list2[1].predictions['classes'][encoding.NODE_SUFFIX])
  self.assertAllEqual(
      np.array([[400, 401, 0, 0]]),
      fpl_list2[1].predictions['probabilities'][encoding.NODE_SUFFIX])

  eval_saved_model.metrics_reset_update_get_list(fpl_list1 + fpl_list2)
  metric_values = eval_saved_model.get_metric_values()
  self.assertDictElementsAlmostEqual(
      metric_values, {
          'total_non_trivial_classes': 7.0,
          'example_count': 4.0,
          'total_score': 2107.0,
      })