Code example #1
0
def load_function_def_library(library):
  """Load every function in `library` as a capture-free concrete function.

  Each function is renamed on load so that freshly created functions never
  collide with names registered earlier in the process.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict mapping each function's original name in the library to a
    `Function` instance with no captured inputs.

  Raises:
    ValueError: if the functions' dependencies form a cycle.
  """
  # TODO(andresp): Look into restoring gradient function information.
  loaded = {}
  renames = {}
  # Note: a fresh graph is used so function_def_to_graph can help validate
  # that each function loads correctly; eager mode exposes no python API to
  # check whether a function is already registered. The resulting
  # func_graphs remain usable from eager code or from other graphs.
  with ops.Graph().as_default() as import_graph:
    for fdef in _sort_function_defs(library):
      fixed_fdef = _fix_fdef(fdef, renames)

      graph = function_def_lib.function_def_to_graph(fixed_fdef)
      concrete = function_lib.Function(graph)
      concrete.add_to_graph(import_graph)

      original_name = fdef.signature.name
      renames[original_name] = concrete.name
      loaded[original_name] = concrete
  return loaded
Code example #2
0
 def _load_func_graphs(self, function_library):
   """Populate `self._functions` from the given FunctionDefLibrary proto.

   Converts every FunctionDef in `function_library` into a func_graph and
   wraps it in a `function.Function`, keyed by its original signature name.
   """
   # TODO(allenl): Do we need to do name mapping here? Not quite sure what
   # happens when loaded names collide with existing names.
   # TODO(andresp): Look into restoring nested and gradient functions in the
   # right order.
   self._functions = {
       fdef.signature.name:
           function.Function(function_def_lib.function_def_to_graph(fdef))
       for fdef in function_library.function
   }
Code example #3
0
def wrap_function(fn, signature, name=None):
    """Wrap the TF 1.x function `fn` into a graph function.

    `fn` is called once with the symbolic arguments described by
    `signature`, traced, and turned into a graph function. Any variables
    created during that trace are owned by the returned object. The
    resulting graph function accepts tensors matching the signature.

    ```python
    def f(x, do_add):
      v = tf.Variable(5.0)
      if do_add:
        op = v.assign_add(x)
      else:
        op = v.assign_sub(x)
      with tf.control_dependencies([op]):
        return v.read_value()

    f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

    assert float(f_add(1.0)) == 6.0
    assert float(f_add(1.0)) == 7.0

    # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
    # of variables, and possibly different non-template arguments.
    f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

    assert float(f_sub(1.0)) == 4.0
    assert float(f_sub(1.0)) == 3.0
    ```

    Args:
      fn: python function to be wrapped
      signature: the placeholder and python arguments to be passed to the
        wrapped function
      name: Optional. The name of the function.

    Returns:
      the wrapped graph function.
    """
    # The holder intercepts variable creation during tracing so the wrapped
    # function owns its variables.
    variable_holder = VariableHolder(fn)
    traced_graph = function.func_graph_from_py_func(
        name,
        variable_holder,
        args=None,
        kwargs=None,
        signature=signature,
        add_control_dependencies=False)
    wrapped = function.Function(traced_graph, signature=signature)
    wrapped._variable_holder = variable_holder
    return wrapped