def testConcatWithNonFullyDefinedElementShape(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=[None, 2])
  l = list_ops.tensor_list_push_back(l, [[0., 1.]])
  l = list_ops.tensor_list_push_back(l, [[2., 3.], [4., 5.]])
  t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[0., 1.], [2., 3.], [4., 5.]])
def testStack(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=scalar_shape())
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(t, [1.0, 2.0])
def testUnknownShape(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=-1)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0, 2.0]))
  _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(e, [1.0, 2.0])
def testPushInFullListFails(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=[], max_num_elements=1)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "Tried to push item into a full list"):
    l = list_ops.tensor_list_push_back(l, 2.)
    self.evaluate(l)
def _testStack(self, max_num_elements):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=scalar_shape(),
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
def testPushInEmptyListWithUnknownElementShape(self):
  with self.cached_session(), self.test_scope():
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
    l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
    # Pushing an element with a different shape should raise an error.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"):
      l = list_ops.tensor_list_push_back(l, 5.)
      self.evaluate(
          list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
def testSetDoesNotUpdatePushIndex(self):
  with self.cached_session(), self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
    # SetItem should not change the push index.
    l = list_ops.tensor_list_set_item(l, 1, 3.)
    l = list_ops.tensor_list_push_back(l, 5.)
    l = list_ops.tensor_list_push_back(l, 7.)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(t, [5., 7.])
def testConcatWithMismatchingTensorShapesFails(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=None)
  l = list_ops.tensor_list_push_back(l, [[0., 1.]])
  l = list_ops.tensor_list_push_back(l, [[2.], [4.]])
  with self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      r"Tried to concat tensors with unequal shapes: "
      r"\[2\] vs \[1\]"):
    t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
    self.evaluate(t)
def _testStack(self, max_num_elements):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[],
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  if not context.executing_eagerly():
    self.assertAllEqual(t.shape.as_list(), [None])
  self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
def testGatherGrad(self):
  with backprop.GradientTape() as tape:
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=scalar_shape())
    c0 = constant_op.constant(1.0)
    tape.watch(c0)
    l = list_ops.tensor_list_push_back(l, c0)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [2.0, 1.0])
    s = (t[0] + t[1]) * (t[0] + t[1])
  dt = tape.gradient(s, c0)
  self.assertAllEqual(self.evaluate(dt), 6.0)
def testStack(self):
  with self.cached_session(), self.test_scope():
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[], max_num_elements=2)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(e, 1.0)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(t.shape.as_list(), [None])
    self.assertAllEqual(t, [1.0, 2.0])
def testPushPop(self):
  with self.cached_session() as sess, self.test_scope():
    num = array_ops.placeholder(dtypes.int32)
    l = list_ops.tensor_list_reserve(
        element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(7, 15)))
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(2.0, shape=(7, 15)))
    l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15)))
    self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15)))
def testPushPop(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=(7, 15),
        element_dtype=dtypes.float32,
        max_num_elements=10)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(7, 15)))
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(2.0, shape=(7, 15)))
    l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15)))
    self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15)))
def testStackWithPartiallyDefinedElementShape(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=[-1])
  l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0]))
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[1.0], [2.0]])
  # Should raise an error when the element tensors do not all have the same
  # shape.
  with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could
  # have None incoming gradients for the TensorLists. If we pass None's
  # through, the custom gradient of TensorListPopBack will create an
  # EmptyTensorList inside the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as
  # zero gradient in certain cases. Consider replacing None gradients with
  # Zeros for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads, util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
def testDefaultGradYs(self):
  with ops.Graph().as_default():
    tl = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
    a = constant(1.0)
    tl = list_ops.tensor_list_push_back(tl, a)

    grad_tl = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
    grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))

    grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
    with self.cached_session() as sess:
      self.assertEquals(self.evaluate(grad), 5.)
def testConcatListWithScalarElementsFails(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=None)
  l1 = list_ops.tensor_list_push_back(l, 1.)
  with self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      "Concat saw a scalar shape at index 0"
      " but requires at least vectors"):
    t = list_ops.tensor_list_concat(l1, element_dtype=dtypes.float32)
    self.evaluate(t)
  l1 = list_ops.tensor_list_push_back(l, [1.])
  l1 = list_ops.tensor_list_push_back(l1, 2.)
  with self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      "Concat saw a scalar shape at index 1"
      " but requires at least vectors"):
    t = list_ops.tensor_list_concat(l1, element_dtype=dtypes.float32)
    self.evaluate(t)
def testPushPopSeparateLists(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=[], element_dtype=dtypes.float32, max_num_elements=20)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
    _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    result = sess.run([e11, [e21, e22], [e31, e32]])
    self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
def _testPushPop(self, max_num_elements):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[],
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), 1.0)
def testStackWithUnknownElementShape(self, max_num_elements):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=None,
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
  # Should raise an error when the element tensors do not all have the same
  # shape.
  with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
    l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)
def body(i, m, t1):
  t1 = control_flow_ops.cond(
      math_ops.equal(list_ops.tensor_list_length(t1), 0),
      lambda: list_ops.empty_tensor_list(m.shape, m.dtype),
      lambda: t1)
  t1 = list_ops.tensor_list_push_back(t1, m * i)
  i += 1.0
  return i, m, t1
def testPushPopSeparateLists(self):
  with self.cached_session() as sess, self.test_scope():
    num = array_ops.placeholder(dtypes.int32)
    l = list_ops.tensor_list_reserve(
        element_shape=scalar_shape(),
        num_elements=num,
        element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
    _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20})
    self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation."""
  if tensor_util.is_tensor(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    element_shape = array_ops.shape(elements)[1:]
    l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
    return l

  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype = tuple(all_dtypes)[0]
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif all_dtypes:
    # Heterogeneous lists are ok.
    if element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant
  else:
    inferred_dtype = dtypes.variant

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape = array_ops.shape(elements[0])
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  elif all_shapes:
    # Heterogeneous lists are ok.
    if element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention
  else:
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l
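# Illustrative usage sketch (not part of the original sources): when
# tf_tensor_list_new above receives a Tensor, the tensor is sliced along its
# first dimension into list elements. The helper name below is hypothetical;
# it assumes the same constant_op, list_ops and dtypes imports used by the
# surrounding snippets.
def _example_list_from_tensor():
  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  l = tf_tensor_list_new(t)
  # The first element of the staged list is the row [1.0, 2.0].
  return list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)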
def testDoNotConstantFoldVariants(self):
  with self.cached_session() as sess, self.test_scope():
    val = array_ops.placeholder(dtype=dtypes.float32)
    l = list_ops.empty_tensor_list(
        element_shape=(7, 15),
        element_dtype=dtypes.float32,
        max_num_elements=10)
    # Note: Pushing a Placeholder will force the constant folding code
    # to build a Const node with a DT_VARIANT output. This tests that XLA
    # passes a cf_consider_fn which prevents folding such nodes.
    l = list_ops.tensor_list_push_back(
        l, array_ops.fill(value=val, dims=(7, 15)))
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(2.0, shape=(7, 15)))
    l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15)))
    self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15)))
def testGatherWithUnknownElementShape(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=-1)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
  t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [2.0, 1.0])
  t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])
  # Should raise an error when the requested tensors do not all have the same
  # shape.
  with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
    t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
    self.evaluate(t)
def testGraphStack(self):
  with context.graph_mode(), self.test_session():
    tl = list_ops.empty_tensor_list(
        element_shape=constant_op.constant([1], dtype=dtypes.int32),
        element_dtype=dtypes.int32)
    tl = list_ops.tensor_list_push_back(tl, [1])
    self.assertAllEqual(
        list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
        [[1]])
def testEmptyTensorListMax(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=(10, 15),
        element_dtype=dtypes.float32,
        max_num_elements=2)
    l = list_ops.tensor_list_push_back(
        l, array_ops.fill(value=3.0, dims=(10, 15)))
    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15)))
def testEmptyTensorListNoMax(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=(7, 15), element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(7, 15)))
    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Set the max number of elements"):
      self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))
def testPushPopGradients(self):
  with backprop.GradientTape() as tape:
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[])
    c = constant_op.constant(1.0)
    tape.watch(c)
    l = list_ops.tensor_list_push_back(l, c)
    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    e = 2 * e
  self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)
def testGraphStack(self):
  with self.cached_session():
    tl = list_ops.empty_tensor_list(
        element_shape=constant_op.constant([1], dtype=dtypes.int32),
        element_dtype=dtypes.int32)
    tl = list_ops.tensor_list_push_back(tl, [1])
    self.assertAllEqual(
        self.evaluate(
            list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
        [[1]])
def testDoNotConstantFoldVariants(self):
  with self.session() as sess, self.test_scope():
    val = array_ops.placeholder(dtype=dtypes.float32)
    l = list_ops.empty_tensor_list(
        element_shape=(7, 15),
        element_dtype=dtypes.float32,
        max_num_elements=10)
    # Note: Pushing a Placeholder will force the constant folding code
    # to build a Const node with a DT_VARIANT output. This tests that XLA
    # passes a cf_consider_fn which prevents folding such nodes.
    l = list_ops.tensor_list_push_back(
        l, array_ops.fill(value=val, dims=(7, 15)))
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(2.0, shape=(7, 15)))
    l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15)))
    self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15)))
def testSerializeListWithMaxNumElements(self):
  if context.num_gpus():
    # TODO(b/119151861): Enable on GPU.
    return
  worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0]
  with ops.Graph().as_default(), session.Session(target=worker.target):
    with ops.device("/job:worker"):
      l = list_ops.empty_tensor_list(
          element_shape=None,
          element_dtype=dtypes.float32,
          max_num_elements=2)
      l = list_ops.tensor_list_push_back(l, 1.)
    with ops.device("/job:ps"):
      l_ps = array_ops.identity(l)
      l_ps = list_ops.tensor_list_push_back(l_ps, 2.)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Tried to push item into a full list"):
      with ops.device("/job:worker"):
        l_worker = array_ops.identity(l_ps)
        l_worker = list_ops.tensor_list_push_back(l_worker, 3.0)
        self.evaluate(l_worker)
def testEmptyTensorList(self):
  dim = 7
  with self.cached_session() as sess, self.test_scope():
    p = array_ops.placeholder(dtypes.int32)
    l = list_ops.empty_tensor_list(
        element_shape=(p, 15), element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(dim, 15)))
    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Use TensorListReserve instead"):
      self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15)))
def testGatherWithUnknownElementShape(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=-1)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
  t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [2.0, 1.0])
  t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])
  # Should raise an error when the requested tensors do not all have the same
  # shape.
  with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
    t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
    self.evaluate(t)
def testZerosLikeVariant(self):
  for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
                dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
                dtypes.float64, dtypes.complex64, dtypes.complex128,
                dtypes.bool):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.variant, element_shape=scalar_shape())

    sub_l = list_ops.empty_tensor_list(
        element_dtype=dtype, element_shape=scalar_shape())
    l = list_ops.tensor_list_push_back(l, sub_l)
    sub_l = list_ops.tensor_list_push_back(sub_l,
                                           math_ops.cast(1, dtype=dtype))
    l = list_ops.tensor_list_push_back(l, sub_l)
    sub_l = list_ops.tensor_list_push_back(sub_l,
                                           math_ops.cast(2, dtype=dtype))
    l = list_ops.tensor_list_push_back(l, sub_l)

    # l : [[],
    #      [1],
    #      [1, 2]]
    #
    # l_zeros : [[],
    #            [0],
    #            [0, 0]]
    l_zeros = array_ops.zeros_like(l)

    outputs = []
    for _ in range(3):
      l_zeros, out = list_ops.tensor_list_pop_back(
          l_zeros, element_dtype=dtypes.variant)
      outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype))

    # Note: `outputs` contains popped values so the order is reversed.
    self.assertAllEqual(self.evaluate(outputs[2]), [])
    self.assertAllEqual(
        self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype))
    self.assertAllEqual(
        self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype))
def testGatherWithPartiallyDefinedElementShape(self, max_num_elements):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[-1],
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([4.0, 5.0]))
  t = list_ops.tensor_list_gather(l, [0], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[1.0]])
  t = list_ops.tensor_list_gather(l, [1, 2], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[2.0, 3.0], [4.0, 5.0]])
  # Should raise an error when the requested tensors do not all have the same
  # shape.
  with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
    t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
    self.evaluate(t)
def testAddNNestedList(self):
  l1 = list_ops.tensor_list_from_tensor([1.0, 2.0], element_shape=[])
  l2 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
  l3 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
  l4 = list_ops.tensor_list_from_tensor([7.0, 8.0], element_shape=[])
  a = list_ops.empty_tensor_list(
      element_dtype=dtypes.variant, element_shape=[])
  a = list_ops.tensor_list_push_back(a, l1)
  a = list_ops.tensor_list_push_back(a, l2)
  b = list_ops.empty_tensor_list(
      element_dtype=dtypes.variant, element_shape=[])
  b = list_ops.tensor_list_push_back(b, l3)
  b = list_ops.tensor_list_push_back(b, l4)
  result = math_ops.add_n((a, b))
  result_0 = list_ops.tensor_list_stack(
      list_ops.tensor_list_get_item(result, 0, element_dtype=dtypes.variant),
      element_dtype=dtypes.float32)
  result_1 = list_ops.tensor_list_stack(
      list_ops.tensor_list_get_item(result, 1, element_dtype=dtypes.variant),
      element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(result_0), [6., 8.])
  self.assertAllEqual(self.evaluate(result_1), [10., 12.])
def testZerosLike(self):
  for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
                dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
                dtypes.float64, dtypes.complex64, dtypes.complex128,
                dtypes.bool):
    l_empty = list_ops.empty_tensor_list(
        element_dtype=dtype, element_shape=scalar_shape())
    l_empty_zeros = array_ops.zeros_like(l_empty)
    t_empty_zeros = list_ops.tensor_list_stack(
        l_empty_zeros, element_dtype=dtype)

    l_full = list_ops.tensor_list_push_back(l_empty,
                                            math_ops.cast(0, dtype=dtype))
    l_full = list_ops.tensor_list_push_back(l_full,
                                            math_ops.cast(1, dtype=dtype))
    l_full_zeros = array_ops.zeros_like(l_full)
    t_full_zeros = list_ops.tensor_list_stack(
        l_full_zeros, element_dtype=dtype)

    self.assertAllEqual(self.evaluate(t_empty_zeros), [])
    self.assertAllEqual(
        self.evaluate(t_full_zeros),
        np.zeros((2,), dtype=dtype.as_numpy_dtype))
def _tf_tensor_list_append(list_, x):
  """Overload of list_append that stages a Tensor list write."""
  def empty_list_of_elements_like_x():
    tensor_x = ops.convert_to_tensor(x)
    return list_ops.empty_tensor_list(
        element_shape=array_ops.shape(tensor_x),
        element_dtype=tensor_x.dtype)

  list_ = control_flow_ops.cond(
      list_ops.tensor_list_length(list_) > 0,
      lambda: list_,
      empty_list_of_elements_like_x,
  )
  return list_ops.tensor_list_push_back(list_, x)
def dynamic_list_append(target, element):
  """Converts a list append call inline."""
  if isinstance(target, tensor_array_ops.TensorArray):
    return target.write(target.size(), element)
  # TODO(mdan): What's the right way to check this?
  # TODO(mdan): We may not need this branch.
  # It may be possible to use TensorList alone if the loop body will not
  # require wrapping it, although we'd have to think about an autoboxing
  # mechanism for lists received as parameter.
  if isinstance(target, ops.Tensor):
    return list_ops.tensor_list_push_back(target, element)

  # Python targets (including TensorList): fallback to their original append.
  target.append(element)
  return target
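# Illustrative usage sketch (not part of the original sources): exercises the
# three dispatch paths of dynamic_list_append above. The helper name below is
# hypothetical; it assumes the same tensor_array_ops, list_ops, dtypes and
# constant_op imports used by the surrounding snippets.
def _example_dynamic_list_append():
  # A TensorArray target dispatches to TensorArray.write.
  ta = tensor_array_ops.TensorArray(dtypes.float32, size=0, dynamic_size=True)
  ta = dynamic_list_append(ta, constant_op.constant(1.0))
  # A variant Tensor holding a TensorList dispatches to tensor_list_push_back.
  tl = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=[])
  tl = dynamic_list_append(tl, constant_op.constant(2.0))
  # A plain Python list falls back to list.append.
  py_list = dynamic_list_append([], 3.0)
  return ta, tl, py_list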
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation."""
  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype = tuple(all_dtypes)[0]
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  else:
    # Heterogeneous lists are ok.
    if element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape = array_ops.shape(elements[0])
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  else:
    # Heterogeneous lists are ok.
    if element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l
def _tf_tensor_list_new(elements):
  """Overload of new_list that stages a Tensor list creation."""
  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    element_dtype = tuple(all_dtypes)[0]
  else:
    # Heterogeneous lists are ok.
    element_dtype = dtypes.variant

  # TODO(mdan): This may fail for elements of variable shapes.
  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    element_shape = array_ops.shape(elements[0])
  else:
    # Heterogeneous lists are ok.
    element_shape = constant_op.constant(-1)  # unknown shape, by convention

  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l
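# Illustrative usage sketch (not part of the original sources): shows how a
# list staged by _tf_tensor_list_new above can be read back with
# tensor_list_stack. The helper name below is hypothetical; it assumes the
# same list_ops, dtypes and constant_op imports used by the surrounding
# snippets.
def _example_stack_staged_list():
  l = _tf_tensor_list_new(
      [constant_op.constant(1.0), constant_op.constant(2.0)])
  # Both elements share dtype float32 and a scalar shape, so the list can be
  # stacked back into a dense [2] tensor holding [1.0, 2.0]. With mixed
  # dtypes, the list would instead be staged with element_dtype=dtypes.variant.
  return list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)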
def testCPUGPUCopyNested(self):
  if not context.num_gpus():
    return
  t = constant_op.constant([1.0, 2.0])
  child_l = list_ops.tensor_list_from_tensor(t, element_shape=[])
  l = list_ops.empty_tensor_list(
      element_shape=constant_op.constant([], dtype=dtypes.int32),
      element_dtype=dtypes.variant)
  l = list_ops.tensor_list_push_back(l, child_l)
  with context.device("gpu:0"):
    l_gpu = array_ops.identity(l)
    _, child_l_gpu = list_ops.tensor_list_pop_back(
        l_gpu, element_dtype=dtypes.variant)
    self.assertAllEqual(
        self.evaluate(
            list_ops.tensor_list_pop_back(
                child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
  l_cpu = array_ops.identity(l_gpu)
  _, child_l_cpu = list_ops.tensor_list_pop_back(
      l_cpu, element_dtype=dtypes.variant)
  self.assertAllEqual(
      self.evaluate(
          list_ops.tensor_list_pop_back(
              child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
def Body(x, tl):
  return x + 1, list_ops.tensor_list_push_back(tl, x)
def InnerBody(inner_x, outer_x, tl):
  return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)
def body(list_, m):
  list_ = control_flow_ops.cond(
      math_ops.equal(list_ops.tensor_list_length(list_), 0),
      lambda: list_ops.empty_tensor_list(m.shape, m.dtype),
      lambda: list_)
  list_ = list_ops.tensor_list_push_back(list_, m)
  return list_, m
def Body(x, tl):
  # There is an accumulator in the loop already so we should not add
  # another.
  tl = list_ops.tensor_list_push_back(tl, x)
  return x**2., tl
def testPushPop(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=scalar_shape())
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), 1.0)
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
  if captured_tensor is not None:
    return captured_tensor

  # Do not accumulate loop invariants.
  if (any(tensor is t for t in self._forward_graph.inputs) and
      any(tensor is t for t in self._forward_graph.outputs)):
    captured_tensor = super(_WhileBodyGradFuncGraph,
                            self)._capture_helper(tensor, name)
    # Add to `popped_tensor_lists` so that this gets added to the list of
    # outputs.
    # TODO(srbs): Rename popped_tensor_lists.
    self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    return captured_tensor

  # Do not accumulate Const nodes. Instead copy them directly in the backward
  # graph.
  # TODO(srbs): This just checks for `Const` nodes. Consider checking for
  # graph compile time consts in general.
  # TODO(srbs): Consider making this a loop input.
  if constant_op.is_constant(tensor):
    real_value = constant_op.constant(
        tensor_util.constant_value(tensor), dtype=tensor.dtype)
    self._indirect_captures[ops.tensor_id(tensor)] = real_value
    return real_value

  # Resource tensors are not accumulated and handled specially.
  if tensor.dtype == dtypes.resource:
    return self._resource_capture_helper(tensor)

  # No need to accumulate loop invariants. Capture them directly.
  # The captured tensor gets resolved to the corresponding while output in
  # `_resolve_grad_captures`.
  if _is_loop_invariant(tensor, self._forward_graph_inputs,
                        self._forward_graph_outputs):
    captured_tensor = super(_WhileBodyGradFuncGraph,
                            self)._capture_helper(tensor, name)
    return captured_tensor

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    #
    # Note: We clear the control dependencies to avoid a cycle in case a
    # control tensor has an input path to an output of the forward While.
    #
    # E.g.:
    # x = tf.while_loop(...)
    # y = f(x)
    # with tf.control_dependencies([y]):
    #   tf.gradients(y, x)
    #
    # Since the EmptyTensorList is fed back into the forward While, not
    # removing the control edge would cause a cycle.
    with self._forward_graph.outer_graph.as_default():
      with util.clear_control_inputs():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype,
            element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations,
            name=_build_accumulator_name(tensor))
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
      accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
  self.popped_tensor_lists[ops.tensor_id(
      captured_accumulator)] = new_tensor_list
  return captured_tensor
def append(self, value):
  self.list_ = list_ops.tensor_list_push_back(self.list_, value)
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None):
  """Like tf.while_loop, except emits a single While op."""
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    shape_invariants = type(shape_invariants)(
        [tensor_shape.scalar()]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + cond_graph.external_captures
    shape_invariants = shape_invariants + type(shape_invariants)(
        [t.shape for t in cond_graph.external_captures])

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      outputs = body(
          *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[len_orig_loop_vars:])

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)

    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(b/118457764): Dedup tensors that are captured in both the cond and
    # body. This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        assert external_capture not in cond_graph.captures, (
            "Looks like both cond and body are capturing the same tensor %s. "
            "This is not supported yet. For now consider passing,"
            " this as a loop variable." % str(external_capture))
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=intermediate_tensor.shape,
          max_num_elements=maximum_iterations)
      loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list, intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are
    # not specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the
    # output of the While op directly. This makes pruning work if the output
    # of while_loop() is fetched: the lowering pass converts the While
    # outputs into IdentityN outputs, which if fetched will cause all ops in
    # the body to be run (since it takes all exit ops as input). After
    # lowering, each output identity op will end up with only the appropriate
    # exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

    # First var is loop counter.
    outputs = _pack_sequence_as(orig_loop_vars,
                                outputs[1:1 + num_flattened_outputs])

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
      return flattened_outputs[0]
    else:
      return outputs
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handles to None. The gradient
  # implementation currently assumes all resource tensors correspond to
  # float32 ResourceVariables, which can lead to runtime shape errors when
  # used with a TensorArray. This is a workaround until TensorArrays are
  # reimplemented with TensorLists instead of resources.
  # Also set the incoming gradient of non-trainable inputs to None. It is
  # possible that we receive non-None gradients for non-trainable types in
  # nested while loops because we accumulate outputs of the inner while as
  # variant tensors which are trainable and hence receive zeros_like tensors
  # in the gradient pass. The non-trainable tensors then receive the popped
  # zeros tensor from this zeros variant. The gradient for the loop vars
  # corresponding to these tensors is None or zeros (this happens only if the
  # loop var is accumulated as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
  # output grads in _grad_fn.
  grads = [
      None if _is_tensor_array_handle(output) or
      not gradients_impl.IsTrainable(output) else grad
      for grad, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
            ), "All trainable loop vars must receive incoming gradients."
  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[
      (y, x, grad) for (y, x, grad) in zip(body_graph.outputs,
                                           body_graph.inputs, grads)
      if grad is not None
  ])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  assert not _is_in_xla_context() or maximum_iterations is not None

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=intermediate_tensor.shape,
        max_num_elements=maximum_iterations)

    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(outputs[0].op)
  _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]

  # Set None as the output gradient for tensors with None input gradient
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    return captured_tensor

  if tensor.dtype == dtypes.resource:
    # Resource-type tensors are not accumulated.
    # If a resource tensor exists in the loop body it must either be a loop
    # input or an output of a nested While op inside the loop body which
    # had captured the external resource.
    if tensor in self._forward_graph.inputs:
      index = self._forward_graph.inputs.index(tensor)
    elif tensor.op.type == "While":
      # Captured resources occur at the same index in the lists of inputs
      # and outputs of a while op. So we lookup the input of `tensor.op` at
      # the same index as the index of `tensor` in the `tensor.op.outputs`.
      index = self._forward_graph.inputs.index(
          tensor.op.inputs[tensor.value_index])
    else:
      raise ValueError(
          "Taking gradient of a while loop which creates"
          " a resource in its body is not supported: %s" % str(tensor))
    # This must be a loop invariant.
    assert self._forward_graph.inputs[index] == self._forward_graph.outputs[
        index], "Resource tensors must be loop invariants %s." % str(
            self._forward_graph._while.inputs[index])
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]
    self._indirect_captures[tensor] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[tensor]

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    with self._forward_graph.outer_graph.as_default():
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=tensor.dtype,
          element_shape=tensor.shape,
          max_num_elements=self._maximum_iterations)
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
      accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[captured_accumulator] = new_tensor_list
  return captured_tensor
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
  """Like tf.while_loop, except emits a single While op."""
  flattened_loop_vars = nest.flatten(loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(loop_vars, shape_invariants)
    flattened_shapes = nest.flatten(shape_invariants)
  else:
    flattened_shapes = [t.shape for t in flattened_loop_vars]

  del shape_invariants

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                          ] + flattened_loop_vars

    flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    cond_graph = function.func_graph_from_py_func(
        cond_name, wrapped_cond, flattened_loop_vars, {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(cond_name))

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
    flattened_shapes = flattened_shapes + [
        t.shape for t in cond_graph.external_captures
    ]

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list of tensors the same length as args.
      """
      outputs = body(*args[:num_outputs])
      if not isinstance(outputs, collections.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    body_graph = function.func_graph_from_py_func(
        body_name, wrapped_body, flattened_loop_vars, {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(body_name))

    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list, intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are
    # not specified.
    _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                         flattened_shapes[1:1 + num_outputs],
                         flattened_loop_vars[1:1 + num_outputs])
    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)
  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args.

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently
      # there is no nest.zip so we call `_pack_sequence_as` which flattens
      # both `orig_loop_vars` and `args`, converts flows in `args` to
      # TensorArrays and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph`
    # so that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are
    # not specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures
          # the `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) +
        list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index +
                                   num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra
      # outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the
    # output of the While op directly. This makes pruning work if the output
    # of while_loop() is fetched: the lowering pass converts the While outputs
    # into IdentityN outputs, which if fetched will cause all ops in the body
    # to be run (since it takes all exit ops as input). After lowering, each
    # output identity op will end up with only the appropriate exit op as
    # input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

    outputs = _pack_sequence_as(
        orig_loop_vars,
        outputs[first_loop_var_index:first_loop_var_index +
                num_flattened_outputs])

    if return_same_structure:
      return outputs

    flattened_outputs = nest.flatten(outputs, expand_composites=True)
    if len(flattened_outputs) == 1:
      return flattened_outputs[0]
    else:
      return outputs
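# A minimal usage sketch (not part of the source above): driving the
# while_loop defined above with a TensorList loop variable to accumulate
# values, mirroring the push_back/stack pattern used in the tests. The import
# paths, the assumption that this function lives in
# tensorflow.python.ops.while_v2, and the helper names are illustrative.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import while_v2


def _accumulate_cond(i, tl):
  del tl  # The predicate only looks at the counter.
  return i < 3.


def _accumulate_body(i, tl):
  # Push the current counter value into the TensorList loop variable.
  return i + 1., list_ops.tensor_list_push_back(tl, i)


with ops.Graph().as_default():
  i0 = constant_op.constant(0.)
  tl0 = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=[])
  _, tl_out = while_v2.while_loop(_accumulate_cond, _accumulate_body,
                                  [i0, tl0])
  # Stacking the accumulated list is expected to yield [0., 1., 2.].
  result = list_ops.tensor_list_stack(tl_out, element_dtype=dtypes.float32)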
def Body(x, tl):
  tl = list_ops.tensor_list_push_back(tl, x)
  tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
  return x**2., tl
def body(i, t1):
  t1 = list_ops.tensor_list_push_back(t1, i)
  i += 1
  return i, t1
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    return captured_tensor

  # Resource tensors are not accumulated and handled specially.
  if tensor.dtype == dtypes.resource:
    return self._resource_capture_helper(tensor)

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    #
    # Note: We clear the control dependencies to avoid a cycle in case a
    # control tensor has an input path to an output of the forward While.
    #
    # E.g.:
    # x = tf.while_loop(...)
    # y = f(x)
    # with tf.control_dependencies([y]):
    #   tf.gradients(y, x)
    #
    # Since the EmptyTensorList is fed back into the forward While, not
    # removing the control edge would cause a cycle.
    with self._forward_graph.outer_graph.as_default():
      with util.clear_control_inputs():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype,
            element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations,
            name=_build_accumulator_name(tensor))
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph,
                               self)._capture_helper(accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[captured_accumulator] = new_tensor_list
  return captured_tensor
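# A hedged sketch (not from the source above) of the scenario _capture_helper
# handles: differentiating through a while_v2 loop whose body produces an
# intermediate (`exp(v)`) that the backward pass needs, which forces the
# machinery above to accumulate that value in a TensorList in the forward
# graph. The import paths are assumptions for illustration.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import while_v2


def _fwd_body(v):
  # exp(v) is an intermediate of the body (its gradient needs the exp
  # output), so the gradient pass asks _capture_helper for it and an
  # accumulator is created in the forward While.
  return math_ops.exp(v) * 2.


with ops.Graph().as_default():
  x = constant_op.constant(0.)
  y = while_v2.while_loop(lambda v: v < 16., _fwd_body, [x],
                          return_same_structure=False)
  dy_dx, = gradients_impl.gradients(y, x)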
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handles to None.
  # TODO(b/118164915): We need a way of distinguishing between TensorArray
  # resource handles and ResourceVariables and set the default gradient of
  # only the TensorArray handle to None.
  grads = [
      None if output.dtype == dtypes.resource else g
      for g, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
             ), "All trainable loop vars must receive incoming gradients."
  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(
            intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # Set None as the output gradient for tensors with None input gradient,
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could
  # have None incoming gradients for the TensorLists. If we pass None's
  # through, the custom gradient of TensorListPopBack will create an
  # EmptyTensorList inside the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as
  # zero gradient in certain cases. Consider replacing None gradients with
  # zeros for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads, util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(
            intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = function.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
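# A minimal sketch, assuming the standard registration mechanism: a gradient
# function like _WhileGrad is attached to the While op via
# ops.RegisterGradient. The exact op type string(s) registered in the real
# source are an assumption for illustration.
from tensorflow.python.framework import ops

ops.RegisterGradient("While")(_WhileGrad)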