Example #1
def main(_):

    # Sets up config and enables GPU memory allocation growth
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Modes 3, 5 and 6 are test-only; abort if training was requested
    if FLAGS.train_mode == 3 and FLAGS.is_train:
        print('Error: Bicubic Mode does not require training')
        return
    elif FLAGS.train_mode == 5 and FLAGS.is_train:
        print(
            'Error: Multi-Dir testing mode for Mode 2 does not require training'
        )
        return
    elif FLAGS.train_mode == 6 and FLAGS.is_train:
        print(
            'Error: Multi-Dir testing mode for Mode 1 does not require training'
        )
        return

    # Starts session; initializes ESPCN object; and runs training or testing operations
    with tf.Session(config=config) as sess:
        espcn = ESPCN(sess,
                      image_size=FLAGS.image_size,
                      is_train=FLAGS.is_train,
                      train_mode=FLAGS.train_mode,
                      scale=FLAGS.scale,
                      c_dim=FLAGS.c_dim,
                      batch_size=FLAGS.batch_size,
                      load_existing_data=FLAGS.load_existing_data,
                      config=config)
        espcn.train(FLAGS)
Example #2
def main(_):
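    # Build an ESPCN instance from the command-line FLAGS and run training or testing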
    with tf.Session() as sess:
        espcn = ESPCN(sess,
                      image_size=FLAGS.image_size,
                      is_train=FLAGS.is_train,
                      scale=FLAGS.scale,
                      c_dim=FLAGS.c_dim,
                      batch_size=FLAGS.batch_size,
                      test_img=FLAGS.test_img,
                      test_path=FLAGS.test_path)

        espcn.train(FLAGS)
Example #3
def main(_):
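    # log_device_placement=True makes TensorFlow log which device each op is placed on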
    with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
            log_device_placement=True)) as sess:
        espcn = ESPCN(
            sess,
            image_size=FLAGS.image_size,
            is_train=FLAGS.is_train,
            scale=FLAGS.scale,
            c_dim=FLAGS.c_dim,
            batch_size=FLAGS.batch_size,
            test_img=FLAGS.test_img,
        )

        espcn.train(FLAGS)
Example #4
    parser.add_argument('--input',
                        type=str,
                        required=True,
                        help='input image to super resolve')
    parser.add_argument('--output',
                        type=str,
                        required=True,
                        help='where to save the output image')
    args = parser.parse_args()

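    # Load the input image and split it into luminance (Y) and chroma (Cb, Cr) channels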
    img = Image.open(args.input).convert('RGB')
    img = rgb2ycrcb(img)
    y, cb, cr = img.split()

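    # Restore the trained ESPCN weights from the checkpoint's 'model' entry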
    ckpt = torch.load(args.model, map_location='cpu')
    model = ESPCN(upscale_factor=args.upscale_factor)
    model.load_state_dict(ckpt['model'])

    model.eval()
    input_tensor = ToTensor()(y).view(1, -1, y.size[1], y.size[0])

    with torch.no_grad():
        out = model(input_tensor)
    out_img_y = out.detach().numpy().squeeze()
    out_img_y *= 255.0
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = Image.fromarray(np.uint8(out_img_y), mode='L')

    out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
    out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr])
    out_img = ycbcr2rgb(out_img)
Example #5
    opts = parser.parse_args()

    if not os.path.exists(opts.weights_dir):
        os.mkdir(opts.weights_dir)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    torch.manual_seed(42)

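    # Pick the requested super-resolution backbone (falls back to FSRCNN)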
    if opts.sr_module == "FSRCNN":
        sr_module = FSRCNN(scale=opts.scale).to(device)
    elif opts.sr_module == "ESPCN":
        sr_module = ESPCN(scale=opts.scale).to(device)
    elif opts.sr_module == "VDSR":
        sr_module = VDSR(scale=opts.scale).to(device)
    else:
        sr_module = FSRCNN(scale=opts.scale).to(device)

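    # Pick the requested LR-side module (falls back to FLRCNN)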
    if opts.lr_module == "FLRCNN":
        lr_module = FLRCNN(scale=opts.scale).to(device)
    elif opts.lr_module == "DESPCN":
        lr_module = DESPCN(scale=opts.scale).to(device)
    elif opts.lr_module == "DVDSR":
        lr_module = DVDSR(scale=opts.scale).to(device)
    else:
        lr_module = FLRCNN(scale=opts.scale).to(device)

    criterion = nn.MSELoss()
Example #6
import tensorflow as tf
from model import ESPCN
import config

if __name__ == '__main__':
    with tf.Session() as sess:
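        # Hyper-parameters are read from config.py rather than from command-line flags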
        espcn = ESPCN(
            sess,
            image_size=config.image_size,
            scale=config.scale,
            c_dim=config.c_dim,
            batch_size=config.batch_size,
        )
        espcn.train()
Example #7
            DataList = []

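            # Mode 2 expects two input arrays; every other mode uses a single input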
            if FLAGS.train_mode == 2:
                xx1, xx2, yy = prepare_data(FLAGS)
                DataList = [xx1, xx2, yy]
            else:
                xx1, yy = prepare_data(FLAGS)
                DataList = [xx1, yy]

            espcn = ESPCN(
                image_size=FLAGS.image_size,
                is_train=FLAGS.is_train,
                train_mode=FLAGS.train_mode,
                scale=FLAGS.scale,
                c_dim=FLAGS.c_dim,
                batch_size=FLAGS.batch_size,
                load_existing_data=FLAGS.load_existing_data,
                device=device,
                learn_rate=FLAGS.learning_rate,
                data_list=DataList)

            # log_device_placement records which device each operation and tensor is assigned to
            config = tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=False,
                device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index]
            )

            run_train_epochs(target, config)
Example #8
                      upscale_factor=args.upscale_factor)
    train_dataloader = DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    val_dataset = Dataset(path=args.path,
                          mode='val',
                          upscale_factor=args.upscale_factor)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers)

    print("==>Loading model")
    model = ESPCN(upscale_factor=args.upscale_factor).to(device)
    #data parallel
    if len(args.gpu_ids) > 1:
        model = nn.DataParallel(model, args.gpu_ids)

    print(model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss().to(device)
    psnr_best = 0
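    # Save a checkpoint whenever the validation PSNR improves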
    for epoch in range(args.epochs):
        train(train_dataloader, model, epoch, criterion, optimizer, args)
        psnr = evaluate(val_dataloader, model, epoch, criterion, args)
        if psnr > psnr_best:
            psnr_best = psnr
            save_path_name = os.path.join(
Example #9
def main() -> None:
    # Initialize the super-resolution model
    print("Build SR model...")
    model = ESPCN(config.upscale_factor).to(config.device)
    print("Build SR model successfully.")

    # Load the super-resolution model weights
    print(f"Load SR model weights `{os.path.abspath(config.model_path)}`...")
    state_dict = torch.load(config.model_path, map_location=config.device)
    model.load_state_dict(state_dict)
    print(f"Load SR model weights `{os.path.abspath(config.model_path)}` successfully.")

    # Create a folder of super-resolution experiment results
    results_dir = os.path.join("results", "test", config.exp_name)
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Put the model in evaluation mode.
    model.eval()
    # Turn on half-precision inference.
    model.half()

    # Initialize the PSNR accumulator.
    total_psnr = 0.0

    # Get a list of test image file names.
    file_names = natsorted(os.listdir(config.hr_dir))
    # Get the number of test image files.
    total_files = len(file_names)

    for index in range(total_files):
        lr_image_path = os.path.join(config.lr_dir, file_names[index])
        sr_image_path = os.path.join(config.sr_dir, file_names[index])
        hr_image_path = os.path.join(config.hr_dir, file_names[index])

        print(f"Processing `{os.path.abspath(hr_image_path)}`...")
        lr_image = Image.open(lr_image_path).convert("RGB")
        bic_image = lr_image.resize([int(lr_image.width * config.upscale_factor), int(lr_image.height * config.upscale_factor)], Image.BICUBIC)
        hr_image = Image.open(hr_image_path).convert("RGB")

        # Extract Y channel lr image data
        lr_image = np.array(lr_image).astype(np.float32)
        lr_ycbcr_image = imgproc.convert_rgb_to_ycbcr(lr_image)
        lr_y_tensor = imgproc.image2tensor(lr_ycbcr_image, range_norm=False, half=True).to(config.device).unsqueeze_(0)

        # Extract Y channel bic image data
        bic_image = np.array(bic_image).astype(np.float32)
        bic_ycbcr_image = imgproc.convert_rgb_to_ycbcr(bic_image)

        # Extract Y channel hr image data.
        hr_image = np.array(hr_image).astype(np.float32)
        hr_ycbcr_image = imgproc.convert_rgb_to_ycbcr(hr_image)
        hr_y_tensor = imgproc.image2tensor(hr_ycbcr_image, range_norm=False, half=True).to(config.device).unsqueeze_(0)

        # Only reconstruct the Y channel image data.
        with torch.no_grad():
            sr_y_tensor = model(lr_y_tensor)

        # Compute PSNR on the Y channel (pixel values in [0, 1])
        total_psnr += 10. * torch.log10(1. / torch.mean((sr_y_tensor - hr_y_tensor) ** 2))

        sr_y_image = imgproc.tensor2image(sr_y_tensor, range_norm=False, half=True)
        sr_image = np.array([sr_y_image, bic_ycbcr_image[..., 1], bic_ycbcr_image[..., 2]]).transpose([1, 2, 0])
        sr_image = np.clip(imgproc.convert_ycbcr_to_rgb(sr_image), 0.0, 255.0).astype(np.uint8)
        sr_image = Image.fromarray(sr_image)
        sr_image.save(sr_image_path)

    print(f"PSNR: {total_psnr / total_files:.2f}.\n")
Example #10
        if opt.upscale != 2:
            raise ValueError("Only 2x upscaling is supported")
        else:
            if opt.model == "FALSR_A":
                model = FALSR_A()
            if opt.model == "FALSR_B":
                model = FALSR_B()

    if opt.model == "SRCNN" and opt.upscale == 4:
        model = SRCNN(num_channels=3, upscale_factor=4)

    if opt.model == "VDSR" and opt.upscale == 4:
        model = VDSR(num_channels=3, base_channels=3, num_residual=20)

    if opt.model == "ESPCN" and opt.upscale == 4:
        model = ESPCN(num_channels=3, feature=64, upscale_factor=4)

if opt.criterion:
    if opt.criterion == "l1":
        criterion = nn.L1Loss()
    if opt.criterion == "l2":
        criterion = nn.MSELoss()
    if opt.criterion == "custom":
        pass

if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

optimizerG = optim.RMSprop(model.parameters(), lr=opt.lr)
Example #11
                                                             staircase=True)

# TPU objects
tpu_strategy = None

# model & optimizer
model = None
optimizer = None

if args.use_tpu:
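    # Connect to the TPU cluster and build the model/optimizer inside the distribution strategy scope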
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_host(cluster_resolver.master())
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    with tpu_strategy.scope():
        model = ESPCN(args.upscale_factor)
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
else:
    model = ESPCN(args.upscale_factor)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Dataset
train_dataset = None
test_dataset = None
if args.use_tpu:
    with tpu_strategy.scope():
        #train_dataset = get_training_set(args.upscale_factor).shuffle(200).batch(args.batch_size)
        train_dataset = get_coco_training_set(
            args.upscale_factor).shuffle(200).batch(args.batch_size).prefetch(
                buffer_size=tf.data.experimental.AUTOTUNE)
        train_dataset = tpu_strategy.experimental_distribute_dataset(
Example #12
        img_dir=config['data']['train_root'],
        upscale_factor=config['model']['upscale_factor'],
        img_channels=config['model']['img_channels'],
        crop_size=config['data']['lr_crop_size'] *
        config['model']['upscale_factor'])
    train_dataloader = DataLoader(dataset=train_set,
                                  batch_size=config['training']['batch_size'],
                                  shuffle=True)
    val_set = get_val_set(img_dir=config['data']['test_root'],
                          upscale_factor=config['model']['upscale_factor'],
                          img_channels=config['model']['img_channels'])
    val_dataloader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

    print('===> Building model')
    sys.stdout.flush()
    model = ESPCN(img_channels=config['model']['img_channels'],
                  upscale_factor=config['model']['upscale_factor']).to(device)
    criterion = nn.MSELoss()
    optimizer = setup_optimizer(model, config)
    scheduler = setup_scheduler(optimizer, config)

    start_iter = 0
    best_val_psnr = -1

    if config['training']['resume'] != 'None':
        print('===> Reloading model')
        sys.stdout.flush()
        ckpt = torch.load(config['training']['resume'])
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        start_iter = ckpt['iter']
Example #13
def build_model() -> nn.Module:
    model = ESPCN(config.upscale_factor).to(config.device)

    return model
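
For reference, a minimal PyTorch sketch of the three-layer ESPCN architecture (Shi et al., 2016) that these examples instantiate. The class name, activation choice, and layer sizes below follow the paper and are illustrative only; the real definition lives in each repository's own model module.

import torch
import torch.nn as nn


class ESPCNSketch(nn.Module):
    """Illustrative ESPCN: feature extraction in LR space followed by sub-pixel convolution."""

    def __init__(self, upscale_factor: int, num_channels: int = 1):
        super().__init__()
        # Feature extraction on the low-resolution input
        self.features = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Tanh(),
        )
        # Predict r*r sub-pixel channels per position, then rearrange them into the HR image
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(32, num_channels * upscale_factor ** 2, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.sub_pixel(self.features(x))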
Example #14
                        type=str,
                        choices=['FLRCNN', 'DESPCN', 'DVDSR'],
                        required=True)
    parser.add_argument("--scale", type=int, default=2)

    opts = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if opts.sr_module == "FSRCNN":
        sr_module = FSRCNN(scale=opts.scale)
    elif opts.sr_module == "ESPCN":
        sr_module = ESPCN(scale=opts.scale)
    elif opts.sr_module == "VDSR":
        sr_module = VDSR(scale=opts.scale)
    else:
        sr_module = FSRCNN(scale=opts.scale)

    if opts.lr_module == "FLRCNN":
        lr_module = FLRCNN(scale=opts.scale)
    elif opts.lr_module == "DESPCN":
        lr_module = DESPCN(scale=opts.scale)
    elif opts.lr_module == "DVDSR":
        lr_module = DVDSR(scale=opts.scale)
    else:
        lr_module = FLRCNN(scale=opts.scale)

    sr_module = sr_module.to(device)
Example #15
            for split, folder in folders.items()}
datasizes = {split: len(datasets[split])
              for split, folder in folders.items()}

batch_size = 4

dataloaders = {split: data.DataLoader(dataset, batch_size=1, shuffle=split == 'train', num_workers=0)
               for split, dataset in datasets.items()}

dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

nepochs = 50
log_step = 100

# Define network
net = ESPCN(8)
net.to(dev)
print(net)

# Define loss
criterion = nn.MSELoss()

# Define optim
optimizer = optim.Adam(net.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)

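# Warm-start from a previously saved checkpoint (restores both model and optimizer state)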
pretrained = torch.load('models/unet_bn_20190912_040318.pth', map_location='cuda:0')
net.load_state_dict(pretrained['model_state_dict'])
optimizer.load_state_dict(pretrained['optimizer_state_dict'])

net.eval()
Example #16
from model import ESPCN
from matplotlib import pyplot as plt
from utils import pre_process, psnr_calculate, convert_ycbcr_to_rgb, convert_rgb_to_ycbcr

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str)
    parser.add_argument('--test_img', type=str)
    parser.add_argument('--scale', type=int, default=4)
    args = parser.parse_args()

    criterion = nn.MSELoss()
    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = ESPCN(num_channel=1, scale=args.scale)

    net.load_state_dict(torch.load(args.weights, map_location=device))

    # eval() is enough here; wrap the actual forward pass in torch.no_grad() at inference time
    net.eval()

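    # Crop dimensions to a multiple of the scale factor, then build the LR input and a bicubic baseline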
    img = Image.open(args.test_img, mode='r').convert('RGB')
    width, height = (img.size[0] // args.scale) * args.scale, (
        img.size[1] // args.scale) * args.scale

    lr = img.resize((width // args.scale, height // args.scale),
                    Image.BICUBIC)
    bicubic = lr.resize((width, height), Image.BICUBIC)
    lr = pre_process(lr.convert('L')).to(device)
Example #17
args = parser.parse_args()
print(args)

torch.manual_seed(args.seed)
CUDA = torch.cuda.is_available()
dtype = set_dtype(CUDA)
device = torch.device("cuda" if CUDA else "cpu")

train_set = get_training_set(args.upscale_factor)
test_set = get_test_set(args.upscale_factor)
train_data_loader = DataLoader(dataset=train_set, num_workers=args.threads, \
                                batch_size=args.train_batch_size, shuffle=True)
test_data_loader = DataLoader(dataset=test_set, num_workers=args.threads, \
                                batch_size=args.test_batch_size, shuffle=False)

net = ESPCN(upscale_factor=args.upscale_factor).to(device)

# Uncomment below to load trained weights
# weights = torch.load('data/weights/weights_epoch_30.pth')
# net.load_state_dict(weights)

criterion = nn.MSELoss()
optim = torch.optim.Adam(net.parameters(), lr=args.lr)


def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(train_data_loader, 1):
        img_in, target = batch[0].to(device).type(dtype), \
                         batch[1].to(device).type(dtype)
Example #18
    parser.add_argument('--cuda',
                        action='store_true',
                        help='whether to use cuda')
    args = parser.parse_args()

    if args.cuda and not torch.cuda.is_available():
        raise Exception('No GPU found')
    device = torch.device('cuda' if args.cuda else 'cpu')
    print('Use device:', device)

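    # Collect and sort every image file in the input directory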
    filenames = os.listdir(args.img_dir)
    image_filenames = [os.path.join(args.img_dir, x) for x in filenames \
                       if is_image_file(x)]
    image_filenames = sorted(image_filenames)

    model = ESPCN(img_channels=args.img_channels,
                  upscale_factor=args.upscale_factor).to(device)
    if args.cuda:
        ckpt = torch.load(args.model)
    else:
        ckpt = torch.load(args.model, map_location='cpu')
    model.load_state_dict(ckpt['model'])

    res = {}

    for i, f in enumerate(image_filenames):
        # Read test image.
        img = Image.open(f).convert('RGB')
        width, height = img.size[0], img.size[1]

        # Crop test image so that it has size that can be downsampled by the upscale factor.
        pad_width = width % args.upscale_factor
Example #19
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--base',
                        '-B',
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='base directory path of program files')
    parser.add_argument('--config_path',
                        type=str,
                        default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--out',
                        '-o',
                        default='results/inference',
                        help='Directory to output the result')

    parser.add_argument('--model',
                        '-m',
                        default='',
                        help='Load model data(snapshot)')

    parser.add_argument('--root',
                        '-R',
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='Root directory path of input image')
    args = parser.parse_args()
    config = yaml_utils.Config(
        yaml.load(open(os.path.join(args.base, args.config_path)),
                  Loader=yaml.FullLoader))
    print('GPU: {}'.format(args.gpu))
    print('')

    hr_patchside = config.patch['patchside']
    config.patch['patchside'] = int(config.patch['patchside'] /
                                    config.upsampling_rate)

    gen = ESPCN(r=config.upsampling_rate)
    chainer.serializers.load_npz(args.model, gen)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
    xp = gen.xp

    # Read test list
    path_pairs = []
    with open(os.path.join(args.base,
                           config.dataset['test_fn'])) as paths_file:
        for line in paths_file:
            line = line.split()
            if not line: continue
            path_pairs.append(line[:])

    case_list = []
    ssim_list = []
    psnr_list = []
    for i in path_pairs:
        print('   LR from: {}'.format(i[0]))
        print('   HR from: {}'.format(i[1]))
        sitkLR = sitk.ReadImage(os.path.join(args.root, i[0]))
        lr = sitk.GetArrayFromImage(sitkLR).astype("float32")

        # Calculate how many patches are needed along each axis
        ze, ye, xe = lr.shape
        xm = int(math.ceil((float(xe) / float(config.patch['patchside']))))
        ym = int(math.ceil((float(ye) / float(config.patch['patchside']))))
        zm = int(math.ceil((float(ze) / float(config.patch['patchside']))))

        margin = ((0, config.patch['patchside']),
                  (0, config.patch['patchside']), (0,
                                                   config.patch['patchside']))
        lr = np.pad(lr, margin, 'edge')
        lr = chainer.Variable(
            xp.array(lr[np.newaxis, np.newaxis, :], dtype=xp.float32))

        zh, yh, xh = ze * config.upsampling_rate, ye * config.upsampling_rate, xe * config.upsampling_rate
        hr_map = np.zeros(
            (zh + hr_patchside, yh + hr_patchside, xh + hr_patchside))

        # Patch loop
        for s in range(xm * ym * zm):
            xi = int(s % xm) * config.patch['patchside']
            yi = int((s % (ym * xm)) / xm) * config.patch['patchside']
            zi = int(s / (ym * xm)) * config.patch['patchside']

            # Extract patch from original image
            patch = lr[:, :, zi:zi + config.patch['patchside'],
                       yi:yi + config.patch['patchside'],
                       xi:xi + config.patch['patchside']]
            with chainer.using_config('train', False), chainer.using_config(
                    'enable_backprop', False):
                hr_patch = gen(patch)

            # Generate HR map
            hr_patch = hr_patch.data
            if args.gpu >= 0:
                hr_patch = chainer.cuda.to_cpu(hr_patch)
            zi, yi, xi = zi * config.upsampling_rate, yi * config.upsampling_rate, xi * config.upsampling_rate
            hr_map[zi:zi + hr_patchside, yi:yi + hr_patchside,
                   xi:xi + hr_patchside] = hr_patch[0, :, :, :]

        print('Save image')
        hr_map = hr_map[:zh, :yh, :xh]

        # Save HR map
        inferenceHrImage = sitk.GetImageFromArray(hr_map)
        lr_spacing = sitkLR.GetSpacing()
        new_spacing = [i / config.upsampling_rate for i in lr_spacing]
        inferenceHrImage.SetSpacing(new_spacing)
        inferenceHrImage.SetOrigin(sitkLR.GetOrigin())
        result_dir = os.path.join(args.base, args.out)
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        fn = os.path.splitext(os.path.basename(i[0]))[0]
        sitk.WriteImage(inferenceHrImage, '{}/{}.mhd'.format(result_dir, fn))

        # Calc metric
        case_list.append(os.path.basename(i[0]))
        # PSNR
        sitkHR = sitk.ReadImage(os.path.join(args.root, i[1]))
        hr_gt = sitk.GetArrayFromImage(sitkHR).astype("float")

        psnr_const = psnr(hr_gt,
                          hr_map,
                          dynamic_range=np.amax(hr_gt) - np.amin(hr_gt))
        print('PSNR: {}'.format(psnr_const))
        psnr_list.append(psnr_const)
        # SSIM
        ssim_const = ssim(hr_gt,
                          hr_map,
                          dynamic_range=np.amax(hr_gt) - np.amin(hr_gt),
                          gaussian_weights=True,
                          use_sample_covariance=False)
        print('SSIM: {}'.format(ssim_const))
        ssim_list.append(ssim_const)

    df = pd.DataFrame({
        'Case': case_list,
        'PSNR': psnr_list,
        'SSIM': ssim_list
    })
    df.to_csv('{}/result.csv'.format(result_dir),
              index=False,
              encoding="utf-8",
              mode='w')
Example #20
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--training_set', type=str)
    parser.add_argument('--val_set', type=str)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--scale', type=int, default=4)
    parser.add_argument('--patch_size', type=int, default=56)
    parser.add_argument('--loss_coeff', type=float, default=0)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--num_workers', type=int, default=8)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ESPCN(num_channel=1, scale=args.scale).to(device)

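    # An extra feature-based loss term is used only when loss_coeff is non-zero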
    k = args.loss_coeff
    if k != 0:
        feat_ext = Feature_Extractor().to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)

    train_set = Train(args.training_set, scale=args.scale, patch_size=args.patch_size)
    trainloader = DataLoader(train_set, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers, pin_memory=True)

    val_set = Validation(args.val_set)
    valloader = DataLoader(val_set, batch_size=1,
                            shuffle=True, num_workers=args.num_workers, pin_memory=True)