def _create_new_tf_function(func_graph):
  """Turns `func_graph` into a TF_Function registered in the default graph.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function, i.e. `func_graph.name`.
  """
  c_graph = func_graph._c_graph
  fn_name = compat.as_str(func_graph.name)
  input_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
  output_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]

  c_func = c_api.TF_GraphToFunction_wrapper(
      c_graph,
      fn_name,
      False,  # append_hash_to_fn_name
      None,  # opers
      input_handles,
      output_handles,
      [],
      None,  # opts
      None)  # description
  # Deliberately kept bound to a local until this function returns.
  # NOTE(review): ScopedTFFunction presumably releases `c_func` once the
  # wrapper is garbage collected -- confirm in c_api_util.
  _scoped_c_func = c_api_util.ScopedTFFunction(c_func)

  # TODO(b/109833212): this round trip is wasteful: the TF_Function* is
  # serialized, deserialized into a Python FunctionDef, and then reserialized
  # to create a new TF_Function that is added to the graph.
  function_def = function.function_def_from_tf_function(c_func)
  py_defined_func = function._from_definition(function_def)
  py_defined_func.add_to_graph(ops.get_default_graph())

  return func_graph.name
def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
  """Asserts that `func` and `new_func` agree on their compared attributes.

  Args:
    func: The reference function object.
    grad_func: Forwarded to `function._from_definition` when `new_func` has
      to be built here; unused otherwise.
    new_func: The function to compare against `func`. When None, it is
      rebuilt from a serialized-then-parsed copy of `func.definition`.
  """
  if new_func is None:
    # Round-trip the definition through its serialized form so that the
    # rebuilt function gets its own FunctionDef object; sharing the same
    # object could mask bugs.
    definition_copy = function_pb2.FunctionDef.FromString(
        func.definition.SerializeToString())
    new_func = function._from_definition(definition_copy, grad_func=grad_func)

  # Compared in this order; the first mismatch fails the test.
  compared_attrs = (
      "name",
      "definition",
      "grad_func_name",
      "declared_input_types",
      "captured_inputs",
  )
  for attr in compared_attrs:
    self.assertEqual(getattr(func, attr), getattr(new_func, attr))
def testCapturedInputs(self):
  """Checks how a captured tensor shows up after `_from_definition`."""
  captured_const = constant_op.constant(10, dtypes.int64)

  @function.Defun(dtypes.int64)
  def Foo(x):
    # `captured_const` is defined outside Foo, so it is a captured input.
    return x + captured_const

  rebuilt = function._from_definition(Foo.definition)

  self.assertEqual(Foo.name, rebuilt.name)
  self.assertEqual(Foo.definition, rebuilt.definition)
  self.assertEqual(Foo.grad_func_name, rebuilt.grad_func_name)

  # Captured inputs are added as regular inputs to the function definition,
  # so the rebuilt function declares one extra int64 input and captures
  # nothing.
  actual_input_types = rebuilt.declared_input_types
  expected_input_types = Foo.declared_input_types + (dtypes.int64,)
  self.assertEqual(actual_input_types, expected_input_types)
  self.assertEqual(len(rebuilt.captured_inputs), 0)