示例#1
0
    def __init__(self):
        """Create ROS publishers, load the LaneNet model, initialise per-frame
        state and finally subscribe to the input topics.

        Subscribers are registered last: rospy can start invoking a
        subscriber's callback as soon as it is constructed, and the callbacks
        (``callbackyolo``, ``callbackTrafficLight``, ``callbackRos``) read the
        model, bridge, box buffers and trackers assigned below.
        """
        # ---- publishers (nothing reads these back, safe to create first) ----
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        self.maskimg_pub = rospy.Publisher("lanedetmask", Image, queue_size=1)
        self.binimg_pub = rospy.Publisher("lanedetbin", Image, queue_size=1)
        self.morphoimg_pub = rospy.Publisher("lanedetmorph",
                                             Image,
                                             queue_size=1)

        # ---- model and post-processing ----
        self.weights_file = rospy.get_param("lanenet_weight")
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warningModule = Detection()
        self.band_width = 1.5
        self.image_X = CFG.IMAGE_WIDTH
        self.image_Y = CFG.IMAGE_HEIGHT
        self.car_X = self.image_X / 2
        self.car_Y = self.image_Y
        self.model = LaneNet(pretrained=False,
                             embed_dim=4,
                             delta_v=.5,
                             delta_d=3.)
        # Only map the checkpoint onto the GPU when one exists; a hard-coded
        # 'cuda:0' makes torch.load fail on CPU-only hosts.
        map_location = 'cuda:0' if self.CUDA else 'cpu'
        self.save_dict = torch.load(self.weights_file,
                                    map_location=map_location)
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()
        # np.ndarray(4) would hand back uninitialised memory; start at zeros.
        self.lastlane = np.zeros(4)
        self.bridge = CvBridge()

        # ---- lane tracking state ----
        self.leftlane = Lane('left')
        self.rightlane = Lane('right')
        self.tracker = LaneTracker()

        # numpy images are (rows, cols, channels), i.e. (height, width, 3).
        self.img = np.zeros([CFG.IMAGE_HEIGHT, CFG.IMAGE_WIDTH, 3], np.uint8)
        self.yoloBoxes = BoundingBoxes()
        self.trafficLightBoxes = BoundingBoxes()

        self.warning = 0

        # ---- subscribers (last, see docstring) ----
        self.yolo_sub = rospy.Subscriber("YOLO_detect_result_boxes",
                                         BoundingBoxes, self.callbackyolo)
        self.tlight_sub = rospy.Subscriber("tlight_detect_result_boxes",
                                           BoundingBoxes,
                                           self.callbackTrafficLight)
        self.image_sub = rospy.Subscriber("/wideangle/image_color", Image,
                                          self.callbackRos)
示例#2
0
 def _load_model(self):
     """Build a LaneNet, load weights from ``self.model_path`` and return it on the GPU.

     When ``self.mode == 'parallel'`` the network is wrapped in DataParallel
     before the weights are loaded.

     The checkpoint is deserialised onto the CPU first. Without a
     ``map_location``, torch.load restores tensors onto the device they were
     saved from, which fails if that GPU index is not present on this host.
     NOTE(review): in 'parallel' mode the state dict must carry
     DataParallel's ``module.`` key prefix. Confirm the checkpoint was saved
     from a wrapped model.
     """
     model = LaneNet()
     if self.mode == 'parallel':
         model = DataParallel(model)
     state_dict = torch.load(self.model_path, map_location='cpu')
     model.load_state_dict(state_dict)
     return model.cuda()
示例#3
0
 def __init__(self):
     """Create the output publisher, load LaneNet weights and subscribe to the camera.

     The image subscriber is registered last. rospy can invoke
     ``callbackRos`` as soon as the Subscriber exists, and that callback
     needs the model, post-processor and CvBridge initialised here.
     """
     self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
     # NOTE(review): hard-coded absolute path; consider a ROS parameter.
     self.weights_file = '/home/iairiv/code/lane_waring_final/experiments/exp1/exp1_best.pth'
     self.CUDA = torch.cuda.is_available()
     self.postprocessor = LaneNetPostProcessor()
     self.warning = Detection()
     self.band_width = 1.5
     self.image_X = 1920
     self.image_Y = 1200
     self.car_X = self.image_X / 2
     self.car_Y = self.image_Y
     self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
     # Only target the GPU when one is available; an unconditional 'cuda:0'
     # makes torch.load fail on CPU-only hosts.
     map_location = 'cuda:0' if self.CUDA else 'cpu'
     self.save_dict = torch.load(self.weights_file, map_location=map_location)
     self.model.load_state_dict(self.save_dict['net'])
     if self.CUDA:
         self.model.cuda()
     self.model.set_test()
     # np.ndarray(4,) is uninitialised memory; start from zeros instead.
     self.lastlane = np.zeros(4)
     self.bridge = CvBridge()
     # Subscribe only once everything callbackRos relies on exists.
     self.image_sub = rospy.Subscriber("/camera/rgb/image_raw", Image, self.callbackRos)
示例#4
0
def main():
    """Run LaneNet on one image and write a colour-coded lane overlay.

    Command-line arguments come from ``parse_args()``: ``img_path``,
    ``weight_path``, ``band_width`` and ``visualize``. The overlay is written
    to ``demo/demo_result.jpg``. It is shown in a window when ``visualize``
    is set.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    args = parse_args()
    img_path = args.img_path
    weight_path = args.weight_path

    _set = "IMAGENET"
    mean = IMG_MEAN[_set]
    std = IMG_STD[_set]
    transform_img = Resize((800, 288))
    transform_x = Compose(ToTensor(), Normalize(mean=mean, std=std))

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("could not read image: {}".format(img_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB for net model input
    img = transform_img({'img': img})['img']
    x = transform_x({'img': img})['img']
    x.unsqueeze_(0)  # add a leading batch dimension

    net = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
    save_dict = torch.load(weight_path, map_location='cpu')
    net.load_state_dict(save_dict['net'])
    net.eval()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = net(x)
    embedding = output['embedding'].detach().cpu().numpy()
    # Take batch item 0 and move its first axis last (presumably C,H,W -> H,W,C).
    embedding = np.transpose(embedding[0], (1, 2, 0))
    bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
    bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]

    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    seg_img = np.zeros_like(img)
    lane_seg_img = embedding_post_process(embedding, bin_seg_pred,
                                          args.band_width, 4)
    color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]],
                     dtype='uint8')
    # Label 0 is skipped as background. Enumerate only the remaining labels
    # so colours are assigned in order even when label 0 is absent (the old
    # ``color[i - 1]`` then began at color[-1]). Wrap around instead of
    # raising IndexError when there are more lanes than colours.
    lane_ids = [lane_idx for lane_idx in np.unique(lane_seg_img) if lane_idx != 0]
    for i, lane_idx in enumerate(lane_ids):
        seg_img[lane_seg_img == lane_idx] = color[i % len(color)]
    img = cv2.addWeighted(src1=seg_img, alpha=0.8, src2=img, beta=1., gamma=0.)

    cv2.imwrite("demo/demo_result.jpg", img)

    if args.visualize:
        cv2.imshow("", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
示例#5
0
                folders.insert(0, path)
            break
    return folders

# ------------ data and model ------------
# Imagenet mean, std
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = Compose(Resize(resize_shape), ToTensor(),
                    Normalize(mean=mean, std=std))
dataset_name = exp_cfg['dataset'].pop('dataset_name')
Dataset_Type = getattr(dataset, dataset_name)
# NOTE(review): the dataset root is hard-wired to 'Tusimple' regardless of
# dataset_name.
test_dataset = Dataset_Type(Dataset_Path['Tusimple'], "test", transform)
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=test_dataset.collate, num_workers=4)

# The checkpoint loaded below replaces every parameter (load_state_dict is
# strict by default), so fetching ImageNet-pretrained backbone weights first
# is wasted work and a needless network dependency.
net = LaneNet(pretrained=False, **exp_cfg['model'])
# normpath + basename also copes with a trailing '/' on exp_dir, where
# exp_dir.split('/')[-1] would yield ''.
exp_name = os.path.basename(os.path.normpath(exp_dir))
save_name = os.path.join(exp_dir, exp_name + '_best.pth')
save_dict = torch.load(save_name, map_location='cpu')
print("\nloading", save_name, "...... From Epoch: ", save_dict['epoch'])
net.load_state_dict(save_dict['net'])
net = torch.nn.DataParallel(net.to(device))
net.eval()

# ------------ test ------------
out_path = os.path.join(exp_dir, "coord_output")
evaluation_path = os.path.join(exp_dir, "evaluate")
# Create output directories if missing (no check-then-create race).
os.makedirs(out_path, exist_ok=True)
os.makedirs(evaluation_path, exist_ok=True)
dump_to_json = []
示例#6
0
class Lane_warning:
    """ROS node that runs LaneNet lane detection on camera frames.

    It tracks the left/right ego lanes, overlays cached YOLO and
    traffic-light boxes, and publishes the annotated frame together with
    LaneNet's mask/binary/morphology debug images.
    """

    def __init__(self):
        """Create publishers, load the model, initialise state, then subscribe.

        Subscribers are registered last: rospy can start invoking a
        subscriber's callback as soon as it is constructed, and every callback
        here reads attributes (model, bridge, box buffers, tracker) that are
        assigned below.
        """
        # ---- publishers ----
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        self.maskimg_pub = rospy.Publisher("lanedetmask", Image, queue_size=1)
        self.binimg_pub = rospy.Publisher("lanedetbin", Image, queue_size=1)
        self.morphoimg_pub = rospy.Publisher("lanedetmorph",
                                             Image,
                                             queue_size=1)

        # ---- model and post-processing ----
        self.weights_file = rospy.get_param("lanenet_weight")
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warningModule = Detection()
        self.band_width = 1.5
        self.image_X = CFG.IMAGE_WIDTH
        self.image_Y = CFG.IMAGE_HEIGHT
        self.car_X = self.image_X / 2
        self.car_Y = self.image_Y
        self.model = LaneNet(pretrained=False,
                             embed_dim=4,
                             delta_v=.5,
                             delta_d=3.)
        # An unconditional 'cuda:0' makes torch.load fail on CPU-only hosts.
        map_location = 'cuda:0' if self.CUDA else 'cpu'
        self.save_dict = torch.load(self.weights_file,
                                    map_location=map_location)
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()
        # np.ndarray(4) would be uninitialised memory; start at zeros.
        self.lastlane = np.zeros(4)
        self.bridge = CvBridge()

        # ---- lane tracking state ----
        self.leftlane = Lane('left')
        self.rightlane = Lane('right')
        self.tracker = LaneTracker()

        # numpy images are (rows, cols, channels), i.e. (height, width, 3).
        self.img = np.zeros([CFG.IMAGE_HEIGHT, CFG.IMAGE_WIDTH, 3], np.uint8)
        self.yoloBoxes = BoundingBoxes()
        self.trafficLightBoxes = BoundingBoxes()

        # Lane-departure warning flag. Detection is currently disabled in
        # process(), so this stays 0 and lanes are drawn green.
        self.warning = 0

        # ---- subscribers (last, see docstring) ----
        self.yolo_sub = rospy.Subscriber("YOLO_detect_result_boxes",
                                         BoundingBoxes, self.callbackyolo)
        self.tlight_sub = rospy.Subscriber("tlight_detect_result_boxes",
                                           BoundingBoxes,
                                           self.callbackTrafficLight)
        self.image_sub = rospy.Subscriber("/wideangle/image_color", Image,
                                          self.callbackRos)

    @staticmethod
    def _store_boxes(target, boxmsg):
        """Copy the first ``boxmsg.objNum`` boxes of *boxmsg* into *target*.

        The replacement list is fully built before being assigned, and
        ``objNum`` is set from its length. A concurrent ``drawYoloResult``
        (box and image callbacks may run on different rospy threads) therefore
        never sees an emptied list paired with a non-zero count.
        """
        boxes = []
        for box in boxmsg.bounding_boxes[:boxmsg.objNum]:
            print('box id:', box.id)
            boxes.append(box)
        target.bounding_boxes = boxes
        target.objNum = len(boxes)

    def callbackyolo(self, boxmsg):
        """Cache the latest YOLO detections for overlay on the next frame."""
        print('callbackyolo, boxes len:', boxmsg.objNum)
        self._store_boxes(self.yoloBoxes, boxmsg)

    def callbackTrafficLight(self, boxmsg):
        """Cache the latest traffic-light detections for overlay on the next frame."""
        print('callbackTrafficLight, boxes len:', boxmsg.objNum)
        self._store_boxes(self.trafficLightBoxes, boxmsg)

    def transform_input(self, img):
        """Turn an image into a normalised 1-image batch tensor.

        The image is resized with ``Resize((512, 256))`` and normalised with
        the ImageNet mean/std. The tensor is moved to the GPU when CUDA is
        available.
        """
        _set = "IMAGENET"
        mean = IMG_MEAN[_set]
        std = IMG_STD[_set]
        transform_img = Resize((512, 256))
        transform_x = Compose(ToTensor(), Normalize(mean=mean, std=std))
        img = transform_img({'img': img})['img']
        x = transform_x({'img': img})['img']
        x.unsqueeze_(0)  # add a leading batch dimension
        # Moving unconditionally to 'cuda' crashes on CPU-only hosts.
        if self.CUDA:
            x = x.to('cuda')
        return x

    def detection(self, input):
        """Run the model on a preprocessed batch and return ``cluster``'s result."""
        if self.CUDA:
            input = input.cuda()
        with torch.no_grad():
            output = self.model(input, None)
        return self.cluster(output)

    def cluster(self, output):
        """Convert raw network outputs into LaneNet post-processing results.

        For batch item 0, the embedding has its first axis moved last and the
        binary segmentation is arg-maxed over the class axis (axis 1) into a
        per-pixel label map. Both go to ``self.postprocessor.postprocess``,
        whose result (indexed as a dict by ``process``) is returned as-is.
        """
        embedding = output['embedding'].detach().cpu().numpy()
        embedding = np.transpose(embedding[0], (1, 2, 0))
        bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]
        return self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred, instance_seg_result=embedding)

    def process(self, frame):
        """Detect, track and draw lanes for one BGR *frame*.

        The tracked left/right lanes are drawn onto ``self.img`` (a fresh copy
        of *frame*); the caller overlays boxes and publishes it. All raw
        detections are also rendered into ``1debug.png``. LaneNet's
        intermediate images are published whenever a mask was produced.
        """
        startt = time.time()
        cropImg = cropRoi(frame)
        input_image = self.transform_input(cropImg)
        postProcResult = self.detection(input_image)

        self.img = frame.copy()

        self.tracker.process(postProcResult['detectedLanes'])

        llane = self.tracker.leftlane
        rlane = self.tracker.rightlane

        # Lane-departure warning (self.warningModule.detect) is disabled, so
        # self.warning keeps its initial value.
        color = (0, 0, 255) if self.warning == 1 else (0, 255, 0)
        # Tracked lanes are assumed to hold (at least) 12 points.
        for idx in range(11):
            cv2.line(
                self.img,
                (int(llane.points[idx][0]), int(llane.points[idx][1])),
                (int(llane.points[idx + 1][0]), int(llane.points[idx + 1][1])),
                color, 10)
            cv2.line(
                self.img,
                (int(rlane.points[idx][0]), int(rlane.points[idx][1])),
                (int(rlane.points[idx + 1][0]), int(rlane.points[idx + 1][1])),
                color, 10)

        # debug: every raw detection, coloured by which tracked lane it matched
        debugImg = frame.copy()
        for lane in postProcResult['detectedLanes']:
            if lane[0][0] == llane.detectedLane[0][0]:
                color = (255, 0, 0)
            elif lane[0][0] == rlane.detectedLane[0][0]:
                color = (255, 255, 0)
            else:
                color = (0, 255, 0)

            for j in range(11):
                # cv2 drawing functions require integer pixel coordinates.
                x0, y0 = int(lane[j][0]), int(lane[j][1])
                x1, y1 = int(lane[j + 1][0]), int(lane[j + 1][1])
                cv2.line(debugImg, (x0, y0), (x1, y1), color, 10)
                cv2.line(debugImg, (0, y0), (int(self.image_X), y0),
                         (0, 0, 255), 3)
        cv2.imwrite(str(1) + 'debug.png', debugImg)

        if postProcResult['mask_image'] is None:
            print('cant find any lane!!!')
        else:
            self.maskimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['mask_image'],
                                          "bgr8"))
            self.binimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['binary_img'],
                                          "mono8"))
            self.morphoimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['morpho_img'],
                                          "mono8"))
            # The annotated frame is published once by callbackRos after the
            # box overlay; publishing it here too sent two frames per image.
        # debug end

        print('lanedet all use:', time.time() - startt)

    def drawYoloResult(self, data):
        """Draw each box in *data* with its ``id`` label onto ``self.img``."""
        # Slicing (rather than indexing over range(objNum)) tolerates a list
        # shorter than objNum, so a concurrent box update cannot raise here.
        for box in data.bounding_boxes[:data.objNum]:
            x0, y0 = int(box.xmin), int(box.ymin)
            cv2.rectangle(self.img, (x0, y0),
                          (int(box.xmax), int(box.ymax)), (0, 255, 0), 3)
            t_size = cv2.getTextSize(box.id, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
            c2 = x0 + t_size[0] + 3, y0 + t_size[1] + 4
            cv2.rectangle(self.img, (x0, y0), c2, (0, 255, 0), -1)
            cv2.putText(self.img, box.id, (x0, y0 + 10),
                        cv2.FONT_HERSHEY_PLAIN, 1, (225, 255, 255), 1)

    def callbackRos(self, data):
        """Handle one camera image: detect lanes, overlay cached boxes, publish."""
        print('callbackros')

        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        # process() sets self.img to its own copy of the frame before drawing.
        self.process(cv_image)

        self.drawYoloResult(self.yoloBoxes)
        self.drawYoloResult(self.trafficLightBoxes)

        self.image_pub.publish(self.bridge.cv2_to_imgmsg(self.img, "bgr8"))
示例#7
0
class Lane_warning:
    """Lane detection with optional lane-departure warning.

    ``callbackRos`` is the ROS image-topic entry point. ``callback`` is an
    offline path: it takes an already-decoded image, runs detection and the
    warning module, and displays the annotated result.
    """

    def __init__(self):
        """Create the publisher, load LaneNet weights, then subscribe to the camera.

        The image subscriber is registered last. rospy can invoke
        ``callbackRos`` as soon as the Subscriber exists, and that callback
        needs the model, post-processor and CvBridge created here.
        """
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        # NOTE(review): hard-coded absolute path; consider a ROS parameter.
        self.weights_file = '/home/iairiv/code/lane_waring_final/experiments/exp1/exp1_best.pth'
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warning = Detection()
        self.band_width = 1.5
        self.image_X = 1920
        self.image_Y = 1200
        self.car_X = self.image_X / 2
        self.car_Y = self.image_Y
        self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
        # An unconditional 'cuda:0' makes torch.load fail on CPU-only hosts.
        map_location = 'cuda:0' if self.CUDA else 'cpu'
        self.save_dict = torch.load(self.weights_file, map_location=map_location)
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()
        # np.ndarray(4,) is uninitialised memory; start from zeros instead.
        self.lastlane = np.zeros(4)
        self.bridge = CvBridge()
        self.image_sub = rospy.Subscriber("/camera/rgb/image_raw", Image, self.callbackRos)

    def transform_input(self, img):
        """Preprocess *img* for the network via ``prep_image``."""
        return prep_image(img)

    def detection(self, input, raw):
        """Run the model on *input* (moved to the GPU when available) and cluster."""
        if self.CUDA:
            input = input.cuda()
        with torch.no_grad():
            output = self.model(input, None)
        return self.cluster(output, raw)

    def cluster(self, output, raw):
        """Post-process network outputs and return the result as a numpy array.

        For batch item 0, the embedding has its first axis moved last and the
        binary segmentation is arg-maxed over axis 1 into a label map. As a
        debug dump, that map is written to ``<i>a.png``, where ``i`` is a
        module-level counter incremented per call. *raw* is currently unused.
        """
        global i
        embedding = output['embedding'].detach().cpu().numpy()
        embedding = np.transpose(embedding[0], (1, 2, 0))
        bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]

        # np.argmax yields an int64 label map, which cv2.imwrite cannot encode
        # (and 0/1 values would be invisible anyway). Dump non-zero labels as
        # 255 in an 8-bit image.
        seg = (bin_seg_pred > 0).astype(np.uint8) * 255
        cv2.imwrite(str(i) + 'a.png', seg)
        i = i + 1

        postprocess_result = self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred,
            instance_seg_result=embedding
        )
        return np.array(postprocess_result)

    @staticmethod
    def _draw_lanes(output, img, color):
        """Draw each lane of *output* onto *img* in *color*, thickness 3.

        Each lane is drawn as a straight segment from its first to its last
        point. Every ``output[k]`` is normalised in place to a plain nested
        list, as the original drawing code did. Returns *img*.
        """
        for k in range(len(output)):
            lane = np.array(output[k]).tolist()
            output[k] = lane
            start = (int(lane[0][0]), int(lane[0][1]))
            end = (int(lane[-1][0]), int(lane[-1][1]))
            cv2.line(img, start, end, color, 3)
        return img

    def write(self, output, img, signal, color):
        """Draw lanes in *color*; when *signal* == 1 also draw a "WARNING" banner.

        ``callback`` depends on this method. It previously existed only as
        commented-out code, so that path raised AttributeError.
        """
        self._draw_lanes(output, img, color)
        if signal == 1:
            cv2.putText(img, "WARNING", (1300, 150), cv2.FONT_HERSHEY_SIMPLEX, 3.0, color, thickness=10)
        return img

    def write_nowarning(self, output, img):
        """Draw lanes in red (BGR (0, 0, 255)) without any warning overlay."""
        return self._draw_lanes(output, img, (0, 0, 255))

    def color(self, signal):
        """Return the BGR drawing colour: green for signal 0, red otherwise."""
        return (0, 255, 0) if signal == 0 else (0, 0, 255)

    # ROS code path, not yet tested; without ROS, use ``callback`` instead.
    def callbackRos(self, data):
        """Decode a ROS image, run lane detection and print any prediction.

        Drawing and re-publishing the result is currently disabled.
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            print(e)
            return
        input_image = self.transform_input(cv_image)
        prediction = self.detection(input_image, cv_image)
        if len(prediction) != 0:
            print(prediction)

    def callback(self, data):
        """Offline path: detect lanes and warnings on image *data*, then display it.

        Prints the preprocessing, detection, warning and drawing times, then
        blocks in ``cv2.waitKey(0)`` until a key is pressed.
        """
        time3 = time.time()
        input_image = self.transform_input(data)
        time4 = time.time() - time3
        print('数据预处理时间:', time4)

        # lane detection
        time5 = time.time()
        prediction = self.detection(input_image, data)
        time6 = time.time() - time5
        print('检测时间:', time6)

        # lane-departure warning
        time7 = time.time()
        signal = self.warning.detect(prediction)
        color = self.color(signal)
        time8 = time.time() - time7
        print('预警时间:', time8)

        # draw lanes
        time1 = time.time()
        img = self.write(prediction, data, signal, color)
        time2 = time.time() - time1
        print('画图时间:', time2)

        cv2.imshow("final", img)
        cv2.waitKey(0)
示例#8
0
# ------------ data and model ------------
# Imagenet mean, std

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = Compose(Resize(resize_shape), ToTensor(),
                    Normalize(mean=mean, std=std))
dataset_name = exp_cfg['dataset'].pop('dataset_name')
Dataset_Type = getattr(dataset, dataset_name)
# NOTE(review): the dataset root is hard-wired to 'Tusimple' regardless of
# dataset_name.
test_dataset = Dataset_Type(Dataset_Path['Tusimple'], "test", transform)
test_loader = DataLoader(test_dataset,
                         batch_size=16,
                         collate_fn=test_dataset.collate,
                         num_workers=4)

net = LaneNet(pretrained=False, **exp_cfg['model'])
# normpath + basename also copes with a trailing '/' on exp_dir, where
# exp_dir.split('/')[-1] would yield ''.
exp_name = os.path.basename(os.path.normpath(exp_dir))
save_name = os.path.join(exp_dir, exp_name + '_best.pth')
# Deserialise on the CPU and let net.to(device) move the weights. A
# hard-coded 'cuda:0' fails when device is the CPU and is pointless when it
# is another GPU.
save_dict = torch.load(save_name, map_location='cpu')
print("\nloading", save_name, "...... From Epoch: ", save_dict['epoch'])
net.load_state_dict(save_dict['net'])
net = net.to(device)
net.eval()

# ------------ test ------------
out_path = os.path.join(exp_dir, "coord_output")
evaluation_path = os.path.join(exp_dir, "evaluate")
# Create output directories if missing (no check-then-create race).
os.makedirs(out_path, exist_ok=True)
os.makedirs(evaluation_path, exist_ok=True)
示例#9
0
    ap = argparse.ArgumentParser()

    # Explicit ``type=`` so values given on the command line are converted.
    # Without it argparse passes them on as strings (e.g. "-b 4" -> '4');
    # only the untouched defaults were numeric.
    ap.add_argument('-e', '--epoch', type=int, default=50)  # Epoch
    ap.add_argument('-b', '--batch', type=int, default=2)  # Batch_size
    ap.add_argument('-dv', '--delta_v', type=float, default=.5)  # delta_v
    ap.add_argument('-dd', '--delta_d', type=float, default=3)  # delta_d
    ap.add_argument('-l', '--learning_rate', type=float,
                    default=5e-4)  # learning_rate
    ap.add_argument('-o', '--optimizer', default='Adam')  # optimizer
    ap.add_argument('-d', '--device', default='GPU')  # training device
    ap.add_argument('-t', '--test_ratio', type=float, default=.1)
    # 'new' trains from scratch; any other value resumes from model_file below.
    ap.add_argument('-s', '--stage', default='new')

    args = vars(ap.parse_args())
    train_indices, test_indices = split_dataset(args['test_ratio'])
    data = build_sampler(TusimpleData('./data', transform=Rescale((256, 512))),
                         args['batch'], 1, train_indices, test_indices)

    if args['stage'] == 'new':
        model = LaneNet()
    else:
        model_file = 'model_1548830895.pkl'
        model = LaneNet()
        # torch.load here returns a whole pickled model object, whose
        # state_dict() is copied into the fresh LaneNet.
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to
        # weights_only=True, which rejects whole-model pickles.
        weight_dict = torch.load(os.path.join('./logs/models/', model_file))
        model.load_state_dict(weight_dict.state_dict())

    train(model, data, args['epoch'], args['batch'], args['delta_v'],
          args['delta_d'], args['learning_rate'], args['optimizer'])
示例#10
0
                          batch_size=exp_cfg['dataset']['batch_size'],
                          shuffle=True,
                          collate_fn=train_dataset.collate,
                          num_workers=8)

# ------------ val data ------------
# Validation pipeline: resize to resize_shape, convert to tensor and normalise
# with mean/std (all defined elsewhere in this script). DataLoader's default
# shuffle=False keeps validation order fixed.
transform_val = Compose(Resize(resize_shape), ToTensor(),
                        Normalize(mean=mean, std=std))
val_dataset = Dataset_Type(Dataset_Path[dataset_name], "val", transform_val)
val_loader = DataLoader(val_dataset,
                        batch_size=8,
                        collate_fn=val_dataset.collate,
                        num_workers=4)

# ------------ preparation ------------
# Build the network with the experiment's model config, move it to `device`
# and wrap it in DataParallel for multi-GPU training.
net = LaneNet(pretrained=True, **exp_cfg['model'])
net = net.to(device)
net = torch.nn.DataParallel(net)

# SGD with hyper-parameters from exp_cfg['optim'], driven by a PolyLR
# schedule over exp_cfg['MAX_ITER'] iterations (0.9 is presumably the
# polynomial power; confirm against PolyLR's signature).
optimizer = optim.SGD(net.parameters(), **exp_cfg['optim'])
lr_scheduler = PolyLR(optimizer, 0.9, exp_cfg['MAX_ITER'])
best_val_loss = 1e6  # large sentinel so the first validation loss is "best"


def train(epoch):
    """Training step for epoch number *epoch* (uses the module-level ``net``).

    Prints the epoch, puts ``net`` in training mode and zeroes the running
    loss accumulators.
    NOTE(review): only this setup is visible here; the training loop itself
    is not shown in this excerpt.
    """
    print("Train Epoch: {}".format(epoch))
    net.train()
    # Running totals: overall loss and its binary-segmentation, variance and
    # distance components.
    train_loss = 0
    train_loss_bin_seg = 0
    train_loss_var = 0
    train_loss_dist = 0
示例#11
0
    torch.save(
        model, os.path.join('./logs/models',
                            'model_{}.pkl'.format(start_time)))
    log.close()


if __name__ == '__main__':
    ap = argparse.ArgumentParser()

    # Explicit ``type=`` so values given on the command line are converted.
    # Without it argparse passes them on as strings (e.g. "-e 10" -> '10');
    # only the untouched defaults were numeric.
    ap.add_argument('-e', '--epoch', type=int, default=30)  # Epoch
    ap.add_argument('-b', '--batch', type=int, default=16)  # Batch_size
    ap.add_argument('-dv', '--delta_v', type=float, default=.5)  # delta_v
    ap.add_argument('-dd', '--delta_d', type=float, default=3)  # delta_d
    ap.add_argument('-l', '--learning_rate', type=float,
                    default=5e-4)  # learning_rate
    ap.add_argument('-o', '--optimizer', default='Adam')  # optimizer
    ap.add_argument('-d', '--device', default='GPU')  # training device
    ap.add_argument('-t', '--test_ratio', type=float, default=.1)

    args = vars(ap.parse_args())

    train_indices, test_indices = split_dataset(args['test_ratio'])
    data = build_sampler(TusimpleData('./data', transform=Rescale((256, 512))),
                         args['batch'], 1, train_indices, test_indices)
    model = LaneNet()

    train(model, data, args['epoch'], args['batch'], args['delta_v'],
          args['delta_d'], args['learning_rate'], args['optimizer'])
示例#12
0
import argparse
import cv2
import torch

from model import LaneNet
from utils.transforms import *
from utils.postprocess import embedding_post_process

# Module-level model (7-dimensional instance embedding) and input transform.
# The transform resizes with Resize((800, 288)), converts to a tensor and
# normalises with per-channel mean/std. These statistics are not the
# ImageNet values, presumably computed on the training data; confirm.
net = LaneNet(pretrained=False, embed_dim=7, delta_v=.5, delta_d=3.)
transform = Compose(Resize((800, 288)), ToTensor(),
                    Normalize(mean=(0.3598, 0.3653, 0.3662), std=(0.2573, 0.2663, 0.2756)))


def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Options: ``--img_path/-i`` (default ``demo/demo.jpg``),
    ``--weight_path/-w`` (no default), ``--delta_v/-dv`` (float, default 0.5)
    and the ``--visualize/-v`` flag.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--img_path", '-i', type=str, default="demo/demo.jpg",
                     help="Path to demo img")
    cli.add_argument("--weight_path", '-w', type=str,
                     help="Path to model weights")
    cli.add_argument("--delta_v", '-dv', type=float, default=0.5,
                     help="Value of delta_v")
    cli.add_argument("--visualize", '-v', action="store_true", default=False,
                     help="Visualize the result")
    return cli.parse_args()


def main():
    args = parse_args()
    img_path = args.img_path
    weight_path = args.weight_path

    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB for net model input
    x = transform(img)[0]