Ejemplo n.º 1
0
            def check_result(got):
                """Verifies the batched-prediction keys of the single extract.

                The actual prediction values cannot be verified, so only the
                presence of every model / output head / prediction key
                combination is checked.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                # 'head/key' strings that must appear under each model.
                expected_keys = [
                    output_name + '/' + pred_key
                    for output_name in ('chinese_head', 'english_head',
                                        'other_head')
                    for pred_key in ('logistic', 'probabilities',
                                     'all_classes')
                ]
                try:
                    self.assertLen(got, 1)
                    for extracts in got:
                        self.assertIn(constants.BATCHED_PREDICTIONS_KEY,
                                      extracts)
                        batched = extracts[constants.BATCHED_PREDICTIONS_KEY]
                        for prediction in batched:
                            for model_name in ('model1', 'model2'):
                                self.assertIn(model_name, prediction)
                                model_prediction = prediction[model_name]
                                for full_key in expected_keys:
                                    self.assertIn(full_key, model_prediction)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 2
0
 def check_result(got):  # pylint: disable=invalid-name
     """Checks calibration-plot matrices, boundaries and the serialized plot.

     Raises:
       util.BeamAssertException: wrapping any failed assertion.
     """
     # (bucket index, expected bucket row) pairs.
     expected_buckets = (
         (0, [-19.0, -17.0, 2.0]),
         (1, [0.0, 1.0, 1.0]),
         (11, [0.00303, 3.00303, 3.0]),
         (10000, [1.99997, 3.99997, 2.0]),
         (10001, [28.0, 32.0, 4.0]),
     )
     # (boundary index, expected boundary value) pairs.
     expected_boundaries = (
         (0, 0.0),
         (10, 0.001),
         (50, 0.005),
         (100, 0.010),
         (1000, 0.100),
         (8000, 0.800),
         (10000, 1.000),
     )
     try:
         self.assertEqual(1, len(got), 'got: %s' % got)
         slice_key, slice_value = got[0]
         self.assertEqual((), slice_key)

         self.assertIn(metric_keys.CALIBRATION_PLOT_MATRICES, slice_value)
         buckets = slice_value[metric_keys.CALIBRATION_PLOT_MATRICES]
         for index, expected_row in expected_buckets:
             self.assertSequenceAlmostEqual(buckets[index], expected_row)

         self.assertIn(metric_keys.CALIBRATION_PLOT_BOUNDARIES, slice_value)
         boundaries = slice_value[metric_keys.CALIBRATION_PLOT_BOUNDARIES]
         for index, expected_boundary in expected_boundaries:
             self.assertAlmostEqual(expected_boundary, boundaries[index])

         # Convert to PlotData only after the raw entries above were read;
         # the helper's name suggests it pops them from the dict.
         plot_data = metrics_for_slice_pb2.PlotData()
         calibration_plot.populate_plots_and_pop(slice_value, plot_data)
         self.assertProtoEquals(
             """lower_threshold_inclusive:1.0
     upper_threshold_exclusive: inf
     num_weighted_examples {
       value: 4.0
     }
     total_weighted_label {
       value: 32.0
     }
     total_weighted_refined_prediction {
       value: 28.0
     }""", plot_data.calibration_histogram_buckets.buckets[10001])
     except AssertionError as err:
         raise util.BeamAssertException(err)
      def check_result(got):
        """Checks the single confusion_matrix_at_thresholds metric.

        Raises:
          util.BeamAssertException: wrapping any failed assertion.
        """
        try:
          self.assertLen(got, 1)
          slice_key, metrics = got[0]
          self.assertEqual(slice_key, ())
          self.assertLen(metrics, 1)
          metric_key = metric_types.MetricKey(
              name='confusion_matrix_at_thresholds')
          self.assertIn(metric_key, metrics)
          expected_text = """
              matrices {
                threshold: 0.3
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 1.0
                true_positives: 1.0
                precision: 0.5
                recall: 0.5
              }
              matrices {
                threshold: 0.5
                false_negatives: 1.0
                true_negatives: 2.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 0.8
                false_negatives: 1.0
                true_negatives: 2.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
          """
          self.assertProtoEquals(expected_text, metrics[metric_key])

        except AssertionError as err:
          raise util.BeamAssertException(err)
 def check_result(got):  # pylint: disable=invalid-name
     """Checks the overall-slice metric values against expected fractions.

     Raises:
       util.BeamAssertException: wrapping any failed assertion.
     """
     try:
         self.assertEqual(1, len(got), 'got: %s' % got)
         (slice_key, value) = got[0]
         self.assertEqual((), slice_key)
         # NOTE(review): every key below reads '[email protected]', which
         # looks like web e-mail obfuscation of names of the form 'metric@k'.
         # If metric_keys.base_key is deterministic, all 24 entries collide and
         # the dict collapses to a single entry holding the last value
         # (4.0 / 5.0), so only that one value is actually checked. Restore the
         # real metric names from the upstream source -- TODO confirm.
         expected_values_dict = {
             metric_keys.base_key('[email protected]'): 2.0 / 3.0,
             metric_keys.base_key('[email protected]'):
             1.0 / 2.0,
             metric_keys.base_key('[email protected]'): 3.0 / 5.0,
             metric_keys.base_key('[email protected]'): 1.0 / 2.0,
             metric_keys.base_key('[email protected]'):
             1.0 / 3.0,
             metric_keys.base_key('[email protected]'): 2.0 / 5.0,
             metric_keys.base_key('[email protected]'): 2.0 / 3.0,
             metric_keys.base_key('[email protected]'): 2.0 / 3.0,
             metric_keys.base_key('[email protected]'): 2.0 / 3.0,
             metric_keys.base_key('[email protected]'): 2.0 / 3.0,
             metric_keys.base_key('[email protected]'): 2.0 / 3.0,
             metric_keys.base_key('[email protected]'): 1.0 / 3.0,
             metric_keys.base_key('[email protected]'):
             1.0 / 2.0,
             metric_keys.base_key('[email protected]'): 2.0 / 5.0,
             metric_keys.base_key('[email protected]'): 1.0 / 2.0,
             metric_keys.base_key('[email protected]'):
             2.0 / 3.0,
             metric_keys.base_key('[email protected]'): 3.0 / 5.0,
             metric_keys.base_key('[email protected]'): 1.0 / 5.0,
             metric_keys.base_key('[email protected]'): 1.0 / 3.0,
             metric_keys.base_key('[email protected]'):
             0.0 / 2.0,
             metric_keys.base_key('[email protected]'): 1.0 / 5.0,
             metric_keys.base_key('[email protected]'): 2.0 / 2.0,
             metric_keys.base_key('[email protected]'):
             2.0 / 3.0,
             metric_keys.base_key('[email protected]'): 4.0 / 5.0,
         }
         self.assertDictElementsAlmostEqual(value, expected_values_dict)
     except AssertionError as err:
         raise util.BeamAssertException(err)
Ejemplo n.º 5
0
      def check_result(got):
        """Checks the prediction keys of each of the two extracts.

        The actual predictions cannot be verified, only their keys.
        `multi_model` / `multi_output` come from the enclosing test; with both
        set, only model1's output names are checked.

        Raises:
          util.BeamAssertException: wrapping any failed assertion.
        """
        try:
          self.assertLen(got, 2)
          for extracts in got:
            self.assertIn(constants.PREDICTIONS_KEY, extracts)
            predictions = extracts[constants.PREDICTIONS_KEY]
            if multi_model:
              for model_name in ('model1', 'model2'):
                self.assertIn(model_name, predictions)
            if multi_output:
              outputs = predictions['model1'] if multi_model else predictions
              for output_name in ('Identity', 'Identity_1'):
                self.assertIn(output_name, outputs)

        except AssertionError as err:
          raise util.BeamAssertException(err)
            def check_result(got):
                """Checks the top_k=3 binary confusion matrices at threshold -inf.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                try:
                    self.assertLen(got, 1)
                    slice_key, metrics = got[0]
                    self.assertEqual(slice_key, ())
                    self.assertLen(metrics, 1)
                    matrices_key = metric_types.MetricKey(
                        name='_binary_confusion_matrices_[-inf]',
                        sub_key=metric_types.SubKey(top_k=3))
                    self.assertIn(matrices_key, metrics)
                    expected_matrices = binary_confusion_matrices.Matrices(
                        thresholds=[float('-inf')],
                        tp=[2.0],
                        fp=[10.0],
                        tn=[6.0],
                        fn=[2.0])
                    self.assertEqual(metrics[matrices_key], expected_matrices)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 7
0
            def check_result(got):
                """Compares per-output metrics against `expected_values`.

                `expected_values` comes from the enclosing test and maps an
                output name to {metric name: expected value}. A metric name
                containing '@' gets a top_k sub key parsed from the text after
                the '@'.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """

                def sub_key_for(metric_name):
                    # 'name@K' -> SubKey(top_k=K); other names have no sub key.
                    if '@' not in metric_name:
                        return None
                    return metric_types.SubKey(
                        top_k=int(metric_name.split('@')[1]))

                try:
                    self.assertLen(got, 1)
                    slice_key, metrics = got[0]
                    self.assertEqual(slice_key, ())
                    expected = {
                        metric_types.MetricKey(
                            name=metric_name,
                            output_name=output_name,
                            sub_key=sub_key_for(metric_name)): value
                        for output_name, per_output in expected_values.items()
                        for metric_name, value in per_output.items()
                    }
                    self.assertDictElementsAlmostEqual(metrics, expected)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 8
0
 def check_result(got):  # pylint: disable=invalid-name
     """Checks confusion matrices and thresholds for the overall slice.

     Raises:
       util.BeamAssertException: wrapping any failed assertion.
     """
     #            |      |       --------- Threshold -----------
     # true label | pred | wt   | -1e-6 | 0.0 | 0.7 | 0.8 | 1.0
     #     -      | 0.0  | 1.0  | FP    | TN  | TN  | TN  | TN
     #     +      | 0.0  | 1.0  | TP    | FN  | FN  | FN  | FN
     #     +      | 0.7  | 3.0  | TP    | TP  | FN  | FN  | FN
     #     -      | 0.8  | 2.0  | FP    | FP  | FP  | TN  | TN
     #     +      | 1.0  | 3.0  | TP    | TP  | TP  | TP  | FN
     # Row i is the expected matrix at thresholds[i]; the columns appear to be
     # [fn, tn, fp, tp, precision, recall] (consistent with the table above).
     expected_rows = [
         [0.0, 0.0, 3.0, 7.0, 7.0 / 10.0, 1.0],
         [1.0, 1.0, 2.0, 6.0, 6.0 / 8.0, 6.0 / 7.0],
         [4.0, 1.0, 2.0, 3.0, 3.0 / 5.0, 3.0 / 7.0],
         [4.0, 3.0, 0.0, 3.0, 1.0, 3.0 / 7.0],
         [7.0, 3.0, 0.0, 0.0, float('nan'), 0.0],
     ]
     expected_thresholds = [-1e-6, 0.0, 0.7, 0.8, 1.0]
     try:
         self.assertEqual(1, len(got), 'got: %s' % got)
         slice_key, slice_value = got[0]
         self.assertEqual((), slice_key)

         self.assertIn(
             metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES, slice_value)
         matrices = slice_value[
             metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES]
         for index, expected_row in enumerate(expected_rows):
             self.assertSequenceAlmostEqual(matrices[index], expected_row)

         self.assertIn(
             metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS,
             slice_value)
         thresholds = slice_value[
             metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS]
         for index, expected_threshold in enumerate(expected_thresholds):
             self.assertAlmostEqual(expected_threshold, thresholds[index])
     except AssertionError as err:
         raise util.BeamAssertException(err)
Ejemplo n.º 9
0
      def check_result(got):
        """Checks the (trigger, labels) pairs grouped under b"Patient/14".

        `trigger1`, `label1`, `trigger2`, `label2` come from the enclosing test.

        Raises:
          util.BeamAssertException: wrapping any failed assertion.
        """
        try:
          self.assertLen(got, 1)
          key, trigger_label_pairs = got[0]
          self.assertEqual(b"Patient/14", key)
          self.assertLen(trigger_label_pairs, 2)
          # Order by trigger event time so the comparison does not depend on
          # the order in which the pairs were produced.
          ordered_pairs = sorted(
              trigger_label_pairs,
              key=lambda pair: pair[0].event_time.value_us)
          expected_pairs = ((trigger1, label1), (trigger2, label2))
          for (got_trigger, got_labels), (want_trigger, want_label) in zip(
              ordered_pairs, expected_pairs):
            self.assertProtoEqual(got_trigger, want_trigger)
            self.assertLen(got_labels, 1)
            self.assertProtoEqual(got_labels[0], want_label)

        except AssertionError as err:
          raise util.BeamAssertException(err)
Ejemplo n.º 10
0
 def check_result(got):
     """Checks the multi-class confusion matrix at threshold 0.5.

     Raises:
       util.BeamAssertException: wrapping any failed assertion.
     """
     try:
         self.assertLen(got, 1)
         got_slice_key, got_metrics = got[0]
         self.assertEqual(got_slice_key, ())
         self.assertLen(got_metrics, 1)
         key = metric_types.MetricKey(
             name='multi_class_confusion_matrix_at_thresholds')
         # Assert presence first so a missing key is reported as a Beam
         # assertion failure rather than escaping as a bare KeyError.
         self.assertIn(key, got_metrics)
         got_matrix = got_metrics[key]
         entry_key = multi_class_confusion_matrix_metrics.MatrixEntryKey
         expected_matrix = multi_class_confusion_matrix_metrics.Matrices({
             0.5: {
                 entry_key(actual_class_id=0, predicted_class_id=0): 0.5,
                 entry_key(actual_class_id=2, predicted_class_id=-1): 1.0,
             }
         })
         self.assertEqual(expected_matrix, got_matrix)
     except AssertionError as err:
         raise util.BeamAssertException(err)
  def _matcher(actual):
    """Matcher function for comparing the example dicts.

    Checks that `actual` has as many examples as `expected` (from the
    enclosing scope), that each example has the same feature keys, and that
    every feature value matches; numpy array values must also match in dtype.

    Raises:
      util.BeamAssertException: carrying the formatted traceback of the
        first failed check.
    """
    try:
      # Check number of examples.
      test.assertLen(actual, len(expected))
      for i, (actual_example, expected_example) in enumerate(
          zip(actual, expected)):
        # Compare key sets first: iterating only actual's keys would miss
        # features absent from `actual`, and an extra key in `actual` would
        # escape as a bare KeyError instead of an assertion failure.
        test.assertEqual(
            set(expected_example), set(actual_example),
            'Mismatched feature keys in actual[{}]'.format(i))
        for key, actual_value in actual_example.items():
          expected_value = expected_example[key]
          # Check each feature value.
          if isinstance(expected_value, np.ndarray):
            test.assertEqual(
                expected_value.dtype, actual_value.dtype,
                'Expected dtype {}, found {} in actual[{}][{}]: {}'.format(
                    expected_value.dtype, actual_value.dtype, i, key,
                    actual_value))
            np.testing.assert_equal(actual_value, expected_value)
          else:
            test.assertEqual(
                expected_value, actual_value,
                'Unexpected value of actual[{}][{}]'.format(i, key))

    except AssertionError:
      raise util.BeamAssertException(traceback.format_exc())
Ejemplo n.º 12
0
            def check_result(got):
                """Checks features, labels, predictions and weights of extracts.

                Each example's weight is expected to equal its 'age' feature.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                try:
                    self.assertLen(got, 2)
                    for item in got:
                        self.assertIn(constants.FEATURES_KEY, item)
                        features_list = item[constants.FEATURES_KEY]
                        for feature in ('language', 'age'):
                            for features_dict in features_list:
                                self.assertIn(feature, features_dict)
                        self.assertIn(constants.LABELS_KEY, item)
                        self.assertIn(constants.PREDICTIONS_KEY, item)
                        for model in ('model1', 'model2'):
                            for predictions_dict in item[
                                    constants.PREDICTIONS_KEY]:
                                self.assertIn(model, predictions_dict)
                        self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, item)
                        weights = item[constants.EXAMPLE_WEIGHTS_KEY]
                        # One weight per example: an index loop would raise a
                        # bare IndexError on too few weights and silently
                        # ignore surplus ones.
                        self.assertLen(weights, len(features_list))
                        for features_dict, weight in zip(features_list,
                                                         weights):
                            self.assertAlmostEqual(features_dict['age'],
                                                   weight)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
      def check_result(got):
        """Checks the three fixed-feature examples in the single extract.

        Raises:
          util.BeamAssertException: wrapping any failed assertion.
        """
        # (fixed_int, fixed_float, fixed_string, label, example weight) for
        # each example, in order.
        expected_rows = [
            (1, 1.0, b'fixed_string1', 1.0, 0.5),
            (1, 1.0, b'fixed_string2', 0.0, 0.0),
            (2, 0.0, b'fixed_string3', 0.0, 1.0),
        ]
        try:
          self.assertLen(got, 1)
          extracts = got[0]
          for index, (fixed_int, fixed_float, fixed_string, label,
                      weight) in enumerate(expected_rows):
            features = extracts[constants.FEATURES_KEY][index]
            self.assertDictElementsAlmostEqual(features, {
                'fixed_int': np.array([fixed_int]),
                'fixed_float': np.array([fixed_float]),
            })
            self.assertEqual(features['fixed_string'],
                             np.array([fixed_string]))
            self.assertAlmostEqual(extracts[constants.LABELS_KEY][index],
                                   np.array([label]))
            self.assertAlmostEqual(
                extracts[constants.EXAMPLE_WEIGHTS_KEY][index],
                np.array([weight]))

        except AssertionError as err:
          raise util.BeamAssertException(err)
Ejemplo n.º 14
0
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(
                        name='multi_class_confusion_matrix_plot')
                    got_matrix = got_plots[key]
                    self.assertProtoEquals(
                        """
              matrices {
                threshold: 0.0
                entries {
                  actual_class_id: 0
                  predicted_class_id: 2
                  num_weighted_examples: 1.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 1
                  num_weighted_examples: 2.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 2
                  num_weighted_examples: 0.25
                }
                entries {
                  actual_class_id: 2
                  predicted_class_id: 2
                  num_weighted_examples: 1.5
                }
              }
          """, got_matrix)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
            def check_result(got):
                """Checks the overall slice has 6 metrics and that the looked-up
                fairness-indicator metric values are NaN.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_metrics, 6)  # 1 threshold * 6 metrics
                    # NOTE(review): both metric names below read
                    # '...[email protected]', which looks like web e-mail
                    # obfuscation of 'name@threshold'-style names. As written
                    # both assertions look up the same key, so only one metric
                    # is actually checked. Restore the intended names from the
                    # upstream source -- TODO confirm.
                    self.assertTrue(
                        math.isnan(got_metrics[metric_types.MetricKey(
                            name=
                            'fairness_indicators_metrics/[email protected]',
                            model_name='',
                            output_name='',
                            sub_key=None)]))
                    self.assertTrue(
                        math.isnan(got_metrics[metric_types.MetricKey(
                            name=
                            'fairness_indicators_metrics/[email protected]',
                            model_name='',
                            output_name='',
                            sub_key=None)]))

                except AssertionError as err:
                    raise util.BeamAssertException(err)
            def check_metrics(got):
                """Checks example counts and weighted mean label/prediction.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                try:
                    self.assertLen(got, 1)
                    slice_key, metrics = got[0]
                    self.assertEqual(slice_key, ())
                    total_weight = (1.0 + 1.0 + 2.0)
                    expected = {
                        metric_types.MetricKey(name='example_count'): 3,
                        metric_types.MetricKey(name='weighted_example_count'):
                            4.0,
                        metric_types.MetricKey(name='mean_label'):
                            (1.0 + 0.0 + 2 * 0.0) / total_weight,
                        metric_types.MetricKey(name='mean_prediction'):
                            (0.2 + 0.8 + 2 * 0.5) / total_weight,
                    }
                    self.assertDictElementsAlmostEqual(metrics, expected)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 17
0
            def check_result(got):
                """Checks query/document count metrics to 5 decimal places.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                expected_counts = (
                    ('total_queries', 3),
                    ('total_documents', 6),
                    ('min_documents', 1),
                    ('max_documents', 3),
                )
                try:
                    self.assertLen(got, 1)
                    slice_key, metrics = got[0]
                    self.assertEqual(slice_key, ())
                    expected = {
                        metric_types.MetricKey(name=name): count
                        for name, count in expected_counts
                    }
                    self.assertDictElementsAlmostEqual(metrics, expected,
                                                       places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 18
0
  def _matcher(actual):
    """Matcher function for comparing DatasetFeatureStatisticsList proto.

    Compares the first dataset of `actual[0]` with the first dataset of
    `expected_result` (from the enclosing scope): example count, feature
    count, and each feature proto (matched by name, numbers normalized).

    Raises:
      util.BeamAssertException: if any comparison fails.
    """
    try:
      test.assertEqual(len(actual), 1)
      # Get the dataset stats from DatasetFeatureStatisticsList proto.
      actual_stats = actual[0].datasets[0]
      expected_stats = expected_result.datasets[0]

      test.assertEqual(actual_stats.num_examples, expected_stats.num_examples)
      test.assertEqual(len(actual_stats.features), len(expected_stats.features))

      expected_features = {
          feature.name: feature for feature in expected_stats.features
      }

      for feature in actual_stats.features:
        # Report an unexpected feature as an assertion failure rather than
        # letting a KeyError escape the matcher.
        test.assertIn(feature.name, expected_features)
        compare.assertProtoEqual(
            test,
            feature,
            expected_features[feature.name],
            normalize_numbers=True)
    # `except X, e` is Python 2 only (a SyntaxError on Python 3).
    except AssertionError as e:
      raise util.BeamAssertException('Failed assert: ' + str(e))
 def check_result(got):  # pylint: disable=invalid-name
     """Checks that the overall-slice result contains the expected metric keys.

     Raises:
       util.BeamAssertException: wrapping any failed assertion.
     """
     try:
         self.assertEqual(1, len(got), 'got: %s' % got)
         (slice_key, value) = got[0]
         self.assertEqual((), slice_key)
         # NOTE(review): all eight keys below read '[email protected]',
         # which looks like web e-mail obfuscation of names of the form
         # 'metric@k'. As written the same key is asserted eight times.
         # Restore the real metric names from the upstream source -- TODO
         # confirm.
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
         self.assertIn(
             metric_keys.base_key('[email protected]'), value)
     except AssertionError as err:
         raise util.BeamAssertException(err)
            def check_result(got):
                """Checks accuracy, label mean and the two custom age metrics.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                expected_by_name = (
                    ('accuracy', 1.0),
                    ('label/mean', 0.5),
                    ('my_mean_age', 3.75),
                    ('my_mean_age_times_label', 1.75),
                )
                try:
                    self.assertLen(got, 1)
                    slice_key, metrics = got[0]
                    self.assertEqual(slice_key, ())
                    expected = {
                        metric_types.MetricKey(name=name,
                                               example_weighted=None): value
                        for name, value in expected_by_name
                    }
                    self.assertDictElementsAlmostEqual(metrics, expected)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 21
0
            def check_result(got):
                """Checks the predictions of the single batched extract.

                `multi_model` / `multi_output` come from the enclosing test;
                with both set, only model1's output names are checked.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                try:
                    self.assertLen(got, 1)
                    extracts = got[0]
                    self.assertIn(constants.PREDICTIONS_KEY, extracts)
                    self.assertLen(extracts[constants.PREDICTIONS_KEY], 2)

                    for item in extracts[constants.PREDICTIONS_KEY]:
                        if multi_model:
                            self.assertIn('model1', item)
                            self.assertIn('model2', item)
                            if multi_output:
                                self.assertIn('Identity', item['model1'])
                                self.assertIn('Identity_1', item['model1'])

                        elif multi_output:
                            self.assertIn('Identity', item)
                            self.assertIn('Identity_1', item)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            # Register the matcher at the enclosing level. This call used to be
            # indented inside check_result's own body, so the matcher was only
            # ever registered from within itself.
            # NOTE(review): assumes `result` is the PCollection built by the
            # enclosing test -- confirm.
            util.assert_that(result, check_result, label='result')
            def check_metrics(got):
                """Checks example counts and the per-class weighted mean label.

                Raises:
                  util.BeamAssertException: wrapping any failed assertion.
                """
                example_weights = (1.0, 2.0, 3.0, 4.0)
                # (class id, 0/1 label indicator of each of the four examples).
                label_indicators = (
                    (0, (1, 0, 0, 0)),
                    (1, (0, 1, 0, 1)),
                    (2, (0, 0, 1, 0)),
                )
                total_weight = sum(example_weights)
                try:
                    self.assertLen(got, 1)
                    slice_key, metrics = got[0]
                    self.assertEqual(slice_key, ())
                    expected = {
                        metric_types.MetricKey(name='example_count'): 4,
                        metric_types.MetricKey(name='weighted_example_count'):
                            total_weight,
                    }
                    for class_id, indicators in label_indicators:
                        weighted_labels = sum(
                            indicator * weight for indicator, weight in zip(
                                indicators, example_weights))
                        mean_label_key = metric_types.MetricKey(
                            name='mean_label',
                            sub_key=metric_types.SubKey(class_id=class_id))
                        expected[mean_label_key] = weighted_labels / total_weight
                    self.assertDictElementsAlmostEqual(metrics, expected)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 23
0
 def check_result(got):  # pylint: disable=invalid-name
     """Checks sampled AUC-plot matrices and thresholds for the overall slice.

     Raises:
       util.BeamAssertException: wrapping any failed assertion.
     """
     #            |      | --------- Threshold -----------
     # true label | pred | -1e-6 | 0.0 | 0.7 | 0.8 | 1.0
     #     -      | 0.0  | FP    | TN  | TN  | TN  | TN
     #     +      | 0.0  | TP    | FN  | FN  | FN  | FN
     #     +      | 0.7  | TP    | TP  | FN  | FN  | FN
     #     -      | 0.8  | FP    | FP  | FP  | TN  | TN
     #     +      | 1.0  | TP    | TP  | TP  | TP  | FN
     # (matrix index, expected row) pairs.
     expected_matrices = (
         (0, [0, 0, 2, 3, 3.0 / 5.0, 1.0]),
         (1, [1, 1, 1, 2, 2.0 / 3.0, 2.0 / 3.0]),
         (7001, [2, 1, 1, 1, 1.0 / 2.0, 1.0 / 3.0]),
         (8001, [2, 2, 0, 1, 1.0 / 1.0, 1.0 / 3.0]),
         (10001, [3, 2, 0, 0, float('nan'), 0.0]),
     )
     # (threshold index, expected threshold) pairs.
     expected_thresholds = (
         (1, 0.0),
         (11, 0.001),
         (51, 0.005),
         (101, 0.010),
         (1001, 0.100),
         (8001, 0.800),
         (10001, 1.000),
     )
     try:
         self.assertEqual(1, len(got), 'got: %s' % got)
         slice_key, slice_value = got[0]
         self.assertEqual((), slice_key)

         self.assertIn(metric_keys.AUC_PLOTS_MATRICES, slice_value)
         matrices = slice_value[metric_keys.AUC_PLOTS_MATRICES]
         for index, expected_row in expected_matrices:
             self.assertSequenceAlmostEqual(matrices[index], expected_row)

         self.assertIn(metric_keys.AUC_PLOTS_THRESHOLDS, slice_value)
         thresholds = slice_value[metric_keys.AUC_PLOTS_THRESHOLDS]
         for index, expected_threshold in expected_thresholds:
             self.assertAlmostEqual(expected_threshold, thresholds[index])
     except AssertionError as err:
         raise util.BeamAssertException(err)
      def check_result(got):
        """Checks that the metric's first key has value 0.4 (5 places).

        Raises:
          util.BeamAssertException: wrapping any failed assertion.
        """
        try:
          self.assertLen(got, 1)
          slice_key, metrics = got[0]
          self.assertEqual(slice_key, ())
          # Worked example (example: prediction, label):
          #   1: 1, 2    2: 2, 1    3: 3, 2    4: 4, 3
          #
          #   sum(pred * label) = 2 + 2 + 6 + 12 = 22
          #   sum(label)        = 2 + 1 + 2 + 3  = 8
          #   sum(pred)         = 1 + 2 + 3 + 4  = 10
          #   sum(label^2)      = 4 + 1 + 4 + 9  = 18
          #   sum(pred^2)       = 1 + 4 + 9 + 16 = 30
          #   examples          = 4
          #
          #   r^2 = (22 - 8 * 10 / 4)^2
          #         / ((30 - 10^2 / 4) * (18 - 8^2 / 4))
          #       = 4 / (5 * 2) = 0.4
          self.assertDictElementsAlmostEqual(
              metrics, {metric.keys[0]: 0.4}, places=5)

        except AssertionError as err:
          raise util.BeamAssertException(err)
Ejemplo n.º 25
0
 def check_result(got):
   """Checks the flip_count metrics (counts and example-id arrays).

   Raises:
     util.BeamAssertException: wrapping any failed assertion.
   """
   try:
     self.assertLen(got, 1)
     got_slice_key, got_metrics = got[0]
     self.assertEqual(got_slice_key, ())
     self.assertLen(got_metrics, 6)
     # NOTE(review): every metric name below reads
     # 'flip_count/[email protected]', which looks like web e-mail
     # obfuscation of 'flip_count/<name>@<threshold>' names. As written the
     # four dict entries share one key (only the last value, 7.0, survives),
     # and the two assertAllEqual calls then expect arrays from that same
     # key, so these checks cannot all refer to one metric. Restore the real
     # names from the upstream source -- TODO confirm.
     self.assertDictElementsAlmostEqual(
         got_metrics, {
             metric_types.MetricKey(
                 name='flip_count/[email protected]',
                 example_weighted=True):
                 5.0,
             metric_types.MetricKey(
                 name='flip_count/[email protected]',
                 example_weighted=True):
                 7.0,
             metric_types.MetricKey(
                 name='flip_count/[email protected]',
                 example_weighted=True):
                 6.0,
             metric_types.MetricKey(
                 name='flip_count/[email protected]',
                 example_weighted=True):
                 7.0,
         })
     self.assertAllEqual(
         got_metrics[metric_types.MetricKey(
             name='flip_count/[email protected]',
             example_weighted=True)], np.array([['id_2'], ['id_3']]))
     self.assertAllEqual(
         got_metrics[metric_types.MetricKey(
             name='flip_count/[email protected]',
             example_weighted=True)],
         np.array([['id_2'], ['id_3'], ['id_4']]))
   except AssertionError as err:
     raise util.BeamAssertException(err)
Ejemplo n.º 26
0
 def check_result(got):
     """Assert the pipeline produced exactly [{'a': 1, 'c': 3}].

     Assertion failures are re-raised as util.BeamAssertException.
     """
     expected = [{'a': 1, 'c': 3}]
     try:
         self.assertEqual(got, expected)
     except AssertionError as e:
         raise util.BeamAssertException(e)
Ejemplo n.º 27
0
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(
                        name='multi_label_confusion_matrix_plot')
                    got_matrix = got_plots[key]
                    self.assertProtoEquals(
                        """
              matrices {
                threshold: 0.5
                entries {
                  actual_class_id: 0
                  predicted_class_id: 0
                  false_negatives: 0.0
                  true_negatives: 0.0
                  false_positives: 0.0
                  true_positives: 2.0
                }
                entries {
                  actual_class_id: 0
                  predicted_class_id: 1
                  false_negatives: 1.0
                  true_negatives: 1.0
                  false_positives: 0.0
                  true_positives: 0.0
                }
                entries {
                  actual_class_id: 0
                  predicted_class_id: 2
                  false_negatives: 0.0
                  true_negatives: 2.0
                  false_positives: 0.0
                  true_positives: 0.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 0
                  false_negatives: 0.0
                  true_negatives: 1.0
                  false_positives: 0.0
                  true_positives: 1.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 1
                  false_negatives: 1.0
                  true_negatives: 0.0
                  false_positives: 0.0
                  true_positives: 1.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 2
                  false_negatives: 0.0
                  false_positives: 0.0
                  true_negatives: 2.0
                  true_positives: 0.0
                }
              }
          """, got_matrix)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Ejemplo n.º 28
0
            def check_result(got):
                """Verify the two multi-model extracts emitted by the pipeline.

                Each extract must carry the feature 'fixed_int' == [1] and a
                per-extract 'fixed_string'. It must also carry labels, example
                weights and predictions for both 'model1' (a single value) and
                'model2' (values for 'output1' and 'output2'). Assertion
                failures are re-raised as util.BeamAssertException.
                """
                # One row per expected extract, in emission order:
                #   (fixed_string, labels, example_weights, predictions)
                # where each of the last three is
                #   (model1 value, {model2 output name: value}).
                expected_rows = (
                    (b'fixed_string1',
                     (1.0, {'output1': 1.0, 'output2': 0.0}),
                     (0.5, {'output1': 0.5, 'output2': 0.5}),
                     (1.0, {'output1': 1.0, 'output2': 1.0})),
                    (b'fixed_string2',
                     (1.0, {'output1': 1.0, 'output2': 1.0}),
                     (0.0, {'output1': 0.0, 'output2': 1.0}),
                     (2.0, {'output1': 2.0, 'output2': 2.0})),
                )
                try:
                    self.assertLen(got, 2)
                    per_model_keys = (constants.LABELS_KEY,
                                      constants.EXAMPLE_WEIGHTS_KEY,
                                      constants.PREDICTIONS_KEY)
                    for index, (fixed_string, labels, weights,
                                predictions) in enumerate(expected_rows):
                        extract = got[index]
                        features = extract[constants.FEATURES_KEY]
                        self.assertDictElementsAlmostEqual(
                            features, {'fixed_int': np.array([1])})
                        self.assertEqual(features['fixed_string'],
                                         np.array([fixed_string]))
                        # Both models must be present under every per-model key.
                        for model_name in ('model1', 'model2'):
                            for extract_key in per_model_keys:
                                self.assertIn(model_name, extract[extract_key])
                        for extract_key, (model1_value, model2_values) in zip(
                                per_model_keys, (labels, weights, predictions)):
                            by_model = extract[extract_key]
                            self.assertAlmostEqual(by_model['model1'],
                                                   np.array([model1_value]))
                            self.assertDictElementsAlmostEqual(
                                by_model['model2'],
                                {output: np.array([value])
                                 for output, value in model2_values.items()})
                except AssertionError as e:
                    raise util.BeamAssertException(e)
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(name='calibration_plot')
                    self.assertIn(key, got_plots)
                    got_plot = got_plots[key]
                    self.assertProtoEquals(
                        """
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                total_weighted_label {
                  value: 4.0
                }
                total_weighted_refined_prediction {
                  value: -0.4
                }
                num_weighted_examples {
                  value: 4.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.1
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.1
                upper_threshold_exclusive: 0.2
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.2
                upper_threshold_exclusive: 0.3
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                  value: 1.6
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.3
                upper_threshold_exclusive: 0.4
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.4
                upper_threshold_exclusive: 0.5
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 0.6
                total_weighted_label {
                  value: 5.0
                }
                total_weighted_refined_prediction {
                  value: 4.0
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.6
                upper_threshold_exclusive: 0.7
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.7
                upper_threshold_exclusive: 0.8
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.8
                upper_threshold_exclusive: 0.9
                total_weighted_label {
                  value: 8.0
                }
                total_weighted_refined_prediction {
                  value: 6.4
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.9
                upper_threshold_exclusive: 1.0
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                total_weighted_label {
                  value: 8.0
                }
                total_weighted_refined_prediction {
                  value: 8.8
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
          """, got_plot)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
 def check_result(got):
   """Assert the output is exactly [{INPUT_KEY: 'input', 'other': 2}].

   Assertion failures are re-raised as util.BeamAssertException.
   """
   try:
     expected = [{constants.INPUT_KEY: 'input', 'other': 2}]
     self.assertEqual(got, expected)
   except AssertionError as e:
     raise util.BeamAssertException(e)