def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
    """Coerces call arguments so they conform to an explicit input signature.

    The leading `len(input_signature)` entries of `inputs` are flattened
    against `input_signature`. Every flattened value whose spec is a
    `tensor_spec.TensorSpec` and which is not already a tensor is converted
    with `ops.convert_to_tensor`, using the spec's dtype as a hint. If at
    least one value was converted, `inputs` is rebuilt from the converted
    flat values using the structure of `input_signature`.

    Args:
      inputs: The Python arguments supplied by the caller.
      input_signature: The (possibly nested) structure of specs to match.
      flat_input_signature: The flattened specs, paired positionally with the
        flattened inputs.

    Returns:
      A 3-tuple `(inputs, flat_inputs, filtered_flat_inputs)`:
      the (possibly repacked) inputs, their flattening with composites
      expanded, and the subset of that flattening made of `ops.Tensor` or
      `resource_variable_ops.BaseResourceVariable` instances.

    Raises:
      ValueError: If the structure of `inputs` does not match
        `input_signature`, if a value cannot be converted to a tensor, or if
        a flattened value is incompatible with its spec.
    """

    def describe(call_args, specs):
        # Renders both sequences, one element per line, for error messages.
        rendered_args = ",\n ".join(str(item) for item in call_args)
        rendered_specs = ",\n ".join(str(item) for item in specs)
        return (" inputs: (\n" + " " + rendered_args + ")\n" +
                " input_signature: (\n" + " " + rendered_specs + ")")

    # Only the leading arguments covered by the signature are matched.
    covered_inputs = inputs[:len(input_signature)]
    try:
        # check_types=False: lists are converted to tuples for `tf.data`.
        flatten_inputs = nest.flatten_up_to(
            input_signature,
            covered_inputs,
            expand_composites=True,
            check_types=False)
    except ValueError:
        raise ValueError("Structure of Python function inputs does not match "
                         "input_signature:\n"
                         f"{describe(inputs, input_signature)}.")

    need_packing = False
    pairs = zip(flatten_inputs, flat_input_signature)
    for position, (candidate, spec) in enumerate(pairs):
        # Skip values that need no conversion: either the spec is not a
        # TensorSpec, or the value is already a tensor.
        if (not isinstance(spec, tensor_spec.TensorSpec) or
                _pywrap_utils.IsTensor(candidate)):
            continue
        try:
            flatten_inputs[position] = ops.convert_to_tensor(
                candidate, dtype_hint=spec.dtype)
        except ValueError:
            raise ValueError(
                "When input_signature is provided, all inputs to "
                "the Python function must be convertible to "
                "tensors:\n"
                f"{describe(inputs, input_signature)}.")
        else:
            # A replaced value means `inputs` must be rebuilt below.
            need_packing = True

    all_compatible = all(
        spec.is_compatible_with(other)
        for spec, other in zip(flat_input_signature, flatten_inputs))
    if not all_compatible:
        raise ValueError("Python inputs incompatible with input_signature:\n"
                         f"{describe(inputs, input_signature)}.")

    if need_packing:
        inputs = nest.pack_sequence_as(
            structure=input_signature,
            flat_sequence=flatten_inputs,
            expand_composites=True)

    flat_inputs = nest.flatten(inputs, expand_composites=True)
    tensor_like_types = (ops.Tensor,
                         resource_variable_ops.BaseResourceVariable)
    filtered_flat_inputs = [
        item for item in flat_inputs if isinstance(item, tensor_like_types)
    ]
    return (inputs, flat_inputs, filtered_flat_inputs)
def watch(self, tensor):
    """Ensures that `tensor` is being traced by this tape.

    The argument is flattened with `nest.flatten` (composites expanded) and
    each leaf is registered with the underlying tape individually.

    Args:
      tensor: a Tensor or Variable, or a nested structure (e.g. a list) of
        them.

    Raises:
      ValueError: if it encounters a leaf that is neither a tensor nor a
        variable.
    """
    for leaf in nest.flatten(tensor, expand_composites=True):
        is_watchable = (_pywrap_utils.IsTensor(leaf) or
                        _pywrap_utils.IsVariable(leaf))
        if not is_watchable:
            raise ValueError(
                "Passed in object of type {}, not tf.Tensor".format(
                    type(leaf)))

        if not backprop_util.IsTrainable(leaf):
            # Warn but still watch the leaf; the numeric argument is
            # presumably the cap on how many times this message is logged.
            logging.log_first_n(
                logging.WARN, "The dtype of the watched tensor must be "
                "floating (e.g. tf.float32), got %r", 5, leaf.dtype)

        # There are many variable-like objects, all of them currently have a
        # `handle` attribute that points to a tensor. If this changes,
        # internals of watch_variable need to change as well.
        register = (tape.watch_variable
                    if hasattr(leaf, "handle") else tape.watch)
        register(self._tape, leaf)