Example #1
0
 def __init__(self, scale=2, num_feats=64, kernel=3, padding=1, bias=True):
     """Build three EDSR branches plus a shared upsampling tail.

     Args:
         scale: Upscaling factor, forwarded to every EDSR branch. When it
             is a power of two, the tail gets ``log2(scale)`` stages of
             (conv, ``PixelShuffle(2)``). For any other value the tail
             contains no upsampling stage at all.
         num_feats: Channel width of the tail convolutions. It does not
             affect the branches, whose widths are fixed at 32, 24 and 8.
         kernel: Kernel size of every tail convolution.
         padding: Padding of every tail convolution.
         bias: Whether the tail convolutions carry a bias term.
     """
     super().__init__()

     # Branch widths are hard-coded; 32 + 24 + 8 equals the default
     # num_feats of 64 -- presumably the branch features are concatenated
     # before the tail (TODO confirm against forward()).
     self.model_image = EDSR(scale=scale, num_feats=32)
     self.model_high = EDSR(scale=scale, num_feats=24)
     self.model_low = EDSR(scale=scale, num_feats=8)

     # Options shared by every convolution in the tail.
     conv_opts = dict(kernel_size=kernel, padding=padding, bias=bias)
     tail_modules = []

     scale_is_pow2 = (scale & (scale - 1)) == 0
     if scale_is_pow2:
         doublings = int(math.log(scale, 2))
         for _ in range(doublings):
             # The conv quadruples the channels so that PixelShuffle(2)
             # can fold them back to num_feats.
             tail_modules.extend((
                 nn.Conv2d(in_channels=num_feats,
                           out_channels=num_feats * 4,
                           **conv_opts),
                 nn.PixelShuffle(2),
             ))

     # Final projection down to 3 output channels.
     tail_modules.append(
         nn.Conv2d(in_channels=num_feats, out_channels=3, **conv_opts))

     self.tail = nn.Sequential(*tail_modules)
     self.add_mean = MeanShift(mode='add')
Example #2
0
 def define_model(self):
     """Create the EDSR network and store it on ``self.sr_model``.

     The constructor receives ``self.LR`` as its positional argument and
     the upscaling factor from ``FLAGS.scale``.
     """
     network = EDSR(self.LR, scale=FLAGS.scale)
     self.sr_model = network
Example #3
0
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from models.EDSR import EDSR
from utils.data_loader import get_loader
import torch

# Seed the RNG before the model is constructed.
torch.manual_seed(0)

# Upscaling factor for this run; loader settings exist only for 2 and 4.
scale_factor = 4

model = EDSR(scale=scale_factor)

# The x2 setup uses get_loader's defaults; the x4 setup passes explicit
# height/width/scale_factor values.
if scale_factor == 2:
    train_loader = get_loader(mode='train', batch_size=16, augment=True)
    test_loader = get_loader(mode='test')
elif scale_factor == 4:
    train_loader = get_loader(
        mode='train',
        batch_size=16,
        height=192,
        width=192,
        scale_factor=4,
        augment=True,
    )
    test_loader = get_loader(
        mode='test',
        height=256,
        width=256,
        scale_factor=4,
    )

"""
import trainer
trainer.train(model, train_loader, test_loader, mode=f'EDSR_x{scale_factor}_Baseline')

import trainer_v1_pool as trainer
trainer.train(model, train_loader, test_loader, mode='EDSR_v1_pool')

import trainer_v2_centered_init as trainer
trainer.train(model, train_loader, test_loader, mode='EDSR_v2_centered_kernel')
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from models.EDSR import EDSR
from utils.data_loader import get_loader
import torch

# Seed the RNG before any model is constructed.
torch.manual_seed(0)

# Upscaling factor for this run; loader settings exist only for 2 and 4.
scale_factor = 2

model = EDSR(scale=scale_factor)

# The x2 setup uses get_loader's defaults; the x4 setup passes explicit
# height/width/scale_factor values.
if scale_factor == 2:
    train_loader = get_loader(mode='train', batch_size=16, augment=True)
    test_loader = get_loader(mode='test')
elif scale_factor == 4:
    train_loader = get_loader(
        mode='train', batch_size=16,
        height=192, width=192, scale_factor=4,
        augment=True,
    )
    test_loader = get_loader(
        mode='test',
        height=256, width=256, scale_factor=4,
    )

# Switch to the EDSR_x1 variant. This rebinds both ``EDSR`` and ``model``,
# so the network constructed above is replaced by the new one.
from models.EDSR_x1 import EDSR
model = EDSR(scale=scale_factor)
Example #5
0
        # Choose the network class that matches the weight key ``k``.
        # Any key not listed here falls back to the plain EDSR model.
        if k == 'Freq_Fusion':
            model = EDSR_fusion(scale=scale_factor).to(device)
        elif k == 'EDSR_x2_v39_hm':
            model = EDSR_hm(scale=scale_factor).to(device)
        elif k == 'EDSR_x2_v40_hm_att':
            model = EDSR_hm_with_att(scale=scale_factor).to(device)
        elif k == 'EDSR_x2_v41_hm_att_v2':
            model = EDSR_hm_with_att_v2(scale=scale_factor).to(device)
        elif k == 'EDSR_x2_v43_base_att2':
            model = EDSR_with_att2(scale=scale_factor).to(device)
        elif k == 'EDSR_x2_v44_base_att2_std':
            model = EDSR_with_att2_std(scale=scale_factor).to(device)
        elif k == 'EDSR_x2_v45_base_att2_std2':
            model = EDSR_with_att2_std2(scale=scale_factor).to(device)
        else:
            model = EDSR(scale=scale_factor).to(device)
        # Restore the parameters stored under this key and switch the
        # module to evaluation mode.
        model.load_state_dict(weights[k])
        model.eval()

        for image in tqdm(set5):
            # get_tensor yields the low-resolution input and its
            # high-resolution counterpart for this entry.
            lr, hr = get_tensor(image)

            # The model variants take different inputs and return tuples
            # of different lengths, so each call is unpacked separately.
            if k == 'Freq_Fusion':
                # The fusion model takes the image plus the two outputs
                # of pass_filter.
                lr_high, lr_low = pass_filter(lr)
                sr, out_img, out_hf, out_lf = model(lr, lr_high, lr_low)
            elif 'hm' in k:
                # Substring match: every key containing 'hm' takes this
                # path and receives the extra filtered input.
                lr_hf = high_pass_filter_hard_kernel(lr)
                sr, _ = model(lr, lr_hf)
            else:
                sr, fea = model(lr)
Example #6
0
import torchvision.transforms as T

from models.EDSR import EDSR
from models.common import GMSD_quality

from utils import imshow
from utils.eval import ssim as get_ssim
from utils.eval import ms_ssim as get_msssim
from utils.eval import psnr as get_psnr

# Seed the RNG before the model is constructed.
torch.manual_seed(0)
scale_factor = 2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def quantize(x):
    """Round a tensor to multiples of 1/255.

    The input is scaled by 255, clamped to [0, 255], rounded to the
    nearest integer and scaled back. Values outside [0, 1] are therefore
    clipped. Assumes ``x`` is an image tensor on a nominal [0, 1] scale
    (TODO confirm against callers).
    """
    return x.mul(255).clamp(0, 255).round().div(255)


model = EDSR(scale=scale_factor).to(device)

# Checkpoints keyed by experiment name. ``map_location`` remaps the stored
# tensors onto ``device``, so files saved on a GPU still load when this
# script runs on a CPU-only machine (without it torch.load raises there).
weights = {
    'Baseline':
    torch.load('./weights/2021.01.15/EDSR_x2_Baseline/epoch_1000.pth',
               map_location=device),
    'MSHF':
    torch.load('./weights/2021.01.15/EDSR_x2_v10_MSHF/epoch_1000.pth',
               map_location=device),
    'GMSD':
    torch.load('./weights/2021.01.15/EDSR_x2_v8_gmsd/epoch_1000.pth',
               map_location=device),
}


def get_tensor(lr):
    trans = T.Compose([T.ToTensor()])
    hr = lr.replace('LR_bicubic/X2', 'HR')
    hr = hr.replace('x2.png', '.png')
    lr = Image.open(lr)
    lr = trans(lr)