示例#1
0
    def testGetFunctionDef(self):
        """context.get_function_def returns the FunctionDef of a traced function.

        The FunctionDef registered for ``f``'s concrete function must exist and
        contain a ``Const`` node (created by ``constant_op.constant``), while an
        unknown function name must raise ``errors.NotFoundError``.
        """

        @def_function.function
        def f():
            return constant_op.constant(1.)

        concrete = f.get_concrete_function()
        function_def = context.get_function_def(concrete.name)
        self.assertIsNotNone(function_def)

        # assertIn reports the actual op list on failure, unlike a bare flag.
        self.assertIn('Const', [node_def.op for node_def in function_def.node_def])

        with self.assertRaises(errors.NotFoundError):
            context.get_function_def('this_should_not_be_found')
示例#2
0
    def _populate_sub_graph_input_shapes(self, graph, graph_fns):
        """
        Populate function (sub-graph) input shapes from control flow op's inputs
        Note that the functions (sub-graphs) are not nested but the control flow
        ops are nested. The input shapes are used to extract sub-graphs from the
        parent graph (as the input of function_def_to_graph).

        Parameter
        ---------
        graph: tf.Graph
            TensorFlow graph.
        graph_fns: list of graph functions.
            List of TensorFlow graph functions.

        Returns
        -------
        sg_input_shapes: dict(str: list)
            Dictionary of function (sub-graph) name and input shape pairs.
        """
        sg_input_shapes = {}
        sub_graphs = []
        for op in graph.get_operations():
            if op.type not in {"StatelessIf", "If", "StatelessWhile", "While"}:
                continue

            sg1, sg2 = None, None
            if op.type in {"StatelessIf", "If"}:
                sg1 = op.get_attr("then_branch").name
                sg2 = op.get_attr("else_branch").name
            if op.type in {"StatelessWhile", "While"}:
                sg1 = op.get_attr("cond").name
                sg2 = op.get_attr("body").name

            # memorize input shapes for sub-graph conversions
            op_input_shapes = [i.get_shape() for i in op.inputs]
            sg_input_shapes.update({
                sg1: op_input_shapes,
                sg2: op_input_shapes
            })
            sub_graphs += [sg1, sg2]

        for name in sub_graphs:
            sg = graph_fns.get(name)
            fn_def = context.get_function_def(name)
            op_input_shapes = sg_input_shapes[name]
            op_input_shapes = op_input_shapes[-len(fn_def.signature.input_arg
                                                   ):]
            fn_graph = _function_def_to_graph(fn_def,
                                              input_shapes=op_input_shapes)
            sg_input_shapes.update(
                self._populate_sub_graph_input_shapes(fn_graph, graph_fns))

        return sg_input_shapes
示例#3
0
    def _dict_from_graph_def(graph, fn_name="main", sg_input_shapes=None):
        """
        Loads a tf.Graph and transform it into dictionary of ParsedTFNodes.
        Potentially contains multiple functions, in such case, recursively
        resolve functions (sub-graphs).

        Parameters
        ----------
        graph: tf.Graph
            TensorFlow graph.
        fn_name: str, optional, defaults to 'main'
            Function name of the graph.
        sg_input_shapes: dict(str: list)
            Dictionary of name and input shapes for functions / sub-graphs.

        Returns
        -------
        dict(str: dict(str: ParsedTFNode))
            Dictionary of function name and dictionary of node name and
            ParsedTFNode object.
        """
        graph_dict = {fn_name: {}}
        graph_inputs = {fn_name: []}
        graph_outputs = {fn_name: []}
        graph_ret = {fn_name: {}}

        for op in graph.get_operations():
            graph_dict[fn_name].update({op.name: ParsedTFNode(op.node_def)})

        for name, sg in graph._functions.items():
            sg_def = context.get_function_def(name)
            if name in sg_input_shapes:
                input_shapes = sg_input_shapes[name]
                input_shapes = input_shapes[-len(sg_def.signature.input_arg):]
                fn_graph = _function_def_to_graph(sg_def,
                                                  input_shapes=input_shapes)

                graph_dict.update(
                    TF2Loader._dict_from_graph_def(fn_graph, name,
                                                   sg_input_shapes)[0])
                graph_inputs.update(
                    {name: [t.name.split(":")[0] for t in fn_graph.inputs]})
                graph_outputs.update(
                    {name: [t.name.split(":")[0] for t in fn_graph.outputs]})

                # ret is a mapping from the output arg names from `signature` to the
                # outputs from `node_def` that should be returned by the function.
                graph_ret.update({name: sg_def.ret})

        return graph_dict, graph_inputs, graph_outputs, graph_ret