def test_deletion_3(self):
    """Concat input 3 has a zero-sized axis-1 dimension and its edge should be dropped.

    Inputs shaped [5, 3], [5, 1], [5, 5] and [5, 0] feed a Concat along
    axis 1. After the transformation the graph is expected to match the same
    Concat fed only by inputs 0-2 on ports 0-2.
    """
    input_shapes = [[5, 3], [5, 1], [5, 5], [5, 0]]

    nodes = {}
    for idx, shape in enumerate(input_shapes):
        nodes.update(shaped_const_with_data('input_{}'.format(idx), shape))
    nodes.update(regular_op_with_shaped_data('concat', [5, 9], {'type': 'Concat', 'axis': 1}))
    nodes.update(result())

    def fan_in(count):
        # Wire input_0..input_{count-1} into Concat ports 0..count-1, then Concat -> output.
        wired = []
        for idx in range(count):
            wired.extend(connect('input_{}'.format(idx), '{}:concat'.format(idx)))
        wired.extend(connect('concat', 'output'))
        return wired

    edges_before = fan_in(4)
    edges_after = fan_in(3)

    graph = build_graph(nodes, edges_before, nodes_with_edges_only=True)
    ConcatOdInputEraserAndPortsReconnect().find_and_replace_pattern(graph)
    graph_ref = build_graph(nodes, edges_after, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_deletion_trailing_unconnected_ports(self):
    """An unconnected trailing Concat input port is expected to be removed.

    Concat has a single connected input on port 0. An extra input port 1 is
    added with nothing attached. After the transformation the graph must
    compare equal to the original topology, and port 1 must no longer be
    among Concat's input ports.
    """
    nodes = {
        **shaped_const_with_data('input_0', [5, 3]),
        **regular_op_with_shaped_data('concat', [5, 3], {'type': 'Concat', 'axis': 1}),
        **result(),
    }
    # The expected topology is identical to the input one; only the
    # dangling port (added below, not via an edge) differs.
    edges = [
        *connect('input_0', '0:concat'),
        *connect('concat', 'output'),
    ]

    graph = build_graph(nodes, edges, nodes_with_edges_only=True)
    Node(graph, 'concat').add_input_port(1)
    ConcatOdInputEraserAndPortsReconnect().find_and_replace_pattern(graph)
    graph_ref = build_graph(nodes, edges, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
    # assertNotIn reports the offending container on failure, unlike assertTrue(x not in y).
    self.assertNotIn(1, Node(graph, 'concat').in_ports())
def test_negative(self):
    """A Concat whose four inputs are all non-empty must be left untouched.

    Every input has shape [1], so the transformation is expected to leave the
    graph identical to a freshly built copy with the same nodes and edges.
    """
    input_count = 4

    nodes = {}
    for idx in range(input_count):
        nodes.update(shaped_const_with_data('input_{}'.format(idx), [1]))
    nodes.update(regular_op_with_shaped_data('concat', [4], {'type': 'Concat'}))
    nodes.update(result())

    edges = []
    for idx in range(input_count):
        edges.extend(connect('input_{}'.format(idx), '{}:concat'.format(idx)))
    edges.extend(connect('concat', 'output'))

    graph = build_graph(nodes, edges, nodes_with_edges_only=True)
    ConcatOdInputEraserAndPortsReconnect().find_and_replace_pattern(graph)
    graph_ref = build_graph(nodes, edges, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_assertion_error(self):
    """The transformation must raise AssertionError when every Concat input is empty.

    All four inputs and the Concat output have shape [0].
    """
    input_count = 4

    nodes = {}
    for idx in range(input_count):
        nodes.update(shaped_const_with_data('input_{}'.format(idx), [0]))
    nodes.update(regular_op_with_shaped_data('concat', [0], {'type': 'Concat'}))
    nodes.update(result())

    edges = []
    for idx in range(input_count):
        edges.extend(connect('input_{}'.format(idx), '{}:concat'.format(idx)))
    edges.extend(connect('concat', 'output'))

    graph = build_graph(nodes, edges, nodes_with_edges_only=True)

    # Construct the transformation outside the guarded region so only
    # find_and_replace_pattern is expected to raise.
    transform = ConcatOdInputEraserAndPortsReconnect()
    with self.assertRaises(AssertionError):
        transform.find_and_replace_pattern(graph)