def __init__(self):
    """Set up ROS publishers, the LaneNet model, tracking state and subscribers.

    Subscribers are registered last: rospy may deliver a message as soon as
    a ``Subscriber`` exists, and the callbacks rely on attributes such as
    ``self.bridge``, ``self.model`` and ``self.tracker``.
    """
    # Output topics: annotated frame plus intermediate images for debugging.
    self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
    self.maskimg_pub = rospy.Publisher("lanedetmask", Image, queue_size=1)
    self.binimg_pub = rospy.Publisher("lanedetbin", Image, queue_size=1)
    self.morphoimg_pub = rospy.Publisher("lanedetmorph", Image, queue_size=1)

    self.weights_file = rospy.get_param("lanenet_weight")
    self.CUDA = torch.cuda.is_available()
    self.postprocessor = LaneNetPostProcessor()
    self.warningModule = Detection()
    self.band_width = 1.5
    self.image_X = CFG.IMAGE_WIDTH
    self.image_Y = CFG.IMAGE_HEIGHT
    # Bottom-centre of the image (presumably the ego-vehicle reference point).
    self.car_X = self.image_X / 2
    self.car_Y = self.image_Y

    self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
    # Map the checkpoint onto a device that actually exists; a fixed
    # 'cuda:0' makes torch.load fail on CPU-only hosts.
    map_location = 'cuda:0' if self.CUDA else 'cpu'
    self.save_dict = torch.load(self.weights_file, map_location=map_location)
    self.model.load_state_dict(self.save_dict['net'])
    if self.CUDA:
        self.model.cuda()
    self.model.set_test()

    # Zero-filled instead of np.ndarray(4,), which holds arbitrary memory.
    self.lastlane = np.zeros(4)
    self.bridge = CvBridge()
    self.leftlane = Lane('left')
    self.rightlane = Lane('right')
    self.tracker = LaneTracker()
    # Placeholder frame; image arrays are indexed (rows=height, cols=width).
    self.img = np.zeros([CFG.IMAGE_HEIGHT, CFG.IMAGE_WIDTH, 3], np.uint8)
    self.yoloBoxes = BoundingBoxes()
    self.trafficLightBoxes = BoundingBoxes()
    self.warning = 0

    # Inputs (registered only now that every attribute above exists).
    self.yolo_sub = rospy.Subscriber(
        "YOLO_detect_result_boxes", BoundingBoxes, self.callbackyolo)
    self.tlight_sub = rospy.Subscriber(
        "tlight_detect_result_boxes", BoundingBoxes, self.callbackTrafficLight)
    self.image_sub = rospy.Subscriber(
        "/wideangle/image_color", Image, self.callbackRos)
def _load_model(self):
    """Create a LaneNet, restore its weights and return it on the GPU.

    When ``self.mode`` equals ``'parallel'`` the network is wrapped in
    ``DataParallel`` before the state dict read from ``self.model_path``
    is applied.
    """
    net = LaneNet()
    wrap_parallel = self.mode == 'parallel'
    if wrap_parallel:
        net = DataParallel(net)
    state = torch.load(self.model_path)
    net.load_state_dict(state)
    return net.cuda()
def __init__(self):
    """Set up the output publisher, the LaneNet model and the image subscriber.

    The subscriber is registered last: rospy may deliver a message as soon
    as a ``Subscriber`` exists, and ``callbackRos`` relies on
    ``self.bridge`` and ``self.model``.
    """
    self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)

    # NOTE(review): machine-specific absolute path; consider a ROS parameter.
    self.weights_file = '/home/iairiv/code/lane_waring_final/experiments/exp1/exp1_best.pth'
    self.CUDA = torch.cuda.is_available()
    self.postprocessor = LaneNetPostProcessor()
    self.warning = Detection()
    self.band_width = 1.5
    self.image_X = 1920
    self.image_Y = 1200
    # Bottom-centre of the image (presumably the ego-vehicle reference point).
    self.car_X = self.image_X/2
    self.car_Y = self.image_Y

    self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
    # Map the checkpoint onto a device that actually exists; a fixed
    # 'cuda:0' makes torch.load fail on CPU-only hosts.
    map_location = 'cuda:0' if self.CUDA else 'cpu'
    self.save_dict = torch.load(self.weights_file, map_location=map_location)
    self.model.load_state_dict(self.save_dict['net'])
    if self.CUDA:
        self.model.cuda()
    self.model.set_test()

    # Zero-filled instead of np.ndarray(4,), which holds arbitrary memory.
    self.lastlane = np.zeros(4)
    self.bridge = CvBridge()

    # Input (registered only now that every attribute above exists).
    self.image_sub = rospy.Subscriber("/camera/rgb/image_raw", Image, self.callbackRos)
def main():
    """Run LaneNet on one image and write the coloured lane overlay.

    Reads ``args.img_path``, loads the weights at ``args.weight_path`` on
    the CPU, clusters the embedding into lane instances and saves the
    blended result to ``demo/demo_result.jpg``.  With ``args.visualize``
    the result is also shown in a window until a key is pressed.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    args = parse_args()
    img_path = args.img_path
    weight_path = args.weight_path

    _set = "IMAGENET"
    mean = IMG_MEAN[_set]
    std = IMG_STD[_set]
    transform_img = Resize((800, 288))
    transform_x = Compose(ToTensor(), Normalize(mean=mean, std=std))

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Cannot read image: {}".format(img_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB for net model input
    img = transform_img({'img': img})['img']
    x = transform_x({'img': img})['img']
    x.unsqueeze_(0)  # add the batch dimension

    net = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
    save_dict = torch.load(weight_path, map_location='cpu')
    net.load_state_dict(save_dict['net'])
    net.eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = net(x)
    embedding = output['embedding']
    embedding = embedding.detach().cpu().numpy()
    embedding = np.transpose(embedding[0], (1, 2, 0))
    binary_seg = output['binary_seg']
    bin_seg_prob = binary_seg.detach().cpu().numpy()
    bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]

    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    seg_img = np.zeros_like(img)

    lane_seg_img = embedding_post_process(embedding, bin_seg_pred, args.band_width, 4)

    color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8')

    for i, lane_idx in enumerate(np.unique(lane_seg_img)):
        if lane_idx == 0:
            # Label 0 is skipped and left uncoloured.
            continue
        # Wrap around the palette so extra lane labels cannot index past it.
        seg_img[lane_seg_img == lane_idx] = color[(i - 1) % len(color)]
    img = cv2.addWeighted(src1=seg_img, alpha=0.8, src2=img, beta=1., gamma=0.)

    cv2.imwrite("demo/demo_result.jpg", img)

    if args.visualize:
        cv2.imshow("", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
folders.insert(0, path) break return folders # ------------ data and model ------------ # Imagenet mean, std mean=(0.485, 0.456, 0.406) std=(0.229, 0.224, 0.225) transform = Compose(Resize(resize_shape), ToTensor(), Normalize(mean=mean, std=std)) dataset_name = exp_cfg['dataset'].pop('dataset_name') Dataset_Type = getattr(dataset, dataset_name) test_dataset = Dataset_Type(Dataset_Path['Tusimple'], "test", transform) test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=test_dataset.collate, num_workers=4) net = LaneNet(pretrained=True, **exp_cfg['model']) save_name = os.path.join(exp_dir, exp_dir.split('/')[-1] + '_best.pth') save_dict = torch.load(save_name, map_location='cpu') print("\nloading", save_name, "...... From Epoch: ", save_dict['epoch']) net.load_state_dict(save_dict['net']) net = torch.nn.DataParallel(net.to(device)) net.eval() # ------------ test ------------ out_path = os.path.join(exp_dir, "coord_output") evaluation_path = os.path.join(exp_dir, "evaluate") if not os.path.exists(out_path): os.mkdir(out_path) if not os.path.exists(evaluation_path): os.mkdir(evaluation_path) dump_to_json = []
class Lane_warning:
    """ROS node logic: detect lanes with LaneNet and publish annotated frames.

    Camera frames arrive in ``callbackRos``; bounding boxes from the YOLO and
    traffic-light detectors are cached by their callbacks and drawn on top of
    the lane overlay before the frame is published.
    """

    def __init__(self):
        """Set up publishers, the LaneNet model, tracking state and subscribers.

        Subscribers are registered last: rospy may deliver a message as soon
        as a ``Subscriber`` exists, and the callbacks rely on attributes such
        as ``self.bridge``, ``self.model`` and ``self.tracker``.
        """
        # Output topics: annotated frame plus intermediate images for debugging.
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)
        self.maskimg_pub = rospy.Publisher("lanedetmask", Image, queue_size=1)
        self.binimg_pub = rospy.Publisher("lanedetbin", Image, queue_size=1)
        self.morphoimg_pub = rospy.Publisher("lanedetmorph", Image, queue_size=1)

        self.weights_file = rospy.get_param("lanenet_weight")
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warningModule = Detection()
        self.band_width = 1.5
        self.image_X = CFG.IMAGE_WIDTH
        self.image_Y = CFG.IMAGE_HEIGHT
        # Bottom-centre of the image (presumably the ego-vehicle reference point).
        self.car_X = self.image_X / 2
        self.car_Y = self.image_Y

        self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
        # Map the checkpoint onto a device that actually exists; a fixed
        # 'cuda:0' makes torch.load fail on CPU-only hosts.
        map_location = 'cuda:0' if self.CUDA else 'cpu'
        self.save_dict = torch.load(self.weights_file, map_location=map_location)
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()

        # Zero-filled instead of np.ndarray(4,), which holds arbitrary memory.
        self.lastlane = np.zeros(4)
        self.bridge = CvBridge()
        self.leftlane = Lane('left')
        self.rightlane = Lane('right')
        self.tracker = LaneTracker()
        # Placeholder frame; image arrays are indexed (rows=height, cols=width).
        self.img = np.zeros([CFG.IMAGE_HEIGHT, CFG.IMAGE_WIDTH, 3], np.uint8)
        self.yoloBoxes = BoundingBoxes()
        self.trafficLightBoxes = BoundingBoxes()
        self.warning = 0

        # Inputs (registered only now that every attribute above exists).
        self.yolo_sub = rospy.Subscriber(
            "YOLO_detect_result_boxes", BoundingBoxes, self.callbackyolo)
        self.tlight_sub = rospy.Subscriber(
            "tlight_detect_result_boxes", BoundingBoxes, self.callbackTrafficLight)
        self.image_sub = rospy.Subscriber(
            "/wideangle/image_color", Image, self.callbackRos)

    @staticmethod
    def _copy_boxes(target, boxmsg):
        """Replace the boxes cached in *target* with the first ``objNum`` of *boxmsg*."""
        target.objNum = boxmsg.objNum
        target.bounding_boxes = []
        for i in range(boxmsg.objNum):
            box = boxmsg.bounding_boxes[i]
            print('box id:', box.id)
            target.bounding_boxes.append(box)

    def callbackyolo(self, boxmsg):
        """Cache the latest YOLO bounding boxes for drawing on the next frame."""
        print('callbackyolo, boxes len:', boxmsg.objNum)
        self._copy_boxes(self.yoloBoxes, boxmsg)

    def callbackTrafficLight(self, boxmsg):
        """Cache the latest traffic-light bounding boxes for drawing on the next frame."""
        print('callbackTrafficLight, boxes len:', boxmsg.objNum)
        self._copy_boxes(self.trafficLightBoxes, boxmsg)

    def transform_input(self, img):
        """Resize and normalise *img* and return it as a batched tensor.

        The tensor is moved to the GPU only when CUDA is available, matching
        the check already made in ``detection``.
        """
        _set = "IMAGENET"
        mean = IMG_MEAN[_set]
        std = IMG_STD[_set]
        transform_img = Resize((512, 256))
        transform_x = Compose(ToTensor(), Normalize(mean=mean, std=std))
        img = transform_img({'img': img})['img']
        x = transform_x({'img': img})['img']
        x.unsqueeze_(0)  # add the batch dimension
        if self.CUDA:
            x = x.to('cuda')
        return x

    def detection(self, input):
        """Run the network on *input* without gradients and post-process the output."""
        if self.CUDA:
            input = input.cuda()
        with torch.no_grad():
            output = self.model(input, None)
        return self.cluster(output)

    def cluster(self, output):
        """Convert raw network *output* into the post-processor's result.

        Takes the first batch element of the embedding (channels moved last)
        and the per-pixel argmax of the binary segmentation, and hands both
        to ``self.postprocessor.postprocess``.
        """
        embedding = output['embedding']
        embedding = embedding.detach().cpu().numpy()
        embedding = np.transpose(embedding[0], (1, 2, 0))
        binary_seg = output['binary_seg']
        bin_seg_prob = binary_seg.detach().cpu().numpy()
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]
        postprocess_result = self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred, instance_seg_result=embedding)
        return postprocess_result

    def process(self, frame):
        """Detect and track lanes in *frame*, draw them and publish the results.

        Updates ``self.img`` with the lane overlay, writes a debug image to
        ``1debug.png`` and publishes the mask/binary/morphology images when
        the post-processor produced a mask.
        """
        startt = time.time()
        cropImg = cropRoi(frame)
        input_image = self.transform_input(cropImg)
        postProcResult = self.detection(input_image)
        self.img = frame.copy()

        self.tracker.process(postProcResult['detectedLanes'])
        llane = self.tracker.leftlane
        rlane = self.tracker.rightlane

        # Red while a warning is active, green otherwise.
        color = (0, 0, 255) if self.warning == 1 else (0, 255, 0)
        # Draws 11 segments, i.e. assumes each tracked lane has at least 12 points.
        for idx in range(11):
            for tracked in (llane, rlane):
                start = tracked.points[idx]
                end = tracked.points[idx + 1]
                cv2.line(self.img,
                         (int(start[0]), int(start[1])),
                         (int(end[0]), int(end[1])),
                         color, 10)

        # Debug view: every detected lane, coloured by whether the tracker
        # matched it to the left lane, the right lane, or neither.
        debugImg = frame.copy()
        for lane in postProcResult['detectedLanes']:
            if lane[0][0] == llane.detectedLane[0][0]:
                color = (255, 0, 0)
            elif lane[0][0] == rlane.detectedLane[0][0]:
                color = (255, 255, 0)
            else:
                color = (0, 255, 0)
            for j in range(11):
                cv2.line(debugImg, (lane[j][0], lane[j][1]),
                         (lane[j + 1][0], lane[j + 1][1]), color, 10)
                # Horizontal guide at the row of each sample point.
                cv2.line(debugImg, (0, lane[j][1]),
                         (int(self.image_X), lane[j][1]), (0, 0, 255), 3)
        cv2.imwrite(str(1) + 'debug.png', debugImg)

        if postProcResult['mask_image'] is None:
            print('cant find any lane!!!')
        else:
            self.maskimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['mask_image'], "bgr8"))
            self.binimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['binary_img'], "mono8"))
            self.morphoimg_pub.publish(
                self.bridge.cv2_to_imgmsg(postProcResult['morpho_img'], "mono8"))
        self.image_pub.publish(self.bridge.cv2_to_imgmsg(self.img, "bgr8"))
        print('lanedet all use:', time.time() - startt)

    def drawYoloResult(self, data):
        """Draw the first ``objNum`` boxes of *data* with id labels onto ``self.img``."""
        for i in range(data.objNum):
            box = data.bounding_boxes[i]
            top_left = (int(box.xmin), int(box.ymin))
            cv2.rectangle(self.img, top_left,
                          (int(box.xmax), int(box.ymax)), (0, 255, 0), 3)
            # Filled background sized to the label text.
            t_size = cv2.getTextSize(box.id, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
            c2 = int(box.xmin) + t_size[0] + 3, int(box.ymin) + t_size[1] + 4
            cv2.rectangle(self.img, top_left, c2, (0, 255, 0), -1)
            cv2.putText(self.img, box.id, (int(box.xmin), int(box.ymin) + 10),
                        cv2.FONT_HERSHEY_PLAIN, 1, (225, 255, 255), 1)

    def callbackRos(self, data):
        """Handle one camera frame: detect lanes, draw cached boxes, publish."""
        print('callbackros')
        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        self.img = cv_image.copy()
        self.process(cv_image)
        self.drawYoloResult(self.yoloBoxes)
        self.drawYoloResult(self.trafficLightBoxes)
        self.image_pub.publish(self.bridge.cv2_to_imgmsg(self.img, "bgr8"))
class Lane_warning:
    """Lane detection with LaneNet plus lane-departure warning overlay.

    ``callbackRos`` is the ROS image callback; ``callback`` runs the same
    pipeline on an already-decoded image and shows the result in a window.
    """

    def __init__(self):
        """Set up the output publisher, the LaneNet model and the image subscriber.

        The subscriber is registered last: rospy may deliver a message as
        soon as a ``Subscriber`` exists, and ``callbackRos`` relies on
        ``self.bridge`` and ``self.model``.
        """
        self.image_pub = rospy.Publisher("lanedetframe", Image, queue_size=1)

        # NOTE(review): machine-specific absolute path; consider a ROS parameter.
        self.weights_file = '/home/iairiv/code/lane_waring_final/experiments/exp1/exp1_best.pth'
        self.CUDA = torch.cuda.is_available()
        self.postprocessor = LaneNetPostProcessor()
        self.warning = Detection()
        self.band_width = 1.5
        self.image_X = 1920
        self.image_Y = 1200
        # Bottom-centre of the image (presumably the ego-vehicle reference point).
        self.car_X = self.image_X/2
        self.car_Y = self.image_Y

        self.model = LaneNet(pretrained=False, embed_dim=4, delta_v=.5, delta_d=3.)
        # Map the checkpoint onto a device that actually exists; a fixed
        # 'cuda:0' makes torch.load fail on CPU-only hosts.
        map_location = 'cuda:0' if self.CUDA else 'cpu'
        self.save_dict = torch.load(self.weights_file, map_location=map_location)
        self.model.load_state_dict(self.save_dict['net'])
        if self.CUDA:
            self.model.cuda()
        self.model.set_test()

        # Zero-filled instead of np.ndarray(4,), which holds arbitrary memory.
        self.lastlane = np.zeros(4)
        self.bridge = CvBridge()

        # Input (registered only now that every attribute above exists).
        self.image_sub = rospy.Subscriber("/camera/rgb/image_raw", Image, self.callbackRos)

    def transform_input(self, img):
        """Return *img* converted to network input by ``prep_image``."""
        return prep_image(img)

    def detection(self, input, raw):
        """Run the network on *input* without gradients and post-process the output.

        *raw* is passed through to ``cluster`` unchanged.
        """
        if self.CUDA:
            input = input.cuda()
        with torch.no_grad():
            output = self.model(input, None)
        return self.cluster(output, raw)

    def cluster(self, output, raw):
        """Post-process raw network *output* and return the result as an array.

        Also dumps the binary segmentation to ``<i>a.png`` using the
        module-level counter ``i``, which is incremented on every call.
        *raw* is accepted for interface compatibility and not used here.
        """
        global i
        embedding = output['embedding']
        embedding = embedding.detach().cpu().numpy()
        embedding = np.transpose(embedding[0], (1, 2, 0))
        binary_seg = output['binary_seg']
        bin_seg_prob = binary_seg.detach().cpu().numpy()
        bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]
        cv2.imwrite(str(i)+'a.png', bin_seg_pred)
        i = i+1
        postprocess_result = self.postprocessor.postprocess(
            binary_seg_result=bin_seg_pred,
            instance_seg_result=embedding
        )
        return np.array(postprocess_result)

    def write(self, output, img, signal, color):
        """Draw each lane in *output* on *img*, plus "WARNING" when *signal* is 1.

        Restored from the commented-out implementation because ``callback``
        calls it.  Each lane is drawn as one straight segment from its first
        to its last point; entries of *output* are converted to plain lists
        in place.
        """
        for idx in range(len(output)):
            output[idx] = np.array(output[idx]).tolist()
            first, last = output[idx][0], output[idx][-1]
            cv2.line(img, (int(first[0]), int(first[1])),
                     (int(last[0]), int(last[1])), color, 3)
        if signal == 1:
            cv2.putText(img, "WARNING", (1300, 150), cv2.FONT_HERSHEY_SIMPLEX,
                        3.0, color, thickness=10)
        return img

    def write_nowarning(self, output, img):
        """Draw each lane in *output* on *img* in red, without any warning text.

        Each lane is drawn as one straight segment from its first to its last
        point; entries of *output* are converted to plain lists in place.
        """
        for idx in range(len(output)):
            output[idx] = np.array(output[idx]).tolist()
            first, last = output[idx][0], output[idx][-1]
            cv2.line(img, (int(first[0]), int(first[1])),
                     (int(last[0]), int(last[1])), (0, 0, 255), 3)
        return img

    def color(self, signal):
        """Return green (BGR) when *signal* is 0 and red otherwise."""
        if signal == 0:
            return (0, 255, 0)
        return (0, 0, 255)

    # ROS code path; not tested yet. Use ``callback`` to test without ROS.
    def callbackRos(self, data):
        """Decode a ROS image message, run lane detection and print the lanes.

        A ``CvBridgeError`` during decoding is printed and the frame is
        dropped.  Drawing and publishing are currently disabled.
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            print(e)
            return
        input_image = self.transform_input(cv_image)
        prediction = self.detection(input_image, cv_image)
        if len(prediction) != 0:
            print(prediction)

    def callback(self, data):
        """Run the full pipeline on image *data*, print stage timings, show the result.

        Blocks in ``cv2.waitKey(0)`` until a key is pressed.
        """
        # Pre-processing.
        time3 = time.time()
        input_image = self.transform_input(data)
        time4 = time.time()-time3
        print('数据预处理时间:',time4)
        # Lane detection.
        time5 = time.time()
        prediction = self.detection(input_image,data)
        time6 = time.time()-time5
        print('检测时间:', time6)
        # Warning decision.
        time7 = time.time()
        signal = self.warning.detect(prediction)
        color = self.color(signal)
        time8 = time.time()-time7
        print('预警时间:',time8)
        # Drawing.
        time1 = time.time()
        img = self.write(prediction, data, signal, color)
        time2 = time.time()-time1
        print('画图时间:',time2)
        cv2.imshow("final",img)
        cv2.waitKey(0)
# ------------ data and model ------------
# ImageNet channel mean and std used to normalise the input.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = Compose(Resize(resize_shape), ToTensor(),
                    Normalize(mean=mean, std=std))
# NOTE: pop() removes 'dataset_name' from the shared config dict.
dataset_name = exp_cfg['dataset'].pop('dataset_name')
Dataset_Type = getattr(dataset, dataset_name)
test_dataset = Dataset_Type(Dataset_Path['Tusimple'], "test", transform)
test_loader = DataLoader(test_dataset, batch_size=16, collate_fn=test_dataset.collate, num_workers=4)

net = LaneNet(pretrained=False, **exp_cfg['model'])
# Best checkpoint is stored as <exp_dir>/<last path component>_best.pth.
save_name = os.path.join(exp_dir, exp_dir.split('/')[-1] + '_best.pth')
# Load on the CPU and let .to(device) below place the weights, so the
# script does not require 'cuda:0' to exist when `device` is something else.
save_dict = torch.load(save_name, map_location='cpu')
print("\nloading", save_name, "...... From Epoch: ", save_dict['epoch'])
net.load_state_dict(save_dict['net'])
net = net.to(device)
net.eval()

# ------------ test ------------
# Output folders for predicted lane coordinates and evaluation results.
out_path = os.path.join(exp_dir, "coord_output")
evaluation_path = os.path.join(exp_dir, "evaluate")
if not os.path.exists(out_path):
    os.mkdir(out_path)
if not os.path.exists(evaluation_path):
    os.mkdir(evaluation_path)
# Command-line options for a training run.  ``type=`` is given for every
# numeric option: without it argparse passes values typed on the command
# line through as strings, and only the defaults keep their numeric type.
ap = argparse.ArgumentParser()
ap.add_argument('-e', '--epoch', type=int, default=50)  # Epoch
ap.add_argument('-b', '--batch', type=int, default=2)  # Batch_size
ap.add_argument('-dv', '--delta_v', type=float, default=.5)  # delta_v
ap.add_argument('-dd', '--delta_d', type=float, default=3)  # delta_d
ap.add_argument('-l', '--learning_rate', type=float, default=5e-4)  # learning_rate
ap.add_argument('-o', '--optimizer', default='Adam')  # optimizer
ap.add_argument('-d', '--device', default='GPU')  # training device
ap.add_argument('-t', '--test_ratio', type=float, default=.1)
ap.add_argument('-s', '--stage', default='new')
args = vars(ap.parse_args())

train_indices, test_indices = split_dataset(args['test_ratio'])
data = build_sampler(TusimpleData('./data', transform=Rescale((256, 512))),
                     args['batch'], 1, train_indices, test_indices)

model = LaneNet()
if args['stage'] != 'new':
    # Any stage other than 'new' resumes from a previously saved model
    # object (the file is expected to hold a whole model, hence
    # ``.state_dict()``).
    model_file = 'model_1548830895.pkl'
    weight_dict = torch.load(os.path.join('./logs/models/', model_file))
    model.load_state_dict(weight_dict.state_dict())

train(model, data, args['epoch'], args['batch'], args['delta_v'],
      args['delta_d'], args['learning_rate'], args['optimizer'])
batch_size=exp_cfg['dataset']['batch_size'], shuffle=True, collate_fn=train_dataset.collate, num_workers=8) # ------------ val data ------------ transform_val = Compose(Resize(resize_shape), ToTensor(), Normalize(mean=mean, std=std)) val_dataset = Dataset_Type(Dataset_Path[dataset_name], "val", transform_val) val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=val_dataset.collate, num_workers=4) # ------------ preparation ------------ net = LaneNet(pretrained=True, **exp_cfg['model']) net = net.to(device) net = torch.nn.DataParallel(net) optimizer = optim.SGD(net.parameters(), **exp_cfg['optim']) lr_scheduler = PolyLR(optimizer, 0.9, exp_cfg['MAX_ITER']) best_val_loss = 1e6 def train(epoch): print("Train Epoch: {}".format(epoch)) net.train() train_loss = 0 train_loss_bin_seg = 0 train_loss_var = 0 train_loss_dist = 0
torch.save( model, os.path.join('./logs/models', 'model_{}.pkl'.format(start_time))) log.close() if __name__ == '__main__': ap = argparse.ArgumentParser() ap.add_argument('-e', '--epoch', default=30) #Epoch ap.add_argument('-b', '--batch', default=16) #Batch_size ap.add_argument('-dv', '--delta_v', default=.5) #delta_v ap.add_argument('-dd', '--delta_d', default=3) #delta_d ap.add_argument('-l', '--learning_rate', default=5e-4) #learning_rate ap.add_argument('-o', '--optimizer', default='Adam') #optimizer ap.add_argument('-d', '--device', default='GPU') #training device ap.add_argument('-t', '--test_ratio', default=.1) #ap.add_argument('-cl','--class_weight',default=.5) #ap.add_argument() #ap.add_argument() args = vars(ap.parse_args()) train_indices, test_indices = split_dataset(args['test_ratio']) data = build_sampler(TusimpleData('./data', transform=Rescale((256, 512))), args['batch'], 1, train_indices, test_indices) model = LaneNet() train(model, data, args['epoch'], args['batch'], args['delta_v'], args['delta_d'], args['learning_rate'], args['optimizer'])
import argparse import cv2 import torch from model import LaneNet from utils.transforms import * from utils.postprocess import embedding_post_process net = LaneNet(pretrained=False, embed_dim=7, delta_v=.5, delta_d=3.) transform = Compose(Resize((800, 288)), ToTensor(), Normalize(mean=(0.3598, 0.3653, 0.3662), std=(0.2573, 0.2663, 0.2756))) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--img_path", '-i', type=str, default="demo/demo.jpg", help="Path to demo img") parser.add_argument("--weight_path", '-w', type=str, help="Path to model weights") parser.add_argument("--delta_v", '-dv', type=float, default=0.5, help="Value of delta_v") parser.add_argument("--visualize", '-v', action="store_true", default=False, help="Visualize the result") args = parser.parse_args() return args def main(): args = parse_args() img_path = args.img_path weight_path = args.weight_path img = cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB for net model input x = transform(img)[0]