def testFuncCondFunc(self):
  """A nested tf.function inside a stateless cond branch stays stateless.

  Two conds are built inside one tf.function, and the true branch of the
  second is itself a tf.function. No stateful op is created anywhere, so:
  both conds must be StatelessIf ops, the second cond must have no control
  inputs, and the nested function must be invoked through a single
  PartitionedCall (never a StatefulPartitionedCall).
  """

  @def_function.function
  def fn_with_cond():
    # Result intentionally unused; cond_2 below must not pick up a control
    # dependency on this op.
    cond_v2.cond_v2(
        constant_op.constant(True),
        lambda: constant_op.constant(1.),
        lambda: constant_op.constant(2.),
        name="cond_1")

    @def_function.function
    def true_branch():
      return constant_op.constant(3.)

    return cond_v2.cond_v2(
        constant_op.constant(True),
        true_branch,
        lambda: constant_op.constant(4.),
        name="cond_2")

  concrete_fn = fn_with_cond.get_concrete_function()
  graph = concrete_fn.graph
  cond_1 = graph.get_operation_by_name("cond_1")
  cond_2 = graph.get_operation_by_name("cond_2")

  # Every functional op is stateless, and cond_2 is not ordered after
  # cond_1 through a control edge.
  for cond_op in (cond_1, cond_2):
    self.assertEqual(cond_op.type, "StatelessIf")
  self.assertEmpty(cond_2.control_inputs)

  # The nested tf.function in cond_2's true branch is called exactly once,
  # and only through the stateless call op.
  true_graph, _ = cond_v2.get_func_graphs(cond_2)
  true_op_types = [op.type for op in true_graph.get_operations()]
  self.assertNotIn("StatefulPartitionedCall", true_op_types)
  self.assertEqual(true_op_types.count("PartitionedCall"), 1)

  result = concrete_fn()
  self.assertEqual(result.op.type, "PartitionedCall")
  self.assertAllEqual(result, 3.0)
def testFuncCondFuncWithVariable(self):
  """Conds that touch the same variable are chained by control edges.

  Four conds are built inside one tf.function:
    * cond_1 and cond_2 both assign v1.
    * cond_3 and cond_4 both assign v2; cond_4 does so through a nested
      tf.function in its false branch.
  All four must be stateful If ops. Each later cond must carry exactly one
  control input, which is the earlier cond on the same variable. The nested
  function must be invoked through a StatefulPartitionedCall.
  """
  v1 = variables.Variable(2.)
  v2 = variables.Variable(4.)
  self.evaluate(variables.global_variables_initializer())

  @def_function.function
  def fn_with_cond():

    def make_self_assign(var):
      # Returns a branch fn that writes `var` back to itself and returns it.
      def branch():
        var.assign(var)
        return var
      return branch

    update_v1 = make_self_assign(v1)
    update_v2 = make_self_assign(v2)

    cond_v2.cond_v2(
        constant_op.constant(True),
        update_v1,
        lambda: constant_op.constant(0.),
        name="cond_1")
    v1_read = cond_v2.cond_v2(
        constant_op.constant(False),
        lambda: constant_op.constant(0.),
        update_v1,
        name="cond_2")
    cond_v2.cond_v2(
        constant_op.constant(True),
        update_v2,
        lambda: constant_op.constant(0.),
        name="cond_3")

    @def_function.function
    def cond_4_false_branch():
      v2.assign(v2)
      return v2

    v2_read = cond_v2.cond_v2(
        constant_op.constant(False),
        lambda: constant_op.constant(0.),
        cond_4_false_branch,
        name="cond_4")
    return v1_read, v2_read

  concrete_fn = fn_with_cond.get_concrete_function()
  graph = concrete_fn.graph
  cond_1, cond_2, cond_3, cond_4 = [
      graph.get_operation_by_name(op_name)
      for op_name in ("cond_1", "cond_2", "cond_3", "cond_4")
  ]

  for cond_op in (cond_1, cond_2, cond_3, cond_4):
    self.assertEqual(cond_op.type, "If")

  # cond_1 -> cond_2 (both assign v1) and cond_3 -> cond_4 (both assign v2).
  for earlier, later in ((cond_1, cond_2), (cond_3, cond_4)):
    self.assertEmpty(earlier.control_inputs)
    self.assertLen(later.control_inputs, 1)
    self.assertIs(later.control_inputs[0], earlier)

  # The nested stateful tf.function is called exactly once, and only
  # through the stateful call op.
  _, false_graph = cond_v2.get_func_graphs(cond_4)
  false_op_types = [op.type for op in false_graph.get_operations()]
  self.assertNotIn("PartitionedCall", false_op_types)
  self.assertEqual(false_op_types.count("StatefulPartitionedCall"), 1)

  outputs = concrete_fn()
  self.assertEqual(outputs[0].op.type, "StatefulPartitionedCall")
  self.assertAllEqual(self.evaluate(outputs), [2.0, 4.0])