def watch(self, tensor):
    """Registers `tensor` with this tape so that it is traced.

    The argument is flattened with `nest.flatten`, and each resulting element
    is checked and then registered individually, in order.

    Args:
      tensor: a Tensor or variable, or a nested structure (e.g. a list) of
        them.

    Raises:
      ValueError: if a flattened element is neither a tensor nor a variable.
        Elements that precede the offending one have already been registered
        with the tape by the time this is raised.
    """
    for item in nest.flatten(tensor):
        traceable = (pywrap_tensorflow.IsTensor(item)
                     or pywrap_tensorflow.IsVariable(item))
        if not traceable:
            raise ValueError(
                "Passed in object of type {}, not tf.Tensor".format(
                    type(item)))

        # A non-floating dtype is only reported (rate-limited through
        # `log_first_n`); the element is still registered below.
        if not item.dtype.is_floating:
            logging.log_first_n(
                logging.WARN, "The dtype of the watched tensor must be "
                "floating (e.g. tf.float32), got %r", 5, item.dtype)

        # Variable-like objects are recognised by their `handle` attribute,
        # which points to a tensor. Every such object currently has one; if
        # that stops being true, the internals of watch_variable must be
        # updated as well.
        register = (tape.watch_variable if hasattr(item, "handle")
                    else tape.watch)
        register(self._tape, item)
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
    """Returns the constant value of `tensor` when it can be computed cheaply.

    The tensor is partially evaluated; on success its value is handed back as
    a numpy ndarray.

    Compatibility(V1): once `constant_value(tensor)` has produced a non-`None`
    result, a different value can no longer be fed for `tensor`. This lets the
    result influence the graph under construction and enables static shape
    optimizations.

    Args:
      tensor: The Tensor to be evaluated. An eager tensor is converted with
        its `numpy()` method; an object that `pywrap_tensorflow.IsTensor`
        does not recognise is returned unchanged.
      partial: If True, the returned numpy array is allowed to have partially
        evaluated values. Values that can't be evaluated will be None.

    Returns:
      A numpy ndarray containing the constant value of the given `tensor`,
      or None if it cannot be calculated. Non-tensor inputs are returned
      as-is.

    Raises:
      TypeError: NOTE(review): earlier documentation listed this for inputs
        that are not an `ops.Tensor`. Nothing in this function raises it
        directly, so it is presumably propagated from `_ConstantValue` --
        confirm against that helper.
    """
    # Eager tensors already carry a concrete value.
    if isinstance(tensor, ops.EagerTensor):
        return tensor.numpy()

    # Anything that is not a tensor is passed straight through.
    if not pywrap_tensorflow.IsTensor(tensor):
        return tensor

    value = _ConstantValue(tensor, partial)
    if value is None:
        return None

    # The caller may now rely on this constant value, so conservatively
    # forbid feeding a different value for `tensor`.
    tensor.graph.prevent_feeding(tensor)
    return value