def test_backward_bfs_for_op_closest_op_detected(self):
    """
    Graph: input -> hsigmoid_1 -> hsigmoid_2 -> result

    Walking backwards from ``result`` in search of HSigmoid must stop at the
    nearest match, so exactly one node, hsigmoid_2, is expected.
    """
    nodes = {}
    for node_attrs in (
        regular_op('input', {'op': 'Parameter'}),
        regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
        regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
        result('result'),
    ):
        nodes.update(node_attrs)

    # Straight chain: every consecutive pair is linked output 0 -> input 0.
    chain = ('input', 'hsigmoid_1', 'hsigmoid_2', 'result')
    edges = [(src, dst, {'out': 0, 'in': 0}) for src, dst in zip(chain, chain[1:])]

    graph = build_graph_with_edge_attrs(nodes, edges)
    graph.stage = 'front'

    start = Node(graph, 'result')
    matches = backward_bfs_for_operation(start, ['HSigmoid'])

    self.assertEqual(len(matches), 1)
    self.assertEqual(matches[0].id, 'hsigmoid_2')
def test_backward_bfs_for_op_parallel_branch_stop_op(self):
    r"""
    Graph:
        input_1 -> hsigmoid_1 -> hsigmoid_2 --\
                                                Concat -> result
        input_2 -> hsigmoid_3 -> ShapeOf -----/

    With ShapeOf passed as a stop operation, the backward search must report
    only hsigmoid_2; hsigmoid_3, located behind the banned ShapeOf node, must
    not be reported.
    """
    nodes = {}
    for node_attrs in (
        regular_op('input_1', {'op': 'Parameter'}),
        regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
        regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
        regular_op('input_2', {'op': 'Parameter'}),
        regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
        regular_op('shapeof', {'op': 'ShapeOf'}),
        regular_op('concat', {'op': 'Concat'}),
        result('result'),
    ):
        nodes.update(node_attrs)

    # (source, destination, destination input port); every edge leaves port 0.
    connections = (
        ('input_1', 'hsigmoid_1', 0),
        ('hsigmoid_1', 'hsigmoid_2', 0),
        ('hsigmoid_2', 'concat', 0),
        ('input_2', 'hsigmoid_3', 0),
        ('hsigmoid_3', 'shapeof', 0),
        ('shapeof', 'concat', 1),
        ('concat', 'result', 0),
    )
    edges = [(src, dst, {'out': 0, 'in': in_port}) for src, dst, in_port in connections]

    graph = build_graph_with_edge_attrs(nodes, edges)
    graph.stage = 'front'

    start = Node(graph, 'result')
    matches = backward_bfs_for_operation(start, ['HSigmoid'], ['ShapeOf'])

    self.assertEqual(len(matches), 1)
    self.assertEqual(matches[0].id, 'hsigmoid_2')
def test_backward_bfs_for_op_parallel_branch_op_detected(self):
    r"""
    Graph:
        input_1 -> hsigmoid_1 -> hsigmoid_2 --\
                                                Concat -> result
        input_2 -> hsigmoid_3 -> hsigmoid_4 --/

    Each branch contributes its HSigmoid nearest to ``result``, so the search
    must report exactly hsigmoid_2 and hsigmoid_4 (in any order).
    """
    nodes = {}
    for node_attrs in (
        regular_op('input_1', {'op': 'Parameter'}),
        regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
        regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
        regular_op('input_2', {'op': 'Parameter'}),
        regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
        regular_op('hsigmoid_4', {'op': 'HSigmoid'}),
        regular_op('concat', {'op': 'Concat'}),
        result('result'),
    ):
        nodes.update(node_attrs)

    # (source, destination, destination input port); every edge leaves port 0.
    connections = (
        ('input_1', 'hsigmoid_1', 0),
        ('hsigmoid_1', 'hsigmoid_2', 0),
        ('hsigmoid_2', 'concat', 0),
        ('input_2', 'hsigmoid_3', 0),
        ('hsigmoid_3', 'hsigmoid_4', 0),
        ('hsigmoid_4', 'concat', 1),
        ('concat', 'result', 0),
    )
    edges = [(src, dst, {'out': 0, 'in': in_port}) for src, dst, in_port in connections]

    graph = build_graph_with_edge_attrs(nodes, edges)
    graph.stage = 'front'

    start = Node(graph, 'result')
    matches = backward_bfs_for_operation(start, ['HSigmoid'])

    self.assertEqual(len(matches), 2)
    # Order of discovery across branches is not part of the contract.
    self.assertSetEqual({match.id for match in matches}, {'hsigmoid_2', 'hsigmoid_4'})
def test_backward_bfs_for_op_no_ops_detected(self):
    """
    Graph: input -> hsigmoid -> result

    Searching for an op type that no node in the graph has must return no
    nodes at all.
    """
    nodes = {}
    for node_attrs in (
        regular_op('input', {'op': 'Parameter'}),
        regular_op('hsigmoid', {'op': 'HSigmoid'}),
        result('result'),
    ):
        nodes.update(node_attrs)

    # Straight chain: every consecutive pair is linked output 0 -> input 0.
    chain = ('input', 'hsigmoid', 'result')
    edges = [(src, dst, {'out': 0, 'in': 0}) for src, dst in zip(chain, chain[1:])]

    graph = build_graph_with_edge_attrs(nodes, edges)
    graph.stage = 'front'

    start = Node(graph, 'result')
    matches = backward_bfs_for_operation(start, ['NonExistingOp'])

    self.assertEqual(len(matches), 0)