def test_init(self):
    """``vgg16_infer_c.init`` returns a ``ctypes.CDLL`` exposing ``qumico``."""
    dll = vgg16_infer_c.init(
        so_lib_path=self.so_lib_path,
        input_info=self.input_info,
        output_info=self.output_info,
    )
    # assertIsInstance is the idiomatic type check and reports the actual
    # type of ``dll`` when it fails.
    self.assertIsInstance(dll, ctypes.CDLL)
    self.assertTrue(hasattr(dll, "qumico"))
def test_vgg16_infer_c(self):
    """Run the compiled VGG16 model on one image per class and check accuracy.

    For every class name in ``self.classes_test``, the image
    ``input/<class name>.jpg`` (located next to this file) is loaded as a
    224x224 RGB image and passed to ``dll.qumico``. The image counts as
    correct when the index of the largest output value selects that same
    class name from ``self.classes_test``. At least 70% of the images must
    be classified correctly.
    """
    # Load the shared library once. It is identical for every image.
    dll = vgg16_infer_c.init(self.so_lib_path, self.input_info, self.output_info)
    input_dir = os.path.join(os.path.dirname(__file__), "input")

    count_correct = 0
    for class_name in self.classes_test:
        img_file = os.path.join(input_dir, class_name + ".jpg")
        # ``grayscale=False`` is the default (and deprecated/removed in newer
        # Keras); ``color_mode='rgb'`` alone expresses the intent.
        img = load_img(img_file, color_mode='rgb', target_size=(224, 224))
        # Prepend a batch axis of size 1 for the single-image inference call.
        input_data = numpy.expand_dims(img_to_array(img), axis=0)
        output = numpy.zeros(dtype=numpy.float32, shape=(1, 10))
        dll.qumico(input_data, output)

        # NOTE(review): assumes self.classes_test is ordered like the model's
        # output indices — confirm against the training label order.
        predicted_index = int(numpy.argmax(output[0]))
        # Exact comparison: ``in`` between two strings is a substring test
        # and would accept e.g. "cat" when "wildcat" was predicted.
        if self.classes_test[predicted_index] == class_name:
            count_correct += 1

    accuracy = count_correct / len(self.classes_test)
    self.assertGreaterEqual(accuracy, 0.7)
def test_init_no_so_lib_path_info(self):
    """``vgg16_infer_c.init`` raises AttributeError when ``so_lib_path`` is None."""
    with self.assertRaises(AttributeError):
        vgg16_infer_c.init(
            so_lib_path=None,
            input_info=self.input_info,
            output_info=self.output_info,
        )