Пример #1
0
def load_eval_result(output_path):
    """Builds an EvalResult for the visualization functions from saved output.

    Args:
      output_path: Directory that contains the metrics and plots output
        files; it is also handed to load_eval_config to obtain the config.

    Returns:
      An EvalResult whose slicing_metrics and plots are lists of
      (key, converted value) pairs and whose config is the loaded eval
      config.
    """
    metrics_file = os.path.join(output_path, _METRICS_OUTPUT_FILE)
    loaded_metrics = metrics_and_plots_evaluator.load_and_deserialize_metrics(
        path=metrics_file)
    plots_file = os.path.join(output_path, _PLOTS_OUTPUT_FILE)
    loaded_plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
        path=plots_file)

    # Metrics go through this module's own converter, while plots use the
    # generic protobuf-to-dict conversion.
    slicing_metrics = [
        (slice_key, _convert_metric_map_to_dict(metric_map))
        for slice_key, metric_map in loaded_metrics
    ]
    plots = [
        (slice_key, json_format.MessageToDict(plot_proto))
        for slice_key, plot_proto in loaded_plots
    ]

    config = load_eval_config(output_path)
    return EvalResult(
        slicing_metrics=slicing_metrics, plots=plots, config=config)
    def testSerializeDeserializeToFile(self):
        """Checks metrics, plots and the eval config survive a file round trip.

        One MetricsForSlice proto, one PlotsForSlice proto and an EvalConfig
        are written under a temp directory by a Beam pipeline, then loaded
        back and compared against what was written.
        """
        # Same column/value pairs as the slice_key of the metrics proto below.
        expected_metrics_key = _make_slice_key(b'fruit', b'pear',
                                               b'animal', b'duck')
        metrics_textproto = """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "pear"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "duck"
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "example_weight"
          value {
            double_value {
              value: 10.0
            }
          }
        }
        metrics {
          key: "auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
            }
          }
        }
        metrics {
          key: "auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
            }
          }
        }"""
        metrics_proto = text_format.Parse(
            metrics_textproto, metrics_for_slice_pb2.MetricsForSlice())

        plots_textproto = """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "peach"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "cow"
          }
        }
        plots {
          key: ''
          value {
            calibration_histogram_buckets {
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.5
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 1.0 }
                total_weighted_refined_prediction { value: 0.3 }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 1.0
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.7 }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
            }
          }
        }"""
        plots_proto = text_format.Parse(
            plots_textproto, metrics_for_slice_pb2.PlotsForSlice())
        # Same column/value pairs as the slice_key of the plots proto above.
        expected_plots_key = _make_slice_key(b'fruit', b'peach',
                                             b'animal', b'cow')

        slice_specs = [
            slicer.SingleSliceSpec(
                features=[('age', 5), ('gender', 'f')], columns=['country']),
            slicer.SingleSliceSpec(
                features=[('age', 6), ('gender', 'm')], columns=['interest'])
        ]
        expected_config = model_eval_lib.EvalConfig(
            model_location='/path/to/model',
            data_location='/path/to/data',
            slice_spec=slice_specs,
            example_weight_metric_key='key')

        output_dir = self._getTempDir()
        with beam.Pipeline() as pipeline:
            # Each source PCollection holds exactly one serialized proto.
            metrics_pcoll = pipeline | 'CreateMetrics' >> beam.Create(
                [metrics_proto.SerializeToString()])
            plots_pcoll = pipeline | 'CreatePlots' >> beam.Create(
                [plots_proto.SerializeToString()])
            results = {
                constants.METRICS_KEY: metrics_pcoll,
                constants.PLOTS_KEY: plots_pcoll
            }
            writers = model_eval_lib.default_writers(output_path=output_dir)
            _ = results | 'WriteResults' >> model_eval_lib.WriteResults(
                writers=writers)
            _ = pipeline | model_eval_lib.WriteEvalConfig(
                expected_config, output_dir)

        # The outputs are read back only after the with-block has exited.
        metrics_file = os.path.join(output_dir,
                                    model_eval_lib._METRICS_OUTPUT_FILE)
        loaded_metrics = (
            metrics_and_plots_evaluator.load_and_deserialize_metrics(
                path=metrics_file))
        plots_file = os.path.join(output_dir,
                                  model_eval_lib._PLOTS_OUTPUT_FILE)
        loaded_plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
            path=plots_file)

        self.assertSliceMetricsListEqual(
            [(expected_metrics_key, metrics_proto.metrics)], loaded_metrics)
        self.assertSlicePlotsListEqual(
            [(expected_plots_key, plots_proto.plots)], loaded_plots)
        loaded_config = model_eval_lib.load_eval_config(output_dir)
        self.assertEqual(expected_config, loaded_config)