def _makeEvalConfig(self):
    """Return an EvalConfig fixture for tests.

    The config uses 'testing_key' as its example weight metric key, has no
    slice spec, and leaves both the data and model locations empty.
    """
    return model_eval_lib.EvalConfig(
        example_weight_metric_key='testing_key',
        slice_spec=None,
        data_location='',
        model_location='')
def _makeEvalResults(self):
    """Return an EvalResults fixture holding two results in model-centric mode.

    The first result carries the slicing metrics from `_makeTestData` and
    points at `data_location_1` / `model_location_1`. The second carries only
    `result_c2` and points at `full_data_location_2` /
    `full_model_location_2`. Neither result has plots.
    """
    first_metrics = self._makeTestData()
    first_config = model_eval_lib.EvalConfig(
        example_weight_metric_key=None,
        slice_spec=None,
        data_location=self.data_location_1,
        model_location=self.model_location_1)
    first_result = model_eval_lib.EvalResult(
        slicing_metrics=first_metrics, plots=None, config=first_config)

    second_metrics = [self.result_c2]
    second_config = model_eval_lib.EvalConfig(
        example_weight_metric_key=None,
        slice_spec=None,
        data_location=self.full_data_location_2,
        model_location=self.full_model_location_2)
    second_result = model_eval_lib.EvalResult(
        slicing_metrics=second_metrics, plots=None, config=second_config)

    return model_eval_lib.EvalResults(
        [first_result, second_result], constants.MODEL_CENTRIC_MODE)
def testSerializeDeserializeEvalConfig(self):
    """Check that an EvalConfig survives a serialize / unpickle round trip."""
    slice_specs = [
        slicer.SingleSliceSpec(
            features=[('age', 5), ('gender', 'f')], columns=['country']),
        slicer.SingleSliceSpec(
            features=[('age', 6), ('gender', 'm')], columns=['interest']),
    ]
    eval_config = model_eval_lib.EvalConfig(
        model_location='/path/to/model',
        data_location='/path/to/data',
        slice_spec=slice_specs,
        example_weight_metric_key='key')

    # The serialized payload is a pickle produced by this test itself; the
    # config is stored in it under _EVAL_CONFIG_KEY.
    payload = pickle.loads(model_eval_lib._serialize_eval_config(eval_config))
    got_eval_config = payload[model_eval_lib._EVAL_CONFIG_KEY]

    self.assertEqual(eval_config, got_eval_config)
def testSerializeDeserializeToFile(self):
    """Round-trip metrics, plots and the eval config through output files.

    A Beam pipeline writes one serialized MetricsForSlice, one serialized
    PlotsForSlice and the EvalConfig under a temp directory. The test then
    loads all three back and checks they match what was written.
    """
    metrics_slice_key = _make_slice_key(b'fruit', b'pear', b'animal', b'duck')
    metrics_for_slice = text_format.Parse(
        """ slice_key { single_slice_keys { column: "fruit" bytes_value: "pear" } single_slice_keys { column: "animal" bytes_value: "duck" } } metrics { key: "accuracy" value { double_value { value: 0.8 } } } metrics { key: "example_weight" value { double_value { value: 10.0 } } } metrics { key: "auc" value { bounded_value { lower_bound { value: 0.1 } upper_bound { value: 0.3 } value { value: 0.2 } } } } metrics { key: "auprc" value { bounded_value { lower_bound { value: 0.05 } upper_bound { value: 0.17 } value { value: 0.1 } } } }""",
        metrics_for_slice_pb2.MetricsForSlice())

    plots_for_slice = text_format.Parse(
        """ slice_key { single_slice_keys { column: "fruit" bytes_value: "peach" } single_slice_keys { column: "animal" bytes_value: "cow" } } plots { key: '' value { calibration_histogram_buckets { buckets { lower_threshold_inclusive: -inf upper_threshold_exclusive: 0.0 num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.0 } } buckets { lower_threshold_inclusive: 0.0 upper_threshold_exclusive: 0.5 num_weighted_examples { value: 1.0 } total_weighted_label { value: 1.0 } total_weighted_refined_prediction { value: 0.3 } } buckets { lower_threshold_inclusive: 0.5 upper_threshold_exclusive: 1.0 num_weighted_examples { value: 1.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.7 } } buckets { lower_threshold_inclusive: 1.0 upper_threshold_exclusive: inf num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.0 } } } } }""",
        metrics_for_slice_pb2.PlotsForSlice())
    plots_slice_key = _make_slice_key(b'fruit', b'peach', b'animal', b'cow')

    slice_specs = [
        slicer.SingleSliceSpec(
            features=[('age', 5), ('gender', 'f')], columns=['country']),
        slicer.SingleSliceSpec(
            features=[('age', 6), ('gender', 'm')], columns=['interest']),
    ]
    eval_config = model_eval_lib.EvalConfig(
        model_location='/path/to/model',
        data_location='/path/to/data',
        slice_spec=slice_specs,
        example_weight_metric_key='key')

    output_path = self._getTempDir()

    with beam.Pipeline() as pipeline:
        metrics_pcoll = pipeline | 'CreateMetrics' >> beam.Create(
            [metrics_for_slice.SerializeToString()])
        plots_pcoll = pipeline | 'CreatePlots' >> beam.Create(
            [plots_for_slice.SerializeToString()])
        evaluation = {
            constants.METRICS_KEY: metrics_pcoll,
            constants.PLOTS_KEY: plots_pcoll,
        }

        writers = model_eval_lib.default_writers(output_path=output_path)
        _ = evaluation | 'WriteResults' >> model_eval_lib.WriteResults(
            writers=writers)
        _ = pipeline | model_eval_lib.WriteEvalConfig(eval_config, output_path)

    # Reads happen after the `with` block so the pipeline has finished and
    # the output files are available.
    metrics_path = os.path.join(
        output_path, model_eval_lib._METRICS_OUTPUT_FILE)
    got_metrics = metrics_and_plots_evaluator.load_and_deserialize_metrics(
        path=metrics_path)
    plots_path = os.path.join(output_path, model_eval_lib._PLOTS_OUTPUT_FILE)
    got_plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
        path=plots_path)

    self.assertSliceMetricsListEqual(
        [(metrics_slice_key, metrics_for_slice.metrics)], got_metrics)
    self.assertSlicePlotsListEqual(
        [(plots_slice_key, plots_for_slice.plots)], got_plots)

    got_eval_config = model_eval_lib.load_eval_config(output_path)
    self.assertEqual(eval_config, got_eval_config)