Exemplo n.º 1
0
      def check_result(got):
        """Checks the weighted top_k=2 calibration histogram plot in `got`.

        Args:
          got: Sequence expected to contain exactly one (slice key, plots)
            pair.

        Raises:
          util.BeamAssertException: Wrapping the AssertionError of the first
            check that fails.
        """
        try:
          self.assertLen(got, 1)
          slice_key, plots = got[0]
          self.assertEqual(slice_key, ())
          self.assertLen(plots, 1)
          plot_key = metric_types.PlotKey(
              name='_calibration_histogram_10000',
              sub_key=metric_types.SubKey(top_k=2),
              example_weighted=True)
          self.assertIn(plot_key, plots)
          histogram = plots[plot_key]
          self.assertLen(histogram, 5)

          make_bucket = calibration_histogram.Bucket
          neg_inf = float('-inf')
          # Expected buckets, listed in the order they appear in the histogram.
          expected_buckets = [
              make_bucket(
                  bucket_id=0,
                  weighted_labels=3.0 + 4.0,
                  # The -inf terms dominate, so this sum evaluates to -inf.
                  weighted_predictions=(2 * 1.0 * neg_inf +
                                        2 * 2.0 * neg_inf +
                                        2 * 3.0 * neg_inf +
                                        2 * 4.0 * neg_inf + -0.1 * 4.0),
                  weighted_examples=(1.0 * 2.0 + 2.0 * 2.0 + 3.0 * 2.0 +
                                     4.0 * 3.0)),
              make_bucket(
                  bucket_id=2001,
                  weighted_labels=0.0 + 0.0,
                  weighted_predictions=0.2 + 3 * 0.2,
                  weighted_examples=1.0 + 3.0),
              make_bucket(
                  bucket_id=5001,
                  weighted_labels=1.0 + 0.0 * 3.0,
                  weighted_predictions=0.5 * 1.0 + 0.5 * 3.0,
                  weighted_examples=1.0 + 3.0),
              make_bucket(
                  bucket_id=8001,
                  weighted_labels=0.0 * 2.0 + 1.0 * 2.0,
                  weighted_predictions=0.8 * 2.0 + 0.8 * 2.0,
                  weighted_examples=2.0 + 2.0),
              make_bucket(
                  bucket_id=10001,
                  weighted_labels=0.0 * 4.0,
                  weighted_predictions=1.1 * 4.0,
                  weighted_examples=4.0),
          ]
          for index, expected in enumerate(expected_buckets):
            self.assertEqual(histogram[index], expected)

        except AssertionError as err:
          raise util.BeamAssertException(err)
Exemplo n.º 2
0
      def check_result(got):
        """Checks the weighted k=2 calibration histogram plot in `got`.

        Args:
          got: Sequence expected to contain exactly one (slice key, plots)
            pair.

        Raises:
          util.BeamAssertException: Wrapping the AssertionError of the first
            check that fails.
        """
        try:
          self.assertLen(got, 1)
          slice_key, plots = got[0]
          self.assertEqual(slice_key, ())
          self.assertLen(plots, 1)
          plot_key = metric_types.PlotKey(
              name='_calibration_histogram_10000',
              sub_key=metric_types.SubKey(k=2),
              example_weighted=True)
          self.assertIn(plot_key, plots)
          histogram = plots[plot_key]
          self.assertLen(histogram, 5)

          make_bucket = calibration_histogram.Bucket
          # Expected buckets, listed in the order they appear in the histogram.
          expected_buckets = [
              make_bucket(
                  bucket_id=0,
                  weighted_labels=0.0 * 4.0,
                  weighted_predictions=-0.2 * 4.0,
                  weighted_examples=4.0),
              make_bucket(
                  bucket_id=1001,
                  weighted_labels=1.0 + 7 * 1.0,
                  weighted_predictions=0.1 + 7 * 0.1,
                  weighted_examples=1.0 + 7.0),
              make_bucket(
                  bucket_id=4001,
                  weighted_labels=1.0 * 3.0 + 0.0 * 5.0,
                  weighted_predictions=0.4 * 3.0 + 0.4 * 5.0,
                  weighted_examples=3.0 + 5.0),
              make_bucket(
                  bucket_id=7001,
                  weighted_labels=0.0 * 2.0 + 0.0 * 6.0,
                  weighted_predictions=0.7 * 2.0 + 0.7 * 6.0,
                  weighted_examples=2.0 + 6.0),
              make_bucket(
                  bucket_id=10001,
                  weighted_labels=0.0 * 8.0,
                  weighted_predictions=1.05 * 8.0,
                  weighted_examples=8.0),
          ]
          for index, expected in enumerate(expected_buckets):
            self.assertEqual(histogram[index], expected)

        except AssertionError as err:
          raise util.BeamAssertException(err)
            def check_result(got):
                """Checks the '_calibration_histogram_10000' plot in `got`.

                Args:
                    got: Sequence expected to contain exactly one
                        (slice key, plots) pair.

                Raises:
                    util.BeamAssertException: Wrapping the AssertionError of
                        the first check that fails.
                """
                try:
                    self.assertLen(got, 1)
                    slice_key, plots = got[0]
                    self.assertEqual(slice_key, ())
                    self.assertLen(plots, 1)
                    plot_key = metric_types.PlotKey(
                        '_calibration_histogram_10000')
                    self.assertIn(plot_key, plots)
                    histogram = plots[plot_key]
                    self.assertLen(histogram, 5)

                    make_bucket = calibration_histogram.Bucket
                    # Expected buckets, in histogram order.
                    expected_buckets = [
                        make_bucket(
                            bucket_id=0,
                            weighted_labels=1.0 * 4.0,
                            weighted_predictions=-0.1 * 4.0,
                            weighted_examples=4.0),
                        make_bucket(
                            bucket_id=2001,
                            weighted_labels=0.0 + 0.0,
                            weighted_predictions=0.2 + 7 * 0.2,
                            weighted_examples=1.0 + 7.0),
                        make_bucket(
                            bucket_id=5001,
                            weighted_labels=1.0 * 5.0,
                            weighted_predictions=0.5 * 3.0 + 0.5 * 5.0,
                            weighted_examples=3.0 + 5.0),
                        make_bucket(
                            bucket_id=8001,
                            weighted_labels=1.0 * 2.0 + 1.0 * 6.0,
                            weighted_predictions=0.8 * 2.0 + 0.8 * 6.0,
                            weighted_examples=2.0 + 6.0),
                        make_bucket(
                            bucket_id=10001,
                            weighted_labels=1.0 * 8.0,
                            weighted_predictions=1.1 * 8.0,
                            weighted_examples=8.0),
                    ]
                    for index, expected in enumerate(expected_buckets):
                        self.assertEqual(histogram[index], expected)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
    def testRebinWithSparseData(self):
        """Rebins a four-bucket histogram onto eleven thresholds.

        The test expects one output bucket per threshold, with every bucket
        that receives no input reported as all zeros.
        """
        sparse_histogram = [
            calibration_histogram.Bucket(4, 5.0, .25, 5.0),  # pred = .05
            calibration_histogram.Bucket(61, 60.0, 36.0, 60.0),  # pred = .6
            calibration_histogram.Bucket(70, 69.0, 47.61, 69.0),  # pred = .69
            calibration_histogram.Bucket(100, 99.0, 98.01, 99.0)  # pred = .99
        ]
        # [0, 0.1, ..., 0.9, 1.0]
        thresholds = [i * 1.0 / 10 for i in range(11)]
        got = calibration_histogram.rebin(thresholds, sparse_histogram, 100)

        # (labels, predictions, examples) for the output buckets expected to
        # be non-empty, keyed by output bucket id. Bucket 6 is the sum of the
        # pred = .6 and pred = .69 inputs.
        populated = {
            0: (5.0, 0.25, 5.0),
            6: (129.0, 83.61, 129.0),
            9: (99.0, 98.01, 99.0),
        }
        empty = (0.0, 0.0, 0.0)
        expected = [
            calibration_histogram.Bucket(bucket_id,
                                         *populated.get(bucket_id, empty))
            for bucket_id in range(len(thresholds))
        ]

        # The length check guarantees zip() does not silently drop buckets.
        self.assertLen(got, len(expected))
        for got_bucket, expected_bucket in zip(got, expected):
            self.assertSequenceAlmostEqual(got_bucket, expected_bucket)
    def testRebin(self):
        """Rebins a dense 102-bucket histogram onto thirteen thresholds.

        Each input bucket carries one example. The outermost thresholds sit
        1e-7 outside [0.0, 1.0] so the out-of-range inputs are expected in
        their own first and last output buckets.
        """
        # [Bucket(0, -1, -0.01), Bucket(1, 0, 0) ... Bucket(101, 101, 1.01)]
        histogram = (
            [calibration_histogram.Bucket(0, -1, -.01, 1.0)] +
            [
                calibration_histogram.Bucket(i + 1, i, i * .01, 1.0)
                for i in range(100)
            ] +
            [calibration_histogram.Bucket(101, 101, 1.01, 1.0)])
        # [-1e-7, 0.0, 0.1, ..., 0.9, 1.0, 1.0+1e-7]
        thresholds = [-1e-7] + [i * 1.0 / 10 for i in range(11)] + [1.0 + 1e-7]
        got = calibration_histogram.rebin(thresholds, histogram, 100)

        # For output buckets 1 <= i <= 10, the ten inputs with labels
        # 10 * (i - 1) ... 10 * (i - 1) + 9 are merged, so
        #   labels = 100 * (i - 1) + (0 + 1 + 2 + ... + 9) = 100 * (i - 1) + 45
        # and predictions = labels * 0.01.
        expected = [
            calibration_histogram.Bucket(0, -1, -0.01, 1.0),
            calibration_histogram.Bucket(1, 45.0, 0.45, 10.0),
            calibration_histogram.Bucket(2, 145.0, 1.45, 10.0),
            calibration_histogram.Bucket(3, 245.0, 2.45, 10.0),
            calibration_histogram.Bucket(4, 345.0, 3.45, 10.0),
            calibration_histogram.Bucket(5, 445.0, 4.45, 10.0),
            calibration_histogram.Bucket(6, 545.0, 5.45, 10.0),
            calibration_histogram.Bucket(7, 645.0, 6.45, 10.0),
            calibration_histogram.Bucket(8, 745.0, 7.45, 10.0),
            calibration_histogram.Bucket(9, 845.0, 8.45, 10.0),
            calibration_histogram.Bucket(10, 945.0, 9.45, 10.0),
            calibration_histogram.Bucket(11, 0.0, 0.0, 0.0),
            calibration_histogram.Bucket(12, 101.0, 1.01, 1.0),
        ]

        # The length check guarantees zip() does not silently drop buckets.
        self.assertLen(got, len(expected))
        for got_bucket, expected_bucket in zip(got, expected):
            self.assertSequenceAlmostEqual(got_bucket, expected_bucket)