def _batch_accumulator(cls, primals, tangents):
    """Factory constructor to test accumulator on batches of tangents.

    The instance is allocated through `__new__` directly, so `__init__` is
    not run; the underlying accumulator is created here with `True` passed
    to `TFE_Py_ForwardAccumulatorNew` (the regular constructor passes
    `False`).

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same
        nesting structure as `primals`, with each element being a vector
        with compatible shape `[None] + primal.shape` of the corresponding
        primal element.

    Returns:
      A batch accumulator object.

    Raises:
      ValueError: If the same tensor or variable is specified multiple
        times in `primals`.
    """
    acc = super(ForwardAccumulator, cls).__new__(cls, primals, tangents)
    acc._recording = False
    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
    # Duplicates are detected by object identity, not by value equality.
    primal_ids = set()
    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
        # Each tangent carries one extra leading (unknown-size) dimension
        # in front of the primal's own shape.
        tangent.shape.assert_is_compatible_with(
            tensor_shape.TensorShape([None]) + primal.shape)
        if id(primal) in primal_ids:
            # Fill in the placeholder so the message names the offending
            # primal instead of printing a literal "{}".
            raise ValueError(
                "Tensor {} was specified as a primal multiple times. This may "
                "indicate an error. If it was intended, please sum the "
                "corresponding tangents.".format(primal))
        primal_ids.add(id(primal))
    acc._watch(primals, tangents)
    return acc
def __init__(self, primals, tangents):
    """Specify tensors to watch and their Jacobian-vector products.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian
    matrix (a Jacobian-vector product) for the function computed while this
    accumulator is active. Since JVPs are computed in forward mode as the
    computation happens, this vector must be supplied in advance.

    Listing a single tensor multiple times in `primals` raises an
    exception. Excluding a tensor from `primals` is equivalent to watching
    it with a tangent tensor of zeros.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same
        nesting structure as `primals`, with each element being a vector
        with the same size as the corresponding primal element.

    Raises:
      ValueError: If the same tensor or variable is specified multiple
        times in `primals`.
    """
    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
    self._recording = False
    # Duplicates are detected by object identity, not by value equality.
    primal_ids = set()
    for primal in nest.flatten(primals):
        if id(primal) in primal_ids:
            # Fill in the placeholder so the message names the offending
            # primal instead of printing a literal "{}".
            raise ValueError(
                "Tensor {} was specified as a primal multiple times. This may "
                "indicate an error. If it was intended, please sum the "
                "corresponding tangents.".format(primal))
        primal_ids.add(id(primal))
    self._watch(primals, tangents)