Example #1
def testShapes(self):
    fdef = self._build_function_def()

    g = function_def_to_graph.function_def_to_graph(fdef)
    self.assertIsNone(g.inputs[0].shape.dims)  # Unknown dims.
    self.assertIsNone(g.inputs[1].shape.dims)  # Unknown dims.
    self.assertIsNone(g.outputs[0].shape.dims)  # Unknown dims.
    self.assertIsNone(g.outputs[1].shape.dims)  # Unknown dims.

    g = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes=[tensor_shape.vector(5),
                            tensor_shape.vector(5)])
    self.assertSequenceEqual(g.inputs[0].shape.dims, [5])
    self.assertSequenceEqual(g.inputs[1].shape.dims, [5])
    self.assertSequenceEqual(g.outputs[0].shape.dims, [5])
    self.assertSequenceEqual(g.outputs[1].shape.dims, [5])

    g = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
    self.assertIsNone(g.inputs[0].shape.dims)
    self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7])
    self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7])
    self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7])

    # Should raise a ValueError if the length of input_shapes does not match
    # the number of input args in FunctionDef.signature.input_arg.
    with self.assertRaises(ValueError):
      g = function_def_to_graph.function_def_to_graph(
          fdef, input_shapes=[tensor_shape.matrix(5, 7)])
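The calls above mix fully unknown shapes with the tensor_shape.vector/tensor_shape.matrix shorthands. A minimal sanity check of those helpers (a sketch, assuming the TF 1.x-era tensor_shape module in which vector() and matrix() still exist):

from tensorflow.python.framework import tensor_shape

# vector(n) and matrix(r, c) are TensorShape shorthands; TensorShape(None) is the
# fully unknown shape whose `.dims` is None, matching the first test case above.
assert tensor_shape.vector(5).as_list() == [5]
assert tensor_shape.matrix(5, 7).as_list() == [5, 7]
assert tensor_shape.TensorShape(None).dims is None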
Example #2
def resolve_functions(tf_graph):
    def toposort(data):
        while True:
            ordered = set(item for item, dep in data.items() if not dep)
            if not ordered:
                break
            yield ordered
            data = {
                item: (dep - ordered)
                for item, dep in data.items() if item not in ordered
            }

    _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
    data = {}
    for k, fdef in tf_graph._functions.items():  # pylint: disable=protected-access
        input_shapes = functions.get(k)
        fdef = fdef.definition
        if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
            input_shapes = input_shapes[:len(fdef.signature.input_arg)]
        try:
            func = function_def_to_graph(fdef, input_shapes=input_shapes)
        except:  # pylint: disable=bare-except
            # If there is a mismatch between caller and function, fall back to
            # the function's own shapes.
            logger.warning("shape mismatch between caller and function: %s", k)
            func = function_def_to_graph(fdef)
        _FUNCTIONS[k] = func
        _, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
        functions.update(tfunctions)
        data[k] = set(tfunctions.keys())

    result = []
    for d in toposort(data):
        result.extend(list(d))
    return [_FUNCTIONS[k] for k in result]
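The nested toposort helper above is a level-order topological sort: each yielded set holds the functions whose dependencies have already been emitted, so callees come out before their callers, and a dependency cycle simply terminates the loop (the cyclic entries are silently dropped). A small self-contained illustration:

def toposort(data):
    # `data` maps each item to the set of items it depends on.
    while True:
        ordered = set(item for item, dep in data.items() if not dep)
        if not ordered:
            break
        yield ordered
        data = {item: (dep - ordered)
                for item, dep in data.items() if item not in ordered}

deps = {"outer_fn": {"inner_fn"}, "inner_fn": set(), "standalone_fn": set()}
print(list(toposort(deps)))
# [{'inner_fn', 'standalone_fn'}, {'outer_fn'}]  (set element order may vary)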
Example #3
    def testShapes(self):
        fdef = self._build_function_def()

        g = function_def_to_graph.function_def_to_graph(fdef)
        self.assertIsNone(g.inputs[0].shape.dims)  # Unknown dims.
        self.assertIsNone(g.inputs[1].shape.dims)  # Unknown dims.
        self.assertIsNone(g.outputs[0].shape.dims)  # Unknown dims.
        self.assertIsNone(g.outputs[1].shape.dims)  # Unknown dims.

        g = function_def_to_graph.function_def_to_graph(
            fdef,
            input_shapes=[tensor_shape.vector(5),
                          tensor_shape.vector(5)])
        self.assertSequenceEqual(g.inputs[0].shape.dims, [5])
        self.assertSequenceEqual(g.inputs[1].shape.dims, [5])
        self.assertSequenceEqual(g.outputs[0].shape.dims, [5])
        self.assertSequenceEqual(g.outputs[1].shape.dims, [5])

        g = function_def_to_graph.function_def_to_graph(
            fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
        self.assertIsNone(g.inputs[0].shape.dims)
        self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7])
        self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7])
        self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7])

        # Should raise a ValueError if the length of input_shapes does not match
        # the number of input args in FunctionDef.signature.input_arg.
        with self.assertRaises(ValueError):
            g = function_def_to_graph.function_def_to_graph(
                fdef, input_shapes=[tensor_shape.matrix(5, 7)])
Example #4
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  functions = {}

  for fdef in _sort_function_defs(library):
    copy = _fix_fdef(fdef, functions)

    func_graph = function_def_lib.function_def_to_graph(copy)
    for dep in _list_function_deps(fdef):
      functions[dep].add_to_graph(func_graph)
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph()

    functions[fdef.signature.name] = func

    # Also register the gradients in the current root context.
    with ops.init_scope():
      func._register_gradient()  # pylint: disable=protected-access

  return functions
Example #5
def resolve_functions(tf_graph):
    def toposort(data):
        while True:
            ordered = set(item for item, dep in data.items() if not dep)
            if not ordered:
                break
            yield ordered
            data = {item: (dep - ordered) for item, dep in data.items() if item not in ordered}

    _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
    data = {}
    for k, fdef in tf_graph._functions.items():  # pylint: disable=protected-access
        input_shapes = functions.get(k)
        fdef = fdef.definition
        if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
            input_shapes = input_shapes[:len(fdef.signature.input_arg)]
        func = function_def_to_graph.function_def_to_graph(fdef, input_shapes=input_shapes)
        _FUNCTIONS[k] = func
        _, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
        functions.update(tfunctions)
        data[k] = set(tfunctions.keys())

    result = []
    for d in toposort(data):
        result.extend(list(d))
    return [_FUNCTIONS[k] for k in result]
Example #6
    def testResourceHandleInputShapes(self):
        # Test that shape inference and validation with resource handles work as
        # expected.

        # Create a graph to generate the input and handle shape attributes in the
        # FunctionDef.
        with ops.Graph().as_default() as g:
            v = variables.Variable(array_ops.ones((2, 3),
                                                  dtype=dtypes.float32))

            @def_function.function(input_signature=[
                tensor_spec.TensorSpec((None, 2, 2), dtypes.int32)
            ])
            def lookup(inp):
                return {
                    # gather_nd expects a nonscalar shape for `v`, otherwise raises
                    # error when doing shape inference.
                    "shape inference": array_ops.gather_nd(v, inp),
                    # Triggers output shape validation. Expected shape must be [].
                    "handle": v.handle
                }

            lookup.get_concrete_function().add_to_graph()
            fdef = g.as_graph_def(add_shapes=True).library.function[0]

        fg = function_def_to_graph.function_def_to_graph(fdef)
        self.assertSequenceEqual(fg.inputs[0].shape.as_list(), [None, 2, 2])
        self.assertSequenceEqual(fg.inputs[1].shape.as_list(), [])
Example #7
def _get_graph(while_op, func_attr_name):
    """Returns `FuncGraph` for the given function attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string

  Returns:
    `FuncGraph`
  """
    # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
    input_shapes = [
        tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
    ]
    func_name = while_op.get_attr(func_attr_name).name
    fdef = while_op.graph._get_function(func_name).definition
    # `while_op.graph` may not be the same as `ops.get_default_graph()` e.g.
    # if the `while_op` is in the body of another if/while/defun. We build the
    # `func_graph` with `while_op.graph` as its `outer_graph`. This resembles how
    # the `FuncGraph` was built in the forward pass. We need this so that we can
    # appropriately capture references to outer tensors in the nested grad graphs.
    with while_op.graph.as_default():
        func_graph = function_def_to_graph.function_def_to_graph(
            fdef, input_shapes)
    func_graph._while = while_op
    return func_graph
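The list comprehension above converts the While op's "output_shapes" attribute, returned by get_attr as TensorShapeProto messages in this TF version, into TensorShape objects that function_def_to_graph accepts. A short standalone sketch of that conversion (assuming only a standard TensorFlow install):

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.framework import tensor_shape

# A rank-2 shape proto with one known and one unknown dimension (size -1 = unknown).
proto = tensor_shape_pb2.TensorShapeProto(dim=[
    tensor_shape_pb2.TensorShapeProto.Dim(size=5),
    tensor_shape_pb2.TensorShapeProto.Dim(size=-1),
])
print(tensor_shape.TensorShape(proto).as_list())  # [5, None]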
Example #8
 def get_func_graph_for_name(self, graph, func_name):
     """Returns the FuncGraph associated to the given func_name if possible."""
     outer_graph = graph
     while graph is not None:
         # pylint: disable=protected-access
         func = graph._get_function(str(func_name))
         if func is not None:
             if hasattr(func, "graph"):
                 return func.graph
             # `outer_graph` may not be the same as `ops.get_default_graph()` e.g.
             # in the case of nested if ops or when the gradient is being computed
             # from inside a Defun. We build the `func_graph` with `outer_graph`
             # as its outer graph.
             with outer_graph.as_default():
                 # This is a _DefinedFunction.
                 func_graph = (function_def_to_graph.function_def_to_graph(
                     func.definition))
             if func_graph is not None:
                 return func_graph
         if hasattr(graph, "outer_graph"):
             graph = graph.outer_graph
         else:
             raise ValueError(
                 "Function {} does not exist in the graph.".format(
                     func_name))
Example #9
def _get_graph(while_op, func_attr_name):
  """Returns `FuncGraph` for the given function attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string

  Returns:
    `FuncGraph`
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  input_shapes = [
      tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
  ]
  func_name = while_op.get_attr(func_attr_name).name
  fdef = while_op.graph._get_function(func_name).definition
  # `while_op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # if the `while_op` is in the body of another if/while/defun. We build the
  # `func_graph` with `while_op.graph` as its `outer_graph`. This resembles how
  # the `FuncGraph` was built in the forward pass. We need this so that we can
  # appropriately capture references to outer tensors in the nested grad graphs.
  with while_op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op
  return func_graph
Example #10
  def testFunctionCallsFromFunction(self):
    x = constant_op.constant(5.0)
    y = constant_op.constant(10.0)

    @function.Defun()
    def fn():

      @function.Defun()
      def inner_fn():
        return x + y

      return inner_fn()

    # Instantiate the function in this graph so that
    # `function_def_to_graph` can find it.
    fn()

    def fn2():
      return 2 * fn()

    fdef = function._DefinedFunction(fn2, [], []).definition
    func_graph = function_def_to_graph.function_def_to_graph(fdef)
    with func_graph.as_default():
      x_ph, y_ph = func_graph.inputs
      with self.test_session(graph=func_graph) as sess:
        self.assertEqual(
            sess.run(func_graph.outputs[0], feed_dict={
                x_ph: 5.0,
                y_ph: 10.0
            }), 30.0)
Example #11
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of `Function`
    without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  # TODO(andresp): Look into restoring gradient function information.
  functions = {}
  name_mapping = {}
  # Note: Use a new graph to allow function_def_to_graph to help validating
  # that the functions are loaded correctly. This is not possible to do
  # just in eager mode as there is no python API to find if a function has
  # been registered in eager. Note also that despite this the created
  # func_graphs can still be used in eager or in other graphs.
  with ops.Graph().as_default() as import_graph:
    for fdef in _sort_function_defs(library):
      copy = _fix_fdef(fdef, name_mapping)

      func_graph = function_def_lib.function_def_to_graph(copy)
      func = function_lib.Function(func_graph)
      func.add_to_graph(import_graph)

      name_mapping[fdef.signature.name] = func.name
      functions[fdef.signature.name] = func
  return functions
Example #12
    def testFunctionCallsFromFunction(self):
        x = constant_op.constant(5.0)
        y = constant_op.constant(10.0)

        @function.Defun()
        def fn():
            @function.Defun()
            def inner_fn():
                return x + y

            return inner_fn()

        # Instantiate the function in this graph so that
        # `function_def_to_graph` can find it.
        fn()

        def fn2():
            return 2 * fn()

        fdef = function._DefinedFunction(fn2, [], []).definition
        func_graph = function_def_to_graph.function_def_to_graph(fdef)
        with func_graph.as_default():
            x_ph, y_ph = func_graph.inputs
            with self.test_session(graph=func_graph) as sess:
                self.assertEqual(
                    sess.run(func_graph.outputs[0],
                             feed_dict={
                                 x_ph: 5.0,
                                 y_ph: 10.0
                             }), 30.0)
Example #13
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  # TODO(andresp): Look into restoring gradient function information.
  functions = {}
  name_mapping = {}
  # Note: Use a new graph to allow function_def_to_graph to help validating
  # that the functions are loaded correctly. This is not possible to do
  # just in eager mode as there is no python API to find if a function has
  # been registered in eager. Note also that despite this the created
  # func_graphs can still be used in eager or in other graphs.
  with ops.Graph().as_default() as import_graph:
    for fdef in _sort_function_defs(library):
      copy = _fix_fdef(fdef, name_mapping)

      func_graph = function_def_lib.function_def_to_graph(copy)
      func = function_lib.ConcreteFunction(func_graph)
      func.add_to_graph(import_graph)

      name_mapping[fdef.signature.name] = func.name
      functions[fdef.signature.name] = func
  return functions
Example #14
def get_func_graph(op, input_shapes, func_name):
    """Generates and returns a FuncGraph for the given op and input_shapes."""
    fdef = None
    graph = op.graph
    # Recursively search the func in graphs.
    while graph is not None:
        func = graph._get_function(func_name)  # pylint: disable=protected-access
        if func is not None:
            fdef = func.definition
            break
        if hasattr(graph, "outer_graph"):
            graph = graph.outer_graph
        else:
            break

    if fdef is None:
        raise KeyError("%s cannot be found in the graph" % func_name)

    # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
    # in the case of nested if ops or when the gradient is being computed
    # from inside a Defun. We build the `func_graph` with `op.graph` as its
    # `outer_graph`. This resembles how the `FuncGraph` was built in the
    # forward pass. We need this so that we can resolve references to tensors
    # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
    with op.graph.as_default():
        func_graph = function_def_to_graph.function_def_to_graph(
            fdef, input_shapes)
    return func_graph
Example #15
  def testFunctionCallsFromFunction(self):
    x = constant_op.constant(5.0)
    y = constant_op.constant(10.0)

    @function.defun
    def fn():

      @function.defun
      def inner_fn():
        return x + y

      return inner_fn()

    @function.defun
    def fn2():
      return 2 * fn()

    fn2_defun = fn2.get_concrete_function()

    # Call `fn2` to make sure `fn` is correctly instantiated so
    # `function_def_to_graph` can find it.
    fn2_defun()

    fdef = fn2_defun._inference_function.definition
    func_graph = function_def_to_graph.function_def_to_graph(fdef)
    with func_graph.as_default():
      x_ph, y_ph = func_graph.inputs
      with self.session(graph=func_graph) as sess:
        self.assertEqual(
            sess.run(func_graph.outputs[0], feed_dict={
                x_ph: 5.0,
                y_ph: 10.0
            }), 30.0)
Example #16
    def testFunctionCallsFromFunction(self):
        x = constant_op.constant(5.0)
        y = constant_op.constant(10.0)

        @function.defun
        def fn():
            @function.defun
            def inner_fn():
                return x + y

            return inner_fn()

        @function.defun
        def fn2():
            return 2 * fn()

        fn2_defun = fn2.get_concrete_function()

        # Call `fn2` to make sure `fn` is correctly instantiated so
        # `function_def_to_graph` can find it.
        fn2_defun()

        fdef = fn2_defun._inference_function.definition
        func_graph = function_def_to_graph.function_def_to_graph(fdef)
        with func_graph.as_default():
            x_ph, y_ph = func_graph.inputs
            with self.session(graph=func_graph) as sess:
                self.assertEqual(
                    sess.run(func_graph.outputs[0],
                             feed_dict={
                                 x_ph: 5.0,
                                 y_ph: 10.0
                             }), 30.0)
Example #17
 def _load_func_graphs(self, function_library):
   # TODO(allenl): Do we need to do name mapping here? Not quite sure what
   # happens when loaded names collide with existing names.
   # TODO(andresp): Look into restoring nested and gradient functions in the
   # right order.
   self._functions = {}
   for fdef in function_library.function:
     graph = function_def_lib.function_def_to_graph(fdef)
     self._functions[fdef.signature.name] = function.Function(graph)
Example #18
 def testInputsAndOutputs(self):
     fdef = self._build_function_def()
     g = function_def_to_graph.function_def_to_graph(fdef)
     self.assertEqual(g.name, "_whats_in_a_name")
     with self.session(graph=g) as sess:
         inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3})
         self.assertSequenceEqual(inputs, [2.0, 3.0])
         outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3})
         self.assertSequenceEqual(outputs, [13.0, 35.0])
Example #19
 def testInputsAndOutputs(self):
   fdef = self._build_function_def()
   g = function_def_to_graph.function_def_to_graph(fdef)
   self.assertEqual(g.name, "_whats_in_a_name")
   with self.session(graph=g) as sess:
     inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3})
     self.assertSequenceEqual(inputs, [2.0, 3.0])
     outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3})
     self.assertSequenceEqual(outputs, [13.0, 35.0])
Example #20
 def count_ops(self, graph):
      '''Return the total number of operators, including ops inside nested functions.'''
     from tensorflow.python.framework.function_def_to_graph import function_def_to_graph
     num = len(graph.get_operations())
     for key, fdef in graph._functions.items():
         sub_tf_graph = function_def_to_graph(fdef.definition)
         self.functions[key] = sub_tf_graph
         num += self.count_ops(sub_tf_graph)
     return num
Example #21
 def _get_func_graph_for_branch(branch_name):
   extra_inputs = if_op.inputs[1:]  # First input is pred.
   input_shapes = [t.shape for t in extra_inputs]
   func_name = if_op.get_attr(branch_name).name
   fdef = if_op.graph._get_function(func_name).definition
   func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
   func_graph.extra_inputs = extra_inputs
   func_graph.extra_args = func_graph.inputs
   func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
   return func_graph
Example #22
def _get_func_graph_for_name(graph, func_name):
    """Returns the FuncGraph associated to the given func_name if possible."""
    func = graph._get_function(str(func_name))  # pylint: disable=protected-access
    if not func:
        raise ValueError(
            'Function {} does not exist in the graph.'.format(func_name))
    if not hasattr(func, 'graph'):
        # This is a _DefinedFunction.
        return function_def_to_graph.function_def_to_graph(func.definition)
    else:
        return func.graph
Example #23
def get_func_graph(op, input_shapes, func_name):
    """Generates and returns a FuncGraph for the given op and input_shapes."""
    fdef = op.graph._get_function(func_name).definition  # pylint: disable=protected-access
    # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
    # in the case of nested if ops or when the gradient is being computed
    # from inside a Defun. We build the `func_graph` with `op.graph` as its
    # `outer_graph`. This resembles how the `FuncGraph` was built in the
    # forward pass. We need this so that we can resolve references to tensors
    # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
    with op.graph.as_default():
        func_graph = function_def_to_graph.function_def_to_graph(
            fdef, input_shapes)
    return func_graph
Example #24
    def to_tf_function_graph(self):
        # type: () -> tf.Graph
        """
    Converts this graph into a new TensorFlow `Graph`. Also takes care of
    variables.
    Note that function_def_to_graph.function_def_to_graph won't work if the
    function calls into other functions.

    Returns a fresh `tf.Graph` containing all the nodes and variables that
    this object represents.
    """
        return function_def_to_graph.function_def_to_graph(
            self.to_function_graph_def())
Example #25
def load_function_def_library(library, load_shared_name_suffix=None):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise a unique name is generated.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    library_function_names = set(fdef.signature.name
                                 for fdef in library.function)
    functions = {}

    if load_shared_name_suffix is None:
        load_shared_name_suffix = "_load_{}".format(ops.uid())
    for fdef in _sort_function_defs(library, library_function_names):
        copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

        # There is no need to copy all functions into the function def graph. It
        # leads to an O(n^2) increase in memory when importing functions, and the
        # extra function definitions are a no-op since they were already imported
        # as functions before and passed in explicitly (due to the topological
        # sort import).
        func_graph = function_def_lib.function_def_to_graph(
            copy, copy_functions=False)

        for dep in _list_function_deps(fdef, library_function_names):
            functions[dep].add_to_graph(func_graph)
        func = function_lib.ConcreteFunction(func_graph)
        func.add_to_graph()
        if context.executing_eagerly():
            func.add_to_graph(ops.get_default_graph())

        functions[fdef.signature.name] = func

        # Also register the gradients in the current root context.
        with ops.init_scope():
            func._register_gradient()  # pylint: disable=protected-access

    return functions
Example #26
def _get_body_graph(while_op):
  """Returns `FuncGraph` for the while body.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body.
  """
  extra_inputs = list(while_op.inputs)
  input_shapes = [t.shape for t in extra_inputs]
  func_name = while_op.get_attr("body").name
  fdef = while_op.graph._get_function(func_name).definition
  func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op
  return func_graph
Example #27
def get_func_graph_for_name(graph, func_name):
    """Returns the FuncGraph associated to the given func_name if possible."""
    while graph is not None:
        func = graph._get_function(str(func_name))  # pylint: disable=protected-access
        if func is not None:
            if hasattr(func, 'graph'):
                return func.graph
            func_graph = function_def_to_graph.function_def_to_graph(
                func.definition)
            if func_graph is not None:
                return func_graph
        if hasattr(graph, 'outer_graph'):
            graph = graph.outer_graph
        else:
            raise ValueError(
                'Function {} does not exist in the graph.'.format(func_name))
Example #28
def _get_body_graph(while_op):
    """Returns `FuncGraph` for the while body.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body.
  """
    extra_inputs = list(while_op.inputs)
    input_shapes = [t.shape for t in extra_inputs]
    func_name = while_op.get_attr("body").name
    fdef = while_op.graph._get_function(func_name).definition
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
    func_graph._while = while_op
    return func_graph
Example #29
def _get_body_graph(while_op):
  """Returns `FuncGraph` for the while body.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body.
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  input_shapes = [
      tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
  ]
  func_name = while_op.get_attr("body").name
  fdef = while_op.graph._get_function(func_name).definition
  func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op
  return func_graph
Example #30
def _get_body_graph(while_op):
  """Returns `FuncGraph` for the while body.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body.
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  input_shapes = [
      tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
  ]
  func_name = while_op.get_attr("body").name
  fdef = while_op.graph._get_function(func_name).definition
  func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op
  return func_graph
Example #31
  def testControlDependencies(self):

    def fn(inp):
      x = constant_op.constant(2.0, name="x")
      # TODO(b/79881896): Test external control dependency once that's
      # supported.
      with ops.control_dependencies([x, inp]):
        constant_op.constant(3.0, name="y")
      return 4.0

    fdef = function._DefinedFunction(
        fn, ["inp"], [dtypes.float32]).definition
    func_graph = function_def_to_graph.function_def_to_graph(fdef)

    op = func_graph.get_operation_by_name("y")
    self.assertEqual(len(op.control_inputs), 2)
    self.assertEqual(op.control_inputs[0].name, "x")
    self.assertEqual(op.control_inputs[1].name, "inp")
Example #32
    def testControlDependencies(self):
        @function.defun
        def fn(inp):
            x = constant_op.constant(2.0, name="x")
            # TODO(b/79881896): Test external control dependency once that's
            # supported.
            with ops.control_dependencies([x, inp]):
                constant_op.constant(3.0, name="y")
            return 4.0

        inp = constant_op.constant(1.0)
        fdef = fn.get_concrete_function(inp).function_def
        func_graph = function_def_to_graph.function_def_to_graph(fdef)

        op = func_graph.get_operation_by_name("y")
        self.assertEqual(len(op.control_inputs), 2)
        self.assertEqual(op.control_inputs[0].name, "x")
        self.assertEqual(op.control_inputs[1].name, "placeholder")
Example #33
 def _get_func_graph_for_branch(name_attr_list):
   """Generates and returns a FuncGraph for the given branch."""
   inputs = op.inputs[1:]  # First input is pred.
   input_shapes = [t.shape for t in inputs]
   fdef = op.graph._get_function(name_attr_list.name).definition
   # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
   # in the case of nested if ops or when the gradient is being computed
   # from inside a Defun. We build the `func_graph` with `op.graph` as its
   # `outer_graph`. This resembles how the `FuncGraph` was built in the
   # forward pass. We need this so that we can resolve references to tensors
   # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
   with op.graph.as_default():
     func_graph = function_def_to_graph.function_def_to_graph(
         fdef, input_shapes)
   func_graph.captures = collections.OrderedDict(zip(inputs,
                                                     func_graph.inputs))
   # Link the op so that the gradient code can use it.
   func_graph._forward_cond = op
   return func_graph
Example #34
  def testControlDependencies(self):

    @function.defun
    def fn(inp):
      x = constant_op.constant(2.0, name="x")
      # TODO(b/79881896): Test external control dependency once that's
      # supported.
      with ops.control_dependencies([x, inp]):
        constant_op.constant(3.0, name="y")
      return 4.0

    inp = constant_op.constant(1.0)
    fdef = fn.get_concrete_function(inp).function_def
    func_graph = function_def_to_graph.function_def_to_graph(fdef)

    op = func_graph.get_operation_by_name("y")
    self.assertEqual(len(op.control_inputs), 2)
    self.assertEqual(op.control_inputs[0].name, "x")
    self.assertEqual(op.control_inputs[1].name, "inp")
Example #35
 def _get_func_graph_for_branch(branch_name):
   """Generates and returns a FuncGraph for the given branch."""
   inputs = if_op.inputs[1:]  # First input is pred.
   input_shapes = [t.shape for t in inputs]
   func_name = if_op.get_attr(branch_name).name
   fdef = if_op.graph._get_function(func_name).definition
   # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
   # in the case of nested if ops or when the gradient is being computed
   # from inside a Defun. We build the `func_graph` with `if_op.graph` as its
   # `outer_graph`. This resembles how the `FuncGraph` was built in the
   # forward pass. We need this so that we can resolve references to tensors
   # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
   with if_op.graph.as_default():
     func_graph = function_def_to_graph.function_def_to_graph(
         fdef, input_shapes)
   func_graph.captures = collections.OrderedDict(zip(inputs,
                                                     func_graph.inputs))
   # Set the if op so that the gradient code can use it.
   func_graph._if = if_op
   return func_graph
Example #36
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  functions = {}

  load_shared_name_suffix = "_load_{}".format(ops.uid())
  for fdef in _sort_function_defs(library):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

    # There is no need to copy functions into the function def graph.
    # It leads to an O(n^2) increase in memory when importing functions,
    # and the extra function definitions are a no-op since they were already
    # imported as functions before (due to the topological sort import).
    func_graph = function_def_lib.function_def_to_graph(
        copy, copy_functions=False)

    for dep in _list_function_deps(fdef):
      functions[dep].add_to_graph(func_graph)
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph()

    functions[fdef.signature.name] = func

    # Also register the gradients in the current root context.
    with ops.init_scope():
      func._register_gradient()  # pylint: disable=protected-access

  return functions
Example #37
def convert(inp_format, inp_loc, out_format, out_loc, output_nodes, ng_backend,
            device_id, backend_optional_params, shape_hints, do_aot,
            save_ng_clusters):
    """Functional api for converting TF models by inserting ngraph nodes.
    Sample usage:
    from tf2ngraph import convert
    convert('savedmodel', 'test_graph' , 'pbtxt', 'test_graph_ngraph.pbtxt', ['out_node'])
    convert('pbtxt', 'test_graph.pbtxt' , 'pbtxt', 'test_graph_ngraph.pbtxt', ['out_node'])

    Parameters:
    inp_format (string): 'savedmodel', 'pbtxt', 'pb'
    inp_loc (string): Location of input file or folder (in case of savedmodel)
    out_format (string): 'savedmodel', 'pbtxt', 'pb'
    out_loc (string): Location of output file or folder (in case of savedmodel)
    output_nodes (iterable of strings): names of output nodes

    Returns: void
   """
    exit_on_error(
        inp_format in allowed_formats['input'], 'Unsupported input format ' +
        inp_format + ". Supported formats: " + str(allowed_formats['input']))
    exit_on_error(
        out_format in allowed_formats['output'], 'Unsupported output format ' +
        out_format + ". Supported formats: " + str(allowed_formats['output']))
    exit_on_error(
        ngraph_bridge.is_grappler_enabled(),
        "ngraph-bridge is not built with grappler enabled, hence tf2ngraph is not supported."
    )
    input_gdef = get_gdef(inp_format, inp_loc)
    attach_device(input_gdef)
    output_gdef = run_ngraph_grappler_optimizer(
        input_gdef, output_nodes, ng_backend, device_id,
        backend_optional_params, shape_hints, do_aot)
    if save_ng_clusters:
        for fn in output_gdef.library.function:
            tf.io.write_graph(
                function_def_to_graph(fn).as_graph_def(),
                '.',
                fn.signature.name + '.pbtxt',
                as_text=True)
    save_model(output_gdef, out_format, out_loc)
Example #38
def load_function_def_library(library):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    functions = {}

    # Note: Use a new graph to allow function_def_to_graph to help validating
    # that the functions are loaded correctly. This is not possible to do
    # just in eager mode as there is no python API to find if a function has
    # been registered in eager. Note also that despite this the created
    # func_graphs can still be used in eager or in other graphs.
    import_graph = ops.Graph()

    for fdef in _sort_function_defs(library):
        with import_graph.as_default():
            copy = _fix_fdef(fdef, functions)

            func_graph = function_def_lib.function_def_to_graph(copy)
            func = function_lib.ConcreteFunction(func_graph)
            func.add_to_graph(import_graph)

            functions[fdef.signature.name] = func

        # Also register the gradients in the current root context.
        with ops.init_scope():
            func._register_gradient()  # pylint: disable=protected-access

    return functions
Example #39
  def testControlDependencies(self):

    v = variables.Variable(1)

    @function.defun
    def fn(inp):
      assign = v.assign(3, name="assign", read_value=False)
      x = constant_op.constant(2.0, name="x")
      # TODO(b/79881896): Test external control dependency once that's
      # supported.
      with ops.control_dependencies([x, inp, assign]):
        constant_op.constant(3.0, name="y")
      return 4.0

    inp = constant_op.constant(1.0)
    fdef = fn.get_concrete_function(inp).function_def
    func_graph = function_def_to_graph.function_def_to_graph(fdef)

    op = func_graph.get_operation_by_name("y")
    self.assertEqual(len(op.control_inputs), 3)
    self.assertEqual(op.control_inputs[0].name, "assign")
    self.assertEqual(op.control_inputs[1].name, "inp")
    self.assertEqual(op.control_inputs[2].name, "x")
Example #40
    def testControlDependencies(self):

        v = variables.Variable(1)

        @function.defun
        def fn(inp):
            assign = v.assign(3, name="assign", read_value=False)
            x = constant_op.constant(2.0, name="x")
            # TODO(b/79881896): Test external control dependency once that's
            # supported.
            with ops.control_dependencies([x, inp, assign]):
                constant_op.constant(3.0, name="y")
            return 4.0

        inp = constant_op.constant(1.0)
        fdef = fn.get_concrete_function(inp).function_def
        func_graph = function_def_to_graph.function_def_to_graph(fdef)

        op = func_graph.get_operation_by_name("y")
        self.assertEqual(len(op.control_inputs), 3)
        self.assertEqual(op.control_inputs[0].name, "assign")
        self.assertEqual(op.control_inputs[1].name, "inp")
        self.assertEqual(op.control_inputs[2].name, "x")
Example #41
def load_function_def_library(library, load_shared_name_suffix=None):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    library_function_names = set(fdef.signature.name
                                 for fdef in library.function)
    functions = {}
    renamed_functions = {}

    # Our graph building code currently requires functions to be registered with
    # some tf.Graph in order to import functions using the
    # op-name-is-function-name calling convention. To avoid leaking memory into
    # the global default graph when executing eagerly, we create a temporary
    # Graph.
    #
    # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
    # fixing function_def_to_graph_def.
    if ops.executing_eagerly_outside_functions():
        graph = ops.Graph()
    else:
        graph = ops.get_default_graph()

    if load_shared_name_suffix is None:
        load_shared_name_suffix = "_load_{}".format(ops.uid())
    for fdef in _sort_function_defs(library, library_function_names):
        copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

        # There is no need to copy all functions into the function def graph. It
        # leads to an O(n^2) increase in memory when importing functions, and the
        # extra function definitions are a no-op since they were already imported
        # as functions before and passed in explicitly (due to the topological
        # sort import).
        with graph.as_default():
            func_graph = function_def_lib.function_def_to_graph(copy)
        _restore_gradient_functions(func_graph, renamed_functions)

        for dep in _list_function_deps(fdef, library_function_names):
            functions[dep].add_to_graph(func_graph)

        # We do not initialize the new ConcreteFunction's function_spec and/or
        # arg_keywords here (which are used to parse the structured and flat
        # signatures, respectively). ConcreteFunctions that are part of a saved
        # function are set up later by recreate_function(), and bare
        # ConcreteFunctions are set up by setup_bare_concrete_function().
        func = function_lib.ConcreteFunction(func_graph)
        func.add_to_graph(graph)

        functions[fdef.signature.name] = func
        renamed_functions[func.name] = func
        if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
            # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
            # is fixed. Currently it's leaking memory to maintain bug compatibility
            # with previous behavior.
            func.add_to_graph(ops.get_default_graph())

    return functions
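A hedged end-to-end sketch of exercising a loader like the one above: trace a tf.function, register it in a plain Graph so its FunctionDef lands in the GraphDef's library, and hand that library to the loader. The graph-building steps here are illustrative assumptions, not part of the example itself.

import tensorflow as tf

@tf.function
def square(x):
    return x * x

concrete = square.get_concrete_function(tf.TensorSpec([], tf.float32))

g = tf.Graph()
concrete.add_to_graph(g)  # registers the traced FunctionDef in `g`
library = g.as_graph_def(add_shapes=True).library

functions = load_function_def_library(library)  # original names -> ConcreteFunction
print(sorted(functions))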