示例#1
0
    def test_flatten_infer_no_shape(self):
        """Flatten.infer is expected to leave the output shape as None when
        the input node's shape is unknown (None)."""
        edges = [('node_1', 'flatten_1'), ('flatten_1', 'node_2')]
        overrides = {
            'node_2': {'is_output': True, 'shape': None},
            'node_1': {'shape': None},
            'flatten_1': {'axis': 1},
        }
        graph = build_graph(nodes_attributes, edges, overrides)

        Flatten.infer(Node(graph, 'flatten_1'))

        self.assertIsNone(graph.node['node_2']['shape'])
示例#2
0
    def test_flatten_infer(self):
        """Check that Flatten with axis=1 turns a [1, 3, 256, 256] input into
        a [1, 3 * 256 * 256] output shape.

        The output node starts without a shape so the assertion can only pass
        if ``Flatten.infer`` actually writes the inferred shape.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'flatten_1'),
                               ('flatten_1', 'node_2')], {
                                   'node_2': {
                                       'is_output': True,
                                       # Deliberately unset: pre-seeding the
                                       # expected shape here would let the
                                       # test pass even if infer() wrote
                                       # nothing to the output node.
                                       'shape': None
                                   },
                                   'node_1': {
                                       'shape': np.array([1, 3, 256, 256])
                                   },
                                   'flatten_1': {
                                       'axis': 1,
                                       'dim': []
                                   }
                               })

        flatten_node = Node(graph, 'flatten_1')

        Flatten.infer(flatten_node)
        exp_shape = np.array([1, 3 * 256 * 256])
        res_shape = graph.node['node_2']['shape']
        self.assertIsNotNone(res_shape)
        # Compare the whole arrays (including their length/rank); an
        # element-wise loop over exp_shape would miss extra trailing dims.
        np.testing.assert_array_equal(res_shape, exp_shape)