def load_function_def_library(library):
  """Loads every function in a FunctionDefLibrary as a concrete function.

  The functions are loaded without captured inputs. Each function is
  processed in the order produced by `_sort_function_defs`. Its definition
  is rewritten by `_fix_fdef` using the names already assigned to
  previously loaded functions. As a result, loaded function names do not
  collide with ones created earlier.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    A dict mapping each function's original name in `library` to the
    corresponding `Function` instance (without captured inputs).

  Raises:
    ValueError: if the dependencies between functions contain a cycle
      (presumably detected by `_sort_function_defs`).
  """
  # TODO(andresp): Look into restoring gradient function information.
  loaded = {}
  renamed = {}

  # A fresh graph lets function_def_to_graph validate that each function
  # loads correctly; eager mode offers no python API to check whether a
  # function has been registered. The resulting func_graphs remain usable
  # in eager mode or in other graphs.
  with ops.Graph().as_default() as import_graph:
    for original_fdef in _sort_function_defs(library):
      fixed_fdef = _fix_fdef(original_fdef, renamed)
      concrete = function_lib.Function(
          function_def_lib.function_def_to_graph(fixed_fdef))
      concrete.add_to_graph(import_graph)

      # Look up the original name only after _fix_fdef has run, matching
      # the order of the original implementation.
      original_name = original_fdef.signature.name
      renamed[original_name] = concrete.name
      loaded[original_name] = concrete

  return loaded
def _load_func_graphs(self, function_library):
  """Converts each FunctionDef in `function_library` into a `Function`.

  Stores the results on `self._functions`, keyed by each FunctionDef's
  `signature.name`. Function names are not remapped.

  Args:
    function_library: object whose `function` attribute iterates over
      FunctionDef protos (presumably a FunctionDefLibrary; verify at call
      sites).
  """
  # TODO(allenl): Do we need to do name mapping here? Not quite sure what
  # happens when loaded names collide with existing names.
  # TODO(andresp): Look into restoring nested and gradient functions in the
  # right order.

  # Publish the (initially empty) dict on self before filling it, so a
  # failure part-way leaves the already-converted entries visible.
  by_name = self._functions = {}
  for function_def in function_library.function:
    func_graph = function_def_lib.function_def_to_graph(function_def)
    by_name[function_def.signature.name] = function.Function(func_graph)
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function `fn` into a graph function.

  `fn` is called once with the symbolic arguments described by `signature`.
  The call is traced and turned into a graph function. Any variables
  created by `fn` are owned by the object that `wrap_function` returns.
  The resulting graph function can be called with tensors that match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Calling tf.compat.v1.wrap_function again produces a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped.
    signature: the placeholder and python arguments to be passed to the
      wrapped function.
    name: Optional. The name of the function.

  Returns:
    The wrapped graph function.
  """
  holder = VariableHolder(fn)
  traced_graph = function.func_graph_from_py_func(
      name,
      holder,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False)
  wrapped = function.Function(traced_graph, signature=signature)
  # Attach the holder so the returned object owns the variables that were
  # created while `fn` was traced.
  wrapped._variable_holder = holder
  return wrapped