Beispiel #1
0
    def test_const_extra_outputs_i1_case(self):
        """Run both Relu/quantize passes on the ``const_extra`` graph.

        The input graph is built from ``nodes`` and the ``const_extra`` edges
        with ``const_1_d`` filled with -1 and both ``quantize`` and
        ``quantize_1`` set to two levels. After the passes run, the graph is
        compared against the reference built from ``ref_const_extra``,
        passing ``'relu'`` as the node argument of ``compare_graphs``.
        """
        attr_overrides = {
            'const_1_d': {
                'value': np.full([1, 2, 3, 4], -1, dtype=np.float32),
            },
            'quantize': {'levels': 2},
            'quantize_1': {'levels': 2},
        }
        graph = build_graph(nodes, const_extra, attr_overrides,
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NHWC'
        graph.stage = 'middle'

        graph_ref = build_graph(nodes, ref_const_extra,
                                nodes_with_edges_only=True)

        # Order matters: the marking pass runs before the fusing pass.
        for transformation in (ReluFakeQuantizeMark, ReluQuantizeFuse):
            transformation().find_and_replace_pattern(graph)

        graphs_match, diff_message = compare_graphs(
            graph, graph_ref, 'relu', check_op_attrs=True)
        self.assertTrue(graphs_match, diff_message)

        expected_value = np.full([1, 2, 3, 4], float('-inf'),
                                 dtype=np.float32)
        # NOTE(review): the result of this comparison is discarded, so it can
        # never fail the test, and it reads from graph_ref rather than from the
        # transformed graph. Confirm the intended target before wrapping it in
        # an assertion.
        np.array_equal(expected_value, graph_ref.node['const_1_d']['value'])
Beispiel #2
0
    def test_classic_i1_positive_case(self):
        """Run both Relu/quantize passes on the ``i1_edges`` graph.

        The input graph is built from ``nodes`` and ``i1_edges`` with
        ``const_1_d`` filled with zeros and ``quantize`` set to two levels.
        After the passes run, the graph is compared against the reference
        built from ``ref_i1_edges``, passing ``'output'`` as the node argument
        of ``compare_graphs``.
        """
        attr_overrides = {
            'const_1_d': {
                'value': np.zeros([1, 2, 3, 4], dtype=np.float32),
            },
            'quantize': {'levels': 2},
        }
        graph = build_graph(nodes, i1_edges, attr_overrides,
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NHWC'
        graph.stage = 'middle'

        graph_ref = build_graph(nodes, ref_i1_edges,
                                nodes_with_edges_only=True)

        # Order matters: the marking pass runs before the fusing pass.
        for transformation in (ReluFakeQuantizeMark, ReluQuantizeFuse):
            transformation().find_and_replace_pattern(graph)

        graphs_match, diff_message = compare_graphs(
            graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(graphs_match, diff_message)