def _serialize_plots(
        plots: Tuple[slicer.SliceKeyType, Dict[Text, Any]],
        post_export_metrics: List[types.AddMetricsCallbackType]) -> bytes:
    """Converts the given slice plots into a serialized PlotsForSlice proto.

    If the slice plots contain an entry keyed by metric_keys.ERROR_METRIC, the
    error is logged and the returned proto contains only the slice key and the
    error message; no other plots are converted in that case.

    Args:
      plots: A (slice_key, slice_plots) tuple for a single slice.
      post_export_metrics: A list of metric callbacks. This should be the same
        list as the one passed to tfma.Evaluate().

    Returns:
      The serialized PlotsForSlice proto.
    """
    slice_key, slice_plots = plots

    result = metrics_for_slice_pb2.PlotsForSlice()
    # The slice key is needed on both the error path and the normal path, so
    # populate it once up front instead of building a second proto.
    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    if metric_keys.ERROR_METRIC in slice_plots:
        error_message = slice_plots[metric_keys.ERROR_METRIC]
        tf.compat.v1.logging.warning(
            'Error for slice: %s with error message: %s ', slice_key,
            error_message)
        # Emit only the error entry; the remaining plots are not converted.
        result.plots[metric_keys.ERROR_METRIC].debug_message = error_message
        return result.SerializeToString()

    # Convert the slice plots into the proto's plots map.
    _convert_slice_plots(slice_plots, post_export_metrics, result.plots)  # pytype: disable=wrong-arg-types

    return result.SerializeToString()
Example #2
0
  def testConvertSlicePlotsToProtoEmptyPlot(self):
    """A slice holding only ERROR_METRIC converts to just the error entry."""
    apple_key = _make_slice_key('fruit', 'apple')
    error_only_plots = {metric_keys.ERROR_METRIC: 'error_message'}

    expected = metrics_for_slice_pb2.PlotsForSlice()
    expected.slice_key.CopyFrom(slicer.serialize_slice_key(apple_key))
    expected.plots[metric_keys.ERROR_METRIC].debug_message = 'error_message'

    converted = (
        metrics_plots_and_validations_writer.convert_slice_plots_to_proto(
            (apple_key, error_only_plots), []))
    self.assertProtoEquals(expected, converted)
Example #3
0
  def testSerializePlots_emptyPlot(self):
    """Serializing an ERROR_METRIC-only slice yields only the error entry."""
    apple_key = _make_slice_key('fruit', 'apple')
    error_only_plots = {metric_keys.ERROR_METRIC: 'error_message'}

    serialized = metrics_and_plots_serialization._serialize_plots(
        (apple_key, error_only_plots), [])
    round_tripped = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)

    expected = metrics_for_slice_pb2.PlotsForSlice()
    expected.slice_key.CopyFrom(slicer.serialize_slice_key(apple_key))
    expected.plots[metric_keys.ERROR_METRIC].debug_message = 'error_message'
    self.assertProtoEquals(expected, round_tripped)
Example #4
0
def _serialize_plots(plots, post_export_metrics):
    """Serializes one slice's plots as a PlotsForSlice proto.

    Args:
      plots: A (slice_key, slice_plots) pair for a single slice.
      post_export_metrics: A list of metric callbacks. This should be the same
        list as the one passed to tfma.Evaluate().

    Returns:
      The serialized PlotsForSlice proto.
    """
    key_of_slice, plots_of_slice = plots
    plots_proto = metrics_for_slice_pb2.PlotsForSlice()
    plots_proto.slice_key.CopyFrom(_convert_slice_key(key_of_slice))
    # The proto's plot_data field is handed to the converter, which is
    # expected to fill it in place before serialization.
    _convert_slice_plots(plots_of_slice, post_export_metrics, plots_proto.plot_data)  # pytype: disable=wrong-arg-types
    return plots_proto.SerializeToString()
Example #5
0
def _serialize_plots(
    plots: Tuple[slicer.SliceKeyType, Dict[Text, Any]],
    post_export_metrics: List[types.AddMetricsCallbackType]) -> bytes:
  """Serializes one slice's plots as a PlotsForSlice proto.

  Args:
    plots: A (slice_key, slice_plots) pair for a single slice.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized PlotsForSlice proto.
  """
  key_of_slice, plots_of_slice = plots
  plots_proto = metrics_for_slice_pb2.PlotsForSlice()
  plots_proto.slice_key.CopyFrom(slicer.serialize_slice_key(key_of_slice))
  # The proto's plots map is handed to the converter, which is expected to
  # populate it in place before serialization.
  _convert_slice_plots(plots_of_slice, post_export_metrics, plots_proto.plots)  # pytype: disable=wrong-arg-types
  return plots_proto.SerializeToString()
def convert_slice_plots_to_proto(
    plots: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.PlotsForSlice:
    """Builds the PlotsForSlice proto for a single slice.

    Args:
      plots: A (slice_key, slice_plots) pair. slice_plots itself is not
        modified; a shallow copy is consumed instead.
      add_metrics_callbacks: A list of metric callbacks. This should be the
        same list as the one passed to tfma.Evaluate(). None is tolerated.

    Returns:
      The PlotsForSlice proto. When slice_plots contains
      metric_keys.ERROR_METRIC, the proto carries only the slice key and the
      error message.

    Raises:
      NotImplementedError: If any plot entry is left over after conversion,
        i.e. it was neither converted here nor popped by a callback.
    """
    key_of_slice, incoming_plots = plots

    output = metrics_for_slice_pb2.PlotsForSlice()
    output.slice_key.CopyFrom(slicer.serialize_slice_key(key_of_slice))

    # Entries are popped from this copy as they are consumed; anything still
    # present at the end is reported as unconverted.
    pending = incoming_plots.copy()

    if metric_keys.ERROR_METRIC in pending:
        logging.warning('Error for slice: %s with error message: %s ',
                        key_of_slice, pending[metric_keys.ERROR_METRIC])
        message = pending.pop(metric_keys.ERROR_METRIC)
        output.plots[metric_keys.ERROR_METRIC].debug_message = message
        return output

    # Legacy callbacks are consulted only when no key is a MetricKey. Each
    # callback exposing populate_plots_and_pop is given the pending dict and
    # the output plots map.
    if add_metrics_callbacks and not any(
            isinstance(k, metric_types.MetricKey) for k in pending):
        for callback in add_metrics_callbacks:
            if hasattr(callback, 'populate_plots_and_pop'):
                callback.populate_plots_and_pop(pending, output.plots)

    # Supported plot value types, paired with the PlotKeyAndValue value field
    # that stores them. Checked in this order with isinstance.
    field_for_type = (
        (metrics_for_slice_pb2.CalibrationHistogramBuckets,
         'calibration_histogram_buckets'),
        (metrics_for_slice_pb2.ConfusionMatrixAtThresholds,
         'confusion_matrix_at_thresholds'),
        (metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds,
         'multi_class_confusion_matrix_at_thresholds'),
        (metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds,
         'multi_label_confusion_matrix_at_thresholds'),
    )

    value_for_parent = {}
    for key in sorted(pending.keys()):
        if not isinstance(key, metric_types.MetricKey):
            # Left in `pending`; reported by the check below.
            continue
        plot_value = pending[key]
        # Plots that differ only by name are combined into one proto entry,
        # so the name is stripped from the grouping key.
        parent_key = key._replace(name=None)
        if parent_key not in value_for_parent:
            entry = output.plot_keys_and_values.add()
            entry.key.CopyFrom(parent_key.to_proto())
            value_for_parent[parent_key] = entry.value

        for proto_type, field_name in field_for_type:
            if isinstance(plot_value, proto_type):
                target = getattr(value_for_parent[parent_key], field_name)
                target.CopyFrom(plot_value)
                pending.pop(key)
                break

    if pending:
        callback_names = [
            cb.name for cb in (add_metrics_callbacks or [])  # pytype: disable=attribute-error
        ]
        raise NotImplementedError(
            'some plots were not converted or popped. keys: %s. '
            'add_metrics_callbacks were: %s' % (pending.keys(),
                                                callback_names))

    return output
  def testWriteMetricsAndPlots(self):
    """Runs an end-to-end eval pipeline and checks the written TFRecords.

    Two examples are evaluated with a fixed-prediction estimator, metrics and
    plots are written by MetricsPlotsAndValidationsWriter, and each output
    file is read back and compared against a single golden record.
    """
    metrics_file = os.path.join(self._getTempDir(), 'metrics')
    plots_file = os.path.join(self._getTempDir(), 'plots')
    temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')

    # Export a fixed-prediction eval model; only its export dir is used below.
    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, temp_eval_export_dir))
    # 'eval_config.json' is listed in disabled_outputs; this test only
    # inspects the metrics and plots files.
    eval_config = config.EvalConfig(
        model_specs=[config.ModelSpec()],
        options=config.Options(
            disabled_outputs={'values': ['eval_config.json']}))
    # Two legacy post-export callbacks: an example count metric and a
    # calibration plot / prediction histogram with num_buckets=2.
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=eval_export_dir,
        add_metrics_callbacks=[
            post_export_metrics.example_count(),
            post_export_metrics.calibration_plot_and_prediction_histogram(
                num_buckets=2)
        ])
    extractors = [
        predict_extractor.PredictExtractor(eval_shared_model),
        slice_key_extractor.SliceKeyExtractor()
    ]
    evaluators = [
        metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
    ]
    output_paths = {
        constants.METRICS_KEY: metrics_file,
        constants.PLOTS_KEY: plots_file
    }
    # The writer is given the same callbacks registered on the shared model.
    writers = [
        metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
            output_paths, eval_shared_model.add_metrics_callbacks)
    ]

    with beam.Pipeline() as pipeline:
      # Two examples, both labeled 1.0, with predictions 0.0 and 1.0.
      example1 = self._makeExample(prediction=0.0, label=1.0)
      example2 = self._makeExample(prediction=1.0, label=1.0)

      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
          ])
          | 'ExtractEvaluateAndWriteResults' >>
          model_eval_lib.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              extractors=extractors,
              evaluators=evaluators,
              writers=writers))
      # pylint: enable=no-value-for-parameter

    # Golden metrics for the overall slice (empty slice_key).
    expected_metrics_for_slice = text_format.Parse(
        """
        slice_key {}
        metrics {
          key: "average_loss"
          value {
            double_value {
              value: 0.5
            }
          }
        }
        metrics {
          key: "post_export_metrics/example_count"
          value {
            double_value {
              value: 2.0
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

    # Read back every record in the metrics TFRecord file.
    metric_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
      metric_records.append(
          metrics_for_slice_pb2.MetricsForSlice.FromString(record))
    self.assertEqual(1, len(metric_records), 'metrics: %s' % metric_records)
    self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

    # Golden calibration histogram for the overall slice, stored under the
    # "post_export_metrics" plots key.
    expected_plots_for_slice = text_format.Parse(
        """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {
              }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {
                value: 1.0
              }
            }
         }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

    # Read back every record in the plots TFRecord file.
    plot_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
      plot_records.append(
          metrics_for_slice_pb2.PlotsForSlice.FromString(record))
    self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
    self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
    def testSerializeDeserializeToFile(self):
        """Round-trips metrics, plots and eval config through output files.

        Already-serialized MetricsForSlice / PlotsForSlice protos are written
        with WriteResults (default writers) plus WriteEvalConfig, then loaded
        back and compared with the inputs.
        """
        # Metrics use the slice fruit=pear, animal=duck.
        metrics_slice_key = _make_slice_key(b'fruit', b'pear', b'animal',
                                            b'duck')
        metrics_for_slice = text_format.Parse(
            """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "pear"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "duck"
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "example_weight"
          value {
            double_value {
              value: 10.0
            }
          }
        }
        metrics {
          key: "auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
            }
          }
        }
        metrics {
          key: "auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
            }
          }
        }""", metrics_for_slice_pb2.MetricsForSlice())
        # Plots use a different slice (fruit=peach, animal=cow) than metrics.
        plots_for_slice = text_format.Parse(
            """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "peach"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "cow"
          }
        }
        plots {
          key: ''
          value {
            calibration_histogram_buckets {
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.5
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 1.0 }
                total_weighted_refined_prediction { value: 0.3 }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 1.0
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.7 }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
            }
          }
        }""", metrics_for_slice_pb2.PlotsForSlice())
        plots_slice_key = _make_slice_key(b'fruit', b'peach', b'animal',
                                          b'cow')
        # Eval config to round-trip; compared for equality after loading.
        eval_config = model_eval_lib.EvalConfig(
            model_location='/path/to/model',
            data_location='/path/to/data',
            slice_spec=[
                slicer.SingleSliceSpec(features=[('age', 5), ('gender', 'f')],
                                       columns=['country']),
                slicer.SingleSliceSpec(features=[('age', 6), ('gender', 'm')],
                                       columns=['interest'])
            ],
            example_weight_metric_key='key')

        output_path = self._getTempDir()
        # Feed the serialized protos straight into the writers; no model
        # evaluation happens in this pipeline.
        with beam.Pipeline() as pipeline:
            metrics = (pipeline
                       | 'CreateMetrics' >> beam.Create(
                           [metrics_for_slice.SerializeToString()]))
            plots = (pipeline
                     | 'CreatePlots' >> beam.Create(
                         [plots_for_slice.SerializeToString()]))
            evaluation = {
                constants.METRICS_KEY: metrics,
                constants.PLOTS_KEY: plots
            }
            _ = (evaluation
                 | 'WriteResults' >> model_eval_lib.WriteResults(
                     writers=model_eval_lib.default_writers(
                         output_path=output_path)))
            _ = pipeline | model_eval_lib.WriteEvalConfig(
                eval_config, output_path)

        # Load the written files back and compare against the inputs.
        metrics = metrics_and_plots_evaluator.load_and_deserialize_metrics(
            path=os.path.join(output_path,
                              model_eval_lib._METRICS_OUTPUT_FILE))
        plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
            path=os.path.join(output_path, model_eval_lib._PLOTS_OUTPUT_FILE))
        self.assertSliceMetricsListEqual(
            [(metrics_slice_key, metrics_for_slice.metrics)], metrics)
        self.assertSlicePlotsListEqual(
            [(plots_slice_key, plots_for_slice.plots)], plots)
        got_eval_config = model_eval_lib.load_eval_config(output_path)
        self.assertEqual(eval_config, got_eval_config)
Example #9
0
    def testSerializePlots(self):
        """A calibration plot keyed by a PlotKey is grouped by its parent key.

        The expected serialized entry's key carries only the output name; the
        plot name ('calibration_plot') does not appear in it.
        """
        apple_key = _make_slice_key('fruit', 'apple')
        calibration_key = metric_types.PlotKey(name='calibration_plot',
                                               output_name='output_name')
        buckets = text_format.Parse(
            """
        buckets {
          lower_threshold_inclusive: -inf
          upper_threshold_exclusive: 0.0
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.0
          upper_threshold_exclusive: 0.5
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 1.0 }
          total_weighted_refined_prediction { value: 0.3 }
        }
        buckets {
          lower_threshold_inclusive: 0.5
          upper_threshold_exclusive: 1.0
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.7 }
        }
        buckets {
          lower_threshold_inclusive: 1.0
          upper_threshold_exclusive: inf
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
     """, metrics_for_slice_pb2.CalibrationHistogramBuckets())

        # No post-export callbacks are involved, so None is passed.
        serialized = metrics_and_plots_serialization._serialize_plots(
            (apple_key, {calibration_key: buckets}), None)

        golden = text_format.Parse(
            """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

        round_tripped = metrics_for_slice_pb2.PlotsForSlice.FromString(
            serialized)
        self.assertProtoEquals(golden, round_tripped)
Example #10
0
    def testSerializeDeserializeToFile(self):
        """Round-trips metrics, plots and eval config through output files.

        Already-serialized MetricsForSlice / PlotsForSlice protos (the plots
        using the plot_data field) are written with
        serialization.WriteMetricsPlotsAndConfig, then loaded back and
        compared with the inputs.
        """
        # Metrics use the slice fruit=pear, animal=duck.
        metrics_slice_key = _make_slice_key('fruit', 'pear', 'animal', 'duck')
        metrics_for_slice = text_format.Parse(
            """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "pear"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "duck"
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "example_weight"
          value {
            double_value {
              value: 10.0
            }
          }
        }
        metrics {
          key: "auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
            }
          }
        }
        metrics {
          key: "auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
            }
          }
        }""", metrics_for_slice_pb2.MetricsForSlice())
        # Plots use a different slice (fruit=peach, animal=cow) than metrics.
        plots_for_slice = text_format.Parse(
            """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "peach"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "cow"
          }
        }
        plot_data {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }""", metrics_for_slice_pb2.PlotsForSlice())
        plots_slice_key = _make_slice_key('fruit', 'peach', 'animal', 'cow')
        # Eval config to round-trip; compared for equality after loading.
        eval_config = api_types.EvalConfig(
            model_location='/path/to/model',
            data_location='/path/to/data',
            slice_spec=[
                slicer.SingleSliceSpec(features=[('age', 5), ('gender', 'f')],
                                       columns=['country']),
                slicer.SingleSliceSpec(features=[('age', 6), ('gender', 'm')],
                                       columns=['interest'])
            ],
            example_weight_metric_key='key')

        output_path = self._getTempDir()
        # Feed the serialized protos straight into the writer transform; no
        # model evaluation happens in this pipeline.
        with beam.Pipeline() as pipeline:
            metrics = (pipeline
                       | 'CreateMetrics' >> beam.Create(
                           [metrics_for_slice.SerializeToString()]))
            plots = (pipeline
                     | 'CreatePlots' >> beam.Create(
                         [plots_for_slice.SerializeToString()]))

            _ = ((metrics, plots)
                 | 'WriteMetricsPlotsAndConfig' >>
                 serialization.WriteMetricsPlotsAndConfig(
                     output_path=output_path, eval_config=eval_config))

        # Load the written files back and compare against the inputs.
        metrics, plots = serialization.load_plots_and_metrics(output_path)
        self.assertSliceMetricsListEqual(
            [(metrics_slice_key, metrics_for_slice.metrics)], metrics)
        self.assertSlicePlotsListEqual(
            [(plots_slice_key, plots_for_slice.plot_data)], plots)
        got_eval_config = serialization.load_eval_config(output_path)
        self.assertEqual(eval_config, got_eval_config)