Example #1
0
    def get_features_predictions_labels_dicts(
        self
    ) -> Tuple[types.TensorTypeMaybeDict, types.TensorTypeMaybeDict,
               types.TensorTypeMaybeDict]:
        """Returns features, predictions, labels dictionaries (or values).

        Each returned dictionary holds references to the underlying nodes, so
        the results can be used to construct new metrics in the same way
        metrics can be constructed in the Trainer.

        Returns:
          A tuple ``(features, predictions, labels)``. Each element is a
          freshly built dictionary copied from the corresponding internal map,
          passed through ``util.extract_tensor_maybe_dict`` so it is unnested
          if it wasn't a dictionary to begin with.
        """

        def _copy_and_unnest(name, node_map):
            # Build a new dict from the map's items (never hand out the
            # internal map itself), then unnest if it wasn't a dictionary to
            # begin with.
            return util.extract_tensor_maybe_dict(name, dict(node_map.items()))

        features = _copy_and_unnest(constants.FEATURES_NAME,
                                    self._features_map)
        predictions = _copy_and_unnest(constants.PREDICTIONS_NAME,
                                       self._predictions_map)
        labels = _copy_and_unnest(constants.LABELS_NAME, self._labels_map)
        return features, predictions, labels
Example #2
0
    def predict_list(
            self, inputs: MultipleInputFeedType) -> List[FetchedTensorValues]:
        """Like predict, but takes a list of inputs.

        Args:
          inputs: A list of input data (or a dict of keys to lists of input
            data). See predict for more details.

        Returns:
           A list of FetchedTensorValues, one per entry of the input_refs
           returned when feeding the inputs. See predict for more details.

        Raises:
          ValueError: If the input_refs returned after feeding the inputs is
            not a 1-D integer numpy array, does not align with the features,
            predictions and labels returned after feeding the inputs, or
            contains an index outside the range of the fed inputs.
        """
        if isinstance(inputs, dict):
            # Only add values for keys that are in the input map (in order).
            input_args = [
                inputs[key] for key in self._input_map if key in inputs
            ]
        else:
            input_args = [inputs]

        # input_refs index into the batch of fed inputs. For a dict input the
        # batch is each per-key list, NOT the dict itself (whose len() is the
        # number of keys). The per-key lists are presumably all the same
        # length, so the first one is used as the batch size.
        num_inputs = len(input_args[0]) if input_args else 0

        (features, predictions, labels, input_refs,
         additional_fetches) = self._predict_list_fn(*input_args)

        # Validate input_refs up front, before any per-example splitting work.
        if (not isinstance(input_refs, np.ndarray) or input_refs.ndim != 1
                or not np.issubdtype(input_refs.dtype, np.integer)):
            raise ValueError(
                'input_refs should be an 1-D array of integers. input_refs was {}.'
                .format(input_refs))

        # Copy so the dict returned by _predict_list_fn is not mutated.
        all_fetches = dict(additional_fetches)
        all_fetches[constants.FEATURES_NAME] = features
        all_fetches[constants.LABELS_NAME] = labels
        all_fetches[constants.PREDICTIONS_NAME] = predictions

        # TODO(cyfoo): Optimise this.
        split_fetches = {
            group: {key: util.split_tensor_value(tensors[key])
                    for key in tensors}
            for group, tensors in all_fetches.items()
        }

        batch_size = input_refs.shape[0]
        for group, tensors in split_fetches.items():
            for result_key, split_values in tensors.items():
                if len(split_values) != batch_size:
                    raise ValueError(
                        'input_refs should be batch-aligned with fetched values; {} key '
                        '{} had {} slices but input_refs had batch size of {}'.
                        format(group, result_key, len(split_values),
                               batch_size))

        result = []
        for i, input_ref in enumerate(input_refs):
            if input_ref < 0 or input_ref >= num_inputs:
                raise ValueError(
                    'An index in input_refs is out of range: {} vs {}; '
                    'inputs: {}'.format(input_ref, num_inputs, inputs))
            # Take the i-th slice of every fetched tensor, regrouped by fetch
            # group and unnested if it wasn't a dictionary to begin with.
            values = {
                group: util.extract_tensor_maybe_dict(
                    group, {key: split_value[i]
                            for key, split_value in split_tensors.items()})
                for group, split_tensors in split_fetches.items()
            }
            result.append(
                FetchedTensorValues(input_ref=input_ref, values=values))

        return result