Example #1
    def get_metric_ops(self, features_dict, predictions_dict, labels_dict):
        ref_tensor = _get_prediction_tensor(predictions_dict)
        if ref_tensor is None:
            # Note that if predictions_dict is a Tensor and not a dict,
            # _get_prediction_tensor will return predictions_dict, so if we get
            # here, it means that predictions_dict is a dict without any of the
            # standard keys.
            #
            # If we can't get any of standard keys, then pick the first key
            # in alphabetical order if the predictions dict is non-empty.
            # If the predictions dict is empty, try the labels dict.
            # If that is empty too, default to the empty Tensor.
            tf.logging.info(
                'ExampleCount post export metric: could not find any of '
                'the standard keys in predictions_dict (keys were: %s)',
                predictions_dict.keys())
            if predictions_dict is not None and predictions_dict.keys():
                first_key = sorted(predictions_dict.keys())[0]
                ref_tensor = predictions_dict[first_key]
                tf.logging.info(
                    'Using the first key from predictions_dict: %s', first_key)
            elif labels_dict is not None:
                if types.is_tensor(labels_dict):
                    ref_tensor = labels_dict
                    tf.logging.info('Using the labels Tensor')
                elif labels_dict.keys():
                    first_key = sorted(labels_dict.keys())[0]
                    ref_tensor = labels_dict[first_key]
                    tf.logging.info('Using the first key from labels_dict: %s',
                                    first_key)

            if ref_tensor is None:
                tf.logging.info(
                    'Could not find a reference Tensor for example count. '
                    'Defaulting to the empty Tensor.')
                ref_tensor = tf.constant([])

        return {
            metric_keys.EXAMPLE_COUNT: metrics.total(tf.shape(ref_tensor)[0])
        }
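
These snippets assume that metrics.total behaves as a running sum: it returns a (value_op, update_op) pair whose update op adds its argument to an accumulator on every batch, so totalling tf.shape(ref_tensor)[0] yields the example count. A minimal sketch of such a metric in TF 1.x terms (the name make_total_metric is illustrative, not the library's API):

import tensorflow as tf

def make_total_metric(values):
    # Minimal sketch, not the library implementation: accumulate the sum of
    # `values` across batches and expose it as a (value_op, update_op) pair.
    values = tf.cast(values, tf.float64)
    total = tf.Variable(
        initial_value=0.0,
        dtype=tf.float64,
        trainable=False,
        collections=[tf.GraphKeys.LOCAL_VARIABLES,
                     tf.GraphKeys.METRIC_VARIABLES],
        name='total')
    update_op = tf.assign_add(total, tf.reduce_sum(values))
    return total.value(), update_op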
Example #2
    def get_metric_ops(self, features_dict, predictions_dict, labels_dict):
        value = features_dict[self._example_weight_key]
        return {metric_keys.EXAMPLE_WEIGHT: metrics.total(value)}
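
A hedged usage sketch for the snippet above: the op simply totals the configured weight feature across all examples. The class name ExampleWeightMetric and the feature name 'age' are hypothetical stand-ins, and metrics / metric_keys are assumed to be the same modules used by the surrounding snippets.

import tensorflow as tf

class ExampleWeightMetric(object):
    # Hypothetical wrapper around the get_metric_ops shown above, only to make
    # the usage runnable; the real class name and constructor may differ.
    def __init__(self, example_weight_key):
        self._example_weight_key = example_weight_key

    def get_metric_ops(self, features_dict, predictions_dict, labels_dict):
        value = features_dict[self._example_weight_key]
        return {metric_keys.EXAMPLE_WEIGHT: metrics.total(value)}

# 'age' is a hypothetical example-weight feature name.
features_dict = {'age': tf.constant([1.0, 2.0, 3.0])}
metric = ExampleWeightMetric(example_weight_key='age')
ops = metric.get_metric_ops(features_dict, predictions_dict={}, labels_dict={})
# After the update op runs for this batch, the value op under
# metric_keys.EXAMPLE_WEIGHT evaluates to 1.0 + 2.0 + 3.0 = 6.0.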
Example #3
    def testVariablePredictionLengths(self):
        # Check that we can handle cases where the model produces predictions of
        # different lengths for different examples.
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_classifier.simple_fixed_prediction_classifier(
                None, temp_eval_export_dir))

        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        _, prediction_dict, _ = (
            eval_saved_model.get_features_predictions_labels_dicts())
        with eval_saved_model.graph_as_default():
            eval_saved_model.register_additional_metric_ops({
                'total_non_trivial_classes':
                metrics.total(
                    tf.reduce_sum(
                        tf.cast(
                            tf.logical_and(
                                tf.not_equal(prediction_dict['classes'], '?'),
                                tf.not_equal(prediction_dict['classes'], '')),
                            tf.int32))),
                'example_count':
                metrics.total(tf.shape(prediction_dict['classes'])[0]),
                'total_score':
                metrics.total(prediction_dict['probabilities']),
            })

        example1 = self._makeExample(classes=['apple'], scores=[100.0])
        example2 = self._makeExample()
        example3 = self._makeExample(
            classes=['durian', 'elderberry', 'fig', 'grape'],
            scores=[300.0, 301.0, 302.0, 303.0])
        example4 = self._makeExample(classes=['banana', 'cherry'],
                                     scores=[400.0, 401.0])

        fpl_list1 = eval_saved_model.predict_list([
            example1.SerializeToString(),
            example2.SerializeToString(),
        ])
        fpl_list2 = eval_saved_model.predict_list([
            example3.SerializeToString(),
            example4.SerializeToString(),
        ])

        # Note that the '?' and 0 default values come from the model.
        self.assertAllEqual(
            np.array([['apple']]),
            fpl_list1[0].predictions['classes'][encoding.NODE_SUFFIX])
        self.assertAllEqual(
            np.array([[100]]),
            fpl_list1[0].predictions['probabilities'][encoding.NODE_SUFFIX])
        self.assertAllEqual(
            np.array([['?']]),
            fpl_list1[1].predictions['classes'][encoding.NODE_SUFFIX])
        self.assertAllEqual(
            np.array([[0]]),
            fpl_list1[1].predictions['probabilities'][encoding.NODE_SUFFIX])

        self.assertAllEqual(
            np.array([['durian', 'elderberry', 'fig', 'grape']]),
            fpl_list2[0].predictions['classes'][encoding.NODE_SUFFIX])
        self.assertAllEqual(
            np.array([[300, 301, 302, 303]]),
            fpl_list2[0].predictions['probabilities'][encoding.NODE_SUFFIX])
        self.assertAllEqual(
            np.array([['banana', 'cherry', '?', '?']]),
            fpl_list2[1].predictions['classes'][encoding.NODE_SUFFIX])
        self.assertAllEqual(
            np.array([[400, 401, 0, 0]]),
            fpl_list2[1].predictions['probabilities'][encoding.NODE_SUFFIX])

        eval_saved_model.metrics_reset_update_get_list(fpl_list1 + fpl_list2)
        metric_values = eval_saved_model.get_metric_values()

        self.assertDictElementsAlmostEqual(
            metric_values, {
                'total_non_trivial_classes': 7.0,
                'example_count': 4.0,
                'total_score': 2107.0,
            })
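
As a sanity check on the asserted values: example1 contributes one non-trivial class and a score of 100, example2 contributes nothing (its '?' class and 0 score are model-supplied padding), example3 contributes four classes and scores 300 through 303, and example4 contributes two classes plus two padded entries and scores 400 and 401. A quick worked-arithmetic sketch:

# Worked arithmetic behind the expected metric values (illustration only).
non_trivial_classes = 1 + 0 + 4 + 2      # 'apple'; padded '?'; four classes; two classes
example_count = 4                        # four serialized examples fed to predict_list
total_score = (100.0 + 0.0                       # example1, example2 (padded 0)
               + 300.0 + 301.0 + 302.0 + 303.0   # example3
               + 400.0 + 401.0 + 0.0 + 0.0)      # example4 plus its two padded zeros
assert (non_trivial_classes, example_count, total_score) == (7, 4, 2107.0)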