Beispiel #1
0
def train():
    """Train ERFNet on Cityscapes for ``num_epochs`` epochs.

    Relies on module-level globals defined elsewhere in this file:
    ``num_classes``, ``num_epochs``, ``weight``, ``batch_size`` and the
    ``eval`` helper.  Saves a full-model checkpoint per epoch under
    ``./model/``.
    """
    train_dataset = Cityscapes(root='.',
                               subset='train',
                               transform=MyTransform())
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=1,
                                               shuffle=True)

    erfnet = ERFNet(num_classes)

    optimizer = Adam(erfnet.parameters(),
                     5e-4, (0.9, 0.999),
                     eps=1e-08,
                     weight_decay=1e-4)
    # Polynomial learning-rate decay with power 0.9 (as in the ERFNet paper).
    lambda1 = lambda epoch: pow((1 - ((epoch - 1) / num_epochs)), 0.9)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    # NOTE(review): nn.NLLLoss2d is a deprecated alias of nn.NLLLoss and is
    # removed in recent PyTorch — migrate when upgrading.
    criterion = nn.NLLLoss2d(weight)
    iou_meter = meter.AverageValueMeter()

    erfnet.cuda()
    erfnet.train()

    for epoch in range(num_epochs):
        print("Epoch {} : ".format(epoch + 1))
        total_loss = []
        iou_meter.reset()
        for i, (images, labels) in enumerate(train_loader):
            # Inputs need no gradients; requires_grad=True on them only
            # wasted memory in the original.
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            outputs = erfnet(images)

            optimizer.zero_grad()
            # dim=1 is the class dimension of the (N, C, H, W) logits.
            loss = criterion(
                torch.nn.functional.log_softmax(outputs, dim=1), labels)
            loss.backward()
            optimizer.step()

            iou = IoU(outputs.max(1)[1].data, labels.data)
            iou_meter.add(iou)
            # .item() replaces the 0.3-era loss.data[0] indexing.
            total_loss.append(loss.item() * batch_size)

        # BUG FIX: the scheduler was stepped twice per epoch — once with the
        # epoch index at the top of the loop and once with the average loss,
        # which LambdaLR would misinterpret as an epoch number, advancing the
        # decay twice as fast.  Step exactly once per epoch, after the
        # optimizer steps, as PyTorch requires.
        scheduler.step()

        iou_avg = iou_meter.value()[0]
        loss_avg = sum(total_loss) / len(total_loss)
        print("IoU : {:.5f}".format(iou_avg))
        print("Loss : {:.5f}".format(loss_avg))
        torch.save(erfnet, './model/erfnet' + str(epoch) + '.pth')

        # Presumably a sibling evaluation helper in this file (it shadows the
        # builtin eval) — TODO confirm.
        eval(epoch)
def main(args):
    """Run ERFNet inference over a cityscapes subset and save the
    trainId->labelId prediction maps under ./save_results/.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    #Import ERFNet model from the folder
    #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet")
    model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Custom loader: a strict load_state_dict() raises on missing keys,
        # so copy only the checkpoint entries whose names match the model.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()

        # BUG FIX: Variable(images, volatile=True) is a no-op on
        # PyTorch >= 0.4 (volatile was removed), so gradients WERE being
        # tracked here; torch.no_grad() restores the intended
        # inference-only behaviour.
        with torch.no_grad():
            outputs = model(images)

        # argmax over the class dimension of the first (only) batch element
        label = outputs[0].max(0)[1].byte().cpu().data
        label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))

        filenameSave = "./save_results/" + filename[0].split("leftImg8bit/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        label_cityscapes.save(filenameSave)

        print(step, filenameSave)
def midfunc():
    """Load a trained ERFNet checkpoint, move it to GPU and run inference.

    Returns 0 on completion (legacy convention from the original script).
    """
    # BUG FIX: the original used Python 2 print statements, which are a
    # SyntaxError in Python 3 (the rest of this file is Python 3).
    print('load Model....\n')
    t1 = time.time()
    net = ERFNet(2)
    checkpoint_path = './erf_save/model/main-erfnet-step-5998-epoch-20.pth'

    net.load_state_dict(torch.load(checkpoint_path))

    net = net.cuda()
    print('Done.\n')
    print('time:', time.time() - t1, '\n')
    print('compute output....\n')
    infer(net)
    return 0
Beispiel #4
0
def main(args):
    """Colorize ERFNet predictions for a cityscapes subset.

    Loads the checkpoint named by ``args``, runs inference over the chosen
    subset, and writes colourized label maps under ./save_color/, optionally
    mirroring each one to a visdom server.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = torch.nn.DataParallel(ERFNet(NUM_CLASSES))
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Tolerant loader: copy only the checkpoint entries whose names
        # exist in the model (a strict load would raise on missing keys).
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")

    dataset = cityscapes(args.datadir,
                         input_transform_cityscapes,
                         target_transform_cityscapes,
                         subset=args.subset)
    loader = DataLoader(dataset,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # Visualizer: run "python3.6 -m visdom.server -port 8097" in another
    # window and open localhost:8097 to watch the predictions stream in.
    if args.visualize:
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # argmax over the class dimension of the first (only) batch element
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))

        filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        ToPILImage()(label_color).save(filenameSave)

        if args.visualize:
            vis.image(label_color.numpy())
        print(step, filenameSave)
Beispiel #5
0
def main(args):
    """Evaluate ERFNet on a cityscapes subset.

    Accumulates an IoU statistic over the whole loader, then prints
    per-class IoU (ANSI-coloured via getColorEntry) and the mean IoU.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Tolerant loader: copy only matching checkpoint entries and report
        # the ones that were skipped.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                print(name, " not loaded")
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir,
                                   input_transform_cityscapes,
                                   target_transform_cityscapes,
                                   subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    iouEvalVal = iouEval(NUM_CLASSES)

    start = time.time()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            labels = labels.cuda()

        # BUG FIX: Variable(..., volatile=True) is a no-op in PyTorch >= 0.4
        # (volatile was removed), so gradients were being tracked during
        # evaluation; torch.no_grad() restores inference-only behaviour.
        with torch.no_grad():
            outputs = model(images)

        iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels)

        filenameSave = filename[0].split("leftImg8bit/")[1]

        print(step, filenameSave)

    iouVal, iou_classes = iouEvalVal.getIoU()

    iou_classes_str = []
    for i in range(iou_classes.size(0)):
        iouStr = getColorEntry(iou_classes[i]) + '{:0.2f}'.format(
            iou_classes[i] * 100) + '\033[0m'
        iou_classes_str.append(iouStr)

    print("---------------------------------------")
    print("Took ", time.time() - start, "seconds")
    print("=======================================")
    #print("TOTAL IOU: ", iou * 100, "%")
    print("Per-Class IoU:")
    # Cityscapes train classes in trainId order; the loop prints exactly the
    # same lines as the original 19 copy-pasted print statements.
    class_names = ("Road", "sidewalk", "building", "wall", "fence", "pole",
                   "traffic light", "traffic sign", "vegetation", "terrain",
                   "sky", "person", "rider", "car", "truck", "bus", "train",
                   "motorcycle", "bicycle")
    for iou_str, class_name in zip(iou_classes_str, class_names):
        print(iou_str, class_name)
    print("=======================================")
    iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(iouVal * 100) + '\033[0m'
    print("MEAN IoU: ", iouStr, "%")
Beispiel #6
0
class Model:
    """Training wrapper around ERFNet for background-image regression.

    Relies on module-level helpers defined elsewhere in this file:
    MonodepthLoss, prepare_dataloader, adjust_learning_rate, ssim,
    to_device.
    """

    def __init__(self, args):
        self.args = args

        self.device = args.device
        if args.model == 'stackhourglass':
            self.model = ERFNet(2)
        else:
            # BUG FIX: the original silently fell through and crashed later
            # with AttributeError on self.model; fail fast instead.
            raise ValueError('unsupported model: {}'.format(args.model))

        self.model.cuda()  #= self.model.to(self.device)

        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
        if args.mode == 'train':
            self.loss_function = MonodepthLoss(
                n=3,
                SSIM_w=0.8,
                disp_gradient_w=0.1, lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)
        else:
            self.model.load_state_dict(torch.load(args.model_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Open a disparity image from disk (lazy PIL handle)."""
        return Image.open(path)

    def train(self):
        """Run the training loop; checkpoint whenever validation improves."""
        best_val_loss = 1000000.0

        for epoch in range(self.args.epochs):
            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()

            for data in self.loader:
                self.optimizer.zero_grad()
                data = to_device(data, self.device)

                left = data['left_image']
                bg_image = data['bg_image']

                disps = self.model(left)

                loss = self.loss_function(disps, bg_image)
                # BUG FIX: the original did `ssim = 1 - ssim_loss`, shadowing
                # the module-level ssim() function, so the second loop
                # iteration tried to call a tensor (TypeError).  Use a
                # distinct local name for the SSIM term.
                ssim_term = 1 - ssim(disps, bg_image)
                combined = loss * 1 + ssim_term * 1
                combined.backward()

                self.optimizer.step()

                running_loss += combined.item()

            running_loss /= (self.n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss,
                  'time:', round(time.time() - c_time, 3), 's')

            # Loss is an MSE-like quantity on [0, 1] images; rescale its root
            # to 8-bit intensity units for readability — TODO confirm.
            TrueLoss = math.sqrt(running_loss) * 255
            print('TrueLoss:', TrueLoss)

            if epoch % 5 == 0:
                self.model.eval()
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)

                    left = data['left_image']
                    bg_image = data['bg_image']

                    with torch.no_grad():
                        disps = self.model(left)

                    loss = self.loss_function(disps, bg_image)
                    ssim_term = 1 - ssim(disps, bg_image)
                    running_val_loss += loss * 1 + ssim_term * 1

                # NOTE(review): hard-coded divisor, presumably the size of
                # the validation set — confirm against the loader.
                running_val_loss /= 200

                print('running_val_loss = %.12f' % (running_val_loss))

            # On epochs where validation is skipped, running_val_loss keeps
            # its last computed value (first computed at epoch 0).
            if running_val_loss < best_val_loss:
                self.save(self.args.model_path[:-4] + '_cpt.pth')
                best_val_loss = running_val_loss
                print('Model_saved')

    def save(self, path):
        """Write the model's state dict to *path*."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a state dict from *path* into the model."""
        self.model.load_state_dict(torch.load(path))
def main(args):
    """Segment a video with ERFNet, showing each colourized frame and writing
    the result to output.mp4.  Press 'q' to stop early.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    fourcc = cv2.VideoWriter_fourcc(*'MP4V')  # Save as video
    out = cv2.VideoWriter('output.mp4', fourcc, 20.0, (640, 352))

    def load_my_state_dict(model, state_dict):
        # Tolerant loader: copy only the checkpoint entries whose names
        # exist in the model (strict load fails on missing keys).
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(
        model, torch.load(weightspath, map_location=torch.device('cpu')))
    print("Model and weights LOADED successfully")

    model.eval()

    # cap = cv2.VideoCapture(0)
    cap = cv2.VideoCapture('project_video_trimmed.mp4')

    while (True):
        # Capture frame-by-frame
        ret, images = cap.read()
        # BUG FIX: read() returns ret=False (frame=None) at end of stream;
        # the original never checked it and crashed inside the transform
        # instead of terminating cleanly.
        if not ret:
            break

        images = trans(images)
        images = images.float()
        images = images.view((1, 3, 352, 640))  # NCHW video frame

        if (not args.cpu):
            images = images.cuda()

        with torch.no_grad():
            outputs = model(images)

        # argmax over the class dimension of the single batch element
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))
        frame = label_color.numpy().transpose(1, 2, 0)

        # Display the resulting frame
        cv2.imshow('Segmented Image', frame)
        out.write(frame)  # To save video file

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()
    # BUG FIX: the original ended with vis.image(...) and
    # print(step, filenameSave) using names never defined in this function
    # (copy-paste leftovers that raised NameError); removed.
Beispiel #8
0
def main(args):
    """Segment the RELLIS-3D test split with ERFNet, save colourized label
    maps under ./save_colour_rellis/ and report per-image forward time.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")
    else:
        print("Loading model: " + modelpath)
        print("Loading weights: " + weightspath)

    # Import ERFNET
    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        # Tolerant loader: copy only matching checkpoint entries (a strict
        # load fails on missing keys).
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    # Set model to Evaluation mode
    model.eval()

    # Setup the dataset loader
    ### RELLIS-3D Dataloader
    enc = False
    loader_test = custom_datasets.setup_loaders(args, enc)
    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    # BUG FIX: time_train was never initialised, so the timing code below
    # raised NameError on the second batch.
    time_train = []

    for step, (images, labels, img_name, _) in enumerate(loader_test):
        start_time = time.time()
        if (not args.cpu):
            images = images.cuda()

        with torch.no_grad():
            outputs = model(images)

        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))

        eval_save_path = "./save_colour_rellis/"
        os.makedirs(eval_save_path, exist_ok=True)

        _, file_name = os.path.split(img_name[0])
        file_name = file_name + ".png"

        label_save = ToPILImage()(label_color)
        label_save.save(os.path.join(eval_save_path, file_name))

        if (args.visualize):
            vis.image(label_color.numpy())
        if step != 0:  #first run always takes some time for setup
            fwt = time.time() - start_time
            time_train.append(fwt)
            print("Forward time per img (b=%d): %.3f (Mean: %.3f)" %
                  (args.batch_size, fwt / args.batch_size,
                   sum(time_train) / len(time_train) / args.batch_size))

        print(step, os.path.join(eval_save_path, file_name))
Beispiel #9
0
def main(args):
    """Run ERFNet inference over a geoMat subset and save each predicted
    label map as a matplotlib figure under ./getMat_2/.
    """

    modelpath = args.loadDir + args.loadModel
    #weightspath = args.loadDir + args.loadWeights #TODO
    # NOTE(review): the weights path is hard-coded to a local checkpoint and
    # overrides args.loadWeights (see the author's TODO above) — confirm
    # before reusing this script on another machine.
    weightspath = "/home/pan/repository/erfnet_pytorch/save/geoMat_2/model_best.pth"
    # print ("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    #Import ERFNet model from the folder
    #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet")
    model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        # Copies only the checkpoint entries whose names exist in the model;
        # a strict load_state_dict() would raise on missing keys.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(geoMat(args.datadir,
                               input_transform_geoMat,
                               target_transform_geoMat,
                               subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            #labels = labels.cuda()

        inputs = Variable(images)
        #targets = Variable(labels)
        with torch.no_grad():
            outputs = model(inputs)

        # argmax over the class dimension of the first (only) batch element
        label = outputs[0].max(0)[1].byte().cpu().data
        # print(label.shape)
        #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = (label.unsqueeze(0))
        # print(label_color.shape)
        filenameSave = "./getMat_2/" + filename[0].split(
            "material_dataset_v2")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        #image_transform(label.byte()).save(filenameSave)

        # label_save = ToPILImage()(label_color)
        # CHW -> HWC so matplotlib can display it
        label_save = label_color.numpy()
        label_save = label_save.transpose(1, 2, 0)
        # label_save.save(filenameSave)

        images = images.cpu().numpy().squeeze(axis=0)
        images = images.transpose(1, 2, 0)

        # print(images.shape)
        # print(label_save.shape)
        # NOTE(review): figure size x dpi (10.4 in @ 10 dpi = 104 px) appears
        # chosen to match the saved-image resolution — confirm.
        plt.figure(figsize=(10.4, 10.4), dpi=10)
        # plt.imshow(images)
        plt.imshow(label_save, alpha=0.6, cmap='gray')
        plt.axis('off')
        # plt.show()
        plt.savefig(filenameSave, dpi=10)
        plt.close()

        if (args.visualize):
            vis.image(label_color.numpy())
        print(step, filenameSave)
Beispiel #10
0
class Model:
    """Two-network training wrapper: ERFNet produces a per-pixel map from the
    left image, and ldyNet regresses the background image from a reduced
    (quadrant-summed, pixel-shuffled) product of that map with the input.

    Relies on module-level helpers defined elsewhere in this file:
    MonodepthLoss, prepare_dataloader, adjust_learning_rate,
    adjust_learning_rate2, to_device.
    """

    def __init__(self, args):
        self.args = args

        self.device = args.device
        # NOTE(review): if args.model != 'stackhourglass', self.model /
        # self.model2 are never assigned and the .cuda() calls below raise
        # AttributeError — confirm only 'stackhourglass' is ever passed.
        if args.model == 'stackhourglass':

            self.model = ERFNet(1)
            self.model2 = ldyNet(1)

        self.model.cuda()  #= self.model.to(self.device)
        self.model2.cuda()

        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
            self.model2 = torch.nn.DataParallel(self.model2)
            self.model2.cuda()
        if args.mode == 'train':
            self.loss_function = MonodepthLoss(n=3,
                                               SSIM_w=0.8,
                                               disp_gradient_w=0.1,
                                               lr_w=1).to(self.device)
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            self.optimizer2 = optim.Adam(self.model2.parameters(),
                                         lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)

            # self.model.load_state_dict(torch.load(args.loadmodel))  # load a pre-trained model with this line

        else:
            self.model.load_state_dict(torch.load(args.model_path))
            self.model2.load_state_dict(torch.load(args.model2_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Open a disparity image from disk (lazy PIL handle)."""
        return Image.open(path)

    def train(self):
        """Alternately train both networks; checkpoint when the (stale — see
        notes below) validation metric improves.
        """

        # losses = []
        #val_losses = []

        best_loss = float('Inf')  # NOTE(review): never updated or used
        best_val_loss = 1000000.0

        for epoch in range(self.args.epochs):

            # losses = []

            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
            if self.args.adjust_lr:
                adjust_learning_rate2(self.optimizer2, epoch,
                                      self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()  #?start?
            self.model2.train()

            for data in self.loader:  #(test-for-train)
                # Load data
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                #template=data['template']
                #randomMatrix=np.random.randint(0,1,(256,256))
                # Net1Iput=torch.from_numpy(randomMatrix)
                disps = self.model(left)
                #j=0
                # while j<(self.args.batch_size-1):
                #     disps1=disps

                # Replicate the input 16x along the channel dimension so it
                # matches the channel count of disps for the product below.
                i = 0
                left1 = left
                while i < 15:

                    left1 = torch.cat((left1, left), 1)
                    i = i + 1
                Net2Iput = disps.mul(left1)  # element-wise multiply
                # Split into four quadrants of 128 along H (dim 2) and
                # W (dim 3); assumes a 256x256 spatial size — TODO confirm.
                Net2Iput11, Net2Iput12 = Net2Iput.split(128, 2)
                Net2Iput11, Net2Iput21 = Net2Iput11.split(128, 3)
                Net2Iput12, Net2Iput22 = Net2Iput12.split(128, 3)

                # Net2Iput1=Net2Iput[:127]
                # Net2Iput2=Net2Iput[128:]
                # Net2Iput11=Net2Iput1[:,0:127]
                # Net2Iput12=Net2Iput1[:,128:255]
                # Net2Iput21=Net2Iput2[:,0:127]
                # Net2Iput22=Net2Iput2[:,128:255]

                # Reduce each quadrant to a single value by summing over the
                # spatial dims (keepdim keeps the 1x1 spatial shape).
                Net2Iput11 = torch.sum(Net2Iput11, 2, keepdim=True, out=None)
                Net2Iput11 = torch.sum(Net2Iput11, 3, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 2, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 3, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 2, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 3, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 2, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 3, keepdim=True, out=None)
                # Reassemble the four quadrant sums into a 2x2 spatial grid.
                Net2Iput1 = torch.cat((Net2Iput11, Net2Iput12), 2)
                Net2Iput2 = torch.cat((Net2Iput21, Net2Iput22), 2)
                Net2Iput = torch.cat((Net2Iput1, Net2Iput2), 3)
                # Net2Iput=torch.cat((Net2Iput,Net2Iput22),1)
                # Trade channels for spatial resolution before feeding net 2.
                ps = nn.PixelShuffle(4)
                Net2Iput = ps(Net2Iput)
                Net2Out = self.model2(Net2Iput)

                # Step net 2 first on its reconstruction loss...
                self.optimizer2.zero_grad()
                lossnet2 = self.loss_function(Net2Out, bg_image)
                lossnet2.backward(retain_graph=True)
                self.optimizer2.step()

                self.optimizer.zero_grad()
                # loss=0
                # plt.figure()
                # plt.imshow(left.squeeze().cpu().detach().numpy())
                # plt.show()
                # plt.imshow(bg_image.squeeze().cpu().detach().numpy())
                # plt.show()
                #disps = self.model(left)
                # print('gggggggggggggggggggggggggggggggggggggggggg',left.shape,disps.shape)
                # plt.imshow(disps.squeeze().cpu().detach().numpy())
                # plt.show()

                # ...then step net 1 through the same (retained) graph.
                loss1 = self.loss_function(Net2Out, bg_image)
                loss1.backward(retain_graph=True)

                self.optimizer.step()

                # losses.append(loss.item())
                running_loss += lossnet2.item()
            # print(' time = %.2f' %(time.time() - c_time))
            running_loss /= (self.n_img / self.args.batch_size)

            # print('running_loss:', running_loss)

            # running_val_loss /= (self.val_n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss, 'time:',
                  round(time.time() - c_time, 3), 's')

            # Presumably rescales the root of an MSE-like loss to 8-bit
            # intensity units — TODO confirm.
            TrueLoss = math.sqrt(running_loss) * 255

            print('TrueLoss:', TrueLoss)
            if epoch % 2 == 0:
                #self.model.eval()
                self.model.eval()
                i = 0
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)

                    left = data['left_image']
                    bg_image = data['bg_image']

                    #with torch.no_grad():
                    # newinput=torch.cat([left,bg],1)
                    # disps = self.model(newinput)

                    #disps = self.model(left)
                    #Net2Out=self.model2(disps)

                    #lossnet2 = self.loss_function(Net2Out,bg_image)
                    # loss1 = self.loss_function((disps+left1.float()),target,mask)
                    # NOTE(review): the validation forward pass above is
                    # commented out, so this accumulates the LAST TRAINING
                    # batch's lossnet2 once per val sample — this is not a
                    # real validation loss; confirm whether intentional.
                    running_val_loss += lossnet2

                #running_val_loss/=100

                print('running_val_loss = %.12f' % (running_val_loss))

            # NOTE(review): on odd epochs running_val_loss keeps its previous
            # value (first computed at epoch 0), so this comparison can act
            # on a stale metric.
            if running_val_loss < best_val_loss:

                self.save(self.args.model_path[:-4] + '_cpt.pth')
                self.save2(self.args.model2_path[:-4] + '_cpt.pth')
                best_val_loss = running_val_loss
                print('Model_saved')
            # self.save(self.args.model_path[:-4] + '_'+str(epoch+1)+'.pth')
            # if  epoch==100:
            #     self.save(self.args.model_path[:-4] + '_100.pth')
            # if  epoch==150:
            #     self.save(self.args.model_path[:-4] + '_150.pth')
            # if  epoch==200:
            #         self.save(self.args.model_path[:-4] + '_200.pth')
            # if  epoch==250:
            #     self.save(self.args.model_path[:-4] + '_250.pth')
            # self.save(self.args.model_path[:-4] + '_last.pth')  # previous round
            # if running_loss < best_val_loss:
            #     #print(running_val_loss)
            #     #print(best_val_loss)
            #     self.save(self.args.model_path[:-4] + '_cpt.pth')
            #     best_val_loss = running_loss
            #     print('Model_saved')

        # print ('Finished Training. Best loss:', best_loss)
        #self.save(self.args.model_path)

    def save(self, path):
        """Write net 1's state dict to *path*."""
        torch.save(self.model.state_dict(), path)

    def save2(self, path):
        """Write net 2's state dict to *path*."""
        torch.save(self.model2.state_dict(), path)

    def load(self, path):
        """Load a state dict from *path* into net 1."""
        self.model.load_state_dict(torch.load(path))
Beispiel #11
0
def main():
    """Instantiate an ERFNet, move it to the GPU, and start training."""
    network = ERFNet(NUM_CLASSES).cuda()
    train(network)
Beispiel #12
0
class Model:
    """Wraps a restoration ERFNet plus its data loader for evaluation.

    Builds the network selected by ``args.model``, loads pretrained weights
    from ``args.loadmodel``, and prepares the evaluation ``DataLoader``.
    """

    def __init__(self, args):
        self.args = args
        self.device = args.device

        # Only the 'stackhourglass' variant is implemented. The original code
        # fell through silently for any other value and crashed later with a
        # NameError on self.model; fail fast with a clear message instead.
        if args.model == 'stackhourglass':
            self.model = ERFNet(2)
        else:
            raise ValueError("unsupported model: {}".format(args.model))

        self.model.cuda()
        self.model.load_state_dict(torch.load(args.loadmodel))

        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode,
            args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width),
            args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def test(self):
        """Run the model over the loader, tracking running L1/SSIM averages.

        Prints the tensor shapes seen, shows each prediction next to its
        ground-truth background image, and prints the final averages.
        """
        i = 0
        sum_loss = 0.0
        sum_ssim = 0.0
        average_ssim = 0.0
        average_loss = 0.0
        for epoch in range(self.args.epochs):
            self.model.eval()

            for data in self.loader:
                i += 1

                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']

                disps = self.model(left)
                print(disps.shape)

                l_loss = l1loss(disps, bg_image)
                ssim_loss = ssim(disps, bg_image)

                sum_loss += l_loss.item()
                sum_ssim += ssim_loss.item()
                # Running averages over ALL batches seen so far, across epochs
                # (i is never reset between epochs).
                average_ssim = sum_ssim / i
                average_loss = sum_loss / i

                # Side-by-side visualisation: prediction (left) vs. target (right).
                # plt.show() blocks until the window is closed.
                disp_show = disps.squeeze()
                bg_show = bg_image.squeeze()
                print(bg_show.shape)
                plt.figure()
                plt.subplot(1, 2, 1)
                plt.imshow(disp_show.data.cpu().numpy())
                plt.subplot(1, 2, 2)
                plt.imshow(bg_show.data.cpu().numpy())
                plt.show()
        print('average loss', average_loss, average_ssim)
def main(args):
    """Evaluate an ERFNet lane-segmentation model on the 'avlane' test split.

    Loads weights from args.loadDir (tolerating missing/extra keys), runs
    inference without gradients, prints a running binary IoU per step, builds
    input|prediction concatenation images under args.loadDir, and plays them
    back in a matplotlib window.
    """

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    #Import ERFNet model from the folder
    #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet")
    if args.pretrainedEncoder:
        # Wrap an ImageNet-classification ERFNet (1000 classes), then pull
        # out its encoder to seed the segmentation network.
        pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
        #pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dcit'])
        pretrainedEnc = next(pretrainedEnc.children()).features.encoder
        if (not args.cuda):
            pretrainedEnc = pretrainedEnc.cpu()
        model = ERFNet(NUM_CLASSES, encoder=pretrainedEnc)
    else:
        model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if args.cuda:
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        # Copies only parameters whose names exist in the target model;
        # checkpoint keys with no matching parameter are silently skipped.
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    # NOTE(review): a missing datadir is only reported, not fatal — the
    # dataset constructor below will still be attempted.
    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    dataset_test = avlane(args.datadir, input_transform_lot,
                          label_transform_lot, 'test')
    loader = DataLoader(dataset_test,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    # Reusable figure: image handle h is updated in-place each step so the
    # results play back like a video.
    fig = plt.figure()
    ax = fig.gca()
    h = ax.imshow(Image.new('RGB', (640 * 2, 480), 0))

    print(len(loader.dataset))

    iouEvalTest = iouEval_binary(NUM_CLASSES)

    with torch.no_grad():
        for step, (images, labels, filename) in enumerate(loader):

            #print(images.shape)
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = labels
            outputs = model(inputs)

            # Binarize the network output at 0.5 into a long tensor of {0, 1}.
            preds = torch.where(outputs > 0.5,
                                torch.ones([1], dtype=torch.long).cuda(),
                                torch.zeros([1], dtype=torch.long).cuda())
            #preds = torch.where(outputs > 0.5, torch.ones([1], dtype=torch.uint8).cuda(), torch.zeros([1], dtype=torch.uint8).cuda()) # b x 1 x h x w

            #label = outputs[0].max(0)[1].byte().cpu().data
            #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
            #label_color = Colorize()(label.unsqueeze(0))

            # iou — the evaluator accumulates over all batches, so iouTest is
            # a running value, not the IoU of this batch alone
            iouEvalTest.addBatch(preds[:, 0],
                                 targets[:, 0])  # no_grad handles it already
            iouTest = iouEvalTest.getIoU()
            iouStr = "test IOU: " + '{:0.2f}'.format(
                iouTest.item() * 100) + "%"
            print(iouStr)

            # save the output, mirroring the test-set directory layout
            filenameSave = os.path.join(args.loadDir, 'test_results',
                                        filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            #image_transform(label.byte()).save(filenameSave)
            #label_save = ToPILImage()(label)
            pred = preds.to(torch.uint8).squeeze(0).cpu()  # 1xhxw
            pred_save = pred_transform_lot(pred)
            #pred_save.save(filenameSave)

            # concatenate data & result: input image on the left, mask on the right
            im1 = Image.open(
                os.path.join(args.datadir, 'data/test',
                             filename[0].split("test/")[1])).convert('RGB')
            im2 = pred_save
            dst = Image.new('RGB', (im1.width + im2.width, im1.height))
            dst.paste(im1, (0, 0))
            dst.paste(im2, (im1.width, 0))
            filenameSaveConcat = os.path.join(args.loadDir,
                                              'test_results_concat',
                                              filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSaveConcat), exist_ok=True)
            #dst.save(filenameSaveConcat)

            # write iou on dst
            # NOTE(review): hardcoded font path — raises OSError on systems
            # without FreeMonoBold.ttf at this location; consider a fallback.
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/freefont/FreeMonoBold.ttf', 36)
            d = ImageDraw.Draw(dst)
            d.text((900, 0), iouStr, font=font, fill=(255, 255, 0))

            # show video
            h.set_data(dst)
            plt.draw()
            plt.axis('off')
            plt.pause(1e-2)

            if (args.visualize):
                # NOTE(review): label_save is only assigned in a commented-out
                # line above, so this branch raises NameError whenever
                # args.visualize is set — confirm and fix.
                vis.image(label_save.numpy())

            print(step, filenameSave)