예제 #1
0
    def test_nms_infer_opset1(self):
        """Check shape and type inference of an opset1 NonMaxSuppression node.

        The test does not set ``output_type`` on the node; output port 0 is
        expected to have the static shape ``[100, 3]`` and dtype int64.
        """
        nms_node = Node(self.graph, 'nms')
        nms_node['version'] = 'opset1'
        NonMaxSuppression.infer(nms_node)
        NonMaxSuppression.type_infer(nms_node)

        out_shape = nms_node.out_port(0).data.get_shape()
        self.assertTrue(np.array_equal(out_shape, [100, 3]),
                        'unexpected output shape: {}'.format(out_shape))
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(nms_node.out_port(0).get_data_type(), np.int64)
예제 #2
0
    def test_nms_infer_i64_opset5_1_out(self):
        """Check shape and type inference of an opset5 NonMaxSuppression node.

        With ``output_type`` set to int64, output port 0 is expected to have a
        dynamic first dimension and a static second dimension of 3, and dtype
        int64.
        """
        nms_node = Node(self.graph, 'nms')
        nms_node['version'] = 'opset5'
        nms_node['output_type'] = np.int64
        NonMaxSuppression.infer(nms_node)
        NonMaxSuppression.type_infer(nms_node)

        out_shape = nms_node.out_port(0).data.get_shape()
        expected_shape = shape_array([dynamic_dimension_value, 3])
        self.assertTrue(np.array_equal(out_shape, expected_shape),
                        'unexpected output shape: {}'.format(out_shape))
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(nms_node.out_port(0).get_data_type(), np.int64)

    def test_nms_infer_v10_opset1(self):
        """Check opset1 NonMaxSuppression inference with IR v10 command-line params.

        The graph's ``cmd_params`` are replaced with a ``FakeAttr`` that enables
        ``generate_experimental_IR_V10`` and sets ``ir_version=10``. Output port 0
        is expected to have the static shape ``[100, 3]`` and dtype int64.
        """
        self.graph.graph['cmd_params'] = FakeAttr(generate_experimental_IR_V10=True, ir_version=10)

        nms_node = Node(self.graph, 'nms')
        nms_node['version'] = 'opset1'
        NonMaxSuppression.infer(nms_node)
        NonMaxSuppression.type_infer(nms_node)

        out_shape = nms_node.out_port(0).data.get_shape()
        self.assertTrue(np.array_equal(out_shape, [100, 3]),
                        'unexpected output shape: {}'.format(out_shape))
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(nms_node.out_port(0).get_data_type(), np.int64)
예제 #4
0
    def test_nms_infer_v10_i32_opset3(self):
        """Check opset3 NonMaxSuppression inference with an int32 ``output_type``.

        The graph's ``cmd_params`` are replaced with a ``FakeAttr`` that sets
        ``ir_version=10``. With ``output_type`` set to int32, output port 0 is
        expected to have the static shape ``[100, 3]`` and dtype int32.
        """
        self.graph.graph['cmd_params'] = FakeAttr(ir_version=10)

        nms_node = Node(self.graph, 'nms')
        nms_node['version'] = 'opset3'
        nms_node['output_type'] = np.int32
        NonMaxSuppression.infer(nms_node)
        NonMaxSuppression.type_infer(nms_node)

        out_shape = nms_node.out_port(0).data.get_shape()
        self.assertTrue(np.array_equal(out_shape, [100, 3]),
                        'unexpected output shape: {}'.format(out_shape))
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(nms_node.out_port(0).get_data_type(), np.int32)