def _watch(self, primals, tangents):
    """Make this accumulator trace `primals`, seeded with `tangents`.

    Forward mode computes Jacobian-vector products (JVPs) while the
    computation runs. `tangents` is the vector that right-multiplies the
    Jacobian of the function computed while this accumulator is active, so
    it has to be supplied before that computation happens.

    Watching the same tensor more than once sums the tangents given for it;
    a tensor that is never watched has a zero tangent.

    A primal whose dtype is not floating point is not rejected: a warning is
    emitted through `logging.log_first_n` (limited to 5 occurrences) and the
    primal is watched anyway.

    Args:
      primals: A Tensor or list of Tensors.
      tangents: A Tensor or list of Tensors with the same structure as
        `primals` (checked with `nest.assert_same_structure`). Each tangent
        is converted to a tensor with its primal's dtype.
    """
    nest.assert_same_structure(primals, tangents)
    flat_primals = nest.flatten(primals)
    flat_tangents = nest.flatten(tangents)
    for primal, raw_tangent in zip(flat_primals, flat_tangents):
        dtype = primal.dtype
        if not dtype.is_floating:
            logging.log_first_n(
                logging.WARN,
                "The dtype of the watched primal must be "
                "floating (e.g. tf.float32), got %r",
                5,
                dtype,
            )
        tangent_tensor = ops.convert_to_tensor(raw_tangent, dtype=dtype)
        # For primals exposing a `handle`, watch the handle instead;
        # convert_to_tensor picks up the handle as captured by whichever
        # function is currently running, if necessary.
        watched = (
            ops.convert_to_tensor(primal.handle)
            if hasattr(primal, "handle")
            else primal
        )
        pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(
            self._accumulator, watched, tangent_tensor
        )
Esempio n. 2
0
 def _watch(primal, tangent):
     """Register a single `primal`/`tangent` pair with the accumulator.

     `self` is not a parameter: it is read from the enclosing scope.
     Presumably that is the accumulator method defining this helper; confirm
     against the surrounding code.

     A non-floating `primal` dtype only triggers a warning via
     `logging.log_first_n` (limited to 5 occurrences); the pair is still
     registered. `tangent` is converted to a tensor of the primal's dtype.
     A primal exposing a `handle` is registered through that handle instead.

     Args:
       primal: The tensor to watch.
       tangent: The tangent value paired with `primal`.
     """
     dtype = primal.dtype
     if not dtype.is_floating:
         logging.log_first_n(
             logging.WARN,
             "The dtype of the watched primal must be "
             "floating (e.g. tf.float32), got %r",
             5,
             dtype,
         )
     tangent_tensor = ops.convert_to_tensor(tangent, dtype=dtype)
     watched = primal
     if hasattr(watched, "handle"):
         # convert_to_tensor yields the handle as captured by whichever
         # function is currently running, if necessary.
         watched = ops.convert_to_tensor(watched.handle)
     pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(
         self._accumulator, watched, tangent_tensor
     )