예제 #1
0
  def testFuncCondFunc(self):
    """Checks that conds with stateless branches inside a tf.function stay stateless.

    `fn_with_cond` builds two conds whose branches only create constants; the
    true branch of `cond_2` is itself a `def_function.function`. The test
    verifies that:
      * both conds are lowered to `StatelessIf`;
      * `cond_2` has no control inputs;
      * `cond_2`'s true-branch graph calls the nested function through exactly
        one `PartitionedCall` and contains no `StatefulPartitionedCall`;
      * calling the concrete function yields a `PartitionedCall` whose value
        is the true-branch result, 3.0.
    """

    @def_function.function
    def fn_with_cond():
      # cond_1's result is unused. It is created before cond_2 so the
      # assertions below can check that it does not end up as a control input
      # of cond_2.
      cond_v2.cond_v2(
          constant_op.constant(True),
          lambda: constant_op.constant(1.),
          lambda: constant_op.constant(2.),
          name="cond_1")

      @def_function.function
      def true_branch():
        return constant_op.constant(3.)

      return cond_v2.cond_v2(
          constant_op.constant(True),
          true_branch,
          lambda: constant_op.constant(4.),
          name="cond_2")

    concrete_fn = fn_with_cond.get_concrete_function()
    cond_1 = concrete_fn.graph.get_operation_by_name("cond_1")
    cond_2 = concrete_fn.graph.get_operation_by_name("cond_2")
    # Verify that all functional ops are stateless and cond_2 does not have
    # any control inputs.
    self.assertEqual(cond_1.type, "StatelessIf")
    self.assertEqual(cond_2.type, "StatelessIf")
    self.assertEmpty(cond_2.control_inputs)
    # The nested tf.function in cond_2's true branch must be invoked exactly
    # once, and through the stateless call op only.
    cond_2_true_graph, _ = cond_v2.get_func_graphs(cond_2)
    true_branch_op_types = [
        op.type for op in cond_2_true_graph.get_operations()
    ]
    self.assertNotIn("StatefulPartitionedCall", true_branch_op_types)
    self.assertEqual(true_branch_op_types.count("PartitionedCall"), 1)
    fn_output = concrete_fn()
    self.assertEqual(fn_output.op.type, "PartitionedCall")
    self.assertAllEqual(fn_output, 3.0)
예제 #2
0
  def testFuncCondFuncWithVariable(self):
    """Checks control dependencies between variable-touching conds in a tf.function.

    Four conds are built in this order: cond_1 and cond_2 each have one
    branch that writes `v1` back to itself, and cond_3 and cond_4 do the same
    for `v2`. cond_4 does its write through a nested `def_function.function`.
    The test expects:
      * all four conds to be stateful `If` ops;
      * cond_1 and cond_3 to have no control inputs;
      * cond_2 to depend only on cond_1, and cond_4 only on cond_3;
      * cond_4's false-branch graph to call the nested function through
        exactly one `StatefulPartitionedCall` and no `PartitionedCall`;
      * the function to return the unchanged values [2.0, 4.0].
    """
    v1 = variables.Variable(2.)
    v2 = variables.Variable(4.)

    self.evaluate(variables.global_variables_initializer())

    @def_function.function
    def fn_with_cond():

      def make_self_assign(var):
        # Builds a branch that assigns `var` to its own value (a variable
        # write that leaves the value unchanged) and returns `var`.
        def self_assign():
          var.assign(var)
          return var
        return self_assign

      def zero():
        return constant_op.constant(0.)

      @def_function.function
      def stateful_false_branch():
        v2.assign(v2)
        return v2

      v1_branch = make_self_assign(v1)
      v2_branch = make_self_assign(v2)

      # The creation order below matters: the expected control edges are
      # cond_1 -> cond_2 and cond_3 -> cond_4.
      cond_v2.cond_v2(
          constant_op.constant(True), v1_branch, zero, name="cond_1")
      v1_out = cond_v2.cond_v2(
          constant_op.constant(False), zero, v1_branch, name="cond_2")
      cond_v2.cond_v2(
          constant_op.constant(True), v2_branch, zero, name="cond_3")
      v2_out = cond_v2.cond_v2(
          constant_op.constant(False),
          zero,
          stateful_false_branch,
          name="cond_4")
      return v1_out, v2_out

    concrete = fn_with_cond.get_concrete_function()
    conds = {
        name: concrete.graph.get_operation_by_name(name)
        for name in ("cond_1", "cond_2", "cond_3", "cond_4")
    }
    # Each cond has a branch that writes a variable; all must be stateful Ifs.
    for cond_op in conds.values():
      self.assertEqual(cond_op.type, "If")
    # The first cond on each variable has no control inputs; the second one
    # has exactly one, and it is the first.
    for first, second in (("cond_1", "cond_2"), ("cond_3", "cond_4")):
      self.assertEmpty(conds[first].control_inputs)
      self.assertLen(conds[second].control_inputs, 1)
      self.assertIs(conds[second].control_inputs[0], conds[first])
    _, else_graph = cond_v2.get_func_graphs(conds["cond_4"])
    else_op_types = [op.type for op in else_graph.get_operations()]
    self.assertNotIn("PartitionedCall", else_op_types)
    self.assertEqual(else_op_types.count("StatefulPartitionedCall"), 1)
    outputs = concrete()
    self.assertEqual(outputs[0].op.type, "StatefulPartitionedCall")
    self.assertAllEqual(self.evaluate(outputs), [2.0, 4.0])