def main(args):
    """Run ERFNet inference on the MAV dataset and save label-id maps.

    args must provide: loadDir, loadModel, loadWeights, datadir, subset,
    num_workers, batch_size, cpu.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights
    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader: tolerates checkpoints whose key set does not exactly
        # match the model (unknown keys are skipped, missing ones keep their
        # initialization) -- plain load_state_dict fails on missing keys.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")
    model.eval()

    # BUGFIX: previously this only printed the error and then crashed inside
    # the DataLoader; abort early when the dataset directory is missing.
    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    loader = DataLoader(
        MAV(args.datadir, input_transform_MAV, target_transform_MAV, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()
        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)
        # Per-pixel argmax over class scores of the first batch element.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_MAV = MAV_trainIds2labelIds(label.unsqueeze(0))
        filenameSave = "./save_results/" + filename[0].split("Images/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        label_MAV.save(filenameSave)
        print(step, filenameSave)
def main(args):
    """Evaluate ERFNet on the RELLIS-3D test split: save colorized outputs
    and report per-image forward time."""
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader: skips checkpoint keys the model does not have.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")
    model.eval()

    # RELLIS-3D dataloader (encoder-only mode disabled).
    enc = False
    loader_test = custom_datasets.setup_loaders(args, enc)

    # For the visualizer: launch "python3.6 -m visdom.server -port 8097"
    # in another window and open localhost:8097.
    if args.visualize:
        vis = visdom.Visdom()

    # BUGFIX: time_train was never initialized, so the first append raised
    # NameError; collect per-image forward times in a local list.
    time_train = []

    eval_save_path = "./save_colour_rellis/"

    for step, (images, labels, img_name, _) in enumerate(loader_test):
        start_time = time.time()
        if not args.cpu:
            images = images.cuda()
        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))

        if not os.path.exists(eval_save_path):
            os.makedirs(eval_save_path)
        _, file_name = os.path.split(img_name[0])
        file_name = file_name + ".png"
        label_save = ToPILImage()(label_color)
        label_save.save(os.path.join(eval_save_path, file_name))

        if args.visualize:
            vis.image(label_color.numpy())

        if step != 0:  # the first run always takes extra time for setup
            fwt = time.time() - start_time
            time_train.append(fwt)
            print("Forward time per img (b=%d): %.3f (Mean: %.3f)" %
                  (args.batch_size, fwt / args.batch_size,
                   sum(time_train) / len(time_train) / args.batch_size))

        print(step, os.path.join(eval_save_path, file_name))
def main(args):
    """Run ERFNet on a video file frame by frame, display the colorized
    segmentation, and write it to 'output.mp4'.

    Press 'q' in the display window to stop early.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights
    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    # Video writer for the segmented output; the frame size must match the
    # 640x352 frames produced below.
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    out = cv2.VideoWriter('output.mp4', fourcc, 20.0, (640, 352))

    def load_my_state_dict(model, state_dict):
        # Custom loader: skips checkpoint keys the model does not have.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(
        model, torch.load(weightspath, map_location=torch.device('cpu')))
    print("Model and weights LOADED successfully")
    model.eval()

    cap = cv2.VideoCapture('project_video_trimmed.mp4')
    while True:
        # Capture frame-by-frame.
        ret, images = cap.read()
        # BUGFIX: stop cleanly at end-of-stream; previously a failed read
        # passed None into the transform and crashed.
        if not ret:
            break
        images = trans(images)
        images = images.float()
        images = images.view((1, 3, 352, 640))  # video frame as NCHW batch

        if not args.cpu:
            images = images.cuda()
        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))
        frame = label_color.numpy().transpose(1, 2, 0)  # CHW -> HWC for OpenCV

        cv2.imshow('Segmented Image', frame)
        out.write(frame)  # append frame to the output video
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()
    # BUGFIX: removed trailing vis.image/print statements that referenced
    # undefined names (vis, step, filenameSave) and raised NameError here.
def main(args):
    """Prepare the save directory, optionally warm-start the model from
    args.state, reload the best checkpoint, and run inference."""
    savedir = f'../save/{args.savedir}'
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    # Record the options used for this run next to the results.
    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(args))

    # Load the model definition and keep a copy of its source with the run.
    assert os.path.exists(args.model + ".py"), "Error: model definition not found"
    model = ERFNet(NUM_CLASSES)
    copyfile(args.model + ".py", savedir + '/' + args.model + ".py")
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded"

    def load_my_state_dict(model, state_dict):
        # Custom loader: skips checkpoint keys the model does not have.
        # IMPROVED: previously this helper was defined twice with identical
        # bodies; it is now defined once and reused.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    if args.state:
        # args.state only provides initialized weights; to resume a training
        # run use the "--resume" option instead.
        model = load_my_state_dict(model, torch.load(args.state))

    print("========== TRAINING FINISHED ===========")
    print("========== START TESTING ==============")

    # NOTE(review): hard-coded user path -- consider making it configurable.
    model_dir = "/home/pandongwei/work_repository/erfnet_pytorch/save/" + args.savedir + '/model_best.pth'
    model = load_my_state_dict(model, torch.load(model_dir))

    inference(model, args)
def main(args):
    """Run ERFNet on the geoMat dataset and save each prediction as a
    grayscale matplotlib figure under ./getMat_2/."""
    modelpath = args.loadDir + args.loadModel
    # TODO(review): weightspath is hard-coded to a local checkpoint and
    # args.loadWeights is ignored -- restore `args.loadDir + args.loadWeights`
    # to honor the CLI value.
    weightspath = "/home/pan/repository/erfnet_pytorch/save/geoMat_2/model_best.pth"
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader: skips checkpoint keys the model does not have.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")
    model.eval()

    # BUGFIX: abort when the dataset directory is missing instead of
    # continuing into a DataLoader that would crash.
    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    loader = DataLoader(
        geoMat(args.datadir, input_transform_geoMat, target_transform_geoMat, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    # For the visualizer: launch "python3.6 -m visdom.server -port 8097"
    # in another window and open localhost:8097.
    if args.visualize:
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()
        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)
        # Per-pixel argmax over class scores of the first batch element.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = (label.unsqueeze(0))

        filenameSave = "./getMat_2/" + filename[0].split("material_dataset_v2")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)

        label_save = label_color.numpy().transpose(1, 2, 0)  # CHW -> HWC
        images = images.cpu().numpy().squeeze(axis=0).transpose(1, 2, 0)

        # Render the prediction as a 10.4x10.4-inch figure at dpi=10
        # (i.e. a 104x104 pixel image) and save it.
        plt.figure(figsize=(10.4, 10.4), dpi=10)
        plt.imshow(label_save, alpha=0.6, cmap='gray')
        plt.axis('off')
        plt.savefig(filenameSave, dpi=10)
        plt.close()  # release the figure to avoid leaking memory per image

        if args.visualize:
            vis.image(label_color.numpy())
        print(step, filenameSave)
class Model:
    """Two-network training wrapper.

    model (ERFNet) maps the left image to a multi-channel map; its product
    with the channel-replicated input is quadrant-pooled, pixel-shuffled,
    and fed to model2 (ldyNet), which is trained to reconstruct the
    background image.
    """

    def __init__(self, args):
        self.args = args
        self.device = args.device
        if args.model == 'stackhourglass':
            self.model = ERFNet(1)
            self.model2 = ldyNet(1)
            self.model.cuda()
            self.model2.cuda()
        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
            self.model2 = torch.nn.DataParallel(self.model2)
            self.model2.cuda()
        if args.mode == 'train':
            self.loss_function = MonodepthLoss(
                n=3, SSIM_w=0.8, disp_gradient_w=0.1, lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            self.optimizer2 = optim.Adam(self.model2.parameters(),
                                         lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)
        else:
            # Evaluation mode: load both pretrained networks and disable
            # augmentation; inference always runs with batch size 1.
            self.model.load_state_dict(torch.load(args.model_path))
            self.model2.load_state_dict(torch.load(args.model2_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)
        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        # Thin helper: open an image from disk with PIL.
        return Image.open(path)

    def train(self):
        """Run the training epochs; checkpoint both networks whenever the
        accumulated validation-loop loss improves."""
        best_loss = float('Inf')  # NOTE(review): never updated below
        best_val_loss = 1000000.0
        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
            if self.args.adjust_lr:
                adjust_learning_rate2(self.optimizer2, epoch,
                                      self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()
            self.model2.train()
            for data in self.loader:
                # Move one batch onto the target device.
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                disps = self.model(left)
                # Replicate the input 16x along the channel axis so it
                # matches the first network's output for the product below.
                i = 0
                left1 = left
                while i < 15:
                    left1 = torch.cat((left1, left), 1)
                    i = i + 1
                Net2Iput = disps.mul(left1)  # element-wise multiply
                # Split into four quadrants (assumes 256x256 spatial size --
                # TODO confirm against the dataloader), reduce each quadrant
                # to one value per channel, then reassemble into a 2x2 map.
                Net2Iput11, Net2Iput12 = Net2Iput.split(128, 2)
                Net2Iput11, Net2Iput21 = Net2Iput11.split(128, 3)
                Net2Iput12, Net2Iput22 = Net2Iput12.split(128, 3)
                Net2Iput11 = torch.sum(Net2Iput11, 2, keepdim=True, out=None)
                Net2Iput11 = torch.sum(Net2Iput11, 3, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 2, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 3, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 2, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 3, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 2, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 3, keepdim=True, out=None)
                Net2Iput1 = torch.cat((Net2Iput11, Net2Iput12), 2)
                Net2Iput2 = torch.cat((Net2Iput21, Net2Iput22), 2)
                Net2Iput = torch.cat((Net2Iput1, Net2Iput2), 3)
                # Rearrange channels into a 4x larger spatial map.
                ps = nn.PixelShuffle(4)
                Net2Iput = ps(Net2Iput)
                Net2Out = self.model2(Net2Iput)
                # Step 1: update the second network on the reconstruction
                # loss; the graph is kept so it can be reused just below.
                self.optimizer2.zero_grad()
                lossnet2 = self.loss_function(Net2Out, bg_image)
                lossnet2.backward(retain_graph=True)
                self.optimizer2.step()
                # Step 2: update the first network on the same objective.
                self.optimizer.zero_grad()
                loss1 = self.loss_function(Net2Out, bg_image)
                loss1.backward(retain_graph=True)
                self.optimizer.step()
                running_loss += lossnet2.item()
            running_loss /= (self.n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss, 'time:',
                  round(time.time() - c_time, 3), 's')
            # Map the (squared-error-like) loss back to 0-255 pixel units.
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)
            if epoch % 2 == 0:
                self.model.eval()
                i = 0
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)
                    left = data['left_image']
                    bg_image = data['bg_image']
                    # NOTE(review): the validation forward pass is commented
                    # out upstream, so this accumulates the LAST training
                    # batch's loss once per validation sample -- it does not
                    # measure validation data. Confirm intent before relying
                    # on it for checkpoint selection.
                    running_val_loss += lossnet2
                print('running_val_loss = %.12f' % (running_val_loss))
                if running_val_loss < best_val_loss:
                    # Checkpoint both networks with a "_cpt" suffix.
                    self.save(self.args.model_path[:-4] + '_cpt.pth')
                    self.save2(self.args.model2_path[:-4] + '_cpt.pth')
                    best_val_loss = running_val_loss
                    print('Model_saved')

    def save(self, path):
        # Persist the first network's weights.
        torch.save(self.model.state_dict(), path)

    def save2(self, path):
        # Persist the second network's weights.
        torch.save(self.model2.state_dict(), path)

    def load(self, path):
        # Restore the first network's weights.
        self.model.load_state_dict(torch.load(path))
def main(args):
    """Run ERFNet on Cityscapes and save colorized predictions under
    ./save_color/."""
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights
    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader: skips checkpoint keys the model does not have.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")
    model.eval()

    # BUGFIX: abort when the dataset directory is missing instead of
    # crashing later inside the DataLoader.
    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    loader = DataLoader(
        cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    # For the visualizer: launch "python3.6 -m visdom.server -port 8097"
    # in another window and open localhost:8097.
    if args.visualize:
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()
        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)
        # Per-pixel argmax over class scores of the first batch element.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))

        filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        label_save = ToPILImage()(label_color)
        label_save.save(filenameSave)

        if args.visualize:
            vis.image(label_color.numpy())
        print(step, filenameSave)
def main(args):
    """Evaluate ERFNet on Cityscapes and print per-class and mean IoU."""
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights
    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    if not args.cpu:
        model = torch.nn.DataParallel(model).cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader: copes with checkpoints saved with or without the
        # DataParallel "module." prefix and reports keys it cannot place.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                if name.startswith("module."):
                    own_state[name.split("module.")[-1]].copy_(param)
                else:
                    print(name, " not loaded")
                    continue
            else:
                own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath, map_location=lambda storage, loc: storage))
    print("Model and weights LOADED successfully")
    model.eval()

    # BUGFIX: abort when the dataset directory is missing instead of
    # crashing later inside the DataLoader.
    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    loader = DataLoader(
        cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    iouEvalVal = iouEval(NUM_CLASSES)
    start = time.time()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()
            labels = labels.cuda()
        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)
        # Accumulate confusion statistics from the argmax prediction.
        iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels)
        filenameSave = filename[0].split("leftImg8bit/")[1]
        print(step, filenameSave)

    iouVal, iou_classes = iouEvalVal.getIoU()

    # Format each class IoU as a colorized percentage string.
    iou_classes_str = []
    for i in range(iou_classes.size(0)):
        iouStr = getColorEntry(iou_classes[i]) + '{:0.2f}'.format(iou_classes[i] * 100) + '\033[0m'
        iou_classes_str.append(iouStr)

    print("---------------------------------------")
    print("Took ", time.time() - start, "seconds")
    print("=======================================")
    print("Per-Class IoU:")
    # Cityscapes class names in trainId order (19 classes).
    class_names = ("Road", "sidewalk", "building", "wall", "fence", "pole",
                   "traffic light", "traffic sign", "vegetation", "terrain",
                   "sky", "person", "rider", "car", "truck", "bus", "train",
                   "motorcycle", "bicycle")
    # IMPROVED: single loop replaces 19 copy-pasted print statements,
    # emitting byte-identical output.
    for iou_str, cls_name in zip(iou_classes_str, class_names):
        print(iou_str, cls_name)
    print("=======================================")
    iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(iouVal * 100) + '\033[0m'
    print("MEAN IoU: ", iouStr, "%")
def main():
    """Instantiate an ERFNet on the GPU and hand it to the training loop."""
    network = ERFNet(NUM_CLASSES).cuda()
    train(network)
class Model:
    """Evaluation wrapper: runs a pretrained ERFNet over the test loader and
    reports average L1 loss and SSIM against the background images."""

    def __init__(self, args):
        self.args = args
        self.device = args.device
        if args.model == 'stackhourglass':
            self.model = ERFNet(2)
            self.model.cuda()
            # Load the pretrained weights to evaluate.
            self.model.load_state_dict(torch.load(args.loadmodel))
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)
        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def test(self):
        """Iterate the loader, accumulate L1/SSIM metrics, and show each
        prediction next to its target image."""
        i = 0
        sum_loss = 0.0
        sum_ssim = 0.0
        average_ssim = 0.0
        average_loss = 0.0
        for epoch in range(self.args.epochs):
            self.model.eval()
            for data in self.loader:
                i = i + 1  # running sample count across all epochs
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                disps = self.model(left)
                print(disps.shape)
                l_loss = l1loss(disps,bg_image)
                ssim_loss = ssim(disps,bg_image)
                # NOTE(review): PSNR is computed but never reported or used.
                psnr222 = psnr1(disps,bg_image)
                sum_loss = sum_loss + l_loss.item()
                sum_ssim += ssim_loss.item()
                # Running averages, recomputed each iteration.
                average_ssim = sum_ssim / i
                average_loss = sum_loss / i
                # Show prediction (left) and target (right) side by side;
                # plt.show() blocks until the window is closed.
                disp_show = disps.squeeze()
                bg_show = bg_image.squeeze()
                print(bg_show.shape)
                plt.figure()
                plt.subplot(1,2,1)
                plt.imshow(disp_show.data.cpu().numpy())
                plt.subplot(1,2,2)
                plt.imshow(bg_show.data.cpu().numpy())
                plt.show()
        print('average loss',average_loss,average_ssim)
def main(args):
    """Run lane-segmentation inference on the avlane test split, report the
    running IoU, and display input/prediction side by side."""
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights
    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    if args.pretrainedEncoder:
        pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
        pretrainedEnc = next(pretrainedEnc.children()).features.encoder
        if not args.cuda:
            pretrainedEnc = pretrainedEnc.cpu()
        model = ERFNet(NUM_CLASSES, encoder=pretrainedEnc)
    else:
        model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if args.cuda:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader: skips checkpoint keys the model does not have.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")
    model.eval()

    # BUGFIX: abort when the dataset directory is missing instead of
    # crashing later inside the DataLoader.
    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
        return

    dataset_test = avlane(args.datadir, input_transform_lot, label_transform_lot, 'test')
    loader = DataLoader(dataset_test, num_workers=args.num_workers,
                        batch_size=args.batch_size, shuffle=False)

    # For the visualizer: launch "python3.6 -m visdom.server -port 8097"
    # in another window and open localhost:8097.
    if args.visualize:
        vis = visdom.Visdom()

    # Matplotlib window reused as a live "video" display below.
    fig = plt.figure()
    ax = fig.gca()
    h = ax.imshow(Image.new('RGB', (640 * 2, 480), 0))
    print(len(loader.dataset))

    iouEvalTest = iouEval_binary(NUM_CLASSES)
    with torch.no_grad():
        for step, (images, labels, filename) in enumerate(loader):
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            inputs = Variable(images)
            targets = labels
            outputs = model(inputs)
            # Binarize the output at 0.5.  BUGFIX: allocate the constant
            # tensors on the output's device instead of calling .cuda()
            # unconditionally (crashed when running without --cuda).
            device = outputs.device
            preds = torch.where(outputs > 0.5,
                                torch.ones([1], dtype=torch.long, device=device),
                                torch.zeros([1], dtype=torch.long, device=device))

            # Running IoU over everything seen so far.
            iouEvalTest.addBatch(preds[:, 0], targets[:, 0])
            iouTest = iouEvalTest.getIoU()
            iouStr = "test IOU: " + '{:0.2f}'.format(iouTest.item() * 100) + "%"
            print(iouStr)

            filenameSave = os.path.join(args.loadDir, 'test_results',
                                        filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            pred = preds.to(torch.uint8).squeeze(0).cpu()  # 1 x h x w
            pred_save = pred_transform_lot(pred)

            # Concatenate the input image and the prediction side by side.
            im1 = Image.open(os.path.join(args.datadir, 'data/test',
                                          filename[0].split("test/")[1])).convert('RGB')
            im2 = pred_save
            dst = Image.new('RGB', (im1.width + im2.width, im1.height))
            dst.paste(im1, (0, 0))
            dst.paste(im2, (im1.width, 0))
            filenameSaveConcat = os.path.join(args.loadDir, 'test_results_concat',
                                              filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSaveConcat), exist_ok=True)

            # Overlay the running IoU text on the concatenated image.
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/freefont/FreeMonoBold.ttf', 36)
            d = ImageDraw.Draw(dst)
            d.text((900, 0), iouStr, font=font, fill=(255, 255, 0))

            # Update the live display.
            h.set_data(dst)
            plt.draw()
            plt.axis('off')
            plt.pause(1e-2)

            if args.visualize:
                # BUGFIX: previously referenced the undefined name label_save
                # (its assignment was commented out), raising NameError.
                vis.image(pred.numpy())
            print(step, filenameSave)
def main():
    """Parse CLI arguments, load the chosen pretrained vessel-segmentation
    model, and run inference on the selected dataset's test images."""
    parser = ArgumentParser()
    arg = parser.add_argument
    arg('--model', default='M2UNet', type=str,
        choices=['M2UNet', 'DRIU', 'ERFNet', 'UNet'])
    arg('--state_dict', default='M2UNetDRIVE.pth', type=str,
        help='pretrained model weights file, stored in models')
    arg('--dataset', default='DRIVE', choices=['DRIVE', 'CHASE_DB1', 'HRF'], type=str,
        help=
        'determines the dataset directory and the amount of cropping that is performed to ensure that the loaded images are multiples of 32.'
        )
    arg('--threshold', default=0.5215, type=float,
        help='threshold to convert probability vessel map to binary map')
    arg('--devicename', default='cpu', type=str, help='device type, default: "cpu"')
    arg('--batch_size', default=1, type=int, help='inference batch size, default: 1')
    # BUGFIX: these were plain value options with default=False, so ANY
    # supplied string -- even "False" -- parsed as truthy; store_true makes
    # them proper boolean flags.
    arg('--save_prob', action='store_true', help='save probability vessel maps to disk')
    arg('--save_binary_mask', action='store_true', help='save binary mask to disk')

    # Fixed directory layout.
    model_path = Path('models')
    data_path = Path('data')

    args = parser.parse_args()
    state_dict_path = model_path.joinpath(args.state_dict)
    dataset_path = data_path.joinpath(args.dataset)
    image_file_path = dataset_path.joinpath('test/images')
    prediction_output_path = dataset_path.joinpath('predictions')
    threshold = args.threshold
    devicename = args.devicename
    batch_size = args.batch_size
    dataset = args.dataset
    save_prob = args.save_prob
    save_binary_mask = args.save_binary_mask

    pin_memory = True
    cudnn.benchmark = True
    device = torch.device(devicename)
    if devicename == 'cpu':
        # On CPU: disable cudnn autotuning, do not pin memory, and keep the
        # run on a single core.
        cudnn.benchmark = False
        pin_memory = False
        torch.set_num_threads(1)

    # BUGFIX: an unmatched model/dataset combination (e.g. DRIU + HRF) left
    # `model` undefined and raised NameError later; fail loudly instead.
    if args.model == 'M2UNet':
        model = m2unet()
    elif args.model == 'DRIU' and dataset in ('DRIVE', 'CHASE_DB1'):
        model = DRIU()
    elif args.model == 'ERFNet':
        model = ERFNet(1)
    elif args.model == 'UNet':
        model = UNet(in_channels=3, n_classes=1, depth=5, wf=6, padding=True,
                     batch_norm=False, up_mode='upconv')
    else:
        raise ValueError(f'unsupported model/dataset combination: {args.model}/{dataset}')

    state_dict = torch.load(str(state_dict_path), map_location=devicename)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # All test image paths, plus bare file names for naming the outputs.
    file_paths = get_file_lists(image_file_path)
    file_names = list(map(lambda x: x.stem, file_paths))

    dataloader = DataLoader(dataset=RetinaDataset(file_paths, dataset),
                            batch_size=batch_size, shuffle=False,
                            num_workers=1, pin_memory=pin_memory)
    run(model, dataloader, batch_size, threshold, device, save_prob,
        save_binary_mask, prediction_output_path, file_names)