Exemplo n.º 1
0
    def __init__(self, opt):
        """Create the generator, the optional discriminator and the training state.

        Parameters:
            opt (Option dictionary): experiment options. This method reads
                ``opt['train']``, ``opt['datasets']['train']``,
                ``opt['network_D']``, ``opt['use_swa']`` and ``opt['use_amp']``.
        """
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed

        # default value, so the attribute also exists when not training
        self.outm = None

        # define losses, optimizer and scheduler
        if self.is_train:
            # Setup network cap:
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)

            # Setup batch augmentations
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  #, "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  #, 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  #, 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize the losses with the opt parameters
            # Generator losses:
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Prepare optimizers
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()

            # Configure SWA
            #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                #TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  #load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))

            # Configure virtual batch
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # 'virtual_batch_size' is optional: comparing None with an int
            # raises TypeError on Python 3, so test for a value first and fall
            # back to the real batch size when it is unset or smaller.
            if virtual_batch and virtual_batch >= batch_size:
                self.virtual_batch = virtual_batch
            else:
                self.virtual_batch = batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

            # Configure FreezeD
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                # only derive a feature location when 'freeze_loc' is
                # configured; otherwise FreezeD stays disabled (None)
                if loc:
                    # fall back to '' so the substring tests below cannot fail
                    # on a missing 'which_model_D'
                    disc = opt["network_D"].get('which_model_D', False) or ''
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    #TODO: TMP, for now only tested with the vgg-like or patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

        # print network
        """ 
Exemplo n.º 2
0
    def __init__(self, opt, step=0):
        """Create the generator, the optional discriminator and the training state.

        Parameters:
            opt (Option dictionary): experiment options. This method reads
                ``opt['train']``, ``opt['datasets']['train']``,
                ``opt['network_D']``, ``opt['network_G']``, ``opt['scale']``,
                ``opt['use_swa']``, ``opt['use_amp']`` and ``opt['use_cem']``.
            step: value forwarded unchanged to ``networks.define_G``.
        """
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt['train']

        # set if data should be normalized (-1,1) or not (0,1)
        # (only read further down, by the CEM configuration)
        if self.is_train:
            z_norm = opt['datasets']['train'].get('znorm', False)

        # specify the models you want to load/save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        # for training and testing, a generator 'G' is needed
        self.model_names = ['G']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt, step=step).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.model_names.append(
                    'D')  # add discriminator to the network list
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed

        self.outm = None

        # define losses, optimizer and scheduler
        if self.is_train:
            # Setup network cap:
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)

            # Setup batch augmentations
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  #, "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  #, 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  #, 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize the losses with the opt parameters
            # Generator losses:
            # for the losses that don't require high precision (can use half precision)
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # for losses that need high precision (use out of the AMP context)
            self.precisegeneratorlosses = losses.PreciseGeneratorLoss(
                opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Prepare optimizers
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()

            # Configure SWA
            #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                #TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  #load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))

            # Configure virtual batch
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # 'virtual_batch_size' is optional: comparing None with an int
            # raises TypeError on Python 3, so test for a value first and fall
            # back to the real batch size when it is unset or smaller.
            if virtual_batch and virtual_batch >= batch_size:
                self.virtual_batch = virtual_batch
            else:
                self.virtual_batch = batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

            # Configure FreezeD
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                if loc:
                    # fall back to '' so the substring tests below cannot fail
                    # on a missing 'which_model_D'
                    disc = opt["network_D"].get('which_model_D', False) or ''
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    #TODO: TMP, for now only tested with the vgg-like or patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

            # Initialize CEM and wrap training generator
            self.CEM = opt.get('use_cem', None)
            if self.CEM:
                CEM_conf = CEMnet.Get_CEM_Conf(opt['scale'])
                CEM_conf.sigmoid_range_limit = bool(opt['network_G'].get(
                    'sigmoid_range_limit', 0))
                if CEM_conf.sigmoid_range_limit:
                    CEM_conf.input_range = [-1, 1] if z_norm else [0, 1]
                kernel = None  # note: could pass a kernel here, but None will use default cubic kernel
                self.CEM_net = CEMnet.CEMnet(CEM_conf, upscale_kernel=kernel)
                self.CEM_net.WrapArchitecture(only_padders=True)
                self.netG = self.CEM_net.WrapArchitecture(
                    self.netG,
                    training_patch_size=opt['datasets']['train']['HR_size'])
                logger.info('CEM enabled')

        # print network
        # TODO:
        # Network summary? Make optional with parameter
        #     could be an selector between traditional print_network() and summary()
        self.print_network(
            verbose=False)  #TODO: pass verbose flag from config file
Exemplo n.º 3
0
    def __init__(self, opt):
        """Create the generator, the optional discriminator and the training state.

        Parameters:
            opt (Option dictionary): experiment options. This method reads
                ``opt['train']``, ``opt['scale']``, ``opt['datasets']['train']``
                and ``opt['use_amp']``.
        """
        super(VSRModel, self).__init__(opt)
        train_opt = opt['train']
        self.scale = opt.get('scale', 4)

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # Setup network cap:
            # define if the generator will have a final capping mechanism in the output
            self.outm = train_opt.get('finalcap', None)

            # Setup batch augmentations (currently disabled)
            #TODO: will need testing. Also consider batch augmentations may be ill defined with the temporal data
            # self.mixup = train_opt.get('mixup', None)
            # if self.mixup:
            #     #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x)
            #     self.mixopts = train_opt.get('mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"]) #, "cutout", "cutblur"]
            #     self.mixprob = train_opt.get('mixprob', [1.0, 1.0, 1.0, 1.0, 1.0]) #, 1.0, 1.0]
            #     self.mixalpha = train_opt.get('mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7]) #, 0.001, 0.7]
            #     self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
            #     self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
            #     self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize the losses with the opt parameters
            # Generator losses:
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  #TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  #original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Optical Flow Reconstruction loss:
            # start disabled, so the attribute is always defined, including
            # when 'ofr_type' is set to a value other than 'ofr'
            self.cri_ofr = False
            ofr_type = train_opt.get('ofr_type', None)
            ofr_weight = train_opt.get('ofr_weight', [0.1, 0.2, 0.1, 0.01])
            if ofr_type and ofr_weight:
                self.ofr_weight = ofr_weight[3]  #lambda 4
                self.ofr_wl1 = ofr_weight[0]  #lambda 1
                self.ofr_wl2 = ofr_weight[1]  #lambda 2
                ofr_wl3 = ofr_weight[2]  #lambda 3
                if ofr_type == 'ofr':
                    from models.modules.loss import OFR_loss
                    #TODO: make the regularization weight an option. lambda3 = 0.1
                    self.cri_ofr = OFR_loss(reg_weight=ofr_wl3).to(self.device)

            # Prepare optimizers
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            #Keep log in loss class instead?
            self.log_dict = OrderedDict()

            # Configure virtual batch
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # 'virtual_batch_size' is optional: comparing None with an int
            # raises TypeError on Python 3, so test for a value first and fall
            # back to the real batch size when it is unset or smaller.
            if virtual_batch and virtual_batch >= batch_size:
                self.virtual_batch = virtual_batch
            else:
                self.virtual_batch = batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

        # print network
        """ 
Exemplo n.º 4
0
    def __init__(self, opt):
        """Initialize the WBC model class.

        Builds the generator, the optional surface/texture discriminators
        and, when training, the losses, optimizers, schedulers and the
        SWA / AMP / FreezeD configuration.

        Parameters:
            opt (Option dictionary): stores all the experiment flags
        """
        super(WBCModel, self).__init__(opt)
        train_opt = opt['train']

        # fetch lambda_idt if provided for identity loss
        self.lambda_idt = train_opt['lambda_identity']
        use_identity = bool(self.lambda_idt and self.lambda_idt > 0.0)

        # specify the images you want to save/display. The training/test
        # scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        if self.is_train and use_identity:
            # if identity loss is used, we also visualize idt_B=G(B)
            self.visual_names.append('idt_B')

        # specify the models you want to load/save to the disk.
        # The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        # for training and testing, a generator 'G' is needed
        self.model_names = ['G']

        # define networks (both generator and discriminator) and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G

        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                # add discriminators to the network list
                self.model_names.append('D_S')  # surface
                self.model_names.append('D_T')  # texture
                self.netD_S = networks.define_D(opt).to(self.device)
                t_opt = opt.copy()  # TODO: tmp to reuse same config.
                # NOTE(review): if ``opt.copy()`` is a shallow copy (as it is
                # for a plain dict), the assignment below also changes
                # opt['network_D'] for the caller. Left as is because later
                # code may rely on it -- confirm before changing.
                t_opt['network_D']['input_nc'] = 1
                self.netD_T = networks.define_D(t_opt).to(self.device)
                self.netD_T.train()
                self.netD_S.train()
        self.load()  # load 'G', 'D_T' and 'D_S' if needed

        # additional WBC component, initial guided filter
        #TODO: parameters for GFs can be in options file
        self.guided_filter = GuidedFilter(r=1, eps=1e-2)

        if self.is_train:
            if use_identity:
                # only works when input and output images have the same
                # number of channels
                assert opt['input_nc'] == opt['output_nc']

            # create image buffers to store previously generated images
            self.fake_S_pool = ImagePool(opt['pool_size'])
            self.fake_T_pool = ImagePool(opt['pool_size'])

            # Setup batch augmentations
            #TODO: test
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  # , "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  # , 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  # , 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize the losses with the opt parameters
            # Generator losses:
            # for the losses that don't require high precision (can use half precision)
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # for losses that need high precision (use out of the AMP context)
            self.precisegeneratorlosses = losses.PreciseGeneratorLoss(
                opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # set filters losses for each representation
            self.surf_losses = train_opt.get('surf_losses', [])
            self.text_losses = train_opt.get('text_losses', [])
            self.struct_losses = train_opt.get('struct_losses', ['fea'])
            self.cont_losses = train_opt.get('cont_losses', ['fea'])
            self.reg_losses = train_opt.get('reg_losses', ['tv'])

            # add identity loss if configured
            # (already inside the training branch, so is_train is not re-tested)
            self.idt_losses = []
            if use_identity:
                self.idt_losses = train_opt.get('idt_losses', ['pix'])

            # custom representations scales
            self.stru_w = train_opt.get('struct_scale', 1)
            self.cont_w = train_opt.get('content_scale', 1)
            self.text_w = train_opt.get('texture_scale', 1)
            self.surf_w = train_opt.get('surface_scale', 0.1)
            self.reg_w = train_opt.get('reg_scale', 1)

            # additional WBC components
            self.colorshift = ColorShift()
            self.guided_filter_surf = GuidedFilter(r=5, eps=2e-1)
            self.sp_transform = get_sp_transform(
                train_opt, opt['datasets']['train']['znorm'])

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                # TODO:
                # self.criterionGAN = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  # TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  # original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy,
                                                      conditional=False)
                # TODO:
                # D_update_ratio and D_init_iters are for WGAN
                # self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                # self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Initialize optimizers
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                # self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                #     self.cri_gan, [self.netD_T, self.netD_S], self.netG,
                #     train_opt, logger, self.optimizers)
                # a single discriminator optimizer covers both D_T and D_S
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    cri_gan=self.cri_gan,
                    netG=self.netG,
                    optim_paramsD=itertools.chain(self.netD_T.parameters(),
                                                  self.netD_S.parameters()),
                    train_opt=train_opt,
                    logger=logger,
                    optimizers=self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            # Configure SWA
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                # TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  # load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))

            # Configure virtual batch
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # 'virtual_batch_size' is optional: comparing None with an int
            # raises TypeError on Python 3, so test for a value first and fall
            # back to the real batch size when it is unset or smaller.
            if virtual_batch and virtual_batch >= batch_size:
                self.virtual_batch = virtual_batch
            else:
                self.virtual_batch = batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

            # Configure FreezeD
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                if loc:
                    # fall back to '' so the substring tests below cannot fail
                    # on a missing 'which_model_D'
                    disc = opt["network_D"].get('which_model_D', False) or ''
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    # TODO: TMP, for now only tested with the vgg-like or patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

            # create logs dictionaries
            self.log_dict = OrderedDict()
            self.log_dict_T = OrderedDict()
            self.log_dict_S = OrderedDict()

        self.print_network(
            verbose=False)  # TODO: pass verbose flag from config file
Exemplo n.º 5
0
    def __init__(self, opt):
        """Initialize the pix2pix class.

        Defines the generator (plus the discriminator when training with a
        GAN weight), loads pretrained weights through ``self.load()`` and,
        when training, configures batch augmentations, frequency separation,
        losses, optimizers, schedulers, SWA, virtual batch, AMP and FreezeD.

        Parameters:
            opt (Option dictionary): stores all the experiment flags
        """
        super(Pix2PixModel, self).__init__(opt)
        train_opt = opt['train']

        # specify the images you want to save/display. The training/test
        # scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        # specify the models you want to load/save to the disk.
        # The training/test scripts will call <BaseModel.save_networks> and
        # <BaseModel.load_networks>
        # For training and testing, a generator 'G' is needed
        self.model_names = ['G']

        # define networks (both generator and discriminator) and load
        # pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G

        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                # add discriminator to the network list
                self.model_names.append('D')
                # define a discriminator; conditional GANs need to take both
                # input and output images; therefore, input channels for D
                # must be input_nc + output_nc
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed

        if self.is_train:
            # Setup batch augmentations
            # TODO: test
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  # , "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  # , 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  # , 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize losses
            # Initialize the losses with the opt parameters
            # Generator losses:
            # for the losses that don't require high precision (can use half
            # precision)
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # for losses that need high precision (use out of the AMP context)
            self.precisegeneratorlosses = losses.PreciseGeneratorLoss(
                opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Discriminator loss: only enabled when both a GAN type and a
            # non-zero GAN weight are configured
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  # TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  # original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy,
                                                      conditional=True)
                # TODO:
                # D_update_ratio and D_init_iters are for WGAN
                # self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                # self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Initialize optimizers
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                # NOTE(review): there is no discriminator optimizer here, so
                # the D step flag is presumably pre-set as done -- confirm
                # against the training loop
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            # Configure SWA
            # TODO: test
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                # TODO: Note: This could be done in resume_training() instead,
                # to prevent creating the swa scheduler and model before they
                # are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  # load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))

            # If using virtual batch
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # 'virtual_batch_size' is optional: when it is missing (None) or
            # smaller than the real batch size, fall back to the real batch
            # size. Checking truthiness first avoids comparing None with an
            # int, which raises TypeError on Python 3.
            if virtual_batch and virtual_batch >= batch_size:
                self.virtual_batch = virtual_batch
            else:
                self.virtual_batch = batch_size
            # how many real batches fit in one virtual batch
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP: enabled only when both `load_amp` and the
            # 'use_amp' option are truthy
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

            # Configure FreezeD
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                if loc:
                    # fall back to an empty string so the substring checks
                    # below cannot raise when 'which_model_D' is not set
                    disc = opt["network_D"].get('which_model_D') or ''
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    # TODO: TMP, for now only tested with the vgg-like or
                    # patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

            # create logs dictionary
            self.log_dict = OrderedDict()

        self.print_network(
            verbose=False)  # TODO: pass verbose flag from config file
Exemplo n.º 6
0
    def __init__(self, opt):
        """Initialize the PBR model.

        Defines the generator (plus the discriminator when training with a
        GAN weight), loads pretrained weights through ``self.load()`` and,
        when training, configures the output cap, batch augmentations,
        frequency separation, losses (for 3 channel and 1 channel maps),
        optimizers, schedulers, virtual batch and AMP.

        Parameters:
            opt (Option dictionary): stores all the experiment flags
        """
        super(PBRModel, self).__init__(opt)
        train_opt = opt['train']

        # specify the models you want to load/save to the disk. The
        # training/test scripts will call <BaseModel.save_networks> and
        # <BaseModel.load_networks>
        # for training and testing, a generator 'G' is needed
        self.model_names = ['G']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                # add discriminator to the network list
                self.model_names.append('D')
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # Setup network cap
            # define if the generator will have a final capping mechanism in
            # the output
            self.outm = train_opt.get('finalcap', None)

            # Setup batch augmentations
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                # TODO: cutblur and cutout need model to be modified so LR
                # and HR have the same dimensions (1x)
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  # , "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  # , 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  # , 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize losses
            # Initialize the losses with the opt parameters
            # Generator losses for 3 channel maps: diffuse, albedo and normal:
            # for the losses that don't require high precision (can use half
            # precision)
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # for losses that need high precision (use out of the AMP context)
            self.precisegeneratorlosses = losses.PreciseGeneratorLoss(
                opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # Generator losses for 1 channel maps (does not support feature
            # networks like VGG):
            # using new option in the loss builder: allow_featnets = False
            # TODO: does it make sense to make fake 3ch images with the 1ch
            # maps?
            # for the losses that don't require high precision (can use half
            # precision)
            self.generatorlosses1ch = losses.GeneratorLoss(
                opt, self.device, False)
            # for losses that need high precision (use out of the AMP context)
            self.precisegeneratorlosses1ch = losses.PreciseGeneratorLoss(
                opt, self.device, False)

            # Discriminator loss: only enabled when both a GAN type and a
            # non-zero GAN weight are configured
            if train_opt['gan_type'] and train_opt['gan_weight']:
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  # TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  # original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy)
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Prepare optimizers
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    self.cri_gan, self.netD, self.netG, train_opt, logger,
                    self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                # NOTE(review): there is no discriminator optimizer here, so
                # the D step flag is presumably pre-set as done -- confirm
                # against the training loop
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            # Keep log in loss class instead?
            self.log_dict = OrderedDict()

            # If using virtual batch
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # 'virtual_batch_size' is optional: when it is missing (None) or
            # smaller than the real batch size, fall back to the real batch
            # size. Checking truthiness first avoids comparing None with an
            # int, which raises TypeError on Python 3.
            if virtual_batch and virtual_batch >= batch_size:
                self.virtual_batch = virtual_batch
            else:
                self.virtual_batch = batch_size
            # how many real batches fit in one virtual batch
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP: enabled only when both `load_amp` and the
            # 'use_amp' option are truthy
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

        # print network
        # TODO:
        # Network summary? Make optional with parameter
        #     could be a selector between traditional print_network() and
        #     summary()
        self.print_network(
            verbose=False)  # TODO: pass verbose flag from config file
Exemplo n.º 7
0
    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Defines generator G_A (plus G_B and, with a GAN weight, the
        discriminators D_A and D_B when training), loads pretrained weights
        through ``self.load()`` and, when training, configures the image
        pools, batch augmentations, frequency separation, losses,
        optimizers, schedulers, virtual batch, AMP and FreezeD.

        Parameters:
            opt (Option dictionary): stores all the experiment flags
        """
        super(CycleGANModel, self).__init__(opt)
        train_opt = opt['train']

        # fetch lambda_idt if provided for identity loss; a missing option
        # is treated the same as a disabled identity loss
        self.lambda_idt = train_opt.get('lambda_identity', None)
        use_identity = bool(self.lambda_idt and self.lambda_idt > 0.0)

        # specify the images you want to save/display. The training/test
        # scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.is_train and use_identity:
            # if identity loss is used, we also visualize idt_B=G_A(B) and
            # idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        # combine visualizations for A and B
        self.visual_names = visual_names_A + visual_names_B

        # specify the models you want to load/save to the disk.
        # The training/test scripts will call <BaseModel.save_networks> and
        # <BaseModel.load_networks>
        # for training and testing, a generator 'G' is needed
        self.model_names = ['G_A']

        # define networks (both generator and discriminator) and load
        # pretrained models
        # *The naming is different from those used in the paper.
        #   Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt).to(self.device)  # G_A

        if self.is_train:
            # for training 2 generators are needed, add to list and define
            self.model_names.append('G_B')
            self.netG_B = networks.define_G(opt).to(self.device)  # G_B

            self.netG_A.train()
            self.netG_B.train()
            if train_opt['gan_weight']:
                # add discriminators to the network list
                self.model_names.append('D_A')
                self.model_names.append('D_B')
                self.netD_A = networks.define_D(opt).to(self.device)  # D_A
                self.netD_B = networks.define_D(opt).to(self.device)  # D_B
                self.netD_A.train()
                self.netD_B.train()
        self.load()  # load 'G_A', 'G_B', 'D_A' and 'D_B' if needed

        if self.is_train:
            if use_identity:
                # only works when input and output images have the same
                # number of channels
                assert opt['input_nc'] == opt['output_nc'], (
                    'identity loss requires input_nc == output_nc')

            # create image buffers to store previously generated images
            self.fake_A_pool = ImagePool(opt['pool_size'])
            self.fake_B_pool = ImagePool(opt['pool_size'])

            # Setup batch augmentations
            # TODO: test
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  # , "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  # , 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  # , 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize losses
            # Initialize the losses with the opt parameters
            # Generator losses:
            # for the losses that don't require high precision (can use half
            # precision)
            self.cyclelosses = losses.GeneratorLoss(opt, self.device)
            # for losses that need high precision (use out of the AMP context)
            self.precisecyclelosses = losses.PreciseGeneratorLoss(
                opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.cyclelosses.loss_list)

            # add identity loss if configured
            if use_identity:
                # TODO: using the same losses as cycle/generator, could be
                # different
                # self.idtlosses = losses.GeneratorLoss(opt, self.device)
                self.idtlosses = self.cyclelosses
                # self.preciseidtlosses = losses.PreciseGeneratorLoss(opt, self.device)
                self.preciseidtlosses = self.precisecyclelosses

            # Discriminator loss: only enabled when both a GAN type and a
            # non-zero GAN weight are configured
            if train_opt['gan_type'] and train_opt['gan_weight']:
                # TODO:
                # self.criterionGAN = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  # TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  # original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy,
                                                      conditional=False)
                # TODO:
                # D_update_ratio and D_init_iters are for WGAN
                # self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                # self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Initialize optimizers: a single optimizer is shared by both
            # generators, and another one by both discriminators
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    cri_gan=self.cri_gan,
                    optim_paramsD=itertools.chain(self.netD_A.parameters(),
                                                  self.netD_B.parameters()),
                    optim_paramsG=itertools.chain(self.netG_A.parameters(),
                                                  self.netG_B.parameters()),
                    train_opt=train_opt,
                    logger=logger,
                    optimizers=self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None,
                    None,
                    optim_paramsG=itertools.chain(self.netG_A.parameters(),
                                                  self.netG_B.parameters()),
                    train_opt=train_opt,
                    logger=logger,
                    optimizers=self.optimizers)
                # NOTE(review): there is no discriminator optimizer here, so
                # the D step flag is presumably pre-set as done -- confirm
                # against the training loop
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            # Configure SWA
            # TODO: configure SWA for two Generators; until then SWA stays
            # disabled regardless of the 'use_swa' option
            self.swa = False

            # If using virtual batch
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # 'virtual_batch_size' is optional: when it is missing (None) or
            # smaller than the real batch size, fall back to the real batch
            # size. Checking truthiness first avoids comparing None with an
            # int, which raises TypeError on Python 3.
            if virtual_batch and virtual_batch >= batch_size:
                self.virtual_batch = virtual_batch
            else:
                self.virtual_batch = batch_size
            # how many real batches fit in one virtual batch
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP: enabled only when both `load_amp` and the
            # 'use_amp' option are truthy
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

            # Configure FreezeD
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                if loc:
                    # fall back to an empty string so the substring checks
                    # below cannot raise when 'which_model_D' is not set
                    disc = opt["network_D"].get('which_model_D') or ''
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    # TODO: TMP, for now only tested with the vgg-like or
                    # patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

            # create logs dictionaries
            self.log_dict = OrderedDict()
            self.log_dict_A = OrderedDict()
            self.log_dict_B = OrderedDict()

        self.print_network(
            verbose=False)  # TODO: pass verbose flag from config file