def test_get_metric_params(self):
    """get_metric_params maps each metric's class name to its parameters.

    Parameters that were not passed to a metric's constructor are expected to
    show up too (the ones checked here report None).
    """
    stddev_metric = metrics_online.StddevAcrossRuns(eval_points=[3, 2, 1],
                                                    baseline=-2)
    cvar_metric = metrics_online.LowerCVaROnDiffs(alpha=0.77)
    params_by_name = eval_metrics.get_metric_params(
        [stddev_metric, cvar_metric])

    expected_stddev_params = {
        'eval_points': [3, 2, 1],
        'baseline': -2,
        'lowpass_thresh': None,
        'window_size': None,
    }
    expected_cvar_params = {
        'target': 'diffs',
        'tail': 'lower',
        'alpha': 0.77,
        'baseline': None,
        'eval_points': None,
        'window_size': None,
        'lowpass_thresh': None,
    }

    self.assertCountEqual(params_by_name.keys(),
                          ['StddevAcrossRuns', 'LowerCVaROnDiffs'])
    self.assertEqual(params_by_name['StddevAcrossRuns'],
                     expected_stddev_params)
    self.assertEqual(params_by_name['LowerCVaROnDiffs'],
                     expected_cvar_params)
 def test_compute_metrics(self):
     """Evaluator.compute_metrics returns results keyed by metric name.

     Each curve is a 2-row array; presumably row 0 holds timepoints and row 1
     values — confirm against metrics_online.
     """
     first_run = np.array([[-1, 0, 1], [1., 1., 1.]])
     second_run = np.array([[-1, 0, 1, 2], [2., 3., 4., 5.]])
     stddev = metrics_online.StddevAcrossRuns(eval_points=[0, 1], baseline=1)
     evaluator = eval_metrics.Evaluator([stddev])

     metric_results = evaluator.compute_metrics([first_run, second_run])

     # The expected literals equal sqrt(2) and 1.5 * sqrt(2).
     expected_stddevs = [1.41421356237, 2.12132034356]
     np.testing.assert_allclose(metric_results['StddevAcrossRuns'],
                                expected_stddevs)
# Example #3
# 0
 def testCorrectStddevAcrossRuns(self, timepoints, lowpass_thresh, baseline,
                                 expected):
     """StddevAcrossRuns on two fixed curves produces `expected`.

     The arguments come from a test parameterization that is not visible here
     (presumably a decorator on this method — confirm).
     """
     metric = metrics_online.StddevAcrossRuns(
         lowpass_thresh=lowpass_thresh,
         eval_points=timepoints,
         baseline=baseline,
     )
     runs = [
         np.array([[-1, 0, 1], [1., 1., 1.]]),
         np.array([[-1, 0, 1, 2], [2., 3., 4., 5.]]),
     ]
     np.testing.assert_allclose(metric(runs), expected)
    def test_write_results(self):
        """Round-trips results and metric params through the Evaluator writers.

        Computes StddevAcrossRuns on two toy curves, writes the results and the
        metric parameters under a temporary path prefix, then reads each file
        back and checks its JSON content.
        """
        # Generate some results.
        curves = [
            np.array([[-1, 0, 1], [1., 1., 1.]]),
            np.array([[-1, 0, 1, 2], [2., 3., 4., 5.]])
        ]
        metric = metrics_online.StddevAcrossRuns(eval_points=[0, 1],
                                                 baseline=1)
        evaluator = eval_metrics.Evaluator([metric])
        results = evaluator.compute_metrics(curves)

        outfile_prefix = os.path.join(flags.FLAGS.test_tmpdir, 'results_')
        params_path = evaluator.write_metric_params(outfile_prefix)
        results_path = evaluator.write_results(results, outfile_prefix)

        # Test write_results. Parse the whole file instead of only its first
        # line so the check does not rely on the writer emitting one-line JSON.
        with open(results_path, 'r', encoding='utf-8') as outfile:
            results_dict = json.load(outfile)
        expected_results = {
            'StddevAcrossRuns': [1.41421356237, 2.12132034356]
        }
        self.assertCountEqual(results_dict.keys(), expected_results.keys())
        np.testing.assert_allclose(expected_results['StddevAcrossRuns'],
                                   results_dict['StddevAcrossRuns'])

        # Test write_metric_params. assertJsonEqual parses both strings, so
        # reading the full file (including any trailing newline) is safe.
        with open(params_path, 'r', encoding='utf-8') as outfile:
            params_loaded = outfile.read()
        expected_params = json.dumps({
            'StddevAcrossRuns': {
                'eval_points': [0, 1],
                'lowpass_thresh': None,
                'baseline': 1,
                'window_size': None,
            }
        })
        self.assertJsonEqual(expected_params, params_loaded)
 def test_window_empty(self):
     """compute_metrics raises ValueError for this single two-point curve.

     The test name suggests the default evaluation window selects no data
     here — confirm against StddevAcrossRuns defaults.
     """
     single_curve = np.array([[0, 2], [2, 3]])
     evaluator = eval_metrics.Evaluator([metrics_online.StddevAcrossRuns()])
     with self.assertRaises(ValueError):
         evaluator.compute_metrics([single_curve])
 def test_window_out_of_range(self):
     """compute_metrics raises ValueError for this single two-point curve.

     The test name suggests the default evaluation window falls outside the
     curve's range — confirm against StddevAcrossRuns defaults.
     """
     single_curve = np.array([[0, 1], [1, 1]])
     evaluator = eval_metrics.Evaluator([metrics_online.StddevAcrossRuns()])
     with self.assertRaises(ValueError):
         evaluator.compute_metrics([single_curve])
# Example #7
# 0
 def testErrorOnEvalPointsOutOfBounds(self, curves, eval_points):
     """Calling StddevAcrossRuns with out-of-bounds eval_points raises.

     `curves` and `eval_points` come from a test parameterization that is not
     visible here (presumably a decorator on this method — confirm).
     """
     metric = metrics_online.StddevAcrossRuns(eval_points=eval_points,
                                              lowpass_thresh=None)
     with self.assertRaises(ValueError):
         metric(curves)