def __init__(self, opt):
    """Initialize the pix2pix class.

    Parameters:
        opt (Option class) -- stores all the experiment flags; needs to be a
            subclass of BaseOptions
    """
    BaseModel.__init__(self, opt)
    # specify the training losses you want to print out. The training/test
    # scripts will call <BaseModel.get_current_losses>
    self.loss_names = ['G', 'G_GAN', 'G_L1', 'D_real', 'D_fake']
    # specify the images you want to save/display. The training/test scripts
    # will call <BaseModel.get_current_visuals>
    self.visual_names = ['cloth_decoded', 'fakes_scaled', 'textures_unnormalized']
    # specify the models you want to save to the disk. The training/test
    # scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
    if self.is_train:
        self.model_names = ['G', 'D']
    else:
        # during test time, only load G
        self.model_names = ['G']
    # define the generator. The "+ 36" matches the extra conditioning
    # channels concatenated onto the cloth input (TODO confirm source of 36).
    self.net_G = define_G(
        opt.cloth_channels + 36, opt.texture_channels, 64, "unet_128",
        opt.norm, True, opt.init_type, opt.init_gain,
    ).to(self.device)

    # Train-time-only components: discriminator, losses, optimizers.
    # (Original code had two consecutive `if self.is_train:` blocks; merged.)
    if self.is_train:
        # conditional GANs need D to see both input and output images, so
        # its channel count is input_nc + output_nc
        self.net_D = define_D(
            opt.cloth_channels + 36 + opt.texture_channels, 64,
            opt.discriminator, opt.n_layers_D, opt.norm,
            opt.init_type, opt.init_gain,
        ).to(self.device)
        # define loss functions
        use_smooth = opt.gan_label_mode == "smooth"
        self.criterionGAN = GANLoss(
            opt.gan_mode, smooth_labels=use_smooth
        ).to(self.device)
        self.criterionL1 = torch.nn.L1Loss()
        # initialize optimizers; schedulers will be automatically created by
        # function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(
            self.net_G.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
        )
        self.optimizer_D = torch.optim.Adam(
            self.net_D.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
        )
def __init__(self, opt):
    """Set up the generator, discriminator, and optimizers.

    Sets self.net_generator to the return value of self.define_G().

    Args:
        opt: experiment options/flags object
    """
    super().__init__(opt)
    # generator is always needed (train and test)
    self.net_generator = self.define_G().to(self.device)
    modules.init_weights(self.net_generator, opt.init_type, opt.init_gain)
    self.model_names = ["generator"]

    if self.is_train:
        # setup discriminator
        self.net_discriminator = discriminators.define_D(
            self.get_D_inchannels(), 64, opt.discriminator, opt.n_layers_D, opt.norm
        ).to(self.device)
        modules.init_weights(self.net_discriminator, opt.init_type, opt.init_gain)
        # load discriminator only at train time
        self.model_names.append("discriminator")

        # setup GAN loss
        use_smooth = opt.gan_label_mode == "smooth"  # idiomatic boolean (was `True if ... else False`)
        self.criterion_GAN = modules.loss.GANLoss(
            opt.gan_mode, smooth_labels=use_smooth
        ).to(self.device)

        # track discriminator losses only when the discriminator term is weighted
        if opt.lambda_discriminator:
            self.loss_names = ["D", "D_real", "D_fake"]
            # gradient-penalty variants ("gp"/"lp" in gan_mode) add an extra term
            if any(gp_mode in opt.gan_mode for gp_mode in ["gp", "lp"]):
                self.loss_names += ["D_gp"]
        self.loss_names += ["G"]
        if opt.lambda_gan:
            self.loss_names += ["G_gan"]

        # Define optimizers
        self.optimizer_G = optimizers.define_optimizer(
            self.net_generator.parameters(), opt, "G"
        )
        self.optimizer_D = optimizers.define_optimizer(
            self.net_discriminator.parameters(), opt, "D"
        )
        self.optimizer_names = ("G", "D")
def __init__(self, opt):
    """Initialize the WarpModel, either in GAN mode or plain cross-entropy mode.

    Args:
        opt: experiment options/flags object
    """
    # channel counts must be set before the (inlined) base setup below,
    # because define_G() presumably reads them — TODO confirm.
    # 3 for RGB; otherwise use the label-channel count from options.
    self.body_channels = (
        opt.body_channels if opt.body_representation == "labels" else 3
    )
    self.cloth_channels = (
        opt.cloth_channels if opt.cloth_representation == "labels" else 3
    )

    # NOTE(review): everything down to the optimizer setup duplicates the
    # base-class __init__ (see the commented-out BaseGAN.__init__ /
    # super().__init__ in the original). Consider calling
    # super().__init__(opt) here instead of maintaining this copy.
    self.opt = opt
    self.gpu_id = opt.gpu_id
    self.is_train = opt.is_train
    # get device name: CPU or GPU
    self.device = (
        torch.device(f"cuda:{self.gpu_id}")
        if self.gpu_id is not None
        else torch.device("cpu")
    )
    # save all the checkpoints to save_dir
    self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    if self.is_train:
        PromptOnce.makedirs(self.save_dir, not opt.no_confirm)
    self.loss_names = []
    self.model_names = []
    self.visual_names = []
    self.optimizer_names = []
    self.image_paths = []
    self.metric = 0  # used for learning rate policy 'plateau'

    self.net_generator = self.define_G().to(self.device)
    modules.init_weights(self.net_generator, opt.init_type, opt.init_gain)
    self.model_names = ["generator"]

    if self.is_train:
        # setup discriminator
        self.net_discriminator = discriminators.define_D(
            self.get_D_inchannels(), 64, opt.discriminator, opt.n_layers_D, opt.norm
        ).to(self.device)
        modules.init_weights(self.net_discriminator, opt.init_type, opt.init_gain)
        # load discriminator only at train time
        self.model_names.append("discriminator")
        # setup GAN loss
        use_smooth = opt.gan_label_mode == "smooth"
        self.criterion_GAN = modules.loss.GANLoss(
            opt.gan_mode, smooth_labels=use_smooth
        ).to(self.device)
        if opt.lambda_discriminator:
            self.loss_names = ["D", "D_real", "D_fake"]
            if any(gp_mode in opt.gan_mode for gp_mode in ["gp", "lp"]):
                self.loss_names += ["D_gp"]
        self.loss_names += ["G"]
        if opt.lambda_gan:
            self.loss_names += ["G_gan"]
        # Define optimizers
        self.optimizer_G = optimizers.define_optimizer(
            self.net_generator.parameters(), opt, "G"
        )
        self.optimizer_D = optimizers.define_optimizer(
            self.net_discriminator.parameters(), opt, "D"
        )
        self.optimizer_names = ("G", "D")

    # TODO: decode visuals for cloth
    self.visual_names = ["inputs_decoded", "bodys_unnormalized", "fakes_decoded"]
    if self.is_train:
        # only show targets during training
        self.visual_names.append("targets_decoded")

    # we use cross entropy loss in both modes
    self.criterion_CE = nn.CrossEntropyLoss()

    if opt.warp_mode != "gan":
        # remove discriminator-related things if no GAN
        self.model_names = ["generator"]
        # BUG FIX: was `self.loss_names = "G"` (a str, not a list), breaking
        # the list convention used everywhere else.
        self.loss_names = ["G"]
        # BUG FIX: the discriminator/optimizer attributes only exist at train
        # time; unguarded `del` raised AttributeError in eval mode.
        if self.is_train:
            del self.net_discriminator
            del self.optimizer_D
            self.optimizer_names = ["G"]
    else:
        self.loss_names += ["G_ce"]
def __init__(self, opt):
    """Initialize the texture model: generator, discriminator, losses, optimizers.

    Args:
        opt: experiment options/flags object
    """
    # NOTE(review): the block below duplicates the base-class __init__ (the
    # original commented out `super().__init__(opt)`); consider restoring the
    # super call instead of maintaining this copy.
    self.opt = opt
    self.gpu_id = opt.gpu_id
    self.is_train = opt.is_train
    # get device name: CPU or GPU
    self.device = (
        torch.device(f"cuda:{self.gpu_id}")
        if self.gpu_id is not None
        else torch.device("cpu")
    )
    # save all the checkpoints to save_dir
    self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    if self.is_train:
        PromptOnce.makedirs(self.save_dir, not opt.no_confirm)
    self.loss_names = []
    self.model_names = []
    self.visual_names = []
    self.optimizer_names = []
    self.image_paths = []
    self.metric = 0  # used for learning rate policy 'plateau'

    self.net_generator = self.define_G().to(self.device)
    modules.init_weights(self.net_generator, opt.init_type, opt.init_gain)
    self.model_names = ["generator"]

    if self.is_train:
        # setup discriminator
        self.net_discriminator = discriminators.define_D(
            self.get_D_inchannels(), 64, opt.discriminator, opt.n_layers_D, opt.norm
        ).to(self.device)
        modules.init_weights(self.net_discriminator, opt.init_type, opt.init_gain)
        # load discriminator only at train time
        self.model_names.append("discriminator")
        # setup GAN loss
        use_smooth = opt.gan_label_mode == "smooth"  # idiomatic boolean (was `True if ... else False`)
        self.criterion_GAN = modules.loss.GANLoss(
            opt.gan_mode, smooth_labels=use_smooth
        ).to(self.device)
        if opt.lambda_discriminator:
            self.loss_names = ["D", "D_real", "D_fake"]
            if any(gp_mode in opt.gan_mode for gp_mode in ["gp", "lp"]):
                self.loss_names += ["D_gp"]
        self.loss_names += ["G"]
        if opt.lambda_gan:
            self.loss_names += ["G_gan"]
        # Define optimizers
        self.optimizer_G = optimizers.define_optimizer(
            self.net_generator.parameters(), opt, "G"
        )
        self.optimizer_D = optimizers.define_optimizer(
            self.net_discriminator.parameters(), opt, "D"
        )
        self.optimizer_names = ("G", "D")

    # TODO: decode cloth visual
    self.visual_names = [
        "textures_unnormalized",
        "cloths_decoded",
        "fakes",
        "fakes_scaled",
    ]
    if self.is_train:
        self.visual_names.append("targets_unnormalized")

        # Define additional losses for the generator.
        self.criterion_L1 = nn.L1Loss().to(self.device)
        # NOTE(review): `modules.losses` here vs `modules.loss` for GANLoss
        # above — confirm which module name is correct.
        self.criterion_perceptual = modules.losses.PerceptualLoss(
            use_style=opt.lambda_style != 0
        ).to(self.device)
        # only report losses whose lambda weight is nonzero
        for loss in ["l1", "content", "style"]:
            if getattr(opt, "lambda_" + loss) != 0:
                self.loss_names.append(f"G_{loss}")