Example #1
    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        if self.isTrain:
            self.model_names = ['G', 'D', 'SC']
        else:  # during test time, only load G and SC
            self.model_names = ['G', 'SC']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids)

        self.netSC = networks.define_C(opt.norm, opt.init_type, opt.init_gain,
                                       self.gpu_ids)
        self.aux_data = aux_dataset.AuxAttnDataset(7000,
                                                   7000,
                                                   self.gpu_ids[0],
                                                   mask_size=32)
        self.zero_attn_holder = torch.zeros(
            (1, 1, opt.mask_size, opt.mask_size),
            dtype=torch.float32).to(self.device)
        self.ones_attn_holder = torch.ones(
            (1, 1, opt.mask_size, opt.mask_size),
            dtype=torch.float32).to(self.device)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images, so the number of channels for D is input_nc + output_nc
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.netD, opt.n_layers_D,
                                          opt.norm, opt.init_type,
                                          opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG.parameters(), self.netSC.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
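
As a side note, here is a minimal sketch (placeholder module names, not part of the example above) of the pattern behind optimizer_G: itertools.chain exposes the parameters of several networks as one iterable, so a single Adam optimizer updates both the generator and the small attention network in the same step.

import itertools

import torch

# Stand-ins for netG and netSC from the example above.
net_g = torch.nn.Linear(8, 8)
net_sc = torch.nn.Linear(8, 8)

# One optimizer over the chained parameter iterators, mirroring optimizer_G.
optimizer_g = torch.optim.Adam(
    itertools.chain(net_g.parameters(), net_sc.parameters()),
    lr=2e-4, betas=(0.5, 0.999))

loss = net_sc(net_g(torch.randn(4, 8))).abs().mean()
optimizer_g.zero_grad()
loss.backward()
optimizer_g.step()  # a single step updates the parameters of both modules
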
Example #2
    def __init__(self, opt):
        """Initialize the CycleGAN-based class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B = G_A(B) and idt_A = G_B(A)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        if opt.concat != 'alpha':
            self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        else:
            self.netG_A = networks.define_G(4, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netG_B = networks.define_G(4, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        self.aux_data = aux_dataset.AuxAttnDataset(3000, 3000, self.gpu_ids[0], mask_size=opt.mask_size)
        self.zero_attn_holder = torch.zeros((1, 1, opt.mask_size, opt.mask_size), dtype=torch.float32).to(self.device)
        self.ones_attn_holder = torch.ones((1, 1, opt.mask_size, opt.mask_size), dtype=torch.float32).to(self.device)
        self.concat = opt.concat

        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids,
                                            opt.mask_size, opt.s1, opt.s2)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids,
                                            opt.mask_size, opt.s1, opt.s2)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                # assert(opt.input_nc == opt.output_nc)
                pass
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
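
The fake_A_pool and fake_B_pool buffers above are instances of the repository's ImagePool utility. As a simplified sketch of the idea (this is not the actual util.image_pool code): the pool stores up to pool_size previously generated images and, once full, sometimes returns a stored image in place of the incoming one, so the discriminator also trains on older generator outputs.

import random

import torch

class SimpleImagePool:
    """Simplified stand-in for the ImagePool buffer used above (illustration only)."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        # Until the pool is full, store the incoming image and return it unchanged.
        if len(self.images) < self.pool_size:
            self.images.append(image.detach().clone())
            return image
        # Once full, with probability 0.5 swap the incoming image for a stored one.
        if random.random() < 0.5:
            idx = random.randrange(self.pool_size)
            stored = self.images[idx]
            self.images[idx] = image.detach().clone()
            return stored
        return image

pool = SimpleImagePool(pool_size=50)
fake_B = pool.query(torch.randn(1, 3, 256, 256))  # image that would be fed to the discriminator
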
Example #3
    def __init__(self, opt):
        # Initialize the network structure, parameter groups
        BaseModel.__init__(self, opt)
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
        ]

        # Intermediate result that will be visualized -- change manually
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        # visual_names_A = ['real_A', 'fake_B', 'rec_A', 'vis_A2B']
        # visual_names_B = ['real_B', 'fake_A', 'rec_B', 'vis_B2A']

        if self.isTrain and self.opt.lambda_identity > 0.0:  # Disabled in our framework
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B

        # All networks (generators, discriminators, and attention nets) are saved/loaded in both training and test time.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'S_CA', 'S_CB']

        # define networks (both Generators and discriminators)
        # Initialize the small attention transformation network
        self.netS_CA = networks.define_C(opt.norm, opt.init_type,
                                         opt.init_gain, self.gpu_ids)
        self.netS_CB = networks.define_C(opt.norm, opt.init_type,
                                         opt.init_gain, self.gpu_ids)

        # Define the structure of the generator based on the concatenation type
        if opt.concat != 'alpha':
            self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                            opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                            opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
        else:
            # Extra Channel
            self.netG_A = networks.define_G(opt.input_nc + 1, opt.output_nc,
                                            opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netG_B = networks.define_G(opt.input_nc + 1, opt.output_nc,
                                            opt.ngf, opt.netG, opt.norm,
                                            not opt.no_dropout, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

        # Auxiliary attention holder
        self.aux_data = aux_dataset.AuxAttnDataset(7000,
                                                   7000,
                                                   self.gpu_ids[0],
                                                   mask_size=32)
        self.zero_attn_holder = torch.zeros(
            (1, 1, opt.mask_size, opt.mask_size),
            dtype=torch.float32).to(self.device)
        self.ones_attn_holder = torch.ones(
            (1, 1, opt.mask_size, opt.mask_size),
            dtype=torch.float32).to(self.device)

        self.concat = opt.concat

        # Visualization purpose only
        self.vis_A2B = torch.zeros((1, 1, 256, 256), dtype=torch.float32).to(self.device)
        self.vis_B2A = torch.zeros((1, 1, 256, 256), dtype=torch.float32).to(self.device)

        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids, opt.mask_size, opt.s1,
                                        opt.s2)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids, opt.mask_size, opt.s1,
                                        opt.s2)

        if self.isTrain:
            # Initialize components that will only be used in training phase

            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                pass
            self.fake_A_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters(),
                self.netS_CA.parameters(), self.netS_CB.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
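
Finally, a hedged illustration of what criterionGAN computes when opt.gan_mode is 'lsgan' (the repository's networks.GANLoss handles several modes; this standalone class only shows the least-squares case): the discriminator prediction is compared against a target map of ones for real images and zeros for fakes.

import torch

class LSGANCriterionSketch(torch.nn.Module):
    """Illustrative least-squares GAN loss; not the repository's GANLoss class."""

    def __init__(self):
        super().__init__()
        self.loss = torch.nn.MSELoss()

    def forward(self, prediction, target_is_real):
        # Target tensor of ones (real) or zeros (fake), matching the prediction shape.
        if target_is_real:
            target = torch.ones_like(prediction)
        else:
            target = torch.zeros_like(prediction)
        return self.loss(prediction, target)

criterion = LSGANCriterionSketch()
d_out = torch.randn(2, 1, 30, 30)       # e.g. a PatchGAN prediction map
loss_real = criterion(d_out, True)      # push D(real) toward 1
loss_fake = criterion(d_out, False)     # push D(fake) toward 0
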