Code example #1
0
File: test.py — Project: TimothyHTimothy/SFL-SFSR
def main():
    """Run the test pipeline.

    Parses the option file, builds the test dataloader and model, runs
    inference on every sample, saves the SR outputs, and prints average
    PSNR / SSIM / LPIPS over the whole set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
    opt = option.dict_to_nonedict(opt)

    dataset_opt = opt['datasets']['test']
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt, opt, None)

    model = Model(opt)

    if test_loader is None:
        return

    calc_lpips = PerceptualLossLPIPS()
    avg_ssim = avg_psnr = avg_lpips = 0
    print("Testing Starts!")
    idx = 0
    test_bar = tqdm(test_loader)
    for test_data in test_bar:
        idx += 1
        img_name = os.path.splitext(os.path.basename(test_data['LQ_path'][0]))[0]
        img_dir = '../test_results_' + opt['name']
        util.mkdir(img_dir)

        model.feed_data(test_data)
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        gt_img = util.tensor2img(visuals['GT'])  # uint8

        # Save SR images for reference
        save_sr_img_path = os.path.join(img_dir, '{:s}_sr.png'.format(img_name))
        util.save_img(sr_img, save_sr_img_path)

        # Metrics are computed directly on the uint8-range images; the
        # original divide-by-255-then-multiply-by-255 round trip was a no-op
        # (and the normalized LQ image was never used at all).
        avg_psnr += util.calculate_psnr(sr_img, gt_img)
        avg_ssim += util.calculate_ssim(sr_img, gt_img)
        avg_lpips += calc_lpips(visuals['SR'], visuals['GT'])

    if idx:  # guard against an empty test loader (avoids ZeroDivisionError)
        avg_psnr = avg_psnr / idx
        avg_ssim = avg_ssim / idx
        avg_lpips = avg_lpips / idx

        print('Test_Result_{:s} psnr: {:.4e} ssim: {:.4e} lpips: {:.4e}'.format(
                opt['name'], avg_psnr, avg_ssim, avg_lpips))
Code example #2
0
File: generate.py — Project: Bin4writing/BasicSR
    def run(self):
        """Generate SR images for every configured dataset.

        Each dataset's outputs are written as PNGs under
        ``self.result_dir/<dataset name>``. Returns ``self`` for chaining.
        """
        for dataset in self.datasets:
            self.log('\nGenerating from [{:s}]...'.format(dataset.cnf['lr_dir']))
            out_dir = os.path.join(self.result_dir, dataset.name)

            util.mkdir(out_dir)
            for sample in dataset.loader:
                lr_path = sample['LR_path'][0]
                stem = os.path.splitext(os.path.basename(lr_path))[0]

                generated = util.tensor2img(self.model.generate(sample))

                util.save_img(generated, os.path.join(out_dir, stem + '.png'))
        return self
Code example #3
0
# Export class-activation maps (CAMs) from a pretrained VGG classifier.
# Requires CUDA; relies on `loader`, `util`, `PATHS`, `dataset` and `target`
# defined earlier in the file.
model = VGG_Classifier().cuda()
model.load_model('../bird/prt.pth')
# NOTE(review): the "train" loader actually wraps the *test* split
# (is_test=True) — presumably intentional for CAM export; confirm.
train_set = loader.TrainDataset(PATHS[dataset]['test'], is_test=True)
train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=1, shuffle=True)

val_set = loader.TrainDataset(PATHS[dataset]['valid'], is_train=False)
val_loader = DataLoader(dataset=val_set, num_workers=1, batch_size=1, shuffle=False)

# Write one CAM image per sample, normalised to the sample's own min/max.
for image, label, indices in train_loader:
    with torch.no_grad():
        image = image.to('cuda')
        cam = model.get_cam(image, 1/16).cpu()
        mm = torch.min(cam)  # per-image minimum for normalisation
        MM = torch.max(cam)  # per-image maximum
    # indices[0] is a path-like string; keep only the file name.
    save_path = os.path.join(target['cam_test'], indices[0].split('/')[-1])
    print(save_path)
    cv2.imwrite(save_path, util.tensor2img(cam, min_max=(mm,MM)))

# NOTE(review): everything below this exit() is unreachable — the
# validation-split export appears to have been disabled deliberately;
# remove the exit() to re-enable it.
exit()


for image, label, indices in val_loader:
    with torch.no_grad():
        image = image.to('cuda')
        cam = model.get_cam(image, 1/16).cpu()
        mm = torch.min(cam)
        MM = torch.max(cam)
    save_path = os.path.join(target['cam_valid'], indices[0].split('/')[-1])
    print(save_path)
    cv2.imwrite(save_path, util.tensor2img(cam, min_max=(mm,MM)))
    
Code example #4
0
import util

# paths.yml is a local, trusted config file. yaml.load without an explicit
# Loader is deprecated (and unsafe on untrusted input); safe_load is the
# correct API for plain config data.
with open('paths.yml', 'r') as f:
    PATHS = yaml.safe_load(f)

dataset = 'bird'
target = PATHS['bird_lr_x16']

train_set = loader.TrainDataset(PATHS[dataset]['train'])
train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=1, shuffle=True)

val_set = loader.TrainDataset(PATHS[dataset]['valid'], is_train=False)
val_loader = DataLoader(dataset=val_set, num_workers=1, batch_size=1, shuffle=False)


# Export 1/16-scale bilinear LR versions of every training image.
for image, label, indices in train_loader:
    with torch.no_grad():
        image = image.to('cuda')
        lr = F.interpolate(image, scale_factor=1/16, mode='bilinear', align_corners=False).cpu()
    # indices[0] is a path-like string; keep only the file name.
    save_path = os.path.join(target['train'], indices[0].split('/')[-1])
    print(save_path)
    cv2.imwrite(save_path, util.tensor2img(lr))


# Same export for the validation split.
for image, label, indices in val_loader:
    with torch.no_grad():
        image = image.to('cuda')
        lr = F.interpolate(image, scale_factor=1/16, mode='bilinear', align_corners=False).cpu()
    save_path = os.path.join(target['valid'], indices[0].split('/')[-1])
    print(save_path)
    cv2.imwrite(save_path, util.tensor2img(lr))
Code example #5
0
def main(config):
    """Evaluate a trained SR model on every configured test set.

    Saves LR/SR images under the experiment directory and logs per-set and
    overall average PSNR/SSIM.

    Args:
        config: parsed experiment configuration dict (paths, dataset,
            device, name, ...).
    """
    device = torch.device(config['device'])

    ##### Setup Dirs #####
    experiment_dir = config['path']['experiments'] + config['name']
    util.mkdir_and_rename(experiment_dir)  # rename experiment folder if exists
    img_dir_sr = experiment_dir + '/sr_images'
    img_dir_lr = experiment_dir + '/lr_images'
    # Create the output dirs once up front; they are loop-invariant, so the
    # original per-image util.mkdir call inside the loop was redundant.
    util.mkdirs((img_dir_sr, img_dir_lr))

    ##### Setup Logger #####
    logger = util.Logger('test', experiment_dir, 'test_' + config['name'])

    ##### Print experiment config #####
    logger.log(util.dict2str(config))

    ##### Load dataset #####
    testing_data_loader = dataset.get_test_sets(config['dataset'], logger)

    trainer = create_model(config, logger)
    trainer.print_network_params(logger)

    total_avg_psnr = 0.0
    total_avg_ssim = 0.0

    for name, test_set in testing_data_loader.items():
        logger.log('Testing Dataset {:s}'.format(name))
        valid_start_time = time.time()
        avg_psnr = 0.0
        avg_ssim = 0.0
        idx = 0
        for i, batch in enumerate(test_set):
            idx += 1
            # batch[2][0] is the sample's path; strip directory + extension.
            img_name = batch[2][0][batch[2][0].rindex('/') + 1:]
            img_name = img_name[:img_name.index('.')]
            infer_time = trainer.test(batch)  # kept for optional per-image logging
            visuals = trainer.get_current_visuals()
            lr_img = util.tensor2img(visuals['LR'])
            sr_img = util.tensor2img(visuals['SR'])  # uint8
            gt_img = util.tensor2img(visuals['HR'])  # uint8
            save_sr_img_path = os.path.join(img_dir_sr, '{:s}.png'.format(img_name))
            save_lr_img_path = os.path.join(img_dir_lr, '{:s}.png'.format(img_name))
            util.save_img(lr_img, save_lr_img_path)
            util.save_img(sr_img, save_sr_img_path)
            crop_size = config['dataset']['scale']
            psnr, ssim = util.calc_metrics(sr_img, gt_img, crop_size)
            #logger.log('[ Image: {:s}  PSNR: {:.4f} SSIM: {:.4f} Inference Time: {:.8f}]'.format(img_name, psnr, ssim, infer_time))
            avg_psnr += psnr
            avg_ssim += ssim
        if idx:  # guard against an empty test set (avoids ZeroDivisionError)
            avg_psnr = avg_psnr / idx
            avg_ssim = avg_ssim / idx
        valid_t = time.time() - valid_start_time
        logger.log('[ Set: {:s} Time:{:.3f}] PSNR: {:.2f} SSIM {:.4f}'.format(name, valid_t, avg_psnr, avg_ssim))

        total_avg_ssim += avg_ssim
        total_avg_psnr += avg_psnr

    total_avg_ssim /= len(testing_data_loader)
    total_avg_psnr /= len(testing_data_loader)

    logger.log('[ Total Average of Sets: PSNR: {:.2f} SSIM {:.4f}'.format(total_avg_psnr, total_avg_ssim))
Code example #6
0
File: test_demo.py — Project: liuguoyou/PAN
def main():
    """Benchmark PAN x4 on the DIV2K test folder.

    Times each forward pass with CUDA events, saves the SR outputs to
    ./test_results/, and reports the parameter count and mean runtime.
    """
    ## test dataset
    test_d = sorted(
        glob.glob('/mnt/hyzhao/Documents/datasets/DIV2K_test/*.png'))

    torch.cuda.current_device()
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = False
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ## some functions
    def readimg(path):
        """Read an image as float32 RGB in [0, 1] (cv2 loads BGR)."""
        im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        im = im.astype(np.float32) / 255.
        im = im[:, :, [2, 1, 0]]
        return im

    def img2tensor(img):
        """HWC numpy image -> 1xCxHxW float tensor."""
        imgt = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()[None, ...]
        return imgt

    ## load model
    model = PANet_arch.PANet(in_nc=3, out_nc=3, nf=40, nb=16)
    model_weight = torch.load('./pretrained_model/PANetx4_DF2K.pth')
    model.load_state_dict(model_weight, strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    number_parameters = sum(p.numel() for p in model.parameters())

    ## running
    print('-----------------Start Running-----------------')
    times = []  # per-image forward time, milliseconds
    # (removed the unused `psnrs` list — no PSNR was ever computed)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i, img_path in enumerate(test_d):
        im = readimg(img_path)
        img_LR = img2tensor(im)
        img_LR = img_LR.to(device)

        start.record()
        img_SR = model(img_LR)
        end.record()

        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

        sr_img = util.tensor2img(img_SR.detach())

        ### save image (os.path.basename is portable, unlike '/'-splitting)
        save_img_path = './test_results/%s' % os.path.basename(img_path)
        util.save_img(sr_img, save_img_path)

        print('Image: %03d' % (i + 1))
    print('Paramters: %d, Mean Time: %.10f' %
          (number_parameters, np.mean(times) / 1000.))
Code example #7
0
def sftgan(load_name="",
           save_name='fin_rlt.png',
           mode='rgb',
           override_input=False):
    """Run the two-stage SFTGAN pipeline on a single image.

    Stage 1 runs an outdoor-scene segmentation network on a bicubic
    down-then-up-scaled copy of the input to obtain class probability maps.
    Stage 2 feeds the 4x-downscaled image together with those maps to
    SFTGAN and writes the super-resolved result to ``save_name``.

    Args:
        load_name: path to the input image.
        save_name: output path for the final SR image.
        mode: 'rgb' for colour input, 'bw' to replicate one channel to 3.
        override_input: if True, overwrite the input file with the
            preprocessed (modcropped / channel-stacked) image.
    """
    path = load_name
    test_img_folder_name = "TMP1"
    # options
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device(
        'cuda')  # if you want to run on CPU, change 'cuda' -> 'cpu'
    # device = torch.device('cpu')

    # make dirs
    test_img_folder = 'SFTGAN/data/' + test_img_folder_name  # HR images
    save_prob_path = 'SFTGAN/data/' + test_img_folder_name + '_segprob'  # probability maps
    save_byteimg_path = 'SFTGAN/data/' + test_img_folder_name + '_byteimg'  # segmentation annotations
    save_colorimg_path = 'SFTGAN/data/' + test_img_folder_name + '_colorimg'  # segmentation color results
    util.mkdirs([save_prob_path, save_byteimg_path, save_colorimg_path])

    test_prob_path = 'SFTGAN/data/' + test_img_folder_name + '_segprob'  # probability maps
    save_result_path = 'SFTGAN/data/' + test_img_folder_name + '_result'  # results
    util.mkdirs([save_result_path])

    # load segmentation model
    seg_model = arch.OutdoorSceneSeg()
    seg_model_path = 'SFTGAN/pretrained_models/segmentation_OST_bic.pth'
    seg_model.load_state_dict(torch.load(seg_model_path), strict=True)
    seg_model.eval()
    seg_model = seg_model.to(device)

    # look-up table, RGB, for coloring the segmentation results
    # (only consumed by the commented-out visualisation code below)
    lookup_table = torch.from_numpy(
        np.array([
            [153, 153, 153],  # 0, background
            [0, 255, 255],  # 1, sky
            [109, 158, 235],  # 2, water
            [183, 225, 205],  # 3, grass
            [153, 0, 255],  # 4, mountain
            [17, 85, 204],  # 5, building
            [106, 168, 79],  # 6, plant
            [224, 102, 102],  # 7, animal
            [255, 255, 255],  # 8/255, void
        ])).float()
    lookup_table /= 255

    print('Testing segmentation probability maps ...')
    """
    for idx, path in enumerate(glob.glob(test_img_folder + '/*')):
        imgname = os.path.basename(path)
        basename = os.path.splitext(imgname)[0]

        if "txt" in path:
          continue
    """

    idx = 0
    # 'if True:' is a leftover of the original per-folder loop (see the
    # commented-out block above); kept so the SEG_OUT handoff is unchanged.
    if True:
        #print(idx + 1, basename, path)
        print(idx + 1)
        # read image
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img = util.modcrop(img, 8)

        print(
            "debug ",
            img.shape,
            img.ndim,
        )
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        if mode == 'bw':
            #print(img.shape) # w,h,3 <- 1
            stacked_img = np.stack((img, ) * 3, axis=2)  # bw -> rgb
            stacked_img = stacked_img[:, :, :, 0]
            #print(stacked_img.shape) # w,h,3 <- 1
            img = stacked_img
            #(424, 1024, 3)
        #print("debug img", img.shape, )

        if override_input:
            print("overriding input ", img.shape, "as", path)
            util.save_img(img, path)

        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()

        # MATLAB imresize
        # You can use the MATLAB to generate LR images first for faster imresize operation
        img_LR = util.imresize(img / 255, 1 / 4, antialiasing=True)
        img = util.imresize(img_LR, 4, antialiasing=True) * 255

        # subtract the segmentation net's per-channel means
        img[0] -= 103.939
        img[1] -= 116.779
        img[2] -= 123.68
        img = img.unsqueeze(0)
        img = img.to(device)

        with torch.no_grad():
            output = seg_model(img).detach().float().cpu().squeeze()

        # keep the probability maps in memory instead of saving to disk
        #torch.save(output, os.path.join(save_prob_path, basename + '_bic.pth'))  # 8xHxW
        SEG_OUT = output
        """
        # save segmentation byte images (annotations)
        _, argmax = torch.max(output, 0)
        argmax = argmax.squeeze().byte()
        cv2.imwrite('foo1.png', argmax.numpy())

        # save segmentation colorful results
        im_h, im_w = argmax.size()
        color = torch.FloatTensor(3, im_h, im_w).fill_(0)  # black
        for i in range(8):
            mask = torch.eq(argmax, i)
            color.select(0, 0).masked_fill_(mask, lookup_table[i][0])  # R
            color.select(0, 1).masked_fill_(mask, lookup_table[i][1])  # G
            color.select(0, 2).masked_fill_(mask, lookup_table[i][2])  # B
        # void
        mask = torch.eq(argmax, 255)
        color.select(0, 0).masked_fill_(mask, lookup_table[8][0])  # R
        color.select(0, 1).masked_fill_(mask, lookup_table[8][1])  # G
        color.select(0, 2).masked_fill_(mask, lookup_table[8][2])  # B
        torchvision.utils.save_image(
            color, 'foo2.png', padding=0, normalize=False)
        """

    del seg_model
    '''
    Codes for testing SFTGAN
    '''

    # options
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    sres_model_path = 'SFTGAN/pretrained_models/SFTGAN_torch.pth'  # torch version
    # sres_model_path = 'SFTGAN/pretrained_models/SFTGAN_noBN_OST_bg.pth'  # pytorch version

    device = torch.device(
        'cuda')  # if you want to run on CPU, change 'cuda' -> 'cpu'
    # device = torch.device('cpu')

    if 'torch' in sres_model_path:  # torch version
        model = arch.SFT_Net_torch()
    else:  # pytorch version
        model = arch.SFT_Net()
    model.load_state_dict(torch.load(sres_model_path), strict=True)
    model.eval()
    model = model.to(device)

    print('Testing SFTGAN ...')
    """
    for idx, path in enumerate(glob.glob(test_img_folder + '/*')):
        imgname = os.path.basename(path)
        basename = os.path.splitext(imgname)[0]

        if "txt" in path:
          continue
    """
    # same leftover single-image wrapper as above; the bare no-op
    # expression statement `path` that used to sit here was removed.
    if True:
        #print(idx + 1, basename)
        print(idx + 1)
        # read image
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img = util.modcrop(img, 8)
        img = img * 1.0 / 255
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        if mode == 'bw':
            #print(img.shape) # w,h,3 <- 1
            stacked_img = np.stack((img, ) * 3, axis=2)  # bw -> rgb
            stacked_img = stacked_img[:, :, :, 0]
            #print(stacked_img.shape) # w,h,3 <- 1
            img = stacked_img
            #(424, 1024, 3)
        #print("debug img", img.shape, )

        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]],
                                            (2, 0, 1))).float()
        # MATLAB imresize
        # You can use the MATLAB to generate LR images first for faster imresize operation
        img_LR = util.imresize(img, 1 / 4, antialiasing=True)
        img_LR = img_LR.unsqueeze(0)
        img_LR = img_LR.to(device)

        # use the in-memory segmentation output from stage 1
        #seg = torch.load(os.path.join(test_prob_path, basename + '_bic.pth'))
        seg = SEG_OUT
        seg = seg.unsqueeze(0)
        # change probability
        # seg.fill_(0)
        # seg[:,5].fill_(1)
        seg = seg.to(device)
        with torch.no_grad():
            output = model((img_LR, seg)).data.float().cpu().squeeze()
        output = util.tensor2img(output)
        util.save_img(output, save_name)
Code example #8
0
# Batch SFTGAN inference: for every HR image in test_img_folder, build a
# bicubic 1/4-scale LR input, load the precomputed segmentation probability
# map, and save the SFTGAN result. Relies on `model`, `device`,
# `test_img_folder`, `test_prob_path` and `save_result_path` being defined
# earlier in the script.
for idx, path in enumerate(glob.glob(test_img_folder + '/*')):
    imgname = os.path.basename(path)
    basename = os.path.splitext(imgname)[0]
    print(idx + 1, basename)
    # read image (as stored on disk), crop so H and W are multiples of 8
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    img = util.modcrop(img, 8)
    img = img * 1.0 / 255
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # reverse channel order, then HWC -> CHW tensor
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]],
                                        (2, 0, 1))).float()
    # MATLAB imresize
    # You can use the MATLAB to generate LR images first for faster imresize operation
    img_LR = util.imresize(img, 1 / 4, antialiasing=True)
    img_LR = img_LR.unsqueeze(0)
    img_LR = img_LR.to(device)

    # read segmentation probability maps (saved as '<basename>_bic.pth')
    seg = torch.load(os.path.join(test_prob_path, basename + '_bic.pth'))
    seg = seg.unsqueeze(0)
    # change probability
    # seg.fill_(0)
    # seg[:,5].fill_(1)
    seg = seg.to(device)
    with torch.no_grad():
        output = model((img_LR, seg)).data.float().cpu().squeeze()
    output = util.tensor2img(output)
    util.save_img(output, os.path.join(save_result_path,
                                       basename + '_rlt.png'))
Code example #9
0
File: runner.py — Project: Bin4writing/BasicSR
    def run(self):
        """Main training loop.

        Iterates epochs and datasets, optimizes the model per batch, and
        every 500 steps logs training stats, every 5000 steps runs a
        cropped-border SSIM validation pass and saves a checkpoint.
        """
        if self.start_epoch:
            self.model.resume_training(self.resume_state)
        self.log('Start training from epoch: {:d}, iter: {:d}'.format(
            self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs):
            for ds in self.datasets:
                for data in ds.loader:
                    self.current_step += 1
                    if self.current_step > self.total_iters:
                        break
                    self.model.update_learning_rate()

                    self.model.feed_data(data)
                    self.model.optimize_parameters(self.current_step)

                    if self.current_step % 500 == 0:
                        logs = self.model.get_current_log()
                        message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                            epoch, self.current_step,
                            self.model.get_current_learning_rate())
                        for k, v in logs.items():
                            message += '{:s}: {:.4e} '.format(k, v)
                            if self.config[
                                    'enable_tensorboard'] and 'debug' not in self.config[
                                        'name']:
                                self.tf_logger.add_scalar(
                                    k, v, self.current_step)
                        self.log(message)

                    if self.current_step % 5000 == 0:
                        avg_ssim = 0.0
                        idx = 0
                        for val_ds in self.val_datasets:
                            for val_data in val_ds.loader:
                                idx += 1
                                # BUGFIX: the path comes from the current
                                # batch dict (val_data), not the dataset
                                # object (val_ds is not subscriptable here).
                                img_name = os.path.splitext(
                                    os.path.basename(val_data['LR_path'][0]))[0]
                                img_dir = os.path.join(
                                    self.config['val_images'], img_name)
                                util.mkdir(img_dir)

                                # BUGFIX: feed the current batch, not the
                                # whole loader object.
                                self.model.feed_data(val_data)
                                self.model.test()

                                visuals = self.model.get_current_visuals()
                                sr_img = util.tensor2img(
                                    visuals['SR'])  # uint8
                                gt_img = util.tensor2img(
                                    visuals['HR'])  # uint8

                                # crop a 4-pixel border before computing SSIM
                                crop_size = 4
                                gt_img = gt_img / 255.
                                sr_img = sr_img / 255.
                                cropped_sr_img = sr_img[
                                    crop_size:-crop_size,
                                    crop_size:-crop_size, :]
                                cropped_gt_img = gt_img[
                                    crop_size:-crop_size,
                                    crop_size:-crop_size, :]
                                avg_ssim += util.calculate_ssim(
                                    cropped_sr_img * 255, cropped_gt_img * 255)

                        if idx:  # guard against empty validation sets
                            avg_ssim = avg_ssim / idx

                        # BUGFIX: the metric computed above is SSIM; it was
                        # previously mislabelled as PSNR in both log lines
                        # and the tensorboard tag.
                        self.log(
                            '# Validation # SSIM: {:.4e}'.format(avg_ssim))
                        self.log(
                            '<epoch:{:3d}, iter:{:8,d}> ssim: {:.4e}'.format(
                                epoch, self.current_step, avg_ssim))

                        if self.config[
                                'enable_tensorboard'] and 'debug' not in self.config[
                                    'name']:
                            self.tf_logger.add_scalar('ssim', avg_ssim,
                                                      self.current_step)

                    if self.current_step % 5000 == 0:
                        self.log('Saving models and training states.')
                        self.model.save(self.current_step)
                        self.model.save_training_state(epoch,
                                                       self.current_step)

        self.log('Saving the final self.model.')
        self.model.save('latest')
        self.log('End of training.')
Code example #10
0
def main():
    """Full training entry point.

    Parses the option file, sets up (optionally distributed) training,
    loggers, dataloaders and the model, then runs the epoch loop with
    periodic validation (PSNR/SSIM/LPIPS), logging and checkpointing.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            print(opt['path'])
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key and path is not None))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            # NOTE(review): float(torch.__version__[0:3]) misparses versions
            # like '1.10' (-> 1.1); acceptable for the 1.x range used here.
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            # pick the first unused trial index for the tensorboard log dir
            trial = 0
            while os.path.isdir('../Loggers/' + opt['name'] + '/' + str(trial)):
                trial += 1
            tb_logger = SummaryWriter(log_dir='../Loggers/' + opt['name'] + '/' + str(trial))
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # -------------------------------------------- ADDED --------------------------------------------
    # extra validation criteria (L1 / MSE / LPIPS) beyond PSNR/SSIM
    l1_loss = torch.nn.L1Loss()
    mse_loss = torch.nn.MSELoss()
    calc_lpips = PerceptualLossLPIPS()
    if torch.cuda.is_available():
        l1_loss = l1_loss.cuda()
        mse_loss = mse_loss.cuda()
    # -----------------------------------------------------------------------------------------------

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # BUGFIX: the attribute was misspelled 'benckmark', which silently set a
    # brand-new attribute and never actually enabled cuDNN autotuning.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = Model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        train_bar = tqdm(train_loader, desc='[%d/%d]' % (epoch, total_epochs))
        for bus, train_data in enumerate(train_bar):

            # validation (runs before the first optimization step of the epoch)
            if epoch % opt['train']['val_freq'] == 0 and bus == 0 and rank <= 0:
                avg_ssim = avg_psnr = avg_lpips = val_pix_err_f = val_pix_err_nf = val_mean_color_err = 0.0
                print("into validation!")
                idx = 0
                val_bar = tqdm(val_loader, desc='[%d/%d]' % (epoch, total_epochs))
                for val_data in val_bar:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8
                    lq_img = util.tensor2img(visuals['LQ'])  # uint8

                    # Save SR and LQ images for reference
                    save_sr_img_path = os.path.join(img_dir,
                                                 '{:s}_{:d}_sr.png'.format(img_name, current_step))
                    save_nr_img_path = os.path.join(img_dir,
                                                 '{:s}_{:d}_lq.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_sr_img_path)
                    util.save_img(lq_img, save_nr_img_path)

                    # calculate PSNR / SSIM / LPIPS
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    lq_img = lq_img / 255.
                    avg_psnr += util.calculate_psnr(sr_img * 255, gt_img * 255)
                    avg_ssim += util.calculate_ssim(sr_img * 255, gt_img * 255)
                    avg_lpips += calc_lpips(visuals['SR'], visuals['GT'])

                    # ----------------------------------------- ADDED -----------------------------------------
                    # NOTE: val_pix_err_f is never accumulated; it always logs 0.
                    val_pix_err_nf += l1_loss(visuals['SR'], visuals['GT'])
                    val_mean_color_err += mse_loss(visuals['SR'].mean(2).mean(1), visuals['GT'].mean(2).mean(1))
                    # -----------------------------------------------------------------------------------------

                if idx:  # guard against an empty validation loader
                    avg_psnr = avg_psnr / idx
                    avg_ssim = avg_ssim / idx
                    avg_lpips = avg_lpips / idx
                    val_pix_err_f /= idx
                    val_pix_err_nf /= idx
                    val_mean_color_err /= idx

                # log
                logger.info('# Validation # PSNR: {:.4e},'.format(avg_psnr))
                logger.info('# Validation # SSIM: {:.4e},'.format(avg_ssim))
                logger.info('# Validation # LPIPS: {:.4e},'.format(avg_lpips))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} ssim: {:.4e} lpips: {:.4e}'.format(
                    epoch, current_step, avg_psnr, avg_ssim, avg_lpips))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('val_ssim', avg_ssim, current_step)
                    tb_logger.add_scalar('val_lpips', avg_lpips, current_step)
                    tb_logger.add_scalar('val_pix_err_nf', val_pix_err_nf, current_step)
                    tb_logger.add_scalar('val_mean_color_err', val_mean_color_err, current_step)

            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)
            model.clear_data()
            #### tb_logger
            if current_step % opt['logger']['tb_freq'] == 0:
                logs = model.get_current_log()
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    for k, v in logs.items():
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)

            #### logger
            if epoch % opt['logger']['print_freq'] == 0 and epoch != 0 and bus == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                if rank <= 0:
                    logger.info(message)

            #### save models and training states
            if epoch % opt['logger']['save_checkpoint_freq'] == 0 and epoch != 0 and bus == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')