class Lane_warning:
    """ROS node that runs LaneNet lane detection and overlays detection boxes.

    Subscribes to ``/wideangle/image_color`` (camera frames) and to two
    ``BoundingBoxes`` topics (``YOLO_detect_result_boxes`` and
    ``tlight_detect_result_boxes``). Every frame is cropped, run through
    LaneNet + ``LaneNetPostProcessor`` + ``LaneTracker``; the tracked left/right
    lanes and the latest boxes are drawn and the result is published on
    ``lanedetframe``. Post-processing debug images are published on
    ``lanedetmask`` / ``lanedetbin`` / ``lanedetmorph``.
    """

    # Number of consecutive segments drawn per lane (the original code drew
    # points 0..11, i.e. 11 segments).
    NUM_LANE_SEGMENTS = 11

    def __init__(self):
        # Output topics: annotated frame plus post-processing debug images.
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        self.maskimg_pub = rospy.Publisher("lanedetmask", Image, queue_size=1)
        self.binimg_pub = rospy.Publisher("lanedetbin", Image, queue_size=1)
        self.morphoimg_pub = rospy.Publisher("lanedetmorph", Image, queue_size=1)

        self.weights_file = rospy.get_param("lanenet_weight")
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warningModule = Detection()
        self.band_width = 1.5
        self.image_X = CFG.IMAGE_WIDTH
        self.image_Y = CFG.IMAGE_HEIGHT
        self.car_X = self.image_X / 2
        self.car_Y = self.image_Y

        self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
        # Only map the checkpoint onto the GPU when one exists; a hard-wired
        # 'cuda:0' makes torch.load fail on CPU-only hosts.
        map_location = 'cuda:0' if self.CUDA else 'cpu'
        self.save_dict = torch.load(self.weights_file, map_location=map_location)
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()

        self.lastlane = np.zeros(4)  # was np.ndarray(4,): uninitialized memory
        self.bridge = CvBridge()
        self.leftlane = Lane('left')
        self.rightlane = Lane('right')
        self.tracker = LaneTracker()
        # Placeholder frame until the first image arrives; numpy images are
        # (rows=height, cols=width, channels).
        self.img = np.zeros((CFG.IMAGE_HEIGHT, CFG.IMAGE_WIDTH, 3), np.uint8)
        self.yoloBoxes = BoundingBoxes()
        self.trafficLightBoxes = BoundingBoxes()
        # Lane-departure flag; nothing in this class currently updates it.
        self.warning = 0

        # Subscribe LAST: rospy may invoke callbacks on other threads as soon
        # as a subscriber exists, and they read the state initialized above.
        self.yolo_sub = rospy.Subscriber(
            "YOLO_detect_result_boxes", BoundingBoxes, self.callbackyolo)
        self.tlight_sub = rospy.Subscriber(
            "tlight_detect_result_boxes", BoundingBoxes, self.callbackTrafficLight)
        self.image_sub = rospy.Subscriber(
            "/wideangle/image_color", Image, self.callbackRos)

    @staticmethod
    def _copy_boxes(target, boxmsg):
        """Copy the first ``boxmsg.objNum`` boxes of *boxmsg* into *target*.

        The list is built before it is assigned, and ``target.objNum`` is set
        to the number of boxes actually copied, so the count never exceeds the
        list length even when the message's count and list disagree.
        """
        boxes = list(boxmsg.bounding_boxes[:max(boxmsg.objNum, 0)])
        for box in boxes:
            print('box id:', box.id)
        target.bounding_boxes = boxes
        target.objNum = len(boxes)

    def callbackyolo(self, boxmsg):
        """Store the latest YOLO object boxes for drawing on the next frame."""
        print('callbackyolo, boxes len:', boxmsg.objNum)
        self._copy_boxes(self.yoloBoxes, boxmsg)

    def callbackTrafficLight(self, boxmsg):
        """Store the latest traffic-light boxes for drawing on the next frame."""
        print('callbackTrafficLight, boxes len:', boxmsg.objNum)
        self._copy_boxes(self.trafficLightBoxes, boxmsg)

    def transform_input(self, img):
        """Resize and ImageNet-normalize *img* into a batched tensor for LaneNet.

        Returns the tensor with a leading batch dimension of 1 added by
        ``unsqueeze_(0)``; it is moved to the GPU only when CUDA is available.
        NOTE(review): the (512, 256) size order depends on the project's
        ``Resize`` transform — confirm whether it is (width, height).
        """
        _set = "IMAGENET"
        mean = IMG_MEAN[_set]
        std = IMG_STD[_set]
        transform_img = Resize((512, 256))
        transform_x = Compose(ToTensor(), Normalize(mean=mean, std=std))
        img = transform_img({'img': img})['img']
        x = transform_x({'img': img})['img']
        x.unsqueeze_(0)
        if self.CUDA:
            x = x.to('cuda')
        return x

    def detection(self, input):
        """Run LaneNet on the prepared *input* tensor and post-process the output."""
        tensor = input.cuda() if self.CUDA else input
        with torch.no_grad():
            output = self.model(tensor, None)
        return self.cluster(output)

    def cluster(self, output):
        """Convert raw LaneNet outputs to numpy and run the lane post-processor.

        ``output['embedding']`` is indexed at batch 0 and transposed to
        channel-last; ``output['binary_seg']`` is arg-maxed over axis 1
        (presumably the class axis of an (N, C, H, W) tensor — confirm).
        Returns the post-processor result; ``process`` reads its
        'detectedLanes', 'mask_image', 'binary_img' and 'morpho_img' keys.
        """
        embedding = output['embedding'].detach().cpu().numpy()
        embedding = np.transpose(embedding[0], (1, 2, 0))
        bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]
        return self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred, instance_seg_result=embedding)

    @staticmethod
    def _polyline_segments(points, max_segments=None):
        """Return consecutive ((x0, y0), (x1, y1)) integer pixel pairs along *points*.

        At most *max_segments* segments are returned (all of them when None);
        fewer than two points yields an empty list. Coordinates are truncated
        with ``int`` because OpenCV drawing calls require integer points.
        """
        pts = list(points)
        if max_segments is not None:
            pts = pts[:max_segments + 1]
        pixels = [(int(p[0]), int(p[1])) for p in pts]
        return list(zip(pixels, pixels[1:]))

    def process(self, frame):
        """Detect and track lanes in *frame* and draw them into ``self.img``.

        Also writes a debug overlay of every raw detection to ``1debug.png``
        and, when the post-processor produced a mask, publishes the mask,
        binary and morphology images. The annotated frame itself is published
        once by ``callbackRos`` after the bounding boxes are drawn (it used to
        be published here as well, producing two messages per frame).
        """
        startt = time.time()
        input_image = self.transform_input(cropRoi(frame))
        postProcResult = self.detection(input_image)
        self.img = frame.copy()

        self.tracker.process(postProcResult['detectedLanes'])
        llane = self.tracker.leftlane
        rlane = self.tracker.rightlane

        # Red when the warning flag is set, green otherwise.
        color = (0, 0, 255) if self.warning == 1 else (0, 255, 0)
        for lane_points in (llane.points, rlane.points):
            for start, end in self._polyline_segments(lane_points, self.NUM_LANE_SEGMENTS):
                cv2.line(self.img, start, end, color, 10)

        # Debug overlay: tracked-left lane blue, tracked-right cyan, others green,
        # plus a red horizontal guide at each segment's start row.
        debugImg = frame.copy()
        for lane in postProcResult['detectedLanes']:
            if lane[0][0] == llane.detectedLane[0][0]:
                lane_color = (255, 0, 0)
            elif lane[0][0] == rlane.detectedLane[0][0]:
                lane_color = (255, 255, 0)
            else:
                lane_color = (0, 255, 0)
            for (x0, y0), (x1, y1) in self._polyline_segments(lane, self.NUM_LANE_SEGMENTS):
                cv2.line(debugImg, (x0, y0), (x1, y1), lane_color, 10)
                cv2.line(debugImg, (0, y0), (int(self.image_X), y0), (0, 0, 255), 3)
        # NOTE(review): writes a file on every frame; debugging aid only.
        cv2.imwrite(str(1) + 'debug.png', debugImg)

        if postProcResult['mask_image'] is None:
            print('cant find any lane!!!')
        else:
            self.maskimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['mask_image'], "bgr8"))
            self.binimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['binary_img'], "mono8"))
            self.morphoimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['morpho_img'], "mono8"))
        print('lanedet all use:', time.time() - startt)

    def drawYoloResult(self, data):
        """Draw every box in *data* onto ``self.img`` with a filled label tab.

        Iterates over at most ``data.objNum`` boxes, bounded by the list length,
        so a stale count cannot cause an IndexError.
        """
        for box in data.bounding_boxes[:max(data.objNum, 0)]:
            x_min, y_min = int(box.xmin), int(box.ymin)
            cv2.rectangle(self.img, (x_min, y_min), (int(box.xmax), int(box.ymax)), (0, 255, 0), 3)
            t_size = cv2.getTextSize(box.id, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
            c2 = (x_min + t_size[0] + 3, y_min + t_size[1] + 4)
            cv2.rectangle(self.img, (x_min, y_min), c2, (0, 255, 0), -1)
            cv2.putText(self.img, box.id, (x_min, y_min + 10),
                        cv2.FONT_HERSHEY_PLAIN, 1, (225, 255, 255), 1)

    def callbackRos(self, data):
        """Handle one camera frame: detect lanes, overlay boxes, publish once."""
        print('callbackros')
        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        self.process(cv_image)  # sets self.img to the lane-annotated frame
        self.drawYoloResult(self.yoloBoxes)
        self.drawYoloResult(self.trafficLightBoxes)
        self.image_pub.publish(self.bridge.cv2_to_imgmsg(self.img, "bgr8"))
class Lane_warning:
    """Simpler LaneNet lane-detection node (no lane tracking, no box overlay).

    ``callbackRos`` handles frames from ``/camera/rgb/image_raw`` (marked by the
    author as not yet tested under ROS). ``callback`` is the offline entry
    point: it takes an image array directly, runs detection and the
    ``Detection`` warning module, draws the lanes and shows them in an OpenCV
    window.

    NOTE(review): this module also defines an earlier ``Lane_warning`` class;
    this later definition replaces it at import time. Confirm which is intended.
    """

    def __init__(self):
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        # NOTE(review): hard-coded, machine-specific weights path.
        self.weights_file = '/home/iairiv/code/lane_waring_final/experiments/exp1/exp1_best.pth'
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warning = Detection()
        self.band_width = 1.5
        self.image_X = 1920
        self.image_Y = 1200
        self.car_X = self.image_X / 2
        self.car_Y = self.image_Y

        self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
        # Only map the checkpoint onto the GPU when one exists; a hard-wired
        # 'cuda:0' makes torch.load fail on CPU-only hosts.
        map_location = 'cuda:0' if self.CUDA else 'cpu'
        self.save_dict = torch.load(self.weights_file, map_location=map_location)
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()

        self.lastlane = np.zeros(4)  # was np.ndarray(4,): uninitialized memory
        self.bridge = CvBridge()
        # Index used to name the per-frame debug mask dumps written by cluster()
        # (replaces an undeclared module-level global).
        self.frame_idx = 0

        # Subscribe LAST: rospy may invoke callbackRos on another thread as soon
        # as the subscriber exists, and it needs the model loaded above.
        self.image_sub = rospy.Subscriber("/camera/rgb/image_raw", Image, self.callbackRos)

    def transform_input(self, img):
        """Prepare *img* for the network via the project's ``prep_image``."""
        return prep_image(img)

    def detection(self, input, raw):
        """Run LaneNet on *input* and post-process; *raw* is forwarded to ``cluster``."""
        tensor = input.cuda() if self.CUDA else input
        with torch.no_grad():
            output = self.model(tensor, None)
        return self.cluster(output, raw)

    def cluster(self, output, raw):
        """Convert LaneNet outputs to numpy, dump the binary mask, post-process.

        ``raw`` is accepted for interface compatibility but unused. The binary
        mask of each frame is written to ``<frame_idx>a.png`` for debugging.
        Returns the post-processor result wrapped in ``np.array``.
        """
        embedding = output['embedding'].detach().cpu().numpy()
        embedding = np.transpose(embedding[0], (1, 2, 0))
        bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]

        # argmax yields int64 class indices (0/1), which cv2.imwrite cannot
        # encode and which would look black anyway; dump a 0/255 uint8 mask.
        cv2.imwrite(str(self.frame_idx) + 'a.png', (bin_seg_pred * 255).astype(np.uint8))
        self.frame_idx += 1

        postprocess_result = self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred, instance_seg_result=embedding)
        return np.array(postprocess_result)

    @staticmethod
    def _lane_endpoints(output):
        """Return ((x0, y0), (x1, y1)) integer endpoints (first and last point) per lane.

        Lanes with no points are skipped. *output* is not modified.
        """
        endpoints = []
        for lane in output:
            pts = np.asarray(lane)
            if len(pts) == 0:
                continue
            first, last = pts[0], pts[-1]
            endpoints.append(((int(first[0]), int(first[1])), (int(last[0]), int(last[1]))))
        return endpoints

    def write(self, output, img, signal, color):
        """Draw each lane as a straight line in *color*; add a banner when *signal* == 1.

        Restored from the author's commented-out "with warning" version, which
        ``callback`` calls. Draws into *img* in place and returns it.
        """
        for start, end in self._lane_endpoints(output):
            cv2.line(img, start, end, color, 3)
        if signal == 1:
            cv2.putText(img, "WARNING", (1300, 150), cv2.FONT_HERSHEY_SIMPLEX, 3.0, color, thickness=10)
        return img

    def write_nowarning(self, output, img):
        """Draw each lane as a red straight line (first to last point) into *img*."""
        for start, end in self._lane_endpoints(output):
            cv2.line(img, start, end, (0, 0, 255), 3)
        return img

    def color(self, signal):
        """Return BGR green for ``signal == 0`` and red for anything else."""
        if signal == 0:
            color = (0, 255, 0)
        else:
            color = (0, 0, 255)
        return color

    def callbackRos(self, data):
        """ROS image callback (untested under ROS per the original author).

        Runs detection and prints a non-empty prediction. Warning, drawing and
        publishing are not wired up yet (TODO).
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            print(e)
            return
        input_image = self.transform_input(cv_image)
        prediction = self.detection(input_image, cv_image)
        if len(prediction) != 0:
            print(prediction)

    def callback(self, data):
        """Offline (non-ROS) entry point: detect, warn, draw, display, and time each stage.

        *data* is an image array (presumably BGR as loaded by OpenCV — confirm).
        Blocks on ``cv2.waitKey(0)`` until a key is pressed.
        """
        # preprocessing
        t_start = time.time()
        input_image = self.transform_input(data)
        print('数据预处理时间:', time.time() - t_start)

        # lane detection
        t_start = time.time()
        prediction = self.detection(input_image, data)
        print('检测时间:', time.time() - t_start)

        # lane-departure warning
        t_start = time.time()
        signal = self.warning.detect(prediction)
        color = self.color(signal)
        print('预警时间:', time.time() - t_start)

        # drawing
        t_start = time.time()
        img = self.write(prediction, data, signal, color)
        print('画图时间:', time.time() - t_start)

        cv2.imshow("final", img)
        cv2.waitKey(0)