def test_concat_value_infer(self, value1, value2, output_value, axis):
    """Check that ``concat_infer`` leaves the expected value on ``node_3``.

    The graph wires ``node_1`` and ``node_2`` into ``concat``; the result
    flows to ``node_3`` and from there to ``op_output``.

    Args:
        value1: Value stored on ``node_1``; its ``.shape`` becomes the node shape.
        value2: Value stored on ``node_2``; its ``.shape`` becomes the node shape.
        output_value: Value pre-set on ``node_3`` and expected after inference.
        axis: ``axis`` attribute assigned to the ``concat`` node.

    NOTE(review): the arguments are presumably injected by a parametrizing
    decorator that is not visible here - confirm.
    """
    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attribute_overrides = {
        'node_3': {'shape': output_value.shape, 'value': output_value},
        'node_1': {'shape': value1.shape, 'value': value1},
        'node_2': {'shape': value2.shape, 'value': value2},
        'concat': {'axis': axis},
    }
    graph = build_graph(nodes_attributes, edges, attribute_overrides)

    concat_infer(Node(graph, 'concat'))

    inferred_value = graph.node['node_3']['value']
    tensors_match = strict_compare_tensors(output_value, inferred_value)
    self.assertTrue(tensors_match)
def test_concat_infer_no_shape(self):
    """Check that ``concat_infer`` raises ``Error`` when an input shape is ``None``.

    ``node_1`` carries a concrete shape while ``node_2`` has ``shape`` set to
    ``None``; inference on the ``concat`` node is expected to fail with a
    message matching the pattern below.
    """
    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attribute_overrides = {
        'node_3': {'shape': None},
        'node_1': {'shape': np.array([1, 3, 227, 227])},
        'node_2': {'shape': None},
        'concat': {'axis': 2},
    }
    graph = build_graph(nodes_attributes, edges, attribute_overrides)
    node_under_test = Node(graph, 'concat')

    # The pattern is applied as a regular-expression search, so the trailing
    # " *" only means "zero or more spaces" after the word "node".
    expected_message = "One of the input shapes is not defined for node *"
    self.assertRaisesRegex(
        Error, expected_message, concat_infer, node_under_test)
def test_concat_infer(self, shape1, shape2, output_shape, axis):
    """Check that ``concat_infer`` writes the expected shape to ``node_3``.

    ``node_3`` starts with both ``shape`` and ``value`` set to ``None``; after
    inference on ``concat`` its ``shape`` is compared with ``output_shape``.

    Args:
        shape1: Shape for ``node_1``, passed through ``shape_array``.
        shape2: Shape for ``node_2``, passed through ``shape_array``.
        output_shape: Shape expected on ``node_3`` after inference.
        axis: ``axis`` attribute assigned to the ``concat`` node.

    NOTE(review): the arguments are presumably injected by a parametrizing
    decorator that is not visible here - confirm.
    """
    first_input_shape = shape_array(shape1)
    second_input_shape = shape_array(shape2)

    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attribute_overrides = {
        'node_3': {'shape': None, 'value': None},
        'node_1': {'shape': first_input_shape},
        'node_2': {'shape': second_input_shape},
        'concat': {'axis': axis},
    }
    graph = build_graph(nodes_attributes, edges, attribute_overrides)

    concat_infer(Node(graph, 'concat'))

    inferred_shape = graph.node['node_3']['shape']
    shapes_match = strict_compare_tensors(output_shape, inferred_shape)
    self.assertTrue(shapes_match)
def test_concat_infer_not_match(self):
    """Check that ``concat_infer`` raises ``Error`` for incompatible input shapes.

    The two inputs differ in dimension 1 (3 vs 2) while the ``concat`` node
    uses ``axis`` 2; inference is expected to fail with a message matching
    the pattern below.
    """
    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attribute_overrides = {
        'node_3': {'shape': None, 'value': None},
        'node_1': {'shape': np.array([1, 3, 227, 227])},
        'node_2': {'shape': np.array([1, 2, 227, 227])},
        'concat': {'axis': 2},
    }
    graph = build_graph(nodes_attributes, edges, attribute_overrides)
    node_under_test = Node(graph, 'concat')

    # The pattern is applied as a regular-expression search; the trailing "*"
    # is a quantifier on the final "e", not a glob wildcard.
    expected_message = "Concat input shapes do not match for node*"
    self.assertRaisesRegex(
        Error, expected_message, concat_infer, node_under_test)