def test_remove_input(self):
    """Detach the second input of the Add node n4 and check the rendered graph.

    The expected rendering keeps every node but no longer contains an
    ``n3:0 -> n4`` edge.
    """
    graph = Graph(self.sample_net().node, output_shapes={}, dtypes={})
    add_node = graph.get_node_by_name("n4")
    graph.remove_input(add_node, add_node.input[1])
    expected = ('digraph { n1 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] n4 [op_type=Add] '
                'n5 [op_type=Abs] n6 [op_type=Identity] input -> n1 n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 '
                'n4:0 -> n5 n5:0 -> n6 }')
    self.assertEqual(expected, onnx_to_graphviz(graph))
def test_insert_node1(self):
    """Insert a new Abs node n7 on n2's "n1:0" input and check order and edges.

    After re-sorting, n7 must appear between n1 and n2. The edge n1 -> n2 is
    replaced by n1 -> n7 -> n2.
    """
    graph = Graph(self.sample_net().node, output_shapes={}, dtypes={})
    consumer = graph.get_node_by_name("n2")
    inserted = graph.insert_new_node_on_input(consumer, "Abs", "n1:0", name="n7")
    # The new node is appended to the node list by hand before sorting.
    # Presumably insert_new_node_on_input does not register it with the graph
    # itself (TODO confirm). The same list object returned by get_nodes() is
    # mutated on purpose.
    all_nodes = graph.get_nodes()
    all_nodes.append(inserted)
    graph.topological_sort(all_nodes)
    expected = ('digraph { n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] '
                'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] '
                'input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }')
    self.assertEqual(expected, onnx_to_graphviz(graph))