def testTensorListFromTensor(self):
  t = constant_op.constant([1.0, 2.0])
  l = list_ops.tensor_list_from_tensor(t, element_shape=[])
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), 2.0)
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), 1.0)
  self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0)
def testUnknownShape(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=None)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0, 2.0]))
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), [1.0, 2.0])
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), 1.0)
def testListFromTensor(self):
  with self.cached_session(), self.test_scope():
    t = constant_op.constant([1.0, 2.0])
    l = list_ops.tensor_list_from_tensor(t, element_shape=[])
    e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(e, 1.0)
    l, e0 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(e0, 2.0)
    l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(e1, 1.0)
    self.assertAllEqual(list_ops.tensor_list_length(l), 0)
def testAccumulatorElementShape(self, shape):

  def MatchShape(actual_tensor_shape):
    # Compare the shapes, treating None dimensions as equal. We do not
    # directly check actual_tensor_shape and tf.TensorShape(shape) for
    # equality because tf.Dimension.__eq__ returns None if either dimension
    # is None.
    if shape is None:
      self.assertIsNone(actual_tensor_shape.dims)
    else:
      self.assertListEqual(actual_tensor_shape.as_list(), shape)

  def GetAccumulatorForInputAtIndex(while_op, idx):
    body_graph = while_v2._get_graph(while_op, "body")
    y_input_t = body_graph.inputs[idx]
    push_back_node = [c for c in y_input_t.consumers()
                      if c.type == "TensorListPushBack"][0]
    output_idx = body_graph.outputs.index(push_back_node.outputs[0])
    return while_op.outputs[output_idx]

  x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
  y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

  # Forward pass.
  ret = while_loop_v2(lambda v, u: v < 8.,
                      lambda v, u: (math_ops.pow(v, u), u), [x, y],
                      return_same_structure=True)
  while_op = ret[0].op.inputs[0].op
  # Gradient pass.
  grad = gradients_impl.gradients(ret[0], x)
  # Note: There is an Identity b/w grad[0] and the While op.
  grad_while_op = grad[0].op.inputs[0].op

  # Get the TensorList output of While op containing the accumulated values
  # of y.
  x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
  output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
  _, val = list_ops.tensor_list_pop_back(output,
                                         element_dtype=dtypes.float32)
  MatchShape(val.shape)

  # Take second derivative to generate intermediate grad_while_op outputs
  gradients_impl.gradients(grad, x)

  # Get the TensorList output of gradient While op containing the accumulated
  # values of grad_x (note that grad_x is needed by the second derivative).
  # grad_while_op.inputs:
  grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
  grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                              grad_output_index)
  _, val = list_ops.tensor_list_pop_back(grad_output,
                                         element_dtype=dtypes.float32)
  MatchShape(val.shape)
def testPushPop(self):
  with self.cached_session() as sess, self.test_scope():
    num = array_ops.placeholder(dtypes.int32)
    l = list_ops.tensor_list_reserve(
        element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(7, 15)))
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(2.0, shape=(7, 15)))
    l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15)))
    self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15)))
def testPushPop(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=(7, 15),
        element_dtype=dtypes.float32,
        max_num_elements=10)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(7, 15)))
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(2.0, shape=(7, 15)))
    l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15)))
    self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15)))
def testAccumulatorElementShape(self, shape):

  def MatchShape(actual_tensor_shape):
    # Compare the shapes, treating None dimensions as equal. We do not
    # directly check actual_tensor_shape and tf.TensorShape(shape) for
    # equality because tf.Dimension.__eq__ returns None if either dimension
    # is None.
    if shape is None:
      self.assertIsNone(actual_tensor_shape.dims)
    else:
      self.assertListEqual(actual_tensor_shape.as_list(), shape)

  def GetAccumulatorForInputAtIndex(while_op, idx):
    body_graph = while_v2._get_body_graph(while_op)
    y_input_t = body_graph.inputs[idx]
    push_back_node = [c for c in y_input_t.consumers()
                      if c.type == "TensorListPushBack"][0]
    output_idx = body_graph.outputs.index(push_back_node.outputs[0])
    return while_op.outputs[output_idx]

  x = constant_op.constant(2.)
  y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

  # Forward pass.
  ret = while_loop_v2(
      lambda v, u: v < 8.,
      lambda v, u: (v * v, u), [x, y],
      return_same_structure=False)
  while_op = ret[0].op.inputs[0].op
  # Get the TensorList output of While op containing the accumulated values
  # of y.
  # while_op.inputs: [counter_arg, x_arg, y_arg, *accumulators]
  output = GetAccumulatorForInputAtIndex(while_op, 2)
  _, val = list_ops.tensor_list_pop_back(output,
                                         element_dtype=dtypes.float32)
  MatchShape(val.shape)

  # Gradient pass.
  grad = gradients_impl.gradients(ret[1], y)
  grad_while_op = grad[0].op.inputs[0].op
  # Get the TensorList output of gradient While op containing the accumulated
  # values of grad_y.
  # grad_while_op.inputs:
  # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args]
  grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 3)
  _, val = list_ops.tensor_list_pop_back(grad_output,
                                         element_dtype=dtypes.float32)
  MatchShape(val.shape)
def testCPUGPUCopy(self):
  if not context.num_gpus():
    return
  t = constant_op.constant([1.0, 2.0])
  l = list_ops.tensor_list_from_tensor(t, element_shape=[])
  with context.device("gpu:0"):
    l_gpu = array_ops.identity(l)
    self.assertAllEqual(
        self.evaluate(
            list_ops.tensor_list_pop_back(
                l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
  l_cpu = array_ops.identity(l_gpu)
  self.assertAllEqual(
      self.evaluate(
          list_ops.tensor_list_pop_back(
              l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
def testPushPopSeparateLists(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=[], element_dtype=dtypes.float32, max_num_elements=20)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
    _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    result = sess.run([e11, [e21, e22], [e31, e32]])
    self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
def testSerialize(self):
  # pylint: disable=g-import-not-at-top
  try:
    import portpicker
  except ImportError:
    return
  with context.graph_mode():
    worker_port = portpicker.pick_unused_port()
    ps_port = portpicker.pick_unused_port()
    cluster_dict = {
        "worker": ["localhost:%s" % worker_port],
        "ps": ["localhost:%s" % ps_port]
    }
    cs = server_lib.ClusterSpec(cluster_dict)

    worker = server_lib.Server(
        cs, job_name="worker", protocol="grpc", task_index=0, start=True)
    unused_ps = server_lib.Server(
        cs, job_name="ps", protocol="grpc", task_index=0, start=True)
    with ops.Graph().as_default(), session.Session(target=worker.target):
      with ops.device("/job:worker"):
        t = constant_op.constant([[1.0], [2.0]])
        l = list_ops.tensor_list_from_tensor(t, element_shape=[1])
      with ops.device("/job:ps"):
        l_ps = array_ops.identity(l)
        l_ps, e = list_ops.tensor_list_pop_back(
            l_ps, element_dtype=dtypes.float32)
      with ops.device("/job:worker"):
        worker_e = array_ops.identity(e)
      self.assertAllEqual(worker_e.eval(), [2.0])
def testPushPopSeparateLists(self):
  with self.cached_session() as sess, self.test_scope():
    num = array_ops.placeholder(dtypes.int32)
    l = list_ops.tensor_list_reserve(
        element_shape=scalar_shape(),
        num_elements=num,
        element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
    _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
    l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
    result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20})
    self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
def _testPushPop(self, max_num_elements):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[],
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), 1.0)
def testDoNotConstantFoldVariants(self):
  with self.cached_session() as sess, self.test_scope():
    val = array_ops.placeholder(dtype=dtypes.float32)
    l = list_ops.empty_tensor_list(
        element_shape=(7, 15),
        element_dtype=dtypes.float32,
        max_num_elements=10)
    # Note: Pushing a Placeholder will force the constant folding code
    # to build a Const node with a DT_VARIANT output. This tests that XLA
    # passes a cf_consider_fn which prevents folding such nodes.
    l = list_ops.tensor_list_push_back(
        l, array_ops.fill(value=val, dims=(7, 15)))
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(2.0, shape=(7, 15)))
    l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15)))
    self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15)))
def testPopFromEmptyTensorListFails(self, max_num_elements):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[],
      max_num_elements=max_num_elements)
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "Trying to pop from an empty list"):
    l = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.evaluate(l)
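# The failure tested above suggests a guarded-pop pattern for lists that may
# be empty. This is a minimal sketch, not part of the list_ops API: the
# `safe_pop` name and its `default` argument are illustrative. It assumes the
# same imports as the surrounding tests plus control_flow_ops and math_ops.
def safe_pop(l, default):
  # Only pop when tensor_list_length reports a non-empty list; otherwise
  # return the list unchanged along with a caller-supplied default element.
  return control_flow_ops.cond(
      math_ops.greater(list_ops.tensor_list_length(l), 0),
      lambda: list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32),
      lambda: (l, default))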
def testEmptyTensorListMax(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=(10, 15), element_dtype=dtypes.float32,
        max_num_elements=2)
    l = list_ops.tensor_list_push_back(
        l, array_ops.fill(value=3.0, dims=(10, 15)))
    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15)))
def testEmptyTensorListNoMax(self):
  with self.cached_session() as sess, self.test_scope():
    l = list_ops.empty_tensor_list(
        element_shape=(7, 15), element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(7, 15)))
    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Set the max number of elements"):
      self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))
def testPushPopGradients(self):
  with backprop.GradientTape() as tape:
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[])
    c = constant_op.constant(1.0)
    tape.watch(c)
    l = list_ops.tensor_list_push_back(l, c)
    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    e = 2 * e
  self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    return captured_tensor

  # Resource tensors are not accumulated and handled specially.
  if tensor.dtype == dtypes.resource:
    return self._resource_capture_helper(tensor)

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    with self._forward_graph.outer_graph.as_default():
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=tensor.dtype,
          element_shape=tensor.shape,
          max_num_elements=self._maximum_iterations)
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph,
                               self)._capture_helper(accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[captured_accumulator] = new_tensor_list
  return captured_tensor
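# A standalone sketch (not TF-internal code) of the accumulator pattern that
# _capture_helper above sets up: the forward pass pushes each intermediate
# onto a TensorList, and the gradient pass pops them back off in reverse
# order. It assumes only the list_ops/constant_op/dtypes imports already used
# by the tests in this file.
acc = list_ops.empty_tensor_list(
    element_dtype=dtypes.float32, element_shape=[])
# Forward pass: accumulate intermediates as the loop body produces them.
for v in (1.0, 2.0, 3.0):
  acc = list_ops.tensor_list_push_back(acc, constant_op.constant(v))
# Gradient pass: consume the intermediates last-in, first-out.
acc, e3 = list_ops.tensor_list_pop_back(acc, element_dtype=dtypes.float32)
acc, e2 = list_ops.tensor_list_pop_back(acc, element_dtype=dtypes.float32)
# e3 evaluates to 3.0 and e2 to 2.0, mirroring the reversed iteration order
# in which a while_v2 gradient graph reads accumulated values.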
def testEmptyTensorList(self):
  dim = 7
  with self.cached_session() as sess, self.test_scope():
    p = array_ops.placeholder(dtypes.int32)
    l = list_ops.empty_tensor_list(
        element_shape=(p, 15), element_dtype=dtypes.float32)
    l = list_ops.tensor_list_push_back(
        l, constant_op.constant(1.0, shape=(dim, 15)))
    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Use TensorListReserve instead"):
      self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15)))
def testSerialize(self):
  worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0]
  with ops.Graph().as_default(), session.Session(target=worker.target):
    with ops.device("/job:worker"):
      t = constant_op.constant([[1.0], [2.0]])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[1])
    with ops.device("/job:ps"):
      l_ps = array_ops.identity(l)
      l_ps, e = list_ops.tensor_list_pop_back(
          l_ps, element_dtype=dtypes.float32)
    with ops.device("/job:worker"):
      worker_e = array_ops.identity(e)
    self.assertAllEqual(self.evaluate(worker_e), [2.0])
def _tf_tensor_list_pop(list_, i, opts):
  """Overload of list_pop that stages a Tensor list pop."""
  if i is not None:
    raise NotImplementedError('tensor lists only support removing from the end')

  if opts.element_dtype is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'type; use set_element_type to annotate it')
  if opts.element_shape is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'shape; use set_element_type to annotate it')
  list_out, x = list_ops.tensor_list_pop_back(
      list_, element_dtype=opts.element_dtype)
  x.set_shape(opts.element_shape)
  return list_out, x
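# A hedged usage sketch for _tf_tensor_list_pop. `PopOpts` is a hypothetical
# stand-in for autograph's real options object; only the element_dtype and
# element_shape attributes that the function reads are assumed here.
import collections

PopOpts = collections.namedtuple('PopOpts', ['element_dtype', 'element_shape'])

l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=[])
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
# Popping with i=None removes from the end; the annotated dtype/shape let the
# overload set a static shape on the popped element. x evaluates to 1.0.
l, x = _tf_tensor_list_pop(l, None, PopOpts(dtypes.float32, []))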
def testCPUGPUCopyNested(self):
  if not context.num_gpus():
    return
  t = constant_op.constant([1.0, 2.0])
  child_l = list_ops.tensor_list_from_tensor(t, element_shape=[])
  l = list_ops.empty_tensor_list(
      element_shape=constant_op.constant([], dtype=dtypes.int32),
      element_dtype=dtypes.variant)
  l = list_ops.tensor_list_push_back(l, child_l)
  with context.device("gpu:0"):
    l_gpu = array_ops.identity(l)
    _, child_l_gpu = list_ops.tensor_list_pop_back(
        l_gpu, element_dtype=dtypes.variant)
    self.assertAllEqual(
        self.evaluate(
            list_ops.tensor_list_pop_back(
                child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
  l_cpu = array_ops.identity(l_gpu)
  _, child_l_cpu = list_ops.tensor_list_pop_back(
      l_cpu, element_dtype=dtypes.variant)
  self.assertAllEqual(
      self.evaluate(
          list_ops.tensor_list_pop_back(
              child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
def testZerosLikeVariant(self):
  for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
                dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
                dtypes.float64, dtypes.complex64, dtypes.complex128,
                dtypes.bool):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.variant, element_shape=scalar_shape())

    sub_l = list_ops.empty_tensor_list(
        element_dtype=dtype, element_shape=scalar_shape())
    l = list_ops.tensor_list_push_back(l, sub_l)
    sub_l = list_ops.tensor_list_push_back(sub_l,
                                           math_ops.cast(1, dtype=dtype))
    l = list_ops.tensor_list_push_back(l, sub_l)
    sub_l = list_ops.tensor_list_push_back(sub_l,
                                           math_ops.cast(2, dtype=dtype))
    l = list_ops.tensor_list_push_back(l, sub_l)

    # l : [[],
    #      [1],
    #      [1, 2]]
    #
    # l_zeros : [[],
    #            [0],
    #            [0, 0]]
    l_zeros = array_ops.zeros_like(l)

    outputs = []
    for _ in range(3):
      l_zeros, out = list_ops.tensor_list_pop_back(
          l_zeros, element_dtype=dtypes.variant)
      outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype))

    # Note: `outputs` contains popped values so the order is reversed.
    self.assertAllEqual(self.evaluate(outputs[2]), [])
    self.assertAllEqual(
        self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype))
    self.assertAllEqual(
        self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype))
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    # For GradientTape housekeeping.
    assert self._tensor_to_accumulator[tensor] in self.captures
    super(_WhileBodyGradFuncGraph, self)._capture_helper(
        self._tensor_to_accumulator[tensor], name)
    return captured_tensor

  assert tensor not in self._tensor_to_accumulator

  # Find the TensorList that was used to accumulate the tensors of this
  # intermediate tensor.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    raise ValueError("Reference to un-accumulated intermediate tensor: ",
                     tensor.name)

  assert accumulator.graph == self._forward_graph
  # Get the While op output corresponding to the accumulator.
  accumulator = self._forward_graph._while.outputs[
      self._forward_graph.outputs.index(accumulator)]

  assert accumulator.graph == self._forward_graph.outer_graph
  self._tensor_to_accumulator[tensor] = accumulator

  # Capture the `accumulator`.
  accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
      accumulator, name)
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      accumulator_ph, element_dtype=tensor.dtype)
  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[accumulator_ph] = new_tensor_list
  return captured_tensor
def testSerialize(self):
  # pylint: disable=g-import-not-at-top
  try:
    import portpicker
  except ImportError:
    return
  with context.graph_mode():
    worker_port = portpicker.pick_unused_port()
    ps_port = portpicker.pick_unused_port()
    cluster_dict = {
        "worker": ["localhost:%s" % worker_port],
        "ps": ["localhost:%s" % ps_port]
    }
    cs = server_lib.ClusterSpec(cluster_dict)

    worker = server_lib.Server(
        cs, job_name="worker", protocol="grpc", task_index=0, start=True)
    unused_ps = server_lib.Server(
        cs, job_name="ps", protocol="grpc", task_index=0, start=True)
    with ops.Graph().as_default(), session.Session(target=worker.target):
      with ops.device("/job:worker"):
        t = constant_op.constant([[1.0], [2.0]])
        l = list_ops.tensor_list_from_tensor(t, element_shape=[1])
      with ops.device("/job:ps"):
        l_ps = array_ops.identity(l)
        l_ps, e = list_ops.tensor_list_pop_back(
            l_ps, element_dtype=dtypes.float32)
      with ops.device("/job:worker"):
        worker_e = array_ops.identity(e)
      self.assertAllEqual(worker_e.eval(), [2.0])
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    return captured_tensor

  if tensor.dtype == dtypes.resource:
    # Resource-type tensors are not accumulated.
    # If a resource tensor exists in the loop body it must either be a loop
    # input or an output of a nested While op inside the loop body which
    # had captured the external resource.
    if tensor in self._forward_graph.inputs:
      index = self._forward_graph.inputs.index(tensor)
    elif tensor.op.type == "While":
      # Captured resources occur at the same index in the lists of inputs
      # and outputs of a while op. So we lookup the input of `tensor.op` at
      # the same index as the index of `tensor` in the `tensor.op.outputs`.
      index = self._forward_graph.inputs.index(
          tensor.op.inputs[tensor.value_index])
    else:
      raise ValueError(
          "Taking gradient of a while loop which creates"
          " a resource in its body is not supported: %s" % str(tensor))
    # This must be a loop invariant.
    assert self._forward_graph.inputs[index] == self._forward_graph.outputs[
        index], "Resource tensors must be loop invariants %s." % str(
            self._forward_graph._while.inputs[index])
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]
    self._indirect_captures[tensor] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[tensor]

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    with self._forward_graph.outer_graph.as_default():
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=tensor.dtype,
          element_shape=tensor.shape,
          max_num_elements=self._maximum_iterations)
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph,
                               self)._capture_helper(accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[captured_accumulator] = new_tensor_list
  return captured_tensor
def pop(self):
  self.list_, value = list_ops.tensor_list_pop_back(self.list_, self.dtype)
  return value
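# A minimal sketch of a wrapper class the two-line `pop` above could belong
# to. The class name and the `push` method are assumptions for illustration,
# not the original container: the point is that keeping the TensorList handle
# and element dtype together lets each pop rebind the handle functionally.
class _TensorListWrapper(object):  # hypothetical name

  def __init__(self, element_dtype):
    self.dtype = element_dtype
    self.list_ = list_ops.empty_tensor_list(
        element_dtype=element_dtype, element_shape=[])

  def push(self, value):
    # tensor_list_push_back returns a new handle; store it back.
    self.list_ = list_ops.tensor_list_push_back(self.list_, value)

  def pop(self):
    self.list_, value = list_ops.tensor_list_pop_back(self.list_, self.dtype)
    return value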
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    # For GradientTape housekeeping.
    assert self._inner_to_outer_tensor[tensor] in self.captures
    super(_WhileBodyGradFuncGraph, self)._capture_helper(
        self._inner_to_outer_tensor[tensor], name)
    return captured_tensor

  if tensor.dtype == dtypes.resource:
    # Resource-type tensors are not accumulated.
    # If a resource tensor exists in the loop body it must either be a loop
    # input or an output of a nested While op inside the loop body which
    # had captured the external resource.
    if tensor in self._forward_graph.inputs:
      index = self._forward_graph.inputs.index(tensor)
    elif tensor.op.type == "While":
      # Captured resources occur at the same index in the lists of inputs
      # and outputs of a while op. So we lookup the input of `tensor.op` at
      # the same index as the index of `tensor` in the `tensor.op.outputs`.
      index = self._forward_graph.inputs.index(
          tensor.op.inputs[tensor.value_index])
    else:
      raise ValueError(
          "Taking gradient of a while loop which creates"
          " a resource in its body is not supported: %s" % str(tensor))
    # This must be a loop invariant.
    assert self._forward_graph.inputs[index] == self._forward_graph.outputs[
        index], "Resource tensors must be loop invariants %s." % str(
            self._forward_graph._while.inputs[index])
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]
    self._inner_to_outer_tensor[tensor] = tensor_in_outer_graph
    self._indirect_captures[tensor] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[tensor]

  assert tensor not in self._inner_to_outer_tensor

  # Find the TensorList that was used to accumulate the tensors of this
  # intermediate tensor.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    raise ValueError("Reference to un-accumulated intermediate tensor: ",
                     tensor.name)

  assert accumulator.graph == self._forward_graph
  # Get the While op output corresponding to the accumulator.
  accumulator = self._forward_graph._while.outputs[
      self._forward_graph.outputs.index(accumulator)]

  assert accumulator.graph == self._forward_graph.outer_graph
  self._inner_to_outer_tensor[tensor] = accumulator

  # Capture the `accumulator`.
  accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
      accumulator, name)
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      accumulator_ph, element_dtype=tensor.dtype)
  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[accumulator_ph] = new_tensor_list
  return captured_tensor
def testPushPop(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=scalar_shape())
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(e, 1.0)
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
  if captured_tensor is not None:
    return captured_tensor

  # Do not accumulate loop invariants.
  if (any(tensor is t for t in self._forward_graph.inputs) and
      any(tensor is t for t in self._forward_graph.outputs)):
    captured_tensor = super(_WhileBodyGradFuncGraph,
                            self)._capture_helper(tensor, name)
    # Add to `popped_tensor_lists` so that this gets added to the list of
    # outputs.
    # TODO(srbs): Rename popped_tensor_lists.
    self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    return captured_tensor

  # Do not accumulate Const nodes. Instead copy them directly in the backward
  # graph.
  # TODO(srbs): This just checks for `Const` nodes. Consider checking for
  # graph compile time consts in general.
  # TODO(srbs): Consider making this a loop input.
  if constant_op.is_constant(tensor):
    real_value = constant_op.constant(
        tensor_util.constant_value(tensor), dtype=tensor.dtype)
    self._indirect_captures[ops.tensor_id(tensor)] = real_value
    return real_value

  # Resource tensors are not accumulated and handled specially.
  if tensor.dtype == dtypes.resource:
    return self._resource_capture_helper(tensor)

  # No need to accumulate loop invariants. Capture them directly.
  # The captured tensor gets resolved to the corresponding while output in
  # `_resolve_grad_captures`.
  if _is_loop_invariant(tensor, self._forward_graph_inputs,
                        self._forward_graph_outputs):
    captured_tensor = super(_WhileBodyGradFuncGraph,
                            self)._capture_helper(tensor, name)
    return captured_tensor

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    #
    # Note: We clear the control dependencies to avoid a cycle in case a
    # control tensor has an input path to an output of the forward While.
    #
    # E.g.:
    # x = tf.while_loop(...)
    # y = f(x)
    # with tf.control_dependencies([y]):
    #   tf.gradients(y, x)
    #
    # Since the EmptyTensorList is fed back into the forward While, not
    # removing the control edge would cause a cycle.
    with self._forward_graph.outer_graph.as_default():
      with util.clear_control_inputs():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype,
            element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations,
            name=_build_accumulator_name(tensor))
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph,
                               self)._capture_helper(accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
  self.popped_tensor_lists[ops.tensor_id(
      captured_accumulator)] = new_tensor_list
  return captured_tensor
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    # For GradientTape housekeeping.
    assert self._inner_to_outer_tensor[tensor] in self.captures
    super(_WhileBodyGradFuncGraph, self)._capture_helper(
        self._inner_to_outer_tensor[tensor], name)
    return captured_tensor

  if tensor.dtype == dtypes.resource:
    # Resource-type tensors are not accumulated.
    # If a resource tensor exists in the loop body it must either be a loop
    # input or an output of a nested While op inside the loop body which
    # had captured the external resource.
    if tensor in self._forward_graph.inputs:
      index = self._forward_graph.inputs.index(tensor)
    elif tensor.op.type == "While":
      # Captured resources occur at the same index in the lists of inputs
      # and outputs of a while op. So we lookup the input of `tensor.op` at
      # the same index as the index of `tensor` in the `tensor.op.outputs`.
      index = self._forward_graph.inputs.index(
          tensor.op.inputs[tensor.value_index])
    else:
      raise ValueError(
          "Taking gradient of a while loop which creates"
          " a resource in its body is not supported: %s" % str(tensor))
    # This must be a loop invariant.
    assert self._forward_graph.inputs[index] == self._forward_graph.outputs[
        index], "Resource tensors must be loop invariants %s." % str(
            self._forward_graph._while.inputs[index])
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]
    self._inner_to_outer_tensor[tensor] = tensor_in_outer_graph
    self._indirect_captures[tensor] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[tensor]

  assert tensor not in self._inner_to_outer_tensor

  # Find the TensorList that was used to accumulate the tensors of this
  # intermediate tensor.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    raise ValueError("Reference to un-accumulated intermediate tensor: ",
                     tensor.name)

  assert accumulator.graph == self._forward_graph
  # Get the While op output corresponding to the accumulator.
  accumulator = self._forward_graph._while.outputs[
      self._forward_graph.outputs.index(accumulator)]

  assert accumulator.graph == self._forward_graph.outer_graph
  self._inner_to_outer_tensor[tensor] = accumulator

  # Capture the `accumulator`.
  accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
      accumulator, name)
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      accumulator_ph, element_dtype=tensor.dtype)
  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[accumulator_ph] = new_tensor_list
  return captured_tensor
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    return captured_tensor

  # Resource tensors are not accumulated and handled specially.
  if tensor.dtype == dtypes.resource:
    return self._resource_capture_helper(tensor)

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    #
    # Note: We clear the control dependencies to avoid a cycle in case a
    # control tensor has an input path to an output of the forward While.
    #
    # E.g.:
    # x = tf.while_loop(...)
    # y = f(x)
    # with tf.control_dependencies([y]):
    #   tf.gradients(y, x)
    #
    # Since the EmptyTensorList is fed back into the forward While, not
    # removing the control edge would cause a cycle.
    with self._forward_graph.outer_graph.as_default():
      with util.clear_control_inputs():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype,
            element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations,
            name=_build_accumulator_name(tensor))
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph,
                               self)._capture_helper(accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[captured_accumulator] = new_tensor_list
  return captured_tensor
def testPushPop(self):
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=scalar_shape())
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(e), 1.0)
def _capture_helper(self, tensor, name):
  if tensor.graph is not self._forward_graph:
    return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

  while tensor.op.type == "Identity":
    # We do not accumulate the output of identity nodes so we try to capture
    # the input of the Identity node instead.
    tensor = tensor.op.inputs[0]

  captured_tensor = self._indirect_captures.get(tensor)
  if captured_tensor is not None:
    return captured_tensor

  # Resource tensors are not accumulated and handled specially.
  if tensor.dtype == dtypes.resource:
    return self._resource_capture_helper(tensor)

  # Create or find an existing accumulator output for `tensor` in the forward
  # graph, and fetch from this accumulator in the gradient graph to get the
  # raw intermediate value.
  accumulator = _get_accumulator(tensor)
  if accumulator is None:
    # Create the initial empty tensor list.
    #
    # Note: We clear the control dependencies to avoid a cycle in case a
    # control tensor has an input path to an output of the forward While.
    #
    # E.g.:
    # x = tf.while_loop(...)
    # y = f(x)
    # with tf.control_dependencies([y]):
    #   tf.gradients(y, x)
    #
    # Since the EmptyTensorList is fed back into the forward While, not
    # removing the control edge would cause a cycle.
    with self._forward_graph.outer_graph.as_default():
      with util.clear_control_inputs():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype,
            element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations,
            name=_build_accumulator_name(tensor))
    self.empty_tensor_lists.append(tensor_list)

    # Push the intermediate tensor to the tensor list. This captures
    # `tensor_list`.
    with self._forward_graph.as_default():
      accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
    # Add the modified tensor list to the list of outputs. This output will
    # be all the accumulated values.
    self._forward_graph.outputs.append(accumulator)

    # Capture in the cond graph as well so the forward cond and body inputs
    # match.
    with self._forward_cond_graph.as_default():
      self._forward_cond_graph.capture(tensor_list)

  # Capture the accumulator tensor list in the gradient graph directly from
  # the forward graph -- we'll later modify this to capture the final list
  # output by the forward While op instead.
  captured_accumulator = super(_WhileBodyGradFuncGraph,
                               self)._capture_helper(accumulator, name)

  # Pop the intermediate value from the tensor list in the gradient graph.
  new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
      captured_accumulator, element_dtype=tensor.dtype)

  self._indirect_captures[tensor] = captured_tensor
  self.popped_tensor_lists[captured_accumulator] = new_tensor_list
  return captured_tensor