Exemplo n.º 1
0
def main(args):
    """Evaluate an ERFNet model on the MAV dataset and save label-ID images.

    Loads weights from ``args.loadDir + args.loadWeights``, runs inference
    over the MAV test split, converts each prediction from train IDs to
    label IDs, and writes the result under ``./save_results/``.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    # Wrap in DataParallel so checkpoint key names ("module.*") match.
    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy only the checkpoint entries whose names exist in the model.

        Unlike ``model.load_state_dict``, this tolerates missing/extra keys.
        """
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Map storages to CPU when running without CUDA so checkpoints saved on
    # a GPU machine still load (the original torch.load call had no
    # map_location and crashed on CPU-only hosts).
    map_location = torch.device('cpu') if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")

    loader = DataLoader(MAV(args.datadir,
                            input_transform_MAV,
                            target_transform_MAV,
                            subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # Per-pixel argmax over classes -> train-ID map for the first image.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_MAV = MAV_trainIds2labelIds(label.unsqueeze(0))

        filenameSave = "./save_results/" + filename[0].split("Images/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        label_MAV.save(filenameSave)

        print(step, filenameSave)
Exemplo n.º 2
0
def main(args):
    """Evaluate an ERFNet model on the RELLIS-3D dataset.

    Loads the checkpoint, runs the test loader, colourises each prediction
    and saves it as a PNG under ``./save_colour_rellis/``, optionally
    streaming images to a visdom server and printing per-image timings.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")
    else:
        print("Loading model: " + modelpath)
        print("Loading weights: " + weightspath)

    # Wrap in DataParallel so checkpoint key names ("module.*") match.
    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy only the checkpoint entries whose names exist in the model."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Map storages to CPU when CUDA is disabled so GPU-saved checkpoints load
    # (the original torch.load call had no map_location and crashed on
    # CPU-only machines).
    map_location = torch.device('cpu') if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    # Set model to Evaluation mode
    model.eval()

    # Setup the dataset loader
    ### RELLIS-3D Dataloader
    enc = False
    loader_test = custom_datasets.setup_loaders(args, enc)
    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if args.visualize:
        vis = visdom.Visdom()

    for step, (images, labels, img_name, _) in enumerate(loader_test):
        start_time = time.time()
        if not args.cpu:
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # Argmax over classes -> train-ID map; colourise for saving.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))

        eval_save_path = "./save_colour_rellis/"
        os.makedirs(eval_save_path, exist_ok=True)

        _, file_name = os.path.split(img_name[0])
        file_name = file_name + ".png"

        label_save = ToPILImage()(label_color)
        label_save.save(os.path.join(eval_save_path, file_name))

        if args.visualize:
            vis.image(label_color.numpy())
        if step != 0:  # first run always takes some time for setup
            fwt = time.time() - start_time
            # NOTE(review): time_train is not defined in this function --
            # presumably a module-level list; confirm it exists at file scope.
            time_train.append(fwt)
            print("Forward time per img (b=%d): %.3f (Mean: %.3f)" %
                  (args.batch_size, fwt / args.batch_size,
                   sum(time_train) / len(time_train) / args.batch_size))

        print(step, os.path.join(eval_save_path, file_name))
Exemplo n.º 3
0
def main(args):
    """Segment a video with ERFNet, displaying frames and recording them.

    Reads 'project_video_trimmed.mp4', runs each 640x352 frame through the
    network, colourises the per-pixel class map, shows it with OpenCV and
    appends it to 'output.mp4'. Press 'q' to stop early.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    # Writer for the segmented output video; frame size must match the
    # (640, 352) frames produced in the loop below.
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    out = cv2.VideoWriter('output.mp4', fourcc, 20.0, (640, 352))

    def load_my_state_dict(model, state_dict):
        """Copy only the checkpoint entries whose names exist in the model."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(
        model, torch.load(weightspath, map_location=torch.device('cpu')))
    print("Model and weights LOADED successfully")

    model.eval()

    # cap = cv2.VideoCapture(0)
    cap = cv2.VideoCapture('project_video_trimmed.mp4')

    while True:
        # Capture frame-by-frame
        ret, images = cap.read()
        if not ret:
            # End of stream or read failure: stop cleanly. The original
            # ignored `ret` and crashed in trans(None) once the video ended.
            break

        images = trans(images)
        images = images.float()
        images = images.view((1, 3, 352, 640))  # single-frame NCHW batch

        if not args.cpu:
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # Argmax over classes -> train IDs, colourise, then CHW -> HWC
        # so OpenCV can display/encode the frame.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))
        frame = label_color.numpy().transpose(1, 2, 0)

        # Display the resulting frame
        cv2.imshow('Segmented Image', frame)
        out.write(frame)  # To save video file

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()
    # Removed the original trailing statements here: they referenced
    # undefined names (vis, step, filenameSave) and always raised NameError
    # after the loop finished.
Exemplo n.º 4
0
def main(args):
    """Driver: save run options, optionally init weights, then run inference.

    Training itself is disabled (the train() call is commented out upstream);
    the function loads 'model_best.pth' from the save directory and calls
    ``inference``.
    """
    savedir = f'../save/{args.savedir}'

    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # Record the full argument namespace alongside the checkpoints.
    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(args))

    # Load Model (and archive its definition file next to the checkpoints).
    assert os.path.exists(args.model + ".py"), "Error: model definition not found"
    model = ERFNet(NUM_CLASSES)
    copyfile(args.model + ".py", savedir + '/' + args.model + ".py")

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded"

    def load_my_state_dict(model, state_dict):
        """Copy only checkpoint entries whose names exist in the model.

        Defined once and reused below; the original duplicated this helper
        verbatim in two places.
        """
        # state_dict = state_dict["state_dict"]
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    if args.state:
        # --state only loads initialized weights; to resume a training run
        # use the "--resume" option instead.
        model = load_my_state_dict(model, torch.load(args.state))

    print("========== TRAINING FINISHED ===========")

    print("========== START TESTING ==============")
    model_dir = "/home/pandongwei/work_repository/erfnet_pytorch/save/"+args.savedir+'/model_best.pth'
    model = load_my_state_dict(model, torch.load(model_dir))
    inference(model,args)
Exemplo n.º 5
0
def main(args):
    """Evaluate ERFNet on the geoMat dataset and save grayscale predictions.

    Renders each per-pixel class map through matplotlib (alpha-blended,
    gray colormap) and writes it under ``./getMat_2/``.
    """
    modelpath = args.loadDir + args.loadModel
    # TODO: weights path is hard-coded; restore args.loadDir + args.loadWeights.
    weightspath = "/home/pan/repository/erfnet_pytorch/save/geoMat_2/model_best.pth"
    print("Loading weights: " + weightspath)

    # Wrap in DataParallel so checkpoint key names ("module.*") match.
    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy only the checkpoint entries whose names exist in the model."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Map storages to CPU when CUDA is disabled so GPU-saved checkpoints load
    # (the original torch.load had no map_location and crashed on CPU-only
    # machines).
    map_location = torch.device('cpu') if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")

    loader = DataLoader(geoMat(args.datadir,
                               input_transform_geoMat,
                               target_transform_geoMat,
                               subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if args.visualize:
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # Argmax over classes -> per-pixel train IDs, shape (1, H, W).
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = (label.unsqueeze(0))
        filenameSave = "./getMat_2/" + filename[0].split(
            "material_dataset_v2")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)

        # CHW -> HWC for matplotlib.
        label_save = label_color.numpy()
        label_save = label_save.transpose(1, 2, 0)

        images = images.cpu().numpy().squeeze(axis=0)
        images = images.transpose(1, 2, 0)

        # Render the label map with matplotlib and save it as an image.
        plt.figure(figsize=(10.4, 10.4), dpi=10)
        plt.imshow(label_save, alpha=0.6, cmap='gray')
        plt.axis('off')
        plt.savefig(filenameSave, dpi=10)
        plt.close()

        if args.visualize:
            vis.image(label_color.numpy())
        print(step, filenameSave)
Exemplo n.º 6
0
class Model:
    """Two-network harness: an ERFNet produces per-pixel outputs that are
    pooled, reassembled and pixel-shuffled into the input of a second net
    (ldyNet) which reconstructs a background image; both are trained with a
    MonodepthLoss-based objective.

    NOTE(review): the networks are only created when args.model ==
    'stackhourglass' -- any other value raises AttributeError at
    self.model.cuda() below; confirm this is intended.
    """
    def __init__(self, args):
        # Keep the full argument namespace for later use in train().
        self.args = args

        self.device = args.device
        if args.model == 'stackhourglass':

            self.model = ERFNet(1)
            self.model2 = ldyNet(1)

        self.model.cuda()  #= self.model.to(self.device)
        self.model2.cuda()

        if args.use_multiple_gpu:
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
            self.model2 = torch.nn.DataParallel(self.model2)
            self.model2.cuda()
        if args.mode == 'train':
            # Combined SSIM / disparity-gradient / LR-consistency loss.
            self.loss_function = MonodepthLoss(n=3,
                                               SSIM_w=0.8,
                                               disp_gradient_w=0.1,
                                               lr_w=1).to(self.device)
            # One Adam optimiser per network.
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=args.learning_rate)
            self.optimizer2 = optim.Adam(self.model2.parameters(),
                                         lr=args.learning_rate)
            self.val_n_img, self.val_loader = prepare_dataloader(
                args.val_data_dir, 'test', args.augment_parameters, False, 1,
                (args.input_height, args.input_width), args.num_workers)

            # self.model.load_state_dict(torch.load(args.loadmodel))  # load a pretrained model with this line

        else:
            # Evaluation mode: restore both checkpoints, no augmentation.
            self.model.load_state_dict(torch.load(args.model_path))
            self.model2.load_state_dict(torch.load(args.model2_path))
            args.augment_parameters = None
            args.do_augmentation = False
            args.batch_size = 1
        self.n_img, self.loader = prepare_dataloader(
            args.data_dir, args.mode, args.augment_parameters,
            args.do_augmentation, args.batch_size,
            (args.input_height, args.input_width), args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def disparity_loader(self, path):
        """Return the image at *path* opened with PIL (lazily decoded)."""
        return Image.open(path)

    def train(self):
        """Training loop.

        Per batch: run model -> channel-replicate the input -> multiply ->
        split into four 128px blocks, pool each over H and W, reassemble,
        pixel-shuffle, run model2; then step optimizer2 on the
        reconstruction loss and optimizer on the same loss. Every 2nd epoch
        a (mostly commented-out) validation pass runs, and both nets are
        checkpointed whenever the tracked validation loss improves.
        """

        # losses = []
        #val_losses = []

        best_loss = float('Inf')  # NOTE(review): tracked but never updated below
        best_val_loss = 1000000.0

        for epoch in range(self.args.epochs):

            # losses = []

            if self.args.adjust_lr:
                adjust_learning_rate(self.optimizer, epoch,
                                     self.args.learning_rate)
            if self.args.adjust_lr:
                adjust_learning_rate2(self.optimizer2, epoch,
                                      self.args.learning_rate)
            c_time = time.time()
            running_loss = 0.0
            self.model.train()  # switch both networks to training mode
            self.model2.train()

            for data in self.loader:  #(test-for-train)
                # Load data
                data = to_device(data, self.device)
                left = data['left_image']
                bg_image = data['bg_image']
                #template=data['template']
                #randomMatrix=np.random.randint(0,1,(256,256))
                # Net1Iput=torch.from_numpy(randomMatrix)
                disps = self.model(left)
                #j=0
                # while j<(self.args.batch_size-1):
                #     disps1=disps

                # Replicate `left` 16x along the channel axis so it matches
                # the channel count of `disps` for the element-wise product.
                i = 0
                left1 = left
                while i < 15:

                    left1 = torch.cat((left1, left), 1)
                    i = i + 1
                Net2Iput = disps.mul(left1)  # element-wise (Hadamard) product
                # Split into four blocks: 128 along dim 2 (H), 128 along dim 3 (W).
                Net2Iput11, Net2Iput12 = Net2Iput.split(128, 2)
                Net2Iput11, Net2Iput21 = Net2Iput11.split(128, 3)
                Net2Iput12, Net2Iput22 = Net2Iput12.split(128, 3)

                # Net2Iput1=Net2Iput[:127]
                # Net2Iput2=Net2Iput[128:]
                # Net2Iput11=Net2Iput1[:,0:127]
                # Net2Iput12=Net2Iput1[:,128:255]
                # Net2Iput21=Net2Iput2[:,0:127]
                # Net2Iput22=Net2Iput2[:,128:255]

                # Sum each block over H and W (keepdim) so each becomes a
                # single value per channel, then reassemble into a 2x2 grid.
                Net2Iput11 = torch.sum(Net2Iput11, 2, keepdim=True, out=None)
                Net2Iput11 = torch.sum(Net2Iput11, 3, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 2, keepdim=True, out=None)
                Net2Iput12 = torch.sum(Net2Iput12, 3, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 2, keepdim=True, out=None)
                Net2Iput21 = torch.sum(Net2Iput21, 3, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 2, keepdim=True, out=None)
                Net2Iput22 = torch.sum(Net2Iput22, 3, keepdim=True, out=None)
                Net2Iput1 = torch.cat((Net2Iput11, Net2Iput12), 2)
                Net2Iput2 = torch.cat((Net2Iput21, Net2Iput22), 2)
                Net2Iput = torch.cat((Net2Iput1, Net2Iput2), 3)
                # Net2Iput=torch.cat((Net2Iput,Net2Iput22),1)
                # Rearrange channels into space (upscale factor 4).
                ps = nn.PixelShuffle(4)
                Net2Iput = ps(Net2Iput)
                Net2Out = self.model2(Net2Iput)

                # First update model2 on the reconstruction loss; retain_graph
                # so the same graph can be backpropagated again below.
                self.optimizer2.zero_grad()
                lossnet2 = self.loss_function(Net2Out, bg_image)
                lossnet2.backward(retain_graph=True)
                self.optimizer2.step()

                # Then update model with the same loss recomputed on Net2Out.
                self.optimizer.zero_grad()
                # loss=0
                # plt.figure()
                # plt.imshow(left.squeeze().cpu().detach().numpy())
                # plt.show()
                # plt.imshow(bg_image.squeeze().cpu().detach().numpy())
                # plt.show()
                #disps = self.model(left)
                # print('gggggggggggggggggggggggggggggggggggggggggg',left.shape,disps.shape)
                # plt.imshow(disps.squeeze().cpu().detach().numpy())
                # plt.show()

                loss1 = self.loss_function(Net2Out, bg_image)
                loss1.backward(retain_graph=True)

                self.optimizer.step()

                # losses.append(loss.item())
                running_loss += lossnet2.item()
            # print(' time = %.2f' %(time.time() - c_time))
            running_loss /= (self.n_img / self.args.batch_size)

            # print('running_loss:', running_loss)

            # running_val_loss /= (self.val_n_img / self.args.batch_size)
            print('Epoch:', epoch + 1, 'train_loss:', running_loss, 'time:',
                  round(time.time() - c_time, 3), 's')

            # Convert the (squared-error-like) loss back to 0-255 pixel scale.
            TrueLoss = math.sqrt(running_loss) * 255

            print('TrueLoss:', TrueLoss)
            if epoch % 2 == 0:
                #self.model.eval()
                self.model.eval()
                i = 0
                running_val_loss = 0.0
                for data in self.val_loader:
                    data = to_device(data, self.device)

                    left = data['left_image']
                    bg_image = data['bg_image']

                    #with torch.no_grad():
                    # newinput=torch.cat([left,bg],1)
                    # disps = self.model(newinput)

                    #disps = self.model(left)
                    #Net2Out=self.model2(disps)

                    #lossnet2 = self.loss_function(Net2Out,bg_image)
                    # loss1 = self.loss_function((disps+left1.float()),target,mask)
                    # NOTE(review): the forward pass above is commented out,
                    # so this accumulates the LAST TRAINING lossnet2 once per
                    # val batch -- almost certainly a bug; confirm intent.
                    running_val_loss += lossnet2

                #running_val_loss/=100

                print('running_val_loss = %.12f' % (running_val_loss))

            # Checkpoint both networks whenever the tracked loss improves.
            if running_val_loss < best_val_loss:

                self.save(self.args.model_path[:-4] + '_cpt.pth')
                self.save2(self.args.model2_path[:-4] + '_cpt.pth')
                best_val_loss = running_val_loss
                print('Model_saved')
            # self.save(self.args.model_path[:-4] + '_'+str(epoch+1)+'.pth')
            # if  epoch==100:
            #     self.save(self.args.model_path[:-4] + '_100.pth')
            # if  epoch==150:
            #     self.save(self.args.model_path[:-4] + '_150.pth')
            # if  epoch==200:
            #         self.save(self.args.model_path[:-4] + '_200.pth')
            # if  epoch==250:
            #     self.save(self.args.model_path[:-4] + '_250.pth')
            # self.save(self.args.model_path[:-4] + '_last.pth')  # previous run
            # if running_loss < best_val_loss:
            #     #print(running_val_loss)
            #     #print(best_val_loss)
            #     self.save(self.args.model_path[:-4] + '_cpt.pth')
            #     best_val_loss = running_loss
            #     print('Model_saved')

        # print ('Finished Training. Best loss:', best_loss)
        #self.save(self.args.model_path)

    def save(self, path):
        """Write model's state_dict to *path*."""
        torch.save(self.model.state_dict(), path)

    def save2(self, path):
        """Write model2's state_dict to *path*."""
        torch.save(self.model2.state_dict(), path)

    def load(self, path):
        """Restore model's weights from the state_dict at *path*."""
        self.model.load_state_dict(torch.load(path))
Exemplo n.º 7
0
def main(args):
    """Evaluate ERFNet on Cityscapes and save colourised predictions.

    Writes one colourised PNG per input image under ``./save_color/`` and
    optionally streams images to a visdom server.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    # Wrap in DataParallel so checkpoint key names ("module.*") match.
    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if not args.cpu:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy only the checkpoint entries whose names exist in the model."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    # Map storages to CPU when CUDA is disabled so GPU-saved checkpoints load
    # (the original torch.load had no map_location and crashed on CPU-only
    # machines).
    map_location = torch.device('cpu') if args.cpu else None
    model = load_my_state_dict(model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
        num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if args.visualize:
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # Argmax over classes -> train-ID map; colourise for saving.
        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))

        filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        label_save = ToPILImage()(label_color)
        label_save.save(filenameSave)

        if args.visualize:
            vis.image(label_color.numpy())
        print(step, filenameSave)
def main(args):
    """Compute per-class and mean IoU of ERFNet on the Cityscapes val set.

    Prints one ANSI-coloured IoU figure per class followed by the mean IoU.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    if not args.cpu:
        model = torch.nn.DataParallel(model).cuda()

    def load_my_state_dict(model, state_dict):
        """Load matching keys; also accept checkpoints saved from a
        DataParallel model by stripping the "module." prefix.

        Robustness fix: the original indexed own_state with the stripped
        name without checking it exists, raising KeyError on stale keys;
        unknown keys are now reported and skipped instead.
        """
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                stripped = name.split("module.")[-1] if name.startswith("module.") else None
                if stripped is not None and stripped in own_state:
                    own_state[stripped].copy_(param)
                else:
                    print(name, " not loaded")
                    continue
            else:
                own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath, map_location=lambda storage, loc: storage))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    iouEvalVal = iouEval(NUM_CLASSES)

    start = time.time()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if not args.cpu:
            images = images.cuda()
            labels = labels.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        # Accumulate confusion statistics: argmax prediction vs ground truth.
        iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels)

        filenameSave = filename[0].split("leftImg8bit/")[1]

        print(step, filenameSave)

    iouVal, iou_classes = iouEvalVal.getIoU()

    # Format each per-class IoU as a colour-coded percentage string.
    iou_classes_str = []
    for i in range(iou_classes.size(0)):
        iouStr = getColorEntry(iou_classes[i]) + '{:0.2f}'.format(iou_classes[i] * 100) + '\033[0m'
        iou_classes_str.append(iouStr)

    print("---------------------------------------")
    print("Took ", time.time() - start, "seconds")
    print("=======================================")
    print("Per-Class IoU:")
    # Cityscapes train-ID class names, in train-ID order.
    class_names = ["Road", "sidewalk", "building", "wall", "fence", "pole",
                   "traffic light", "traffic sign", "vegetation", "terrain",
                   "sky", "person", "rider", "car", "truck", "bus", "train",
                   "motorcycle", "bicycle"]
    for iou_str, class_name in zip(iou_classes_str, class_names):
        print(iou_str, class_name)
    print("=======================================")
    iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(iouVal * 100) + '\033[0m'
    print("MEAN IoU: ", iouStr, "%")
Exemplo n.º 9
0
def main():
    """Build an ERFNet on the GPU and launch training."""
    network = ERFNet(NUM_CLASSES).cuda()
    train(network)
Exemplo n.º 10
0
class Model:
    """Evaluation harness: wraps a 2-channel ERFNet that predicts a
    background reconstruction, with a test() loop that accumulates L1/SSIM
    averages and displays each prediction next to its ground truth.
    """

    def __init__(self, args):
        self.args = args

        self.device = args.device
        if args.model == 'stackhourglass':
            self.model = ERFNet(2)

        # NOTE(review): self.model is only assigned for 'stackhourglass';
        # any other args.model value crashes here -- confirm intended.
        self.model.cuda()

        self.model.load_state_dict(torch.load(args.loadmodel))
        self.n_img, self.loader = prepare_dataloader(args.data_dir, args.mode,
                                                     args.augment_parameters,
                                                     args.do_augmentation, args.batch_size,
                                                     (args.input_height, args.input_width),
                                                     args.num_workers)

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def test(self):
        """Run the loader through the model args.epochs times, tracking
        running-average L1 and SSIM losses and plotting each prediction
        beside its ground-truth background image."""
        i = 0
        sum_loss = 0.0
        sum_ssim = 0.0
        average_ssim = 0.0
        average_loss = 0.0
        for epoch in range(self.args.epochs):

            self.model.eval()

            for data in self.loader:
                i = i + 1

                data = to_device(data, self.device)

                left = data['left_image']
                bg_image = data['bg_image']

                disps = self.model(left)

                print(disps.shape)

                l_loss = l1loss(disps, bg_image)
                ssim_loss = ssim(disps, bg_image)
                # Fixed: this line was indented with tabs amid spaces in the
                # original, which is a TabError (syntax error) in Python 3.
                psnr222 = psnr1(disps, bg_image)

                sum_loss = sum_loss + l_loss.item()
                sum_ssim += ssim_loss.item()

                average_ssim = sum_ssim / i
                average_loss = sum_loss / i

                # Side-by-side display: prediction (left) vs ground truth.
                disp_show = disps.squeeze()
                bg_show = bg_image.squeeze()
                print(bg_show.shape)
                plt.figure()
                plt.subplot(1, 2, 1)
                plt.imshow(disp_show.data.cpu().numpy())
                plt.subplot(1, 2, 2)
                plt.imshow(bg_show.data.cpu().numpy())
                plt.show()
        print('average loss', average_loss, average_ssim)
def main(args):
    """Run ERFNet inference over the 'avlane' test split.

    Loads an ERFNet checkpoint (optionally with a pretrained encoder),
    thresholds the network output at 0.5 into a binary lane mask, reports
    a running IoU per batch, and renders each input image next to its
    predicted mask (with the IoU drawn on top) in a live matplotlib window.

    Args:
        args: parsed CLI namespace providing loadDir, loadModel, loadWeights,
              pretrainedEncoder, cuda, datadir, num_workers, batch_size and
              visualize.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    # Build the network, optionally wrapping a pretrained ImageNet encoder.
    if args.pretrainedEncoder:
        pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
        pretrainedEnc = next(pretrainedEnc.children()).features.encoder
        if (not args.cuda):
            pretrainedEnc = pretrainedEnc.cpu()
        model = ERFNet(NUM_CLASSES, encoder=pretrainedEnc)
    else:
        model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if args.cuda:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy only the checkpoint entries whose keys exist in the model.

        Unlike ``model.load_state_dict``, this tolerates missing/extra keys.
        """
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    dataset_test = avlane(args.datadir, input_transform_lot,
                          label_transform_lot, 'test')
    loader = DataLoader(dataset_test,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    # Live figure showing <input | prediction> side by side (640x480 each).
    fig = plt.figure()
    ax = fig.gca()
    h = ax.imshow(Image.new('RGB', (640 * 2, 480), 0))

    print(len(loader.dataset))

    iouEvalTest = iouEval_binary(NUM_CLASSES)

    with torch.no_grad():
        for step, (images, labels, filename) in enumerate(loader):

            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = labels
            outputs = model(inputs)

            # Threshold sigmoid output into a binary {0,1} mask (b x 1 x h x w).
            preds = torch.where(outputs > 0.5,
                                torch.ones([1], dtype=torch.long).cuda(),
                                torch.zeros([1], dtype=torch.long).cuda())

            # Accumulate IoU over all batches seen so far.
            iouEvalTest.addBatch(preds[:, 0],
                                 targets[:, 0])  # no_grad handles it already
            iouTest = iouEvalTest.getIoU()
            iouStr = "test IOU: " + '{:0.2f}'.format(
                iouTest.item() * 100) + "%"
            print(iouStr)

            # save the output
            filenameSave = os.path.join(args.loadDir, 'test_results',
                                        filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            pred = preds.to(torch.uint8).squeeze(0).cpu()  # 1xhxw
            pred_save = pred_transform_lot(pred)

            # concatenate input image & prediction side by side
            im1 = Image.open(
                os.path.join(args.datadir, 'data/test',
                             filename[0].split("test/")[1])).convert('RGB')
            im2 = pred_save
            dst = Image.new('RGB', (im1.width + im2.width, im1.height))
            dst.paste(im1, (0, 0))
            dst.paste(im2, (im1.width, 0))
            filenameSaveConcat = os.path.join(args.loadDir,
                                              'test_results_concat',
                                              filename[0].split("test/")[1])
            os.makedirs(os.path.dirname(filenameSaveConcat), exist_ok=True)

            # write the IoU text on the concatenated image
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/freefont/FreeMonoBold.ttf', 36)
            d = ImageDraw.Draw(dst)
            d.text((900, 0), iouStr, font=font, fill=(255, 255, 0))

            # show video
            h.set_data(dst)
            plt.draw()
            plt.axis('off')
            plt.pause(1e-2)

            if (args.visualize):
                # BUGFIX: the original referenced `label_save`, which was only
                # created on a commented-out line -> NameError when
                # --visualize was set.  Send the binary prediction instead.
                vis.image(pred.numpy())

            print(step, filenameSave)
# Exemplo n.º 12
# 0
def main():
    """Run retinal-vessel segmentation inference from the command line.

    Parses CLI options, builds the selected network (M2UNet / DRIU / ERFNet /
    UNet), loads its pretrained weights, and hands a test-set DataLoader to
    ``run`` for prediction, thresholding and optional saving of probability
    maps and binary masks.
    """
    parser = ArgumentParser()
    arg = parser.add_argument
    arg('--model',
        default='M2UNet',
        type=str,
        choices=['M2UNet', 'DRIU', 'ERFNet', 'UNet'])
    arg('--state_dict',
        default='M2UNetDRIVE.pth',
        type=str,
        help='pretrained model weights file, stored in models')
    arg('--dataset',
        default='DRIVE',
        choices=['DRIVE', 'CHASE_DB1', 'HRF'],
        type=str,
        help=
        'determines the dataset directory and the amount of cropping that is performed to ensure that the loaded images are multiples of 32.'
        )
    arg('--threshold',
        default=0.5215,
        type=float,
        help='threshold to convert probability vessel map to binary map')
    arg('--devicename',
        default='cpu',
        type=str,
        help='device type, default: "cpu"')
    arg('--batch_size',
        default=1,
        type=int,
        help='inference batch size, default: 1')
    # BUGFIX: the original used `default=False` with no action/type, so ANY
    # supplied value (even "False") arrived as a truthy string.  store_true
    # makes these proper boolean flags.
    arg('--save_prob',
        action='store_true',
        help='save probability vessel maps to disk')
    arg('--save_binary_mask',
        action='store_true',
        help='save binary mask to disk')

    # Paths
    model_path = Path('models')
    data_path = Path('data')
    log_path = Path('logs')

    # parse arguments
    args = parser.parse_args()
    state_dict_path = model_path.joinpath(args.state_dict)
    dataset_path = data_path.joinpath(args.dataset)
    image_file_path = dataset_path.joinpath('test/images')
    prediction_output_path = dataset_path.joinpath('predictions')
    threshold = args.threshold
    devicename = args.devicename
    batch_size = args.batch_size
    dataset = args.dataset
    save_prob = args.save_prob
    save_binary_mask = args.save_binary_mask

    # default device type is 'cuda:0'
    pin_memory = True
    cudnn.benchmark = True
    device = torch.device(devicename)

    if devicename == 'cpu':
        # if run on cpu, disable cudnn benchmark and do not pin memory
        cudnn.benchmark = False
        pin_memory = False
        # only run on one core
        torch.set_num_threads(1)

    # Build the requested network.  The original independent `if` chain left
    # `model` unbound for DRIU + HRF, crashing later with a NameError; fail
    # fast with a clear message instead.
    if args.model == 'M2UNet':
        model = m2unet()
    elif args.model == 'DRIU':
        if dataset not in ('DRIVE', 'CHASE_DB1'):
            raise ValueError(
                'DRIU is only supported for the DRIVE and CHASE_DB1 datasets')
        model = DRIU()
    elif args.model == 'ERFNet':
        model = ERFNet(1)
    else:  # 'UNet' (argparse `choices` guarantees this is the only remainder)
        model = UNet(in_channels=3,
                     n_classes=1,
                     depth=5,
                     wf=6,
                     padding=True,
                     batch_norm=False,
                     up_mode='upconv')

    state_dict = torch.load(str(state_dict_path), map_location=devicename)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    # list of all files include path
    file_paths = get_file_lists(image_file_path)
    # list of file names only
    file_names = list(map(lambda x: x.stem, file_paths))
    # dataloader
    dataloader = DataLoader(dataset=RetinaDataset(file_paths, dataset),
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=pin_memory)

    run(model, dataloader, batch_size, threshold, device, save_prob,
        save_binary_mask, prediction_output_path, file_names)