def __init__(self, opt, rank):
        """Add the semantic network ``f_s`` and its losses to the parent model.

        Declares the ``sem`` generator loss (plus ``out_mask`` when
        ``opt.out_mask`` is set) and the ``f_s`` loss. In training mode it also
        builds ``netf_s``, its loss criteria, its Adam optimizer and the
        NetworkGroup that optimizes it.

        Parameters:
            opt -- experiment options, read by attribute.
            rank -- forwarded unchanged to the parent constructor.

        Raises:
            ValueError -- in training mode, when ``opt.out_mask`` is set and
                ``opt.loss_out_mask`` is not 'L1', 'MSE' or 'Charbonnier'.
        """
        super().__init__(opt, rank)

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>
        losses_G = ['sem']
        if opt.out_mask:
            losses_G += ['out_mask']

        losses_f_s = ['f_s']

        self.loss_names_G += losses_G
        self.loss_names_f_s = losses_f_s

        self.loss_names = self.loss_names_G + self.loss_names_D + self.loss_names_f_s

        # define the semantic network f_s (training only)
        if self.isTrain:
            self.netf_s = networks.define_f(opt.input_nc,
                                            nclasses=opt.semantic_nclasses,
                                            init_type=opt.init_type,
                                            init_gain=opt.init_gain,
                                            gpu_ids=self.gpu_ids,
                                            fs_light=opt.fs_light)

            self.model_names += ['f_s']

            # define loss functions
            self.criterionf_s = torch.nn.modules.CrossEntropyLoss()

            if opt.out_mask:
                if opt.loss_out_mask == 'L1':
                    self.criterionMask = torch.nn.L1Loss()
                elif opt.loss_out_mask == 'MSE':
                    self.criterionMask = torch.nn.MSELoss()
                elif opt.loss_out_mask == 'Charbonnier':
                    self.criterionMask = L1_Charbonnier_loss(
                        opt.charbonnier_eps)
                else:
                    # Fail fast: otherwise criterionMask would stay undefined
                    # and the mask loss would presumably fail much later.
                    raise ValueError(
                        "Unknown loss_out_mask %r, expected one of "
                        "'L1', 'MSE', 'Charbonnier'" % (opt.loss_out_mask,))

            self.optimizer_f_s = torch.optim.Adam(self.netf_s.parameters(),
                                                  lr=opt.lr_f_s,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_f_s)

            # With iter_size > 1, hand the loss names to IterCalculator, then
            # rename each entry in place to '<name>_avg' and initialise the
            # matching loss_<name>_avg attribute to 0.
            if self.opt.iter_size > 1:
                self.iter_calculator = IterCalculator(self.loss_names)
                for i, cur_loss in enumerate(self.loss_names):
                    self.loss_names[i] = cur_loss + '_avg'
                    setattr(self, "loss_" + self.loss_names[i], 0)

            ###Making groups
            # f_s is optimized by optimizer_f_s on the loss produced by
            # compute_f_s_loss (stored as loss_f_s).
            self.group_f_s = NetworkGroup(
                networks_to_optimize=["f_s"],
                forward_functions=None,
                backward_functions=["compute_f_s_loss"],
                loss_names_list=["loss_names_f_s"],
                optimizer=["optimizer_f_s"],
                loss_backward=["loss_f_s"])
            self.networks_groups.append(self.group_f_s)
Example #2
0
    def __init__(self, opt):
        """Add the classifier network ``CLS`` and its losses to the parent model.

        Declares the ``sem`` generator loss and the ``CLS`` loss. In training
        mode it also builds ``netCLS``, registers it in ``model_names``
        (the list of models saved/loaded by the base model), selects its
        criterion (L1 / MSE for regression, cross-entropy otherwise), builds
        its Adam optimizer and the NetworkGroup that optimizes it.

        Parameters:
            opt -- experiment options, read by attribute.
        """
        super().__init__(opt)

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>

        losses_G = ['sem']

        losses_CLS = ['CLS']

        self.loss_names_G += losses_G
        self.loss_names_CLS = losses_CLS
        self.loss_names = self.loss_names_G + self.loss_names_CLS + self.loss_names_D

        # define the classifier network (training only)
        if self.isTrain:
            self.netCLS = networks.define_C(opt.output_nc,
                                            opt.ndf,
                                            opt.crop_size,
                                            init_type=opt.init_type,
                                            init_gain=opt.init_gain,
                                            gpu_ids=self.gpu_ids,
                                            nclasses=opt.semantic_nclasses)
            # Register CLS like the other auxiliary networks so it is saved
            # and loaded together with the rest of the model.
            self.model_names += ['CLS']

            # define loss function: regression (L1 or MSE) or classification
            if opt.regression:
                if opt.l1_regression:
                    self.criterionCLS = torch.nn.L1Loss()
                else:
                    self.criterionCLS = torch.nn.modules.MSELoss()
            else:
                self.criterionCLS = torch.nn.modules.CrossEntropyLoss()

            self.optimizer_CLS = torch.optim.Adam(self.netCLS.parameters(),
                                                  lr=opt.lr_f_s,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_CLS)

            # With iter_size > 1, hand the loss names to IterCalculator, then
            # rename each entry in place to '<name>_avg' and initialise the
            # matching loss_<name>_avg attribute to 0.
            if self.opt.iter_size > 1:
                self.iter_calculator = IterCalculator(self.loss_names)
                for i, cur_loss in enumerate(self.loss_names):
                    self.loss_names[i] = cur_loss + '_avg'
                    setattr(self, "loss_" + self.loss_names[i], 0)

            self.niter = 0

            ###Making groups
            # CLS is optimized by optimizer_CLS on the loss produced by
            # compute_CLS_loss (stored as loss_CLS).
            self.group_CLS = NetworkGroup(
                networks_to_optimize=["CLS"],
                forward_functions=None,
                backward_functions=["compute_CLS_loss"],
                loss_names_list=["loss_names_CLS"],
                optimizer=["optimizer_CLS"],
                loss_backward=["loss_CLS"])
            self.networks_groups.append(self.group_CLS)
Example #3
0
    def __init__(self, opt):
        """Add a classifier network ``CLS`` on top of the parent model.

        Declares the semantic losses for both translation directions
        (``sem_AB``, ``sem_BA``) and the ``CLS`` loss. Everything else is
        training-only: registering ``CLS`` in ``model_names``, building
        ``netCLS``, picking its criterion, building its Adam optimizer and the
        NetworkGroup that optimizes it.

        Parameters:
            opt -- experiment options, read by attribute.
        """
        super().__init__(opt)

        # losses reported through base_model.get_current_losses
        self.loss_names_G += ['sem_AB', 'sem_BA']
        self.loss_names_CLS = ['CLS']
        self.loss_names = self.loss_names_G + self.loss_names_D + self.loss_names_CLS

        if not self.isTrain:
            return

        # models saved/loaded by base_model.save_networks / load_networks
        self.model_names += ['CLS']

        self.netCLS = networks.define_C(opt.output_nc, opt.ndf, opt.crop_size,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        nclasses=opt.semantic_nclasses,
                                        template=opt.cls_template,
                                        pretrained=opt.cls_pretrained)

        # classification by default; L1 or MSE when regressing
        if not opt.regression:
            self.criterionCLS = torch.nn.modules.CrossEntropyLoss()
        elif opt.l1_regression:
            self.criterionCLS = torch.nn.L1Loss()
        else:
            self.criterionCLS = torch.nn.modules.MSELoss()

        self.optimizer_CLS = torch.optim.Adam(self.netCLS.parameters(),
                                              lr=opt.lr_f_s,
                                              betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_CLS)

        self.rec_noise = opt.rec_noise

        # iter_size > 1: give the names to IterCalculator, then switch each
        # entry (in place) to its '_avg' form with a zeroed loss attribute.
        if self.opt.iter_size > 1:
            self.iter_calculator = IterCalculator(self.loss_names)
            for position, name in enumerate(self.loss_names):
                averaged = name + '_avg'
                self.loss_names[position] = averaged
                setattr(self, "loss_" + averaged, 0)

        self.group_CLS = NetworkGroup(
            networks_to_optimize=["CLS"],
            forward_functions=None,
            backward_functions=["compute_CLS_loss"],
            loss_names_list=["loss_names_CLS"],
            optimizer=["optimizer_CLS"],
            loss_backward=["loss_CLS"],
        )
        self.networks_groups.append(self.group_CLS)
Example #4
0
    def __init__(self, opt):
        """Attach a projection network ``P_B`` to the parent model.

        ``netP_B`` is built with ``(opt.nuplet_size - 1) * opt.input_nc`` input
        channels and ``opt.output_nc`` output channels. It is optimized by its
        own Adam optimizer (learning rate ``opt.P_lr``) and is trained either
        in a separate NetworkGroup (when ``opt.no_train_P_fake_images`` is
        set) or jointly with G, by replacing the first NetworkGroup.

        Parameters:
            opt -- experiment options, read by attribute.
        """
        super().__init__(opt)

        with_adversarial_P = self.opt.adversarial_loss_p

        # generator-side losses: optional adversarial projection loss, then recut
        self.loss_names_G += (
            ["proj_fake_B_adversarial"] if with_adversarial_P else []
        ) + ["recut"]
        # losses of the projection network P
        self.loss_names_P = ["proj_real_B"] + (
            ["proj_real_A_adversarial", "proj_real_B_adversarial"]
            if with_adversarial_P else []
        )

        self.loss_names = (self.loss_names_G + self.loss_names_f_s +
                           self.loss_names_D + self.loss_names_P)

        # iter_size > 1: give the names to IterCalculator, then switch each
        # entry (in place) to its '_avg' form with a zeroed loss attribute.
        if self.opt.iter_size > 1:
            self.iter_calculator = IterCalculator(self.loss_names)
            for position, name in enumerate(self.loss_names):
                avg_name = name + '_avg'
                self.loss_names[position] = avg_name
                setattr(self, "loss_" + avg_name, 0)

        self.visual_names += [["real_A_last", "proj_fake_B"],
                              ["real_B_last", "proj_real_B"]]

        P_input_nc = (self.opt.nuplet_size - 1) * opt.input_nc
        self.netP_B = networks.define_G(P_input_nc, opt.output_nc, opt.ngf,
                                        opt.netP, opt.norm, not opt.no_dropout,
                                        opt.G_spectral, opt.init_type,
                                        opt.init_gain, self.gpu_ids,
                                        padding_type=opt.G_padding_type,
                                        opt=self.opt)
        self.model_names += ["P_B"]

        self.optimizer_P = torch.optim.Adam(self.netP_B.parameters(),
                                            lr=opt.P_lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_P)

        if not self.opt.no_train_P_fake_images:
            # P and G networks are trained at the same time: the joint group
            # replaces the first entry of networks_groups.
            self.group_G = NetworkGroup(
                networks_to_optimize=["G", "P_B"],
                forward_functions=["forward", "forward_P"],
                backward_functions=["compute_G_loss", "compute_P_loss"],
                loss_names_list=["loss_names_G", "loss_names_P"],
                optimizer=["optimizer_G", "optimizer_P"],
                loss_backward=["loss_G", "loss_P"],
            )
            self.networks_groups[0] = self.group_G
        else:
            # P is trained in its own group, inserted at index 1
            self.group_P = NetworkGroup(
                networks_to_optimize=["P_B"],
                forward_functions=["forward_P"],
                backward_functions=["compute_P_loss"],
                loss_names_list=["loss_names_P"],
                optimizer=["optimizer_P"],
                loss_backward=["loss_P"],
            )
            self.networks_groups.insert(1, self.group_P)

        self.criterionCycle = torch.nn.L1Loss()
Example #5
0
    def __init__(self, opt, rank):
        """Build a contrastive (PatchNCE) image-translation model.

        Always builds the generator ``netG`` and the feature network ``netF``.
        In training mode it also builds:

        - the discriminator ``netD``, plus ``netD_global`` unless
          ``opt.netD_global == "none"``;
        - the loss criteria (one PatchNCE criterion per entry of
          ``opt.nce_layers``);
        - the Adam optimizers;
        - the G and D NetworkGroups;
        - the discriminator loss helper ``D_loss``.

        Parameters:
            opt -- experiment options, read by attribute.
            rank -- forwarded unchanged to BaseModel.__init__.
        """
        BaseModel.__init__(self, opt, rank)

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>
        losses_G = ['G_GAN', 'G', 'NCE']
        losses_D = ['D_tot', 'D']
        if opt.nce_idt and self.isTrain:
            losses_G += ['NCE_Y']
        if opt.netD_global != "none":
            losses_D += ['D_global']
            losses_G += ['G_GAN_global']

        self.loss_names_G = losses_G
        self.loss_names_D = losses_D

        self.loss_names = self.loss_names_G + self.loss_names_D

        visual_names_A = ['real_A', 'fake_B']
        visual_names_B = ['real_B']

        # opt.nce_layers is a comma-separated list of integer layer indices
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if opt.nce_idt and self.isTrain:
            visual_names_B += ['idt_B']
        self.visual_names += [visual_names_A, visual_names_B]

        if self.opt.diff_aug_policy != '':
            self.visual_names.append(['fake_B_aug'])
            self.visual_names.append(['real_B_aug'])

        if self.isTrain:
            self.model_names = ['G', 'F', 'D']
            if opt.netD_global != "none":
                self.model_names += ['D_global']

        else:  # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.norm,
                                      not opt.no_dropout,
                                      opt.G_spectral,
                                      opt.init_type,
                                      opt.init_gain,
                                      self.gpu_ids,
                                      opt=self.opt)
        self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG,
                                      not opt.no_dropout, opt.init_type,
                                      opt.init_gain, opt.no_antialias,
                                      self.gpu_ids, opt)
        self.netF.set_device(self.device)
        if self.isTrain:
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm,
                                          opt.D_dropout, opt.D_spectral,
                                          opt.init_type, opt.init_gain,
                                          opt.no_antialias, self.gpu_ids, opt)
            if opt.netD_global != "none":
                self.netD_global = networks.define_D(
                    opt.output_nc, opt.ndf, opt.netD_global, opt.n_layers_D,
                    opt.norm, opt.D_dropout, opt.D_spectral, opt.init_type,
                    opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

            # define loss functions
            self.criterionGAN = loss.GANLoss(opt.gan_mode).to(self.device)
            # one PatchNCE criterion per layer listed in opt.nce_layers
            self.criterionNCE = [
                PatchNCELoss(opt).to(self.device) for _ in self.nce_layers
            ]

            self.criterionIdt = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, opt.beta2))
            if opt.netD_global == "none":
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1,
                                                           opt.beta2))
            else:
                self.optimizer_D = torch.optim.Adam(
                    itertools.chain(self.netD.parameters(),
                                    self.netD_global.parameters()),
                    lr=opt.lr,
                    betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            # With iter_size > 1, hand the loss names to IterCalculator, then
            # rename each entry in place to '<name>_avg' and initialise the
            # matching loss_<name>_avg attribute to 0.
            if self.opt.iter_size > 1:
                self.iter_calculator = IterCalculator(self.loss_names)
                for i, cur_loss in enumerate(self.loss_names):
                    self.loss_names[i] = cur_loss + '_avg'
                    setattr(self, "loss_" + self.loss_names[i], 0)

            # without a global discriminator its losses are constant zeros
            if opt.netD_global == "none":
                self.loss_D_global = 0
                self.loss_G_GAN_global = 0

            ###Making groups
            if opt.netD_global != "none":
                self.D_global_loss = loss.DiscriminatorGANLoss(
                    opt, self.netD_global, self.device, gan_mode="lsgan")
            self.networks_groups = []

            self.group_G = NetworkGroup(networks_to_optimize=["G", "F"],
                                        forward_functions=["forward"],
                                        backward_functions=["compute_G_loss"],
                                        loss_names_list=["loss_names_G"],
                                        optimizer=["optimizer_G"],
                                        loss_backward=["loss_G"],
                                        networks_to_ema=["G"])
            self.networks_groups.append(self.group_G)

            D_to_optimize = ["D"]
            if opt.netD_global != "none":
                D_to_optimize.append("D_global")
            self.group_D = NetworkGroup(networks_to_optimize=D_to_optimize,
                                        forward_functions=None,
                                        backward_functions=["compute_D_loss"],
                                        loss_names_list=["loss_names_D"],
                                        optimizer=["optimizer_D"],
                                        loss_backward=["loss_D_tot"])
            self.networks_groups.append(self.group_D)

            # Discriminator loss helper. It wraps self.netD, which only exists
            # in training mode, so it has to be built inside this branch.
            if self.opt.use_contrastive_loss_D:
                self.D_loss = loss.DiscriminatorContrastiveLoss(
                    opt, self.netD, self.device)
            else:
                self.D_loss = loss.DiscriminatorGANLoss(opt, self.netD,
                                                        self.device)

            self.objects_to_update.append(self.D_loss)
Example #6
0
    def __init__(self, opt, rank):
        """Initialize the CycleGAN class.

        Always builds the generators G_A and G_B. In training mode it also
        builds:

        - the discriminators D_A and D_B, plus D_A_global / D_B_global unless
          ``opt.netD_global == "none"``;
        - the cycle and identity L1 criteria;
        - the Adam optimizers;
        - the discriminator loss helper ``D_loss``;
        - the G and D NetworkGroups.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
            rank -- forwarded unchanged to the parent constructor
        """
        super().__init__(opt, rank)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>

        losses_G = ['G_A', 'G_B']

        losses_G += ['cycle_A', 'idt_A',
                     'cycle_B', 'idt_B']

        losses_D = ['D_A', 'D_B']

        if opt.netD_global != "none":
            losses_D += ['D_A_global', 'D_B_global']
            losses_G += ['G_A_global', 'G_B_global']

        self.loss_names_G = losses_G
        self.loss_names_D = losses_D

        self.loss_names = self.loss_names_G + self.loss_names_D

        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize the identity outputs idt_A and idt_B
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names += [visual_names_A, visual_names_B]  # combine visualizations for A and B

        if self.opt.diff_aug_policy != '':
            self.visual_names.append(['real_A_aug', 'fake_B_aug'])
            self.visual_names.append(['real_B_aug', 'fake_A_aug'])

        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
            if opt.netD_global != "none":
                self.model_names += ['D_A_global', 'D_B_global']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.G_spectral, opt.init_type, opt.init_gain, self.gpu_ids,
                                        padding_type=opt.G_padding_type, opt=self.opt)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.G_spectral, opt.init_type, opt.init_gain, self.gpu_ids,
                                        padding_type=opt.G_padding_type, opt=self.opt)

        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.D_dropout, opt.D_spectral,
                                            opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, self.opt)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.D_dropout, opt.D_spectral,
                                            opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, self.opt)

            if opt.netD_global != "none":
                self.netD_A_global = networks.define_D(opt.output_nc, opt.ndf, opt.netD_global,
                                                       opt.n_layers_D, opt.norm, opt.D_dropout, opt.D_spectral,
                                                       opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, self.opt)
                self.netD_B_global = networks.define_D(opt.input_nc, opt.ndf, opt.netD_global,
                                                       opt.n_layers_D, opt.norm, opt.D_dropout, opt.D_spectral,
                                                       opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, self.opt)

            # disabled losses are reported as constant zeros
            if self.opt.lambda_identity == 0.0:
                self.loss_idt_A = 0
                self.loss_idt_B = 0
            if opt.netD_global == "none":
                self.loss_D_A_global = 0
                self.loss_D_B_global = 0
                self.loss_G_A_global = 0
                self.loss_G_B_global = 0

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            # define loss functions
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.netD_global == "none":
                self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:
                self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters(),
                                                                    self.netD_A_global.parameters(),
                                                                    self.netD_B_global.parameters()),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            # With iter_size > 1, hand the loss names to IterCalculator, then
            # rename each entry in place to '<name>_avg' and initialise the
            # matching loss_<name>_avg attribute to 0.
            if self.opt.iter_size > 1:
                self.iter_calculator = IterCalculator(self.loss_names)
                for i, cur_loss in enumerate(self.loss_names):
                    self.loss_names[i] = cur_loss + '_avg'
                    setattr(self, "loss_" + self.loss_names[i], 0)

            self.rec_noise = opt.rec_noise

            # discriminator loss helper built on netD_A
            if self.opt.use_contrastive_loss_D:
                self.D_loss = loss.DiscriminatorContrastiveLoss(opt, self.netD_A, self.device)
            else:
                self.D_loss = loss.DiscriminatorGANLoss(opt, self.netD_A, self.device)

            self.objects_to_update.append(self.D_loss)

            ###Making groups
            self.networks_groups = []
            if opt.netD_global != "none":
                self.D_global_loss = loss.DiscriminatorGANLoss(opt, self.netD_A_global, self.device)

            self.group_G = NetworkGroup(networks_to_optimize=["G_A", "G_B"],
                                        forward_functions=["forward"],
                                        backward_functions=["compute_G_loss"],
                                        loss_names_list=["loss_names_G"],
                                        optimizer=["optimizer_G"],
                                        loss_backward=["loss_G"],
                                        networks_to_ema=["G_A", "G_B"])
            self.networks_groups.append(self.group_G)

            # The global discriminators' parameters are part of optimizer_D,
            # so list them in the D group as well.
            # NOTE(review): assumes NetworkGroup uses networks_to_optimize to
            # decide which networks are trained in this group — confirm.
            D_to_optimize = ["D_A", "D_B"]
            if opt.netD_global != "none":
                D_to_optimize += ["D_A_global", "D_B_global"]
            self.group_D = NetworkGroup(networks_to_optimize=D_to_optimize,
                                        forward_functions=None,
                                        backward_functions=["compute_D_loss"],
                                        loss_names_list=["loss_names_D"],
                                        optimizer=["optimizer_D"],
                                        loss_backward=["loss_D"])
            self.networks_groups.append(self.group_D)
Example #7
0
    def __init__(self, opt, rank):
        """Add semantic (f_s) and mask-related components to the parent model.

        Declares the semantic losses for both directions, the optional
        output-mask losses, the ``f_s`` loss and the optional mask
        discriminator losses. It always builds ``netf_s``. In training mode it
        also builds:

        - the optional mask discriminators ``netD_A_mask`` / ``netD_B_mask``;
        - the mask image pools and the loss criteria;
        - the G, D and f_s optimizers (Adam, or MADGRAD when ``opt.madgrad``);
        - the f_s NetworkGroup.

        Parameters:
            opt -- experiment options, read by attribute. ``disc_in_mask``,
                ``out_mask`` and ``fs_light`` default to False when absent.
            rank -- forwarded unchanged to the parent constructor.

        Raises:
            ValueError -- in training mode, when ``out_mask`` or
                ``disc_in_mask`` is set and ``opt.loss_out_mask`` is not
                'L1', 'MSE' or 'Charbonnier'.
        """
        super().__init__(opt, rank)
        # default these flags to False when opt does not define them
        if not hasattr(opt, 'disc_in_mask'):
            opt.disc_in_mask = False
        if not hasattr(opt, 'out_mask'):
            opt.out_mask = False
        if not hasattr(opt, 'fs_light'):
            opt.fs_light = False

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        losses_G = ['sem_AB', 'sem_BA']

        if opt.out_mask:
            losses_G += ['out_mask_AB', 'out_mask_BA']

        losses_f_s = ['f_s']

        losses_D = []
        if opt.disc_in_mask:
            losses_D = ['D_A_mask', 'D_B_mask']

        self.loss_names_G += losses_G
        self.loss_names_f_s = losses_f_s
        self.loss_names_D += losses_D

        self.loss_names = self.loss_names_G + self.loss_names_f_s + self.loss_names_D

        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names += ['f_s']
            if opt.disc_in_mask:
                self.model_names += ['D_A_mask', 'D_B_mask']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        if self.isTrain:
            if opt.disc_in_mask:
                # NOTE(review): these define_D calls pass 10 positional
                # arguments, while other define_D calls in this codebase pass
                # (..., init_gain, no_antialias, gpu_ids, opt). Confirm the
                # signature — self.gpu_ids may be landing in the no_antialias
                # slot.
                self.netD_A_mask = networks.define_D(
                    opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                    opt.D_dropout, opt.D_spectral, opt.init_type,
                    opt.init_gain, self.gpu_ids)
                self.netD_B_mask = networks.define_D(
                    opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                    opt.D_dropout, opt.D_spectral, opt.init_type,
                    opt.init_gain, self.gpu_ids)

        self.netf_s = networks.define_f(opt.input_nc,
                                        nclasses=opt.semantic_nclasses,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        fs_light=opt.fs_light)

        if self.isTrain:
            self.fake_A_pool_mask = ImagePool(opt.pool_size)
            self.fake_B_pool_mask = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionf_s = torch.nn.modules.CrossEntropyLoss()

            if opt.out_mask or opt.disc_in_mask:
                if opt.loss_out_mask == 'L1':
                    self.criterionMask = torch.nn.L1Loss()
                elif opt.loss_out_mask == 'MSE':
                    self.criterionMask = torch.nn.MSELoss()
                elif opt.loss_out_mask == 'Charbonnier':
                    self.criterionMask = L1_Charbonnier_loss(
                        opt.charbonnier_eps)
                else:
                    # Fail fast: otherwise criterionMask would stay undefined
                    # and the mask loss would presumably fail much later.
                    raise ValueError(
                        "Unknown loss_out_mask %r, expected one of "
                        "'L1', 'MSE', 'Charbonnier'" % (opt.loss_out_mask,))

            # initialize optimizers (Adam by default, MADGRAD when opt.madgrad)
            G_nets = [self.netG_A, self.netG_B]
            D_nets = [self.netD_A, self.netD_B]
            if opt.disc_in_mask:
                D_nets += [self.netD_A_mask, self.netD_B_mask]

            def chain_params(nets):
                # Concatenate the parameters of several networks, in order.
                return itertools.chain.from_iterable(
                    net.parameters() for net in nets)

            # NOTE(review): optimizer_G and optimizer_D are (re)built here, but
            # only optimizer_f_s is appended to self.optimizers below. If the
            # parent constructor appended its own optimizer_G / optimizer_D,
            # self.optimizers still holds those objects. Confirm that whatever
            # iterates self.optimizers (e.g. LR schedulers) sees the intended
            # ones.
            if opt.madgrad:
                self.optimizer_G = MADGRAD(chain_params(G_nets), lr=opt.lr)
                self.optimizer_D = MADGRAD(chain_params(D_nets), lr=opt.D_lr)
                self.optimizer_f_s = MADGRAD(self.netf_s.parameters(),
                                             lr=opt.lr_f_s)
            else:
                betas = (opt.beta1, 0.999)
                self.optimizer_G = torch.optim.Adam(chain_params(G_nets),
                                                    lr=opt.lr,
                                                    betas=betas)
                self.optimizer_D = torch.optim.Adam(chain_params(D_nets),
                                                    lr=opt.D_lr,
                                                    betas=betas)
                self.optimizer_f_s = torch.optim.Adam(self.netf_s.parameters(),
                                                      lr=opt.lr_f_s,
                                                      betas=betas)

            self.optimizers.append(self.optimizer_f_s)

            # With iter_size > 1, hand the loss names to IterCalculator, then
            # rename each entry in place to '<name>_avg' and initialise the
            # matching loss_<name>_avg attribute to 0.
            if self.opt.iter_size > 1:
                self.iter_calculator = IterCalculator(self.loss_names)
                for i, cur_loss in enumerate(self.loss_names):
                    self.loss_names[i] = cur_loss + '_avg'
                    setattr(self, "loss_" + self.loss_names[i], 0)

            ###Making groups
            # f_s is optimized by optimizer_f_s on the loss produced by
            # compute_f_s_loss (stored as loss_f_s).
            self.group_f_s = NetworkGroup(
                networks_to_optimize=["f_s"],
                forward_functions=None,
                backward_functions=["compute_f_s_loss"],
                loss_names_list=["loss_names_f_s"],
                optimizer=["optimizer_f_s"],
                loss_backward=["loss_f_s"])
            self.networks_groups.append(self.group_f_s)