Exemplo n.º 1
0
def load_and_deserialize_plots(
        path: Text) -> List[Tuple[slicer.SliceKeyType, Any]]:
    """Reads PlotsForSlice records from a TFRecord file and converts them.

    Args:
      path: Path of a TFRecord file whose records are serialized
        PlotsForSlice protos.

    Returns:
      A list with one (slice_key, plots_map) pair per record, in file order.
      plots_map is keyed by output name and then by sub key id. Plots from the
      legacy `plots` / `plot_data` fields are stored under the ('', '') entry.
    """

    def _build_plots_map(proto):
        # Nested mapping: output_name -> sub_key_id -> plot dict.
        by_output = {}
        if proto.plots:
            converted = _convert_proto_map_to_dict(proto.plots)
            label_names = list(converted)
            # A single label is unwrapped so callers get its data directly.
            if len(label_names) == 1:
                by_output[''] = {'': converted[label_names[0]]}
            else:
                by_output[''] = {'': converted}
        elif proto.HasField('plot_data'):
            by_output[''] = {'': json_format.MessageToDict(proto.plot_data)}

        for entry in proto.plot_keys_and_values:
            if entry.key.HasField('sub_key'):
                sub_key_id = _get_sub_key_id(entry.key.sub_key)
            else:
                sub_key_id = ''
            by_output.setdefault(entry.key.output_name, {})[sub_key_id] = (
                json_format.MessageToDict(entry.value))
        return by_output

    deserialized = []
    for serialized in tf.compat.v1.python_io.tf_record_iterator(path):
        proto = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
        plots_map = _build_plots_map(proto)
        slice_key = slicer.deserialize_slice_key(proto.slice_key)  # pytype: disable=wrong-arg-types
        deserialized.append((slice_key, plots_map))
    return deserialized
Exemplo n.º 2
0
def load_and_deserialize_attributions(
    output_path: Text,
    output_file_format: Text = '',
    slice_specs: Optional[Iterable[slicer.SingleSliceSpec]] = None
) -> Iterator[metrics_for_slice_pb2.AttributionsForSlice]:
    """Yields AttributionsForSlice protos read from attribution output files.

    Args:
      output_path: Path or pattern to search for attribution files under. If a
        directory is passed, files matching 'attributions*' will be searched
        for.
      output_file_format: Optional file extension. When non-empty, only files
        ending in '.<output_file_format>' are matched; the value is also
        passed on to the raw record reader.
      slice_specs: Optional SingleSliceSpecs used as a filter. A record is
        yielded only if its slice key matches at least one spec. When empty or
        None, every record is yielded.

    Yields:
      AttributionsForSlice protos found in matching files.
    """
    search_root = output_path
    if tf.io.gfile.isdir(search_root):
        search_root = os.path.join(search_root, constants.ATTRIBUTIONS_KEY)
    file_glob = _match_all_files(search_root)
    if output_file_format:
        file_glob += '.' + output_file_format
    matched_files = tf.io.gfile.glob(file_glob)
    for raw in _raw_value_iterator(matched_files, output_file_format):
        record = metrics_for_slice_pb2.AttributionsForSlice.FromString(raw)
        # Slice keys are only deserialized when a filter was requested.
        keep = not slice_specs or slicer.slice_key_matches_slice_specs(
            slicer.deserialize_slice_key(record.slice_key), slice_specs)
        if keep:
            yield record
Exemplo n.º 3
0
def load_and_deserialize_metrics(
    path: Text,
    model_name: Optional[Text] = None
) -> List[Tuple[slicer.SliceKeyType, Any]]:
    """Reads MetricsForSlice records and builds a metrics map for each slice.

    Args:
      path: Path of a TFRecord file whose records are serialized
        MetricsForSlice protos.
      model_name: Name of the model whose metrics should be returned. If not
        given, each record must contain metrics for exactly one model name.

    Returns:
      A list of (slice_key, metrics_map) pairs in file order. metrics_map is
      keyed by output name, then sub key id, then metric name.

    Raises:
      ValueError: If a record has no metrics for `model_name`, or if
        `model_name` is not given and the record does not hold exactly one
        model name.
    """
    output = []
    for serialized in tf.compat.v1.python_io.tf_record_iterator(path):
        proto = metrics_for_slice_pb2.MetricsForSlice.FromString(serialized)

        # model name -> output name -> sub key id -> metric name -> value.
        per_model = {}
        if proto.metrics:
            # Legacy map-style metrics carry no model/output/sub key.
            per_model[''] = {
                '': {'': _convert_proto_map_to_dict(proto.metrics)}
            }

        for entry in proto.metric_keys_and_values:
            key = entry.key
            per_output = per_model.setdefault(key.model_name, {})
            per_sub_key = per_output.setdefault(key.output_name, {})
            sub_key_id = (_get_sub_key_id(key.sub_key)
                          if key.HasField('sub_key') else '')
            per_sub_key.setdefault(sub_key_id, {})[key.name] = (
                json_format.MessageToDict(entry.value))

        available = list(per_model.keys())
        if model_name in per_model:
            selected = per_model[model_name]
        elif not model_name and len(available) == 1:
            # No name requested: fall back to the only model present.
            selected = per_model[available[0]]
        else:
            raise ValueError('Fail to find metrics for model name: %s . '
                             'Available model names are [%s]' %
                             (model_name, ', '.join(available)))

        output.append((
            slicer.deserialize_slice_key(proto.slice_key),  # pytype: disable=wrong-arg-types
            selected))
    return output
Exemplo n.º 4
0
def convert_attributions_proto_to_dict(
    attributions_for_slice: metrics_for_slice_pb2.AttributionsForSlice,
    model_name: Optional[Text] = None
) -> Tuple[slicer.SliceKeyType, Optional[view_types.AttributionsByOutputName]]:
    """Converts an AttributionsForSlice proto into a (slice_key, dict) pair.

    Entries are grouped as model name -> output name -> sub key id -> metric
    name -> {attribution key: value dict}. Diff entries are stored under
    '<name>_diff'.

    Args:
      attributions_for_slice: The proto to convert.
      model_name: Name of the model whose attributions should be returned. If
        not given, the model name of the diff entries is used (when they all
        agree). Failing that, the only model present is chosen.

    Returns:
      (slice_key, attributions_map). attributions_map is None when the proto
      has no attributions_keys_and_values.

    Raises:
      ValueError: If attributions exist but none can be selected for the
        requested or inferred model name.
    """
    by_model = {}
    # Model name taken from diff entries; '' marks conflicting names.
    diff_model_name = None
    for entry in attributions_for_slice.attributions_keys_and_values:
        key = entry.key
        by_output = by_model.setdefault(key.model_name, {})
        by_sub_key = by_output.setdefault(key.output_name, {})
        sub_key_id = (str(metric_types.SubKey.from_proto(key.sub_key))
                      if key.HasField('sub_key') else '')
        by_name = by_sub_key.setdefault(sub_key_id, {})
        if key.is_diff:
            if diff_model_name is None:
                diff_model_name = key.model_name
            elif diff_model_name != key.model_name:
                # Conflicting diff models; '' may lead to the ValueError below.
                diff_model_name = ''
            metric_name = '{}_diff'.format(key.name)
        else:
            metric_name = key.name
        by_name[metric_name] = {
            attribution_key: json_format.MessageToDict(
                entry.values[attribution_key])
            for attribution_key in entry.values
        }

    candidates = list(by_model.keys())
    effective_name = model_name or diff_model_name
    selected = None
    if effective_name in by_model:
        selected = by_model[effective_name]
    elif not effective_name and len(candidates) == 1:
        # Nothing requested or inferred: use the only model present.
        selected = by_model[candidates[0]]
    elif candidates:
        raise ValueError(
            'Fail to find attribution metrics for model name: %s . '
            'Available model names are [%s]' % (model_name,
                                               ', '.join(candidates)))

    return (slicer.deserialize_slice_key(attributions_for_slice.slice_key),
            selected)
Exemplo n.º 5
0
def convert_plots_proto_to_dict(
    plots_for_slice: metrics_for_slice_pb2.PlotsForSlice,
    model_name: Optional[Text] = None
) -> Optional[Tuple[slicer.SliceKeyType,
                    Optional[view_types.PlotsByOutputName]]]:
    """Converts a PlotsForSlice proto into a (slice_key, plots_map) pair.

    Plots are first grouped as model name -> output name -> sub key id ->
    plot dict. The map for a single model is then selected.

    Args:
      plots_for_slice: The PlotsForSlice proto to convert.
      model_name: Name of the model whose plots should be returned. If not
        given and exactly one model name is present, that model is used.

    Returns:
      A (slice_key, plots_map) tuple. plots_map is keyed by output name and
      then sub key id, or is None when the proto carries no plots. Returns
      None (after logging a warning) when plots exist but none can be
      selected for `model_name`.
    """
    model_plots_map = {}
    if plots_for_slice.plots:
        # Legacy map-style plots, stored under model/output/sub key ''.
        plot_dict = _convert_proto_map_to_dict(plots_for_slice.plots)
        keys = list(plot_dict.keys())
        # If there is only one label, choose it automatically.
        plot_data = plot_dict[keys[0]] if len(keys) == 1 else plot_dict
        model_plots_map[''] = {'': {'': plot_data}}
    elif plots_for_slice.HasField('plot_data'):
        # Legacy single plot. The converted dict is stored as-is. Wrapping it
        # in braces would build a set, which fails because dicts are
        # unhashable.
        model_plots_map[''] = {
            '': {
                '': json_format.MessageToDict(plots_for_slice.plot_data)
            }
        }

    for kv in plots_for_slice.plot_keys_and_values:
        current_model_name = kv.key.model_name
        if current_model_name not in model_plots_map:
            model_plots_map[current_model_name] = {}
        output_name = kv.key.output_name
        if output_name not in model_plots_map[current_model_name]:
            model_plots_map[current_model_name][output_name] = {}

        sub_key_plots_map = model_plots_map[current_model_name][output_name]
        sub_key_id = str(metric_types.SubKey.from_proto(
            kv.key.sub_key)) if kv.key.HasField('sub_key') else ''
        sub_key_plots_map[sub_key_id] = json_format.MessageToDict(kv.value)

    plots_map = None
    keys = list(model_plots_map.keys())
    if model_name in model_plots_map:
        # Use the provided model name if there is a match.
        plots_map = model_plots_map[model_name]
    elif not model_name and len(keys) == 1:
        # Show result of the only model if no model name is specified.
        plots_map = model_plots_map[keys[0]]
    elif keys:
        # No match found.
        logging.warning(
            'Fail to find plots for model name: %s . '
            'Available model names are [%s]', model_name, ', '.join(keys))
        return None

    return (slicer.deserialize_slice_key(plots_for_slice.slice_key), plots_map)
Exemplo n.º 6
0
    def testDeserializeSliceKey(self):
        """A SliceKey proto deserializes into its (column, value) pairs."""
        slice_key_proto = text_format.Parse(
            """
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 1.0
          }
        """, metrics_for_slice_pb2.SliceKey())

        deserialized = slicer.deserialize_slice_key(slice_key_proto)
        expected_pairs = [('age', 5), ('language', 'english'), ('price', 1.0)]
        self.assertCountEqual(expected_pairs, deserialized)
def load_and_deserialize_attributions(
    output_path: str,
    output_file_format: str = _TFRECORD_FORMAT,
    slice_specs: Optional[Iterable[slicer.SingleSliceSpec]] = None
) -> Iterator[metrics_for_slice_pb2.AttributionsForSlice]:
  """Yields AttributionsForSlice protos read from attribution output files.

  Args:
    output_path: Path or pattern to search for attribution files under. If a
      directory is passed, files matching 'attributions*' will be searched for.
    output_file_format: Optional file extension to filter files by and the
      format to use for parsing. The default value is tfrecord.
    slice_specs: Optional SingleSliceSpecs used as a filter. A record is
      yielded only if its slice key matches at least one spec. When empty or
      None, every record is yielded.

  Yields:
    AttributionsForSlice protos found in matching files.
  """
  search_root = output_path
  if tf.io.gfile.isdir(search_root):
    search_root = os.path.join(search_root, constants.ATTRIBUTIONS_KEY)

  if output_file_format:
    suffixed_pattern = (
        _match_all_files(search_root) + '.' + output_file_format)
  else:
    suffixed_pattern = _match_all_files(search_root)
  matched_files = tf.io.gfile.glob(suffixed_pattern)

  if not matched_files:
    # Backwards compatibility: older outputs may lack the extension, but the
    # requested format is still used for parsing them.
    matched_files = tf.io.gfile.glob(_match_all_files(search_root))

  for serialized in _raw_value_iterator(matched_files, output_file_format):
    record = metrics_for_slice_pb2.AttributionsForSlice.FromString(serialized)
    if not slice_specs or slicer.slice_key_matches_slice_specs(
        slicer.deserialize_slice_key(record.slice_key), slice_specs):
      yield record
Exemplo n.º 8
0
def load_and_deserialize_metrics(
    path: Text,
    model_name: Optional[Text] = None
) -> List[Tuple[slicer.SliceKeyType, Any]]:
    """Reads MetricsForSlice records and builds a metrics map for each slice.

    Diff metrics are stored under '<name>_diff'. When no model name is given,
    the model that owns the diff metrics is selected (if they all agree).

    Args:
      path: Path of a TFRecord file whose records are serialized
        MetricsForSlice protos.
      model_name: Name of the model whose metrics should be returned.

    Returns:
      A list of (slice_key, metrics_map) pairs in file order. metrics_map is
      keyed by output name, then sub key id, then metric name. When a named
      model is selected, metrics recorded under the empty model name are
      copied into its map as well.

    Raises:
      ValueError: If no model's metrics can be selected for a record.
    """
    # TODO(b/150413770): support metrics from multiple candidates.
    output = []
    for serialized in tf.compat.v1.python_io.tf_record_iterator(path):
        proto = metrics_for_slice_pb2.MetricsForSlice.FromString(serialized)

        # model name -> output name -> sub key id -> metric name -> value.
        per_model = {}
        if proto.metrics:
            # Legacy map-style metrics carry no model/output/sub key.
            per_model[''] = {
                '': {'': _convert_proto_map_to_dict(proto.metrics)}
            }

        # Model name taken from diff entries; '' marks conflicting names.
        diff_model_name = None
        for entry in proto.metric_keys_and_values:
            key = entry.key
            per_sub_key = per_model.setdefault(key.model_name, {}).setdefault(
                key.output_name, {})
            sub_key_id = (_get_sub_key_id(key.sub_key)
                          if key.HasField('sub_key') else '')
            per_name = per_sub_key.setdefault(sub_key_id, {})
            if key.is_diff:
                if diff_model_name is None:
                    diff_model_name = key.model_name
                elif diff_model_name != key.model_name:
                    # Setting '' to trigger no match found ValueError below.
                    diff_model_name = ''
                metric_name = '{}_diff'.format(key.name)
            else:
                metric_name = key.name
            per_name[metric_name] = json_format.MessageToDict(entry.value)

        available = list(per_model.keys())
        effective_name = model_name or diff_model_name
        if effective_name in per_model:
            selected = per_model[effective_name]
            if effective_name and '' in per_model:
                # Copy model-independent (e.g. example_count) metrics in.
                for output_name, by_sub_key in per_model[''].items():
                    for sub_key_id, by_name in by_sub_key.items():
                        if by_name:
                            selected.setdefault(output_name, {}).setdefault(
                                sub_key_id, {}).update(by_name)
        elif not effective_name and len(available) == 1:
            # Nothing requested or inferred: use the only model present.
            selected = per_model[available[0]]
        else:
            raise ValueError('Fail to find metrics for model name: %s . '
                             'Available model names are [%s]' %
                             (model_name, ', '.join(available)))

        output.append((
            slicer.deserialize_slice_key(proto.slice_key),  # pytype: disable=wrong-arg-types
            selected))
    return output
Exemplo n.º 9
0
def partition_slices(
    metrics: List[metrics_for_slice_pb2.MetricsForSlice],
    metric_key: metric_types.MetricKey,
    comparison_type: Text = 'HIGHER',
    alpha: float = 0.01,
    min_num_examples: int = 1
) -> Tuple[List[SliceComparisonResult], List[SliceComparisonResult]]:
    """Splits slices into statistically significant and non-significant ones.

    Args:
      metrics: List of slice metrics protos, one of which must be the overall
        slice (empty slice key). The metrics are expected to have
        MetricValue.confidence_interval populated, which happens when metrics
        are computed with confidence intervals enabled.
      metric_key: Key of the metric on which significance testing is done.
      comparison_type: 'HIGHER' or 'LOWER': whether we look for slices whose
        metric is higher or lower than that of the overall slice.
      alpha: Significance level for the statistical test.
      min_num_examples: Minimum number of examples a slice must have. Zero is
        treated like one.

    Returns:
      A (significant_slices, non_significant_slices) tuple of
      SliceComparisonResult lists.
    """
    assert comparison_type in ['HIGHER', 'LOWER']
    min_num_examples = min_num_examples or 1

    by_slice = {}
    for slice_proto in metrics:
        by_slice[slicer_lib.deserialize_slice_key(
            slice_proto.slice_key)] = slice_proto
    overall_slice_metrics = by_slice.pop(())

    # The example count is looked up under the same model/output/sub key.
    count_key = metric_types.MetricKey(
        name=example_count.EXAMPLE_COUNT_NAME,
        model_name=metric_key.model_name,
        output_name=metric_key.output_name,
        sub_key=metric_key.sub_key,
        is_diff=metric_key.is_diff)
    overall = _get_metrics_as_dict(overall_slice_metrics)

    significant, non_significant = [], []
    for slice_key, slice_metrics in by_slice.items():
        values = _get_metrics_as_dict(slice_metrics)
        count = values[count_key].unsampled_value
        num_examples = int(count)
        if num_examples < min_num_examples:
            continue
        target = values[metric_key]
        # Skip slices that cannot be compared meaningfully.
        if np.isnan(target.unsampled_value):
            continue
        if target.sample_standard_deviation == 0:
            logging.warning(
                'Ignoring slice: %s with standard deviation: %s ', slice_key,
                target.sample_standard_deviation)
            continue
        # TODO(pachristopher): Should we use weighted example count?
        if count <= 1:
            logging.warning(
                'Ignoring slice: %s with example count: %s ', slice_key, count)
            continue

        overall_target = overall[metric_key]
        overall_count = overall[count_key].unsampled_value
        is_significant, p_value = _is_significant_slice(
            target.unsampled_value, target.sample_standard_deviation, count,
            overall_target.unsampled_value,
            overall_target.sample_standard_deviation, overall_count,
            comparison_type, alpha)
        effect_size = _compute_effect_size(
            target.unsampled_value, target.sample_standard_deviation,
            overall_target.unsampled_value,
            overall_target.sample_standard_deviation)
        comparison = SliceComparisonResult(
            slice_key, num_examples, target.unsampled_value,
            overall_target.unsampled_value, p_value, effect_size,
            slice_metrics)
        if is_significant:
            significant.append(comparison)
        else:
            non_significant.append(comparison)
    return significant, non_significant
def find_significant_slices(
        metrics: List[metrics_for_slice_pb2.MetricsForSlice],
        metric_key: Text,
        comparison_type: Text = 'HIGHER',
        alpha: float = 0.01,
        min_num_examples: int = 10) -> List[SliceComparisonResult]:
    """Returns the slices whose metric differs significantly from overall.

    Args:
      metrics: List of slice metrics protos, one of which must be the overall
        slice (empty slice key). The metrics are expected to have
        MetricValue.confidence_interval populated, which happens when metrics
        are computed with confidence intervals enabled.
      metric_key: Name of the metric on which significance testing is done.
      comparison_type: 'HIGHER' or 'LOWER': whether we look for slices whose
        metric is higher or lower than that of the overall slice.
      alpha: Significance level for the statistical test.
      min_num_examples: Minimum number of examples a slice must have.

    Returns:
      SliceComparisonResult entries for the statistically significant slices.
      Slice keys are sorted tuples of (column, value) pairs.
    """
    assert comparison_type in ['HIGHER', 'LOWER']
    assert min_num_examples > 0

    by_slice = {
        tuple(sorted(slicer_lib.deserialize_slice_key(proto.slice_key))):
        proto for proto in metrics
    }
    overall_slice_metrics = by_slice.pop(())
    overall = _get_metrics_as_dict(overall_slice_metrics)
    # A slice is uninteresting unless it is strictly above (HIGHER) or
    # strictly below (LOWER) the overall value.
    not_interesting = (operator.le if comparison_type == 'HIGHER'
                       else operator.ge)

    found = []
    for slice_key, slice_metrics in by_slice.items():
        values = _get_metrics_as_dict(slice_metrics)
        num_examples = values['example_count'].unsampled_value
        if num_examples < min_num_examples:
            continue
        target = values[metric_key]
        if np.isnan(target.unsampled_value):
            continue
        overall_target = overall[metric_key]
        if not_interesting(target.unsampled_value,
                           overall_target.unsampled_value):
            continue

        is_significant, p_value = _is_significant_slice(
            target.unsampled_value, target.sample_standard_deviation,
            num_examples, overall_target.unsampled_value,
            overall_target.sample_standard_deviation,
            overall['example_count'].unsampled_value, comparison_type, alpha)
        if not is_significant:
            continue
        effect_size = _compute_effect_size(
            target.unsampled_value, target.sample_standard_deviation,
            overall_target.unsampled_value,
            overall_target.sample_standard_deviation)
        found.append(
            SliceComparisonResult(
                slice_key, num_examples, target.unsampled_value,
                overall_target.unsampled_value, p_value, effect_size,
                slice_metrics))  # pytype: disable=wrong-arg-types
    return found
Exemplo n.º 11
0
def convert_metrics_proto_to_dict(
    metrics_for_slice: metrics_for_slice_pb2.MetricsForSlice,
    model_name: Optional[Text] = None
) -> Optional[Tuple[slicer.SliceKeyOrCrossSliceKeyType,
                    Optional[view_types.MetricsByOutputName]]]:
  """Converts a MetricsForSlice proto into a (slice_key, metrics_map) pair.

  Args:
    metrics_for_slice: The proto to convert.
    model_name: Name of the model whose metrics should be returned. If not
      given, the model that owns the diff metrics is used (when they all
      agree). Failing that, the only model present is chosen.

  Returns:
    (slice_key, metrics_map). slice_key is a cross slice key when the proto
    has `cross_slice_key` set, and a regular slice key otherwise. metrics_map
    is keyed by output name, sub key id and metric name, or is None when the
    proto carries no metrics. Returns None (after logging a warning) when
    metrics exist but none can be selected for the requested model.
  """
  # model name -> output name -> sub key id -> metric name -> value.
  by_model = {}
  if metrics_for_slice.metrics:
    by_model[''] = {
        '': {
            '': _convert_proto_map_to_dict(metrics_for_slice.metrics)
        }
    }

  # Model name taken from diff entries; '' marks conflicting names.
  diff_model_name = None
  for entry in metrics_for_slice.metric_keys_and_values:
    key = entry.key
    by_sub_key = by_model.setdefault(key.model_name, {}).setdefault(
        key.output_name, {})
    sub_key_id = (str(metric_types.SubKey.from_proto(key.sub_key))
                  if key.HasField('sub_key') else '')
    by_name = by_sub_key.setdefault(sub_key_id, {})
    if key.is_diff:
      if diff_model_name is None:
        diff_model_name = key.model_name
      elif diff_model_name != key.model_name:
        # Setting '' to trigger no match found ValueError below.
        diff_model_name = ''
      metric_name = '{}_diff'.format(key.name)
    elif key.HasField('aggregation_type'):
      # NOTE(review): aggregation_type is formatted with str(). If it is a
      # message, this embeds its text-format form in the name — confirm that
      # is the intended metric name.
      metric_name = '{}_{}'.format(key.aggregation_type, key.name)
    else:
      metric_name = key.name
    by_name[metric_name] = json_format.MessageToDict(entry.value)

  available = list(by_model.keys())
  effective_name = model_name or diff_model_name
  if effective_name in by_model:
    selected = by_model[effective_name]
    if effective_name and '' in by_model:
      # Copy model-independent (e.g. example_count) metrics into the model.
      for output_name, sub_key_maps in by_model[''].items():
        for shared_sub_key_id, named_values in sub_key_maps.items():
          if named_values:
            selected.setdefault(output_name, {}).setdefault(
                shared_sub_key_id, {}).update(named_values)
  elif not effective_name and len(available) == 1:
    # Nothing requested or inferred: use the only model present.
    selected = by_model[available[0]]
  elif available:
    logging.warning(
        'Fail to find metrics for model name: %s . '
        'Available model names are [%s]', model_name, ', '.join(available))
    return None
  else:
    selected = None

  if metrics_for_slice.HasField('cross_slice_key'):
    slice_key = slicer.deserialize_cross_slice_key(
        metrics_for_slice.cross_slice_key)
  else:
    slice_key = slicer.deserialize_slice_key(metrics_for_slice.slice_key)
  return (slice_key, selected)
Exemplo n.º 12
0
def find_top_slices(metrics: List[metrics_for_slice_pb2.MetricsForSlice],
                    metric_key: Text,
                    statistics: statistics_pb2.DatasetFeatureStatisticsList,
                    comparison_type: Text = 'HIGHER',
                    min_num_examples: int = 10,
                    num_top_slices: int = 10,
                    rank_by: Text = 'EFFECT_SIZE'):
    """Returns the top-k statistically significant slices.

    Args:
      metrics: List of slice metrics protos, one of which must be the overall
        slice (empty slice key). The metrics are expected to have
        MetricValue.confidence_interval populated, which happens when metrics
        are computed with confidence intervals enabled.
      metric_key: Name of the metric on which significance testing is done.
      statistics: Data statistics used to configure AutoSliceKeyExtractor.
      comparison_type: 'HIGHER' or 'LOWER': whether we look for slices whose
        metric is higher or lower than that of the overall slice.
      min_num_examples: Minimum number of examples a slice must have.
      num_top_slices: Number of top slices to return.
      rank_by: 'EFFECT_SIZE' (largest first) or 'PVALUE' (smallest first).

    Returns:
      Up to `num_top_slices` SliceComparisonResult entries, ordered by
      `rank_by`. Slice keys are stringified, and bucketized features are shown
      as value ranges.
    """
    assert comparison_type in ['HIGHER', 'LOWER']
    assert min_num_examples > 0
    assert 0 < num_top_slices
    assert rank_by in ['EFFECT_SIZE', 'PVALUE']

    by_slice = {
        slicer_lib.deserialize_slice_key(proto.slice_key): proto
        for proto in metrics
    }
    overall_slice_metrics = by_slice.pop(())

    boundaries = auto_slice_key_extractor._get_bucket_boundaries(statistics)  # pylint: disable=protected-access
    overall = _get_metrics_as_dict(overall_slice_metrics)
    # A slice is uninteresting unless it is strictly above (HIGHER) or
    # strictly below (LOWER) the overall value.
    not_interesting = (operator.le if comparison_type == 'HIGHER'
                       else operator.ge)
    prefix = auto_slice_key_extractor.TRANSFORMED_FEATURE_PREFIX

    candidates = []
    for slice_key, slice_metrics in by_slice.items():
        values = _get_metrics_as_dict(slice_metrics)
        num_examples = values['example_count'].unsampled_value
        if num_examples < min_num_examples:
            continue
        target = values[metric_key]
        if np.isnan(target.unsampled_value):
            continue
        overall_target = overall[metric_key]
        if not_interesting(target.unsampled_value,
                           overall_target.unsampled_value):
            continue

        # Only consider statistically significant slices.
        is_significant, pvalue = _is_significant_slice(
            target.unsampled_value, target.sample_standard_deviation,
            num_examples, overall_target.unsampled_value,
            overall_target.sample_standard_deviation,
            overall['example_count'].unsampled_value, comparison_type)
        if not is_significant:
            continue

        # Map auto-generated bucket features back to readable value ranges.
        readable = []
        for feature, value in slice_key:
            if feature.startswith(prefix):
                feature = feature[len(prefix):]
                value = _bucket_to_range(value, boundaries[feature])
            readable.append((feature, value))
        display_key = slicer_lib.stringify_slice_key(tuple(readable))

        effect_size = _compute_effect_size(
            target.unsampled_value, target.sample_standard_deviation,
            overall_target.unsampled_value,
            overall_target.sample_standard_deviation)
        candidates.append(
            SliceComparisonResult(display_key, num_examples,
                                  target.unsampled_value,
                                  overall_target.unsampled_value, pvalue,
                                  effect_size))

    if rank_by == 'PVALUE':
        ranked = sorted(candidates, key=operator.attrgetter('pvalue'))
    else:
        ranked = sorted(candidates, key=operator.attrgetter('effect_size'),
                        reverse=True)
    return ranked[:num_top_slices]