def testEvaluateExistingMetricsCSVInputBasic(self):
        """Check metrics of the exported CSV linear classifier (batch update).

        Exports the simple CSV linear classifier into the eval export
        directory, loads it as an ``EvalSavedModel``, runs
        ``metrics_reset_update_get_list`` over two CSV rows, then reads the
        metric values back with ``get_metric_values`` and asserts that
        'accuracy' and 'auc' are (almost) 1.0.
        """
        eval_export_root = self._getEvalExportDir()
        # Only the second element of the export result (the eval export
        # directory) is needed here; the first element is discarded.
        _unused_export_dir, eval_export_dir = (
            csv_linear_classifier.simple_csv_linear_classifier(
                None, eval_export_root))

        model = load.EvalSavedModel(eval_export_dir)
        # NOTE(review): the field layout of these rows is defined by
        # csv_linear_classifier; presumably the last field is the label.
        csv_rows = ['3.0,english,1.0', '3.0,chinese,0.0']
        # The return value is ignored; metrics are read back separately below.
        model.metrics_reset_update_get_list(csv_rows)

        expected = {'accuracy': 1.0, 'auc': 1.0}
        self.assertDictElementsAlmostEqual(model.get_metric_values(), expected)
Example #2
0
    def testEvaluateExistingMetricsCSVInputBasic(self):
        """Check metrics of the exported CSV linear classifier (row by row).

        Exports the simple CSV linear classifier, loads it as an
        ``EvalSavedModel`` and, for each of two CSV rows in order, calls
        ``predict`` and passes the result to ``perform_metrics_update``.
        Finally asserts that the values returned by ``get_metric_values``
        have 'accuracy' and 'auc' (almost) equal to 1.0.
        """
        _, eval_export_root = self._getExportDirs()
        _, eval_export_dir = (
            csv_linear_classifier.simple_csv_linear_classifier(
                None, eval_export_root))

        model = load.EvalSavedModel(eval_export_dir)
        # NOTE(review): row layout comes from csv_linear_classifier;
        # presumably the last field is the label.
        for csv_row in ('3.0,english,1.0', '3.0,chinese,0.0'):
            # predict() output is handed straight to the metric update.
            fpl = model.predict(csv_row)
            model.perform_metrics_update(fpl)

        expected = {'accuracy': 1.0, 'auc': 1.0}
        self.assertDictElementsAlmostEqual(model.get_metric_values(), expected)