예제 #1
0
    def test_should_error_out_for_not_recognized_args(self):
        """add_metrics rejects a metric_fn declaring an unsupported argument.

        The raised ValueError's message must mention the offending argument
        name ('not_recognized').
        """
        estimator = linear.LinearClassifier([fc.numeric_column('x')])

        def metric_fn(features, not_recognized):
            del features, not_recognized  # Unused; only the signature matters.
            return {}

        # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
        # the return value of add_metrics is irrelevant here, so not kept.
        with self.assertRaisesRegex(ValueError, 'not_recognized'):
            extenders.add_metrics(estimator, metric_fn)
예제 #2
0
    def test_all_args_are_optional(self):
        """A metric_fn taking no arguments is accepted and its metric reported."""
        def zero_arg_metric_fn():
            two = constant_op.constant([2.])
            return {'two': metrics_lib.mean(two)}

        train_eval_input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
        base_estimator = linear.LinearClassifier([fc.numeric_column('x')])
        wrapped = extenders.add_metrics(base_estimator, zero_arg_metric_fn)

        wrapped.train(input_fn=train_eval_input_fn)
        results = wrapped.evaluate(input_fn=train_eval_input_fn)
        self.assertEqual(2., results['two'])
예제 #3
0
    def test_overrides_existing_metrics(self):
        """A metric_fn key that matches an existing metric replaces its value."""
        input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
        classifier = linear.LinearClassifier([fc.numeric_column('x')])
        classifier.train(input_fn=input_fn)
        # Sanity check: before wrapping, the estimator's own 'auc' is not 2.
        baseline = classifier.evaluate(input_fn=input_fn)
        self.assertNotEqual(2., baseline['auc'])

        def constant_auc_metric_fn():
            return {'auc': metrics_lib.mean(constant_op.constant([2.]))}

        overridden = extenders.add_metrics(classifier, constant_auc_metric_fn)
        self.assertEqual(2., overridden.evaluate(input_fn=input_fn)['auc'])
예제 #4
0
    def test_all_supported_args_in_different_order(self):
        """metric_fn arguments are matched by name rather than by position.

        The metric_fn declares labels, config, features and predictions in a
        non-canonical order and checks that each receives a plausible value.
        """
        input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
        estimator = linear.LinearClassifier([fc.numeric_column('x')])
        # Records invocations: the checks below live inside metric_fn, so
        # without this the test would pass vacuously if it were never called.
        calls = []

        def metric_fn(labels, config, features, predictions):
            calls.append(True)
            self.assertIn('x', features)
            self.assertIsNotNone(labels)
            self.assertIn('logistic', predictions)
            self.assertIsInstance(config, estimator_lib.RunConfig)
            return {}

        estimator = extenders.add_metrics(estimator, metric_fn)

        estimator.train(input_fn=input_fn)
        estimator.evaluate(input_fn=input_fn)
        self.assertTrue(calls, 'metric_fn was never called')
예제 #5
0
    def test_should_add_metrics(self):
        """Metrics from metric_fn are reported alongside the estimator's own."""
        xs = np.arange(4)[:, None, None]
        ys = np.ones(4)[:, None]
        input_fn = get_input_fn(x=xs, y=ys)
        classifier = linear.LinearClassifier([fc.numeric_column('x')])

        def mean_x_metric_fn(features):
            return {'mean_x': metrics_lib.mean(features['x'])}

        wrapped = extenders.add_metrics(classifier, mean_x_metric_fn)
        wrapped.train(input_fn=input_fn)
        results = wrapped.evaluate(input_fn=input_fn)

        # New metric: the mean of x over the values 0, 1, 2, 3.
        self.assertIn('mean_x', results)
        self.assertEqual(1.5, results['mean_x'])
        # Original metrics must survive the wrapping.
        self.assertIn('auc', results)