コード例 #1
0
    def test_dequantize_no_zerop(self):
        """Resolve a 'dequantize' node that has a scale input but no zero point.

        The source graph feeds 'input' and 'scale_param_dq' into
        'dequantize'. After DequantizeLinearResolver runs, the graph must
        compare equal to a reference wired as 'input' -> 'cast' -> 'mul' ->
        'out', with 'scale_param_dq' also feeding 'mul'.
        """
        # Source graph: the node under test receives only data and scale.
        source_edges = [
            ('input', 'dequantize'),
            ('scale_param_dq', 'dequantize'),
            ('dequantize', 'out'),
        ]
        source_node_attrs = {
            'scale_param_dq': {
                'shape': np.array([]),
                'value': np.float32(1.0 / 255),
            },
        }
        graph = build_graph(nodes1_attributes,
                            source_edges,
                            source_node_attrs,
                            nodes_with_edges_only=True)

        # Reference graph: built from its own freshly created attribute
        # values so the two graphs do not share any objects.
        expected_edges = [
            ('input', 'cast'),
            ('cast', 'mul'),
            ('scale_param_dq', 'mul'),
            ('mul', 'out'),
        ]
        expected_node_attrs = {
            'scale_param_dq': {
                'shape': np.array([]),
                'value': np.float32(1.0 / 255),
            },
        }
        graph_ref = build_graph(nodes_ref_attributes,
                                expected_edges,
                                expected_node_attrs,
                                nodes_with_edges_only=True)

        # The stage is set on the graph before the transformation is applied.
        graph.stage = 'front'
        resolver = DequantizeLinearResolver()
        resolver.find_and_replace_pattern(graph)

        # Compare starting from the 'out' node, including operation attributes;
        # the comparison message is used as the assertion failure text.
        is_equal, message = compare_graphs(graph,
                                           graph_ref,
                                           'out',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
コード例 #2
0
    def test_dequantize(self):
        """Resolve a 'dequantize' node that has both scale and zero-point inputs.

        The source graph feeds 'input', 'scale_param_dq' and 'zerop_param_dq'
        into 'dequantize'. After DequantizeLinearResolver runs, the graph must
        compare equal to a reference wired as 'input' -> 'cast' -> 'sub' ->
        'mul' -> 'out', with 'zerop_param_dq' also feeding 'sub' and
        'scale_param_dq' also feeding 'mul'.
        """
        # Source graph: data, scale and zero point all enter the tested node.
        source_edges = [
            ('input', 'dequantize'),
            ('scale_param_dq', 'dequantize'),
            ('zerop_param_dq', 'dequantize'),
            ('dequantize', 'out'),
        ]
        source_node_attrs = {
            'scale_param_dq': {
                'shape': np.array([]),
                'value': np.float32(1.0 / 255),
            },
            'zerop_param_dq': {
                'shape': np.array([]),
                'value': np.uint8(0),
            },
        }
        graph = build_graph(nodes1_attributes,
                            source_edges,
                            source_node_attrs,
                            nodes_with_edges_only=True)

        # Command-line parameters attached to the graph for this test.
        cmd_params = Namespace(keep_shape_ops=True, data_type='FP32')
        graph.graph['cmd_params'] = cmd_params

        # Reference graph: built from its own freshly created attribute
        # values so the two graphs do not share any objects.
        expected_edges = [
            ('input', 'cast'),
            ('cast', 'sub'),
            ('zerop_param_dq', 'sub'),
            ('sub', 'mul'),
            ('scale_param_dq', 'mul'),
            ('mul', 'out'),
        ]
        expected_node_attrs = {
            'scale_param_dq': {
                'shape': np.array([]),
                'value': np.float32(1.0 / 255),
            },
            'zerop_param_dq': {
                'shape': np.array([]),
                'value': np.uint8(0),
            },
        }
        graph_ref = build_graph(nodes_ref_attributes,
                                expected_edges,
                                expected_node_attrs,
                                nodes_with_edges_only=True)

        # The stage is set on the graph before the transformation is applied.
        graph.stage = 'front'
        resolver = DequantizeLinearResolver()
        resolver.find_and_replace_pattern(graph)

        # Compare starting from the 'out' node, including operation attributes;
        # the comparison message is used as the assertion failure text.
        is_equal, message = compare_graphs(graph,
                                           graph_ref,
                                           'out',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)