示例#1
0
 def testSerializeDeserializeEvalConfig(self):
     """Round-trips an EvalConfig through `eval_config.json`.

     The config is serialized with `_serialize_eval_config`, written into
     a temp directory, read back with `load_eval_config`, and the result
     is compared with the original for equality.
     """
     temp_dir = self._getTempDir()

     eval_options = config.Options()
     eval_options.compute_confidence_intervals.value = False
     eval_options.k_anonymization_count.value = 1

     # Name the two slices up front so the EvalConfig call stays compact.
     country_slice = config.SlicingSpec(
         feature_keys=['country'],
         feature_values={'age': '5', 'gender': 'f'})
     interest_slice = config.SlicingSpec(
         feature_keys=['interest'],
         feature_values={'age': '6', 'gender': 'm'})

     expected = config.EvalConfig(
         input_data_specs=[config.InputDataSpec(location='/path/to/data')],
         model_specs=[config.ModelSpec(location='/path/to/model')],
         output_data_specs=[
             config.OutputDataSpec(default_location=temp_dir)
         ],
         slicing_specs=[country_slice, interest_slice],
         options=eval_options)

     json_path = os.path.join(temp_dir, 'eval_config.json')
     with tf.io.gfile.GFile(json_path, 'w') as json_file:
         json_file.write(model_eval_lib._serialize_eval_config(expected))

     actual = model_eval_lib.load_eval_config(temp_dir)
     self.assertEqual(expected, actual)
示例#2
0
 def testSerializeDeserializeLegacyEvalConfig(self):
     """Checks that a pickled legacy config loads as the matching EvalConfig.

     A `LegacyConfig` is pickled, together with the TFMA version string,
     into a TFRecord file named `eval_config`. The file is then loaded via
     `load_eval_config` and compared against a hand-built
     `config.EvalConfig`.
     """
     temp_dir = self._getTempDir()
     legacy = LegacyConfig(
         model_location='/path/to/model',
         data_location='/path/to/data',
         slice_spec=[
             slicer.SingleSliceSpec(
                 columns=['country'],
                 features=[('age', 5), ('gender', 'f')]),
             slicer.SingleSliceSpec(
                 columns=['interest'],
                 features=[('age', 6), ('gender', 'm')]),
         ],
         example_count_metric_key=None,
         example_weight_metric_key='key',
         compute_confidence_intervals=False,
         k_anonymization_count=1)

     # Legacy on-disk layout used here: one pickled dict in a TFRecord file.
     payload = {
         'tfma_version': tfma_version.VERSION_STRING,
         'eval_config': legacy,
     }
     record_path = os.path.join(temp_dir, 'eval_config')
     with tf.io.TFRecordWriter(record_path) as writer:
         writer.write(pickle.dumps(payload))

     actual = model_eval_lib.load_eval_config(temp_dir)

     expected_options = config.Options()
     expected_options.compute_confidence_intervals.value = (
         legacy.compute_confidence_intervals)
     expected_options.k_anonymization_count.value = (
         legacy.k_anonymization_count)

     # The expected feature values are strings ('5', '6') even though the
     # legacy slice specs above were built with ints.
     country_slice = config.SlicingSpec(
         feature_keys=['country'],
         feature_values={'age': '5', 'gender': 'f'})
     interest_slice = config.SlicingSpec(
         feature_keys=['interest'],
         feature_values={'age': '6', 'gender': 'm'})

     expected = config.EvalConfig(
         input_data_specs=[
             config.InputDataSpec(location=legacy.data_location)
         ],
         model_specs=[config.ModelSpec(location=legacy.model_location)],
         output_data_specs=[
             config.OutputDataSpec(default_location=temp_dir)
         ],
         slicing_specs=[country_slice, interest_slice],
         options=expected_options)
     self.assertEqual(expected, actual)
    def testSerializeDeserializeToFile(self):
        """Round-trips metrics, plots and the eval config through files.

        One `MetricsForSlice` and one `PlotsForSlice` proto are parsed from
        text format, serialized, and written by a Beam pipeline using the
        default writers; the eval config is written in the same pipeline.
        Everything is then loaded back from the output directory and
        compared with what was written.
        """
        metrics_text = """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "pear"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "duck"
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "example_weight"
          value {
            double_value {
              value: 10.0
            }
          }
        }
        metrics {
          key: "auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
            }
          }
        }
        metrics {
          key: "auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
            }
          }
        }"""
        plots_text = """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "peach"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "cow"
          }
        }
        plots {
          key: ''
          value {
            calibration_histogram_buckets {
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.5
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 1.0 }
                total_weighted_refined_prediction { value: 0.3 }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 1.0
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.7 }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
            }
          }
        }"""

        # Expected slice keys mirror the slice_key sections of the protos.
        expected_metrics_key = _make_slice_key(
            b'fruit', b'pear', b'animal', b'duck')
        expected_metrics = text_format.Parse(
            metrics_text, metrics_for_slice_pb2.MetricsForSlice())
        expected_plots = text_format.Parse(
            plots_text, metrics_for_slice_pb2.PlotsForSlice())
        expected_plots_key = _make_slice_key(
            b'fruit', b'peach', b'animal', b'cow')

        expected_config = model_eval_lib.EvalConfig(
            model_location='/path/to/model',
            data_location='/path/to/data',
            slice_spec=[
                slicer.SingleSliceSpec(
                    features=[('age', 5), ('gender', 'f')],
                    columns=['country']),
                slicer.SingleSliceSpec(
                    features=[('age', 6), ('gender', 'm')],
                    columns=['interest']),
            ],
            example_weight_metric_key='key')

        temp_dir = self._getTempDir()
        with beam.Pipeline() as pipeline:
            metrics_pcoll = pipeline | 'CreateMetrics' >> beam.Create(
                [expected_metrics.SerializeToString()])
            plots_pcoll = pipeline | 'CreatePlots' >> beam.Create(
                [expected_plots.SerializeToString()])
            evaluation = {
                constants.METRICS_KEY: metrics_pcoll,
                constants.PLOTS_KEY: plots_pcoll,
            }
            writers = model_eval_lib.default_writers(output_path=temp_dir)
            _ = evaluation | 'WriteResults' >> model_eval_lib.WriteResults(
                writers=writers)
            _ = pipeline | model_eval_lib.WriteEvalConfig(
                expected_config, temp_dir)

        # The pipeline has finished on leaving the `with` block; read back.
        metrics_path = os.path.join(
            temp_dir, model_eval_lib._METRICS_OUTPUT_FILE)
        loaded_metrics = (
            metrics_and_plots_evaluator.load_and_deserialize_metrics(
                path=metrics_path))
        plots_path = os.path.join(temp_dir, model_eval_lib._PLOTS_OUTPUT_FILE)
        loaded_plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
            path=plots_path)

        self.assertSliceMetricsListEqual(
            [(expected_metrics_key, expected_metrics.metrics)], loaded_metrics)
        self.assertSlicePlotsListEqual(
            [(expected_plots_key, expected_plots.plots)], loaded_plots)

        loaded_config = model_eval_lib.load_eval_config(temp_dir)
        self.assertEqual(expected_config, loaded_config)