def load_and_deserialize_plots(path):
    """Returns deserialized plots loaded from given path.

    Reads every serialized `PlotsForSlice` record in the TFRecord file at
    `path`. Each record contributes up to two entries, both keyed by the
    record's deserialized slice key:
      * `(slice_key, plot_data)` when the `plot_data` field is set, and
      * `(slice_key, plots)` when the `plots` field is non-empty.
    A record with neither field populated contributes nothing.

    Args:
      path: Path to a TFRecord file of serialized `PlotsForSlice` protos.

    Returns:
      A list of `(slice_key, plots)` tuples in file order.
    """
    result = []
    # `tf.python_io` is not exposed in TF 2.x; the `tf.compat.v1` alias is the
    # spelling the other loaders use and works across TF versions.
    for record in tf.compat.v1.python_io.tf_record_iterator(path):
        plots_for_slice = metrics_for_slice_pb2.PlotsForSlice.FromString(
            record)
        has_plot_data = plots_for_slice.HasField('plot_data')
        has_plots = bool(plots_for_slice.plots)
        if not (has_plot_data or has_plots):
            continue
        # Deserialize once; the same key is shared by both entries below.
        slice_key = slicer.deserialize_slice_key(plots_for_slice.slice_key)  # pytype: disable=wrong-arg-types
        if has_plot_data:
            result.append((slice_key, plots_for_slice.plot_data))
        if has_plots:
            result.append((slice_key, plots_for_slice.plots))
    return result
def load_and_deserialize_plots(
        path: Text) -> List[Tuple[slicer.SliceKeyType, Any]]:
    """Loads per-slice plots from a TFRecord file of `PlotsForSlice` protos.

    Each record becomes one `(slice_key, plots_map)` entry. `plots_map` is a
    nested dict of the form `{output_name: {sub_key_id: plot}}`:

      * Legacy plots go under output `''` and sub key `''`. They come from the
        `plots` map, unwrapped to its only value when it holds exactly one
        label. If `plots` is empty, they come from `plot_data`, converted with
        `json_format.MessageToDict`.
      * Each `plot_keys_and_values` entry goes under its key's `output_name`
        and the id of its `sub_key` (`''` when no sub key is set).

    Args:
      path: Location of the TFRecord file to read.

    Returns:
      A list of `(slice_key, plots_map)` tuples in file order.
    """
    deserialized = []
    for serialized in tf.compat.v1.python_io.tf_record_iterator(path):
        proto = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
        by_output = {}

        if proto.plots:
            label_to_plot = _convert_proto_map_to_dict(proto.plots)
            if len(label_to_plot) == 1:
                # A lone label is unwrapped so callers get the plot directly.
                (only_label,) = label_to_plot
                legacy_plot = label_to_plot[only_label]
            else:
                legacy_plot = label_to_plot
            by_output[''] = {'': legacy_plot}
        elif proto.HasField('plot_data'):
            by_output[''] = {'': json_format.MessageToDict(proto.plot_data)}

        for entry in proto.plot_keys_and_values:
            per_sub_key = by_output.setdefault(entry.key.output_name, {})
            if entry.key.HasField('sub_key'):
                sub_key_id = _get_sub_key_id(entry.key.sub_key)
            else:
                sub_key_id = ''
            per_sub_key[sub_key_id] = json_format.MessageToDict(entry.value)

        slice_key = slicer.deserialize_slice_key(proto.slice_key)  # pytype: disable=wrong-arg-types
        deserialized.append((slice_key, by_output))

    return deserialized
Beispiel #3
0
def load_and_deserialize_metrics(
    path: Text,
    model_name: Optional[Text] = None
) -> List[Tuple[slicer.SliceKeyType, Any]]:
    """Loads metrics from the given location and builds a metric map for it.

    For every `MetricsForSlice` record, a nested map is built with the shape
    `{model_name: {output_name: {multi_class_key_id: {metric_name: value}}}}`.
    Legacy `metrics` go under the `''` model, output and class key. Each
    `metric_keys[i]` / `metric_values[i]` pair is placed according to its key.
    One model's sub-map is then selected for the record.

    Args:
      path: Location of the TFRecord file of `MetricsForSlice` protos.
      model_name: Model whose metrics to return. When None, the record must
        contain exactly one model, and that model is used.

    Returns:
      A list of `(slice_key, metrics_map)` tuples in file order, where
      `metrics_map` is `{output_name: {multi_class_key_id: {name: value}}}`.

    Raises:
      ValueError: If no model matching `model_name` can be selected for a
        record.
    """
    loaded = []
    for serialized in tf.compat.v1.python_io.tf_record_iterator(path):
        proto = metrics_for_slice_pb2.MetricsForSlice.FromString(serialized)

        by_model = {}
        if proto.metrics:
            by_model[''] = {
                '': {
                    '': _convert_proto_map_to_dict(proto.metrics)
                }
            }

        for metric_key, metric_value in zip(proto.metric_keys,
                                            proto.metric_values):
            by_output = by_model.setdefault(metric_key.model_name, {})
            by_class = by_output.setdefault(metric_key.output_name, {})
            if metric_key.HasField('multi_class_key'):
                class_id = _get_multi_class_key_id(metric_key.multi_class_key)
            else:
                class_id = ''
            by_name = by_class.setdefault(class_id, {})
            by_name[metric_key.name] = json_format.MessageToDict(metric_value)

        model_names = list(by_model.keys())
        if model_name in by_model:
            # An explicitly requested model that is present wins.
            selected = by_model[model_name]
        elif model_name is None and len(model_names) == 1:
            # Nothing requested: fall back to the only model available.
            selected = by_model[model_names[0]]
        else:
            raise ValueError('Fail to find metrics for model name: %s . '
                             'Available model names are [%s]' %
                             (model_name, ', '.join(model_names)))

        loaded.append((
            slicer.deserialize_slice_key(proto.slice_key),  # pytype: disable=wrong-arg-types
            selected))
    return loaded
Beispiel #4
0
def load_and_deserialize_metrics(path):
  """Returns deserialized metrics loaded from the given path.

  Args:
    path: Path to a TFRecord file of serialized `MetricsForSlice` protos.

  Returns:
    A list of `(slice_key, metrics)` tuples, one per record in file order.
    `metrics` is the record's `metrics` field as stored in the proto.
  """
  result = []
  # `tf.python_io` is not exposed in TF 2.x; use the `tf.compat.v1` alias like
  # the other loaders do.
  for record in tf.compat.v1.python_io.tf_record_iterator(path):
    metrics_for_slice = metrics_for_slice_pb2.MetricsForSlice.FromString(record)
    result.append((
        slicer.deserialize_slice_key(metrics_for_slice.slice_key),  # pytype: disable=wrong-arg-types
        metrics_for_slice.metrics))
  return result
Beispiel #5
0
def load_and_deserialize_metrics(path: Text
                                ) -> List[Tuple[slicer.SliceKeyType, Any]]:
  """Reads `MetricsForSlice` records from `path` and pairs each with its slice.

  Args:
    path: Location of a TFRecord file of serialized `MetricsForSlice` protos.

  Returns:
    One `(slice_key, metrics)` tuple per record, in file order. `metrics` is
    the proto's `metrics` field as stored.
  """
  # Parse lazily so each record is decoded just before it is consumed.
  parsed = (
      metrics_for_slice_pb2.MetricsForSlice.FromString(raw)
      for raw in tf.compat.v1.python_io.tf_record_iterator(path))
  return [
      (slicer.deserialize_slice_key(entry.slice_key),  # pytype: disable=wrong-arg-types
       entry.metrics)
      for entry in parsed
  ]
Beispiel #6
0
def load_and_deserialize_plots(
    path: Text) -> List[Tuple[slicer.SliceKeyType, Any]]:
  """Returns deserialized plots loaded from given path.

  Each `PlotsForSlice` record yields one `(slice_key, plots)` tuple, where
  `plots` is the record's `plots` map. Records written by older code use the
  single `plot_data` field instead. That value is copied into `plots` under
  the empty-string key, so callers only ever read `plots`.

  Args:
    path: Path to a TFRecord file of serialized `PlotsForSlice` protos.

  Returns:
    A list of `(slice_key, plots)` tuples in file order.

  Raises:
    RuntimeError: If a record sets both `plots` and the legacy `plot_data`.
  """
  result = []
  # `tf.python_io` is not exposed in TF 2.x; use the `tf.compat.v1` alias like
  # the other loaders do.
  for record in tf.compat.v1.python_io.tf_record_iterator(path):
    plots_for_slice = metrics_for_slice_pb2.PlotsForSlice.FromString(record)
    if plots_for_slice.HasField('plot_data'):
      if plots_for_slice.plots:
        raise RuntimeError('Both plots and plot_data are set.')

      # For backward compatibility, plot data generated with old code is added
      # to the plots map under the default empty-string key.
      plots_for_slice.plots[''].CopyFrom(plots_for_slice.plot_data)

    result.append((
        slicer.deserialize_slice_key(plots_for_slice.slice_key),  # pytype: disable=wrong-arg-types
        plots_for_slice.plots))
  return result
  def testDeserializeSliceKey(self):
    """Int64, bytes and float slice values deserialize to (column, value)."""
    key_proto = text_format.Parse(
        """
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 1.0
          }
        """, metrics_for_slice_pb2.SliceKey())

    # Order-insensitive comparison: only the set of pairs matters.
    expected_pairs = [
        ('age', 5),
        ('language', b'english'),
        ('price', 1.0),
    ]
    self.assertItemsEqual(expected_pairs,
                          slicer.deserialize_slice_key(key_proto))