def init_predictor(args, gpu_mem_mb=1000, gpu_id=3):
    """Build a Paddle Inference predictor configured to run on a GPU.

    Args:
        args: Namespace-like object exposing ``model_dir``, ``model_file`` and
            ``params_file``. When ``model_dir`` is empty (or None), the model
            is loaded from the ``model_file``/``params_file`` pair; otherwise
            it is loaded from ``model_dir``.
        gpu_mem_mb: Initial GPU memory pool size in MB, passed as the first
            argument of ``Config.enable_use_gpu``. Defaults to 1000.
        gpu_id: GPU device id, passed as the second argument of
            ``Config.enable_use_gpu``. Defaults to 3; lower it on machines
            with fewer GPUs.

    Returns:
        The predictor returned by ``create_predictor(config)``.
    """
    config = Config()
    # Treat None like "" so an option declared without a default still
    # falls back to the separate model/params files.
    if not args.model_dir:
        config.set_model(args.model_file, args.params_file)
    else:
        config.set_model(args.model_dir)
    config.enable_use_gpu(gpu_mem_mb, gpu_id)
    return create_predictor(config)
#开始预测 predictor.run() #开始看结果 results =[] output_names = predictor.get_output_names() for i, name in enumerate(output_names): output_tensor = predictor.get_output_handle(name) output_data = output_tensor.copy_to_cpu() results.append(output_data) return results if __name__ == '__main__': #读入的摄像头信息根据大家自己ROS结点自己读入咯,这里我就直接用摄像头读取替代了,大家到时候这里自己更换进行 cap = cv2.VideoCapture(1) config = Config() config.set_model("inference_model/yolov3_mobilenet_v1_voc/__model__","inference_model/yolov3_mobilenet_v1_voc/__params__") config.disable_gpu() config.enable_mkldnn() predictor = create_predictor(config) im_size = 416 im_shape = np.array([416, 416]).reshape((1, 2)).astype(np.int32) while(1): success, img = cap.read() if (success == False): break img = cv2.resize(img, (im_size,im_size),0, 0) data = preprocess(img, im_size) results = trash_detect(trash_detector, [data, im_shape]) for res in results[0]: img = cv2.rectangle(img, (int(res[2]), int(res[3])), (int(res[4]), int(res[5])), (255, 0, 0), 2) cv2.imshow("img", img)