def train():
    """Train ERFNet on Cityscapes and save a checkpoint after every epoch.

    Relies on module-level globals: num_classes, num_epochs, batch_size,
    weight (class-balance weights for NLL), IoU, meter, and eval (the
    per-epoch validation routine) -- TODO confirm they exist at module scope.
    """
    train_dataset = Cityscapes(root='.', subset='train', transform=MyTransform())
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    erfnet = ERFNet(num_classes)
    optimizer = Adam(erfnet.parameters(), 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=1e-4)
    # Polynomial LR decay (power 0.9) across the whole run.
    lambda1 = lambda epoch: pow((1 - ((epoch - 1) / num_epochs)), 0.9)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    criterion = nn.NLLLoss2d(weight)  # NOTE: deprecated alias of nn.NLLLoss on newer torch
    iou_meter = meter.AverageValueMeter()

    erfnet.cuda()
    erfnet.train()

    for epoch in range(num_epochs):
        print("Epoch {} : ".format(epoch + 1))
        total_loss = []
        iou_meter.reset()
        # Advance the LR schedule exactly once per epoch.  The original code
        # additionally called scheduler.step(loss_avg) at epoch end, feeding a
        # *loss value* to LambdaLR as if it were an epoch index and corrupting
        # the schedule; that call has been removed.
        scheduler.step(epoch)
        for i, (images, labels) in enumerate(train_loader):
            # Inputs need no gradients; the redundant trailing .cuda() calls
            # and requires_grad=True were dropped.
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            outputs = erfnet(images)

            optimizer.zero_grad()
            # log_softmax over the class dimension (dim=1) + NLL = cross-entropy.
            loss = criterion(torch.nn.functional.log_softmax(outputs, dim=1), labels)
            loss.backward()
            optimizer.step()

            iou = IoU(outputs.max(1)[1].data, labels.data)
            iou_meter.add(iou)
            # loss.item() replaces the deprecated loss.data[0] indexing, which
            # raises on torch >= 0.5.
            total_loss.append(loss.item() * batch_size)

        iou_avg = iou_meter.value()[0]
        loss_avg = sum(total_loss) / len(total_loss)
        print("IoU : {:.5f}".format(iou_avg))
        print("Loss : {:.5f}".format(loss_avg))
        torch.save(erfnet, './model/erfnet' + str(epoch) + '.pth')
        eval(epoch)  # presumably a project-level validation function shadowing builtin eval -- confirm
def main(args):
    """Run ERFNet inference over a Cityscapes subset and save trainId->labelId maps.

    Writes one image per input under ./save_results/, mirroring the
    leftImg8bit directory layout.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader tolerant to missing keys: copies only the entries
        # present in the target model's own state dict.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes,
                                   target_transform_cityscapes, subset=args.subset),
                        num_workers=args.num_workers, batch_size=args.batch_size,
                        shuffle=False)

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()

        # torch.no_grad() replaces the Variable(..., volatile=True) API,
        # which was removed from PyTorch in 0.4.
        with torch.no_grad():
            outputs = model(Variable(images))

        # Argmax over classes -> HxW trainId map for the first batch element.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))

        filenameSave = "./save_results/" + filename[0].split("leftImg8bit/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        label_cityscapes.save(filenameSave)

        print(step, filenameSave)
def midfunc():
    """Load a trained ERFNet checkpoint onto the GPU and run inference.

    Returns 0 on completion.
    """
    # Converted from Python 2 print statements to Python 3 print() calls --
    # the original syntax is a SyntaxError in this Python 3 code base.
    print('load Model....\n')
    t1 = time.time()
    Net = ERFNet(2)
    model = './erf_save/model/main-erfnet-step-5998-epoch-20.pth'
    Net.load_state_dict(torch.load(model))
    Net = Net.cuda()
    print('Done.\n')
    print('time:', time.time() - t1, '\n')
    print('compute output....\n')
    infer(Net)
    return 0
def main(args):
    # Colourised ERFNet inference over Cityscapes: writes per-pixel class
    # colour maps under ./save_color/ and optionally streams them to visdom.
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    #Import ERFNet model from the folder
    #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet")
    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(model, state_dict):  #custom function to load model when not all dict elements
        # Copies only parameters whose names exist in the target model, so
        # checkpoints with missing/extra keys still load.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
                        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            #labels = labels.cuda()

        inputs = Variable(images)
        #targets = Variable(labels)
        with torch.no_grad():
            outputs = model(inputs)

        # Argmax over classes -> HxW trainId map for the first batch element.
        label = outputs[0].max(0)[1].byte().cpu().data
        #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = Colorize()(label.unsqueeze(0))

        filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        #image_transform(label.byte()).save(filenameSave)
        label_save = ToPILImage()(label_color)
        label_save.save(filenameSave)

        if (args.visualize):
            vis.image(label_color.numpy())
        print(step, filenameSave)
def main(args):
    """Evaluate ERFNet on Cityscapes and report per-class and mean IoU."""
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader tolerant to missing keys; reports the names it skips.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                print(name, " not loaded")
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes,
                                   target_transform_cityscapes, subset=args.subset),
                        num_workers=args.num_workers, batch_size=args.batch_size,
                        shuffle=False)

    iouEvalVal = iouEval(NUM_CLASSES)

    start = time.time()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            labels = labels.cuda()

        # torch.no_grad() replaces the Variable(..., volatile=True) API,
        # which was removed from PyTorch in 0.4.
        with torch.no_grad():
            outputs = model(Variable(images))

        iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels)

        filenameSave = filename[0].split("leftImg8bit/")[1]
        print(step, filenameSave)

    iouVal, iou_classes = iouEvalVal.getIoU()

    # ANSI-coloured "xx.xx" strings, one per class.
    iou_classes_str = []
    for i in range(iou_classes.size(0)):
        iouStr = getColorEntry(iou_classes[i]) + '{:0.2f}'.format(iou_classes[i] * 100) + '\033[0m'
        iou_classes_str.append(iouStr)

    print("---------------------------------------")
    print("Took ", time.time() - start, "seconds")
    print("=======================================")
    #print("TOTAL IOU: ", iou * 100, "%")
    print("Per-Class IoU:")
    # The 19 Cityscapes training classes in trainId order; replaces the
    # previous 19 copy-pasted print statements (identical output).
    class_names = ["Road", "sidewalk", "building", "wall", "fence", "pole",
                   "traffic light", "traffic sign", "vegetation", "terrain",
                   "sky", "person", "rider", "car", "truck", "bus", "train",
                   "motorcycle", "bicycle"]
    for class_iou_str, class_name in zip(iou_classes_str, class_names):
        print(class_iou_str, class_name)
    print("=======================================")
    iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(iouVal * 100) + '\033[0m'
    print("MEAN IoU: ", iouStr, "%")
class Model:
    """Training/eval wrapper around an ERFNet used for background reconstruction.

    In 'train' mode it optimises the network against a MonodepthLoss plus a
    (1 - SSIM) term, validates every 5 epochs, and checkpoints on improvement.
    """

    def __init__(self, args):
        self.args = args
        self.device = args.device
        if args.model == 'stackhourglass':
            self.model = ERFNet(2)
        self.model.cuda()  #= self.model.to(self.device)
        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()

        if args.mode == 'train':
            self.loss_function = MonodepthLoss(n=3, SSIM_w=0.8,
                                               disp_gradient_w=0.1, lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(), lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)
            # self.model.load_state_dict(torch.load(args.loadmodel))  # optionally warm-start from a checkpoint
        else:
            self.model.load_state_dict(torch.load(args.model_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1

        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Open the disparity image at *path* as a PIL image."""
        return Image.open(path)

    def train(self):
        """Run the full training loop; validate every 5 epochs and checkpoint
        to <model_path>_cpt.pth whenever validation loss improves."""
        best_val_loss = 1000000.0

        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch, self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()
            for data in self.loader:
                self.optimizer.zero_grad()
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']

                disps = self.model(left)

                loss = self.loss_function(disps, bg_image)
                ssim_loss = ssim(disps, bg_image)
                # BUGFIX: the result used to be stored in a local named `ssim`,
                # which made `ssim` function-local everywhere in this method
                # and caused the ssim() call above to raise UnboundLocalError.
                # A distinct name avoids the shadowing.
                dssim = 1 - ssim_loss
                loss1 = loss * 1 + dssim * 1
                loss1.backward()
                self.optimizer.step()
                running_loss += loss1.item()

            running_loss /= (self.n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss,
                  'time:', round(time.time() - c_time, 3), 's')
            # Rough RMS pixel error in 0-255 units -- assumes the loss is an
            # MSE-like quantity on [0, 1] images; TODO confirm.
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)

            if epoch % 5 == 0:
                self.model.eval()
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)
                    left = data['left_image']
                    bg_image = data['bg_image']
                    with torch.no_grad():
                        disps = self.model(left)
                        loss = self.loss_function(disps, bg_image)
                        ssim_loss = ssim(disps, bg_image)
                        dssim = 1 - ssim_loss
                        loss1 = loss * 1 + dssim * 1
                        running_val_loss += loss1

                # NOTE(review): divides by a hard-coded 200 rather than the
                # actual number of validation batches -- confirm intent.
                running_val_loss /= 200
                print('running_val_loss = %.12f' % (running_val_loss))
                if running_val_loss < best_val_loss:
                    self.save(self.args.model_path[:-4] + '_cpt.pth')
                    best_val_loss = running_val_loss
                    print('Model_saved')

    def save(self, path):
        """Serialise the model weights to *path*."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Restore model weights from *path*."""
        self.model.load_state_dict(torch.load(path))
def main(args):
    """Run ERFNet segmentation on a video file, display the colourised
    result live, and record it to output.mp4.  Press 'q' to stop early."""
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    fourcc = cv2.VideoWriter_fourcc(*'MP4V')  # Save as video
    out = cv2.VideoWriter('output.mp4', fourcc, 20.0, (640, 352))

    def load_my_state_dict(model, state_dict):
        # Custom loader tolerant to checkpoints with missing/extra keys.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath, map_location=torch.device('cpu')))
    print("Model and weights LOADED successfully")

    model.eval()

    # cap = cv2.VideoCapture(0)
    cap = cv2.VideoCapture('project_video_trimmed.mp4')
    while (True):
        # Capture frame-by-frame
        ret, images = cap.read()
        # BUGFIX: stop cleanly at end-of-stream instead of crashing when
        # cap.read() returns (False, None).
        if not ret:
            break
        images = trans(images)
        images = images.float()
        images = images.view((1, 3, 352, 640))  # frame -> NCHW tensor

        if (not args.cpu):
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # Argmax over classes -> HxW label map, then colourise.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))
        frame = label_color.numpy().transpose(1, 2, 0)  # CHW -> HWC for OpenCV

        # Display the resulting frame
        cv2.imshow('Segmented Image', frame)
        out.write(frame)  # To save video file
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()
    # The original trailing visualize block referenced undefined names
    # (vis, step, filenameSave) left over from a dataset-loader variant and
    # would always raise NameError; it has been removed.
def main(args):
    # Colourised ERFNet inference over the RELLIS-3D test split; saves the
    # per-pixel colour maps and reports per-image forward time.
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")
    else:
        print("Loading model: " + modelpath)
        print("Loading weights: " + weightspath)

        # Import ERFNET
        model = ERFNet(NUM_CLASSES)
        model = torch.nn.DataParallel(model)
        if (not args.cpu):
            model = model.cuda()

        def load_my_state_dict(model, state_dict):  #custom function to load model when not all dict elements
            # Copies only parameters present in the model's own state dict,
            # so checkpoints with missing/extra keys still load.
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                own_state[name].copy_(param)
            return model

        model = load_my_state_dict(model, torch.load(weightspath))
        print("Model and weights LOADED successfully")

        # Set model to Evaluation mode
        model.eval()

        # Setup the dataset loader
        ### RELLIS-3D Dataloader
        enc = False
        loader_test = custom_datasets.setup_loaders(args, enc)

        # For visualizer:
        # must launch in other window "python3.6 -m visdom.server -port 8097"
        # and access localhost:8097 to see it
        if (args.visualize):
            vis = visdom.Visdom()

        for step, (images, labels, img_name, _) in enumerate(loader_test):
            start_time = time.time()
            if (not args.cpu):
                images = images.cuda()
                #labels = labels.cuda()

            inputs = Variable(images)
            with torch.no_grad():
                outputs = model(inputs)

            # Argmax over classes -> HxW label map for the first batch element.
            label = outputs[0].max(0)[1].byte().cpu().data
            label_color = Colorize()(label.unsqueeze(0))

            eval_save_path = "./save_colour_rellis/"
            if not os.path.exists(eval_save_path):
                os.makedirs(eval_save_path)
            _, file_name = os.path.split(img_name[0])
            file_name = file_name + ".png"

            #image_transform(label.byte()).save(filenameSave)
            label_save = ToPILImage()(label_color)
            label_save.save(os.path.join(eval_save_path, file_name))

            if (args.visualize):
                vis.image(label_color.numpy())

            if step != 0:  #first run always takes some time for setup
                # NOTE(review): time_train is not defined in this function --
                # presumably a module-level list; confirm before relying on it.
                fwt = time.time() - start_time
                time_train.append(fwt)
                print("Forward time per img (b=%d): %.3f (Mean: %.3f)" % (args.batch_size, fwt / args.batch_size, sum(time_train) / len(time_train) / args.batch_size))

            print(step, os.path.join(eval_save_path, file_name))
def main(args): modelpath = args.loadDir + args.loadModel #weightspath = args.loadDir + args.loadWeights #TODO weightspath = "/home/pan/repository/erfnet_pytorch/save/geoMat_2/model_best.pth" # print ("Loading model: " + modelpath) print("Loading weights: " + weightspath) #Import ERFNet model from the folder #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet") model = ERFNet(NUM_CLASSES) model = torch.nn.DataParallel(model) if (not args.cpu): model = model.cuda() #model.load_state_dict(torch.load(args.state)) #model.load_state_dict(torch.load(weightspath)) #not working if missing key def load_my_state_dict( model, state_dict ): #custom function to load model when not all dict elements own_state = model.state_dict() for name, param in state_dict.items(): if name not in own_state: continue own_state[name].copy_(param) return model model = load_my_state_dict(model, torch.load(weightspath)) print("Model and weights LOADED successfully") model.eval() if (not os.path.exists(args.datadir)): print("Error: datadir could not be loaded") loader = DataLoader(geoMat(args.datadir, input_transform_geoMat, target_transform_geoMat, subset=args.subset), num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) # For visualizer: # must launch in other window "python3.6 -m visdom.server -port 8097" # and access localhost:8097 to see it if (args.visualize): vis = visdom.Visdom() for step, (images, labels, filename, filenameGt) in enumerate(loader): if (not args.cpu): images = images.cuda() #labels = labels.cuda() inputs = Variable(images) #targets = Variable(labels) with torch.no_grad(): outputs = model(inputs) label = outputs[0].max(0)[1].byte().cpu().data # print(label.shape) #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0)) label_color = (label.unsqueeze(0)) # print(label_color.shape) filenameSave = "./getMat_2/" + filename[0].split( "material_dataset_v2")[1] os.makedirs(os.path.dirname(filenameSave), exist_ok=True) 
#image_transform(label.byte()).save(filenameSave) # label_save = ToPILImage()(label_color) label_save = label_color.numpy() label_save = label_save.transpose(1, 2, 0) # label_save.save(filenameSave) images = images.cpu().numpy().squeeze(axis=0) images = images.transpose(1, 2, 0) # print(images.shape) # print(label_save.shape) plt.figure(figsize=(10.4, 10.4), dpi=10) # plt.imshow(images) plt.imshow(label_save, alpha=0.6, cmap='gray') plt.axis('off') # plt.show() plt.savefig(filenameSave, dpi=10) plt.close() if (args.visualize): vis.image(label_color.numpy()) print(step, filenameSave)
class Model:
    # Two-stage training wrapper: self.model (ERFNet) produces a map from the
    # left image; its output, modulated by the tiled input, is quadrant-pooled,
    # pixel-shuffled and fed to self.model2 (ldyNet), which reconstructs the
    # background image.  Both networks are optimised on the same loss.
    def __init__(self, args):
        self.args = args
        self.device = args.device
        if args.model == 'stackhourglass':
            self.model = ERFNet(1)
            self.model2 = ldyNet(1)
        self.model.cuda()  #= self.model.to(self.device)
        self.model2.cuda()
        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
            self.model2 = torch.nn.DataParallel(self.model2)
            self.model2.cuda()

        if args.mode == 'train':
            self.loss_function = MonodepthLoss(n=3, SSIM_w=0.8, disp_gradient_w=0.1, lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(), lr=args.learning_rate)
            self.optimizer2 = optim.Adam(self.model2.parameters(), lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)
            # self.model.load_state_dict(torch.load(args.loadmodel))  # optionally warm-start from a pretrained model
        else:
            self.model.load_state_dict(torch.load(args.model_path))
            self.model2.load_state_dict(torch.load(args.model2_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1

        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        # Open the disparity image at *path* as a PIL image.
        return Image.open(path)

    def train(self):
        # Alternating updates: each batch first steps optimizer2 (model2) on
        # the reconstruction loss, then steps optimizer (model) on the same
        # quantity, using retain_graph so the shared graph survives both
        # backward passes.
        # losses = []
        #val_losses = []
        best_loss = float('Inf')
        best_val_loss = 1000000.0
        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch, self.args.learning_rate)
            if self.args.adjust_lr:
                adjust_learning_rate2(self.optimizer2, epoch, self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()
            self.model2.train()
            for data in self.loader:
                # Load data
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                #template=data['template']
                #randomMatrix=np.random.randint(0,1,(256,256))
                # Net1Iput=torch.from_numpy(randomMatrix)
                disps = self.model(left)
                # Tile the left image 16x along the channel dim so it matches
                # the channel count of disps for the element-wise product
                # below -- assumes disps has 16x the channels of left; confirm.
                i = 0
                left1 = left
                while i < 15:
                    left1 = torch.cat((left1, left), 1)
                    i = i + 1
                Net2Iput = disps.mul(left1)  # element-wise product
                # Split into four 128x128 spatial quadrants, collapse each to a
                # single summed value, and reassemble into a 2x2 grid.
                Net2Iput11, Net2Iput12 = Net2Iput.split(128, 2)
                Net2Iput11, Net2Iput21 = Net2Iput11.split(128, 3)
                Net2Iput12, Net2Iput22 = Net2Iput12.split(128, 3)
                Net2Iput11 = torch.sum(Net2Iput11, 2, keepdim=True, out=None)
                Net2Iput11 = torch.sum(Net2Iput11, 3, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 2, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 3, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 2, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 3, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 2, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 3, keepdim=True, out=None)
                Net2Iput1 = torch.cat((Net2Iput11, Net2Iput12), 2)
                Net2Iput2 = torch.cat((Net2Iput21, Net2Iput22), 2)
                Net2Iput = torch.cat((Net2Iput1, Net2Iput2), 3)
                # Net2Iput=torch.cat((Net2Iput,Net2Iput22),1)
                # Rearrange channels into space before feeding model2.
                ps = nn.PixelShuffle(4)
                Net2Iput = ps(Net2Iput)
                Net2Out = self.model2(Net2Iput)
                self.optimizer2.zero_grad()
                lossnet2 = self.loss_function(Net2Out, bg_image)
                lossnet2.backward(retain_graph=True)
                self.optimizer2.step()
                self.optimizer.zero_grad()
                loss1 = self.loss_function(Net2Out, bg_image)
                loss1.backward(retain_graph=True)
                self.optimizer.step()
                running_loss += lossnet2.item()

            running_loss /= (self.n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss, 'time:', round(time.time() - c_time, 3), 's')
            # Rough RMS pixel error in 0-255 units -- assumes an MSE-like loss
            # on [0, 1] images; TODO confirm.
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)

            if epoch % 2 == 0:
                self.model.eval()
                i = 0
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)
                    left = data['left_image']
                    bg_image = data['bg_image']
                    #with torch.no_grad():
                    #disps = self.model(left)
                    #Net2Out=self.model2(disps)
                    #lossnet2 = self.loss_function(Net2Out,bg_image)
                    # NOTE(review): the validation forward pass is commented
                    # out, so this accumulates the last *training* lossnet2 on
                    # every iteration -- confirm this is intentional.
                    running_val_loss += lossnet2
                #running_val_loss/=100
                print('running_val_loss = %.12f' % (running_val_loss))
                if running_val_loss < best_val_loss:
                    self.save(self.args.model_path[:-4] + '_cpt.pth')
                    self.save2(self.args.model2_path[:-4] + '_cpt.pth')
                    best_val_loss = running_val_loss
                    print('Model_saved')

    def save(self, path):
        # Serialise model weights to *path*.
        torch.save(self.model.state_dict(), path)

    def save2(self, path):
        # Serialise model2 weights to *path*.
        torch.save(self.model2.state_dict(), path)

    def load(self, path):
        # Restore model weights from *path*.
        self.model.load_state_dict(torch.load(path))
def main():
    """Build an ERFNet for NUM_CLASSES classes, move it to the GPU, and train it."""
    network = ERFNet(NUM_CLASSES).cuda()
    train(network)
class Model:
    # Evaluation-only wrapper: loads a trained ERFNet(2) checkpoint and
    # measures running-average L1 / SSIM (and PSNR) between predictions and
    # background images, visualising each pair with matplotlib.
    def __init__(self, args):
        self.args = args
        self.device = args.device
        if args.model == 'stackhourglass':
            self.model = ERFNet(2)
        self.model.cuda()
        self.model.load_state_dict(torch.load(args.loadmodel))
        # print(self.model.state_dict())
        self.n_img, self.loader = prepare_dataloader(args.data_dir, args.mode, args.augment_parameters, args.do_augmentation, args.batch_size, (args.input_height, args.input_width), args.num_workers)
        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def test(self):
        # Runs args.epochs passes over the loader, accumulating running
        # averages of L1 loss and SSIM and plotting each prediction next to
        # its target.
        i = 0
        sum_loss = 0.0
        sum_ssim = 0.0
        average_ssim = 0.0
        average_loss = 0.0
        for epoch in range(self.args.epochs):
            self.model.eval()
            for data in self.loader:
                i = i + 1
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                disps = self.model(left)
                print(disps.shape)
                l_loss = l1loss(disps, bg_image)
                ssim_loss = ssim(disps, bg_image)
                psnr222 = psnr1(disps, bg_image)  # NOTE(review): computed but never used
                sum_loss = sum_loss + l_loss.item()
                sum_ssim += ssim_loss.item()
                average_ssim = sum_ssim / i
                average_loss = sum_loss / i
                # Show prediction (left) and target (right) side by side.
                disp_show = disps.squeeze()
                bg_show = bg_image.squeeze()
                print(bg_show.shape)
                plt.figure()
                plt.subplot(1, 2, 1)
                plt.imshow(disp_show.data.cpu().numpy())
                plt.subplot(1, 2, 2)
                plt.imshow(bg_show.data.cpu().numpy())
                plt.show()
        print('average loss', average_loss, average_ssim)
def main(args):
    """Binary-lane ERFNet evaluation on the avlane test split.

    Computes a running IoU, builds a side-by-side composite of input and
    thresholded prediction with the IoU drawn on it, previews it live, and
    optionally streams the prediction to visdom.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    if args.pretrainedEncoder:
        pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
        pretrainedEnc = next(pretrainedEnc.children()).features.encoder
        if (not args.cuda):
            pretrainedEnc = pretrainedEnc.cpu()
        model = ERFNet(NUM_CLASSES, encoder=pretrainedEnc)
    else:
        model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if args.cuda:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader tolerant to checkpoints with missing/extra keys.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    dataset_test = avlane(args.datadir, input_transform_lot, label_transform_lot, 'test')
    loader = DataLoader(dataset_test, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    # Live preview canvas: input and prediction side by side.
    fig = plt.figure()
    ax = fig.gca()
    h = ax.imshow(Image.new('RGB', (640 * 2, 480), 0))
    print(len(loader.dataset))

    iouEvalTest = iouEval_binary(NUM_CLASSES)
    with torch.no_grad():
        for step, (images, labels, filename) in enumerate(loader):
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            inputs = Variable(images)
            targets = labels
            outputs = model(inputs)

            # Threshold at 0.5 -> {0,1} long mask, b x 1 x h x w.
            preds = torch.where(outputs > 0.5,
                                torch.ones([1], dtype=torch.long).cuda(),
                                torch.zeros([1], dtype=torch.long).cuda())

            iouEvalTest.addBatch(preds[:, 0], targets[:, 0])
            iouTest = iouEvalTest.getIoU()
            iouStr = "test IOU: " + '{:0.2f}'.format(iouTest.item() * 100) + "%"
            print(iouStr)

            # save the output
            filenameSave = os.path.join(args.loadDir, 'test_results', filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            pred = preds.to(torch.uint8).squeeze(0).cpu()  # 1xhxw
            pred_save = pred_transform_lot(pred)
            #pred_save.save(filenameSave)

            # concatenate data & result
            im1 = Image.open(os.path.join(args.datadir, 'data/test', filename[0].split("test/")[1])).convert('RGB')
            im2 = pred_save
            dst = Image.new('RGB', (im1.width + im2.width, im1.height))
            dst.paste(im1, (0, 0))
            dst.paste(im2, (im1.width, 0))
            filenameSaveConcat = os.path.join(args.loadDir, 'test_results_concat', filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSaveConcat), exist_ok=True)
            #dst.save(filenameSaveConcat)

            # write iou on dst
            font = ImageFont.truetype('/usr/share/fonts/truetype/freefont/FreeMonoBold.ttf', 36)
            d = ImageDraw.Draw(dst)
            d.text((900, 0), iouStr, font=font, fill=(255, 255, 0))

            # show video
            h.set_data(dst)
            plt.draw()
            plt.axis('off')
            plt.pause(1e-2)

            if (args.visualize):
                # BUGFIX: this branch previously referenced `label_save`,
                # which was only assigned in commented-out code and raised a
                # NameError; send the thresholded prediction tensor instead.
                vis.image(pred.numpy())
            print(step, filenameSave)