def testShapes(self):
  """function_def_to_graph honors and validates its `input_shapes` argument."""
  fdef = self._build_function_def()

  # Without input_shapes, every input and output has unknown dims.
  graph = function_def_to_graph.function_def_to_graph(fdef)
  for tensor in (graph.inputs[0], graph.inputs[1], graph.outputs[0],
                 graph.outputs[1]):
    self.assertIsNone(tensor.shape.dims)  # Unknown dims.

  # Two length-5 vectors: the shape propagates to both outputs.
  graph = function_def_to_graph.function_def_to_graph(
      fdef, input_shapes=[tensor_shape.vector(5), tensor_shape.vector(5)])
  for tensor in (graph.inputs[0], graph.inputs[1], graph.outputs[0],
                 graph.outputs[1]):
    self.assertSequenceEqual(tensor.shape.dims, [5])

  # A None entry leaves that input unknown; the 5x7 matrix propagates.
  graph = function_def_to_graph.function_def_to_graph(
      fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
  self.assertIsNone(graph.inputs[0].shape.dims)
  for tensor in (graph.inputs[1], graph.outputs[0], graph.outputs[1]):
    self.assertSequenceEqual(tensor.shape.dims, [5, 7])

  # A ValueError is expected when len(input_shapes) differs from the number
  # of input args in FunctionDef.signature.input_arg.
  with self.assertRaises(ValueError):
    function_def_to_graph.function_def_to_graph(
        fdef, input_shapes=[tensor_shape.matrix(5, 7)])
def resolve_functions(tf_graph):
    """Convert the functions registered on `tf_graph` into graphs.

    Each FunctionDef in `tf_graph._functions` is converted with
    `function_def_to_graph`, using the input shapes that `tflist_to_onnx`
    reported for it (trimmed to the function's declared argument count).
    If conversion with those shapes raises, a warning is logged and the
    conversion is retried without `input_shapes`. Every converted graph is
    stored in `_FUNCTIONS` and scanned with `tflist_to_onnx` for the
    functions it references.

    Args:
        tf_graph: graph whose `_functions` should be converted.

    Returns:
        List of converted function graphs. A function is listed only after
        every function it references. Functions whose references never
        resolve (e.g. a reference cycle) are omitted.
    """

    def toposort(data):
        """Yield successive sets of keys whose dependency sets are exhausted."""
        while True:
            ordered = {item for item, dep in data.items() if not dep}
            if not ordered:
                break
            yield ordered
            data = {item: (dep - ordered)
                    for item, dep in data.items() if item not in ordered}

    _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
    data = {}
    for k, fdef in tf_graph._functions.items():  # pylint: disable=protected-access
        input_shapes = functions.get(k)
        fdef = fdef.definition
        # Drop shapes beyond the function's declared arguments.
        if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
            input_shapes = input_shapes[:len(fdef.signature.input_arg)]
        try:
            func = function_def_to_graph(fdef, input_shapes=input_shapes)
        except Exception:  # pylint: disable=broad-except
            # Caller-observed shapes disagree with the function; retry without
            # them. `Exception` (not a bare except) so that KeyboardInterrupt
            # and SystemExit still propagate.
            logger.warning("shape missmatch between caller and function: %s", k)
            func = function_def_to_graph(fdef)
        _FUNCTIONS[k] = func
        _, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
        functions.update(tfunctions)
        data[k] = set(tfunctions.keys())

    result = []
    for d in toposort(data):
        result.extend(d)
    return [_FUNCTIONS[k] for k in result]
def testShapes(self):
  """Inferred shapes follow `input_shapes`; a wrong-length list is rejected."""
  fdef = self._build_function_def()

  def endpoints(graph):
    # Both inputs followed by both outputs.
    return [graph.inputs[0], graph.inputs[1], graph.outputs[0],
            graph.outputs[1]]

  unshaped = function_def_to_graph.function_def_to_graph(fdef)
  for tensor in endpoints(unshaped):
    self.assertIsNone(tensor.shape.dims)  # Unknown dims.

  vector_graph = function_def_to_graph.function_def_to_graph(
      fdef, input_shapes=[tensor_shape.vector(5), tensor_shape.vector(5)])
  for tensor in endpoints(vector_graph):
    self.assertSequenceEqual(tensor.shape.dims, [5])

  mixed_graph = function_def_to_graph.function_def_to_graph(
      fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
  mixed = endpoints(mixed_graph)
  self.assertIsNone(mixed[0].shape.dims)
  for tensor in mixed[1:]:
    self.assertSequenceEqual(tensor.shape.dims, [5, 7])

  # The length of input_shapes must match FunctionDef.signature.input_arg.
  with self.assertRaises(ValueError):
    function_def_to_graph.function_def_to_graph(
        fdef, input_shapes=[tensor_shape.matrix(5, 7)])
def load_function_def_library(library):
  """Loads each FunctionDef in `library` as a ConcreteFunction.

  Function names are rewritten during loading (via `_fix_fdef`) so they do not
  collide with previously created functions.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from each original function name in the library to a
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  loaded = {}
  for fdef in _sort_function_defs(library):
    fixed_fdef = _fix_fdef(fdef, loaded)
    graph = function_def_lib.function_def_to_graph(fixed_fdef)
    # Make every already-loaded dependency callable from the new graph.
    for dependency_name in _list_function_deps(fdef):
      loaded[dependency_name].add_to_graph(graph)
    concrete = function_lib.ConcreteFunction(graph)
    concrete.add_to_graph()
    loaded[fdef.signature.name] = concrete
    # Gradients are registered in the current root context.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access
  return loaded
def resolve_functions(tf_graph):
    """Convert the functions of `tf_graph` to graphs, callees first.

    Each converted graph is cached in `_FUNCTIONS`; the functions it references
    (according to `tflist_to_onnx`) determine the output order.
    """

    def toposort(deps_by_name):
        remaining = deps_by_name
        while True:
            ready = {name for name, deps in remaining.items() if not deps}
            if not ready:
                break
            yield ready
            remaining = {name: (deps - ready)
                         for name, deps in remaining.items() if name not in ready}

    _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
    dependencies = {}
    for name, registered in tf_graph._functions.items():  # pylint: disable=protected-access
        fdef = registered.definition
        input_shapes = functions.get(name)
        arg_count = len(fdef.signature.input_arg)
        if input_shapes and arg_count < len(input_shapes):
            input_shapes = input_shapes[:arg_count]
        func = function_def_to_graph.function_def_to_graph(fdef, input_shapes=input_shapes)
        _FUNCTIONS[name] = func
        _, _, _, _, _, called = tflist_to_onnx(func, {})
        functions.update(called)
        dependencies[name] = set(called.keys())

    ordered_names = [name for batch in toposort(dependencies) for name in batch]
    return [_FUNCTIONS[name] for name in ordered_names]
def testResourceHandleInputShapes(self):
  """Shape inference and validation work for inputs that are resource handles."""
  # Build the FunctionDef in a fresh graph so that it carries the input and
  # handle shape attributes.
  with ops.Graph().as_default() as g:
    var = variables.Variable(array_ops.ones((2, 3), dtype=dtypes.float32))
    signature = [tensor_spec.TensorSpec((None, 2, 2), dtypes.int32)]

    @def_function.function(input_signature=signature)
    def lookup(inp):
      # gather_nd expects a nonscalar shape for `var`; otherwise shape
      # inference raises.
      gathered = array_ops.gather_nd(var, inp)
      # Returning the handle triggers output shape validation (expects []).
      return {"shape inference": gathered, "handle": var.handle}

    lookup.get_concrete_function().add_to_graph()
    fdef = g.as_graph_def(add_shapes=True).library.function[0]

  fg = function_def_to_graph.function_def_to_graph(fdef)
  self.assertSequenceEqual(fg.inputs[0].shape.as_list(), [None, 2, 2])
  self.assertSequenceEqual(fg.inputs[1].shape.as_list(), [])
def _get_graph(while_op, func_attr_name):
  """Builds the `FuncGraph` for one function attribute of a While op.

  Args:
    while_op: The While Operation.
    func_attr_name: string naming the op attribute that holds the function.

  Returns:
    `FuncGraph` whose `_while` attribute points back at `while_op`.
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  input_shapes = list(
      map(tensor_shape.TensorShape, while_op.get_attr("output_shapes")))
  function_name = while_op.get_attr(func_attr_name).name
  fdef = while_op.graph._get_function(function_name).definition  # pylint: disable=protected-access
  # Build under `while_op.graph` (which may differ from the default graph,
  # e.g. inside another if/while/defun) so it becomes the FuncGraph's
  # `outer_graph`, mirroring the forward pass; nested grad graphs then capture
  # outer tensors correctly.
  with while_op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op  # pylint: disable=protected-access
  return func_graph
def get_func_graph_for_name(self, graph, func_name):
  """Returns the FuncGraph associated to the given func_name if possible.

  Looks `func_name` up in `graph` and then in each successive `outer_graph`.

  Args:
    graph: graph in which to start the lookup.
    func_name: name of the function to find.

  Returns:
    The function's own `graph` attribute when it has one; otherwise a
    FuncGraph rebuilt from its FunctionDef with the starting `graph` as the
    default graph.

  Raises:
    ValueError: if no graph in the lookup chain defines `func_name`.
  """
  outer_graph = graph
  while graph is not None:
    # pylint: disable=protected-access
    func = graph._get_function(str(func_name))
    if func is not None:
      if hasattr(func, "graph"):
        return func.graph
      # `outer_graph` may not be the same as `ops.get_default_graph()` e.g.
      # in the case of nested if ops or when the gradient is being computed
      # from inside a Defun. We build the `func_graph` with `outer_graph`
      # as its outer graph.
      with outer_graph.as_default():
        # This is a _DefinedFunction.
        func_graph = function_def_to_graph.function_def_to_graph(
            func.definition)
      if func_graph is not None:
        return func_graph
    if hasattr(graph, "outer_graph"):
      graph = graph.outer_graph
    else:
      break
  # Reached when the chain ends without a match, including when an
  # `outer_graph` attribute is None (previously that case returned None).
  raise ValueError("Function {} does not exist in the graph.".format(func_name))
def _get_graph(while_op, func_attr_name):
  """Returns the `FuncGraph` for the function stored in `func_attr_name`.

  Args:
    while_op: The While Operation.
    func_attr_name: string naming the function-valued attribute of `while_op`.

  Returns:
    `FuncGraph` tagged with `_while = while_op`.
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  shapes = [tensor_shape.TensorShape(proto)
            for proto in while_op.get_attr("output_shapes")]
  owner_graph = while_op.graph
  fdef = owner_graph._get_function(  # pylint: disable=protected-access
      while_op.get_attr(func_attr_name).name).definition
  # Use `while_op.graph` as the outer graph (not necessarily the default
  # graph) to mirror the forward pass, so nested grad graphs can capture
  # references to outer tensors.
  with owner_graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(fdef, shapes)
  func_graph._while = while_op  # pylint: disable=protected-access
  return func_graph
def testFunctionCallsFromFunction(self):
  """A FunctionDef that calls another Defun imports and evaluates correctly."""
  x = constant_op.constant(5.0)
  y = constant_op.constant(10.0)

  @function.Defun()
  def fn():

    @function.Defun()
    def inner_fn():
      return x + y

    return inner_fn()

  # Calling `fn` registers it in this graph, which is where
  # `function_def_to_graph` looks it up.
  fn()

  def fn2():
    return 2 * fn()

  fdef = function._DefinedFunction(fn2, [], []).definition  # pylint: disable=protected-access
  func_graph = function_def_to_graph.function_def_to_graph(fdef)
  with func_graph.as_default():
    x_placeholder, y_placeholder = func_graph.inputs
    with self.test_session(graph=func_graph) as sess:
      feeds = {x_placeholder: 5.0, y_placeholder: 10.0}
      self.assertEqual(sess.run(func_graph.outputs[0], feed_dict=feeds), 30.0)
def load_function_def_library(library):
  """Loads each FunctionDef in `library` as a `Function` without captures.

  Function names are rewritten during loading (via `_fix_fdef` and the
  accumulated name mapping) so they do not collide with existing ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from each original function name in the library to a `Function`
    without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  # TODO(andresp): Look into restoring gradient function information.
  loaded = {}
  original_to_new_name = {}
  # Loading into a dedicated graph lets function_def_to_graph help validate
  # that the functions load correctly (eager mode has no Python API to check
  # whether a function is registered). The resulting func_graphs remain usable
  # in eager mode or in other graphs.
  with ops.Graph().as_default() as import_graph:
    for fdef in _sort_function_defs(library):
      renamed_fdef = _fix_fdef(fdef, original_to_new_name)
      graph = function_def_lib.function_def_to_graph(renamed_fdef)
      loaded_fn = function_lib.Function(graph)
      loaded_fn.add_to_graph(import_graph)
      original_name = fdef.signature.name
      original_to_new_name[original_name] = loaded_fn.name
      loaded[original_name] = loaded_fn
  return loaded
def testFunctionCallsFromFunction(self):
  """Converting a function that calls a nested Defun yields 2 * (x + y)."""
  x = constant_op.constant(5.0)
  y = constant_op.constant(10.0)

  @function.Defun()
  def fn():

    @function.Defun()
    def inner_fn():
      return x + y

    return inner_fn()

  # Instantiate `fn` in this graph so `function_def_to_graph` can find it.
  fn()

  def fn2():
    return 2 * fn()

  defined = function._DefinedFunction(fn2, [], [])  # pylint: disable=protected-access
  func_graph = function_def_to_graph.function_def_to_graph(defined.definition)
  with func_graph.as_default():
    x_ph, y_ph = func_graph.inputs
    with self.test_session(graph=func_graph) as sess:
      value = sess.run(func_graph.outputs[0], feed_dict={x_ph: 5.0, y_ph: 10.0})
      self.assertEqual(value, 30.0)
def load_function_def_library(library):
  """Loads each FunctionDef in `library` as a ConcreteFunction.

  Function names are rewritten during loading (via `_fix_fdef` and the
  accumulated name mapping) so they do not collide with existing ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from each original function name in the library to a
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  # TODO(andresp): Look into restoring gradient function information.
  loaded = {}
  original_to_new_name = {}
  # A dedicated graph lets function_def_to_graph help validate the loaded
  # functions (eager mode offers no Python API to check registration). The
  # resulting func_graphs can still be used eagerly or in other graphs.
  with ops.Graph().as_default() as import_graph:
    for fdef in _sort_function_defs(library):
      renamed_fdef = _fix_fdef(fdef, original_to_new_name)
      graph = function_def_lib.function_def_to_graph(renamed_fdef)
      concrete = function_lib.ConcreteFunction(graph)
      concrete.add_to_graph(import_graph)
      original_name = fdef.signature.name
      original_to_new_name[original_name] = concrete.name
      loaded[original_name] = concrete
  return loaded
def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes.

  `func_name` is looked up in `op.graph` and then in each successive
  `outer_graph`.

  Raises:
    KeyError: if no graph in that chain defines `func_name`.
  """
  fdef = None
  search_graph = op.graph
  while search_graph is not None:
    func = search_graph._get_function(func_name)  # pylint: disable=protected-access
    if func is not None:
      fdef = func.definition
      break
    # A graph without `outer_graph` ends the search.
    search_graph = getattr(search_graph, "outer_graph", None)

  if fdef is None:
    raise KeyError("%s cannot be found in the graph" % func_name)

  # Build under `op.graph` (not necessarily the default graph, e.g. nested if
  # ops or gradients computed inside a Defun) so it becomes the FuncGraph's
  # `outer_graph`, as in the forward pass; `_resolve_grad_inputs` relies on
  # this to resolve tensor references from the gradient graph.
  with op.graph.as_default():
    return function_def_to_graph.function_def_to_graph(fdef, input_shapes)
def testFunctionCallsFromFunction(self):
  """A defun calling another defun converts and evaluates to 2 * (x + y)."""
  x = constant_op.constant(5.0)
  y = constant_op.constant(10.0)

  @function.defun
  def fn():

    @function.defun
    def inner_fn():
      return x + y

    return inner_fn()

  @function.defun
  def fn2():
    return 2 * fn()

  concrete_fn2 = fn2.get_concrete_function()
  # Invoking the concrete function instantiates `fn`, so that
  # `function_def_to_graph` can find it.
  concrete_fn2()

  fdef = concrete_fn2._inference_function.definition  # pylint: disable=protected-access
  func_graph = function_def_to_graph.function_def_to_graph(fdef)
  with func_graph.as_default():
    x_ph, y_ph = func_graph.inputs
    with self.session(graph=func_graph) as sess:
      feeds = {x_ph: 5.0, y_ph: 10.0}
      self.assertEqual(sess.run(func_graph.outputs[0], feed_dict=feeds), 30.0)
def testFunctionCallsFromFunction(self):
  """Importing a defun that calls a nested defun preserves its result."""
  x = constant_op.constant(5.0)
  y = constant_op.constant(10.0)

  @function.defun
  def fn():

    @function.defun
    def inner_fn():
      return x + y

    return inner_fn()

  @function.defun
  def fn2():
    return 2 * fn()

  fn2_concrete = fn2.get_concrete_function()
  # Run it once so `fn` is instantiated and visible to
  # `function_def_to_graph`.
  fn2_concrete()

  inference_fn = fn2_concrete._inference_function  # pylint: disable=protected-access
  func_graph = function_def_to_graph.function_def_to_graph(
      inference_fn.definition)
  with func_graph.as_default():
    x_ph, y_ph = func_graph.inputs
    with self.session(graph=func_graph) as sess:
      result = sess.run(func_graph.outputs[0],
                        feed_dict={x_ph: 5.0, y_ph: 10.0})
      self.assertEqual(result, 30.0)
def _load_func_graphs(self, function_library):
  """Rebuilds `self._functions` from the FunctionDefs in `function_library`.

  Each FunctionDef is converted to a graph, wrapped in `function.Function`,
  and keyed by its signature name.
  """
  # TODO(allenl): Do we need to do name mapping here? Not quite sure what
  # happens when loaded names collide with existing names.
  # TODO(andresp): Look into restoring nested and gradient functions in the
  # right order.
  loaded = self._functions = {}
  for function_def in function_library.function:
    func_graph = function_def_lib.function_def_to_graph(function_def)
    loaded[function_def.signature.name] = function.Function(func_graph)
def testInputsAndOutputs(self):
  """The imported graph keeps the function name and computes its outputs."""
  fdef = self._build_function_def()
  graph = function_def_to_graph.function_def_to_graph(fdef)
  self.assertEqual(graph.name, "_whats_in_a_name")

  feeds = {"x:0": 2, "y:0": 3}
  with self.session(graph=graph) as sess:
    self.assertSequenceEqual(
        sess.run(graph.inputs, feed_dict=feeds), [2.0, 3.0])
    self.assertSequenceEqual(
        sess.run(graph.outputs, feed_dict=feeds), [13.0, 35.0])
def testInputsAndOutputs(self):
  """Feeding x=2, y=3 returns the inputs unchanged and outputs [13, 35]."""
  imported = function_def_to_graph.function_def_to_graph(
      self._build_function_def())
  self.assertEqual(imported.name, "_whats_in_a_name")

  with self.session(graph=imported) as sess:
    feed = {"x:0": 2, "y:0": 3}
    fed_inputs = sess.run(imported.inputs, feed_dict=feed)
    self.assertSequenceEqual(fed_inputs, [2.0, 3.0])
    computed = sess.run(imported.outputs, feed_dict=feed)
    self.assertSequenceEqual(computed, [13.0, 35.0])
def count_ops(self, graph):
    '''Count operations in `graph` plus, recursively, in its library functions.

    Every converted function graph is also stored in `self.functions` under
    its key in `graph._functions`.
    '''
    from tensorflow.python.framework.function_def_to_graph import function_def_to_graph
    total = len(graph.get_operations())
    for name, registered in graph._functions.items():
        subgraph = function_def_to_graph(registered.definition)
        self.functions[name] = subgraph
        total += self.count_ops(subgraph)
    return total
def _get_func_graph_for_branch(branch_name):
  """Builds the FuncGraph for one branch of `if_op` and records its captures.

  `if_op` comes from the enclosing scope; every input after the predicate is
  treated as an extra input of the branch.
  """
  extra_inputs = if_op.inputs[1:]  # Input 0 is the predicate.
  shapes = [extra.shape for extra in extra_inputs]
  branch_fn_name = if_op.get_attr(branch_name).name
  fdef = if_op.graph._get_function(branch_fn_name).definition  # pylint: disable=protected-access
  func_graph = function_def_to_graph.function_def_to_graph(fdef, shapes)
  func_graph.extra_inputs = extra_inputs
  func_graph.extra_args = func_graph.inputs
  # Map each outer tensor to the placeholder that stands in for it.
  func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))  # pylint: disable=protected-access
  return func_graph
def _get_func_graph_for_name(graph, func_name): """Returns the FuncGraph associated to the given func_name if possible.""" func = graph._get_function(str(func_name)) # pylint: disable=protected-access if not func: raise ValueError( 'Function {} does not exist in the graph.'.format(func_name)) if not hasattr(func, 'graph'): # This is a _DefinedFunction. return function_def_to_graph.function_def_to_graph(func.definition) else: return func.graph
def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes.

  Args:
    op: operation whose `graph` defines the function `func_name`.
    input_shapes: forwarded as the `input_shapes` argument of
      `function_def_to_graph`.
    func_name: name of the function to convert.

  Returns:
    The FuncGraph built from the function's definition while `op.graph` is
    the default graph.

  Raises:
    KeyError: if `op.graph` has no function named `func_name`.
  """
  func = op.graph._get_function(func_name)  # pylint: disable=protected-access
  if func is None:
    # Fail with a descriptive error instead of an AttributeError on `None`.
    raise KeyError("%s cannot be found in the graph" % func_name)
  fdef = func.definition
  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  return func_graph
def to_tf_function_graph(self):
    # type: () -> tf.Graph
    """Build a new TensorFlow graph from this object's function definition.

    The definition comes from `self.to_function_graph_def()` and is converted
    with `function_def_to_graph.function_def_to_graph`. Note that
    function_def_to_graph.function_def_to_graph won't work if the function
    calls into other functions.

    Returns:
        A freshly built graph containing the nodes this object represents.
    """
    function_def = self.to_function_graph_def()
    return function_def_to_graph.function_def_to_graph(function_def)
def load_function_def_library(library, load_shared_name_suffix=None):
  """Loads each FunctionDef in `library` as a ConcreteFunction.

  Function names are rewritten during loading so they do not collide with
  previously created functions.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise a unique name is generated.

  Returns:
    Dict from each original function name in the library to a
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {fdef.signature.name for fdef in library.function}
  loaded = {}
  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  for fdef in _sort_function_defs(library, library_function_names):
    renamed_fdef = _fix_fdef(fdef, loaded, load_shared_name_suffix)
    # Dependencies are added explicitly below (the topological sort loads them
    # first), so copying every function into the graph would only add
    # O(n^2) memory.
    func_graph = function_def_lib.function_def_to_graph(
        renamed_fdef, copy_functions=False)
    for dependency_name in _list_function_deps(fdef, library_function_names):
      loaded[dependency_name].add_to_graph(func_graph)

    concrete = function_lib.ConcreteFunction(func_graph)
    concrete.add_to_graph()
    if context.executing_eagerly():
      concrete.add_to_graph(ops.get_default_graph())
    loaded[fdef.signature.name] = concrete

    # Gradients are registered in the current root context.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access
  return loaded
def _get_body_graph(while_op):
  """Builds the `FuncGraph` for the body function of `while_op`.

  Input shapes are taken from the op's current inputs.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body, with `_while` set to `while_op`.
  """
  shapes = [inp.shape for inp in list(while_op.inputs)]
  body_fn_name = while_op.get_attr("body").name
  body_fdef = while_op.graph._get_function(body_fn_name).definition  # pylint: disable=protected-access
  body_graph = function_def_to_graph.function_def_to_graph(body_fdef, shapes)
  body_graph._while = while_op  # pylint: disable=protected-access
  return body_graph
def get_func_graph_for_name(graph, func_name):
  """Returns the FuncGraph associated to the given func_name if possible.

  Looks `func_name` up in `graph` and then in each successive `outer_graph`.

  Args:
    graph: graph in which to start the lookup.
    func_name: name of the function to find.

  Returns:
    The function's own `graph` attribute when it has one; otherwise a
    FuncGraph rebuilt from its FunctionDef.

  Raises:
    ValueError: if no graph in the lookup chain defines `func_name`.
  """
  while graph is not None:
    func = graph._get_function(str(func_name))  # pylint: disable=protected-access
    if func is not None:
      if hasattr(func, 'graph'):
        return func.graph
      func_graph = function_def_to_graph.function_def_to_graph(
          func.definition)
      if func_graph is not None:
        return func_graph
    if hasattr(graph, 'outer_graph'):
      graph = graph.outer_graph
    else:
      break
  # Reached when the chain ends without a match, including when an
  # `outer_graph` attribute is None (previously that case returned None).
  raise ValueError(
      'Function {} does not exist in the graph.'.format(func_name))
def _get_body_graph(while_op):
  """Returns the `FuncGraph` of the while body, shaped like the op's inputs.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body; its `_while` attribute is `while_op`.
  """
  inputs = list(while_op.inputs)
  input_shapes = [tensor.shape for tensor in inputs]
  graph = while_op.graph
  fdef = graph._get_function(  # pylint: disable=protected-access
      while_op.get_attr("body").name).definition
  func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op  # pylint: disable=protected-access
  return func_graph
def _get_body_graph(while_op):
  """Builds the `FuncGraph` for the body function of `while_op`.

  Input shapes come from the op's `output_shapes` attribute.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body, with `_while` set to `while_op`.
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  shapes = list(
      map(tensor_shape.TensorShape, while_op.get_attr("output_shapes")))
  body_fn_name = while_op.get_attr("body").name
  body_fdef = while_op.graph._get_function(body_fn_name).definition  # pylint: disable=protected-access
  body_graph = function_def_to_graph.function_def_to_graph(body_fdef, shapes)
  body_graph._while = while_op  # pylint: disable=protected-access
  return body_graph
def _get_body_graph(while_op):
  """Returns the while-body `FuncGraph`, shaped by the op's `output_shapes`.

  Args:
    while_op: The While Operation.

  Returns:
    `FuncGraph` for the while body; its `_while` attribute is `while_op`.
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  input_shapes = [tensor_shape.TensorShape(shape_proto)
                  for shape_proto in while_op.get_attr("output_shapes")]
  graph = while_op.graph
  fdef = graph._get_function(  # pylint: disable=protected-access
      while_op.get_attr("body").name).definition
  func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op  # pylint: disable=protected-access
  return func_graph
def testControlDependencies(self):
  """Control inputs inside a _DefinedFunction survive conversion."""

  def fn(inp):
    x = constant_op.constant(2.0, name="x")
    # TODO(b/79881896): Test external control dependency once that's
    # supported.
    with ops.control_dependencies([x, inp]):
      constant_op.constant(3.0, name="y")
    return 4.0

  defined = function._DefinedFunction(fn, ["inp"], [dtypes.float32])  # pylint: disable=protected-access
  func_graph = function_def_to_graph.function_def_to_graph(defined.definition)

  control_inputs = func_graph.get_operation_by_name("y").control_inputs
  self.assertEqual(len(control_inputs), 2)
  self.assertEqual([dep.name for dep in control_inputs], ["x", "inp"])
def testControlDependencies(self):
  """Control inputs inside a defun survive conversion of its FunctionDef."""

  @function.defun
  def fn(inp):
    x = constant_op.constant(2.0, name="x")
    # TODO(b/79881896): Test external control dependency once that's
    # supported.
    with ops.control_dependencies([x, inp]):
      constant_op.constant(3.0, name="y")
    return 4.0

  arg = constant_op.constant(1.0)
  fdef = fn.get_concrete_function(arg).function_def
  func_graph = function_def_to_graph.function_def_to_graph(fdef)

  control_inputs = func_graph.get_operation_by_name("y").control_inputs
  self.assertEqual(len(control_inputs), 2)
  self.assertEqual([dep.name for dep in control_inputs], ["x", "placeholder"])
def _get_func_graph_for_branch(name_attr_list):
  """Generates and returns a FuncGraph for the given branch.

  `op` comes from the enclosing scope; every input after the predicate is
  captured by the branch.
  """
  branch_inputs = op.inputs[1:]  # Input 0 is the predicate.
  shapes = [tensor.shape for tensor in branch_inputs]
  fdef = op.graph._get_function(name_attr_list.name).definition  # pylint: disable=protected-access
  # Build under `op.graph` (not necessarily the default graph, e.g. nested if
  # ops or gradients computed inside a Defun) so it becomes the FuncGraph's
  # `outer_graph`, as in the forward pass; `_resolve_grad_inputs` relies on
  # this to resolve tensor references from the gradient graph.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(fdef, shapes)
  func_graph.captures = collections.OrderedDict(
      zip(branch_inputs, func_graph.inputs))
  # Link the op so that the gradient code can use it.
  func_graph._forward_cond = op  # pylint: disable=protected-access
  return func_graph
def testControlDependencies(self):
  """Control inputs inside a defun are kept, with the input named "inp"."""

  @function.defun
  def fn(inp):
    x = constant_op.constant(2.0, name="x")
    # TODO(b/79881896): Test external control dependency once that's
    # supported.
    with ops.control_dependencies([x, inp]):
      constant_op.constant(3.0, name="y")
    return 4.0

  arg = constant_op.constant(1.0)
  fdef = fn.get_concrete_function(arg).function_def
  func_graph = function_def_to_graph.function_def_to_graph(fdef)

  control_inputs = func_graph.get_operation_by_name("y").control_inputs
  self.assertEqual(len(control_inputs), 2)
  self.assertEqual([dep.name for dep in control_inputs], ["x", "inp"])
def _get_func_graph_for_branch(branch_name):
  """Generates and returns a FuncGraph for the given branch of `if_op`.

  `if_op` comes from the enclosing scope; every input after the predicate is
  captured by the branch.
  """
  branch_inputs = if_op.inputs[1:]  # Input 0 is the predicate.
  shapes = [tensor.shape for tensor in branch_inputs]
  branch_fn_name = if_op.get_attr(branch_name).name
  fdef = if_op.graph._get_function(branch_fn_name).definition  # pylint: disable=protected-access
  # Build under `if_op.graph` (not necessarily the default graph, e.g. nested
  # if ops or gradients computed inside a Defun) so it becomes the FuncGraph's
  # `outer_graph`, as in the forward pass; `_resolve_grad_inputs` relies on
  # this to resolve tensor references from the gradient graph.
  with if_op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(fdef, shapes)
  func_graph.captures = collections.OrderedDict(
      zip(branch_inputs, func_graph.inputs))
  # Set the if op so that the gradient code can use it.
  func_graph._if = if_op  # pylint: disable=protected-access
  return func_graph
def load_function_def_library(library):
  """Loads each FunctionDef in `library` as a ConcreteFunction.

  Function names are rewritten during loading so they do not collide with
  previously created functions.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from each original function name in the library to a
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  loaded = {}
  suffix = "_load_{}".format(ops.uid())
  for fdef in _sort_function_defs(library):
    renamed_fdef = _fix_fdef(fdef, loaded, suffix)
    # Dependencies were already imported (topological order) and are added
    # explicitly below; copying them into the graph would only cost O(n^2)
    # memory.
    func_graph = function_def_lib.function_def_to_graph(
        renamed_fdef, copy_functions=False)
    for dependency_name in _list_function_deps(fdef):
      loaded[dependency_name].add_to_graph(func_graph)
    concrete = function_lib.ConcreteFunction(func_graph)
    concrete.add_to_graph()
    loaded[fdef.signature.name] = concrete
    # Gradients are registered in the current root context.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access
  return loaded
def convert(inp_format, inp_loc, out_format, out_loc, output_nodes, ng_backend,
            device_id, backend_optional_params, shape_hints, do_aot,
            save_ng_clusters):
    """Convert a TF model by inserting nGraph nodes (functional API).

    Validates the formats, loads the model, attaches devices, runs the nGraph
    grappler optimizer and saves the result. Every argument is required.

    Sample usage:
        from tf2ngraph import convert
        convert('savedmodel', 'test_graph', 'pbtxt', 'test_graph_ngraph.pbtxt',
                ['out_node'], ng_backend, device_id, backend_optional_params,
                shape_hints, do_aot, save_ng_clusters)

    Parameters:
        inp_format (string): one of allowed_formats['input'], e.g.
            'savedmodel', 'pbtxt', 'pb'
        inp_loc (string): Location of input file or folder (in case of
            savedmodel)
        out_format (string): one of allowed_formats['output'], e.g.
            'savedmodel', 'pbtxt', 'pb'
        out_loc (string): Location of output file or folder (in case of
            savedmodel)
        output_nodes (iterable of strings): names of output nodes
        ng_backend, device_id, backend_optional_params, shape_hints, do_aot:
            forwarded unchanged to run_ngraph_grappler_optimizer.
        save_ng_clusters: if truthy, each function in the optimized graph's
            library is also written to ./<function name>.pbtxt.

    Returns:
        None. exit_on_error is called if a format is unsupported or
        ngraph-bridge was built without grappler (presumably terminating the
        program).
    """
    for kind, fmt in (('input', inp_format), ('output', out_format)):
        supported = allowed_formats[kind]
        exit_on_error(
            fmt in supported,
            'Unsupported ' + kind + ' format ' + fmt + ". Supported formats: " +
            str(supported))
    exit_on_error(
        ngraph_bridge.is_grappler_enabled(),
        "ngraph-bridge is not built with grappler enabled, hence tf2ngraph is not supported."
    )

    input_gdef = get_gdef(inp_format, inp_loc)
    attach_device(input_gdef)
    output_gdef = run_ngraph_grappler_optimizer(
        input_gdef, output_nodes, ng_backend, device_id,
        backend_optional_params, shape_hints, do_aot)

    if save_ng_clusters:
        for cluster_fdef in output_gdef.library.function:
            cluster_graph_def = function_def_to_graph(cluster_fdef).as_graph_def()
            tf.io.write_graph(cluster_graph_def, '.',
                              cluster_fdef.signature.name + '.pbtxt',
                              as_text=True)
    save_model(output_gdef, out_format, out_loc)
def load_function_def_library(library):
  """Loads each FunctionDef in `library` as a ConcreteFunction.

  Function names are rewritten during loading so they do not collide with
  previously created functions.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from each original function name in the library to a
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  loaded = {}
  # A dedicated graph lets function_def_to_graph help validate the loaded
  # functions (eager mode offers no Python API to check registration). The
  # resulting func_graphs can still be used eagerly or in other graphs.
  import_graph = ops.Graph()
  for fdef in _sort_function_defs(library):
    with import_graph.as_default():
      renamed_fdef = _fix_fdef(fdef, loaded)
      func_graph = function_def_lib.function_def_to_graph(renamed_fdef)

    concrete = function_lib.ConcreteFunction(func_graph)
    concrete.add_to_graph(import_graph)
    loaded[fdef.signature.name] = concrete

    # Gradients are registered in the current root context.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access
  return loaded
def testControlDependencies(self):
  """Control inputs, including a stateful assign, survive conversion."""
  v = variables.Variable(1)

  @function.defun
  def fn(inp):
    assign = v.assign(3, name="assign", read_value=False)
    x = constant_op.constant(2.0, name="x")
    # TODO(b/79881896): Test external control dependency once that's
    # supported.
    with ops.control_dependencies([x, inp, assign]):
      constant_op.constant(3.0, name="y")
    return 4.0

  arg = constant_op.constant(1.0)
  fdef = fn.get_concrete_function(arg).function_def
  func_graph = function_def_to_graph.function_def_to_graph(fdef)

  control_inputs = func_graph.get_operation_by_name("y").control_inputs
  self.assertEqual(len(control_inputs), 3)
  self.assertEqual([dep.name for dep in control_inputs],
                   ["assign", "inp", "x"])
def testControlDependencies(self):
  """The converted "y" op keeps its assign, input and constant control deps."""
  v = variables.Variable(1)

  @function.defun
  def fn(inp):
    assign = v.assign(3, name="assign", read_value=False)
    x = constant_op.constant(2.0, name="x")
    # TODO(b/79881896): Test external control dependency once that's
    # supported.
    with ops.control_dependencies([x, inp, assign]):
      constant_op.constant(3.0, name="y")
    return 4.0

  concrete = fn.get_concrete_function(constant_op.constant(1.0))
  func_graph = function_def_to_graph.function_def_to_graph(
      concrete.function_def)

  y_op = func_graph.get_operation_by_name("y")
  self.assertEqual(len(y_op.control_inputs), 3)
  expected = ["assign", "inp", "x"]
  self.assertEqual([dep.name for dep in y_op.control_inputs], expected)
def load_function_def_library(library, load_shared_name_suffix=None):
  """Loads each FunctionDef in `library` as a ConcreteFunction.

  Function names are rewritten during loading so they do not collide with
  previously created functions.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.

  Returns:
    Dict from each original function name in the library to a
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {fdef.signature.name for fdef in library.function}
  loaded = {}
  loaded_by_new_name = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  for fdef in _sort_function_defs(library, library_function_names):
    renamed_fdef = _fix_fdef(fdef, loaded, load_shared_name_suffix)

    # NOTE(review): dependencies are added explicitly below (they were loaded
    # earlier in topological order). Other versions of this loader pass
    # copy_functions=False here to avoid O(n^2) memory; this call does not.
    # Confirm whether that is intended.
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(renamed_fdef)
    _restore_gradient_functions(func_graph, loaded_by_new_name)

    for dependency_name in _list_function_deps(fdef, library_function_names):
      loaded[dependency_name].add_to_graph(func_graph)

    # The new ConcreteFunction's function_spec and/or arg_keywords (used to
    # parse the structured and flat signatures, respectively) are not set up
    # here: ConcreteFunctions that are part of a saved function are set up
    # later by recreate_function(), and bare ConcreteFunctions by
    # setup_bare_concrete_function().
    concrete = function_lib.ConcreteFunction(func_graph)
    concrete.add_to_graph(graph)

    loaded[fdef.signature.name] = concrete
    loaded_by_new_name[concrete.name] = concrete
    has_trt_engine = any(
        op.type == "TRTEngineOp" for op in func_graph.get_operations())
    if has_trt_engine:
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      concrete.add_to_graph(ops.get_default_graph())

  return loaded