def test_tf_concat_infer_negative_axis(self):
    """Concatenate two equal 4-D shapes along ``axis=-1``.

    The expected result doubles the last dimension:
    ``[1, 3, 227, 227]`` ++ ``[1, 3, 227, 227]`` on axis -1 gives
    ``[1, 3, 227, 454]``.
    """
    graph = build_graph(
        nodes_attributes,
        [('node_1', 'concat'),
         ('node_2', 'concat'),
         ('concat', 'node_3'),
         ('node_3', 'op_output')],
        {
            'node_3': {'shape': None},
            'node_1': {'shape': np.array([1, 3, 227, 227])},
            'node_2': {'shape': np.array([1, 3, 227, 227])},
            'concat': {'axis': -1},
        })
    concat_node = Node(graph, 'concat')
    concat_infer(concat_node)
    exp_shape = np.array([1, 3, 227, 454])
    res_shape = graph.node['node_3']['shape']
    # Check presence and rank before comparing dims. Iterating over
    # exp_shape alone would accept a result with extra trailing dims, and
    # a missing (None) shape would raise TypeError instead of failing.
    self.assertIsNotNone(res_shape)
    self.assertEqual(len(exp_shape), len(res_shape))
    for expected_dim, actual_dim in zip(exp_shape, res_shape):
        self.assertEqual(expected_dim, actual_dim)
def test_concat_value_infer(self, value1, value2, output_value, axis):
    """Check that concat_infer computes the output value from constant inputs.

    ``value1`` and ``value2`` are the constant input tensors. ``output_value``
    is the expected result of concatenating them along ``axis``. These
    arguments are presumably supplied by a parameterization decorator that is
    not shown here.
    """
    graph = build_graph(
        nodes_attributes,
        [('node_1', 'concat'),
         ('node_2', 'concat'),
         ('concat', 'node_3'),
         ('node_3', 'op_output')],
        {
            # Start the output node empty. Pre-seeding it with output_value
            # would let the assertion below pass even if concat_infer never
            # wrote a value.
            'node_3': {'shape': None, 'value': None},
            'node_1': {'shape': value1.shape, 'value': value1},
            'node_2': {'shape': value2.shape, 'value': value2},
            'concat': {'axis': axis},
        })
    concat_node = Node(graph, 'concat')
    concat_infer(concat_node)
    res_value = graph.node['node_3']['value']
    self.assertTrue(strict_compare_tensors(output_value, res_value))
def test_tf_concat_infer_no_shape(self):
    """With the second input's shape unknown, the output shape should stay None.

    NOTE(review): test_concat_infer_no_shape builds the same graph but
    expects concat_infer to raise an Error. Presumably the two tests target
    different concat_infer implementations; confirm which import each one
    uses.
    """
    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attrs = {
        'node_3': {'shape': None},
        'node_1': {'shape': np.array([1, 3, 227, 227])},
        'node_2': {'shape': None},
        'concat': {'axis': 2},
    }
    graph = build_graph(nodes_attributes, edges, attrs)
    concat_infer(Node(graph, 'concat'))
    self.assertIsNone(graph.node['node_3']['shape'])
def test_concat_infer_no_shape(self):
    """concat_infer must raise an Error when one input shape is undefined."""
    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attrs = {
        'node_3': {'shape': None},
        'node_1': {'shape': np.array([1, 3, 227, 227])},
        'node_2': {'shape': None},
        'concat': {'axis': 2},
    }
    graph = build_graph(nodes_attributes, edges, attrs)
    node = Node(graph, 'concat')
    # This is a regex matched with re.search: the trailing " *" only
    # quantifies the space, so it effectively matches the message prefix.
    expected_msg = "One of the input shapes is not defined for node *"
    with self.assertRaisesRegex(Error, expected_msg):
        concat_infer(node)
def test_concat_infer(self, shape1, shape2, output_shape, axis):
    """The shape inferred for the concat output must equal ``output_shape``.

    ``shape1``/``shape2`` are wrapped with ``shape_array`` before being
    attached to the two input nodes. The output node starts with no shape
    and no value.
    """
    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attrs = {
        'node_3': {'shape': None, 'value': None},
        'node_1': {'shape': shape_array(shape1)},
        'node_2': {'shape': shape_array(shape2)},
        'concat': {'axis': axis},
    }
    graph = build_graph(nodes_attributes, edges, attrs)
    concat_infer(Node(graph, 'concat'))
    inferred = graph.node['node_3']['shape']
    self.assertTrue(strict_compare_tensors(output_shape, inferred))
def test_concat_infer_not_match(self):
    """Inputs that differ on a non-concat axis must make concat_infer raise.

    The inputs differ in dim 1 (3 vs 2) while concatenating on axis 2.
    """
    edges = [
        ('node_1', 'concat'),
        ('node_2', 'concat'),
        ('concat', 'node_3'),
        ('node_3', 'op_output'),
    ]
    attrs = {
        'node_3': {'shape': None, 'value': None},
        'node_1': {'shape': np.array([1, 3, 227, 227])},
        'node_2': {'shape': np.array([1, 2, 227, 227])},
        'concat': {'axis': 2},
    }
    graph = build_graph(nodes_attributes, edges, attrs)
    node = Node(graph, 'concat')
    # The message is matched as a regex via re.search, so this acts as a
    # prefix check.
    expected_msg = "Concat input shapes do not match for node*"
    with self.assertRaisesRegex(Error, expected_msg):
        concat_infer(node)