Beispiel #1
0
    def test_forward_bfs_simple(self):
        """Check forward_bfs on a single chain Placeholder->ScaleShift->Mul->Add.

        The search may pass only through ops listed in ``allowed_ops`` (or any
        op when ``allowed_all=True``) and reports the nodes whose op is in the
        target list. Results are compared as lists of node ids so a failure
        shows what bfs actually returned.
        """
        # Placeholder->ScaleShift->Mul->Add
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output')
                             ])

        # ScaleShift and Mul are traversable, so the walk reaches the Add.
        res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul'], ['Add'])
        self.assertEqual([node.id for node in res], ['add_1'], 'Add operation was not found by bfs')

        # allowed_all=True lets the walk pass through any op.
        res = forward_bfs(Node(graph, 'placeholder_1'), [], ['Add'], allowed_all=True)
        self.assertEqual([node.id for node in res], ['add_1'], 'Add operation was not found by bfs')

        # Mul is not traversable here, so the Add must not be reported.
        res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Add'])
        self.assertEqual(len(res), 0, 'No one node should be found! But bfs found {} nodes'.format(len(res)))

        # Mul is a target (and not traversable), so only mul_1 is reported.
        res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Mul', 'Add'])
        self.assertEqual([node.id for node in res], ['mul_1'], 'BFS should find only one Mul operation')
Beispiel #2
0
    def test_forward_bfs_hard(self):
        """Check forward_bfs on a graph that branches right after the placeholder.

        Both branches start at placeholder_1_data and meet again in Concat, so
        each expected result depends on what every branch can reach. Results
        are compared as node-id lists (order-insensitive where two branches
        contribute) so that duplicated or missing nodes are caught.
        """
        # Placeholder->ScaleShift->Mul1->Add1---->Concat
        #             `----------->Add2->Mul2--'
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('placeholder_1_data', 'add_2'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'), ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'), ('add_2', 'add_2_data'),
                             ('add_2_data', 'mul_2'), ('mul_2', 'mul_2_data'),
                             ('add_1_data', 'concat_1'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output')])

        # Every op on both branches is traversable; they meet in one Concat,
        # which must be reported exactly once.
        res = forward_bfs(Node(graph, 'placeholder_1'),
                          ['ScaleShift', 'Mul', 'Add'], ['Concat'])
        self.assertEqual(
            [node.id for node in res], ['concat_1'],
            'Probably Concat operation was not found by bfs')

        # Each branch ends at its own Add. assertCountEqual ignores order but,
        # unlike a length-plus-membership check, it fails if one Add is
        # reported twice while the other is missing.
        res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul'],
                          ['Add'])
        self.assertCountEqual(
            [node.id for node in res], ['add_1', 'add_2'],
            'Add operations was not found by bfs')

        # Mul is not traversable, so the ScaleShift branch cannot get past
        # mul_1. The test expects an empty result even though add_2 directly
        # consumes placeholder_1_data, i.e. a dead-end branch is expected to
        # void the whole search.
        res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift'],
                          ['Add'])
        self.assertEqual(len(res), 0, 'BFS shouldn\'t find any operations')

        # With allowed_all=True both Adds are reachable.
        res = forward_bfs(Node(graph, 'placeholder_1'), [], ['Add'],
                          allowed_all=True)
        self.assertCountEqual(
            [node.id for node in res], ['add_1', 'add_2'],
            'Add operations was not found by bfs')

        # From the data node Mul is again not traversable, so Concat must not
        # be reported.
        res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'],
                          ['Concat'])
        self.assertEqual(
            len(res), 0,
            'No one node should be found! But bfs found {} nodes'.format(
                len(res)))
def fuse_linear_ops(graph: Graph):
    """
    Fuse Mul operations that have a value input into neighbouring Convolution, Deconvolution or MatMul nodes.

    Only ``Mul`` nodes are considered; other linear ops such as ``Add`` are not handled by this function.
    A node is a candidate when its ``op`` is ``'Mul'``, ``get_value_in_port(node)`` is not None, and the
    node has ``can_be_fused`` set.

    The graph is processed in two passes:

    1. Backward: nodes are visited in ``graph.pseudo_topological_sort()`` order and each candidate is passed
       to ``_fuse_mul`` together with the target ops found by ``backward_bfs``.
    2. Forward: the graph is sorted again with ``reverse=True`` only after the first pass has finished, and
       each candidate is passed to ``_fuse_mul`` (with ``False`` as its fourth argument, presumably selecting
       the forward direction) together with the target ops found by ``forward_bfs``.

    The sum of the ``_fuse_mul`` results is logged at debug level; nothing is returned.

    :param graph: the graph to process.
    """
    # Op types a Mul may be fused into, shared by both passes.
    fusion_targets = ['Convolution', 'Deconvolution', 'MatMul']

    def is_fusable_mul(node):
        # Only a Mul with a value input that is explicitly marked as fusable is a candidate.
        return (node.soft_get('op') == 'Mul'
                and get_value_in_port(node) is not None
                and node.has_and_set('can_be_fused'))

    fuse_count = 0

    # Fusion in backward direction.
    for node in graph.pseudo_topological_sort():
        if is_fusable_mul(node):
            fuse_nodes = backward_bfs(node, [], fusion_targets)
            fuse_count += _fuse_mul(graph, node, fuse_nodes)

    # Fusion in forward direction. The sort is recomputed here, after the backward pass, so it reflects
    # the graph as left by that pass.
    for node in graph.pseudo_topological_sort(reverse=True):
        if is_fusable_mul(node):
            fuse_nodes = forward_bfs(node, [], fusion_targets)
            fuse_count += _fuse_mul(graph, node, fuse_nodes, False)

    log.debug("Fused {} nodes".format(fuse_count))