Exemplo n.º 1
0
def init_predictor(args):
    """Build a Paddle Inference predictor from command-line style arguments.

    Args:
        args: Object exposing ``model_dir``, ``model_file`` and ``params_file``
            attributes (e.g. an ``argparse.Namespace``). When ``model_dir`` is
            empty or ``None`` the separate model/params files are loaded;
            otherwise ``model_dir`` is passed to ``set_model`` on its own.

    Returns:
        The predictor returned by ``create_predictor`` for this config.
    """
    config = Config()
    # Truthiness test also routes a missing (None) model_dir to the
    # model/params-file branch instead of calling set_model(None).
    if not args.model_dir:
        config.set_model(args.model_file, args.params_file)
    else:
        config.set_model(args.model_dir)
    # Uncomment to silence Paddle's glog output:
    # config.disable_glog_info()
    # enable_use_gpu(memory_pool_init_size_mb, device_id): 1000 MB initial GPU
    # memory pool on device 3. NOTE(review): the device id is hard-coded.
    config.enable_use_gpu(1000, 3)
    return create_predictor(config)


def trash_detect(predictor, inputs):
    """Feed ``inputs`` to ``predictor``, run inference and return all outputs.

    Args:
        predictor: A predictor such as the one returned by ``init_predictor``.
        inputs: Sequence of numpy arrays, one per model input, in the order
            reported by ``predictor.get_input_names()``.

    Returns:
        List of numpy arrays, one per model output, in the order reported by
        ``predictor.get_output_names()``.

    Raises:
        ValueError: If the number of arrays differs from the number of model
            inputs (``zip`` would otherwise silently drop the extras).
    """
    input_names = predictor.get_input_names()
    if len(inputs) != len(input_names):
        raise ValueError(
            f"model expects {len(input_names)} inputs, got {len(inputs)}"
        )
    for name, array in zip(input_names, inputs):
        handle = predictor.get_input_handle(name)
        handle.reshape(array.shape)
        # copy() hands Paddle a fresh C-contiguous array even if `array`
        # is a non-contiguous view.
        handle.copy_from_cpu(array.copy())
    # Run prediction.
    predictor.run()
    # Collect every output tensor back into host (numpy) memory.
    return [
        predictor.get_output_handle(name).copy_to_cpu()
        for name in predictor.get_output_names()
    ]

if __name__ == '__main__':
    # Frames should come from your own ROS node; a local camera read stands
    # in for it here -- replace this capture with your ROS input.
    cap = cv2.VideoCapture(1)
    if not cap.isOpened():
        raise SystemExit("Cannot open camera 1")

    # CPU inference with oneDNN (MKL-DNN) acceleration.
    config = Config()
    config.set_model("inference_model/yolov3_mobilenet_v1_voc/__model__",
                     "inference_model/yolov3_mobilenet_v1_voc/__params__")
    config.disable_gpu()
    config.enable_mkldnn()
    predictor = create_predictor(config)

    im_size = 416
    # (1, 2) int32 size tensor fed to the model alongside the image.
    im_shape = np.array([[im_size, im_size]], dtype=np.int32)

    try:
        while True:
            success, img = cap.read()
            if not success:
                break
            img = cv2.resize(img, (im_size, im_size))
            data = preprocess(img, im_size)
            # Use the predictor built above (`trash_detector` was never defined).
            results = trash_detect(predictor, [data, im_shape])
            # NOTE(review): rows look like [class_id, score, x1, y1, x2, y2]
            # in resized-image pixels -- confirm against the model's output.
            for res in results[0]:
                img = cv2.rectangle(img, (int(res[2]), int(res[3])),
                                    (int(res[4]), int(res[5])), (255, 0, 0), 2)
            cv2.imshow("img", img)
            # imshow only repaints inside waitKey; 'q' also ends the loop.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Free the camera and close the display window even on errors.
        cap.release()
        cv2.destroyAllWindows()