def test_simplernms_infer_no_shape(self):
    """Check that shape inference leaves the output shape unset for this config.

    Builds a one-edge graph (SimplerNMS_1 -> node_1) with ``feat_stride``
    set to 12, runs ``SimplerNMSOp.simplernms_infer`` and expects the output
    node's ``shape`` attribute to still be ``None`` afterwards.
    """
    # NOTE(review): presumably feat_stride=12 is what makes inference bail
    # out (the sibling "ideal" test uses 16) -- confirm against SimplerNMSOp.
    attr_overrides = {
        'node_1': {'is_output': True, 'shape': None},
        'SimplerNMS_1': {
            'feat_stride': 12,
            'post_nms_topn': 150,
            'scale': [1, 2, 3],
        },
    }
    graph = build_graph(
        nodes_attributes,
        [('SimplerNMS_1', 'node_1')],
        attr_overrides,
    )

    SimplerNMSOp.simplernms_infer(Node(graph, 'SimplerNMS_1'))

    output_shape = graph.node['node_1']['shape']
    self.assertIsNone(output_shape)
def test_simplernms_infer_ideal(self):
    """Check shape inference and scale conversion for a valid configuration.

    Builds a one-edge graph (SimplerNMS_1 -> node_1) with ``feat_stride``
    set to 16 and ``post_nms_topn`` set to 150, runs
    ``SimplerNMSOp.simplernms_infer`` and expects:

    * the output node's shape to be exactly ``[150, 5]``;
    * the op node's ``scale`` attribute to hold the string form of each
      original value (``['1', '2', '3']``).
    """
    graph = build_graph(
        nodes_attributes,
        [('SimplerNMS_1', 'node_1')],
        {
            'node_1': {'is_output': True, 'shape': None},
            'SimplerNMS_1': {
                'feat_stride': 16,
                'post_nms_topn': 150,
                'scale': [1, 2, 3],
            },
        },
    )
    simplernms_node = Node(graph, 'SimplerNMS_1')

    SimplerNMSOp.simplernms_infer(simplernms_node)

    exp_shape = np.array([150, 5])
    res_shape = graph.node['node_1']['shape']

    # Fail with a clear message (rather than a TypeError) if inference
    # did not populate the shape at all.
    self.assertIsNotNone(res_shape)
    # Compare ranks first: checking only the first len(exp_shape) entries
    # would silently accept a result with extra trailing dimensions.
    self.assertEqual(len(exp_shape), len(res_shape))
    for expected_dim, actual_dim in zip(exp_shape, res_shape):
        self.assertEqual(expected_dim, actual_dim)

    self.assertEqual(simplernms_node.scale, ['1', '2', '3'])