Example #1
def default_writers(
    output_path: Optional[Text],
    eval_shared_model: Optional[Union[types.EvalSharedModel,
                                      Dict[Text,
                                           types.EvalSharedModel]]] = None
) -> List[writer.Writer]:  # pylint: disable=invalid-name
    """Returns the default writers for use in WriteResults.

  Args:
    output_path: Output path.
    eval_shared_model: Optional shared model (single-model evaluation) or dict
      of shared models keyed by model name (multi-model evaluation). Only
      required if legacy add_metrics_callbacks are used.
  """
    add_metric_callbacks = []
    if eval_shared_model:
        eval_shared_models = eval_shared_model
        if not isinstance(eval_shared_model, dict):
            eval_shared_models = {'': eval_shared_model}
        for v in eval_shared_models.values():
            add_metric_callbacks.extend(v.add_metrics_callbacks)

    output_paths = {
        constants.METRICS_KEY: os.path.join(output_path,
                                            constants.METRICS_KEY),
        constants.PLOTS_KEY: os.path.join(output_path, constants.PLOTS_KEY)
    }
    return [
        metrics_and_plots_writer.MetricsAndPlotsWriter(
            output_paths=output_paths,
            add_metrics_callbacks=add_metric_callbacks)
    ]
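
A minimal usage sketch for the variant above (eval_export_dir and output_path are placeholder names, and default_eval_shared_model is assumed to be available from model_eval_lib); the resulting writers are intended to be passed via the writers= argument, as the test at the end of this section does:

# Sketch only: build the default metrics/plots writers for a single-model run.
eval_shared_model = model_eval_lib.default_eval_shared_model(
    eval_saved_model_path=eval_export_dir,  # placeholder path
    add_metrics_callbacks=[post_export_metrics.example_count()])
writers = default_writers(
    output_path=output_path, eval_shared_model=eval_shared_model)
# 'writers' now holds a single MetricsAndPlotsWriter whose output_paths point
# at os.path.join(output_path, constants.METRICS_KEY) and
# os.path.join(output_path, constants.PLOTS_KEY).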
def default_writers(
    eval_shared_model: Optional[types.EvalSharedModel] = None,
    eval_shared_models: Optional[List[types.EvalSharedModel]] = None,
    output_path: Optional[Text] = None,
    eval_config: Optional[config.EvalConfig] = None,
) -> List[writer.Writer]:  # pylint: disable=invalid-name
    """Returns the default writers for use in WriteResults.

  Args:
    eval_shared_model: Shared model (single-model evaluation).
    eval_shared_models: Shared models (multi-model evaluation).
    output_path: Deprecated (use EvalConfig).
    eval_config: Eval config.
  """
    # TODO(b/141016373): Add support for multiple models.
    if eval_config is not None:
        output_spec = eval_config.output_data_specs[0]
    elif output_path is not None:
        output_spec = config.OutputDataSpec(default_location=output_path)
    else:
        # Guard against output_spec being unbound below.
        raise ValueError('either eval_config or output_path must be provided')
    if eval_shared_model is not None:
        eval_shared_models = [eval_shared_model]
    output_paths = {
        constants.METRICS_KEY: output_filename(output_spec,
                                               constants.METRICS_KEY),
        constants.PLOTS_KEY: output_filename(output_spec, constants.PLOTS_KEY)
    }
    return [
        metrics_and_plots_writer.MetricsAndPlotsWriter(
            eval_shared_model=eval_shared_models[0], output_paths=output_paths)
    ]
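
For comparison, a sketch of the same call against the EvalConfig-driven variant above, routing the output location through config.OutputDataSpec exactly as the function body does (output_path is again a placeholder, and eval_shared_model is reused from the previous sketch):

# Sketch only: supply the output location via an EvalConfig instead of a path.
eval_config = config.EvalConfig(
    output_data_specs=[config.OutputDataSpec(default_location=output_path)])
writers = default_writers(
    eval_shared_model=eval_shared_model, eval_config=eval_config)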
def default_writers(eval_shared_model: types.EvalSharedModel,
                    output_path: Text) -> List[writer.Writer]:  # pylint: disable=invalid-name
  """Returns the default writers for use in WriteResults.

  Args:
    eval_shared_model: Shared model parameters for EvalSavedModel.
    output_path: Path to store results files under.
  """
  output_paths = {
      constants.METRICS_KEY: os.path.join(output_path, _METRICS_OUTPUT_FILE),
      constants.PLOTS_KEY: os.path.join(output_path, _PLOTS_OUTPUT_FILE)
  }
  return [
      metrics_and_plots_writer.MetricsAndPlotsWriter(
          eval_shared_model=eval_shared_model, output_paths=output_paths)
  ]
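
In this oldest variant both arguments are required and the writer receives the shared model itself rather than its add_metrics_callbacks, so the call reduces to a one-liner (sketch, same placeholder names as above):

# Sketch only: both parameters are mandatory here.
writers = default_writers(
    eval_shared_model=eval_shared_model, output_path=output_path)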
    def testWriteMetricsAndPlots(self):
        metrics_file = os.path.join(self._getTempDir(), 'metrics')
        plots_file = os.path.join(self._getTempDir(), 'plots')
        temp_eval_export_dir = os.path.join(self._getTempDir(),
                                            'eval_export_dir')

        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        eval_config = config.EvalConfig(
            model_specs=[config.ModelSpec()],
            options=config.Options(disabled_outputs=['eval_config.json']))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                post_export_metrics.example_count(),
                post_export_metrics.calibration_plot_and_prediction_histogram(
                    num_buckets=2)
            ])
        extractors = [
            predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]
        evaluators = [
            metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(
                eval_shared_model)
        ]
        output_paths = {
            constants.METRICS_KEY: metrics_file,
            constants.PLOTS_KEY: plots_file
        }
        writers = [
            metrics_and_plots_writer.MetricsAndPlotsWriter(
                output_paths, eval_shared_model.add_metrics_callbacks)
        ]

        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(prediction=0.0, label=1.0)
            example2 = self._makeExample(prediction=1.0, label=1.0)

            # pylint: disable=no-value-for-parameter
            _ = (pipeline
                 | 'Create' >> beam.Create([
                     example1.SerializeToString(),
                     example2.SerializeToString(),
                 ])
                 | 'ExtractEvaluateAndWriteResults' >>
                 model_eval_lib.ExtractEvaluateAndWriteResults(
                     eval_config=eval_config,
                     eval_shared_model=eval_shared_model,
                     extractors=extractors,
                     evaluators=evaluators,
                     writers=writers))
            # pylint: enable=no-value-for-parameter

        expected_metrics_for_slice = text_format.Parse(
            """
        slice_key {}
        metrics {
          key: "average_loss"
          value {
            double_value {
              value: 0.5
            }
          }
        }
        metrics {
          key: "post_export_metrics/example_count"
          value {
            double_value {
              value: 2.0
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

        metric_records = []
        for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
            metric_records.append(
                metrics_for_slice_pb2.MetricsForSlice.FromString(record))
        self.assertEqual(1, len(metric_records),
                         'metrics: %s' % metric_records)
        self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

        expected_plots_for_slice = text_format.Parse(
            """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {
              }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {
                value: 1.0
              }
            }
          }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

        plot_records = []
        for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
            plot_records.append(
                metrics_for_slice_pb2.PlotsForSlice.FromString(record))
        self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
        self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
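
Outside of tests, when the writers come from default_writers so that metrics and plots land under their standard names beneath output_path, the written results can usually be read back in one call; a hedged sketch, assuming the load_eval_result helper from model_eval_lib:

# Sketch only: load results written under the default output layout.
eval_result = model_eval_lib.load_eval_result(output_path)
print(eval_result.slicing_metrics)  # per-slice metrics as loaded from disk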