Пример #1
0
    def test_negative_5(self):
        """MulFakeQuantizeFuse must leave the graph unchanged for a [0, 1, 2] multiplier.

        The reference graph is a copy of the input taken *before* the pass runs,
        so the comparison only succeeds if the pass made no modifications.
        NOTE(review): the zero element in the per-channel multiplier is presumably
        what prevents fusion here — confirm against the pass implementation.
        """
        out_low = np.broadcast_to(np.array([0]), (1, 1, 1, 1))
        out_high = np.broadcast_to(np.array([1]), (1, 1, 1, 1))
        update_attrs = {
            'mul': {'can_be_fused': True},
            'mul_const_data': {
                'shape': np.array([3, 1, 1]),
                'value': np.array([[[0]], [[1]], [[2]]]),
            },
            'quantize_data': {'shape': np.array([2, 3, 4, 4])},
            'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': out_low},
            'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': out_high},
        }

        graph = build_graph(nodes, edges, update_attrs, nodes_with_edges_only=True)
        graph.stage = 'middle'
        # Snapshot of the untouched graph serves as the expected result.
        graph_ref = graph.copy()

        MulFakeQuantizeFuse().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
Пример #2
0
    def test_2(self):
        """A scalar multiplier of 2 in front of FakeQuantize is fused into it.

        The expected graph is built from ``edges_ref`` and carries input-range
        constants ``mi_i_data = [-5]`` and ``ma_i_data = [5]``.
        NOTE(review): presumably these are the fixture's original input range
        divided by the multiplier 2 — confirm against the ``nodes`` fixture.
        """
        def scalar_data(value):
            # Fresh 1-element data-node attributes on every call, so the input and
            # reference graphs never share the same numpy arrays.
            return {'shape': np.array([1]), 'value': np.array([value])}

        graph = build_graph(
            nodes,
            edges,
            {
                'mul': {'can_be_fused': True},
                'mul_const_data': scalar_data(2),
                'quantize_data': {'shape': np.array([2, 3, 4, 4])},
                'mi_o_data': scalar_data(0),
                'ma_o_data': scalar_data(1),
            },
            nodes_with_edges_only=True,
        )
        graph.stage = 'middle'

        graph_ref = build_graph(
            nodes,
            edges_ref,
            {
                'quantize_data': {'shape': np.array([2, 3, 4, 4])},
                'mul_const_data': scalar_data(2),
                'mi_o_data': scalar_data(0),
                'ma_o_data': scalar_data(1),
                'mi_i_data': scalar_data(-5),
                'ma_i_data': scalar_data(5),
            },
            nodes_with_edges_only=True,
        )

        MulFakeQuantizeFuse().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
Пример #3
0
    def test_negative_1(self):
        """With the default fixture attributes, MulFakeQuantizeFuse changes nothing.

        Unlike the other tests, ``mul`` gets no ``can_be_fused`` override here.
        NOTE(review): presumably the default ``nodes`` fixture leaves ``mul``
        non-fusable, which is what blocks the pass — confirm.
        """
        def make_graph():
            # Input and reference are built identically from the shared fixtures.
            return build_graph(nodes, edges, nodes_with_edges_only=True)

        graph = make_graph()
        graph.stage = 'middle'
        graph_ref = make_graph()

        MulFakeQuantizeFuse().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
Пример #4
0
    def find_and_replace_pattern(self, graph: Graph):
        """Run the fusing pipeline over ``graph`` (and, for most passes, its sub-graphs).

        Reads ``fw``, ``cmd_params`` and ``layout`` from ``graph.graph`` and then,
        in order, with ``clean_up()`` passes between stages:

        1. ``fuse_pad``, marking of non-fusable nodes (``argv.finegrain_fusing``),
           batch-norm decomposition and, for Caffe, ScaleShift -> Mul/Add.
        2. Div and Sub rewrites.
        3. Unless ``argv.disable_fusing``: ScaleShift -> Mul/Add for non-Caffe
           frameworks, Mul/Add sequence fusion, eltwise input normalization and
           linear-op fusion.
        4. Unless ``argv.disable_gfusing``: grouped-convolution fusion, followed by
           another linear-op fusion round when fusing is enabled.
        5. Eltwise input normalization on every graph.
        6. Unless ``argv.disable_fusing``: the FakeQuantize fusing passes.
        7. ``fuse_pad`` again, then ``stride_optimization`` when the layout is not
           NHWC and ``argv.disable_resnet_optimization`` is not set.
        """
        fw = graph.graph['fw']
        argv = graph.graph['cmd_params']
        layout = graph.graph['layout']

        def on_all_graphs(transform):
            # Apply ``transform`` to the top-level graph and each nested sub-graph.
            for_graph_and_each_sub_graph_recursively(graph, transform)

        def clean_up_all_graphs():
            # Run ``clean_up()`` on the top-level graph and each nested sub-graph.
            on_all_graphs(lambda g: g.clean_up())

        on_all_graphs(fuse_pad)
        clean_up_all_graphs()

        # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
        on_all_graphs(lambda g: mark_unfused_nodes(g, argv.finegrain_fusing))

        # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
        # IE doesn't support batchNormInference with 4 inputs, so we have to split it to two ScaleShift
        on_all_graphs(convert_batch_norm)

        if fw == 'caffe':
            # Caffe ScaleShift is always converted to Mul->Add, even with fusing disabled.
            on_all_graphs(convert_scale_shift_to_mul_add)

        on_all_graphs(Div().find_and_replace_pattern)
        on_all_graphs(Sub().find_and_replace_pattern)
        clean_up_all_graphs()

        if not argv.disable_fusing:
            if fw != 'caffe':
                # Other frameworks convert ScaleShift to Mul->Add only when fusing is on.
                on_all_graphs(convert_scale_shift_to_mul_add)
                clean_up_all_graphs()

            # Fuse sequences of Mul/Add operations.
            on_all_graphs(fuse_mul_add_sequence)
            clean_up_all_graphs()

            # NOTE(review): unlike the later normalize_eltwise_inputs call, this one
            # is applied to the top-level graph only, so sub-graphs are not
            # normalized by it before fuse_linear_ops runs. Confirm this is intended.
            normalize_eltwise_inputs(graph)
            clean_up_all_graphs()

            # Fuse linear operations into Convolution.
            on_all_graphs(fuse_linear_ops)
            clean_up_all_graphs()

        if not argv.disable_gfusing:
            on_all_graphs(grouped_convolutions_fusing)
            clean_up_all_graphs()
            if not argv.disable_fusing:
                on_all_graphs(fuse_linear_ops)
                clean_up_all_graphs()

        on_all_graphs(normalize_eltwise_inputs)
        clean_up_all_graphs()

        if not argv.disable_fusing:
            # These passes are called directly on the top-level graph rather than
            # through for_graph_and_each_sub_graph_recursively.
            MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
            FakeQuantizeFuse().find_and_replace_pattern(graph)
            AddFakeQuantizeFuse().find_and_replace_pattern(graph)
            MulFakeQuantizeFuse().find_and_replace_pattern(graph)
            clean_up_all_graphs()

        on_all_graphs(fuse_pad)
        clean_up_all_graphs()

        if layout != 'NHWC' and not argv.disable_resnet_optimization:
            stride_optimization(graph)