Ejemplo n.º 1
0
def conv(in_f,
         out_f,
         kernel_size,
         stride=1,
         bias=True,
         pad='zero',
         downsample_mode='stride'):
    """Build a 2-D convolution with optional reflection padding and downsampling.

    Args:
        in_f: Number of input channels.
        out_f: Number of output channels.
        kernel_size: Convolution kernel size. A padding of
            ``(kernel_size - 1) // 2`` is used, which preserves the spatial
            size for odd kernels when the convolution stride is 1.
        stride: Downsampling factor. With ``downsample_mode='stride'`` it is
            passed to the convolution itself. Otherwise the convolution runs
            at stride 1 and a separate layer performs the downsampling.
        bias: Whether the convolution has a bias term.
        pad: ``'reflection'`` prepends an ``nn.ReflectionPad2d``. Any other
            value uses the convolution's built-in zero padding.
        downsample_mode: One of ``'stride'``, ``'avg'``, ``'max'``,
            ``'lanczos2'`` or ``'lanczos3'``.

    Returns:
        ``nn.Sequential`` containing the padder (if any), the convolution and
        the downsampler (if any), in that order.

    Raises:
        ValueError: If ``stride != 1`` and ``downsample_mode`` is not supported.
    """
    downsampler = None
    if stride != 1 and downsample_mode != 'stride':
        if downsample_mode == 'avg':
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ('lanczos2', 'lanczos3'):
            # Applied after the convolution, so it operates on out_f planes.
            downsampler = Downsampler(n_planes=out_f,
                                      factor=stride,
                                      kernel_type=downsample_mode,
                                      phase=0.5,
                                      preserve_size=True)
        else:
            # An explicit raise survives `python -O`. A bare `assert False`
            # would be stripped, and the layer would silently not downsample.
            raise ValueError(
                "unknown downsample_mode {!r}; expected 'stride', 'avg', "
                "'max', 'lanczos2' or 'lanczos3'".format(downsample_mode))

        # The dedicated downsampler handles the reduction, so the
        # convolution itself runs at full resolution.
        stride = 1

    padder = None
    to_pad = (kernel_size - 1) // 2
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0  # padding is already applied by the padder

    convolver = nn.Conv2d(in_f,
                          out_f,
                          kernel_size,
                          stride,
                          padding=to_pad,
                          bias=bias)

    layers = [layer for layer in (padder, convolver, downsampler)
              if layer is not None]
    return nn.Sequential(*layers)
Ejemplo n.º 2
0
                  NET_TYPE,
                  pad,
                  skip_n33d=128,
                  skip_n33u=128,
                  skip_n11=4,
                  num_scales=5,
                  upsample_mode='bilinear').type(dtype)

    # Losses
    mse = torch.nn.MSELoss().type(dtype)

    img_LR_var = np_to_var(imgs['LR_np']).type(dtype)

    downsampler = Downsampler(n_planes=3,
                              factor=factor,
                              kernel_type=KERNEL_TYPE,
                              phase=0.5,
                              preserve_size=True).type(dtype)

    def closure():
        global i

        if reg_noise_std > 0:
            net_input.data = net_input_saved + \
                (noise.normal_() * reg_noise_std)

        out_HR = net(net_input)  # ([1, 3, H, W])
        out_LR = downsampler(out_HR)  # ([1, 3, H/factor, W/factor])

        total_loss = mse(out_LR, img_LR_var)
        # total_loss = CharbonnierLoss(out_LR, img_LR_var)
Ejemplo n.º 3
0
    def __init__(self, config):
        """Set up the generator, feature network, optimizers and downsampler.

        Copies hyper-parameters out of ``config``. When
        ``config['custom_mask']`` is true, it also builds a binary 256x256 mask
        from the image at ``config['mask_path']`` and computes its regions.
        It then creates ``models.Generator`` (plus ``models.Discriminator``
        when ``config['ftr_type'] == 'Discriminator'``), loads or randomizes
        the weights, prepares the latent code and LR schedulers, builds the
        loss functions, and creates a 4x Lanczos-2 downsampler.

        Args:
            config: Dict-like settings. Every key read below must be present.
        """
        self.rank, self.world_size = 0, 1
        if config['dist']:
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()

        self.config = config
        self.mode = config['dgp_mode']
        self.custom_mask = config['custom_mask']
        self.update_G = config['update_G']
        self.update_embed = config['update_embed']
        self.iterations = config['iterations']
        self.ftr_num = config['ftr_num']
        self.ft_num = config['ft_num']
        self.lr_ratio = config['lr_ratio']
        self.G_lrs = config['G_lrs']
        self.z_lrs = config['z_lrs']
        self.use_in = config['use_in']
        self.select_num = config['select_num']
        self.factor = 4  # Downsample factor
        self.mask_path = config['mask_path']

        # Create selective masking
        if self.custom_mask:
            # The `with` block closes the file handle that Image.open keeps
            # open. ToTensor() reads the pixels while the file is still open.
            with Image.open(self.mask_path) as mask_img:
                mask_tensor = transforms.ToTensor()(mask_img).unsqueeze_(0)
            resized_mask = F.interpolate(mask_tensor,
                                         size=(256, 256),
                                         mode='bilinear',
                                         align_corners=False)
            # Binarize at a 0.9 threshold. Keep channel 0 of batch item 0
            # as a 2-D mask on the GPU.
            self.mask = (resized_mask > 0.9).float()[0][0].cuda()
            self.regions = self.get_regions(self.mask)

        # create model
        self.G = models.Generator(**config).cuda()
        self.D = models.Discriminator(
            **config).cuda() if config['ftr_type'] == 'Discriminator' else None
        # One Adam param group per generator block (plus one extra group).
        self.G.optim = torch.optim.Adam(
            [{
                'params': self.G.get_params(i, self.update_embed)
            } for i in range(len(self.G.blocks) + 1)],
            lr=config['G_lr'],
            betas=(config['G_B1'], config['G_B2']),
            weight_decay=0,
            eps=1e-8)

        # load weights
        if config['random_G']:
            self.random_G()
        else:
            utils.load_weights(self.G if not (config['use_ema']) else None,
                               self.D,
                               config['weights_root'],
                               name_suffix=config['load_weights'],
                               G_ema=self.G if config['use_ema'] else None,
                               strict=False)

        self.G.eval()
        if self.D is not None:
            self.D.eval()
        # Snapshot of the loaded generator weights.
        self.G_weight = deepcopy(self.G.state_dict())

        # prepare latent variable and optimizer
        self._prepare_latent()
        # prepare learning rate scheduler
        self.G_scheduler = utils.LRScheduler(self.G.optim, config['warm_up'])
        self.z_scheduler = utils.LRScheduler(self.z_optim, config['warm_up'])

        # loss functions
        self.mse = torch.nn.MSELoss()
        if config['ftr_type'] == 'Discriminator':
            self.ftr_net = self.D
            self.criterion = utils.DiscriminatorLoss(
                ftr_num=config['ftr_num'][0])
        else:
            vgg = torchvision.models.vgg16(pretrained=True).cuda().eval()
            self.ftr_net = models.subsequence(vgg.features, last_layer='20')
            self.criterion = utils.PerceptLoss()

        # Downsampler for producing low-resolution image
        self.downsampler = Downsampler(n_planes=3,
                                       factor=self.factor,
                                       kernel_type='lanczos2',
                                       phase=0.5,
                                       preserve_size=True).type(
                                           torch.cuda.FloatTensor)
Ejemplo n.º 4
0
def build_closure(writer, dtype):
    """Build the per-iteration training closure for multi-image super-resolution.

    Settings are read from ``config.cfg``. The returned closure:

    - picks one training pair from the global ``state.imgs`` (at random when
      ``random_swap`` is set, otherwise round-robin on ``state.i``);
    - runs ``state.net`` on a Gaussian-perturbed copy of its input;
    - downsamples the output 4x with a Lanczos-2 kernel and back-propagates
      the MSE against the low-resolution target;
    - periodically records PSNR/loss history, TensorBoard scalars, parameter
      histograms, image grids and progress figures.

    Args:
        writer: TensorBoard-style writer. ``add_scalar``, ``add_histogram``
            and ``add_image`` are called on it.
        dtype: Tensor type passed to ``.type()`` for the loss and downsampler.

    Returns:
        A zero-argument closure. Each call performs one forward/backward
        step, increments ``state.i`` and returns the loss tensor.
    """
    # Read config file
    config = common.configparser.ConfigParser()
    config.read('config.cfg')

    plot_steps_low = config.getint('DEFAULT', 'plot_steps_low')
    plot_steps_high = config.getint('DEFAULT', 'plot_steps_high')

    blur = config['DEFAULT']['blur'] != 'none'

    mse = torch.nn.MSELoss().type(dtype)
    downsampler = Downsampler(n_planes=config.getint('DEFAULT', 'n_channels'),
                              factor=4,
                              kernel_type='lanczos2',
                              phase=0.5,
                              preserve_size=True).type(dtype)

    # Optional settings. `fallback` is returned when either the section or
    # the option is missing.
    augmented_history = config.getboolean('LOADING', 'augmented_history',
                                          fallback=False)
    ignore_ground_truth = config.getboolean('LOADING', 'ignore_ground_truth',
                                            fallback=False)
    random_swap = config.getboolean('DEFAULT', 'random_swap', fallback=False)
    if random_swap:
        # Fixed seed so the random choice of training pairs is reproducible.
        random.seed(101)

    # Standard deviation of the Gaussian noise added to the network input.
    noise_level = config.getfloat('DEFAULT', 'noise_level', fallback=0.03)

    def get_loss(out_LR, ground_truth_LR, blurred_LR):
        """Return the MSE between the LR output and the training target.

        The target is ``blurred_LR`` when ``blur`` is enabled in the config,
        otherwise ``ground_truth_LR``.
        """
        used_image = blurred_LR if blur else ground_truth_LR
        return mse(out_LR, used_image)

    def get_images(before_HR, after_HR, bicubic_HR, blurred_HR, before_LR,
                   after_LR, blurred_LR):
        """Return (HR_grid, LR_grid) image grids for TensorBoard.

        HR grid: blurred, before, after, bicubic, residual(after, before).
        LR grid: blurred, before, after, residual(after, before).
        """

        def prep(img):
            # Flip along dims 1 and 0, then clamp to [0, 1] for display.
            return torch.clamp(torch.flip(img, [1, 0]), 0, 1)

        HR_grid = torchvision.utils.make_grid([
            prep(blurred_HR),
            prep(before_HR),
            prep(after_HR),
            prep(bicubic_HR),
            prep(common.makeResidual(after_HR, before_HR)),
        ], 5)
        LR_grid = torchvision.utils.make_grid([
            prep(blurred_LR),
            prep(before_LR),
            prep(after_LR),
            prep(common.makeResidual(after_LR, before_LR)),
        ], 4)

        return HR_grid, LR_grid

    def record_history(history, index, out_HR, out_LR, ground_truth_HR,
                       ground_truth_LR, loss_value):
        """Append this step's metrics to `history` and return them.

        Returns:
            ``(psnr_LR, psnr_HR, target_loss, psnr_blurred,
            psnr_downsampled)``. The HR metrics are 0 (and not appended) when
            ground truth is ignored. The two augmented metrics are None
            unless ``augmented_history`` is enabled.
        """
        psnr_LR = compare_psnr(common.torch_to_np(ground_truth_LR),
                               common.torch_to_np(out_LR))
        if ignore_ground_truth:
            psnr_HR = 0
            target_loss = 0
        else:
            psnr_HR = compare_psnr(common.torch_to_np(ground_truth_HR),
                                   common.torch_to_np(out_HR))
            target_loss = sr_utils.compare_HR(
                common.torch_to_np(ground_truth_HR),
                common.torch_to_np(out_HR))
            history.psnr_HR.append(psnr_HR)
            history.target_loss.append(target_loss)
        history.iteration.append(state.i)
        history.psnr_LR.append(psnr_LR)
        history.training_loss.append(loss_value)

        psnr_blurred = None
        psnr_downsampled = None
        if augmented_history:
            HR_torch_blurred = state.imgs[index][
                'HR_torch_blurred'].detach().cpu()
            LR_torch_downsampled = state.imgs[index][
                'LR_torch_downsampled'].detach().cpu()
            psnr_blurred = compare_psnr(common.torch_to_np(HR_torch_blurred),
                                        common.torch_to_np(out_HR))
            psnr_downsampled = compare_psnr(
                common.torch_to_np(LR_torch_downsampled),
                common.torch_to_np(out_LR))
            history.psnr_blurred.append(psnr_blurred)
            history.psnr_downsampled.append(psnr_downsampled)

        return psnr_LR, psnr_HR, target_loss, psnr_blurred, psnr_downsampled

    def closure():
        """Run one training step on one image pair and return the loss."""
        # Train with a different input/output pair at each iteration.
        if random_swap:
            index = random.randint(0, len(state.imgs) - 1)
        else:
            index = state.i % len(state.imgs)
        sample = state.imgs[index]
        net_input = sample['net_input']
        ground_truth_LR = sample['LR_torch']
        ground_truth_HR = sample['HR_torch']
        bicubic_HR = sample['HR_torch_bicubic']
        # Blurred images are not loaded, so the ground truth stands in for them.
        blurred_LR = ground_truth_LR
        blurred_HR = ground_truth_HR

        # Feed a noise-perturbed copy of the input through the network.
        noise_t = net_input.detach().clone().normal_()
        net_input_noisy = net_input.detach() + noise_t * noise_level
        out_HR = state.net(net_input_noisy)
        out_LR = downsampler(out_HR)

        out_HR = out_HR.squeeze(0)
        out_LR = out_LR.squeeze(0)

        # Get loss and train
        total_loss = get_loss(out_LR, ground_truth_LR, blurred_LR)
        total_loss.backward()

        # Everything below is metrics/plotting, done on CPU copies.
        out_HR = out_HR.detach().cpu()
        out_LR = out_LR.detach().cpu()
        ground_truth_LR = ground_truth_LR.cpu()
        ground_truth_HR = ground_truth_HR.cpu()
        bicubic_HR = bicubic_HR.cpu()
        blurred_LR = blurred_LR.cpu()
        blurred_HR = blurred_HR.cpu()

        if state.i % plot_steps_low < len(state.imgs):
            loss_value = total_loss.item()
            (psnr_LR, psnr_HR, target_loss, psnr_blurred,
             psnr_downsampled) = record_history(sample['history_low'], index,
                                                out_HR, out_LR,
                                                ground_truth_HR,
                                                ground_truth_LR, loss_value)

            if augmented_history:
                print("{} {} {} {} {} {}".format(state.i, index, psnr_LR,
                                                 psnr_HR, psnr_blurred,
                                                 psnr_downsampled))
            elif not ignore_ground_truth:
                print("{} {} {} {}".format(state.i, index, psnr_LR, psnr_HR))
            else:
                print("{} {} {} {}".format(state.i, index, psnr_LR,
                                           loss_value))

            # TensorBoard History
            writer.add_scalar('PSNR LR', psnr_LR, state.i)
            writer.add_scalar('PSNR HR', psnr_HR, state.i)
            writer.add_scalar('Training Loss', loss_value, state.i)
            writer.add_scalar('Target Loss', target_loss, state.i)

        if state.i % plot_steps_high < len(state.imgs):
            # Lower frequency capturing of large data
            record_history(sample['history_high'], index, out_HR, out_LR,
                           ground_truth_HR, ground_truth_LR,
                           total_loss.item())

            # Save parameters
            for name, param in state.net.named_parameters():
                writer.add_histogram('Parameter {}'.format(name),
                                     param.flatten(), state.i)

            # Add images
            HR_grid, LR_grid = get_images(ground_truth_HR, out_HR, bicubic_HR,
                                          blurred_HR, ground_truth_LR, out_LR,
                                          blurred_LR)
            writer.add_image('Network LR Output', LR_grid, state.i)
            writer.add_image('Network HR Output', HR_grid, state.i)

            sr_utils.make_progress_figure(
                ground_truth_HR, bicubic_HR, out_HR, ground_truth_LR, out_LR,
                'HR_update_frame_{}_{}'.format(index, state.i),
                'LR_update_frame_{}_{}'.format(index, state.i))
        state.i += 1

        return total_loss

    return closure
def train(args, imgs):
    """Super-resolve ``imgs`` with a deep-image-prior ``skip`` network.

    The network is fitted so that its Lanczos-2 downsampled output matches
    ``imgs['LR_np']``. Adam runs for the first ``min(100, args.epoch)``
    iterations, then LBFGS runs for the remaining ``args.epoch - 100`` (when
    positive). The loss and LR/HR PSNR are printed every iteration. The HR
    estimate is saved as ``sr_<iteration>.png`` every 500 iterations and at
    iteration ``args.epoch - 1``.

    NOTE(review): the closure declares ``global iteration, psnr_HR``, so a
    module-level ``iteration`` must already exist when this is called.
    NOTE(review): everything from ``from models.model_sr import Model``
    onward looks like a separate, truncated snippet pasted into this
    function. It rebuilds ``net`` and ``closure`` and uses ``dtype``, which
    is not defined in this function. It is kept verbatim; confirm against
    the original source.
    """
    # Load model
    net = skip(num_input_channels=32,
               num_output_channels=3,
               num_channels_down=[64, 64, 64, 128, 128],
               num_channels_up=[64, 64, 64, 128, 128],
               num_channels_skip=[4, 4, 4, 4, 4],
               upsample_mode='bilinear',
               need_sigmoid=True,
               need_bias=True,
               pad='reflection',
               act_fun='LeakyReLU')
    net = net.float()
    net = net.cuda() if torch.cuda.is_available() else net

    # Compute the number of parameters
    s = sum(p.numel() for p in net.parameters())
    print('Number of parameters: ', s)
    criterion = nn.MSELoss()

    # define input: 32-channel noise sized from dims 1 and 2 of HR_np
    net_input = get_noise(32, 'noise', (np.shape(
        imgs['HR_np'])[1], np.shape(imgs['HR_np'])[2])).float().detach()
    img_LR_var = np_to_var(imgs['LR_np']).float()
    downsampler = Downsampler(n_planes=3,
                              factor=args.factor,
                              kernel_type='lanczos2',
                              phase=0.5,
                              preserve_size=True).float()
    downsampler = downsampler.cuda() if torch.cuda.is_available(
    ) else downsampler
    net_input_saved = net_input.data.clone()

    # Define closure
    reg_noise_std = 0.0
    tv_weight = 0.0
    net_input = net_input.cuda() if torch.cuda.is_available() else net_input
    img_LR_var = img_LR_var.cuda() if torch.cuda.is_available() else img_LR_var
    net_input_saved = net_input_saved.cuda() if torch.cuda.is_available(
    ) else net_input_saved

    def closure():
        global iteration, psnr_HR
        if reg_noise_std > 0:
            # randn_like keeps the noise on the input's device, so this
            # path also works on CPU-only machines.
            net_input.data = net_input_saved + (
                torch.randn_like(net_input_saved) * reg_noise_std)
        out_HR = net(net_input)
        out_LR = downsampler(out_HR)
        total_loss = criterion(out_LR, img_LR_var)
        if tv_weight > 0:
            total_loss += tv_weight * tv_loss(out_HR)
        total_loss.backward()

        # Log. `.item()` is used because indexing a 0-dim loss tensor
        # (`.data[0]`) raises IndexError in current PyTorch.
        psnr_LR = compare_psnr(imgs['LR_np'], out_LR.data.cpu().numpy()[0])
        psnr_HR = compare_psnr(imgs['HR_np'], out_HR.data.cpu().numpy()[0])
        print('Iteration %05d   Loss %3f    PSNR_LR %.3f   PSNR_HR %.3f' %
              (iteration, total_loss.item(), psnr_LR, psnr_HR),
              '\r',
              end='')
        if iteration % 500 == 0 or iteration == args.epoch - 1:
            out_HR_np = var_to_np(out_HR)
            io.imsave(
                'sr_' + str(iteration) + '.png',
                np.transpose(np.clip(out_HR_np, 0, 1), [1, 2, 0])[:, :, :3])
        iteration += 1
        return total_loss

    # Optimize: Adam for (at most) the first 100 iterations, LBFGS for the
    # rest, so the total number of iterations is args.epoch.
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for _ in range(min(100, args.epoch)):
        optimizer.zero_grad()
        closure()
        optimizer.step()
    if args.epoch > 100:
        print('Starting optimization with LBFGS')
        optimizer = torch.optim.LBFGS(net.parameters(),
                                      max_iter=args.epoch - 100,
                                      lr=0.01,
                                      tolerance_grad=-1,
                                      tolerance_change=-1)

        def closure2():
            optimizer.zero_grad()
            return closure()

        optimizer.step(closure2)

    # Show final result
    out_HR_np = np.clip(var_to_np(net(net_input)), 0, 1)
    plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], out_HR_np],
                    factor=4,
                    nrow=1)
    print('\nFinal PSNR: ', psnr_HR)
    from models.model_sr import Model
    net = Model()

    net = net.type(dtype)

    net_input = get_noise(
        args.input_depth, args.noise_method,
        (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()

    mse = torch.nn.MSELoss().type(dtype)

    img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)
    downsampler = Downsampler(n_planes=3,
                              factor=args.factor,
                              kernel_type='lanczos2',
                              phase=0.5,
                              preserve_size=True).type(dtype)

    psnr_gt_best = 0

    i = 0
    PSNR_list = []

    _t = {'im_detect': Timer(), 'misc': Timer()}

    def closure():

        global i, net_input, psnr_gt_best, PSNR_list

        _t['im_detect'].tic()
Ejemplo n.º 7
0
    def __init__(self, config):
        """Initialise the model wrapper from a configuration mapping.

        Steps:
        - copy hyper-parameters out of ``config``;
        - build ``models.Generator`` (plus ``models.Discriminator`` when
          ``config['ftr_type']`` is ``'Discriminator'``) and its Adam
          optimizer;
        - restore or randomize the weights, then snapshot the generator
          state;
        - prepare the latent code and LR schedulers;
        - choose the feature network and criterion;
        - build a Lanczos-2 downsampler whose factor is 2 in ``'hybrid'``
          mode and 4 otherwise.

        Args:
            config: Dict-like settings. Every key read below must be present.
        """
        if config['dist']:
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            # Single-process defaults.
            self.rank, self.world_size = 0, 1
        self.config = config
        self.mode = config['dgp_mode']
        # These settings are stored under the same name as their config key.
        for key in ('update_G', 'update_embed', 'iterations', 'ftr_num',
                    'ft_num', 'lr_ratio', 'G_lrs', 'z_lrs', 'use_in',
                    'select_num'):
            setattr(self, key, config[key])
        # Downsample factor
        self.factor = 4
        if self.mode == 'hybrid':
            self.factor = 2

        # Networks: generator always, discriminator only when it provides
        # the features.
        self.G = models.Generator(**config).cuda()
        if config['ftr_type'] == 'Discriminator':
            self.D = models.Discriminator(**config).cuda()
        else:
            self.D = None
        param_groups = [{
            'params': self.G.get_params(block_idx, self.update_embed)
        } for block_idx in range(len(self.G.blocks) + 1)]
        self.G.optim = torch.optim.Adam(param_groups,
                                        lr=config['G_lr'],
                                        betas=(config['G_B1'],
                                               config['G_B2']),
                                        weight_decay=0,
                                        eps=1e-8)

        # Weights: random init or load from disk (into G_ema when use_ema).
        if not config['random_G']:
            use_ema = config['use_ema']
            utils.load_weights(None if use_ema else self.G,
                               self.D,
                               config['weights_root'],
                               name_suffix=config['load_weights'],
                               G_ema=self.G if use_ema else None,
                               strict=False)
        else:
            self.random_G()

        for net in (self.G, self.D):
            if net is not None:
                net.eval()
        self.G_weight = deepcopy(self.G.state_dict())

        # Latent code + optimizer, then warm-up LR schedulers.
        self._prepare_latent()
        warm_up = config['warm_up']
        self.G_scheduler = utils.LRScheduler(self.G.optim, warm_up)
        self.z_scheduler = utils.LRScheduler(self.z_optim, warm_up)

        # Losses / feature extractor.
        self.mse = torch.nn.MSELoss()
        if config['ftr_type'] != 'Discriminator':
            vgg = torchvision.models.vgg16(pretrained=True).cuda().eval()
            self.ftr_net = models.subsequence(vgg.features, last_layer='20')
            self.criterion = utils.PerceptLoss()
        else:
            self.ftr_net = self.D
            self.criterion = utils.DiscriminatorLoss(
                ftr_num=config['ftr_num'][0])

        # Produces low-resolution images from generator output.
        lowres_sampler = Downsampler(n_planes=3,
                                     factor=self.factor,
                                     kernel_type='lanczos2',
                                     phase=0.5,
                                     preserve_size=True)
        self.downsampler = lowres_sampler.type(torch.cuda.FloatTensor)
Ejemplo n.º 8
0
def get_h(n_ch, blur_type, use_fourier, dtype):
    """Return the blur operator for ``blur_type``.

    Args:
        n_ch: Number of image channels. Used only by the ``Downsampler`` path.
        blur_type: ``'uniform_blur'`` or ``'gauss_blur'``.
        use_fourier: If true, return a function that applies ``torch_blur``.
            Otherwise return a ``Downsampler`` module built with factor 1.
        dtype: Tensor type used by the returned operator.

    Raises:
        ValueError: If ``blur_type`` is not a supported value.
    """
    if blur_type not in ('uniform_blur', 'gauss_blur'):
        # An explicit raise survives `python -O`. The old assert message also
        # named the wrong values ('uniform' / 'gauss').
        raise ValueError(
            "blur_type must be 'uniform_blur' or 'gauss_blur', got {!r}".format(
                blur_type))
    if use_fourier:
        return lambda im: torch_blur(im, blur_type, dtype)
    return Downsampler(n_ch, 1, blur_type, preserve_size=True).type(dtype)