def GetAccumulatorForInputAtIndex(while_op, idx):
  """Returns the While output that accumulates the body input at `idx`."""
  body_graph = while_v2._get_graph(while_op, "body")
  y_input_t = body_graph.inputs[idx]
  push_back_node = [c for c in y_input_t.consumers()
                    if c.type == "TensorListPushBack"][0]
  output_idx = body_graph.outputs.index(push_back_node.outputs[0])
  return while_op.outputs[output_idx]
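# Illustrative sketch (not part of the original file): given a while_op like
# the one built in testDuplicateAccumulator below, where the body pushes `x`
# onto a TensorList, the helper maps the body-input index of `x` to the While
# output carrying the accumulated list. `x_input_index`, `accumulated_list`
# and `accumulated_xs` are names assumed for this example only.
#
#   x_input_index = [i for i, inp in enumerate(while_op.inputs)
#                    if inp == x][0]
#   accumulated_list = GetAccumulatorForInputAtIndex(while_op, x_input_index)
#   accumulated_xs = list_ops.tensor_list_stack(
#       accumulated_list, element_dtype=dtypes.float32)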
def testDuplicateAccumulator(self):
  x = constant_op.constant(2.)

  tensor_list = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=ScalarShape())

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5.

  def Body(x, tl):
    # There is an accumulator in the loop already so we should not add
    # another.
    tl = list_ops.tensor_list_push_back(tl, x)
    return x**2., tl

  ret = while_loop_v2(
      Cond, Body, [x, tensor_list], return_same_structure=False)

  for op in ops.get_default_graph().get_operations():
    if op.type == "While" or op.type == "StatelessWhile":
      while_op = op

  body_graph = while_v2._get_graph(while_op, "body")
  x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
  x_input_t = body_graph.inputs[x_input_index]
  accumulator_count = len(
      [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
  self.assertEqual(accumulator_count, 1)

  grad = gradients_impl.gradients(ret[0], x)
  with self.cached_session() as sess:
    self.assertEqual(sess.run(ret[0]), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])
def testDoNotAccumulateConstNodes(self):

  def Body(v):
    return v * 2.0

  v0 = constant_op.constant(2.)
  ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
  # Gradients computation has the side-effect of updating the forward op
  # which is what we want to test.
  unused_grad = gradients_impl.gradients(ret, [v0])[0]
  # ret is separated from the `While` op by an `Identity` so we skip over
  # that.
  forward_while_op = ret.op.inputs[0].op
  body_graph = while_v2._get_graph(forward_while_op, "body")
  push_back_nodes = [
      o for o in body_graph.get_operations()
      if o.type == "TensorListPushBack"
  ]
  # Gradient of `Mul` requires accumulating both its inputs. But since one
  # of those is a Const (2.0), we should have just one accumulator.
  self.assertLen(push_back_nodes, 1)
def _assertNotAccumulated(self, while_op, index):
  """Asserts that `while_op` input at `index` is not accumulated."""
  body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
  placeholder = body_graph.inputs[index]
  self.assertNotIn("TensorListPushBack",
                   [op.type for op in placeholder.consumers()])