def testUpdateConfigWithDefaultsNoBaselineModelNonRubberstamp(self):
    """A change_threshold with no baseline must fail unless rubber stamping.

    The config declares a MeanLabel change_threshold, which needs a baseline
    model to diff against. With has_baseline=False and rubber_stamp=False the
    update helper has no legal way to satisfy the threshold, so it must raise.
    """
    pbtxt = """
        model_specs { name: "" }
        metrics_specs {
          metrics {
            class_name: "MeanLabel"
            per_slice_thresholds {
              slicing_specs: {}
              threshold {
                value_threshold { lower_bound { value: 0.9 } }
                change_threshold {
                  direction: HIGHER_IS_BETTER
                  absolute { value: -1e-10 }
                }
              }
            }
          }
        }
    """
    parsed_config = text_format.Parse(pbtxt, config_pb2.EvalConfig())
    # No baseline and no rubber stamp: the change_threshold is unsatisfiable.
    with self.assertRaises(RuntimeError):
        config_util.update_eval_config_with_defaults(
            parsed_config, has_baseline=False, rubber_stamp=False)
def testUpdateConfigWithDefaultsDoesNotAutomaticallyAddBaselineModel(self):
    """An explicit is_baseline spec suppresses automatic baseline injection.

    With has_baseline=True but "model2" already marked is_baseline, the
    helper must not add a synthetic baseline spec; it only fills in the
    missing metrics_specs.model_names with both declared models.
    """
    eval_config_pbtxt = """
        model_specs { name: "model1" }
        model_specs { name: "model2" is_baseline: true }
        metrics_specs {
          metrics { class_name: "WeightedExampleCount" }
        }
    """
    eval_config = text_format.Parse(eval_config_pbtxt, config_pb2.EvalConfig())
    expected_eval_config_pbtxt = """
        model_specs { name: "model1" }
        model_specs { name: "model2" is_baseline: true }
        metrics_specs {
          metrics { class_name: "WeightedExampleCount" }
          model_names: ["model1", "model2"]
        }
    """
    expected_eval_config = text_format.Parse(
        expected_eval_config_pbtxt, config_pb2.EvalConfig())
    got_eval_config = config_util.update_eval_config_with_defaults(
        eval_config, has_baseline=True)
    # Fix: assertProtoEquals takes (expected, actual); the original passed
    # them swapped, which produced backwards failure diffs.
    self.assertProtoEquals(expected_eval_config, got_eval_config)
def testUpdateConfigWithoutBaselineModelWhenModelNameProvided(self):
    """Explicitly listed metrics_specs.model_names are left untouched.

    The metrics spec already names "candidate" only, so the helper must not
    append the baseline model to it — defaults are only applied to specs
    whose model_names list is empty.
    """
    eval_config_pbtxt = """
        model_specs { name: "candidate" }
        model_specs { name: "baseline" is_baseline: true }
        metrics_specs {
          metrics { class_name: "WeightedExampleCount" }
          model_names: "candidate"
        }
    """
    eval_config = text_format.Parse(eval_config_pbtxt, config_pb2.EvalConfig())
    expected_eval_config_pbtxt = """
        model_specs { name: "candidate" }
        model_specs { name: "baseline" is_baseline: true }
        metrics_specs {
          metrics { class_name: "WeightedExampleCount" }
          model_names: ["candidate"]
        }
    """
    expected_eval_config = text_format.Parse(
        expected_eval_config_pbtxt, config_pb2.EvalConfig())
    got_eval_config = config_util.update_eval_config_with_defaults(
        eval_config, has_baseline=True)
    # Fix: assertProtoEquals takes (expected, actual); the original passed
    # them swapped, which produced backwards failure diffs.
    self.assertProtoEquals(expected_eval_config, got_eval_config)
def testUpdateConfigWithDefaultsMultiModel(self):
    """With multiple model specs, empty model_names default to all models.

    The first metrics spec has no model_names, so it gets every declared
    model ("model1", "model2"); the second explicitly names "model1" and
    must be preserved as-is.
    """
    eval_config_pbtxt = """
        model_specs { name: "model1" }
        model_specs { name: "model2" }
        metrics_specs {
          metrics { class_name: "WeightedExampleCount" }
        }
        metrics_specs {
          metrics { class_name: "MeanLabel" }
          model_names: ["model1"]
        }
    """
    eval_config = text_format.Parse(eval_config_pbtxt, config_pb2.EvalConfig())
    expected_eval_config_pbtxt = """
        model_specs { name: "model1" }
        model_specs { name: "model2" }
        metrics_specs {
          metrics { class_name: "WeightedExampleCount" }
          model_names: ["model1", "model2"]
        }
        metrics_specs {
          metrics { class_name: "MeanLabel" }
          model_names: ["model1"]
        }
    """
    expected_eval_config = text_format.Parse(
        expected_eval_config_pbtxt, config_pb2.EvalConfig())
    got_eval_config = config_util.update_eval_config_with_defaults(eval_config)
    # Fix: assertProtoEquals takes (expected, actual); the original passed
    # them swapped, which produced backwards failure diffs.
    self.assertProtoEquals(expected_eval_config, got_eval_config)
def testUpdateConfigWithDefaultsEmtpyModelName(self):
    """A single unnamed model spec yields model_names: [""].

    Verifies the empty-string model name is propagated into metrics_specs
    rather than being dropped or replaced with a default name.

    NOTE(review): "Emtpy" is a typo for "Empty"; kept because renaming
    would change the test's public identifier.
    """
    eval_config_pbtxt = """
        model_specs { name: "" }
        metrics_specs {
          metrics { class_name: "ExampleCount" }
        }
    """
    eval_config = text_format.Parse(eval_config_pbtxt, config_pb2.EvalConfig())
    expected_eval_config_pbtxt = """
        model_specs { name: "" }
        metrics_specs {
          metrics { class_name: "ExampleCount" }
          model_names: [""]
        }
    """
    expected_eval_config = text_format.Parse(
        expected_eval_config_pbtxt, config_pb2.EvalConfig())
    got_eval_config = config_util.update_eval_config_with_defaults(eval_config)
    # Fix: assertProtoEquals takes (expected, actual); the original passed
    # them swapped, which produced backwards failure diffs.
    self.assertProtoEquals(expected_eval_config, got_eval_config)
def testUpdateConfigWithDefaultsAutomaticallyAddsBaselineModel(self):
    """has_baseline=True expands an unnamed spec into candidate + baseline.

    A single anonymous model spec is rewritten into a "candidate" spec and a
    cloned "baseline" spec (is_baseline: true); per-model settings such as
    label_key are copied to both, and metrics_specs.model_names lists both.
    """
    eval_config_pbtxt = """
        model_specs { label_key: "my_label" }
        metrics_specs {
          metrics { class_name: "ExampleCount" }
        }
    """
    eval_config = text_format.Parse(eval_config_pbtxt, config_pb2.EvalConfig())
    expected_eval_config_pbtxt = """
        model_specs { name: "candidate" label_key: "my_label" }
        model_specs { name: "baseline" label_key: "my_label" is_baseline: true }
        metrics_specs {
          metrics { class_name: "ExampleCount" }
          model_names: ["candidate", "baseline"]
        }
    """
    expected_eval_config = text_format.Parse(
        expected_eval_config_pbtxt, config_pb2.EvalConfig())
    got_eval_config = config_util.update_eval_config_with_defaults(
        eval_config, has_baseline=True)
    # Fix: assertProtoEquals takes (expected, actual); the original passed
    # them swapped, which produced backwards failure diffs.
    self.assertProtoEquals(expected_eval_config, got_eval_config)
def testUpdateConfigWithDefaultsRemoveBaselineModel(self):
    """Rubber stamping without a baseline strips baseline-relative config.

    With has_baseline=False and rubber_stamp=True, the baseline model spec
    is removed (leaving one anonymous spec) and every change_threshold —
    inline on the metric, in per_slice/cross_slice thresholds, and in the
    threshold maps — is dropped, while value_thresholds are preserved.
    """
    eval_config_pbtxt = """
        model_specs { name: "candidate" }
        model_specs { name: "baseline" is_baseline: true }
        metrics_specs {
          metrics {
            class_name: "MeanLabel"
            threshold {
              value_threshold { lower_bound { value: 0.9 } }
              change_threshold {
                direction: HIGHER_IS_BETTER
                absolute { value: -1e-10 }
              }
            }
            per_slice_thresholds {
              threshold {
                value_threshold { lower_bound { value: 0.9 } }
                change_threshold {
                  direction: HIGHER_IS_BETTER
                  absolute { value: -1e-10 }
                }
              }
            }
            cross_slice_thresholds {
              threshold {
                value_threshold { lower_bound { value: 0.9 } }
                change_threshold {
                  direction: HIGHER_IS_BETTER
                  absolute { value: -1e-10 }
                }
              }
            }
          }
          thresholds {
            key: "my_metric"
            value {
              value_threshold { lower_bound { value: 0.9 } }
              change_threshold {
                direction: HIGHER_IS_BETTER
                absolute { value: -1e-10 }
              }
            }
          }
          per_slice_thresholds {
            key: "my_metric"
            value {
              thresholds {
                threshold {
                  value_threshold { lower_bound { value: 0.9 } }
                  change_threshold {
                    direction: HIGHER_IS_BETTER
                    absolute { value: -1e-10 }
                  }
                }
              }
            }
          }
          cross_slice_thresholds {
            key: "my_metric"
            value {
              thresholds {
                threshold {
                  value_threshold { lower_bound { value: 0.9 } }
                  change_threshold {
                    direction: HIGHER_IS_BETTER
                    absolute { value: -1e-10 }
                  }
                }
              }
            }
          }
        }
    """
    eval_config = text_format.Parse(eval_config_pbtxt, config_pb2.EvalConfig())
    expected_eval_config_pbtxt = """
        model_specs {}
        metrics_specs {
          metrics {
            class_name: "MeanLabel"
            threshold {
              value_threshold { lower_bound { value: 0.9 } }
            }
            per_slice_thresholds {
              threshold {
                value_threshold { lower_bound { value: 0.9 } }
              }
            }
            cross_slice_thresholds {
              threshold {
                value_threshold { lower_bound { value: 0.9 } }
              }
            }
          }
          thresholds {
            key: "my_metric"
            value {
              value_threshold { lower_bound { value: 0.9 } }
            }
          }
          per_slice_thresholds {
            key: "my_metric"
            value {
              thresholds {
                threshold {
                  value_threshold { lower_bound { value: 0.9 } }
                }
              }
            }
          }
          cross_slice_thresholds {
            key: "my_metric"
            value {
              thresholds {
                threshold {
                  value_threshold { lower_bound { value: 0.9 } }
                }
              }
            }
          }
          model_names: [""]
        }
    """
    expected_eval_config = text_format.Parse(
        expected_eval_config_pbtxt, config_pb2.EvalConfig())
    # Only valid when rubber stamping.
    got_eval_config = config_util.update_eval_config_with_defaults(
        eval_config, has_baseline=False, rubber_stamp=True)
    # Fix: assertProtoEquals takes (expected, actual); the original passed
    # them swapped, which produced backwards failure diffs.
    self.assertProtoEquals(expected_eval_config, got_eval_config)