def _convert_inputs_to_signature(inputs, input_signature,
                                 flat_input_signature):
    """Converts inputs to pass into a function with an explicit signature.

    Each flattened input is paired with its spec. If the spec is a
    `tensor_spec.TensorSpec` and the value is not already a tensor, the value
    is converted with `ops.convert_to_tensor`, using the spec's dtype as a
    hint. Every flattened input is then checked for compatibility with its
    spec.

    Args:
      inputs: Sequence of Python positional inputs. Only the first
        `len(input_signature)` entries are matched against the signature.
      input_signature: (Nested) structure of specs that the inputs must match.
      flat_input_signature: Flat sequence of specs, zipped element-wise with
        the flattened inputs. Presumably this is `input_signature` flattened
        with `expand_composites=True` -- TODO confirm at call sites.

    Returns:
      A 3-tuple `(inputs, flat_inputs, tensor_inputs)`:
        inputs: the original `inputs` if no conversion was needed; otherwise
          the converted values re-packed into the structure of
          `input_signature`.
        flat_inputs: `inputs` flattened with `expand_composites=True`.
        tensor_inputs: the entries of `flat_inputs` that are `ops.Tensor` or
          `resource_variable_ops.BaseResourceVariable` instances.

    Raises:
      ValueError: If the structure of `inputs` does not match
        `input_signature`, if a value paired with a `TensorSpec` cannot be
        converted to a tensor, or if any flattened input is incompatible with
        its spec. When the failure comes from an underlying `ValueError`, that
        error is chained as `__cause__`.
    """

    def format_error_message(inputs, input_signature):
        # List both sides one entry per line so mismatches are easy to spot.
        return ("  inputs: (\n" + "    " +
                ",\n    ".join(str(i) for i in inputs) + ")\n" +
                "  input_signature: (\n" + "    " +
                ",\n    ".join(str(i) for i in input_signature) + ")")

    try:
        flatten_inputs = nest.flatten_up_to(
            input_signature,
            inputs[:len(input_signature)],
            expand_composites=True,
            check_types=False)  # lists are converted to tuples for `tf.data`.
    except ValueError as e:
        # Chain explicitly so the traceback reports the structural mismatch as
        # the direct cause rather than as "another exception occurred".
        raise ValueError(
            "Structure of Python function inputs does not match "
            "input_signature:\n"
            f"{format_error_message(inputs, input_signature)}.") from e

    need_packing = False
    for index, (value, spec) in enumerate(
            zip(flatten_inputs, flat_input_signature)):
        if (isinstance(spec, tensor_spec.TensorSpec)
                and not _pywrap_utils.IsTensor(value)):
            try:
                flatten_inputs[index] = ops.convert_to_tensor(
                    value, dtype_hint=spec.dtype)
            except ValueError as e:
                raise ValueError(
                    "When input_signature is provided, all inputs to "
                    "the Python function must be convertible to "
                    "tensors:\n"
                    f"{format_error_message(inputs, input_signature)}."
                ) from e
            # At least one value was replaced, so `inputs` must be rebuilt.
            need_packing = True

    if any(not spec.is_compatible_with(other)
           for spec, other in zip(flat_input_signature, flatten_inputs)):
        raise ValueError("Python inputs incompatible with input_signature:\n"
                         f"{format_error_message(inputs, input_signature)}.")

    if need_packing:
        # NOTE(review): the packed structure is `input_signature`, so any
        # inputs beyond `len(input_signature)` (excluded by the slice above)
        # are not part of the result. Presumably callers never pass extras --
        # confirm.
        inputs = nest.pack_sequence_as(structure=input_signature,
                                       flat_sequence=flatten_inputs,
                                       expand_composites=True)

    flat_inputs = nest.flatten(inputs, expand_composites=True)

    return (inputs, flat_inputs, [
        t for t in flat_inputs
        if isinstance(t, (ops.Tensor,
                          resource_variable_ops.BaseResourceVariable))
    ])
Beispiel #2
0
  def watch(self, tensor):
    """Marks every tensor or variable in `tensor` as traced by this tape.

    `tensor` is flattened with `expand_composites=True`. A warning is logged
    for any element that `backprop_util.IsTrainable` rejects. Elements that
    expose a `handle` attribute are registered via `tape.watch_variable`;
    all other elements go through `tape.watch`.

    Args:
      tensor: a Tensor or variable, or a (possibly nested) structure of them.

    Raises:
      ValueError: if a flattened element is neither a tensor nor a variable.
        Elements that precede it in flattened order have already been
        watched at that point.
    """
    for item in nest.flatten(tensor, expand_composites=True):
      is_tensor_like = (_pywrap_utils.IsTensor(item) or
                        _pywrap_utils.IsVariable(item))
      if not is_tensor_like:
        raise ValueError("Passed in object of type {}, not tf.Tensor".format(
            type(item)))
      if not backprop_util.IsTrainable(item):
        logging.log_first_n(
            logging.WARN, "The dtype of the watched tensor must be "
            "floating (e.g. tf.float32), got %r", 5, item.dtype)
      # Variable-like objects currently all carry a `handle` attribute that
      # points to a tensor, and watch_variable's internals rely on that. If
      # this ever changes, watch_variable must change as well.
      register = (tape.watch_variable if hasattr(item, "handle")
                  else tape.watch)
      register(self._tape, item)