Example #1
class SRFlowModel(SRRaGANModel):
    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt, step)
        train_opt = opt['train']

        self.heats = opt_get(opt, ['val', 'heats'], [0.0])  # sampling temperatures (iterable)
        self.n_sample = opt_get(opt, ['val', 'n_sample'], 1)
        hr_size = opt_get(opt, ['datasets', 'train', 'HR_size'], 160)
        self.lr_size = hr_size // opt['scale']
        self.nll = None

        if self.is_train:
            """
            Initialize losses
            """
            # nll loss
            self.fl_weight = opt_get(self.opt, ['train', 'fl_weight'], 1)
            """
            Prepare optimizer
            """
            self.optDstep = True  # no Discriminator being used
            self.optimizers, self.optimizer_G = optimizers.get_optimizers_filter(
                None,
                None,
                self.netG,
                train_opt,
                logger,
                self.optimizers,
                param_filter='RRDB')
            """
            Prepare schedulers
            """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)
            """
            Set RRDB training state
            """
            train_RRDB_delay = opt_get(self.opt,
                                       ['network_G', 'train_RRDB_delay'])
            if train_RRDB_delay is not None and step < int(train_RRDB_delay * self.opt['train']['niter']) \
                and self.netG.module.RRDB_training:
                if self.netG.module.set_rrdb_training(False):
                    logger.info(
                        'RRDB module frozen, will unfreeze at iter: {}'.format(
                            int(train_RRDB_delay *
                                self.opt['train']['niter'])))

    # TODO: CEM is WIP
    # def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False,
    #             epses=None, reverse_with_grad=False, lr_enc=None, add_gt_noise=False,
    #             step=None, y_label=None, CEM_net=None)
    #     """
    #     Run forward pass G(LR); called by <optimize_parameters> and <test> functions.
    #     Can be used either with 'data' passed directly or loaded 'self.var_L'.
    #     CEM_net can be used during inference to pass different CEM wrappers.
    #     """
    #     if not isinstance(lr, torch.Tensor):
    #         gt, lr = self.real_H, self.var_L

    #     if CEM_net is not None:
    #         wrapped_netG = CEM_net.WrapArchitecture(self.netG)
    #         net_out = wrapped_netG(gt=gt, lr=lr, z=z, eps_std=eps_std, reverse=reverse,
    #                             epses=epses, reverse_with_grad=reverse_with_grad,
    #                             lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
    #                             y_label=y_label)
    #     else:
    #         net_out = self.netG(gt=gt, lr=lr, z=z, eps_std=eps_std, reverse=reverse,
    #                         epses=epses, reverse_with_grad=reverse_with_grad,
    #                         lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
    #                         y_label=y_label)

    #     if reverse:
    #         sr, logdet = net_out
    #         return sr, logdet
    #     else:
    #         z, nll, y_logits = net_out
    #         return z, nll, y_logits

    def add_optimizer_and_scheduler_RRDB(self, train_opt):
        # Note: this function from the original SRFlow code seems partially broken.
        # Since the RRDB optimizer is already created on init, this is currently unused.
        # optimizers
        assert len(self.optimizers) == 1, self.optimizers
        assert len(self.optimizer_G.param_groups[1]['params']) == 0, \
            self.optimizer_G.param_groups[1]
        # can optimize a subset of the model's parameters
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                if '.RRDB.' in k:
                    self.optimizer_G.param_groups[1]['params'].append(v)
        assert len(self.optimizer_G.param_groups[1]['params']) > 0

    def optimize_parameters(self, step):
        # unfreeze RRDB module if train_RRDB_delay is set
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and \
                int(step/self.accumulations) > int(train_RRDB_delay * self.opt['train']['niter']) \
                and not self.netG.module.RRDB_training:
            if self.netG.module.set_rrdb_training(True):
                logger.info('Unfreezing RRDB module.')
                if len(self.optimizers) == 1:
                    # add the RRDB optimizer only if missing
                    self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        # self.print_rrdb_state()

        self.netG.train()
        """
        Calculate and log losses
        """
        l_g_total = 0
        if self.fl_weight > 0:
            # compute the negative log-likelihood of the latent z under a standard (unit-variance) Gaussian prior
            # with self.cast():  # needs testing, reduced precision could affect results
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False)
            nll_loss = torch.mean(nll)
            l_g_nll = self.fl_weight * nll_loss
            # # /with self.cast():
            self.log_dict['nll_loss'] = l_g_nll.item()
            l_g_total += l_g_nll / self.accumulations

        if self.generatorlosses.loss_list or self.precisegeneratorlosses.loss_list:
            # batch (mixup) augmentations
            aug = None
            if self.mixup:
                self.real_H, self.var_L, mask, aug = BatchAug(
                    self.real_H, self.var_L, self.mixopts, self.mixprob,
                    self.mixalpha, self.aux_mixprob, self.aux_mixalpha,
                    self.mix_p)

            with self.cast():
                z = self.get_z(heat=0,
                               seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                self.fake_H, logdet = self.netG(lr=self.var_L,
                                                z=z,
                                                eps_std=0,
                                                reverse=True,
                                                reverse_with_grad=True)

            # batch (mixup) augmentations
            # cutout-ed pixels are discarded when calculating loss by masking removed pixels
            if aug == "cutout":
                self.fake_H, self.real_H = self.fake_H * mask, self.real_H * mask

            # TODO: CEM is WIP
            # unpad images if using CEM
            # if self.CEM:
            #     self.fake_H = self.CEM_net.HR_unpadder(self.fake_H)
            #     self.real_H = self.CEM_net.HR_unpadder(self.real_H)
            #     self.var_ref = self.CEM_net.HR_unpadder(self.var_ref)

            if self.generatorlosses.loss_list:
                with self.cast():
                    # regular losses
                    loss_results, self.log_dict = self.generatorlosses(
                        self.fake_H, self.real_H, self.log_dict, self.f_low)
                    l_g_total += sum(loss_results) / self.accumulations

            if self.precisegeneratorlosses.loss_list:
                # high precision generator losses (can be affected by AMP half precision)
                precise_loss_results, self.log_dict = self.precisegeneratorlosses(
                    self.fake_H, self.real_H, self.log_dict, self.f_low)
                l_g_total += sum(precise_loss_results) / self.accumulations

        if self.amp:
            self.amp_scaler.scale(l_g_total).backward()
        else:
            l_g_total.backward()

        # only step and clear gradient if virtual batch has completed
        if (step + 1) % self.accumulations == 0:
            if self.amp:
                self.amp_scaler.step(self.optimizer_G)
                self.amp_scaler.update()
            else:
                self.optimizer_G.step()
            self.optimizer_G.zero_grad()
            self.optGstep = True

    def print_rrdb_state(self):
        for name, param in self.netG.module.named_parameters():
            if "RRDB.conv_first.weight" in name:
                print(name, param.requires_grad, param.data.abs().sum())
        print('params',
              [len(p['params']) for p in self.optimizer_G.param_groups])

    def test(self, CEM_net=None):
        self.netG.eval()
        self.fake_H = {}
        for heat in self.heats:
            for i in range(self.n_sample):
                z = self.get_z(heat,
                               seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                with torch.no_grad():
                    self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L,
                                                               z=z,
                                                               eps_std=heat,
                                                               reverse=True)
        with torch.no_grad():
            _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        self.netG.train()
        self.nll = nll.mean().item()

    # TODO
    def get_encode_nll(self, lq, gt):
        self.netG.eval()
        with torch.no_grad():
            _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
        self.netG.train()
        return nll.mean().item()

    # TODO: only used for testing code
    def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
        return self.get_sr_with_z(lq, heat, seed, z, epses)[0]

    # TODO
    def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, _, _ = self.netG(gt=gt,
                                lr=lq,
                                reverse=False,
                                epses=epses,
                                add_gt_noise=add_gt_noise)
        self.netG.train()
        return z

    # TODO
    def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, nll, _ = self.netG(gt=gt,
                                  lr=lq,
                                  reverse=False,
                                  epses=epses,
                                  add_gt_noise=add_gt_noise)
        self.netG.train()
        return z, nll

    # TODO: used by get_sr
    def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
        self.netG.eval()
        z = self.get_z(heat, seed, batch_size=lq.shape[0],
                       lr_shape=lq.shape) if z is None and epses is None else z

        with torch.no_grad():
            sr, logdet = self.netG(lr=lq,
                                   z=z,
                                   eps_std=heat,
                                   reverse=True,
                                   epses=epses)
        self.netG.train()
        return sr, z

    # TODO: used in optimize_parameters and test
    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
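        # Shape note (illustrative numbers, not taken from any config): with
        # scale=4 and HR_size=160 -> self.lr_size=40, and flow levels L=3, the
        # non-split branch below gives fac = 2**(L-3) = 1 and z_size = 40, so z
        # has shape (batch_size, 3*8*8*fac*fac, z_size, z_size) = (batch_size, 192, 40, 40).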
        if seed: torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)
            size = (batch_size, C, H, W)
            # sample from N(0, heat^2) when a temperature is given, else use zeros
            z = torch.normal(mean=0, std=heat,
                             size=size) if heat else torch.zeros(size)
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2**(L - 3)
            z_size = int(self.lr_size // fac)
            size = (batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)
            # sample from N(0, heat^2) when a temperature is given, else use zeros
            z = torch.normal(mean=0, std=heat,
                             size=size) if heat else torch.zeros(size)
        return z

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = True
        for heat in self.heats:
            for i in range(self.n_sample):
                out_dict[('SR', heat, i)] = \
                    self.fake_H[(heat, i)].detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict
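
As a side note to the `fl_weight` NLL term above, the short self-contained sketch below evaluates only the standard-Gaussian prior part of a latent's negative log-likelihood; the model's actual `nll` additionally includes log-determinant contributions from the flow layers, and the latent shape used here is just an assumed example.

import math
import torch

def gaussian_prior_nll(z):
    # -log N(z; 0, I), averaged over every element of z (prior term only)
    return 0.5 * (z ** 2 + math.log(2 * math.pi)).mean()

z = torch.randn(4, 192, 40, 40)   # e.g. a latent sampled at heat=1.0
print(gaussian_prior_nll(z))      # about 0.5 * (1 + log(2*pi)) ~= 1.419
z0 = torch.zeros(4, 192, 40, 40)  # the deterministic heat=0 latent
print(gaussian_prior_nll(z0))     # 0.5 * log(2*pi) ~= 0.919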
Example #2
class SRRaGANModel(BaseModel):
    def __init__(self, opt):
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt['train']

        # set if data should be normalized (-1,1) or not (0,1)
        if self.is_train:
            z_norm = opt['datasets']['train'].get('znorm', False)

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            """
            Setup network cap
            """
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)
            """
            Setup batch augmentations
            """
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  #, "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  #, 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  #, 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)
            """
            Setup frequency separation
            """
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)
            """
            Initialize losses
            """
            #Initialize the losses with the opt parameters
            # Generator losses:
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False
            """
            Prepare optimizers
            """
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True
            """
            Prepare schedulers
            """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()
            """
            Configure SWA
            """
            #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                #TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  #load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))
            """
            If using virtual batch
            """
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            self.virtual_batch = virtual_batch if virtual_batch and \
                virtual_batch >= batch_size else batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()
            """
            Configure AMP
            """
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast
            """
            Configure FreezeD
            """
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                if loc:
                    disc = opt["network_D"].get('which_model_D', False)
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    #TODO: TMP, for now only tested with the vgg-like or patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

        # print network
        """ 
        TODO:
        Network summary? Make optional with parameter
            could be an selector between traditional print_network() and summary()
        """
        #self.print_network() #TODO

    def feed_data(self, data, need_HR=True):
        # LR images
        self.var_L = data['LR'].to(self.device)
        if need_HR:  # train or val
            # HR images
            self.var_H = data['HR'].to(self.device)
            # discriminator references
            input_ref = data.get('ref', data['HR'])
            self.var_ref = input_ref.to(self.device)

    def feed_data_batch(self, data, need_HR=True):
        # LR
        self.var_L = data

    def optimize_parameters(self, step):
        # G
        # freeze discriminator while generator is trained to prevent BP
        if self.cri_gan:
            self.requires_grad(self.netD, flag=False, net_type='D')

        # batch (mixup) augmentations
        aug = None
        if self.mixup:
            self.var_H, self.var_L, mask, aug = BatchAug(
                self.var_H, self.var_L, self.mixopts, self.mixprob,
                self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p)

        ### Network forward, generate SR
        with self.cast():
            if self.outm:  #if the model has the final activation option
                self.fake_H = self.netG(self.var_L, outm=self.outm)
            else:  #regular models without the final activation option
                self.fake_H = self.netG(self.var_L)
        #/with self.cast():

        # batch (mixup) augmentations
        # cutout-ed pixels are discarded when calculating loss by masking removed pixels
        if aug == "cutout":
            self.fake_H, self.var_H = self.fake_H * mask, self.var_H * mask

        l_g_total = 0
        """
        Calculate and log losses
        """
        loss_results = []
        # training generator and discriminator
        # update generator (on its own if only training generator or alternatively if training GAN)
        if (self.cri_gan is not True) or (step % self.D_update_ratio == 0
                                          and step > self.D_init_iters):
            with self.cast(
            ):  # Casts operations to mixed precision if enabled, else nullcontext
                # regular losses
                loss_results, self.log_dict = self.generatorlosses(
                    self.fake_H, self.var_H, self.log_dict, self.f_low)
                l_g_total += sum(loss_results) / self.accumulations

                if self.cri_gan:
                    # adversarial loss
                    l_g_gan = self.adversarial(
                        self.fake_H,
                        self.var_ref,
                        netD=self.netD,
                        stage='generator',
                        fsfilter=self.f_high)  # (sr, hr)
                    self.log_dict['l_g_gan'] = l_g_gan.item()
                    l_g_total += l_g_gan / self.accumulations

            #/with self.cast():

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_g_total).backward()
            else:
                l_g_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale gradients of the optimizer's params, call
                    # optimizer.step() if no infs/NaNs in gradients, else, skipped
                    self.amp_scaler.step(self.optimizer_G)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                    #TODO: remove. for debugging AMP
                    #print("AMP Scaler state dict: ", self.amp_scaler.state_dict())
                else:
                    self.optimizer_G.step()
                self.optimizer_G.zero_grad()
                self.optGstep = True

        if self.cri_gan:
            # update discriminator
            if isinstance(self.feature_loc, int):
                # unfreeze all D
                self.requires_grad(self.netD, flag=True)
                # then freeze up to the selected layers
                for loc in range(self.feature_loc):
                    self.requires_grad(self.netD,
                                       False,
                                       target_layer=loc,
                                       net_type='D')
            else:
                # unfreeze discriminator
                self.requires_grad(self.netD, flag=True)

            l_d_total = 0

            with self.cast(
            ):  # Casts operations to mixed precision if enabled, else nullcontext
                l_d_total, gan_logs = self.adversarial(
                    self.fake_H,
                    self.var_ref,
                    netD=self.netD,
                    stage='discriminator',
                    fsfilter=self.f_high)  # (sr, hr)

                for g_log in gan_logs:
                    self.log_dict[g_log] = gan_logs[g_log]

                l_d_total /= self.accumulations
            #/with autocast():

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_d_total).backward()
            else:
                l_d_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale gradients of the optimizer's params, call
                    # optimizer.step() if no infs/NaNs in gradients, else, skipped
                    self.amp_scaler.step(self.optimizer_D)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                else:
                    self.optimizer_D.step()
                self.optimizer_D.zero_grad()
                self.optDstep = True

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            if self.is_train:
                self.fake_H = self.netG(self.var_L)
            else:
                #self.fake_H = self.netG(self.var_L, isTest=True)
                self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
        #TODO for PPON ?
        #if get stages 1 and 2
        #out_dict['SR_content'] = ...
        #out_dict['SR_structure'] = ...
        return out_dict

    def get_current_visuals_batch(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach().float().cpu()
        out_dict['SR'] = self.fake_H.detach().float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach().float().cpu()
        #TODO for PPON ?
        #if get stages 1 and 2
        #out_dict['SR_content'] = ...
        #out_dict['SR_structure'] = ...
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            if self.cri_gan:
                s, n = self.get_network_description(self.netD)
                if isinstance(self.netD, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netD.__class__.__name__,
                        self.netD.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netD.__class__.__name__)

                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            #TODO: feature network is not being trained, is it necessary to visualize? Maybe just name?
            # maybe show the generatorlosses instead?
            '''
            if self.generatorlosses.cri_fea:  # F, Perceptual Network
                #s, n = self.get_network_description(self.netF)
                s, n = self.get_network_description(self.generatorlosses.netF) #TODO
                #s, n = self.get_network_description(self.generatorlosses.loss_list.netF) #TODO
                if isinstance(self.generatorlosses.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(self.generatorlosses.netF.__class__.__name__,
                                                    self.generatorlosses.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.generatorlosses.netF.__class__.__name__)

                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
                logger.info(s)
            '''

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading pretrained model for G [{:s}] ...'.format(
                load_path_G))
            strict = self.opt['network_G'].get('strict', None)
            self.load_network(load_path_G, self.netG, strict, model_type='G')
        if self.opt['is_train'] and self.opt['train']['gan_weight']:
            load_path_D = self.opt['path']['pretrain_model_D']
            if self.opt['is_train'] and load_path_D is not None:
                logger.info('Loading pretrained model for D [{:s}] ...'.format(
                    load_path_D))
                strict = self.opt['network_D'].get('strict', None)
                self.load_network(load_path_D, self.netD, strict, model_type='D')

    def load_swa(self):
        if self.opt['is_train'] and self.opt['use_swa']:
            load_path_swaG = self.opt['path']['pretrain_model_swaG']
            if self.opt['is_train'] and load_path_swaG is not None:
                logger.info(
                    'Loading pretrained model for SWA G [{:s}] ...'.format(
                        load_path_swaG))
                self.load_network(load_path_swaG, self.swa_model)

    def save(self, iter_step, latest=None, loader=None):
        self.save_network(self.netG, 'G', iter_step, latest)
        if self.cri_gan:
            self.save_network(self.netD, 'D', iter_step, latest)
        if self.swa:
            # when training with networks that use BN
            # # Update bn statistics for the swa_model only at the end of training
            # if not isinstance(iter_step, int): #TODO: not sure if it should be done only at the end
            self.swa_model = self.swa_model.cpu()
            torch.optim.swa_utils.update_bn(loader, self.swa_model)
            self.swa_model = self.swa_model.cuda()
            # Check swa BN statistics
            # for module in self.swa_model.modules():
            #     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            #         print(module.running_mean)
            #         print(module.running_var)
            #         print(module.momentum)
            #         break
            self.save_network(self.swa_model, 'swaG', iter_step, latest)
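
Both `optimize_parameters` paths above emulate a virtual batch: each loss is pre-divided by `self.accumulations` and the optimizers only step once every `accumulations` calls, with a GradScaler handling the AMP case. Below is a minimal, self-contained sketch of that pattern; the toy linear model, the data and the `accumulations` value are assumptions for illustration only.

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"            # autocast/GradScaler become no-ops otherwise
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulations = 4                     # virtual_batch_size // batch_size
scaler = GradScaler(enabled=use_amp)

optimizer.zero_grad()
for step in range(8):
    x = torch.randn(2, 16, device=device)
    y = torch.randn(2, 1, device=device)
    with autocast(enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y) / accumulations
    scaler.scale(loss).backward()     # gradients accumulate across the virtual batch
    if (step + 1) % accumulations == 0:
        scaler.step(optimizer)        # skipped internally if inf/NaN grads are found
        scaler.update()
        optimizer.zero_grad()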
Example #3
class SRRaGANModel(BaseModel):
    def __init__(self, opt, step=0):
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt['train']

        # set if data should be normalized (-1,1) or not (0,1)
        if self.is_train:
            z_norm = opt['datasets']['train'].get('znorm', False)

        # specify the models you want to load/save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        # for training and testing, a generator 'G' is needed
        self.model_names = ['G']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt, step=step).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.model_names.append(
                    'D')  # add discriminator to the network list
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed

        self.outm = None

        # define losses, optimizer and scheduler
        if self.is_train:
            """
            Setup network cap
            """
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)
            """
            Setup batch augmentations
            """
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  #, "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  #, 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  #, 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)
            """
            Setup frequency separation
            """
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)
            """
            Initialize losses
            """
            #Initialize the losses with the opt parameters
            # Generator losses:
            # for the losses that don't require high precision (can use half precision)
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # for losses that need high precision (use out of the AMP context)
            self.precisegeneratorlosses = losses.PreciseGeneratorLoss(
                opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False
            """
            Prepare optimizers
            """
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True
            """
            Prepare schedulers
            """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()
            """
            Configure SWA
            """
            #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                #TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  #load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))
            """
            If using virtual batch
            """
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            self.virtual_batch = virtual_batch if virtual_batch and \
                virtual_batch >= batch_size else batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()
            """
            Configure AMP
            """
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast
            """
            Configure FreezeD
            """
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                if loc:
                    disc = opt["network_D"].get('which_model_D', False)
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    #TODO: TMP, for now only tested with the vgg-like or patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')
            """
            Initialize CEM and wrap training generator 
            """
            self.CEM = opt.get('use_cem', None)
            if self.CEM:
                CEM_conf = CEMnet.Get_CEM_Conf(opt['scale'])
                CEM_conf.sigmoid_range_limit = bool(opt['network_G'].get(
                    'sigmoid_range_limit', 0))
                if CEM_conf.sigmoid_range_limit:
                    CEM_conf.input_range = [-1, 1] if z_norm else [0, 1]
                kernel = None  # note: could pass a kernel here, but None will use default cubic kernel
                self.CEM_net = CEMnet.CEMnet(CEM_conf, upscale_kernel=kernel)
                self.CEM_net.WrapArchitecture(only_padders=True)
                self.netG = self.CEM_net.WrapArchitecture(
                    self.netG,
                    training_patch_size=opt['datasets']['train']['HR_size'])
                logger.info('CEM enabled')

        # print network
        """ 
        TODO:
        Network summary? Make optional with parameter
            could be an selector between traditional print_network() and summary()
        """
        # TODO: pass verbose flag from config file
        self.print_network(verbose=False)

    def feed_data(self, data, need_HR=True):
        # LR images
        self.var_L = data['LR'].to(self.device)  # LQ
        if need_HR:  # train or val
            # HR images
            self.real_H = data['HR'].to(self.device)  # GT
            # discriminator references
            input_ref = data.get('ref', data['HR'])
            self.var_ref = input_ref.to(self.device)

    def feed_data_batch(self, data, need_HR=True):
        # LR
        self.var_L = data

    def forward(self, data=None, CEM_net=None):
        """
        Run forward pass; called by <optimize_parameters> and <test> functions.
        Can be used either with 'data' passed directly or loaded 'self.var_L'. 
        CEM_net can be used during inference to pass different CEM wrappers.
        """
        if isinstance(data, torch.Tensor):
            if CEM_net is not None:
                wrapped_netG = CEM_net.WrapArchitecture(self.netG)
                return wrapped_netG(data)
            else:
                return self.netG(data)

        if CEM_net is not None:
            wrapped_netG = CEM_net.WrapArchitecture(self.netG)
            self.fake_H = wrapped_netG(self.var_L)  # G(LR)
        else:
            if self.outm:  #if the model has the final activation option
                self.fake_H = self.netG(self.var_L, outm=self.outm)
            else:  #regular models without the final activation option
                self.fake_H = self.netG(self.var_L)  # G(LR)

    def optimize_parameters(self, step):
        # G
        # freeze discriminator while generator is trained to prevent BP
        if self.cri_gan:
            self.requires_grad(self.netD, flag=False, net_type='D')

        # batch (mixup) augmentations
        aug = None
        if self.mixup:
            self.real_H, self.var_L, mask, aug = BatchAug(
                self.real_H, self.var_L, self.mixopts, self.mixprob,
                self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p)

        ### Network forward, generate SR
        with self.cast():
            self.forward()
        #/with self.cast():

        # batch (mixup) augmentations
        # cutout-ed pixels are discarded when calculating loss by masking removed pixels
        if aug == "cutout":
            self.fake_H, self.real_H = self.fake_H * mask, self.real_H * mask

        # unpad images if using CEM
        if self.CEM:
            self.fake_H = self.CEM_net.HR_unpadder(self.fake_H)
            self.real_H = self.CEM_net.HR_unpadder(self.real_H)
            self.var_ref = self.CEM_net.HR_unpadder(self.var_ref)

        l_g_total = 0
        """
        Calculate and log losses
        """
        loss_results = []
        # training generator and discriminator
        # update generator (on its own if only training generator or alternatively if training GAN)
        if (self.cri_gan is not True) or (step % self.D_update_ratio == 0
                                          and step > self.D_init_iters):
            with self.cast(
            ):  # Casts operations to mixed precision if enabled, else nullcontext
                # regular losses
                loss_results, self.log_dict = self.generatorlosses(
                    self.fake_H, self.real_H, self.log_dict, self.f_low)
                l_g_total += sum(loss_results) / self.accumulations

                if self.cri_gan:
                    # adversarial loss
                    l_g_gan = self.adversarial(
                        self.fake_H,
                        self.var_ref,
                        netD=self.netD,
                        stage='generator',
                        fsfilter=self.f_high)  # (sr, hr)
                    self.log_dict['l_g_gan'] = l_g_gan.item()
                    l_g_total += l_g_gan / self.accumulations

            #/with self.cast():
            # high precision generator losses (can be affected by AMP half precision)
            if self.precisegeneratorlosses.loss_list:
                precise_loss_results, self.log_dict = self.precisegeneratorlosses(
                    self.fake_H, self.real_H, self.log_dict, self.f_low)
                l_g_total += sum(precise_loss_results) / self.accumulations

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_g_total).backward()
            else:
                l_g_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale gradients of the optimizer's params, call
                    # optimizer.step() if no infs/NaNs in gradients, else, skipped
                    self.amp_scaler.step(self.optimizer_G)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                    #TODO: remove. for debugging AMP
                    #print("AMP Scaler state dict: ", self.amp_scaler.state_dict())
                else:
                    self.optimizer_G.step()
                self.optimizer_G.zero_grad()
                self.optGstep = True

        if self.cri_gan:
            # update discriminator
            if isinstance(self.feature_loc, int):
                # unfreeze all D
                self.requires_grad(self.netD, flag=True)
                # then freeze up to the selected layers
                for loc in range(self.feature_loc):
                    self.requires_grad(self.netD,
                                       False,
                                       target_layer=loc,
                                       net_type='D')
            else:
                # unfreeze discriminator
                self.requires_grad(self.netD, flag=True)

            l_d_total = 0

            with self.cast(
            ):  # Casts operations to mixed precision if enabled, else nullcontext
                l_d_total, gan_logs = self.adversarial(
                    self.fake_H,
                    self.var_ref,
                    netD=self.netD,
                    stage='discriminator',
                    fsfilter=self.f_high)  # (sr, hr)

                for g_log in gan_logs:
                    self.log_dict[g_log] = gan_logs[g_log]

                l_d_total /= self.accumulations
            #/with autocast():

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_d_total).backward()
            else:
                l_d_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale gradients of the optimizer's params, call
                    # optimizer.step() if no infs/NaNs in gradients, else, skipped
                    self.amp_scaler.step(self.optimizer_D)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                else:
                    self.optimizer_D.step()
                self.optimizer_D.zero_grad()
                self.optDstep = True

    def test(self, CEM_net=None):
        """Forward function used in test time.
        This function wraps <forward> function in no_grad() so intermediate steps 
        for backprop are not saved.
        """
        self.netG.eval()
        with torch.no_grad():
            self.forward(CEM_net=CEM_net)
        self.netG.train()

    def test_x8(self, CEM_net=None):
        """Geometric self-ensemble forward function used in test time.
        Will upscale each image 8 times in different rotations/flips 
        and average the results into a single image.
        """
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [
                self.forward(data=aug, CEM_net=CEM_net) for aug in lr_list
            ]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def test_chop(self, patch_size=200, step=1.0, CEM_net=None):
        """Chop forward function used in test time.
        Converts large images into patches of size (patch_size, patch_size).
        Make sure the patch size is small enough that your GPU memory is sufficient.
        Examples: patch_size = 200 for BlindSR, 64 for ABPN
        """
        batch_size, channels, img_height, img_width = self.var_L.size()
        # if (patch_size * (1.0 - step)) % 1 < 0.5:
        #     patch_size += 1
        patch_size = min(img_height, img_width, patch_size)
        scale = self.opt['scale']

        img_patches = extract_patches_2d(img=self.var_L,
                                         patch_shape=(patch_size, patch_size),
                                         step=[step, step],
                                         batch_first=True).squeeze(0)

        n_patches = img_patches.size(0)
        highres_patches = []

        self.netG.eval()
        with torch.no_grad():
            for p in range(n_patches):
                lowres_input = img_patches[p:p + 1]
                prediction = self.forward(data=lowres_input, CEM_net=CEM_net)
                highres_patches.append(prediction)

        highres_patches = torch.cat(highres_patches, 0)

        self.fake_H = recompose_tensor(highres_patches,
                                       img_height,
                                       img_width,
                                       step=step,
                                       scale=scale)
        self.netG.train()

    def get_current_log(self):
        """Return traning losses / errors. train.py will print out these on the 
        console, and save them to a file"""
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return visualization images."""
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def get_current_visuals_batch(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach().float().cpu()
        out_dict['SR'] = self.fake_H.detach().float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach().float().cpu()
        return out_dict
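
The `test_x8` ensemble above flips/transposes the input eight ways, super-resolves each copy, inverts the transforms in reverse order and averages. The self-contained sketch below reproduces that bookkeeping in pure torch, using bicubic x4 interpolation as a stand-in for `netG` (the stand-in and the 32x32 input are assumptions for illustration).

import torch
import torch.nn.functional as F

def upscale(x):
    # stand-in for netG: plain bicubic x4 upsampling, illustration only
    return F.interpolate(x, scale_factor=4, mode='bicubic', align_corners=False)

def transform(v, op):
    # same ops as _transform above, expressed with torch instead of numpy
    if op == 'v':
        return torch.flip(v, dims=[3])
    if op == 'h':
        return torch.flip(v, dims=[2])
    return v.transpose(2, 3)  # 't'

lr = torch.rand(1, 3, 32, 32)
lr_list = [lr]
for tf in ('v', 'h', 't'):
    lr_list.extend([transform(t, tf) for t in lr_list])  # 8 augmented inputs

with torch.no_grad():
    sr_list = [upscale(aug) for aug in lr_list]

# undo the augmentations in reverse order, then average the eight outputs;
# with a real generator the outputs differ, so averaging reduces artifacts
for i in range(len(sr_list)):
    if i > 3:
        sr_list[i] = transform(sr_list[i], 't')
    if i % 4 > 1:
        sr_list[i] = transform(sr_list[i], 'h')
    if (i % 4) % 2 == 1:
        sr_list[i] = transform(sr_list[i], 'v')

sr = torch.cat(sr_list, dim=0).mean(dim=0, keepdim=True)
print(sr.shape)  # torch.Size([1, 3, 128, 128])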
Example #4
class inpaintModel(BaseModel):
    def __init__(self, opt):
        super(inpaintModel, self).__init__(opt)

        self.counter = 0

        train_opt = opt['train']

        # set if data should be normalized (-1,1) or not (0,1)
        if self.is_train:
            z_norm = opt['datasets']['train'].get('znorm', False)

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed
        self.which_model_G = opt['network_G']['which_model_G']

        # define losses, optimizer and scheduler
        if self.is_train:
            """
            Setup network cap
            """
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)
            """
            Setup batch augmentations
            """
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  #, "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  #, 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  #, 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)
            """
            Setup frequency separation
            """
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)
            """
            Initialize losses
            """
            #Initialize the losses with the opt parameters
            # Generator losses:
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False
            """
            Prepare optimizers
            """
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True
            """
            Prepare schedulers
            """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()
            """
            Configure SWA
            """
            #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                #TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  #load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))
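                # The SWA model keeps a running average of the generator
                # weights from 'swa_start_iter' onwards; its BatchNorm
                # statistics are refreshed with torch.optim.swa_utils.update_bn()
                # in save() below.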
            """
            If using virtual batch
            """
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # fall back to the regular batch size if no (or a smaller)
            # virtual batch size was configured
            self.virtual_batch = virtual_batch if virtual_batch is not None \
                and virtual_batch >= batch_size else batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()
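            # Gradient accumulation sketch (example values, not defaults):
            # batch_size=4 and virtual_batch_size=16 give accumulations=4,
            # so every loss below is divided by 4 and optimizer.step() /
            # zero_grad() only run on every 4th call to optimize_parameters().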
            """
            Configure AMP
            """
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast
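            # With AMP on, the forward/loss code below runs inside
            # 'with self.cast():' (autocast) and self.amp_scaler scales the
            # loss before backward() and unscales it in step(); with AMP off,
            # self.cast is the no-op nullcast context manager.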

        # print network
        """
        TODO:
        Network summary? Make optional with parameter
        could be a selector between the traditional print_network() and summary()
        """
        #self.print_network() #TODO

    #https://github.com/Yukariin/DFNet/blob/master/data.py
    def random_mask(self,
                    height=256,
                    width=256,
                    min_stroke=1,
                    max_stroke=4,
                    min_vertex=1,
                    max_vertex=12,
                    min_brush_width_divisor=16,
                    max_brush_width_divisor=10):

        mask = np.ones((height, width))

        min_brush_width = height // min_brush_width_divisor
        max_brush_width = height // max_brush_width_divisor
        max_angle = 2 * np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke + 1)
        average_length = np.sqrt(height * height + width * width) / 8

        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex + 1)
            start_x = np.random.randint(width)
            start_y = np.random.randint(height)

            for _ in range(num_vertex):
                # sample a stroke direction in [0, max_angle)
                angle = np.random.uniform(0, max_angle)
                length = np.clip(
                    np.random.normal(average_length, average_length // 2), 0,
                    2 * average_length)
                brush_width = np.random.randint(min_brush_width,
                                                max_brush_width + 1)
                end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                end_y = (start_y + length * np.cos(angle)).astype(np.int32)

                cv2.line(mask, (start_y, start_x), (end_y, end_x), 0.,
                         brush_width)

                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        return torch.from_numpy(
            mask.reshape((1, ) + mask.shape).astype(np.float32)).unsqueeze(0)
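        # The returned mask has shape (1, 1, height, width), dtype float32,
        # with 1 marking pixels to keep and 0 marking the free-form strokes
        # to inpaint; it can be broadcast-multiplied with a (1, C, H, W)
        # image tensor (as done in masking_images() below).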

    def masking_images(self):
        # generate one random mask per image in the batch; the stacked mask
        # has shape (B, 1, H, W) and broadcasts over the image channels
        masks = [
            self.random_mask(height=self.var_L.shape[2],
                             width=self.var_L.shape[3]).to(self.device)
            for _ in range(self.var_L.shape[0])
        ]
        mask = torch.cat(masks, dim=0)
        return self.var_L * mask, mask

    def masking_images_with_invert(self):
        # same as masking_images(), but also returns the inverted (masked-out)
        # region, needed by models like Pluralistic
        masks = [
            self.random_mask(height=self.var_L.shape[2],
                             width=self.var_L.shape[3]).to(self.device)
            for _ in range(self.var_L.shape[0])
        ]
        mask = torch.cat(masks, dim=0)
        return self.var_L * mask, self.var_L * (1 - mask), mask

    def feed_data(self, data, need_HR=True):
        # LR images
        if self.which_model_G in ('EdgeConnect', 'PRVS'):
            self.var_L = data['LR'].to(self.device)
            self.canny_data = data['img_HR_canny'].to(self.device)
            self.grayscale_data = data['img_HR_gray'].to(self.device)
            #self.mask = data['green_mask'].to(self.device)
        else:
            self.var_L = data['LR'].to(self.device)

        if need_HR:  # train or val
            # HR images
            self.var_H = data['HR'].to(self.device)
            # discriminator references
            input_ref = data.get('ref', data['HR'])
            self.var_ref = input_ref.to(self.device)
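        # Assumed dataloader layout, based on the keys accessed above
        # (illustrative, not a definitive spec):
        #   data = {'LR': ..., 'HR': ..., 'ref': ...,          # 'ref' defaults to 'HR'
        #           'img_HR_canny': ..., 'img_HR_gray': ...}   # EdgeConnect / PRVS only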

    def feed_data_batch(self, data, need_HR=True):
        # LR
        self.var_L = data

    def optimize_parameters(self, step):
        # G
        # freeze discriminator while generator is trained to prevent BP
        if self.cri_gan:
            for p in self.netD.parameters():
                p.requires_grad = False

        # batch (mixup) augmentations
        aug = None
        if self.mixup:
            self.var_H, self.var_L, mask, aug = BatchAug(
                self.var_H, self.var_L, self.mixopts, self.mixprob,
                self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p)

        if self.which_model_G == 'Pluralistic':
            # pluralistic needs the inpainted area as an image and not only the cut-out
            self.var_L, img_inverted, mask = self.masking_images_with_invert()
        else:
            self.var_L, mask = self.masking_images()

        ### Network forward, generate inpainted fake
        with self.cast():
            # normal
            if self.which_model_G in ('AdaFill', 'MEDFE', 'RFR', 'LBAM',
                                      'DMFN', 'partial', 'Adaptive', 'DFNet',
                                      'RN'):
                self.fake_H = self.netG(self.var_L, mask)
            # 2 rgb images
            elif self.which_model_G in ('CRA', 'pennet', 'deepfillv1',
                                        'deepfillv2', 'Global', 'crfill',
                                        'DeepDFNet'):
                self.fake_H, self.other_img = self.netG(self.var_L, mask)

            # special
            elif self.which_model_G == 'Pluralistic':
                self.fake_H, self.kl_rec, self.kl_g = self.netG(
                    self.var_L, img_inverted, mask)
                #save_image(self.fake_H, "self.fake_H_pluralistic.png")  #debug

            elif self.which_model_G == 'EdgeConnect':
                self.fake_H, self.other_img = self.netG(
                    self.var_L, self.canny_data, self.grayscale_data, mask)

            elif self.which_model_G == 'FRRN':
                self.fake_H, mid_x, mid_mask = self.netG(self.var_L, mask)

            elif self.which_model_G == 'PRVS':
                self.fake_H, _, edge_small, edge_big = self.netG(
                    self.var_L, mask, self.canny_data)

            elif self.which_model_G == 'CSA':
                #out_c, out_r, csa, csa_d
                coarse_result, self.fake_H, csa, csa_d = self.netG(
                    self.var_L, mask)

            elif self.which_model_G == 'atrous':
                self.fake_H = self.netG(self.var_L)

            else:
                print("Selected model is not implemented.")

        # Merge inpainted data with original data in masked region
        self.fake_H = self.var_L * mask + self.fake_H * (1 - mask)
        #save_image(self.fake_H, 'self_fake_H.png')

        #/with self.cast():
        #self.fake_H = self.netG(self.var_L, mask)

        #self.counter += 1
        #save_image(mask, str(self.counter)+'mask_train.png')
        #save_image(self.fake_H, str(self.counter)+'fake_H_train.png')

        # batch (mixup) augmentations
        # cutout-ed pixels are discarded when calculating loss by masking removed pixels
        if aug == "cutout":
            self.fake_H, self.var_H = self.fake_H * mask, self.var_H * mask

        l_g_total = 0
        """
        Calculate and log losses
        """
        loss_results = []
        # training generator and discriminator
        # update the generator: every step if there is no discriminator,
        # otherwise only every D_update_ratio steps after D_init_iters
        if (self.cri_gan is not True) or (step % self.D_update_ratio == 0
                                          and step > self.D_init_iters):
            with self.cast(
            ):  # Casts operations to mixed precision if enabled, else nullcontext
                # regular losses
                loss_results, self.log_dict = self.generatorlosses(
                    self.fake_H, self.var_H, self.log_dict, self.f_low)

                # additional losses, for models that output more than a single image
                ###############################
                # deepfillv2 / global / crfill / CRA
                if self.which_model_G in ('deepfillv2', 'Global', 'crfill',
                                          'CRA'):
                    L1Loss = nn.L1Loss()
                    l1_stage1 = L1Loss(self.other_img, self.var_H)

                    self.log_dict.update(l1_stage1=l1_stage1)
                    loss_results.append(l1_stage1)

                # edge-connect
                if self.which_model_G == 'EdgeConnect':
                    L1Loss = nn.L1Loss()
                    l1_edge = L1Loss(self.other_img, self.canny_data)

                    self.log_dict.update(l1_edge=l1_edge)
                    loss_results.append(l1_edge)
                ###############################
                # csa
                if self.which_model_G == 'CSA':
                    #coarse_result, refine_result, csa, csa_d = g_model(masked, mask)
                    L1Loss = nn.L1Loss()
                    recon_loss = L1Loss(coarse_result, self.var_H) + L1Loss(
                        self.fake_H, self.var_H)

                    from models.modules.csa_loss import ConsistencyLoss
                    cons = ConsistencyLoss()

                    cons_loss = cons(csa, csa_d, self.var_H, mask)

                    self.log_dict.update(recon_loss=recon_loss)
                    loss_results.append(recon_loss)
                    self.log_dict.update(cons_loss=cons_loss)
                    loss_results.append(cons_loss)
                ###############################
                # pluralistic (encoder kl loss)
                if self.which_model_G == 'Pluralistic':
                    loss_kl_rec = self.kl_rec.mean()
                    loss_kl_g = self.kl_g.mean()

                    self.log_dict.update(loss_kl_rec=loss_kl_rec)
                    loss_results.append(loss_kl_rec)
                    self.log_dict.update(loss_kl_g=loss_kl_g)
                    loss_results.append(loss_kl_g)
                ###############################
                # deepfillv1
                if self.which_model_G == 'deepfillv1':
                    from models.modules.deepfillv1_loss import ReconLoss
                    ReconLoss_ = ReconLoss(1, 1, 1, 1)
                    reconstruction_loss = ReconLoss_(self.var_H,
                                                     self.other_img,
                                                     self.fake_H, mask)

                    self.log_dict.update(
                        reconstruction_loss=reconstruction_loss)
                    loss_results.append(reconstruction_loss)
                ###############################
                # pennet
                if self.which_model_G == 'pennet':
                    L1Loss = nn.L1Loss()
                    if self.other_img is not None:
                        # multi-scale pyramid loss against resized HR targets
                        pyramid_loss = 0
                        for f in self.other_img:
                            pyramid_loss += L1Loss(
                                f,
                                torch.nn.functional.interpolate(
                                    self.var_H,
                                    size=f.size()[2:4],
                                    mode='bilinear',
                                    align_corners=True))

                        self.log_dict.update(pyramid_loss=pyramid_loss)
                        loss_results.append(pyramid_loss)
                ###############################
                # FRRN
                if self.which_model_G == 'FRRN':
                    L1Loss = nn.L1Loss()
                    # generator step loss: accumulate the L1 loss of every
                    # intermediate (mid) result over its own mask
                    mid_l1_loss = 0
                    for idx in range(len(mid_x) - 1):
                        mid_l1_loss += L1Loss(mid_x[idx] * mid_mask[idx],
                                              self.var_H * mid_mask[idx])

                    self.log_dict.update(mid_l1_loss=mid_l1_loss)
                    loss_results.append(mid_l1_loss)
                ###############################
                # PRVS
                if self.which_model_G == 'PRVS':
                    L1Loss = nn.L1Loss()
                    #from models.modules.PRVS_loss import edge_loss
                    #[edge_small, edge_big]
                    #adv_loss_0 = self.edge_loss(fake_edge[1], real_edge)
                    #dv_loss_1 = self.edge_loss(fake_edge[0], F.interpolate(real_edge, scale_factor = 0.5))

                    #adv_loss_0 = edge_loss(self, edge_big, self.canny_data, self.grayscale_data)
                    #adv_loss_1 = edge_loss(self, edge_small, torch.nn.functional.interpolate(self.canny_data, scale_factor = 0.5))

                    # l1 instead of discriminator loss
                    edge_big_l1 = L1Loss(edge_big, self.canny_data)
                    edge_small_l1 = L1Loss(
                        edge_small,
                        torch.nn.functional.interpolate(self.canny_data,
                                                        scale_factor=0.5))

                    self.log_dict.update(edge_big_l1=edge_big_l1)
                    loss_results.append(edge_big_l1)
                    self.log_dict.update(edge_small_l1=edge_small_l1)
                    loss_results.append(edge_small_l1)
                ###############################

                #for key, value in self.log_dict.items():
                #    print(key, value)

                l_g_total += sum(loss_results) / self.accumulations

                if self.cri_gan:
                    # adversarial loss
                    l_g_gan = self.adversarial(
                        self.fake_H,
                        self.var_ref,
                        netD=self.netD,
                        stage='generator',
                        fsfilter=self.f_high)  # (sr, hr)
                    self.log_dict['l_g_gan'] = l_g_gan.item()
                    l_g_total += l_g_gan / self.accumulations

            #/with self.cast():

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_g_total).backward()
            else:
                l_g_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale the optimizer's gradients and call
                    # optimizer.step() if no infs/NaNs are found; otherwise the step is skipped
                    self.amp_scaler.step(self.optimizer_G)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                    #TODO: remove. for debugging AMP
                    #print("AMP Scaler state dict: ", self.amp_scaler.state_dict())
                else:
                    self.optimizer_G.step()
                self.optimizer_G.zero_grad()
                self.optGstep = True

        if self.cri_gan:
            # update discriminator
            # unfreeze discriminator
            for p in self.netD.parameters():
                p.requires_grad = True
            l_d_total = 0

            with self.cast(
            ):  # Casts operations to mixed precision if enabled, else nullcontext
                l_d_total, gan_logs = self.adversarial(
                    self.fake_H,
                    self.var_ref,
                    netD=self.netD,
                    stage='discriminator',
                    fsfilter=self.f_high)  # (sr, hr)

                for g_log in gan_logs:
                    self.log_dict[g_log] = gan_logs[g_log]

                l_d_total /= self.accumulations
            #/with autocast():

            if self.amp:
                # call backward() on scaled loss to create scaled gradients.
                self.amp_scaler.scale(l_d_total).backward()
            else:
                l_d_total.backward()

            # only step and clear gradient if virtual batch has completed
            if (step + 1) % self.accumulations == 0:
                if self.amp:
                    # unscale the optimizer's gradients and call
                    # optimizer.step() if no infs/NaNs are found; otherwise the step is skipped
                    self.amp_scaler.step(self.optimizer_D)
                    # Update GradScaler scale for next iteration.
                    self.amp_scaler.update()
                else:
                    self.optimizer_D.step()
                self.optimizer_D.zero_grad()
                self.optDstep = True
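        # Illustrative training-step call order (a sketch; the actual
        # training script may differ):
        #   model.feed_data(batch)            # set var_L / var_H / var_ref
        #   model.optimize_parameters(step)   # G update (+ D update if cri_gan)
        #   logs = model.get_current_log()    # per-loss log values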

    def test(self, data):
        # Note: a random mask could be generated for validation instead of
        # using the provided 'green_mask', e.g.:
        #   if self.which_model_G == 'Pluralistic':
        #       # pluralistic needs the inpainted area as an image, not only the cut-out
        #       self.var_L, img_inverted, mask = self.masking_images_with_invert()
        #   else:
        #       self.var_L, mask = self.masking_images()

        self.mask = data['green_mask'].float().to(self.device).unsqueeze(0)

        if self.which_model_G == 'Pluralistic':
            img_inverted = self.var_L * (1 - self.mask)

        self.var_L = self.var_L * self.mask

        self.netG.eval()
        with torch.no_grad():
            if self.is_train:
                # normal
                if self.which_model_G in ('AdaFill', 'MEDFE', 'RFR', 'LBAM',
                                          'DMFN', 'partial', 'Adaptive',
                                          'DFNet', 'RN'):
                    self.fake_H = self.netG(self.var_L, self.mask)
                # 2 rgb images
                elif self.which_model_G in ('CRA', 'pennet', 'deepfillv1',
                                            'deepfillv2', 'Global', 'crfill',
                                            'DeepDFNet'):
                    self.fake_H, _ = self.netG(self.var_L, self.mask)

                # special
                elif self.which_model_G == 'Pluralistic':
                    self.fake_H, _, _ = self.netG(self.var_L, img_inverted,
                                                  self.mask)

                elif self.which_model_G == 'EdgeConnect':
                    self.fake_H, _ = self.netG(self.var_L, self.canny_data,
                                               self.grayscale_data, self.mask)

                elif self.which_model_G == 'FRRN':
                    self.fake_H, _, _ = self.netG(self.var_L, self.mask)

                elif self.which_model_G == 'PRVS':
                    self.fake_H, _, _, _ = self.netG(self.var_L, self.mask,
                                                     self.canny_data)

                elif self.which_model_G == 'CSA':
                    _, self.fake_H, _, _ = self.netG(self.var_L, self.mask)

                elif self.which_model_G == 'atrous':
                    self.fake_H = self.netG(self.var_L)
                else:
                    print("Selected model is not implemented.")

        # Merge inpainted data with original data in masked region
        self.fake_H = self.var_L * self.mask + self.fake_H * (1 - self.mask)
        self.netG.train()
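        # Illustrative validation usage (a sketch, assuming a 'green_mask'
        # entry in the validation data):
        #   model.feed_data(val_data)
        #   model.test(val_data)              # inpaint the masked region
        #   visuals = model.get_current_visuals()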

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
        #TODO for PPON ?
        #if get stages 1 and 2
        #out_dict['SR_content'] = ...
        #out_dict['SR_structure'] = ...
        return out_dict

    def get_current_visuals_batch(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach().float().cpu()
        out_dict['SR'] = self.fake_H.detach().float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach().float().cpu()
        #TODO for PPON ?
        #if get stages 1 and 2
        #out_dict['SR_content'] = ...
        #out_dict['SR_structure'] = ...
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            if self.cri_gan:
                s, n = self.get_network_description(self.netD)
                if isinstance(self.netD, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netD.__class__.__name__,
                        self.netD.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netD.__class__.__name__)

                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            #TODO: feature network is not being trained, is it necessary to visualize? Maybe just name?
            # maybe show the generatorlosses instead?
            '''
            if self.generatorlosses.cri_fea:  # F, Perceptual Network
                #s, n = self.get_network_description(self.netF)
                s, n = self.get_network_description(self.generatorlosses.netF) #TODO
                #s, n = self.get_network_description(self.generatorlosses.loss_list.netF) #TODO
                if isinstance(self.generatorlosses.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(self.generatorlosses.netF.__class__.__name__,
                                                    self.generatorlosses.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.generatorlosses.netF.__class__.__name__)

                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
                logger.info(s)
            '''

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading pretrained model for G [{:s}] ...'.format(
                load_path_G))
            strict = self.opt['path'].get('strict', None)
            self.load_network(load_path_G, self.netG, strict)
        if self.opt['is_train'] and self.opt['train']['gan_weight']:
            load_path_D = self.opt['path']['pretrain_model_D']
            if self.opt['is_train'] and load_path_D is not None:
                logger.info('Loading pretrained model for D [{:s}] ...'.format(
                    load_path_D))
                self.load_network(load_path_D, self.netD)

    def load_swa(self):
        if self.opt['is_train'] and self.opt['use_swa']:
            load_path_swaG = self.opt['path']['pretrain_model_swaG']
            if self.opt['is_train'] and load_path_swaG is not None:
                logger.info(
                    'Loading pretrained model for SWA G [{:s}] ...'.format(
                        load_path_swaG))
                self.load_network(load_path_swaG, self.swa_model)

    def save(self, iter_step, latest=None, loader=None):
        self.save_network(self.netG, 'G', iter_step, latest)
        if self.cri_gan:
            self.save_network(self.netD, 'D', iter_step, latest)
        if self.swa:
            # when training with networks that use BN
            # # Update bn statistics for the swa_model only at the end of training
            # if not isinstance(iter_step, int): #TODO: not sure if it should be done only at the end
            self.swa_model = self.swa_model.cpu()
            torch.optim.swa_utils.update_bn(loader, self.swa_model)
            self.swa_model = self.swa_model.cuda()
            # Check swa BN statistics
            # for module in self.swa_model.modules():
            #     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            #         print(module.running_mean)
            #         print(module.running_var)
            #         print(module.momentum)
            #         break
            self.save_network(self.swa_model, 'swaG', iter_step, latest)