예제 #1
0
def filter_metrics(
        eval_result: tfma.EvalResult,
        metrics_include: Optional[List[str]] = None,
        metrics_exclude: Optional[List[str]] = None) -> tfma.EvalResult:
    """Filters metrics in a TFMA EvalResult.

    Exactly one of `metrics_include` / `metrics_exclude` must be supplied.
    A value that is `None` or an empty list counts as "not supplied".

    Args:
        eval_result: The TFMA EvalResult object. It is not modified; a new
            EvalResult is built and returned.
        metrics_include: The names of metrics to keep in the EvalResult.
            Mutually exclusive with metrics_exclude.
        metrics_exclude: The names of metrics to discard in the EvalResult.
            Mutually exclusive with metrics_include.

    Returns:
        A new EvalResult holding the filtered slicing metrics. Every input
        slice is still present (its metrics dict may end up empty), while
        output-name / sub-key levels left with no metrics are dropped. The
        remaining fields (plots, attributions, config, data_location,
        file_format, model_location) are passed through unchanged.

    Raises:
        ValueError: if both or neither of metrics_include and metrics_exclude
            are provided.
    """
    if metrics_include and not metrics_exclude:
        # Keep only the metrics named in the include list.
        selected_names, keep_selected = frozenset(metrics_include), True
    elif metrics_exclude and not metrics_include:
        # Keep every metric except those named in the exclude list.
        selected_names, keep_selected = frozenset(metrics_exclude), False
    else:
        raise ValueError(
            'filter_metrics() requires exactly one of metrics_include '
            'and metrics_exclude.')

    def include(metric_name: str) -> bool:
        # The frozenset gives O(1) membership instead of scanning the list
        # for every metric of every slice.
        return (metric_name in selected_names) == keep_selected

    filtered_slicing_metrics = []
    for slice_key, metrics_by_output in eval_result.slicing_metrics:
        # Nested layout walked here: output name -> sub key -> metric name.
        filtered_metrics = {}
        for output_name, metrics_by_subkey in metrics_by_output.items():
            for subkey, metrics in metrics_by_subkey.items():
                kept = {
                    name: value
                    for name, value in metrics.items()
                    if include(name)
                }
                # Only materialize output/sub-key levels that keep something.
                if kept:
                    filtered_metrics.setdefault(output_name, {})[subkey] = kept
        filtered_slicing_metrics.append(
            tfma.view.SlicedMetrics(slice=slice_key, metrics=filtered_metrics))

    return tfma.EvalResult(slicing_metrics=filtered_slicing_metrics,
                           plots=eval_result.plots,
                           attributions=eval_result.attributions,
                           config=eval_result.config,
                           data_location=eval_result.data_location,
                           file_format=eval_result.file_format,
                           model_location=eval_result.model_location)
예제 #2
0
    def test_annotate_eval_results_plots(self):
        """Checks that a non-empty plot is produced per metric/slice group.

        The fixture mixes single-feature (weekday) slices, two-feature
        (gender, age) slices, one slice whose prediction/mean is a plain
        doubleValue accompanied by an __ERROR__ entry, and the overall (empty)
        slice. Only 'weekday' and 'gender, age' graphics for average_loss and
        prediction/mean are expected.
        """

        def bounded_slice(slice_key, loss, mean, lower, upper):
            # One slice entry whose prediction/mean carries confidence bounds.
            return (slice_key, {
                '': {
                    '': {
                        'average_loss': {
                            'doubleValue': loss
                        },
                        'prediction/mean': {
                            'boundedValue': {
                                'value': mean,
                                'lowerBound': lower,
                                'upperBound': upper,
                            }
                        },
                        'average_loss_diff': {}
                    }
                }
            })

        female_age_20_slice = (
            (('gender', 'female'), ('age', 20)),
            {
                '': {
                    '': {
                        'average_loss': {
                            'doubleValue': 2.092138290405273
                        },
                        'prediction/mean': {
                            'doubleValue': 0.3767518997192383
                        },
                        'average_loss_diff': {},
                        '__ERROR__': {
                            # CI not computed because only 16 samples
                            # were non-empty. Expected 20.
                            'bytesValue':
                            'Q0kgbm90IGNvbXB1dGVkIGJlY2F1c2Ugb25seSAxNiBzYW1wbGVzIHdlcmUgbm9uLWVtcHR5LiBFeHBlY3RlZCAyMC4='
                        }
                    }
                }
            })

        slicing_metrics = [
            bounded_slice((('weekday', 0), ), 0.07875693589448929,
                          0.5100112557411194, 0.4100112557411194,
                          0.6100112557411194),
            bounded_slice((('weekday', 1), ), 4.4887189865112305,
                          0.4839990735054016, 0.3839990735054016,
                          0.5839990735054016),
            bounded_slice((('weekday', 2), ), 2.092138290405273,
                          0.3767518997192383, 0.1767518997192383,
                          0.5767518997192383),
            bounded_slice((('gender', 'male'), ('age', 10)),
                          2.092138290405273, 0.3767518997192383,
                          0.1767518997192383, 0.5767518997192383),
            female_age_20_slice,
            bounded_slice((), 1.092138290405273, 0.4767518997192383,
                          0.2767518997192383, 0.6767518997192383),
        ]
        eval_result = tfma.EvalResult(slicing_metrics=slicing_metrics,
                                      plots=None,
                                      attributions=None,
                                      config=None,
                                      data_location=None,
                                      file_format=None,
                                      model_location=None)
        model_card = model_card_module.ModelCard()
        graphics.annotate_eval_result_plots(model_card, eval_result)

        generated = model_card.quantitative_analysis.graphics.collection
        expected_metrics_names = {
            'average_loss | weekday', 'prediction/mean | weekday',
            'average_loss | gender, age', 'prediction/mean | gender, age'
        }
        self.assertSameElements(expected_metrics_names,
                                [graph.name for graph in generated])

        for graph in generated:
            logging.info('%s: %s', graph.name, graph.image)
            self.assertNotEmpty(graph.image,
                                f'feature {graph.name} has empty plot')
예제 #3
0
 def test_filter_metrics(self):
   """filter_metrics keeps only average_loss via either include or exclude.

   Both filtering modes must give the same result, and passing both lists or
   neither list must raise ValueError.
   """
   eval_result = tfma.EvalResult(
       slicing_metrics=_SLICING_METRICS,
       plots=None,
       attributions=None,
       config=None,
       data_location=None,
       file_format=None,
       model_location=None)
   metrics_include = ['average_loss']
   metrics_exclude = [
       'prediction/mean', 'int_array', 'float_array', 'invalid_array'
   ]
   # Slices expected in the filtered result, each keeping only average_loss.
   average_loss_by_slice = [
       ((('weekday', 0),), 0.07875693589448929),
       ((('weekday', 1),), 4.4887189865112305),
       ((('weekday', 2),), 2.092138290405273),
       ((('gender', 'male'), ('age', 10)), 2.092138290405273),
       ((('gender', 'female'), ('age', 20)), 2.092138290405273),
       ((), 1.092138290405273),
   ]
   expected_slicing_metrics = [
       (slice_key, {'': {'': {'average_loss': {'doubleValue': loss}}}})
       for slice_key, loss in average_loss_by_slice
   ]

   valid_cases = (
       ('metrics_include', {'metrics_include': metrics_include}),
       ('metrics_exclude', {'metrics_exclude': metrics_exclude}),
   )
   for case_name, filter_kwargs in valid_cases:
     with self.subTest(name=case_name):
       filtered = tfx_util.filter_metrics(eval_result, **filter_kwargs)
       self.assertEqual(filtered.slicing_metrics, expected_slicing_metrics)

   with self.subTest(
       name='both metrics_include and metrics_exclude (invalid)'):
     with self.assertRaises(ValueError):
       tfx_util.filter_metrics(
           eval_result,
           metrics_include=metrics_include,
           metrics_exclude=metrics_exclude)
   with self.subTest(
       name='neither metrics_include nor metrics_exclude (invalid)'):
     with self.assertRaises(ValueError):
       tfx_util.filter_metrics(eval_result)
예제 #4
0
  def test_annotate_eval_results_plots(self):
    """Checks that a non-empty plot is produced per metric/slice group.

    Every slice in the fixture carries a bounded prediction/mean. Only
    'weekday' and 'gender, age' graphics for average_loss and prediction/mean
    are expected.
    """

    def bounded_slice(slice_key, loss, mean, lower, upper):
      # One slice entry whose prediction/mean carries confidence bounds.
      return (slice_key, {
          '': {
              '': {
                  'average_loss': {
                      'doubleValue': loss
                  },
                  'prediction/mean': {
                      'boundedValue': {
                          'value': mean,
                          'lowerBound': lower,
                          'upperBound': upper,
                      }
                  },
                  'average_loss_diff': {}
              }
          }
      })

    slicing_metrics = [
        bounded_slice((('weekday', 0),), 0.07875693589448929,
                      0.5100112557411194, 0.4100112557411194,
                      0.6100112557411194),
        bounded_slice((('weekday', 1),), 4.4887189865112305,
                      0.4839990735054016, 0.3839990735054016,
                      0.5839990735054016),
        bounded_slice((('weekday', 2),), 2.092138290405273,
                      0.3767518997192383, 0.1767518997192383,
                      0.5767518997192383),
        bounded_slice((('gender', 'male'), ('age', 10)), 2.092138290405273,
                      0.3767518997192383, 0.1767518997192383,
                      0.5767518997192383),
        bounded_slice((('gender', 'female'), ('age', 20)), 2.092138290405273,
                      0.3767518997192383, 0.1767518997192383,
                      0.5767518997192383),
        bounded_slice((), 1.092138290405273, 0.4767518997192383,
                      0.2767518997192383, 0.6767518997192383),
    ]
    # NOTE(review): positional arguments rely on EvalResult's field order in
    # the TFMA version this test targets; confirm before upgrading TFMA.
    eval_result = tfma.EvalResult(slicing_metrics, None, None, None, None, None)
    model_card = model_card_module.ModelCard()
    graphics.annotate_eval_result_plots(model_card, eval_result)

    generated = model_card.quantitative_analysis.graphics.collection
    expected_metrics_names = {
        'average_loss | weekday', 'prediction/mean | weekday',
        'average_loss | gender, age', 'prediction/mean | gender, age'
    }
    self.assertSameElements(expected_metrics_names,
                            [graph.name for graph in generated])

    for graph in generated:
      logging.info('%s: %s', graph.name, graph.image)
      self.assertNotEmpty(graph.image, f'feature {graph.name} has empty plot')
예제 #5
0
  def test_annotate_eval_results_metrics(self):
    """Checks the PerformanceMetric entries derived from _SLICING_METRICS.

    Each expected entry is compared in order, field by field (type, slice,
    value), against model_card.quantitative_analysis.performance_metrics.
    """
    eval_result = tfma.EvalResult(
        slicing_metrics=_SLICING_METRICS,
        plots=None,
        attributions=None,
        config=None,
        data_location=None,
        file_format=None,
        model_location=None)
    model_card = ModelCard()
    tfx_util.annotate_eval_result_metrics(model_card, eval_result)

    # (type, value, slice) for every metric the annotation should produce.
    expected_rows = [
        ('average_loss', '0.07875693589448929', 'weekday_0'),
        ('prediction/mean', '0.5100112557411194', 'weekday_0'),
        ('average_loss', '4.4887189865112305', 'weekday_1'),
        ('prediction/mean', '0.4839990735054016', 'weekday_1'),
        ('average_loss', '2.092138290405273', 'weekday_2'),
        ('prediction/mean', '0.3767518997192383', 'weekday_2'),
        ('average_loss', '2.092138290405273', 'gender_male_X_age_10'),
        ('prediction/mean', '0.3767518997192383', 'gender_male_X_age_10'),
        ('average_loss', '2.092138290405273', 'gender_female_X_age_20'),
        ('prediction/mean', '0.3767518997192383', 'gender_female_X_age_20'),
        ('average_loss', '1.092138290405273', ''),
        ('prediction/mean', '0.4767518997192383', ''),
        ('int_array', '1, 2, 3', ''),
        ('float_array', '1.1, 2.2, 3.3', ''),
    ]
    expected_metrics = [
        PerformanceMetric(type=metric_type, value=value, slice=slice_name)
        for metric_type, value, slice_name in expected_rows
    ]

    actual_metrics = model_card.quantitative_analysis.performance_metrics
    self.assertEqual(len(actual_metrics), len(expected_metrics))
    for actual, expected in zip(actual_metrics, expected_metrics):
      self.assertEqual(actual.type, expected.type)
      self.assertEqual(actual.slice, expected.slice)
      self.assertEqual(actual.value, expected.value)