Example 1
    def test_mark_unfused_nodes_2(self):
        # Placeholder->ScaleShift->Mul->Add
        # Local factories so every data/weight node gets its own fresh arrays,
        # exactly as the original literal dicts did.
        def data_attrs():
            return {'shape': np.array([1, 227, 227, 3])}

        def weight_attrs():
            return {'shape': np.array([1]), 'value': 6}

        edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'mul_1'),
            ('mul_1_w', 'mul_1'),
            ('mul_1', 'mul_1_data'),
            ('mul_1_data', 'add_1'),
            ('add_1_w', 'add_1'),
            ('add_1', 'add_1_data'),
            ('add_1_data', 'mul_2'),
            ('mul_2_w', 'mul_2'),
            ('mul_2', 'mul_2_data'),
            ('mul_2_data', 'concat_1'),
            ('concat_1', 'concat_1_data'),
            ('placeholder_1_data', 'concat_1'),
            ('concat_1_data', 'op_output'),
        ]
        update_attrs = {
            'placeholder_1_data': data_attrs(),
            'mul_1_data': data_attrs(),
            'add_1_data': data_attrs(),
            'mul_2_data': data_attrs(),
            'mul_1_w': weight_attrs(),
            'add_1_w': weight_attrs(),
            'mul_2_w': weight_attrs(),
        }
        graph = build_graph(nodes_attributes, edges, update_attrs)
        graph.graph['layout'] = 'NHWC'

        # '.*' matches every node name, so fusing must be disabled everywhere.
        mark_unfused_nodes(graph, '.*')

        for name in ('mul_1', 'mul_2', 'add_1', 'placeholder_1', 'concat_1'):
            self.assertFalse(graph.node[name]['can_be_fused'],
                             "can_be_fused should be False")
Example 2
    def find_and_replace_pattern(self, graph: Graph):
        """Run the fusing pipeline over *graph* and every nested sub-graph.

        The pass order is significant and is preserved exactly: pad fusing,
        unfused-node marking, FusedBatchNorm decomposition, Div/Sub lowering,
        Mul/Add sequence fusing, linear-op fusing, grouped-convolution fusing,
        FakeQuantize fusing, and finally the optional ResNet stride
        optimization.
        """
        framework = graph.graph['fw']
        cli = graph.graph['cmd_params']
        data_layout = graph.graph['layout']

        # Shorthand: apply a pass to the graph and, recursively, each sub-graph.
        recurse = for_graph_and_each_sub_graph_recursively

        def clean_up(g: Graph):
            g.clean_up()

        recurse(graph, fuse_pad)
        recurse(graph, clean_up)

        # Mark nodes matched by the user regex with 'can_be_fused': False to
        # disable fusing for them.
        recurse(graph, lambda g: mark_unfused_nodes(g, cli.finegrain_fusing))

        # IE doesn't support batchNormInference with 4 inputs, so FusedBatchNorm
        # is split into a Mul->Add->Mul->Add sequence (two ScaleShifts).
        recurse(graph, convert_batch_norm)

        if framework == 'caffe':
            # For Caffe, decompose ScaleShift into Mul->Add at this point.
            recurse(graph, convert_scale_shift_to_mul_add)

        recurse(graph, Div().find_and_replace_pattern)
        recurse(graph, Sub().find_and_replace_pattern)
        recurse(graph, clean_up)

        if not cli.disable_fusing:
            if framework != 'caffe':
                # Other frameworks decompose ScaleShift here instead.
                recurse(graph, convert_scale_shift_to_mul_add)
                recurse(graph, clean_up)

            # Collapse sequences of Mul/Add operations.
            recurse(graph, fuse_mul_add_sequence)
            recurse(graph, clean_up)

            normalize_eltwise_inputs(graph)
            recurse(graph, clean_up)

            # Fold linear operations into Convolution.
            recurse(graph, fuse_linear_ops)
            recurse(graph, clean_up)

        if not cli.disable_gfusing:
            recurse(graph, grouped_convolutions_fusing)
            recurse(graph, clean_up)
            if not cli.disable_fusing:
                recurse(graph, fuse_linear_ops)
                recurse(graph, clean_up)

        recurse(graph, normalize_eltwise_inputs)
        recurse(graph, clean_up)

        if not cli.disable_fusing:
            MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
            FakeQuantizeFuse().find_and_replace_pattern(graph)
            AddFakeQuantizeFuse().find_and_replace_pattern(graph)
            MulFakeQuantizeFuse().find_and_replace_pattern(graph)
            recurse(graph, clean_up)

        recurse(graph, fuse_pad)
        recurse(graph, clean_up)

        if data_layout != 'NHWC' and not cli.disable_resnet_optimization:
            stride_optimization(graph)