Exemplo n.º 1
0
def main(_):
    """Print the flag values, ensure output folders exist, then run FSRCNN.

    The single positional argument is ignored (presumably the argv list a
    TF app runner passes in — confirm against the caller).
    """
    pp.pprint(flags.FLAGS.__flags)

    # Create the checkpoint and output folders only when they are missing.
    for required_dir in (FLAGS.checkpoint_dir, FLAGS.output_dir):
        if os.path.exists(required_dir):
            continue
        os.makedirs(required_dir)

    # Build the model inside a session and hand control to its run() method.
    with tf.Session() as session:
        FSRCNN(session, config=FLAGS).run()
Exemplo n.º 2
0
    def __init__(self, weigths, scale, *, onnx: Path):
        """Load either an ONNX Runtime session or a PyTorch FSRCNN model.

        Args:
            weigths: path to a PyTorch state_dict for FSRCNN; used only when
                ``onnx`` is falsy. (The misspelled name is kept so keyword
                callers keep working.)
            scale: upscaling factor, stored as ``self.scale``.
            onnx: path to an exported ONNX model, or a falsy value to use
                PyTorch. A truthy value that is not a ``str``/``Path``
                (e.g. ``True``) falls back to ``"fsrcnn.onnx"`` as before.
        """
        self.onnx = onnx
        self.scale = scale

        if onnx:
            # Honour the caller-supplied model path instead of always opening
            # the hard-coded default file.
            onnx_path = onnx if isinstance(onnx, (str, Path)) else "fsrcnn.onnx"
            self.session = onnxruntime.InferenceSession(str(onnx_path))
            print("Upscaling with ONNX")
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = FSRCNN(**model_settings).to(self.device)
            # map_location lets weights saved from a GPU run load on a
            # CPU-only machine (and vice versa).
            self.model.load_state_dict(
                torch.load(weigths, map_location=self.device))
            self.model.eval()
Exemplo n.º 3
0
def main(_):
    """Print flags, prepare directories, then build and run FSRCNN.

    When ``FLAGS.fast`` is set, ``FLAGS.checkpoint_dir`` is overwritten with
    ``PREFIX_PATH + "checkpoint_fast/"`` before any directory is created.
    The positional argument is ignored.
    """
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.fast:
        FLAGS.checkpoint_dir = PREFIX_PATH + "checkpoint_fast/"

    for required_dir in (FLAGS.checkpoint_dir, FLAGS.output_dir):
        if not os.path.exists(required_dir):
            os.makedirs(required_dir)

    with tf.Session() as session:
        # Seed before the model is constructed so its ops pick it up.
        tf.set_random_seed(FLAGS.seed)
        FSRCNN(session, config=FLAGS).run()
Exemplo n.º 4
0
def construct_model(model_path='result/model-50.pkl'):
    """Build an FSRCNN wrapped in ``nn.DataParallel`` on GPU 0 and load weights.

    Args:
        model_path: checkpoint to load (defaults to the previously hard-coded
            ``'result/model-50.pkl'``). Because the weights are loaded after
            wrapping, the state_dict keys must carry the ``module.`` prefix
            that a DataParallel-wrapped model produces.

    Returns:
        The DataParallel-wrapped model in eval mode.
    """
    from model import FSRCNN
    model = FSRCNN()
    # Deserialize onto the CPU first: the checkpoint may reference a GPU id
    # that does not exist on this machine. load_state_dict copies the
    # tensors onto the model's GPU parameters afterwards.
    ckpt = torch.load(model_path, map_location='cpu')
    model = nn.DataParallel(model.cuda(), [0])
    model.load_state_dict(ckpt)
    model.eval()
    return model
Exemplo n.º 5
0
class Upscaler:
    """Super-resolve PIL images with FSRCNN via ONNX Runtime or PyTorch.

    Only the luma (Y) channel goes through the network; the chroma (Cb, Cr)
    channels are resized with bicubic interpolation and merged back.
    """

    def __init__(self, weigths, scale, *, onnx: Path):
        """Load either an ONNX Runtime session or a PyTorch FSRCNN model.

        Args:
            weigths: path to a PyTorch state_dict for FSRCNN; used only when
                ``onnx`` is falsy. (The misspelled name is kept so keyword
                callers keep working.)
            scale: upscaling factor; output size is the input size times this.
            onnx: path to an exported ONNX model, or a falsy value to use
                PyTorch. A truthy value that is not a ``str``/``Path``
                (e.g. ``True``) falls back to ``"fsrcnn.onnx"`` as before.
        """
        self.onnx = onnx
        self.scale = scale

        if onnx:
            # Honour the caller-supplied model path instead of always opening
            # the hard-coded default file.
            onnx_path = onnx if isinstance(onnx, (str, Path)) else "fsrcnn.onnx"
            self.session = onnxruntime.InferenceSession(str(onnx_path))
            print("Upscaling with ONNX")
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = FSRCNN(**model_settings).to(self.device)
            # map_location lets weights saved from a GPU run load on a
            # CPU-only machine (and vice versa).
            self.model.load_state_dict(
                torch.load(weigths, map_location=self.device))
            self.model.eval()

    def _upscale(self, data: np.ndarray):
        """Run the network on ``data`` and return the output as a NumPy array.

        ``upscaleImage`` passes a float32 array of shape ``(1, 1, H, W)``
        scaled to ``[0, 1]``.
        """
        if self.onnx:
            input_name = self.session.get_inputs()[0].name
            return self.session.run(None, {input_name: data})[0]
        tensor = torch.from_numpy(data).to(self.device)
        result = self.model(tensor)
        # Tensor.numpy() only works on CPU tensors that do not require grad,
        # so detach and move back to the host first (needed on CUDA).
        return result.detach().cpu().numpy()

    def upscaleImage(self, image: Image):
        """Return ``image`` upscaled by ``self.scale`` as a YCbCr PIL image."""
        with torch.no_grad():
            hr_size = (image.width * self.scale, image.height * self.scale)
            y, cb, cr = image.convert("YCbCr").split()

            # Luma: add batch and channel axes and normalise to [0, 1].
            array_y = np.array(y)[np.newaxis, np.newaxis].astype(np.float32) / 255.0
            result_y = self._upscale(array_y)
            # Clip before casting: the network can overshoot [0, 1], and an
            # out-of-range float -> uint8 cast is platform-dependent (it
            # typically wraps, e.g. 256 -> 0), producing speckle artefacts.
            result_y = np.clip(result_y[0, 0] * 255.0, 0.0, 255.0).astype(np.uint8)
            result_image_y = Image.fromarray(result_y)

            # Chroma: plain bicubic resize to the target size.
            cb_hr = cb.resize(hr_size, resample=Image.BICUBIC)
            cr_hr = cr.resize(hr_size, resample=Image.BICUBIC)
            return Image.merge("YCbCr", (result_image_y, cb_hr, cr_hr))
Exemplo n.º 6
0
def construct_model(model_path, device):
    """Build an FSRCNN, load weights from ``model_path`` and set eval mode.

    Checkpoints saved from an ``nn.DataParallel``-wrapped model store their
    keys as ``module.<name>``; that prefix is stripped so the weights load
    into a bare model. Keys without the prefix are used unchanged.

    Args:
        model_path: path to a saved state_dict.
        device: target device (string or ``torch.device``).

    Returns:
        The FSRCNN model on ``device``, in eval mode.
    """
    from model import FSRCNN
    model = FSRCNN()
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    ckpt = torch.load(model_path, map_location=device)
    # Match the full "module." prefix: a bare startswith('module') would also
    # hit keys such as "module_x..." and chop off the wrong characters.
    prefix = 'module.'
    new_ckpt = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt.items()
    }
    model = model.to(device)
    model.load_state_dict(new_ckpt)
    model.eval()
    return model
Exemplo n.º 7
0
from config import model_settings, batch_size, learning_rate, epochs

device = "cuda" if torch.cuda.is_available() else "cpu"

with h5py.File("datasets/General-100.h5") as f:

    outdir = Path("out")
    outdir.mkdir(exist_ok=True)

    # Create data loaders.
    train_dataloader = DataLoader(
        TrainDataset(f["train"]), batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(TestDataset(f["test"]), batch_size=1)

    # Create the model
    model = FSRCNN(**model_settings).to(device)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        for batch, (X, y) in enumerate(tqdm(dataloader, total=size // batch_size)):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
Exemplo n.º 8
0
    parser.add_argument("--label_percents", type=float, default=0.3)
    parser.add_argument('--num_workers', type=int, default=0)

    opts = parser.parse_args()

    if not os.path.exists(opts.weights_dir):
        os.mkdir(opts.weights_dir)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    torch.manual_seed(42)

    if opts.sr_module == "FSRCNN":
        sr_module = FSRCNN(scale=opts.scale).to(device)
    elif opts.sr_module == "ESPCN":
        sr_module = ESPCN(scale=opts.scale).to(device)
    elif opts.sr_module == "VDSR":
        sr_module = VDSR(scale=opts.scale).to(device)
    else:
        sr_module = FSRCNN(scale=opts.scale).to(device)

    if opts.lr_module == "FLRCNN":
        lr_module = FLRCNN(scale=opts.scale).to(device)
    elif opts.lr_module == "DESPCN":
        lr_module = DESPCN(scale=opts.scale).to(device)
    elif opts.lr_module == "DVDSR":
        lr_module = DVDSR(scale=opts.scale).to(device)
    else:
        lr_module = FLRCNN(scale=opts.scale).to(device)
Exemplo n.º 9
0
def _str2bool(value):
    """Convert a command-line string such as ``"True"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is ``True``
    for every non-empty string, so ``--is_train False`` would enable
    training. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


def main():
    """Train or test FSRCNN depending on ``--is_train``.

    Training reads ``<checkpoint_dir>/train.h5`` and runs Adam for
    ``--epochs`` passes over the data. Testing reads
    ``<checkpoint_dir>/test.h5``, runs the model once, stitches the output
    with ``merge`` and saves ``<sample_dir>/test_image.png``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--epochs', type=int, default=55, help='Number of epoch [55]')
    parser.add_argument('--batch_size', type=int, default=128, help='The size of batch images [128]')
    parser.add_argument('--image_size', type=int, default=33, help='The size of image to use [33]')
    parser.add_argument('--label_size', type=int, default=33, help='The size of label [33]')
    # type=float: with type=int any non-integer value (e.g. "1e-3") was
    # rejected, so the learning rate could not be changed from the CLI.
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='The learning rate of gradient descent algorithm [1e-4]')
    parser.add_argument('--c_dim', type=int, default=3, help='Dimension of image color [3]')
    parser.add_argument('--scale', type=int, default=3, help='The size of scale factor for preprocessing input image [3]')
    parser.add_argument('--stride', type=int, default=14, help='The size of stride to apply input image [14]')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', help='Name of checkpoint directory [checkpoint]')
    parser.add_argument('--sample_dir', type=str, default='sample', help='Name of sample directory [sample]')
    # _str2bool instead of type=bool, which would treat "False" as True.
    parser.add_argument('--is_train', type=_str2bool, default=False, help='True for training, False for testing [False]')

    args = parser.parse_args()

    for required_dir in (args.checkpoint_dir, args.sample_dir):
        if not os.path.exists(required_dir):
            os.makedirs(required_dir)

    srcnn = FSRCNN(image_size=args.image_size, label_size=args.label_size,
                   batch_size=args.batch_size, c_dim=args.c_dim)

    # Adam optimizer.
    optimizer = tf.keras.optimizers.Adam(args.learning_rate)

    def run_optimization(x, y):
        """Apply one Adam update to ``srcnn`` using the batch ``(x, y)``."""
        # Record only the forward pass and the loss on the tape.
        with tf.GradientTape() as tape:
            pred = srcnn(x, is_training=True)
            loss = mse(pred, y)
        # Differentiate and update outside the tape so these ops are not
        # themselves recorded.
        trainable_variables = srcnn.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        optimizer.apply_gradients(zip(gradients, trainable_variables))

    def train(args):
        """Run the training loop or a single test pass, based on ``args.is_train``."""
        if args.is_train:
            # NOTE(review): input_setup presumably writes the train.h5 read
            # below — confirm.
            input_setup(args)
            print("Training...")
            data_dir = os.path.join('./{}'.format(args.checkpoint_dir), "train.h5")
            train_data, train_label = read_data(data_dir)

            display_step = 5
            # Floor division: trailing samples that do not fill a whole
            # batch are skipped every epoch.
            batch_idxs = len(train_data) // args.batch_size
            for step in range(args.epochs):
                for idx in range(batch_idxs):
                    start, end = idx * args.batch_size, (idx + 1) * args.batch_size
                    batch_images = train_data[start:end]
                    batch_labels = train_label[start:end]
                    run_optimization(batch_images, batch_labels)

                    # Every display_step-th epoch, report the post-update
                    # loss of each batch.
                    if step % display_step == 0:
                        pred = srcnn(batch_images)
                        loss = mse(pred, batch_labels)
                        print("step: %i, loss: %f" % (step, loss))
        else:
            # In test mode input_setup returns (nx, ny), which merge() uses
            # to reassemble the per-patch predictions.
            nx, ny = input_setup(args)
            print("Testing...")
            data_dir = os.path.join('./{}'.format(args.checkpoint_dir), "test.h5")
            test_data, _ = read_data(data_dir)

            result = srcnn(test_data)
            result = merge(result, [nx, ny])
            result = result.squeeze()

            image_path = os.path.join(os.getcwd(), args.sample_dir, "test_image.png")
            print(result.shape)
            imsave(result, image_path)

    train(args)
Exemplo n.º 10
0
    upscale=2.0,
    is_scale_back=True)
demo_dataset_x4_scale = ImageDatasetFromFile(
    "./DIV2K800/train/DIV2K_train_HR/DIV2K_train_HR/",
    upscale=4.0,
    is_scale_back=True)

# NOTE(review): the training loader always uses demo_dataset_x4 regardless
# of opt.upscale, and shuffle is left at DataLoader's default (off) —
# confirm both are intended.
train_data_loader = data.DataLoader(dataset=demo_dataset_x4,
                                    batch_size=opt.batch_size,
                                    num_workers=8,
                                    drop_last=True,
                                    pin_memory=True)

# Pick the network from opt.model and opt.upscale.
if opt.model:
    if opt.model == "FSRCNN" and opt.upscale == 4:
        model = FSRCNN(num_channels=3, upscale_factor=4)

    if opt.model in ("FALSR_A", "FALSR_B"):
        # `!=` compares values; `is not 2` tested object identity, so e.g. a
        # float 2.0 was rejected (and it triggers a SyntaxWarning on 3.8+).
        if opt.upscale != 2:
            # `raise ("...")` raised a TypeError because a string is not an
            # exception; raise a real exception carrying the message.
            raise ValueError("ONLY SUPPORT 2X")
        if opt.model == "FALSR_A":
            model = FALSR_A()
        else:
            model = FALSR_B()

    if opt.model == "SRCNN" and opt.upscale == 4:
        model = SRCNN(num_channels=3, upscale_factor=4)
Exemplo n.º 11
0
#!/usr/bin/env python3

import onnx
import torch

from config import model_settings
from model import FSRCNN

# Rebuild FSRCNN from its saved state_dict and export it to ONNX.
model = FSRCNN(**model_settings)
# The model and the dummy input stay on the CPU; map_location="cpu" also lets
# weights saved from a GPU run load on a machine without CUDA.
model.load_state_dict(torch.load('result.pth', map_location='cpu'))
model.eval()

# Dummy input used for tracing; dims 2 and 3 are declared dynamic below so
# the exported graph is not tied to this 10x10 size.
inputs = torch.ones(1, 1, 10, 10)

torch.onnx.export(
    model, inputs, "fsrcnn.onnx", verbose=True,
    input_names=["input_image"], dynamic_axes={"input_image": [2, 3]})


# Reload the exported file and run the ONNX structural checker on it.
onnx_model = onnx.load("fsrcnn.onnx")
onnx.checker.check_model(onnx_model)
Exemplo n.º 12
0
def main() -> None:
    """Evaluate the FSRCNN checkpoint on the test set and save SR images.

    For every file name in ``config.hr_dir`` the image with the same name in
    ``config.lr_dir`` is super-resolved, PSNR against the HR image is
    accumulated, and the colour result (SR luma + bicubic chroma) is written
    to ``config.sr_dir``. The mean PSNR is printed at the end.
    """
    # Create a folder of super-resolution experiment results
    results_dir = os.path.join("results", "test", config.exp_name)
    os.makedirs(results_dir, exist_ok=True)
    # SR images are saved into config.sr_dir; make sure it exists too so the
    # first save() does not fail when it differs from results_dir.
    os.makedirs(config.sr_dir, exist_ok=True)

    # Initialize the super-resolution model
    print("Build SR model...")
    model = FSRCNN(config.upscale_factor).to(config.device)
    print("Build SR model successfully.")

    # Load the super-resolution model weights
    print(f"Load SR model weights `{os.path.abspath(config.model_path)}`...")
    state_dict = torch.load(config.model_path, map_location=config.device)
    model.load_state_dict(state_dict)
    print(f"Load SR model weights `{os.path.abspath(config.model_path)}` successfully.")

    # Evaluation mode with half-precision weights; the input tensors below
    # are created with half=True to match.
    model.eval()
    model.half()

    # Test images are matched by file name across the LR, HR and SR dirs.
    file_names = natsorted(os.listdir(config.hr_dir))
    total_files = len(file_names)
    if total_files == 0:
        # Averaging over an empty set would raise ZeroDivisionError.
        print(f"No test images found in `{os.path.abspath(config.hr_dir)}`.")
        return

    total_psnr = 0.0
    for file_name in file_names:
        lr_image_path = os.path.join(config.lr_dir, file_name)
        sr_image_path = os.path.join(config.sr_dir, file_name)
        hr_image_path = os.path.join(config.hr_dir, file_name)

        print(f"Processing `{os.path.abspath(hr_image_path)}`...")
        lr_image = Image.open(lr_image_path).convert("RGB")
        bic_image = lr_image.resize([int(lr_image.width * config.upscale_factor), int(lr_image.height * config.upscale_factor)], Image.BICUBIC)
        hr_image = Image.open(hr_image_path).convert("RGB")

        # NOTE(review): the full YCbCr arrays are passed to image2tensor
        # below; confirm imgproc reduces them to the Y channel the model
        # expects.
        # Extract Y channel lr image data
        lr_image = np.array(lr_image).astype(np.float32)
        lr_ycbcr_image = imgproc.convert_rgb_to_ycbcr(lr_image)
        lr_y_tensor = imgproc.image2tensor(lr_ycbcr_image, range_norm=False, half=True).to(config.device).unsqueeze_(0)

        # Bicubic-upscaled image: only its Cb/Cr channels are reused below.
        bic_image = np.array(bic_image).astype(np.float32)
        bic_ycbcr_image = imgproc.convert_rgb_to_ycbcr(bic_image)

        # Extract Y channel hr image data.
        hr_image = np.array(hr_image).astype(np.float32)
        hr_ycbcr_image = imgproc.convert_rgb_to_ycbcr(hr_image)
        hr_y_tensor = imgproc.image2tensor(hr_ycbcr_image, range_norm=False, half=True).to(config.device).unsqueeze_(0)

        # Only reconstruct the Y channel image data.
        with torch.no_grad():
            sr_y_tensor = model(lr_y_tensor)

        # PSNR in float32: squaring small half-precision differences
        # underflows to zero and inflates the score. The peak value 1.0
        # assumes tensors scaled to [0, 1] — TODO confirm imgproc scaling.
        mse = torch.mean((sr_y_tensor.float() - hr_y_tensor.float()) ** 2)
        total_psnr += (10. * torch.log10(1. / mse)).item()

        sr_y_image = imgproc.tensor2image(sr_y_tensor, range_norm=False, half=True)
        sr_image = np.array([sr_y_image, bic_ycbcr_image[..., 1], bic_ycbcr_image[..., 2]]).transpose([1, 2, 0])
        sr_image = np.clip(imgproc.convert_ycbcr_to_rgb(sr_image), 0.0, 255.0).astype(np.uint8)
        sr_image = Image.fromarray(sr_image)
        sr_image.save(sr_image_path)

    print(f"PSNR: {total_psnr / total_files:.2f}.\n")
Exemplo n.º 13
0
def build_model() -> nn.Module:
    """Create an FSRCNN for ``config.upscale_factor`` placed on ``config.device``."""
    return FSRCNN(config.upscale_factor).to(config.device)
Exemplo n.º 14
0
import torch
import torch.backends.cudnn as cudnn

#training code
if args.phase == 'train':
    dataloaders = data.DataLoader(DataLoader(args),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  pin_memory=True)

    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and args.ngpu > 0) else "cpu")

    print("constructing model ....")
    model = FSRCNN()

    model = nn.DataParallel(model.to(device), gpuids)

    if args.resume:
        model.load_state_dict(torch.load(args.model_path))
    print("model constructed")

    summary_writer = SummaryWriter(args.log_dir)

    optimizer = torch.optim.Adam(
        [{
            'params': model.module.extract_features.parameters()
        }, {
            'params': model.module.shrink.parameters()
        }, {
Exemplo n.º 15
0
    parser.add_argument('--lr_weights', type=str, required=True)
    parser.add_argument('--lr_module',
                        type=str,
                        choices=['FLRCNN', 'DESPCN', 'DVDSR'],
                        required=True)
    parser.add_argument("--scale", type=int, default=2)

    opts = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if opts.sr_module == "FSRCNN":
        sr_module = FSRCNN(scale=opts.scale)
    elif opts.sr_module == "ESPCN":
        sr_module = ESPCN(scale=opts.scale)
    elif opts.sr_module == "VDSR":
        sr_module = VDSR(scale=opts.scale)
    else:
        sr_module = FSRCNN(scale=opts.scale)

    if opts.lr_module == "FLRCNN":
        lr_module = FLRCNN(scale=opts.scale)
    elif opts.lr_module == "DESPCN":
        lr_module = DESPCN(scale=opts.scale)
    elif opts.lr_module == "DVDSR":
        lr_module = DVDSR(scale=opts.scale)
    else:
        lr_module = FLRCNN(scale=opts.scale)
    parser.add_argument('--cuda',
                        action='store_true',
                        help='whether to use cuda')
    args = parser.parse_args()

    if args.cuda and not torch.cuda.is_available():
        raise Exception('No GPU found')
    device = torch.device('cuda' if args.cuda else 'cpu')
    print('Use device:', device)

    filenames = os.listdir(args.img_dir)
    image_filenames = [os.path.join(args.img_dir, x) for x in filenames \
                       if is_image_file(x)]
    image_filenames = sorted(image_filenames)

    model = FSRCNN(img_channels=args.img_channels,
                   upscale_factor=args.upscale_factor).to(device)
    if args.cuda:
        ckpt = torch.load(args.model)
    else:
        ckpt = torch.load(args.model, map_location='cpu')
    model.load_state_dict(ckpt['model'])

    res = {}

    for i, f in enumerate(image_filenames):
        # Read test image.
        img = Image.open(f).convert('RGB')
        width, height = img.size[0], img.size[1]

        # Crop test image so that it has size that can be downsampled by the upscale factor.
        pad_width = width % args.upscale_factor
Exemplo n.º 17
0
        upscale_factor=config['model']['upscale_factor'],
        img_channels=config['model']['img_channels'],
        crop_size=config['data']['lr_crop_size'] *
        config['model']['upscale_factor'])
    train_dataloader = DataLoader(dataset=train_set,
                                  batch_size=config['training']['batch_size'],
                                  shuffle=True)

    val_set = get_val_set(img_dir=config['data']['test_root'],
                          upscale_factor=config['model']['upscale_factor'],
                          img_channels=config['model']['img_channels'])
    val_dataloader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

    print('===> Building model')
    sys.stdout.flush()
    model = FSRCNN(img_channels=config['model']['img_channels'],
                   upscale_factor=config['model']['upscale_factor']).to(device)
    criterion = nn.MSELoss()
    optimizer = setup_optimizer(model, config)
    scheduler = setup_scheduler(optimizer, config)

    start_iter = 0
    best_val_psnr = -1

    if config['training']['resume'] != 'None':
        print('===> Reloading model')
        sys.stdout.flush()
        ckpt = torch.load(config['training']['resume'])
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        start_iter = ckpt['iter']