コード例 #1
0
    def test_bn(self):
        bn_pb = FakeBNProtoLayer(FakeParam('eps', 0.0001))
        mean = [1, 2.5, 3]
        var = [0.5, 0.1, 1.2]
        scale = [2.3, 3.4, 4.5]
        shift = [0.8, 0.6, 0.4]
        bn_bin = FakeBNBinLayer([
            FakeParam('data', mean),
            FakeParam('data', var),
            FakeParam('data', scale),
            FakeParam('data', shift)
        ])
        nodes = [
            ('input', {
                'kind': 'op',
                'type': 'Identity',
                'op': 'Identity'
            }),
            ('bn', {
                'type': None,
                'kind': 'op',
                'op': 'BN',
                'pb': bn_pb,
                'model_pb': bn_bin
            }),
            ('output', {
                'kind': 'op',
                'type': 'Identity',
                'op': 'Identity'
            }),
        ]
        edges = [
            ('input', 'bn', {
                'in': 0,
                'out': 0
            }),
            ('bn', 'output', {
                'in': 0,
                'out': 0
            }),
        ]
        graph = build_graph_with_attrs(nodes, edges)
        node = Node(graph, 'bn')
        graph.stage = 'front'

        BNToScaleShift().find_and_replace_pattern(graph)

        ref_nodes = {
            'input': {
                'kind': 'op',
                'type': 'Identity',
                'op': 'Identity'
            },
            'scale': {
                'kind': 'op',
                'type': 'Const',
                'op': 'Const',
                'value': np.array([1.11796412, 3.2272172, 4.74282367])
            },
            'shift': {
                'kind': 'op',
                'type': 'Const',
                'op': 'Const',
                'value': np.array([-2.07131747, -10.87253847, -20.14270653])
            },
            'ss': {
                'type': 'ScaleShift',
                'kind': 'op',
                'op': 'ScaleShift'
            },
            'output': {
                'kind': 'op',
                'type': 'Identity',
                'op': 'Identity'
            },
        }
        ref_edges = [
            ('input', 'ss', {
                'in': 0,
                'out': 0
            }),
            ('scale', 'ss', {
                'in': 1,
                'out': 0
            }),
            ('shift', 'ss', {
                'in': 2,
                'out': 0
            }),
            ('ss', 'output', {
                'in': 0,
                'out': 0
            }),
        ]
        ref_graph = build_graph_with_edge_attrs(ref_nodes, ref_edges)
        (flag, resp) = compare_graphs(graph,
                                      ref_graph,
                                      'input',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
コード例 #2
0
 def test_get_list_from_container_list_match_empty(self):
     res = get_list_from_container(FakeParam('prop', []), 'prop', int)
     self.assertEqual(res, [])
コード例 #3
0
 def test_get_list_from_container_no_existing_param(self):
     res = get_list_from_container(FakeParam("p", "1"), 'prop', int)
     self.assertEqual(res, [])
コード例 #4
0
 def test_get_list_from_container_simple_type_match(self):
     res = get_list_from_container(FakeParam('prop', 10), 'prop', int)
     self.assertEqual(res, [10])
コード例 #5
0
 def ListFields(self):
     keys = []
     for k in self.dict_values.keys():
         keys.append([FakeParam('name', k)])
     return keys