Example #1
def add_input(
        self, accumulator: Dict[str, List[float]],
        extracts: metric_types.StandardMetricInputs
) -> Dict[str, List[float]]:
     if constants.ATTRIBUTIONS_KEY not in extracts:
        raise ValueError(
            '{} missing from extracts {}.\n\nAn attribution extractor is '
            'required to use attribution metrics.'.format(
                constants.ATTRIBUTIONS_KEY, extracts))
     attributions = extracts[constants.ATTRIBUTIONS_KEY]
     if self._key.model_name:
         attributions = util.get_by_keys(attributions,
                                         [self._key.model_name])
     if self._key.output_name:
         attributions = util.get_by_keys(attributions,
                                         [self._key.output_name])
     _, _, example_weight = next(
         metric_util.to_label_prediction_example_weight(
             extracts,
             eval_config=self._eval_config,
             model_name=self._key.model_name,
             output_name=self._key.output_name,
             sub_key=self._key.sub_key,
             example_weighted=self._key.example_weighted,
             allow_none=True,
             flatten=False))
     example_weight = float(example_weight)
     for k, v in attributions.items():
         v = util.to_numpy(v)
         if self._key.sub_key is not None:
             if self._key.sub_key.class_id is not None:
                 v = _scores_by_class_id(self._key.sub_key.class_id, v)
             elif self._key.sub_key.k is not None:
                 v = _scores_by_top_k(self._key.sub_key.k, v)
                 v = np.array(v[self._key.sub_key.k - 1])
             elif self._key.sub_key.top_k is not None:
                 v = _scores_by_top_k(self._key.sub_key.top_k, v)
         if k not in accumulator:
             accumulator[k] = [0.0] * v.size
         self._sum(accumulator[k], v * example_weight)
     return accumulator
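
To make the weighted accumulation concrete, here is a minimal sketch in plain numpy with a stand-in for self._sum and a hypothetical feature name; it illustrates the behavior under those assumptions and is not the TFMA implementation:

import numpy as np
from typing import Dict, List

def _sum(totals: List[float], values: np.ndarray) -> None:
    # Elementwise in-place add, standing in for the self._sum call above.
    for i, v in enumerate(values.flatten()):
        totals[i] += float(v)

accumulator: Dict[str, List[float]] = {}
# Two hypothetical examples: per-feature attributions with weights 1.0, 2.0.
for attributions, example_weight in [({'age': np.array([0.4])}, 1.0),
                                     ({'age': np.array([0.1])}, 2.0)]:
    for k, v in attributions.items():
        if k not in accumulator:
            accumulator[k] = [0.0] * v.size
        _sum(accumulator[k], v * example_weight)

print(accumulator)  # {'age': [0.6000000000000001]}  i.e. 0.4*1.0 + 0.1*2.0
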
Example #2
def add_input(self, accumulator: Dict[Text, List[float]],
              attributions: Dict[Text, Any]) -> Dict[Text, List[float]]:
     if self._key.model_name:
         attributions = util.get_by_keys(attributions,
                                         [self._key.model_name])
     if self._key.output_name:
         attributions = util.get_by_keys(attributions,
                                         [self._key.output_name])
     for k, v in attributions.items():
         v = util.to_numpy(v)
         if self._key.sub_key is not None:
             if self._key.sub_key.class_id is not None:
                 v = _scores_by_class_id(self._key.sub_key.class_id, v)
             elif self._key.sub_key.k is not None:
                 v = _scores_by_top_k(self._key.sub_key.k, v)
                 v = np.array(v[self._key.sub_key.k - 1])
             elif self._key.sub_key.top_k is not None:
                 v = _scores_by_top_k(self._key.sub_key.top_k, v)
         if k not in accumulator:
             accumulator[k] = [0.0] * v.size
         self._sum(accumulator[k], v)
     return accumulator
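
This unweighted variant is one half of a Beam-style combiner. Below is a hedged sketch of the matching merge step, assuming the class follows the usual beam.CombineFn contract; the merge_accumulators name and body here are an assumption for illustration, not code from the library:

from typing import Dict, Iterable, List

def merge_accumulators(
        accumulators: Iterable[Dict[str, List[float]]]
) -> Dict[str, List[float]]:
    # Elementwise sum of the per-feature totals produced by add_input.
    result: Dict[str, List[float]] = {}
    for accumulator in accumulators:
        for k, values in accumulator.items():
            if k not in result:
                result[k] = list(values)
            else:
                result[k] = [a + b for a, b in zip(result[k], values)]
    return result

print(merge_accumulators([{'age': [0.5]}, {'age': [0.25]}]))  # {'age': [0.75]}
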
Example #3
def select_class_id(
    class_id: int,
    labels: Any,
    predictions: Any,
    sparse_labels: Optional[bool] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Selects values for given class ID from multi-class labels and predictions.

  Args:
    class_id: Class ID to filter the labels and predictions by.
    labels: Array or list of processed labels (1D, 2D, or 3D).
    predictions: Array or list of processed predictions (1D, 2D, or 3D).
    sparse_labels: True if sparse labels are being used. If None then the
      sparseness will be inferred from the shapes of the labels and predictions
      (i.e. if the shapes are different then the labels will be assumed to be
      sparse).

  Returns:
    A (labels, predictions) tuple with the predictions returned in the same form
    as the originals (except for the last dimension which will be 1).

  Raises:
    ValueError: If the labels or predictions cannot be formatted properly.
  """
    labels = util.to_numpy(labels)
    predictions = util.to_numpy(predictions)
    if labels.size == 0 or predictions.size == 0:
        return (labels, predictions)

    def lookup(arr, target):
        if class_id < 0 or class_id >= len(arr):
            raise ValueError(
                f'class_id "{class_id}" out of range of {target}: {arr}')
        return arr[class_id]

    # Convert scalars to arrays
    if not labels.shape:
        labels = labels.reshape((1, ))
    if not predictions.shape:
        predictions = predictions.reshape((1, ))

    sparse_labels = _verify_sparse_labels(labels,
                                          predictions,
                                          sparse_labels=sparse_labels)
    if sparse_labels and labels.shape[-1] != 1:
        # Convert to [[class_id1], ...]
        labels = labels.reshape((-1, 1))

    labels_out_shape = list(labels.shape)
    labels_out_shape[-1] = 1
    predictions_out_shape = list(predictions.shape)
    predictions_out_shape[-1] = 1

    # Convert labels and predictions into the form ([[...], [...]])
    if len(labels.shape) > 1:
        # Flatten all but the last dim (a, b, c) -> (a * b, c)
        labels = labels.reshape((-1, labels.shape[-1]))
    else:
        labels = labels.reshape((1, labels.shape[0]))
    if len(predictions.shape) > 1:
        predictions = predictions.reshape((-1, predictions.shape[-1]))
    else:
        predictions = predictions.reshape((1, predictions.shape[0]))

    if sparse_labels:
        # Labels are of the form [[class_id1], [class_id2], ...]
        labels = np.array([int(l[0] == class_id) for l in labels])
    else:
        # Labels are of the form [[0, 0, 1, ...], [0, 0, 0, ...], ...]
        labels = np.array([lookup(l, 'labels') for l in labels])
    predictions = np.array([lookup(p, 'predictions') for p in predictions])

    return (labels.reshape(labels_out_shape),
            predictions.reshape(predictions_out_shape))
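
A short, self-contained illustration of the sparse vs. dense branches above, re-deriving the expected outputs with plain numpy (the values are hypothetical):

import numpy as np

predictions = np.array([0.3, 0.6, 0.1])
class_id = 1

# Sparse labels: [2] means "the true class is 2", so class 1 gets label 0.
sparse = np.array([2])
print(int(sparse[0] == class_id), predictions[class_id])   # 0 0.6

# Dense labels: one entry per class, so class 1's label is read directly.
dense = np.array([0, 1, 0])
print(dense[class_id], predictions[class_id])              # 1 0.6
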
Example #4
def prepare_labels_and_predictions(
    labels: Any,
    predictions: Any,
    prediction_key: Optional[Text] = None,
    label_vocabulary: Optional[Union[np.ndarray, List[Text]]] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """Prepares labels and predictions for use in calculations.

  If the predictions are a dict (i.e. estimator based output) this function will
  apply the necessary lookup based on the prediction_key provided (or using a
  default set of common keys such as 'probabilities', etc). Note that the
  predictions passed as args must be AFTER the model_name and/or output_name
  lookups have been performed. This function also applies any label vocabulary
  transformations where possible.

  If successful, the final output of calling this function will be a pair of
  numpy arrays representing the labels and predictions.

  Args:
    labels: List, np.ndarray, or SparseTensorValue of values (1D, 2D, or 3D).
    predictions: List or np.ndarray of prediction values (1D, 2D, or 3D) or a
      dict of prediction values keyed by prediction_key or common estimator keys
      (logistic, probabilities, etc).
    prediction_key: Optional predictions key. Used when the predict output is a
      dict.
    label_vocabulary: Optional label vocabulary to convert label values to ints
      (if prediction is a dict containing an 'all_classes' key that will be used
      if label_vocabulary is None).

  Returns:
    A (labels, predictions) tuple suitable for metric calculations.

  Raises:
    ValueError: If the labels or predictions are in an invalid format.
  """
    if isinstance(predictions, Mapping):
        if label_vocabulary is None:
            if _ALL_CLASSES in predictions:
                # Check for use of estimator label vocab under ALL_CLASSES. This was
                # added in 06/2019 for eval signatures because the CLASSES only contains
                # the label for the chosen class.
                label_vocabulary = util.to_numpy(predictions[_ALL_CLASSES])
            elif (tf.saved_model.CLASSIFY_OUTPUT_SCORES in predictions
                  and tf.saved_model.CLASSIFY_OUTPUT_CLASSES in predictions):
                # For classification model using the default serving signature, the
                # CLASSES contains the full vocabulary. The check for scores is needed
                # here to avoid matching CLASSES in the eval case (scores are not used
                # in eval).
                label_vocabulary = util.to_numpy(
                    predictions[tf.saved_model.CLASSIFY_OUTPUT_CLASSES])
            if label_vocabulary is not None:
                while len(label_vocabulary.shape) > 1:
                    # Remove the batch dimensions.
                    label_vocabulary = label_vocabulary[0]
        if not prediction_key:
            # Estimator predictions use dicts of scores, probabilities, classes, etc.
            for k in (tf.saved_model.CLASSIFY_OUTPUT_SCORES,
                      tf.saved_model.REGRESS_OUTPUTS, _PREDICTIONS, _LOGISTIC,
                      _PROBABILITIES, _LOGITS):
                if k in predictions:
                    predictions = predictions[k]
                    prediction_key = k
                    break
        elif prediction_key in predictions:
            predictions = predictions[prediction_key]

    if isinstance(predictions, Mapping):
        raise ValueError(
            'unable to prepare prediction for metric computation because the '
            'prediction is a dict with unrecognized keys. If a multi-output model '
            'was used check that an output name was provided in all the relevant '
            'settings (MetricsSpec.output_names, etc). If the model returns a dict '
            'for its output and the output does not contain one of the common '
            'prediction keys (e.g. logistic, probabilities, etc), then '
            'ModelSpec.prediction_key can be used to specify which key to use for '
            f'the predicted value: prediction={predictions}, '
            f'prediction_key={prediction_key}')

    if predictions is not None:
        predictions = util.to_numpy(predictions)

    if labels is not None:
        if (isinstance(labels, types.SparseTensorValue)
                or isinstance(labels, tf.compat.v1.SparseTensorValue)):
            if predictions is None or predictions.size == 0:
                raise ValueError(
                    'predictions must also be used if labels are of type '
                    f'SparseTensorValue: labels={labels}')
            values = (labels.values
                      if labels.values is not None else np.array([]))
            indices = (labels.indices
                       if labels.indices is not None else np.array([]))
            if label_vocabulary is not None and values.dtype.kind in ('U', 'S',
                                                                      'O'):
                values = _string_labels_to_class_ids(label_vocabulary, values)
                # If vocab is used then the values will be the indices into the vocab
                # and we should use multi-hot encoding to store the output. We can
                # accomplish this by passing 1's for the values and using the values
                # converted from the vocab as the indices to insert the 1's at the
                # proper offsets in the resulting multi-hot vector.
                labels = _to_dense_tensor(np.ones(values.shape), values,
                                          predictions.shape)
            else:
                labels = _to_dense_tensor(values, indices, predictions.shape)
        else:
            labels = util.to_numpy(labels)
            if label_vocabulary is not None and labels.dtype.kind in ('U', 'S',
                                                                      'O'):
                labels = _string_labels_to_class_ids(label_vocabulary, labels)

    return (labels, predictions)
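
A hedged sketch of the default key resolution above for estimator-style prediction dicts. The literal key strings mirror the constants the function checks (assuming the usual values of the tf.saved_model signature constants), and the dict contents are hypothetical:

import numpy as np

# Estimator-style output: a dict of arrays keyed by common names.
predictions = {'probabilities': np.array([0.3, 0.6, 0.1]),
               'classes': np.array([b'a', b'b', b'c'])}

# Take the first common key that is present, in priority order.
for key in ('scores', 'outputs', 'predictions', 'logistic',
            'probabilities', 'logits'):
    if key in predictions:
        predictions = predictions[key]
        break

print(predictions)  # [0.3 0.6 0.1]
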
Example #5
def to_label_prediction_example_weight(
    inputs: metric_types.StandardMetricInputs,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None,
    fractional_labels: bool = False,
    flatten: bool = True,
    squeeze: bool = True,
    allow_none: bool = False,
    require_single_example_weight: bool = False
) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Yields label, prediction, and example weights for use in calculations.

  Where applicable this function will perform model and output name lookups as
  well as any required class ID, top K, etc conversions. It will also apply
  prediction keys and label vocabularies given the necessary information is
  provided as part of the EvalConfig (or standard estimator based naming is
  used). The sparseness of labels will be inferred from the shapes of the labels
  and predictions (i.e. if the shapes are different then the labels will be
  assumed to be sparse).

  If successful, the final output of calling this function will be a tuple of
  numpy arrays representing the label, prediction, and example weight
  respectively. Labels and predictions will be returned in the same shape
  provided (default behavior) unless (1) flatten is True in which case a series
  of values (one per class ID) will be returned with last dimension of size 1 or
  (2) a sub_key is used in which case the last dimension may be re-shaped to
  match the new number of outputs (1 for class_id or k, top_k for top k with
  aggregation).

  Note that for top_k without aggregation, the non-top_k prediction values will
  be set to float('-inf'), but for top_k with aggregation the values will be
  truncated to only return the top k values.

  Examples:

    # default behavior
    #
    # Binary classification
    Input  : labels=[1] predictions=[0.6]
    Output : (np.array([1]), np.array([0.6]), np.array([1.0]))
    # Multi-class classification w/ sparse labels
    Input  : labels=[2] predictions=[0.3, 0.6, 0.1]
    Output : (np.array([2]), np.array([0.3, 0.6, 0.1]), np.array([1.0]))
    # Multi-class / multi-label classification w/ dense labels
    Input  : labels=[0, 1, 1] predictions=[0.3, 0.6, 0.1]
    Output : (np.array([0, 1, 1]), np.array([0.3, 0.6, 0.1]), np.array([1.0]))

    # flatten=True
    #
    # Multi-class classification w/ sparse labels
    Input  : labels=[2], predictions=[0.3, 0.6, 0.1]
    Output : (np.array([0]), np.array([0.3]), np.array([1.0])),
             (np.array([0]), np.array([0.6]), np.array([1.0])),
             (np.array([1]), np.array([0.1]), np.array([1.0]))
    # Multi-class/multi-label classification w/ dense labels
    Input  : labels=[0, 0, 1], predictions=[0.3, 0.6, 0.1]
    Output : (np.array([0]), np.array([0.3]), np.array([1.0])),
             (np.array([0]), np.array([0.6]), np.array([1.0])),
             (np.array([1]), np.array([0.1]), np.array([1.0]))

    # sub_key.class_id=[2]
    #
    # Multi-class classification w/ sparse labels
    Input  : labels=[2] predictions=[0.3, 0.6, 0.1]
    Output : (np.array([1]), np.array([0.1]), np.array([1.0]))
    # Multi-class classification w/ dense labels
    Input  : labels=[0, 0, 1] predictions=[0.3, 0.6, 0.1]
    Output : (np.array([1]), np.array([0.1]), np.array([1.0]))

    # sub_key.top_k=2 and aggregation_type is None (i.e. binarization of top 2).
    #
    # Multi-class classification w/ sparse labels
    Input  : labels=[2] predictions=[0.3, 0.6, 0.1]
    Output : (np.array([0, 0, 1]), np.array([0.3, 0.6, -inf]), np.array([1.0]))
    # Multi-class classification w/ dense labels
    Input  : labels=[0, 0, 1] predictions=[0.3, 0.1, 0.6]
    Output : (np.array([0, 0, 1]), np.array([0.3, -inf, 0.6]), np.array([1.0]))

    # sub_key.top_k=2 and aggregation_type is not None (i.e. aggregate top 2).
    #
    # Multi-class classification w/ sparse labels
    Input  : labels=[2] predictions=[0.3, 0.6, 0.1]
    Output : (np.array([0, 1]), np.array([0.3, 0.6]), np.array([1.0]))
    # Multi-class classification w/ dense labels
    Input  : labels=[0, 0, 1] predictions=[0.3, 0.1, 0.6]
    Output : (np.array([0, 0]), np.array([0.3, 0.6]), np.array([1.0]))

    # sub_key.k=2 (i.e. binarization by choosing 2nd largest predicted value).
    #
    # Multi-class classification w/ sparse labels
    Input  : labels=[0] predictions=[0.3, 0.6, 0.1]
    Output : (np.array([1]), np.array([0.3]), np.array([1.0]))
    # Multi-class classification w/ dense labels
    Input  : labels=[0] predictions=[0.3]
    Output : (np.array([0]), np.array([0.3]), np.array([1.0]))

  Args:
    inputs: Standard metric inputs.
    eval_config: Optional eval config.
    model_name: Optional model name (if multi-model evaluation).
    output_name: Optional output name (if multi-output model type).
    sub_key: Optional sub key.
    aggregation_type: Optional aggregation type.
    class_weights: Optional class weights to apply to multi-class / multi-label
      labels and predictions. If used, flatten must also be True.
    fractional_labels: If true, each incoming tuple of (label, prediction, and
      example weight) will be split into two tuples as follows (where l, p, w
      represent the resulting label, prediction, and example weight values):
        (1) l = 0.0, p = prediction, and w = example_weight * (1.0 - label)
        (2) l = 1.0, p = prediction, and w = example_weight * label
      If enabled, an exception will be raised if labels are not within [0, 1].
      The implementation is such that tuples associated with a weight of zero
      are not yielded. This means it is safe to enable fractional_labels even
      when the labels only take on the values of 0.0 or 1.0.
    flatten: True to flatten the final label and prediction outputs so that the
      yielded values are always arrays of size 1. For example, multi-class /
      multi-label outputs would be converted into label and prediction pairs
      that could then be processed by a binary classification metric in order to
      compute a micro average over all classes. If the example weight is not a
      scalar, then they will be flattened as well, otherwise the same example
      weight value will be output for each pair of labels and predictions.
    squeeze: True to squeeze any outputs that have rank > 1. This transforms
      outputs such as np.array([[1]]) to np.array([1]).
    allow_none: True to allow labels or predictions with None values to be
      returned. When used, the values will be returned as empty np.ndarrays. The
      example weight will always be non-empty.
    require_single_example_weight: True to require that the example_weight be a
      single value.

  Yields:
    Tuple of (label, prediction, example_weight).
  """
    def fn_call_str():
        return (f'to_label_prediction_example_weight(inputs={inputs}, '
                f'eval_config={eval_config}, model_name={model_name}, '
                f'output_name={output_name}, sub_key={sub_key}, '
                f'aggregation_type={aggregation_type}, '
                f'class_weights={class_weights}, '
                f'fractional_labels={fractional_labels}, flatten={flatten}, '
                f'squeeze={squeeze}, allow_none={allow_none})')

    def optionally_get_by_keys(value: Any, keys: List[Text]) -> Any:
        if isinstance(value, Mapping):
            new_value = util.get_by_keys(value, keys, optional=True)
            if new_value is not None:
                return new_value
        return value

    try:
        prediction_key = ''
        label_key = ''
        if eval_config and eval_config.model_specs:
            for spec in eval_config.model_specs:
                # To maintain consistency between settings where single models are used,
                # always use '' as the model name regardless of whether a name is passed
                spec_name = (spec.name
                             if len(eval_config.model_specs) > 1 else '')
                if spec_name == model_name:
                    prediction_key = spec.prediction_key
                    label_key = spec.label_key
                    break

        label = inputs.label
        if label_key:
            # This is to support a custom EvalSavedModel where the labels are a dict
            # but the keys are not output_names.
            label = optionally_get_by_keys(label, [label_key])
        prediction = inputs.prediction
        example_weight = inputs.example_weight
        if example_weight is None:
            example_weight = np.array(
                1.0, dtype=np.float32)  # tf-ranking needs float32
        if model_name:
            prediction = util.get_by_keys(prediction, [model_name])
            # Labels and weights can optionally be keyed by model name.
            label = optionally_get_by_keys(label, [model_name])
            example_weight = optionally_get_by_keys(example_weight,
                                                    [model_name])
        if output_name:
            prediction = util.get_by_keys(prediction, [output_name])
            # Labels and example weights can optionally be keyed by output name.
            label = optionally_get_by_keys(label, [output_name])
            example_weight = optionally_get_by_keys(example_weight,
                                                    [output_name])

        if isinstance(label, Mapping):
            raise ValueError(
                'unable to prepare label for metric computation because the label is '
                'a dict with unrecognized keys. If a multi-output model was used '
                'check that an output name was provided in all the relevant '
                'settings (ModelSpec.label_keys, MetricsSpec.output_names, etc): '
                f'label={label}, output_name={output_name}')
        if isinstance(example_weight, Mapping):
            raise ValueError(
                'unable to prepare example_weight for metric computation because the '
                'example_weight is a dict with unrecognized keys. If a multi-output '
                'model was used check that an output name was provided in all the '
                'relevant settings (ModelSpec.example_weight_keys, '
                f'MetricsSpec.output_names, etc): example_weight={example_weight}, '
                f'output_name={output_name}')

        label, prediction = prepare_labels_and_predictions(
            label, prediction, prediction_key)

        if not allow_none:
            for txt, value in zip(('label', 'prediction'),
                                  (label, prediction)):
                if value is None:
                    raise ValueError(
                        f'no value provided for {txt}\n\n'
                        'This may be caused by a configuration error (i.e. label, '
                        'and/or prediction keys were not specified) or an '
                        'error in the pipeline.')

        example_weight = util.to_numpy(example_weight)
        if require_single_example_weight and example_weight.size > 1:
            example_weight = example_weight.flatten()
            if not np.all(example_weight == example_weight[0]):
                raise ValueError(
                    'if example_weight size > 1, the values must all be the same: '
                    f'example_weight={example_weight}\n\n'
                    'This is most likely a configuration error.')
            example_weight = np.array(example_weight[0])

        if sub_key is not None:
            if sub_key.class_id is not None:
                label, prediction = select_class_id(sub_key.class_id, label,
                                                    prediction)
            elif sub_key.k is not None:
                indices = top_k_indices(sub_key.k, prediction)
                if len(prediction.shape) == 1:
                    indices = indices[0]  # 1D
                else:
                    # 2D, take kth values
                    indices = (indices[0][0::sub_key.k],
                               indices[1][0::sub_key.k])
                if label.shape != prediction.shape:
                    label = one_hot(label, prediction)
                label = select_indices(label, indices)
                prediction = select_indices(prediction, indices)
            elif sub_key.top_k is not None:
                # Set all non-top-k predictions to -inf. Note that we do not sort.
                indices = top_k_indices(sub_key.top_k, prediction)
                if aggregation_type is None:
                    top_k_predictions = np.full(prediction.shape,
                                                float('-inf'))
                    top_k_predictions[indices] = prediction[indices]
                    prediction = top_k_predictions
                else:
                    if label.shape != prediction.shape:
                        label = one_hot(label, prediction)
                    label = select_indices(label, indices)
                    prediction = select_indices(prediction, indices)

        # For consistency, make sure all outputs are arrays (i.e. convert scalars)
        if label is not None and not label.shape:
            label = label.reshape((1, ))
        if prediction is not None and not prediction.shape:
            prediction = prediction.reshape((1, ))
        if not example_weight.shape:
            example_weight = example_weight.reshape((1, ))

        label = label if label is not None else np.array([])
        prediction = prediction if prediction is not None else np.array([])

        flatten_size = prediction.size or label.size
        if flatten:
            if example_weight.size == 1:
                example_weight = np.array(
                    [float(example_weight) for i in range(flatten_size)])
            elif example_weight.size != flatten_size:
                raise ValueError(
                    'example_weight size does not match the size of labels and '
                    f'predictions: label={label}, prediction={prediction}, '
                    f'example_weight={example_weight}')

        if class_weights:
            if not flatten:
                raise ValueError(
                    'class_weights can only be used when flatten is also used: '
                    f'class_weights={class_weights}, flatten={flatten}\n\n'
                    'This is likely caused by a configuration error (i.e. micro '
                    "averaging being applied to metrics that don't support micro "
                    'averaging).')
            example_weight = np.array([
                example_weight[i] *
                class_weights[i] if i in class_weights else 0.0
                for i in range(flatten_size)
            ])

        def yield_results(label, prediction, example_weight):
            if (not flatten or (label.size == 0 and prediction.size == 0)
                    or (label.size == 1 and prediction.size == 1
                        and example_weight.size == 1)):
                if squeeze:
                    yield (_squeeze(label), _squeeze(prediction),
                           _squeeze(example_weight))
                else:
                    yield label, prediction, example_weight
            elif label.size == 0:
                for p, w in zip(prediction.flatten(),
                                example_weight.flatten()):
                    yield label, np.array([p]), np.array([w])
            elif prediction.size == 0:
                for l, w in zip(label.flatten(), example_weight.flatten()):
                    yield np.array([l]), prediction, np.array([w])
            elif label.size == prediction.size and label.size == example_weight.size:
                for l, p, w in zip(label.flatten(), prediction.flatten(),
                                   example_weight.flatten()):
                    yield np.array([l]), np.array([p]), np.array([w])
            elif (label.shape[-1] == 1
                  and prediction.size == example_weight.size):
                label = one_hot(label, prediction)
                for l, p, w in zip(label.flatten(), prediction.flatten(),
                                   example_weight.flatten()):
                    yield np.array([l]), np.array([p]), np.array([w])
            else:
                raise ValueError(
                    'unable to pair labels, predictions, and example weights: '
                    f'label={label}, prediction={prediction}, '
                    f'example_weight={example_weight}\n\n'
                    'This is most likely a configuration error.')

        for result in yield_results(label, prediction, example_weight):
            if fractional_labels and label.size:
                for new_result in _yield_fractional_labels(*result):
                    yield new_result
            else:
                yield result
    except Exception as e:
        import sys  # pylint: disable=g-import-not-at-top
        raise type(e)(str(e) + f'\n\n{fn_call_str()}').with_traceback(
            sys.exc_info()[2])
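
To ground the flatten=True behavior documented above, here is a minimal numpy sketch of how one multi-class input is broken into per-class (label, prediction, example_weight) tuples; the values are hypothetical and this only mirrors the pairing logic, not the full function:

import numpy as np

labels = np.array([0, 0, 1])             # dense multi-class labels
predictions = np.array([0.3, 0.6, 0.1])
example_weight = np.array([1.0])

# A scalar weight is broadcast so there is one weight per flattened pair.
weights = np.full(predictions.size, example_weight.item())
for l, p, w in zip(labels.flatten(), predictions.flatten(), weights):
    print(np.array([l]), np.array([p]), np.array([w]))
# [0] [0.3] [1.]
# [0] [0.6] [1.]
# [1] [0.1] [1.]
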
Example #6
def top_k_indices(
        top_k: int,
        scores: Any,
        sort: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Returns top_k indices into a list of scores.

  Note that the indices are returned in a form that is useful for assigning
  values to the array. If using to select values from an array you may need to
  reshape the output. Examples:

     # Assigning values to scores based on indices
     indices = top_k_indices(1, scores)
     scores[indices] = 0.0

     # Selecting top_k
     indices = top_k_indices(top_k, scores)
     scores[indices].reshape(scores.shape[:-1] + (top_k,))

  Args:
    top_k: Number of top k values to return.
    scores: Array or list of scores for computing the top_k indices.
    sort: True if the indices should be sorted (in descending order).

  Returns:
    An array of indices into scores that can be used with either 1D or 2D
    arrays. If sort was True the indices will be returned in descending order of
    score (i.e. top score first).

  Raises:
    ValueError: If top_k doesn't match scores or input has more than 2 dims.
  """
    scores = util.to_numpy(scores)
    if scores.shape[-1] < top_k:
        raise ValueError(
            'not enough values were provided to perform the requested '
            f'calculations for top k. The requested value for k is {top_k}, but the '
            f'values are {scores}\n\nThis may be caused by a metric configuration '
            'error or an error in the pipeline.')

    if len(scores.shape) == 1:
        # 1D data
        indices = np.argpartition(scores, -top_k)[-top_k:]
        if sort:
            indices = indices[np.argsort(-scores[indices])]
        return indices
    elif len(scores.shape) == 2:
        # 2D data
        indices = np.argpartition(scores, -top_k, axis=-1)[:, -top_k:]
        # The above creates an n x top_k matrix where each row in indices matches
        # the corresponding row in scores. For example:
        #   [
        #      [<row1_top_k_index_1>, <row_1_top_k_index_2>, ...],
        #      [<row2_top_k_index_1>, <row_2_top_k_index_2>, ...],
        #      ...
        #   ]
        # However, numpy indexing wants the index to be a 2-tuple where the
        # first tuple value contains the row indices (each repeated top_k
        # times) and the second tuple value contains the column indices:
        #   (row1, row1, ..., row2, ...), (row1_top_k_index_1, row1_top_k_index_2, ...)
        if sort:
            for i in range(indices.shape[0]):
                indices[i] = indices[i][np.argsort(-scores[i][indices[i]])]
        return np.arange(indices.shape[0]).repeat(top_k), indices.flatten()
    else:
        raise NotImplementedError(
            'top_k not supported for shapes > 2: scores = {}'.format(scores))
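
A small self-contained check of the 2D indexing convention described in the comments above; the scores are hypothetical, and the masking step also illustrates the top_k-with-float('-inf') behavior noted in example #5:

import numpy as np

scores = np.array([[0.3, 0.6, 0.1],
                   [0.2, 0.1, 0.7]])
top_k = 2

# Per-row argpartition, then pair repeated row indices with the flattened
# column indices so the result can index the 2D array (same recipe as above).
cols = np.argpartition(scores, -top_k, axis=-1)[:, -top_k:]
indices = (np.arange(cols.shape[0]).repeat(top_k), cols.flatten())

masked = np.full(scores.shape, float('-inf'))
masked[indices] = scores[indices]
print(masked)
# [[ 0.3  0.6 -inf]
#  [ 0.2 -inf  0.7]]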