コード例 #1
0
    def test_concat_value_infer(self, value1, value2, output_value, axis):
        """Check the value that ``concat_infer`` stores on the output node.

        ``node_1`` and ``node_2`` carry ``value1`` and ``value2`` (with their
        shapes) and the concat node uses ``axis``. After inference the value
        found on ``node_3`` must compare equal to ``output_value``.

        The output node is created with no shape and no value so that the
        final assertion can only succeed when ``concat_infer`` itself writes
        the result; seeding it with ``output_value`` up front would let the
        check pass even if inference produced nothing.
        """
        edges = [
            ('node_1', 'concat'),
            ('node_2', 'concat'),
            ('concat', 'node_3'),
            ('node_3', 'op_output'),
        ]
        attr_overrides = {
            # Start empty: the expected value must come from inference.
            'node_3': {'shape': None, 'value': None},
            'node_1': {'shape': value1.shape, 'value': value1},
            'node_2': {'shape': value2.shape, 'value': value2},
            'concat': {'axis': axis},
        }
        graph = build_graph(nodes_attributes, edges, attr_overrides)

        concat_node = Node(graph, 'concat')
        concat_infer(concat_node)

        res_value = graph.node['node_3']['value']
        self.assertTrue(strict_compare_tensors(output_value, res_value))
コード例 #2
0
    def test_concat_infer_no_shape(self):
        """Check that ``concat_infer`` raises ``Error`` for an undefined input shape.

        ``node_1`` has a concrete shape while ``node_2`` is built with
        ``shape=None``; running inference on the concat node is expected to
        raise ``Error`` with the "input shapes is not defined" message.
        """
        edges = [
            ('node_1', 'concat'),
            ('node_2', 'concat'),
            ('concat', 'node_3'),
            ('node_3', 'op_output'),
        ]
        attr_overrides = {
            'node_3': {'shape': None},
            'node_1': {'shape': np.array([1, 3, 227, 227])},
            'node_2': {'shape': None},
            'concat': {'axis': 2},
        }
        graph = build_graph(nodes_attributes, edges, attr_overrides)
        concat_node = Node(graph, 'concat')

        # NOTE(review): the trailing " *" is a regex quantifier (zero or more
        # spaces), not a glob wildcard; the pattern matches any message that
        # contains the text before it.
        expected_message = "One of the input shapes is not defined for node *"
        with self.assertRaisesRegex(Error, expected_message):
            concat_infer(concat_node)
コード例 #3
0
    def test_concat_infer(self, shape1, shape2, output_shape, axis):
        """Check the shape that ``concat_infer`` stores on the output node.

        ``node_1`` and ``node_2`` get ``shape_array(shape1)`` and
        ``shape_array(shape2)``; ``node_3`` starts with no shape and no value.
        After inference along ``axis`` the shape found on ``node_3`` must
        compare equal to ``output_shape``.
        """
        edges = [
            ('node_1', 'concat'),
            ('node_2', 'concat'),
            ('concat', 'node_3'),
            ('node_3', 'op_output'),
        ]
        attr_overrides = {
            'node_3': {'shape': None, 'value': None},
            'node_1': {'shape': shape_array(shape1)},
            'node_2': {'shape': shape_array(shape2)},
            'concat': {'axis': axis},
        }
        graph = build_graph(nodes_attributes, edges, attr_overrides)

        concat_infer(Node(graph, 'concat'))

        inferred_shape = graph.node['node_3']['shape']
        self.assertTrue(strict_compare_tensors(output_shape, inferred_shape))
コード例 #4
0
    def test_concat_infer_not_match(self):
        """Check that ``concat_infer`` raises ``Error`` for mismatching inputs.

        The two inputs have shapes ``[1, 3, 227, 227]`` and
        ``[1, 2, 227, 227]`` and the concat axis is 2, so the shapes differ
        in dimension 1, which is not the concat axis. Inference is expected to
        raise ``Error`` with the "input shapes do not match" message.
        """
        edges = [
            ('node_1', 'concat'),
            ('node_2', 'concat'),
            ('concat', 'node_3'),
            ('node_3', 'op_output'),
        ]
        attr_overrides = {
            'node_3': {'shape': None, 'value': None},
            'node_1': {'shape': np.array([1, 3, 227, 227])},
            'node_2': {'shape': np.array([1, 2, 227, 227])},
            'concat': {'axis': 2},
        }
        graph = build_graph(nodes_attributes, edges, attr_overrides)
        concat_node = Node(graph, 'concat')

        # No trailing "*": in a regex it is a quantifier on the preceding
        # character (so "node*" would also accept "nod"), not a glob wildcard.
        # assertRaisesRegex searches the message, so a plain prefix suffices.
        expected_message = "Concat input shapes do not match for node"
        with self.assertRaisesRegex(Error, expected_message):
            concat_infer(concat_node)