示例#1
0
def _create_new_tf_function(func_graph):
    """Builds a TF_Function from `func_graph` and registers it in the default graph.

    Args:
      func_graph: function._FuncGraph

    Returns:
      `func_graph.name`. The TF_Function is created under that same name
      (converted with `compat.as_str`) and no hash suffix is requested.
    """
    # Gather the call arguments up front, in the order they are passed below.
    graph_handle = func_graph._c_graph
    function_name = compat.as_str(func_graph.name)
    input_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
    output_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]

    tf_function = c_api.TF_GraphToFunction_wrapper(
        graph_handle,
        function_name,
        False,  # append_hash_to_fn_name
        None,  # opers
        input_handles,
        output_handles,
        [],
        None,  # opts
        None)  # description
    # Held only for the wrapper object itself; presumably it manages the
    # lifetime of the raw handle -- confirm against c_api_util.
    _ = c_api_util.ScopedTFFunction(tf_function)

    # TODO(b/109833212): wasteful round trip -- the TF_Function* is serialized,
    # parsed into a Python FunctionDef, and then serialized again to build the
    # TF_Function that actually gets added to the graph.
    function_def = function.function_def_from_tf_function(tf_function)
    rebuilt_function = function._from_definition(function_def)
    rebuilt_function.add_to_graph(ops.get_default_graph())

    return func_graph.name
示例#2
0
def _create_new_tf_function(func_graph):
  """Turns `func_graph` into a TF_Function and adds it to the default graph.

  Args:
    func_graph: function._FuncGraph

  Returns:
    `func_graph.name`. The TF_Function is created under that same name
    (converted with `compat.as_str`) and no hash suffix is requested.
  """
  # Resolve every argument first, preserving the order used in the call.
  c_graph = func_graph._c_graph
  fn_name = compat.as_str(func_graph.name)
  tf_inputs = [tensor._as_tf_output() for tensor in func_graph.inputs]
  tf_outputs = [tensor._as_tf_output() for tensor in func_graph.outputs]

  raw_function = c_api.TF_GraphToFunction_wrapper(
      c_graph,
      fn_name,
      False,  # append_hash_to_fn_name
      None,  # opers
      tf_inputs,
      tf_outputs,
      [],
      None,  # opts
      None)  # description
  # Held only for the wrapper object itself; presumably it manages the
  # lifetime of the raw handle -- confirm against c_api_util.
  _ = c_api_util.ScopedTFFunction(raw_function)

  # TODO(b/109833212): wasteful round trip -- the TF_Function* is serialized,
  # parsed into a Python FunctionDef, and then serialized again to build the
  # TF_Function that actually gets added to the graph.
  parsed_def = function.function_def_from_tf_function(raw_function)
  python_function = function._from_definition(parsed_def)
  python_function.add_to_graph(ops.get_default_graph())

  return func_graph.name
示例#3
0
 def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
   """Asserts that `func` and `new_func` agree on their key attributes.

   Args:
     func: the reference function object.
     grad_func: forwarded to `function._from_definition` when `new_func`
       has to be built here; unused otherwise.
     new_func: the function to compare against `func`. When None, it is
       rebuilt from a serialized copy of `func.definition`.
   """
   if new_func is None:
     # Round-trip the definition through its serialized form so that the
     # rebuilt function works from a distinct object; sharing the same proto
     # instance could mask bugs.
     wire_bytes = func.definition.SerializeToString()
     copied_def = function_pb2.FunctionDef.FromString(wire_bytes)
     new_func = function._from_definition(copied_def, grad_func=grad_func)
   # Compared in a fixed order; the first mismatch is the one reported.
   for attr in ("name", "definition", "grad_func_name",
                "declared_input_types", "captured_inputs"):
     self.assertEqual(getattr(func, attr), getattr(new_func, attr))
示例#4
0
 def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
   """Checks `func` against `new_func`, attribute by attribute.

   Args:
     func: the reference function object.
     grad_func: passed through to `function._from_definition` only when
       `new_func` is constructed by this helper.
     new_func: the function expected to match `func`. If None, one is
       created from a serialized-then-parsed copy of `func.definition`.
   """
   compared_attrs = (
       "name",
       "definition",
       "grad_func_name",
       "declared_input_types",
       "captured_inputs",
   )
   if new_func is None:
     # Serialize and re-parse so the new function is built from its own copy
     # of the definition; reusing the same object could hide bugs.
     serialized = func.definition.SerializeToString()
     fresh_def = function_pb2.FunctionDef.FromString(serialized)
     new_func = function._from_definition(fresh_def, grad_func=grad_func)
   for attr_name in compared_attrs:
     expected = getattr(func, attr_name)
     actual = getattr(new_func, attr_name)
     self.assertEqual(expected, actual)
示例#5
0
  def testCapturedInputs(self):
    # Rebuilds a Defun that closes over an outer tensor from its definition
    # and checks how the captured tensor shows up on the rebuilt function.
    captured = constant_op.constant(10, dtypes.int64)

    @function.Defun(dtypes.int64)
    def Foo(x):
      return x + captured

    rebuilt = function._from_definition(Foo.definition)

    # These attributes must survive the rebuild unchanged.
    for attr in ("name", "definition", "grad_func_name"):
      self.assertEqual(getattr(Foo, attr), getattr(rebuilt, attr))

    # The rebuilt function is expected to list the captured tensor as one
    # extra declared int64 input and to report no captured inputs of its own.
    self.assertEqual(rebuilt.declared_input_types,
                     Foo.declared_input_types + (dtypes.int64,))
    self.assertEqual(len(rebuilt.captured_inputs), 0)