def testGetFunctionDef(self):
    """context.get_function_def returns the FunctionDef of a traced function.

    The FunctionDef of a concrete function that returns a constant must
    contain a 'Const' node, and looking up an unknown name must raise
    NotFoundError.
    """

    @def_function.function
    def f():
        return constant_op.constant(1.)

    concrete = f.get_concrete_function()
    function_def = context.get_function_def(concrete.name)
    self.assertIsNotNone(function_def)
    # The traced body returns a constant, so a Const op must be present.
    self.assertTrue(
        any(node_def.op == 'Const' for node_def in function_def.node_def))

    with self.assertRaises(errors.NotFoundError):
        context.get_function_def('this_should_not_be_found')
def _populate_sub_graph_input_shapes(self, graph, graph_fns):
    """
    Populate function (sub-graph) input shapes from control flow ops' inputs.

    Note that the functions (sub-graphs) are not nested but the control flow
    ops are nested. The input shapes are used to extract sub-graphs from the
    parent graph (as the input of function_def_to_graph). Each sub-graph is
    converted and scanned recursively, so shapes for control flow ops nested
    inside a branch or loop body are collected as well.

    Parameters
    ----------
    graph: tf.Graph
        TensorFlow graph.
    graph_fns: graph functions
        TensorFlow graph functions; not read here, only forwarded to the
        recursive calls.

    Returns
    -------
    sg_input_shapes: dict(str: list)
        Dictionary of function (sub-graph) name and input shape pairs.
    """
    # Control flow op type -> attributes naming the functions it invokes.
    branch_attrs = {
        "StatelessIf": ("then_branch", "else_branch"),
        "If": ("then_branch", "else_branch"),
        "StatelessWhile": ("cond", "body"),
        "While": ("cond", "body"),
    }

    sg_input_shapes = {}
    sub_graphs = []
    for op in graph.get_operations():
        attrs = branch_attrs.get(op.type)
        if attrs is None:
            continue
        sg_names = [op.get_attr(attr).name for attr in attrs]
        # memorize input shapes for sub-graph conversions
        op_input_shapes = [i.get_shape() for i in op.inputs]
        for sg_name in sg_names:
            sg_input_shapes[sg_name] = op_input_shapes
        sub_graphs.extend(sg_names)

    # dict.fromkeys de-duplicates while preserving order: a function shared
    # by several control flow ops is converted and recursed into only once.
    for name in dict.fromkeys(sub_graphs):
        fn_def = context.get_function_def(name)
        num_args = len(fn_def.signature.input_arg)
        # Keep the trailing shapes, one per function argument. Guard the
        # zero-argument case: `shapes[-0:]` equals `shapes[0:]` and would
        # keep every op input shape instead of none.
        op_input_shapes = sg_input_shapes[name]
        op_input_shapes = op_input_shapes[-num_args:] if num_args else []
        fn_graph = _function_def_to_graph(fn_def, input_shapes=op_input_shapes)
        sg_input_shapes.update(
            self._populate_sub_graph_input_shapes(fn_graph, graph_fns))
    return sg_input_shapes
def _dict_from_graph_def(graph, fn_name="main", sg_input_shapes=None):
    """
    Loads a tf.Graph and transform it into dictionary of ParsedTFNodes.

    Potentially contains multiple functions, in such case, recursively
    resolve functions (sub-graphs).

    Parameters
    ----------
    graph: tf.Graph
        TensorFlow graph.
    fn_name: str, optional, defaults to 'main'
        Function name of the graph.
    sg_input_shapes: dict(str: list), optional
        Dictionary of name and input shapes for functions / sub-graphs.
        Only functions listed here are converted; ``None`` is treated as an
        empty mapping.

    Returns
    -------
    graph_dict: dict(str: dict(str: ParsedTFNode))
        Dictionary of function name and dictionary of node name and
        ParsedTFNode object.
    graph_inputs: dict(str: list(str))
        Function name -> input tensor names with the ":<index>" suffix
        stripped (empty list for ``fn_name``).
    graph_outputs: dict(str: list(str))
        Function name -> output tensor names with the ":<index>" suffix
        stripped (empty list for ``fn_name``).
    graph_ret: dict
        Function name -> the FunctionDef ``ret`` mapping (empty dict for
        ``fn_name``).
    """
    # The default used to be dereferenced directly, so `name in None` raised
    # TypeError for any graph that contained functions.
    if sg_input_shapes is None:
        sg_input_shapes = {}

    graph_dict = {fn_name: {}}
    graph_inputs = {fn_name: []}
    graph_outputs = {fn_name: []}
    graph_ret = {fn_name: {}}

    for op in graph.get_operations():
        graph_dict[fn_name][op.name] = ParsedTFNode(op.node_def)

    for name in graph._functions:
        sg_def = context.get_function_def(name)
        if name not in sg_input_shapes:
            # No recorded input shapes for this function: it is not converted.
            continue

        num_args = len(sg_def.signature.input_arg)
        # Keep the trailing shapes, one per function argument; `[-0:]` would
        # keep all of them for a zero-argument function.
        input_shapes = sg_input_shapes[name]
        input_shapes = input_shapes[-num_args:] if num_args else []
        fn_graph = _function_def_to_graph(sg_def, input_shapes=input_shapes)

        graph_dict.update(
            TF2Loader._dict_from_graph_def(fn_graph, name, sg_input_shapes)[0])
        graph_inputs[name] = [t.name.split(":")[0] for t in fn_graph.inputs]
        graph_outputs[name] = [t.name.split(":")[0] for t in fn_graph.outputs]
        # ret is a mapping from the output arg names from `signature` to the
        # outputs from `node_def` that should be returned by the function.
        graph_ret[name] = sg_def.ret

    return graph_dict, graph_inputs, graph_outputs, graph_ret