Ejemplo n.º 1
0
    def test_create_nullary(self):
        """Registers a hand-written nullary FunctionDef with the runtime client.

        The FunctionDef has no inputs and one DT_INT32 output `o`. That output
        is wired to the `output` of a `Const` node named `retval`, whose value
        attribute holds the scalar int 1. The test makes no assertions, so it
        passes as long as `Runtime.CreateFunction` accepts the definition
        without raising.
        """
        nullary_text = """
            signature {
               name: 'NullaryFunction'
               output_arg { name: 'o' type: DT_INT32 }
             }
             node_def {
               name: 'retval'
               op: 'Const'
               attr {
                 key: 'dtype'
                 value { type: DT_INT32 }
               }
               attr {
                 key: 'value'
                 value {
                   tensor {
                     dtype: DT_INT32
                     tensor_shape {}
                     int_val: 1
                   }
                 }
               }
             }
             ret { key: 'o' value: 'retval:output' }
         """
        # Parse the text-format proto into a fresh FunctionDef message.
        function_def = text_format.Parse(nullary_text, function_pb2.FunctionDef())

        # Keep the context in a named local for as long as the runtime uses it.
        eager_ctx = runtime_client.GlobalEagerContext()
        runtime = runtime_client.Runtime(eager_ctx)
        runtime.CreateFunction(function_def)
Ejemplo n.º 2
0
    def test_get_function_proto_from_py_runtime_function(self):
        """Round-trips a traced tf.function's proto through the runtime client.

        A `def_function.function` returning 1 is traced to a concrete function.
        Its FunctionDef is then looked up by signature name through a Runtime
        built on the global Python eager context. The returned proto must carry
        the same signature name. The test is skipped unless TF2 behavior is
        enabled.
        """
        if not tf2.enabled():
            self.skipTest("TF2 test")

        @def_function.function
        def f():
            return 1

        concrete = f.get_concrete_function()
        expected_name = concrete.function_def.signature.name

        eager_ctx = runtime_client.GlobalPythonEagerContext()
        runtime = runtime_client.Runtime(eager_ctx)
        fetched_def = runtime.GetFunctionProto(expected_name)

        self.assertEqual(fetched_def.signature.name, expected_name)
Ejemplo n.º 3
0
    def test_concrete_function_editing_proto(self):
        """Edits a traced function's proto and re-registers it under its name.

        The FunctionDef of a traced tf.function (which returns 1) is fetched from
        the runtime. The first `int_val` in the `value` tensor of its first node
        is set to 2, and the edited proto is passed back to `CreateFunction`.
        Calling the tf.function afterwards must evaluate to 2. The test is
        skipped unless TF2 behavior is enabled.
        """
        if not tf2.enabled():
            self.skipTest("TF2 test")

        @def_function.function
        def f():
            return 1

        concrete = f.get_concrete_function()

        eager_ctx = runtime_client.GlobalPythonEagerContext()
        runtime = runtime_client.Runtime(eager_ctx)
        edited_def = runtime.GetFunctionProto(concrete.function_def.signature.name)

        # NOTE(review): assumes node_def[0] is the Const node produced by
        # `return 1`. The sub-message below is a live reference, so writing
        # through it mutates edited_def in place.
        const_tensor = edited_def.node_def[0].attr["value"].tensor
        const_tensor.int_val[0] = 2

        runtime.CreateFunction(edited_def)

        self.assertAllEqual(self.evaluate(f()), 2)
Ejemplo n.º 4
0
    def test_create_function_called_by_py_runtime(self):
        """Creates a function via the runtime client and runs it from Python.

        A hand-written nullary FunctionDef named `NullaryFunction` is registered
        through a Runtime on the global Python eager context. Its output is a
        `Const` node holding the scalar int 1. The function is then executed by
        name through `execute.execute` with one expected output, no inputs and
        no attrs. That single output must equal 1. The test is skipped unless
        TF2 behavior is enabled.
        """
        if not tf2.enabled():
            self.skipTest("TF2 test")

        nullary_text = """
            signature {
               name: 'NullaryFunction'
               output_arg { name: 'o' type: DT_INT32 }
             }
             node_def {
               name: 'retval'
               op: 'Const'
               attr {
                 key: 'dtype'
                 value { type: DT_INT32 }
               }
               attr {
                 key: 'value'
                 value {
                   tensor {
                     dtype: DT_INT32
                     tensor_shape {}
                     int_val: 1
                   }
                 }
               }
             }
             ret { key: 'o' value: 'retval:output' }
         """
        function_def = text_format.Parse(nullary_text, function_pb2.FunctionDef())

        eager_ctx = runtime_client.GlobalPythonEagerContext()
        runtime = runtime_client.Runtime(eager_ctx)
        runtime.CreateFunction(function_def)

        # The list-pattern unpack also asserts that exactly one output came back.
        [result] = execute.execute(
            "NullaryFunction", 1, [], (), context.context()
        )
        self.assertAllEqual(result, 1)