Example #1
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
        super(SRFlowNet, self).__init__()

        self.opt = opt
        quant = opt_get(opt, ['datasets', 'train', 'quant'])
        self.quant = 255 if quant is None else quant
        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
        hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
        hidden_channels = hidden_channels or 64
        self.RRDB_training = True  # Default is true

        # train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        # Note: disabling RRDB training here causes the RRDB optimizer not to be
        # created; it then has to be added with add_optimizer_and_scheduler_RRDB.
        # To control it from the options instead:
        # set_RRDB_to_train = opt_get(self.opt, ['network_G', 'train_RRDB']) or False
        # self.set_rrdb_training(set_RRDB_to_train)

        self.flowUpsamplerNet = \
            FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
                             flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
        self.i = 0
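
A minimal instantiation sketch (not from the source): the nested option keys are the ones the constructor above reads; the concrete values and the 23-block RRDB config are illustrative assumptions, not canonical defaults.

# Hypothetical options dict; keys mirror the opt_get lookups above.
opt = {
    'scale': 4,
    'datasets': {'train': {'quant': 32}},
    'network_G': {
        'flow': {'hidden_channels': 64, 'K': 16, 'L': 3,
                 'coupling': 'CondAffineSeparatedAndCond'},
    },
}
net = SRFlowNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4, K=16, opt=opt)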
Example #2
    def __init__(self, in_channels, opt):
        super().__init__()
        self.need_features = True
        self.in_channels = in_channels
        self.in_channels_rrdb = 320
        self.kernel_hidden = 1  # from GLOW/RealNVP papers
        self.n_hidden_layers = 1  # from GLOW/RealNVP papers
        hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
        self.hidden_channels = 64 if hidden_channels is None else hidden_channels

        self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'],  0.0001)

        self.channels_for_nn = self.in_channels // 2
        self.channels_for_co = self.in_channels - self.channels_for_nn

        self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                              out_channels=self.channels_for_co * 2,
                              hidden_channels=self.hidden_channels,
                              kernel_hidden=self.kernel_hidden,
                              n_hidden_layers=self.n_hidden_layers)

        self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                out_channels=self.in_channels * 2,
                                hidden_channels=self.hidden_channels,
                                kernel_hidden=self.kernel_hidden,
                                n_hidden_layers=self.n_hidden_layers)
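
A sketch of how fAffine and fFeatures are typically consumed, in the GLOW/RealNVP style the comments above reference; thops.split_feature halving the channel dimension and the shifted sigmoid are assumptions, not lines from this file:

    # Sketch: turn conditioning features into a strictly positive scale and a
    # shift, then apply them to z with the matching log-determinant update.
    def feature_extract(self, cond, f):
        h = f(cond)
        shift, scale = thops.split_feature(h, "cross")
        scale = torch.sigmoid(scale + 2.) + self.affine_eps  # keep scale > 0
        return scale, shift

    # forward direction (the reverse pass divides and subtracts instead):
    #   scale, shift = self.feature_extract(ft, self.fFeatures)
    #   z = (z + shift) * scale
    #   logdet = logdet + thops.sum(torch.log(scale), dim=[1, 2, 3])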
Example #3
    def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        logdet = torch.zeros_like(gt[:, 0, 0, 0])
        pixels = thops.pixels(gt)

        z = gt

        if add_gt_noise:
            # Setup
            noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
            if noiseQuant:
                z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
            logdet = logdet + float(-np.log(self.quant) * pixels)

        # Encode
        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
                                              y_onehot=y_onehot)

        objective = logdet.clone()

        if isinstance(epses, (list, tuple)):
            z = epses[-1]
        else:
            z = epses

        objective = objective + flow.GaussianDiag.logp(None, None, z)

        nll = (-objective) / float(np.log(2.) * pixels)

        if isinstance(epses, list):
            return epses, nll, logdet
        return z, nll, logdet
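
For reference, the objective above restated in isolation; a sketch assuming GaussianDiag.logp(None, None, z) evaluates the standard-normal log-density and thops.pixels(gt) returns H * W:

import numpy as np
import torch

def nll_bits_per_pixel(z, logdet):
    # logp(z) = -0.5 * sum(z**2 + log(2*pi)); nll in bits per pixel
    pixels = z.shape[2] * z.shape[3]
    logp = (-0.5 * (z ** 2 + float(np.log(2 * np.pi)))).sum(dim=[1, 2, 3])
    return -(logdet + logp) / float(np.log(2.) * pixels)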
Example #4
    def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
        z = epses.pop() if isinstance(epses, list) else z

        fl_fea = z
        # debug.imwrite("fl_fea", fl_fea)
        level_conditionals = {}
        if opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) is not True:
            for level in range(self.L + 1):
                level_conditionals[level] = rrdbResults[self.levelToName[level]]

        for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
            size = shape[2]
            level = int(np.log(160 / size) / np.log(2))
            # size = fl_fea.shape[2]
            # level = int(np.log(160 / size) / np.log(2))

            if isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                              rrdbResults[self.levelToName[level]], logdet=logdet,
                                                              y_onehot=y_onehot)
            elif isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
            else:
                fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)

        sr = fl_fea

        assert sr.shape[1] == 3
        return sr, logdet
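
A quick worked check of the level arithmetic above (the hard-coded 160 is the HR training patch size):

import numpy as np
for size in (80, 40, 20):
    print(size, int(np.log(160 / size) / np.log(2)))  # -> 1, 2, 3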
Example #5
    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt, step)
        train_opt = opt['train']

        self.heats = opt_get(opt, ['val', 'heats'], 0.0)
        self.n_sample = opt_get(opt, ['val', 'n_sample'], 1)
        hr_size = opt_get(opt, ['datasets', 'train', 'HR_size'], 160)
        self.lr_size = hr_size // opt['scale']
        self.nll = None

        if self.is_train:
            """
            Initialize losses
            """
            # nll loss
            self.fl_weight = opt_get(self.opt, ['train', 'fl_weight'], 1)
            """
            Prepare optimizer
            """
            self.optDstep = True  # no Discriminator being used
            self.optimizers, self.optimizer_G = optimizers.get_optimizers_filter(
                None,
                None,
                self.netG,
                train_opt,
                logger,
                self.optimizers,
                param_filter='RRDB')
            """
            Prepare schedulers
            """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)
            """
            Set RRDB training state
            """
            train_RRDB_delay = opt_get(self.opt,
                                       ['network_G', 'train_RRDB_delay'])
            if train_RRDB_delay is not None and step < int(train_RRDB_delay * self.opt['train']['niter']) \
                and self.netG.module.RRDB_training:
                if self.netG.module.set_rrdb_training(False):
                    logger.info(
                        'RRDB module frozen, will unfreeze at iter: {}'.format(
                            int(train_RRDB_delay *
                                self.opt['train']['niter'])))
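
To make the freeze window concrete, a worked example with illustrative values:

train_RRDB_delay = 0.5     # fraction of training, as read from the opt above
niter = 200000             # illustrative total iterations
unfreeze_iter = int(train_RRDB_delay * niter)  # RRDB stays frozen until 100000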
Example #6
    def arch_split(self, H, W, L, levels, opt, opt_get):
        correct_splits = opt_get(
            opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
        correction = 0 if correct_splits else 1
        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and \
                L < levels - correction:
            logs_eps = opt_get(opt,
                               ['network_G', 'flow', 'split', 'logs_eps']) or 0
            consume_ratio = opt_get(
                opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
            position_name = get_position_name(H, self.opt['scale'])
            position = position_name if opt_get(
                opt, ['network_G', 'flow', 'split', 'conditional']) else None
            cond_channels = opt_get(
                opt, ['network_G', 'flow', 'split', 'cond_channels'])
            cond_channels = 0 if cond_channels is None else cond_channels

            t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')

            if t == 'Split2d':
                split = Split2d(num_channels=self.C,
                                logs_eps=logs_eps,
                                position=position,
                                cond_channels=cond_channels,
                                consume_ratio=consume_ratio,
                                opt=opt)
            else:
                raise NotImplementedError("Unknown split type: {}".format(t))
            self.layers.append(split)
            self.output_shapes.append([-1, split.num_channels_pass, H, W])
            self.C = split.num_channels_pass
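
A worked check of the correction flag above (the levels value is illustrative): with correct_splits left False, the last level gets no split.

levels = 4
for correct_splits in (False, True):
    correction = 0 if correct_splits else 1
    print([L for L in range(1, levels + 1) if L < levels - correction])
    # -> [1, 2] then [1, 2, 3]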
Example #7
    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
        if seed is not None:  # `if seed:` would skip a legitimate seed of 0
            torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)
            size = (batch_size, C, H, W)
            z = torch.normal(mean=0, std=heat,
                             size=size) if heat > 0 else torch.zeros(size)
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2**(L - 3)
            z_size = int(self.lr_size // fac)
            z = torch.normal(mean=0,
                             std=heat,
                             size=(batch_size, 3 * 8 * 8 * fac * fac, z_size,
                                   z_size))
        return z
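
A usage sketch (assumed driver code; model is an SRFlowModel instance and lr a LR batch tensor): draw latents at a few temperatures and decode them through the reverse pass.

for heat in (0.0, 0.5, 0.9):
    z = model.get_z(heat, seed=42, batch_size=lr.shape[0], lr_shape=lr.shape)
    with torch.no_grad():
        sr, _ = model.netG(lr=lr, z=z, eps_std=heat, reverse=True)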
Example #8
    def rrdbPreprocessing(self, lr):
        rrdbResults = self.RRDB(lr, get_steps=True)
        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        if len(block_idxs) > 0:
            concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)

            if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']):
                keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
                if 'fea_up0' in rrdbResults.keys():
                    keys.append('fea_up0')
                if 'fea_up-1' in rrdbResults.keys():
                    keys.append('fea_up-1')
                if self.opt['scale'] >= 8:
                    keys.append('fea_up8')
                if self.opt['scale'] == 16:
                    keys.append('fea_up16')
                for k in keys:
                    h = rrdbResults[k].shape[2]
                    w = rrdbResults[k].shape[3]
                    rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
        return rrdbResults
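
Channel bookkeeping sketch (nf = 64 and the block list are illustrative assumptions): each listed feature grows by the width of `concat`.

nf, blocks = 64, [1, 8, 15, 22]   # assumed stackRRDB blocks setting
concat_ch = len(blocks) * nf      # 256 channels in `concat`
cond_ch = nf + concat_ch          # 320, cf. in_channels_rrdb in Example #2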
Example #9
    def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
        fl_fea = gt
        reverse = False
        level_conditionals = {}
        bypasses = {}

        L = opt_get(self.opt, ['network_G', 'flow', 'L'])

        for level in range(1, L + 1):
            bypasses[level] = torch.nn.functional.interpolate(
                gt,
                scale_factor=2**-level,
                mode='bilinear',
                align_corners=False,
                recompute_scale_factor=False)

        for layer, shape in zip(self.layers, self.output_shapes):
            size = shape[2]
            level = int(np.log(160 / size) / np.log(2))

            level_conditionals[level] = rrdbResults[self.levelToName[level]]

            if isinstance(layer, FlowStep):
                fl_fea, logdet = layer(fl_fea,
                                       logdet,
                                       reverse=reverse,
                                       rrdbResults=level_conditionals[level])
            elif isinstance(layer, Split2d):
                fl_fea, logdet = self.forward_split2d(
                    epses,
                    fl_fea,
                    layer,
                    logdet,
                    reverse,
                    level_conditionals[level],
                    y_onehot=y_onehot)
            else:
                fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)

        z = fl_fea

        if not isinstance(epses, list):
            return z, logdet

        epses.append(z)
        return epses, logdet
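
A quick check of the bypass pyramid above: a 160 px GT yields bilinear downsamples at 80, 40 and 20 px for levels 1..3.

import torch
import torch.nn.functional as F
gt = torch.randn(1, 3, 160, 160)
for level in range(1, 4):
    by = F.interpolate(gt, scale_factor=2**-level, mode='bilinear',
                       align_corners=False, recompute_scale_factor=False)
    print(by.shape[-1])  # 80, 40, 20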
Example #10
def model_val(opt_net=None, state_dict=None, model_type=None):
    if model_type == 'G':
        model = opt_get(opt_net, ['network_G', 'type']).lower()
        if model in ('rrdb_net', 'esrgan'):  # convert to 'normal' naming
            return mod2normal(state_dict)
        elif model == 'mrrdb_net' or model == 'srflow_net':  # convert to 'mod' naming
            return normal2mod(state_dict)
        return state_dict
    elif model_type == 'D':
        # no particular Discriminator validation at the moment
        # model = opt_get(opt_net, ['network_G', 'type']).lower()
        return state_dict
    # if model_type not provided, return unchanged
    # (can do other validations here)
    return state_dict
Example #11
def model_val(opt_net=None, state_dict=None, model_type=None):
    if model_type == 'G':
        model = opt_get(opt_net, ['network_G', 'which_model_G'])
        if model == 'RRDB_net':  # convert to 'normal' naming
            return mod2normal(state_dict)
        elif model == 'MRRDB_net' or model == 'SRFlow_net':  # convert to 'mod' naming
            return normal2mod(state_dict)
        else:
            return state_dict
    elif model_type == 'D':
        # no particular Discriminator validation at the moment
        # model = opt_get(opt_net, ['network_G', 'which_model_D'])
        return state_dict
    else:
        # if model_type not provided, return unchanged 
        # (can do other validations here)
        return state_dict
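
A usage sketch applying either variant above before loading weights (the file name and strict flag are assumptions):

state_dict = torch.load('generator.pth', map_location='cpu')
state_dict = model_val(opt_net=opt, state_dict=state_dict, model_type='G')
netG.load_state_dict(state_dict, strict=True)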
Example #12
    def optimize_parameters(self, step):
        # unfreeze RRDB module if train_RRDB_delay is set
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and \
                int(step/self.accumulations) > int(train_RRDB_delay * self.opt['train']['niter']) \
                and not self.netG.module.RRDB_training:
            if self.netG.module.set_rrdb_training(True):
                logger.info('Unfreezing RRDB module.')
                if len(self.optimizers) == 1:
                    # add the RRDB optimizer only if missing
                    self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        # self.print_rrdb_state()

        self.netG.train()
        """
        Calculate and log losses
        """
        l_g_total = 0
        if self.fl_weight > 0:
            # compute the negative log-likelihood of the output z under a
            # standard (unit-variance) Gaussian prior
            # with self.cast():  # needs testing, reduced precision could affect results
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False)
            nll_loss = torch.mean(nll)
            l_g_nll = self.fl_weight * nll_loss
            # # /with self.cast():
            self.log_dict['nll_loss'] = l_g_nll.item()
            l_g_total += l_g_nll / self.accumulations

        if self.generatorlosses.loss_list or self.precisegeneratorlosses.loss_list:
            # batch (mixup) augmentations
            aug = None
            if self.mixup:
                self.real_H, self.var_L, mask, aug = BatchAug(
                    self.real_H, self.var_L, self.mixopts, self.mixprob,
                    self.mixalpha, self.aux_mixprob, self.aux_mixalpha,
                    self.mix_p)

            with self.cast():
                z = self.get_z(heat=0,
                               seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                self.fake_H, logdet = self.netG(lr=self.var_L,
                                                z=z,
                                                eps_std=0,
                                                reverse=True,
                                                reverse_with_grad=True)

            # batch (mixup) augmentations: pixels removed by cutout are masked
            # out of the loss
            if aug == "cutout":
                self.fake_H, self.real_H = self.fake_H * mask, self.real_H * mask

            # TODO: CEM is WIP
            # unpad images if using CEM
            # if self.CEM:
            #     self.fake_H = self.CEM_net.HR_unpadder(self.fake_H)
            #     self.real_H = self.CEM_net.HR_unpadder(self.real_H)
            #     self.var_ref = self.CEM_net.HR_unpadder(self.var_ref)

            if self.generatorlosses.loss_list:
                with self.cast():
                    # regular losses
                    loss_results, self.log_dict = self.generatorlosses(
                        self.fake_H, self.real_H, self.log_dict, self.f_low)
                    l_g_total += sum(loss_results) / self.accumulations

            if self.precisegeneratorlosses.loss_list:
                # high precision generator losses (can be affected by AMP half precision)
                precise_loss_results, self.log_dict = self.precisegeneratorlosses(
                    self.fake_H, self.real_H, self.log_dict, self.f_low)
                l_g_total += sum(precise_loss_results) / self.accumulations

        if self.amp:
            self.amp_scaler.scale(l_g_total).backward()
        else:
            l_g_total.backward()

        # only step and clear gradient if virtual batch has completed
        if (step + 1) % self.accumulations == 0:
            if self.amp:
                self.amp_scaler.step(self.optimizer_G)
                self.amp_scaler.update()
            else:
                self.optimizer_G.step()
            self.optimizer_G.zero_grad()
            self.optGstep = True
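
The virtual-batch pattern above in isolation (a sketch; loader, compute_loss, optimizer and accumulations are assumed names): every loss term is pre-divided by accumulations, gradients accumulate across calls, and the optimizer steps once per virtual batch.

for step, (lr, hr) in enumerate(loader):
    loss = compute_loss(lr, hr) / accumulations  # scale each micro-batch
    loss.backward()                              # gradients accumulate
    if (step + 1) % accumulations == 0:
        optimizer.step()
        optimizer.zero_grad()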
Example #13
    def get_affineInCh(self, opt_get):
        affineInCh = opt_get(
            self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        affineInCh = (len(affineInCh) + 1) * 64
        return affineInCh
Example #14
    def get_condAffSetting(self, opt, opt_get):
        condAff = opt_get(opt, ['network_G', 'flow', 'condAff'])
        condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
        return condAff
Example #15
    def get_n_rrdb_channels(self, opt, opt_get):
        blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
        n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
        return n_rrdb
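
A worked example for both channel helpers (Examples #13 and #15), using an assumed block list:

blocks = [1, 8, 15, 22]
assert (len(blocks) + 1) * 64 == 320  # matches in_channels_rrdb in Example #2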
Example #16
    def forward(self, x, get_steps=False):
        fea = self.conv_first(x)

        block_idxs = opt_get(
            self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        block_results = {}

        for idx, m in enumerate(self.RRDB_trunk.children()):
            fea = m(fea)
            if idx in block_idxs:
                block_results["block_{}".format(idx)] = fea

        trunk = self.trunk_conv(fea)

        last_lr_fea = fea + trunk

        fea_up2 = self.upconv1(
            F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up2)

        fea_up4 = self.upconv2(
            F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up4)

        fea_up8 = None
        fea_up16 = None
        fea_up32 = None

        if self.scale >= 8:
            fea_up8 = self.upconv3(
                F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up8)
        if self.scale >= 16:
            fea_up16 = self.upconv4(
                F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up16)
        if self.scale >= 32:
            fea_up32 = self.upconv5(
                F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up32)

        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        results = {
            'last_lr_fea': last_lr_fea,
            'fea_up1': last_lr_fea,
            'fea_up2': fea_up2,
            'fea_up4': fea_up4,
            'fea_up8': fea_up8,
            'fea_up16': fea_up16,
            'fea_up32': fea_up32,
            'out': out
        }

        fea_up0_en = opt_get(self.opt,
                             ['network_G', 'flow', 'fea_up0']) or False
        if fea_up0_en:
            results['fea_up0'] = F.interpolate(last_lr_fea,
                                               scale_factor=1 / 2,
                                               mode='bilinear',
                                               align_corners=False,
                                               recompute_scale_factor=True)
        fea_upn1_en = opt_get(self.opt,
                              ['network_G', 'flow', 'fea_up-1']) or False
        if fea_upn1_en:
            results['fea_up-1'] = F.interpolate(last_lr_fea,
                                                scale_factor=1 / 4,
                                                mode='bilinear',
                                                align_corners=False,
                                                recompute_scale_factor=True)

        if get_steps:
            for k, v in block_results.items():
                results[k] = v
            return results
        else:
            return out
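
A usage sketch (rrdb_net and lr assumed in scope): request the intermediate features used for flow conditioning instead of the plain SR output.

feats = rrdb_net(lr, get_steps=True)
# feats holds fea_up1/fea_up2/fea_up4 (plus optional fea_up0, fea_up-1 and
# block_<idx> entries): the keys levelToName maps to in Example #17.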
Example #17
    def __init__(self,
                 image_shape,
                 hidden_channels,
                 K,
                 L=None,
                 actnorm_scale=1.0,
                 flow_permutation=None,
                 flow_coupling="affine",
                 LU_decomposed=False,
                 opt=None):
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.L = opt_get(opt, ['network_G', 'flow', 'L'])
        self.K = opt_get(opt, ['network_G', 'flow', 'K'])
        if isinstance(self.K, int):
            self.K = [K] * (self.L + 1)

        self.opt = opt
        H, W, self.C = image_shape
        self.check_image_shape()

        if opt['scale'] == 16:
            self.levelToName = {
                0: 'fea_up16',
                1: 'fea_up8',
                2: 'fea_up4',
                3: 'fea_up2',
                4: 'fea_up1',
            }
        elif opt['scale'] == 8:
            self.levelToName = {
                0: 'fea_up8',
                1: 'fea_up4',
                2: 'fea_up2',
                3: 'fea_up1',
                4: 'fea_up0',
            }
        elif opt['scale'] == 4:
            self.levelToName = {
                0: 'fea_up4',
                1: 'fea_up2',
                2: 'fea_up1',
                3: 'fea_up0',
                4: 'fea_up-1',
            }

        affineInCh = self.get_affineInCh(opt_get)
        flow_permutation = self.get_flow_permutation(flow_permutation, opt)

        normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])

        conditional_channels = {}
        n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
        n_bypass_channels = opt_get(
            opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
        conditional_channels[0] = n_rrdb
        for level in range(1, self.L + 1):
            # Level 1 gets conditionals from 2, 3, 4 => L - level
            # Level 2 gets conditionals from 3, 4
            # Level 3 gets conditionals from 4
            # Level 4 gets conditionals from None
            n_bypass = 0 if n_bypass_channels is None else (
                self.L - level) * n_bypass_channels
            conditional_channels[level] = n_rrdb + n_bypass

        # Upsampler
        for level in range(1, self.L + 1):
            # 1. Squeeze
            H, W = self.arch_squeeze(H, W)

            # 2. K FlowStep
            self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale,
                                           hidden_channels, opt)
            self.arch_FlowStep(
                H,
                self.K[level],
                LU_decomposed,
                W,
                actnorm_scale,
                affineInCh,
                flow_coupling,
                flow_permutation,
                hidden_channels,
                normOpt,
                opt,
                opt_get,
                n_conditinal_channels=conditional_channels[level])
            # 3. Split
            self.arch_split(H, W, level, self.L, opt, opt_get)

        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
        else:
            self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)

        self.H = H
        self.W = W
        self.scaleH = 160 / H
        self.scaleW = 160 / W
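
A worked example of the final bookkeeping (L = 3, HR patch 160): each of the three squeezes halves H and W, so the latent grid is 20 x 20 and scaleH = scaleW = 8, which is what get_z (Example #7) divides by.

L, H = 3, 160
for _ in range(L):
    H //= 2        # arch_squeeze halves the spatial dims
print(H, 160 / H)  # -> 20 8.0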