def main(args):
    """Restore saved refinement/segmentation checkpoints and continue training.

    Creates the run directory, records the options in ``opts.txt``, builds a
    ``UNet`` (``rmodel``) and a ``Net`` seeded through ``fill_weights`` from
    the encoder of an ``ERFNet_imagenet`` checkpoint, overwrites both with the
    stored ``CB_iFL`` checkpoints and passes them to ``train``.

    Args:
        args: option object; ``savedir`` and ``pretrainedEncoder`` are read
            here and the whole object is forwarded to ``train``.
    """
    savedir = '/home/shyam.nandan/NewExp/final_code/save/' + args.savedir  #change path here

    # exist_ok=True replaces the racy "exists() then makedirs()" sequence.
    os.makedirs(savedir, exist_ok=True)

    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(args))

    rmodel = torch.nn.DataParallel(UNet()).cuda()

    # The encoder checkpoint is loaded through a DataParallel wrapper
    # (presumably its keys carry the "module." prefix - TODO confirm) and
    # only the encoder sub-module is kept afterwards.
    imagenet_net = torch.nn.DataParallel(ERFNet_imagenet(1000))
    imagenet_net.load_state_dict(
        torch.load(args.pretrainedEncoder)['state_dict'])
    pretrainedEnc = imagenet_net.module.features.encoder

    model = Net(NUM_CLASSES)
    model = fill_weights(model, pretrainedEnc)
    model = torch.nn.DataParallel(model).cuda()

    # Both networks are then overwritten with the stored best checkpoints.
    rmodel_path = '/home/shyam.nandan/NewExp/final_code/results/CB_iFL/rmodel_best.pth'
    rmodel.load_state_dict(torch.load(rmodel_path))

    model_path = '/home/shyam.nandan/NewExp/final_code/results/CB_iFL/model_best.pth'
    model.load_state_dict(torch.load(model_path))

    model = train(args, rmodel, model, False)
Beispiel #2
0
def main(args):
    """Train the network from ``args.model`` in an encoder and a decoder stage.

    Writes the options and a copy of the model definition into
    ``args.savedir``, optionally initialises weights from ``args.state``,
    runs ``train(..., True)`` unless ``args.decoder`` is set, rebuilds the
    network around the chosen encoder and finally runs ``train(..., False)``.

    Args:
        args: option object; reads ``savedir``, ``model``, ``cuda``,
            ``state``, ``decoder``, ``trainInOneGo`` and
            ``pretrainedEncoder``; may set ``resume`` to ``False``.
    """
    savedir = args.savedir

    # exist_ok=True replaces the racy "exists() then makedirs()" sequence.
    os.makedirs(savedir, exist_ok=True)

    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(args))

    #Load Model
    assert os.path.exists(args.model + ".py"), "Error: model definition not found"
    model_file = importlib.import_module(args.model)
    model = model_file.Net(NUM_CLASSES)
    copyfile(args.model + ".py", savedir + '/' + args.model + ".py")

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    if args.state:
        #if args.state is provided then load this state for training
        #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!!
        def load_my_state_dict(model, state_dict):
            """Copy every entry of state_dict whose key exists in model; skip the rest."""
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name in own_state:
                    own_state[name].copy_(param)
            return model

        model = load_my_state_dict(model, torch.load(args.state))

    if not args.decoder:
        print("========== ENCODER TRAINING ===========")
        model = train(args, model, True) #Train encoder
        if args.trainInOneGo:
            args.resume = False
    #CAREFUL: for some reason, after training encoder alone, the decoder gets weights=0.
    #We must reinit decoder weights or reload network passing only encoder in order to train decoder
    print("========== DECODER TRAINING ===========")
    if not args.state:
        if args.pretrainedEncoder:
            print("Loading encoder pretrained in imagenet")
            from erfnet_imagenet import ERFNet as ERFNet_imagenet
            pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
            pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dict'])
            pretrainedEnc = pretrainedEnc.module.features.encoder
            if not args.cuda:
                pretrainedEnc = pretrainedEnc.cpu()     #because loaded encoder is probably saved in cuda
        else:
            # The model is only wrapped in DataParallel when args.cuda is set;
            # unwrap conditionally so the CPU path reaches the encoder too.
            base_net = model.module if isinstance(model, torch.nn.DataParallel) else model
            pretrainedEnc = base_net.encoder
        model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc)  #Add decoder to encoder
        if args.cuda:
            model = torch.nn.DataParallel(model).cuda()
        #When loading encoder reinitialize weights for decoder because they are set to 0 when training dec
    model = train(args, model, False)   #Train decoder
    print("========== TRAINING FINISHED ===========")
Beispiel #3
0
	def __init__(self, classes, embed_dim, resnet, pretrained_model=None,
				 pretrained=True, use_torch_up=False):
		"""Build the shared encoder, one 1x1 head per entry of ``classes`` and an embedding head.

		Args:
			classes: dict; its keys are stored as ``self.datasets`` and each
				value is used as the output channel count of one head.
			embed_dim: output channel count of the extra ``en_map`` head.
			resnet, pretrained_model, pretrained: accepted but not read here.
			use_torch_up: if true, heads are upsampled with
				``nn.UpsamplingBilinear2d``; otherwise with a frozen
				transposed convolution (stride 8).
		"""
		super().__init__()
		assert isinstance(classes, dict), f"num_labels should be dict, got {type(classes)}"
		self.datasets = list(classes.keys())
		self.embed_dim = embed_dim

		# Load the ImageNet checkpoint through a DataParallel wrapper, then keep only its encoder.
		imagenet_net = torch.nn.DataParallel(ERFNet_imagenet(1000))
		checkpoint = torch.load("erfnet_encoder_pretrained.pth.tar")
		imagenet_net.load_state_dict(checkpoint['state_dict'])
		pretrainedEnc = next(imagenet_net.children()).features.encoder

		model = ECANet(num_classes=1000, encoder=pretrainedEnc)  #Add decoder to encoder
		# NOTE(review): the wrapper is never used afterwards; kept because constructing it is what the original does.
		pmodel = nn.DataParallel(model)

		def make_point_conv(out_channels):
			# 1x1 conv on the encoder output: normal(0, sqrt(2 / (k*k*out))) weights, zero bias.
			conv = nn.Conv2d(model.out_dim, out_channels, kernel_size=1, bias=True)
			fan = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
			conv.weight.data.normal_(0, math.sqrt(2. / fan))
			conv.bias.data.zero_()
			return conv

		def make_frozen_upsample(channels):
			# Per-channel transposed conv filled by fill_up_weights and excluded from training.
			deconv = nn.ConvTranspose2d(channels, channels, 16, stride=8, padding=4,
										output_padding=0, groups=channels,
										bias=False)
			fill_up_weights(deconv)
			deconv.weight.requires_grad = False
			return deconv

		## Encoder: every child of the ECANet except the last two.
		self.base = nn.Sequential(*list(model.children())[:-2])
		self.seg = nn.ModuleList() ## Decoder 1d conv
		self.up = nn.ModuleList() ## Decoder upsample (non-trainable)

		for n_labels in classes.values():
			self.seg.append(make_point_conv(n_labels))
			if use_torch_up:
				upsample = nn.UpsamplingBilinear2d(scale_factor=8)
			else:
				upsample = make_frozen_upsample(n_labels)
			self.up.append(upsample)

		## Encoder output module
		self.en_map = make_point_conv(self.embed_dim)
		self.en_up = make_frozen_upsample(self.embed_dim)
Beispiel #4
0
def main():
    """Run encoder-then-decoder training configured by the module-level ARGS_* values.

    Copies the model definition into ``ARGS_SAVE_DIR``, optionally loads
    matching weights from ``ARGS_STATE``, calls ``train(model, True)`` unless
    ``ARGS_DECODER`` is set, rebuilds the network around the selected encoder
    and then calls ``train(model, False)``.
    """
    savedir = ARGS_SAVE_DIR
    # exist_ok=True replaces the racy "exists() then makedirs()" sequence.
    os.makedirs(savedir, exist_ok=True)

    model_file = importlib.import_module(ARGS_MODEL)
    model = model_file.Net(NUM_CLASSES)
    copyfile(ARGS_MODEL + ".py", savedir + '/' + ARGS_MODEL + ".py")
    if ARGS_CUDA:
        model = torch.nn.DataParallel(model).cuda()
    if ARGS_STATE:

        def load_my_state_dict(model, state_dict):
            """Copy every entry of state_dict whose key exists in model; skip the rest."""
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name in own_state:
                    own_state[name].copy_(param)
            return model

        model = load_my_state_dict(model, torch.load(ARGS_STATE))

    if not ARGS_DECODER:
        print("#################### ENCODER TRAINING ####################")
        model = train(model, True)
    print("#################### DECODER TRAINING ####################")
    if not ARGS_STATE:
        if ARGS_PRETRAINED_ENCODER:
            print("Loading encoder pretrained in imagenet")
            from erfnet_imagenet import ERFNet as ERFNet_imagenet
            pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
            pretrainedEnc.load_state_dict(
                torch.load(ARGS_PRETRAINED_ENCODER)['state_dict'])
            pretrainedEnc = pretrainedEnc.module.features.encoder
            if not ARGS_CUDA:
                pretrainedEnc = pretrainedEnc.cpu()
        else:
            # The model is only wrapped in DataParallel when ARGS_CUDA is set;
            # unwrap conditionally so the CPU path reaches the encoder too.
            base_net = model.module if isinstance(model, torch.nn.DataParallel) else model
            pretrainedEnc = base_net.encoder
        model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc)
        if ARGS_CUDA:
            model = torch.nn.DataParallel(model).cuda()
    model = train(model, False)
    print("#################### TRAINING FINISHED ####################")
Beispiel #5
0
def main():
    """Run encoder-then-decoder training with the options hard-coded in ``Args``.

    Writes the option values and a copy of the model definition below
    ``../save/``, optionally loads matching weights from ``args.state``,
    calls ``train(args, model, True)`` unless ``args.decoder`` is set,
    rebuilds the network around the selected encoder and then calls
    ``train(args, model, False)``.
    """
    class Args():
        cuda = True  # NOTE: cpu-only has not been tested so you might have to change code if you deactivate this flag
        model = "erfnet"
        state = False
        port = 8097
        datadir = "/esat/toyota/trace/deeplearning/datasets_public/cityscapes/leftImg8bit_trainvaltest"
        height = 512
        num_epochs = 5
        num_workers = 4
        batch_size = 2
        steps_loss = 50
        steps_plot = 50
        epochs_save = 0  # You can use this value to save model every X epochs
        savedir = "~/Document/thesis_kontras/"
        decoder = False
        pretrainedEncoder = False  # , default="../trained_models/erfnet_encoder_pretrained.pth.tar")
        visualize = False
        iouTrain = False  # recommended: False (takes more time to train otherwise)
        iouVal = True
        resume = False

    args = Args()
    # NOTE(review): "~" is not expanded here, so a literal "~" directory is
    # created below ../save - confirm this is intended.
    savedir = f'../save/{args.savedir}'

    # exist_ok=True replaces the racy "exists() then makedirs()" sequence.
    os.makedirs(savedir, exist_ok=True)

    # str(args) on a plain class instance only yields "<...Args object at 0x...>",
    # so record the actual option values instead.
    options = {key: value for key, value in vars(Args).items()
               if not key.startswith('_')}
    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(options))

    #Load Model
    assert os.path.exists(args.model + ".py"), "Error: model definition not found"
    model_file = importlib.import_module(args.model)
    model = model_file.Net(NUM_CLASSES)
    copyfile(args.model + ".py", savedir + '/' + args.model + ".py")

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    if args.state:
        #if args.state is provided then load this state for training
        #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!!
        def load_my_state_dict(model, state_dict):
            """Copy every entry of state_dict whose key exists in model; skip the rest."""
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name in own_state:
                    own_state[name].copy_(param)
            return model

        model = load_my_state_dict(model, torch.load(args.state))

    if not args.decoder:
        print("========== ENCODER TRAINING ===========")
        model = train(args, model, True) #Train encoder
    #CAREFUL: for some reason, after training encoder alone, the decoder gets weights=0.
    #We must reinit decoder weights or reload network passing only encoder in order to train decoder
    print("========== DECODER TRAINING ===========")
    if not args.state:
        if args.pretrainedEncoder:
            print("Loading encoder pretrained in imagenet")
            from erfnet_imagenet import ERFNet as ERFNet_imagenet
            pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
            pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dict'])
            pretrainedEnc = pretrainedEnc.module.features.encoder
            if not args.cuda:
                pretrainedEnc = pretrainedEnc.cpu()     #because loaded encoder is probably saved in cuda
        else:
            # The model is only wrapped in DataParallel when args.cuda is set;
            # unwrap conditionally so the CPU path reaches the encoder too.
            base_net = model.module if isinstance(model, torch.nn.DataParallel) else model
            pretrainedEnc = base_net.encoder
        model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc)  #Add decoder to encoder
        if args.cuda:
            model = torch.nn.DataParallel(model).cuda()
        #When loading encoder reinitialize weights for decoder because they are set to 0 when training dec
    model = train(args, model, False)   #Train decoder
    print("========== TRAINING FINISHED ===========")
Beispiel #6
0
def main(args):
    """Train the network from ``args.model`` in an encoder and a decoder stage.

    Writes the options and a copy of the model definition to
    ``../save/<args.savedir>``, optionally loads matching weights from
    ``args.state``, calls ``train(args, model, True)`` unless
    ``args.decoder`` is set, rebuilds the network around the selected encoder
    and then calls ``train(args, model, False)``.

    Args:
        args: option object; reads ``savedir``, ``model``, ``cuda``,
            ``state``, ``decoder`` and ``pretrainedEncoder``.
    """
    savedir = f'../save/{args.savedir}'

    # exist_ok=True replaces the racy "exists() then makedirs()" sequence.
    os.makedirs(savedir, exist_ok=True)

    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(args))

    #Load Model
    assert os.path.exists(args.model +
                          ".py"), "Error: model definition not found"
    model_file = importlib.import_module(args.model)
    model = model_file.Net(NUM_CLASSES)
    copyfile(args.model + ".py", savedir + '/' + args.model + ".py")

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    if args.state:
        #if args.state is provided then load this state for training
        #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!!
        def load_my_state_dict(model, state_dict):
            """Copy every entry of state_dict whose key exists in model; skip the rest."""
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name in own_state:
                    own_state[name].copy_(param)
            return model

        model = load_my_state_dict(model, torch.load(args.state))

    if not args.decoder:
        print("========== ENCODER TRAINING ===========")
        model = train(args, model, True)  #Train encoder
    #CAREFUL: for some reason, after training encoder alone, the decoder gets weights=0.
    #We must reinit decoder weights or reload network passing only encoder in order to train decoder
    print("========== DECODER TRAINING ===========")
    if not args.state:
        if args.pretrainedEncoder:
            print("Loading encoder pretrained in imagenet")
            from erfnet_imagenet import ERFNet as ERFNet_imagenet
            pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
            # strict=False: checkpoint keys that do not match are tolerated.
            pretrainedEnc.load_state_dict(
                torch.load(args.pretrainedEncoder)['state_dict'], strict=False)
            pretrainedEnc = pretrainedEnc.module.features.encoder
            if not args.cuda:
                #because loaded encoder is probably saved in cuda
                pretrainedEnc = pretrainedEnc.cpu()
        else:
            # The model is only wrapped in DataParallel when args.cuda is set;
            # unwrap conditionally so the CPU path reaches the encoder too.
            base_net = model.module if isinstance(model, torch.nn.DataParallel) else model
            pretrainedEnc = base_net.encoder
        model = model_file.Net(NUM_CLASSES,
                               encoder=pretrainedEnc)  #Add decoder to encoder
        if args.cuda:
            model = torch.nn.DataParallel(model).cuda()
        #When loading encoder reinitialize weights for decoder because they are set to 0 when training dec
    model = train(args, model, False)  #Train decoder
    print("========== TRAINING FINISHED ===========")
                      args['dataset']['kwargs']['transform'])

# Sequential, unshuffled loader over `dataset`, one sample per batch.
dataset_it = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    pin_memory=bool(args['cuda']))

# `True` is the value passed to get_model when no pretrained encoder is requested.
pretrainedEnc = True
if args['pretrain_encoder']['apply']:
    print("Loading encoder pretrained in imagenet")
    from erfnet_imagenet import ERFNet as ERFNet_imagenet

    pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
    pretrainedEnc.load_state_dict(
        torch.load(args['pretrain_encoder']['path'])['state_dict'])
    pretrainedEnc = next(pretrainedEnc.children()).features.encoder

model = get_model(args['model']['name'],
                  args['model']['kwargs']['num_classes'], pretrainedEnc)
model = torch.nn.DataParallel(model).to(device)

# The original `assert (False, msg)` asserted a non-empty tuple, which is
# always true, so a missing checkpoint was silently ignored. Fail explicitly.
if not os.path.exists(args['checkpoint_path']):
    raise FileNotFoundError('checkpoint_path {} does not exist!'.format(
        args['checkpoint_path']))

print("model_loaded")
# map_location lets a checkpoint saved on another device load onto `device`.
state = torch.load(args['checkpoint_path'], map_location=device)
model.load_state_dict(state['model_state_dict'], strict=True)
Beispiel #8
0
def main(args):
    """Assemble the road/defect cascade network and train it.

    Writes the options and a copy of the model definition to
    ``../save/<args.savedir>``, builds ``Net(20)`` from ``args.model``
    (optionally initialised from ``args.state`` or rebuilt around a selected
    encoder), wraps it in ``defectSegmentationNet``, combines it with an
    ``erfnet2`` network inside ``cascadeNet``, loads matching weights from the
    release checkpoint and calls ``train`` on the cascade.

    Args:
        args: option object; reads ``savedir``, ``model``, ``cuda``,
            ``state`` and ``pretrainedEncoder``.
    """
    savedir = '../save/'+args.savedir

    # exist_ok=True replaces the racy "exists() then makedirs()" sequence.
    os.makedirs(savedir, exist_ok=True)

    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(args))

    def load_my_state_dict(model, state_dict):
        """Copy every entry of state_dict whose key exists in model; skip the rest."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                own_state[name].copy_(param)
        return model

    #Load Model
    assert os.path.exists(args.model + ".py"), "Error: model definition not found"
    model_file = importlib.import_module(args.model)
    model = model_file.Net(20)
    copyfile(args.model + ".py", savedir + '/' + args.model + ".py")

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    if args.state:
        #if args.state is provided then load this state for training
        #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!!
        model = load_my_state_dict(model, torch.load(args.state))

    print("========== TRAINING ===========")
    if not args.state:
        if args.pretrainedEncoder:
            print("Loading encoder pretrained in imagenet")
            from erfnet_imagenet import ERFNet as ERFNet_imagenet
            pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
            pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dict'])
            pretrainedEnc = pretrainedEnc.module.features.encoder
            if not args.cuda:
                pretrainedEnc = pretrainedEnc.cpu()     #because loaded encoder is probably saved in cuda
        else:
            # The model is only wrapped in DataParallel when args.cuda is set;
            # unwrap conditionally so the unwrapped case reaches the encoder too.
            base_net = model.module if isinstance(model, torch.nn.DataParallel) else model
            pretrainedEnc = base_net.encoder
        model = model_file.Net(20, encoder=pretrainedEnc)  #Add decoder to encoder
        model = model.cuda()
        if args.cuda:
            model = torch.nn.DataParallel(model).cuda()

    # The cascade used to be built only when args.state was falsy, leaving
    # `cascadeNet_model` undefined (NameError at train) when --state was given.
    # NOTE(review): it is now built in both cases; confirm that weights loaded
    # from args.state are meant to feed the cascade before the release
    # checkpoint below is applied.
    defSegNet = defectSegmentationNet(20, model)
    defSegNet = torch.nn.DataParallel(defSegNet).cuda()

    road_segNet_file = importlib.import_module('erfnet2')
    roadSeg_model = road_segNet_file.Net(20)
    roadSeg_model = torch.nn.DataParallel(roadSeg_model).cuda()

    cascadeNet_model = cascadeNet(roadSeg_model, defSegNet)
    cascadeNet_model = torch.nn.DataParallel(cascadeNet_model).cuda()
    cascadeNet_model = load_my_state_dict(
        cascadeNet_model,
        torch.load('../save/release_version_test/model_best.pth'))

    model = train(args, cascadeNet_model, False)   #Train decoder
    print("========== TRAINING FINISHED ===========")
Beispiel #9
0
def main(args):
    """Train the ``ERFNet`` from ``args.model`` in two stages and save the final weights.

    Writes the options and a copy of the model definition to
    ``../save/<args.savedir>``, optionally loads matching weights from
    ``args.state``, calls ``train(args, model, True)`` unless
    ``args.decoder`` is set, rebuilds the network around the selected encoder,
    calls ``train(args, model, False)`` and stores the resulting state dict as
    ``weight_final.pth``.

    Args:
        args: option object; reads ``savedir``, ``model``, ``cuda``,
            ``state``, ``decoder`` and ``pretrainedEncoder``.
    """
    savedir = f'../save/{args.savedir}'

    # exist_ok=True replaces the racy "exists() then makedirs()" sequence.
    os.makedirs(savedir, exist_ok=True)

    with open(savedir + '/opts.txt', "w") as myfile:
        myfile.write(str(args))

    #Load Model
    assert os.path.exists(args.model +
                          ".py"), "Error: model definition not found"
    model_file = importlib.import_module(args.model)
    model = model_file.ERFNet(NUM_CLASSES)
    copyfile(args.model + ".py", savedir + '/' + args.model + ".py")

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    if args.state:

        def load_my_state_dict(model, state_dict):
            """Copy every entry of state_dict whose key exists in model; skip the rest."""
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name in own_state:
                    own_state[name].copy_(param)
            return model

        model = load_my_state_dict(model, torch.load(args.state))

    if not args.decoder:
        print("========== ENCODER TRAINING ===========")
        model = train(args, model, True)  #Train encoder
    #CAREFUL: for some reason, after training encoder alone, the decoder gets weights=0.
    #We must reinit decoder weights or reload network passing only encoder in order to train decoder
    print("========== DECODER TRAINING ===========")
    if not args.state:
        if args.pretrainedEncoder:
            print("Loading encoder pretrained in imagenet")
            from erfnet_imagenet import ERFNet as ERFNet_imagenet
            pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
            pretrainedEnc.load_state_dict(
                torch.load(args.pretrainedEncoder)['state_dict'])
            pretrainedEnc = pretrainedEnc.module.features.encoder
            if not args.cuda:
                #because loaded encoder is probably saved in cuda
                pretrainedEnc = pretrainedEnc.cpu()
        else:
            # The model is only wrapped in DataParallel when args.cuda is set;
            # unwrap conditionally so the CPU path reaches the encoder too.
            base_net = model.module if isinstance(model, torch.nn.DataParallel) else model
            pretrainedEnc = base_net.encoder
        model = model_file.ERFNet(
            NUM_CLASSES, encoder=pretrainedEnc)  #Add decoder to encoder
        if args.cuda:
            model = torch.nn.DataParallel(model).cuda()
        #When loading encoder reinitialize weights for decoder because they are set to 0 when training dec
    # TODO: load the weights trained on cityscape
    # ("../save/cityscape_6classes_2/model_best.pth") and use them for transfer
    # learning: drop the mismatching layers
    # 'module.decoder.output_conv.weight' / 'module.decoder.output_conv.bias'
    # from the checkpoint, then call model.load_state_dict(..., strict=False).

    model = train(args, model, False)  #Train decoder
    save_path = f'../save/{args.savedir}/weight_final.pth'
    torch.save(model.state_dict(), save_path)
    print("========== TRAINING FINISHED ===========")
def main(args):
    """Evaluate an ERFNet checkpoint on the ``avlane`` test split with a live preview.

    Builds the network, copies matching weights from
    ``args.loadDir + args.loadWeights``, thresholds the outputs at 0.5,
    accumulates IoU with ``iouEval_binary`` and shows each input next to its
    prediction (with the running IoU drawn on top) in a matplotlib window.

    Args:
        args: option object; reads ``loadDir``, ``loadModel``,
            ``loadWeights``, ``pretrainedEncoder``, ``cuda``, ``datadir``,
            ``num_workers``, ``batch_size`` and ``visualize``.
    """
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    if args.pretrainedEncoder:
        # NOTE(review): no checkpoint is loaded into this encoder (the load
        # call was already disabled), so it only determines the architecture
        # handed to ERFNet before the weights below are copied in.
        pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
        pretrainedEnc = pretrainedEnc.module.features.encoder
        if not args.cuda:
            pretrainedEnc = pretrainedEnc.cpu()
        model = ERFNet(NUM_CLASSES, encoder=pretrainedEnc)
    else:
        model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if args.cuda:
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        """Copy every entry of state_dict whose key exists in model; skip the rest."""
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                own_state[name].copy_(param)
        return model

    # Without CUDA, map the checkpoint tensors to the CPU while loading.
    map_location = None if args.cuda else 'cpu'
    model = load_my_state_dict(
        model, torch.load(weightspath, map_location=map_location))
    print("Model and weights LOADED successfully")

    model.eval()

    if not os.path.exists(args.datadir):
        print("Error: datadir could not be loaded")

    dataset_test = avlane(args.datadir, input_transform_lot,
                          label_transform_lot, 'test')
    loader = DataLoader(dataset_test,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if args.visualize:
        vis = visdom.Visdom()

    fig = plt.figure()
    ax = fig.gca()
    h = ax.imshow(Image.new('RGB', (640 * 2, 480), 0))

    print(len(loader.dataset))

    iouEvalTest = iouEval_binary(NUM_CLASSES)

    # Loop-invariant: load the overlay font once instead of once per image.
    font = ImageFont.truetype(
        '/usr/share/fonts/truetype/freefont/FreeMonoBold.ttf', 36)

    with torch.no_grad():
        for step, (images, labels, filename) in enumerate(loader):

            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            targets = labels
            outputs = model(images)

            # Binary mask on the same device as the outputs; the previous
            # torch.where(...) called .cuda() even when args.cuda was off.
            preds = (outputs > 0.5).long()

            # iou
            iouEvalTest.addBatch(preds[:, 0],
                                 targets[:, 0])  # no_grad handles it already
            iouTest = iouEvalTest.getIoU()
            iouStr = "test IOU: " + '{:0.2f}'.format(
                iouTest.item() * 100) + "%"
            print(iouStr)

            # Path of the sample relative to the "test/" folder.
            rel_name = filename[0].split("test/")[1]

            # Output directory is prepared; saving itself is disabled.
            filenameSave = os.path.join(args.loadDir, 'test_results',
                                        rel_name)
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            pred = preds.to(torch.uint8).squeeze(0).cpu()  # 1xhxw
            pred_save = pred_transform_lot(pred)

            # concatenate data & result
            im1 = Image.open(
                os.path.join(args.datadir, 'data/test',
                             rel_name)).convert('RGB')
            im2 = pred_save
            dst = Image.new('RGB', (im1.width + im2.width, im1.height))
            dst.paste(im1, (0, 0))
            dst.paste(im2, (im1.width, 0))
            filenameSaveConcat = os.path.join(args.loadDir,
                                              'test_results_concat',
                                              rel_name)
            os.makedirs(os.path.dirname(filenameSaveConcat), exist_ok=True)

            # write iou on dst
            d = ImageDraw.Draw(dst)
            d.text((900, 0), iouStr, font=font, fill=(255, 255, 0))

            # show video
            h.set_data(dst)
            plt.draw()
            plt.axis('off')
            plt.pause(1e-2)

            if args.visualize:
                # `label_save` was never defined (NameError); send the
                # predicted mask instead, scaled from 0/1 to 0/255.
                # NOTE(review): assumes batch_size == 1 so pred is 1xHxW - confirm.
                vis.image((pred * 255).numpy())

            print(step, filenameSave)