def _insert_op(fn, op, name, attr=None):
    """
    Build a fresh ParsedTFNode and register it in ``fn.graph`` under its name.

    Parameters
    ----------
    fn: SSAFunction
        Function whose ``graph`` mapping receives the new node.
    op: str
        Operation type assigned to the new node.
    name: str
        Name of the new node; also used as its key in ``fn.graph``.
    attr: dict or None (optional)
        When not None, assigned (by reference, not copied) as the node's
        ``attr``. When None, the node keeps whatever ``attr`` value
        ``ParsedTFNode()`` gives it.

    Returns
    -------
    node: ParsedTFNode
        The node that was just inserted.
    """
    new_node = ParsedTFNode()
    new_node.name = name
    new_node.op = op
    if attr is not None:
        new_node.attr = attr
    # Plain key assignment: an existing entry with the same name is replaced
    # (assuming fn.graph is a dict-like mapping).
    fn.graph[new_node.name] = new_node
    return new_node
def test_init(self):
    """A node built from the mock TF node exposes the expected parsed fields."""
    node = ParsedTFNode(_mock_tf_node())
    node.parse_from_attr()

    # (attribute, expected value) pairs, checked in this order.
    expectations = (
        ("name", "aNode"),
        ("op", "Placeholder"),
        ("inputs", ["anInput"]),
        ("control_inputs", ["aControlInput"]),
    )
    for attr_name, expected in expectations:
        self.assertEqual(expected, getattr(node, attr_name))
def _dict_from_graph_def(graph, fn_name="main", sg_input_shapes=None):
    """
    Loads a tf.Graph and transform it into dictionary of ParsedTFNodes.
    Potentially contains multiple functions, in such case, recursively
    resolve functions (sub-graphs).

    Parameters
    ----------
    graph: tf.Graph
        TensorFlow graph.
    fn_name: str, optional, defaults to 'main'
        Function name of the graph.
    sg_input_shapes: dict(str: list) or None
        Dictionary of name and input shapes for functions / sub-graphs.
        When None, ``input_shapes=None`` is passed to
        ``_function_def_to_graph`` for every sub-graph.

    Returns
    -------
    graph_dict: dict(str: dict(str: ParsedTFNode))
        Dictionary of function name and dictionary of node name and
        ParsedTFNode object.
    graph_inputs: dict(str: list(str))
        Function name -> names of the function graph's input tensors with
        the ":<index>" suffix stripped. Initialised empty for ``fn_name``.
    graph_outputs: dict(str: list(str))
        Same as ``graph_inputs`` for the function graph's output tensors.
    graph_ret: dict(str: dict)
        Function name -> plain-dict copy of the FunctionDef ``ret`` map, with
        key "identity" renamed to "identity_0" when present. Initialised
        empty for ``fn_name``.
    """
    graph_dict = {fn_name: {}}
    graph_inputs = {fn_name: []}
    graph_outputs = {fn_name: []}
    graph_ret = {fn_name: {}}

    for op in graph.get_operations():
        graph_dict[fn_name][op.name] = ParsedTFNode(op.node_def)

    for name, sg in graph._functions.items():
        sg_def = sg.definition
        num_inputs = len(sg_def.signature.input_arg)

        input_shapes = None
        if sg_input_shapes is not None:
            # Keep only the trailing shapes that belong to the function's
            # input args. Guard num_inputs == 0 explicitly: ``shapes[-0:]``
            # is ``shapes[0:]`` (the whole list), not an empty list.
            shapes = sg_input_shapes[name]
            input_shapes = shapes[-num_inputs:] if num_inputs else []
        fn_graph = _function_def_to_graph(sg_def, input_shapes=input_shapes)

        # Merge everything the recursive call resolved (including nested
        # functions) so the four returned maps stay key-consistent.
        sub_dict, sub_inputs, sub_outputs, sub_ret = TF2Loader._dict_from_graph_def(
            fn_graph, name, sg_input_shapes
        )
        graph_dict.update(sub_dict)
        graph_inputs.update(sub_inputs)
        graph_outputs.update(sub_outputs)
        graph_ret.update(sub_ret)

        # Set this function's own entries after the merge, replacing the
        # empty placeholders the recursive call created for ``name``.
        graph_inputs[name] = [t.name.split(":")[0] for t in fn_graph.inputs]
        graph_outputs[name] = [t.name.split(":")[0] for t in fn_graph.outputs]

        # ret is a mapping from the output arg names from `signature` to the
        # outputs from `node_def` that should be returned by the function.
        # Copy it instead of mutating sg_def.ret in place. In-place mutation
        # would rewrite the graph's FunctionDef and make a second pass over
        # the same function fail on the missing "identity" key.
        sg_def_ret = dict(sg_def.ret)
        if "identity" in sg_def_ret:
            sg_def_ret["identity_0"] = sg_def_ret.pop("identity")
        graph_ret[name] = sg_def_ret

    return graph_dict, graph_inputs, graph_outputs, graph_ret
def test_copy(self):
    """copy() returns a distinct node of the same type with matching properties."""
    parsed_node = ParsedTFNode(_mock_tf_node())
    parsed_node.parse_from_attr()
    # Named ``copied`` rather than ``copy`` to avoid shadowing the stdlib module.
    copied = parsed_node.copy()

    # assertIsInstance reports both the object and the expected type on failure.
    self.assertIsInstance(copied, type(parsed_node))
    # Without this, a copy() that simply returned ``self`` would pass every
    # equality check below.
    self.assertIsNot(copied, parsed_node)

    props = [
        "name",
        "op",
        "datatype",
        "value",
        "inputs",
        "control_inputs",
        "outputs",
        "control_outputs",
        "attr",
        "original_node",
    ]
    for prop in props:
        self.assertEqual(
            getattr(parsed_node, prop),
            getattr(copied, prop),
            "Mismatch in property {}".format(prop),
        )