def testDuplicateAccumulator(self):
  """Checks that no second accumulator is added for an already-pushed value.

  The loop body itself pushes `x` onto a TensorList, so the body graph input
  for `x` is expected to have exactly one "TensorListPushBack" consumer. The
  loop result and its gradient with respect to `x` are checked as well.
  """
  x = constant_op.constant(2.)
  tensor_list = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=ScalarShape())

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5.

  def Body(x, tl):
    # There is an accumulator in the loop already so we should not add
    # another.
    tl = list_ops.tensor_list_push_back(tl, x)
    return x**2., tl

  ret = while_loop_v2(Cond, Body, [x, tensor_list])

  # Collect every "While" op and fail with a readable message when there is
  # none, instead of hitting a NameError on an unbound `while_op` below.
  while_ops = [
      op for op in ops.get_default_graph().get_operations()
      if op.type == "While"
  ]
  self.assertTrue(while_ops, "No While op found in the default graph.")
  # Use the last match, which is what a plain assign-on-match scan yields.
  while_op = while_ops[-1]

  body_graph = while_v2._get_body_graph(while_op)
  # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
  x_input_t = body_graph.inputs[1]
  accumulator_count = len([
      c for c in x_input_t.consumers() if c.type == "TensorListPushBack"
  ])
  self.assertEqual(accumulator_count, 1)

  grad = gradients_impl.gradients(ret[0], x)
  with self.cached_session() as sess:
    # x goes 2 -> 4 -> 16 and the loop stops because 16 >= 5.
    self.assertEqual(sess.run(ret[0]), 16.)
    # Two squarings give ret[0] == x**4, so the gradient is 4 * x**3 == 32.
    self.assertSequenceEqual(self.evaluate(grad), [32.])
def GetAccumulatorForInputAtIndex(while_op, idx):
  """Returns the `while_op` output matching the accumulator of body input `idx`.

  Takes the body graph of `while_op`, looks at its input at position `idx`,
  and picks the first consumer of that input whose type is
  "TensorListPushBack". The first output of that consumer is located in the
  body graph outputs, and the `while_op` output at the same position is
  returned.

  Args:
    while_op: The While op whose body graph is inspected.
    idx: Position of the input of interest in the body graph inputs.

  Returns:
    The element of `while_op.outputs` at the position where the push-back
    op's first output appears in the body graph outputs.

  Raises:
    IndexError: If the input has no "TensorListPushBack" consumer.
  """
  body = while_v2._get_body_graph(while_op)
  pushes = []
  for consumer in body.inputs[idx].consumers():
    if consumer.type == "TensorListPushBack":
      pushes.append(consumer)
  # Only the first matching consumer is considered.
  pushed_list = pushes[0].outputs[0]
  position = body.outputs.index(pushed_list)
  return while_op.outputs[position]
def testDuplicateAccumulator(self):
  """Verifies a value pushed inside the loop body gets a single accumulator.

  `Body` pushes `x` onto a TensorList on every iteration. The test inspects
  the body graph of the While op and expects exactly one "TensorListPushBack"
  consumer of the `x` input, then checks the loop output and its gradient.
  """
  x = constant_op.constant(2.)
  tensor_list = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=ScalarShape())

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5.

  def Body(x, tl):
    # There is an accumulator in the loop already so we should not add
    # another.
    tl = list_ops.tensor_list_push_back(tl, x)
    return x**2., tl

  ret = while_loop_v2(Cond, Body, [x, tensor_list])

  # Gather the "While" ops up front so a missing op produces an explicit test
  # failure rather than a NameError from an unassigned `while_op`.
  graph_ops = ops.get_default_graph().get_operations()
  while_ops = [op for op in graph_ops if op.type == "While"]
  self.assertTrue(while_ops, "No While op found in the default graph.")
  # The last match is used, matching an assign-on-every-match scan.
  while_op = while_ops[-1]

  body_graph = while_v2._get_body_graph(while_op)
  # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
  x_input_t = body_graph.inputs[1]
  accumulator_count = len(
      [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
  self.assertEqual(accumulator_count, 1)

  grad = gradients_impl.gradients(ret[0], x)
  with self.cached_session() as sess:
    # Starting at 2, x is squared twice (2 -> 4 -> 16) before 16 >= 5 ends
    # the loop.
    self.assertEqual(sess.run(ret[0]), 16.)
    # ret[0] == x**4 for this input, so d(ret[0])/dx == 4 * x**3 == 32.
    self.assertSequenceEqual(sess.run(grad), [32.])
def testTensorListOutputElementShape(self, shape):
  """Checks the element shape of the TensorList accumulating `y`.

  The test is skipped by its first statement (b/115982901). The remaining
  code builds a while loop that passes a placeholder `y` of the given
  `shape` through unchanged, finds the While op output that corresponds to
  the "TensorListPushBack" consumer of `y` in the body graph, pops one
  element from it, and compares that element's static shape with `shape`.

  Args:
    shape: Shape used for the placeholder and for the expected element shape.
  """
  self.skipTest("b/115982901")

  x = constant_op.constant(2.)
  y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

  def Cond(v, u):
    return v < 8.

  def Body(v, u):
    return (v * v, u)

  loop_outputs = while_loop_v2(Cond, Body, [x, y])

  # Get the TensorList output of While op containing the accumulated values
  # of y.
  loop_op = loop_outputs[0].op
  body = while_v2._get_body_graph(loop_op)
  # body.inputs: [counter_arg, x_arg, y_arg, *accumulators]
  y_arg = body.inputs[2]
  push_ops = []
  for consumer in y_arg.consumers():
    if consumer.type == "TensorListPushBack":
      push_ops.append(consumer)
  first_push = push_ops[0]
  position = body.outputs.index(first_push.outputs[0])
  accumulated = loop_op.outputs[position]

  _, val = list_ops.tensor_list_pop_back(
      accumulated, element_dtype=dtypes.float32)
  expected_shape = tensor_shape.TensorShape(shape)
  self.assertEqual(val.shape, expected_shape)