def train():
    """Train ERFNet for semantic segmentation on the Cityscapes train split.

    Relies on module-level names defined elsewhere in this file:
    Cityscapes, MyTransform, ERFNet, num_classes, num_epochs, weight,
    IoU, meter, and a project-level eval() checkpoint-evaluation helper.

    Saves a full-model checkpoint to ./model/erfnet<epoch>.pth each epoch.
    """
    train_dataset = Cityscapes(root='.', subset='train',
                               transform=MyTransform())
    # BUG FIX: `batch_size` was read later but never defined in this scope;
    # pin it to the value actually given to the DataLoader.
    batch_size = 1
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)

    erfnet = ERFNet(num_classes)
    optimizer = Adam(erfnet.parameters(), 5e-4, (0.9, 0.999),
                     eps=1e-08, weight_decay=1e-4)

    # Polynomial LR decay with power 0.9 (the schedule used by ERFNet).
    lambda1 = lambda epoch: pow((1 - ((epoch - 1) / num_epochs)), 0.9)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

    # NLLLoss replaces the removed NLLLoss2d; it accepts 2-D (H, W) targets.
    criterion = nn.NLLLoss(weight)
    iou_meter = meter.AverageValueMeter()

    erfnet.cuda()
    erfnet.train()

    for epoch in range(num_epochs):
        print("Epoch {} : ".format(epoch + 1))
        total_loss = []
        iou_meter.reset()

        for i, (images, labels) in enumerate(train_loader):
            # Inputs need no gradients; the old requires_grad=True on the
            # images only wasted memory and compute.
            images = images.cuda()
            labels = labels.cuda()

            outputs = erfnet(images)
            optimizer.zero_grad()
            # log_softmax needs an explicit class dimension: dim=1 for
            # (N, C, H, W) logits.
            loss = criterion(
                torch.nn.functional.log_softmax(outputs, dim=1), labels)
            loss.backward()
            optimizer.step()

            iou = IoU(outputs.max(1)[1].data, labels.data)
            iou_meter.add(iou)
            total_loss.append(loss.item() * batch_size)

        iou_avg = iou_meter.value()[0]
        loss_avg = sum(total_loss) / len(total_loss)

        # BUG FIX: the scheduler was previously stepped twice per epoch, the
        # second time as scheduler.step(loss_avg) — LambdaLR interprets its
        # argument as an epoch index, so passing the loss corrupted the LR
        # schedule. Step exactly once per epoch instead.
        scheduler.step()

        print("IoU : {:.5f}".format(iou_avg))
        print("Loss : {:.5f}".format(loss_avg))
        torch.save(erfnet, './model/erfnet' + str(epoch) + '.pth')
        # NOTE(review): eval here shadows the builtin — presumably a
        # project-level per-epoch evaluation helper; confirm it exists.
        eval(epoch)
class Model:
    """Training/inference wrapper for a single ERFNet optimized with a
    Monodepth reconstruction loss plus an SSIM term.

    Relies on project-level names defined elsewhere: ERFNet, MonodepthLoss,
    prepare_dataloader, adjust_learning_rate, to_device, ssim, Image.
    """

    def __init__(self, args):
        self.args = args
        self.device = args.device
        if args.model == 'stackhourglass':
            # NOTE(review): self.model is only created for this args.model
            # value; any other value crashes below — confirm intended.
            self.model = ERFNet(2)
            self.model.cuda()  # = self.model.to(self.device)
        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
        if args.mode == 'train':
            self.loss_function = MonodepthLoss(
                n=3, SSIM_w=0.8, disp_gradient_w=0.1,
                lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            # Validation always loads in 'test' mode, unaugmented, batch 1.
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)
        else:
            # Inference mode: restore weights and force batch size 1.
            self.model.load_state_dict(torch.load(args.model_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)
        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Open an image file with PIL (decoded lazily on first access)."""
        return Image.open(path)

    def train(self):
        """Run the training loop; validate every 5th epoch and checkpoint
        whenever the validation loss improves."""
        best_val_loss = 1000000.0
        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()
            for data in self.loader:
                self.optimizer.zero_grad()
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                disps = self.model(left)
                loss = self.loss_function(disps, bg_image)
                # BUG FIX: this value was previously assigned to a local
                # named `ssim`, shadowing the module-level ssim() function
                # and raising UnboundLocalError on the very first batch.
                ssim_term = 1 - ssim(disps, bg_image)
                loss1 = loss + ssim_term
                loss1.backward()
                self.optimizer.step()
                running_loss += loss1.item()

            running_loss /= (self.n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss,
                  'time:', round(time.time() - c_time, 3), 's')
            # Heuristic rescaling of the mean loss onto a 0-255 pixel scale
            # (assumes a squared-error-like loss on [0, 1] images — confirm).
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)

            if epoch % 5 == 0:
                self.model.eval()
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)
                    left = data['left_image']
                    bg_image = data['bg_image']
                    with torch.no_grad():
                        disps = self.model(left)
                        loss = self.loss_function(disps, bg_image)
                        ssim_term = 1 - ssim(disps, bg_image)
                        loss1 = loss + ssim_term
                    # Accumulate a plain float, not a tensor.
                    running_val_loss += loss1.item()
                # NOTE(review): 200 looks like a hard-coded validation-set
                # size; presumably should be self.val_n_img — confirm.
                running_val_loss /= 200
                print('running_val_loss = %.12f' % (running_val_loss))
                if running_val_loss < best_val_loss:
                    self.save(self.args.model_path[:-4] + '_cpt.pth')
                    best_val_loss = running_val_loss
                    print('Model_saved')

    def save(self, path):
        """Persist the model weights (state_dict only) to `path`."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Restore model weights previously written by save()."""
        self.model.load_state_dict(torch.load(path))
class Model:
    """Two-stage training wrapper: ERFNet (stage 1) produces a multi-channel
    mask that is applied to the input, pooled per quadrant, pixel-shuffled,
    and fed to a second network ldyNet (stage 2); both networks are trained
    against bg_image with the same Monodepth loss.

    NOTE(review): this second `Model` definition shadows the earlier class
    of the same name in this file — only this one is visible to importers.
    """

    def __init__(self, args):
        self.args = args
        self.device = args.device
        if args.model == 'stackhourglass':
            # NOTE(review): self.model / self.model2 are only created for
            # this args.model value; any other value crashes below.
            self.model = ERFNet(1)
            self.model2 = ldyNet(1)
            self.model.cuda()  # = self.model.to(self.device)
            self.model2.cuda()
        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
            self.model2 = torch.nn.DataParallel(self.model2)
            self.model2.cuda()
        if args.mode == 'train':
            self.loss_function = MonodepthLoss(
                n=3, SSIM_w=0.8, disp_gradient_w=0.1,
                lr_w=1).to(self.device)
            # One optimizer per network, same learning rate.
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            self.optimizer2 = optim.Adam(self.model2.parameters(),
                                         lr=args.learning_rate)
            # Validation always loads in 'test' mode, unaugmented, batch 1.
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)
            # self.model.load_state_dict(torch.load(args.loadmodel))  # load a pretrained model here if needed
        else:
            # Inference mode: restore both networks and force batch size 1.
            self.model.load_state_dict(torch.load(args.model_path))
            self.model2.load_state_dict(torch.load(args.model2_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)
        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Open an image file with PIL (decoded lazily on first access)."""
        return Image.open(path)

    def train(self):
        """Train both networks; every 2nd epoch run the (currently inert —
        see NOTE below) validation pass and checkpoint on improvement."""
        best_loss = float('Inf')  # never updated below
        best_val_loss = 1000000.0
        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
            if self.args.adjust_lr:
                adjust_learning_rate2(self.optimizer2, epoch,
                                      self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()
            self.model2.train()
            for data in self.loader:
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                # Stage 1 forward pass.
                disps = self.model(left)
                # Replicate the input 16x along the channel axis.
                i = 0
                left1 = left
                while i < 15:
                    left1 = torch.cat((left1, left), 1)
                    i = i + 1
                Net2Iput = disps.mul(left1)  # element-wise product
                # Split spatially into four quadrants (chunks of 128 along
                # dims 2 and 3 — assumes a 256x256 map; TODO confirm) ...
                Net2Iput11, Net2Iput12 = Net2Iput.split(128, 2)
                Net2Iput11, Net2Iput21 = Net2Iput11.split(128, 3)
                Net2Iput12, Net2Iput22 = Net2Iput12.split(128, 3)
                # ... reduce each quadrant to one value per channel
                # (sum over both spatial dims, keeping them as size 1) ...
                Net2Iput11 = torch.sum(Net2Iput11, 2, keepdim=True, out=None)
                Net2Iput11 = torch.sum(Net2Iput11, 3, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 2, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 3, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 2, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 3, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 2, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 3, keepdim=True, out=None)
                # ... reassemble the four sums into a 2x2 spatial grid and
                # pixel-shuffle channels into space (upscale factor 4).
                Net2Iput1 = torch.cat((Net2Iput11, Net2Iput12), 2)
                Net2Iput2 = torch.cat((Net2Iput21, Net2Iput22), 2)
                Net2Iput = torch.cat((Net2Iput1, Net2Iput2), 3)
                ps = nn.PixelShuffle(4)
                Net2Iput = ps(Net2Iput)
                # Stage 2 forward pass: reconstruct bg_image.
                Net2Out = self.model2(Net2Iput)
                # Update model2 first; retain_graph keeps the shared graph
                # alive so a second backward can reach model's parameters.
                self.optimizer2.zero_grad()
                lossnet2 = self.loss_function(Net2Out, bg_image)
                lossnet2.backward(retain_graph=True)
                self.optimizer2.step()
                self.optimizer.zero_grad()
                # NOTE(review): loss1 is recomputed from the same Net2Out,
                # so this second backward pushes the same objective into
                # stage-1's parameters (stage 2 already stepped above).
                loss1 = self.loss_function(Net2Out, bg_image)
                loss1.backward(retain_graph=True)
                self.optimizer.step()
                running_loss += lossnet2.item()
            # Average per-batch loss over the epoch.
            running_loss /= (self.n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss,
                  'time:', round(time.time() - c_time, 3), 's')
            # Heuristic rescaling of the loss onto a 0-255 pixel scale
            # (assumes a squared-error-like loss on [0, 1] images — confirm).
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)
            if epoch % 2 == 0:
                # NOTE(review): model2 is never switched to eval mode here.
                self.model.eval()
                i = 0
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)
                    left = data['left_image']
                    bg_image = data['bg_image']
                    # NOTE(review): the validation forward pass is entirely
                    # commented out, so this loop only re-accumulates the
                    # LAST TRAINING BATCH's lossnet2 once per val sample —
                    # the validation metric (and the checkpoint decision
                    # below) does not measure validation data at all.
                    # with torch.no_grad():
                    #     disps = self.model(left)
                    #     Net2Out = self.model2(disps)
                    #     lossnet2 = self.loss_function(Net2Out, bg_image)
                    running_val_loss += lossnet2
                print('running_val_loss = %.12f' % (running_val_loss))
                if running_val_loss < best_val_loss:
                    self.save(self.args.model_path[:-4] + '_cpt.pth')
                    self.save2(self.args.model2_path[:-4] + '_cpt.pth')
                    best_val_loss = running_val_loss
                    print('Model_saved')

    def save(self, path):
        """Persist stage-1 (model) weights to `path`."""
        torch.save(self.model.state_dict(), path)

    def save2(self, path):
        """Persist stage-2 (model2) weights to `path`."""
        torch.save(self.model2.state_dict(), path)

    def load(self, path):
        """Restore stage-1 (model) weights from `path`."""
        self.model.load_state_dict(torch.load(path))