Code example #1
    def testUpdateConfigWithDefaultsNoBaselineModelNonRubberstamp(self):
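        """Change thresholds without a baseline raise RuntimeError when not rubber stamping."""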
        eval_config_pbtxt = """
      model_specs { name: "" }
      metrics_specs {
        metrics {
          class_name: "MeanLabel"
          per_slice_thresholds {
            slicing_specs: {}
            threshold {
              value_threshold {
                lower_bound { value: 0.9 }
              }
              change_threshold {
                direction: HIGHER_IS_BETTER
                absolute { value: -1e-10 }
              }
            }
          }
        }
      }
    """

        eval_config = text_format.Parse(eval_config_pbtxt,
                                        config_pb2.EvalConfig())

        with self.assertRaises(RuntimeError):
            config_util.update_eval_config_with_defaults(eval_config,
                                                         has_baseline=False,
                                                         rubber_stamp=False)
Code example #2
    def testUpdateConfigWithDefaultsDoesNotAutomaticallyAddBaselineModel(self):
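        """An explicit baseline model_spec is kept; defaults only fill in missing model_names."""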
        eval_config_pbtxt = """
      model_specs { name: "model1" }
      model_specs { name: "model2" is_baseline: true }
      metrics_specs {
        metrics { class_name: "WeightedExampleCount" }
      }
    """
        eval_config = text_format.Parse(eval_config_pbtxt,
                                        config_pb2.EvalConfig())

        expected_eval_config_pbtxt = """
      model_specs { name: "model1" }
      model_specs { name: "model2" is_baseline: true }
      metrics_specs {
        metrics { class_name: "WeightedExampleCount" }
        model_names: ["model1", "model2"]
      }
    """
        expected_eval_config = text_format.Parse(expected_eval_config_pbtxt,
                                                 config_pb2.EvalConfig())

        got_eval_config = config_util.update_eval_config_with_defaults(
            eval_config, has_baseline=True)
        self.assertProtoEquals(got_eval_config, expected_eval_config)
Code example #3
    def testUpdateConfigWithoutBaselineModelWhenModelNameProvided(self):
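        """model_names already set on a metrics_spec are left untouched, not expanded to the baseline."""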
        eval_config_pbtxt = """
      model_specs { name: "candidate" }
      model_specs { name: "baseline" is_baseline: true }
      metrics_specs {
        metrics { class_name: "WeightedExampleCount" }
        model_names: "candidate"
      }
    """
        eval_config = text_format.Parse(eval_config_pbtxt,
                                        config_pb2.EvalConfig())

        expected_eval_config_pbtxt = """
      model_specs { name: "candidate" }
      model_specs { name: "baseline" is_baseline: true }
      metrics_specs {
        metrics { class_name: "WeightedExampleCount" }
        model_names: ["candidate"]
      }
    """
        expected_eval_config = text_format.Parse(expected_eval_config_pbtxt,
                                                 config_pb2.EvalConfig())

        got_eval_config = config_util.update_eval_config_with_defaults(
            eval_config, has_baseline=True)
        self.assertProtoEquals(got_eval_config, expected_eval_config)
Code example #4
    def testUpdateConfigWithDefaultsMultiModel(self):
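        """Empty model_names default to all configured models; explicitly set ones are preserved."""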
        eval_config_pbtxt = """
      model_specs { name: "model1" }
      model_specs { name: "model2" }
      metrics_specs {
        metrics { class_name: "WeightedExampleCount" }
      }
      metrics_specs {
        metrics { class_name: "MeanLabel" }
        model_names: ["model1"]
      }
    """
        eval_config = text_format.Parse(eval_config_pbtxt,
                                        config_pb2.EvalConfig())

        expected_eval_config_pbtxt = """
      model_specs { name: "model1" }
      model_specs { name: "model2" }
      metrics_specs {
        metrics { class_name: "WeightedExampleCount" }
        model_names: ["model1", "model2"]
      }
      metrics_specs {
        metrics { class_name: "MeanLabel" }
        model_names: ["model1"]
      }
    """
        expected_eval_config = text_format.Parse(expected_eval_config_pbtxt,
                                                 config_pb2.EvalConfig())

        got_eval_config = config_util.update_eval_config_with_defaults(
            eval_config)
        self.assertProtoEquals(got_eval_config, expected_eval_config)
Code example #5
    def testUpdateConfigWithDefaultsEmptyModelName(self):
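        """An empty model name is carried over into model_names as the empty string."""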
        eval_config_pbtxt = """
      model_specs { name: "" }
      metrics_specs {
        metrics { class_name: "ExampleCount" }
      }
    """
        eval_config = text_format.Parse(eval_config_pbtxt,
                                        config_pb2.EvalConfig())

        expected_eval_config_pbtxt = """
      model_specs { name: "" }
      metrics_specs {
        metrics { class_name: "ExampleCount" }
        model_names: [""]
      }
    """
        expected_eval_config = text_format.Parse(expected_eval_config_pbtxt,
                                                 config_pb2.EvalConfig())

        got_eval_config = config_util.update_eval_config_with_defaults(
            eval_config)
        self.assertProtoEquals(got_eval_config, expected_eval_config)
Code example #6
    def testUpdateConfigWithDefaultsAutomaticallyAddsBaselineModel(self):
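        """With has_baseline=True and a single unnamed model_spec, "candidate" and "baseline" specs are generated."""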
        eval_config_pbtxt = """
      model_specs { label_key: "my_label" }
      metrics_specs {
        metrics { class_name: "ExampleCount" }
      }
    """
        eval_config = text_format.Parse(eval_config_pbtxt,
                                        config_pb2.EvalConfig())

        expected_eval_config_pbtxt = """
      model_specs { name: "candidate" label_key: "my_label" }
      model_specs { name: "baseline" label_key: "my_label" is_baseline: true }
      metrics_specs {
        metrics { class_name: "ExampleCount" }
        model_names: ["candidate", "baseline"]
      }
    """
        expected_eval_config = text_format.Parse(expected_eval_config_pbtxt,
                                                 config_pb2.EvalConfig())

        got_eval_config = config_util.update_eval_config_with_defaults(
            eval_config, has_baseline=True)
        self.assertProtoEquals(got_eval_config, expected_eval_config)
Code example #7
    def testUpdateConfigWithDefaultsRemoveBaselineModel(self):
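        """Without a baseline (rubber stamping), the baseline model_spec and all change thresholds are dropped."""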
        eval_config_pbtxt = """
      model_specs { name: "candidate" }
      model_specs { name: "baseline" is_baseline: true }
      metrics_specs {
        metrics {
          class_name: "MeanLabel"
          threshold {
            value_threshold {
              lower_bound { value: 0.9 }
            }
            change_threshold {
              direction: HIGHER_IS_BETTER
              absolute { value: -1e-10 }
            }
          }
          per_slice_thresholds {
            threshold {
              value_threshold {
                lower_bound { value: 0.9 }
              }
              change_threshold {
                direction: HIGHER_IS_BETTER
                absolute { value: -1e-10 }
              }
            }
          }
          cross_slice_thresholds {
            threshold {
              value_threshold {
                lower_bound { value: 0.9 }
              }
              change_threshold {
                direction: HIGHER_IS_BETTER
                absolute { value: -1e-10 }
              }
            }
          }
        }
        thresholds {
          key: "my_metric"
          value {
            value_threshold {
              lower_bound { value: 0.9 }
            }
            change_threshold {
              direction: HIGHER_IS_BETTER
              absolute { value: -1e-10 }
            }
          }
        }
        per_slice_thresholds {
          key: "my_metric"
          value {
            thresholds {
              threshold {
                value_threshold {
                  lower_bound { value: 0.9 }
                }
                change_threshold {
                  direction: HIGHER_IS_BETTER
                  absolute { value: -1e-10 }
                }
              }
            }
          }
        }
        cross_slice_thresholds {
          key: "my_metric"
          value {
            thresholds {
              threshold {
                value_threshold {
                  lower_bound { value: 0.9 }
                }
                change_threshold {
                  direction: HIGHER_IS_BETTER
                  absolute { value: -1e-10 }
                }
              }
            }
          }
        }
      }
    """
        eval_config = text_format.Parse(eval_config_pbtxt,
                                        config_pb2.EvalConfig())

        expected_eval_config_pbtxt = """
      model_specs {}
      metrics_specs {
        metrics {
          class_name: "MeanLabel"
          threshold {
            value_threshold {
              lower_bound { value: 0.9 }
            }
          }
          per_slice_thresholds {
            threshold {
              value_threshold {
                lower_bound { value: 0.9 }
              }
            }
          }
          cross_slice_thresholds {
            threshold {
              value_threshold {
                lower_bound { value: 0.9 }
              }
            }
          }
        }
        thresholds {
          key: "my_metric"
          value {
            value_threshold {
              lower_bound { value: 0.9 }
            }
          }
        }
        per_slice_thresholds {
          key: "my_metric"
          value {
            thresholds {
              threshold {
                value_threshold {
                  lower_bound { value: 0.9 }
                }
              }
            }
          }
        }
        cross_slice_thresholds {
          key: "my_metric"
          value {
            thresholds {
              threshold {
                value_threshold {
                  lower_bound { value: 0.9 }
                }
              }
            }
          }
        }
        model_names: [""]
      }
    """
        expected_eval_config = text_format.Parse(expected_eval_config_pbtxt,
                                                 config_pb2.EvalConfig())

        # Only valid when rubber stamping.
        got_eval_config = config_util.update_eval_config_with_defaults(
            eval_config, has_baseline=False, rubber_stamp=True)
        self.assertProtoEquals(got_eval_config, expected_eval_config)
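These snippets appear to be test methods from TensorFlow Model Analysis' config_util tests. To run them outside their original file, a scaffold roughly like the sketch below is needed; the import paths follow the public tensorflow_model_analysis layout and the class name is only an assumption for illustration.

# Minimal scaffold (a sketch; import paths and the class name are assumptions
# based on the public tensorflow_model_analysis layout).
import tensorflow as tf
from google.protobuf import text_format

from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import config_util


class ConfigUtilTest(tf.test.TestCase):
    # Paste the test methods shown above into this class body.
    # tf.test.TestCase provides the assertProtoEquals helper they rely on.
    pass


if __name__ == '__main__':
    tf.test.main()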