Example #1
0
def _create_feed_dict_from_input_data(
        input_data: repr_dataset.RepresentativeSample,
        signature_def: meta_graph_pb2.SignatureDef) -> Dict[str, np.ndarray]:
    """Constructs a feed_dict from input data.

    Note: This function should only be used in graph mode.

    This is a helper function that converts an 'input key -> input value'
    mapping to a feed dict. A feed dict is an 'input tensor name -> input value'
    mapping and can be directly passed to the `feed_dict` argument of
    `sess.run()`.

    Args:
      input_data: Input key -> input value mapping. The input keys should match
        the input keys of `signature_def`.
      signature_def: A SignatureDef representing the function that `input_data`
        is an input to.

    Returns:
      Feed dict, which is intended to be used as input for `sess.run`. It is
      essentially a mapping: input tensor name -> input value. `Tensor` values
      in `input_data` are replaced by the result of `.eval()`, so the input
      values in the feed dict are not `Tensor`s.

    Raises:
      KeyError: If a key of `input_data` is not an input key of
        `signature_def`.
    """
    feed_dict = {}
    for input_key, input_value in input_data.items():
        # `signature_def.inputs` is a protobuf message map. Indexing it with an
        # unknown key silently inserts an empty TensorInfo, which mutates
        # `signature_def`, and the lookup yields an empty tensor name. Check
        # membership first so that a bad key fails loudly instead.
        if input_key not in signature_def.inputs:
            raise KeyError(
                f'Invalid input key: {input_key}. The SignatureDef expects '
                f'input keys of: {set(signature_def.inputs.keys())}.')
        input_tensor_name = signature_def.inputs[input_key].name

        value = input_value
        if isinstance(input_value, core.Tensor):
            # Take the data out of the tensor.
            value = input_value.eval()

        feed_dict[input_tensor_name] = value

    return feed_dict
Example #2
0
    def validator(
        sample: repr_dataset.RepresentativeSample
    ) -> repr_dataset.RepresentativeSample:
        """Checks that a single representative sample is well-formed.

        A well-formed sample is a `collections.abc.Mapping` of
        {input_key: input_value} whose key set equals `expected_input_keys`.

        Args:
          sample: The `RepresentativeSample` to check.

        Returns:
          `sample` itself, unchanged, when it passes both checks.

        Raises:
          ValueError: If `sample` is not a `collections.abc.Mapping`.
          KeyError: If the set of keys of `sample` differs from
            `expected_input_keys`.
        """
        is_mapping = isinstance(sample, collections.abc.Mapping)
        if not is_mapping:
            raise ValueError(
                'Invalid representative sample type. Provide a mapping '
                '(usually a dict) of {input_key: input_value}. '
                f'Got type: {type(sample)} instead.')

        actual_keys = set(sample.keys())
        if actual_keys != expected_input_keys:
            raise KeyError(
                'Invalid input keys for representative sample. The function expects '
                f'input keys of: {set(expected_input_keys)}. '
                f'Got: {actual_keys}. Please provide correct input keys for '
                'representative samples.')

        return sample
Example #3
0
def _convert_values_to_tf_tensors(
        sample: repr_dataset.RepresentativeSample
) -> Mapping[str, core.Tensor]:
    """Returns a copy of `sample` whose values are all Tensors.

    Values that are already `Tensor`s are reused as-is. Every other value is
    converted with `ops.convert_to_tensor_v2_with_dispatch`. A new dict is
    built, so `sample` itself is not mutated.

    Args:
      sample: A representative sample, which is a map of
        {name -> tensorlike value}.

    Returns:
      A new map of {name -> tensor} with the same keys as `sample`.
    """
    def _as_tensor(value):
        # Tensors pass through untouched; anything else is converted.
        if isinstance(value, core.Tensor):
            return value
        return ops.convert_to_tensor_v2_with_dispatch(value)

    return {name: _as_tensor(value) for name, value in sample.items()}
Example #4
0
def _contains_tensor(sample: repr_dataset.RepresentativeSample) -> bool:
    """Determines whether `sample` contains any tf.Tensors.

    Args:
      sample: A `RepresentativeSample`, i.e. a mapping of
        {input_key: input_value}.

    Returns:
      True iff at least one value of `sample` is a `tf.Tensor`.
    """
    return any(isinstance(value, core.Tensor) for value in sample.values())