Esempio n. 1
0
    def test_backward_bfs_for_op_closest_op_detected(self):
        """
        Graph: input -> hsigmoid_1 -> hsigmoid_2 -> result

        Walking backward from 'result' looking for HSigmoid, only the first
        HSigmoid encountered (hsigmoid_2) is expected in the result.
        """
        nodes = {
            **regular_op('input', {'op': 'Parameter'}),
            **regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
            **regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
            **result('result'),
        }
        # (source, destination, destination input port); every edge leaves output port 0.
        connections = [
            ('input', 'hsigmoid_1', 0),
            ('hsigmoid_1', 'hsigmoid_2', 0),
            ('hsigmoid_2', 'result', 0),
        ]
        edges = [(src, dst, {'out': 0, 'in': in_port})
                 for src, dst, in_port in connections]

        graph = build_graph_with_edge_attrs(nodes, edges)
        graph.stage = 'front'

        start_node = Node(graph, 'result')
        found_nodes = backward_bfs_for_operation(start_node, ['HSigmoid'])

        self.assertEqual(len(found_nodes), 1)
        self.assertEqual(found_nodes[0].id, 'hsigmoid_2')
Esempio n. 2
0
    def test_backward_bfs_for_op_parallel_branch_stop_op(self):
        r"""
        Graph:
        input_1 -> hsigmoid_1 -> hsigmoid_2 ->
                                               \
                                                - Concat->result
                                               /
        input_2 -> hsigmoid_3 -> ShapeOf    ->

        ShapeOf is passed as a stop operation, so the lower branch must not be
        explored past it: only hsigmoid_2 is expected, never hsigmoid_3.
        """
        nodes = {
            **regular_op('input_1', {'op': 'Parameter'}),
            **regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
            **regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
            **regular_op('input_2', {'op': 'Parameter'}),
            **regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
            **regular_op('shapeof', {'op': 'ShapeOf'}),
            **regular_op('concat', {'op': 'Concat'}),
            **result('result'),
        }
        # (source, destination, destination input port); every edge leaves output port 0.
        # The two branches feed Concat's input ports 0 and 1 respectively.
        connections = [
            ('input_1', 'hsigmoid_1', 0),
            ('hsigmoid_1', 'hsigmoid_2', 0),
            ('hsigmoid_2', 'concat', 0),
            ('input_2', 'hsigmoid_3', 0),
            ('hsigmoid_3', 'shapeof', 0),
            ('shapeof', 'concat', 1),
            ('concat', 'result', 0),
        ]
        edges = [(src, dst, {'out': 0, 'in': in_port})
                 for src, dst, in_port in connections]

        graph = build_graph_with_edge_attrs(nodes, edges)
        graph.stage = 'front'

        start_node = Node(graph, 'result')
        found_nodes = backward_bfs_for_operation(start_node, ['HSigmoid'],
                                                 ['ShapeOf'])

        self.assertEqual(len(found_nodes), 1)
        self.assertEqual(found_nodes[0].id, 'hsigmoid_2')
Esempio n. 3
0
    def test_backward_bfs_for_op_parallel_branch_op_detected(self):
        r"""
        Graph:
        input_1 -> hsigmoid_1 -> hsigmoid_2 ->
                                               \
                                                - Concat->result
                                               /
        input_2 -> hsigmoid_3 -> hsigmoid_4 ->

        Each branch contributes its first HSigmoid seen from 'result', so
        exactly hsigmoid_2 and hsigmoid_4 are expected (in any order).
        """
        nodes = {
            **regular_op('input_1', {'op': 'Parameter'}),
            **regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
            **regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
            **regular_op('input_2', {'op': 'Parameter'}),
            **regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
            **regular_op('hsigmoid_4', {'op': 'HSigmoid'}),
            **regular_op('concat', {'op': 'Concat'}),
            **result('result'),
        }
        # (source, destination, destination input port); every edge leaves output port 0.
        # The two branches feed Concat's input ports 0 and 1 respectively.
        connections = [
            ('input_1', 'hsigmoid_1', 0),
            ('hsigmoid_1', 'hsigmoid_2', 0),
            ('hsigmoid_2', 'concat', 0),
            ('input_2', 'hsigmoid_3', 0),
            ('hsigmoid_3', 'hsigmoid_4', 0),
            ('hsigmoid_4', 'concat', 1),
            ('concat', 'result', 0),
        ]
        edges = [(src, dst, {'out': 0, 'in': in_port})
                 for src, dst, in_port in connections]

        graph = build_graph_with_edge_attrs(nodes, edges)
        graph.stage = 'front'

        start_node = Node(graph, 'result')
        found_nodes = backward_bfs_for_operation(start_node, ['HSigmoid'])

        self.assertEqual(len(found_nodes), 2)
        # Order of discovery between branches is not asserted; compare as a set.
        self.assertSetEqual({node.id for node in found_nodes},
                            {'hsigmoid_2', 'hsigmoid_4'})
Esempio n. 4
0
    def test_backward_bfs_for_op_no_ops_detected(self):
        """
        Graph: input -> hsigmoid -> result

        Searching backward for an op type that does not occur anywhere in the
        graph ('NonExistingOp') must yield an empty result.
        """
        nodes = {
            **regular_op('input', {'op': 'Parameter'}),
            **regular_op('hsigmoid', {'op': 'HSigmoid'}),
            **result('result'),
        }
        # (source, destination, destination input port); every edge leaves output port 0.
        connections = [
            ('input', 'hsigmoid', 0),
            ('hsigmoid', 'result', 0),
        ]
        edges = [(src, dst, {'out': 0, 'in': in_port})
                 for src, dst, in_port in connections]

        graph = build_graph_with_edge_attrs(nodes, edges)
        graph.stage = 'front'

        start_node = Node(graph, 'result')
        found_nodes = backward_bfs_for_operation(start_node, ['NonExistingOp'])

        self.assertEqual(len(found_nodes), 0)