Beispiel #1
0
    def test_tf_concat_infer_negative_axis(self):
        """Concat of two [1, 3, 227, 227] inputs with axis=-1 yields [1, 3, 227, 454].

        With axis=-1 the expected output sums the last dimension (227 + 227),
        i.e. the negative axis is resolved relative to the last dimension.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'concat'), ('node_2', 'concat'),
                               ('concat', 'node_3'), ('node_3', 'op_output')],
            {
                'node_3': {
                    'shape': None
                },
                'node_1': {
                    'shape': np.array([1, 3, 227, 227])
                },
                'node_2': {
                    'shape': np.array([1, 3, 227, 227])
                },
                'concat': {
                    'axis': -1
                }
            })

        concat_node = Node(graph, 'concat')
        concat_infer(concat_node)
        exp_shape = np.array([1, 3, 227, 454])
        res_shape = graph.node['node_3']['shape']
        # Compare the whole array, rank included. An element-wise loop over
        # exp_shape alone would accept an output with extra trailing dims, and
        # would crash (rather than fail cleanly) on a shorter or None result.
        np.testing.assert_array_equal(res_shape, exp_shape)
Beispiel #2
0
    def test_concat_value_infer(self, value1, value2, output_value, axis):
        """Concat of constant inputs value1 and value2 along axis must produce output_value.

        The arguments are presumably supplied by a parameterizing decorator
        outside this view — TODO confirm.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'concat'), ('node_2', 'concat'),
                               ('concat', 'node_3'), ('node_3', 'op_output')],
            {
                # The output node starts empty. Pre-seeding it with the
                # expected value would let the test pass even if concat_infer
                # never computed anything.
                'node_3': {
                    'shape': None,
                    'value': None
                },
                'node_1': {
                    'shape': value1.shape,
                    'value': value1
                },
                'node_2': {
                    'shape': value2.shape,
                    'value': value2
                },
                'concat': {
                    'axis': axis
                }
            })

        concat_node = Node(graph, 'concat')
        concat_infer(concat_node)
        res_value = graph.node['node_3']['value']
        self.assertTrue(strict_compare_tensors(output_value, res_value))
Beispiel #3
0
    def test_tf_concat_infer_no_shape(self):
        """With node_2's shape set to None, the inferred output shape stays None."""
        edges = [('node_1', 'concat'),
                 ('node_2', 'concat'),
                 ('concat', 'node_3'),
                 ('node_3', 'op_output')]
        overrides = {
            'node_3': {'shape': None},
            'node_1': {'shape': np.array([1, 3, 227, 227])},
            'node_2': {'shape': None},
            'concat': {'axis': 2},
        }
        graph = build_graph(nodes_attributes, edges, overrides)

        concat_infer(Node(graph, 'concat'))
        self.assertIsNone(graph.node['node_3']['shape'])
Beispiel #4
0
    def test_concat_infer_no_shape(self):
        """concat_infer must raise Error when node_2's input shape is None."""
        edges = [('node_1', 'concat'), ('node_2', 'concat'),
                 ('concat', 'node_3'), ('node_3', 'op_output')]
        overrides = {
            'node_3': {'shape': None},
            'node_1': {'shape': np.array([1, 3, 227, 227])},
            'node_2': {'shape': None},
            'concat': {'axis': 2},
        }
        graph = build_graph(nodes_attributes, edges, overrides)
        node = Node(graph, 'concat')

        expected_message = "One of the input shapes is not defined for node *"
        with self.assertRaisesRegex(Error, expected_message):
            concat_infer(node)
Beispiel #5
0
    def test_concat_infer(self, shape1, shape2, output_shape, axis):
        """Inputs of shapes shape1 and shape2 concatenated along axis give output_shape.

        The arguments are presumably supplied by a parameterizing decorator
        outside this view — TODO confirm.
        """
        edges = [('node_1', 'concat'), ('node_2', 'concat'),
                 ('concat', 'node_3'), ('node_3', 'op_output')]
        overrides = {
            'node_3': {'shape': None, 'value': None},
            'node_1': {'shape': shape_array(shape1)},
            'node_2': {'shape': shape_array(shape2)},
            'concat': {'axis': axis},
        }
        graph = build_graph(nodes_attributes, edges, overrides)

        concat_infer(Node(graph, 'concat'))
        inferred = graph.node['node_3']['shape']
        self.assertTrue(strict_compare_tensors(output_shape, inferred))
Beispiel #6
0
    def test_concat_infer_not_match(self):
        """concat_infer must raise Error when the input shapes are incompatible.

        The two inputs differ in dimension 1 (3 vs 2) while the concat axis is 2.
        """
        first_input = {'shape': np.array([1, 3, 227, 227])}
        second_input = {'shape': np.array([1, 2, 227, 227])}
        graph = build_graph(
            nodes_attributes,
            [('node_1', 'concat'),
             ('node_2', 'concat'),
             ('concat', 'node_3'),
             ('node_3', 'op_output')],
            {'node_3': {'shape': None, 'value': None},
             'node_1': first_input,
             'node_2': second_input,
             'concat': {'axis': 2}})
        node = Node(graph, 'concat')

        with self.assertRaisesRegex(Error, "Concat input shapes do not match for node*"):
            concat_infer(node)