def test_init(self):
    """Check that ``conv_infer_c.init`` hands back a usable library handle.

    The returned object must be exactly a ``ctypes.CDLL`` and must
    expose an attribute named ``qumico``.
    """
    handle = conv_infer_c.init(
        self.so_lib_path,
        self.input_info,
        self.output_info,
    )

    handle_type = type(handle)
    self.assertEqual(handle_type, ctypes.CDLL)

    exposes_entry_point = hasattr(handle, "qumico")
    self.assertTrue(exposes_entry_point)
def test_infer_c_no_output(self):
    """Check that ``infer_c`` rejects a missing output argument.

    Calling ``conv_infer_c.infer_c`` with ``output=None`` is expected
    to raise ``ctypes.ArgumentError``.
    """
    dll = conv_infer_c.init(
        so_lib_path=self.so_lib_path,
        input_info=self.input_info,
        output_info=self.output_info,
    )

    # Only the infer_c call sits inside the assertion scope, so a failure
    # while loading the library above is not mistaken for the expected error.
    with self.assertRaises(ctypes.ArgumentError):
        conv_infer_c.infer_c(dll=dll, input=self.dll_input, output=None)
def test_infer_c(self):
    """Run inference on every prepared sample and check the predicted labels.

    Each sample yielded by ``prepare_infer_dataset()`` is given a leading
    axis, cast to float32 and passed to ``conv_infer_c.infer_c`` together
    with a zeroed ``(1, 10)`` float32 output array.  That array is then run
    through ``common_tool.softmax`` and ``common_tool.onehot_decoding`` and
    the first decoded element is collected.
    """
    dll = conv_infer_c.init(self.so_lib_path, self.input_info, self.output_info)

    def predict(sample):
        # A fresh output array per sample; it is what gets decoded below
        # after the infer_c call.
        scores = numpy.zeros(shape=(1, 10), dtype=numpy.float32)
        batch = numpy.expand_dims(sample, 0).astype(numpy.float32)
        conv_infer_c.infer_c(dll, batch, scores)
        probabilities = common_tool.softmax(scores)
        labels = common_tool.onehot_decoding(probabilities)
        return labels[0]

    predictions = [predict(sample) for sample in prepare_infer_dataset()]

    # assertCountEqual compares the two sequences irrespective of order.
    expected = [7, 5, 1, 0, 4, 1, 4, 9, 2, 9]
    self.assertCountEqual(predictions, expected)
def test_init_no_so_lib_path_info(self):
    """Check that ``init`` fails when no shared-library path is given.

    Passing ``so_lib_path=None`` to ``conv_infer_c.init`` is expected to
    raise ``AttributeError``.
    """
    with self.assertRaises(AttributeError):
        conv_infer_c.init(
            so_lib_path=None,
            input_info=self.input_info,
            output_info=self.output_info,
        )