Ejemplo n.º 1
0
class Lane_warning:
    """ROS node object that runs LaneNet on camera frames and publishes overlays.

    It subscribes to one camera image topic and two ``BoundingBoxes`` topics
    (YOLO detections and traffic-light detections). For every image it runs
    the LaneNet model, feeds the detected lanes to a ``LaneTracker``, draws
    the tracked left/right lanes and the most recently received boxes on the
    frame, and publishes the annotated frame plus the mask / binary /
    morphology images returned by the post-processor.
    """

    # Maximum number of segments drawn per lane. The drawing code was written
    # for lanes sampled at 12 points (11 consecutive segments).
    _LANE_SEGMENTS = 11

    def __init__(self):
        """Create publishers, load the LaneNet weights, then start subscribing.

        The weights path is read from the ROS parameter ``lanenet_weight``.
        """
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        self.maskimg_pub = rospy.Publisher("lanedetmask", Image, queue_size=1)
        self.binimg_pub = rospy.Publisher("lanedetbin", Image, queue_size=1)
        self.morphoimg_pub = rospy.Publisher("lanedetmorph",
                                             Image,
                                             queue_size=1)

        self.weights_file = rospy.get_param("lanenet_weight")
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warningModule = Detection()
        self.band_width = 1.5
        self.image_X = CFG.IMAGE_WIDTH
        self.image_Y = CFG.IMAGE_HEIGHT
        self.car_X = self.image_X / 2
        self.car_Y = self.image_Y
        self.model = LaneNet(pretrained=False,
                             embed_dim=4,
                             delta_v=.5,
                             delta_d=3.)
        # Only map the checkpoint onto the GPU when one is actually present,
        # consistent with the ``self.CUDA`` checks used everywhere else.
        self.save_dict = torch.load(
            self.weights_file,
            map_location='cuda:0' if self.CUDA else 'cpu')
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()
        # NOTE(review): np.ndarray(4,) is uninitialised memory; nothing in
        # this class reads it.
        self.lastlane = np.ndarray(4, )
        self.bridge = CvBridge()

        self.leftlane = Lane('left')
        self.rightlane = Lane('right')
        self.tracker = LaneTracker()

        # Placeholder frame until the first image arrives (rows = height,
        # columns = width, as for any OpenCV image).
        self.img = np.zeros([CFG.IMAGE_HEIGHT, CFG.IMAGE_WIDTH, 3], np.uint8)
        self.yoloBoxes = BoundingBoxes()
        self.trafficLightBoxes = BoundingBoxes()

        # The lane-departure warning call is currently disabled in process(),
        # so this flag stays 0 and lanes are always drawn in green.
        self.warning = 0

        # Subscribe last: rospy may invoke callbacks as soon as a subscriber
        # exists, and the callbacks need everything initialised above.
        self.yolo_sub = rospy.Subscriber("YOLO_detect_result_boxes",
                                         BoundingBoxes, self.callbackyolo)
        self.tlight_sub = rospy.Subscriber("tlight_detect_result_boxes",
                                           BoundingBoxes,
                                           self.callbackTrafficLight)
        self.image_sub = rospy.Subscriber("/wideangle/image_color", Image,
                                          self.callbackRos)

    @staticmethod
    def _valid_boxes(boxmsg):
        """Return a new list holding at most ``objNum`` boxes of ``boxmsg``."""
        return list(boxmsg.bounding_boxes[:max(boxmsg.objNum, 0)])

    def callbackyolo(self, boxmsg):
        """Store the boxes of the latest YOLO detection message.

        The list is built first and swapped in with a single assignment so a
        concurrently running image callback never sees a half-filled list.
        """
        print('callbackyolo, boxes len:', boxmsg.objNum)
        boxes = self._valid_boxes(boxmsg)
        for box in boxes:
            print('box id:', box.id)
        self.yoloBoxes.bounding_boxes = boxes
        self.yoloBoxes.objNum = len(boxes)

    def callbackTrafficLight(self, boxmsg):
        """Store the boxes of the latest traffic-light detection message."""
        print('callbackTrafficLight, boxes len:', boxmsg.objNum)
        boxes = self._valid_boxes(boxmsg)
        for box in boxes:
            print('box id:', box.id)
        self.trafficLightBoxes.bounding_boxes = boxes
        self.trafficLightBoxes.objNum = len(boxes)

    def transform_input(self, img):
        """Resize and normalise ``img`` and return it as a batched tensor.

        The image goes through ``Resize((512, 256))``, ``ToTensor`` and
        ``Normalize`` with the "IMAGENET" mean/std, then gets a leading batch
        dimension. The tensor is moved to the GPU only when CUDA is available.
        """
        _set = "IMAGENET"
        mean = IMG_MEAN[_set]
        std = IMG_STD[_set]
        transform_img = Resize((512, 256))
        transform_x = Compose(ToTensor(), Normalize(mean=mean, std=std))
        resized = transform_img({'img': img})['img']
        x = transform_x({'img': resized})['img']
        x.unsqueeze_(0)
        if self.CUDA:
            x = x.to('cuda')
        return x

    def detection(self, input):
        """Run the model on ``input`` without gradients and post-process it.

        Returns whatever :meth:`cluster` returns for the model output.
        """
        if self.CUDA:
            input = input.cuda()
        with torch.no_grad():
            output = self.model(input, None)
        return self.cluster(output)

    def cluster(self, output):
        """Convert the raw model output to numpy and run the post-processor.

        ``output`` must provide the keys ``'embedding'`` and ``'binary_seg'``.
        Only the first item of the batch is used.
        """
        embedding = output['embedding'].detach().cpu().numpy()
        # First batch item, with axis 0 moved to the end.
        embedding = np.transpose(embedding[0], (1, 2, 0))
        bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
        # Per-pixel argmax over axis 1, first batch item.
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]

        postprocess_result = self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred, instance_seg_result=embedding)
        return postprocess_result

    def _draw_lane(self, img, points, color, thickness=10):
        """Draw up to ``_LANE_SEGMENTS`` consecutive segments of ``points``."""
        for idx in range(min(self._LANE_SEGMENTS, len(points) - 1)):
            start = (int(points[idx][0]), int(points[idx][1]))
            end = (int(points[idx + 1][0]), int(points[idx + 1][1]))
            cv2.line(img, start, end, color, thickness)

    def process(self, frame):
        """Detect and track lanes in ``frame`` and publish the result images.

        Side effects: replaces ``self.img`` with an annotated copy of
        ``frame``, writes ``1debug.png`` to the working directory, and, when
        the post-processor produced a mask image, publishes the mask, binary,
        morphology and annotated images.
        """
        startt = time.time()
        cropImg = cropRoi(frame)
        input_image = self.transform_input(cropImg)
        postProcResult = self.detection(input_image)

        self.img = frame.copy()

        detected = postProcResult['detectedLanes']
        self.tracker.process(detected)
        llane = self.tracker.leftlane
        rlane = self.tracker.rightlane

        # Red while a warning is active, green otherwise.
        color = (0, 0, 255) if self.warning == 1 else (0, 255, 0)
        self._draw_lane(self.img, llane.points, color)
        self._draw_lane(self.img, rlane.points, color)

        # Debug image: every detected lane, coloured by whether its first x
        # coordinate matches the lane the tracker picked as left or right.
        debugImg = frame.copy()
        for lane in detected:
            if lane[0][0] == llane.detectedLane[0][0]:
                lane_color = (255, 0, 0)
            elif lane[0][0] == rlane.detectedLane[0][0]:
                lane_color = (255, 255, 0)
            else:
                lane_color = (0, 255, 0)

            for j in range(min(self._LANE_SEGMENTS, len(lane) - 1)):
                # cv2.line needs plain ints; cast like the tracked-lane path.
                x0, y0 = int(lane[j][0]), int(lane[j][1])
                x1, y1 = int(lane[j + 1][0]), int(lane[j + 1][1])
                cv2.line(debugImg, (x0, y0), (x1, y1), lane_color, 10)
                # Horizontal marker at the row of each sample point.
                cv2.line(debugImg, (0, y0), (int(self.image_X), y0),
                         (0, 0, 255), 3)
        cv2.imwrite('1debug.png', debugImg)

        if postProcResult['mask_image'] is None:
            print('cant find any lane!!!')
        else:
            self.maskimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['mask_image'],
                                          "bgr8"))
            self.binimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['binary_img'],
                                          "mono8"))
            self.morphoimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['morpho_img'],
                                          "mono8"))
            # NOTE(review): callbackRos publishes self.img again after the
            # boxes are drawn, so "lanedetframe" gets two messages per frame.
            self.image_pub.publish(self.bridge.cv2_to_imgmsg(self.img, "bgr8"))

        print('lanedet all use:', time.time() - startt)

    def drawYoloResult(self, data):
        """Draw every box in ``data`` with its id label onto ``self.img``.

        At most ``data.objNum`` boxes are drawn; slicing (instead of indexing
        by ``objNum``) keeps this safe if the list is shorter than announced.
        """
        for box in data.bounding_boxes[:max(data.objNum, 0)]:
            top_left = (int(box.xmin), int(box.ymin))
            bottom_right = (int(box.xmax), int(box.ymax))
            cv2.rectangle(self.img, top_left, bottom_right, (0, 255, 0), 3)
            # Filled background sized to the label text, then the label.
            t_size = cv2.getTextSize(box.id, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
            label_corner = (top_left[0] + t_size[0] + 3,
                            top_left[1] + t_size[1] + 4)
            cv2.rectangle(self.img, top_left, label_corner, (0, 255, 0), -1)
            cv2.putText(self.img, box.id, (top_left[0], top_left[1] + 10),
                        cv2.FONT_HERSHEY_PLAIN, 1, (225, 255, 255), 1)

    def callbackRos(self, data):
        """Image callback: detect lanes, overlay stored boxes, publish."""
        print('callbackros')

        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        self.img = cv_image.copy()

        self.process(cv_image)

        self.drawYoloResult(self.yoloBoxes)
        self.drawYoloResult(self.trafficLightBoxes)

        self.image_pub.publish(self.bridge.cv2_to_imgmsg(self.img, "bgr8"))
Ejemplo n.º 2
0
class Lane_warning:
    """LaneNet lane detector with a ROS image callback and an offline callback.

    :meth:`callbackRos` handles ROS ``Image`` messages and only prints the
    prediction. :meth:`callback` takes an already decoded image, runs
    detection plus the warning module, draws the result and shows it in an
    OpenCV window.
    """

    # Checkpoint used when no explicit path is passed to the constructor.
    DEFAULT_WEIGHTS_FILE = '/home/iairiv/code/lane_waring_final/experiments/exp1/exp1_best.pth'

    def __init__(self, weights_file=None):
        """Load the LaneNet checkpoint and start subscribing to camera images.

        Args:
            weights_file: Optional checkpoint path; defaults to
                ``DEFAULT_WEIGHTS_FILE``.
        """
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        self.weights_file = weights_file or self.DEFAULT_WEIGHTS_FILE
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warning = Detection()
        self.band_width = 1.5
        self.image_X = 1920
        self.image_Y = 1200
        self.car_X = self.image_X/2
        self.car_Y = self.image_Y
        self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
        # Only map the checkpoint onto the GPU when one is actually present,
        # consistent with the ``self.CUDA`` checks below.
        self.save_dict = torch.load(
            self.weights_file,
            map_location='cuda:0' if self.CUDA else 'cpu')
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()
        # NOTE(review): np.ndarray(4,) is uninitialised memory; nothing in
        # this class reads it.
        self.lastlane = np.ndarray(4,)
        self.bridge = CvBridge()

        # Subscribe last: rospy may invoke the callback as soon as the
        # subscriber exists, and it needs the model and bridge set up above.
        self.image_sub = rospy.Subscriber("/camera/rgb/image_raw", Image, self.callbackRos)

    def transform_input(self, img):
        """Return ``prep_image(img)``, the model input for ``img``."""
        return prep_image(img)

    def detection(self, input, raw):
        """Run the model on ``input`` without gradients and post-process it.

        ``raw`` is forwarded unchanged to :meth:`cluster`.
        """
        if self.CUDA:
            input = input.cuda()
        with torch.no_grad():
            output = self.model(input, None)
        return self.cluster(output, raw)

    def cluster(self, output, raw):
        """Post-process the raw model output and return it as a numpy array.

        ``output`` must provide the keys ``'embedding'`` and ``'binary_seg'``.
        Only the first item of the batch is used. ``raw`` is accepted for
        interface compatibility and not used.
        """
        # Module-level frame counter, defined elsewhere in this file.
        global i
        embedding = output['embedding'].detach().cpu().numpy()
        # First batch item, with axis 0 moved to the end.
        embedding = np.transpose(embedding[0], (1, 2, 0))
        bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
        # Per-pixel argmax over axis 1, first batch item.
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]

        # NOTE(review): debug dump on every call; this writes the raw argmax
        # result (not a 0/255 image) to the working directory.
        cv2.imwrite(str(i)+'a.png', bin_seg_pred)
        i = i + 1

        postprocess_result = self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred,
            instance_seg_result=embedding
        )
        return np.array(postprocess_result)

    @staticmethod
    def _draw_lanes(output, img, color):
        """Draw one straight line per lane, from its first to its last point."""
        for lane in output:
            pts = np.array(lane)
            if pts.size == 0:
                # Nothing to draw for an empty lane.
                continue
            start = (int(pts[0][0]), int(pts[0][1]))
            end = (int(pts[-1][0]), int(pts[-1][1]))
            cv2.line(img, start, end, color, 3)
        return img

    def write(self, output, img, signal, color):
        """Draw the lanes in ``color`` and a "WARNING" banner if ``signal`` is 1.

        Draws on ``img`` in place and returns it. ``output`` is not modified.
        """
        self._draw_lanes(output, img, color)
        if signal == 1:
            cv2.putText(img, "WARNING", (1300, 150), cv2.FONT_HERSHEY_SIMPLEX, 3.0, color, thickness=10)
        return img

    def write_nowarning(self, output, img):
        """Draw the lanes in red, without any warning banner.

        Draws on ``img`` in place and returns it. ``output`` is not modified.
        """
        return self._draw_lanes(output, img, (0, 0, 255))

    def color(self, signal):
        """Return green (BGR) when ``signal`` is 0, red for any other value."""
        if signal == 0:
            return (0, 255, 0)
        return (0, 0, 255)

    # ROS code path; it has not been tested yet. Without ROS, use callback()
    # for testing instead.
    def callbackRos(self, data):
        """ROS image callback: run detection and print a non-empty prediction.

        Drawing and publishing of the result are currently disabled.
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            print(e)
            return

        input_image = self.transform_input(cv_image)
        prediction = self.detection(input_image, cv_image)
        if len(prediction) != 0:
            print(prediction)

    def callback(self, data):
        """Offline entry point: detect, warn, draw and display one image.

        ``data`` is an already decoded image. Prints the time spent in each
        stage and blocks in ``cv2.waitKey(0)`` until a key is pressed.
        """
        # Pre-processing.
        time3 = time.time()
        input_image = self.transform_input(data)
        time4 = time.time()-time3
        print('数据预处理时间:',time4)

        # Lane detection.
        time5 = time.time()
        prediction = self.detection(input_image, data)
        time6 = time.time()-time5
        print('检测时间:', time6)

        # Warning decision.
        time7 = time.time()
        signal = self.warning.detect(prediction)
        color = self.color(signal)
        time8 = time.time()-time7
        print('预警时间:',time8)

        # Drawing.
        time1 = time.time()
        img = self.write(prediction, data, signal, color)
        time2 = time.time()-time1
        print('画图时间:',time2)

        cv2.imshow("final",img)
        cv2.waitKey(0)