示例#1
0
def _insert_op(fn, op, name, attr=None):
    """
    Build a fresh ParsedTFNode and register it in the graph of ``fn``.

    Parameters
    ----------
    fn: SSAFunction
        Function whose ``graph`` mapping receives the new node.
    op: str
        Operation type assigned to the new node.
    name: str
        Name of the new node; also the key it is stored under in ``fn.graph``.
    attr: dict or None (optional)
        Attributes for the new node. When None, ``attr`` is not assigned and
        the node keeps whatever value ``ParsedTFNode()`` initialised it with.

    Returns
    -------
    node: ParsedTFNode
        The newly created node (the same object stored in ``fn.graph``).

    Notes
    -----
    Any existing entry in ``fn.graph`` under the same name is replaced
    (assuming ``fn.graph`` is a plain dict — not verified here).
    """
    created = ParsedTFNode()
    created.op, created.name = op, name
    # Only override attributes when the caller actually supplied some.
    if attr is not None:
        created.attr = attr
    target_graph = fn.graph
    target_graph[created.name] = created
    return created
示例#2
0
 def test_init(self):
     """A node parsed from the mock TF node exposes the expected fields."""
     node = ParsedTFNode(_mock_tf_node())
     node.parse_from_attr()
     # (property name, expected value), checked in this order; each property
     # is read only when its turn comes.
     expectations = (
         ("name", "aNode"),
         ("op", "Placeholder"),
         ("inputs", ["anInput"]),
         ("control_inputs", ["aControlInput"]),
     )
     for prop_name, expected in expectations:
         self.assertEqual(expected, getattr(node, prop_name))
示例#3
0
    def _dict_from_graph_def(graph, fn_name="main", sg_input_shapes=None):
        """
        Load a tf.Graph and transform it into dictionaries of ParsedTFNodes.
        The graph may contain multiple functions; in that case, functions
        (sub-graphs) are resolved recursively.

        Parameters
        ----------
        graph: tf.Graph
            TensorFlow graph.
        fn_name: str, optional, defaults to 'main'
            Function name of the graph.
        sg_input_shapes: dict(str: list) or None, optional
            Dictionary of name and input shapes for functions / sub-graphs.
            For each function, only the trailing ``len(signature.input_arg)``
            shapes of its entry are used. When this is None, or a function
            has no entry, ``input_shapes=None`` is passed to
            ``_function_def_to_graph``. Presumably that wraps TF's
            ``function_def_to_graph``, which then uses the shapes recorded
            in the FunctionDef — confirm.

        Returns
        -------
        graph_dict: dict(str: dict(str: ParsedTFNode))
            Function name -> (node name -> ParsedTFNode). Includes `fn_name`
            and every (possibly nested) sub-graph.
        graph_inputs: dict(str: list(str))
            Function name -> input tensor names with the ":<index>" suffix
            stripped. The entry for `fn_name` itself is left empty here.
        graph_outputs: dict(str: list(str))
            Function name -> output tensor names, same format as inputs.
        graph_ret: dict(str: dict)
            Function name -> mapping from signature output arg names to the
            node outputs the function returns. The entry for `fn_name` is an
            empty dict.
        """
        if sg_input_shapes is None:
            sg_input_shapes = {}

        graph_dict = {fn_name: {}}
        graph_inputs = {fn_name: []}
        graph_outputs = {fn_name: []}
        graph_ret = {fn_name: {}}

        for op in graph.get_operations():
            graph_dict[fn_name][op.name] = ParsedTFNode(op.node_def)

        for name, sg in graph._functions.items():
            sg_def = sg.definition
            num_args = len(sg_def.signature.input_arg)
            input_shapes = sg_input_shapes.get(name)
            if input_shapes is not None:
                # Keep the trailing `num_args` shapes. A plain
                # `input_shapes[-num_args:]` keeps *all* shapes when
                # num_args == 0, because -0 == 0.
                start = max(len(input_shapes) - num_args, 0)
                input_shapes = input_shapes[start:]
            fn_graph = _function_def_to_graph(sg_def, input_shapes=input_shapes)

            # Merge every dict the recursive call returns, not only the node
            # dict. That way, functions nested inside this one also get
            # inputs/outputs/ret entries matching their graph_dict entry.
            sub_dict, sub_inputs, sub_outputs, sub_ret = (
                TF2Loader._dict_from_graph_def(fn_graph, name, sg_input_shapes)
            )
            graph_dict.update(sub_dict)
            graph_inputs.update(sub_inputs)
            graph_outputs.update(sub_outputs)
            graph_ret.update(sub_ret)

            # The recursive call only produced empty placeholder entries for
            # `name` itself; replace them with this function's interface.
            graph_inputs[name] = [t.name.split(":")[0] for t in fn_graph.inputs]
            graph_outputs[name] = [t.name.split(":")[0] for t in fn_graph.outputs]

            # ret is a mapping from the output arg names from `signature` to
            # the outputs from `node_def` that should be returned by the
            # function. It is copied so the FunctionDef is not mutated. The
            # "identity" -> "identity_0" rename only happens when that key
            # exists, so functions without it do not raise KeyError.
            # (Presumably the rename makes the first output follow an
            # `identity_<i>` naming pattern — TODO confirm.)
            sg_ret = dict(sg_def.ret)
            if "identity" in sg_ret:
                sg_ret["identity_0"] = sg_ret.pop("identity")
            graph_ret[name] = sg_ret

        return graph_dict, graph_inputs, graph_outputs, graph_ret
示例#4
0
 def test_copy(self):
     """copy() returns a same-typed node whose listed properties match."""
     source = ParsedTFNode(_mock_tf_node())
     source.parse_from_attr()
     # Named `duplicate` rather than `copy` to avoid shadowing the stdlib
     # module name.
     duplicate = source.copy()
     self.assertTrue(isinstance(duplicate, type(source)))
     compared_props = (
         "name",
         "op",
         "datatype",
         "value",
         "inputs",
         "control_inputs",
         "outputs",
         "control_outputs",
         "attr",
         "original_node",
     )
     for prop_name in compared_props:
         expected = getattr(source, prop_name)
         actual = getattr(duplicate, prop_name)
         self.assertEqual(
             expected,
             actual,
             "Mismatch in property {}".format(prop_name),
         )