Exemple #1
0
def train():
    """Train ERFNet on the Cityscapes training split.

    Relies on names defined elsewhere in the file/project: ``Cityscapes``,
    ``MyTransform``, ``ERFNet``, ``IoU``, ``meter``, ``num_classes``,
    ``num_epochs`` and ``weight`` (per-class weights for the NLL loss).

    After every epoch the mean IoU and the mean per-sample loss are printed,
    the whole model object is saved to ``./model/erfnet<epoch>.pth`` and
    ``eval(epoch)`` is called.
    NOTE(review): ``eval`` is presumably a project-level validation routine
    that shadows the builtin; the builtin ``eval`` would raise TypeError on an
    int — confirm it is defined in this module.
    """
    train_dataset = Cityscapes(root='.',
                               subset='train',
                               transform=MyTransform())
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=1,
                                               shuffle=True)

    erfnet = ERFNet(num_classes)
    # Move to the GPU before building the optimizer so it references the
    # device-resident parameters from the start.
    erfnet.cuda()

    optimizer = Adam(erfnet.parameters(),
                     5e-4, (0.9, 0.999),
                     eps=1e-08,
                     weight_decay=1e-4)

    def poly_decay(epoch):
        # Polynomial LR decay for 0-based epochs: factor 1.0 at epoch 0 and
        # 0.0 at num_epochs. (An `epoch - 1` offset only fits 1-based loops;
        # here it would give a factor > 1 on the first epoch.)
        return pow(1 - epoch / num_epochs, 0.9)

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_decay)
    # NLLLoss accepts (N, C, H, W) log-probabilities directly; NLLLoss2d is a
    # deprecated alias for the same computation.
    criterion = nn.NLLLoss(weight)
    iou_meter = meter.AverageValueMeter()

    for epoch in range(num_epochs):
        print("Epoch {} : ".format(epoch + 1))
        # eval(epoch) below may switch the network to eval mode; restore it.
        erfnet.train()
        iou_meter.reset()
        loss_sum = 0.0
        n_samples = 0

        for images, labels in train_loader:
            # Inputs need no gradient; only the network parameters are trained.
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = erfnet(images)
            loss = criterion(torch.nn.functional.log_softmax(outputs, dim=1),
                             labels)
            loss.backward()
            optimizer.step()

            predictions = outputs.detach().max(1)[1]
            iou_meter.add(IoU(predictions, labels))
            # criterion returns a per-batch mean; weight it by the batch size
            # so loss_avg below is a true per-sample mean.
            n_in_batch = images.size(0)
            loss_sum += loss.item() * n_in_batch
            n_samples += n_in_batch

        # Advance the LR schedule exactly once per epoch, after the optimizer
        # steps (LambdaLR expects no metric argument).
        scheduler.step()

        iou_avg = iou_meter.value()[0]
        loss_avg = loss_sum / max(n_samples, 1)
        print("IoU : {:.5f}".format(iou_avg))
        print("Loss : {:.5f}".format(loss_avg))
        torch.save(erfnet, './model/erfnet' + str(epoch) + '.pth')

        eval(epoch)
Exemple #2
0
class Model:
    """Single-network trainer: ERFNet maps ``left_image`` to ``bg_image``.

    ``args`` is an argparse-style namespace. Attributes read here: ``device``,
    ``model``, ``use_multiple_gpu``, ``mode``, ``learning_rate``,
    ``val_data_dir``, ``augment_parameters``, ``input_height``,
    ``input_width``, ``num_workers``, ``model_path``, ``do_augmentation``,
    ``batch_size``, ``data_dir``, ``epochs`` and ``adjust_lr``.
    """

    def __init__(self, args):
        """Build the network and data loaders.

        In ``'train'`` mode also creates the loss, the Adam optimizer and the
        validation loader. In any other mode, weights are loaded from
        ``args.model_path`` and augmentation is disabled (batch size 1).

        Raises:
            ValueError: if ``args.model`` is not a supported architecture.
        """
        self.args = args
        self.device = args.device

        if args.model == 'stackhourglass':
            self.model = ERFNet(2)
        else:
            # Fail clearly instead of with an AttributeError on self.model.
            raise ValueError('Unsupported model: {!r}'.format(args.model))

        # Keep the network on the same device the batches are moved to.
        self.model.to(self.device)

        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
        if args.mode == 'train':
            self.loss_function = MonodepthLoss(
                n=3,
                SSIM_w=0.8,
                disp_gradient_w=0.1, lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(args.val_data_dir, 'test',
                                                                 args.augment_parameters,
                                                                 False, 1,
                                                                 (args.input_height, args.input_width),
                                                                 args.num_workers)
            # Optional warm start: load pretrained weights here, e.g. from
            # args.loadmodel, via self.model.load_state_dict(torch.load(...)).
        else:
            self.model.load_state_dict(torch.load(args.model_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(args.data_dir, args.mode,
                                                     args.augment_parameters,
                                                     args.do_augmentation, args.batch_size,
                                                     (args.input_height, args.input_width),
                                                     args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Open the image at ``path`` with PIL."""
        return Image.open(path)

    def _combined_loss(self, disps, target):
        """Return MonodepthLoss plus the SSIM dissimilarity ``1 - ssim``."""
        # Named so it never rebinds the module-level ``ssim`` function; an
        # assignment to ``ssim`` inside a function makes every use of it in
        # that function local (UnboundLocalError on the first call).
        ssim_term = 1 - ssim(disps, target)
        return self.loss_function(disps, target) + ssim_term

    def _train_one_epoch(self):
        """Run one optimisation pass over ``self.loader``.

        Returns:
            The mean combined loss per batch (0.0 for an empty loader).
        """
        self.model.train()
        loss_total = 0.0
        n_batches = 0
        for data in self.loader:
            data = to_device(data, self.device)
            left = data['left_image']
            bg_image = data['bg_image']

            self.optimizer.zero_grad()
            disps = self.model(left)
            loss = self._combined_loss(disps, bg_image)
            loss.backward()
            self.optimizer.step()

            loss_total += loss.item()
            n_batches += 1
        return loss_total / max(n_batches, 1)

    def _validate(self):
        """Return the mean combined loss per batch over ``self.val_loader``."""
        self.model.eval()
        loss_total = 0.0
        n_batches = 0
        with torch.no_grad():
            for data in self.val_loader:
                data = to_device(data, self.device)
                disps = self.model(data['left_image'])
                loss_total += self._combined_loss(disps, data['bg_image']).item()
                n_batches += 1
        # Averages over the actual number of validation batches rather than a
        # hard-coded dataset size.
        return loss_total / max(n_batches, 1)

    def train(self):
        """Train for ``args.epochs`` epochs, validating every 5th epoch.

        Whenever the validation loss improves, the weights are saved to
        ``args.model_path`` with its last 4 characters replaced by
        ``'_cpt.pth'``.
        """
        best_val_loss = 1000000.0

        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
            c_time = time.time()
            running_loss = self._train_one_epoch()
            print('Epoch:', epoch + 1, 'train_loss:', running_loss, 'time:',
                  round(time.time() - c_time, 3), 's')

            # NOTE(review): presumably converts an MSE on [0, 1]-scaled images
            # into an RMSE in 0-255 units — confirm against the loss definition.
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)

            if epoch % 5 == 0:
                running_val_loss = self._validate()
                print('running_val_loss = %.12f' % (running_val_loss))

                # Only compare against a freshly computed validation loss.
                if running_val_loss < best_val_loss:
                    self.save(self.args.model_path[:-4] + '_cpt.pth')
                    best_val_loss = running_val_loss
                    print('Model_saved')

    def save(self, path):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` from ``path`` into the model."""
        self.model.load_state_dict(torch.load(path))
Exemple #3
0
class Model:
    """Two-stage trainer: ERFNet (``self.model``) feeds ldyNet (``self.model2``).

    ERFNet's output is multiplied element-wise with the tiled input image and
    reduced to a small block-sum map, which ldyNet maps to ``bg_image``. Both
    networks are trained jointly on ldyNet's MonodepthLoss.

    ``args`` is an argparse-style namespace. Attributes read here: ``device``,
    ``model``, ``use_multiple_gpu``, ``mode``, ``learning_rate``,
    ``val_data_dir``, ``augment_parameters``, ``input_height``,
    ``input_width``, ``num_workers``, ``model_path``, ``model2_path``,
    ``do_augmentation``, ``batch_size``, ``data_dir``, ``epochs`` and
    ``adjust_lr``.
    """

    def __init__(self, args):
        """Build both networks and the data loaders.

        In ``'train'`` mode also creates the loss, one Adam optimizer per
        network and the validation loader. In any other mode, weights are
        loaded from ``args.model_path`` / ``args.model2_path`` and
        augmentation is disabled (batch size 1).

        Raises:
            ValueError: if ``args.model`` is not a supported architecture.
        """
        self.args = args
        self.device = args.device

        if args.model == 'stackhourglass':
            self.model = ERFNet(1)
            self.model2 = ldyNet(1)
        else:
            # Fail clearly instead of with an AttributeError on self.model.
            raise ValueError('Unsupported model: {!r}'.format(args.model))

        # Keep both networks on the same device the batches are moved to.
        self.model.to(self.device)
        self.model2.to(self.device)

        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
            self.model2 = torch.nn.DataParallel(self.model2)
            self.model2.cuda()
        if args.mode == 'train':
            self.loss_function = MonodepthLoss(n=3,
                                               SSIM_w=0.8,
                                               disp_gradient_w=0.1,
                                               lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            self.optimizer2 = optim.Adam(self.model2.parameters(),
                                         lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)
            # Optional warm start: load pretrained weights here, e.g. from
            # args.loadmodel, via self.model.load_state_dict(torch.load(...)).
        else:
            self.model.load_state_dict(torch.load(args.model_path))
            self.model2.load_state_dict(torch.load(args.model2_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Open the image at ``path`` with PIL."""
        return Image.open(path)

    @staticmethod
    def _sensing_input(disps, left, block=128, copies=16, upscale=4):
        """Build ldyNet's input from ERFNet's output and the input image.

        ``left`` is tiled ``copies`` times along dim 1 and multiplied
        element-wise with ``disps``. The product is split into a 2x2 grid of
        ``block``-sized tiles over dims 2 and 3, each tile is summed (giving a
        ``(..., 2, 2)`` map laid out as [[top-left, top-right],
        [bottom-left, bottom-right]]), and the map goes through
        ``PixelShuffle(upscale)``.

        NOTE(review): the 2-way unpacking below requires dims 2 and 3 to split
        into exactly two ``block`` chunks (e.g. 256x256 inputs) — confirm.
        """
        product = disps.mul(left.repeat(1, copies, 1, 1))
        top, bottom = product.split(block, 2)
        top_left, top_right = top.split(block, 3)
        bottom_left, bottom_right = bottom.split(block, 3)
        tl, tr, bl, br = (tile.sum(dim=(2, 3), keepdim=True)
                          for tile in (top_left, top_right,
                                       bottom_left, bottom_right))
        grid = torch.cat((torch.cat((tl, tr), 3),
                          torch.cat((bl, br), 3)), 2)
        return nn.PixelShuffle(upscale)(grid)

    def _forward(self, left):
        """Run the full ERFNet -> sensing -> ldyNet pipeline on ``left``."""
        disps = self.model(left)
        return self.model2(self._sensing_input(disps, left))

    def _train_one_epoch(self):
        """Run one joint optimisation pass over ``self.loader``.

        Returns:
            The mean ldyNet loss per batch (0.0 for an empty loader).
        """
        self.model.train()
        self.model2.train()
        loss_total = 0.0
        n_batches = 0
        for data in self.loader:
            data = to_device(data, self.device)
            left = data['left_image']
            bg_image = data['bg_image']

            self.optimizer.zero_grad()
            self.optimizer2.zero_grad()
            loss = self.loss_function(self._forward(left), bg_image)
            # One backward pass fills the gradients of both networks. Stepping
            # optimizer2 between two backward passes through the same graph
            # would modify model2's saved weights in place, which autograd
            # rejects.
            loss.backward()
            self.optimizer2.step()
            self.optimizer.step()

            loss_total += loss.item()
            n_batches += 1
        return loss_total / max(n_batches, 1)

    def _validate(self):
        """Return the mean ldyNet loss per batch over ``self.val_loader``."""
        self.model.eval()
        self.model2.eval()
        loss_total = 0.0
        n_batches = 0
        with torch.no_grad():
            for data in self.val_loader:
                data = to_device(data, self.device)
                output = self._forward(data['left_image'])
                loss_total += self.loss_function(output, data['bg_image']).item()
                n_batches += 1
        return loss_total / max(n_batches, 1)

    def train(self):
        """Train both networks for ``args.epochs`` epochs, validating every 2nd.

        Whenever the validation loss improves, both state dicts are saved to
        ``args.model_path`` / ``args.model2_path`` with their last 4 characters
        replaced by ``'_cpt.pth'``.
        """
        best_val_loss = 1000000.0

        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
                adjust_learning_rate2(self.optimizer2, epoch,
                                      self.args.learning_rate)
            c_time = time.time()
            running_loss = self._train_one_epoch()
            print('Epoch:', epoch + 1, 'train_loss:', running_loss, 'time:',
                  round(time.time() - c_time, 3), 's')

            # NOTE(review): presumably converts an MSE on [0, 1]-scaled images
            # into an RMSE in 0-255 units — confirm against the loss definition.
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)

            if epoch % 2 == 0:
                running_val_loss = self._validate()
                print('running_val_loss = %.12f' % (running_val_loss))

                # Only compare against a freshly computed validation loss.
                if running_val_loss < best_val_loss:
                    self.save(self.args.model_path[:-4] + '_cpt.pth')
                    self.save2(self.args.model2_path[:-4] + '_cpt.pth')
                    best_val_loss = running_val_loss
                    print('Model_saved')

    def save(self, path):
        """Write ERFNet's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    def save2(self, path):
        """Write ldyNet's ``state_dict`` to ``path``."""
        torch.save(self.model2.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` from ``path`` into ERFNet."""
        self.model.load_state_dict(torch.load(path))