def example_count_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    include_example_count: bool = True,
    include_weighted_example_count: bool = True) -> List[config.MetricsSpec]:
  """Builds MetricsSpecs for the example count and weighted example count.

  Args:
    model_names: Optional list of model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    include_example_count: True to add example_count metric.
    include_weighted_example_count: True to add weighted_example_count metric.
      A weighted example count will be added per output for multi-output
      models.

  Returns:
    A list with at most two config.MetricsSpec entries, in this order: the
    example_count spec (scoped by model_names only), then the
    weighted_example_count spec (scoped by model_names and output_names).
    Each entry is present only when its include_* flag is set.
  """
  # NOTE(review): a second `example_count_specs` (adding `output_weights`)
  # appears alongside this one; if both live in the same module, the later
  # definition replaces this one at import time -- confirm which is intended.
  unweighted_part = (
      [
          config.MetricsSpec(
              metrics=[_serialize_tfma_metric(example_count.ExampleCount())],
              model_names=model_names)
      ]
      if include_example_count else [])
  weighted_part = (
      [
          config.MetricsSpec(
              metrics=[
                  _serialize_tfma_metric(
                      weighted_example_count.WeightedExampleCount())
              ],
              model_names=model_names,
              output_names=output_names)
      ]
      if include_weighted_example_count else [])
  return unweighted_part + weighted_part
def example_count_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    output_weights: Optional[Dict[Text, float]] = None,
    include_example_count: bool = True,
    include_weighted_example_count: bool = True) -> List[config.MetricsSpec]:
  """Returns metric specs for example count and weighted example counts.

  Args:
    model_names: Optional list of model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model). Only
      applied to the weighted_example_count spec.
    output_weights: Optional output weights for creating overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e. output
      ignored). Only applied to the weighted_example_count spec.
    include_example_count: True to add example_count metric.
    include_weighted_example_count: True to add weighted_example_count metric.
      A weighted example count will be added per output for multi-output
      models.

  Returns:
    A list of config.MetricsSpec with at most two entries, in this order: the
    example_count spec (if include_example_count) and the
    weighted_example_count spec (if include_weighted_example_count).
  """
  specs = []
  if include_example_count:
    metric_config = _serialize_tfma_metric(example_count.ExampleCount())
    # Only model_names is passed here; output_names/output_weights are
    # deliberately not attached to the example_count spec.
    specs.append(
        config.MetricsSpec(metrics=[metric_config], model_names=model_names))
  if include_weighted_example_count:
    metric_config = _serialize_tfma_metric(
        weighted_example_count.WeightedExampleCount())
    specs.append(
        config.MetricsSpec(
            metrics=[metric_config],
            model_names=model_names,
            output_names=output_names,
            output_weights=output_weights))
  return specs
def testWeightedExampleCount(self, model_name, output_name):
  """Checks that weighted_example_count sums the per-example weights.

  Three examples with weights 0.5, 1.0 and 0.7 are fed through the metric's
  combiner under a single empty slice key. When output_name / model_name are
  truthy, each example's weights are nested under them (output first, then
  model), and the resulting MetricKey carries the same names.

  Args:
    model_name: Model name used for the metric computation and key (may be
      falsy, in which case no model-level nesting is applied).
    output_name: Output name used for the metric computation and key (may be
      falsy, in which case no output-level nesting is applied).
  """
  metric = weighted_example_count.WeightedExampleCount().computations(
      model_names=[model_name], output_names=[output_name])[0]

  weights = [0.5, 1.0, 0.7]
  # Same left-to-right summation as the original 0.5 + 1.0 + 0.7 literal.
  expected_weighted_count = sum(weights)

  def make_example(weight):
    # Wrap the weight list under the output name, then the model name, so
    # the input shape matches what the metric is computed for.
    example_weights = [weight]
    if output_name:
      example_weights = {output_name: example_weights}
    if model_name:
      example_weights = {model_name: example_weights}
    return {
        'labels': None,
        'predictions': None,
        'example_weights': example_weights
    }

  examples = [make_example(weight) for weight in weights]

  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    result = (
        pipeline
        | 'Create' >> beam.Create(examples)
        | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
        | 'AddSlice' >> beam.Map(lambda x: ((), x))
        | 'ComputeMetric' >> beam.CombinePerKey(metric.combiner))
    # pylint: enable=no-value-for-parameter

    def check_result(got):
      try:
        self.assertLen(got, 1)
        got_slice_key, got_metrics = got[0]
        self.assertEqual(got_slice_key, ())
        weighted_example_count_key = metric_types.MetricKey(
            name='weighted_example_count',
            model_name=model_name,
            output_name=output_name)
        self.assertDictElementsAlmostEqual(
            got_metrics, {weighted_example_count_key: expected_weighted_count})
      except AssertionError as err:
        # Chain the cause so the original assertion traceback is preserved.
        raise util.BeamAssertException(err) from err

    util.assert_that(result, check_result, label='result')