def check_result(got):
                    """Verifies the per-slice metrics, then serializes each entry.

                    Args:
                      got: Sized collection of (slice_key, metrics) pairs.
                        Exactly three entries are expected: the overall slice,
                        'first_slice' and 'second_slice'.

                    Raises:
                      util.BeamAssertException: If any assertion fails. The
                        message includes batch_size from the enclosing scope.
                    """
                    overall_slice = ()
                    first_slice = (('slice_key', b'first_slice'), )
                    second_slice = (('slice_key', b'second_slice'), )
                    # Expected metric values, listed in the order they are
                    # verified below.
                    expected_by_slice = [
                        (overall_slice, {
                            'accuracy': 0.4,
                            'label/mean': 0.6,
                            'my_mean_age': 4.0,
                            'my_mean_age_times_label': 2.6,
                            'added_example_count': 5.0
                        }),
                        (first_slice, {
                            'accuracy': 1.0,
                            'label/mean': 0.5,
                            'my_mean_age': 3.0,
                            'my_mean_age_times_label': 1.5,
                            'added_example_count': 2.0
                        }),
                        (second_slice, {
                            'accuracy': 0.0,
                            'label/mean': 2.0 / 3.0,
                            'my_mean_age': 14.0 / 3.0,
                            'my_mean_age_times_label': 10.0 / 3.0,
                            'added_example_count': 3.0
                        }),
                    ]
                    try:
                        self.assertEqual(3, len(got), 'got: %s' % got)
                        metrics_by_slice = {
                            key: metrics for key, metrics in got
                        }
                        self.assertItemsEqual(
                            list(metrics_by_slice.keys()),
                            [overall_slice, first_slice, second_slice])
                        for expected_slice, expected_metrics in expected_by_slice:
                            self.assertDictElementsWithIntervalsAlmostEqual(
                                metrics_by_slice[expected_slice],
                                expected_metrics)
                        # Ensure that serialization of the key at the end of
                        # ComputeMetricsAndPlots works.
                        for key, metrics in got:
                            metrics_and_plots_evaluator._serialize_metrics(
                                (key, metrics), [])

                    except AssertionError as err:
                        # This function is redefined every iteration, so it will have the
                        # right value of batch_size.
                        raise util.BeamAssertException(
                            'batch_size = %d, error: %s' % (batch_size, err))  # pylint: disable=cell-var-from-loop
 def testUncertaintyValuedMetrics(self):
     """Serializes confidence-interval metrics and checks the resulting proto.

     Covers one finite interval and one all-NaN interval. Both are expected
     to come back as bounded_value entries with POISSON_BOOTSTRAP methodology.
     """
     expected_text = """
     slice_key {}
     metrics {
       key: "one_dim"
       value {
         bounded_value {
           value {
             value: 2.0
           }
           lower_bound {
             value: 1.0
           }
           upper_bound {
             value: 3.0
           }
           methodology: POISSON_BOOTSTRAP
         }
       }
     }
     metrics {
       key: "nans"
       value {
         bounded_value {
           value {
             value: nan
           }
           lower_bound {
             value: nan
           }
           upper_bound {
             value: nan
           }
           methodology: POISSON_BOOTSTRAP
         }
       }
     }
     """
     empty_slice = _make_slice_key()
     # Positional arguments line up with value / lower_bound / upper_bound in
     # the expected proto above.
     interval_metrics = {
         'one_dim': types.ValueWithConfidenceInterval(2.0, 1.0, 3.0),
         'nans': types.ValueWithConfidenceInterval(
             float('nan'), float('nan'), float('nan')),
     }
     expected_proto = text_format.Parse(
         expected_text, metrics_for_slice_pb2.MetricsForSlice())

     serialized = metrics_and_plots_evaluator._serialize_metrics(
         (empty_slice, interval_metrics), [])
     actual_proto = metrics_for_slice_pb2.MetricsForSlice.FromString(
         serialized)
     self.assertProtoEquals(expected_proto, actual_proto)
 def testTensorValuedMetrics(self):
     """Serializes numpy-array metrics and checks the resulting proto.

     Uses a 1-D float32 array, a 2-D string array and a 3-D int64 array. Each
     is expected to appear as an array_value carrying its data_type, shape and
     flattened values.
     """
     expected_text = """
     slice_key {}
     metrics {
       key: "one_dim"
       value {
         array_value {
           data_type: FLOAT32
           shape: 4
           float32_values: [1.0, 2.0, 3.0, 4.0]
         }
       }
     }
     metrics {
       key: "two_dims"
       value {
         array_value {
           data_type: BYTES
           shape: [2, 3]
           bytes_values: ["two", "dims", "test", "TWO", "DIMS", "TEST"]
         }
       }
     }
     metrics {
       key: "three_dims"
       value {
         array_value {
           data_type: INT64
           shape: [2, 1, 3]
           int64_values: [100, 200, 300, 500, 600, 700]
         }
       }
     }
     """
     empty_slice = _make_slice_key()
     float_vector = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
     string_matrix = np.array(
         [['two', 'dims', 'test'], ['TWO', 'DIMS', 'TEST']])
     int_cube = np.array(
         [[[100, 200, 300]], [[500, 600, 700]]], dtype=np.int64)
     array_metrics = {
         'one_dim': float_vector,
         'two_dims': string_matrix,
         'three_dims': int_cube,
     }
     expected_proto = text_format.Parse(
         expected_text, metrics_for_slice_pb2.MetricsForSlice())

     serialized = metrics_and_plots_evaluator._serialize_metrics(
         (empty_slice, array_metrics), [])
     actual_proto = metrics_for_slice_pb2.MetricsForSlice.FromString(
         serialized)
     self.assertProtoEquals(expected_proto, actual_proto)
    def testStringMetrics(self):
        """Serializes bytes-valued metrics and checks the resulting proto.

        Covers plain ASCII, valid multi-byte UTF-8 and an invalid UTF-8
        sequence. Each is expected to round-trip unchanged as a bytes_value.
        """
        metric_names = ('valid_ascii', 'valid_unicode', 'invalid_unicode')
        empty_slice = _make_slice_key()
        bytes_metrics = {
            'valid_ascii': b'test string',
            'valid_unicode': b'\xF0\x9F\x90\x84',  # U+1F404, Cow
            'invalid_unicode': b'\xE2\x28\xA1',
        }

        expected_proto = metrics_for_slice_pb2.MetricsForSlice()
        # Mark slice_key as present even though it has no fields set.
        expected_proto.slice_key.SetInParent()
        for name in metric_names:
            expected_proto.metrics[name].bytes_value = bytes_metrics[name]

        serialized = metrics_and_plots_evaluator._serialize_metrics(
            (empty_slice, bytes_metrics), [])
        actual_proto = metrics_for_slice_pb2.MetricsForSlice.FromString(
            serialized)
        self.assertProtoEquals(expected_proto, actual_proto)
    def testSerializeMetrics(self):
        """Serializes scalar and AUC-style metrics for a three-column slice.

        The plain 'accuracy' value is expected as a double_value. Each AUC /
        AUPRC value, together with its lower and upper bound entries, is
        expected as a single bounded_value with RIEMANN_SUM methodology.
        """
        template_text = """
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }"""
        sliced_key = _make_slice_key(
            'age', 5, 'language', 'english', 'price', 0.3)
        raw_metrics = {
            'accuracy': 0.8,
            _full_key(metric_keys.AUPRC): 0.1,
            _full_key(metric_keys.lower_bound(metric_keys.AUPRC)): 0.05,
            _full_key(metric_keys.upper_bound(metric_keys.AUPRC)): 0.17,
            _full_key(metric_keys.AUC): 0.2,
            _full_key(metric_keys.lower_bound(metric_keys.AUC)): 0.1,
            _full_key(metric_keys.upper_bound(metric_keys.AUC)): 0.3
        }
        # Fill the $auc / $auprc placeholders with the fully qualified keys.
        expected_text = string.Template(template_text).substitute(
            auc=_full_key(metric_keys.AUC),
            auprc=_full_key(metric_keys.AUPRC))
        expected_proto = text_format.Parse(
            expected_text, metrics_for_slice_pb2.MetricsForSlice())

        auc_metrics = [
            post_export_metrics.auc(),
            post_export_metrics.auc(curve='PR'),
        ]
        serialized = metrics_and_plots_evaluator._serialize_metrics(
            (sliced_key, raw_metrics), auc_metrics)
        actual_proto = metrics_for_slice_pb2.MetricsForSlice.FromString(
            serialized)
        self.assertProtoEquals(expected_proto, actual_proto)
    def testSerializeConfusionMatrices(self):
        """Serializes confusion matrices at three thresholds and checks the proto.

        The matrices and thresholds metrics are expected to be combined into
        one confusion_matrix_at_thresholds entry, with each value emitted both
        as a plain field and as a bounded_* field. The last row has a NaN
        precision.
        """
        expected_text = """
        slice_key {}
        metrics {
          key: "post_export_metrics/confusion_matrix_at_thresholds"
          value {
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
                bounded_false_negatives {
                  value {
                    value: 0.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 2.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 1.0
                  }
                }
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
                bounded_false_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 1.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 0.5
                  }
                }
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
                bounded_false_negatives {
                  value {
                    value: 2.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: nan
                  }
                }
                bounded_recall {
                  value {
                    value: 0.0
                  }
                }
              }
            }
          }
        }
        """
        empty_slice = _make_slice_key()

        thresholds = [0.25, 0.75, 1.00]
        # One row per threshold. Judging by the expected proto above, the
        # columns are [false_negatives, true_negatives, false_positives,
        # true_positives, precision, recall].
        matrices = [[0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
                    [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
                    [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0]]

        matrix_metrics = {
            _full_key(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES):
            matrices,
            _full_key(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS):
            thresholds,
        }
        expected_proto = text_format.Parse(
            expected_text, metrics_for_slice_pb2.MetricsForSlice())

        confusion_metrics = [
            post_export_metrics.confusion_matrix_at_thresholds(thresholds)
        ]
        serialized = metrics_and_plots_evaluator._serialize_metrics(
            (empty_slice, matrix_metrics), confusion_metrics)
        actual_proto = metrics_for_slice_pb2.MetricsForSlice.FromString(
            serialized)
        self.assertProtoEquals(expected_proto, actual_proto)